Source code for allensdk.brain_observatory.ecephys.write_nwb.schemas

import argschema.fields
import marshmallow as mm
import numpy as np
from allensdk.brain_observatory.argschema_utilities import (
    RaisingSchema,
    check_read_access,
    check_write_access,
)
from argschema import ArgSchema
from argschema.fields import (
    Boolean,
    DateTime,
    Dict,
    Float,
    Int,
    List,
    LogLevel,
    Nested,
    String,
)


[docs]class BaseBehaviorSessionDataSchema(RaisingSchema): behavior_session_id = Int( required=True, description=( "Unique identifier for the " "behavior session to write into " "NWB format" ), ) driver_line = List( String, required=True, cli_as_single_argument=True, description="Genetic driver line(s) of subject", ) reporter_line = List( String, required=True, cli_as_single_argument=True, description="Genetic reporter line(s) of subject", ) full_genotype = String( required=True, description="Full genotype of subject" ) rig_name = String( required=True, description=( "Name of experimental rig used for " "the behavior session" ), ) date_of_acquisition = String( required=True, description=( "Date of acquisition of " "behavior session, in string " "format" ), ) external_specimen_name = Int( required=True, description="LabTracks ID of the subject" ) behavior_stimulus_file = argschema.fields.InputFile( required=True, description=("Path of behavior_stimulus " "camstim *.pkl file"), ) date_of_birth = String(required=True, description="Subject date of birth") sex = String(required=True, description="Subject sex") age = String(required=True, description="Subject age")
[docs]class Channel(RaisingSchema):
[docs] @mm.pre_load def set_field_defaults(self, data, **kwargs): if data.get("filtering") is None: data["filtering"] = ( "AP band: 500 Hz high-pass; " "LFP band: 1000 Hz low-pass" ) if data.get("structure_acronym") is None: data["structure_acronym"] = "" return data
id = Int(required=True) probe_id = Int(required=True) valid_data = Boolean(required=True) probe_channel_number = Int(required=True) probe_vertical_position = Int(required=True) probe_horizontal_position = Int(required=True) structure_id = Int(required=True, allow_none=True) structure_acronym = String(required=True) anterior_posterior_ccf_coordinate = Float(allow_none=True) dorsal_ventral_ccf_coordinate = Float(allow_none=True) left_right_ccf_coordinate = Float(allow_none=True) impedence = Float(required=False, allow_none=True, default=None) filtering = String(required=False)
[docs] @mm.post_load def set_impedence_default(self, data, **kwargs): # This must be a post_load operation as np.nan is not a valid # JSON format 'float' type for the Marshmallow `Float` field # (so validation fails if this is set at pre_load) if data.get("impedence") is None: data["impedence"] = np.nan return data
[docs]class Unit(RaisingSchema): id = Int(required=True) peak_channel_id = Int(required=True) local_index = Int( required=True, help="within-probe index of this unit.", ) cluster_id = Int( required=True, help="within-probe identifier of this unit", ) quality = String(required=True) firing_rate = Float(required=True) snr = Float(required=True, allow_none=True) isi_violations = Float(required=True) presence_ratio = Float(required=True) amplitude_cutoff = Float(required=True) isolation_distance = Float(required=True, allow_none=True) l_ratio = Float(required=True, allow_none=True) d_prime = Float(required=True, allow_none=True) nn_hit_rate = Float(required=True, allow_none=True) nn_miss_rate = Float(required=True, allow_none=True) max_drift = Float(required=True, allow_none=True) cumulative_drift = Float(required=True, allow_none=True) silhouette_score = Float(required=True, allow_none=True) waveform_duration = Float(required=True, allow_none=True) waveform_halfwidth = Float(required=True, allow_none=True) PT_ratio = Float(required=True, allow_none=True) repolarization_slope = Float(required=True, allow_none=True) recovery_slope = Float(required=True, allow_none=True) amplitude = Float(required=True, allow_none=True) spread = Float(required=True, allow_none=True) velocity_above = Float(required=True, allow_none=True) velocity_below = Float(required=True, allow_none=True)
[docs]class Lfp(RaisingSchema): input_data_path = String(required=True, validate=check_read_access) input_timestamps_path = String(required=True, validate=check_read_access) input_channels_path = String(required=True, validate=check_read_access) output_path = String(required=False)
[docs]class Probe(RaisingSchema): id = Int(required=True) name = String(required=True) spike_times_path = String(required=True, validate=check_read_access) spike_clusters_file = String(required=True, validate=check_read_access) mean_waveforms_path = String(required=True, validate=check_read_access) channels = Nested(Channel, many=True, required=True) units = Nested(Unit, many=True, required=True) lfp = Nested(Lfp, many=False, required=False, allow_none=True) csd_path = String( required=False, validate=check_read_access, allow_none=True, help="""path to h5 file containing calculated current source density""", ) sampling_rate = Float( default=30000.0, help="""sampling rate (Hz, master clock) at which raw data were acquired on this probe""", ) lfp_sampling_rate = Float( default=2500.0, allow_none=True, help="""sampling rate of LFP data on this probe""", ) temporal_subsampling_factor = Float( default=2.0, allow_none=True, help="""subsampling factor applied to lfp data for this probe (across time)""", ) spike_amplitudes_path = String( validate=check_read_access, help="""path to npy file containing scale factor applied to the kilosort template used to extract each spike""", ) spike_templates_path = String( validate=check_read_access, help="""path to file associating each spike with a kilosort template""", # noqa: E501 ) templates_path = String( validate=check_read_access, help="""path to file containing an (nTemplates)x(nSamples)x(nUnits) array of kilosort templates""", ) inverse_whitening_matrix_path = String( validate=check_read_access, help="""Kilosort templates are whitened. In order to use them for scaling spike amplitudes to volts, we need to remove the whitening""", ) amplitude_scale_factor = Float( default=0.195e-6, allow_none=True, help="""amplitude scale factor converting raw amplitudes to Volts. Default converts from bits -> uV -> V""", ) scale_mean_waveform_and_csd = Float( default=1, allow_none=True, help="""Amount to scale the mean waveform and CSD by. (data / scale). This is a fix for a set of data documented in the change log. The values for unit amplitudes were changed in the input_json file and do not use this scale. If the data in LIMS for these sessions is updated, this scaling is not needed. Default is 1""" )
[docs]class InvalidEpoch(RaisingSchema): id = Int(required=True) type = String(required=True) label = String(required=True) start_time = Float(required=True) end_time = Float(required=True)
[docs]class SessionMetadata(RaisingSchema): specimen_name = String(required=True) age_in_days = Float(required=True) full_genotype = String(required=True) strain = String(required=True) sex = String(required=True) stimulus_name = String(required=True) species = String(required=True) donor_id = Int(required=True)
[docs]class BaseNeuropixelsSchema(ArgSchema): """Base schema for writing NWB files for projects with behavior + ecephys""" probes = Nested( Probe, many=True, required=True, help="records of the individual probes used for this experiment", ) optotagging_table_path = argschema.fields.InputFile( required=False, help="""file at this path contains information about the optogenetic stimulation applied during this experiment""", ) running_speed_path = String( required=False, help="""data collected about the running behavior of the experiment's subject""", ) eye_tracking_rig_geometry = Dict( required=False, help="""Mapping containing information about session rig geometry used for eye gaze mapping.""", )
[docs]class VCNInputSchema(BaseNeuropixelsSchema): """Input schema for visual coding neuropixels project"""
[docs] class Meta: unknown = mm.RAISE
log_level = LogLevel( default="INFO", help="set the logging level of the module" ) output_path = String( required=True, validate=check_write_access, help="write outputs to here", ) session_id = Int( required=True, help="unique identifier for this ecephys session" ) session_start_time = DateTime( required=True, help="the date and time (iso8601) at which the session started", ) stimulus_table_path = String( required=True, validate=check_read_access, help="path to stimulus table file", ) invalid_epochs = Nested( InvalidEpoch, many=True, required=True, help="epochs with invalid data" ) session_sync_path = String( required=True, validate=check_read_access, help="""Path to an h5 experiment session sync file (*.sync). This file relates events from different acquisition modalities to one another in time.""", ) pool_size = Int( default=3, help="number of child processes used to write probewise lfp files", ) eye_dlc_ellipses_path = String( required=False, validate=check_read_access, help="""h5 filepath containing raw ellipse fits produced by Deep Lab Cuts of subject eye, pupil, and corneal reflections during experiment""", ) eye_gaze_mapping_path = String( required=False, allow_none=True, help="""h5 filepath containing eye gaze behavior of the experiment's subject""", ) session_metadata = Nested( SessionMetadata, allow_none=True, required=False, help="miscellaneous information describing this session" "", ) running_speed_path = String( required=True, help="""data collected about the running behavior of the experiment's subject""", )
[docs]class ProbeOutputs(RaisingSchema): nwb_path = String(required=True) id = Int(required=True)
[docs]class OutputSchema(RaisingSchema): nwb_path = String(required=True, description="path to output file") probe_outputs = Nested(ProbeOutputs, required=True, many=True)