import numpy as np
import pandas as pd
from allensdk import one
from scipy.stats import norm
# Default rolling-window size (number of trials) used by the rate and
# d-prime functions in this module.
SLIDING_WINDOW = 100
def get_go_responses(hit=None, miss=None, aborted=None):
    """Encode go-trial outcomes as a float array, dropping aborted trials.

    Parameters
    ----------
    hit : sequence of bool
        True where the trial was a hit.
    miss : sequence of bool
        True where the trial was a miss.
    aborted : sequence of bool
        True where the trial was aborted; those trials are removed.

    Returns
    -------
    np.ndarray of float
        One entry per non-aborted trial: 1 for a hit, 0 for a miss and NaN
        for anything else (e.g. catch trials).

    Raises
    ------
    AssertionError
        If the three inputs do not have the same length.
    """
    assert len(hit) == len(miss) == len(aborted)

    keep = ~np.array(aborted, dtype=bool)
    hit_mask = np.array(hit, dtype=bool)[keep]
    miss_mask = np.array(miss, dtype=bool)[keep]

    # Start from all-NaN so that pd.Series.rolling skips non-go trials, then
    # mark hits and misses (a miss is written last, as in the original).
    responses = np.full(hit_mask.shape, float("nan"), dtype="float")
    responses[hit_mask] = 1
    responses[miss_mask] = 0
    return responses
def get_hit_rate(
    hit=None, miss=None, aborted=None, sliding_window=SLIDING_WINDOW
):
    """Compute the rolling hit rate over non-aborted trials.

    Parameters
    ----------
    hit, miss, aborted : sequence of bool
        Per-trial outcome flags, forwarded to ``get_go_responses``.
    sliding_window : int, optional
        Size of the rolling window.

    Returns
    -------
    np.ndarray
        Rolling mean of the go responses (1 = hit, 0 = miss, NaN ignored),
        one value per non-aborted trial.
    """
    responses = pd.Series(
        get_go_responses(hit=hit, miss=miss, aborted=aborted)
    )
    windowed = responses.rolling(window=sliding_window, min_periods=0)
    return windowed.mean().values
def get_trial_count_corrected_hit_rate(
    hit=None, miss=None, aborted=None, sliding_window=SLIDING_WINDOW
):
    """Compute the rolling hit rate, bounded according to the trial count.

    Each rolling hit-rate value is passed through ``trial_number_limit``
    together with the number of go trials in its window.

    Parameters
    ----------
    hit, miss, aborted : sequence of bool
        Per-trial outcome flags, forwarded to ``get_go_responses``.
    sliding_window : int, optional
        Size of the rolling window.

    Returns
    -------
    np.ndarray of float
        Corrected hit rate, one value per non-aborted trial. Empty when
        there are no non-aborted trials.
    """
    go_responses = get_go_responses(hit=hit, miss=miss, aborted=aborted)

    # Build the rolling window once; both the count and the mean use it.
    rolling = pd.Series(go_responses).rolling(
        window=sliding_window, min_periods=0
    )
    go_responses_count = rolling.count()
    hit_rate = rolling.mean().values

    # otypes is given explicitly: without it np.vectorize cannot infer the
    # output dtype from zero-length input and raises ValueError.
    trial_count_corrected_hit_rate = np.vectorize(
        trial_number_limit, otypes=[float]
    )(hit_rate, go_responses_count)
    return trial_count_corrected_hit_rate
def get_catch_responses(correct_reject=None, false_alarm=None, aborted=None):
    """Encode catch-trial outcomes as a float array, dropping aborted trials.

    Parameters
    ----------
    correct_reject : sequence of bool
        True where the trial was a correct rejection.
    false_alarm : sequence of bool
        True where the trial was a false alarm.
    aborted : sequence of bool
        True where the trial was aborted; those trials are removed.

    Returns
    -------
    np.ndarray of float
        One entry per non-aborted trial: 1 for a false alarm, 0 for a
        correct rejection and NaN for anything else (e.g. go trials).

    Raises
    ------
    AssertionError
        If the three inputs do not have the same length.
    """
    assert len(correct_reject) == len(false_alarm) == len(aborted)

    keep = ~np.array(aborted, dtype=bool)
    reject_mask = np.array(correct_reject, dtype=bool)[keep]
    alarm_mask = np.array(false_alarm, dtype=bool)[keep]

    # Start from all-NaN so that pd.Series.rolling skips non-catch trials,
    # then mark false alarms and correct rejections (the rejection is
    # written last, as in the original).
    responses = np.full(reject_mask.shape, float("nan"), dtype="float")
    responses[alarm_mask] = 1
    responses[reject_mask] = 0
    return responses
