# Source code for allensdk.internal.model.glif.spike_cutting

import numpy as np
from scipy import stats
from scipy.optimize import curve_fit, fmin
from allensdk.internal.model.glif.find_spikes import align_and_cut_spikes, ALIGN_CUT_WINDOW
import logging

import matplotlib.pyplot as plt 

def calc_spike_cut_and_v_reset_via_expvar_residuals(all_current_list, all_voltage_list, dt, El_reference, deltaV,
                                                   max_spike_cut_time=False, MAKE_PLOT=False, SHOW_PLOT=False,
                                                   PUBLICATION_PLOT=False, BLOCK=False):
    '''Find where spikes should be cut and fit the linear voltage-reset rule.

    Spikes are aligned and cut out of every sweep with ``align_and_cut_spikes``.
    For every candidate offset after threshold (from 1 ms up to
    ``max_spike_cut_time`` when it is given and shorter than 10 ms, otherwise
    up to 10 ms) the voltage at that offset is regressed
    (``scipy.stats.linregress``) against the voltage at threshold across all
    spikes.  The offset whose fit has the smallest sum of absolute residuals
    is chosen as the cut point.  Despite the function name, explained
    variance itself is not computed; the residual sum is the criterion.

    Waveforms are re-referenced by subtracting ``El_reference + deltaV``
    before fitting, so the returned intercept is relative to that reference.
    The slope and intercept are not constrained.

    Parameters
    ----------
    all_current_list : list
        Current traces injected into the neuron, one per sweep.
    all_voltage_list : list
        Voltage traces, one per sweep.
    dt : float
        Sampling interval; candidate times (seconds) are divided by it.
    El_reference : float
        Reference (resting) potential subtracted from the waveforms.
    deltaV : float
        Additional offset subtracted from the waveforms.
    max_spike_cut_time : float or False
        Upper bound on the candidate cut time after threshold.  Only used
        when truthy and smaller than 0.010.
    MAKE_PLOT, SHOW_PLOT, PUBLICATION_PLOT, BLOCK : bool
        Diagnostic plotting switches.  ``BLOCK`` is forwarded to
        ``plt.show``.

    Returns
    -------
    spike_cut_length : int
        Chosen end index relative to the start of the cut window (see
        ``ALIGN_CUT_WINDOW``).
    slope : float
        Slope of the voltage-reset regression at the chosen cut point.
    intercept : float
        Intercept of that regression (relative to ``El_reference + deltaV``).

    Raises
    ------
    ValueError
        If no spikes are found, or if the candidate window is empty (e.g.
        ``max_spike_cut_time`` below 1 ms).
    '''
    # Only the voltage waveforms and the threshold indices are needed here.
    (temp_v_spike_shape_list, _all_i_spike_shape_list, all_thresholdInd,
     _waveIndOfFirstSpikes, _spikeFromWhichSweep) = align_and_cut_spikes(all_voltage_list, all_current_list, dt)

    if len(temp_v_spike_shape_list) == 0:
        raise ValueError('no spikes were found; cannot determine the spike cut length')

    # Re-reference the waveforms.  Shifting x and y by the same constant does
    # not change the regression slope, but the intercept becomes relative to
    # El_reference + deltaV.
    all_v_spike_shape_list = [shape - El_reference - deltaV for shape in temp_v_spike_shape_list]

    # Candidate cut points, expressed in samples after threshold.
    if max_spike_cut_time and max_spike_cut_time < .010:
        max_cut_time = max_spike_cut_time
    else:
        max_cut_time = .010
    offsets_after_threshold = range(int(.001 / dt), int(max_cut_time / dt))
    if len(offsets_after_threshold) == 0:
        raise ValueError('no candidate cut times between 0.001 and %g for dt=%g' % (max_cut_time, dt))

    threshold_inds = np.asarray(all_thresholdInd)
    # Voltage at threshold for each spike: the regressor of the reset rule.
    xdata = np.array([v[ind] for v, ind in zip(all_v_spike_shape_list, threshold_inds)])

    # One array of per-spike end indices for each candidate offset.
    end_inds_per_offset = [threshold_inds + offset for offset in offsets_after_threshold]
    fits = []               # (slope, intercept, r_value) for each candidate offset
    sum_abs_residuals = []  # selection criterion for each candidate offset
    for end_inds in end_inds_per_offset:
        ydata = np.array([v[ind] for v, ind in zip(all_v_spike_shape_list, end_inds)])
        slope, intercept, r_value, _p_value, _std_err = stats.linregress(xdata, ydata)
        fits.append((slope, intercept, r_value))
        sum_abs_residuals.append(np.abs(ydata - (slope * xdata + intercept)).sum())

    # np.argmin picks the first minimum on ties.
    best = int(np.argmin(sum_abs_residuals))
    best_end_inds = end_inds_per_offset[best]
    v_at_cut = np.array([v[ind] for v, ind in zip(all_v_spike_shape_list, best_end_inds)])
    time_at_cut = best_end_inds * dt
    # The fit for the chosen offset used exactly (xdata, v_at_cut); reuse it
    # instead of refitting.
    slope, intercept, r_value = fits[best]
    slope = float(slope)
    intercept = float(intercept)

    logging.debug('mean of voltage before spike %s', np.mean(xdata))
    logging.debug('mean of voltage after spike %s', np.mean(v_at_cut))

    # NOTE: only the first spike's end index is used, so this assumes every
    # spike has its threshold at the same index within the cut window.
    spike_cut_length = int(best_end_inds[0] - int(ALIGN_CUT_WINDOW[0] / dt))

    if MAKE_PLOT:
        truncatedTime = np.arange(len(all_v_spike_shape_list[0])) * dt
        _fig, (ax_raw, ax_adj) = plt.subplots(2, 1, figsize=(20, 10))
        for raw, adjusted, thr in zip(temp_v_spike_shape_list, all_v_spike_shape_list, threshold_inds):
            ax_raw.plot(truncatedTime, raw)
            ax_raw.plot(truncatedTime[thr], raw[thr], '*k')
            ax_adj.plot(truncatedTime, adjusted)
        ax_raw.set_title('Non adusted spikes')
        ax_adj.plot(time_at_cut, v_at_cut, '*k')
        ax_adj.set_xlabel('time (s)', fontsize=20)
        ax_adj.set_ylabel('voltage (V)', fontsize=20)
        ax_adj.set_title("Adjusted spikes (RP=%.3g, deltaV=%.3g)" % (El_reference, deltaV))

    if PUBLICATION_PLOT:
        truncatedTime = np.arange(len(all_v_spike_shape_list[0])) * dt
        plt.figure(figsize=(20, 5))
        for raw, thr in zip(temp_v_spike_shape_list, threshold_inds):
            raw = np.asarray(raw)
            plt.plot(truncatedTime * 1000, raw * 1e3, lw=2)
            plt.plot(truncatedTime[thr] * 1000, raw[thr] * 1e3, '.k', ms=10)
        # Undo the re-referencing so the cut points sit on the raw waveforms.
        plt.plot(time_at_cut * 1000, (v_at_cut + El_reference + deltaV) * 1.e3, '.k', ms=10)
        plt.xlabel('Time (ms)', fontsize=16)
        plt.ylabel('Voltage (mV)', fontsize=16)
        plt.xlim([0, 12])
        plt.tight_layout()

    if SHOW_PLOT:
        plt.show(block=BLOCK)

    if MAKE_PLOT:
        plt.figure(figsize=(20, 10))
        plt.plot(xdata, v_at_cut, 'b.', ms=16, label='noise')
        plt.xlabel('voltage at spike initiation (V)', fontsize=20)
        plt.ylabel('voltage after spike (V)', fontsize=20)
        xlim = np.array([xdata.min(), xdata.max()])
        plotLineRegress1(slope, intercept, r_value, xlim)
        plotLineRegressRed(slope, intercept, np.nan, xlim)
        plt.legend(loc=2, fontsize=20)

    if SHOW_PLOT:
        plt.show(block=BLOCK)

    if PUBLICATION_PLOT:
        plt.figure(figsize=(7, 5))
        plt.plot(xdata * 1e3, v_at_cut * 1e3, 'b.', ms=16)
        plt.xlabel('Voltage at spike initiation (mV)', fontsize=16)
        plt.ylabel('Voltage after spike (mV)', fontsize=16)
        xlim_mV = np.array([xdata.min(), xdata.max()]) * 1e3
        plt.plot(xlim_mV, slope * xlim_mV + intercept * 1e3, '-k', lw=4)
        plt.tight_layout()
        plt.show(block=BLOCK)

    return spike_cut_length, slope, intercept
