from collections import deque
from typing import Optional, Callable, Any
import numpy as np
import h5py
from allensdk.brain_observatory.sync_dataset import Dataset
import pandas as pd
import logging
try:
import cv2
except ImportError:
cv2 = None
TRANSITION_FRAME_INTERVAL = 60
REG_PHOTODIODE_INTERVAL = 1.0 # seconds
REG_PHOTODIODE_STD = 0.05 # seconds
PHOTODIODE_ANOMALY_THRESHOLD = 0.5 # seconds
LONG_STIM_THRESHOLD = 0.2 # seconds
MAX_MONITOR_DELAY = 0.07 # seconds
def get_keys(sync_dset: Dataset) -> dict:
    """
    Gets the correct keys for the sync file by searching the sync file
    line labels. Removes key from the dictionary if it is not in the
    sync dataset line labels.

    Args:
        sync_dset: The sync dataset to search for keys within

    Returns:
        key_dict: dictionary of key value pairs for finding data in the
        sync file
    """
    # key_dict contains key value pairs where key is expected label category
    # and value is the possible data for each category existing in sync
    # dataset line labels
    key_dict = {
        "photodiode": ["stim_photodiode", "photodiode"],
        "2p": ["2p_vsync"],
        "stimulus": ["stim_vsync", "vsync_stim"],
        "eye_camera": ["cam2_exposure", "eye_tracking",
                       "eye_frame_received"],
        "behavior_camera": ["cam1_exposure", "behavior_monitoring",
                            "beh_frame_received"],
        "acquiring": ["2p_acquiring", "acq_trigger"],
        "lick_sensor": ["lick_1", "lick_sensor"]
    }
    label_set = set(sync_dset.line_labels)
    remove_keys = []
    for key, candidates in key_dict.items():
        # Keep a category only when exactly one of its candidate labels is
        # present in the sync file's line labels; otherwise mark the
        # category for removal (ambiguous or absent).
        matches = set(candidates).intersection(label_set)
        if len(matches) == 1:
            key_dict[key] = matches.pop()
        else:
            remove_keys.append(key)
    if len(remove_keys) > 0:
        # Warn the user about data sources that could not be resolved
        logging.warning("Could not find valid lines for the following data "
                        "sources")
        for key in remove_keys:
            # Bug fix: the message previously lacked the closing
            # parenthesis after the candidate label list.
            logging.warning(f"{key} (valid line label(s) = {key_dict[key]})")
            key_dict.pop(key)
    return key_dict
def calculate_monitor_delay(sync_dset, stim_times, photodiode_key,
                            transition_frame_interval=TRANSITION_FRAME_INTERVAL,  # noqa: E501
                            max_monitor_delay=MAX_MONITOR_DELAY):
    """Calculate monitor delay.

    Compares stimulus frame times (subsampled every
    `transition_frame_interval` frames) against the photodiode
    transition events and returns the mean offset.

    Parameters
    ----------
    sync_dset : Dataset
        Sync dataset containing the photodiode line.
    stim_times : np.ndarray
        Stimulus timestamps in seconds.
    photodiode_key : str
        Line label of the photodiode signal.
    transition_frame_interval : int
        Number of stimulus frames between photodiode transitions.
    max_monitor_delay : float
        Maximum plausible mean delay in seconds.

    Returns
    -------
    float
        Mean monitor delay in seconds.

    Raises
    ------
    ValueError
        If the mean delay is negative or exceeds `max_monitor_delay`.
    """
    transitions = stim_times[::transition_frame_interval]
    photodiode_events = get_real_photodiode_events(sync_dset, photodiode_key)
    # Pair each expected transition with the corresponding photodiode event
    transition_events = photodiode_events[0:len(transitions)]

    delays = transition_events - transitions
    delay = np.mean(delays)
    logging.info(f"Calculated monitor delay: {delay}. \n "
                 f"Max monitor delay: {np.max(delays)}. \n "
                 f"Min monitor delay: {np.min(delays)}.\n "
                 f"Std monitor delay: {np.std(delays)}.")

    if delay < 0 or delay > max_monitor_delay:
        # Bug fix: report the threshold actually used for the check
        # (the max_monitor_delay argument), not the module default.
        raise ValueError(f"Delay ({delay}s) falls outside expected value "
                         f"range (0-{max_monitor_delay}s).")
    return delay
