import numpy as np
import pandas as pd
from six import string_types
import scipy.ndimage as ndi
import scipy.stats as st
from scipy.signal import welch
from scipy.optimize import curve_fit
from scipy.fftpack import fft
from scipy import signal
import logging
import matplotlib.pyplot as plt
from .stimulus_analysis import StimulusAnalysis, osi, dsi, deg2rad
from ...circle_plots import FanPlotter
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
logger = logging.getLogger(__name__)
[docs]class DriftingGratings(StimulusAnalysis):
"""
A class for computing single-unit metrics from the drifting gratings stimulus of an ecephys session NWB file.
To use, pass in a EcephysSession object::
session = EcephysSession.from_nwb_path('/path/to/my.nwb')
dg_analysis = DriftingGratings(session)
or, alternatively, pass in the file path::
dg_analysis = DriftingGratings('/path/to/my.nwb')
You can also pass in a unit filter dictionary which will only select units with certain properties. For example
to get only those units which are on probe C and found in the VISp area::
dg_analysis = DriftingGratings(session, filter={'location': 'probeC', 'structure_acronym': 'VISp'})
To get a table of the individual unit metrics ranked by unit ID::
metrics_table_df = dg_analysis.metrics()
"""
def __init__(self, ecephys_session, col_ori='orientation', col_tf='temporal_frequency', col_contrast='contrast',
trial_duration=2.0, **kwargs):
super(DriftingGratings, self).__init__(ecephys_session, trial_duration=trial_duration, **kwargs)
self._metrics = None
self._orivals = None
self._number_ori = None
self._tfvals = None
self._number_tf = None
self._contrastvals = None
self._number_constrast = None
self._col_ori = col_ori
self._col_tf = col_tf
self._col_contrast = col_contrast
if self._params is not None:
# TODO: Need to make sure
self._params = self._params.get('drifting_gratings', {})
self._stimulus_key = self._params.get('stimulus_key', None) # Overwrites parent value with argvars
else:
self._params = {}
self._stim_table_contrast = None
#stim_table = self.stim_table
#self._stim_table_contrast = stim_table[stim_table['stimulus_name'] == 'drifting_gratings_contrast']
#self._stim_table = stim_table[stim_table['stimulus_name'] != 'drifting_gratings_contrast']
self._conditionwise_statistics_contrast = None
self._stimulus_conditions_contrast = None
@property
def stim_table_contrast(self):
if self._stim_table_contrast is None:
stim_table = self.ecephys_session.stimulus_presentations
if 'drifting_gratings_contrast' in stim_table['stimulus_name'].unique():
self._stim_table_contrast = stim_table[stim_table['stimulus_name'] == 'drifting_gratings_contrast']
else:
self._stim_table_contrast = pd.DataFrame()
return self._stim_table_contrast
@property
def name(self):
return 'Drifting Gratings'
@property
def orivals(self):
""" Array of grating orientation conditions """
if self._orivals is None:
self._get_stim_table_stats()
return self._orivals
@property
def number_ori(self):
""" Number of grating orientation conditions """
if self._number_ori is None:
self._get_stim_table_stats()
return self._number_ori
@property
def tfvals(self):
""" Array of grating temporal frequency conditions """
if self._tfvals is None:
self._get_stim_table_stats()
return self._tfvals
@property
def number_tf(self):
""" Number of grating temporal frequency conditions """
if self._tfvals is None:
self._get_stim_table_stats()
return self._number_tf
@property
def contrastvals(self):
""" Array of grating temporal frequency conditions """
if self._contrastvals is None:
self._get_stim_table_stats()
return self._contrastvals
@property
def number_contrast(self):
""" Number of grating temporal frequency conditions """
if self._number_contrast is None:
self._get_stim_table_stats()
return self._number_contrast
@property
def null_condition(self):
""" Stimulus condition ID for null (blank) stimulus """
return self.stimulus_conditions[self.stimulus_conditions[self._col_tf] == 'null'].index
@property
def stimulus_conditions_contrast(self):
""" Stimulus conditions for contrast stimulus """
if self._stimulus_conditions_contrast is None:
# TODO: look into efficiency of using a table intersect instead.
contrast_condition_list = self.stim_table_contrast.stimulus_condition_id.unique()
self._stimulus_conditions_contrast = self.ecephys_session.stimulus_conditions[
self.ecephys_session.stimulus_conditions.index.isin(contrast_condition_list)
]
return self._stimulus_conditions_contrast
@property
def conditionwise_statistics_contrast(self):
""" Conditionwise statistics for contrast stimulus """
if self._conditionwise_statistics_contrast is None:
self._conditionwise_statistics_contrast = self.ecephys_session.conditionwise_spike_statistics(
self.stim_table_contrast.index.values,
self.unit_ids
)
return self._conditionwise_statistics_contrast
@property
def METRICS_COLUMNS(self):
return [('pref_ori_dg', np.float64),
('pref_ori_multi_dg', bool),
('pref_tf_dg', np.float64),
('pref_tf_multi_dg', bool),
('c50_dg', np.float64),
('f1_f0_dg', np.float64),
('mod_idx_dg', np.float64),
('g_osi_dg', np.float64),
('g_dsi_dg', np.float64),
('firing_rate_dg', np.float64),
('fano_dg', np.float64),
('lifetime_sparseness_dg', np.float64),
('run_pval_dg', np.float64),
('run_mod_dg', np.float64)]
@property
def metrics(self):
if self._metrics is None:
logger.info('Calculating metrics for ' + self.name)
unit_ids = self.unit_ids
metrics_df = self.empty_metrics_table()
if len(self.stim_table) > 0:
metrics_df['pref_ori_dg'] = [self._get_pref_ori(unit) for unit in unit_ids]
metrics_df['pref_ori_multi_dg'] = [
self._check_multiple_pref_conditions(unit_id, self._col_ori, self.orivals) for unit_id in unit_ids
]
metrics_df['pref_tf_dg'] = [self._get_pref_tf(unit) for unit in unit_ids]
metrics_df['pref_tf_multi_dg'] = [
self._check_multiple_pref_conditions(unit_id, self._col_tf, self.tfvals) for unit_id in unit_ids
]
metrics_df['f1_f0_dg'] = [self._get_f1_f0(unit, self._get_preferred_condition(unit))
for unit in unit_ids]
metrics_df['mod_idx_dg'] = [self._get_modulation_index(unit, self._get_preferred_condition(unit))
for unit in unit_ids]
metrics_df['g_osi_dg'] = [self._get_selectivity(unit, metrics_df.loc[unit]['pref_tf_dg'], 'osi')
for unit in unit_ids]
metrics_df['g_dsi_dg'] = [self._get_selectivity(unit, metrics_df.loc[unit]['pref_tf_dg'], 'dsi')
for unit in unit_ids]
metrics_df['firing_rate_dg'] = [self._get_overall_firing_rate(unit) for unit in unit_ids]
metrics_df['fano_dg'] = [self._get_fano_factor(unit, self._get_preferred_condition(unit))
for unit in unit_ids]
metrics_df['lifetime_sparseness_dg'] = [self._get_lifetime_sparseness(unit) for unit in unit_ids]
metrics_df.loc[:, ['run_pval_dg', 'run_mod_dg']] = [
self._get_running_modulation(unit, self._get_preferred_condition(unit)) for unit in unit_ids]
if len(self.stim_table_contrast) > 0:
metrics_df['c50_dg'] = [self._get_c50(unit) for unit in unit_ids]
self._metrics = metrics_df
return self._metrics
[docs] @classmethod
def known_stimulus_keys(cls):
return ['drifting_gratings', 'drifting_gratings_75_repeats']
def _get_stim_table_stats(self):
""" Extract orientations and temporal frequencies from the stimulus table """
self._orivals = np.sort(self.stimulus_conditions.loc[self.stimulus_conditions[self._col_ori]
!= 'null'][self._col_ori].unique())
self._number_ori = len(self._orivals)
self._tfvals = np.sort(self.stimulus_conditions.loc[self.stimulus_conditions[self._col_tf]
!= 'null'][self._col_tf].unique())
self._number_tf = len(self._tfvals)
self._contrastvals = np.sort(self.stimulus_conditions.loc[self.stimulus_conditions[self._col_contrast]
!= 'null'][self._col_contrast].unique())
self._number_contrast = len(self._contrastvals)
def _get_pref_ori(self, unit_id):
""" Calculate the preferred orientation condition for a given unit
Parameters
----------
unit_id : int
unique ID for the unit of interest
Returns
-------
pref_ori : float
stimulus orientation driving the maximal response
"""
# TODO: Most of the _get_pref_*() methods can be combined into one method and shared among the classes
similar_conditions = [self.stimulus_conditions.index[self.stimulus_conditions[self._col_ori] == ori].tolist()
for ori in self.orivals]
df = pd.DataFrame(
index=self.orivals,
data={'spike_mean': [self.conditionwise_statistics.loc[unit_id].loc[condition_inds]['spike_mean'].mean()
for condition_inds in similar_conditions]}
).rename_axis(self._col_ori)
return df.idxmax().iloc[0]
def _get_pref_tf(self, unit_id):
""" Calculate the preferred temporal frequency condition for a given unit
Params:
-------
unit_id : int
unique ID for the unit of interest
Returns
-------
pref_tf : float
stimulus temporal frequency driving the maximal response
"""
similar_conditions = [self.stimulus_conditions.index[self.stimulus_conditions[self._col_tf] == tf].tolist()
for tf in self.tfvals]
df = pd.DataFrame(
index=self.tfvals,
data={'spike_mean': [self.conditionwise_statistics.loc[unit_id].loc[condition_inds]['spike_mean'].mean()
for condition_inds in similar_conditions]}
).rename_axis(self._col_tf)
return df.idxmax().iloc[0]
def _get_selectivity(self, unit_id, pref_tf, selectivity_type='osi'):
""" Calculate the orientation or direction selectivity for a given unit
Params:
-------
unit_id - unique ID for the unit of interest
pref_tf - preferred temporal frequency for this unit
selectivity_type - 'osi' or 'dsi'
Returns:
-------
selectivity - orientation or direction selectivity value
"""
orivals_rad = deg2rad(self.orivals).astype('complex128')
condition_inds = self.stimulus_conditions[self.stimulus_conditions[self._col_tf] == pref_tf].index.values
df = self.conditionwise_statistics.loc[unit_id].loc[condition_inds]
df = df.assign(ori=self.stimulus_conditions.loc[df.index.values][self._col_ori])
df = df.sort_values(by=['ori']) # do not replace with self._col_ori unless we modify the line above
tuning = np.array(df['spike_mean'].values)
if selectivity_type == 'osi':
return osi(orivals_rad, tuning)
elif selectivity_type == 'dsi':
return dsi(orivals_rad, tuning)
else:
warnings.warn(f'unkown selectivity function {selectivity_type}.')
return np.nan
def _get_f1_f0(self, unit_id, condition_id):
""" Calculate F1/F0 for a given unit
A measure of how tightly locked a unit's firing rate is to the cycles of a drifting grating
Parameters
----------
unit_id - unique ID for the unit of interest
condition_id - ID for the condition of interest (usually the preferred condition)
Returns
-------
f1_f0 - metric
"""
presentation_ids = self.stim_table[self.stim_table['stimulus_condition_id'] == condition_id].index.values
tf = self.stim_table.loc[presentation_ids[0]][self._col_tf]
dataset = self.ecephys_session.presentationwise_spike_counts(
bin_edges=np.arange(0, self.trial_duration, 0.001),
stimulus_presentation_ids=presentation_ids,
unit_ids=[unit_id]
).drop('unit_id')
arr = np.squeeze(dataset.values)
trial_duration = dataset.time_relative_to_stimulus_onset.max() #TODO: If there a reason not to use self.trial_duration?
return f1_f0(arr, tf, trial_duration)
def _get_modulation_index(self, unit_id, condition_id):
""" Calculate modulation index for a given unit.
Parameters
----------
unit_id : int
unique ID for the unit of interest
condition_id :
ID for the condition of interest (usually the preferred condition)
Returns
-------
modulation_index : metric
"""
tf = self.stimulus_conditions.loc[condition_id][self._col_tf]
data = self.conditionwise_psth.sel(unit_id=unit_id, stimulus_condition_id=condition_id).data
sample_rate = 1 / np.mean(np.diff(self.conditionwise_psth.time_relative_to_stimulus_onset))
return modulation_index(data, tf, sample_rate)
def _get_c50(self, unit_id):
""" Calculate C50 for a given unit. Only valid if the contrast tuning stimulus is present. Otherwise,
return NaN value
Parameters
----------
unit_id : int
unique ID for the unit of interest
Returns:
-------
c50 : float
metric
"""
contrast_conditions = self.stim_table_contrast[
(self.stim_table_contrast[self._col_ori] == self._get_pref_ori(unit_id))]['stimulus_condition_id'].unique()
# contrasts = self.stimulus_conditions_contrast.loc[contrast_conditions]['contrast'].values.astype('float')
contrasts = self.stimulus_conditions_contrast.loc[contrast_conditions][self._col_contrast].values.astype('float')
mean_responses = self.conditionwise_statistics_contrast.loc[unit_id].loc[contrast_conditions]['spike_mean'].values.astype('float')
return c50(contrasts, mean_responses)
# Methods need to either be removed or updated to work with latest adaptor. Talked with Jsh and decision still
# pending.
'''
def _get_tfdi(self, unit_id, pref_ori):
""" Calculate temporal frequency discrimination index for a given unit
Only valid if the contrast tuning stimulus is present
Otherwise, return NaN value
Params:
-------
unit_id - unique ID for the unit of interest
pref_ori - preferred orientation for that cell
Returns:
-------
tfdi - metric
"""
### NEEDS TO BE UPDATED FOR NEW ADAPTER
v = list(self.spikes.keys())[nc]
tf_tuning = self.response_events[pref_ori, 1:, nc, 0]
trials = self.mean_sweep_events[(self.stim_table['Ori'] == self.orivals[pref_ori])][v].values
sse_part = np.sqrt(np.sum((trials-trials.mean())**2)/(len(trials)-5))
return (np.ptp(tf_tuning))/(np.ptp(tf_tuning) + 2*sse_part)
'''
'''
def _get_suppressed_contrast(self, unit_id, pref_ori, pref_tf):
""" Calculate two metrics used to determine if a unit is suppressed by contrast
Params:
-------
unit_id - unique ID for the unit of interest
pref_ori - preferred orientation for that cell
pref_tf - preferred temporal frequency for that cell
Returns:
-------
peak_blank - metric
all_blank - metric
"""
### NEEDS TO BE UPDATED FOR NEW ADAPTER
blank = self.response_events[0, 0, nc, 0]
peak = self.response_events[pref_ori, pref_tf+1, nc, 0]
all_resp = self.response_events[:, 1:, nc, 0].mean()
peak_blank = peak - blank
all_blank = all_resp - blank
return peak_blank, all_blank
'''
'''
def _fit_tf_tuning(self, unit_id, pref_ori, pref_tf):
""" Performs Gaussian or exponential fit on the temporal frequency tuning curve at the preferred orientation.
Params:
-------
unit_id - unique ID for the unit of interest
pref_ori - preferred orientation for that cell
pref_tf - preferred temporal frequency for that cell
Returns:
-------
fit_tf_ind - metric
fit_tf - metric
tf_low_cutoff - metric
tf_high_cutoff - metric
"""
### NEEDS TO BE UPDATED FOR NEW ADAPTER
tf_tuning = self.response_events[pref_ori, 1:, nc, 0]
fit_tf_ind = np.NaN
fit_tf = np.NaN
tf_low_cutoff = np.NaN
tf_high_cutoff = np.NaN
if pref_tf in range(1, 4):
try:
popt, pcov = curve_fit(gauss_function, range(5), tf_tuning, p0=[np.amax(tf_tuning), pref_tf, 1.],
maxfev=2000)
tf_prediction = gauss_function(np.arange(0., 4.1, 0.1), *popt)
fit_tf_ind = popt[1]
fit_tf = np.power(2, popt[1])
low_cut_ind = np.abs(tf_prediction - (tf_prediction.max() / 2.))[:tf_prediction.argmax()].argmin()
high_cut_ind = np.abs(tf_prediction - (tf_prediction.max() / 2.))[
tf_prediction.argmax():].argmin() + tf_prediction.argmax()
if low_cut_ind > 0:
low_cutoff = np.arange(0, 4.1, 0.1)[low_cut_ind]
tf_low_cutoff = np.power(2, low_cutoff)
elif high_cut_ind < 49:
high_cutoff = np.arange(0, 4.1, 0.1)[high_cut_ind]
tf_high_cutoff = np.power(2, high_cutoff)
except Exception:
pass
else:
fit_tf_ind = pref_tf
fit_tf = self.tfvals[pref_tf]
try:
popt, pcov = curve_fit(exp_function, range(5), tf_tuning,
p0=[np.amax(tf_tuning), 2., np.amin(tf_tuning)], maxfev=2000)
tf_prediction = exp_function(np.arange(0., 4.1, 0.1), *popt)
if pref_tf == 0:
high_cut_ind = np.abs(tf_prediction - (tf_prediction.max() / 2.))[
tf_prediction.argmax():].argmin() + tf_prediction.argmax()
high_cutoff = np.arange(0, 4.1, 0.1)[high_cut_ind]
tf_high_cutoff = np.power(2, high_cutoff)
else:
low_cut_ind = np.abs(tf_prediction - (tf_prediction.max() / 2.))[
:tf_prediction.argmax()].argmin()
low_cutoff = np.arange(0, 4.1, 0.1)[low_cut_ind]
tf_low_cutoff = np.power(2, low_cutoff)
except Exception:
pass
return fit_tf_ind, fit_tf, tf_low_cutoff, tf_high_cutoff
'''
## VISUALIZATION ##
[docs] def plot_raster(self, stimulus_condition_id, unit_id):
""" Plot raster for one condition and one unit """
idx_tf = np.where(self.tfvals == self.stimulus_conditions.loc[stimulus_condition_id][self._col_tf])[0]
idx_ori = np.where(self.orivals == self.stimulus_conditions.loc[stimulus_condition_id][self._col_ori])[0]
if len(idx_tf) == len(idx_ori) == 1:
presentation_ids = self.presentationwise_statistics.xs(unit_id, level=1)[
self.presentationwise_statistics.xs(unit_id, level=1)['stimulus_condition_id'] == stimulus_condition_id
].index.values
df = self.presentationwise_spike_times[
(self.presentationwise_spike_times['stimulus_presentation_id'].isin(presentation_ids)) &
(self.presentationwise_spike_times['unit_id'] == unit_id)]
x = df.index.values - self.stim_table.loc[df.stimulus_presentation_id].start_time
_, y = np.unique(df.stimulus_presentation_id, return_inverse=True)
plt.subplot(self.number_tf, self.number_ori, idx_tf*self.number_ori + idx_ori + 1)
plt.scatter(x, y, c='k', s=1, alpha=0.25)
plt.axis('off')
[docs] def plot_response_summary(self, unit_id, bar_thickness=0.25):
""" Plot the spike counts across conditions """
df = self.stimulus_conditions.drop(index=self.null_condition)
df['tf_index'] = np.searchsorted(self.tfvals, df[self._col_tf].values)
df['ori_index'] = np.searchsorted(self.orivals, df[self._col_ori].values)
cond_values = self.presentationwise_statistics.xs(unit_id, level=1)['stimulus_condition_id']
x = df.loc[cond_values.values]['tf_index'] + np.random.rand(cond_values.size) * bar_thickness - bar_thickness/2
y = self.presentationwise_statistics.xs(unit_id, level=1)['spike_counts']
c = df.loc[cond_values.values]['tf_index']
plt.subplot(2, 1, 1)
plt.scatter(y, x, c=c, alpha=0.5, cmap='Purples', vmin=-5)
locs, labels = plt.yticks(ticks=np.arange(self.number_tf), labels=self.tfvals)
plt.ylabel('Temporal frequency')
plt.xlabel('Spikes per trial')
plt.ylim([self.number_tf, -1])
x = df.loc[cond_values.values]['ori_index'] + np.random.rand(cond_values.size) * bar_thickness - bar_thickness/2
y = self.presentationwise_statistics.xs(unit_id, level=1)['spike_counts']
c = df.loc[cond_values.values]['ori_index']
plt.subplot(2, 1, 2)
plt.scatter(x, y, c=c, alpha=0.5, cmap='Spectral')
locs, labels = plt.xticks(ticks=np.arange(self.number_ori), labels=self.orivals)
plt.xlabel('Orientation')
plt.ylabel('Spikes per trial')
[docs] def make_star_plot(self, unit_id):
""" Make a 2P-style Star Plot based on presentationwise spike counts"""
angle_data = self.stimulus_conditions.loc[self.presentationwise_statistics.xs(unit_id, level=1)['stimulus_condition_id']][self._col_ori].values
r_data = self.stimulus_conditions.loc[self.presentationwise_statistics.xs(unit_id, level=1)['stimulus_condition_id']][self._col_tf].values
data = self.presentationwise_statistics.xs(unit_id, level=1)['spike_counts'].values
null_trials = np.where(angle_data == 'null')[0]
angle_data = np.delete(angle_data, null_trials)
r_data = np.delete(r_data, null_trials)
data = np.delete(data, null_trials)
cmin = np.min(data)
cmax = np.max(data)
fp = FanPlotter.for_drifting_gratings()
fp.plot(r_data=r_data, angle_data=angle_data, data=data, clim=[cmin, cmax])
fp.show_axes(closed=False)
plt.ylim([-5, 5])
plt.axis('equal')
plt.axis('off')
### General functions ###
def _gauss_function(x, a, x0, sigma):
"""
fit gaussian function at log scale
good for fitting band pass, not good at low pass or high pass
"""
return a*np.exp(-(x-x0)**2/(2*sigma**2))
def _exp_function(x, a, b, c):
return a*np.exp(-b*x)+c
def _contrast_curve(x, b, c, d, e):
"""Difference of gaussian.
fit sigmoid function at log scale
not good for fitting band pass
- b: hill slope
- c: min response
- d: max response
- e: EC50
"""
return c+(d-c)/(1+np.exp(b*(np.log(x)-np.log(e))))
[docs]def c50(contrasts, responses):
"""Computes C50, the halfway point between the maximum and minimum values in a curved fitted against a difference
of gaussian for the contrast values and their responese (mean spike rates)
Parameters
----------
contrasts : array of floats
list of different contrast stimuli
responses : array of floats
array of responses (spike rates)
Returns
-------
c50 : float
"""
if contrasts.size == 0 or contrasts.size != responses.size:
warnings.warn('the contrasts and responses arrays must be of the same length')
return np.nan
try:
# find the paraemters that best fit the contrast curve give x = contrast-vals and y = responses
fitCoefs, _ = curve_fit(_contrast_curve, contrasts, responses, maxfev=100000)
except RuntimeError as e:
warnings.warn(str(e))
return np.nan
# Create the constrast curve using the optimized parameters, get the halfway range point on the curve
# resids = responses - contrast_curve(contrasts.astype('float'), *fitCoefs)
X = np.linspace(min(contrasts)*0.9, max(contrasts)*1.1, 256) #
y_fit = _contrast_curve(X, *fitCoefs)
y_middle = (np.max(y_fit) - np.min(y_fit)) / 2 + np.min(y_fit)
try:
# y_fit is unlikely to be sorted, so to get the optimial value we should sort by y_fit and X before calling
# numpy's searchsorted()
sorted_indicies = np.argsort(y_fit)
X_sorted = X[sorted_indicies]
y_fit_sorted = y_fit[sorted_indicies]
c50 = X_sorted[np.searchsorted(y_fit_sorted, y_middle)]
except IndexError as e:
warnings.warn(str(e))
return np.nan
return c50
[docs]def f1_f0(arr, tf, trial_duration):
"""Computes F1/F0 of a drifting grating response
Parameters
----------
arr :
DataArray with trials x bin-times
tf :
temporal frequency of the stimulus
Returns
-------
f1_f0 : float
metric
"""
if arr.size == 0:
return np.nan
if arr.ndim == 1:
arr = arr.reshape(1, arr.size)
# For each trial group the bins into blocks that will go to the length of the temporal frequency
num_bins = arr.shape[1]
num_trials = arr.shape[0]
cycles_per_trial = int(tf * trial_duration)
bins_per_cycle = int(num_bins / cycles_per_trial)
if bins_per_cycle == 0:
# can occur if temp-freq x trial duration is greater than the total trial duration
return np.nan
arr = arr[:, :cycles_per_trial*bins_per_cycle].reshape((num_trials, cycles_per_trial, bins_per_cycle))
avg_rate = np.mean(arr, 1)
AMP = 2*np.abs(fft(avg_rate, bins_per_cycle)) / bins_per_cycle
f0 = 0.5*AMP[:, 0]
f1 = AMP[:, 1]
selection = f0 > 0.0
if not np.any(selection):
# No spikes found
return np.nan
return np.nanmean(f1[selection]/f0[selection])
[docs]def modulation_index(response_psth, tf, sample_rate):
"""Depth of modulation by each cycle of a drifting grating; similar to F1/F0
ref: Matteucci et al. (2019) Nonlinear processing of shape information
in rat lateral extrastriate cortex. J Neurosci 39: 1649-1670
Parameters
----------
response_psth : array of floats
the binned responses of a unit for a given stimuli
tf : float
the temporal frequency
sample_rate : float
the sampling rate of response_psth
Returns
-------
modulation_index : float
the mi value
"""
if response_psth.size == 0:
warnings.warn('response_psth is empty')
return np.nan
f, psd = signal.welch(response_psth, fs=sample_rate, nperseg=1024) # get freqs. and power spectral density
mean_psd = np.mean(psd)
if mean_psd == 0.0:
# TODO: Check with josh, should it be 0 or nan?
return 0.0
tf_index = np.searchsorted(f, tf)
if not 0 <= tf_index < psd.size:
warnings.warn('specified temporal frequency is not within the singals sampling range. Please adjust tf and/or'
'sample_rate parameters.')
return np.nan
return abs((psd[tf_index] - np.mean(psd))/np.sqrt(np.mean(psd**2)- mean_psd**2))