import scipy.signal as signal
from scipy.stats import zscore
import numpy as np
import pandas as pd
import warnings
from typing import Iterable, Union, Any, Optional
[docs]def calc_deriv(x, time):
dx = np.diff(x, prepend=np.nan)
dt = np.diff(time, prepend=np.nan)
return dx / dt
def _angular_change(summed_voltage: np.ndarray,
vmax: Union[np.ndarray, float]) -> np.ndarray:
"""
Compute the change in degrees in radians at each point from the
summed voltage encoder data.
Parameters
----------
summed_voltage: 1d np.ndarray
The "unwrapped" voltage signal from the encoder, cumulatively
summed. See `_unwrap_voltage_signal`.
vmax: 1d np.ndarray or float
Either a constant float, or a 1d array (typically constant)
of values. These values represent the theoretical max voltage
value of the encoder. If an array, needs to be the same length
as the summed_voltage array.
Returns
-------
np.ndarray
1d array of change in degrees in radians from each point
"""
delta_theta = np.diff(summed_voltage, prepend=np.nan) / vmax * 2 * np.pi
return delta_theta
def _shift(
arr: Iterable,
periods: int = 1,
fill_value: float = np.nan) -> np.ndarray:
"""
Shift index of an iterable (array-like) by desired number of
periods with an optional fill value (default = NaN).
Parameters
----------
arr: Iterable (array-like)
Iterable containing numeric data. If int, will be converted to
float in returned object.
periods: int (default=1)
The number of elements to shift.
fill_value: float (default=np.nan)
The value to fill at the beginning of the shifted array
Returns
-------
np.ndarray (1d)
Copy of input object as a 1d array, shifted.
"""
if periods <= 0:
raise ValueError("Can only shift for periods > 0.")
if fill_value is None:
fill_value = np.nan
if isinstance(fill_value, float):
# Circumvent issue if int-like array with np.nan as fill
shifted = np.roll(arr, periods).astype(float)
else:
shifted = np.roll(arr, periods)
shifted[:periods] = fill_value
return shifted
[docs]def deg_to_dist(angular_speed: np.ndarray) -> np.ndarray:
"""
Takes the angular speed (radians/s) at each step in radians, and
computes the linear speed in cm/s.
Parameters
----------
angular_speed: np.ndarray (1d)
1d array of angular speed in radians/s
Returns
-------
np.ndarray (1d)
Linear speed in cm/s at each time point.
"""
wheel_diameter = 6.5 * 2.54 # 6.5" wheel diameter, 2.54 = cm/in
running_radius = 0.5 * (
# assume the animal runs at 2/3 the distance from the wheel center
2.0 * wheel_diameter / 3.0)
running_speed_cm_per_sec = angular_speed * running_radius
return running_speed_cm_per_sec
def _identify_wraps(vsig: Iterable, *,
min_threshold: float = 1.5,
max_threshold: float = 3.5):
"""
Identify "wraps" in the voltage signal. In practice, this is when
the encoder voltage signal crosses 5V and wraps to 0V, or
vice-versa.
Argument defaults and implementation suggestion via @dougo
Parameters
----------
vsig: Iterable (array-like)
1d array-like iterable of voltage signal
min_threshold: float (default=1.5)
The min_threshold value that must be crossed to be considered
a possible wrapping point.
max_threshold: float (default=3.5)
The max threshold value that must be crossed to be considered
a possible wrapping point.
Returns
-------
Tuple
Tuple of ([indices of positive wraps], [indices of negative wraps])
"""
# Compare against previous value
shifted_vsig = _shift(vsig)
if not isinstance(vsig, np.ndarray):
vsig = np.array(vsig)
# Suppress warnings for when comparing to nan values
with np.errstate(invalid='ignore'):
pos_wraps = np.asarray(
np.logical_and(vsig < min_threshold, shifted_vsig > max_threshold)
).nonzero()[0]
neg_wraps = np.asarray(
np.logical_and(vsig > max_threshold, shifted_vsig < min_threshold)
).nonzero()[0]
return pos_wraps, neg_wraps
def _local_boundaries(time, index, span: float = 0.25) -> tuple:
"""
Given a 1d array of monotonically increasing timestamps, and a
point in that array (`index`), compute the indices that form the
inclusive boundary around `index` for timespan `span`.
Values in `time` must monotonically increase. Flat lines (same value
multiple times) are OK. The neighborhood may terminate around the
index if the `span` is too small for the sampling rate. A warning
will be raised in this case.
Returns
-------
Tuple
Tuple of corresponding to the start, end indices that bound
a time span of length `span` (maximally)
E.g.
```
time = np.array([0, 1, 1.5, 2, 2.2, 2.5, 3, 3.5])
_local_boundary(time, 3, 1.0)
>>> (1, 6)
```
"""
if np.diff(time[~np.isnan(time)]).min() < 0:
raise ValueError("Data do not monotonically increase. This probably "
"means there is an error in your time series.")
t_val = time[index]
max_val = t_val + abs(span)
min_val = t_val - abs(span)
eligible_indices = np.nonzero((time <= max_val) & (time >= min_val))[0]
max_ix = eligible_indices.max()
min_ix = eligible_indices.min()
if (min_ix == index) or (max_ix == index):
warnings.warn("Unable to find two data points around index "
f"for span={span} that do not include the index. "
"This could mean that your time span is too small for "
"the time data sampling rate, the data are not "
"monotonically increasing, or that you are trying "
"to find a neighborhood at the beginning/end of the "
"data stream.")
return min_ix, max_ix
def _clip_speed_wraps(speed, time, wrap_indices, t_span: float = 0.25):
"""
Correct for artifacts at the voltage 'wraps'. Sometimes there are
transient spikes in speed at the 'wrap' points. This doesn't make
sense since speed on a running wheel should be a smoothly varying
function. Take the neighborhood of values in +/- `t_span` seconds
around wrap points, and clip the value at the wrap point
such that it does not exceed the min/max values in the neighborhood.
"""
corrected_speed = speed.copy()
for wrap in wrap_indices:
start_ix, end_ix = _local_boundaries(time, wrap, t_span)
local_slice = np.concatenate( # Remove the wrap point
(speed[start_ix:wrap], speed[wrap+1:end_ix+1]))
corrected_speed[wrap] = np.clip(
speed[wrap], np.nanmin(local_slice), np.nanmax(local_slice))
return corrected_speed
def _unwrap_voltage_signal(
vsig: Iterable,
pos_wrap_ix: Iterable,
neg_wrap_ix: Iterable,
*,
vmax: Optional[float] = None,
max_threshold: float = 5.1,
max_diff: float = 1.0) -> np.ndarray:
"""
Calculate the change in voltage at each timestamp.
'Unwraps' the
voltage data coming from the encoder at the value `vmax`. If `vmax`
is a float, use that value to 'wrap'. If it is None, then compute
the maximum value from the observed voltage signal (`vsig`, as long
as the maximum value is under the value of `max_threshold` (to
account for possible outlier data/encoder errors).
The reason is because the rotary encoder should theoretically wrap
at 5V, but in practice does not always reach 5V before wrapping
back to 0V. If it is assumed that the encoder wraps at 5V, but
actually does not reach that voltage, then the computed running
speed can be transiently higher at the timestamps of the signal
'wraps'.
Parameters
----------
vsig: Iterable (array-like)
The raw voltage data from the rotary encoder
vmax: Optional[float] (default=None)
The value at which, upon passing this threshold, the voltage
"wraps" back to 0V on the encoder.
max_threshold: float (default=5.1)
The maximum threshold for the `vmax` value. Used only if
`vmax` is `None`. To account for the possibility of outlier
data/encoder errors, the computed `vmax` should not exceed
this value.
max_diff: float (default=1.0)
The maximum voltage difference allowed between two adjacent
points, after accounting for the voltage "wrap". Values
exceeding this threshold will be set to np.nan.
Returns
-------
np.ndarray
1d np.ndarray of the "unwrapped" signal from `vsig`.
"""
if not isinstance(vsig, np.ndarray):
vsig = np.array(vsig)
if vmax is None:
vmax = vsig[vsig < max_threshold].max()
unwrapped_diff = np.zeros(vsig.shape)
vsig_last = _shift(vsig)
if len(pos_wrap_ix):
# positive wraps: subtract from the previous value and add vmax
unwrapped_diff[pos_wrap_ix] = (
(vsig[pos_wrap_ix] + vmax) - vsig_last[pos_wrap_ix])
# negative: subtract vmax and the previous value
if len(neg_wrap_ix):
unwrapped_diff[neg_wrap_ix] = (
vsig[neg_wrap_ix] - (vsig_last[neg_wrap_ix] + vmax))
# Other indices, just compute straight diff from previous value
wrap_ix = np.concatenate((pos_wrap_ix, neg_wrap_ix))
other_ix = np.array(list(set(range(len(vsig_last))).difference(wrap_ix)))
unwrapped_diff[other_ix] = vsig[other_ix] - vsig_last[other_ix]
# Correct for wrap artifacts based on allowed `max_diff` value
# (fill with nan)
# Suppress warnings when comparing with nan values to reduce noise
with np.errstate(invalid='ignore'):
unwrapped_diff = np.where(
np.abs(unwrapped_diff) <= max_diff, unwrapped_diff, np.nan)
# Get nan indices to propogate to the cumulative sum (otherwise
# treated as 0)
unwrapped_nans = np.array(np.isnan(unwrapped_diff)).nonzero()
summed_diff = np.nancumsum(unwrapped_diff) + vsig[0] # Add the baseline
summed_diff[unwrapped_nans] = np.nan
return summed_diff
def _zscore_threshold_1d(data: np.ndarray,
threshold: float = 5.0) -> np.ndarray:
"""
Replace values in 1d array `data` that exceed `threshold` number
of SDs from the mean with NaN.
Parameters
---------
data: np.ndarray
1d np array of values
threshold: float (default=5.0)
Z-score threshold to replace with NaN.
Returns
-------
np.ndarray (1d)
A copy of `data` with values exceeding `threshold` SDs from
the mean replaced with NaN.
"""
corrected_data = data.copy().astype("float")
scores = zscore(data, nan_policy="omit")
# Suppress warnings when comparing to nan values to reduce noise
with np.errstate(invalid='ignore'):
corrected_data[np.abs(scores) > threshold] = np.nan
return corrected_data
[docs]def get_running_df(data, time: np.ndarray, lowpass: bool = True, zscore_threshold=10.0):
"""
Given the data from the behavior 'pkl' file object and a 1d
array of timestamps, compute the running speed. Returns a
dataframe with the raw voltage data as well as the computed speed
at each timestamp. By default, the running speed is filtered with
a 10 Hz Butterworth lowpass filter to remove artifacts caused by
the rotary encoder.
Parameters
----------
data
Deserialized 'behavior pkl' file data
time: np.ndarray (1d)
Timestamps for running data measurements
lowpass: bool (default=True)
Whether to apply a 10Hz low-pass filter to the running speed
data.
zscore_threshold: float
The threshold to use for removing outlier running speeds which might be noise and not true signal
Returns
-------
pd.DataFrame
Dataframe with an index of timestamps and the following
columns:
"speed": computed running speed
"dx": angular change, computed during data collection
"v_sig": voltage signal from the encoder
"v_in": the theoretical maximum voltage that the encoder
will reach prior to "wrapping". This should
theoretically be 5V (after crossing 5V goes to 0V, or
vice versa). In practice the encoder does not always
reach this value before wrapping, which can cause
transient spikes in speed at the voltage "wraps".
The raw data are provided so that the user may compute their
own speed from source, if desired.
Notes
-----
Though the angular change is available in the raw data
(key="dx"), this method recomputes the angular change from the
voltage signal (key="vsig") due to very specific, low-level
artifacts in the data caused by the encoder. See method
docstrings for more detailed information. The raw data is
included in the final output in case the end user wants to apply
their own corrections and compute running speed from the raw
source.
"""
v_sig = data["items"]["behavior"]["encoders"][0]["vsig"]
v_in = data["items"]["behavior"]["encoders"][0]["vin"]
if len(v_in) > len(time) + 1:
error_string = ("length of v_in ({}) cannot be longer than length of "
"time ({}) + 1, they are off by {}").format(
len(v_in),
len(time),
abs(len(v_in) - len(time))
)
raise ValueError(error_string)
if len(v_in) == len(time) + 1:
warnings.warn(
"Time array is 1 value shorter than encoder array. Last encoder "
"value removed\n", UserWarning, stacklevel=1)
v_in = v_in[:-1]
v_sig = v_sig[:-1]
# dx = 'd_theta' = angular change
# There are some issues with angular change in the raw data so we
# recompute this value
dx_raw = data["items"]["behavior"]["encoders"][0]["dx"]
# Identify "wraps" in the voltage signal that need to be unwrapped
# This is where the encoder switches from 0V to 5V or vice versa
pos_wraps, neg_wraps = _identify_wraps(
v_sig, min_threshold=1.5, max_threshold=3.5)
# Unwrap the voltage signal and apply correction for transient spikes
unwrapped_vsig = _unwrap_voltage_signal(
v_sig, pos_wraps, neg_wraps, max_threshold=5.1, max_diff=1.0)
angular_change_point = _angular_change(unwrapped_vsig, v_in)
angular_change = np.nancumsum(angular_change_point)
# Add the nans back in (get turned to 0 in nancumsum)
angular_change[np.isnan(angular_change_point)] = np.nan
angular_speed = calc_deriv(angular_change, time) # speed in radians/s
linear_speed = deg_to_dist(angular_speed)
# Artifact correction to speed data
wrap_corrected_linear_speed = _clip_speed_wraps(
linear_speed, time, np.concatenate([pos_wraps, neg_wraps]),
t_span=0.25)
outlier_corrected_linear_speed = _zscore_threshold_1d(
wrap_corrected_linear_speed, threshold=zscore_threshold)
# Final filtering (optional) for smoothing out the speed data
if lowpass:
b, a = signal.butter(3, Wn=4, fs=60, btype="lowpass")
outlier_corrected_linear_speed = signal.filtfilt(
b, a, np.nan_to_num(outlier_corrected_linear_speed))
return pd.DataFrame({
'speed': outlier_corrected_linear_speed[:len(time)],
'dx': dx_raw[:len(time)],
'v_sig': v_sig[:len(time)],
'v_in': v_in[:len(time)],
}, index=pd.Index(time, name='timestamps'))