import logging
import operator as op
from collections import defaultdict
from six.moves import reduce
import numpy as np
[docs]class Stitcher(object):
def __init__(self, image_dimensions, tiles, average_tiles, channels):
logging.info('image_dimensions: {0}'.format(image_dimensions))
self.image_dimensions = image_dimensions
self.average_tiles = defaultdict(lambda *a, **k: None, average_tiles)
self.tiles = tiles
self.channels = channels
[docs] def run(self, cb=np.array):
slice_image, stitched_indicator = initialize_images(self.image_dimensions, len(self.channels))
missing_tiles = {}
for tile in self.tiles:
if tile.is_missing:
missing_tiles[tile.index] = tile.get_missing_path()
tile.initialize_image()
else:
tile.apply_average_tile_to_self(self.average_tiles[tile.channel])
tile.trim_self()
self.stitch(slice_image, stitched_indicator, tile, cb)
return slice_image, missing_tiles
[docs] def stitch(self, slice_image, stitched_indicator, tile, cb=np.array):
region = tile.get_image_region()
current_region = slice_image[region]
indicator_region = stitched_indicator[region]
stup = (tile.size['row'], tile.size['column'])
blend = get_blend(indicator_region, stup, cb)
blend = make_blended_tile(blend, tile.image, current_region)
logging.info('obtained blend')
slice_image[region] = blend
stitched_indicator[region] = 1
logging.info('updated image region with tile data')
[docs]def initialize_image(dimensions, nchannels, dtype, order='C'):
return np.zeros((dimensions['row'], dimensions['column'], nchannels), dtype=dtype, order=order)
[docs]def initialize_images(dimensions, nchannels):
return initialize_image(dimensions, nchannels, np.uint16), initialize_image(dimensions, nchannels, np.int8)
[docs]def make_blended_tile(blend, tile, current_region):
return np.multiply((1 - blend), tile) + np.multiply(blend, current_region)
[docs]def get_indicator_bound_point(indicator, lg, axis):
'''Finds the index of first change in a binary mask
along a specified axis in a specified direction
'''
delta = np.diff(indicator, axis=axis)
points = np.where(lg(delta, 0))
del delta
points = np.unique(points[axis])
size = indicator.shape[axis]
points = points[lg(points, size / 2.0)]
if len(points) > 0:
return points[-1]
return None
[docs]def blend_component_from_point(point, mesh, lg):
'''Obtains a normalized component of the blend, which describes depth of
overlap along a specified axis in a specified direction
'''
# this has the effect that the shallowest part of the blend
# is always 0 - symmetric with the deepest after normalization.
blend = point - mesh + 1
blend[lg(blend, 0)] = 0
blend = np.fabs(blend)
mx = np.amax(blend)
mx = mx if mx > 0.0 else 1.0
return blend / mx
[docs]def get_blend_component(indicator, lg, axis, meshes):
'''
'''
point = get_indicator_bound_point(indicator, lg, axis)
if point is None:
return []
return [blend_component_from_point(point, meshes[axis], lg)]
[docs]def get_overall_blend(indicator, meshes):
'''
'''
blends = []
for lg in (op.lt, op.gt):
for axis in (0, 1):
blends.extend(get_blend_component(indicator, lg, axis, meshes))
if len(blends) == 0:
return np.zeros_like(indicator)
return reduce(np.maximum, blends)
[docs]def get_blend(indicator_region, stup, cb=np.array):
'''
'''
meshes = np.meshgrid(*map(np.arange, stup), indexing='ij')
blend = get_overall_blend(indicator_region, meshes)
return cb(np.multiply(blend, indicator_region))