Source code for roiextractors.testing

"""Testing utilities for the roiextractors package."""

import warnings
from collections.abc import Iterable
from typing import Literal

import numpy as np
from numpy.testing import assert_array_almost_equal, assert_array_equal
from numpy.typing import DTypeLike

from roiextractors import NumpyImagingExtractor, NumpySegmentationExtractor

from .imagingextractor import ImagingExtractor
from .segmentationextractor import SegmentationExtractor

NoneType = type(None)
floattype = (float, np.floating)
inttype = (int, np.integer)


[docs] def generate_dummy_video( size: tuple[int, int, int] | tuple[int, int, int, int], dtype: DTypeLike = "uint16", seed: int = 0 ): """Generate a dummy video of a given size and dtype. Parameters ---------- size : tuple[int, int, int] or tuple[int, int, int, int] Size of the video to generate. For planar data: (num_frames, num_rows, num_columns) For volumetric data: (num_frames, num_rows, num_columns, num_planes) dtype : DTypeLike, optional Dtype of the video to generate, by default "uint16". seed : int, default 0 seed for the random number generator, by default 0. Returns ------- video : np.ndarray A dummy video of the given size and dtype. """ dtype = np.dtype(dtype) number_of_bytes = dtype.itemsize rng = np.random.default_rng(seed) low = 0 if "u" in dtype.name else 2 ** (number_of_bytes - 1) - 2**number_of_bytes high = 2**number_of_bytes - 1 if "u" in dtype.name else 2**number_of_bytes - 2 ** (number_of_bytes - 1) - 1 video = rng.random(size=size) if "float" in dtype.name else rng.integers(low=low, high=high, size=size, dtype=dtype) return video
def generate_dummy_imaging_extractor(
    *,
    num_rows: int = 10,
    num_columns: int = 10,
    sampling_frequency: float = 30.0,
    dtype: DTypeLike = "uint16",
    seed: int = 0,
    num_samples: int | None = 30,
    has_native_timestamps: bool = False,
    num_planes: int | None = None,
):
    """Build a `NumpyImagingExtractor` filled with random data for use in tests.

    Parameters
    ----------
    num_rows : int, optional
        Number of rows in the video, by default 10.
    num_columns : int, optional
        Number of columns in the video, by default 10.
    sampling_frequency : float, optional
        Sampling frequency of the video, by default 30.
    dtype : DTypeLike, optional
        Dtype of the video, by default "uint16".
    seed : int, default 0
        Seed for the random number generator, by default 0.
    num_samples : int, default 30
        Number of samples in the video, by default 30.
    has_native_timestamps : bool, default False
        If True, ``get_native_timestamps`` is overridden on the returned instance to
        return evenly spaced timestamps computed as ``sample_index / sampling_frequency``.
    num_planes : int, optional
        Number of depth planes for volumetric data. If None, the video is generated
        with a trailing axis of length 1.

    Returns
    -------
    ImagingExtractor
        A `NumpyImagingExtractor` instance wrapping the random video. When ``num_planes``
        is given, the instance additionally has ``is_volumetric`` set to True and
        volumetric-aware ``get_num_planes``, ``get_series``, ``get_sample_shape`` and
        ``get_volume_shape`` methods bound to it.
    """
    import types

    is_volumetric = num_planes is not None

    # NumpyImagingExtractor treats the last axis as channels, so for volumetric data
    # one channel name per plane is supplied; otherwise a single channel is used.
    if is_volumetric:
        trailing_axis_length = num_planes
        channel_names_to_use = [f"plane_{i}" for i in range(num_planes)]
    else:
        trailing_axis_length = 1
        channel_names_to_use = ["channel_num_0"]

    video = generate_dummy_video(
        size=(num_samples, num_rows, num_columns, trailing_axis_length), dtype=dtype, seed=seed
    )
    imaging_extractor = NumpyImagingExtractor(
        timeseries=video, sampling_frequency=sampling_frequency, channel_names=channel_names_to_use
    )

    # TODO: Once channel names properly support planes, refactor NumpyImagingExtractor
    # to natively handle volumetric data instead of using types.MethodType overrides.
    # The challenge is that NumpyImagingExtractor fundamentally treats the last dimension
    # as channels, but volumetric data needs the last dimension to be planes.
    if is_volumetric:
        imaging_extractor.is_volumetric = True
        imaging_extractor._num_planes = num_planes

        def get_num_planes(self):
            """Get the number of depth planes."""
            return self._num_planes

        def get_series(self, start_sample=None, end_sample=None):
            """Get volumetric series data with all planes."""
            first = 0 if start_sample is None else start_sample
            last = self.get_num_samples() if end_sample is None else end_sample
            # Keep every trailing axis: (time, height, width, planes).
            return self._video[first:last, ...]

        def get_sample_shape(self):
            """Get the shape of a single volumetric sample."""
            return (*self.get_image_shape(), self.get_num_planes())

        def get_volume_shape(self):
            """Get the shape of the volume (num_rows, num_columns, num_planes)."""
            return (*self.get_image_shape(), self.get_num_planes())

        # Bind each override to this instance only, leaving the class untouched.
        for override in (get_num_planes, get_series, get_sample_shape, get_volume_shape):
            setattr(imaging_extractor, override.__name__, types.MethodType(override, imaging_extractor))

    # NOTE: get_native_timestamps is overridden for testing purposes only.
    # NumpyImagingExtractor correctly returns None for get_native_timestamps() because numpy
    # arrays don't have native timestamps. This override creates synthetic timestamps to test
    # code that handles extractors with native timestamp support (like some microscopy file
    # formats). It is testing-specific and should NOT be added to NumpyImagingExtractor itself.
    if has_native_timestamps:

        def get_native_timestamps(self, start_sample=None, end_sample=None):
            """Return evenly spaced synthetic timestamps for the requested sample range."""
            first = 0 if start_sample is None else start_sample
            last = self.get_num_samples() if end_sample is None else end_sample
            # Timestamps are regenerated on every call rather than cached.
            all_timestamps = np.arange(self.get_num_samples()) / self.get_sampling_frequency()
            return all_timestamps[first:last]

        imaging_extractor.get_native_timestamps = types.MethodType(get_native_timestamps, imaging_extractor)

    return imaging_extractor
