# Source code for allensdk.internal.model.glif.ASGLM

import numpy as np
import itertools
import allensdk.internal.model.GLM as GLM
import logging
import statsmodels.api as sm
import matplotlib.pyplot as plt

def _remove_post_spike_samples(trace, spike_starts, npcut):
    '''Return `trace` without the `npcut` samples (rows) that follow each spike start.'''
    trace = np.asarray(trace)
    keep = np.ones(trace.shape[0], dtype=bool)
    if npcut > 0:
        for start in spike_starts:
            # Slice assignment clips at the end of the trace, so a spike that
            # occurs fewer than npcut samples before the end is handled.
            keep[start:start + npcut] = False
    return trace[keep]


def ASGLM_pairwise(ks_int, I_stim, voltage, spike_ind, cinit, tauinit, SCL, dt, resting_potential, SHORT_RUN=False, MAKE_PLOT=False, SHOW_PLOT=False, BLOCK=False):
    '''Fit the membrane resistance and after-spike current amplitudes for every
    pair of candidate decay rates and return the results of the best pair.

    For each pair of k's, a basis of after-spike currents is built with
    GLM.create_basis_IPSP for every sweep, the samples following each spike are
    removed, and a Gaussian GLM with identity link is fit per sweep to

        dv/dt - I/cinit = sum_j(a_j * basis_j) / cinit - beta * (v - resting_potential) / tauinit

    where a_j are the after-spike current amplitudes and the resistance is
    recovered as R = tauinit / (beta * cinit).  The pair with the highest
    log-likelihood averaged over sweeps is selected.

    Parameters
    ----------
    ks_int: list
        initial possible k's (k=1/tau, where tau is the time constant of the
        exponential decay).  Must contain exactly five values so that there
        are 10 pairs (the figure layout is a fixed 5x2 grid).
    I_stim: list of arrays
        input stimulus traces of sweeps
    voltage: list of arrays
        voltage of cell as a result of I_stim
    spike_ind: list of arrays
        each array contains the index of the spikes
    cinit: float
        membrane capacitance
    tauinit: float
        time constant of membrane
    SCL: float
        number of indices that should be cut after a spike (converted with int())
    dt: float
        size of time step of injected current
    resting_potential: float
        voltage subtracted from the membrane voltage in the leak term
    SHORT_RUN: bool
        if True only the first pair of k's is fit (a warning is logged)
    MAKE_PLOT: bool
        if True draw the basis currents (figure 78), the summed fitted
        after-spike current (figure 79) and the individual fitted currents
        (figure 80)
    SHOW_PLOT: bool
        if True (and MAKE_PLOT is True) call plt.show after drawing
    BLOCK: bool
        passed as `block` to plt.show for the basis figure only

    Returns
    -------
    best_k_pair: numpy array
        the selected pair of k's (filter-unit k's divided by dt)
    best_asc_amp: numpy array
        fitted after-spike current amplitudes of the selected pair, one row
        per sweep
    best_R: numpy array
        fitted resistance of the selected pair for each sweep
    best_llh: numpy array
        log-likelihood of the selected pair for each sweep

    Raises
    ------
    ValueError
        if the number of pairs of k's is not 10, or if the fit failed for
        every pair so that no best pair can be selected
    '''
    # Post-spike filter parameters (hard-coded)
    nkt = 8000  # filter length in samples; should be longer than the longest after-spike current
    DTsim = dt  # simulation step equals the sample width
    neye = 0  # number of identity basis vectors
    ncos = 2  # number of bases fit simultaneously
    flag_exp = 1  # 1 means use exponential bases, else raised-cosine bumps
    vL = resting_potential
    tst = 0  # time-step to start
    npcut = int(SCL)  # number of points to cut after each spike initiation

    # Convert the candidate k's into filter time units (ks_list is k*dt).
    f = 1e-3/dt  # pre-factor for getting time units correct for bases
    taus_int = [1000./kk for kk in ks_int]  # converting to ms
    taus_filter = [f*j for j in taus_int]  # convert ms time to filter time units
    ks_list = [1.0/i for i in taus_filter]

    # Collect spikes strictly between tst and the end of each sweep
    t0_list = []
    for stim, spikes in zip(I_stim, spike_ind):
        tend = len(stim)
        t0_list.append([spk for spk in spikes if tst < spk < tend])

    # Create the list of pairs of k's, in filter units and in the caller's units
    ks_pairs = list(itertools.combinations(ks_list, ncos))
    ks_pairs_in_SI_units = list(itertools.combinations(ks_int, ncos))
    if len(ks_pairs) != 10:
        raise ValueError('figure subplots will need to be changed as there is a different number than 10 ks_pairs.')

    if SHORT_RUN:
        logging.warning("You are not doing all the ks pairs in ASGLM_pairwise")
        ks_pairs = ks_pairs[:1]

    R_for_all_ks_pairs = []
    asc_amp_for_all_ks_pairs = []
    llh_for_all_ks_pairs = []

    # zip stops at the shorter list, so SHORT_RUN iterates over one pair only
    for ks_ind, (ks_fit_units, ks_SI_units) in enumerate(zip(ks_pairs, ks_pairs_in_SI_units)):
        print('ks_fit_units', ks_fit_units)

        # ---- Create basis after-spike currents for every sweep ----
        if MAKE_PLOT:
            plt.figure(78, figsize=(20,10))
        basis_IPSP_list = []
        for rr in range(len(I_stim)):  # loop over repeats
            # find the basis of the entire trace
            basis_IPSP, _ = GLM.create_basis_IPSP(neye, ncos, taus_filter, ks_fit_units, DTsim, t0_list[rr], I_stim[rr], nkt, flag_exp, npcut)
            basis_IPSP_list.append(basis_IPSP)
            if MAKE_PLOT:
                # The plot window starts 10 samples before the first spike of
                # the first sweep; it is only computed when plotting so that a
                # first sweep without spikes does not break a non-plotting run.
                si = t0_list[0][0]-10  # plot start_ind
                se = si+nkt+10  # plot end_ind
                tvec = dt*np.arange(si-tst, se-tst)  # convert time-steps to real time
                plt.figure(78)
                plt.subplot(5, 2, ks_ind+1)
                plt.plot(1e3*tvec, basis_IPSP[si:se,:], lw=2, label=str(rr))  # 1e3 plots time on x-axis in ms
                plt.xlabel('time (ms)')
                plt.title("k's "+str(ks_SI_units))
        if MAKE_PLOT:
            plt.annotate('ASGLM (fit asc and R): AScurrent basis', xy=(.4, .985),
                         xycoords='figure fraction',
                         horizontalalignment='left', verticalalignment='top',
                         fontsize=20)
            plt.legend()
            plt.tight_layout()
            if SHOW_PLOT:
                plt.show(block=BLOCK)

        # ---- Cut spikes out of dv, v, i and the basis for every sweep ----
        i_all_swps_list = []
        b_ipsp_all_swps_list = []
        v_all_swps_list = []
        dv_all_swps_list = []
        for ss in range(len(I_stim)):  # loop over repeats
            tend = len(I_stim[ss])
            i = I_stim[ss][tst:tend-1]
            b_ipsp = basis_IPSP_list[ss][tst:tend-1]
            v = voltage[ss][tst:tend-1]
            vs = voltage[ss][tst+1:tend]
            dv = (vs-v)/dt  # forward-difference derivative of voltage

            # delete npcut points after each spike from every quantity
            cut_starts = [int(t0)-tst for t0 in t0_list[ss]]
            v_all_swps_list.append(_remove_post_spike_samples(v, cut_starts, npcut))
            dv_all_swps_list.append(_remove_post_spike_samples(dv, cut_starts, npcut))
            i_all_swps_list.append(_remove_post_spike_samples(i, cut_starts, npcut))
            b_ipsp_all_swps_list.append(_remove_post_spike_samples(b_ipsp, cut_starts, npcut))

        # ---- Fit amplitudes and resistance for every sweep with a GLM ----
        if MAKE_PLOT:
            plt.figure(79, figsize=(20, 12))
            plt.figure(80, figsize=(20, 12))
        R_for_each_sweep = []
        asc_amp_for_each_sweep = []
        llh_for_each_sweep = []
        sweeps = zip(v_all_swps_list, dv_all_swps_list, i_all_swps_list, b_ipsp_all_swps_list)
        for kkk, (v_spike_deleted, dv_spike_deleted, i_spike_deleted, b_ipsp_spikes_deleted) in enumerate(sweeps):
            # Regressors: the ncos basis currents scaled by 1/C, then the leak term.
            inp = np.zeros((len(i_spike_deleted), ncos+1))
            inp[:, :ncos] = (1.0/cinit)*b_ipsp_spikes_deleted[:, :ncos]
            inp[:, ncos] = -(v_spike_deleted-vL)/tauinit
            out = dv_spike_deleted - i_spike_deleted/cinit
            try:
                # Gaussian() defaults to the identity link; passing the link
                # class explicitly is rejected by recent statsmodels releases.
                glm_fit = sm.GLM(out, inp, family=sm.families.Gaussian())
                res = glm_fit.fit()
                fitprs = res.params  # [asc amplitudes ..., leak coefficient tauinit/(R*cinit)]
                llh = res.llf
                fit_R = tauinit/(fitprs[ncos]*cinit)
                fit_asc_amp = fitprs[:ncos]
                # total post-spike current: basis functions weighted by the fitted amplitudes
                ipsc = np.sum(b_ipsp_spikes_deleted[:, 0:ncos]*fit_asc_amp, 1)
            except Exception as e:
                # Best effort: a sweep whose fit fails is recorded as NaN.
                logging.warning("fit didn't work: " + str(e))
                llh = np.nan
                fit_R = np.nan
                fit_asc_amp = np.full(ncos, np.nan)
                ipsc = np.full(len(b_ipsp_spikes_deleted[:, 0]), np.nan)

            R_for_each_sweep.append(fit_R)
            asc_amp_for_each_sweep.append(fit_asc_amp)
            llh_for_each_sweep.append(llh)

            if MAKE_PLOT:
                # Summed after-spike current just after the first spike of the first sweep
                plt.figure(79)
                plt.subplot(5, 2, ks_ind+1)
                plot_inds = np.arange(int(t0_list[0][0])-tst, int(t0_list[0][0])-tst+nkt)
                tvec = dt*plot_inds
                plt.plot(tvec, ipsc[plot_inds], lw=2, label='llh='+str(llh))
                plt.xlabel('time (s)')
                plt.ylabel('current (A)')
                plt.title("k's "+str(ks_SI_units))

                # Individual exponential currents, drawn with the caller's k's
                plt.figure(80)
                plt.subplot(5, 2, ks_ind+1)
                t_plot = np.arange(10000)*dt
                for ASC, KS in zip(fit_asc_amp, np.array(ks_SI_units)):
                    single_asc_trace = ASC*np.exp(-KS*t_plot)
                    plt.plot(t_plot, single_asc_trace, lw=2, label='sweep '+str(kkk))
                plt.xlabel('time (s)')
                plt.ylabel('current (A)')
                plt.title("k's "+str(ks_SI_units))

        if MAKE_PLOT:
            plt.figure(79)
            plt.tight_layout()
            plt.annotate('ASGLM (fit asc and R): Sum fit after spike currents', xy=(.3, .985),
                         xycoords='figure fraction',
                         horizontalalignment='left', verticalalignment='top',
                         fontsize=20)
            plt.legend()
            plt.figure(80)
            plt.tight_layout()
            plt.annotate('ASGLM (fit asc and R): Individual Currents', xy=(.3, .985),
                         xycoords='figure fraction',
                         horizontalalignment='left', verticalalignment='top',
                         fontsize=20)
            plt.legend()
            if SHOW_PLOT:
                # NOTE(review): BLOCK is not honoured here, unlike the basis figure - confirm intent
                plt.show(block=False)

        # results of this pair, one entry per sweep
        R_for_all_ks_pairs.append(R_for_each_sweep)
        asc_amp_for_all_ks_pairs.append(asc_amp_for_each_sweep)
        llh_for_all_ks_pairs.append(llh_for_each_sweep)

    # ---- Select the pair with the highest sweep-averaged log-likelihood ----
    ave_llh_for_each_pair = np.mean(llh_for_all_ks_pairs, axis=1)
    if np.all(np.isnan(ave_llh_for_each_pair)):
        raise ValueError('GLM fit failed for every pair of ks; no best pair can be selected.')
    # nanargmax skips pairs containing a failed (NaN) sweep and, like the
    # original equality search, returns the first index attaining the maximum.
    best_ks_pair_ind = int(np.nanargmax(ave_llh_for_each_pair))
    best_k_pair = np.array(ks_pairs[best_ks_pair_ind])/dt  # filter-unit k's are k*dt, so divide by dt
    best_asc_amp = np.array(asc_amp_for_all_ks_pairs[best_ks_pair_ind])
    best_R = np.array(R_for_all_ks_pairs[best_ks_pair_ind])
    best_llh = np.array(llh_for_all_ks_pairs[best_ks_pair_ind])

    print('**********from ASGLM_pairwise******************************************')
    print('best_ks_pair_ind', best_ks_pair_ind)
    print('best_asc_amp', best_asc_amp)
    print('best_k_pair', best_k_pair)
    print('best_R', best_R)
    print('best_llh', best_llh)
    print('**********done with ASGLM_pairwise***********************************')

    return best_k_pair, best_asc_amp, best_R, best_llh