# Source code for allensdk.brain_observatory.behavior.stimulus_processing.stimulus_templates
from typing import Dict, List
import warnings
import numpy as np
import pandas as pd
from allensdk.brain_observatory.behavior.stimulus_processing.util import \
convert_filepath_caseinsensitive
class StimulusImage(np.ndarray):
    """Container class for image stimuli.

    An ``np.ndarray`` subclass that carries the image's name alongside its
    pixel values; the name is exposed through the read-only ``name``
    property. The name is propagated to arrays derived from this one
    (slices, copies, views) and survives pickling.
    """

    def __new__(cls, input_array: np.ndarray, name: str):
        """
        Parameters
        ----------
        input_array:
            The unwarped image values. Converted with ``np.asarray`` and
            view-cast to ``StimulusImage``.
        name:
            Name of the image
        """
        obj = np.asarray(input_array).view(cls)
        obj._name = name
        return obj

    def __array_finalize__(self, obj):
        """Inherit the image name when NumPy creates a derived array.

        NumPy invokes this hook for every new StimulusImage, including
        ones produced by slicing, copying or view-casting. Without it those
        arrays would have no ``_name`` attribute and ``.name`` would raise
        ``AttributeError``. Sources without a name (plain ndarrays, or
        ``obj is None`` for explicit construction) yield ``None``; ``__new__``
        then overwrites it with the real name.
        """
        self._name = getattr(obj, '_name', None)

    def __reduce__(self):
        """Append the image name to ndarray's pickle state."""
        reconstruct, args, state = super().__reduce__()
        return reconstruct, args, state + (getattr(self, '_name', None),)

    def __setstate__(self, state):
        """Restore the image name appended by ``__reduce__``."""
        self._name = state[-1]
        super().__setstate__(state[:-1])

    @property
    def name(self):
        """Name of the image."""
        return self._name
[docs]class StimulusTemplate:
"""Container class for a collection of image stimuli"""
def __init__(self, image_set_name: str, image_attributes: List[dict],
images: List[np.ndarray]):
"""
Parameters
----------
image_set_name:
the name of the image set
image_attributes
List of image attributes as returned by the stimulus pkl
images
List of images as returned by the stimulus pkl
"""
self._image_set_name = image_set_name
image_set_name = convert_filepath_caseinsensitive(
image_set_name)
self._image_set_filepath = image_set_name
self._images: Dict[str, StimulusImage] = {}
for attr, image in zip(image_attributes, images):
image_name = attr['image_name']
self.__add_image(name=image_name, values=image)
@property
def image_set_name(self) -> str:
return self._image_set_name
@property
def image_names(self) -> List[str]:
return list(self.keys())
@property
def images(self) -> List[StimulusImage]:
return list(self.values())
[docs] def keys(self):
return self._images.keys()
[docs] def values(self):
return self._images.values()
[docs] def items(self):
return self._images.items()
[docs] def to_dataframe(self) -> pd.DataFrame:
index = pd.Index(self.image_names, name='image_name')
df = pd.DataFrame({'image': self.images}, index=index)
df.name = self._image_set_name
return df
def __add_image(self, name: str, values: np.ndarray):
"""
Parameters
----------
name:
Name of the image
values
The unwarped image values
"""
image = StimulusImage(input_array=values, name=name)
self._images[name] = image
def __getitem__(self, item) -> StimulusImage:
"""
Given an image name, returns the corresponding StimulusImage
"""
return self._images[item]
def __len__(self):
return len(self._images)
def __iter__(self):
yield from self._images
def __repr__(self):
return f'{self._images}'
def __eq__(self, other: object):
if isinstance(other, StimulusTemplate):
if self.image_set_name != other.image_set_name:
return False
if sorted(self.image_names) != sorted(other.image_names):
return False
for (img_name, self_img) in self.items():
other_img = other._images[img_name]
if not np.array_equal(self_img, other_img):
return False
return True
else:
raise NotImplementedError(
"Cannot compare a StimulusTemplate with an object of type: "
f"{type(other)}!")