import datetime
from pynwb import NWBFile, NWBHDF5IO
import pandas as pd
import allensdk.brain_observatory.nwb as nwb
import numpy as np
import SimpleITK as sitk
import pytz
import uuid
from pandas.util.testing import assert_frame_equal
import os
import math
import numpy as np
import xarray as xr
import pandas as pd
from allensdk.core.lazy_property import LazyProperty
from allensdk.brain_observatory.nwb.nwb_api import NwbApi
from allensdk.brain_observatory.behavior.trials_processing import TRIAL_COLUMN_DESCRIPTION_DICT
from allensdk.brain_observatory.behavior.schemas import OphysBehaviorMetaDataSchema, OphysBehaviorTaskParametersSchema
from allensdk.brain_observatory.nwb.metadata import load_LabMetaData_extension
from allensdk.brain_observatory.behavior.behavior_ophys_api import BehaviorOphysApiBase
# Module import side effect: load the LabMetaData NWB extension for the ophys
# behavior metadata schema and the task-parameters schema, both under the
# 'AIBS_ophys_behavior' name. BehaviorOphysNwbApi.get_task_parameters later
# reads nwbfile.lab_meta_data['task_parameters'] through
# OphysBehaviorTaskParametersSchema.
# NOTE(review): presumably this must run before any NWB file containing these
# lab-metadata types is written or read — confirm in
# allensdk.brain_observatory.nwb.metadata.
load_LabMetaData_extension(OphysBehaviorMetaDataSchema, 'AIBS_ophys_behavior')
load_LabMetaData_extension(OphysBehaviorTaskParametersSchema, 'AIBS_ophys_behavior')
class BehaviorOphysNwbApi(NwbApi, BehaviorOphysApiBase):
    """NWB-backed API for a behavior + ophys session.

    ``save`` writes a session object's data into a new NWB file at
    ``self.path``. The ``get_*`` methods read the same data back from
    ``self.nwbfile``.

    NOTE(review): ``self.path`` and ``self.nwbfile`` are not set in this class;
    presumably they are provided by ``NwbApi`` — confirm.
    """

    def __init__(self, *args, **kwargs):
        """Create the API.

        Parameters
        ----------
        filter_invalid_rois : bool, optional (keyword, default False)
            If True, ``get_cell_specimen_table`` keeps only rows whose
            ``valid_roi`` column is True. The dff and corrected-fluorescence
            getters join against that table, so they are filtered as well.
        *args, **kwargs
            Forwarded unchanged to the base-class constructors.
        """
        # Pop our own option first so it is not forwarded to the base classes.
        self.filter_invalid_rois = kwargs.pop("filter_invalid_rois", False)
        super().__init__(*args, **kwargs)

    def save(self, session_object):
        """Build an NWB file from ``session_object`` and write it to ``self.path``.

        Parameters
        ----------
        session_object
            Session exposing the attributes read below: ``metadata``,
            ``ophys_experiment_id``, ``stimulus_timestamps``,
            ``running_data_df``, ``stimulus_templates``,
            ``stimulus_presentations``, ``trials``, ``licks``, ``rewards``,
            imaging projections, ``task_parameters``, ``cell_specimen_table``,
            trace tables, ``ophys_timestamps`` and ``motion_correction``.

        Returns
        -------
        NWBFile
            The in-memory NWB file that was written (opened with mode 'w').
        """
        nwbfile = NWBFile(
            session_description=str(session_object.metadata['session_type']),
            identifier=str(session_object.ophys_experiment_id),
            session_start_time=session_object.metadata['experiment_datetime'],
            # Take "now" directly in UTC. Localizing the naive local-time
            # datetime.now() as UTC would mislabel the creation time on any
            # machine whose clock is not set to UTC.
            file_create_date=datetime.datetime.now(pytz.utc),
        )

        # Stimulus timestamps.
        nwb.add_stimulus_timestamps(nwbfile, session_object.stimulus_timestamps)

        # Running data, with the unit recorded for each column.
        unit_dict = {'v_sig': 'V', 'v_in': 'V', 'speed': 'cm/s', 'timestamps': 's', 'dx': 'cm'}
        nwb.add_running_data_df_to_nwbfile(nwbfile, session_object.running_data_df, unit_dict)

        # Stimulus templates. Each template also gets an index made of the
        # presentations whose 'image_set' matches the template's name.
        for name, image_data in session_object.stimulus_templates.items():
            nwb.add_stimulus_template(nwbfile, image_data, name)
            nwb_template = nwbfile.stimulus_template[name]
            presentations = session_object.stimulus_presentations
            stimulus_index = presentations[presentations['image_set'] == nwb_template.name]
            nwb.add_stimulus_index(nwbfile, stimulus_index, nwb_template)

        # Stimulus presentations and trials.
        nwb.add_stimulus_presentations(nwbfile, session_object.stimulus_presentations)
        nwb.add_trials(nwbfile, session_object.trials, TRIAL_COLUMN_DESCRIPTION_DICT)

        # Licks and rewards are only written when non-empty.
        if len(session_object.licks) > 0:
            nwb.add_licks(nwbfile, session_object.licks)
        if len(session_object.rewards) > 0:
            nwb.add_rewards(nwbfile, session_object.rewards)

        # Summary images.
        nwb.add_max_projection(nwbfile, session_object.max_projection)
        nwb.add_average_image(nwbfile, session_object.average_projection)
        nwb.add_segmentation_mask_image(nwbfile, session_object.segmentation_mask_image)

        # Lab metadata and task parameters.
        nwb.add_metadata(nwbfile, session_object.metadata)
        nwb.add_task_parameters(nwbfile, session_object.task_parameters)

        # ROI table, traces and motion correction.
        nwb.add_cell_specimen_table(nwbfile, session_object.cell_specimen_table)
        nwb.add_dff_traces(nwbfile, session_object.dff_traces, session_object.ophys_timestamps)
        nwb.add_corrected_fluorescence_traces(nwbfile, session_object.corrected_fluorescence_traces)
        nwb.add_motion_correction(nwbfile, session_object.motion_correction)

        with NWBHDF5IO(self.path, 'w') as nwb_file_writer:
            nwb_file_writer.write(nwbfile)

        return nwbfile

    def get_running_data_df(self, **kwargs):
        """Return running data as a DataFrame indexed by ``timestamps``.

        ``speed`` always comes from ``get_running_speed()``. ``v_in`` and
        ``v_sig`` are added when present in the file's acquisition, and ``dx``
        when present in the ``running`` module.

        Columns are returned in the order ``speed, dx, v_sig, v_in``. Optional
        columns that are absent from the file are omitted instead of raising
        ``KeyError``.
        """
        running_speed = self.get_running_speed()
        running_data_df = pd.DataFrame(
            {'speed': running_speed.values},
            index=pd.Index(running_speed.timestamps, name='timestamps'),
        )

        for key in ['v_in', 'v_sig']:
            if key in self.nwbfile.acquisition:
                running_data_df[key] = self.nwbfile.get_acquisition(key).data

        for key in ['dx']:
            if ('running' in self.nwbfile.modules) and (key in self.nwbfile.modules['running'].fields['data_interfaces']):
                running_data_df[key] = self.nwbfile.modules['running'].get_data_interface(key).data

        column_order = ['speed', 'dx', 'v_sig', 'v_in']
        return running_data_df[[col for col in column_order if col in running_data_df.columns]]

    def get_stimulus_templates(self, **kwargs):
        """Return ``{template name: template data array}`` for every stimulus template."""
        return {key: val.data[:] for key, val in self.nwbfile.stimulus_template.items()}

    def get_ophys_timestamps(self) -> np.ndarray:
        """Return the timestamps of the dff ``traces`` series in ``two_photon_imaging``."""
        return self.nwbfile.modules['two_photon_imaging'].get_data_interface('dff').roi_response_series['traces'].timestamps[:]

    def get_stimulus_timestamps(self) -> np.ndarray:
        """Return the timestamps stored in the ``stimulus`` module's ``timestamps`` interface."""
        return self.nwbfile.modules['stimulus'].get_data_interface('timestamps').timestamps[:]

    def get_trials(self) -> pd.DataFrame:
        """Return the trials table indexed by ``trials_id``, without ``lick_events``."""
        trials = self.nwbfile.trials.to_dataframe()
        if 'lick_events' in trials.columns:
            trials.drop('lick_events', inplace=True, axis=1)
        trials.index = trials.index.rename('trials_id')
        return trials

    def get_licks(self) -> pd.DataFrame:
        """Return lick times as a one-column (``time``) DataFrame.

        The frame is empty when the file has no ``licking`` module.
        """
        if 'licking' in self.nwbfile.modules:
            licks = self.nwbfile.modules['licking'].get_data_interface('licks')
            return pd.DataFrame({'time': licks['timestamps'].timestamps[:]})
        return pd.DataFrame({'time': []})

    def get_rewards(self) -> pd.DataFrame:
        """Return rewards indexed by ``timestamps``.

        The frame has ``volume`` and ``autorewarded`` columns. The timestamps
        come from the ``autorewarded`` series. The frame is empty when the
        file has no ``rewards`` module.
        """
        if 'rewards' in self.nwbfile.modules:
            rewards_module = self.nwbfile.modules['rewards']
            autorewarded_series = rewards_module.get_data_interface('autorewarded')
            time = autorewarded_series.timestamps[:]
            autorewarded = autorewarded_series.data[:]
            volume = rewards_module.get_data_interface('volume').data[:]
            return pd.DataFrame({'volume': volume, 'timestamps': time, 'autorewarded': autorewarded}).set_index('timestamps')
        return pd.DataFrame({'volume': [], 'timestamps': [], 'autorewarded': []}).set_index('timestamps')

    def get_max_projection(self, image_api=None) -> sitk.Image:
        """Return the ``max_projection`` image from ``two_photon_imaging``."""
        return self.get_image('max_projection', 'two_photon_imaging', image_api=image_api)

    def get_average_projection(self, image_api=None) -> sitk.Image:
        """Return the ``average_image`` image from ``two_photon_imaging``."""
        return self.get_image('average_image', 'two_photon_imaging', image_api=image_api)

    def get_segmentation_mask_image(self, image_api=None) -> sitk.Image:
        """Return the ``segmentation_mask_image`` image from ``two_photon_imaging``."""
        return self.get_image('segmentation_mask_image', 'two_photon_imaging', image_api=image_api)

    def get_task_parameters(self) -> dict:
        """Return the ``task_parameters`` lab metadata, serialized as a dict by its schema."""
        metadata_nwb_obj = self.nwbfile.lab_meta_data['task_parameters']
        data = OphysBehaviorTaskParametersSchema().dump(metadata_nwb_obj)
        return data

    def get_cell_specimen_table(self) -> pd.DataFrame:
        """Return the ROI table indexed by ``cell_specimen_id``.

        The table's original index becomes a ``cell_roi_id`` column.
        ``cell_specimen_id`` values of -1 are mapped to None, and each
        ``image_mask`` is cast to a bool array. If ``self.filter_invalid_rois``
        is set, only rows whose ``valid_roi`` is True are kept.
        """
        df = (self.nwbfile.modules['two_photon_imaging']
              .data_interfaces['image_segmentation']
              .plane_segmentations['cell_specimen_table']
              .to_dataframe())
        df.index.rename('cell_roi_id', inplace=True)
        # NOTE(review): -1 presumably marks "no cell specimen id" in the file — confirm.
        df['cell_specimen_id'] = [None if csid == -1 else csid for csid in df['cell_specimen_id'].values]
        df['image_mask'] = [mask.astype(bool) for mask in df['image_mask'].values]
        df.reset_index(inplace=True)
        df.set_index('cell_specimen_id', inplace=True)
        if self.filter_invalid_rois:
            df = df[df["valid_roi"]]
        return df

    def get_dff_traces(self) -> pd.DataFrame:
        """Return dF/F traces aligned to the cell specimen table.

        The result is indexed like ``get_cell_specimen_table()`` and has
        columns ``cell_roi_id`` and ``dff`` (one trace array per ROI).

        Raises
        ------
        AssertionError
            If the number of trace frames differs from the number of ophys
            timestamps. An assert is kept here because ``equals`` relies on
            AssertionError to signal mismatches.
        """
        dff_nwb = self.nwbfile.modules['two_photon_imaging'].data_interfaces['dff'].roi_response_series['traces']
        dff_traces = dff_nwb.data[:]
        # Rows are ROIs and columns are frames; every frame needs a timestamp.
        _, number_of_dff_frames = dff_traces.shape
        num_of_timestamps = len(self.get_ophys_timestamps())
        assert num_of_timestamps == number_of_dff_frames, (
            f"dff traces have {number_of_dff_frames} frames but there are "
            f"{num_of_timestamps} ophys timestamps"
        )
        df = pd.DataFrame(
            {'dff': [x for x in dff_traces]},
            index=pd.Index(data=dff_nwb.rois.table.id[:], name='cell_roi_id'),
        )
        cell_specimen_table = self.get_cell_specimen_table()
        return cell_specimen_table[['cell_roi_id']].join(df, on='cell_roi_id')

    def get_corrected_fluorescence_traces(self) -> pd.DataFrame:
        """Return corrected fluorescence traces aligned to the cell specimen table.

        The result is indexed like ``get_cell_specimen_table()`` and has
        columns ``cell_roi_id`` and ``corrected_fluorescence`` (one trace array
        per ROI).
        """
        corrected_fluorescence_nwb = self.nwbfile.modules['two_photon_imaging'].data_interfaces['corrected_fluorescence'].roi_response_series['traces']
        df = pd.DataFrame(
            {'corrected_fluorescence': [x for x in corrected_fluorescence_nwb.data[:]]},
            index=pd.Index(data=corrected_fluorescence_nwb.rois.table.id[:], name='cell_roi_id'),
        )
        cell_specimen_table = self.get_cell_specimen_table()
        return cell_specimen_table[['cell_roi_id']].join(df, on='cell_roi_id')

    def get_motion_correction(self) -> pd.DataFrame:
        """Return the ``x`` and ``y`` series of the ``motion_correction`` module as columns."""
        motion_module = self.nwbfile.modules['motion_correction']
        return pd.DataFrame({axis: motion_module.get_data_interface(axis).data[:] for axis in ('x', 'y')})