class _DummySegmentationExtractor(NumpySegmentationExtractor):
    """Private NumpySegmentationExtractor subclass that can carry synthetic native timestamps.

    NumpySegmentationExtractor returns None for get_native_timestamps() because numpy arrays
    do not have native timestamps. This subclass stores an optional timestamp array so the
    dummy generator can provide timestamps without modifying the underlying class.
    """

    def __init__(self, *args, native_timestamps: np.ndarray | None = None, **kwargs):
        """Forward all arguments to the parent class and keep the optional timestamps."""
        super().__init__(*args, **kwargs)
        self._native_timestamps = native_timestamps

    def get_native_timestamps(
        self, start_sample: int | None = None, end_sample: int | None = None
    ) -> np.ndarray | None:
        """Return the stored timestamps sliced to ``[start_sample, end_sample)``.

        Returns None when no timestamps were supplied at construction. A missing
        ``start_sample`` means 0; a missing ``end_sample`` means ``get_num_samples()``.
        """
        stored_timestamps = self._native_timestamps
        if stored_timestamps is None:
            return None

        first = 0 if start_sample is None else start_sample
        last = self.get_num_samples() if end_sample is None else end_sample
        return stored_timestamps[first:last]
def generate_dummy_segmentation_extractor(
    *,
    num_rois: int = 10,
    num_rows: int = 25,
    num_columns: int = 25,
    sampling_frequency: float = 30.0,
    has_summary_images: bool = True,
    has_raw_signal: bool = True,
    has_dff_signal: bool = True,
    has_deconvolved_signal: bool = True,
    has_neuropil_signal: bool = True,
    rejected_list: list | None = None,
    seed: int = 0,
    num_samples: int | None = 30,
    mask_type: Literal["image", "pixel"] = "image",
    native_timestamps: Literal["evenly_spaced", "unevenly_spaced"] | None = None,
) -> SegmentationExtractor:
    """Generate a dummy segmentation extractor for testing.

    The segmentation extractor is built by feeding random data into a private
    `NumpySegmentationExtractor` subclass.

    Parameters
    ----------
    num_rois : int, optional
        Number of regions of interest, by default 10.
    num_rows : int, optional
        Number of rows in the hypothetical video from which the data was extracted, by default 25.
    num_columns : int, optional
        Number of columns in the hypothetical video from which the data was extracted, by default 25.
    sampling_frequency : float, optional
        Sampling frequency of the hypothetical video from which the data was extracted, by default 30.0.
    has_summary_images : bool, optional
        Whether the dummy segmentation extractor has summary images or not (mean and correlation).
    has_raw_signal : bool, optional
        Whether a raw fluorescence signal is desired in the object, by default True.
    has_dff_signal : bool, optional
        Whether a relative (df/f) fluorescence signal is desired in the object, by default True.
    has_deconvolved_signal : bool, optional
        Whether a deconvolved signal is desired in the object, by default True.
    has_neuropil_signal : bool, optional
        Whether a neuropil signal is desired in the object, by default True.
    rejected_list : list, optional
        A list of rejected rois, None by default. When given, the accepted list is every
        generated ROI id that is not in this list, kept in ROI order.
    seed : int, default 0
        Seed for the random number generator, by default 0.
    num_samples : int, optional
        Number of samples in the recording, by default 30.
    mask_type : str, default "image"
        Type of mask to generate. One of "image" or "pixel".
        "image" generates dense masks of shape (num_rows, num_columns, num_rois).
        "pixel" generates sparse masks as a list of (n_pixels, 3) arrays with columns [y, x, weight].
    native_timestamps : "evenly_spaced" | "unevenly_spaced" | None, default None
        Controls whether the extractor returns native timestamps.
        None: no native timestamps (returns None).
        "evenly_spaced": evenly spaced timestamps based on sampling_frequency.
        "unevenly_spaced": timestamps with small random jitter around the regular spacing.

    Returns
    -------
    SegmentationExtractor
        A segmentation extractor with random data fed into `NumpySegmentationExtractor`.

    Raises
    ------
    ValueError
        If ``mask_type`` or ``native_timestamps`` is not one of the supported values.

    Notes
    -----
    Note that this dummy example is meant to be a mock object with the right shape, structure and objects but does not
    contain meaningful content. That is, the image masks matrices are not plausible image mask for a roi, the raw signal
    is not a meaningful biological signal and is not related appropriately to the deconvolved signal, etc.
    """
    valid_mask_types = ("image", "pixel")
    if mask_type not in valid_mask_types:
        raise ValueError(f"mask_type must be one of {valid_mask_types}, got '{mask_type}'")

    # Every draw below consumes the same generator, so the order of the calls determines
    # the data produced for a given seed. Keep the order stable.
    rng = np.random.default_rng(seed)

    # Create dummy image masks (always needed for NumpySegmentationExtractor construction)
    image_masks = rng.random((num_rows, num_columns, num_rois))
    movie_dims = (num_rows, num_columns)

    # Create signals
    raw = rng.random((num_samples, num_rois)) if has_raw_signal else None
    dff = rng.random((num_samples, num_rois)) if has_dff_signal else None
    deconvolved = rng.random((num_samples, num_rois)) if has_deconvolved_signal else None
    neuropil = rng.random((num_samples, num_rois)) if has_neuropil_signal else None

    # Summary images
    mean_image = rng.random((num_rows, num_columns)) if has_summary_images else None
    correlation_image = rng.random((num_rows, num_columns)) if has_summary_images else None

    # Rois: zero-pad ids to a common width, e.g. width=2 for 10 ROIs (roi_00 ... roi_09),
    # width=3 for 101 to 1000 ROIs (roi_000 ... roi_100).
    width = len(str(num_rois - 1))
    roi_ids = [f"roi_{roi_index:0{width}d}" for roi_index in range(num_rois)]
    roi_locations_rows = rng.integers(low=0, high=num_rows, size=num_rois)
    roi_locations_columns = rng.integers(low=0, high=num_columns, size=num_rois)
    roi_locations = np.vstack((roi_locations_rows, roi_locations_columns))

    if rejected_list is None:
        accepted_list = roi_ids
    else:
        # Filter in ROI order instead of taking a set difference: iterating a set of
        # strings yields an arbitrary order that can change between interpreter runs.
        rejected_ids = set(rejected_list)
        accepted_list = [roi_id for roi_id in roi_ids if roi_id not in rejected_ids]

    # Generate native timestamps if requested
    if native_timestamps is None:
        native_timestamps_array = None
    elif native_timestamps == "evenly_spaced":
        native_timestamps_array = np.arange(num_samples) / sampling_frequency
    elif native_timestamps == "unevenly_spaced":
        timestamps = np.arange(num_samples) / sampling_frequency
        jitter = rng.normal(loc=0.0, scale=0.1 / sampling_frequency, size=num_samples)
        # Sorting keeps the jittered timestamps monotonically non-decreasing.
        native_timestamps_array = np.sort(timestamps + jitter)
    else:
        valid_types = (None, "evenly_spaced", "unevenly_spaced")
        raise ValueError(f"native_timestamps must be one of {valid_types}, got '{native_timestamps}'")

    dummy_segmentation_extractor = _DummySegmentationExtractor(
        sampling_frequency=sampling_frequency,
        image_masks=image_masks,
        raw=raw,
        dff=dff,
        deconvolved=deconvolved,
        neuropil=neuropil,
        mean_image=mean_image,
        correlation_image=correlation_image,
        roi_ids=roi_ids,
        roi_locations=roi_locations,
        accepted_list=accepted_list,
        rejected_list=rejected_list,
        movie_dims=movie_dims,
        channel_names=["channel_num_0"],
        native_timestamps=native_timestamps_array,
    )

    # Replace mask data with pixel masks if requested
    if mask_type == "pixel":
        from .segmentationextractor import _ROIMasks

        num_pixels_per_roi = 5
        pixel_masks = []
        for _ in range(num_rois):
            y_coords = rng.integers(low=0, high=num_rows, size=num_pixels_per_roi).astype(float)
            x_coords = rng.integers(low=0, high=num_columns, size=num_pixels_per_roi).astype(float)
            weights = rng.random(num_pixels_per_roi)
            pixel_masks.append(np.column_stack([y_coords, x_coords, weights]))

        roi_id_map = {roi_id: index for index, roi_id in enumerate(dummy_segmentation_extractor.get_roi_ids())}
        # NOTE(review): the keyword is spelled `mask_tpe` here; kept as-is, confirm it
        # matches the field name declared on _ROIMasks.
        dummy_segmentation_extractor._roi_masks = _ROIMasks(
            data=pixel_masks,
            mask_tpe="nwb-pixel_mask",
            field_of_view_shape=(num_rows, num_columns),
            roi_id_map=roi_id_map,
        )

    return dummy_segmentation_extractor
