import argschema.fields
import marshmallow as mm
import numpy as np
from allensdk.brain_observatory.argschema_utilities import (
RaisingSchema,
check_read_access,
check_write_access,
)
from argschema import ArgSchema
from argschema.fields import (
Boolean,
DateTime,
Dict,
Float,
Int,
List,
LogLevel,
Nested,
String,
)
[docs]class BaseBehaviorSessionDataSchema(RaisingSchema):
behavior_session_id = Int(
required=True,
description=(
"Unique identifier for the "
"behavior session to write into "
"NWB format"
),
)
driver_line = List(
String,
required=True,
cli_as_single_argument=True,
description="Genetic driver line(s) of subject",
)
reporter_line = List(
String,
required=True,
cli_as_single_argument=True,
description="Genetic reporter line(s) of subject",
)
full_genotype = String(
required=True, description="Full genotype of subject"
)
rig_name = String(
required=True,
description=(
"Name of experimental rig used for " "the behavior session"
),
)
date_of_acquisition = String(
required=True,
description=(
"Date of acquisition of " "behavior session, in string " "format"
),
)
external_specimen_name = Int(
required=True, description="LabTracks ID of the subject"
)
behavior_stimulus_file = argschema.fields.InputFile(
required=True,
description=("Path of behavior_stimulus " "camstim *.pkl file"),
)
date_of_birth = String(required=True, description="Subject date of birth")
sex = String(required=True, description="Subject sex")
age = String(required=True, description="Subject age")
[docs]class Channel(RaisingSchema):
[docs] @mm.pre_load
def set_field_defaults(self, data, **kwargs):
if data.get("filtering") is None:
data["filtering"] = (
"AP band: 500 Hz high-pass; " "LFP band: 1000 Hz low-pass"
)
if data.get("structure_acronym") is None:
data["structure_acronym"] = ""
return data
id = Int(required=True)
probe_id = Int(required=True)
valid_data = Boolean(required=True)
probe_channel_number = Int(required=True)
probe_vertical_position = Int(required=True)
probe_horizontal_position = Int(required=True)
structure_id = Int(required=True, allow_none=True)
structure_acronym = String(required=True)
anterior_posterior_ccf_coordinate = Float(allow_none=True)
dorsal_ventral_ccf_coordinate = Float(allow_none=True)
left_right_ccf_coordinate = Float(allow_none=True)
impedence = Float(required=False, allow_none=True, default=None)
filtering = String(required=False)
[docs] @mm.post_load
def set_impedence_default(self, data, **kwargs):
# This must be a post_load operation as np.nan is not a valid
# JSON format 'float' type for the Marshmallow `Float` field
# (so validation fails if this is set at pre_load)
if data.get("impedence") is None:
data["impedence"] = np.nan
return data
[docs]class Unit(RaisingSchema):
id = Int(required=True)
peak_channel_id = Int(required=True)
local_index = Int(
required=True,
help="within-probe index of this unit.",
)
cluster_id = Int(
required=True,
help="within-probe identifier of this unit",
)
quality = String(required=True)
firing_rate = Float(required=True)
snr = Float(required=True, allow_none=True)
isi_violations = Float(required=True)
presence_ratio = Float(required=True)
amplitude_cutoff = Float(required=True)
isolation_distance = Float(required=True, allow_none=True)
l_ratio = Float(required=True, allow_none=True)
d_prime = Float(required=True, allow_none=True)
nn_hit_rate = Float(required=True, allow_none=True)
nn_miss_rate = Float(required=True, allow_none=True)
max_drift = Float(required=True, allow_none=True)
cumulative_drift = Float(required=True, allow_none=True)
silhouette_score = Float(required=True, allow_none=True)
waveform_duration = Float(required=True, allow_none=True)
waveform_halfwidth = Float(required=True, allow_none=True)
PT_ratio = Float(required=True, allow_none=True)
repolarization_slope = Float(required=True, allow_none=True)
recovery_slope = Float(required=True, allow_none=True)
amplitude = Float(required=True, allow_none=True)
spread = Float(required=True, allow_none=True)
velocity_above = Float(required=True, allow_none=True)
velocity_below = Float(required=True, allow_none=True)
[docs]class Lfp(RaisingSchema):
input_data_path = String(required=True, validate=check_read_access)
input_timestamps_path = String(required=True, validate=check_read_access)
input_channels_path = String(required=True, validate=check_read_access)
output_path = String(required=False)
[docs]class Probe(RaisingSchema):
id = Int(required=True)
name = String(required=True)
spike_times_path = String(required=True, validate=check_read_access)
spike_clusters_file = String(required=True, validate=check_read_access)
mean_waveforms_path = String(required=True, validate=check_read_access)
channels = Nested(Channel, many=True, required=True)
units = Nested(Unit, many=True, required=True)
lfp = Nested(Lfp, many=False, required=False, allow_none=True)
csd_path = String(
required=False,
validate=check_read_access,
allow_none=True,
help="""path to h5 file containing calculated current
source density""",
)
sampling_rate = Float(
default=30000.0,
help="""sampling rate (Hz, master clock) at which raw data were
acquired on this probe""",
)
lfp_sampling_rate = Float(
default=2500.0,
allow_none=True,
help="""sampling rate of LFP data on this probe""",
)
temporal_subsampling_factor = Float(
default=2.0,
allow_none=True,
help="""subsampling factor applied to lfp data for
this probe (across time)""",
)
spike_amplitudes_path = String(
validate=check_read_access,
help="""path to npy file containing scale factor applied to the
kilosort template used to extract each spike""",
)
spike_templates_path = String(
validate=check_read_access,
help="""path to file associating each spike with a kilosort template""", # noqa: E501
)
templates_path = String(
validate=check_read_access,
help="""path to file containing an (nTemplates)x(nSamples)x(nUnits)
array of kilosort templates""",
)
inverse_whitening_matrix_path = String(
validate=check_read_access,
help="""Kilosort templates are whitened. In order to use them for
scaling spike amplitudes to volts, we need to remove
the whitening""",
)
amplitude_scale_factor = Float(
default=0.195e-6,
allow_none=True,
help="""amplitude scale factor converting raw amplitudes to Volts.
Default converts from bits -> uV -> V""",
)
scale_mean_waveform_and_csd = Float(
default=1,
allow_none=True,
help="""Amount to scale the mean waveform and CSD by. (data / scale).
This is a fix for a set of data documented in the change log.
The values for unit amplitudes were changed in the input_json
file and do not use this scale.
If the data in LIMS for these sessions is updated, this scaling
is not needed. Default is 1"""
)
[docs]class InvalidEpoch(RaisingSchema):
id = Int(required=True)
type = String(required=True)
label = String(required=True)
start_time = Float(required=True)
end_time = Float(required=True)
[docs]class BaseNeuropixelsSchema(ArgSchema):
"""Base schema for writing NWB files for projects with
behavior + ecephys"""
probes = Nested(
Probe,
many=True,
required=True,
help="records of the individual probes used for this experiment",
)
optotagging_table_path = argschema.fields.InputFile(
required=False,
help="""file at this path contains information about the optogenetic
stimulation applied during this experiment""",
)
running_speed_path = String(
required=False,
help="""data collected about the running behavior of the experiment's
subject""",
)
eye_tracking_rig_geometry = Dict(
required=False,
help="""Mapping containing information about session rig geometry used
for eye gaze mapping.""",
)
[docs]class ProbeOutputs(RaisingSchema):
nwb_path = String(required=True)
id = Int(required=True)
[docs]class OutputSchema(RaisingSchema):
nwb_path = String(required=True, description="path to output file")
probe_outputs = Nested(ProbeOutputs, required=True, many=True)