Source code for roiextractors.segmentationextractor
"""Base segmentation extractors.
Classes
-------
SegmentationExtractor
Abstract class that contains all the meta-data and output data from the ROI segmentation operation when applied to
the pre-processed data. It also contains methods to read from various data formats output from the
processing pipelines like SIMA, CaImAn, Suite2p, CNMF-E.
SampleSlicedSegmentationExtractor
Class to get a lazy sample slice.
"""
import warnings
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Literal
import numpy as np
from numpy.typing import ArrayLike
# TODO make public once API stabilizes.
@dataclass
class _RoiResponse:
"""Represents a fluorescence response (trace) with its metadata."""
response_type: str
data: ArrayLike # Shape: (num_samples, num_rois)
roi_ids: list[str | int]
@dataclass
class _PropertyInfo:
"""Bundles property data with its metadata.
Keeps the property values and their descriptive metadata together as a single unit,
so they cannot drift out of sync when copied or passed around.
"""
data: np.ndarray
description: str = ""
class _ROIMasks:
"""Internal container for all ROI spatial representations in native NWB-compatible format.
Stores all ROI masks (cells + background/neuropil) together with their ID mapping.
The representation format matches NWB standards for efficient reading/writing.
Note: This is a private class. Users should access ROI masks through SegmentationExtractor
methods like get_roi_image_masks() and get_roi_pixel_masks().
Attributes
----------
data : ArrayLike
Native format data:
- "nwb-image_mask": (height, width, n_rois) dense array, possibly lazy (DatasetView/h5py.Dataset)
- "nwb-pixel_mask": list of (n_pixels, 3) arrays with columns [y, x, weight]
- "nwb-voxel_mask": list of (n_voxels, 4) arrays with columns [y, x, z, weight]
mask_tpe : Literal["nwb-image_mask", "nwb-pixel_mask", "nwb-voxel_mask"]
Type of NWB-compatible representation.
field_of_view_shape : tuple[int, ...]
Shape of imaging FOV: (height, width) for 2D or (depth, height, width) for 3D.
roi_id_map : dict[str | int, int]
Maps ROI ID -> index in data structure.
- For dense: roi_id -> slice index along last axis
- For sparse lists: roi_id -> list index
Examples: {0: 0, 1: 1, "background0": 2, "background1": 3}
"""
def __init__(
self,
data: ArrayLike,
mask_tpe: Literal["nwb-image_mask", "nwb-pixel_mask", "nwb-voxel_mask"],
field_of_view_shape: tuple[int, ...],
roi_id_map: dict[str | int, int],
):
"""Initialize ROI representations container.
Parameters
----------
data : ArrayLike
ROI mask data in native format.
mask_tpe : Literal["nwb-image_mask", "nwb-pixel_mask", "nwb-voxel_mask"]
Format type following NWB conventions.
field_of_view_shape : tuple[int, ...]
Shape of the imaging field of view (height, width) or (depth, height, width).
roi_id_map : dict[str | int, int]
Mapping from ROI ID to index in data structure.
"""
self.data = data
self.mask_tpe = mask_tpe
self.field_of_view_shape = field_of_view_shape
self.roi_id_map = roi_id_map
@property
def is_volumetric(self) -> bool:
"""True if this is 3D volumetric data, False for 2D."""
return len(self.field_of_view_shape) == 3
@property
def num_rois(self) -> int:
"""Total number of ROIs in this container."""
return len(self.roi_id_map)
def get_roi_ids(self) -> list[str | int]:
"""Get all ROI IDs in this container.
Returns
-------
list[str | int]
List of all ROI IDs (cells + background).
"""
return list(self.roi_id_map.keys())
def get_roi_image_mask(self, roi_id: str | int) -> np.ndarray:
"""Get dense image mask for a single ROI.
Parameters
----------
roi_id : str | int
The ROI identifier.
Returns
-------
np.ndarray
Dense 2D or 3D array matching field_of_view_shape.
"""
index = self.roi_id_map[roi_id]
if self.mask_tpe == "nwb-image_mask":
# Extract slice from dense stack
return np.asarray(self.data[:, :, index])
elif self.mask_tpe == "nwb-pixel_mask":
# Convert sparse pixel list to dense
dense_mask = np.zeros(self.field_of_view_shape, dtype=np.float32)
pixel_data = self.data[index] # (n_pixels, 3): [y, x, weight]
if len(pixel_data) > 0:
y_coords = pixel_data[:, 0].astype(int)
x_coords = pixel_data[:, 1].astype(int)
weights = pixel_data[:, 2]
dense_mask[y_coords, x_coords] = weights
return dense_mask
elif self.mask_tpe == "nwb-voxel_mask":
# Convert sparse voxel list to dense 3D
dense_mask = np.zeros(self.field_of_view_shape, dtype=np.float32)
voxel_data = self.data[index] # (n_voxels, 4): [y, x, z, weight]
if len(voxel_data) > 0:
y_coords = voxel_data[:, 0].astype(int)
x_coords = voxel_data[:, 1].astype(int)
z_coords = voxel_data[:, 2].astype(int)
weights = voxel_data[:, 3]
dense_mask[y_coords, x_coords, z_coords] = weights
return dense_mask
def get_roi_pixel_mask(self, roi_id: str | int) -> np.ndarray:
"""Get sparse pixel mask for a single ROI.
Parameters
----------
roi_id : str | int
The ROI identifier.
Returns
-------
np.ndarray
Array with shape (n_pixels, 3) with columns [y, x, weight].
For 3D: (n_voxels, 4) with columns [y, x, z, weight].
"""
index = self.roi_id_map[roi_id]
if self.mask_tpe == "nwb-pixel_mask":
return np.asarray(self.data[index])
elif self.mask_tpe == "nwb-voxel_mask":
return np.asarray(self.data[index])
else:
# Convert dense to sparse
dense_mask = self.get_roi_image_mask(roi_id)
if self.is_volumetric:
# 3D case
y_coords, x_coords, z_coords = np.nonzero(dense_mask)
weights = dense_mask[y_coords, x_coords, z_coords]
return np.column_stack([y_coords, x_coords, z_coords, weights])
else:
# 2D case
y_coords, x_coords = np.nonzero(dense_mask)
weights = dense_mask[y_coords, x_coords]
return np.column_stack([y_coords, x_coords, weights])
[docs]
class SegmentationExtractor(ABC):
"""Abstract segmentation extractor class.
An abstract class that contains all the meta-data and output data from
the ROI segmentation operation when applied to the pre-processed data.
It also contains methods to read from various data formats output from the
output from the processing pipelines like SIMA, CaImAn, Suite2p, CNMF-E.
All the methods with @abstract decorator have to be defined by the
format specific classes that inherit from this.
"""
def __init__(self):
"""Create a new SegmentationExtractor for a specific data format (unique to each child SegmentationExtractor)."""
self._sampling_frequency = None
self._times = None
self._channel_names = ["OpticalChannel"]
self._num_planes = 1
self._roi_ids: list[str | int] | None = None
self._roi_responses: list[_RoiResponse] = []
self._summary_images = {}
self._roi_masks: _ROIMasks | None = None
self._properties: dict[str, _PropertyInfo] = {}
[docs]
def get_accepted_list(self) -> list:
"""Get a list of accepted ROI ids.
.. deprecated::
`get_accepted_list` is deprecated and will be removed in May 2026.
Use `get_property()` instead to access format-specific acceptance data.
Returns
-------
accepted_list: list
List of accepted ROI ids.
"""
warnings.warn(
"get_accepted_list is deprecated and will be removed in May 2026. "
"Use get_property() instead to access format-specific acceptance data.",
DeprecationWarning,
stacklevel=2,
)
# Default: all ROIs accepted
return list(self.get_roi_ids())
[docs]
def get_rejected_list(self) -> list:
"""Get a list of rejected ROI ids.
.. deprecated::
`get_rejected_list` is deprecated and will be removed in May 2026.
Use `get_property()` instead to access format-specific acceptance data.
Returns
-------
rejected_list: list
List of rejected ROI ids.
"""
warnings.warn(
"get_rejected_list is deprecated and will be removed in May 2026. "
"Use get_property() instead to access format-specific acceptance data.",
DeprecationWarning,
stacklevel=2,
)
# Default: no ROIs rejected
return []
[docs]
@abstractmethod
def get_native_timestamps(
self, start_sample: int | None = None, end_sample: int | None = None
) -> np.ndarray | None:
"""Get the original timestamps from the data source.
Parameters
----------
start_sample : int, optional
Start sample index (inclusive).
end_sample : int, optional
End sample index (exclusive).
Returns
-------
timestamps : np.ndarray or None
The original timestamps in seconds, or None if not available.
"""
return None
[docs]
@abstractmethod
def get_frame_shape(self) -> tuple[int, int]:
"""Get frame size of movie (height, width).
Returns
-------
frame_shape: array_like
2-D array: image height x image width
"""
pass
[docs]
def get_num_samples(self) -> int:
"""Get the number of samples in the recording (duration of recording).
Returns
-------
num_samples: int
Number of samples in the recording.
"""
for trace in self.get_traces_dict().values():
if trace is not None and len(trace.shape) > 0:
return trace.shape[0]
[docs]
def get_roi_locations(self, roi_ids=None) -> np.ndarray:
"""Get the locations of the Regions of Interest (ROIs).
.. deprecated::
`get_roi_locations` is deprecated and will be removed in or after September 2026.
Use `get_property("roi_centroids", roi_ids)` instead for centroid data
stored as a property.
Parameters
----------
roi_ids: array_like
A list or 1D array of ids of the ROIs. Length is the number of ROIs requested.
Returns
-------
roi_locs: numpy.ndarray
2-D array: 2 X no_ROIs. The pixel ids (x,y) where the centroid of the ROI is.
"""
warnings.warn(
"get_roi_locations is deprecated and will be removed in or after September 2026. "
"Use get_property('roi_centroids', roi_ids) instead.",
FutureWarning,
stacklevel=2,
)
if roi_ids is None:
roi_ids = self.get_roi_ids()
roi_location = np.zeros([2, len(roi_ids)], dtype="int")
for roi_index, roi_id in enumerate(roi_ids):
image_mask = self.get_roi_image_masks(roi_ids=[roi_id])
temp = np.where(image_mask == np.amax(image_mask))
roi_location[:, roi_index] = np.array([np.median(temp[0]), np.median(temp[1])]).T
return roi_location
[docs]
def get_roi_ids(self) -> list:
"""Get the list of ROI ids.
Returns
-------
roi_ids: list
List of roi ids.
"""
if self._roi_ids is not None:
return self._roi_ids
# For backward compatibility, only return cell ROIs (exclude background components)
if self._roi_masks is not None:
all_roi_ids = self._roi_masks.get_roi_ids()
cell_roi_ids = [rid for rid in all_roi_ids if not str(rid).startswith("background")]
return cell_roi_ids
return list(range(self.get_num_rois()))
[docs]
def get_roi_image_masks(self, roi_ids=None) -> np.ndarray:
"""Get the image masks extracted from segmentation algorithm.
Parameters
----------
roi_ids: array_like
A list or 1D array of ids of the ROIs. Length is the number of ROIs requested.
Returns
-------
image_masks: numpy.ndarray
3-D array(val 0 or 1): image_height X image_width X length(roi_ids)
"""
if roi_ids is None:
roi_ids = self.get_roi_ids()
if self._roi_masks is None:
# Fallback for extractors that haven't migrated yet
raise NotImplementedError("This extractor has not been updated to use the new ROI representation system.")
# Filter to only cell ROIs (exclude background)
cell_roi_ids = [rid for rid in roi_ids if not str(rid).startswith("background")]
if len(cell_roi_ids) == 0:
frame_shape = self.get_frame_shape()
return np.zeros((*frame_shape, 0))
# Get masks from representations
masks = []
for roi_id in cell_roi_ids:
mask = self._roi_masks.get_roi_image_mask(roi_id)
masks.append(mask)
return np.stack(masks, axis=2)
[docs]
def get_roi_pixel_masks(self, roi_ids=None) -> np.array:
"""Get the weights applied to each of the pixels of the mask.
Parameters
----------
roi_ids: array_like
A list or 1D array of ids of the ROIs. Length is the number of ROIs requested.
Returns
-------
pixel_masks: list
List of length number of rois, each element is a 2-D array with shape (number_of_non_zero_pixels, 3).
Columns 1 and 2 are the x and y coordinates of the pixel, while the third column represents the weight of
the pixel.
"""
if roi_ids is None:
roi_ids = self.get_roi_ids()
if self._roi_masks is None:
# Fallback for extractors that haven't migrated yet
raise NotImplementedError("This extractor has not been updated to use the new ROI representation system.")
# Filter to only cell ROIs (exclude background)
cell_roi_ids = [rid for rid in roi_ids if not str(rid).startswith("background")]
# Get pixel masks from representations
pixel_masks = []
for roi_id in cell_roi_ids:
pixel_mask = self._roi_masks.get_roi_pixel_mask(roi_id)
pixel_masks.append(pixel_mask)
return pixel_masks
[docs]
def get_background_ids(self) -> list:
"""Get the list of background components ids.
Returns
-------
background_components_ids: list
List of background components ids.
"""
if self._roi_masks is None:
return list(range(self.get_num_background_components()))
# Extract background IDs from roi_masks
all_roi_ids = self._roi_masks.get_roi_ids()
background_ids = [rid for rid in all_roi_ids if str(rid).startswith("background")]
return background_ids
[docs]
def get_background_image_masks(self, background_ids=None) -> np.ndarray:
"""Get the background image masks extracted from segmentation algorithm.
Parameters
----------
background_ids: array_like
A list or 1D array of ids of the background components. Length is the number of background components requested.
Returns
-------
background_image_masks: numpy.ndarray
3-D array(val 0 or 1): image_height X image_width X length(background_ids)
"""
if background_ids is None:
background_ids = self.get_background_ids()
if self._roi_masks is None:
# Fallback for extractors that haven't migrated yet
return np.zeros((*self.get_frame_shape(), 0))
if len(background_ids) == 0:
frame_shape = self.get_frame_shape()
return np.zeros((*frame_shape, 0))
# Get masks from representations
masks = []
for bg_id in background_ids:
mask = self._roi_masks.get_roi_image_mask(bg_id)
masks.append(mask)
return np.stack(masks, axis=2)
[docs]
def get_background_pixel_masks(self, background_ids=None) -> np.array:
"""Get the weights applied to each of the pixels of the mask.
Parameters
----------
background_ids: array_like
A list or 1D array of ids of the ROIs. Length is the number of ROIs requested.
Returns
-------
pixel_masks: list
List of length number of rois, each element is a 2-D array with shape (number_of_non_zero_pixels, 3).
Columns 1 and 2 are the x and y coordinates of the pixel, while the third column represents the weight of
the pixel.
"""
if background_ids is None:
background_ids = self.get_background_ids()
if self._roi_masks is None:
# Fallback for extractors that haven't migrated yet
return []
# Get pixel masks from representations
pixel_masks = []
for bg_id in background_ids:
pixel_mask = self._roi_masks.get_roi_pixel_mask(bg_id)
pixel_masks.append(pixel_mask)
return pixel_masks
[docs]
def slice_samples(self, start_sample: int | None = None, end_sample: int | None = None):
"""Return a new SegmentationExtractor ranging from the start_sample to the end_sample.
Parameters
----------
start_sample: int, optional
Start sample index (inclusive).
end_sample: int, optional
End sample index (exclusive).
Returns
-------
segmentation: SampleSlicedSegmentationExtractor
The sliced SegmentationExtractor object.
"""
return SampleSlicedSegmentationExtractor(
parent_segmentation=self, start_sample=start_sample, end_sample=end_sample
)
[docs]
def select_rois(self, roi_ids: list[str | int]):
"""Return a new SegmentationExtractor with only the specified ROIs.
Parameters
----------
roi_ids : list[str | int]
List of ROI IDs to include. Can include both cell and background ROI IDs.
The order of IDs is preserved in the returned extractor.
Returns
-------
segmentation : RoiSlicedSegmentationExtractor
The ROI-sliced SegmentationExtractor object.
Raises
------
ValueError
If roi_ids is empty or contains IDs not present in the extractor.
Notes
-----
This method creates a lazy view of the segmentation data with a subset of ROIs.
The slicing is applied to ROI-related data while temporal and spatial properties
are preserved.
Examples
--------
>>> # Select specific ROIs
>>> subset = extractor.select_rois([0, 1, 2])
>>>
>>> # Compose with temporal slicing
>>> subset = extractor.select_rois([0, 1, 2]).slice_samples(100, 200)
"""
if not roi_ids:
raise ValueError("roi_ids cannot be empty")
all_valid_ids = set(self.get_roi_ids()) | set(self.get_background_ids())
invalid_ids = [rid for rid in roi_ids if rid not in all_valid_ids]
if invalid_ids:
raise ValueError(
f"ROI ids {invalid_ids} not found in extractor. "
f"Available cell ROI ids: {self.get_roi_ids()}, "
f"Available background ids: {self.get_background_ids()}"
)
from .roislicedsegmentationextractor import _RoiSlicedSegmentationExtractor
return _RoiSlicedSegmentationExtractor(parent_segmentation=self, roi_ids=roi_ids)
[docs]
def get_traces(
self,
roi_ids: list[int | str] = None,
start_frame: int | None = None,
end_frame: int | None = None,
name: str = "raw",
) -> np.ndarray:
"""Get the traces of each ROI specified by roi_ids.
Parameters
----------
roi_ids: array_like
A list or 1D array of ids of the ROIs. Length is the number of ROIs requested.
start_frame: int
The starting frame of the trace.
end_frame: int
The ending frame of the trace.
name: str
The name of the trace to retrieve ex. 'raw', 'dff', 'neuropil', 'deconvolved'
Returns
-------
traces: array_like
2-D array (ROI x timepoints)
"""
traces_dict = self.get_traces_dict()
if traces_dict.get(name) is None:
return None
response = next((r for r in self._roi_responses if r.response_type == name), None)
if response is None:
raise ValueError(
f"Traces for {name} are registered in the trace dictionary but missing from the internal store."
)
data = np.asarray(response.data)
sliced = data[start_frame:end_frame, :]
input_roi_ids = roi_ids
if input_roi_ids is None:
return np.array(sliced)
# Match ROI ids by value, allowing for differing orders between sources
response_roi_ids = list(response.roi_ids)
indices: list[int] = []
missing_roi_ids: list = []
for roi_id in input_roi_ids:
try:
indices.append(response_roi_ids.index(roi_id))
except ValueError:
missing_roi_ids.append(roi_id)
if missing_roi_ids:
raise ValueError(
f"ROI ids {missing_roi_ids} not found for response '{name}'. Available ids: {response_roi_ids}"
)
return np.array(sliced[:, indices])
[docs]
def get_traces_dict(self) -> dict:
"""Get traces as a dictionary with key as the name of the ROiResponseSeries.
Returns
-------
_roi_response_dict: dict
dictionary with key, values representing different types of RoiResponseSeries:
Raw Fluorescence, DeltaFOverF, Denoised, Neuropil, Deconvolved, Background, etc.
"""
traces = {response.response_type: response.data for response in self._roi_responses}
for expected_type in ("raw", "dff", "neuropil", "deconvolved", "denoised", "baseline", "background"):
traces.setdefault(expected_type, None)
return traces
[docs]
def get_images_dict(self) -> dict:
"""Get images as a dictionary with key as the name of the ROIResponseSeries.
Returns
-------
_roi_image_dict: dict
dictionary with key, values representing different types of Images used in segmentation:
Mean, Correlation image, Maximum projection, etc.
"""
return dict(self._summary_images)
[docs]
def get_image(self, name: str = "correlation") -> np.ndarray:
"""Get specific images: mean or correlation.
Parameters
----------
name:str
name of the type of image to retrieve
Returns
-------
images: numpy.ndarray
"""
if name not in self.get_images_dict():
raise ValueError(f"could not find {name} image, enter one of {list(self.get_images_dict().keys())}")
return self.get_images_dict().get(name)
[docs]
def get_sampling_frequency(self) -> float:
"""Get the sampling frequency in Hz.
Returns
-------
sampling_frequency: float
Sampling frequency of the recording in Hz.
"""
if self._sampling_frequency is not None:
return float(self._sampling_frequency)
return self._sampling_frequency
[docs]
def get_num_rois(self) -> int:
"""Get total number of Regions of Interest (ROIs) in the acquired images.
Returns
-------
num_rois: int
The number of ROIs extracted.
"""
if self._roi_masks is not None:
# Count only cell ROIs (exclude background)
all_roi_ids = self._roi_masks.get_roi_ids()
cell_roi_ids = [rid for rid in all_roi_ids if not str(rid).startswith("background")]
return len(cell_roi_ids)
# Fallback to trace-based counting
for trace in self.get_traces_dict().values():
if trace is not None and len(trace.shape) > 0:
return trace.shape[1]
return 0
[docs]
def get_num_background_components(self) -> int:
"""Get total number of background components in the acquired images.
Returns
-------
num_background_components: int
The number of background components extracted.
"""
if self._roi_masks is not None:
# Count background ROIs from representations
all_roi_ids = self._roi_masks.get_roi_ids()
background_ids = [rid for rid in all_roi_ids if str(rid).startswith("background")]
return len(background_ids)
# Fallback to response-based counting
for response in self._roi_responses:
if response.response_type in {"neuropil", "background"}:
data = response.data
if data is None:
continue
if not hasattr(data, "shape"):
continue
if len(data.shape) == 1:
return int(data.shape[0])
return int(data.shape[1])
return 0
[docs]
def get_num_channels(self) -> int:
"""Get number of channels in the pipeline.
Returns
-------
num_of_channels: int
number of channels
Deprecated
----------
This method will be removed on or after September 2026.
"""
warnings.warn(
"get_num_channels is deprecated and will be removed on or after September 2026.",
FutureWarning,
stacklevel=2,
)
return len(self._channel_names)
[docs]
def get_num_planes(self) -> int:
"""Get the default number of planes of imaging for the segmentation extractor.
Notes
-----
Defaults to 1 for all but the MultiSegmentationExtractor.
Returns
-------
self._num_planes: int
number of planes
"""
return self._num_planes
[docs]
def set_times(self, times: ArrayLike):
"""Set the recording times in seconds for each frame.
Parameters
----------
times: array-like
The times in seconds for each frame
Notes
-----
Operates on _times attribute of the SegmentationExtractor object.
"""
assert len(times) == self.get_num_samples(), "'times' should have the same length of the number of samples!"
self._times = np.array(times, dtype=np.float64)
[docs]
def has_time_vector(self) -> bool:
"""Detect if the SegmentationExtractor has a time vector set or not.
Returns
-------
has_time_vector: bool
True if the SegmentationExtractor has a time vector set, otherwise False.
"""
return self._times is not None
[docs]
def get_timestamps(self, start_sample: int | None = None, end_sample: int | None = None) -> np.ndarray:
"""
Retrieve the timestamps for the data in this extractor.
Parameters
----------
start_sample : int, optional
The starting sample index. If None, starts from the beginning.
end_sample : int, optional
The ending sample index. If None, goes to the end.
Returns
-------
timestamps: numpy.ndarray
The timestamps for the data stream.
"""
# Set defaults
if start_sample is None:
start_sample = 0
if end_sample is None:
end_sample = self.get_num_samples()
# Return cached timestamps if available
if self._times is not None:
return self._times[start_sample:end_sample]
# See if native timetstamps are available from the format
native_timestamps = self.get_native_timestamps()
if native_timestamps is not None:
self._times = native_timestamps # Cache the native timestamps
return native_timestamps[start_sample:end_sample]
# Fallback to calculated timestamps from sampling frequency
sample_indices = np.arange(start_sample, end_sample)
return sample_indices / self.get_sampling_frequency()
[docs]
def time_to_sample_indices(self, times: float | ArrayLike) -> np.ndarray:
"""Convert user-inputted times (in seconds) to sample indices.
Parameters
----------
times: float or array-like
The times (in seconds) to be converted to sample indices.
Returns
-------
sample_indices: np.ndarray
The corresponding sample indices.
"""
# Ensure native timestamps are cached if available
native_timestamps = self.get_native_timestamps()
if native_timestamps is not None:
self._times = native_timestamps # Cache the native timestamps
if self._times is not None:
return (np.searchsorted(self._times, times, side="right") - 1).astype("int64")
else:
return np.round(times * self.get_sampling_frequency()).astype("int64")
[docs]
def set_property(self, key: str, values: ArrayLike, ids: ArrayLike, *, description: str = ""):
"""Set property values for ROIs.
Parameters
----------
key: str
The name of the property.
values: array-like
Array of property values. Must have same length as ids and num_rois.
ids: array-like
Array of ROI ids corresponding to the values. Must have same length as values and num_rois.
description: str, optional
Description of the property.
"""
values = np.asarray(values)
ids = list(ids)
num_rois = self.get_num_rois()
# Check that all arrays have the correct length
if len(values) != num_rois or len(ids) != num_rois:
raise ValueError(
f"Length of values ({len(values)}) and ids ({len(ids)}) must match number of ROIs ({num_rois})"
)
# Verify that the provided ids match the extractor's ROI ids
extractor_roi_ids = self.get_roi_ids()
if set(ids) != set(extractor_roi_ids):
raise ValueError("Provided ids must match the extractor's ROI ids")
# Create property array with values in the correct order
property_array = np.empty(values.shape, dtype=values.dtype)
for roi_index, roi_id in enumerate(extractor_roi_ids):
id_index = ids.index(roi_id)
property_array[roi_index] = values[id_index]
self._properties[key] = _PropertyInfo(data=property_array, description=description)
[docs]
def get_property(self, key: str, ids: ArrayLike) -> np.ndarray:
"""Get property values for ROIs.
Parameters
----------
key: str
The name of the property.
ids: array-like
Array of ROI ids to get property values for.
Returns
-------
values: array-like
Array of property values for the specified ROIs.
"""
ids = np.asarray(ids)
if key not in self._properties:
available_keys = list(self._properties.keys())
raise KeyError(f"Property '{key}' not found. Available properties: {available_keys}")
# Check that all requested ROI ids exist in extractor
all_roi_ids = self.get_roi_ids()
for roi_id in ids:
if roi_id not in all_roi_ids:
raise ValueError(f"ROI id {roi_id} not found in extractor. Available ROI ids: {all_roi_ids}")
# Map ids to indices and filter data
indices = [all_roi_ids.index(roi_id) for roi_id in ids]
return self._properties[key].data[indices]
[docs]
def get_property_keys(self) -> list[str]:
"""Get list of available property keys.
Returns
-------
keys: list
List of property names.
"""
return list(self._properties.keys())
[docs]
def get_property_info(self, key: str) -> _PropertyInfo:
"""Get property data and metadata bundled together.
Parameters
----------
key: str
The name of the property.
Returns
-------
property_info: _PropertyInfo
Dataclass containing the property data and description.
"""
if key not in self._properties:
available_keys = list(self._properties.keys())
raise KeyError(f"Property '{key}' not found. Available properties: {available_keys}")
return self._properties[key]
[docs]
class SampleSlicedSegmentationExtractor(SegmentationExtractor):
    """Class to get a lazy sample slice.

    Do not use this class directly but use `.slice_samples(...)` on a SegmentationExtractor object.
    """

    extractor_name = "SampleSlicedSegmentationExtractor"

    def __init__(
        self,
        parent_segmentation: SegmentationExtractor,
        start_sample: int | None = None,
        end_sample: int | None = None,
    ):
        """Initialize a SegmentationExtractor whose samples subset the parent.

        Subset is exclusive on the right bound, that is, the indexes of this SegmentationExtractor range over
        [0, ..., end_sample-start_sample-1], which is used to resolve the index mapping in `get_traces(...)`.

        Parameters
        ----------
        parent_segmentation : SegmentationExtractor
            The SegmentationExtractor object to subset the samples of.
        start_sample : int, optional
            The left bound of the samples to subset.
            The default is the start sample of the parent.
        end_sample : int, optional
            The right bound of the samples, exclusively, to subset.
            The default is end sample of the parent.
        """
        self._parent_segmentation = parent_segmentation

        parent_size = self._parent_segmentation.get_num_samples()
        if start_sample is None:
            start_sample = 0
        else:
            assert 0 <= start_sample < parent_size
        if end_sample is None:
            end_sample = parent_size
        else:
            assert 0 < end_sample <= parent_size
        assert end_sample > start_sample, "'start_sample' must be smaller than 'end_sample'!"

        self._start_sample = start_sample
        self._end_sample = end_sample
        self._num_samples = self._end_sample - self._start_sample

        # Resets every base attribute; everything below overrides those defaults.
        super().__init__()

        # Share the parent's ROI representations (spatial data is same, only temporal is sliced)
        if hasattr(self._parent_segmentation, "_roi_masks"):
            self._roi_masks = self._parent_segmentation._roi_masks

        self._roi_ids = list(self._parent_segmentation.get_roi_ids())

        for roi_response in self._parent_segmentation._roi_responses:
            sliced_data = roi_response.data[start_sample:end_sample, :]
            self._roi_responses.append(
                _RoiResponse(roi_response.response_type, sliced_data, list(roi_response.roi_ids))
            )

        self._summary_images = dict(self._parent_segmentation.get_images_dict())

        # Preserve parent's channel names and other attributes (access attribute directly to avoid deprecation warning)
        self._channel_names = self._parent_segmentation._channel_names
        self._num_planes = self._parent_segmentation.get_num_planes()

        # The _times attribute of the sliced extractor acts like a view to the parent's _times,
        # which is memory efficient. However, it maintains copy semantics which are safer for the following reasons:
        # Currently, there are only two ways of setting the _times:
        #
        # 1. set_times() method - always overwrites the entire _times array
        # 2. get_timestamps() method - in some cases will cache get_native_timestamps() output
        #
        # Both methods overwrite the entire _times array of the instance, preventing aliasing
        # problems where the _times reference of a slice extractor could be modified by the parent
        # or vice versa. See issue 498 for more details about this design.
        if self._parent_segmentation._times is not None:
            self._times = self._parent_segmentation._times[start_sample:end_sample]

        # Properties use the same copy-on-write pattern as _times above.
        # The shallow dict copy shares the underlying _PropertyInfo instances with the parent
        # (memory efficient), but set_property() always creates a new _PropertyInfo and rebinds
        # the dict key, so writes only affect this instance.
        self._properties = dict(self._parent_segmentation._properties)

    def get_native_timestamps(
        self, start_sample: int | None = None, end_sample: int | None = None
    ) -> np.ndarray | None:
        """Get the parent's native timestamps for a range given relative to this slice.

        Parameters
        ----------
        start_sample : int, optional
            Start sample index within the slice (inclusive). Defaults to 0.
        end_sample : int, optional
            End sample index within the slice (exclusive). Defaults to the slice length.

        Returns
        -------
        timestamps : np.ndarray or None
            Whatever the parent returns for the corresponding parent range.
        """
        # Use explicit None checks: an `or` default would turn a legitimate
        # end_sample=0 (empty range) into "until the end of the slice".
        if start_sample is None:
            start_sample = 0
        if end_sample is None:
            end_sample = self.get_num_samples()

        # Map slice-relative indices to parent indices
        parent_start = self._start_sample + start_sample
        parent_end = self._start_sample + end_sample
        return self._parent_segmentation.get_native_timestamps(start_sample=parent_start, end_sample=parent_end)

    def get_frame_shape(self) -> tuple[int, int]:
        """Return the parent's frame shape as a tuple."""
        return tuple(self._parent_segmentation.get_frame_shape())

    def get_roi_image_masks(self, roi_ids=None) -> np.ndarray:
        """Return the parent's ROI image masks for ``roi_ids``."""
        return self._parent_segmentation.get_roi_image_masks(roi_ids=roi_ids)

    def get_roi_pixel_masks(self, roi_ids: ArrayLike | None = None) -> list[np.ndarray]:
        """Return the parent's ROI pixel masks for ``roi_ids``."""
        return self._parent_segmentation.get_roi_pixel_masks(roi_ids=roi_ids)

    def get_background_image_masks(self, background_ids=None) -> np.ndarray:
        """Return the parent's background image masks for ``background_ids``."""
        return self._parent_segmentation.get_background_image_masks(background_ids=background_ids)

    def get_background_pixel_masks(self, background_ids: ArrayLike | None = None) -> list[np.ndarray]:
        """Return the parent's background pixel masks for ``background_ids``."""
        return self._parent_segmentation.get_background_pixel_masks(background_ids=background_ids)

    def get_num_background_components(self) -> int:
        """Return the parent's number of background components."""
        return self._parent_segmentation.get_num_background_components()

    def get_image(self, name: str = "correlation") -> np.ndarray:
        """Return the parent's summary image called ``name``."""
        return self._parent_segmentation.get_image(name=name)

    def get_sampling_frequency(self) -> float:
        """Return the parent's sampling frequency."""
        return self._parent_segmentation.get_sampling_frequency()

    def has_time_vector(self) -> bool:
        """Report whether the parent segmentation has a time vector set."""
        # Override to check parent segmentation for time vector
        return self._parent_segmentation.has_time_vector()