Source code for roiextractors.imagingextractor
"""Base class definitions for all ImagingExtractors.
Classes
-------
ImagingExtractor
Abstract class that contains all the meta-data and input data from the imaging data.
"""
import warnings
from abc import ABC, abstractmethod
from copy import deepcopy
from math import prod
import numpy as np
from .core_utils import _convert_bytes_to_str, _convert_seconds_to_str
from .extraction_tools import ArrayType, DtypeType, FloatType
[docs]
class ImagingExtractor(ABC):
"""Abstract class that contains all the meta-data and input data from the imaging data."""
def __init__(self, *args, **kwargs) -> None:
"""Initialize the ImagingExtractor object."""
self._args = args
self._kwargs = kwargs
self._times = None
self.name = self.__class__.__name__
self.is_volumetric = False
def _repr_text(self) -> str:
"""Generate text representation of the ImagingExtractor object."""
num_samples = self.get_num_samples()
sample_shape = self.get_sample_shape()
dtype = self.get_dtype()
sf_hz = self.get_sampling_frequency()
# Format sampling frequency
if not sf_hz.is_integer():
sampling_frequency_repr = f"{sf_hz:f} Hz"
else:
sampling_frequency_repr = f"{sf_hz:0.1f}Hz"
# Calculate duration
duration = num_samples / sf_hz
duration_repr = _convert_seconds_to_str(duration)
# Calculate memory size using product of all dimensions in image_size
memory_size = num_samples * prod(sample_shape) * dtype.itemsize
memory_repr = _convert_bytes_to_str(memory_size)
# Format shape string based on whether data is volumetric or not
sample_shape_repr = f"{sample_shape[0]} rows x {sample_shape[1]} columns "
if self.is_volumetric:
sample_shape_repr += f"x {sample_shape[2]} planes"
return (
f"{self.name}\n"
f" Number of samples: {num_samples:,} \n"
f" Sample shape: {sample_shape_repr} \n"
f" Sampling rate: {sampling_frequency_repr}\n"
f" Duration: {duration_repr}\n"
f" Imaging data memory: {memory_repr} ({dtype} dtype)"
)
def __repr__(self):
return self._repr_text()
[docs]
@abstractmethod
def get_image_shape(self) -> tuple[int, int]:
"""Get the shape of the video frame (num_rows, num_columns).
Returns
-------
image_shape: tuple
Shape of the video frame (num_rows, num_columns).
"""
pass
[docs]
def get_frame_shape(self) -> tuple[int, int]:
"""Get the shape of the video frame (num_rows, num_columns).
Returns
-------
frame_shape: tuple
Shape of the video frame (num_rows, num_columns).
Notes
-----
This method is equivalent to get_image_shape()
"""
return self.get_image_shape()
[docs]
def get_sample_shape(self) -> tuple[int, int] | tuple[int, int, int]:
"""
Get the shape of a single sample elements from the series.
If the series is volumetric, the shape is the shape of the volume and otherwise
returns the shape of a single frame/image.
Returns
-------
sample_shape: tuple
Shape of a single sample from the time series (num_rows, num_columns, num_planes)
if volumetric (num_rows, num_columns) otherwise
"""
if self.is_volumetric:
return self.get_volume_shape()
else:
return self.get_frame_shape()
[docs]
def get_num_planes(self) -> int:
"""Get the number of depth planes.
Returns
-------
num_planes: int
The number of depth planes.
Raises
------
NotImplementedError
If the extractor is not volumetric.
"""
if not self.is_volumetric:
raise NotImplementedError(
"This extractor is not volumetric. "
"The get_num_planes method is only available for volumetric extractors."
)
raise NotImplementedError("This method must be implemented by volumetric extractor subclasses.")
[docs]
def get_volume_shape(self) -> tuple[int, int, int]:
"""Get the shape of the volumetric video (num_rows, num_columns, num_planes).
Returns
-------
video_shape: tuple
Shape of the volumetric video (num_rows, num_columns, num_planes).
Raises
------
NotImplementedError
If the extractor is not volumetric or does not implement get_num_planes method.
"""
if not self.is_volumetric:
raise NotImplementedError(
"This extractor is not volumetric. "
"The get_volume_shape method is only available for volumetric extractors."
)
if not hasattr(self, "get_num_planes") or not callable(getattr(self, "get_num_planes")):
raise NotImplementedError(
"This extractor does not implement get_num_planes method. "
"The get_volume_shape method requires the get_num_planes method to be implemented."
)
frame_shape = self.get_frame_shape()
return (frame_shape[0], frame_shape[1], self.get_num_planes())
[docs]
@abstractmethod
def get_num_samples(self) -> int:
"""Get the number of samples in the video.
Returns
-------
num_samples: int
Number of samples in the video.
"""
pass
[docs]
@abstractmethod
def get_sampling_frequency(self) -> float:
"""Get the sampling frequency in Hz.
Returns
-------
sampling_frequency: float
Sampling frequency in Hz.
"""
pass
[docs]
def get_channel_names(self) -> list:
"""Get the channel names in the recoding.
Returns
-------
channel_names: list
List of strings of channel names
"""
return [f"channel_{i}" for i in range(self.get_num_channels())]
[docs]
def get_dtype(self) -> DtypeType:
"""Get the data type of the video.
Returns
-------
dtype: dtype
Data type of the video.
"""
return self.get_series(start_sample=0, end_sample=2).dtype
[docs]
@abstractmethod
def get_series(self, start_sample: int | None = None, end_sample: int | None = None) -> np.ndarray:
"""Get the series of samples.
Parameters
----------
start_sample: int, optional
Start sample index (inclusive).
end_sample: int, optional
End sample index (exclusive).
Returns
-------
series: numpy.ndarray
The series of samples.
Notes
-----
Importantly, we follow the convention that the dimensions of the array are returned in their matrix order,
More specifically:
(time, height, width)
Which is equivalent to:
(samples, rows, columns)
For volumetric data, the dimensions are:
(time, height, width, planes)
Which is equivalent to:
(samples, rows, columns, planes)
Note that this does not match the cartesian convention:
(t, x, y)
Where x is the columns width or and y is the rows or height.
"""
pass
[docs]
def get_samples(self, sample_indices: ArrayType) -> np.ndarray:
"""Get specific samples from indices (not necessarily continuous).
Parameters
----------
sample_indices: array-like
Indices of samples to return.
Returns
-------
samples: numpy.ndarray
The samples.
"""
assert (
max(sample_indices) < self.get_num_samples()
), "'sample_indices' range beyond number of available samples!"
if np.all(np.diff(sample_indices) == 1):
return self.get_series(start_sample=sample_indices[0], end_sample=sample_indices[-1] + 1)
relative_indices = np.array(sample_indices) - sample_indices[0]
series = self.get_series(start_sample=sample_indices[0], end_sample=sample_indices[-1] + 1)
return series[relative_indices]
[docs]
@abstractmethod
def get_native_timestamps(
self, start_sample: int | None = None, end_sample: int | None = None
) -> np.ndarray | None:
"""
Retrieve the original unaltered timestamps for the data in this interface.
This function should retrieve the data on-demand by re-initializing the IO.
Can be overridden to return None if the extractor does not have native timestamps.
Parameters
----------
start_sample : int, optional
The starting sample index. If None, starts from the beginning.
end_sample : int, optional
The ending sample index. If None, goes to the end.
Returns
-------
timestamps: numpy.ndarray or None
The timestamps for the data stream, or None if native timestamps are not available.
"""
return None
[docs]
def get_timestamps(self, start_sample: int | None = None, end_sample: int | None = None) -> np.ndarray:
"""
Retrieve the timestamps for the data in this extractor.
Parameters
----------
start_sample : int, optional
The starting sample index. If None, starts from the beginning.
end_sample : int, optional
The ending sample index. If None, goes to the end.
Returns
-------
timestamps: numpy.ndarray
The timestamps for the data stream.
"""
# Set defaults
if start_sample is None:
start_sample = 0
if end_sample is None:
end_sample = self.get_num_samples()
# Return cached timestamps if available
if self._times is not None:
return self._times[start_sample:end_sample]
# See if native timetstamps are available from the format
native_timestamps = self.get_native_timestamps()
if native_timestamps is not None:
self._times = native_timestamps # Cache the native timestamps
return native_timestamps[start_sample:end_sample]
# Fallback to calculated timestamps from sampling frequency
sample_indices = np.arange(start_sample, end_sample)
return sample_indices / self.get_sampling_frequency()
[docs]
def time_to_sample_indices(self, times: FloatType | ArrayType) -> FloatType | np.ndarray:
"""Convert user-inputted times (in seconds) to sample indices.
Parameters
----------
times: float or array-like
The times (in seconds) to be converted to sample indices.
Returns
-------
sample_indices: float or array-like
The corresponding sample indices.
"""
# Default implementation
if self._times is None:
return np.round(times * self.get_sampling_frequency()).astype("int64")
else:
return np.searchsorted(self._times, times).astype("int64")
[docs]
def set_times(self, times: ArrayType) -> None:
"""Set the recording times (in seconds) for each frame.
Parameters
----------
times: array-like
The times in seconds for each frame.
Raises
------
ValueError
If the length of 'times' does not match the number of samples.
"""
num_samples = self.get_num_samples()
num_timestamps = len(times)
if num_timestamps != num_samples:
raise ValueError(
f"Mismatch between the number of samples and timestamps: "
f"{num_samples} samples, but {num_timestamps} timestamps provided. "
"Ensure the length of 'times' matches the number of samples."
)
self._times = np.array(times).astype("float64", copy=False)
[docs]
def has_time_vector(self) -> bool:
"""Detect if the ImagingExtractor has a time vector set or not.
Returns
-------
has_times: bool
True if the ImagingExtractor has a time vector set, otherwise False.
"""
return self._times is not None
[docs]
def copy_times(self, extractor) -> None:
"""Copy times from another extractor.
Parameters
----------
extractor
The extractor from which the times will be copied.
"""
if extractor._times is not None:
self.set_times(deepcopy(extractor._times))
[docs]
def slice_samples(self, start_sample: int | None = None, end_sample: int | None = None):
"""Return a new ImagingExtractor ranging from the start_sample to the end_sample.
Parameters
----------
start_sample: int, optional
Start sample index (inclusive).
end_sample: int, optional
End sample index (exclusive).
Returns
-------
imaging: SampleSlicedImagingExtractor
The sliced ImagingExtractor object.
"""
return SampleSlicedImagingExtractor(parent_imaging=self, start_sample=start_sample, end_sample=end_sample)
[docs]
def slice_field_of_view(
self,
row_start: int | None = None,
row_end: int | None = None,
column_start: int | None = None,
column_end: int | None = None,
):
"""Return a new ImagingExtractor with a spatially sliced field of view.
Parameters
----------
row_start: int, optional
Starting row index (inclusive). Default is 0.
row_end: int, optional
Ending row index (exclusive). Default is the full height.
column_start: int, optional
Starting column index (inclusive). Default is 0.
column_end: int, optional
Ending column index (exclusive). Default is the full width.
Returns
-------
imaging: FieldOfViewSlicedImagingExtractor
The spatially sliced ImagingExtractor object.
Notes
-----
This method creates a lazy view of the imaging data with a cropped field of view.
The slicing is applied to the spatial dimensions (rows and columns) of each frame.
For volumetric data, all depth planes are preserved with the same spatial crop applied.
Examples
--------
>>> # Crop to center 100x100 region starting at (50, 50)
>>> cropped_extractor = extractor.slice_field_of_view(row_start=50, row_end=150,
... column_start=50, column_end=150)
>>>
>>> # Get top-left quadrant
>>> height, width = extractor.get_image_shape()
>>> quadrant = extractor.slice_field_of_view(row_end=height//2, column_end=width//2)
>>>
>>> # Compose with temporal slicing
>>> subset = extractor.slice_samples(0, 1000).slice_field_of_view(100, 200, 100, 200)
"""
return _FieldOfViewSlicedImagingExtractor(
parent_imaging=self,
row_start=row_start,
row_end=row_end,
column_start=column_start,
column_end=column_end,
)
[docs]
class SampleSlicedImagingExtractor(ImagingExtractor):
"""Class to get a lazy sample slice.
Do not use this class directly but use `.slice_samples(...)` on an ImagingExtractor object.
"""
extractor_name = "SampleSlicedImagingExtractor"
def __init__(
self, parent_imaging: ImagingExtractor, start_sample: int | None = None, end_sample: int | None = None
):
"""Initialize an ImagingExtractor whose samples subset the parent.
Subset is exclusive on the right bound, that is, the indexes of this ImagingExtractor range over
[0, ..., end_sample-start_sample-1], which is used to resolve the index mapping in `get_samples(sample_indices=[...])`.
Parameters
----------
parent_imaging : ImagingExtractor
The ImagingExtractor object to subset the samples of.
start_sample : int, optional
The left bound of the samples to subset.
The default is the start sample of the parent.
end_sample : int, optional
The right bound of the samples, exclusively, to subset.
The default is end sample of the parent.
"""
self._parent_imaging = parent_imaging
self._start_sample = start_sample
self._end_sample = end_sample
self._num_samples = self._end_sample - self._start_sample
parent_size = self._parent_imaging.get_num_samples()
if start_sample is None:
start_sample = 0
else:
assert 0 <= start_sample < parent_size
if end_sample is None:
end_sample = parent_size
else:
assert 0 < end_sample <= parent_size
assert end_sample > start_sample, "'start_sample' must be smaller than 'end_sample'!"
super().__init__()
if getattr(self._parent_imaging, "_times") is not None:
self._times = self._parent_imaging._times[start_sample:end_sample]
# Inherit volumetric properties from parent
self.is_volumetric = self._parent_imaging.is_volumetric
[docs]
def get_samples(self, sample_indices: ArrayType) -> np.ndarray:
assert max(sample_indices) < self._num_samples, "'sample_indices' range beyond number of available samples!"
mapped_sample_indices = np.array(sample_indices) + self._start_sample
return self._parent_imaging.get_samples(sample_indices=mapped_sample_indices)
[docs]
def get_series(self, start_sample: int | None = None, end_sample: int | None = None) -> np.ndarray:
assert start_sample is None or start_sample >= 0, (
f"'start_sample' must be greater than or equal to zero! Received '{start_sample}'.\n"
"Negative slicing semantics are not supported."
)
start_sample_shifted = (start_sample or 0) + self._start_sample
end_sample_shifted = end_sample
if end_sample is not None:
end_sample_shifted = end_sample + self._start_sample
return self._parent_imaging.get_series(start_sample=start_sample_shifted, end_sample=end_sample_shifted)
[docs]
def get_image_shape(self) -> tuple[int, int]:
"""Get the shape of the video frame (num_rows, num_columns).
Returns
-------
image_shape: tuple
Shape of the video frame (num_rows, num_columns).
"""
return self._parent_imaging.get_image_shape()
[docs]
def get_sampling_frequency(self) -> float:
return self._parent_imaging.get_sampling_frequency()
[docs]
def get_channel_names(self) -> list:
warnings.warn(
"get_channel_names is deprecated and will be removed in May 2026 or after.",
category=FutureWarning,
stacklevel=2,
)
return self._parent_imaging.get_channel_names()
[docs]
def get_num_planes(self) -> int:
"""Get the number of depth planes.
Returns
-------
num_planes: int
The number of depth planes.
Raises
------
NotImplementedError
If the parent extractor is not volumetric.
"""
if not self.is_volumetric:
raise NotImplementedError(
"This extractor is not volumetric. "
"The get_num_planes method is only available for volumetric extractors."
)
return self._parent_imaging.get_num_planes()
[docs]
def get_native_timestamps(
self, start_sample: int | None = None, end_sample: int | None = None
) -> np.ndarray | None:
# Get the full original timestamps from parent, but return only our slice range
if start_sample is None:
start_sample = 0
if end_sample is None:
end_sample = self.get_num_samples()
# Map relative indices to absolute indices in the parent
actual_start = self._start_sample + start_sample
actual_end = self._start_sample + end_sample
# Get timestamps from parent for our specific range
return self._parent_imaging.get_native_timestamps(start_sample=actual_start, end_sample=actual_end)
[docs]
def has_time_vector(self) -> bool:
# Override to check parent imaging for time vector
return self._parent_imaging.has_time_vector()
class _FieldOfViewSlicedImagingExtractor(ImagingExtractor):
"""Class to get a lazy field of view slice.
Do not use this class directly but use `.slice_field_of_view(...)` on an ImagingExtractor object.
"""
extractor_name = "FieldOfViewSlicedImagingExtractor"
def __init__(
self,
parent_imaging: ImagingExtractor,
row_start: int | None = None,
row_end: int | None = None,
column_start: int | None = None,
column_end: int | None = None,
):
"""Initialize an ImagingExtractor with a spatially sliced field of view.
Parameters
----------
parent_imaging : ImagingExtractor
The ImagingExtractor object to spatially slice.
row_start : int, optional
The starting row index (inclusive). Default is 0.
row_end : int, optional
The ending row index (exclusive). Default is the full height.
column_start : int, optional
The starting column index (inclusive). Default is 0.
column_end : int, optional
The ending column index (exclusive). Default is the full width.
"""
self._parent_imaging = parent_imaging
# Get parent image shape to determine defaults
parent_height, parent_width = parent_imaging.get_image_shape()
# Set defaults and validate
if row_start is None:
row_start = 0
if row_end is None:
row_end = parent_height
if column_start is None:
column_start = 0
if column_end is None:
column_end = parent_width
# Validation
assert (
0 <= row_start < parent_height
), f"'row_start' ({row_start}) must be >= 0 and < parent height ({parent_height})"
assert 0 < row_end <= parent_height, f"'row_end' ({row_end}) must be > 0 and <= parent height ({parent_height})"
assert row_end > row_start, f"'row_end' ({row_end}) must be greater than 'row_start' ({row_start})"
assert (
0 <= column_start < parent_width
), f"'column_start' ({column_start}) must be >= 0 and < parent width ({parent_width})"
assert (
0 < column_end <= parent_width
), f"'column_end' ({column_end}) must be > 0 and <= parent width ({parent_width})"
assert (
column_end > column_start
), f"'column_end' ({column_end}) must be greater than 'column_start' ({column_start})"
# Store slices
self._row_slice = slice(row_start, row_end)
self._column_slice = slice(column_start, column_end)
# Calculate new image shape
self._image_shape = (row_end - row_start, column_end - column_start)
super().__init__()
# Copy time vector from parent if it exists
if getattr(self._parent_imaging, "_times") is not None:
self._times = self._parent_imaging._times
# Inherit volumetric properties from parent
self.is_volumetric = self._parent_imaging.is_volumetric
def get_image_shape(self) -> tuple[int, int]:
"""Get the shape of the spatially sliced video frame (num_rows, num_columns).
Returns
-------
image_shape: tuple
Shape of the sliced video frame (num_rows, num_columns).
"""
return self._image_shape
def get_series(self, start_sample: int | None = None, end_sample: int | None = None) -> np.ndarray:
"""Get the spatially sliced series of samples.
Parameters
----------
start_sample: int, optional
Start sample index (inclusive).
end_sample: int, optional
End sample index (exclusive).
Returns
-------
series: numpy.ndarray
The spatially sliced series with shape (samples, rows, columns) for 2D data
or (samples, rows, columns, planes) for volumetric data.
"""
# Get full data from parent
data = self._parent_imaging.get_series(start_sample=start_sample, end_sample=end_sample)
# Apply spatial slicing
if self.is_volumetric:
# Volumetric data: (time, height, width, planes)
return data[:, self._row_slice, self._column_slice, :]
else:
# 2D data: (time, height, width)
return data[:, self._row_slice, self._column_slice]
def get_samples(self, sample_indices: ArrayType) -> np.ndarray:
"""Get specific spatially sliced samples from indices.
Parameters
----------
sample_indices: array-like
Indices of samples to return.
Returns
-------
samples: numpy.ndarray
The spatially sliced samples.
"""
# Get samples from parent
data = self._parent_imaging.get_samples(sample_indices=sample_indices)
# Apply spatial slicing
if self.is_volumetric:
return data[:, self._row_slice, self._column_slice, :]
else:
return data[:, self._row_slice, self._column_slice]
def get_num_samples(self) -> int:
"""Get the number of samples.
Returns
-------
num_samples: int
Number of samples in the video.
"""
return self._parent_imaging.get_num_samples()
def get_sampling_frequency(self) -> float:
"""Get the sampling frequency in Hz.
Returns
-------
sampling_frequency: float
Sampling frequency in Hz.
"""
return self._parent_imaging.get_sampling_frequency()
def get_channel_names(self) -> list:
"""Get the channel names.
Returns
-------
channel_names: list
List of strings of channel names.
"""
return self._parent_imaging.get_channel_names()
def get_num_planes(self) -> int:
"""Get the number of depth planes.
Returns
-------
num_planes: int
The number of depth planes.
Raises
------
NotImplementedError
If the parent extractor is not volumetric.
"""
if not self.is_volumetric:
raise NotImplementedError(
"This extractor is not volumetric. "
"The get_num_planes method is only available for volumetric extractors."
)
return self._parent_imaging.get_num_planes()
def get_native_timestamps(
self, start_sample: int | None = None, end_sample: int | None = None
) -> np.ndarray | None:
"""Get native timestamps from parent (FOV slicing doesn't affect temporal metadata).
Parameters
----------
start_sample : int, optional
The starting sample index.
end_sample : int, optional
The ending sample index.
Returns
-------
timestamps: numpy.ndarray or None
The timestamps for the data stream.
"""
return self._parent_imaging.get_native_timestamps(start_sample=start_sample, end_sample=end_sample)
def has_time_vector(self) -> bool:
"""Check if parent has time vector.
Returns
-------
has_times: bool
True if the parent ImagingExtractor has a time vector set.
"""
return self._parent_imaging.has_time_vector()