Source code for allensdk.internal.brain_observatory.mask_set

import itertools
import numpy as np
import logging

class MaskSet(object):
    """A collection of masks with bounding-box distances and cached set ops.

    On construction the bounding box of every mask is computed with
    ``make_bbs`` and the pairwise bounding-box distance matrix with
    ``bb_dist``.  Unions, intersections and their pixel counts are memoised
    in per-instance dictionaries keyed by the sorted tuple of unique mask
    indices.

    Parameters
    ----------
    masks : sequence of numpy.ndarray
        Indexable collection of masks.  NOTE(review): the code relies on
        ``.copy()``, ``|=``, ``&=`` and ``.sum()``, so these are presumably
        same-shaped boolean (or integer) arrays -- confirm against callers.

    Notes
    -----
    Arrays returned from the caches are the cached objects themselves, not
    copies; mutating them in place would corrupt later lookups.
    """

    def __init__(self, masks):
        self.masks = masks
        self.bbs = make_bbs(self.masks)
        self.mask_dist = bb_dist(self.bbs)

        # Memoisation tables, all keyed by the tuple from ``_idx_key``.
        # ``cached_sizes`` is not used by any method in this class but is
        # kept because it is a public attribute.
        self.cached_sizes = {}
        self.cached_unions = {}
        self.cached_union_sizes = {}
        self.cached_intersections = {}
        self.cached_intersection_sizes = {}

    @property
    def count(self):
        """Number of masks (one bounding box is stored per mask)."""
        return len(self.bbs)

    def distance(self, mask_idxs):
        """Return the largest pairwise bounding-box distance among masks.

        Raises ``ValueError`` (from ``max``) when fewer than two indices are
        given, because there are then no pairs to compare.
        """
        return max(self.mask_dist[i, j]
                   for (i, j) in itertools.combinations(mask_idxs, 2))

    def close(self, mask_idxs, max_dist):
        """Return True when no pair of the masks is farther than `max_dist`.

        With fewer than two indices there are no pairs, so the result is
        True.
        """
        return not any(self.mask_dist[i, j] > max_dist
                       for (i, j) in itertools.combinations(mask_idxs, 2))

    def close_sets(self, set_size, max_dist):
        """Lazily yield every `set_size`-tuple of mask indices that is close.

        Tuples come from ``itertools.combinations`` over ``range(count)``,
        filtered by :meth:`close` with `max_dist`.
        """
        mask_sets = itertools.combinations(range(len(self.bbs)), set_size)
        return (ms for ms in mask_sets if self.close(ms, max_dist))

    def _idx_key(self, idxs):
        """Canonical cache key: sorted tuple of the unique indices."""
        return tuple(sorted(set(idxs)))

    def mask(self, mask_idx):
        """Return the stored mask at `mask_idx` (not a copy)."""
        return self.masks[mask_idx]

    def union(self, mask_idxs):
        """Return the element-wise OR of the selected masks.

        Returns ``None`` for an empty selection.  A single-index selection
        returns a fresh copy and is not cached; multi-index results are
        cached and the cached array is returned on later calls.
        """
        mask_idxs = self._idx_key(mask_idxs)

        if mask_idxs in self.cached_unions:
            return self.cached_unions[mask_idxs]

        if len(mask_idxs) == 0:
            return None

        i0 = mask_idxs[0]
        union = self.masks[i0].copy()

        if len(mask_idxs) == 1:
            return union

        for idx in mask_idxs[1:]:
            union |= self.masks[idx]

        self.cached_unions[mask_idxs] = union

        return union

    def overlap_fraction(self, idx0, idx1):
        """Return intersection size divided by union size for two masks."""
        union_size = self.union_size([idx0, idx1])
        overlap_size = self.intersection_size([idx0, idx1])

        return float(overlap_size) / float(union_size)

    def detect_duplicates(self, overlap_threshold):
        """Return the set of index pairs whose overlap exceeds a threshold.

        Only pairs whose bounding boxes touch or overlap (distance <= 0) are
        examined.  Each element of the result is a sorted ``(idx0, idx1)``
        tuple whose :meth:`overlap_fraction` is strictly greater than
        `overlap_threshold`.
        """
        duplicate_masks = set()

        for idx0, idx1 in self.close_sets(set_size=2, max_dist=0):
            overlap_frac = self.overlap_fraction(idx0, idx1)

            if overlap_frac > overlap_threshold:
                duplicate_masks.add(tuple(sorted([idx0, idx1])))

        return duplicate_masks

    def mask_is_union_of_set(self, mask_idx, set_idxs, threshold):
        """Decide whether one mask looks like the union of a set of masks.

        Returns True when (a) for every set member, its overlap with the
        candidate mask is at least `threshold` times the member's size, and
        (b) the overlap between the candidate and the union of the set is
        more than `threshold` times the candidate's size.
        """
        # does this mask overlap with each element of the set individually?
        # i.e. overlap of mask and set element covers most of the set element
        for set_mask_idx in set_idxs:
            overlap_size = self.intersection_size([set_mask_idx, mask_idx])
            set_mask_size = self.size(set_mask_idx)

            if overlap_size < threshold * set_mask_size:
                return False

        # does the union of the set elements cover most of this mask?
        set_union = self.union(set_idxs)
        mask = self.mask(mask_idx)
        overlap = set_union & mask
        overlap_size = overlap.sum()

        return overlap_size > self.size(mask_idx) * threshold

    def detect_unions(self, set_size=2, max_dist=10, threshold=0.7):
        """Find masks that appear to be unions of `set_size` other masks.

        Returns a dict mapping the index of each union-like mask to the
        tuple of indices it is a union of.  If a mask matches more than one
        set a warning is logged and the later set replaces the earlier one.
        """
        union_masks = {}

        for set_idxs in self.close_sets(set_size, max_dist):
            for mask_idx in range(self.count):
                if mask_idx in set_idxs:
                    continue

                # the candidate must itself be near every member of the set
                if not self.close([mask_idx] + list(set_idxs), max_dist):
                    continue

                if self.mask_is_union_of_set(mask_idx, set_idxs, threshold):
                    if mask_idx in union_masks:
                        logging.warning("already detected this mask as a union")
                    union_masks[mask_idx] = set_idxs

        return union_masks

    def union_size(self, mask_idxs):
        """Return (and cache) the sum over the union of the selected masks."""
        mask_idxs = self._idx_key(mask_idxs)

        if mask_idxs in self.cached_union_sizes:
            return self.cached_union_sizes[mask_idxs]

        s = self.union(mask_idxs).sum()
        self.cached_union_sizes[mask_idxs] = s

        return s

    def intersection(self, mask_idxs):
        """Return the element-wise AND of the selected masks.

        Returns ``None`` for an empty selection.  When any pair of the
        selected masks has non-overlapping bounding boxes the result is an
        all-zero array with the shape and dtype of the first stored mask, so
        it can be combined with other masks exactly like a computed
        intersection.  Single-index and all-zero results are not cached.
        """
        mask_idxs = self._idx_key(mask_idxs)

        if mask_idxs in self.cached_intersections:
            return self.cached_intersections[mask_idxs]

        if len(mask_idxs) == 0:
            return None

        # don't cache the empty ones
        if not self.close(mask_idxs, 0):
            template = self.masks[0]
            # Match the masks' dtype: a float array here would break
            # subsequent `&` / `|` with boolean masks.
            return np.zeros(template.shape, dtype=template.dtype)

        i0 = mask_idxs[0]
        intersection = self.masks[i0].copy()

        if len(mask_idxs) == 1:
            return intersection

        for idx in mask_idxs[1:]:
            intersection &= self.masks[idx]

        self.cached_intersections[mask_idxs] = intersection

        return intersection

    def intersection_size(self, mask_idxs):
        """Return (and cache) the sum over the intersection of the masks."""
        mask_idxs = self._idx_key(mask_idxs)

        if mask_idxs in self.cached_intersection_sizes:
            return self.cached_intersection_sizes[mask_idxs]

        s = self.intersection(mask_idxs).sum()
        self.cached_intersection_sizes[mask_idxs] = s

        return s

    def size(self, mask_idx):
        """Return the sum of a single mask, via the union-size cache."""
        return self.union_size([mask_idx])
