import itertools
import numpy as np
import logging
[docs]class MaskSet( object ):
def __init__(self, masks):
self.masks = masks
self.bbs = make_bbs(self.masks)
self.mask_dist = bb_dist(self.bbs)
self.cached_sizes = {}
self.cached_unions = {}
self.cached_union_sizes = {}
self.cached_intersections = {}
self.cached_intersection_sizes = {}
@property
def count(self):
return len(self.bbs)
[docs] def distance(self, mask_idxs):
return max(self.mask_dist[i,j] for (i,j) in itertools.combinations(mask_idxs, 2))
[docs] def close(self, mask_idxs, max_dist):
return not any(self.mask_dist[i,j] > max_dist for (i,j) in itertools.combinations(mask_idxs, 2))
[docs] def close_sets(self, set_size, max_dist):
mask_sets = itertools.combinations(range(len(self.bbs)), set_size)
return (ms for ms in mask_sets if self.close(ms, max_dist))
def _idx_key(self, idxs):
return tuple(sorted(set(idxs)))
[docs] def mask(self, mask_idx):
return self.masks[mask_idx]
[docs] def union(self, mask_idxs):
mask_idxs = self._idx_key(mask_idxs)
if mask_idxs in self.cached_unions:
return self.cached_unions[mask_idxs]
if len(mask_idxs) == 0:
return None
i0 = mask_idxs[0]
union = self.masks[i0].copy()
if len(mask_idxs) == 1:
return union
for idx in mask_idxs[1:]:
union |= self.masks[idx]
self.cached_unions[mask_idxs] = union
return union
[docs] def overlap_fraction(self, idx0, idx1):
union_size = self.union_size([idx0,idx1])
overlap_size = self.intersection_size([idx0,idx1])
return float(overlap_size) / float(union_size)
[docs] def detect_duplicates(self, overlap_threshold):
duplicate_masks = set()
for idx0,idx1 in self.close_sets(set_size=2, max_dist=0):
overlap_frac = self.overlap_fraction(idx0, idx1)
if overlap_frac > overlap_threshold:
duplicate_masks.add(tuple(sorted([idx0,idx1])))
return duplicate_masks
[docs] def mask_is_union_of_set(self, mask_idx, set_idxs, threshold):
# does this mask overlap with each element of the set individually?
# i.e. overlap of mask and set element covers most of the set element
for set_mask_idx in set_idxs:
overlap_size = self.intersection_size([set_mask_idx, mask_idx])
set_mask_size = self.size(set_mask_idx)
if overlap_size < threshold * set_mask_size:
return False
# does this mask cover more than the union of the individual set elements?
set_union = self.union(set_idxs)
mask = self.mask(mask_idx)
overlap = set_union & mask
overlap_size = overlap.sum()
return overlap_size > self.size(mask_idx) * threshold
[docs] def detect_unions(self, set_size=2, max_dist=10, threshold=0.7):
union_masks = {}
mask_combos = list(self.close_sets(set_size, max_dist))
for i, set_idxs in enumerate(mask_combos):
for mask_idx in range(self.count):
if mask_idx in set_idxs:
continue
elif not self.close([mask_idx] + list(set_idxs), max_dist):
continue
if self.mask_is_union_of_set(mask_idx, set_idxs, threshold):
if mask_idx in union_masks:
logging.warning("already detected this mask as a union")
union_masks[mask_idx] = set_idxs
return union_masks
[docs] def union_size(self, mask_idxs):
mask_idxs = self._idx_key(mask_idxs)
if mask_idxs in self.cached_union_sizes:
return self.cached_union_sizes[mask_idxs]
s = self.union(mask_idxs).sum()
self.cached_union_sizes[mask_idxs] = s
return s
[docs] def intersection(self, mask_idxs):
mask_idxs = self._idx_key(mask_idxs)
if mask_idxs in self.cached_intersections:
return self.cached_intersections[mask_idxs]
if len(mask_idxs) == 0:
return None
# don't cache the empty ones
if not self.close(mask_idxs, 0):
return np.zeros(self.masks[0].shape)
i0 = mask_idxs[0]
intersection = self.masks[i0].copy()
if len(mask_idxs) == 1:
return intersection
for idx in mask_idxs[1:]:
intersection &= self.masks[idx]
self.cached_intersections[mask_idxs] = intersection
return intersection
[docs] def intersection_size(self, mask_idxs):
mask_idxs = self._idx_key(mask_idxs)
if mask_idxs in self.cached_intersection_sizes:
return self.cached_intersection_sizes[mask_idxs]
s = self.intersection(mask_idxs).sum()
self.cached_intersection_sizes[mask_idxs] = s
return s
[docs] def size(self, mask_idx):
return self.union_size([mask_idx])
[docs]def make_bbs(masks):
bbs = []
for i in range(len(masks)):
m = np.where(masks[i])
bbs.append([[m[0].min(), m[0].max()],[m[1].min(), m[1].max()]])
return bbs
[docs]def bb_dist(bbs):
num_bbs = len(bbs)
dist = np.zeros((num_bbs, num_bbs))
for i,j in itertools.combinations(range(num_bbs), 2):
bbi = bbs[i]
bbj = bbs[j]
if bbi[0][0] < bbj[0][1]:
distx = bbj[0][0] - bbi[0][1]
else:
distx = bbi[0][0] - bbj[0][1]
if bbi[1][0] < bbj[1][1]:
disty = bbj[1][0] - bbi[1][1]
else:
disty = bbi[1][0] - bbj[1][1]
dist[i,j] = max(distx,disty)
dist[j,i] = dist[i,j]
return dist