def equals(A, B, reraise=False):
    """Return True if every lazily loaded field of ``A`` equals that of ``B``.

    The fields compared are the union of attribute names whose value in
    ``A.__dict__`` or ``B.__dict__`` is a ``LazyProperty``. They are visited in
    sorted order and compared with ``compare_fields``.

    If reading or comparing a field raises ``NotImplementedError``, that field
    is skipped, provided neither ``A.api`` nor ``B.api`` has the method named
    by the class attribute's ``getter_name``. Otherwise the objects are treated
    as unequal. Comparison then continues with the remaining fields.

    Parameters
    ----------
    A, B
        Objects to compare.
    reraise : bool
        If True, re-raise the AssertionError/AttributeError that signalled a
        mismatch instead of returning False.

    Returns
    -------
    bool
    """
    field_set = {
        key
        for obj in (A, B)
        for key, val in obj.__dict__.items()
        if isinstance(val, LazyProperty)
    }
    try:
        for field in sorted(field_set):
            # Handle NotImplementedError per field, so one unimplemented field
            # neither ends the comparison early nor skips later fields. This
            # inner handler sits inside the outer try, so its AssertionError
            # goes through the reraise/return-False path below.
            try:
                x1, x2 = getattr(A, field), getattr(B, field)
                err_msg = f"{field} on {A} did not equal {field} on {B} (\n{x1} vs\n{x2}\n)"
                compare_fields(x1, x2, err_msg)
            except NotImplementedError:
                A_implements_get_field = hasattr(A.api, getattr(type(A), field).getter_name)
                B_implements_get_field = hasattr(B.api, getattr(type(B), field).getter_name)
                assert not A_implements_get_field and not B_implements_get_field, (
                    f"{field} is implemented on only one of {A} and {B}"
                )
    except (AssertionError, AttributeError):
        if reraise:
            raise
        return False
    return True
