"""Various tools for extraction of ROIs from imaging data.
Classes
-------
VideoStructure
A data class for specifying the structure of a video.
"""
import importlib.util
import sys
from dataclasses import dataclass
from functools import wraps
from pathlib import Path
from platform import python_version
from types import ModuleType
import h5py
import lazy_ops
import numpy as np
import zarr
from numpy.typing import ArrayLike, DTypeLike
from packaging import version
from tqdm import tqdm
# Type aliases for annotations.
ArrayType = ArrayLike  # alias of numpy.typing.ArrayLike
PathType = str | Path  # a path given either as a string or as a pathlib.Path
NumpyArray = np.ndarray  # alias of numpy.ndarray
DtypeType = DTypeLike  # alias of numpy.typing.DTypeLike
IntType = int | np.integer  # a Python int or a numpy integer scalar
FloatType = float  # alias of the builtin float
[docs]
def calculate_regular_series_rate(series: ArrayType, tolerance_decimals: int = 6) -> float | None:
"""Calculate the rate of a regular series from consecutive differences.
If all differences between consecutive points are the same (within rounding tolerance),
returns the rate as `1.0 / interval`. Otherwise returns None.
Parameters
----------
series : array-like
Array of timestamps or time points.
tolerance_decimals : int, default: 6
Number of decimal places for rounding when checking uniformity.
Returns
-------
float | None
The calculated rate if the series is regular, None otherwise.
"""
diff_ts = np.diff(series)
rounded_diff_ts = diff_ts.round(decimals=tolerance_decimals)
uniq_diff_ts = np.unique(rounded_diff_ts)
rate = 1.0 / diff_ts[0] if len(uniq_diff_ts) == 1 else None
return rate
[docs]
def raise_multi_channel_or_depth_not_implemented(extractor_name: str):
    """Raise a NotImplementedError for an extractor that does not support multiple channels or depth (z-axis).

    Parameters
    ----------
    extractor_name : str
        Name of the extractor without the ``Extractor`` suffix; the suffix is appended in the message.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError(
        f"The {extractor_name}Extractor does not currently support multiple color channels or 3-dimensional depth. "
        "If you wish to request either of these features, please do so by raising an issue at "
        "https://github.com/catalystneuro/roiextractors/issues"
    )
[docs]
@dataclass
class VideoStructure:
    """A data class for specifying the structure of a video.

    The role of the data class is to ensure consistency in naming and provide some initial
    consistency checks to ensure the validity of the structure.

    Attributes
    ----------
    num_rows : int
        The number of rows of each frame as a matrix.
    num_columns : int
        The number of columns of each frame as a matrix.
    num_channels : int
        The number of channels (1 for grayscale, 3 for color).
    rows_axis : int
        The axis or dimension corresponding to the rows.
    columns_axis : int
        The axis or dimension corresponding to the columns.
    channels_axis : int
        The axis or dimension corresponding to the channels.
    frame_axis : int
        The axis or dimension corresponding to the frames in the video.

    After initialization the instance additionally exposes ``frame_shape`` (the shape of a single
    frame, i.e. the video shape with the frame axis removed) and ``number_of_pixels_per_frame``
    (the product of ``frame_shape``).

    As an example if you wanted to build the structure for a video with gray (n_channels=1) frames of 10 x 5
    where the video is to have the following shape (num_frames, num_rows, num_columns, num_channels) you
    could define the class this way:

    >>> from roiextractors.extraction_tools import VideoStructure
    >>> num_rows = 10
    >>> num_columns = 5
    >>> num_channels = 1
    >>> frame_axis = 0
    >>> rows_axis = 1
    >>> columns_axis = 2
    >>> channels_axis = 3
    >>> video_structure = VideoStructure(
    ...     num_rows=num_rows,
    ...     num_columns=num_columns,
    ...     num_channels=num_channels,
    ...     rows_axis=rows_axis,
    ...     columns_axis=columns_axis,
    ...     channels_axis=channels_axis,
    ...     frame_axis=frame_axis,
    ... )
    """

    num_rows: int
    num_columns: int
    num_channels: int
    rows_axis: int
    columns_axis: int
    channels_axis: int
    frame_axis: int

    def __post_init__(self) -> None:
        """Validate the structure of the video and initialize the shape of the frame."""
        self._validate_video_structure()
        self._initialize_frame_shape()
        self.number_of_pixels_per_frame = np.prod(self.frame_shape)

    def _initialize_frame_shape(self) -> None:
        """Initialize the shape of the frame (the video shape without the frame axis)."""
        shape = [None, None, None, None]
        shape[self.rows_axis] = self.num_rows
        shape[self.columns_axis] = self.num_columns
        shape[self.channels_axis] = self.num_channels
        # Dropping the frame axis leaves the three per-frame dimensions in video order.
        shape.pop(self.frame_axis)
        self.frame_shape = tuple(shape)

    def _validate_video_structure(self) -> None:
        """Validate the structure of the video.

        Raises
        ------
        ValueError
            If the four axes are not distinct or any of them lies outside 0..3 (inclusive).
        """
        exception_message = (
            "Invalid structure: "
            f"{self.__repr__()}, "
            "each property axis should be unique value between 0 and 3 (inclusive)"
        )

        axis_values = {self.rows_axis, self.columns_axis, self.channels_axis, self.frame_axis}
        if len(axis_values) != 4:
            raise ValueError(exception_message)

        # A video has exactly four dimensions, so the last valid axis index is 3.
        if any(axis < 0 or axis > 3 for axis in axis_values):
            raise ValueError(exception_message)

    def build_video_shape(self, n_frames: int) -> tuple[int, int, int, int]:
        """Build the shape of the video from class attributes.

        Parameters
        ----------
        n_frames : int
            The number of frames in the video.

        Returns
        -------
        Tuple[int, int, int, int]
            The shape of the video.

        Notes
        -----
        The class attributes frame_axis, rows_axis, columns_axis and channels_axis are used to determine the order of the
        dimensions in the returned tuple.
        """
        video_shape = [None] * 4
        video_shape[self.frame_axis] = n_frames
        video_shape[self.rows_axis] = self.num_rows
        video_shape[self.columns_axis] = self.num_columns
        video_shape[self.channels_axis] = self.num_channels
        return tuple(video_shape)
[docs]
def read_numpy_memmap_video(
    file_path: PathType, video_structure: VideoStructure, dtype: DtypeType, offset: int = 0
) -> np.array:
    """Open a binary video file as a read-only numpy memmap.

    The number of frames is inferred from the file size: the bytes remaining after ``offset``
    are floor-divided by the size in bytes of a single frame.

    Parameters
    ----------
    file_path : PathType
        The file_path where the data resides.
    video_structure : VideoStructure
        A VideoStructure instance describing the structure of the video to read. This includes parameters
        such as the number of rows, columns and channels plus which axis (i.e. dimension) of the
        image corresponds to each of them.

        As an example you create one of these structures in the following way:

        from roiextractors.extraction_tools import VideoStructure

        num_rows = 10
        num_columns = 5
        num_channels = 3
        frame_axis = 0
        rows_axis = 1
        columns_axis = 2
        channels_axis = 3

        video_structure = VideoStructure(
            num_rows=num_rows,
            num_columns=num_columns,
            num_channels=num_channels,
            rows_axis=rows_axis,
            columns_axis=columns_axis,
            channels_axis=channels_axis,
            frame_axis=frame_axis,
        )
    dtype : DtypeType
        The type of the data to be loaded (int, float, etc.)
    offset : int, optional
        The offset in bytes. Usually corresponds to the number of bytes occupied by the header. 0 by default.

    Returns
    -------
    video_memap: np.array
        A numpy memmap pointing to the video.
    """
    total_bytes = Path(file_path).stat().st_size
    bytes_per_frame = video_structure.number_of_pixels_per_frame * np.dtype(dtype).itemsize
    # Floor division: any trailing bytes that do not fill a whole frame are ignored.
    frame_count = (total_bytes - offset) // bytes_per_frame

    return np.memmap(
        file_path,
        offset=offset,
        dtype=dtype,
        mode="r",
        shape=video_structure.build_video_shape(n_frames=frame_count),
    )
def _pixel_mask_extractor(image_mask_, _roi_ids) -> list:
"""Convert image mask to pixel mask.
Pixel masks are an alternative data format for storage of image masks which relies on the sparsity of the images.
The location and weight of each non-zero pixel is stored for each mask.
Parameters
----------
image_mask_: numpy.ndarray
Dense representation of the ROIs with shape (number_of_rows, number_of_columns, number_of_rois).
_roi_ids: list
List of roi ids with length number_of_rois.
Returns
-------
pixel_masks: list
List of length number of rois, each element is a 2-D array with shape (number_of_non_zero_pixels, 3).
Columns 1 and 2 are the x and y coordinates of the pixel, while the third column represents the weight of
the pixel.
"""
pixel_mask_list = []
for i, roiid in enumerate(_roi_ids):
image_mask = np.array(image_mask_[:, :, i])
_locs = np.where(image_mask > 0)
_pix_values = image_mask[image_mask > 0]
pixel_mask_list.append(np.vstack((_locs[0], _locs[1], _pix_values)).T)
return pixel_mask_list
def _image_mask_extractor(pixel_mask, _roi_ids, image_shape) -> np.ndarray:
"""Convert a pixel mask to image mask.
Parameters
----------
pixel_mask: list
list of pixel masks (no pixels X 3)
_roi_ids: list
list of roi ids with length number_of_rois
image_shape: array_like
shape of the image (number_of_rows, number_of_columns)
Returns
-------
image_mask: np.ndarray
Dense representation of the ROIs with shape (number_of_rows, number_of_columns, number_of_rois).
"""
image_mask = np.zeros(list(image_shape) + [len(_roi_ids)])
for no, rois in enumerate(_roi_ids):
for y, x, wt in pixel_mask[rois]:
image_mask[int(y), int(x), no] = wt
return image_mask
[docs]
def check_get_frames_args(func):
    """Check the arguments of the get_frames function.

    This decorator allows the get_frames function to be queried with either
    an integer, slice or an array and handles a common return.

    Parameters
    ----------
    func: function
        The get_frames function, called as ``func(imaging, frame_idxs, channel)``.

    Returns
    -------
    corrected_args: function
        The get_frames function with corrected arguments. When exactly one frame index is
        requested (an integer or a length-1 sequence) the first element of the result is
        returned; slices and longer sequences return the result unchanged.

    Raises
    ------
    AssertionError
        If 'frame_idxs' exceed the number of frames (not checked for slices).

    Deprecated
    ----------
    This function will be removed on or after June 2026.
    The get_frames method it decorates has been removed.
    """
    import warnings

    # Emitted once, at decoration time.
    warnings.warn(
        "check_get_frames_args() is deprecated and will be removed on or after June 2026. "
        "The get_frames method has been removed.",
        FutureWarning,
        stacklevel=2,
    )

    @wraps(func)
    def corrected_args(imaging, frame_idxs, channel=0):
        channel = int(channel)

        if isinstance(frame_idxs, slice):
            # A slice has no len() and is not bounds-checked; forward it as-is.
            return func(imaging, frame_idxs, channel)

        if isinstance(frame_idxs, (int, np.integer)):
            frame_idxs = [frame_idxs]
        frame_idxs = np.array(frame_idxs)

        # Raised explicitly so the check survives running Python with -O.
        if not np.all(frame_idxs < imaging.get_num_samples()):
            raise AssertionError("'frame_idxs' exceed number of frames")

        frames = func(imaging, frame_idxs, channel)
        if len(frame_idxs) == 1:
            return frames[0]
        return frames

    return corrected_args
def _cast_start_end_frame(start_frame, end_frame):
"""Cast start and end frame to int or None.
Parameters
----------
start_frame: int, float, None
The start frame.
end_frame: int, float, None
The end frame.
Returns
-------
start_frame: int, None
The start frame.
end_frame: int, None
The end frame.
Raises
------
ValueError
If start_frame is not an int, float or None.
ValueError
If end_frame is not an int, float or None.
"""
if isinstance(start_frame, float):
start_frame = int(start_frame)
elif isinstance(start_frame, (int, np.integer, type(None))):
start_frame = start_frame
else:
raise ValueError("start_frame must be an int, float (not infinity), or None")
if isinstance(end_frame, float) and np.isfinite(end_frame):
end_frame = int(end_frame)
elif isinstance(end_frame, (int, np.integer, type(None))):
end_frame = end_frame
# else end_frame is infinity (accepted for get_unit_spike_train)
if start_frame is not None:
start_frame = int(start_frame)
if end_frame is not None and np.isfinite(end_frame):
end_frame = int(end_frame)
return start_frame, end_frame
[docs]
def check_get_videos_args(func):
    """Normalise and check the frame-range arguments of the get_videos function.

    The returned wrapper fills in missing bounds (``start_frame`` defaults to 0 and
    ``end_frame`` to the number of samples), resolves negative bounds relative to the
    number of samples, casts the bounds and the channel to int and then calls ``func``
    with keyword arguments.

    Parameters
    ----------
    func: function
        The get_videos function.

    Returns
    -------
    corrected_args: function
        The get_videos function with corrected arguments.

    Raises
    ------
    Exception
        If 'start_frame' or 'end_frame' exceeds the number of frames.
    AssertionError
        If 'start_frame' is not less than 'end_frame'.

    Deprecated
    ----------
    This function will be removed on or after January 2026.
    The get_video method it decorates has been removed.
    """
    import warnings

    # Emitted once, at decoration time.
    warnings.warn(
        "check_get_videos_args() is deprecated and will be removed on or after January 2026. "
        "The get_video method has been removed.",
        FutureWarning,
        stacklevel=2,
    )

    @wraps(func)
    def corrected_args(imaging, start_frame=None, end_frame=None, channel=0):
        if start_frame is None:
            start_frame = 0
        elif start_frame > imaging.get_num_samples():
            raise Exception(f"'start_frame' exceeds number of frames {imaging.get_num_samples()}!")
        elif start_frame < 0:
            # Negative bounds count back from the end.
            start_frame = imaging.get_num_samples() + start_frame

        if end_frame is None:
            end_frame = imaging.get_num_samples()
        elif end_frame > imaging.get_num_samples():
            raise Exception(f"'end_frame' exceeds number of frames {imaging.get_num_samples()}!")
        elif end_frame < 0:
            end_frame = imaging.get_num_samples() + end_frame

        assert end_frame - start_frame > 0, "'start_frame' must be less than 'end_frame'!"

        start_frame, end_frame = _cast_start_end_frame(start_frame, end_frame)
        channel = int(channel)
        return func(imaging, start_frame=start_frame, end_frame=end_frame, channel=channel)

    return corrected_args
# TODO will be moved eventually, but for now it's very handy :)
[docs]
def show_video(imaging, ax=None):
    """Show video as animation.

    Parameters
    ----------
    imaging: ImagingExtractor
        The imaging extractor object whose samples are displayed.
    ax: matplotlib axis
        Axis to plot the video. If None, a new figure and axis are created.

    Returns
    -------
    anim: matplotlib.animation.FuncAnimation
        Animation of the video.
    """
    import matplotlib.animation as animation
    import matplotlib.pyplot as plt

    def animate_func(i, imaging, im, ax):
        # Show the sample index as title and swap in the i-th sample.
        ax.set_title(f"{i}")
        im.set_array(imaging.get_samples(sample_indices=[i])[0])
        return [im]

    if ax is None:
        fig = plt.figure(figsize=(5, 5))
        ax = fig.add_subplot(111)
    else:
        # Animate on the figure that owns the caller-supplied axis.
        fig = ax.get_figure()

    im0 = imaging.get_samples(sample_indices=[0])[0]
    im = ax.imshow(im0, interpolation="none", aspect="auto", vmin=0, vmax=1)
    # Delay between frames in milliseconds, derived from the sampling frequency.
    interval = 1 / imaging.get_sampling_frequency() * 1000
    anim = animation.FuncAnimation(
        fig,
        animate_func,
        frames=imaging.get_num_samples(),
        fargs=(imaging, im, ax),
        interval=interval,
        blit=False,
    )
    return anim
[docs]
def check_keys(dict_: dict) -> dict:
    """Replace top-level mat-objects in a dictionary by nested dictionaries.

    Every value that is a ``scipy.io.matlab.mat_struct`` is converted with ``todict``.
    The dictionary is modified in place and also returned.

    Parameters
    ----------
    dict_: dict
        Dictionary to check.

    Returns
    -------
    dict: dict
        The same dictionary, with mat-objects converted to nested dictionaries.

    Raises
    ------
    ImportError
        If scipy cannot be imported.
    """
    from scipy.io.matlab import mat_struct

    # Only existing keys are reassigned, so the dictionary size never changes mid-iteration.
    for key, value in dict_.items():
        if isinstance(value, mat_struct):
            dict_[key] = todict(value)
    return dict_
[docs]
def todict(matobj):
    """Recursively build nested dictionaries from a mat-object.

    Only the attributes listed in ``matobj._fieldnames`` are exported; values that are
    themselves ``mat_struct`` instances are converted recursively.

    Parameters
    ----------
    matobj: mat_struct
        Matlab object to convert to nested dictionary.

    Returns
    -------
    dict: dict
        Dictionary with mat-objects converted to nested dictionaries.
    """
    from scipy.io.matlab import mat_struct

    converted = {}
    for field_name in matobj._fieldnames:
        value = matobj.__dict__[field_name]
        converted[field_name] = todict(value) if isinstance(value, mat_struct) else value
    return converted
[docs]
def get_package(
package_name: str,
installation_instructions: str | None = None,
excluded_platforms_and_python_versions: dict[str, list[str]] | None = None,
) -> ModuleType:
"""Check if package is installed and return module if so.
Otherwise, raise informative error describing how to perform the installation.
Inspired by https://docs.python.org/3/library/importlib.html#checking-if-a-module-can-be-imported.
Parameters
----------
package_name : str
Name of the package to be imported.
installation_instructions : str, optional
String describing the source, options, and alias of package name (if needed) for installation.
For example,
>>> installation_source = "conda install -c conda-forge my-package-name"
Defaults to f"pip install {package_name}".
excluded_platforms_and_python_versions : dict mapping string platform names to a list of string versions, optional
In case some combinations of platforms or Python versions are not allowed for the given package, specify
this dictionary to raise a more specific error to that issue.
For example, `excluded_platforms_and_python_versions = dict(darwin=["3.7"])` will raise an informative error
when running on MacOS with Python version 3.7.
Allows all platforms and Python versions used by default.
Raises
------
ModuleNotFoundError
If the package is not installed.
"""
installation_instructions = installation_instructions or f"pip install {package_name}"
excluded_platforms_and_python_versions = excluded_platforms_and_python_versions or dict()
if package_name in sys.modules:
return sys.modules[package_name]
if importlib.util.find_spec(package_name) is not None:
return importlib.import_module(name=package_name)
for excluded_version in excluded_platforms_and_python_versions.get(sys.platform, list()):
if version.parse(python_version()).minor == version.parse(excluded_version).minor:
raise ModuleNotFoundError(
f"\nThe package '{package_name}' is not available on the {sys.platform} platform for "
f"Python version {excluded_version}!"
)
raise ModuleNotFoundError(
f"\nThe required package'{package_name}' is not installed!\n"
f"To install this package, please run\n\n\t{installation_instructions}\n"
)