Source code for roiextractors.volumetricimagingextractor

"""Base class definition for volumetric imaging extractors."""

import warnings
from typing import Iterable

import numpy as np

from .extraction_tools import DtypeType
from .imagingextractor import ImagingExtractor


class VolumetricImagingExtractor(ImagingExtractor):
    """Class to combine multiple ImagingExtractor objects by depth plane."""

    extractor_name = "VolumetricImaging"
    installation_mesg = ""
    # NOTE: misspelled name kept as an alias so any existing code reading it keeps working.
    installatiuon_mesage = installation_mesg

    def __init__(self, imaging_extractors: list[ImagingExtractor]):
        """Initialize a VolumetricImagingExtractor object from a list of ImagingExtractors.

        Each extractor in the list provides one depth plane: plane ``i`` of the resulting
        volume is read from ``imaging_extractors[i]``.

        Parameters
        ----------
        imaging_extractors: list of ImagingExtractor
            List of imaging extractor objects, one per depth plane.

        Raises
        ------
        AssertionError
            If ``imaging_extractors`` is not a non-empty list of ImagingExtractor objects,
            or if the extractors disagree on any of the properties that must be shared.
        """
        super().__init__()
        assert isinstance(imaging_extractors, list), "Enter a list of ImagingExtractor objects as argument"
        assert all(
            isinstance(imaging_extractor, ImagingExtractor) for imaging_extractor in imaging_extractors
        ), "All elements of the list must be ImagingExtractor objects"
        self._check_consistency_between_imaging_extractors(imaging_extractors)
        self._imaging_extractors = imaging_extractors
        self._num_planes = len(imaging_extractors)
        self.is_volumetric = True

    @staticmethod
    def _check_consistency_between_imaging_extractors(imaging_extractors: list[ImagingExtractor]):
        """Check that essential properties are consistent between extractors so that they can be combined appropriately.

        Parameters
        ----------
        imaging_extractors: list of ImagingExtractor
            List of imaging extractor objects.

        Raises
        ------
        AssertionError
            If the list is empty or any of the properties are not consistent between extractors.

        Notes
        -----
        This method checks the following properties:
            - sampling frequency
            - image shape
            - channel names
            - data type
            - number of samples
        """
        # Without this check an empty list would fail below with a misleading
        # "sampling frequency is not consistent (found set())" message.
        assert len(imaging_extractors) > 0, "At least one ImagingExtractor is required to build a volume."

        def _to_hashable(value):
            # Lists/arrays must become tuples to go into a set; strings are Iterable but
            # already hashable, and converting them would split them into characters.
            if isinstance(value, Iterable) and not isinstance(value, str):
                return tuple(value)
            return value

        properties_to_check = dict(
            get_sampling_frequency="The sampling frequency",
            get_image_shape="The shape of a frame",
            get_channel_names="The name of the channels",
            get_dtype="The data type",
            get_num_samples="The number of samples",
        )
        with warnings.catch_warnings():
            # get_channel_names is deprecated; sub-extractors may warn about it. The user did
            # not call it, so don't surface that warning from this internal check.
            warnings.filterwarnings("ignore", message="get_channel_names is deprecated", category=FutureWarning)
            for method, property_message in properties_to_check.items():
                values = [getattr(extractor, method)() for extractor in imaging_extractors]
                unique_values = {_to_hashable(value) for value in values}
                assert (
                    len(unique_values) == 1
                ), f"{property_message} is not consistent over the files (found {unique_values})."

    def get_series(self, start_sample: int | None = None, end_sample: int | None = None) -> np.ndarray:
        """Get a contiguous series of volumes.

        Parameters
        ----------
        start_sample : int, optional
            Index of the first sample to return (inclusive). Negative values count back
            from the end. Defaults to 0.
        end_sample : int, optional
            Index one past the last sample to return (exclusive). Negative values count back
            from the end. Defaults to the number of samples.

        Returns
        -------
        series : numpy.ndarray
            Array of shape ``(end_sample - start_sample, *self.get_sample_shape())``; the last
            axis indexes the depth planes, in the order the extractors were given.

        Raises
        ------
        ValueError
            If the resolved range is empty or falls outside the available samples.
        """
        num_samples = self.get_num_samples()

        if start_sample is None:
            start_sample = 0
        elif start_sample < 0:
            # A negative index beyond -num_samples would otherwise wrap to a negative index
            # and be passed on to the per-plane extractors unchecked.
            if start_sample < -num_samples:
                raise ValueError(f"start_sample {start_sample} is out of range for {num_samples} samples")
            start_sample = num_samples + start_sample
        elif start_sample >= num_samples:
            raise ValueError(
                f"start_sample {start_sample} is greater than or equal to the number of samples {num_samples}"
            )

        if end_sample is None:
            end_sample = num_samples
        elif end_sample < 0:
            end_sample = num_samples + end_sample
        elif end_sample > num_samples:
            raise ValueError(f"end_sample {end_sample} is greater than the number of samples {num_samples}")

        # Also catches an end_sample that wrapped below zero, since start_sample >= 0 here.
        if end_sample <= start_sample:
            raise ValueError(f"end_sample {end_sample} is less than or equal to start_sample {start_sample}")

        series = np.zeros((end_sample - start_sample, *self.get_sample_shape()), self.get_dtype())
        for plane_index, imaging_extractor in enumerate(self._imaging_extractors):
            series[..., plane_index] = imaging_extractor.get_series(start_sample, end_sample)
        return series

    def get_image_shape(self) -> tuple[int, int]:
        """Get the shape of the video frame (num_rows, num_columns).

        Returns
        -------
        image_shape: tuple
            Shape of the video frame (num_rows, num_columns).
        """
        return self._imaging_extractors[0].get_image_shape()

    def get_num_planes(self) -> int:
        """Get the number of depth planes.

        Returns
        -------
        _num_planes: int
            The number of depth planes.
        """
        return self._num_planes

    def get_num_samples(self) -> int:
        """Get the number of samples (checked to be identical across planes at construction)."""
        return self._imaging_extractors[0].get_num_samples()

    def get_sampling_frequency(self) -> float:
        """Get the sampling frequency (checked to be identical across planes at construction)."""
        return self._imaging_extractors[0].get_sampling_frequency()

    def get_channel_names(self) -> list:
        """Get the channel names of the first plane's extractor.

        Deprecated: emits a FutureWarning on every call.
        """
        warnings.warn(
            "get_channel_names is deprecated and will be removed in May 2026 or after.",
            category=FutureWarning,
            stacklevel=2,
        )
        return self._imaging_extractors[0].get_channel_names()

    def get_dtype(self) -> DtypeType:
        """Get the data type (checked to be identical across planes at construction)."""
        return self._imaging_extractors[0].get_dtype()

    def get_volume_shape(self) -> tuple[int, int, int]:
        """Get the shape of the volumetric video (num_rows, num_columns, num_planes).

        Returns
        -------
        video_shape: tuple
            Shape of the volumetric video (num_rows, num_columns, num_planes).
        """
        image_shape = self.get_image_shape()
        return (image_shape[0], image_shape[1], self.get_num_planes())

    def depth_slice(self, start_plane: int | None = None, end_plane: int | None = None):
        """Return a new VolumetricImagingExtractor ranging from the start_plane to the end_plane.

        Parameters
        ----------
        start_plane : int, optional
            First plane to include (inclusive). Defaults to 0.
        end_plane : int, optional
            Plane at which to stop (exclusive). Defaults to the number of planes.

        Raises
        ------
        AssertionError
            If ``start_plane`` is not in ``[0, num_planes)`` or ``end_plane`` is not in
            ``(start_plane, num_planes]``.
        """
        start_plane = start_plane if start_plane is not None else 0
        end_plane = end_plane if end_plane is not None else self._num_planes
        assert (
            0 <= start_plane < self._num_planes
        ), f"'start_plane' ({start_plane}) must be greater than or equal to 0 and smaller than the number of planes ({self._num_planes})."
        assert (
            start_plane < end_plane <= self._num_planes
        ), f"'end_plane' ({end_plane}) must be greater than 'start_plane' ({start_plane}) and smaller than or equal to the number of planes ({self._num_planes})."
        return DepthSliceVolumetricImagingExtractor(parent_extractor=self, start_plane=start_plane, end_plane=end_plane)

    def slice_samples(self, start_sample: int | None = None, end_sample: int | None = None):
        """Return a new VolumetricImagingExtractor with a subset of samples.

        Raises
        ------
        NotImplementedError
            Always; sample slicing is not supported for volumetric extractors.
        """
        raise NotImplementedError(
            "slice_samples is not implemented for VolumetricImagingExtractor due to conflicts with get_series()."
        )

    def get_native_timestamps(
        self, start_sample: int | None = None, end_sample: int | None = None
    ) -> np.ndarray | None:
        """Get the native timestamps of the first plane's extractor for the requested sample range."""
        # Delegate to the first imaging extractor; sampling frequency and sample count are
        # checked for consistency across planes at construction.
        return self._imaging_extractors[0].get_native_timestamps(start_sample, end_sample)
