# Source code for allensdk.brain_observatory.sync_dataset

"""
dataset.py

Dataset object for loading and unpacking an HDF5 dataset generated by
    sync.py

@author: derricw

Allen Institute for Brain Science

Dependencies
------------
numpy  http://www.numpy.org/
h5py   http://www.h5py.org/

"""
import collections
from typing import Union, Sequence, Optional

import h5py as h5
import numpy as np

import warnings
import logging
logger = logging.getLogger(__name__)

# Version number for this dataset module (presumably tracks the sync
# dataset reader/format revision -- TODO confirm).
dset_version = 1.04


def unpack_uint32(uint32_array, endian='L'):
    """
    Unpacks an array of 32-bit unsigned integers into bits.

    *Not currently used by sync dataset because get_bit is better and does
    basically the same thing. It is kept because it can account for
    endianness and could have other uses in the future.

    Parameters
    ----------
    uint32_array : numpy.ndarray
        Array with dtype ``np.uint32`` (any shape; it is flattened).
    endian : str
        Column ordering of the result. With the default ``'L'`` (or any
        value other than ``'B'``), column 0 of each row holds bit 31 (the
        most significant bit) and column 31 holds bit 0. With ``'B'``
        (case-insensitive) the columns are reversed, so column ``i`` holds
        bit ``i``.

    Returns
    -------
    numpy.ndarray
        uint8 array of shape (n_values, 32) containing 0s and 1s.

    Raises
    ------
    TypeError
        If ``uint32_array`` does not have dtype ``np.uint32``.
    """
    if not uint32_array.dtype == np.uint32:
        raise TypeError("Must be uint32 ndarray.")
    # np.getbuffer no longer exists. Instead, take an explicitly
    # little-endian, contiguous copy (a no-op on little-endian hosts) and view
    # its raw bytes, so the result does not depend on the host byte order.
    little_endian = np.ascontiguousarray(uint32_array, dtype='<u4')
    uint8_array = little_endian.view(np.uint8).reshape(-1, 4)
    # Reverse the bytes of each word so the most significant byte comes first;
    # np.unpackbits then emits each byte most-significant-bit first.
    uint8_array = np.fliplr(uint8_array)
    bits = np.unpackbits(uint8_array).reshape(-1, 32)
    if endian.upper() == 'B':
        bits = np.fliplr(bits)
    return bits
def get_bit(uint_array, bit):
    """
    Extract a single bit from every element of an unsigned-integer array.

    Parameters
    ----------
    uint_array : numpy.ndarray
        The array to extract bits from.
    bit : int
        Zero-based position of the bit to extract.

    Returns
    -------
    numpy.ndarray
        uint8 values: 1 where the requested bit is set, 0 elsewhere.
    """
    mask = 2 ** bit
    is_set = np.bitwise_and(uint_array, mask) != 0
    return is_set.astype(np.uint8)
