Source code for allensdk.brain_observatory.circle_plots

# Allen Institute Software License - This software license is the 2-clause BSD
# license plus a third clause that prohibits redistribution for commercial
# purposes without further permission.
#
# Copyright 2017. Allen Institute. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Redistributions for commercial purposes are not permitted without the
# Allen Institute's written permission.
# For purposes of this license, commercial purposes is the incorporation of the
# Allen Institute's software into anything for which you will charge fees or
# other compensation. Contact terms@alleninstitute.org for commercial licensing
# opportunities.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#

import math

try:
    xrange
except:
    from past.builtins import xrange

import numpy as np
import pandas as pd
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.collections import PatchCollection, LineCollection
import matplotlib.transforms as mxfms
import matplotlib.colors as mcolors
import skimage.transform
from six import iteritems


DEFAULT_COLOR_MAP = LinearSegmentedColormap.from_list('default', [[.7,0,.7,0.0],[.7,0,0,1]])
DEFAULT_MEAN_RESP_COLOR_MAP = LinearSegmentedColormap.from_list('default', [[0.0,0.0,0.5,0.0],[0.0,0.0,0.5,1]])
DEFAULT_AXIS_COLOR = (0.8, 0.8, 0.8)
DEFAULT_LABEL_COLOR = (0.8, 0.8, 0.8)
LSN_ON_COLOR_MAP = LinearSegmentedColormap.from_list('default', [[.7,0,.7,0.0],[.7,0,0,1]])
LSN_OFF_COLOR_MAP = LinearSegmentedColormap.from_list('default', [[0.0,0.7,.7,0.0],[0,0,0.7,1]])
HEX_POSITIONS = []


