import warnings
from collections.abc import Collection
from collections import defaultdict
import xarray as xr
import numpy as np
import pandas as pd
import scipy.stats
from allensdk.core.lazy_property import LazyPropertyMixin
from allensdk.brain_observatory.ecephys.ecephys_session_api import EcephysSessionApi, EcephysNwbSessionApi, EcephysNwb1Api
from allensdk.brain_observatory.ecephys.stimulus_table import naming_utilities
from allensdk.brain_observatory.ecephys.stimulus_table._schemas import default_stimulus_renames, default_column_renames
NON_STIMULUS_PARAMETERS = tuple([
'start_time',
'stop_time',
'duration',
'stimulus_block',
"stimulus_condition_id"
]) # stimulus_presentation column names not describing a parameter of a stimulus
class EcephysSession(LazyPropertyMixin):
''' Represents data from a single EcephysSession
Attributes
----------
    units : pd.DataFrame
A table whose rows are sorted units (putative neurons) and whose columns are characteristics
of those units.
Index is:
unit_id : int
Unique integer identifier for this unit.
Columns are:
firing_rate : float
This unit's firing rate (spikes / s) calculated over the window of that unit's activity
(the time from its first detected spike to its last).
isi_violations : float
        Estimate of this unit's contamination rate (larger means that more of the spikes assigned
to this unit probably originated from other neurons). Calculated as a ratio of the firing
rate of the unit over periods where spikes would be isi-violating vs the total firing
rate of the unit.
peak_channel_id : int
Unique integer identifier for this unit's peak channel. A unit's peak channel is the channel on
which its peak-to-trough amplitude difference is maximized. This is assessed using the kilosort 2
templates rather than the mean waveforms for a unit.
snr : float
Signal to noise ratio for this unit.
probe_horizontal_position : numeric
The horizontal (short-axis) position of this unit's peak channel in microns.
probe_vertical_position : numeric
The vertical (long-axis, lower values are closer to the probe base) position of
this unit's peak channel in microns.
probe_id : int
Unique integer identifier for this unit's probe.
probe_description : str
Human-readable description carrying miscellaneous information about this unit's probe.
location : str
Gross-scale location of this unit's probe.
spike_times : dict
Maps integer unit ids to arrays of spike times (float) for those units.
running_speed : RunningSpeed
NamedTuple with two fields
timestamps : numpy.ndarray
Timestamps of running speed data samples
values : np.ndarray
Running speed of the experimental subject (in cm / s).
mean_waveforms : dict
Maps integer unit ids to xarray.DataArrays containing mean spike waveforms for that unit.
stimulus_presentations : pd.DataFrame
Table whose rows are stimulus presentations and whose columns are presentation characteristics.
        A stimulus presentation is the smallest unit of distinct stimulus presentation and usually lasts for
        a single frame of the 60 Hz display. Since not all parameters are relevant to all stimuli, this table
contains many 'null' values.
Index is
stimulus_presentation_id : int
Unique identifier for this stimulus presentation
Columns are
start_time : float
Time (s) at which this presentation began
stop_time : float
Time (s) at which this presentation ended
duration : float
stop_time - start_time (s). Included for convenience.
stimulus_name : str
Identifies the stimulus family (e.g. "drifting_gratings" or "natural_movie_3") used
for this presentation. The stimulus family, along with relevant parameter values, provides the
information required to reconstruct the stimulus presented during this presentation. The empty
string indicates a blank period.
stimulus_block : numeric
A stimulus block is made by sequentially presenting presentations from the same stimulus family.
This value is the index of the block which contains this presentation. During a blank period,
this is 'null'.
TF : float
Temporal frequency, or 'null' when not appropriate.
SF : float
Spatial frequency, or 'null' when not appropriate
Ori : float
Orientation (in degrees) or 'null' when not appropriate
Contrast : float
Pos_x : float
Pos_y : float
Color : numeric
Image : numeric
Phase : float
stimulus_condition_id : integer
identifies the session-unique stimulus condition (permutation of parameters) to which this presentation
belongs
stimulus_conditions : pd.DataFrame
Each row is a unique permutation (within this session) of stimulus parameters presented during this experiment.
        Columns are as stimulus presentations, sans start_time, stop_time, stimulus_block, and duration.
inter_presentation_intervals : pd.DataFrame
The elapsed time between each immediately sequential pair of stimulus presentations. This is a dataframe with a
two-level multiindex (levels are 'from_presentation_id' and 'to_presentation_id'). It has a single column,
'interval', which reports the elapsed time between the two presentations in seconds on the experiment's master
clock.
'''
DETAILED_STIMULUS_PARAMETERS = (
"colorSpace",
"flipHoriz",
"flipVert",
"depth",
"interpolate",
"mask",
"opacity",
"rgbPedestal",
"tex",
"texRes",
"units",
"rgb",
"signalDots",
"noiseDots",
"fieldSize",
"fieldShape",
"fieldPos",
"nDots",
"dotSize",
"dotLife",
"color_triplet"
)
@property
def num_units(self):
return self._units.shape[0]
@property
def num_probes(self):
return self.probes.shape[0]
@property
def num_channels(self):
return self.channels.shape[0]
@property
def num_stimulus_presentations(self):
return self.stimulus_presentations.shape[0]
@property
def stimulus_names(self):
return self.stimulus_presentations['stimulus_name'].unique().tolist()
@property
def stimulus_conditions(self):
        self.stimulus_presentations  # touch the lazy presentations table, which also builds the conditions table
return self._stimulus_conditions
@property
def rig_geometry_data(self):
if self._rig_metadata:
return self._rig_metadata["rig_geometry_data"]
else:
return None
@property
def rig_equipment_name(self):
if self._rig_metadata:
return self._rig_metadata["rig_equipment"]
else:
return None
@property
def specimen_name(self):
return self._metadata["specimen_name"]
@property
def age_in_days(self):
return self._metadata["age_in_days"]
@property
def sex(self):
return self._metadata["sex"]
@property
def full_genotype(self):
return self._metadata["full_genotype"]
@property
def session_type(self):
return self._metadata["stimulus_name"]
@property
def units(self):
return self._units.drop(columns=['width_rf', 'height_rf',
'on_screen_rf', 'time_to_peak_fl',
'time_to_peak_rf', 'time_to_peak_sg',
'sustained_idx_fl', 'time_to_peak_dg'],
errors='ignore')
@property
def structure_acronyms(self):
return self.channels["ecephys_structure_acronym"].unique().tolist()
@property
def structurewise_unit_counts(self):
return self.units["ecephys_structure_acronym"].value_counts()
@property
def metadata(self):
return {
"specimen_name": self.specimen_name,
"session_type": self.session_type,
"full_genotype": self.full_genotype,
"sex": self.sex,
"age_in_days": self.age_in_days,
"rig_equipment_name": self.rig_equipment_name,
"num_units": self.num_units,
"num_channels": self.num_channels,
"num_probes": self.num_probes,
"num_stimulus_presentations": self.num_stimulus_presentations,
"session_start_time": self.session_start_time,
"ecephys_session_id": self.ecephys_session_id,
"structure_acronyms": self.structure_acronyms,
"stimulus_names": self.stimulus_names
}
@property
def stimulus_presentations(self):
return self.__class__._remove_detailed_stimulus_parameters(self._stimulus_presentations)
@property
def spike_times(self):
if not hasattr(self, "_accessed_spike_times"):
self._accessed_spike_times = True
self._warn_invalid_spike_intervals()
return self._spike_times
def __init__(
self,
api: EcephysSessionApi,
test: bool = False,
**kwargs
):
""" Construct an EcephysSession object, which provides access to
detailed data for a single extracellular electrophysiology
(neuropixels) session.
Parameters
----------
api :
Used to access data, which is then cached on this object. Must
expose the EcephysSessionApi interface. Standard options include
instances of:
            EcephysNwbSessionApi :: reads data from a Neurodata Without
                Borders 2.0 file.
test :
If true, check during construction that this session's api is
valid.
"""
self.api: EcephysSessionApi = api
self.ecephys_session_id = self.LazyProperty(self.api.get_ecephys_session_id)
self.session_start_time = self.LazyProperty(self.api.get_session_start_time)
self.running_speed = self.LazyProperty(self.api.get_running_speed)
self.mean_waveforms = self.LazyProperty(self.api.get_mean_waveforms, wrappers=[self._build_mean_waveforms])
self._spike_times = self.LazyProperty(self.api.get_spike_times, wrappers=[self._build_spike_times])
self.optogenetic_stimulation_epochs = self.LazyProperty(self.api.get_optogenetic_stimulation)
self.spike_amplitudes = self.LazyProperty(self.api.get_spike_amplitudes)
self.probes = self.LazyProperty(self.api.get_probes)
self.channels = self.LazyProperty(self.api.get_channels)
self._stimulus_presentations = self.LazyProperty(self.api.get_stimulus_presentations,
wrappers=[self._build_stimulus_presentations, self._mask_invalid_stimulus_presentations])
self.inter_presentation_intervals = self.LazyProperty(self._build_inter_presentation_intervals)
self.invalid_times = self.LazyProperty(self.api.get_invalid_times)
self._units = self.LazyProperty(self.api.get_units, wrappers=[self._build_units_table])
self._rig_metadata = self.LazyProperty(self.api.get_rig_metadata)
self._metadata = self.LazyProperty(self.api.get_metadata)
if test:
self.api.test()
    def get_current_source_density(self, probe_id):
        """ Obtain current source density (CSD) of trial-averaged response to a flash stimulus for this probe.
        See allensdk.brain_observatory.ecephys.current_source_density for details of CSD calculation.
        CSD is computed with a 1D method (second spatial derivative) without prior spatial smoothing.
        Users should apply spatial smoothing of their choice (e.g., a Gaussian filter) to the computed CSD.
Parameters
----------
probe_id : int
identify the probe whose CSD data ought to be loaded
Returns
-------
xr.DataArray :
dimensions are channel (id) and time (seconds, relative to stimulus onset). Values are current source
density assessed on that channel at that time (V/m^2)
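
        Examples
        --------
        A minimal sketch (assumes ``session`` is an ``EcephysSession`` with CSD data for the probe):

        >>> probe_id = session.probes.index.values[0]
        >>> csd = session.get_current_source_density(probe_id)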
"""
return self.api.get_current_source_density(probe_id)
    def get_lfp(self, probe_id, mask_invalid_intervals=True):
''' Load an xarray DataArray with LFP data from channels on a single probe
Parameters
----------
probe_id : int
identify the probe whose LFP data ought to be loaded
mask_invalid_intervals : bool
if True (default) will mask data in the invalid intervals with np.nan
Returns
-------
xr.DataArray :
dimensions are channel (id) and time (seconds). Values are sampled LFP data.
Notes
-----
        Unlike many other data access methods on this class, this one does not cache the loaded data
        in memory, due to the large size of the LFP data.
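
        Examples
        --------
        A minimal sketch (assumes ``session`` is an ``EcephysSession`` whose probe has LFP data):

        >>> probe_id = session.probes.index.values[0]
        >>> lfp = session.get_lfp(probe_id)
        >>> early_lfp = lfp.sel(time=slice(0.0, 1.0))  # first second of the recording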
'''
if mask_invalid_intervals:
probe_name = self.probes.loc[probe_id]["description"]
fail_tags = ["all_probes", probe_name]
invalid_time_intervals = self._filter_invalid_times_by_tags(fail_tags)
lfp = self.api.get_lfp(probe_id)
time_points = lfp.time
valid_time_points = self._get_valid_time_points(time_points, invalid_time_intervals)
return lfp.where(cond=valid_time_points)
else:
return self.api.get_lfp(probe_id)
    def _get_valid_time_points(self, time_points, invalid_time_intervals):
all_time_points = xr.DataArray(
name="time_points",
data=[True] * len(time_points),
dims=['time'],
coords=[time_points]
)
valid_time_points = all_time_points
        for ix, invalid_time_interval in invalid_time_intervals.iterrows():
invalid_time_points = (time_points >= invalid_time_interval['start_time']) & (time_points <= invalid_time_interval['stop_time'])
valid_time_points = np.logical_and(valid_time_points, np.logical_not(invalid_time_points))
return valid_time_points
def _filter_invalid_times_by_tags(self, tags):
"""
Parameters
----------
invalid_times: pd.DataFrame
of invalid times
tags: list
of tags
Returns
-------
pd.DataFrame of invalid times having tags
"""
invalid_times = self.invalid_times.copy()
if not invalid_times.empty:
mask = invalid_times['tags'].apply(lambda x: any([t in x for t in tags]))
invalid_times = invalid_times[mask]
return invalid_times
    def get_inter_presentation_intervals_for_stimulus(self, stimulus_names):
''' Get a subset of this session's inter-presentation intervals, filtered by stimulus name.
Parameters
----------
stimulus_names : array-like of str
The names of stimuli to include in the output.
Returns
-------
pd.DataFrame :
inter-presentation intervals, filtered to the requested stimulus names.
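
        Examples
        --------
        A minimal sketch (assumes ``session`` is an ``EcephysSession`` that presented drifting gratings):

        >>> intervals = session.get_inter_presentation_intervals_for_stimulus(['drifting_gratings'])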
'''
stimulus_names = coerce_scalar(stimulus_names, f'expected stimulus_names to be a collection (list-like), but found {type(stimulus_names)}: {stimulus_names}')
filtered_presentations = self.stimulus_presentations[self.stimulus_presentations['stimulus_name'].isin(stimulus_names)]
filtered_ids = set(filtered_presentations.index.values)
return self.inter_presentation_intervals[
(self.inter_presentation_intervals.index.isin(filtered_ids, level='from_presentation_id'))
& (self.inter_presentation_intervals.index.isin(filtered_ids, level='to_presentation_id'))
]
    def get_stimulus_table(self, stimulus_names=None, include_detailed_parameters=False, include_unused_parameters=False):
'''Get a subset of stimulus presentations by name, with irrelevant parameters filtered off
Parameters
----------
        stimulus_names : array-like of str
            The names of stimuli to include in the output.
        include_detailed_parameters : bool, optional
            If True, keep the low-level rendering parameters listed in DETAILED_STIMULUS_PARAMETERS. Default False.
        include_unused_parameters : bool, optional
            If True, keep parameter columns that take no non-null value for the requested presentations. Default False.
Returns
-------
pd.DataFrame :
Rows are filtered presentations, columns are the relevant subset of stimulus parameters
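
        Examples
        --------
        A minimal sketch (assumes ``session`` is an ``EcephysSession`` that presented drifting gratings):

        >>> gratings = session.get_stimulus_table(['drifting_gratings'])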
'''
if stimulus_names is None:
stimulus_names = self.stimulus_names
stimulus_names = coerce_scalar(stimulus_names, f'expected stimulus_names to be a collection (list-like), but found {type(stimulus_names)}: {stimulus_names}')
presentations = self._stimulus_presentations[self._stimulus_presentations['stimulus_name'].isin(stimulus_names)]
if not include_detailed_parameters:
presentations = self.__class__._remove_detailed_stimulus_parameters(presentations)
if not include_unused_parameters:
presentations = removed_unused_stimulus_presentation_columns(presentations)
return presentations
    def get_stimulus_epochs(self, duration_thresholds=None):
""" Reports continuous periods of time during which a single kind of stimulus was presented
Parameters
        ----------
duration_thresholds : dict, optional
keys are stimulus names, values are floating point durations in seconds. All epochs with
- a given stimulus name
- a duration shorter than the associated threshold
will be removed from the results
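
        Returns
        -------
        pd.DataFrame :
            Rows are epochs; columns are 'start_time', 'stop_time', 'duration', 'stimulus_name', 'stimulus_block'

        Examples
        --------
        A minimal sketch (assumes ``session`` is an ``EcephysSession``):

        >>> epochs = session.get_stimulus_epochs()
        >>> long_epochs = epochs[epochs['duration'] > 60.0]  # keep epochs longer than one minute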
"""
if duration_thresholds is None:
duration_thresholds = {"spontaneous_activity": 90.0}
presentations = self.stimulus_presentations.copy()
diff_indices = nan_intervals(presentations["stimulus_block"].values)
epochs = []
for left, right in zip(diff_indices[:-1], diff_indices[1:]):
epochs.append({
"start_time": presentations.iloc[left]["start_time"],
"stop_time": presentations.iloc[right - 1]["stop_time"],
"stimulus_name": presentations.iloc[left]["stimulus_name"],
"stimulus_block": presentations.iloc[left]["stimulus_block"]
})
epochs = pd.DataFrame(epochs)
epochs["duration"] = epochs["stop_time"] - epochs["start_time"]
for key, threshold in duration_thresholds.items():
epochs = epochs[
(epochs["stimulus_name"] != key)
| (epochs["duration"] >= threshold)
]
return epochs.loc[:, ["start_time", "stop_time", "duration", "stimulus_name", "stimulus_block"]]
    def get_invalid_times(self):
""" Report invalid time intervals with tags describing the scope of invalid data
The tags format: [scope,scope_id,label]
scope:
'EcephysSession': data is invalid across session
'EcephysProbe': data is invalid for a single probe
label:
'all_probes': gain fluctuations on the Neuropixels probe result in missed spikes and LFP saturation events
'stimulus' : very long frames (>3x the normal frame length) make any stimulus-locked analysis invalid
'probe#': probe # stopped sending data during this interval (spikes and LFP samples will be missing)
'optotagging': missing optotagging data
Returns
-------
pd.DataFrame :
Rows are invalid intervals, columns are 'start_time' (s), 'stop_time' (s), 'tags'
"""
return self.invalid_times
    def get_pupil_data(self, suppress_pupil_data: bool = True) -> pd.DataFrame:
"""Return a dataframe with eye tracking data
Parameters
----------
suppress_pupil_data : bool, optional
Whether or not to suppress eye gaze mapping data in output
dataframe, by default True.
Returns
-------
pd.DataFrame
Contains columns for eye, pupil and cr ellipse fits:
*_center_x
*_center_y
*_height
*_width
*_phi
May also contain raw/filtered columns for gaze mapping if
suppress_pupil_data is set to False:
*_eye_area
*_pupil_area
*_screen_coordinates_x_cm
*_screen_coordinates_y_cm
*_screen_coordinates_spherical_x_deg
                *_screen_coordinates_spherical_y_deg
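
        Examples
        --------
        A minimal sketch (assumes ``session`` is an ``EcephysSession`` with eye tracking data):

        >>> pupil_data = session.get_pupil_data()
        >>> pupil_heights = pupil_data['pupil_height']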
"""
return self.api.get_pupil_data(suppress_pupil_data=suppress_pupil_data)
def _mask_invalid_stimulus_presentations(self, stimulus_presentations):
"""Mask invalid stimulus presentations
Find stimulus presentations overlapping with invalid times
Mask stimulus names with "invalid_presentation", keep "start_time" and "stop_time", mask remaining data with np.nan
Parameters
----------
stimulus_presentations : pd.DataFrame
table including all stimulus presentations
Returns
-------
pd.DataFrame :
table with masked invalid presentations
"""
fail_tags = ["stimulus"]
invalid_times = self._filter_invalid_times_by_tags(fail_tags)
for ix_sp, sp in stimulus_presentations.iterrows():
stim_epoch = sp['start_time'], sp['stop_time']
for ix_it, it in invalid_times.iterrows():
invalid_interval = it['start_time'], it['stop_time']
if _overlap(stim_epoch, invalid_interval):
                    stimulus_presentations.loc[ix_sp, :] = np.nan  # ix_sp is an index label; use .loc to match the .at calls below
stimulus_presentations.at[ix_sp, "stimulus_name"] = "invalid_presentation"
stimulus_presentations.at[ix_sp, "start_time"] = stim_epoch[0]
stimulus_presentations.at[ix_sp, "stop_time"] = stim_epoch[1]
return stimulus_presentations
    def presentationwise_spike_counts(
self,
bin_edges,
stimulus_presentation_ids,
unit_ids,
binarize=False,
dtype=None,
large_bin_size_threshold=0.001,
time_domain_callback=None
):
''' Build an array of spike counts surrounding stimulus onset per unit and stimulus frame.
Parameters
        ----------
bin_edges : numpy.ndarray
Spikes will be counted into the bins defined by these edges. Values are in seconds, relative
to stimulus onset.
stimulus_presentation_ids : array-like
Filter to these stimulus presentations
unit_ids : array-like
Filter to these units
binarize : bool, optional
If true, all counts greater than 0 will be treated as 1. This results in lower storage overhead,
but is only reasonable if bin sizes are fine (<= 1 millisecond).
large_bin_size_threshold : float, optional
If binarize is True and the largest bin width is greater than this value, a warning will be emitted.
time_domain_callback : callable, optional
The time domain is a numpy array whose values are trial-aligned bin
edges (each row is aligned to a different trial). This optional function will be
applied to the time domain before counting spikes.
Returns
-------
xarray.DataArray :
Data array whose dimensions are stimulus presentation, unit,
and time bin and whose values are spike counts.
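
        Examples
        --------
        A minimal sketch (assumes ``session`` is an ``EcephysSession``):

        >>> bin_edges = np.arange(-0.1, 0.5, 0.01)  # 10 ms bins spanning stimulus onset
        >>> counts = session.presentationwise_spike_counts(
        ...     bin_edges=bin_edges,
        ...     stimulus_presentation_ids=session.stimulus_presentations.index.values,
        ...     unit_ids=session.units.index.values
        ... )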
'''
stimulus_presentations = self._filter_owned_df('stimulus_presentations', ids=stimulus_presentation_ids)
units = self._filter_owned_df('units', ids=unit_ids)
largest_bin_size = np.amax(np.diff(bin_edges))
if binarize and largest_bin_size > large_bin_size_threshold:
warnings.warn(
f'You\'ve elected to binarize spike counts, but your maximum bin width is {largest_bin_size:2.5f} seconds. '
'Binarizing spike counts with such a large bin width can cause significant loss of accuracy! '
f'Please consider only binarizing spike counts when your bins are <= {large_bin_size_threshold} seconds wide.'
)
bin_edges = np.array(bin_edges)
domain = build_time_window_domain(bin_edges, stimulus_presentations['start_time'].values, callback=time_domain_callback)
out_of_order = np.where(np.diff(domain, axis=1) < 0)
if len(out_of_order[0]) > 0:
            out_of_order_time_bins = [(row, col) for row, col in zip(*out_of_order)]
raise ValueError(f"The time domain specified contains out-of-order bin edges at indices: {out_of_order_time_bins}")
ends = domain[:, -1]
starts = domain[:, 0]
time_diffs = starts[1:] - ends[:-1]
overlapping = np.where(time_diffs < 0)[0]
if len(overlapping) > 0:
            # Ignoring intervals that overlap multiple time bins, because finding those would take O(n) time
overlapping = [(s, s + 1) for s in overlapping]
warnings.warn(f"You've specified some overlapping time intervals between neighboring rows: {overlapping}, "
f"with a maximum overlap of {np.abs(np.min(time_diffs))} seconds.")
tiled_data = build_spike_histogram(
domain, self.spike_times, units.index.values, dtype=dtype, binarize=binarize
)
tiled_data = xr.DataArray(
name='spike_counts',
data=tiled_data,
coords={
'stimulus_presentation_id': stimulus_presentations.index.values,
'time_relative_to_stimulus_onset': bin_edges[:-1] + np.diff(bin_edges) / 2,
'unit_id': units.index.values
},
dims=['stimulus_presentation_id', 'time_relative_to_stimulus_onset', 'unit_id']
)
return tiled_data
    def presentationwise_spike_times(self, stimulus_presentation_ids=None, unit_ids=None):
''' Produce a table associating spike times with units and stimulus presentations
Parameters
----------
stimulus_presentation_ids : array-like
Filter to these stimulus presentations
unit_ids : array-like
Filter to these units
Returns
-------
pandas.DataFrame :
Index is
spike_time : float
On the session's master clock.
Columns are
stimulus_presentation_id : int
The stimulus presentation on which this spike occurred.
unit_id : int
The unit that emitted this spike.
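
        Examples
        --------
        A minimal sketch (assumes ``session`` is an ``EcephysSession`` and that a stimulus named
        'flashes' was presented):

        >>> flash_ids = session.get_stimulus_table(['flashes']).index.values
        >>> flash_spikes = session.presentationwise_spike_times(stimulus_presentation_ids=flash_ids)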
'''
stimulus_presentations = self._filter_owned_df('stimulus_presentations', ids=stimulus_presentation_ids)
units = self._filter_owned_df('units', ids=unit_ids)
presentation_times = np.zeros([stimulus_presentations.shape[0] * 2])
presentation_times[::2] = np.array(stimulus_presentations['start_time'])
presentation_times[1::2] = np.array(stimulus_presentations['stop_time'])
all_presentation_ids = np.array(stimulus_presentations.index.values)
presentation_ids = []
unit_ids = []
spike_times = []
        for unit_id in units.index.values:
data = self.spike_times[unit_id]
indices = np.searchsorted(presentation_times, data) - 1
index_valid = indices % 2 == 0
presentations = all_presentation_ids[np.floor(indices / 2).astype(int)]
sorder = np.argsort(presentations)
presentations = presentations[sorder]
index_valid = index_valid[sorder]
data = data[sorder]
changes = np.where(np.ediff1d(presentations, to_begin=1, to_end=1))[0]
for ii, jj in zip(changes[:-1], changes[1:]):
values = data[ii:jj][index_valid[ii:jj]]
if values.size == 0:
continue
unit_ids.append(np.zeros([values.size]) + unit_id)
presentation_ids.append(np.zeros([values.size]) + presentations[ii])
spike_times.append(values)
if not spike_times:
# If there are no units firing during the given stimulus return an empty dataframe
            return pd.DataFrame(columns=['stimulus_presentation_id', 'unit_id',
                                         'time_since_stimulus_presentation_onset'],
                                index=pd.Index([], name='spike_time'))
spike_df = pd.DataFrame({
'stimulus_presentation_id': np.concatenate(presentation_ids).astype(int),
'unit_id': np.concatenate(unit_ids).astype(int)
}, index=pd.Index(np.concatenate(spike_times), name='spike_time'))
# Add time since stimulus presentation onset
onset_times = self._filter_owned_df(
"stimulus_presentations", ids=all_presentation_ids)["start_time"]
spikes_with_onset = spike_df.join(onset_times,
on=["stimulus_presentation_id"])
spikes_with_onset["time_since_stimulus_presentation_onset"] = (
spikes_with_onset.index - spikes_with_onset["start_time"]
)
spikes_with_onset.sort_values('spike_time', axis=0, inplace=True)
spikes_with_onset.drop(columns=["start_time"], inplace=True)
return spikes_with_onset
    def conditionwise_spike_statistics(self, stimulus_presentation_ids=None, unit_ids=None, use_rates=False):
""" Produce summary statistics for each distinct stimulus condition
Parameters
----------
stimulus_presentation_ids : array-like
identifies stimulus presentations from which spikes will be considered
unit_ids : array-like
identifies units whose spikes will be considered
use_rates : bool, optional
If True, use firing rates. If False, use spike counts.
Returns
-------
pd.DataFrame :
Rows are indexed by unit id and stimulus condition id. Values are summary statistics describing spikes
emitted by a specific unit across presentations within a specific condition.
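
        Examples
        --------
        A minimal sketch (assumes ``session`` is an ``EcephysSession``):

        >>> stats = session.conditionwise_spike_statistics(use_rates=True)
        >>> mean_rates = stats['spike_mean']  # mean firing rate (spikes / s) per unit and condition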
"""
# TODO: Need to return an empty df if no matching unit-ids or presentation-ids are found
# TODO: To use filter_owned_df() make sure to convert the results from a Series to a Dataframe
        stimulus_presentation_ids = (stimulus_presentation_ids if stimulus_presentation_ids is not None
                                     else self.stimulus_presentations.index.values)  # default: use all presentations
presentations = self.stimulus_presentations.loc[stimulus_presentation_ids, ["stimulus_condition_id", "duration"]]
spikes = self.presentationwise_spike_times(
stimulus_presentation_ids=stimulus_presentation_ids, unit_ids=unit_ids
)
if spikes.empty:
# In the case there are no spikes
spike_counts = pd.DataFrame({'spike_count': 0},
index=pd.MultiIndex.from_product([stimulus_presentation_ids, unit_ids],
names=['stimulus_presentation_id', 'unit_id']))
else:
spike_counts = spikes.copy()
spike_counts["spike_count"] = np.zeros(spike_counts.shape[0])
spike_counts = spike_counts.groupby(["stimulus_presentation_id", "unit_id"]).count()
            unit_ids = unit_ids if unit_ids is not None else spikes['unit_id'].unique()  # if not explicitly given, take unit ids from the spikes table
spike_counts = spike_counts.reindex(pd.MultiIndex.from_product([stimulus_presentation_ids,
unit_ids],
names=['stimulus_presentation_id',
'unit_id']), fill_value=0)
sp = pd.merge(spike_counts, presentations, left_on="stimulus_presentation_id", right_index=True, how="left")
sp.reset_index(inplace=True)
if use_rates:
sp["spike_rate"] = sp["spike_count"] / sp["duration"]
sp.drop(columns=["spike_count"], inplace=True)
extractor = _extract_summary_rate_statistics
else:
sp.drop(columns=["duration"])
extractor = _extract_summary_count_statistics
summary = []
for ind, gr in sp.groupby(["stimulus_condition_id", "unit_id"]):
summary.append(extractor(ind, gr))
return pd.DataFrame(summary).set_index(keys=["unit_id", "stimulus_condition_id"])
    def get_parameter_values_for_stimulus(self, stimulus_name, drop_nulls=True):
""" For each stimulus parameter, report the unique values taken on by that
parameter while a named stimulus was presented.
Parameters
----------
        stimulus_name : str
            filter to presentations of this stimulus
        drop_nulls : bool, optional
            If False, include 'null' among a parameter's reported values when some presentations lack that parameter.
Returns
-------
dict :
maps parameters (column names) to their unique values.
"""
presentation_ids = self.get_stimulus_table([stimulus_name]).index.values
return self.get_stimulus_parameter_values(presentation_ids, drop_nulls=drop_nulls)
    def get_stimulus_parameter_values(self, stimulus_presentation_ids=None, drop_nulls=True):
''' For each stimulus parameter, report the unique values taken on by that
parameter throughout the course of the session.
Parameters
----------
        stimulus_presentation_ids : array-like, optional
            If provided, only parameter values from these stimulus presentations will be considered.
        drop_nulls : bool, optional
            If False, include 'null' among a parameter's reported values when some presentations lack that parameter.
Returns
-------
dict :
maps parameters (column names) to their unique values.
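
        Examples
        --------
        A minimal sketch (assumes ``session`` is an ``EcephysSession``; the column name 'orientation'
        is illustrative and varies by session):

        >>> parameter_values = session.get_stimulus_parameter_values()
        >>> orientations = parameter_values.get('orientation')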
'''
stimulus_presentations = self._filter_owned_df('stimulus_presentations', ids=stimulus_presentation_ids)
stimulus_presentations = stimulus_presentations.drop(columns=list(NON_STIMULUS_PARAMETERS) + ['stimulus_name'])
stimulus_presentations = removed_unused_stimulus_presentation_columns(stimulus_presentations)
parameters = {}
for colname in stimulus_presentations.columns:
uniques = stimulus_presentations[colname].unique()
non_null = np.array(uniques[uniques != "null"])
            non_null = non_null[non_null == non_null]  # NaN != NaN, so this drops NaN entries
non_null = np.sort(non_null)
if not drop_nulls and "null" in uniques:
non_null = np.concatenate([non_null, ["null"]])
parameters[colname] = non_null
return parameters
    def channel_structure_intervals(self, channel_ids):
        """ find, within a list of channels, the intervals of channels inserted into particular structures
        Parameters
        ----------
        channel_ids : list
            A list of channel ids
Returns
-------
labels : np.ndarray
for each detected interval, the label associated with that interval
intervals : np.ndarray
one element longer than labels. Start and end indices for intervals.
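
        Examples
        --------
        A minimal sketch (assumes ``session`` is an ``EcephysSession``; channels should come from a single probe):

        >>> probe_id = session.probes.index.values[0]
        >>> channel_ids = session.channels[session.channels['probe_id'] == probe_id].index.values
        >>> labels, intervals = session.channel_structure_intervals(channel_ids)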
"""
structure_id_key = "ecephys_structure_id"
structure_label_key = "ecephys_structure_acronym"
        channel_ids = np.sort(channel_ids)  # np.array(channel_ids).sort() would sort a discarded temporary
table = self.channels.loc[channel_ids]
unique_probes = table["probe_id"].unique()
if len(unique_probes) > 1:
warnings.warn("Calculating structure boundaries across channels from multiple probes.")
intervals = nan_intervals(table[structure_id_key].values)
labels = table[structure_label_key].iloc[intervals[:-1]].values
return labels, intervals
def _build_spike_times(self, spike_times):
retained_units = set(self._units.index.values)
output_spike_times = {}
for unit_id in list(spike_times.keys()):
data = spike_times.pop(unit_id)
if unit_id not in retained_units:
continue
output_spike_times[unit_id] = data
return output_spike_times
def _build_stimulus_presentations(self, stimulus_presentations, nonapplicable="null"):
stimulus_presentations.index.name = 'stimulus_presentation_id'
stimulus_presentations = stimulus_presentations.drop(columns=['stimulus_index'])
# TODO: putting these here for now; after SWDB 2019, will rerun stimulus table module for all sessions
# and can remove these
stimulus_presentations = naming_utilities.collapse_columns(stimulus_presentations)
stimulus_presentations = naming_utilities.standardize_movie_numbers(stimulus_presentations)
stimulus_presentations = naming_utilities.add_number_to_shuffled_movie(stimulus_presentations)
stimulus_presentations = naming_utilities.map_stimulus_names(
stimulus_presentations, default_stimulus_renames
)
stimulus_presentations = naming_utilities.map_column_names(stimulus_presentations, default_column_renames)
# pandas groupby ops ignore nans, so we need a new "nonapplicable" value that pandas does not recognize as null ...
stimulus_presentations.replace("", nonapplicable, inplace=True)
stimulus_presentations.fillna(nonapplicable, inplace=True)
stimulus_presentations['duration'] = stimulus_presentations['stop_time'] - stimulus_presentations['start_time']
# TODO: database these
stimulus_conditions = {}
presentation_conditions = []
cid_counter = -1
# TODO: Can we have parameters on what columns to omit? If stimulus_block or duration is left in it can affect
# how conditionwise_spike_statistics counts spikes
params_only = stimulus_presentations.drop(columns=["start_time", "stop_time", "duration", "stimulus_block"])
for row in params_only.itertuples(index=False):
if row in stimulus_conditions:
cid = stimulus_conditions[row]
else:
cid_counter += 1
stimulus_conditions[row] = cid_counter
cid = cid_counter
presentation_conditions.append(cid)
cond_ids = []
cond_vals = []
for cv, ci in stimulus_conditions.items():
cond_ids.append(ci)
cond_vals.append(cv)
self._stimulus_conditions = pd.DataFrame(cond_vals, index=pd.Index(data=cond_ids, name="stimulus_condition_id"))
stimulus_presentations["stimulus_condition_id"] = presentation_conditions
return stimulus_presentations
def _build_units_table(self, units_table):
channels = self.channels.copy()
probes = self.probes.copy()
self._unmerged_units = units_table.copy()
table = pd.merge(units_table, channels, left_on='peak_channel_id', right_index=True, suffixes=['_unit', '_channel'])
table = pd.merge(table, probes, left_on='probe_id', right_index=True, suffixes=['_unit', '_probe'])
table.index.name = 'unit_id'
table = table.rename(columns={
'description': 'probe_description',
'local_index_channel': 'channel_local_index',
'PT_ratio': 'waveform_PT_ratio',
'amplitude': 'waveform_amplitude',
'duration': 'waveform_duration',
'halfwidth': 'waveform_halfwidth',
'recovery_slope': 'waveform_recovery_slope',
'repolarization_slope': 'waveform_repolarization_slope',
'spread': 'waveform_spread',
'velocity_above': 'waveform_velocity_above',
'velocity_below': 'waveform_velocity_below',
'sampling_rate': 'probe_sampling_rate',
'lfp_sampling_rate': 'probe_lfp_sampling_rate',
'has_lfp_data': 'probe_has_lfp_data',
'l_ratio': 'L_ratio',
'pref_images_multi_ns': 'pref_image_multi_ns',
})
return table.sort_values(by=['probe_description', 'probe_vertical_position', 'probe_horizontal_position'])
def _build_nwb1_waveforms(self, mean_waveforms):
# _build_mean_waveforms() assumes every unit has the same number of waveforms and that a unit-waveform exists
# for all channels. This is not true for NWB 1 files where each unit has ONE waveform on ONE channel
units_df = self._units
output_waveforms = {}
sampling_rate_lu = {uid: self.probes.loc[r['probe_id']]['sampling_rate'] for uid, r in units_df.iterrows()}
for uid in list(mean_waveforms.keys()):
data = mean_waveforms.pop(uid)
output_waveforms[uid] = xr.DataArray(
data=data,
dims=['channel_id', 'time'],
coords={
'channel_id': [units_df.loc[uid]['peak_channel_id']],
'time': np.arange(data.shape[1]) / sampling_rate_lu[uid]
}
)
return output_waveforms
def _build_mean_waveforms(self, mean_waveforms):
if isinstance(self.api, EcephysNwb1Api):
return self._build_nwb1_waveforms(mean_waveforms)
channel_id_lut = defaultdict(lambda: -1)
for cid, row in self.channels.iterrows():
channel_id_lut[(row["local_index"], row["probe_id"])] = cid
probe_id_lut = {uid: row['probe_id'] for uid, row in self._units.iterrows()}
output_waveforms = {}
for uid in list(mean_waveforms.keys()):
data = mean_waveforms.pop(uid)
if uid not in probe_id_lut: # It's been filtered out during unit table generation!
continue
probe_id = probe_id_lut[uid]
output_waveforms[uid] = xr.DataArray(
data=data,
dims=['channel_id', 'time'],
coords={
'channel_id': [channel_id_lut[(ii, probe_id)] for ii in range(data.shape[0])],
'time': np.arange(data.shape[1]) / self.probes.loc[probe_id]['sampling_rate']
}
)
output_waveforms[uid] = output_waveforms[uid][output_waveforms[uid]["channel_id"] != -1]
return output_waveforms
def _build_inter_presentation_intervals(self):
intervals = pd.DataFrame({
'from_presentation_id': self.stimulus_presentations.index.values[:-1],
'to_presentation_id': self.stimulus_presentations.index.values[1:],
'interval': self.stimulus_presentations['start_time'].values[1:] - self.stimulus_presentations['stop_time'].values[:-1]
})
return intervals.set_index(['from_presentation_id', 'to_presentation_id'], inplace=False)
def _filter_owned_df(self, key, ids=None, copy=True):
df = getattr(self, key)
if copy:
df = df.copy()
if ids is None:
return df
ids = coerce_scalar(ids, f'a scalar ({ids}) was provided as ids, filtering to a single row of {key}.')
df = df.loc[ids]
if df.shape[0] == 0:
warnings.warn(f'filtering to an empty set of {key}!')
return df
@classmethod
def _remove_detailed_stimulus_parameters(cls, presentations):
columns = list(cls.DETAILED_STIMULUS_PARAMETERS)
return presentations.drop(columns=columns, errors="ignore")
    @classmethod
def from_nwb_path(cls, path, nwb_version=2, api_kwargs=None, **kwargs):
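        """ Construct an EcephysSession from an NWB file path.

        A minimal usage sketch (the path 'session.nwb' is a hypothetical NWB 2.x file):

        >>> session = EcephysSession.from_nwb_path('session.nwb', nwb_version=2)
        """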
api_kwargs = {} if api_kwargs is None else api_kwargs
# TODO: Is there a way for pynwb to check the file before actually loading it with io read? If so we could
        # automatically check what NWB version is being input
nwb_version = int(nwb_version) # only use major version
if nwb_version >= 2:
NWBAdaptorCls = EcephysNwbSessionApi
elif nwb_version == 1:
NWBAdaptorCls = EcephysNwb1Api
else:
raise Exception(f'specified NWB version {nwb_version} not supported. Supported versions are: 2.X, 1.X')
return cls(api=NWBAdaptorCls.from_path(path=path, **api_kwargs), **kwargs)
def _warn_invalid_spike_intervals(self):
fail_tags = list(self.probes["description"])
fail_tags.append("all_probes")
invalid_time_intervals = self._filter_invalid_times_by_tags(fail_tags)
if not invalid_time_intervals.empty:
warnings.warn("Session includes invalid time intervals that could be accessed with the attribute 'invalid_times',"
"Spikes within these intervals are invalid and may need to be excluded from the analysis.")
def build_spike_histogram(time_domain, spike_times, unit_ids, dtype=None, binarize=False):
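    """ Count each unit's spikes into the bins defined by each row of time_domain.

    Returns an array of shape (num_trials, num_bins, num_units); if binarize is
    True, any nonzero count is stored as 1.
    """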
time_domain = np.array(time_domain)
unit_ids = np.array(unit_ids)
tiled_data = np.zeros(
(time_domain.shape[0], time_domain.shape[1] - 1, unit_ids.size),
dtype=(np.uint8 if binarize else np.uint16) if dtype is None else dtype
)
starts = time_domain[:, :-1]
ends = time_domain[:, 1:]
for ii, unit_id in enumerate(unit_ids):
data = np.array(spike_times[unit_id])
start_positions = np.searchsorted(data, starts.flat)
end_positions = np.searchsorted(data, ends.flat, side="right")
counts = (end_positions - start_positions)
tiled_data[:, :, ii].flat = counts > 0 if binarize else counts
return tiled_data
def build_time_window_domain(bin_edges, offsets, callback=None):
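    """ Create one row of trial-aligned bin edges per offset: bin_edges shifted by
    that offset, with callback applied to the resulting array.

    Examples
    --------
    >>> build_time_window_domain(np.array([0.0, 1.0, 2.0]), np.array([10.0, 100.0]))
    array([[ 10.,  11.,  12.],
           [100., 101., 102.]])
    """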
callback = (lambda x: x) if callback is None else callback
domain = np.tile(bin_edges[None, :], (len(offsets), 1))
domain += offsets[:, None]
return callback(domain)
def removed_unused_stimulus_presentation_columns(stimulus_presentations):
to_drop = []
for cn in stimulus_presentations.columns:
if np.all(stimulus_presentations[cn].isna()):
to_drop.append(cn)
elif np.all(stimulus_presentations[cn].astype(str).values == ''):
to_drop.append(cn)
elif np.all(stimulus_presentations[cn].astype(str).values == 'null'):
to_drop.append(cn)
return stimulus_presentations.drop(columns=to_drop)
def nan_intervals(array, nan_like=["null"]):
""" find interval bounds (bounding consecutive identical values) in an array, which may contain nans
Parameters
-----------
array : np.ndarray
Returns
-------
np.ndarray :
start and end indices of detected intervals (one longer than the number of intervals)
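
    Examples
    --------
    Mixed types and "null" entries are handled; a run of "null" forms its own interval:

    >>> nan_intervals(np.array([1, 1, 2, "null", "null", 3], dtype=object))
    array([0, 2, 3, 5, 6])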
"""
intervals = [0]
current = array[0]
for ii, item in enumerate(array[1:]):
if is_distinct_from(item, current):
intervals.append(ii + 1)
current = item
intervals.append(len(array))
return np.unique(intervals)
def is_distinct_from(left, right):
if type(left) != type(right):
return True
if pd.isna(left) and pd.isna(right):
return False
if left is None and right is None:
return False
return left != right
def array_intervals(array):
""" find interval bounds (bounding consecutive identical values) in an array
Parameters
-----------
array : np.ndarray
Returns
-------
np.ndarray :
start and end indices of detected intervals (one longer than the number of intervals)
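
    Examples
    --------
    The boundaries of the runs in [1, 1, 2, 2, 2, 3] fall at indices 0, 2, 5, and 6:

    >>> array_intervals(np.array([1, 1, 2, 2, 2, 3]))
    array([0, 2, 5, 6])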
"""
changes = np.flatnonzero(np.diff(array)) + 1
return np.concatenate([[0], changes, [len(array)]])
def coerce_scalar(value, message, warn=False):
if not isinstance(value, Collection) or isinstance(value, str):
if warn:
warnings.warn(message)
return [value]
return value
def _extract_summary_count_statistics(index, group):
return {
"stimulus_condition_id": index[0],
"unit_id": index[1],
"spike_count": group["spike_count"].sum(),
"stimulus_presentation_count": group.shape[0],
"spike_mean": np.mean(group["spike_count"].values),
"spike_std": np.std(group["spike_count"].values, ddof=1),
"spike_sem": scipy.stats.sem(group["spike_count"].values)
}
def _extract_summary_rate_statistics(index, group):
return {
"stimulus_condition_id": index[0],
"unit_id": index[1],
"stimulus_presentation_count": group.shape[0],
"spike_mean": np.mean(group["spike_rate"].values),
"spike_std": np.std(group["spike_rate"].values, ddof=1),
"spike_sem": scipy.stats.sem(group["spike_rate"].values)
}
def _overlap(a, b):
"""Check if the two intervals overlap
Parameters
----------
a : tuple
start, stop times
b : tuple
start, stop times
Returns
-------
bool : True if overlap, otherwise False
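
    Examples
    --------
    Intervals sharing any time (including a single endpoint) overlap:

    >>> _overlap((0.0, 1.0), (0.5, 2.0))
    True
    >>> _overlap((0.0, 1.0), (1.5, 2.0))
    False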
"""
return max(a[0], b[0]) <= min(a[1], b[1])