Source code for allensdk.brain_observatory.comparison_utils

import datetime
import math
from typing import Any, Optional, Set

import SimpleITK as sitk
import numpy as np
import pandas as pd
import xarray as xr
from pandas.testing import assert_frame_equal


[docs]def compare_fields(x1: Any, x2: Any, err_msg="", ignore_keys: Optional[Set[str]] = None): """Helper function to compare if two fields (attributes) are equal to one another. Parameters ---------- x1 : Any The first field x2 : Any The other field err_msg : str, optional The error message to display if two compared fields do not equal one another, by default "" (an empty string) ignore_keys For dictionary comparison, ignore these keys """ if ignore_keys is None: ignore_keys = set() if isinstance(x1, pd.Index): x1 = x1.values x2 = x2.values if isinstance(x1, pd.DataFrame): try: assert_frame_equal(x1, x2, check_like=True) except AssertionError as e: print(e) print(err_msg) raise elif isinstance(x1, pd.Series): pd.testing.assert_series_equal(x1, x2) elif isinstance(x1, np.ndarray): np.testing.assert_array_almost_equal(x1, x2, err_msg=err_msg) elif isinstance(x1, xr.DataArray): xr.testing.assert_allclose(x1, x2) elif isinstance(x1, (list, tuple)): assert len(x1) == len(x2) for i in range(len(x1)): compare_fields(x1=x1[i], x2=x2[i]) elif isinstance(x1, (sitk.Image,)): assert x1.GetSize() == x2.GetSize(), err_msg assert x1 == x2, err_msg elif isinstance(x1, (datetime.datetime, pd.Timestamp)): if isinstance(x1, pd.Timestamp): x1 = x1.to_pydatetime() if isinstance(x2, pd.Timestamp): x2 = x2.to_pydatetime() time_delta = (x1 - x2).total_seconds() # Timestamp differences should be less than 60 seconds assert abs(time_delta) < 60 elif isinstance(x1, (float,)): if math.isnan(x1) or math.isnan(x2): both_nan = (math.isnan(x1) and math.isnan(x2)) assert both_nan, err_msg else: assert x1 == x2, err_msg elif isinstance(x1, (dict,)): for key in set(x1.keys()).union(set(x2.keys())): if key in ignore_keys: continue key_err_msg = f"Mismatch when checking key {key}. {err_msg}" compare_fields(x1[key], x2[key], err_msg=key_err_msg) else: assert x1 == x2, err_msg