def _find_last_n(arr: np.ndarray, n: int,
                 cond: Callable[[Any], bool]) -> Optional[int]:
    """
    Find the final index where the prior `n` values in an array meet
    the condition `cond` (inclusive).

    Parameters
    ==========
    arr: numpy.1darray
    n: int
    cond: Callable that returns True if condition is met, False
        otherwise. Should be able to be applied to the array elements
        without any additional arguments.
    """
    # Search the reversed array for the first n-length run of matches,
    # then translate that position back to the original orientation.
    idx_in_reversed = _find_n(arr[::-1], n, cond)
    if idx_in_reversed is None:
        return None
    return len(arr) - idx_in_reversed - 1
def _find_n(arr: np.ndarray, n: int,
cond: Callable[[Any], bool]) -> Optional[int]:
"""
Find the index where the next `n` values in an array meet the
condition `cond` (inclusive).
Parameters
==========
arr: numpy.1darray
n: int
cond: Callable that returns True if condition is met, False
otherwise. Should be able to be applied to the array elements
without any additional arguments.
"""
if len(arr) < n:
return None
queue = deque(np.apply_along_axis(cond, 0, arr[:n]), maxlen=n)
i = 0
while queue.count(True) < n:
try:
i += 1
queue.append(cond(arr[i+n-1]))
except IndexError:
return None
return i
def get_photodiode_events(sync_dset, photodiode_key):
    """Returns the photodiode events with the start/stop indicators and
    the window init flash stripped off. These transitions occur roughly
    ~1.0s apart, since the sync square changes state every N frames
    (where N = 60, and frame rate is 60 Hz). Because there are no
    markers for when the first transition of this type started, we
    estimate based on the event intervals. For the first valid event,
    find the first two events that both meet the following criteria:
        The next event occurs ~1.0s later
    For the last valid event, find the last two events that both meet
    the following criteria:
        The previous valid event occurred ~1.0s before

    Raises
    ------
    ValueError
        If no photodiode events exist, or no valid start/end event can
        be located.
    """
    all_events = sync_dset.get_events_by_line(photodiode_key, units="seconds")
    # Fail fast on an empty event stream before computing intervals
    if not len(all_events):
        raise ValueError("No photodiode events found. Please check "
                         "the input data for errors. ")

    # Interval to the previous/next event for every event; padded with 0
    # at each end so the arrays align with `all_events`.
    all_events_diff = np.ediff1d(all_events, to_begin=0, to_end=0)
    all_events_diff_prev = all_events_diff[:-1]
    all_events_diff_next = all_events_diff[1:]
    min_interval = REG_PHOTODIODE_INTERVAL - REG_PHOTODIODE_STD
    max_interval = REG_PHOTODIODE_INTERVAL + REG_PHOTODIODE_STD

    first_valid_index = _find_n(
        all_events_diff_next, 2,
        lambda x: (x >= min_interval) & (x <= max_interval))
    last_valid_index = _find_last_n(
        all_events_diff_prev, 2,
        lambda x: (x >= min_interval) & (x <= max_interval))

    if first_valid_index is None:
        raise ValueError("Can't find valid start event")
    if last_valid_index is None:
        raise ValueError("Can't find valid end event")
    return all_events[first_valid_index:last_valid_index + 1]
def get_real_photodiode_events(sync_dset, photodiode_key,
                               anomaly_threshold=PHOTODIODE_ANOMALY_THRESHOLD):
    """Gets the photodiode events with the anomalies removed.

    An anomaly is an event followed by another event less than
    `anomaly_threshold` seconds later.
    """
    raw_events = get_photodiode_events(sync_dset, photodiode_key)
    too_close = np.where(np.diff(raw_events) < anomaly_threshold)
    return np.delete(raw_events, too_close)
def get_alignment_array(ref, other, int_method=np.floor):
    """Generate an alignment array.

    Maps each time in `other` onto a fractional index into `ref` via
    linear interpolation, then rounds with `int_method`. Times outside
    the range of `ref` map to NaN.
    """
    ref_indices = np.arange(len(ref))
    fractional_positions = np.interp(other, ref, ref_indices,
                                     left=np.nan, right=np.nan)
    return int_method(fractional_positions)
def get_video_length(filename):
    """Return the frame count of a video file, or None if unavailable.

    Uses OpenCV when importable; logs a warning and returns None when
    cv2 is missing or too old to expose CAP_PROP_FRAME_COUNT.
    """
    if cv2 is None:
        logging.warning("Could not get length for %s", filename)
        return None
    capture = None
    try:
        capture = cv2.VideoCapture(filename)
        return int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    except AttributeError:
        logging.warning("Could not get length for %s, opencv out of date",
                        filename)
        return None
    finally:
        # Bug fix: release the capture handle so the file descriptor is
        # not leaked (the original never called release()).
        if capture is not None and hasattr(capture, "release"):
            capture.release()