def _assert_iterable_shape(iterable, shape): """Assert that the iterable has the given shape. If the iterable is a numpy array, the shape is checked directly.""" ar = iterable if isinstance(iterable, np.ndarray) else np.array(iterable) for ar_shape, given_shape in zip(ar.shape, shape): if isinstance(given_shape, int): assert ar_shape == given_shape, f"Expected {given_shape}, received {ar_shape}!" def _assert_iterable_shape_max(iterable, shape_max): """Assert that the iterable has a shape less than or equal to the given maximum shape.""" ar = iterable if isinstance(iterable, np.ndarray) else np.array(iterable) for ar_shape, given_shape in zip(ar.shape, shape_max): if isinstance(given_shape, int): assert ar_shape <= given_shape def _assert_iterable_element_dtypes(iterable, dtypes): """Assert that the iterable has elements of the given dtypes.""" if isinstance(iterable, Iterable) and not isinstance(iterable, str): for iter in iterable: _assert_iterable_element_dtypes(iter, dtypes) else: assert isinstance(iterable, dtypes), f"array is none of the types {dtypes}" def _assert_iterable_complete(iterable, dtypes=None, element_dtypes=None, shape=None, shape_max=None): """Assert that the iterable is complete, i.e. it is not None and has the given dtypes, element_dtypes, shape and shape_max.""" assert isinstance(iterable, dtypes), f"iterable {type(iterable)} is none of the types {dtypes}" if not isinstance(iterable, NoneType): if shape is not None: _assert_iterable_shape(iterable, shape=shape) if shape_max is not None: _assert_iterable_shape_max(iterable, shape_max=shape_max) if element_dtypes is not None: _assert_iterable_element_dtypes(iterable, element_dtypes)
def check_segmentations_equal(
    segmentation_extractor1: SegmentationExtractor, segmentation_extractor2: SegmentationExtractor
):
    """Check that two segmentation extractors have equal fields.

    Both extractors are first run through ``check_segmentation_return_types``. Then the
    ROI count, sample count, sampling frequency, frame shape, the image and pixel masks
    of the first ROI of each extractor, summary images, accepted/rejected lists, ROI ids,
    traces and timestamps are compared.

    Parameters
    ----------
    segmentation_extractor1 : SegmentationExtractor
        First extractor to compare.
    segmentation_extractor2 : SegmentationExtractor
        Second extractor to compare.

    Raises
    ------
    AssertionError
        If any of the compared fields differ.
    """
    check_segmentation_return_types(segmentation_extractor1)
    check_segmentation_return_types(segmentation_extractor2)

    # assert equality:
    assert segmentation_extractor1.get_num_rois() == segmentation_extractor2.get_num_rois()
    assert segmentation_extractor1.get_num_samples() == segmentation_extractor2.get_num_samples()
    assert np.isclose(
        segmentation_extractor1.get_sampling_frequency(), segmentation_extractor2.get_sampling_frequency()
    )
    assert_array_equal(segmentation_extractor1.get_frame_shape(), segmentation_extractor2.get_frame_shape())

    # Each extractor is queried with its OWN first ROI id, for image and pixel masks alike.
    # Whether the ids themselves agree is asserted separately further down.
    first_roi_ids1 = segmentation_extractor1.get_roi_ids()[:1]
    first_roi_ids2 = segmentation_extractor2.get_roi_ids()[:1]

    assert_array_equal(
        segmentation_extractor1.get_roi_image_masks(roi_ids=first_roi_ids1),
        segmentation_extractor2.get_roi_image_masks(roi_ids=first_roi_ids2),
    )

    # Pixel masks are compared as unordered sets of their flattened values.
    pixel_values1 = set(segmentation_extractor1.get_roi_pixel_masks(roi_ids=first_roi_ids1)[0].flatten())
    pixel_values2 = set(segmentation_extractor2.get_roi_pixel_masks(roi_ids=first_roi_ids2)[0].flatten())
    assert pixel_values1 == pixel_values2

    check_segmentations_images(segmentation_extractor1, segmentation_extractor2)

    assert_array_equal(segmentation_extractor1.get_accepted_list(), segmentation_extractor2.get_accepted_list())
    assert_array_equal(segmentation_extractor1.get_rejected_list(), segmentation_extractor2.get_rejected_list())
    assert_array_equal(segmentation_extractor1.get_roi_ids(), segmentation_extractor2.get_roi_ids())
    assert_array_equal(segmentation_extractor1.get_traces(), segmentation_extractor2.get_traces())
    assert_array_equal(
        segmentation_extractor1.get_timestamps(),
        segmentation_extractor2.get_timestamps(),
    )
