Source code for allensdk.brain_observatory.ecephys.stimulus_analysis.static_gratings

from six import string_types
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit
from functools import partial
import logging

import matplotlib.pyplot as plt

from .stimulus_analysis import StimulusAnalysis
from .stimulus_analysis import osi, deg2rad
from ...circle_plots import FanPlotter

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)


logger = logging.getLogger(__name__)


class StaticGratings(StimulusAnalysis):
    """
    A class for computing single-unit metrics from the static gratings stimulus of an ecephys session NWB file.

    To use, pass in a EcephysSession object::
        session = EcephysSession.from_nwb_path('/path/to/my.nwb')
        sg_analysis = StaticGratings(session)

    or, alternatively, pass in the file path::
        sg_analysis = StaticGratings('/path/to/my.nwb')

    You can also pass in a unit filter dictionary which will only select units with certain properties. For example
    to get only those units which are on probe C and found in the VISp area::
        sg_analysis = StaticGratings(session, filter={'location': 'probeC', 'ecephys_structure_acronym': 'VISp'})

    To get a table of the individual unit metrics ranked by unit ID (``metrics`` is a property, not a method)::
        metrics_table_df = sg_analysis.metrics
    """

    def __init__(self, ecephys_session, col_ori='orientation', col_sf='spatial_frequency', col_phase='phase',
                 trial_duration=0.25, **kwargs):
        """
        Parameters
        ----------
        ecephys_session :
            Forwarded unchanged to the parent constructor (session object or NWB path, per the class docstring).
        col_ori, col_sf, col_phase : str
            Names of the stimulus-table columns holding orientation, spatial frequency and phase.
        trial_duration : float
            Forwarded to the parent constructor and also stored on this instance.
        **kwargs :
            Forwarded unchanged to the parent constructor.
        """
        super(StaticGratings, self).__init__(ecephys_session, trial_duration=trial_duration, **kwargs)

        # Lazily populated by _get_stim_table_stats() on first property access.
        self._orivals = None
        self._number_ori = None
        self._sfvals = None
        self._number_sf = None
        self._phasevals = None
        self._number_phase = None

        # Cache for the metrics table; filled by the `metrics` property.
        self._metrics = None

        self._col_ori = col_ori
        self._col_sf = col_sf
        self._col_phase = col_phase
        self._trial_duration = trial_duration

        # TODO: module_name should be a static class variable
        if self._params is not None:
            # Keep only the static-gratings section of the parameters.
            self._params = self._params.get('static_gratings', {})
            self._stimulus_key = self._params.get('stimulus_key', None)  # Overwrites parent value with argvars
        else:
            self._params = {}

    @property
    def name(self):
        """ Human-readable stimulus name """
        return 'Static Gratings'

    @property
    def orivals(self):
        """ Array of grating orientation conditions """
        if self._orivals is None:
            self._get_stim_table_stats()
        return self._orivals

    @property
    def number_ori(self):
        """ Number of grating orientation conditions """
        if self._number_ori is None:
            self._get_stim_table_stats()
        return self._number_ori

    @property
    def sfvals(self):
        """ Array of grating spatial frequency conditions """
        if self._sfvals is None:
            self._get_stim_table_stats()
        return self._sfvals

    @property
    def number_sf(self):
        """ Number of grating spatial frequency conditions """
        if self._number_sf is None:
            self._get_stim_table_stats()
        return self._number_sf

    @property
    def phasevals(self):
        """ Array of grating phase conditions """
        if self._phasevals is None:
            self._get_stim_table_stats()
        return self._phasevals

    @property
    def number_phase(self):
        """ Number of grating phase conditions """
        if self._number_phase is None:
            self._get_stim_table_stats()
        return self._number_phase

    @property
    def null_condition(self):
        """ Stimulus condition ID for null (blank) stimulus """
        return self.stimulus_conditions[self.stimulus_conditions[self._col_sf] == 'null'].index

    @property
    def METRICS_COLUMNS(self):
        """ List of (column name, dtype) pairs describing the metrics table """
        return [('pref_sf_sg', np.float64),
                ('pref_sf_multi_sg', bool),
                ('pref_ori_sg', np.float64),
                ('pref_ori_multi_sg', bool),
                ('pref_phase_sg', np.float64),
                ('pref_phase_multi_sg', bool),
                ('g_osi_sg', np.float64),
                ('time_to_peak_sg', np.float64),
                ('firing_rate_sg', np.float64),
                ('fano_sg', np.float64),
                ('lifetime_sparseness_sg', np.float64),
                ('run_pval_sg', np.float64),
                ('run_mod_sg', np.float64)]

    @property
    def metrics(self):
        """ Table of per-unit metrics, computed on first access and cached afterwards.

        If the stimulus table is empty, the table returned by ``empty_metrics_table()`` is cached unfilled.
        """
        if self._metrics is None:
            logger.info('Calculating metrics for %s', self.name)

            unit_ids = self.unit_ids
            metrics_df = self.empty_metrics_table()

            if len(self.stim_table) > 0:
                metrics_df['pref_sf_sg'] = [self._get_pref_sf(unit) for unit in unit_ids]
                metrics_df['pref_sf_multi_sg'] = [
                    self._check_multiple_pref_conditions(unit_id, self._col_sf, self.sfvals)
                    for unit_id in unit_ids
                ]
                metrics_df['pref_ori_sg'] = [self._get_pref_ori(unit) for unit in unit_ids]
                metrics_df['pref_ori_multi_sg'] = [
                    self._check_multiple_pref_conditions(unit_id, self._col_ori, self.orivals)
                    for unit_id in unit_ids
                ]
                metrics_df['pref_phase_sg'] = [self._get_pref_phase(unit) for unit in unit_ids]
                metrics_df['pref_phase_multi_sg'] = [
                    self._check_multiple_pref_conditions(unit_id, self._col_phase, self.phasevals)
                    for unit_id in unit_ids
                ]
                metrics_df['g_osi_sg'] = [self._get_osi(unit,
                                                        metrics_df.loc[unit]['pref_sf_sg'],
                                                        metrics_df.loc[unit]['pref_phase_sg'])
                                          for unit in unit_ids]

                # The preferred condition feeds three metrics below; look it up once per unit
                # instead of three times (assumes the inherited lookup is deterministic).
                pref_conditions = [self._get_preferred_condition(unit) for unit in unit_ids]

                metrics_df['time_to_peak_sg'] = [self._get_time_to_peak(unit, cond)
                                                 for unit, cond in zip(unit_ids, pref_conditions)]
                metrics_df['firing_rate_sg'] = [self._get_overall_firing_rate(unit) for unit in unit_ids]
                metrics_df['fano_sg'] = [self._get_fano_factor(unit, cond)
                                         for unit, cond in zip(unit_ids, pref_conditions)]
                metrics_df['lifetime_sparseness_sg'] = [self._get_lifetime_sparseness(unit) for unit in unit_ids]
                metrics_df.loc[:, ['run_pval_sg', 'run_mod_sg']] = \
                    [self._get_running_modulation(unit, cond)
                     for unit, cond in zip(unit_ids, pref_conditions)]

            self._metrics = metrics_df

        return self._metrics

    @classmethod
    def known_stimulus_keys(cls):
        """ Stimulus-table keys this analysis class recognises """
        return ['static_gratings']

    def _get_stim_table_stats(self):
        """ Extract orientations, spatial frequencies, and phases from the stimulus table

        Each list is the sorted set of unique non-'null' values of the corresponding column.
        """
        def sorted_non_null(column):
            conditions = self.stimulus_conditions
            return np.sort(conditions.loc[conditions[column] != 'null'][column].unique())

        self._orivals = sorted_non_null(self._col_ori)
        self._number_ori = len(self._orivals)

        self._sfvals = sorted_non_null(self._col_sf)
        self._number_sf = len(self._sfvals)

        self._phasevals = sorted_non_null(self._col_phase)
        self._number_phase = len(self._phasevals)

    def _sg_pref_stimulus_value(self, unit_id, column, values):
        """Return the entry of `values` whose conditions yield the largest mean response for a unit.

        Parameters
        ----------
        unit_id : int
            unique ID for the unit of interest
        column : str
            stimulus-conditions column to group by
        values : array-like
            candidate values of `column`

        Returns
        -------
        The value in `values` with the largest average 'spike_mean'.
        """
        unit_stats = self.conditionwise_statistics.loc[unit_id]

        # Gather the stimulus_condition_id values that share the same value of `column`.
        grouped_condition_ids = [
            self.stimulus_conditions.index[self.stimulus_conditions[column] == value].tolist()
            for value in values
        ]

        # Average the 'spike_mean' column over each group and pick the group with the largest mean.
        df = pd.DataFrame(
            index=values,
            data={'spike_mean': [unit_stats.loc[condition_ids]['spike_mean'].mean()
                                 for condition_ids in grouped_condition_ids]}
        ).rename_axis(column)

        return df.idxmax().iloc[0]

    def _get_pref_sf(self, unit_id):
        """Calculate the preferred spatial frequency condition for a given unit.

        Parameters
        ----------
        unit_id : int
            unique ID for the unit of interest

        Returns
        -------
        pref_sf : float
            spatial frequency driving the maximal response
        """
        return self._sg_pref_stimulus_value(unit_id, self._col_sf, self.sfvals)

    def _get_pref_ori(self, unit_id):
        """ Calculate the preferred orientation condition for a given unit

        Parameters
        ----------
        unit_id : int
            unique ID for the unit of interest

        Returns
        -------
        pref_ori : float
            stimulus orientation driving the maximal response
        """
        return self._sg_pref_stimulus_value(unit_id, self._col_ori, self.orivals)

    def _get_pref_phase(self, unit_id):
        """Calculate the preferred phase condition for a given unit

        Parameters
        ----------
        unit_id : int
            unique ID for the unit of interest

        Returns
        -------
        pref_phase : float
            stimulus phase driving the maximal response
        """
        return self._sg_pref_stimulus_value(unit_id, self._col_phase, self.phasevals)

    def _get_osi(self, unit_id, pref_sf, pref_phase):
        """ Calculate the orientation selectivity for a given unit

        Parameters
        ----------
        unit_id : int
            unique ID for the unit of interest
        pref_sf : float
            preferred spatial frequency for this unit
        pref_phase : float
            preferred phase for this unit

        Returns
        -------
        osi : float
            orientation selectivity value
        """
        orivals_rad = deg2rad(self.orivals).astype('complex128')  # TODO: can we use numpy deg2rad?

        # Conditions shown at the preferred spatial frequency and phase (orientation varies).
        condition_inds = self.stimulus_conditions[
            (self.stimulus_conditions[self._col_sf] == pref_sf) &
            (self.stimulus_conditions[self._col_phase] == pref_phase)
        ].index.values

        df = self.conditionwise_statistics.loc[unit_id].loc[condition_inds]
        # Sort the responses by orientation so they line up with the sorted self.orivals.
        df = df.assign(ori=self.stimulus_conditions.loc[df.index.values][self._col_ori])
        df = df.sort_values(by=['ori'])

        tuning = np.array(df['spike_mean'].values)

        return osi(orivals_rad, tuning)

    ## VISUALIZATION ##

    def plot_raster(self, stimulus_condition_id, unit_id):
        """ Plot raster for one condition and one unit

        The raster is drawn into one cell of a (number_sf x number_ori) subplot grid. Nothing is drawn
        when the condition's spatial frequency or orientation is not among sfvals / orivals
        (e.g. the null condition).
        """
        condition = self.stimulus_conditions.loc[stimulus_condition_id]
        idx_sf = np.where(self.sfvals == condition[self._col_sf])[0]
        idx_ori = np.where(self.orivals == condition[self._col_ori])[0]

        if not (len(idx_sf) == len(idx_ori) == 1):
            return

        unit_stats = self.presentationwise_statistics.xs(unit_id, level=1)
        presentation_ids = unit_stats[unit_stats['stimulus_condition_id'] == stimulus_condition_id].index.values

        spike_times = self.presentationwise_spike_times
        df = spike_times[
            (spike_times['stimulus_presentation_id'].isin(presentation_ids)) &
            (spike_times['unit_id'] == unit_id)
        ]

        # Spike times relative to the start of their presentation.
        x = df.index.values - self.stim_table.loc[df.stimulus_presentation_id].start_time
        # Map presentation ids onto consecutive row numbers 0..n-1.
        _, y = np.unique(df.stimulus_presentation_id, return_inverse=True)

        # plt.subplot requires a plain integer index; idx_sf / idx_ori are 1-element arrays.
        subplot_index = int(idx_sf[0]) * self.number_ori + int(idx_ori[0]) + 1
        plt.subplot(self.number_sf, self.number_ori, subplot_index)
        plt.scatter(x, y, c='k', s=1, alpha=0.25)
        plt.axis('off')

    def plot_response_summary(self, unit_id, bar_thickness=0.25):
        """ Plot the spike counts across conditions

        Top panel: spikes per trial against spatial frequency; bottom panel: spikes per trial against
        orientation. Points are coloured by phase index and jittered randomly by up to +/- bar_thickness/2.
        """
        df = self.stimulus_conditions.drop(index=self.null_condition)

        df['sf_index'] = np.searchsorted(self.sfvals, df[self._col_sf].values)
        df['ori_index'] = np.searchsorted(self.orivals, df[self._col_ori].values)
        df['phase_index'] = np.searchsorted(self.phasevals, df[self._col_phase].values)

        unit_stats = self.presentationwise_statistics.xs(unit_id, level=1)
        cond_values = unit_stats['stimulus_condition_id']
        # NOTE(review): the null condition was dropped from df above; presumably cond_values never
        # contains it here -- confirm, since .loc raises on missing labels in recent pandas.
        presented = df.loc[cond_values.values]
        spike_counts = unit_stats['spike_counts']
        phase_colors = presented['phase_index']

        sf_jittered = presented['sf_index'] + np.random.rand(cond_values.size) * bar_thickness - bar_thickness/2

        plt.subplot(2, 1, 1)
        plt.scatter(spike_counts, sf_jittered, c=phase_colors, alpha=0.5, cmap='Blues', vmin=-5)
        plt.yticks(ticks=np.arange(self.number_sf), labels=self.sfvals)
        plt.ylabel('Spatial frequency')
        plt.xlabel('Spikes per trial')
        plt.ylim([self.number_sf, -1])

        ori_jittered = presented['ori_index'] + np.random.rand(cond_values.size) * bar_thickness - bar_thickness/2

        plt.subplot(2, 1, 2)
        plt.scatter(ori_jittered, spike_counts, c=phase_colors, alpha=0.5, cmap='Spectral')
        plt.xticks(ticks=np.arange(self.number_ori), labels=self.orivals)
        plt.xlabel('Orientation')
        plt.ylabel('Spikes per trial')

    def make_fan_plot(self, unit_id):
        """ Make a 2P-style Fan Plot based on presentationwise spike counts"""
        unit_stats = self.presentationwise_statistics.xs(unit_id, level=1)
        presented = self.stimulus_conditions.loc[unit_stats['stimulus_condition_id']]

        angle_data = presented[self._col_ori].values
        r_data = presented[self._col_sf].values
        group_data = presented[self._col_phase].values
        data = unit_stats['spike_counts'].values

        # Remove the blank (null) trials from every array.
        null_trials = np.where(angle_data == 'null')[0]

        angle_data = np.delete(angle_data, null_trials)
        r_data = np.delete(r_data, null_trials)
        group_data = np.delete(group_data, null_trials)
        data = np.delete(data, null_trials)

        cmin = np.min(data)
        cmax = np.max(data)

        fp = FanPlotter.for_static_gratings()
        fp.plot(r_data=r_data, angle_data=angle_data, group_data=group_data, data=data, clim=[cmin, cmax])
        fp.show_axes(closed=False)
        plt.axis('off')
