from six import string_types
import numpy as np
import pandas as pd
from scipy.optimize import curve_fit
from functools import partial
import logging
import matplotlib.pyplot as plt
from .stimulus_analysis import StimulusAnalysis
from .stimulus_analysis import osi, deg2rad
from ...circle_plots import FanPlotter
import warnings
# Module-level logger for this analysis module.
logger = logging.getLogger(__name__)
# NOTE: installs a process-wide "ignore" filter for FutureWarning as a side effect of importing this module.
warnings.simplefilter('ignore', FutureWarning)
class StaticGratings(StimulusAnalysis):
    """
    A class for computing single-unit metrics from the static gratings stimulus of an ecephys session NWB file.

    To use, pass in an EcephysSession object::

        session = EcephysSession.from_nwb_path('/path/to/my.nwb')
        sg_analysis = StaticGratings(session)

    or, alternatively, pass in the file path::

        sg_analysis = StaticGratings('/path/to/my.nwb')

    You can also pass in a unit filter dictionary which will only select units with certain properties. For example
    to get only those units which are on probe C and found in the VISp area::

        sg_analysis = StaticGratings(session, filter={'location': 'probeC', 'ecephys_structure_acronym': 'VISp'})

    To get a table of the individual unit metrics indexed by unit ID (``metrics`` is a property, not a method)::

        metrics_table_df = sg_analysis.metrics
    """

    def __init__(self, ecephys_session, col_ori='orientation', col_sf='spatial_frequency', col_phase='phase',
                 trial_duration=0.25, **kwargs):
        """
        Parameters
        ----------
        ecephys_session : EcephysSession or str
            Session object or NWB file path; passed through to StimulusAnalysis.
        col_ori : str
            Name of the stimulus-condition column holding grating orientation.
        col_sf : str
            Name of the stimulus-condition column holding grating spatial frequency.
        col_phase : str
            Name of the stimulus-condition column holding grating phase.
        trial_duration : float
            Passed through to StimulusAnalysis.
        **kwargs
            Passed through to StimulusAnalysis.
        """
        super(StaticGratings, self).__init__(ecephys_session, trial_duration=trial_duration, **kwargs)

        # Condition summaries; filled in lazily by _get_stim_table_stats().
        self._orivals = None
        self._number_ori = None
        self._sfvals = None
        self._number_sf = None
        self._phasevals = None
        self._number_phase = None

        # Cached result of the ``metrics`` property.
        self._metrics = None

        self._col_ori = col_ori
        self._col_sf = col_sf
        self._col_phase = col_phase
        self._trial_duration = trial_duration

        if self._params is not None:
            # Keep only the static-gratings section of the parameters. Its 'stimulus_key' entry (None when absent)
            # overwrites the value set by the parent constructor.
            self._params = self._params.get('static_gratings', {})
            self._stimulus_key = self._params.get('stimulus_key', None)
        else:
            self._params = {}

    @property
    def name(self):
        return 'Static Gratings'

    @property
    def orivals(self):
        """ Array of grating orientation conditions """
        if self._orivals is None:
            self._get_stim_table_stats()
        return self._orivals

    @property
    def number_ori(self):
        """ Number of grating orientation conditions """
        if self._number_ori is None:
            self._get_stim_table_stats()
        return self._number_ori

    @property
    def sfvals(self):
        """ Array of grating spatial frequency conditions """
        if self._sfvals is None:
            self._get_stim_table_stats()
        return self._sfvals

    @property
    def number_sf(self):
        """ Number of grating spatial frequency conditions """
        if self._number_sf is None:
            self._get_stim_table_stats()
        return self._number_sf

    @property
    def phasevals(self):
        """ Array of grating phase conditions """
        if self._phasevals is None:
            self._get_stim_table_stats()
        return self._phasevals

    @property
    def number_phase(self):
        """ Number of grating phase conditions """
        if self._number_phase is None:
            self._get_stim_table_stats()
        return self._number_phase

    @property
    def null_condition(self):
        """ Stimulus condition ID for null (blank) stimulus """
        return self.stimulus_conditions[self.stimulus_conditions[self._col_sf] == 'null'].index

    @property
    def METRICS_COLUMNS(self):
        return [('pref_sf_sg', np.float64),
                ('pref_sf_multi_sg', bool),
                ('pref_ori_sg', np.float64),
                ('pref_ori_multi_sg', bool),
                ('pref_phase_sg', np.float64),
                ('pref_phase_multi_sg', bool),
                ('g_osi_sg', np.float64),
                ('time_to_peak_sg', np.float64),
                ('firing_rate_sg', np.float64),
                ('fano_sg', np.float64),
                ('lifetime_sparseness_sg', np.float64),
                ('run_pval_sg', np.float64),
                ('run_mod_sg', np.float64)]

    @property
    def metrics(self):
        """Per-unit static-gratings metrics table; computed on first access and cached afterwards."""
        if self._metrics is None:
            logger.info('Calculating metrics for ' + self.name)
            unit_ids = self.unit_ids
            metrics_df = self.empty_metrics_table()

            if len(self.stim_table) > 0:
                # Each preferred value feeds several metrics, so compute it once per unit.
                pref_sfs = [self._get_pref_sf(unit) for unit in unit_ids]
                pref_oris = [self._get_pref_ori(unit) for unit in unit_ids]
                pref_phases = [self._get_pref_phase(unit) for unit in unit_ids]
                pref_conditions = [self._get_preferred_condition(unit) for unit in unit_ids]

                metrics_df['pref_sf_sg'] = pref_sfs
                metrics_df['pref_sf_multi_sg'] = [
                    self._check_multiple_pref_conditions(unit_id, self._col_sf, self.sfvals) for unit_id in unit_ids
                ]
                metrics_df['pref_ori_sg'] = pref_oris
                metrics_df['pref_ori_multi_sg'] = [
                    self._check_multiple_pref_conditions(unit_id, self._col_ori, self.orivals) for unit_id in unit_ids
                ]
                metrics_df['pref_phase_sg'] = pref_phases
                metrics_df['pref_phase_multi_sg'] = [
                    self._check_multiple_pref_conditions(unit_id, self._col_phase, self.phasevals)
                    for unit_id in unit_ids
                ]
                # OSI is evaluated at each unit's own preferred spatial frequency and phase.
                metrics_df['g_osi_sg'] = [
                    self._get_osi(unit, pref_sf, pref_phase)
                    for unit, pref_sf, pref_phase in zip(unit_ids, pref_sfs, pref_phases)
                ]
                metrics_df['time_to_peak_sg'] = [
                    self._get_time_to_peak(unit, pref_cond) for unit, pref_cond in zip(unit_ids, pref_conditions)
                ]
                metrics_df['firing_rate_sg'] = [self._get_overall_firing_rate(unit) for unit in unit_ids]
                metrics_df['fano_sg'] = [
                    self._get_fano_factor(unit, pref_cond) for unit, pref_cond in zip(unit_ids, pref_conditions)
                ]
                metrics_df['lifetime_sparseness_sg'] = [self._get_lifetime_sparseness(unit) for unit in unit_ids]
                metrics_df.loc[:, ['run_pval_sg', 'run_mod_sg']] = [
                    self._get_running_modulation(unit, pref_cond)
                    for unit, pref_cond in zip(unit_ids, pref_conditions)
                ]

            self._metrics = metrics_df

        return self._metrics

    @classmethod
    def known_stimulus_keys(cls):
        return ['static_gratings']

    def _non_null_values(self, column):
        """Sorted unique values of ``column`` in stimulus_conditions, excluding the 'null' (blank) marker."""
        conditions = self.stimulus_conditions
        return np.sort(conditions.loc[conditions[column] != 'null'][column].unique())

    def _get_stim_table_stats(self):
        """ Extract orientations, spatial frequencies, and phases from the stimulus table """
        self._orivals = self._non_null_values(self._col_ori)
        self._number_ori = len(self._orivals)

        self._sfvals = self._non_null_values(self._col_sf)
        self._number_sf = len(self._sfvals)

        self._phasevals = self._non_null_values(self._col_phase)
        self._number_phase = len(self._phasevals)

    def _get_pref_value(self, unit_id, column, values):
        """Return the entry of ``values`` that drives the largest response of ``unit_id``.

        For each value, the conditionwise 'spike_mean' of every stimulus condition whose ``column`` equals that
        value is averaged; the value with the largest average is returned.

        Parameters
        ----------
        unit_id : int
            unique ID for the unit of interest
        column : str
            stimulus-condition column to group by
        values : array-like
            candidate values of ``column``

        Returns
        -------
        The element of ``values`` with the largest mean response.
        """
        unit_stats = self.conditionwise_statistics.loc[unit_id]
        condition_groups = [self.stimulus_conditions.index[self.stimulus_conditions[column] == value].tolist()
                            for value in values]
        df = pd.DataFrame(
            index=values,
            data={'spike_mean': [unit_stats.loc[condition_inds]['spike_mean'].mean()
                                 for condition_inds in condition_groups]}
        ).rename_axis(column)
        return df.idxmax().iloc[0]

    def _get_pref_sf(self, unit_id):
        """Calculate the preferred spatial frequency condition for a given unit.

        Parameters
        ----------
        unit_id : int
            unique ID for the unit of interest

        Returns
        -------
        pref_sf : float
            spatial frequency driving the maximal response
        """
        return self._get_pref_value(unit_id, self._col_sf, self.sfvals)

    def _get_pref_ori(self, unit_id):
        """ Calculate the preferred orientation condition for a given unit

        Parameters
        ----------
        unit_id : int
            unique ID for the unit of interest

        Returns
        -------
        pref_ori :float
            stimulus orientation driving the maximal response
        """
        return self._get_pref_value(unit_id, self._col_ori, self.orivals)

    def _get_pref_phase(self, unit_id):
        """Calculate the preferred phase condition for a given unit

        Parameters
        ----------
        unit_id : int
            unique ID for the unit of interest

        Returns
        -------
        pref_phase : float
            stimulus phase driving the maximal response
        """
        return self._get_pref_value(unit_id, self._col_phase, self.phasevals)

    def _get_osi(self, unit_id, pref_sf, pref_phase):
        """ Calculate the orientation selectivity for a given unit

        Parameters
        ----------
        unit_id : int
            unique ID for the unit of interest
        pref_sf : float
            preferred spatial frequency for this unit
        pref_phase : float
            preferred phase for this unit

        Returns
        -------
        osi : float
            orientation selectivity value
        """
        orivals_rad = deg2rad(self.orivals).astype('complex128')  # TODO: can we use numpy deg2rad?

        condition_inds = self.stimulus_conditions[
            (self.stimulus_conditions[self._col_sf] == pref_sf) &
            (self.stimulus_conditions[self._col_phase] == pref_phase)
        ].index.values
        df = self.conditionwise_statistics.loc[unit_id].loc[condition_inds]
        df = df.assign(ori=self.stimulus_conditions.loc[df.index.values][self._col_ori])
        # Sort by orientation so the tuning curve lines up with the (sorted) orivals.
        # NOTE(review): assumes every orientation is present at (pref_sf, pref_phase) — confirm.
        df = df.sort_values(by=['ori'])

        tuning = np.array(df['spike_mean'].values)
        return osi(orivals_rad, tuning)

    ## VISUALIZATION ##
    def plot_raster(self, stimulus_condition_id, unit_id):
        """ Plot raster for one condition and one unit """
        condition = self.stimulus_conditions.loc[stimulus_condition_id]
        idx_sf = np.where(self.sfvals == condition[self._col_sf])[0]
        idx_ori = np.where(self.orivals == condition[self._col_ori])[0]

        # Only draw when the condition maps to exactly one (sf, ori) panel (not, e.g., the null condition).
        if len(idx_sf) == len(idx_ori) == 1:
            unit_stats = self.presentationwise_statistics.xs(unit_id, level=1)
            presentation_ids = unit_stats[unit_stats['stimulus_condition_id'] == stimulus_condition_id].index.values

            spike_times = self.presentationwise_spike_times
            df = spike_times[
                (spike_times['stimulus_presentation_id'].isin(presentation_ids)) &
                (spike_times['unit_id'] == unit_id)
            ]

            x = df.index.values - self.stim_table.loc[df.stimulus_presentation_id].start_time
            _, y = np.unique(df.stimulus_presentation_id, return_inverse=True)

            # plt.subplot needs an integer panel number, not a 1-element array.
            panel = int(idx_sf[0]) * self.number_ori + int(idx_ori[0]) + 1
            plt.subplot(self.number_sf, self.number_ori, panel)
            plt.scatter(x, y, c='k', s=1, alpha=0.25)
            plt.axis('off')

    def plot_response_summary(self, unit_id, bar_thickness=0.25):
        """ Plot the spike counts across conditions """
        df = self.stimulus_conditions.drop(index=self.null_condition)
        df['sf_index'] = np.searchsorted(self.sfvals, df[self._col_sf].values)
        df['ori_index'] = np.searchsorted(self.orivals, df[self._col_ori].values)
        df['phase_index'] = np.searchsorted(self.phasevals, df[self._col_phase].values)

        unit_stats = self.presentationwise_statistics.xs(unit_id, level=1)
        # Blank (null) presentations have no sf/ori/phase index; drop them rather than fail the lookup in ``df``.
        unit_stats = unit_stats[~unit_stats['stimulus_condition_id'].isin(self.null_condition)]
        cond_values = unit_stats['stimulus_condition_id'].values
        spike_counts = unit_stats['spike_counts']
        cond_info = df.loc[cond_values]
        phase_colors = cond_info['phase_index']

        # Top panel: spike counts per spatial frequency, jittered vertically within each row.
        x = cond_info['sf_index'] + np.random.rand(cond_values.size) * bar_thickness - bar_thickness / 2
        plt.subplot(2, 1, 1)
        plt.scatter(spike_counts, x, c=phase_colors, alpha=0.5, cmap='Blues', vmin=-5)
        plt.yticks(ticks=np.arange(self.number_sf), labels=self.sfvals)
        plt.ylabel('Spatial frequency')
        plt.xlabel('Spikes per trial')
        plt.ylim([self.number_sf, -1])

        # Bottom panel: spike counts per orientation, jittered horizontally within each column.
        x = cond_info['ori_index'] + np.random.rand(cond_values.size) * bar_thickness - bar_thickness / 2
        plt.subplot(2, 1, 2)
        plt.scatter(x, spike_counts, c=phase_colors, alpha=0.5, cmap='Spectral')
        plt.xticks(ticks=np.arange(self.number_ori), labels=self.orivals)
        plt.xlabel('Orientation')
        plt.ylabel('Spikes per trial')

    def make_fan_plot(self, unit_id):
        """ Make a 2P-style Fan Plot based on presentationwise spike counts"""
        unit_stats = self.presentationwise_statistics.xs(unit_id, level=1)
        trial_conditions = self.stimulus_conditions.loc[unit_stats['stimulus_condition_id']]

        angle_data = trial_conditions[self._col_ori].values
        r_data = trial_conditions[self._col_sf].values
        group_data = trial_conditions[self._col_phase].values
        data = unit_stats['spike_counts'].values

        # Drop blank presentations; they have no orientation to place on the fan.
        null_trials = np.where(angle_data == 'null')[0]
        angle_data = np.delete(angle_data, null_trials)
        r_data = np.delete(r_data, null_trials)
        group_data = np.delete(group_data, null_trials)
        data = np.delete(data, null_trials)

        cmin = np.min(data)
        cmax = np.max(data)

        fp = FanPlotter.for_static_gratings()
        fp.plot(r_data=r_data, angle_data=angle_data, group_data=group_data, data=data, clim=[cmin, cmax])
        fp.show_axes(closed=False)
        plt.axis('off')