def check_segmentations_images(
    segmentation_extractor1: SegmentationExtractor,
    segmentation_extractor2: SegmentationExtractor,
):
    """Check that the segmentation images are equal for the given segmentation extractors.

    The two ``get_images_dict()`` results must contain the same image names, and the
    image stored under each name must be element-wise equal.

    Parameters
    ----------
    segmentation_extractor1 : SegmentationExtractor
        First extractor to compare.
    segmentation_extractor2 : SegmentationExtractor
        Second extractor to compare.

    Raises
    ------
    AssertionError
        If the number of images, their names, or the content of any image differ.
    """
    images_in_extractor1 = segmentation_extractor1.get_images_dict()
    images_in_extractor2 = segmentation_extractor2.get_images_dict()

    assert len(images_in_extractor1) == len(images_in_extractor2)

    # With equal lengths already asserted, equal key views means identical name sets.
    image_names_are_equal = images_in_extractor1.keys() == images_in_extractor2.keys()
    assert image_names_are_equal, "The names of segmentation images in the segmentation extractors are not the same."

    for image_name, image_in_extractor1 in images_in_extractor1.items():
        # Pass the message through err_msg so it is actually reported on failure.
        assert_array_equal(
            image_in_extractor1,
            images_in_extractor2[image_name],
            err_msg=f"The segmentation images for {image_name} are not equal.",
        )