def get_false_alarm_rate(
    correct_reject=None,
    false_alarm=None,
    aborted=None,
    sliding_window=SLIDING_WINDOW,
):
    """Compute the rolling false-alarm rate over non-aborted trials.

    Parameters
    ----------
    correct_reject, false_alarm, aborted : sequence of bool
        Per-trial outcome flags, forwarded to ``get_catch_responses``.
    sliding_window : int, optional
        Size of the rolling window.

    Returns
    -------
    np.ndarray
        Rolling mean of the catch responses (1 = false alarm, 0 = correct
        rejection, NaN ignored), one value per non-aborted trial.
    """
    responses = pd.Series(
        get_catch_responses(
            correct_reject=correct_reject,
            false_alarm=false_alarm,
            aborted=aborted,
        )
    )
    windowed = responses.rolling(window=sliding_window, min_periods=0)
    return windowed.mean().values
def get_trial_count_corrected_false_alarm_rate(
    correct_reject=None,
    false_alarm=None,
    aborted=None,
    sliding_window=SLIDING_WINDOW,
):
    """Compute the rolling false-alarm rate, bounded by the trial count.

    Each rolling false-alarm-rate value is passed through
    ``trial_number_limit`` together with the number of catch trials in its
    window.

    Parameters
    ----------
    correct_reject, false_alarm, aborted : sequence of bool
        Per-trial outcome flags, forwarded to ``get_catch_responses``.
    sliding_window : int, optional
        Size of the rolling window.

    Returns
    -------
    np.ndarray of float
        Corrected false-alarm rate, one value per non-aborted trial. Empty
        when there are no non-aborted trials.
    """
    catch_responses = get_catch_responses(
        correct_reject=correct_reject, false_alarm=false_alarm, aborted=aborted
    )

    # Build the rolling window once; both the count and the mean use it.
    rolling = pd.Series(catch_responses).rolling(
        window=sliding_window, min_periods=0
    )
    catch_responses_count = rolling.count()
    false_alarm_rate = rolling.mean().values

    # otypes is given explicitly: without it np.vectorize cannot infer the
    # output dtype from zero-length input and raises ValueError.
    trial_count_corrected_false_alarm_rate = np.vectorize(
        trial_number_limit, otypes=[float]
    )(false_alarm_rate, catch_responses_count)
    return trial_count_corrected_false_alarm_rate
def get_rolling_dprime(
    rolling_hit_rate, rolling_fa_rate, sliding_window=SLIDING_WINDOW
):
    """Compute d-prime element-wise for paired rolling rates.

    Parameters
    ----------
    rolling_hit_rate : iterable of float
        Rolling hit rate values.
    rolling_fa_rate : iterable of float
        Rolling false-alarm rate values, paired positionally with
        ``rolling_hit_rate``. Pairing uses ``zip``, so extra trailing values
        in the longer input are ignored.
    sliding_window : int, optional
        Window size forwarded to ``get_dprime`` for every pair.

    Returns
    -------
    np.ndarray
        One d-prime value per (hit rate, false-alarm rate) pair.
    """
    return np.array(
        [
            # Forward the caller's window rather than the module default.
            get_dprime(hr, far, sliding_window=sliding_window)
            for hr, far in zip(rolling_hit_rate, rolling_fa_rate)
        ]
    )
def get_dprime(hit_rate, fa_rate, sliding_window=SLIDING_WINDOW):
    """Calculate the d-prime for a given hit rate and false alarm rate.

    https://en.wikipedia.org/wiki/Sensitivity_index

    Parameters
    ----------
    hit_rate : float
        rate of hits in the True class
    fa_rate : float
        rate of false alarms in the False class
    sliding_window : int, optional
        Window size used to derive the clipping limits. Both rates are
        clipped to ``(1 / sliding_window, 1 - 1 / sliding_window)`` before
        conversion, i.e. (0.01, 0.99) for a window of 100. A window of 2 or
        less leaves no usable range between the limits.

    Returns
    -------
    d_prime
        ``Z(hit_rate) - Z(fa_rate)``, where Z is the inverse of the standard
        normal CDF.
    """
    # Derive the limits from the window actually requested by the caller.
    limits = (1 / sliding_window, 1 - 1 / sliding_window)
    assert limits[0] > 0.0, "limits[0] must be greater than 0.0"
    assert limits[1] < 1.0, "limits[1] must be less than 1.0"

    Z = norm.ppf

    # Limit values in order to avoid d' infinity
    hit_rate = np.clip(hit_rate, limits[0], limits[1])
    fa_rate = np.clip(fa_rate, limits[0], limits[1])

    d_prime = Z(pd.Series(hit_rate)) - Z(pd.Series(fa_rate))
    # The rates were wrapped in a Series; `one` unwraps the single result.
    return one(d_prime)
def trial_number_limit(p, N):
    """Bound a rate according to the number of trials it was computed from.

    Parameters
    ----------
    p : float
        Rate to bound. A null value is returned unchanged.
    N : int
        Number of trials behind ``p``.

    Returns
    -------
    float
        NaN when ``N`` is 0; otherwise ``p`` limited to the range
        ``[1 / (2 * N), 1 - 1 / (2 * N)]``.
    """
    if N == 0:
        return np.nan
    if pd.isnull(p):
        return p

    # Apply the lower bound first, then the upper bound.
    raised = np.max((p, 1.0 / (2 * N)))
    return np.min((raised, 1 - 1.0 / (2 * N)))