import numpy as np
import scipy.ndimage as ndi
from scipy.optimize import curve_fit, leastsq
import logging
import matplotlib.pyplot as plt
from ...chisquare_categorical import chisq_from_stim_table
from .stimulus_analysis import StimulusAnalysis
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
logger = logging.getLogger(__name__)
[docs]class ReceptiveFieldMapping(StimulusAnalysis):
"""
A class for computing single-unit metrics from the receptive field mapping stimulus of an ecephys session NWB file.
To use, pass in a EcephysSession object::
session = EcephysSession.from_nwb_path('/path/to/my.nwb')
rf_analysis = ReceptiveFieldMapping(session)
or, alternatively, pass in the file path::
rf_analysis = ReceptiveFieldMapping('/path/to/my.nwb')
You can also pass in a unit filter dictionary which will only select units with certain properties. For example
to get only those units which are on probe C and found in the VISp area::
rf_analysis = ReceptiveFieldMapping(session, filter={'location': 'probeC', 'ecephys_structure_acronym': 'VISp'})
To get a table of the individual unit metrics ranked by unit ID::
metrics_table_df = rf_analysis.metrics()
"""
def __init__(self, ecephys_session, col_pos_x='x_position', col_pos_y='y_position', trial_duration=0.25,
minimum_spike_count=10.0, mask_threshold=0.5, **kwargs):
super(ReceptiveFieldMapping, self).__init__(ecephys_session, trial_duration=trial_duration, **kwargs)
self._pos_x = None
self._pos_y = None
self._rf_matrix = None
self._col_pos_x = col_pos_x
self._col_pos_y = col_pos_y
self._minimum_spike_count = minimum_spike_count
self._mask_threshold = mask_threshold
#if self._params is not None:
# self._params = self._params['receptive_field_mapping']
# self._stimulus_key = self._params['stimulus_key']
# self._minimum_spike_count = self._params.get('minimum_spike_count', minimum_spike_count)
# self._mask_threshold = self._params.get('mask_threshold', mask_threshold)
@property
def name(self):
return 'Receptive Field Mapping'
@property
def elevations(self):
""" Array of stimulus elevations """
if self._pos_y is None:
self._get_stim_table_stats()
return self._pos_y
@property
def azimuths(self):
""" Array of stimulus azimuths """
if self._pos_x is None:
self._get_stim_table_stats()
return self._pos_x
@property
def number_elevations(self):
""" Number of stimulus elevations """
if self._pos_y is None:
self._get_stim_table_stats()
return len(self._pos_y)
@property
def number_azimuths(self):
""" Number of stimulus azimuths """
if self._pos_x is None:
self._get_stim_table_stats()
return len(self._pos_y) # TODO: Save this instead of calculating every time.
@property
def null_condition(self):
""" Stimulus condition ID for null stimulus (not used, so set to -1) """
# TODO: Remove
return -1
@property
def receptive_fields(self):
""" Spatial receptive fields for N units (9 x 9 x N matrix of responses) """
if self._rf_matrix is None:
bin_edges = np.linspace(0, 0.249, 3)
self.stim_table.loc[:, self._col_pos_y] = 40.0 - self.stim_table[self._col_pos_y]
presentationwise_response_matrix = self.ecephys_session.presentationwise_spike_counts(
bin_edges=bin_edges,
stimulus_presentation_ids=self.stim_table.index.values,
unit_ids=self.unit_ids,
)
self._rf_matrix = self._response_by_stimulus_position(presentationwise_response_matrix, self.stim_table)
return self._rf_matrix
@property
def METRICS_COLUMNS(self):
return [('azimuth_rf', np.float64),
('elevation_rf', np.float64),
('width_rf', np.float64),
('height_rf', np.float64),
('area_rf', np.float64),
('p_value_rf', np.float64),
('on_screen_rf', bool),
('firing_rate_rf', np.float64),
('fano_rf', np.float64),
('time_to_peak_rf', np.float64),
('lifetime_sparseness_rf', np.float64),
('run_mod_rf', np.float64),
('run_pval_rf', np.float64)
]
@property
def metrics(self):
if self._metrics is None:
logger.info('Calculating metrics for ' + self.name)
unit_ids = self.unit_ids
metrics_df = self.empty_metrics_table()
if len(self.stim_table) > 0:
metrics_df.loc[:, ['azimuth_rf',
'elevation_rf',
'width_rf',
'height_rf',
'area_rf',
'p_value_rf',
'on_screen_rf',
]] = [self._get_rf_stats(unit) for unit in unit_ids]
metrics_df['firing_rate_rf'] = [self._get_overall_firing_rate(unit) for unit in unit_ids]
metrics_df['fano_rf'] = [self._get_fano_factor(unit, self._get_preferred_condition(unit))
for unit in unit_ids]
metrics_df['time_to_peak_rf'] = [self._get_time_to_peak(unit, self._get_preferred_condition(unit))
for unit in unit_ids]
metrics_df['lifetime_sparseness_rf'] = [self._get_lifetime_sparseness(unit) for unit in unit_ids]
metrics_df.loc[:, ['run_pval_rf', 'run_mod_rf']] = \
[self._get_running_modulation(unit, self._get_preferred_condition(unit)) for unit in unit_ids]
self._metrics = metrics_df
return self._metrics
[docs] @classmethod
def known_stimulus_keys(cls):
return ['receptive_field_mapping', 'gabor', "gabors"]
def _find_stimulus_key(self, stim_table):
known_keys_lc = [k.lower() for k in self.__class__.known_stimulus_keys()]
for table_key in stim_table['stimulus_name'].unique():
table_key_lc = table_key.lower()
for known_key in known_keys_lc:
if table_key_lc.startswith(known_key):
return table_key
else:
return None
def _get_stim_table_stats(self):
""" Extract azimuths and elevations from stimulus table."""
self._pos_y = np.sort(self.stimulus_conditions.loc[self.stimulus_conditions[self._col_pos_y]
!= 'null'][self._col_pos_y].unique())
self._pos_x = np.sort(self.stimulus_conditions.loc[self.stimulus_conditions[self._col_pos_x]
!= 'null'][self._col_pos_x].unique())
[docs] def get_receptive_field(self, unit_id):
""" Alias for _get_rf
"""
return self._get_rf(unit_id)
def _get_rf(self, unit_id):
""" Extract the receptive field for one unit
Parameters
----------
unit_id : int
unique ID for the unit of interest
Returns
-------
receptive_field : 9 x 9 numpy array
"""
return self.receptive_fields['spike_counts'].sel(unit_id=unit_id).data
def _response_by_stimulus_position(self, dataset, presentations, row_key=None, column_key=None, unit_key='unit_id',
time_key='time_relative_to_stimulus_onset', spike_count_key='spike_count'):
""" Calculate the unit's response to different locations
of the Gabor patch
Returns
-------
dataset : xarray
dataset of receptive fields
"""
if row_key is None:
row_key = self._col_pos_y
if column_key is None:
column_key = self._col_pos_x
dataset = dataset.copy()
dataset[spike_count_key] = dataset.sum(dim=time_key)
dataset = dataset.drop(time_key)
dataset[row_key] = presentations.loc[:, row_key]
dataset[column_key] = presentations.loc[:, column_key]
dataset = dataset.to_dataframe()
dataset = dataset.reset_index(unit_key).groupby([row_key, column_key, unit_key]).sum()
return dataset.to_xarray()
def _get_rf_stats(self, unit_id):
""" Calculate a variety of metrics for one unit's receptive field
Parameters
----------
unit_id : int
unique ID for the unit of interest
Returns
-------
azimuth :
preferred azimuth in degrees, based on center of mass of thresholded RF
elevation :
preferred elevation in degrees, based on center of mass of thresholded RF
width :
receptive field width in degrees, based on Gaussian fit
height :
receptive field height in degrees, based on Gaussian fit
area :
receptive field area in degrees^2, based on thresholded RF area
p_value :
probability that a significant receptive field is present, based on categorical chi-square test
on_screen :
True if the receptive field is away from the screen edge, based on Gaussian fit
"""
rf = self._get_rf(unit_id)
spikes_per_trial = self.presentationwise_statistics.xs(unit_id, level=1)['spike_counts'].values
if np.sum(spikes_per_trial) < self._minimum_spike_count:
return np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, False
p_value = chisq_from_stim_table(self.stim_table, [self._col_pos_x, self._col_pos_y],
np.expand_dims(spikes_per_trial,1))
#print(self._params)
#exit()
rf_thresh, azimuth, elevation, area = threshold_rf(rf, self._mask_threshold)
if is_rf_inverted(rf_thresh):
rf = invert_rf(rf)
(peak_height, center_y, center_x, width_y, width_x), success = fit_2d_gaussian(rf)
on_screen = rf_on_screen(rf, center_y, center_x)
height_deg = convert_pixels_to_degrees(width_y)
width_deg = convert_pixels_to_degrees(width_x)
azimuth_deg = convert_azimuth_to_degrees(azimuth)
elevation_deg = convert_elevation_to_degrees(elevation)
area_deg = convert_pixel_area_to_degrees(area)
return azimuth_deg, elevation_deg, width_deg, height_deg, area_deg, p_value[0], on_screen
## VISUALIZATION ##
[docs] def plot_raster(self, stimulus_condition_id, unit_id):
""" Plot raster for one condition and one unit """
idx_elev = np.where(self.elevations == self.stimulus_conditions.loc[stimulus_condition_id][self._col_pos_y])[0]
idx_azi = np.where(self.azimuths == self.stimulus_conditions.loc[stimulus_condition_id][self._col_pos_x])[0]
if len(idx_elev) == len(idx_azi) == 1:
presentation_ids = self.presentationwise_statistics.xs(unit_id, level=1)[
self.presentationwise_statistics.xs(unit_id, level=1)[
'stimulus_condition_id'] == stimulus_condition_id].index.values
df = self.presentationwise_spike_times[ \
(self.presentationwise_spike_times['stimulus_presentation_id'].isin(presentation_ids)) & \
(self.presentationwise_spike_times['unit_id'] == unit_id) ]
x = df.index.values - self.stim_table.loc[df.stimulus_presentation_id].start_time
_, y = np.unique(df.stimulus_presentation_id, return_inverse=True)
idx_elev = self.number_elevations - idx_elev - 1 # reverse the elevation index so it matches the RF
plt.subplot(self.number_elevations, self.number_azimuths, idx_elev*self.number_azimuths + idx_azi + 1)
plt.scatter(x, y, c='k', s=1, alpha=0.25)
plt.axis('off')
[docs] def plot_rf(self, unit_id):
""" Plot the spike counts across conditions """
plt.imshow(self._get_rf(unit_id), cmap='Greys')
plt.axis('off')
#### HELPER FUNCTIONS ####
def _gaussian_function_2d(peak_height, center_y, center_x, width_y, width_x):
"""Returns a 2D Gaussian function
Parameters
----------
peak_height :
peak of distribution
center_y :
y-coordinate of distribution center
center_x :
x-coordinate of distribution center
width_y :
width of distribution along x-axis
width_x :
width of distribution along y-axis
Returns
-------
f(x,y) : function
Returns the value of the distribution at a particular x,y coordinate
"""
return lambda x,y: peak_height \
* np.exp( \
-( \
((center_y - y) / width_y)**2 \
+ ((center_x - x) / width_x)**2 \
) \
/ 2 \
)
[docs]def gaussian_moments_2d(data):
"""Finds the moments of a 2D Gaussian distribution, given an input matrix
Parameters
----------
data : numpy.ndarray
2D matrix
Returns
-------
peak_height :
peak of distribution
center_y :
y-coordinate of distribution center
center_x :
x-coordinate of distribution center
width_y :
width of distribution along x-axis
width_x :
width of distribution along y-axis
"""
total = data.sum()
height = data.max()
Y, X = np.indices(data.shape)
center_y = (Y*data).sum()/total
center_x = (X*data).sum()/total
if np.isnan(center_y) or np.isinf(center_y) or np.isnan(center_x) or np.isinf(center_x):
return None
col = data[:, int(center_x)]
row = data[int(center_y), :]
width_y = np.sqrt(np.abs((np.arange(row.size)-center_y)**2*row).sum()/row.sum())
width_x = np.sqrt(np.abs((np.arange(col.size)-center_x)**2*col).sum()/col.sum())
return height, center_y, center_x, width_y, width_x
[docs]def fit_2d_gaussian(matrix):
"""Fits a receptive field with a 2-dimensional Gaussian distribution
Parameters
----------
matrix : numpy.ndarray
2D matrix of spike counts
Returns
-------
parameters - tuple
peak_height : peak of distribution
center_y : y-coordinate of distribution center
center_x : x-coordinate of distribution center
width_y : width of distribution along x-axis
width_x : width of distribution along y-axis
success - bool
True if a fit was found, False otherwise
"""
params = gaussian_moments_2d(matrix)
if params is None:
return (np.nan, np.nan, np.nan, np.nan, np.nan), False
errorfunction = lambda p: np.ravel(_gaussian_function_2d(*p)(*np.indices(matrix.shape)) - matrix)
fit_params, ier = leastsq(errorfunction, params)
success = True if ier < 5 else False
return fit_params, success
[docs]def is_rf_inverted(rf_thresh):
"""Checks if the receptive field mapping timulus is suppressing or exciting the cell
Parameters
----------
rf_thresh : matrix
matrix of spike counts at each stimulus position
Returns
-------
if_rf_inverted : bool
True if the receptive field is inverted
"""
edge_mask = np.zeros(rf_thresh.shape)
edge_mask[:,0] = 1
edge_mask[:,-1] = 1
edge_mask[0,:] = 1
edge_mask[-1,:] = 1
num_edge_pixels = np.sum(rf_thresh * edge_mask)
return num_edge_pixels > np.sum(edge_mask) / 2
[docs]def invert_rf(rf):
"""Creates an inverted version of the receptive field
Parameters
----------
rf - matrix of spike counts at each stimulus position
Returns
-------
rf_inverted - new RF matrix
"""
return np.max(rf) - rf
[docs]def threshold_rf(rf, threshold):
"""Creates a spatial mask based on the receptive field peak, and returns the x, y coordinates of the center of
mass, as well as the area.
Parameters
----------
rf : numpy.ndarray
2D matrix of spike counts
threshold : float
Threshold as ratio of the RF's standard deviation
Returns
-------
threshold_rf : numpy.ndarray
Thresholded version of the original RF
center_x : float
x-coordinate of mask center of mass
center_y : float
y-coordinate of mask center of mass
area : float
area of mask
"""
rf_filt = ndi.gaussian_filter(rf, 1)
threshold_value = np.max(rf_filt) - np.std(rf_filt) * threshold
rf_thresh = np.zeros(rf.shape, dtype='bool')
rf_thresh[rf_filt > threshold_value] = True
labels, num_features = ndi.label(rf_thresh)
best_label = np.argmax(ndi.maximum(rf_filt, labels=labels, index=np.unique(labels)))
labels[labels != best_label] = 0
labels[labels > 0] = 1
center_y, center_x = ndi.measurements.center_of_mass(labels)
area = float(np.sum(labels))
return labels, np.around(center_x, 4), np.around(center_y, 4), area
[docs]def rf_on_screen(rf, center_y, center_x):
"""Checks whether the receptive field is on the screen, given the center location."""
return 0 < center_y < rf.shape[0] and 0 < center_x < rf.shape[1]
[docs]def convert_elevation_to_degrees(elevation_in_pixels, elevation_offset_degrees=-30):
"""Converts a pixel-based elevation into degrees relative to center of gaze
The receptive field computed by this class is oriented such that the
pixel values are in the correct relative location when using matplotlib.pyplot.imshow(),
which places (0,0) in the upper-left corner of the figure.
Therefore, we need to invert the elevation value prior to converting to degrees.
Parameters
----------
elevation_in_pixels : float
elevation_offset_degrees: float
Returns
-------
elevation_in_degrees : float
"""
elevation_in_degrees = convert_pixels_to_degrees(8 - elevation_in_pixels) + elevation_offset_degrees
return elevation_in_degrees
[docs]def convert_azimuth_to_degrees(azimuth_in_pixels, azimuth_offset_degrees=10):
"""Converts a pixel-based azimuth into degrees relative to center of gaze
Parameters
----------
azimuth_in_pixels : float
azimuth_offset_degrees: float
Returns
-------
azimuth_in_degrees : float
"""
azimuth_in_degrees = convert_pixels_to_degrees((azimuth_in_pixels)) + azimuth_offset_degrees
return azimuth_in_degrees
[docs]def convert_pixels_to_degrees(value_in_pixels, degrees_to_pixels_ratio=10):
"""Converts a pixel-based distance into degrees
Parameters
----------
value_in_pixels : float
degrees_to_pixels_ratio: float
Returns
-------
value in degrees : float
"""
return value_in_pixels * degrees_to_pixels_ratio
[docs]def convert_pixel_area_to_degrees(area_in_pixels):
"""Converts a pixel-based area measure into degrees
Each pixel is a square with side of length <degrees_to_pixels_ratio>
So the area in degrees is area_in_pixels * <degrees to_pixels_ratio>^2
Parameters
----------
area_in_pixels : float
Returns
-------
area_in_degrees : float
"""
return area_in_pixels * pow(convert_pixels_to_degrees(1), 2)