from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.pyplot as plt
import numpy as np
def plot_spike_counts(
    data_array,
    time_coords,
    cbar_label,
    title,
    xlabel='time relative to stimulus onset (s)',
    ylabel='unit',
    xtick_step=20
):  # pragma: no cover
    '''Draw a 2D array of spike counts as an image with a colorbar.

    Parameters
    ----------
    data_array : xarray.DataArray
        2D data array of unitwise values per time bin. See
        EcephysSession.sweepwise_spike_counts. The transpose of this array
        is what gets drawn, so its first dimension runs along the x axis.
    time_coords : array-like
        Values used to label the x axis; converted with ``np.array``.
        Presumably one entry per time bin of ``data_array`` (confirm
        against callers).
    cbar_label : str
        Label written along the colorbar.
    title : str
        Title of the main axes.
    xlabel : str, optional
        Label of the x axis.
    ylabel : str, optional
        Label of the y axis. The y axis itself carries no ticks.
    xtick_step : int, optional
        Stride between labelled x ticks: a tick is placed at every
        ``xtick_step``-th position and labelled with the matching entry of
        ``time_coords`` formatted to three decimals.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.
    '''
    figure, axes = plt.subplots(figsize=(12, 12))

    # Carve a thin strip off the right-hand side of the main axes to hold
    # the colorbar, so the colorbar height follows the image.
    colorbar_axes = make_axes_locatable(axes).append_axes("right", 0.2, pad=0.05)

    image = axes.imshow(data_array.T, interpolation='none')
    plt.colorbar(image, cax=colorbar_axes)
    colorbar_axes.set_ylabel(cbar_label, fontsize=16)

    # Rows are only labelled collectively; suppress per-row ticks.
    axes.yaxis.set_major_locator(plt.NullLocator())
    axes.set_ylabel(ylabel, fontsize=16)

    # Tick positions and tick labels are taken with the same stride so the
    # two sequences always have equal length.
    times = np.array(time_coords)
    tick_positions = np.arange(0, len(times), xtick_step)
    tick_labels = [f'{value:1.3f}' for value in times[::xtick_step]]
    axes.set_xticks(tick_positions)
    axes.set_xticklabels(tick_labels, rotation=45)
    axes.set_xlabel(xlabel, fontsize=16)

    axes.set_title(title, fontsize=20)
    return figure
class _VlPlotter:
    '''Callable that draws one horizontal band of vertical lines per call.

    The y range [0, 1] is split into ``num_objects`` equal bands. Each call
    draws its lines in the next band up, starting from the bottom.
    '''

    def __init__(self, ax, num_objects, cmap=plt.cm.tab20, cycle_colors=False):
        '''Store the drawing target and colour settings.

        Parameters
        ----------
        ax :
            Object whose ``vlines`` method is used for drawing.
        num_objects : int
            Number of bands the unit y range is divided into.
        cmap :
            Colormap; called with an integer index and read for its ``N``
            attribute.
        cycle_colors : bool
            If True, colours follow the band index (wrapping at ``cmap.N``);
            otherwise each band gets a randomly drawn colormap index.
        '''
        self.ax = ax
        self.num_objects = num_objects
        self.cmap = cmap
        self.cycle_colors = cycle_colors
        # Index of the band the next call will draw into.
        self.ii = 0

    def __call__(self, gb):
        '''Draw vertical lines at ``gb.index.values`` in the current band.

        Returns None; the band counter is advanced as a side effect.
        '''
        band_bottom = self.ii / self.num_objects
        band_top = (self.ii + 1) / self.num_objects

        if self.cycle_colors:
            color_index = self.ii % self.cmap.N
        else:
            color_index = np.random.randint(self.cmap.N)
        band_color = self.cmap(color_index)

        self.ax.vlines(gb.index.values, band_bottom, band_top, colors=band_color)
        self.ii += 1
def raster_plot(spike_times, figsize=(8,8), cmap=plt.cm.tab20, title='spike raster', cycle_colors=False):
    '''Draw a spike raster with one row of vertical lines per unit.

    Parameters
    ----------
    spike_times : pandas.DataFrame
        Must have 'stimulus_presentation_id' and 'unit_id' columns. Rows are
        grouped by 'unit_id' and the index values of each group are used as
        the x positions of that unit's lines (presumably spike times in
        seconds, given the x axis label; confirm against callers).
    figsize : tuple, optional
        Figure size passed to ``plt.subplots``.
    cmap : optional
        Colormap from which each unit's colour is drawn.
    title : str, optional
        Title of the axes.
    cycle_colors : bool, optional
        If True, colours step through ``cmap`` in order; otherwise each unit
        gets a randomly chosen colour from ``cmap``.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure.
    '''
    fig, ax = plt.subplots(figsize=figsize)
    plotter = _VlPlotter(ax, num_objects=len(spike_times['unit_id'].unique()), cmap=cmap, cycle_colors=cycle_colors)

    # aggregate is called on each column, so pass only one (eg the stimulus_presentation_id)
    # to plot each unit once
    spike_times[['stimulus_presentation_id', 'unit_id']].groupby('unit_id').agg(plotter)

    ax.set_xlabel('time (s)', fontsize=16)
    ax.set_ylabel('unit', fontsize=16)
    ax.set_title(title, fontsize=20)

    # Act on this function's own axes rather than pyplot's implicit
    # "current axes", so the result does not depend on global pyplot state.
    ax.set_yticks([])
    ax.axis('tight')

    return fig