def fit_sf_tuning(sf_tuning_responses, sf_values, pref_sf_index):
    """Performs gaussian or exponential fit on the spatial frequency tuning curve at preferred orientation/phase for
    a given cell.

    The fitted curve is evaluated on a grid from 0 to 4 in steps of 0.1 (in units of sf_values index). Grid positions
    are converted to spatial frequency as ``0.02 * 2**position``. That conversion only matches ``sf_values`` when they
    are octave-spaced starting at 0.02 (e.g. 0.02, 0.04, ..., 0.32).

    :param sf_tuning_responses: An array of len N, with each value the (averaged) response of a cell at a given spatial
        freq. stimulus.
    :param sf_values: An array of len N, with each value the spatial freq. of the stimulus (corresponding to
        sf_tuning_response).
    :param pref_sf_index: The pre-determined prefered spatial frequency (sf_values index) of the cell.
    :return: index for the preferred sf from the curve fit, prefered sf from the curve fit, low cutoff sf from the
        curve fit, high cutoff sf from the curve fit. Any value that cannot be determined is NaN.
    """
    fit_sf_ind = np.nan
    fit_sf = np.nan
    sf_low_cutoff = np.nan
    sf_high_cutoff = np.nan

    fit_grid = np.arange(0., 4.1, 0.1)
    last_grid_ind = len(fit_grid) - 1
    sf_positions = np.arange(len(sf_values))

    if pref_sf_index in range(1, len(sf_values) - 1):
        # If the prefered spatial freq is an interior case try to fit the tunning curve with a gaussian.
        try:
            popt, _ = curve_fit(gauss_function, sf_positions, sf_tuning_responses,
                                p0=[np.amax(sf_tuning_responses), pref_sf_index, 1.], maxfev=2000)
            sf_prediction = gauss_function(fit_grid, *popt)
            fit_sf_ind = popt[1]
            fit_sf = 0.02 * np.power(2, popt[1])

            peak_ind = sf_prediction.argmax()
            half_max_dist = np.abs(sf_prediction - (sf_prediction.max() / 2.))

            # Each side is evaluated independently. If the closest point to half-max is the grid edge, the curve
            # never falls to half-max on that side and the cutoff stays NaN.
            if peak_ind > 0:
                low_cut_ind = half_max_dist[:peak_ind].argmin()
                if low_cut_ind > 0:
                    sf_low_cutoff = 0.02 * np.power(2, fit_grid[low_cut_ind])
            high_cut_ind = half_max_dist[peak_ind:].argmin() + peak_ind
            if high_cut_ind < last_grid_ind:
                sf_high_cutoff = 0.02 * np.power(2, fit_grid[high_cut_ind])
        except Exception as err:
            # Best effort: a failed fit leaves the unresolved outputs as NaN.
            logger.debug('Gaussian spatial frequency fit failed: %s', err)
    else:
        # If the prefered spatial freq is a boundary value try to fit the tunning curve with an exponential
        fit_sf_ind = pref_sf_index
        fit_sf = sf_values[pref_sf_index]
        try:
            popt, _ = curve_fit(exp_function, sf_positions, sf_tuning_responses,
                                p0=[np.amax(sf_tuning_responses), 2., np.amin(sf_tuning_responses)], maxfev=2000)
            sf_prediction = exp_function(fit_grid, *popt)
            peak_ind = sf_prediction.argmax()
            half_max_dist = np.abs(sf_prediction - (sf_prediction.max() / 2.))
            if pref_sf_index == 0:
                high_cut_ind = half_max_dist[peak_ind:].argmin() + peak_ind
                sf_high_cutoff = 0.02 * np.power(2, fit_grid[high_cut_ind])
            else:
                low_cut_ind = half_max_dist[:peak_ind].argmin()
                sf_low_cutoff = 0.02 * np.power(2, fit_grid[low_cut_ind])
        except Exception as err:
            # Best effort: a failed fit leaves the cutoffs as NaN.
            logger.debug('Exponential spatial frequency fit failed: %s', err)

    return fit_sf_ind, fit_sf, sf_low_cutoff, sf_high_cutoff