def fit_sf_tuning(sf_tuning_responses, sf_values, pref_sf_index):
    """Performs gaussian or exponential fit on the spatial frequency tuning curve at preferred orientation/phase for
    a given cell.

    An interior preferred index is fitted with a gaussian; a boundary index (first or last) is fitted with an
    exponential. Fitting is best-effort: if the fit or the cutoff search fails, the values not yet computed stay NaN.

    :param sf_tuning_responses: An array of len N, with each value the (averaged) response of a cell at a given spatial
        freq. stimulus.
    :param sf_values: An array of len N, with each value the spatial freq. of the stimulus (corresponding to
        sf_tuning_response).
    :param pref_sf_index: The pre-determined preferred spatial frequency (sf_values index) of the cell.
    :return: tuple of (index for the preferred sf from the curve fit, preferred sf from the curve fit,
        low cutoff sf from the curve fit, high cutoff sf from the curve fit)
    """
    # np.nan rather than np.NaN: the capitalised alias was removed in NumPy 2.0.
    fit_sf_ind = np.nan
    fit_sf = np.nan
    sf_low_cutoff = np.nan
    sf_high_cutoff = np.nan

    sf_indices = np.arange(len(sf_values))
    # Dense grid, in sf-index units, on which the fitted curve is evaluated to find half-max cutoffs.
    # Grid positions convert to spatial frequencies as 0.02 * 2**position.
    # NOTE(review): presumably this hard-codes sf values starting at 0.02 with octave spacing
    # over five conditions -- confirm.
    fine_grid = np.arange(0., 4.1, 0.1)

    if pref_sf_index in range(1, len(sf_values) - 1):
        # If the preferred spatial freq is an interior case try to fit the tuning curve with a gaussian.
        try:
            popt, _ = curve_fit(gauss_function, sf_indices, sf_tuning_responses,
                                p0=[np.amax(sf_tuning_responses), pref_sf_index, 1.], maxfev=2000)
            sf_prediction = gauss_function(fine_grid, *popt)
            fit_sf_ind = popt[1]
            fit_sf = 0.02 * np.power(2, popt[1])

            peak_ind = sf_prediction.argmax()
            dist_to_half_max = np.abs(sf_prediction - (sf_prediction.max() / 2.))
            # argmin() on an empty slice (peak at grid index 0) raises ValueError; that is caught
            # below and leaves both cutoffs as NaN.
            low_cut_ind = dist_to_half_max[:peak_ind].argmin()
            high_cut_ind = dist_to_half_max[peak_ind:].argmin() + peak_ind

            if low_cut_ind > 0:
                low_cutoff = fine_grid[low_cut_ind]
                sf_low_cutoff = 0.02 * np.power(2, low_cutoff)
            # NOTE(review): `elif` means the high cutoff is only reported when the low one is not,
            # and `< 4` compares a grid *index* (0..40) against what looks like a grid *value*.
            # Behaviour is preserved here; confirm the intent before changing.
            elif high_cut_ind < 4:
                high_cutoff = fine_grid[high_cut_ind]
                sf_high_cutoff = 0.02 * np.power(2, high_cutoff)
        except Exception as err:
            # Deliberately best-effort: keep whatever was computed so far, but leave a trace.
            logger.debug('Gaussian spatial frequency fit failed: %s', err)
    else:
        # If the preferred spatial freq is a boundary value try to fit the tuning curve with an exponential.
        fit_sf_ind = pref_sf_index
        fit_sf = sf_values[pref_sf_index]
        try:
            popt, _ = curve_fit(exp_function, sf_indices, sf_tuning_responses,
                                p0=[np.amax(sf_tuning_responses), 2., np.amin(sf_tuning_responses)], maxfev=2000)
            sf_prediction = exp_function(fine_grid, *popt)

            peak_ind = sf_prediction.argmax()
            dist_to_half_max = np.abs(sf_prediction - (sf_prediction.max() / 2.))

            if pref_sf_index == 0:
                # Lowest sf preferred: only a high cutoff is reported.
                high_cut_ind = dist_to_half_max[peak_ind:].argmin() + peak_ind
                high_cutoff = fine_grid[high_cut_ind]
                sf_high_cutoff = 0.02 * np.power(2, high_cutoff)
            else:
                # Highest sf preferred: only a low cutoff is reported.
                low_cut_ind = dist_to_half_max[:peak_ind].argmin()
                low_cutoff = fine_grid[low_cut_ind]
                sf_low_cutoff = 0.02 * np.power(2, low_cutoff)
        except Exception as err:
            # Deliberately best-effort: the cutoffs stay NaN when the fit fails.
            logger.debug('Exponential spatial frequency fit failed: %s', err)

    return fit_sf_ind, fit_sf, sf_low_cutoff, sf_high_cutoff