class DepthSliceVolumetricImagingExtractor(VolumetricImagingExtractor):
    """Lazy view over a contiguous range of depth planes of a VolumetricImagingExtractor.

    Only meant for volumetric imaging data. Rather than instantiating this class directly,
    call ``.depth_slice(...)`` on a VolumetricImagingExtractor object.
    """

    extractor_name = "DepthSliceVolumetricImagingExtractor"

    def __init__(
        self,
        parent_extractor: VolumetricImagingExtractor,
        start_plane: int | None = None,
        end_plane: int | None = None,
    ):
        """Build a volumetric extractor from a subset of the parent's planes.

        The selection follows Python slice semantics, with the right bound excluded: the
        planes of the new extractor are numbered 0 through ``end_plane - start_plane - 1``.
        No data is read; the parent's per-plane extractors are reused as-is.

        Parameters
        ----------
        parent_extractor : VolumetricImagingExtractor
            Extractor whose planes are being subset.
        start_plane : int, optional
            First plane to keep (inclusive). When omitted, selection starts at the parent's
            first plane.
        end_plane : int, optional
            Plane at which selection stops (exclusive). When omitted, selection runs through
            the parent's last plane.
        """
        selected_planes = slice(start_plane, end_plane)
        plane_extractors = parent_extractor._imaging_extractors[selected_planes]
        super().__init__(imaging_extractors=plane_extractors)