Source code for allensdk.internal.mouse_connectivity.interval_unionize.interval_unionizer
from __future__ import division
import logging
import functools
from collections import defaultdict
import copy as cp
import numpy as np
from six import iteritems
[docs]class IntervalUnionizer(object):
[docs] @classmethod
def record_cb(cls):
return defaultdict(lambda *a, **k: 0, {})
def __init__(self, exclude_structure_ids=None):
'''Builds unionize records from grid data. Unionize records are
summaries of experimental observations occuring in a particular
spatial domain. Domains are generally specified by the intersection of
1. a brain structure
2. an injection polygon (or its inverse)
3. the left or right side of the brain
Parameters
----------
exclude_structure_ids : list of int, optional
Don't generate records for these structures. Defaults to [0],
which excludes everything not in the brain.
'''
if exclude_structure_ids is None:
exclude_structure_ids = [0]
self.exclude_structure_ids = exclude_structure_ids
[docs] def setup_interval_map(self, annotation):
'''Build a map from structure ids to intervals in the sorted flattened
reference space.
Parameters
----------
annotation : np.ndarray
Segmentation label array.
'''
logging.info('getting flat annotation')
flat_annot = annotation.flat
logging.info('finding sort')
self.sort = np.argsort(flat_annot)
logging.info('sorting flat annotation')
flat_annot = flat_annot[self.sort]
logging.info('finding bounds')
diff = np.diff(flat_annot)
bounds = np.nonzero(diff)[0]
uniques = [ flat_annot[ii] for ii in bounds ] + [flat_annot[-1]]
logging.info('building map')
lower_bounds = [0] + (bounds + 1).tolist()
upper_bounds = (bounds + 1).tolist() + [len(flat_annot)]
self.interval_map = {sid: item for sid, item
in zip(uniques, zip(lower_bounds, upper_bounds))
if sid not in self.exclude_structure_ids}
[docs] @classmethod
def propagate_record(cls, child_record, ancestor_record, copy_all=False):
'''Updates one unionize corresponding to a rootward structure with
information from a unionize corresponding to a leafward structure
Parameters
----------
child_record : unionize
Data will be drawn from this record
ancestor_record : unionize
This record will be updated
'''
raise NotImplementedError('specify in subclass!')
[docs] @classmethod
def propagate_unionizes(cls, direct_unionizes, ancestor_id_map):
'''Structures are arranged in a tree, whose leafward-oriented edges
indicate physical containment. This method updates rootward unionize
records with information from leafward ones.
Parameters
----------
direct_unionizes : list of unionizes
Each entry is a unionize record produced from a collection of
directly labeled voxels in the segmentation volume.
ancestor_id_map : dict
Keys are structure ids. Values are ids of all structures rootward in
the tree, including the key node
Returns
-------
output_unionizes : list of unionizes
Contains completed unionize records at all depths in the structure
tree
'''
output_unionizes = defaultdict(cls.record_cb, cp.deepcopy(direct_unionizes))
for k, v in iteritems(direct_unionizes):
for aid in ancestor_id_map[k]:
if k == aid:
continue
logging.debug('propagating data from {0} to {1}'.format(k, aid))
output_unionizes[aid] = cls.propagate_record(v, output_unionizes[aid])
return output_unionizes
[docs] @classmethod
def propagate_to_bilateral(cls, lateral_unionizes):
bilateral = defaultdict(cls.record_cb, {})
for sid in list(lateral_unionizes.keys()):
unionize = lateral_unionizes[sid]
other_id = -1 * sid
if (sid in bilateral) or (other_id in bilateral):
continue
logging.debug('bilateralizing structure {0}'.format(sid))
other = lateral_unionizes[other_id]
bilateral[sid] = cls.propagate_record(unionize, bilateral[sid], True)
bilateral[sid] = cls.propagate_record(other, bilateral[sid], True)
return bilateral
[docs] def postprocess_unionizes(self, raw_unionizes, **kwargs):
'''Carry out additional calculations/formatting derivative of core
unionization.
Parameters
----------
raw_unionizes : list of unionizes
Each entry is a unionize record.
'''
raise NotImplementedError('specify in subclass!')
[docs] def sort_data_arrays(self, data_arrays):
'''Apply the precomputed sort to flattened data arrays
Parameters
----------
data_arrays : dict
Keys identify types of data volume. Values are flattened, unsorted
arrays.
Returns
-------
dict :
As input, but values are sorted
'''
logging.info('sorting data arrays')
return {k: v[self.sort] for k, v in iteritems(data_arrays)}
[docs] def direct_unionize(self, data_arrays, pre_sorted=False, **kwargs):
'''Obtain unionize records from directly annotated regions.
Parameters
----------
data_arrays : dict
Keys identify types of data volume. Values are flattened arrays.
sorted : bool, optional
If False, data arrays will be sorted.
'''
if not pre_sorted:
data_arrays = self.sort_data_arrays(data_arrays)
unionizes = {}
for sid, (low, high) in iteritems(self.interval_map):
logging.debug( 'unionizing structure {0} :: voxel_count={1}'.format(sid, high - low) )
unionizes[sid] = self.extract_data(data_arrays, low, high, **kwargs)
return unionizes