import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt
import allensdk.internal.core.lims_utilities as lu
from allensdk.internal.core.lims_pipeline_module import PipelineModule, run_module
import argparse, os, logging, shutil
import h5py
import numpy as np
import shutil
import allensdk.brain_observatory.demixer as demixer
from allensdk.config.manifest import Manifest
import allensdk.core.json_utilities as ju
import logging
EXCLUDE_LABELS = ["union", "duplicate", "motion_border" ]
[docs]def debug(experiment_id, local=False):
OUTPUT_DIRECTORY = "/data/informatics/CAM/demix"
SDK_PATH = "/data/informatics/CAM/analysis/allensdk"
SCRIPT = "/data/informatics/CAM/analysis/allensdk/allensdk/internal/pipeline_modules/run_demixing.py"
sd = lu.query("select storage_directory from ophys_experiments where id = %d" % experiment_id)[0]['storage_directory']
rois = lu.query("select * from cell_rois where ophys_experiment_id = %d" % experiment_id)
exc_labels = lu.query("""
select cr.id, rel.name as exclusion_label from cell_rois cr
join cell_rois_roi_exclusion_labels crrel on crrel.cell_roi_id = cr.id
join roi_exclusion_labels rel on crrel.roi_exclusion_label_id = rel.id
where cr.ophys_experiment_id = %d
""" % experiment_id)
nrois = { roi['id']: dict(width=roi['width'],
height=roi['height'],
x=roi['x'],
y=roi['y'],
id=roi['id'],
valid=roi['valid_roi'],
mask=roi['mask_matrix'],
exclusion_labels=[])
for roi in rois }
for exc_label in exc_labels:
nrois[exc_label['id']]['exclusion_labels'].append(exc_label['exclusion_label'])
movie_path_response = lu.query('''
select wkf.filename, wkf.storage_directory from well_known_files wkf
join well_known_file_types wkft on wkft.id = wkf.well_known_file_type_id
where wkf.attachable_id = {} and wkf.attachable_type = 'OphysExperiment'
and wkft.name = 'MotionCorrectedImageStack'
'''.format(experiment_id))
movie_h5_path = os.path.join(movie_path_response[0]['storage_directory'], movie_path_response[0]['filename'])
exp_dir = os.path.join(OUTPUT_DIRECTORY, str(experiment_id))
input_data = {
"movie_h5": movie_h5_path,
"traces_h5": os.path.join(sd, "processed", "roi_traces.h5"),
"roi_masks": nrois.values(),
"output_file": os.path.join(exp_dir, "demixed_traces.h5")
}
run_module(SCRIPT,
input_data,
exp_dir,
sdk_path=SDK_PATH,
pbs=dict(vmem=160,
job_name="demix_%d"% experiment_id,
walltime="36:00:00"),
local=local,
optional_args=['--log-level','DEBUG'])
[docs]def assert_exists(file_name):
if not os.path.exists(file_name):
raise IOError("file does not exist: %s" % file_name)
[docs]def get_path(obj, key, check_exists):
try:
path = obj[key]
except KeyError:
raise KeyError("required input field '%s' does not exist" % key)
if check_exists:
assert_exists(path)
return path
[docs]def main():
mod = PipelineModule()
mod.parser.add_argument("--exclude-labels", nargs="*", default=EXCLUDE_LABELS)
data = mod.input_data()
logging.debug("reading input")
traces, masks, valid, trace_ids, movie_h5, output_h5 = parse_input(data, mod.args.exclude_labels)
logging.debug("excluded masks: %s", str(zip(np.where(~valid)[0], trace_ids[~valid])))
output_dir = os.path.dirname(output_h5)
plot_dir = os.path.join(output_dir, "demix_plots")
if os.path.exists(plot_dir):
shutil.rmtree(plot_dir)
Manifest.safe_mkdir(plot_dir)
logging.debug("reading movie")
with h5py.File(movie_h5, 'r') as f:
movie = f['data'].value
# only demix non-union, non-duplicate ROIs
valid_idxs = np.where(valid)
demix_traces = traces[valid_idxs]
demix_masks = masks[valid_idxs]
logging.debug("demixing")
demixed_traces, drop_frames = demixer.demix_time_dep_masks(demix_traces, movie, demix_masks)
nt_inds = demixer.plot_negative_transients(demix_traces,
demixed_traces,
valid[valid_idxs],
demix_masks,
trace_ids[valid_idxs],
plot_dir)
logging.debug("rois with negative transients: %s", str(trace_ids[valid_idxs][nt_inds]))
nb_inds = demixer.plot_negative_baselines(demix_traces,
demixed_traces,
demix_masks,
trace_ids[valid_idxs],
plot_dir)
# negative baseline rois (and those that overlap with them) become nans
logging.debug("rois with negative baselines (or overlap with them): %s", str(trace_ids[valid_idxs][nb_inds]))
demixed_traces[nb_inds, :] = np.nan
logging.info("Saving output")
out_traces = np.zeros(traces.shape, dtype=demix_traces.dtype)
out_traces[:] = np.nan
out_traces[valid_idxs] = demixed_traces
with h5py.File(output_h5, 'w') as f:
f.create_dataset("data", data=out_traces, compression="gzip")
roi_names = np.array([str(rn) for rn in trace_ids]).astype(np.string_)
f.create_dataset("roi_names", data=roi_names)
mod.write_output_data(dict(
negative_transient_roi_ids=trace_ids[valid_idxs][nt_inds],
negative_baseline_roi_ids=trace_ids[valid_idxs][nb_inds]
))
if __name__ == "__main__": main()