def get_ophys_data_length(filename):
    """Return the length of the ophys data in an HDF5 file.

    The length is taken from the second axis of the file's "data"
    dataset.
    """
    with h5py.File(filename, "r") as h5_file:
        return h5_file["data"].shape[1]
def get_stim_data_length(filename: str) -> int:
    """Get stimulus data length from .pkl file.

    Parameters
    ----------
    filename : str
        Path of stimulus data .pkl file.

    Returns
    -------
    int
        Stimulus data length.
    """
    stim_data = pd.read_pickle(filename)

    # A subset of stimulus .pkl files do not have the "vsynccount" field.
    # MPE *won't* be backfilling the "vsynccount" field for these .pkl
    # files, so the least worst option is to recalculate the vsync count
    # from the recorded frame intervals.
    try:
        return stim_data["vsynccount"]
    except KeyError:
        frame_intervals = stim_data["items"]["behavior"]["intervalsms"]
        return len(frame_intervals) + 1
def corrected_video_timestamps(video_name, timestamps, data_length):
    """Compare a video's timestamps against its expected frame count.

    Returns the (unmodified) timestamps together with delta, the
    difference len(timestamps) - data_length. Delta is 0 when
    data_length is None or the counts agree; a mismatch is logged.
    """
    if data_length is None:
        logging.info("No data length provided for %s", video_name)
        return timestamps, 0

    delta = len(timestamps) - data_length
    if delta != 0:
        logging.info("%s data of length %s has timestamps of length "
                     "%s", video_name, data_length, len(timestamps))
    return timestamps, delta