class Dataset(object):
    """
    A sync dataset.  Contains methods for loading
        and parsing the binary data.

    Parameters
    ----------
    path : str
        Path to HDF5 file.

    Examples
    --------
    >>> dset = Dataset('my_h5_file.h5')
    >>> logger.info(dset.meta_data)
    >>> dset.stats()
    >>> dset.close()

    >>> with Dataset('my_h5_file.h5') as d:
    ...     logger.info(d.meta_data)
    ...     d.stats()

    The sync file documentation from MPE can be found at
    sharepoint > Instrumentation > Shared Documents >
    Sync_line_labels_discussion_2020-01-27-.xlsx  # NOQA E501

    Direct link:
    https://alleninstitute.sharepoint.com/:x:/s/Instrumentation/ES2bi1xJ3E9NupX-zQeXTlYBS2mVVySycfbCQhsD_jPMUw?e=Z9jCwH  # NOQA E501

    """
    FRAME_KEYS = ('frames', 'stim_vsync', 'vsync_stim')
    PHOTODIODE_KEYS = ('photodiode', 'stim_photodiode')
    OPTOGENETIC_STIMULATION_KEYS = ("LED_sync", "opto_trial")
    EYE_TRACKING_KEYS = (
        # Expected eye tracking line label after 3/27/2020;
        # clocks eye tracking frame pulses (port 0, line 9).
        "eye_frame_received",
        # Previous line label for eye tracking (prior to ~ Oct. 2018).
        "cam2_exposure",
        "eyetracking",
        "eye_cam_exposing",
        # An undocumented, but possible eye tracking line label.
        "eye_tracking",
    )
    BEHAVIOR_TRACKING_KEYS = (
        # Expected behavior line label after 3/27/2020;
        # clocks behavior tracking frame pulses (port 0, line 8).
        "beh_frame_received",
        "cam1_exposure",
        "behavior_monitoring",
    )
    DEPRECATED_KEYS = set()

    def __init__(self, path):
        self.dfile = self.load(path)
        self._check_line_labels()

    def _check_line_labels(self):
        """Warn if line labels are missing or include deprecated keys."""
        if hasattr(self, "line_labels"):
            deprecated_keys = set(self.line_labels) & self.DEPRECATED_KEYS
            if deprecated_keys:
                warnings.warn((f"The loaded sync file contains the "
                               f"following deprecated line label keys: "
                               f"{deprecated_keys}. Consider updating the "
                               f"sync file line labels."), stacklevel=2)
        else:
            warnings.warn(("The loaded sync file has no line labels and may "
                           "not be valid."), stacklevel=2)

    def _process_times(self):
        """
        Preprocesses the time array to account for rollovers.
            This is only relevant for event-based sampling.

        Returns
        -------
        numpy.ndarray
            int64 array of shape (N, 1): the counter column (column 0 of the
            'data' dataset) with 4294967296 (2**32) added after each point
            where the counter decreases.
        """
        times = self.get_all_events()[:, 0:1].astype(np.int64)

        # ediff1d flattens, so ``intervals`` is 1-D with length N.
        intervals = np.ediff1d(times, to_begin=0)
        # Each sample is offset by 2**32 times the number of rollovers seen
        # at or before it (vectorised form of the per-rollover loop).
        rollover_counts = np.cumsum(intervals < 0).astype(np.int64)
        times += (rollover_counts * 4294967296)[:, np.newaxis]

        return times

    def load(self, path):
        """
        Loads an hdf5 sync dataset.

        Parameters
        ----------
        path : str
            Path to hdf5 file.

        Returns
        -------
        h5py.File
            The open file handle (also stored as ``self.dfile``).
        """
        # Opened read-only. (A historical note said the 'r' mode was once
        # removed because some sync files failed to load; it is present here.)
        self.dfile = h5.File(path, 'r')
        try:
            # NOTE(security): 'meta' is a Python-literal string that is
            # evaluated with eval(); only open sync files from trusted sources.
            self.meta_data = eval(self.dfile['meta'][()])
            self.line_labels = self.meta_data['line_labels']
            self.times = self._process_times()
        except Exception:
            # Do not leak the open file handle if parsing fails.
            self.dfile.close()
            raise
        return self.dfile

    @property
    def sample_freq(self):
        """Sample frequency from 'sample_freq', else 'counter_output_freq'."""
        try:
            return float(self.meta_data['ni_daq']['sample_freq'])
        except KeyError:
            return float(self.meta_data['ni_daq']['counter_output_freq'])

    def get_bit(self, bit):
        """
        Returns the values for a specific bit.

        Parameters
        ----------
        bit : int
            Bit to return.
        """
        return get_bit(self.get_all_bits(), bit)

    def get_line(self, line):
        """
        Returns the values for a specific line.

        Parameters
        ----------
        line : str
            Line to return.
        """
        bit = self._line_to_bit(line)
        return self.get_bit(bit)

    def get_bit_changes(self, bit):
        """
        Returns the first derivative of a specific bit.
            Data points are 1 on rising edges and 255 on falling edges
            (uint8 wrap-around of -1).

        Parameters
        ----------
        bit : int
            Bit for which to return changes.
        """
        bit_array = self.get_bit(bit)
        return np.ediff1d(bit_array, to_begin=0)

    def get_line_changes(self, line):
        """
        Returns the first derivative of a specific line.
            Data points are 1 on rising edges and 255 on falling edges.

        Parameters
        ----------
        line : (str)
            Line name for which to return changes.
        """
        bit = self._line_to_bit(line)
        return self.get_bit_changes(bit)

    def get_all_bits(self):
        """
        Returns the data for all bits (the last column of 'data').
        """
        # Slice the dataset directly so only the needed column is read.
        return self.dfile['data'][:, -1]

    def get_all_times(self, units='samples'):
        """
        Returns all counter values.

        Parameters
        ----------
        units : str
            Return times in 'samples' or 'seconds'

        Raises
        ------
        ValueError
            If ``units`` is not recognised.
        """
        if self.meta_data['ni_daq']['counter_bits'] == 32:
            times = self.get_all_events()[:, 0]
        else:
            times = self.times
        units = units.lower()
        if units == 'samples':
            return times
        elif units in ['seconds', 'sec', 'secs']:
            freq = self.sample_freq
            return times / freq
        else:
            raise ValueError("Only 'samples' or 'seconds' are valid units.")

    def get_all_events(self):
        """
        Returns all counter values and their corresponding IO state.
        """
        return self.dfile['data'][()]

    def get_events_by_bit(self, bit, units='samples'):
        """
        Returns all counter values for transitions (both rising and falling)
            for a specific bit.

        Parameters
        ----------
        bit : int
            Bit for which to return events.
        """
        changes = self.get_bit_changes(bit)
        return self.get_all_times(units)[np.where(changes != 0)]

    def get_events_by_line(self, line, units='samples'):
        """
        Returns all counter values for transitions (both rising and falling)
            for a specific line.

        Parameters
        ----------
        line : str
            Line for which to return events.
        """
        bit = self._line_to_bit(line)
        return self.get_events_by_bit(bit, units)

    def _line_to_bit(self, line):
        """
        Returns the bit for a specified line.  Either line name and number is
            accepted.

        Parameters
        ----------
        line : str or int
            Line name (or bit number) for which to return corresponding bit.

        Raises
        ------
        ValueError
            If ``line`` is a str that is not among ``self.line_labels``.
        TypeError
            If ``line`` is neither a str nor an integer.
        """
        # numpy integers are accepted as bit numbers; bool is excluded to
        # match the previous exact-``int`` check.
        if isinstance(line, (int, np.integer)) and \
                not isinstance(line, (bool, np.bool_)):
            return int(line)
        elif isinstance(line, str):
            return self.line_labels.index(line)
        else:
            raise TypeError("Incorrect line type.  Try a str or int.")

    def _bit_to_line(self, bit):
        """
        Returns the line name for a specified bit.

        Parameters
        ----------
        bit : int
            Bit for which to return the corresponding line name.
        """
        return self.line_labels[bit]

    def _resolve_line_name(self, line):
        """
        Returns ``line`` unchanged if it is a bit number or an exact line
        label; otherwise returns it lower-cased (legacy lookup behaviour).
        """
        if isinstance(line, str) and line not in self.line_labels:
            return line.lower()
        return line

    def get_rising_edges(self, line, units='samples'):
        """
        Returns the counter values for the rising edges for a specific bit or
            line.

        Parameters
        ----------
        line : str
            Line for which to return edges.
        """
        bit = self._line_to_bit(line)
        changes = self.get_bit_changes(bit)
        return self.get_all_times(units)[np.where(changes == 1)]

    def get_edges(
        self,
        kind: str,
        keys: Union[str, Sequence[str]],
        units: str = "seconds",
        permissive: bool = False
    ) -> Optional[np.ndarray]:
        """ Utility function for extracting edge times from a line

        Parameters
        ----------
        kind : One of "rising", "falling", or "all". Should this method return
            timestamps for rising, falling or both edges on the appropriate
            line
        keys : These will be checked in sequence. Timestamps will be returned
            for the first which is present in the line labels
        units : one of "seconds" or "samples". The returned "time"stamps will
            be given in these units.
        permissive : If False (default) and none of the keys is found, a
            KeyError is raised. If True, None is returned instead.

        Returns
        -------
        An array of edge times. If permissive is True and none of the keys
            were found, returns None. For kind="all" the rising and falling
            edges are merged into one sorted 1-D array.

        Raises
        ------
        KeyError : none of the provided keys were found among this dataset's
            line labels (and permissive is False)
        ValueError : kind or units is not recognised

        """
        if kind == 'falling':
            fn = self.get_falling_edges
        elif kind == 'rising':
            fn = self.get_rising_edges
        elif kind == 'all':
            rising = self.get_edges('rising', keys, units, permissive)
            falling = self.get_edges('falling', keys, units, permissive)
            if rising is None or falling is None:
                return None
            # axis=None flattens, so (k, 1) edge arrays are sorted too.
            return np.sort(np.concatenate([rising, falling]), axis=None)
        else:
            raise ValueError(
                f"kind must be 'rising', 'falling' or 'all', got {kind!r}")

        if isinstance(keys, (str, int, np.integer)):
            keys = [keys]

        for key in keys:
            # Only a failed label lookup means "try the next key"; errors from
            # the edge computation itself (e.g. bad units) must propagate.
            try:
                self._line_to_bit(key)
            except ValueError:
                continue
            return fn(key, units)

        if not permissive:
            raise KeyError(
                f"none of {keys} were found in this dataset's line labels")
        return None

    def get_falling_edges(self, line, units='samples'):
        """
        Returns the counter values for the falling edges for a specific bit
            or line.

        Parameters
        ----------
        line : str
            Line for which to return edges.
        """
        bit = self._line_to_bit(line)
        changes = self.get_bit_changes(bit)
        return self.get_all_times(units)[np.where(changes == 255)]

    def get_nearest(self,
                    source,
                    target,
                    source_edge="rising",
                    target_edge="rising",
                    direction="previous",
                    units='indices',
                    ):
        """
        For all values of the source line, finds the nearest edge from the
            target line.

        By default, returns the indices of the target edges.

        Args:
            source (str, int): desired source line
            target (str, int): desired target line
            source_edge [Optional(str)]: "rising" or "falling" source edges
            target_edge [Optional(str)]: "rising" or "falling" target edges
            direction (str): "previous" or "next". Whether to prefer the
                previous edge or the following edge.
            units (str): "indices", "samples" or "seconds"

        Raises:
            KeyError: invalid units
        """
        source_getter = getattr(
            self, "get_{}_edges".format(source_edge.lower()))
        target_getter = getattr(
            self, "get_{}_edges".format(target_edge.lower()))
        # Exact labels (e.g. "LED_sync") and bit numbers are used as given;
        # other strings fall back to lower case as before.
        source_edges = np.ravel(
            source_getter(self._resolve_line_name(source), units="samples"))
        target_edges = np.ravel(
            target_getter(self._resolve_line_name(target), units="samples"))
        indices = np.searchsorted(target_edges, source_edges, side="right")
        if direction.lower() == "previous":
            indices[np.where(indices != 0)] -= 1
        elif direction.lower() == "next":
            indices[np.where(indices == len(target_edges))] = -1
        if units in ["indices", 'index']:
            return indices
        elif units == "samples":
            return target_edges[indices]
        elif units in ['sec', 'seconds', 'second']:
            return target_edges[indices] / self.sample_freq
        else:
            raise KeyError(
                "Invalid units.  Try 'seconds', 'samples' or 'indices'")

    def get_analog_channel(self,
                           channel,
                           start_time=0.0,
                           stop_time=None,
                           downsample=1):
        """
        Returns the data from the specified analog channel between the
            timepoints.

        Args:
            channel (int, str): desired channel index or label
            start_time (Optional[float]): start time in seconds
            stop_time (Optional[float]): stop time in seconds
            downsample (Optional[int]): downsample factor

        Returns:
            ndarray: slice of data for specified channel

        Raises:
            KeyError: no analog data present

        """
        if isinstance(channel, str):
            channel_index = self.analog_meta_data['analog_labels'].index(
                channel)
            channel = self.analog_meta_data['analog_channels'].index(
                channel_index)

        if "analog_data" in self.dfile.keys():
            dset = self.dfile['analog_data']
            analog_meta = self.get_analog_meta()
            sample_rate = analog_meta['analog_sample_rate']
            start = int(start_time * sample_rate)
            if stop_time:
                stop = int(stop_time * sample_rate)
                return dset[start:stop:downsample, channel]
            else:
                return dset[start::downsample, channel]
        else:
            raise KeyError("No analog data was saved.")

    def get_analog_meta(self):
        """
        Returns the metadata for the analog data.

        Raises:
            KeyError: no analog data present
        """
        if "analog_meta" in self.dfile.keys():
            # ``.value`` was removed in h5py 3; ``[()]`` reads the scalar.
            # NOTE(security): eval() of file content; trusted files only.
            return eval(self.dfile['analog_meta'][()])
        else:
            raise KeyError("No analog data was saved.")

    @property
    def analog_meta_data(self):
        """Metadata for the analog data (see ``get_analog_meta``)."""
        return self.get_analog_meta()

    def line_stats(self, line, print_results=True):
        """
        Quick-and-dirty analysis of a bit.

        Returns None when the line has no events, a reduced dict for sparse
        lines (10 or fewer events), and a full dict otherwise.

        TODO: Split this up into smaller functions.
        """
        # convert to bit
        bit = self._line_to_bit(line)

        # get the bit's data
        bit_data = self.get_bit(bit)
        total_data_points = len(bit_data)

        # get the events
        events = self.get_events_by_bit(bit)
        total_events = len(events)

        # get the rising edges
        rising = self.get_rising_edges(bit)
        total_rising = len(rising)

        # get falling edges
        falling = self.get_falling_edges(bit)
        total_falling = len(falling)

        if total_events <= 0:
            if print_results:
                logger.info("*" * 70)
                logger.info("No events on line: %s", line)
                logger.info("*" * 70)
            return None
        elif total_events <= 10:
            if print_results:
                logger.info("*" * 70)
                logger.info("Sparse events on line: %s", line)
                logger.info("Rising: %s", total_rising)
                logger.info("Falling: %s", total_falling)
                logger.info("*" * 70)
            return {
                'line': line,
                'bit': bit,
                'total_rising': total_rising,
                'total_falling': total_falling,
                'avg_freq': None,
                'duty_cycle': None,
            }
        else:

            # period
            period = self.period(line)

            avg_period = period['avg']
            max_period = period['max']
            min_period = period['min']
            period_sd = period['sd']

            # freq
            avg_freq = self.frequency(line)

            # duty cycle
            duty_cycle = self.duty_cycle(line)

            if print_results:
                logger.info("*" * 70)

                logger.info("Quick stats for line: %s", line)
                logger.info("Bit: %i", bit)
                logger.info("Data points: %i", total_data_points)
                logger.info("Total transitions: %i", total_events)
                logger.info("Rising edges: %i", total_rising)
                logger.info("Falling edges: %i", total_falling)
                logger.info("Average period: %s", avg_period)
                logger.info("Minimum period: %s", min_period)
                logger.info("Max period: %s", max_period)
                logger.info("Period SD: %s", period_sd)
                logger.info("Average freq: %s", avg_freq)
                logger.info("Duty cycle: %s", duty_cycle)

                logger.info("*" * 70)

            return {
                'line': line,
                'bit': bit,
                'total_data_points': total_data_points,
                'total_events': total_events,
                'total_rising': total_rising,
                'total_falling': total_falling,
                'avg_period': avg_period,
                'min_period': min_period,
                'max_period': max_period,
                'period_sd': period_sd,
                'avg_freq': avg_freq,
                'duty_cycle': duty_cycle,
            }

    def period(self, line, edge="rising"):
        """
        Returns a dictionary with avg, min, max, and standard deviation of
            the period for a line, in units of 1 / counter_output_freq.

        The first edge is skipped; intervals are taken between the remaining
        consecutive edges.

        Raises
        ------
        ValueError
            If ``edge`` is not "rising" or "falling".
        IndexError
            If the line has 2 or fewer edges of the requested kind.
        """
        bit = self._line_to_bit(line)

        if edge.lower() == "rising":
            edges = self.get_rising_edges(bit)
        elif edge.lower() == "falling":
            edges = self.get_falling_edges(bit)
        else:
            raise ValueError(
                f"edge must be 'rising' or 'falling', got {edge!r}")

        if len(edges) > 2:

            timebase_freq = self.meta_data['ni_daq']['counter_output_freq']
            periods = np.ediff1d(edges[1:]) / timebase_freq
            avg_period = np.mean(periods)
            max_period = np.max(periods)
            min_period = np.min(periods)
            # Spread of the individual periods (previously the std of the
            # scalar mean, which was always 0).
            period_sd = np.std(periods)

        else:
            raise IndexError("Not enough edges for period: %i" % len(edges))

        return {
            'avg': avg_period,
            'max': max_period,
            'min': min_period,
            'sd': period_sd,
        }

    def frequency(self, line, edge="rising"):
        """
        Returns the average frequency of a line.
        """

        period = self.period(line, edge)
        return 1.0 / period['avg']

    def duty_cycle(self, line):
        """
        Placeholder for the duty cycle of a line.

        Not implemented: always returns the string ``"fix me"``. The earlier
        draft implementation was unreachable, contained errors, and was
        noted to freeze Python, so it has been removed.

        Parameters
        ----------
        line : str or int
            Line name or bit number (currently unused).
        """
        # TODO: implement as total high time / total time spanned by the
        # events on this line.
        return "fix me"

    def stats(self):
        """
        Quick-and-dirty analysis of all bits.  Logs a few things about each
            bit where events are found.

        Returns
        -------
        list of dict
            ``line_stats`` results for every bit in range(32) with events.
        """
        bits = []
        for i in range(32):
            bits.append(self.line_stats(i, print_results=False))
        active_bits = [x for x in bits if x is not None]
        logger.info("Active bits: %s", len(active_bits))
        for bit in active_bits:
            logger.info("*" * 70)
            logger.info("Bit: %i", bit['bit'])
            logger.info("Label: %s", self.line_labels[bit['bit']])
            logger.info("Rising edges: %i", bit['total_rising'])
            logger.info("Falling edges: %i", bit["total_falling"])
            logger.info("Average freq: %s", bit['avg_freq'])
            logger.info("Duty cycle: %s", bit['duty_cycle'])
        logger.info("*" * 70)
        return active_bits

    def plot_all(self,
                 start_time,
                 stop_time,
                 auto_show=True,
                 ):
        """
        Plot all active bits.

        Yikes.  Come up with a better way to show this.
        """
        import matplotlib.pyplot as plt
        for bit in range(32):
            if len(self.get_events_by_bit(bit)) > 0:
                self.plot_bit(bit,
                              start_time,
                              stop_time,
                              auto_show=False, )
        if auto_show:
            plt.show()

    def plot_bits(self,
                  bits,
                  start_time=0.0,
                  end_time=None,
                  auto_show=True,
                  ):
        """
        Plots a list of bits, one shared-axis subplot per bit.
        """
        import matplotlib.pyplot as plt

        subplots = len(bits)
        f, axes = plt.subplots(subplots, sharex=True, sharey=True)
        # A single subplot returns a bare Axes; collections.Iterable no longer
        # exists (Python 3.10), so use numpy's iterability check.
        if not np.iterable(axes):
            axes = [axes]

        for bit, ax in zip(bits, axes):
            self.plot_bit(bit,
                          start_time,
                          end_time,
                          auto_show=False,
                          axes=ax)
        f.subplots_adjust(hspace=0)

        if auto_show:
            plt.show()

        return f, axes

    def plot_bit(self,
                 bit,
                 start_time=0.0,
                 end_time=None,
                 auto_show=True,
                 axes=None,
                 name="",
                 ):
        """
        Plots a specific bit at a specific time period.
        """
        import matplotlib.pyplot as plt

        # get_all_times may return an (N, 1) array; flatten it so the boolean
        # window lines up with the 1-D bit array.
        times = np.ravel(self.get_all_times(units='sec'))
        if not end_time:
            end_time = 2 ** 32

        window = (times < end_time) & (times > start_time)

        # The pyplot module has no get_yticklabels/set_xlabel, so always draw
        # on an Axes object (the current one when none is supplied).
        ax = axes if axes is not None else plt.gca()

        if not name:
            name = self._bit_to_line(bit)
        if not name:
            name = str(bit)

        bit_values = self.get_bit(bit)
        ax.step(times[window], bit_values[window], where='post')
        ax.set_ylim(-0.1, 1.1)
        plt.setp(ax.get_yticklabels(), visible=False)
        ax.set_xlabel('time (seconds)')
        ax.legend([name])

        if auto_show:
            plt.show()

        return plt.gcf()

    def plot_line(self,
                  line,
                  start_time=0.0,
                  end_time=None,
                  auto_show=True,
                  ):
        """
        Plots a specific line at a specific time period.
        """
        import matplotlib.pyplot as plt
        bit = self._line_to_bit(line)
        self.plot_bit(bit, start_time, end_time, auto_show=False)

        if auto_show:
            plt.show()

    def plot_lines(self,
                   lines,
                   start_time=0.0,
                   end_time=None,
                   auto_show=True,
                   ):
        """
        Plots specific lines at a specific time period.
        """
        import matplotlib.pyplot as plt
        bits = [self._line_to_bit(line) for line in lines]
        f, axes = self.plot_bits(bits,
                                 start_time,
                                 end_time,
                                 auto_show=False,
                                 )

        plt.subplots_adjust(left=0.025, right=0.975, bottom=0.05, top=0.95)
        if auto_show:
            plt.show()

        return f, axes

    def close(self):
        """
        Closes the dataset.
        """
        self.dfile.close()

    def __enter__(self):
        """
        So we can use context manager (with...as) like any other open file.

        Examples
        --------
        >>> with Dataset('my_data.h5') as d:
        ...     d.stats()

        """
        return self

    def __exit__(self, type, value, traceback):
        """
        Exit statement for context manager; closes the dataset.
        """
        self.close()
if __name__ == "__main__":
    # Running this module directly performs no action.
    pass