def _is_nan(value):
    """Return True if *value* is a NaN number; False for NaN-free or non-numeric values."""
    try:
        return math.isnan(value)
    except TypeError:
        # None, strings, multi-element arrays, etc. cannot be NaN scalars.
        return False


def _compare_dict_fields(d1, d2, err_msg=""):
    """Assert two dicts have the same keys and equal values.

    ndarray values are compared almost-equal. Float values treat NaN as equal
    to NaN. Everything else is compared with ``==``.
    """
    for key in set(d1.keys()).union(d2.keys()):
        key_err_msg = f"mismatch when checking key {key}. {err_msg}"
        # A key on only one side is a mismatch. Report it as AssertionError
        # rather than KeyError, because that is what ``equals`` treats as
        # "not equal".
        assert key in d1 and key in d2, key_err_msg
        v1, v2 = d1[key], d2[key]
        if isinstance(v1, np.ndarray):
            np.testing.assert_array_almost_equal(v1, v2, err_msg=key_err_msg)
        elif isinstance(v1, float):
            v1_nan, v2_nan = _is_nan(v1), _is_nan(v2)
            if v1_nan or v2_nan:
                assert v1_nan and v2_nan, key_err_msg
            else:
                assert v1 == v2, key_err_msg
        else:
            assert v1 == v2, key_err_msg


