Source code for allensdk.brain_observatory.ecephys.ecephys_session

import warnings
from collections.abc import Collection
from collections import defaultdict

import xarray as xr
import numpy as np
import pandas as pd
import scipy.stats

from allensdk.core.lazy_property import LazyPropertyMixin
from allensdk.brain_observatory.ecephys.ecephys_session_api import EcephysSessionApi, EcephysNwbSessionApi, EcephysNwb1Api
from allensdk.brain_observatory.ecephys.stimulus_table import naming_utilities
from allensdk.brain_observatory.ecephys.stimulus_table._schemas import default_stimulus_renames, default_column_renames


NON_STIMULUS_PARAMETERS = tuple([
    'start_time',
    'stop_time',
    'duration',
    'stimulus_block',
    "stimulus_condition_id"
])  # stimulus_presentation column names not describing a parameter of a stimulus


class EcephysSession(LazyPropertyMixin):
    ''' Represents data from a single EcephysSession

    Attributes
    ----------
    units : pd.DataFrame
        A table whose rows are sorted units (putative neurons) and whose columns are characteristics
        of those units.
        Index is:
            unit_id : int
                Unique integer identifier for this unit.
        Columns are:
            firing_rate : float
                This unit's firing rate (spikes / s) calculated over the window of that unit's activity
                (the time from its first detected spike to its last).
            isi_violations : float
                Estimate of this unit's contamination rate (larger means that more of the spikes assigned
                to this unit probably originated from other neurons). Calculated as a ratio of the firing
                rate of the unit over periods where spikes would be isi-violating vs the total firing
                rate of the unit.
            peak_channel_id : int
                Unique integer identifier for this unit's peak channel. A unit's peak channel is the channel on
                which its peak-to-trough amplitude difference is maximized. This is assessed using the kilosort 2
                templates rather than the mean waveforms for a unit.
            snr : float
                Signal to noise ratio for this unit.
            probe_horizontal_position : numeric
                The horizontal (short-axis) position of this unit's peak channel in microns.
            probe_vertical_position : numeric
                The vertical (long-axis, lower values are closer to the probe base) position of
                this unit's peak channel in microns.
            probe_id : int
                Unique integer identifier for this unit's probe.
            probe_description : str
                Human-readable description carrying miscellaneous information about this unit's probe.
            location : str
                Gross-scale location of this unit's probe.
    spike_times : dict
        Maps integer unit ids to arrays of spike times (float) for those units.
    running_speed : RunningSpeed
        NamedTuple with two fields
            timestamps : numpy.ndarray
                Timestamps of running speed data samples
            values : np.ndarray
                Running speed of the experimental subject (in cm / s).
    mean_waveforms : dict
        Maps integer unit ids to xarray.DataArrays containing mean spike waveforms for that unit.
    stimulus_presentations : pd.DataFrame
        Table whose rows are stimulus presentations and whose columns are presentation characteristics.
        A stimulus presentation is the smallest unit of distinct stimulus presentation and lasts for
        (usually) one 60 Hz frame. Since not all parameters are relevant to all stimuli, this table
        contains many 'null' values.
        Index is
            stimulus_presentation_id : int
                Unique identifier for this stimulus presentation
        Columns are
            start_time : float
                Time (s) at which this presentation began
            stop_time : float
                Time (s) at which this presentation ended
            duration : float
                stop_time - start_time (s). Included for convenience.
            stimulus_name : str
                Identifies the stimulus family (e.g. "drifting_gratings" or "natural_movie_3") used
                for this presentation. The stimulus family, along with relevant parameter values, provides the
                information required to reconstruct the stimulus presented during this presentation. The empty
                string indicates a blank period.
            stimulus_block : numeric
                A stimulus block is made by sequentially presenting presentations from the same stimulus family.
                This value is the index of the block which contains this presentation. During a blank period,
                this is 'null'.
            TF : float
                Temporal frequency, or 'null' when not appropriate.
            SF : float
                Spatial frequency, or 'null' when not appropriate
            Ori : float
                Orientation (in degrees) or 'null' when not appropriate
            Contrast : float
            Pos_x : float
            Pos_y : float
            Color : numeric
            Image : numeric
            Phase : float
            stimulus_condition_id : integer
                identifies the session-unique stimulus condition (permutation of parameters) to which this
                presentation belongs
    stimulus_conditions : pd.DataFrame
        Each row is a unique permutation (within this session) of stimulus parameters presented during this
        experiment. Columns are as stimulus presentations, sans start_time, stop_time, stimulus_block, and duration.
    inter_presentation_intervals : pd.DataFrame
        The elapsed time between each immediately sequential pair of stimulus presentations. This is a dataframe
        with a two-level multiindex (levels are 'from_presentation_id' and 'to_presentation_id'). It has a single
        column, 'interval', which reports the elapsed time between the two presentations in seconds on the
        experiment's master clock.

    '''

    DETAILED_STIMULUS_PARAMETERS = (
        "colorSpace",
        "flipHoriz",
        "flipVert",
        "depth",
        "interpolate",
        "mask",
        "opacity",
        "rgbPedestal",
        "tex",
        "texRes",
        "units",
        "rgb",
        "signalDots",
        "noiseDots",
        "fieldSize",
        "fieldShape",
        "fieldPos",
        "nDots",
        "dotSize",
        "dotLife",
        "color_triplet"
    )

    @property
    def num_units(self):
        return self._units.shape[0]

    @property
    def num_probes(self):
        return self.probes.shape[0]

    @property
    def num_channels(self):
        return self.channels.shape[0]

    @property
    def num_stimulus_presentations(self):
        return self.stimulus_presentations.shape[0]

    @property
    def stimulus_names(self):
        return self.stimulus_presentations['stimulus_name'].unique().tolist()

    @property
    def stimulus_conditions(self):
        # accessing stimulus_presentations triggers the lazy build that also populates _stimulus_conditions
        self.stimulus_presentations
        return self._stimulus_conditions

    @property
    def rig_geometry_data(self):
        if self._rig_metadata:
            return self._rig_metadata["rig_geometry_data"]
        else:
            return None

    @property
    def rig_equipment_name(self):
        if self._rig_metadata:
            return self._rig_metadata["rig_equipment"]
        else:
            return None

    @property
    def specimen_name(self):
        return self._metadata["specimen_name"]

    @property
    def age_in_days(self):
        return self._metadata["age_in_days"]

    @property
    def sex(self):
        return self._metadata["sex"]

    @property
    def full_genotype(self):
        return self._metadata["full_genotype"]

    @property
    def session_type(self):
        return self._metadata["stimulus_name"]

    @property
    def units(self):
        return self._units.drop(columns=['width_rf', 'height_rf', 'on_screen_rf',
                                         'time_to_peak_fl', 'time_to_peak_rf', 'time_to_peak_sg',
                                         'sustained_idx_fl', 'time_to_peak_dg'],
                                errors='ignore')

    @property
    def structure_acronyms(self):
        return self.channels["ecephys_structure_acronym"].unique().tolist()

    @property
    def structurewise_unit_counts(self):
        return self.units["ecephys_structure_acronym"].value_counts()

    @property
    def metadata(self):
        return {
            "specimen_name": self.specimen_name,
            "session_type": self.session_type,
            "full_genotype": self.full_genotype,
            "sex": self.sex,
            "age_in_days": self.age_in_days,
            "rig_equipment_name": self.rig_equipment_name,
            "num_units": self.num_units,
            "num_channels": self.num_channels,
            "num_probes": self.num_probes,
            "num_stimulus_presentations": self.num_stimulus_presentations,
            "session_start_time": self.session_start_time,
            "ecephys_session_id": self.ecephys_session_id,
            "structure_acronyms": self.structure_acronyms,
            "stimulus_names": self.stimulus_names
        }

    @property
    def stimulus_presentations(self):
        return self.__class__._remove_detailed_stimulus_parameters(self._stimulus_presentations)

    @property
    def spike_times(self):
        if not hasattr(self, "_accessed_spike_times"):
            self._accessed_spike_times = True
            self._warn_invalid_spike_intervals()
        return self._spike_times

    def __init__(
        self,
        api: EcephysSessionApi,
        test: bool = False,
        **kwargs
    ):
        """ Construct an EcephysSession object, which provides access to detailed data for a single
        extracellular electrophysiology (neuropixels) session.

        Parameters
        ----------
        api :
            Used to access data, which is then cached on this object. Must expose the EcephysSessionApi
            interface. Standard options include instances of:
                EcephysNwbSessionApi :: reads data from a Neurodata Without Borders 2.0 file.
        test :
            If true, check during construction that this session's api is valid.

        """
        self.api: EcephysSessionApi = api

        self.ecephys_session_id = self.LazyProperty(self.api.get_ecephys_session_id)
        self.session_start_time = self.LazyProperty(self.api.get_session_start_time)
        self.running_speed = self.LazyProperty(self.api.get_running_speed)
        self.mean_waveforms = self.LazyProperty(self.api.get_mean_waveforms, wrappers=[self._build_mean_waveforms])
        self._spike_times = self.LazyProperty(self.api.get_spike_times, wrappers=[self._build_spike_times])
        self.optogenetic_stimulation_epochs = self.LazyProperty(self.api.get_optogenetic_stimulation)
        self.spike_amplitudes = self.LazyProperty(self.api.get_spike_amplitudes)

        self.probes = self.LazyProperty(self.api.get_probes)
        self.channels = self.LazyProperty(self.api.get_channels)

        self._stimulus_presentations = self.LazyProperty(self.api.get_stimulus_presentations,
                                                         wrappers=[self._build_stimulus_presentations,
                                                                   self._mask_invalid_stimulus_presentations])
        self.inter_presentation_intervals = self.LazyProperty(self._build_inter_presentation_intervals)
        self.invalid_times = self.LazyProperty(self.api.get_invalid_times)

        self._units = self.LazyProperty(self.api.get_units, wrappers=[self._build_units_table])
        self._rig_metadata = self.LazyProperty(self.api.get_rig_metadata)
        self._metadata = self.LazyProperty(self.api.get_metadata)

        if test:
            self.api.test()
    def get_current_source_density(self, probe_id):
        """ Obtain current source density (CSD) of the trial-averaged response to a flash stimulus for this probe.
        See allensdk.brain_observatory.ecephys.current_source_density for details of CSD calculation.

        CSD is computed with a 1D method (second spatial derivative) without prior spatial smoothing.
        Users should apply spatial smoothing of their choice (e.g., Gaussian filter) to the computed CSD.

        Parameters
        ----------
        probe_id : int
            identify the probe whose CSD data ought to be loaded

        Returns
        -------
        xr.DataArray :
            dimensions are channel (id) and time (seconds, relative to stimulus onset). Values are current source
            density assessed on that channel at that time (V/m^2)

        """
        return self.api.get_current_source_density(probe_id)
    def get_lfp(self, probe_id, mask_invalid_intervals=True):
        ''' Load an xarray DataArray with LFP data from channels on a single probe

        Parameters
        ----------
        probe_id : int
            identify the probe whose LFP data ought to be loaded
        mask_invalid_intervals : bool
            if True (default) will mask data in the invalid intervals with np.nan

        Returns
        -------
        xr.DataArray :
            dimensions are channel (id) and time (seconds). Values are sampled LFP data.

        Notes
        -----
        Unlike many other data access methods on this class, this one does not cache the loaded data in memory
        due to the large size of the LFP data.

        '''
        if mask_invalid_intervals:
            probe_name = self.probes.loc[probe_id]["description"]
            fail_tags = ["all_probes", probe_name]
            invalid_time_intervals = self._filter_invalid_times_by_tags(fail_tags)
            lfp = self.api.get_lfp(probe_id)
            time_points = lfp.time
            valid_time_points = self._get_valid_time_points(time_points, invalid_time_intervals)
            return lfp.where(cond=valid_time_points)
        else:
            return self.api.get_lfp(probe_id)
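    # Illustrative sketch (not part of the original module): one way get_lfp might be used to pull the
    # LFP for the first probe in a session and restrict it to a ten-second window. The dimension name
    # 'time' follows the docstring above and is an assumption, as is the chosen window.
    def _example_lfp_usage(self):
        probe_id = self.probes.index.values[0]
        lfp = self.get_lfp(probe_id)
        # slice out ten seconds of data on every channel of this probe
        return lfp.sel(time=slice(100.0, 110.0))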
    def _get_valid_time_points(self, time_points, invalid_time_intervals):

        all_time_points = xr.DataArray(
            name="time_points",
            data=[True] * len(time_points),
            dims=['time'],
            coords=[time_points]
        )

        valid_time_points = all_time_points
        for ix, invalid_time_interval in invalid_time_intervals.iterrows():
            invalid_time_points = (time_points >= invalid_time_interval['start_time']) & (time_points <= invalid_time_interval['stop_time'])
            valid_time_points = np.logical_and(valid_time_points, np.logical_not(invalid_time_points))

        return valid_time_points

    def _filter_invalid_times_by_tags(self, tags):
        """
        Parameters
        ----------
        tags : list
            tags to filter the invalid times on

        Returns
        -------
        pd.DataFrame of invalid times having tags
        """
        invalid_times = self.invalid_times.copy()
        if not invalid_times.empty:
            mask = invalid_times['tags'].apply(lambda x: any([t in x for t in tags]))
            invalid_times = invalid_times[mask]

        return invalid_times
    def get_inter_presentation_intervals_for_stimulus(self, stimulus_names):
        ''' Get a subset of this session's inter-presentation intervals, filtered by stimulus name.

        Parameters
        ----------
        stimulus_names : array-like of str
            The names of stimuli to include in the output.

        Returns
        -------
        pd.DataFrame :
            inter-presentation intervals, filtered to the requested stimulus names.

        '''
        stimulus_names = coerce_scalar(stimulus_names,
                                       f'expected stimulus_names to be a collection (list-like), but found {type(stimulus_names)}: {stimulus_names}')
        filtered_presentations = self.stimulus_presentations[self.stimulus_presentations['stimulus_name'].isin(stimulus_names)]
        filtered_ids = set(filtered_presentations.index.values)

        return self.inter_presentation_intervals[
            (self.inter_presentation_intervals.index.isin(filtered_ids, level='from_presentation_id'))
            & (self.inter_presentation_intervals.index.isin(filtered_ids, level='to_presentation_id'))
        ]
    def get_stimulus_table(self, stimulus_names=None, include_detailed_parameters=False, include_unused_parameters=False):
        '''Get a subset of stimulus presentations by name, with irrelevant parameters filtered off

        Parameters
        ----------
        stimulus_names : array-like of str
            The names of stimuli to include in the output.

        Returns
        -------
        pd.DataFrame :
            Rows are filtered presentations, columns are the relevant subset of stimulus parameters

        '''
        if stimulus_names is None:
            stimulus_names = self.stimulus_names

        stimulus_names = coerce_scalar(stimulus_names,
                                       f'expected stimulus_names to be a collection (list-like), but found {type(stimulus_names)}: {stimulus_names}')
        presentations = self._stimulus_presentations[self._stimulus_presentations['stimulus_name'].isin(stimulus_names)]

        if not include_detailed_parameters:
            presentations = self.__class__._remove_detailed_stimulus_parameters(presentations)

        if not include_unused_parameters:
            presentations = removed_unused_stimulus_presentation_columns(presentations)

        return presentations
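    # Illustrative sketch (not part of the original module): filtering the presentation table to a
    # single stimulus family. The name "drifting_gratings" is an assumption; the available names can
    # be read from the stimulus_names property, and the selected columns follow the class docstring.
    def _example_stimulus_table_usage(self):
        drifting = self.get_stimulus_table(["drifting_gratings"])
        return drifting[["start_time", "stop_time", "duration", "stimulus_condition_id"]].head()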
    def get_stimulus_epochs(self, duration_thresholds=None):
        """ Reports continuous periods of time during which a single kind of stimulus was presented

        Parameters
        ----------
        duration_thresholds : dict, optional
            keys are stimulus names, values are floating point durations in seconds. All epochs with
                - a given stimulus name
                - a duration shorter than the associated threshold
            will be removed from the results

        """
        if duration_thresholds is None:
            duration_thresholds = {"spontaneous_activity": 90.0}

        presentations = self.stimulus_presentations.copy()
        diff_indices = nan_intervals(presentations["stimulus_block"].values)

        epochs = []
        for left, right in zip(diff_indices[:-1], diff_indices[1:]):
            epochs.append({
                "start_time": presentations.iloc[left]["start_time"],
                "stop_time": presentations.iloc[right - 1]["stop_time"],
                "stimulus_name": presentations.iloc[left]["stimulus_name"],
                "stimulus_block": presentations.iloc[left]["stimulus_block"]
            })
        epochs = pd.DataFrame(epochs)
        epochs["duration"] = epochs["stop_time"] - epochs["start_time"]

        for key, threshold in duration_thresholds.items():
            epochs = epochs[
                (epochs["stimulus_name"] != key)
                | (epochs["duration"] >= threshold)
            ]

        return epochs.loc[:, ["start_time", "stop_time", "duration", "stimulus_name", "stimulus_block"]]
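    # Illustrative sketch (not part of the original module): summing epoch durations per stimulus name,
    # using the columns returned by get_stimulus_epochs above.
    def _example_stimulus_epoch_durations(self):
        epochs = self.get_stimulus_epochs()
        return epochs.groupby("stimulus_name")["duration"].sum()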
    def get_invalid_times(self):
        """ Report invalid time intervals with tags describing the scope of invalid data

        The tags format: [scope,scope_id,label]

        scope:
            'EcephysSession': data is invalid across session
            'EcephysProbe': data is invalid for a single probe
        label:
            'all_probes': gain fluctuations on the Neuropixels probe result in missed spikes and LFP saturation events
            'stimulus' : very long frames (>3x the normal frame length) make any stimulus-locked analysis invalid
            'probe#': probe # stopped sending data during this interval (spikes and LFP samples will be missing)
            'optotagging': missing optotagging data

        Returns
        -------
        pd.DataFrame :
            Rows are invalid intervals, columns are 'start_time' (s), 'stop_time' (s), 'tags'
        """

        return self.invalid_times
    def get_pupil_data(self, suppress_pupil_data: bool = True) -> pd.DataFrame:
        """Return a dataframe with eye tracking data

        Parameters
        ----------
        suppress_pupil_data : bool, optional
            Whether or not to suppress eye gaze mapping data in the output
            dataframe, by default True.

        Returns
        -------
        pd.DataFrame
            Contains columns for eye, pupil and cr ellipse fits:
                *_center_x
                *_center_y
                *_height
                *_width
                *_phi
            May also contain raw/filtered columns for gaze mapping if
            suppress_pupil_data is set to False:
                *_eye_area
                *_pupil_area
                *_screen_coordinates_x_cm
                *_screen_coordinates_y_cm
                *_screen_coordinates_spherical_x_deg
                *_screen_coordinates_spherical_y_deg
        """
        return self.api.get_pupil_data(suppress_pupil_data=suppress_pupil_data)
    def _mask_invalid_stimulus_presentations(self, stimulus_presentations):
        """Mask invalid stimulus presentations

        Find stimulus presentations overlapping with invalid times.
        Mask stimulus names with "invalid_presentation", keep "start_time" and "stop_time", mask remaining data with np.nan

        Parameters
        ----------
        stimulus_presentations : pd.DataFrame
            table including all stimulus presentations

        Returns
        -------
        pd.DataFrame :
            table with masked invalid presentations

        """
        fail_tags = ["stimulus"]
        invalid_times = self._filter_invalid_times_by_tags(fail_tags)

        for ix_sp, sp in stimulus_presentations.iterrows():
            stim_epoch = sp['start_time'], sp['stop_time']

            for ix_it, it in invalid_times.iterrows():
                invalid_interval = it['start_time'], it['stop_time']
                if _overlap(stim_epoch, invalid_interval):
                    stimulus_presentations.iloc[ix_sp, :] = np.nan
                    stimulus_presentations.at[ix_sp, "stimulus_name"] = "invalid_presentation"
                    stimulus_presentations.at[ix_sp, "start_time"] = stim_epoch[0]
                    stimulus_presentations.at[ix_sp, "stop_time"] = stim_epoch[1]

        return stimulus_presentations
    def presentationwise_spike_counts(self,
                                      bin_edges,
                                      stimulus_presentation_ids,
                                      unit_ids,
                                      binarize=False,
                                      dtype=None,
                                      large_bin_size_threshold=0.001,
                                      time_domain_callback=None):
        ''' Build an array of spike counts surrounding stimulus onset per unit and stimulus frame.

        Parameters
        ----------
        bin_edges : numpy.ndarray
            Spikes will be counted into the bins defined by these edges. Values are in seconds, relative
            to stimulus onset.
        stimulus_presentation_ids : array-like
            Filter to these stimulus presentations
        unit_ids : array-like
            Filter to these units
        binarize : bool, optional
            If true, all counts greater than 0 will be treated as 1. This results in lower storage overhead,
            but is only reasonable if bin sizes are fine (<= 1 millisecond).
        large_bin_size_threshold : float, optional
            If binarize is True and the largest bin width is greater than this value, a warning will be emitted.
        time_domain_callback : callable, optional
            The time domain is a numpy array whose values are trial-aligned bin edges (each row is aligned
            to a different trial). This optional function will be applied to the time domain before counting spikes.

        Returns
        -------
        xarray.DataArray :
            Data array whose dimensions are stimulus presentation, unit, and time bin and whose values are
            spike counts.

        '''
        stimulus_presentations = self._filter_owned_df('stimulus_presentations', ids=stimulus_presentation_ids)
        units = self._filter_owned_df('units', ids=unit_ids)

        largest_bin_size = np.amax(np.diff(bin_edges))
        if binarize and largest_bin_size > large_bin_size_threshold:
            warnings.warn(
                f'You\'ve elected to binarize spike counts, but your maximum bin width is {largest_bin_size:2.5f} seconds. '
                'Binarizing spike counts with such a large bin width can cause significant loss of accuracy! '
                f'Please consider only binarizing spike counts when your bins are <= {large_bin_size_threshold} seconds wide.'
            )

        bin_edges = np.array(bin_edges)
        domain = build_time_window_domain(bin_edges, stimulus_presentations['start_time'].values, callback=time_domain_callback)

        out_of_order = np.where(np.diff(domain, axis=1) < 0)
        if len(out_of_order[0]) > 0:
            out_of_order_time_bins = [(row, col) for row, col in zip(*out_of_order)]
            raise ValueError(f"The time domain specified contains out-of-order bin edges at indices: {out_of_order_time_bins}")

        ends = domain[:, -1]
        starts = domain[:, 0]
        time_diffs = starts[1:] - ends[:-1]
        overlapping = np.where(time_diffs < 0)[0]

        if len(overlapping) > 0:
            # Ignoring intervals that overlap multiple time bins because trying to figure that out would take O(n)
            overlapping = [(s, s + 1) for s in overlapping]
            warnings.warn(f"You've specified some overlapping time intervals between neighboring rows: {overlapping}, "
                          f"with a maximum overlap of {np.abs(np.min(time_diffs))} seconds.")

        tiled_data = build_spike_histogram(
            domain, self.spike_times, units.index.values, dtype=dtype, binarize=binarize
        )

        tiled_data = xr.DataArray(
            name='spike_counts',
            data=tiled_data,
            coords={
                'stimulus_presentation_id': stimulus_presentations.index.values,
                'time_relative_to_stimulus_onset': bin_edges[:-1] + np.diff(bin_edges) / 2,
                'unit_id': units.index.values
            },
            dims=['stimulus_presentation_id', 'time_relative_to_stimulus_onset', 'unit_id']
        )

        return tiled_data
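    # Illustrative sketch (not part of the original module): building a peristimulus time histogram
    # for one stimulus family by averaging the binned counts over presentations. The stimulus name
    # "flashes", the bin width, and the time window are assumptions.
    def _example_psth(self):
        presentation_ids = self.get_stimulus_table(["flashes"]).index.values
        bin_edges = np.arange(-0.05, 0.25, 0.01)  # 10 ms bins around stimulus onset
        counts = self.presentationwise_spike_counts(
            bin_edges=bin_edges,
            stimulus_presentation_ids=presentation_ids,
            unit_ids=self.units.index.values
        )
        # average over presentations to obtain a time-bin x unit PSTH
        return counts.mean(dim="stimulus_presentation_id")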
    def presentationwise_spike_times(self, stimulus_presentation_ids=None, unit_ids=None):
        ''' Produce a table associating spike times with units and stimulus presentations

        Parameters
        ----------
        stimulus_presentation_ids : array-like
            Filter to these stimulus presentations
        unit_ids : array-like
            Filter to these units

        Returns
        -------
        pandas.DataFrame :
            Index is
                spike_time : float
                    On the session's master clock.
            Columns are
                stimulus_presentation_id : int
                    The stimulus presentation on which this spike occurred.
                unit_id : int
                    The unit that emitted this spike.
        '''
        stimulus_presentations = self._filter_owned_df('stimulus_presentations', ids=stimulus_presentation_ids)
        units = self._filter_owned_df('units', ids=unit_ids)

        presentation_times = np.zeros([stimulus_presentations.shape[0] * 2])
        presentation_times[::2] = np.array(stimulus_presentations['start_time'])
        presentation_times[1::2] = np.array(stimulus_presentations['stop_time'])
        all_presentation_ids = np.array(stimulus_presentations.index.values)

        presentation_ids = []
        unit_ids = []
        spike_times = []

        for ii, unit_id in enumerate(units.index.values):
            data = self.spike_times[unit_id]
            indices = np.searchsorted(presentation_times, data) - 1

            index_valid = indices % 2 == 0
            presentations = all_presentation_ids[np.floor(indices / 2).astype(int)]

            sorder = np.argsort(presentations)
            presentations = presentations[sorder]
            index_valid = index_valid[sorder]
            data = data[sorder]

            changes = np.where(np.ediff1d(presentations, to_begin=1, to_end=1))[0]
            for ii, jj in zip(changes[:-1], changes[1:]):
                values = data[ii:jj][index_valid[ii:jj]]
                if values.size == 0:
                    continue

                unit_ids.append(np.zeros([values.size]) + unit_id)
                presentation_ids.append(np.zeros([values.size]) + presentations[ii])
                spike_times.append(values)

        if not spike_times:
            # If there are no units firing during the given stimulus return an empty dataframe
            return pd.DataFrame(columns=['spike_times', 'stimulus_presentation', 'unit_id',
                                         'time_since_stimulus_presentation_onset'])

        spike_df = pd.DataFrame({
            'stimulus_presentation_id': np.concatenate(presentation_ids).astype(int),
            'unit_id': np.concatenate(unit_ids).astype(int)
        }, index=pd.Index(np.concatenate(spike_times), name='spike_time'))

        # Add time since stimulus presentation onset
        onset_times = self._filter_owned_df("stimulus_presentations", ids=all_presentation_ids)["start_time"]
        spikes_with_onset = spike_df.join(onset_times, on=["stimulus_presentation_id"])
        spikes_with_onset["time_since_stimulus_presentation_onset"] = (
            spikes_with_onset.index - spikes_with_onset["start_time"]
        )

        spikes_with_onset.sort_values('spike_time', axis=0, inplace=True)
        spikes_with_onset.drop(columns=["start_time"], inplace=True)

        return spikes_with_onset
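    # Illustrative sketch (not part of the original module): latency of the first spike of a single
    # unit after each presentation of a stimulus, using the columns documented above. The stimulus
    # name "flashes" and the choice of unit are assumptions.
    def _example_first_spike_latencies(self):
        presentation_ids = self.get_stimulus_table(["flashes"]).index.values
        spikes = self.presentationwise_spike_times(
            stimulus_presentation_ids=presentation_ids,
            unit_ids=self.units.index.values[:1]
        )
        return spikes.groupby("stimulus_presentation_id")["time_since_stimulus_presentation_onset"].min()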
    def conditionwise_spike_statistics(self, stimulus_presentation_ids=None, unit_ids=None, use_rates=False):
        """ Produce summary statistics for each distinct stimulus condition

        Parameters
        ----------
        stimulus_presentation_ids : array-like
            identifies stimulus presentations from which spikes will be considered
        unit_ids : array-like
            identifies units whose spikes will be considered
        use_rates : bool, optional
            If True, use firing rates. If False, use spike counts.

        Returns
        -------
        pd.DataFrame :
            Rows are indexed by unit id and stimulus condition id. Values are summary statistics describing spikes
            emitted by a specific unit across presentations within a specific condition.

        """
        # TODO: Need to return an empty df if no matching unit-ids or presentation-ids are found
        # TODO: To use filter_owned_df() make sure to convert the results from a Series to a Dataframe
        stimulus_presentation_ids = (stimulus_presentation_ids if stimulus_presentation_ids is not None
                                     else self.stimulus_presentations.index.values)  # In case no ids were passed, use every presentation
        presentations = self.stimulus_presentations.loc[stimulus_presentation_ids, ["stimulus_condition_id", "duration"]]

        spikes = self.presentationwise_spike_times(
            stimulus_presentation_ids=stimulus_presentation_ids, unit_ids=unit_ids
        )

        if spikes.empty:
            # In the case there are no spikes
            spike_counts = pd.DataFrame({'spike_count': 0},
                                        index=pd.MultiIndex.from_product([stimulus_presentation_ids, unit_ids],
                                                                         names=['stimulus_presentation_id', 'unit_id']))
        else:
            spike_counts = spikes.copy()
            spike_counts["spike_count"] = np.zeros(spike_counts.shape[0])
            spike_counts = spike_counts.groupby(["stimulus_presentation_id", "unit_id"]).count()
            unit_ids = unit_ids if unit_ids is not None else spikes['unit_id'].unique()  # If not explicitly stated get unit ids from spikes table.
            spike_counts = spike_counts.reindex(pd.MultiIndex.from_product([stimulus_presentation_ids, unit_ids],
                                                                           names=['stimulus_presentation_id', 'unit_id']),
                                                fill_value=0)

        sp = pd.merge(spike_counts, presentations, left_on="stimulus_presentation_id", right_index=True, how="left")
        sp.reset_index(inplace=True)

        if use_rates:
            sp["spike_rate"] = sp["spike_count"] / sp["duration"]
            sp.drop(columns=["spike_count"], inplace=True)
            extractor = _extract_summary_rate_statistics
        else:
            sp.drop(columns=["duration"], inplace=True)
            extractor = _extract_summary_count_statistics

        summary = []
        for ind, gr in sp.groupby(["stimulus_condition_id", "unit_id"]):
            summary.append(extractor(ind, gr))

        return pd.DataFrame(summary).set_index(keys=["unit_id", "stimulus_condition_id"])
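    # Illustrative sketch (not part of the original module): per-condition firing-rate statistics for
    # all units, joined back to the parameters of each condition. The stimulus name "static_gratings"
    # is an assumption; column names follow the docstrings above.
    def _example_conditionwise_stats(self):
        presentation_ids = self.get_stimulus_table(["static_gratings"]).index.values
        stats = self.conditionwise_spike_statistics(
            stimulus_presentation_ids=presentation_ids,
            unit_ids=self.units.index.values,
            use_rates=True
        )
        # attach the stimulus parameters describing each condition to the statistics table
        return pd.merge(stats.reset_index(), self.stimulus_conditions,
                        left_on="stimulus_condition_id", right_index=True, how="left")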
    def get_parameter_values_for_stimulus(self, stimulus_name, drop_nulls=True):
        """ For each stimulus parameter, report the unique values taken on by that
        parameter while a named stimulus was presented.

        Parameters
        ----------
        stimulus_name : str
            filter to presentations of this stimulus

        Returns
        -------
        dict :
            maps parameters (column names) to their unique values.

        """
        presentation_ids = self.get_stimulus_table([stimulus_name]).index.values
        return self.get_stimulus_parameter_values(presentation_ids, drop_nulls=drop_nulls)
    def get_stimulus_parameter_values(self, stimulus_presentation_ids=None, drop_nulls=True):
        ''' For each stimulus parameter, report the unique values taken on by that
        parameter throughout the course of the session.

        Parameters
        ----------
        stimulus_presentation_ids : array-like, optional
            If provided, only parameter values from these stimulus presentations will be considered.

        Returns
        -------
        dict :
            maps parameters (column names) to their unique values.

        '''
        stimulus_presentations = self._filter_owned_df('stimulus_presentations', ids=stimulus_presentation_ids)
        stimulus_presentations = stimulus_presentations.drop(columns=list(NON_STIMULUS_PARAMETERS) + ['stimulus_name'])
        stimulus_presentations = removed_unused_stimulus_presentation_columns(stimulus_presentations)

        parameters = {}
        for colname in stimulus_presentations.columns:
            uniques = stimulus_presentations[colname].unique()

            non_null = np.array(uniques[uniques != "null"])
            non_null = np.sort(non_null)

            if not drop_nulls and "null" in uniques:
                non_null = np.concatenate([non_null, ["null"]])

            parameters[colname] = non_null

        return parameters
    def channel_structure_intervals(self, channel_ids):
        """ find on a list of channels the intervals of channels inserted into particular structures

        Parameters
        ----------
        channel_ids : list
            A list of channel ids

        Returns
        -------
        labels : np.ndarray
            for each detected interval, the label associated with that interval
        intervals : np.ndarray
            one element longer than labels. Start and end indices for intervals.

        """
        structure_id_key = "ecephys_structure_id"
        structure_label_key = "ecephys_structure_acronym"

        channel_ids = np.sort(channel_ids)  # order channels by id so intervals follow the probe
        table = self.channels.loc[channel_ids]

        unique_probes = table["probe_id"].unique()
        if len(unique_probes) > 1:
            warnings.warn("Calculating structure boundaries across channels from multiple probes.")

        intervals = nan_intervals(table[structure_id_key].values)
        labels = table[structure_label_key].iloc[intervals[:-1]].values

        return labels, intervals
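    # Illustrative sketch (not part of the original module): listing the structures spanned by the
    # channels of one probe, in channel order, as (label, start_index, end_index) triples.
    def _example_probe_structure_spans(self):
        probe_id = self.probes.index.values[0]
        channel_ids = self.channels[self.channels["probe_id"] == probe_id].index.values
        labels, intervals = self.channel_structure_intervals(channel_ids)
        return list(zip(labels, intervals[:-1], intervals[1:]))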
    def _build_spike_times(self, spike_times):
        retained_units = set(self._units.index.values)
        output_spike_times = {}

        for unit_id in list(spike_times.keys()):
            data = spike_times.pop(unit_id)
            if unit_id not in retained_units:
                continue
            output_spike_times[unit_id] = data

        return output_spike_times

    def _build_stimulus_presentations(self, stimulus_presentations, nonapplicable="null"):
        stimulus_presentations.index.name = 'stimulus_presentation_id'
        stimulus_presentations = stimulus_presentations.drop(columns=['stimulus_index'])

        # TODO: putting these here for now; after SWDB 2019, will rerun stimulus table module for all sessions
        # and can remove these
        stimulus_presentations = naming_utilities.collapse_columns(stimulus_presentations)
        stimulus_presentations = naming_utilities.standardize_movie_numbers(stimulus_presentations)
        stimulus_presentations = naming_utilities.add_number_to_shuffled_movie(stimulus_presentations)
        stimulus_presentations = naming_utilities.map_stimulus_names(
            stimulus_presentations, default_stimulus_renames
        )
        stimulus_presentations = naming_utilities.map_column_names(stimulus_presentations, default_column_renames)

        # pandas groupby ops ignore nans, so we need a new "nonapplicable" value that pandas does not recognize as null ...
        stimulus_presentations.replace("", nonapplicable, inplace=True)
        stimulus_presentations.fillna(nonapplicable, inplace=True)

        stimulus_presentations['duration'] = stimulus_presentations['stop_time'] - stimulus_presentations['start_time']

        # TODO: database these
        stimulus_conditions = {}
        presentation_conditions = []
        cid_counter = -1

        # TODO: Can we have parameters on what columns to omit? If stimulus_block or duration is left in it can affect
        # how conditionwise_spike_statistics counts spikes
        params_only = stimulus_presentations.drop(columns=["start_time", "stop_time", "duration", "stimulus_block"])
        for row in params_only.itertuples(index=False):

            if row in stimulus_conditions:
                cid = stimulus_conditions[row]
            else:
                cid_counter += 1
                stimulus_conditions[row] = cid_counter
                cid = cid_counter

            presentation_conditions.append(cid)

        cond_ids = []
        cond_vals = []
        for cv, ci in stimulus_conditions.items():
            cond_ids.append(ci)
            cond_vals.append(cv)

        self._stimulus_conditions = pd.DataFrame(cond_vals, index=pd.Index(data=cond_ids, name="stimulus_condition_id"))
        stimulus_presentations["stimulus_condition_id"] = presentation_conditions

        return stimulus_presentations

    def _build_units_table(self, units_table):
        channels = self.channels.copy()
        probes = self.probes.copy()

        self._unmerged_units = units_table.copy()
        table = pd.merge(units_table, channels, left_on='peak_channel_id', right_index=True, suffixes=['_unit', '_channel'])
        table = pd.merge(table, probes, left_on='probe_id', right_index=True, suffixes=['_unit', '_probe'])

        table.index.name = 'unit_id'
        table = table.rename(columns={
            'description': 'probe_description',
            'local_index_channel': 'channel_local_index',
            'PT_ratio': 'waveform_PT_ratio',
            'amplitude': 'waveform_amplitude',
            'duration': 'waveform_duration',
            'halfwidth': 'waveform_halfwidth',
            'recovery_slope': 'waveform_recovery_slope',
            'repolarization_slope': 'waveform_repolarization_slope',
            'spread': 'waveform_spread',
            'velocity_above': 'waveform_velocity_above',
            'velocity_below': 'waveform_velocity_below',
            'sampling_rate': 'probe_sampling_rate',
            'lfp_sampling_rate': 'probe_lfp_sampling_rate',
            'has_lfp_data': 'probe_has_lfp_data',
            'l_ratio': 'L_ratio',
            'pref_images_multi_ns': 'pref_image_multi_ns',
        })

        return table.sort_values(by=['probe_description', 'probe_vertical_position', 'probe_horizontal_position'])

    def _build_nwb1_waveforms(self, mean_waveforms):
        # _build_mean_waveforms() assumes every unit has the same number of waveforms and that a unit-waveform exists
        # for all channels. This is not true for NWB 1 files where each unit has ONE waveform on ONE channel
        units_df = self._units
        output_waveforms = {}
        sampling_rate_lu = {uid: self.probes.loc[r['probe_id']]['sampling_rate'] for uid, r in units_df.iterrows()}

        for uid in list(mean_waveforms.keys()):
            data = mean_waveforms.pop(uid)
            output_waveforms[uid] = xr.DataArray(
                data=data,
                dims=['channel_id', 'time'],
                coords={
                    'channel_id': [units_df.loc[uid]['peak_channel_id']],
                    'time': np.arange(data.shape[1]) / sampling_rate_lu[uid]
                }
            )

        return output_waveforms

    def _build_mean_waveforms(self, mean_waveforms):
        if isinstance(self.api, EcephysNwb1Api):
            return self._build_nwb1_waveforms(mean_waveforms)

        channel_id_lut = defaultdict(lambda: -1)
        for cid, row in self.channels.iterrows():
            channel_id_lut[(row["local_index"], row["probe_id"])] = cid

        probe_id_lut = {uid: row['probe_id'] for uid, row in self._units.iterrows()}

        output_waveforms = {}
        for uid in list(mean_waveforms.keys()):
            data = mean_waveforms.pop(uid)

            if uid not in probe_id_lut:  # It's been filtered out during unit table generation!
                continue

            probe_id = probe_id_lut[uid]
            output_waveforms[uid] = xr.DataArray(
                data=data,
                dims=['channel_id', 'time'],
                coords={
                    'channel_id': [channel_id_lut[(ii, probe_id)] for ii in range(data.shape[0])],
                    'time': np.arange(data.shape[1]) / self.probes.loc[probe_id]['sampling_rate']
                }
            )
            output_waveforms[uid] = output_waveforms[uid][output_waveforms[uid]["channel_id"] != -1]

        return output_waveforms

    def _build_inter_presentation_intervals(self):
        intervals = pd.DataFrame({
            'from_presentation_id': self.stimulus_presentations.index.values[:-1],
            'to_presentation_id': self.stimulus_presentations.index.values[1:],
            'interval': self.stimulus_presentations['start_time'].values[1:] - self.stimulus_presentations['stop_time'].values[:-1]
        })
        return intervals.set_index(['from_presentation_id', 'to_presentation_id'], inplace=False)

    def _filter_owned_df(self, key, ids=None, copy=True):
        df = getattr(self, key)

        if copy:
            df = df.copy()

        if ids is None:
            return df

        ids = coerce_scalar(ids, f'a scalar ({ids}) was provided as ids, filtering to a single row of {key}.')

        df = df.loc[ids]

        if df.shape[0] == 0:
            warnings.warn(f'filtering to an empty set of {key}!')

        return df

    @classmethod
    def _remove_detailed_stimulus_parameters(cls, presentations):
        columns = list(cls.DETAILED_STIMULUS_PARAMETERS)
        return presentations.drop(columns=columns, errors="ignore")
    @classmethod
    def from_nwb_path(cls, path, nwb_version=2, api_kwargs=None, **kwargs):
        api_kwargs = {} if api_kwargs is None else api_kwargs

        # TODO: Is there a way for pynwb to check the file before actually loading it with io read? If so we could
        # automatically check what NWB version is being input
        nwb_version = int(nwb_version)  # only use major version
        if nwb_version >= 2:
            NWBAdaptorCls = EcephysNwbSessionApi
        elif nwb_version == 1:
            NWBAdaptorCls = EcephysNwb1Api
        else:
            raise Exception(f'specified NWB version {nwb_version} not supported. Supported versions are: 2.X, 1.X')

        return cls(api=NWBAdaptorCls.from_path(path=path, **api_kwargs), **kwargs)
    def _warn_invalid_spike_intervals(self):
        fail_tags = list(self.probes["description"])
        fail_tags.append("all_probes")
        invalid_time_intervals = self._filter_invalid_times_by_tags(fail_tags)

        if not invalid_time_intervals.empty:
            warnings.warn("Session includes invalid time intervals that can be accessed via the attribute 'invalid_times'. "
                          "Spikes within these intervals are invalid and may need to be excluded from the analysis.")
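
# Illustrative sketch (not part of the original module): constructing a session directly from a local
# NWB 2 file via the from_nwb_path classmethod above. The file name is hypothetical; users typically
# obtain these files through the AllenSDK EcephysProjectCache instead.
def _example_build_session_from_nwb(path="ecephys_session_715093703.nwb"):
    session = EcephysSession.from_nwb_path(path, nwb_version=2)
    return session.metadata  # basic counts and identifiers for the loaded session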
def build_spike_histogram(time_domain, spike_times, unit_ids, dtype=None, binarize=False):
    time_domain = np.array(time_domain)
    unit_ids = np.array(unit_ids)

    tiled_data = np.zeros(
        (time_domain.shape[0], time_domain.shape[1] - 1, unit_ids.size),
        dtype=(np.uint8 if binarize else np.uint16) if dtype is None else dtype
    )

    starts = time_domain[:, :-1]
    ends = time_domain[:, 1:]

    for ii, unit_id in enumerate(unit_ids):
        data = np.array(spike_times[unit_id])

        start_positions = np.searchsorted(data, starts.flat)
        end_positions = np.searchsorted(data, ends.flat, side="right")
        counts = (end_positions - start_positions)

        tiled_data[:, :, ii].flat = counts > 0 if binarize else counts

    return tiled_data
def build_time_window_domain(bin_edges, offsets, callback=None):
    callback = (lambda x: x) if callback is None else callback
    domain = np.tile(bin_edges[None, :], (len(offsets), 1))
    domain += offsets[:, None]
    return callback(domain)
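
# Illustrative sketch (not part of the original module): build_time_window_domain tiles one set of
# relative bin edges across per-trial offsets, yielding a (n_trials, n_edges) array of absolute edges.
# The numeric values below are arbitrary.
def _example_time_window_domain():
    bin_edges = np.array([0.0, 0.1, 0.2])   # edges relative to stimulus onset (s)
    onsets = np.array([10.0, 20.0])         # two presentation start times (s)
    domain = build_time_window_domain(bin_edges, onsets)
    # domain == [[10.0, 10.1, 10.2],
    #            [20.0, 20.1, 20.2]]
    return domain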
def removed_unused_stimulus_presentation_columns(stimulus_presentations):
    to_drop = []
    for cn in stimulus_presentations.columns:
        if np.all(stimulus_presentations[cn].isna()):
            to_drop.append(cn)
        elif np.all(stimulus_presentations[cn].astype(str).values == ''):
            to_drop.append(cn)
        elif np.all(stimulus_presentations[cn].astype(str).values == 'null'):
            to_drop.append(cn)
    return stimulus_presentations.drop(columns=to_drop)
def nan_intervals(array, nan_like=["null"]):
    """ find interval bounds (bounding consecutive identical values) in an array, which may contain nans

    Parameters
    ----------
    array : np.ndarray

    Returns
    -------
    np.ndarray :
        start and end indices of detected intervals (one longer than the number of intervals)

    """
    intervals = [0]
    current = array[0]
    for ii, item in enumerate(array[1:]):
        if is_distinct_from(item, current):
            intervals.append(ii + 1)
            current = item
    intervals.append(len(array))

    return np.unique(intervals)
def is_distinct_from(left, right):
    if type(left) != type(right):
        return True
    if pd.isna(left) and pd.isna(right):
        return False
    if left is None and right is None:
        return False

    return left != right
def array_intervals(array):
    """ find interval bounds (bounding consecutive identical values) in an array

    Parameters
    ----------
    array : np.ndarray

    Returns
    -------
    np.ndarray :
        start and end indices of detected intervals (one longer than the number of intervals)

    """
    changes = np.flatnonzero(np.diff(array)) + 1
    return np.concatenate([[0], changes, [len(array)]])
def coerce_scalar(value, message, warn=False):
    if not isinstance(value, Collection) or isinstance(value, str):
        if warn:
            warnings.warn(message)
        return [value]
    return value
def _extract_summary_count_statistics(index, group):
    return {
        "stimulus_condition_id": index[0],
        "unit_id": index[1],
        "spike_count": group["spike_count"].sum(),
        "stimulus_presentation_count": group.shape[0],
        "spike_mean": np.mean(group["spike_count"].values),
        "spike_std": np.std(group["spike_count"].values, ddof=1),
        "spike_sem": scipy.stats.sem(group["spike_count"].values)
    }


def _extract_summary_rate_statistics(index, group):
    return {
        "stimulus_condition_id": index[0],
        "unit_id": index[1],
        "stimulus_presentation_count": group.shape[0],
        "spike_mean": np.mean(group["spike_rate"].values),
        "spike_std": np.std(group["spike_rate"].values, ddof=1),
        "spike_sem": scipy.stats.sem(group["spike_rate"].values)
    }


def _overlap(a, b):
    """Check if the two intervals overlap

    Parameters
    ----------
    a : tuple
        start, stop times
    b : tuple
        start, stop times

    Returns
    -------
    bool : True if overlap, otherwise False
    """
    return max(a[0], b[0]) <= min(a[1], b[1])
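
# Illustrative sketch (not part of the original module): _overlap treats intervals as closed, so
# intervals that merely touch at an endpoint count as overlapping. The numeric values are arbitrary.
def _example_overlap():
    assert _overlap((0.0, 1.0), (1.0, 2.0))       # shared endpoint counts as overlap
    assert not _overlap((0.0, 1.0), (1.5, 2.0))   # disjoint intervals do not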