def check_segmentation_return_types(seg: SegmentationExtractor):
    """Check that the return types of the segmentation extractor are correct.

    Every assertion delegates container, shape and element checks to
    ``_assert_iterable_complete`` where applicable.

    Parameters
    ----------
    seg : SegmentationExtractor
        The extractor whose getter return values are validated.

    Raises
    ------
    AssertionError
        If any getter returns a value of an unexpected type or shape.
    """
    # Scalar properties.
    assert isinstance(seg.get_num_rois(), int)
    assert isinstance(seg.get_num_samples(), int)
    assert isinstance(seg.get_sampling_frequency(), (NoneType, floattype))

    # Image size: an iterable of two integers.
    image_size = seg.get_image_size()
    _assert_iterable_complete(image_size, dtypes=Iterable, element_dtypes=inttype, shape=(2,))

    # Image mask of the first ROI: a float array shaped (*image_size, 1).
    first_roi_image_masks = seg.get_roi_image_masks(roi_ids=seg.get_roi_ids()[:1])
    expected_image_mask_shape = (*seg.get_image_size(), 1)
    _assert_iterable_complete(
        first_roi_image_masks,
        dtypes=(np.ndarray,),
        element_dtypes=floattype,
        shape=expected_image_mask_shape,
    )

    # ROI ids: a list with one entry per ROI.
    roi_ids = seg.get_roi_ids()
    _assert_iterable_complete(roi_ids, dtypes=(list,), shape=(seg.get_num_rois(),))

    # Pixel masks: a list of float arrays with at most 3 columns.
    first_two_pixel_masks = seg.get_roi_pixel_masks(roi_ids=seg.get_roi_ids()[:2])
    assert isinstance(first_two_pixel_masks, list)

    first_roi_pixel_mask = seg.get_roi_pixel_masks(roi_ids=seg.get_roi_ids()[:1])[0]
    # NOTE(review): np.prod returns a numpy integer, and the shape helpers only compare
    # entries that are plain `int`, so the row bound below appears not to be enforced.
    pixel_mask_shape_max = (np.prod(seg.get_image_size()), 3)
    _assert_iterable_complete(
        first_roi_pixel_mask,
        dtypes=(np.ndarray,),
        element_dtypes=floattype,
        shape_max=pixel_mask_shape_max,
    )

    # Summary images: float arrays no larger than the image size, or None.
    for image_name in seg.get_images_dict():
        image = seg.get_image(image_name)
        _assert_iterable_complete(
            image,
            dtypes=(np.ndarray, NoneType),
            element_dtypes=floattype,
            shape_max=(*seg.get_image_size(),),
        )

    # Accepted and rejected lists: lists (or None) with at most one entry per ROI.
    for get_roi_list in (seg.get_accepted_list, seg.get_rejected_list):
        roi_list = get_roi_list()
        _assert_iterable_complete(roi_list, dtypes=(list, NoneType), shape_max=(seg.get_num_rois(),))

    # ROI locations: silence the FutureWarning whose message starts with "get_roi_locations".
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="get_roi_locations", category=FutureWarning)
        roi_locations = seg.get_roi_locations()
        _assert_iterable_complete(
            roi_locations,
            dtypes=(np.ndarray,),
            shape=(2, seg.get_num_rois()),
            element_dtypes=inttype,
        )

    # Traces: a float array or None.
    traces = seg.get_traces()
    # NOTE(review): as above, np.prod yields a numpy integer, so this first-axis
    # expectation appears to be skipped by the shape helper - confirm the intent.
    expected_traces_shape = (np.prod(seg.get_num_rois()), None)
    _assert_iterable_complete(
        traces,
        dtypes=(np.ndarray, NoneType),
        element_dtypes=floattype,
        shape=expected_traces_shape,
    )

    # Dictionaries of traces and images, with a fixed set of trace names.
    assert isinstance(seg.get_traces_dict(), dict)
    assert isinstance(seg.get_images_dict(), dict)
    expected_trace_names = {"raw", "dff", "neuropil", "deconvolved", "denoised"}
    assert expected_trace_names == set(seg.get_traces_dict())