class OphysTimeAligner(object):
    """Aligns timestamps from a sync file across the ophys, stimulus,
    eye-video, and behavior-video data streams, clipping each stream
    against the length of its corresponding data file when provided.
    """

    def __init__(self, sync_file, scanner=None, dff_file=None,
                 stimulus_pkl=None, eye_video=None, behavior_video=None,
                 long_stim_threshold=LONG_STIM_THRESHOLD):
        # Default to the Scientifica scanner when none is specified
        self.scanner = scanner if scanner is not None else "SCIVIVO"
        self._dataset = Dataset(sync_file)
        # Resolve which line labels are actually present in this sync file
        self._keys = get_keys(self._dataset)
        self.long_stim_threshold = long_stim_threshold
        # Lazily-computed caches (see monitor_delay and
        # clipped_stim_timestamps properties)
        self._monitor_delay = None
        self._clipped_stim_ts_delta = None
        self._clipped_stim_timestamp_values = None

        # Expected data lengths are only known when the corresponding
        # data file is supplied; otherwise no clipping is performed.
        if dff_file is not None:
            self.ophys_data_length = get_ophys_data_length(dff_file)
        else:
            self.ophys_data_length = None
        if stimulus_pkl is not None:
            self.stim_data_length = get_stim_data_length(stimulus_pkl)
        else:
            self.stim_data_length = None
        if eye_video is not None:
            self.eye_data_length = get_video_length(eye_video)
        else:
            self.eye_data_length = None
        if behavior_video is not None:
            self.behavior_data_length = get_video_length(behavior_video)
        else:
            self.behavior_data_length = None

    @property
    def dataset(self):
        # The underlying sync Dataset
        return self._dataset

    @property
    def ophys_timestamps(self):
        """Get the timestamps for the ophys data."""
        ophys_key = self._keys["2p"]
        if self.scanner == "SCIVIVO":
            # Scientifica data looks different than Nikon.
            # http://confluence.corp.alleninstitute.org/display/IT/Ophys+Time+Sync
            times = self.dataset.get_rising_edges(ophys_key, units="seconds")
        elif self.scanner == "NIKONA1RMP":
            # Nikon has a signal that indicates when it started writing to disk
            acquiring_key = self._keys["acquiring"]
            acquisition_start = self._dataset.get_rising_edges(
                acquiring_key, units="seconds")[0]
            ophys_times = self._dataset.get_falling_edges(
                ophys_key, units="seconds")
            # Discard frame pulses emitted before acquisition began
            times = ophys_times[ophys_times >= acquisition_start]
        else:
            raise ValueError("Invalid scanner: {}".format(self.scanner))

        return times

    @property
    def corrected_ophys_timestamps(self):
        """Ophys timestamps truncated to the ophys data length.

        Returns (times, delta) where delta is the number of excess
        timestamps removed. Raises ValueError when there are fewer
        timestamps than data samples.
        """
        times = self.ophys_timestamps

        delta = 0
        if self.ophys_data_length is not None:
            if len(times) < self.ophys_data_length:
                raise ValueError(
                    "Got too few timestamps ({}) for ophys data length "
                    "({})".format(len(times), self.ophys_data_length))
            elif len(times) > self.ophys_data_length:
                logging.info("Ophys data of length %s has timestamps of "
                             "length %s, truncating timestamps",
                             self.ophys_data_length, len(times))
                delta = len(times) - self.ophys_data_length
                # Extra timestamps are assumed to be at the end of the
                # recording and are dropped
                times = times[:-delta]
        else:
            logging.info("No data length provided for ophys stream")

        return times, delta

    @property
    def stim_timestamps(self):
        # Raw stimulus vsync falling edges from the sync file
        stim_key = self._keys["stimulus"]
        return self.dataset.get_falling_edges(stim_key, units="seconds")

    def _get_clipped_stim_timestamps(self):
        # Compute stimulus timestamps with any spurious initial DAQ spike
        # removed, plus the remaining length mismatch vs. the .pkl file.
        timestamps = self.stim_timestamps

        delta = 0
        if self.stim_data_length is not None and \
           self.stim_data_length < len(timestamps):
            stim_key = self._keys["stimulus"]
            rising = self.dataset.get_rising_edges(stim_key, units="seconds")

            # Some versions of camstim caused a spike when the DAQ is first
            # initialized. Remove it.
            # NOTE(review): assumes at least two rising edges exist here;
            # an IndexError would propagate otherwise.
            if rising[1] - rising[0] > self.long_stim_threshold:
                logging.info("Initial DAQ spike detected from stimulus, "
                             "removing it")
                timestamps = timestamps[1:]

            delta = len(timestamps) - self.stim_data_length
            if delta != 0:
                logging.info("Stim data of length %s has timestamps of "
                             "length %s",
                             self.stim_data_length, len(timestamps))
        elif self.stim_data_length is None:
            logging.info("No data length provided for stim stream")

        return timestamps, delta

    @property
    def clipped_stim_timestamps(self):
        """
        Return the stimulus timestamps with the erroneous initial spike
        removed (if relevant)

        Returns
        -------
        timestamps: np.ndarray
            An array of stimulus timestamps in seconds with the monitor
            delay added
        delta: int
            Difference between the length of timestamps
            and the number of frames reported in the stimulus
            pickle file, i.e.
            len(timestamps) - len(pkl_file['items']['behavior']['intervalsms'])
        """
        # Cached after the first computation; the delta doubles as the
        # "already computed" flag
        if self._clipped_stim_ts_delta is None:
            (self._clipped_stim_timestamp_values,
             self._clipped_stim_ts_delta) = self._get_clipped_stim_timestamps()

        return (self._clipped_stim_timestamp_values,
                self._clipped_stim_ts_delta)

    def _get_monitor_delay(self):
        # delta is unused here; only the clipped timestamps are needed
        timestamps, delta = self.clipped_stim_timestamps
        photodiode_key = self._keys["photodiode"]
        delay = calculate_monitor_delay(self.dataset,
                                        timestamps,
                                        photodiode_key)
        return delay

    @property
    def monitor_delay(self):
        """
        The monitor delay (in seconds) associated with the session
        """
        # Cached after the first computation
        if self._monitor_delay is None:
            self._monitor_delay = self._get_monitor_delay()
        return self._monitor_delay

    @property
    def corrected_stim_timestamps(self):
        """
        The stimulus timestamps corrected for monitor delay

        Returns
        -------
        timestamps: np.ndarray
            An array of stimulus timestamps in seconds with the monitor
            delay added
        delta: int
            Difference between the length of timestamps and
            the number of frames reported in the stimulus
            pickle file, i.e.
            len(timestamps) - len(pkl_file['items']['behavior']['intervalsms'])
        delay: float
            The monitor delay in seconds
        """
        timestamps, delta = self.clipped_stim_timestamps
        delay = self.monitor_delay

        return timestamps + delay, delta, delay

    @property
    def behavior_video_timestamps(self):
        # Falling edges of the behavior camera exposure line
        key = self._keys["behavior_camera"]
        return self.dataset.get_falling_edges(key, units="seconds")

    @property
    def corrected_behavior_video_timestamps(self):
        # (timestamps, delta) after comparison with the video frame count
        return corrected_video_timestamps("Behavior video",
                                          self.behavior_video_timestamps,
                                          self.behavior_data_length)

    @property
    def eye_video_timestamps(self):
        # Falling edges of the eye camera exposure line
        key = self._keys["eye_camera"]
        return self.dataset.get_falling_edges(key, units="seconds")

    @property
    def corrected_eye_video_timestamps(self):
        # (timestamps, delta) after comparison with the video frame count
        return corrected_video_timestamps("Eye video",
                                          self.eye_video_timestamps,
                                          self.eye_data_length)