def _plot_regression_line(slope, intercept, r, xlim, fmt):
    '''Plot ``y = slope * x + intercept`` at the points in ``xlim`` with format ``fmt``.

    The line is labelled with its slope, intercept and ``r`` so it can show
    up in a legend.  ``xlim`` may be any array-like of x values.
    '''
    xlim = np.asarray(xlim, dtype=float)
    y = slope * xlim + intercept
    logging.debug('slope=%s intercept=%s xlim=%s', slope, intercept, xlim)
    plt.plot(xlim, y, fmt, lw=4, label='slope=%.2f, intercept=%.3f, r=%.2f' % (slope, intercept, r))


def plotLineRegress1(slope, intercept, r, xlim):
    '''Plot a regression line in black on the current matplotlib axes.

    Parameters
    ----------
    slope, intercept : float
        Line coefficients.
    r : float
        Correlation coefficient, used only for the legend label (may be nan).
    xlim : array-like
        x values at which the line is evaluated (typically ``[xmin, xmax]``).
    '''
    _plot_regression_line(slope, intercept, r, xlim, '-k')


def plotLineRegressRed(slope, intercept, r, xlim):
    '''Plot a regression line in red on the current matplotlib axes.

    Same parameters as ``plotLineRegress1``; only the line colour differs.
    '''
    _plot_regression_line(slope, intercept, r, xlim, '-r')