def get_sfdi(sf_tuning_responses, mean_sweeps_trials, bias=5):
    """Computes spatial frequency discrimination index for cell

    The index is ``range / (range + 2 * spread)``, where ``range`` is the peak-to-peak value of the tuning
    responses and ``spread`` is the square root of the summed squared deviations of the trial values from
    their mean, divided by ``len(mean_sweeps_trials) - bias``.

    :param sf_tuning_responses: An array of len N, with each value the (averaged) response of a cell at a given
        spatial freq. stimulus.
    :param mean_sweeps_trials: Array of per-trial values (must support ``.mean()``, ``len()`` and element-wise
        arithmetic).
    :param bias: Amount subtracted from the number of trials in the denominator of the spread term.
    :return: The sfdi value (float)
    """
    deviations = mean_sweeps_trials - mean_sweeps_trials.mean()
    denominator = len(mean_sweeps_trials) - bias
    spread = np.sqrt(np.sum(deviations ** 2) / denominator)

    response_range = np.ptp(sf_tuning_responses)
    return response_range / (response_range + 2 * spread)
def gauss_function(x, a, x0, sigma):
    """Gaussian with amplitude `a`, centre `x0` and width `sigma`, evaluated at `x` (no normalisation factor)."""
    offset = x - x0
    exponent = -offset**2 / (2 * sigma**2)
    return a * np.exp(exponent)
def exp_function(x, a, b, c):
    """Exponential with amplitude `a`, rate `b` and offset `c`: ``a * exp(-b * x) + c``."""
    decay = np.exp(-b * x)
    return a * decay + c