# Source code for allensdk.brain_observatory.ecephys.ecephys_project_cache

from functools import partial
from pathlib import Path
from typing import Any, List, Optional, Union, Callable
import ast

import pandas as pd
import SimpleITK as sitk
import numpy as np
import pynwb

from allensdk.api.cache import Cache
from allensdk.core.authentication import DbCredentials
from allensdk.brain_observatory.ecephys.ecephys_project_api import (
    EcephysProjectApi, EcephysProjectLimsApi, EcephysProjectWarehouseApi,
    EcephysProjectFixedApi
)
from allensdk.brain_observatory.ecephys.ecephys_project_api.http_engine import (
    write_bytes_from_coroutine, write_from_stream
)
from allensdk.brain_observatory.ecephys.ecephys_session_api import (
    EcephysNwbSessionApi
)
from allensdk.brain_observatory.ecephys.ecephys_session import EcephysSession
from allensdk.brain_observatory.ecephys import get_unit_filter_value
from allensdk.api.caching_utilities import one_file_call_caching


class EcephysProjectCache(Cache):

    SESSIONS_KEY = 'sessions'
    PROBES_KEY = 'probes'
    CHANNELS_KEY = 'channels'
    UNITS_KEY = 'units'

    SESSION_DIR_KEY = 'session_data'
    SESSION_NWB_KEY = 'session_nwb'
    PROBE_LFP_NWB_KEY = "probe_lfp_nwb"

    NATURAL_MOVIE_DIR_KEY = "movie_dir"
    NATURAL_MOVIE_KEY = "natural_movie"

    NATURAL_SCENE_DIR_KEY = "natural_scene_dir"
    NATURAL_SCENE_KEY = "natural_scene"

    SESSION_ANALYSIS_METRICS_KEY = "session_analysis_metrics"
    TYPEWISE_ANALYSIS_METRICS_KEY = "typewise_analysis_metrics"

    MANIFEST_VERSION = '0.2.1'

    SUPPRESS_FROM_UNITS = (
        "air_channel_index",
        "surface_channel_index",
        "has_nwb",
        "lfp_temporal_subsampling_factor",
        "epoch_name_quality_metrics",
        "epoch_name_waveform_metrics",
        "isi_experiment_id"
    )
    SUPPRESS_FROM_CHANNELS = (
        "air_channel_index", "surface_channel_index", "name",
        "date_of_acquisition", "published_at", "specimen_id",
        "session_type", "isi_experiment_id", "age_in_days", "sex",
        "genotype", "has_nwb", "lfp_temporal_subsampling_factor"
    )
    SUPPRESS_FROM_PROBES = (
        "air_channel_index", "surface_channel_index",
        "date_of_acquisition", "published_at", "specimen_id",
        "session_type", "isi_experiment_id", "age_in_days", "sex",
        "genotype", "has_nwb", "lfp_temporal_subsampling_factor"
    )
    SUPPRESS_FROM_SESSION_TABLE = (
        "has_nwb", "isi_experiment_id", "date_of_acquisition"
    )

    def __init__(
            self,
            fetch_api: EcephysProjectApi = EcephysProjectWarehouseApi.default(),
            fetch_tries: int = 2,
            stream_writer: Callable = write_from_stream,
            manifest: Optional[Union[str, Path]] = None,
            version: Optional[str] = None,
            cache: bool = True):
        """ Entrypoint for accessing ecephys (neuropixels) data. Supports
        access to cross-session data (like stimulus templates) and
        high-level summaries of sessionwise data, and provides tools for
        downloading detailed sessionwise data (such as spike times).

        Parameters
        ==========
        fetch_api :
            Used to pull data from remote sources, after which it is
            locally cached. Any object exposing the EcephysProjectApi
            interface is suitable. Standard options are:
                EcephysProjectWarehouseApi :: The default. Fetches publicly
                    available Allen Institute data.
                EcephysProjectFixedApi :: Refuses to fetch any data - only
                    the existing local cache is accessible. Useful if you
                    want to settle on a fixed dataset for analysis.
                EcephysProjectLimsApi :: Fetches bleeding-edge data from
                    the Allen Institute's internal database. Only works if
                    you are on our internal network.
        fetch_tries : int
            Maximum number of times to attempt a download before giving up
            and raising an exception. Note that this is total tries, not
            retries.
        manifest : str or Path
            full path at which manifest json will be stored
            (default = "ecephys_project_manifest.json" in the local
            directory.)
        version : str
            version of manifest file. If this mismatches the version
            recorded in the file at manifest, an error will be raised.
        cache: bool
            Whether to write to the cache (default=True)
        """
        manifest_ = manifest or "ecephys_project_manifest.json"
        version_ = version or self.MANIFEST_VERSION

        super(EcephysProjectCache, self).__init__(manifest=manifest_,
                                                  version=version_,
                                                  cache=cache)
        self.fetch_api = fetch_api
        self.fetch_tries = fetch_tries
        self.stream_writer = stream_writer

    def _get_sessions(self):
        path = self.get_cache_path(None, self.SESSIONS_KEY)
        response = one_file_call_caching(path, self.fetch_api.get_sessions,
                                         write_csv, read_csv,
                                         num_tries=self.fetch_tries)
        if "structure_acronyms" in response.columns:
            # unfortunately, structure_acronyms is a list of str
            response["ecephys_structure_acronyms"] = [
                ast.literal_eval(item)
                for item in response["structure_acronyms"]
            ]
            response.drop(columns=["structure_acronyms"], inplace=True)
        return response

    def _get_probes(self):
        path: str = self.get_cache_path(None, self.PROBES_KEY)
        probes = one_file_call_caching(path, self.fetch_api.get_probes,
                                       write_csv, read_csv,
                                       num_tries=self.fetch_tries)
        # Divide the lfp sampling rate by the subsampling factor for
        # clearer presentation (if provided)
        if all(c in list(probes) for c in
               ["lfp_sampling_rate", "lfp_temporal_subsampling_factor"]):
            probes["lfp_sampling_rate"] = (
                probes["lfp_sampling_rate"]
                / probes["lfp_temporal_subsampling_factor"])
        return probes

    def _get_channels(self):
        path = self.get_cache_path(None, self.CHANNELS_KEY)
        return one_file_call_caching(path, self.fetch_api.get_channels,
                                     write_csv, read_csv,
                                     num_tries=self.fetch_tries)

    def _get_units(self, filter_by_validity: bool = True,
                   **unit_filter_kwargs) -> pd.DataFrame:
        path = self.get_cache_path(None, self.UNITS_KEY)
        units = one_file_call_caching(path, self.fetch_api.get_units,
                                      write_csv, read_csv,
                                      num_tries=self.fetch_tries)
        units = units.rename(columns={
            'PT_ratio': 'waveform_PT_ratio',
            'amplitude': 'waveform_amplitude',
            'duration': 'waveform_duration',
            'halfwidth': 'waveform_halfwidth',
            'recovery_slope': 'waveform_recovery_slope',
            'repolarization_slope': 'waveform_repolarization_slope',
            'spread': 'waveform_spread',
            'velocity_above': 'waveform_velocity_above',
            'velocity_below': 'waveform_velocity_below',
            'l_ratio': 'L_ratio',
        })

        units = units[
            (units["amplitude_cutoff"]
             <= get_unit_filter_value("amplitude_cutoff_maximum",
                                      **unit_filter_kwargs))
            & (units["presence_ratio"]
               >= get_unit_filter_value("presence_ratio_minimum",
                                        **unit_filter_kwargs))
            & (units["isi_violations"]
               <= get_unit_filter_value("isi_violations_maximum",
                                        **unit_filter_kwargs))
        ]

        if "quality" in units.columns and filter_by_validity:
            units = units[units["quality"] == "good"]
            units.drop(columns="quality", inplace=True)

        if "ecephys_structure_id" in units.columns \
                and unit_filter_kwargs.get("filter_out_of_brain_units", True):
            units = units[~(units["ecephys_structure_id"].isna())]

        return units

    def _get_annotated_probes(self):
        sessions = self._get_sessions()
        probes = self._get_probes()
        return pd.merge(
            probes, sessions, left_on="ecephys_session_id",
            right_index=True, suffixes=['_probe', '_session'])

    def _get_annotated_channels(self):
        channels = self._get_channels()
        probes = self._get_annotated_probes()
        return pd.merge(
            channels, probes, left_on="ecephys_probe_id",
            right_index=True, suffixes=['_channel', '_probe'])

    def _get_annotated_units(self, filter_by_validity: bool = True,
                             **unit_filter_kwargs) -> pd.DataFrame:
        units = self._get_units(filter_by_validity=filter_by_validity,
                                **unit_filter_kwargs)
        channels = self._get_annotated_channels()
        annotated_units = pd.merge(
            units, channels, left_on='ecephys_channel_id',
            right_index=True, suffixes=['_unit', '_channel'])
        annotated_units = annotated_units.rename(columns={
            'name': 'probe_name',
            'phase': 'probe_phase',
            'sampling_rate': 'probe_sampling_rate',
            'lfp_sampling_rate': 'probe_lfp_sampling_rate',
            'local_index': 'peak_channel'
        })
        return annotated_units

    def get_session_table(self, suppress=None) -> pd.DataFrame:
        sessions = self._get_sessions()

        count_owned(sessions, self._get_annotated_units(),
                    "ecephys_session_id", "unit_count", inplace=True)
        count_owned(sessions, self._get_annotated_channels(),
                    "ecephys_session_id", "channel_count", inplace=True)
        count_owned(sessions, self._get_annotated_probes(),
                    "ecephys_session_id", "probe_count", inplace=True)

        get_grouped_uniques(sessions, self._get_annotated_channels(),
                            "ecephys_session_id", "ecephys_structure_acronym",
                            "ecephys_structure_acronyms", inplace=True)

        if suppress is None:
            suppress = list(self.SUPPRESS_FROM_SESSION_TABLE)
        sessions.drop(columns=suppress, inplace=True, errors="ignore")
        sessions = sessions.rename(columns={'genotype': 'full_genotype'})

        return sessions
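
    # Hypothetical usage sketch (not part of this module): filter the
    # session table by session type. The cache construction and session
    # type string below are illustrative.
    #
    #     cache = EcephysProjectCache.from_warehouse(
    #         manifest="ecephys_cache_dir/manifest.json")
    #     sessions = cache.get_session_table()
    #     bo_sessions = sessions[
    #         sessions["session_type"] == "brain_observatory_1.1"]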

    def get_probes(self, suppress=None):
        probes = self._get_annotated_probes()

        count_owned(probes, self._get_annotated_units(),
                    "ecephys_probe_id", "unit_count", inplace=True)
        count_owned(probes, self._get_annotated_channels(),
                    "ecephys_probe_id", "channel_count", inplace=True)

        get_grouped_uniques(probes, self._get_annotated_channels(),
                            "ecephys_probe_id", "ecephys_structure_acronym",
                            "ecephys_structure_acronyms", inplace=True)

        if suppress is None:
            suppress = list(self.SUPPRESS_FROM_PROBES)
        probes.drop(columns=suppress, inplace=True, errors="ignore")

        return probes

    def get_channels(self, suppress=None):
        """ Load (potentially downloading and caching) a table whose rows
        are individual channels.
        """
        channels = self._get_annotated_channels()

        count_owned(channels, self._get_annotated_units(),
                    "ecephys_channel_id", "unit_count", inplace=True)

        if suppress is None:
            suppress = list(self.SUPPRESS_FROM_CHANNELS)
        channels.drop(columns=suppress, inplace=True, errors="ignore")
        channels.rename(columns={"name": "probe_name"},
                        inplace=True, errors="ignore")

        return channels

    def get_units(self, suppress: Optional[List[str]] = None,
                  filter_by_validity: bool = True,
                  **unit_filter_kwargs) -> pd.DataFrame:
        """Reports a table consisting of all sorted units across the entire
        extracellular electrophysiology project.

        Parameters
        ----------
        suppress : Optional[List[str]], optional
            A list of dataframe column names to hide, by default None
            (None will hide the columns specified in SUPPRESS_FROM_UNITS)
        filter_by_validity : bool, optional
            Filter units so that only 'valid' units are returned,
            by default True
        **unit_filter_kwargs :
            Additional keyword arguments that can be used to filter units
            (for power users).

        Returns
        -------
        pd.DataFrame
            A table consisting of sorted units across the entire
            extracellular electrophysiology project
        """
        if suppress is None:
            suppress = list(self.SUPPRESS_FROM_UNITS)

        units = self._get_annotated_units(
            filter_by_validity=filter_by_validity, **unit_filter_kwargs)
        units.drop(columns=suppress, inplace=True, errors="ignore")

        return units
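
    # Hypothetical usage sketch: tighten the default quality filters via
    # unit_filter_kwargs. These keyword names mirror the
    # get_unit_filter_value lookups used in _get_units above; the values
    # are illustrative.
    #
    #     units = cache.get_units(amplitude_cutoff_maximum=0.05,
    #                             presence_ratio_minimum=0.95,
    #                             isi_violations_maximum=0.4)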

    def get_session_data(self, session_id: int,
                         filter_by_validity: bool = True,
                         **unit_filter_kwargs):
        """ Obtain an EcephysSession object containing detailed data for a
        single session.
        """
        def read(_path):
            session_api = self._build_nwb_api_for_session(
                _path, session_id, filter_by_validity, **unit_filter_kwargs)
            return EcephysSession(api=session_api, test=True)

        return one_file_call_caching(
            self.get_cache_path(
                None, self.SESSION_NWB_KEY, session_id, session_id),
            partial(self.fetch_api.get_session_data, session_id),
            self.stream_writer,
            read,
            num_tries=self.fetch_tries
        )
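
    # Hypothetical usage sketch: download (or read back from the local
    # cache) one session's NWB file and inspect its units. The session id
    # is an illustrative placeholder.
    #
    #     session = cache.get_session_data(715093703)
    #     print(session.units.head())
    #     spike_times = session.spike_times  # dict keyed by unit id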

    def _build_nwb_api_for_session(self, path, session_id,
                                   filter_by_validity, **unit_filter_kwargs):

        get_analysis_metrics = partial(
            self.get_unit_analysis_metrics_for_session,
            session_id=session_id,
            annotate=False,
            filter_by_validity=True,
            **unit_filter_kwargs
        )

        return EcephysNwbSessionApi(
            path=path,
            probe_lfp_paths=self._setup_probe_promises(session_id),
            additional_unit_metrics=get_analysis_metrics,
            external_channel_columns=partial(
                self._get_substitute_channel_columns, session_id),
            filter_by_validity=filter_by_validity,
            **unit_filter_kwargs
        )

    def _setup_probe_promises(self, session_id):
        probes = self.get_probes()
        probe_ids = probes[
            probes["ecephys_session_id"] == session_id].index.values
        return {
            probe_id: partial(
                one_file_call_caching,
                self.get_cache_path(
                    None, self.PROBE_LFP_NWB_KEY, session_id, probe_id),
                partial(self.fetch_api.get_probe_lfp_data, probe_id),
                self.stream_writer,
                read_nwb,
                num_tries=self.fetch_tries
            )
            for probe_id in probe_ids
        }

    def _get_substitute_channel_columns(self, session_id):
        channels = self.get_channels()
        return channels.loc[
            channels["ecephys_session_id"] == session_id,
            [
                "ecephys_structure_id",
                "ecephys_structure_acronym",
                "anterior_posterior_ccf_coordinate",
                "dorsal_ventral_ccf_coordinate",
                "left_right_ccf_coordinate"
            ]
        ]

    def get_natural_movie_template(self, number):
        return one_file_call_caching(
            self.get_cache_path(None, self.NATURAL_MOVIE_KEY, number),
            partial(self.fetch_api.get_natural_movie_template, number=number),
            self.stream_writer,
            read_movie,
            num_tries=self.fetch_tries
        )

    def get_natural_scene_template(self, number):
        return one_file_call_caching(
            self.get_cache_path(None, self.NATURAL_SCENE_KEY, number),
            partial(self.fetch_api.get_natural_scene_template, number=number),
            self.stream_writer,
            read_scene,
            num_tries=self.fetch_tries
        )
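
    # Hypothetical usage sketch: stimulus templates come back as numpy
    # arrays (scenes via SimpleITK, movies via np.load, per the module
    # readers below). The template numbers are illustrative.
    #
    #     movie = cache.get_natural_movie_template(1)   # frames-first array
    #     scene = cache.get_natural_scene_template(10)  # 2D image array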

    def get_all_session_types(self, **session_kwargs):
        return self._get_all_values(
            "session_type", self.get_session_table, **session_kwargs)

    def get_all_full_genotypes(self, **session_kwargs):
        return self._get_all_values(
            "full_genotype", self.get_session_table, **session_kwargs)

    def get_structure_acronyms(self, **channel_kwargs) -> List[str]:
        return self._get_all_values(
            "ecephys_structure_acronym", self.get_channels, **channel_kwargs)

    def get_all_ages(self, **session_kwargs):
        return self._get_all_values(
            "age_in_days", self.get_session_table, **session_kwargs)

    def get_all_sexes(self, **session_kwargs):
        return self._get_all_values(
            "sex", self.get_session_table, **session_kwargs)

    def _get_all_values(self, key, method=None, **method_kwargs) -> List[Any]:
        if method is None:
            method = self.get_session_table
        data = method(**method_kwargs)
        return data[key].unique().tolist()

    def get_unit_analysis_metrics_for_session(
            self, session_id, annotate: bool = True,
            filter_by_validity: bool = True, **unit_filter_kwargs):
        """ Cache and return a table of analysis metrics calculated on each
        unit from a specified session. See get_session_table for a list of
        sessions.

        Parameters
        ----------
        session_id : int
            identifies the session from which to fetch analysis metrics.
        annotate : bool, optional
            if True, information from the annotated units table will be
            merged onto the outputs
        filter_by_validity : bool, optional
            Filter units used by analysis so that only 'valid' units are
            returned, by default True
        **unit_filter_kwargs :
            Additional keyword arguments that can be used to filter units
            (for power users).

        Returns
        -------
        metrics : pd.DataFrame
            Each row corresponds to a single unit, describing a set of
            analysis metrics calculated on that unit.
        """
        path = self.get_cache_path(
            None, self.SESSION_ANALYSIS_METRICS_KEY, session_id, session_id)

        fetch_metrics = partial(self.fetch_api.get_unit_analysis_metrics,
                                ecephys_session_ids=[session_id])
        metrics = one_file_call_caching(path, fetch_metrics,
                                        write_metrics_csv, read_metrics_csv,
                                        num_tries=self.fetch_tries)

        if annotate:
            units = self.get_units(filter_by_validity=filter_by_validity,
                                   **unit_filter_kwargs)
            units = units[units["ecephys_session_id"] == session_id]
            metrics = pd.merge(units, metrics, left_index=True,
                               right_index=True, how="inner")
            metrics.index.rename("ecephys_unit_id", inplace=True)

        return metrics
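
    # Hypothetical usage sketch: fetch per-unit analysis metrics for one
    # session, annotated with the units table (the default). The session
    # id is an illustrative placeholder.
    #
    #     metrics = cache.get_unit_analysis_metrics_for_session(715093703)
    #     print(metrics.columns)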

    def get_unit_analysis_metrics_by_session_type(
            self, session_type, annotate: bool = True,
            filter_by_validity: bool = True, **unit_filter_kwargs):
        """ Cache and return a table of analysis metrics calculated on each
        unit from a specified session type. See get_all_session_types for a
        list of session types.

        Parameters
        ----------
        session_type : str
            identifies the session type for which to fetch analysis metrics.
        annotate : bool, optional
            if True, information from the annotated units table will be
            merged onto the outputs
        filter_by_validity : bool, optional
            Filter units used by analysis so that only 'valid' units are
            returned, by default True
        **unit_filter_kwargs :
            Additional keyword arguments that can be used to filter units
            (for power users).

        Returns
        -------
        metrics : pd.DataFrame
            Each row corresponds to a single unit, describing a set of
            analysis metrics calculated on that unit.
        """
        known_session_types = self.get_all_session_types()
        if session_type not in known_session_types:
            raise ValueError(
                f"unrecognized session type: {session_type}. "
                f"Available types: {known_session_types}")

        path = self.get_cache_path(
            None, self.TYPEWISE_ANALYSIS_METRICS_KEY, session_type)

        fetch_metrics = partial(self.fetch_api.get_unit_analysis_metrics,
                                session_types=[session_type])
        metrics = one_file_call_caching(
            path,
            fetch_metrics,
            write_metrics_csv,
            read_metrics_csv,
            num_tries=self.fetch_tries
        )

        if annotate:
            units = self.get_units(filter_by_validity=filter_by_validity,
                                   **unit_filter_kwargs)
            metrics = pd.merge(units, metrics, left_index=True,
                               right_index=True, how="inner")
            metrics.index.rename("ecephys_unit_id", inplace=True)

        return metrics

    def add_manifest_paths(self, manifest_builder):
        manifest_builder = super(
            EcephysProjectCache, self).add_manifest_paths(manifest_builder)

        manifest_builder.add_path(
            self.SESSIONS_KEY, 'sessions.csv',
            parent_key='BASEDIR', typename='file'
        )
        manifest_builder.add_path(
            self.PROBES_KEY, 'probes.csv',
            parent_key='BASEDIR', typename='file'
        )
        manifest_builder.add_path(
            self.CHANNELS_KEY, 'channels.csv',
            parent_key='BASEDIR', typename='file'
        )
        manifest_builder.add_path(
            self.UNITS_KEY, 'units.csv',
            parent_key='BASEDIR', typename='file'
        )
        manifest_builder.add_path(
            self.SESSION_DIR_KEY, 'session_%d',
            parent_key='BASEDIR', typename='dir'
        )
        manifest_builder.add_path(
            self.SESSION_NWB_KEY, 'session_%d.nwb',
            parent_key=self.SESSION_DIR_KEY, typename='file'
        )
        manifest_builder.add_path(
            self.SESSION_ANALYSIS_METRICS_KEY,
            'session_%d_analysis_metrics.csv',
            parent_key=self.SESSION_DIR_KEY, typename='file'
        )
        manifest_builder.add_path(
            self.PROBE_LFP_NWB_KEY, 'probe_%d_lfp.nwb',
            parent_key=self.SESSION_DIR_KEY, typename='file'
        )
        manifest_builder.add_path(
            self.NATURAL_MOVIE_DIR_KEY, "natural_movie_templates",
            parent_key="BASEDIR", typename="dir"
        )
        manifest_builder.add_path(
            self.TYPEWISE_ANALYSIS_METRICS_KEY, "%s_analysis_metrics.csv",
            parent_key='BASEDIR', typename="file"
        )
        manifest_builder.add_path(
            self.NATURAL_MOVIE_KEY, "natural_movie_%d.h5",
            parent_key=self.NATURAL_MOVIE_DIR_KEY, typename="file"
        )
        manifest_builder.add_path(
            self.NATURAL_SCENE_DIR_KEY, "natural_scene_templates",
            parent_key="BASEDIR", typename="dir"
        )
        manifest_builder.add_path(
            self.NATURAL_SCENE_KEY, "natural_scene_%d.tiff",
            parent_key=self.NATURAL_SCENE_DIR_KEY, typename="file"
        )

        return manifest_builder

    @classmethod
    def _from_http_source_default(cls, fetch_api_cls, fetch_api_kwargs,
                                  **kwargs):
        fetch_api_kwargs = (
            {"asynchronous": True}
            if fetch_api_kwargs is None else fetch_api_kwargs)

        if kwargs.get("stream_writer") is None:
            if fetch_api_kwargs.get("asynchronous", True):
                kwargs["stream_writer"] = write_bytes_from_coroutine
            else:
                kwargs["stream_writer"] = write_from_stream

        return cls(
            fetch_api=fetch_api_cls.default(**fetch_api_kwargs),
            **kwargs
        )

    @classmethod
    def from_lims(cls, lims_credentials: Optional[DbCredentials] = None,
                  scheme: Optional[str] = None,
                  host: Optional[str] = None,
                  asynchronous: bool = True,
                  manifest: Optional[str] = None,
                  version: Optional[str] = None,
                  cache: bool = True,
                  fetch_tries: int = 2):
        """ Create an instance of EcephysProjectCache with an
        EcephysProjectLimsApi. Retrieves bleeding-edge data stored locally
        on Allen Institute servers. Only available for use on-site at the
        Allen Institute or through a vpn. Requires Allen Institute database
        credentials.

        Parameters
        ==========
        lims_credentials : DbCredentials
            Credentials to access LIMS database. If not provided will
            attempt to find credentials in environment variables.
        scheme : str
            URI scheme, such as "http". Defaults to the
            EcephysProjectLimsApi.default value if unspecified. Will not
            be used unless `host` is also specified.
        host : str
            Web host. Defaults to the EcephysProjectLimsApi.default value
            if unspecified. Will not be used unless `scheme` is also
            specified.
        asynchronous : bool
            Whether to fetch files asynchronously. Defaults to True.
        manifest : str or Path
            full path at which manifest json will be stored
        version : str
            version of manifest file. If this mismatches the version
            recorded in the file at manifest, an error will be raised.
        cache: bool
            Whether to write to the cache (default=True)
        fetch_tries : int
            Maximum number of times to attempt a download before giving up
            and raising an exception. Note that this is total tries, not
            retries.
        """
        if scheme and host:
            app_kwargs = {"scheme": scheme, "host": host}
        else:
            app_kwargs = None
        return cls._from_http_source_default(
            EcephysProjectLimsApi,
            {
                "lims_credentials": lims_credentials,
                "app_kwargs": app_kwargs,
                "asynchronous": asynchronous,
            },  # expects a dictionary of fetch_api kwargs
            manifest=manifest, version=version, cache=cache,
            fetch_tries=fetch_tries)

    @classmethod
    def from_warehouse(cls,
                       scheme: Optional[str] = None,
                       host: Optional[str] = None,
                       asynchronous: bool = True,
                       manifest: Optional[Union[str, Path]] = None,
                       version: Optional[str] = None,
                       cache: bool = True,
                       fetch_tries: int = 2):
        """ Create an instance of EcephysProjectCache with an
        EcephysProjectWarehouseApi. Retrieves released data stored in the
        warehouse.

        Parameters
        ==========
        scheme : str
            URI scheme, such as "http". Defaults to the
            EcephysProjectWarehouseApi.default value if unspecified.
            Will not be used unless `host` is also specified.
        host : str
            Web host. Defaults to the EcephysProjectWarehouseApi.default
            value if unspecified. Will not be used unless `scheme` is also
            specified.
        asynchronous : bool
            Whether to fetch files asynchronously. Defaults to True.
        manifest : str or Path
            full path at which manifest json will be stored
        version : str
            version of manifest file. If this mismatches the version
            recorded in the file at manifest, an error will be raised.
        cache: bool
            Whether to write to the cache (default=True)
        fetch_tries : int
            Maximum number of times to attempt a download before giving up
            and raising an exception. Note that this is total tries, not
            retries.
        """
        if scheme and host:
            app_kwargs = {"scheme": scheme, "host": host,
                          "asynchronous": asynchronous}
        else:
            app_kwargs = None
        return cls._from_http_source_default(
            EcephysProjectWarehouseApi,
            app_kwargs,
            manifest=manifest,
            version=version,
            cache=cache,
            fetch_tries=fetch_tries
        )
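
    # Hypothetical usage sketch: this is the standard entry point for
    # external users. The manifest path is illustrative; pointing it at an
    # existing manifest reuses that local cache.
    #
    #     cache = EcephysProjectCache.from_warehouse(
    #         manifest="ecephys_cache_dir/manifest.json")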

    @classmethod
    def fixed(cls, manifest=None, version=None):
        """ Creates an EcephysProjectCache that refuses to fetch any data
        - only the existing local cache is accessible. Useful if you want
        to settle on a fixed dataset for analysis.

        Parameters
        ==========
        manifest : str or Path
            full path to existing manifest json
        version : str
            version of manifest file. If this mismatches the version
            recorded in the file at manifest, an error will be raised.
        """
        return cls(fetch_api=EcephysProjectFixedApi(),
                   manifest=manifest, version=version)
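
# Hypothetical usage sketch: once a dataset has been downloaded, a fixed
# cache guarantees no further fetches occur. The manifest path is
# illustrative and must already exist:
#
#     cache = EcephysProjectCache.fixed(
#         manifest="ecephys_cache_dir/manifest.json")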

def count_owned(this, other, foreign_key, count_key, inplace=False):
    # count how many rows of `other` reference each row of `this` via
    # `foreign_key`, storing the result on `this` under `count_key`
    if not inplace:
        this = this.copy()

    counts = other.loc[:, foreign_key].value_counts()
    this[count_key] = 0
    this.loc[counts.index.values, count_key] = counts.values

    if not inplace:
        return this

def get_grouped_uniques(this, other, foreign_key, field_key, unique_key,
                        inplace=False):
    # for each row of `this`, collect the unique values of `field_key`
    # among the rows of `other` that reference it via `foreign_key`
    if not inplace:
        this = this.copy()

    uniques = other.groupby(foreign_key)\
        .apply(lambda grp: pd.DataFrame(grp)[field_key].unique())
    this[unique_key] = 0
    this.loc[uniques.index.values, unique_key] = uniques.values

    if not inplace:
        return this
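
# Hypothetical illustration of the two helpers above with toy frames (the
# column names are made up for the example):
#
#     this = pd.DataFrame(index=[1, 2])
#     other = pd.DataFrame({"fk": [1, 1, 2],
#                           "acronym": ["VISp", "VISp", "LGd"]})
#     count_owned(this, other, "fk", "n", inplace=True)
#     # this["n"] is now [2, 1]
#     get_grouped_uniques(this, other, "fk", "acronym", "acronyms",
#                         inplace=True)
#     # this["acronyms"] is now [["VISp"], ["LGd"]]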

def read_csv(path) -> pd.DataFrame:
    return pd.read_csv(path, index_col="id")


def write_csv(path, df):
    df.to_csv(path)


def write_metrics_csv(path, df):
    df.to_csv(path)


def read_metrics_csv(path):
    return pd.read_csv(path, index_col='ecephys_unit_id')


def read_scene(path):
    return sitk.GetArrayFromImage(sitk.ReadImage(path))


def read_movie(path):
    return np.load(path, allow_pickle=False)

def read_nwb(path):
    reader = pynwb.NWBHDF5IO(str(path), 'r')
    nwbfile = reader.read()
    # if the file is corrupt, make sure an exception gets raised during read
    nwbfile.identifier
    return nwbfile
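
# A minimal end-to-end sketch of how the pieces above fit together. The
# manifest path below is an illustrative placeholder, and running this
# will download data from the warehouse.
if __name__ == "__main__":
    cache = EcephysProjectCache.from_warehouse(
        manifest="ecephys_cache_dir/manifest.json")

    # the session table is cached as sessions.csv under the manifest dir
    sessions = cache.get_session_table()
    print(sessions.head())

    # fetch (or read back from the local cache) one session's NWB file
    session = cache.get_session_data(sessions.index.values[0])
    print(session.units.head())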