def compare_fields(x1, x2, err_msg=""):
    """Assert that two field values are equal, using a check suited to ``x1``'s type.

    * ``pd.DataFrame``: ``assert_frame_equal`` with ``check_like=True``
      (row/column order ignored). ``err_msg`` is printed before re-raising.
    * ``np.ndarray``: ``np.testing.assert_array_almost_equal``.
    * ``dict``: keys must match. Values are compared per key (arrays almost
      equal, floats with NaN == NaN, everything else with ``==``).
    * ``list``: ``==``.
    * ``xr.DataArray``: ``xr.testing.assert_allclose``.
    * ``sitk.Image``: same size, then ``==``.
    * anything else: ``==``.

    Raises
    ------
    AssertionError
        If the values differ.
    """
    if isinstance(x1, pd.DataFrame):
        try:
            assert_frame_equal(x1, x2, check_like=True)
        except Exception:
            print(err_msg)
            raise
    elif isinstance(x1, np.ndarray):
        np.testing.assert_array_almost_equal(x1, x2, err_msg=err_msg)
    elif isinstance(x1, dict):
        _compare_dict_fields(x1, x2, err_msg)
    elif isinstance(x1, list):
        assert x1 == x2, err_msg
    elif isinstance(x1, xr.DataArray):
        xr.testing.assert_allclose(x1, x2)
    elif isinstance(x1, sitk.Image):
        assert x1.GetSize() == x2.GetSize(), err_msg
        assert x1 == x2, err_msg
    else:
        assert x1 == x2, err_msg