Source code for allensdk.brain_observatory.ecephys.ecephys_session_api.ecephys_nwb1_session_api

from typing import Dict
import pandas as pd
import numpy as np
import h5py
import collections
import warnings

# from allensdk.brain_observatory.nwb.nwb_api import NwbApi
from .ecephys_session_api import EcephysSessionApi
from allensdk.brain_observatory.running_speed import RunningSpeed


[docs]class IDCreator(object): def __init__(self, init_id=0): self._c_id = init_id self._map = {}
[docs] def get_id(self, key): if key in self._map: return self._map[key] else: id_val = self._c_id self._c_id += 1 self._map[key] = id_val return id_val
def __getitem__(self, key): return self.get_id(key) def __contains__(self, key): return isinstance(key, collections.Hashable)
[docs]class EcephysNwb1Api(EcephysSessionApi): """An EcephySession adaptor for reading NWB1.0 files. Was created by sight using an assortment of existing NWB1 files. It is possible that parts of the NWB1 standard (?!) is missing or not properly implemented. NWB1 vs NWB2 issues: * In NWB 1 there is no difference between global unit-ids and probe's local-index. A unit is unique to one channel * Units are missing information about firing_rate, isi_violation, and quality. - So that EcephysSession._build_units() actually return values I had to set quality=good for all units * NWB Stimulus_presentations missing stimulus_block, stimulus_index and Image column - To get EcephysSession.conditionwise_spikes() working had to make up a block number for every stimulus type * NWB1 missing a 'valid_data' tag for channels. Had to set to True otherwise EcephysSession won't see any channels * There were no 'channels' table/group in NWB1. Instead we had to iterate through all the units and pull out the distinct channel info. * In NWB2 each unit has a mean-waveform for every channel on the probe. In NWB1 A unit only has a single waveform * The NWB1 identifier is a string """ def __init__(self, path, *args, **kwargs): self._path = path self._h5_root = h5py.File(self._path, 'r') try: # check file is a valid NWB 1 file version_str = self._h5_root['nwb_version'][()] if not (version_str.startswith('NWB-1.') or version_str.startswith('1.')): raise Exception('{} is not a valid NWB 1 file path'.format(self._path)) except Exception: raise # EcephysSession requires session wide ids for units/channels/etc but NWB 1 doesn't have such a thing (ids # are relative to the probe). The following data-stuctures are used build and fetch session ids without having # to parse all the tables. self._unit_ids = IDCreator() self._channel_ids = IDCreator() self._probe_ids = IDCreator() @property def processing_grp(self): return self._h5_root['/processing'] @property def running_speed_grp(self): return self._h5_root['/acquisition/timeseries/RunningSpeed'] def _probe_groups(self): return [(pname, pgrp) for pname, pgrp in self.processing_grp.items() if isinstance(pgrp, h5py.Group) and pname.lower().startswith('probe')]
[docs] def get_running_speed(self): running_speed_grp = self.running_speed_grp return pd.DataFrame({ "start_time": running_speed_grp['timestamps'][:], "velocity": running_speed_grp['data'][:] # average velocities over a given interval })
__stim_col_map = { # Used for mapping column names from NWB 1.0 features ds to their appropiate NWB 2.0 name b'temporal_frequency': 'TF', b'spatial_frequency': 'SF', b'pos_x': 'Pos_x', b'pos_y': 'Pos_y', b'orientation': 'Ori', b'color': 'Color', b'phase': 'Phase', b'frame': 'Image' }
[docs] def get_stimulus_presentations(self) -> pd.DataFrame: # TODO: Missing 'stimulus_block', 'stimulus_index, Image, stimulus_presentations_df = None presentation_ids = 0 # make up a id for every stim-presentation stim_pres_grp = self._h5_root['/stimulus/presentation'] # Stimulus-presentations are heirarchily grouped by presentation name. Iterate through all of them and build # a single table. for block_i, (stim_name, stim_grp) in enumerate(stim_pres_grp.items()): timestamps = stim_grp['timestamps'][()] start_times = timestamps[:, 0] if timestamps.shape[1] == 2: stop_times = timestamps[:, 1] else: # Some of the datasets have an optotagging stimulus with no stop time. continue stop_times = np.nan n_stims = stim_grp['num_samples'][()] try: # parse the features/data datasets, map old column names (temporal freq->TF, phase-> phase, etc). stim_props = {self.__stim_col_map.get(ftr_name, ftr_name): stim_grp['data'][:, i] for i, ftr_name in enumerate(stim_grp['features'][()])} except Exception: stim_props = {} stim_df = pd.DataFrame({ 'stimulus_presentation_id': np.arange(presentation_ids, presentation_ids + n_stims), 'start_time': start_times, 'stop_time': stop_times, 'stimulus_name': stim_name, 'TF': stim_props.get('TF', np.nan), 'SF': stim_props.get('SF', np.nan), 'Ori': stim_props.get('Ori', np.nan), 'Pos_x': stim_props.get('Pos_x', np.nan), 'Pos_y': stim_props.get('Pos_y', np.nan), 'Color': stim_props.get('Color', np.nan), 'Phase': stim_props.get('Phase', np.nan), 'Image': stim_props.get('Image', np.nan), 'stimulus_block': block_i # Required by conditionwise_spike_counts(), add made-up number }) presentation_ids += n_stims if stimulus_presentations_df is None: stimulus_presentations_df = stim_df else: stimulus_presentations_df = stimulus_presentations_df.append(stim_df) stimulus_presentations_df['stimulus_index'] = 0 # I'm not sure what column is, but is droped by EcephysSession stimulus_presentations_df.set_index('stimulus_presentation_id', inplace=True) return stimulus_presentations_df
[docs] def get_probes(self) -> pd.DataFrame: probe_ids = [] locations = [] for prb_name, prb_grp in self._probe_groups(): probe_ids.append(self._probe_ids[prb_name]) locations.append(prb_name) probes_df = pd.DataFrame({ 'id': pd.Series(probe_ids, dtype=np.uint64), 'location': pd.Series(locations, dtype=object), 'description': "" # TODO: Find description }) probes_df.set_index('id', inplace=True) probes_df['sampling_rate'] = 30000.0 # TODO: calculate real sampling rate for each probe. return probes_df
[docs] def get_channels(self) -> pd.DataFrame: # TODO: Missing: manual_structure_id processing_grp = self.processing_grp max_channels = sum(len(prb_grp['unit_list']) for prb_grp in processing_grp.values()) channel_ids = np.zeros(max_channels, dtype=np.uint64) local_channel_indices = np.zeros(max_channels, dtype=np.int64) prb_ids = np.zeros(max_channels, dtype=np.uint64) prb_hrz_pos = np.zeros(max_channels, dtype=np.int64) prb_vert_pos = np.zeros(max_channels, dtype=np.int64) struct_acronyms = np.empty(max_channels, dtype=object) channel_indx = 0 existing_channels = set() # In NWB 1.0 files I used I couldn't find a channel group/dataset. Instead we have to iterate through all units # to get information about all available channels for prb_name, prb_grp in self._probe_groups(): prb_id = self._probe_ids[prb_name] unit_list = prb_grp['unit_list'][()] for indx, uid in enumerate(unit_list): unit_grp = prb_grp['UnitTimes'][str(uid)] local_channel_index = unit_grp['channel'][()] channel_id = self._channel_ids[(prb_name, local_channel_index)] if channel_id in existing_channels: # If a channel has already been processed (ie it's shared by another unit) skip it. I'm assuming # position/ccf info is the same for every probe/channel_id. continue else: channel_ids[channel_indx] = channel_id local_channel_indices[channel_indx] = local_channel_index prb_ids[channel_indx] = prb_id prb_hrz_pos[channel_indx] = unit_grp['xpos_probe'][()] prb_vert_pos[channel_indx] = unit_grp['ypos_probe'][()] try: struct_acronyms[channel_indx] = str(unit_grp['ccf_structure'][()], encoding='ascii') except TypeError: struct_acronyms[channel_indx] = unit_grp['ccf_structure'][()] existing_channels.add(channel_id) channel_indx += 1 n_channels = len(existing_channels) channels_df = pd.DataFrame({ 'id': channel_ids[:n_channels], 'local_index': local_channel_indices[:n_channels], 'probe_id': prb_ids[:n_channels], 'probe_horizontal_position': prb_hrz_pos[:n_channels], 'probe_vertical_position': prb_vert_pos[:n_channels], 'ecephys_structure_acronym': struct_acronyms[:n_channels], 'valid_data': True # TODO: Pull out valid table column from NWB }) channels_df.set_index('id', inplace=True) return channels_df
[docs] def get_mean_waveforms(self) -> Dict[int, np.ndarray]: waveforms = {} for prb_name, prb_grp in self._probe_groups(): # There is one waveform for any given spike, but still calling it "mean" wavefor for indx, uid in enumerate(prb_grp['unit_list']): unit_grp = prb_grp['UnitTimes'][str(uid)] unit_id = self._unit_ids[(prb_name, uid)] waveforms[unit_id] = np.array([unit_grp['waveform'][()],]) # EcephysSession is expecting an array of waveforms return waveforms
[docs] def get_spike_times(self) -> Dict[int, np.ndarray]: spike_times = {} for prb_name, prb_grp in self._probe_groups(): for indx, uid in enumerate(prb_grp['unit_list']): unit_grp = prb_grp['UnitTimes'][str(uid)] unit_id = self._unit_ids[(prb_name, uid)] spike_times[unit_id] = unit_grp['times'][()] return spike_times
[docs] def get_units(self) -> pd.DataFrame: # TODO: Missing properties: firing_rate, isi_violations unit_ids = np.zeros(0, dtype=np.uint64) local_indices = np.zeros(0, dtype=np.int64) peak_channel_ids = np.zeros(0, dtype=np.int64) snrs = np.zeros(0, dtype=np.float64) for prb_name, prb_grp in self._probe_groups(): # visit every /processing/probeN/UnitList/N/ group to build # TODO: Since just visting the tree is so expensive, maybe build the channels and probes at the same time. unit_list = prb_grp['unit_list'][()] prb_uids = np.zeros(len(unit_list), dtype=np.uint64) prb_channels = np.zeros(len(unit_list), dtype=np.int64) prb_snr = np.zeros(len(unit_list), dtype=np.float64) for indx, uid in enumerate(unit_list): unit_grp = prb_grp['UnitTimes'][str(uid)] prb_uids[indx] = self._unit_ids[(prb_name, uid)] prb_channels[indx] = self._channel_ids[(prb_name, unit_grp['channel'][()])] prb_snr[indx] = unit_grp['snr'][()] unit_ids = np.append(unit_ids, prb_uids) local_indices = np.append(local_indices, unit_list) peak_channel_ids = np.append(peak_channel_ids, prb_channels) snrs = np.append(snrs, prb_snr) units_df = pd.DataFrame({ 'unit_id': pd.Series(unit_ids, dtype=np.int64), 'local_index': local_indices, 'peak_channel_id': peak_channel_ids, 'snr': snrs, 'quality': "good" # TODO: NWB 1.0 is missing quality table, need to find an equivelent }) units_df.set_index('unit_id', inplace=True) return units_df
[docs] def get_ecephys_session_id(self) -> int: # Doesn't look like the session_id is stored return EcephysSessionApi.session_na
[docs] @classmethod def from_path(cls, path, **kwargs): # TODO: Validate that file is proper NWB1 return cls(path=path, **kwargs)