def make_bbs(masks):
    """Compute an inclusive bounding box for every mask.

    Each box is ``[[min0, max0], [min1, max1]]``: the smallest and largest
    index of the truthy entries along the first and second axes, as reported
    by ``np.where``.  Boxes are returned in the same order as `masks`.

    A mask with no truthy entries makes ``min()`` raise ``ValueError``.
    """
    # Masks are fetched by position, exactly as the container is indexed
    # elsewhere, rather than by iterating the container.
    nonzero_per_mask = (np.where(masks[k]) for k in range(len(masks)))

    return [
        [[hits[axis].min(), hits[axis].max()] for axis in (0, 1)]
        for hits in nonzero_per_mask
    ]
def bb_dist(bbs):
    """Build the symmetric matrix of pairwise bounding-box distances.

    `bbs` is a sequence of boxes shaped ``[[min0, max0], [min1, max1]]``.
    For every pair the separation is measured along each of the two axes and
    the larger value is stored at both ``[i, j]`` and ``[j, i]``.  The
    diagonal stays zero.  A positive entry is the gap between the boxes; a
    zero or negative entry means the boxes touch or overlap on both axes.
    """

    def _axis_gap(first, second):
        # `first` / `second` are [min, max] extents along a single axis.
        if first[0] < second[1]:
            return second[0] - first[1]
        return first[0] - second[1]

    n_boxes = len(bbs)
    dist = np.zeros((n_boxes, n_boxes))

    for i, j in itertools.combinations(range(n_boxes), 2):
        box_i, box_j = bbs[i], bbs[j]
        gap0 = _axis_gap(box_i[0], box_j[0])
        gap1 = _axis_gap(box_i[1], box_j[1])
        dist[i, j] = dist[j, i] = max(gap0, gap1)

    return dist