def get_sfdi(sf_tuning_responses, mean_sweeps_trials, bias=5):
    """Computes spatial frequency discrimination index for cell

    SFDI = R / (R + 2 * sqrt(SSE / (n - bias))), where R is the peak-to-peak range of ``sf_tuning_responses``,
    SSE is the sum of squared deviations of ``mean_sweeps_trials`` from their mean, and n is the number of trials.

    :param sf_tuning_responses: An array-like of len N, with each value the (averaged) response of a cell at a given
        spatial freq. stimulus.
    :param mean_sweeps_trials: An array-like of per-trial responses (events/spikes) pooled across all trials.
    :param bias: Number subtracted from the trial count in the denominator of the error term.
    :return: The sfdi value (float), or NaN when there are not more trials than ``bias``.
    """
    sf_tuning_responses = np.asarray(sf_tuning_responses)
    mean_sweeps_trials = np.asarray(mean_sweeps_trials)

    # With n <= bias the error term is undefined (division by zero or sqrt of a negative number).
    dof = len(mean_sweeps_trials) - bias
    if dof <= 0:
        return np.nan

    trial_mean = mean_sweeps_trials.mean()
    sse_part = np.sqrt(np.sum((mean_sweeps_trials - trial_mean) ** 2) / dof)
    tuning_range = np.ptp(sf_tuning_responses)
    return tuning_range / (tuning_range + 2 * sse_part)
def gauss_function(x, a, x0, sigma):
    """Gaussian ``a * exp(-(x - x0)**2 / (2 * sigma**2))`` with no baseline offset; ``x`` may be a scalar or array."""
    offset = x - x0
    return a * np.exp(-offset ** 2 / (2 * sigma ** 2))
def exp_function(x, a, b, c):
    """Exponential ``a * exp(-b * x) + c``; ``x`` may be a scalar or array."""
    decay = np.exp(-b * x)
    return c + a * decay