Source code for allensdk.brain_observatory.ecephys.stimulus_analysis.flashes

import numpy as np
import pandas as pd
from six import string_types
import scipy.ndimage as ndi
import scipy.stats as st
from scipy.optimize import curve_fit
import logging

import matplotlib.pyplot as plt

from .stimulus_analysis import StimulusAnalysis, get_fr

import warnings
# Ignore FutureWarning for the rest of the process. NOTE(review): simplefilter edits the
# global warnings filter list, so this also hides FutureWarnings raised by unrelated code.
warnings.simplefilter('ignore', FutureWarning)


logger = logging.getLogger(name=__name__)


class Flashes(StimulusAnalysis):
    """
    A class for computing single-unit metrics from the full-field flash stimulus of an ecephys session NWB file.

    To use, pass in an EcephysSession object::
        session = EcephysSession.from_nwb_path('/path/to/my.nwb')
        fl_analysis = Flashes(session)

    or, alternatively, pass in the file path::
        fl_analysis = Flashes('/path/to/my.nwb')

    You can also pass in a unit filter dictionary which will only select units with certain properties. For example
    to get only those units which are on probe C and found in the VISp area::
        fl_analysis = Flashes(session, filter={'location': 'probeC', 'ecephys_structure_acronym': 'VISp'})

    To get a table of the individual unit metrics ranked by unit ID (``metrics`` is a property, not a method)::
        metrics_table_df = fl_analysis.metrics

    """

    def __init__(self, ecephys_session, col_color='color', trial_duration=0.25, **kwargs):
        super(Flashes, self).__init__(ecephys_session, trial_duration=trial_duration, **kwargs)

        self._metrics = None

        self._colors = None
        self._col_color = col_color  # name of the stimulus-condition column holding the flash color

        if self._params is not None:
            self._params = self._params.get('flashes', {})
            # Overwrites the parent value with the 'flashes' params, but only when they actually provide a key;
            # otherwise keep whatever the parent constructor set rather than silently resetting it to None.
            self._stimulus_key = self._params.get('stimulus_key', getattr(self, '_stimulus_key', None))
        else:
            self._params = {}

    @property
    def name(self):
        return 'Flashes'

    @property
    def colors(self):
        """ Sorted array of 'color' conditions (black vs. white flash), excluding 'null' """
        if self._colors is None:
            self._get_stim_table_stats()

        return self._colors

    @property
    def number_colors(self):
        """ Number of 'color' conditions (black vs. white flash) """
        return len(self.colors)

    @property
    def null_condition(self):
        """ Stimulus condition ID for null stimulus (not used, so set to -1) """
        # TODO: If null_condition is not used remove it, parent should have it set to 1
        return -1

    @property
    def METRICS_COLUMNS(self):
        return [('on_off_ratio_fl', np.float64),
                ('sustained_idx_fl', np.float64),
                ('firing_rate_fl', np.float64),
                ('time_to_peak_fl', np.float64),
                ('fano_fl', np.float64),
                ('lifetime_sparseness_fl', np.float64),
                ('run_pval_fl', np.float64),
                ('run_mod_fl', np.float64)]

    @property
    def metrics(self):
        """ Per-unit flash metrics table, computed on first access and cached afterwards.

        Columns are listed in METRICS_COLUMNS. When the stimulus table is empty, the empty metrics table is
        returned unfilled.
        """
        if self._metrics is None:
            logger.info('Calculating metrics for %s', self.name)

            unit_ids = self.unit_ids
            metrics_df = self.empty_metrics_table()

            if len(self.stim_table) > 0:
                # Four metrics depend on each unit's preferred condition; look it up once per unit instead of
                # once per metric.
                preferred = list(zip(unit_ids, [self._get_preferred_condition(unit) for unit in unit_ids]))

                metrics_df['on_off_ratio_fl'] = [self._get_on_off_ratio(unit) for unit in unit_ids]
                metrics_df['sustained_idx_fl'] = [self._get_sustained_index(unit, cond) for unit, cond in preferred]
                metrics_df['firing_rate_fl'] = [self._get_overall_firing_rate(unit) for unit in unit_ids]
                metrics_df['time_to_peak_fl'] = [self._get_time_to_peak(unit, cond) for unit, cond in preferred]
                metrics_df['fano_fl'] = [self._get_fano_factor(unit, cond) for unit, cond in preferred]
                metrics_df['lifetime_sparseness_fl'] = [self._get_lifetime_sparseness(unit) for unit in unit_ids]
                metrics_df.loc[:, ['run_pval_fl', 'run_mod_fl']] = [
                    self._get_running_modulation(unit, cond) for unit, cond in preferred]

            self._metrics = metrics_df

        return self._metrics

    def _find_stimulus_key(self, stim_table):
        """Tries to guess the correct stimulus_key based on the data.

        :param stim_table: stimulus table with a 'stimulus_name' column
        :return: the first stimulus name (original casing) that starts, case-insensitively, with one of
            known_stimulus_keys(); None if no name matches.
        """
        known_prefixes = tuple(k.lower() for k in self.__class__.known_stimulus_keys())
        for table_key in stim_table['stimulus_name'].unique():
            try:
                table_key_lc = table_key.lower()
            except AttributeError:
                # Non-string entries (e.g. NaN for a missing name) can never match a known key.
                continue

            if table_key_lc.startswith(known_prefixes):
                return table_key

        return None

    @classmethod
    def known_stimulus_keys(cls):
        """ Stimulus-name prefixes (lower case) that _find_stimulus_key treats as the flash stimulus """
        return ['flash', 'flashes']

    def _get_stim_table_stats(self):
        """ Extract the sorted, non-'null' colors from the stimulus conditions table """
        conditions = self.stimulus_conditions
        non_null = conditions.loc[conditions[self._col_color] != 'null']
        self._colors = np.sort(non_null[self._col_color].unique())

    def _get_sustained_index(self, unit_id, condition_id):
        """ Calculate the sustained index for a given unit, a measure of the transience of the flash response.

        Parameters
        ----------
        unit_id : int
            unique ID for the unit of interest
        condition_id : int
            stimulus condition whose PSTH is evaluated

        Returns
        -------
        sustained_index : float
            ratio of the mean PSTH and the maximum of the PSTH; NaN if the PSTH maximum is zero.

        A cell that fires very transiently will have a sustained index close to 0.
        A cell that fires continuously throughout the flash will have a sustained index closer to 1.
        """
        psth = self.conditionwise_psth.sel(unit_id=unit_id, stimulus_condition_id=condition_id).data
        peak = np.amax(psth)
        if peak == 0:
            # 0/0 is undefined; return NaN explicitly instead of triggering a RuntimeWarning.
            return np.nan

        return np.mean(psth) / peak

    def _get_on_off_ratio(self, unit_id):
        """ Gets the ratio of mean spikes for on-stimuli vs off-stimuli.

        The "on" condition is the one whose color equals 1.0, the "off" condition the one whose color equals -1.0.

        Parameters
        ----------
        unit_id : int
            unique ID for the unit of interest

        Returns
        -------
        on_off_ratio : float
            NaN if either condition is missing or the off response has no spikes.
        """
        conditions = self.stimulus_conditions
        on_condition_id = conditions[conditions[self._col_color] == 1.0].index.values
        off_condition_id = conditions[conditions[self._col_color] == -1.0].index.values

        unit_stats = self.conditionwise_statistics.loc[unit_id]
        on_mean_spikes = unit_stats.loc[on_condition_id]['spike_mean'].values
        off_mean_spikes = unit_stats.loc[off_condition_id]['spike_mean'].values

        if len(on_mean_spikes) == 0 or len(off_mean_spikes) == 0:
            return np.nan

        if off_mean_spikes[0] > 0:
            return on_mean_spikes[0] / off_mean_spikes[0]

        return np.nan

    ## VISUALIZATION ##

    def plot_raster(self, stimulus_condition_id, unit_id):
        """ Plot raster for one condition and one unit, in the subplot row matching the condition's color """
        condition_color = self.stimulus_conditions.loc[stimulus_condition_id][self._col_color]
        idx_color = np.where(self.colors == condition_color)[0]

        if len(idx_color) == 1:
            unit_presentations = self.presentationwise_statistics.xs(unit_id, level=1)
            presentation_ids = unit_presentations[
                unit_presentations['stimulus_condition_id'] == stimulus_condition_id].index.values

            spike_times = self.presentationwise_spike_times
            df = spike_times[
                (spike_times['stimulus_presentation_id'].isin(presentation_ids)) &
                (spike_times['unit_id'] == unit_id)
            ]

            # spike times relative to the onset of their presentation
            x = df.index.values - self.stim_table.loc[df.stimulus_presentation_id].start_time
            _, y = np.unique(df.stimulus_presentation_id, return_inverse=True)

            # plt.subplot requires an integer position; idx_color is a length-1 index array.
            plt.subplot(self.number_colors, 1, int(idx_color[0]) + 1)
            plt.scatter(x, y, c='k', s=1, alpha=0.25)
            plt.axis('off')

    def plot_response(self, unit_id):
        """ Plot an overlaid PSTH histogram for each color condition """
        plot_colors = ('darkslateblue', 'grey')

        conditions = self.stimulus_conditions
        for idx, color in enumerate(self.colors):
            condition_id = conditions[conditions[self._col_color] == color].index.values[0]
            psth = self.conditionwise_psth.sel(unit_id=unit_id, stimulus_condition_id=condition_id).values
            # cycle the palette so more than two color conditions do not raise IndexError
            bar_color = plot_colors[idx % len(plot_colors)]
            plt.bar(np.arange(len(psth)) - 0.5, psth, color=bar_color, alpha=0.5, width=1.0)
            plt.step(np.arange(len(psth)), psth, color=bar_color)

        plt.axis('off')