def check_imaging_equal(imaging_extractor1: ImagingExtractor, imaging_extractor2: ImagingExtractor):
    """Check that two imaging extractors have equal fields.

    Compares the number of samples, the sampling frequency (with ``np.isclose``), the
    sample shape, the first sample of the series, and the timestamps (almost-equal).

    Parameters
    ----------
    imaging_extractor1 : ImagingExtractor
        First extractor to compare.
    imaging_extractor2 : ImagingExtractor
        Second extractor to compare.

    Raises
    ------
    AssertionError
        If any of the compared fields differ.
    """
    first, second = imaging_extractor1, imaging_extractor2

    assert first.get_num_samples() == second.get_num_samples()

    sampling_frequencies_match = np.isclose(first.get_sampling_frequency(), second.get_sampling_frequency())
    assert sampling_frequencies_match

    assert_array_equal(first.get_sample_shape(), second.get_sample_shape())

    # Only the first sample is loaded from each extractor.
    first_sample_window = {"start_sample": 0, "end_sample": 1}
    assert_array_equal(first.get_series(**first_sample_window), second.get_series(**first_sample_window))

    assert_array_almost_equal(first.get_timestamps(), second.get_timestamps())
def check_imaging_return_types(img_ex: ImagingExtractor):
    """Check that the return types of the imaging extractor are correct.

    Verifies that the sample count is an integer, the sampling frequency is a float,
    and the image size is an iterable of two integers.

    Parameters
    ----------
    img_ex : ImagingExtractor
        The extractor whose getter return values are validated.

    Raises
    ------
    AssertionError
        If any getter returns a value of an unexpected type or shape.
    """
    num_samples = img_ex.get_num_samples()
    assert isinstance(num_samples, inttype)

    sampling_frequency = img_ex.get_sampling_frequency()
    assert isinstance(sampling_frequency, floattype)

    image_size = img_ex.get_image_size()
    _assert_iterable_complete(iterable=image_size, dtypes=Iterable, element_dtypes=inttype, shape=(2,))
# This needs a method for getting frame shape not image size. It only works for n_channel==1 # two_first_frames = img_ex.get_frames(frame_idxs=[0, 1]) # _assert_iterable_complete( # iterable=two_first_frames, # dtypes=(np.ndarray,), # element_dtypes=inttype + floattype, # shape=(2, *img_ex.get_image_size()), # )