import logging
import numpy as np
import h5py
#import matplotlib
#matplotlib.use('agg')
import matplotlib.pyplot as plt
import os
[docs]def background_trace(trace, save_dir, data_set=None):
fig,ax = plt.subplots(1)
ax.plot(trace)
if data_set is not None:
_add_stim_epochs(trace, ax, data_set)
save_file = os.path.join(save_dir, 'background_trace.pdf')
fig.savefig(save_file)
logging.info("Background Trace saved to %s", save_file)
plt.close(fig)
[docs]def correlation_report(dm, save_dir, without_masks=True):
'''
parameters:
dm: [DeMix object]
without_masks: boolean
'''
logging.info("Generating Correlation Report")
if without_masks:
cor, cor_demix = compute_correlations_without_masks(dm)
no_diagonal_mask = (1.0 - np.eye(cor.shape[0])).astype('bool')
fig, ax = plt.subplots(1)
ax.plot(dm.mask_overlap[:-1,:-1][no_diagonal_mask],cor[no_diagonal_mask]-cor_demix[no_diagonal_mask],'o')
ax.set_xlim([-1,np.max(dm.mask_overlap[:-1,:-1][no_diagonal_mask])])
ax.set_title('Delta Correlation vs. mask overlap')
save_file = os.path.join(save_dir,'cor_vs_overlap.pdf')
fig.savefig(save_file)
logging.info("\tCorrelation overlap saved to %s", save_file)
plt.close(fig)
fig, ax = plt.subplots(1,3)
delta_cor = cor- cor_demix
ax[0].hist(delta_cor[dm.mask_overlap[:-1,:-1]==0],bins=100)
ax[0].set_title('Delta Correlation')
ax[1].hist(cor[no_diagonal_mask],bins=100)
ax[1].set_title('Pre-demix Correlation')
ax[2].hist(cor_demix[no_diagonal_mask],bins=100)
ax[2].set_title('Post-demix Correlation')
save_file = os.path.join(save_dir,'cor_hist.pdf')
fig.savefig(save_file)
logging.info("\tCorrelation histograms saved to %s", save_file)
plt.close(fig)
else:
raise Exception('without_masks=False not yet implemented')
[docs]def plot_masks(dm, save_dir, movie_file, movie_dataset, window=150, add_background=True):
logging.info("Plotting masks")
overlap_pairs = [(x,y) for (x,y) in zip(*np.where(dm.mask_overlap >0)) if x>y and x!=dm.mask_overlap.shape[0]-1]
movie_data = h5py.File(movie_file,'r')
bg_traces = dm.get_traces_with_background()
for i,pair in enumerate(overlap_pairs):
fig_pair, ax_pair = plt.subplots(2,2)
mask = np.zeros(dm.masks[0].shape)
rgb_shape = (dm.masks[0].shape[0],dm.masks[0].shape[1],3)
rgb_mask = np.zeros(rgb_shape,dtype=np.uint8)
#for p in pair:
rgb_mask[:,:,2] = 255*dm.masks[pair[0]]
rgb_mask[:,:,1] = 255*dm.masks[pair[1]]
for p in pair:
mask += dm.masks[p]
non_zeros = np.where(mask)
non_zeros_mask = np.zeros(mask.shape)
ylower = np.min(non_zeros[0])
yupper = np.max(non_zeros[0])
xlower = np.min(non_zeros[1])
xupper = np.max(non_zeros[1])
#ax_pair[0,0].imshow(mask[ylower:yupper,xlower:xupper])
ax_pair[0,0].imshow(rgb_mask[ylower:yupper,xlower:xupper])
trace1 = dm.traces[pair[0]]
trace2 = dm.traces[pair[1]]
if add_background:
trace1_demix = bg_traces[pair[0]]
trace2_demix = bg_traces[pair[1]]
else:
trace1_demix = dm.traces_demix[pair[0]]
trace2_demix = dm.traces_demix[pair[1]]
center = np.where(trace1==np.max(trace1))[0][0]
ax_pair[1,0].plot(trace1[(center-window):(center+window)],label=str(pair[0]))
ax_pair[1,0].plot(trace2[(center-window):(center+window)],label=str(pair[1]))
ax_pair[1,0].legend()
ax_pair[1,1].plot(trace1_demix[center-window:center+window],label=str(pair[0]))
ax_pair[1,1].plot(trace2_demix[center-window:center+window],label=str(pair[1]))
ax_pair[1,1].legend()
ax_pair[0,1].imshow(movie_data[movie_dataset][center,ylower:yupper,xlower:xupper])
save_file = os.path.join(save_dir,'masks_'+str(pair[0])+'_'+str(pair[1])+'.pdf')
fig_pair.savefig(save_file)
plt.close(fig_pair)
logging.info("\tMask saved to %s", save_file)
#print(overlap_pairs)
#print(np.unique(dm.mask_overlap[no_diagonal_mask][dm.mask_overlap[no_diagonal_mask]>0]))
def _get_epoch_windows(stim_table):
start = np.array(stim_table.start)
end = np.array(stim_table.end)
windows = zip(start,end)
window_list = [[start[0]]]
for i,w in enumerate(windows[1:]):
if start[i+1] - end[i]>1:
window_list[-1].append(end[i])
window_list.append([start[i+1]])
#window_list += [end[i],start[i+1]]
window_list[-1].append(end[-1])
#window_list = [start[0]]
#window_list += [ start[x+1] for x in list(np.where(np.abs(start[1:] - end[:-1]) > 1)[0])]
#window_list += [end[-1]]
#print(window_list)
return window_list
def _add_stim_epochs(trace,ax,data_set):
stim_colors_dict = {'locally_sparse_noise':'green','drifting_gratings':'yellow','natural_movie_one':'magenta','natural_movie_two':'magenta','natural_movie_three':'red','natural_scenes':'orange','spontaneous':'grey','static_gratings':'blue'}
from allensdk.brain_observatory.stimulus_info import stimuli_in_session
stim_types = stimuli_in_session(data_set.get_metadata()['session_type'])
for stim in stim_types:
#print(stim)
stim_table = data_set.get_stimulus_table(stim)
window_list = _get_epoch_windows(stim_table)
for w in window_list:
#ax.fill_betweenx(np.arange(trace.shape[0]),w[0],w[1],facecolor=stim_colors_dict[stim],alpha=0.2)
#ax.axvspan(w[0],w[1],np.min(trace),np.max(trace),facecolor=stim_colors_dict[stim],alpha=0.2)
ax.axvspan(w[0],w[1],facecolor=stim_colors_dict[stim],alpha=0.2)
ax.set_ylim([np.min(trace),np.max(trace)])
[docs]def compute_non_overlap_masks(dm):
no_masks = np.zeros(dm.masks.shape).astype(int)
overlap_val = np.zeros(dm.masks.shape[0])
for i, m in enumerate(dm.masks):
overlap_1 = np.sum(dm.masks[:i],axis=0)
overlap_2 = np.sum(dm.masks[i+1:],axis=0)
overlap = overlap_1 + overlap_2
overlap_val[i] = np.sum(overlap)
no1 = overlap == 0
#no2 = overlap_2 == 0
no_masks[i] = np.logical_and(no1, m)
dm.no_masks = no_masks
dm.overlap = overlap
return dm.no_masks
[docs]def compute_non_overlap_traces(dm, movie_path, movie_dataset):
no_traces_shape = (dm.traces.shape[0],dm.traces.shape[1])
no_traces = np.zeros(no_traces_shape)
N, T = no_traces.shape
chunk_size = 1000
num_chunks = int(np.ceil(T/float(chunk_size)))
normalized_flat_masks = dm.no_masks.reshape(N,-1).T # shape (pixels, N)
normalized_flat_masks /= np.sum(normalized_flat_masks,axis=0) # shape(pixels, N)
movie_f = h5py.File(movie_path)
movie = movie_f[movie_dataset]
logging.debug("Getting traces for %d chunks", num_chunks)
for n in range(num_chunks):
print("Chunk = ", n)
data = movie[n*chunk_size:(n+1)*chunk_size]
data = data.reshape(chunk_size,-1) # This line throws an error
no_traces[:,n*chunk_size:(n+1)*chunk_size] = np.dot(data,normalized_flat_masks).T
movie_f.close()
logging.debug("Done")
dm.no_traces = no_traces
return dm.no_traces
[docs]def compute_correlations(dm, movie_path, movie_dataset):
compute_non_overlap_masks(dm)
compute_non_overlap_traces(dm, movie_path, movie_dataset)
no_mean = np.mean(dm.no_traces)
dm_mean = np.mean(dm.traces_demix)
t_mean = np.mean(dm.traces)
C_no_dm = np.mean( (dm.no_traces-no_mean)*(dm.traces_demix-dm_mean), axis=1)
C_t_dm = np.mean( (dm.traces-t_mean)*(dm.traces_demix-dm_mean), axis=1)
return C_no_dm, C_t_dm
[docs]def compute_correlations_without_masks(dm):
N, T = dm.traces.shape
N=N -1
traces = (dm.traces.T - np.mean(dm.traces.T,axis=0)) # shape (T,N)
traces_demix = (dm.traces_demix.T - np.mean(dm.traces_demix.T,axis=0)) # shape (T,N)
traces /= np.std(traces,axis=0)
traces_demix /= np.std(traces_demix,axis=0)
cor = np.dot(traces.T,traces)/T
cor_demix = np.dot(traces_demix.T,traces_demix)/T
return cor[:N,:N], cor_demix[:N,:N]