import numpy as np
from scipy import stats
from scipy.optimize import curve_fit, fmin
from allensdk.internal.model.glif.find_spikes import align_and_cut_spikes, ALIGN_CUT_WINDOW
import logging
import matplotlib.pyplot as plt
[docs]def calc_spike_cut_and_v_reset_via_expvar_residuals(all_current_list,
all_voltage_list, dt, El_reference, deltaV,
max_spike_cut_time=False,
MAKE_PLOT=False, SHOW_PLOT=False, PUBLICATION_PLOT=False, BLOCK=False):
'''This function calculates where the spike should be cut based on explained variance.
The goal is to find a model where the voltage after a spike maximally explains the
voltage before a spike. This will also specify the voltage reset rule
inputs:
spike_determination_method: string specifing the method used to find threshold
all_current_list: list of current (list of current traces injected into neuron)
all_voltage_list: list of voltages (list of voltage trace)
The change is that if the slope is greater than one or intercept is greater than zero it forces it.
Regardless of required force the residuals are used.
'''
#--find the region of the spike needed for calculation of explained variance
(temp_v_spike_shape_list, all_i_spike_shape_list, all_thresholdInd, waveIndOfFirstSpikes, spikeFromWhichSweep) \
= align_and_cut_spikes(all_voltage_list, all_current_list, dt)
#--At this point it is unclear how this calculation should be done.
#--the slope should be fine no matter what, but the intercept dependency
#--will depend on the El, and deltaV
#--change reference
all_v_spike_shape_list=[shape-El_reference-deltaV for shape in temp_v_spike_shape_list]
# --setting limits to find explained variance
if max_spike_cut_time and max_spike_cut_time < .010:
expVarIndRangeAfterSpike = range(int(.001 / dt), int(max_spike_cut_time / dt)) #NOTE: THIS IS USED IN REFERENCE TO SPIKE TIME
else:
expVarIndRangeAfterSpike = range(int(.001 / dt), int(.010 / dt)) #NOTE: THIS IS USED IN REFERENCE TO SPIKE TIME
vectorIndex_of_max_explained_var = expVarIndRangeAfterSpike[0] # this is just here for the title of the plot
list_of_endPointArrays = [] # this should end up a list of numpy arrays where each numpy array contains the indices of the v_spike_shape_list that are a certain time after the threshold
for ii in expVarIndRangeAfterSpike:
list_of_endPointArrays.append(np.array(all_thresholdInd) + ii)
def line_force_slope_to_1(x,c):
return x+c
def line_force_int_to_0(x, m): #TODO: CHANGE THIS TO REST TOD DISCONNECT EVERYTHING.
return m*x
# HERE YOU GET THE SLOPE AND INTERCEPT AT EACH POINT
linRegress_error_4_each_time_end = []
slope_at_each_time_end=[]
intercept_at_each_time_end=[]
varData_4_each_time_end = []
varModel_4_each_time_end = []
chi2 = []
sum_residuals_4_each_time_end=[]
xdata = np.array([v[all_thresholdInd[ii]] for ii, v in enumerate(all_v_spike_shape_list)])
var_of_Vdata_beforeSpike = np.var(xdata)
for jj, vectorOfIndAcrossWaves in enumerate(list_of_endPointArrays): # these indices should be in terms of the spike waveforms
# print('jj', jj)
# TODO: Teeter get rid of the nonblipness
v_at_specificEndPoint = [all_v_spike_shape_list[ii][index] for ii, index in enumerate(vectorOfIndAcrossWaves)] # this is calculating variance at certain time points
# --currently the model of voltage reset is a linear regression between voltage before the spike and the voltage after the spike but it could be more complicated (for example as a function of current)
ydata = np.array(v_at_specificEndPoint) # this is the voltage at the specified end point
slope, intercept, r_value, p_value, std_err = stats.linregress(xdata, ydata)
# print(slope, intercept, r_value, p_value, std_err)
# if slope>1.0:
# logging.warning('linear regression slope is bigger than one: forcing slope to 1 and refitting intercept.')
# slope=1.0
# (intercept, nothing)=curve_fit(line_force_slope_to_1, xdata, ydata)
# #print("NEW INTERCEPT:", intercept)
# if intercept>0.0:
# #warnings.warn('/t ... and intercept is bigger than zero: forcing intercept to 0')
# intercept=0.0
#
# if intercept>0.0:
# logging.warning('Intercept is bigger than zero: forcing intercept to 0 and refitting slope.')
# intercept=0.0
# (slope, nothing)=curve_fit(line_force_int_to_0,xdata, ydata)
# #print("NEW SLOPE: ", slope)
# if slope>1.0:
# logging.warning('/t ... and linear regression slope is bigger than one: forcing slope to 1.')
# slope=1.0
slope_at_each_time_end.append(slope)
intercept_at_each_time_end.append(intercept)
ymodel = slope * xdata + intercept
residuals = ydata - ymodel
sum_residuals=sum(abs(residuals))
sum_residuals_4_each_time_end.append(sum_residuals)
chi2.append(np.var(residuals)) # how well the model describes the data
linRegress_error_4_each_time_end.append(std_err)
varData_4_each_time_end.append(np.var(v_at_specificEndPoint))
varModel_4_each_time_end.append(np.var(ymodel))
# --these will line up with how many arrays there are in the list
vectorIndex_of_min_sum_residuals = sum_residuals_4_each_time_end.index(min(sum_residuals_4_each_time_end))
#----NOTE THIS ISNT ACTUALLY CALCULATING EXPLAINED VARIANCE!!!!!!!!!!!!!!!!!!
vectorIndex_of_max_explained_var=vectorIndex_of_min_sum_residuals
all_v_spike_init_list = [v[all_thresholdInd[ii]] for ii, v in enumerate(all_v_spike_shape_list)]
# USE THIS WHEN MUTIPLE VECTORS all_v_at_min_expVar_list=[v[list_of_endPointArrays[vectorIndex_of_max_explained_var][ii]] for ii, v in enumerate(all_v_spike_shape_list)]
all_v_at_min_expVar_list = [v[list_of_endPointArrays[vectorIndex_of_max_explained_var][ii]] for ii, v in enumerate(all_v_spike_shape_list)]
time_at_minExpVar=list_of_endPointArrays[vectorIndex_of_max_explained_var]*dt
if MAKE_PLOT:
truncatedTime = np.arange(0, len(all_v_spike_shape_list[0])) * dt
plt.figure(figsize=(20, 10))
for ii in range(0, len(all_v_spike_shape_list)):
plt.subplot(2,1,1)
plt.plot(truncatedTime, temp_v_spike_shape_list[ii])
# plt.plot(truncatedTime[aligned_peakInd[ii]],spikewave[aligned_peakInd[ii]], '.k'
plt.plot(truncatedTime[all_thresholdInd[ii]], temp_v_spike_shape_list[ii][all_thresholdInd[ii]], '*k')
plt.title('Non adusted spikes')
plt.subplot(2,1,2)
plt.plot(truncatedTime, all_v_spike_shape_list[ii])
plt.plot(time_at_minExpVar, all_v_at_min_expVar_list, '*k')
plt.xlabel('time (s)', fontsize=20)
plt.ylabel('voltage (mV)', fontsize=20)
plt.title("Adjusted spikes (RP=%.3g, deltaV=%.3g)" % (El_reference,deltaV))
if PUBLICATION_PLOT:
truncatedTime = np.arange(0, len(all_v_spike_shape_list[0])) * dt
plt.figure(figsize=(20, 5))
for ii in range(0, len(all_v_spike_shape_list)):
plt.plot(truncatedTime*1000, temp_v_spike_shape_list[ii]*1e3, lw=2)
# plt.plot(truncatedTime[aligned_peakInd[ii]],spikewave[aligned_peakInd[ii]], '.k'
plt.plot(truncatedTime[all_thresholdInd[ii]]*1000, temp_v_spike_shape_list[ii][all_thresholdInd[ii]]*1e3, '.k', ms=10)
# plt.title('Spike Cutting', fontsize=20)
# plt.subplot(2,1,2)
# plt.plot(truncatedTime, all_v_spike_shape_list[ii])
plt.plot(time_at_minExpVar*1000, (np.array(all_v_at_min_expVar_list)+El_reference+deltaV)*1.e3, '.k', ms=10)
plt.xlabel('Time (ms)', fontsize=16)
plt.ylabel('Voltage (mV)', fontsize=16)
plt.xlim([0,12])
plt.tight_layout()
# plt.title("Adjusted spikes (RP=%.3g, deltaV=%.3g)" % (El_reference,deltaV))
if SHOW_PLOT:
plt.show(block=BLOCK)
# indNotExcluded_In_regress=list(np.setdiff1d(np.array([theInd for theInd in spikeIndDict['nonblip']]), np.array(waveIndOfFirstSpikes)))
# something is wrong with all_v_at_min_expVar_list--look at the difference between starting at .003 and .005 after thresh
if MAKE_PLOT:
plt.figure(figsize=(20, 10))
plt.plot(all_v_spike_init_list, all_v_at_min_expVar_list, 'b.', ms=16, label='noise') # list of voltage traces for blip
plt.xlabel('voltage at spike initiation (V)', fontsize=20)
plt.ylabel('voltage after spike (V)', fontsize=20)
# plt.title(cellTitle, fontsize=20)
slope_at_min_expVar_list, intercept_at_min_expVar_list, r_value_at_min_expVar_list, p_value_at_min_expVar_list, std_err_at_min_expVar_list = \
stats.linregress(np.array(all_v_spike_init_list), np.array(all_v_at_min_expVar_list))
print('mean of voltage before spike', np.mean(all_v_spike_init_list))
print('mean of voltage after spike', np.mean(all_v_at_min_expVar_list))
spike_cut_length= (list_of_endPointArrays[vectorIndex_of_max_explained_var][0])-int(ALIGN_CUT_WINDOW[0]/dt) #note this is dangerous if they arent' all at the same ind
if MAKE_PLOT:
xlim = np.array([min(all_v_spike_init_list), max(all_v_spike_init_list)])
plotLineRegress1(slope_at_min_expVar_list, intercept_at_min_expVar_list, r_value_at_min_expVar_list, xlim)
plt.legend(loc=2, fontsize=20)
if MAKE_PLOT:
xlim = np.array([min(all_v_spike_init_list), max(all_v_spike_init_list)])
plotLineRegressRed(slope_at_each_time_end[vectorIndex_of_max_explained_var], intercept_at_each_time_end[vectorIndex_of_max_explained_var], np.NAN, xlim)
plt.legend(loc=2, fontsize=20)
if SHOW_PLOT:
plt.show(block=BLOCK)
if PUBLICATION_PLOT:
plt.figure(figsize=(7, 5))
plt.plot(np.array(all_v_spike_init_list)*1e3, np.array(all_v_at_min_expVar_list)*1e3, 'b.', ms=16) # list of voltage traces for blip
plt.xlabel('Voltage at spike initiation (mV)', fontsize=16)
plt.ylabel('Voltage after spike (mV)', fontsize=16)
# plt.title('Voltage reset rules', fontsize=20)
xlim = np.array([min(all_v_spike_init_list), max(all_v_spike_init_list)])
def plot_hack(slope, intercept, r,xlim):
y=slope*xlim+intercept
plt.plot(xlim, y, '-k', lw=4)# label='slope='+"%.2f"%slope+', intercept='+"%.3f"%intercept)
plot_hack(slope_at_min_expVar_list, intercept_at_min_expVar_list*1e3, r_value_at_min_expVar_list, xlim*1e3)
plt.legend(loc=2, fontsize=16)
plt.tight_layout()
plt.show(block=BLOCK)
#TODO: Corinne look to see if these were calculated with zeroed out El if not does is matter?
if isinstance(slope_at_min_expVar_list, np.ndarray):
slope_at_min_expVar_list=float(slope_at_min_expVar_list[0])
if isinstance(intercept_at_min_expVar_list, np.ndarray):
intercept_at_min_expVar_list=float(intercept_at_min_expVar_list[0])
if type(intercept_at_min_expVar_list)==list or type(intercept_at_min_expVar_list)==np.ndarray:
intercept_at_min_expVar_list=intercept_at_min_expVar_list[0]
return spike_cut_length, slope_at_min_expVar_list, intercept_at_min_expVar_list
[docs]def plotLineRegress1(slope, intercept, r,xlim):
y=slope*xlim+intercept
print('slope=', slope, 'intercept=', intercept, 'xlim', xlim)
plt.plot(xlim, y, '-k', lw=4, label='slope='+"%.2f"%slope+', intercept='+"%.3f"%intercept+', r='+"%.2f"%r)
[docs]def plotLineRegressRed(slope, intercept, r,xlim):
y=slope*xlim+intercept
print('slope=', slope, 'intercept=', intercept, 'xlim', xlim)
plt.plot(xlim, y, '-r', lw=4, label='slope='+"%.2f"%slope+', intercept='+"%.3f"%intercept+', r='+"%.2f"%r)