#!/usr/bin/env python2
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 5 15:52:22 2019
@author: dan
"""
# TODO: Fix header
import warnings
import numpy as np
[docs]def chisq_from_stim_table(
stim_table, columns, mean_sweep_events, num_shuffles=1000, verbose=False
):
# stim_table is a pandas DataFrame with len = num_sweeps
# columns is a list of column names that define the categories (e.g.
# ['Ori','Contrast']) mean_sweep_events is a numpy array with shape
# (num_sweeps,num_cells)
sweep_categories = stim_table_to_categories(
stim_table, columns, verbose=verbose
)
p_vals = compute_chi_shuffle(
mean_sweep_events, sweep_categories, num_shuffles=num_shuffles
)
return p_vals
[docs]def compute_chi_shuffle(
mean_sweep_events, sweep_categories, num_shuffles=1000
):
# mean_sweep_events is a numpy array with shape (num_sweeps,num_cells)
# sweep_conditions is a numpy array with shape (num_sweeps)
# sweep_conditions gives the category label for each sweep
(num_sweeps, num_cells) = np.shape(mean_sweep_events)
if len(sweep_categories) != num_sweeps:
warnings.warn("sweep_categories and num_sweeps do not match")
return np.nan
sweep_categories_dummy = make_category_dummy(sweep_categories)
expected = compute_expected(mean_sweep_events, sweep_categories_dummy)
observed = compute_observed(mean_sweep_events, sweep_categories_dummy)
chi_actual = compute_chi(observed, expected)
chi_shuffle = np.zeros((num_cells, num_shuffles))
for ns in range(num_shuffles):
shuffle_sweeps = np.random.choice(num_sweeps, size=(num_sweeps,))
shuffle_sweep_events = mean_sweep_events[shuffle_sweeps]
shuffle_expected = compute_expected(
shuffle_sweep_events, sweep_categories_dummy
)
shuffle_observed = compute_observed(
shuffle_sweep_events, sweep_categories_dummy
)
chi_shuffle[:, ns] = compute_chi(shuffle_observed, shuffle_expected)
p_vals = np.mean(chi_actual.reshape(num_cells, 1) < chi_shuffle, axis=1)
return p_vals
[docs]def stim_table_to_categories(stim_table, columns, verbose=False):
# get the categories for all sweeps with each unique combination of
# parameters in 'columns' being one category
# sweeps with non-finite values in ANY column (e.g. np.NaN) are labeled
# as blank sweeps (category = -1)
# TODO: Replace with EcephysSession.get_stimulus_conditions
num_sweeps = len(stim_table)
num_params = len(columns)
unique_params = []
options_per_column = []
max_combination = 1
for column in columns:
column_params = np.unique(stim_table[column].values)
# column_params = column_params[np.isfinite(column_params)]
unique_params.append(column_params)
options_per_column.append(len(column_params))
max_combination *= len(column_params)
category = 0
sweep_categories = -1 * np.ones((num_sweeps,))
curr_combination = np.zeros((num_params,), dtype="int")
options_per_column = np.array(options_per_column).astype("int")
all_tried = False
while not all_tried:
matches_combination = np.ones((num_sweeps,), dtype=bool)
for i_col, column in enumerate(columns):
param = unique_params[i_col][curr_combination[i_col]]
matches_param = stim_table[column].values == param
matches_combination *= matches_param
if np.any(matches_combination):
sweep_categories[matches_combination] = category
if verbose:
print("Category " + str(category))
for i_col, column in enumerate(columns):
param = unique_params[i_col][curr_combination[i_col]]
print(column + ": " + str(param))
category += 1
# advance the combination
curr_combination = advance_combination(
curr_combination, options_per_column
)
all_tried = curr_combination[0] == options_per_column[0]
if verbose:
blank_sweeps = sweep_categories == -1
print("num blank: " + str(blank_sweeps.sum()))
return sweep_categories
[docs]def advance_combination(curr_combination, options_per_column):
num_cols = len(curr_combination)
might_carry = True
col = num_cols - 1
while might_carry:
curr_combination[col] += 1
if col == 0 or curr_combination[col] < options_per_column[col]:
might_carry = False
else:
curr_combination[col] = 0
col -= 1
return curr_combination
[docs]def make_category_dummy(sweep_categories):
# makes a dummy variable version of the sweep category list
num_sweeps = len(sweep_categories)
categories = np.unique(sweep_categories)
num_categories = len(categories)
sweep_category_mat = np.zeros((num_sweeps, num_categories), dtype=bool)
for i_cat, category in enumerate(categories):
category_idx = np.argwhere(sweep_categories == category)[:, 0]
sweep_category_mat[category_idx, i_cat] = True
return sweep_category_mat
[docs]def compute_observed(mean_sweep_events, sweep_conditions):
(num_sweeps, num_conditions) = np.shape(sweep_conditions)
num_cells = np.shape(mean_sweep_events)[1]
observed_mat = (mean_sweep_events.T).reshape(
num_cells, num_sweeps, 1
) * sweep_conditions.reshape(1, num_sweeps, num_conditions)
observed = np.sum(observed_mat, axis=1)
return observed
[docs]def compute_expected(mean_sweep_events, sweep_conditions):
num_conditions = np.shape(sweep_conditions)[1]
num_cells = np.shape(mean_sweep_events)[1]
sweeps_per_condition = np.sum(sweep_conditions, axis=0)
events_per_sweep = np.mean(mean_sweep_events, axis=0)
expected = sweeps_per_condition.reshape(
1, num_conditions
) * events_per_sweep.reshape(num_cells, 1)
return expected
[docs]def compute_chi(observed, expected):
chi = (observed - expected) ** 2 / expected
chi = np.where(expected > 0, chi, 0.0)
return np.sum(chi, axis=1)
# %%