[docs]def polar_to_xy(angles, radius): """ Convert an array of angles (in radians) and a radius in polar coordinates to an array of x,y coordinates. """ x = radius*np.cos(angles) y = radius*np.sin(angles) return np.array([x,y]).T
[docs]def polar_linspace(radius, start_angle, stop_angle, num, endpoint=False, degrees=True): """ Evenly distributed list of x,y coordinates from an input range of angles and a radius in polar coordinates. """ angles = np.linspace(start_angle, stop_angle, num=num, endpoint=endpoint) if degrees is True: angles *= np.pi / 180.0 return polar_to_xy(angles, radius)
[docs]def spiral_trials(radii, x=0.0, y=0.0): radii = np.array(radii) circles = [] if radii.size > 0: spiral = hex_pack(radii[0], len(radii)) for i,radius in enumerate(radii): circles.append(mpatches.Circle((spiral[i][0], spiral[i][1]), radii[i])) pos_xfm = mxfms.Affine2D().translate(x,y) collection = PatchCollection(circles) collection.set_transform(pos_xfm) return collection
[docs]def spiral_trials_polar(r, theta, radii, offset=None): if offset is None: offset = [0,0] collection = spiral_trials(radii, r + offset[0], offset[1]) rot_xfm = mxfms.Affine2D().rotate(theta) collection.set_transform(collection.get_transform() + rot_xfm) return collection
[docs]def angle_lines(angles, inner_radius, outer_radius): inner_pos = polar_to_xy(angles, inner_radius) outer_pos = polar_to_xy(angles, outer_radius) segments = np.array(list(zip(inner_pos, outer_pos))) return LineCollection(segments)
[docs]def radial_arcs(rs, start_theta, end_theta): arcs = [] for r in rs: arcs.append(mpatches.Arc((0,0), 2*r, 2*r, theta1=start_theta*180.0/np.pi, theta2=end_theta*180.0/np.pi)) return PatchCollection(arcs)
[docs]def rings_in_hex_pack(ct): return np.ceil((-3.0 + np.sqrt(9.0 - 12.0*(1.0 - ct))) / 6.0 + 1.0)
[docs]def radial_circles(rs): circles = [ mpatches.Circle((0,0), r) for r in rs ] return PatchCollection(circles)
[docs]def reset_hex_pack(): global HEX_POSITIONS HEX_POSITIONS = []
[docs]def hex_pack(radius, n): global HEX_POSITIONS if len(HEX_POSITIONS) < n: HEX_POSITIONS = build_hex_pack(n) return HEX_POSITIONS[:n]*radius*2.0
[docs]def build_hex_pack(n): pos = [] sq32 = math.sqrt(3.0) / 2.0 N = 1 vs = [ [-0.5, -sq32], [-1.0, 0.0], [-0.5, sq32], [0.5, sq32], [1, 0], [0.5, -sq32] ] pos.append([0,0]) while len(pos) < n: layer_pos = [ ] for i,v in enumerate(vs): x = - N * v[1] * sq32 y = N * v[0] * sq32 if N % 2 == 1: x -= 0.5 * v[0] y -= 0.5 * v[1] layer_pos.append([]) layer_pos[i].append([x,y]) mag = 1 sign = 1 for j in xrange(N-1): x += v[0] * mag * sign y += v[1] * mag * sign mag += 1 sign = -sign layer_pos[i].append([x,y]) for j in range(N): for i in range(len(vs)): if j < len(layer_pos[i]): pos.append(layer_pos[i][j]) N+=1 return np.array(pos)
[docs]def polar_line_circles(radii, theta, start_r=0): circles = [ mpatches.Circle( (0,0), radii[0] ) ] line_xfm = mxfms.Affine2D().translate(start_r,0).rotate(theta) x = 0 for ri in range(1, len(radii)): x += radii[ri-1] + radii[ri] circles.append(mpatches.Circle( (x,0), radii[ri] )) collection = PatchCollection(circles) collection.set_transform(line_xfm) return collection
[docs]def wedge_ring(N, inner_radius, outer_radius, start=0, stop=360): degs = np.linspace(start, stop, N+1, endpoint=True) wedges = [] if stop > start: for i in range(len(degs)-1): wedges.append( mpatches.Wedge( (0,0), outer_radius, degs[i], degs[i+1], width=outer_radius-inner_radius ) ) else: for i in range(1,len(degs)): wedges.append( mpatches.Wedge( (0,0), outer_radius, degs[i], degs[i-1], width=outer_radius-inner_radius ) ) return PatchCollection(wedges)
[docs]def add_angle_labels(ax, angles, labels, radius, color=None, fontdict=None, offset=0.05): angle_pos = polar_to_xy(angles, radius) for i in range(len(angle_pos)): xy = angle_pos[i,:] u = xy + xy / np.linalg.norm(xy) * offset ax.text(u[0], u[1], labels[i], color=color, horizontalalignment='center', verticalalignment='center', fontdict=fontdict)
[docs]def add_arrow(ax, radius, start_angle, end_angle, color=None, width=18.0): if color is None: color = DEFAULT_LABEL_COLOR fig = ax.get_figure() size = fig.get_size_inches() dpi = fig.get_dpi() mutation_scale = size[0] * dpi / 500.0 * width d_angle = end_angle - start_angle start_pos = (radius * np.cos(start_angle), radius * np.sin(start_angle)) end_pos = (radius * np.cos(end_angle), radius * np.sin(end_angle)) connstyle = mpatches.ConnectionStyle.Angle3(angleA=0, angleB=(d_angle*180.0/np.pi)) arrowstyle = mpatches.ArrowStyle.Simple(tail_width=0.33, head_length=0.66, head_width=1.0) ax.add_patch(mpatches.FancyArrowPatch(posA=start_pos, posB=end_pos, arrowstyle=arrowstyle, connectionstyle=connstyle, facecolor=color, linewidth=0, mutation_scale=mutation_scale))
[docs]def make_pincushion_plot(data, trials, on, nrows, ncols, clim=None, color_map=None, radius=None): if radius is None: max_sweeps = 0 for sweeps in trials.itervalues(): max_sweeps = max(max_sweeps, len(sweeps[0])) rings = rings_in_hex_pack(max_sweeps) radius = 0.5 / (2.0 * rings - 1.0) if clim is None: clim = [ data.min(), data.max() ] if color_map is None: color_map = LSN_ON_COLOR_MAP if on else LSN_OFF_COLOR_MAP ax = plt.gca() for (col,row,on_state), sweeps in iteritems(trials): if on_state != on: continue valid_sweeps = sweeps[0][sweeps[0] < data.size] responses = np.sort(data[valid_sweeps])[::-1] responses = responses[responses >= clim[0]] if responses.size > 0: coll = spiral_trials(np.ones(responses.shape)*radius, col+0.5, row+0.5) coll.set_transform(coll.get_transform() + ax.transData) coll.set_array(responses) coll.set_cmap(color_map) coll.set_clim(clim) coll.set_linewidths(0) ax.add_collection(coll) ax.set_ylim((0,nrows)) ax.set_xlim((0,ncols))
[docs]class PolarPlotter( object ): DIR_CW = -1 DIR_CCW = 1 def __init__(self, direction=DIR_CW, angle_start=0, circle_scale=1.1, inner_radius=None, plot_center=(0.0,0.0), plot_scale=0.9): self.plot_scale = plot_scale self.plot_center = plot_center self.angle_transform = np.vectorize(lambda x: ((x + angle_start)*direction)*np.pi/180.0) self.inner_radius = inner_radius self.circle_scale = circle_scale
[docs] def finalize(self): ax = plt.gca() fig = plt.gcf() figsize = fig.get_size_inches() aspect = figsize[0] / figsize[1] w = 2.0 / self.plot_scale h = w / aspect bounds = ( self.plot_center[0] - w*.5, self.plot_center[0] + w*.5, self.plot_center[1] - h*.5, self.plot_center[1] + h*.5 ) ax.set_xlim(bounds[0], bounds[1]) ax.set_ylim(bounds[2], bounds[3]) plt.subplots_adjust(left=0,right=1,bottom=0,top=1)
@classmethod def _clim(self, clim, data): if clim is None: clim = [ data.min(), data.max() ] if clim[0] == clim[1]: clim[0] = 0 if clim[0] == clim[1]: clim[1] = 1 return clim
[docs]class TrackPlotter( PolarPlotter ): def __init__(self, direction=PolarPlotter.DIR_CW, angle_start=270.0, inner_radius=.45, ring_length=None, *args, **kwargs): super(TrackPlotter, self).__init__(direction=direction, angle_start=angle_start, inner_radius=inner_radius, *args, **kwargs) self.ring_length = ring_length
[docs] def show_arrow(self, color=None): start, end = self.angle_transform([0.0, 40.0]) add_arrow(plt.gca(), self.inner_radius * .85, start, end, color)
[docs] def plot(self, data, clim=None, cmap=DEFAULT_COLOR_MAP, mean_cmap=DEFAULT_MEAN_RESP_COLOR_MAP, norm=None): ax = plt.gca() clim = self._clim(clim, data) if self.ring_length: data = skimage.transform.resize(data.astype(np.float64), (data.shape[0], self.ring_length), mode='constant', anti_aliasing=False) data_mean = data.mean(axis=0) data = np.vstack((data, data_mean)) radii = np.linspace(self.inner_radius, 1.0, data.shape[0]+2) start,stop = self.angle_transform([0,360])*180.0/np.pi if norm is None: norm = mcolors.PowerNorm(0.5, vmin=clim[0], vmax=clim[1], clip=True) for i, row_data in enumerate(data): inner_radius = radii[i] if i < data.shape[0] - 1: outer_radius = radii[i+1] ring_cmap = cmap else: outer_radius = radii[i+2] ring_cmap = mean_cmap wedges = wedge_ring(len(row_data), inner_radius, outer_radius, start=start, stop=stop) wedges.set_array(row_data) #wedges.set_clim(clim) wedges.set_cmap(ring_cmap) wedges.set_norm(norm) wedges.set_edgecolors((0,0,0,0)) ax.add_collection(wedges) self.finalize()
[docs]class CoronaPlotter( PolarPlotter ): def __init__(self, angle_start=270, plot_scale=1.2, inner_radius=.3, *args, **kwargs): super(CoronaPlotter, self).__init__(inner_radius=inner_radius, angle_start=angle_start, plot_scale=plot_scale, *args, **kwargs) self.categories = None self.cat_idx_map = None
[docs] def infer_dims(self, category_data): self.set_dims(np.sort(np.unique(category_data)))
[docs] def set_dims(self, categories): self.categories = categories self.cat_idx_map = dict(zip(categories, range(len(categories))))
[docs] def show_arrow(self, color=None): start, end = self.angle_transform([0.0, 40.0]) add_arrow(plt.gca(), self.inner_radius * .85, start, end, color)
[docs] def show_circle(self, color=None): if color is None: color = DEFAULT_LABEL_COLOR ax = plt.gca() collection = radial_circles([0.96 * self.inner_radius]) collection.set_facecolor((0,0,0,0)) collection.set_edgecolor(color) collection.set_zorder(1) ax.add_collection(collection)
[docs] def plot(self, category_data, data=None, clim=None, cmap=DEFAULT_COLOR_MAP): ax = plt.gca() if self.categories is None: self.infer_dims(category_data) if data is None: data = np.ones(len(category_data)) clim = self._clim(clim, data) num_cats = len(self.categories) hth = 180.0 / num_cats degs = np.linspace(hth, 360.0-hth, num_cats) degs = self.angle_transform(degs) circle_radius = self.inner_radius * abs(np.sin((degs[1] - degs[0]) * .5)) radii = np.ones(len(data)) * circle_radius * self.circle_scale df = pd.DataFrame({ 'category': category_data }) gb = df.groupby(['category']) for category, trials in iteritems(gb.groups): idx = self.cat_idx_map[category] order = np.argsort(data[trials])[::-1] trial_order = np.array(trials)[order] circles = polar_line_circles(radii[trial_order], degs[idx], self.inner_radius) circles.set_transform(circles.get_transform() + ax.transData) circles.set_array(data[trial_order]) circles.set_cmap(cmap) circles.set_clim(clim) circles.set_edgecolors((0,0,0,0)) circles.set_zorder(2) ax.add_collection(circles) self.finalize()
[docs]class FanPlotter( PolarPlotter ): def __init__(self, group_scale=0.9, *args, **kwargs): super(FanPlotter, self).__init__(*args, **kwargs) self.group_scale = group_scale self.angles = None self.xangles = None self.angle_map = None self.rs = None self.radii = None self.r_radius_map = None self.groups = None self.group_offsets = None self.group_offset_map = None self.group_radius = None
[docs] def infer_dims(self, r_data, angle_data, group_data): rs = np.sort(np.unique(r_data)) angles = np.sort(np.unique(angle_data)) groups = np.sort(np.unique(group_data)) if group_data is not None else None self.set_dims(rs, angles, groups)
[docs] def set_dims(self, rs, angles, groups): self.angles = angles self.xangles = self.angle_transform(angles) self.angle_map = dict(zip(self.angles, self.xangles)) self.rs = rs num_rs = len(rs) # map r value to radius if self.inner_radius is None: self.inner_radius = 1.0 / ( 2 * num_rs ) hdr = ( 1.0 - self.inner_radius ) / num_rs / 2.0 self.radii = np.linspace(self.inner_radius + hdr, 1.0 - hdr, num_rs) self.r_radius_map = dict(zip(rs, self.radii)) self.group_radius = hdr * self.group_scale self.groups = groups if groups is not None else [ np.nan ] num_groups = len(self.groups) # map group to group offset if num_groups == 1: self.group_offsets = [ [ 0, 0 ] ] else: offset_radius = self.group_radius * self.circle_scale self.group_offsets = polar_linspace(offset_radius/np.sqrt(2), -45, -45-360, num_groups) self.group_radius = offset_radius * 0.5 self.group_offset_map = dict(zip(self.groups, self.group_offsets))
[docs] def show_axes(self, angles=None, radii=None, closed=False, color=None): ax = plt.gca() if self.angles is None: raise Exception("dimensions not set!") if color is None: color = DEFAULT_AXIS_COLOR if angles is None: angles = self.xangles if radii is None: radii = self.radii lines = angle_lines(angles, radii[0], radii[-1]) lines.set_zorder(1) lines.set_edgecolors(color) ax.add_collection(lines) if closed: collection = radial_circles(radii) else: collection = radial_arcs(radii, min(angles), max(angles)) collection.set_facecolors((0,0,0,0.0)) collection.set_edgecolors(color) collection.set_zorder(1) ax.add_collection(collection)
[docs] def show_angle_labels(self, angles=None, labels=None, color=None, offset=.05, fontdict=None): if angles is None: angles = self.xangles if labels is None: labels = self.angles.astype(int) if color is None: color = DEFAULT_LABEL_COLOR add_angle_labels(plt.gca(), angles, labels, 1.0, offset=offset, color=color, fontdict=fontdict)
[docs] def show_group_labels(self, groups=None, color=None, fontdict=None): ax = plt.gca() if groups is None: groups = self.groups if color is None: color = DEFAULT_LABEL_COLOR r = self.inner_radius*.5 angle = 90.0 x = r * np.cos(angle) y = r * np.sin(angle) for group in groups: off = self.group_offset_map[group] xfm = mxfms.Affine2D().translate(r+off[0]*2.0,off[1]*2.0).rotate(self.angle_transform(angle)) p = xfm.transform_point([0,0]) ax.text(p[0], p[1], group, color=color, horizontalalignment='center', verticalalignment='center', fontdict=fontdict) start_theta = self.angle_transform(angle+20) end_theta = self.angle_transform(angle-20) ax.add_patch(mpatches.Arc((0,0), 2*r, 2*r, theta1=start_theta*180.0/np.pi, theta2=end_theta*180.0/np.pi, color=color)) ax.add_collection(LineCollection([[[0, .7*r], [0, 1.3*r]]], color=color))
[docs] def show_r_labels(self, radii=None, labels=None, color=None, offset=.1, fontdict=None): ax = plt.gca() if radii is None: radii = self.radii if labels is None: labels = self.rs if color is None: color = DEFAULT_LABEL_COLOR if labels is None: labels = self.rs line_th = self.xangles[0] line_x = radii * np.cos(line_th) line_y = radii * np.sin(line_th) for i,(x,y) in enumerate(zip(line_x,line_y)): ax.text(x, y-offset, labels[i], color=color, horizontalalignment='center', verticalalignment='center', fontdict=fontdict)
[docs] def plot(self, r_data, angle_data, group_data=None, data=None, cmap=DEFAULT_COLOR_MAP, clim=None, rmap=None, rlim=None, axis_color=None, label_color=None): ax = plt.gca() if data is None: data = np.ones(len(r_data)) clim = self._clim(clim, data) if rmap is None: rnorm = np.vectorize(lambda x: 1.0) else: if rlim is None: rlim = clim norm = mcolors.Normalize(clim[0], clim[1]) rnorm = np.vectorize(lambda x: rmap(norm(x))) if self.angles is None: self.infer_dims(r_data, angle_data, group_data) num_groups = len(self.groups) num_rs = len(self.rs) num_angles = len(self.angles) df = pd.DataFrame({ 'group': group_data, 'angle': angle_data, 'r': r_data }) # compute circle radius trials_per_group = float(len(df)) / num_groups / num_rs / num_angles rings = rings_in_hex_pack(trials_per_group) circle_radius = self.group_radius / (2*rings - 1) * self.circle_scale gb = df.groupby(['group', 'angle', 'r']) for (group, angle, r), trials in iteritems(gb.groups): responses = np.sort(data[trials])[::-1] circles = spiral_trials_polar(self.r_radius_map[r], self.angle_map[angle], rnorm(responses) * circle_radius, offset=self.group_offset_map[group]) circles.set_transform(circles.get_transform() + ax.transData) circles.set_array(responses) circles.set_cmap(cmap) circles.set_clim(clim) circles.set_zorder(2) circles.set_linewidths(0) ax.add_collection(circles) self.finalize()
[docs] @staticmethod def for_static_gratings(): return FanPlotter(angle_start=180, plot_scale=0.9, circle_scale=2.0, group_scale=0.4, plot_center=[0,.45], inner_radius=.2)
[docs] @staticmethod def for_drifting_gratings(): return FanPlotter()