Source code for roiextractors.extractors.numpyextractors.numpyextractors
"""Imaging and Segmentation Extractors for .npy files.
Classes
-------
NumpyImagingExtractor
An ImagingExtractor specified by timeseries .npy file, sampling frequency, and channel names.
NumpySegmentationExtractor
A Segmentation extractor specified by image masks and traces .npy files.
"""
import warnings
from pathlib import Path
import numpy as np
from ...extraction_tools import ArrayType, FloatType, PathType
from ...imagingextractor import ImagingExtractor
from ...segmentationextractor import (
SegmentationExtractor,
_ROIMasks,
_RoiResponse,
)
[docs]
class NumpyImagingExtractor(ImagingExtractor):
"""An ImagingExtractor specified by timeseries .npy file, sampling frequency, and channel names.

The video can be supplied either as a path to a ``.npy`` file, which is opened
memory-mapped in read-only mode, or directly as a ``numpy.ndarray``.
"""
extractor_name = "NumpyImagingExtractor"
installation_mesg = "" # error message when not installed
def __init__(
    self,
    timeseries: PathType,
    sampling_frequency: FloatType,
    channel_names: ArrayType = None,
):
    """Create a NumpyImagingExtractor from a .npy file or an in-memory array.

    Parameters
    ----------
    timeseries: PathType or numpy.ndarray
        Path to a .npy file, or the video itself as a numpy array. The array is
        read as (num_frames, num_rows, num_columns) for a single channel, or as
        (num_frames, num_rows, num_columns, num_channels) otherwise.
    sampling_frequency: FloatType
        Sampling frequency of the video in Hz. Stored as a ``float``.
    channel_names: ArrayType, optional
        List of channel names, one per channel. When omitted the names
        ``channel_0``, ``channel_1``, ... are generated.

    Raises
    ------
    ValueError
        If `timeseries` is a path that does not point to an existing file.
    TypeError
        If `timeseries` is neither a path nor a numpy array.
    AssertionError
        If the file does not have the .npy suffix, or if `channel_names` does not
        contain exactly one name per channel.
    """
    ImagingExtractor.__init__(self)

    if isinstance(timeseries, (str, Path)):
        timeseries = Path(timeseries)
        if not timeseries.is_file():
            raise ValueError("'timeseries' is does not exist")
        assert timeseries.suffix == ".npy", "'timeseries' file is not a numpy file (.npy)"
        self.is_dumpable = True
        # Open the file memory-mapped (read-only) instead of reading it eagerly.
        self._video = np.load(timeseries, mmap_mode="r")
        self._kwargs = {
            "timeseries": str(timeseries.absolute()),
            "sampling_frequency": sampling_frequency,
        }
    elif isinstance(timeseries, np.ndarray):
        # An in-memory array cannot be described by a path, hence not dumpable.
        self.is_dumpable = False
        self._video = timeseries
        self._kwargs = {
            "timeseries": timeseries,
            "sampling_frequency": sampling_frequency,
        }
    else:
        raise TypeError("'timeseries' can be a str or a numpy array")

    # Keep the float conversion: it used to be overwritten by a second,
    # unconverted assignment right after it.
    self._sampling_frequency = float(sampling_frequency)
    self._channel_names = channel_names

    (
        self._num_samples,
        self._num_rows,
        self._num_columns,
        self._num_channels,
    ) = self.get_volume_shape(self._video)

    if self._video.ndim == 3:
        # Single-channel video: append the channel axis LAST so that the stored
        # array follows the (frames, rows, columns, channels) layout that
        # `get_volume_shape` and `get_series` rely on.
        self._video = self._video[..., np.newaxis]

    if self._channel_names is not None:
        assert len(self._channel_names) == self._num_channels, (
            "'channel_names' length is different than number " "of channels"
        )
    else:
        self._channel_names = [f"channel_{ch}" for ch in range(self._num_channels)]
[docs]
@staticmethod
def get_volume_shape(video) -> tuple[int, int, int, int]:
    """Return the shape of a video as (num_frames, num_rows, num_columns, num_channels).

    Parameters
    ----------
    video: numpy.ndarray
        The video whose shape is requested. A three-dimensional video is read as
        (num_frames, num_rows, num_columns) with a single channel; any other video
        must unpack into exactly four dimensions.

    Returns
    -------
    video_shape: tuple
        The shape of the video (num_frames, num_rows, num_columns, num_channels).
    """
    dims = video.shape
    if len(dims) != 3:
        # Unpacking fails with a ValueError unless there are exactly four axes.
        num_frames, num_rows, num_columns, num_channels = dims
        return num_frames, num_rows, num_columns, num_channels

    # Three axes: no explicit channel axis, so report one channel.
    num_frames, num_rows, num_columns = dims
    return num_frames, num_rows, num_columns, 1
[docs]
def get_series(self, start_sample=None, end_sample=None) -> np.ndarray:
    """Return a range of samples taken at index 0 of the last video axis.

    Parameters
    ----------
    start_sample: int, optional
        First index of the slice along the first axis of the stored video.
        ``None`` starts at the beginning.
    end_sample: int, optional
        End index (exclusive) of the slice along the first axis.
        ``None`` runs to the end.

    Returns
    -------
    series: numpy.ndarray
        ``self._video[start_sample:end_sample, ..., 0]``, i.e. the first channel
        when the last axis of the stored video indexes channels.
    """
    sample_window = slice(start_sample, end_sample)
    return self._video[sample_window, ..., 0]
[docs]
def get_image_shape(self) -> tuple[int, int]:
    """Return the size of a single video frame.

    Returns
    -------
    image_shape: tuple
        Shape of the video frame (num_rows, num_columns).
    """
    num_rows, num_columns = self._num_rows, self._num_columns
    return num_rows, num_columns
[docs]
def get_channel_names(self):
    """Return the channel names (deprecated).

    Emits a ``FutureWarning`` on every call.

    Returns
    -------
    channel_names: list
        The channel names stored on the extractor.
    """
    deprecation_message = "get_channel_names is deprecated and will be removed in May 2026 or after."
    # stacklevel=2 attributes the warning to the caller of this method.
    warnings.warn(deprecation_message, category=FutureWarning, stacklevel=2)
    return self._channel_names
[docs]
def get_native_timestamps(
self, start_sample: int | None = None, end_sample: int | None = None
) -> np.ndarray | None:
# Numpy arrays do not have native timestamps
return None
[docs]
class NumpySegmentationExtractor(SegmentationExtractor):
"""A Segmentation extractor specified by image masks and traces .npy files.

NumpySegmentationExtractor objects are built to contain all data coming from
a file format for which there is currently no support. To construct this,
all data must be entered manually as arguments, either as numpy arrays or as
paths to ``.npy`` files.
"""
extractor_name = "NumpySegmentationExtractor"
installation_mesg = "" # error message when not installed
def __init__(
    self,
    image_masks,
    raw=None,
    dff=None,
    deconvolved=None,
    neuropil=None,
    mean_image=None,
    correlation_image=None,
    roi_ids=None,
    roi_locations=None,
    sampling_frequency=None,
    rejected_list=None,
    channel_names=None,
    movie_dims=None,
    accepted_list=None,
):
    """Create a NumpySegmentationExtractor from numpy arrays or .npy files.

    `image_masks` selects the mode. When it is a path, every trace argument that
    is given must be a path to a .npy file as well, and all files are opened
    memory-mapped. When it is an array, every trace argument must be an array
    (or None).

    Parameters
    ----------
    image_masks: np.ndarray or path-like
        Binary image for each of the regions of interest, stacked as
        (px, py, num_rois), or the path to a .npy file holding such an array.
    raw: np.ndarray or path-like
        Fluorescence response of each of the ROI in time
    dff: np.ndarray or path-like
        DfOverF response of each of the ROI in time
    deconvolved: np.ndarray or path-like
        deconvolved response of each of the ROI in time
    neuropil: np.ndarray or path-like
        neuropil response of each of the ROI in time
    mean_image: np.ndarray
        Mean image
    correlation_image: np.ndarray
        correlation image
    roi_ids: int list
        Unique ids of the ROIs if any. Defaults to 0 .. num_rois - 1.
    roi_locations: np.ndarray
        x and y location representative of ROI mask
    sampling_frequency: float
        Frame rate of the movie
    rejected_list: list
        list of ROI ids that are rejected manually or via automated rejection
    channel_names: list
        list of strings representing channel names
    movie_dims: tuple
        height x width of the movie. Defaults to the first two dimensions of
        the image masks.
    accepted_list: list
        list of ROI ids that are accepted

    Raises
    ------
    ValueError
        If `image_masks` is a path that does not point to an existing file.
    TypeError
        If `image_masks` is neither a path nor a numpy array.
    AssertionError
        If a file does not have the .npy suffix, if a trace is not an array in
        array mode, or if the last axis of an in-memory trace does not match the
        number of ROIs.
    """
    # Presumably sets up `_roi_responses` and `_summary_images`, which are
    # filled below — confirm against SegmentationExtractor.
    SegmentationExtractor.__init__(self)

    cell_ids: list[int] | None = None if roi_ids is None else list(roi_ids)

    if isinstance(image_masks, (str, Path)):
        image_masks = Path(image_masks)
        if not image_masks.is_file():
            # The message used to name a non-existent 'timeeseries' argument.
            raise ValueError("'image_masks' file does not exist")
        assert image_masks.suffix == ".npy", "'image_masks' file is not a numpy file (.npy)"
        from_files = True
        image_masks_data = np.load(image_masks, mmap_mode="r")
    elif isinstance(image_masks, np.ndarray):
        for trace in (raw, dff, neuropil, deconvolved):
            assert isinstance(trace, (np.ndarray, type(None)))
        from_files = False
        image_masks_data = image_masks
    else:
        raise TypeError("'image_masks' can be a str or a numpy array")

    # Only the file-backed variant can be rebuilt from plain (string) kwargs.
    self.is_dumpable = from_files

    if cell_ids is None:
        cell_ids = list(range(image_masks_data.shape[2]))
    self._roi_ids = cell_ids

    # Create ROI representations
    # NOTE(review): the keyword is spelled `mask_tpe`; presumably that is how
    # `_ROIMasks` declares it — confirm before renaming.
    self._roi_masks = _ROIMasks(
        data=image_masks_data,
        mask_tpe="nwb-image_mask",
        field_of_view_shape=image_masks_data.shape[:2],
        roi_id_map={roi_id: index for index, roi_id in enumerate(cell_ids)},
    )

    if from_files:
        # Responses are registered in the order raw, dff, deconvolved, neuropil.
        trace_files = {}
        for trace_name, trace_file in (
            ("raw", raw),
            ("dff", dff),
            ("deconvolved", deconvolved),
            ("neuropil", neuropil),
        ):
            if trace_file is None:
                continue
            trace_file = Path(trace_file)
            assert trace_file.suffix == ".npy", f"'{trace_name}' file is not a numpy file (.npy)"
            trace_data = np.load(trace_file, mmap_mode="r")
            self._roi_responses.append(_RoiResponse(trace_name, trace_data, cell_ids))
            trace_files[trace_name] = trace_file

        # Only the traces that were actually given end up in the kwargs.
        self._kwargs = {"image_masks": str(image_masks.absolute())}
        for trace_name in ("raw", "dff", "neuropil", "deconvolved"):
            if trace_name in trace_files:
                self._kwargs[trace_name] = str(trace_files[trace_name].absolute())
    else:
        # Responses are registered in the order raw, dff, neuropil, deconvolved.
        for trace_name, trace in (
            ("raw", raw),
            ("dff", dff),
            ("neuropil", neuropil),
            ("deconvolved", deconvolved),
        ):
            if trace is None:
                continue
            assert image_masks_data.shape[-1] == trace.shape[-1], (
                f"Inconsistency between image masks and {trace_name} traces. "
                "Image masks must be (px, py, num_rois), "
                "traces must be (num_frames, num_rois)"
            )
            self._roi_responses.append(_RoiResponse(trace_name, trace, cell_ids))

        # Keys mirror the constructor parameters ("raw", not "signal") so the
        # kwargs can be passed back to this constructor.
        self._kwargs = {
            "image_masks": image_masks,
            "raw": raw,
            "dff": dff,
            "neuropil": neuropil,
            "deconvolved": deconvolved,
        }

    # Get image masks data from representations
    image_masks_data = self._roi_masks.data
    # The masks are (px, py, num_rois); only the first two axes describe a frame.
    self._movie_dims = movie_dims if movie_dims is not None else tuple(image_masks_data.shape[:2])

    if mean_image is not None:
        self._summary_images["mean"] = mean_image
    if correlation_image is not None:
        self._summary_images["correlation"] = correlation_image

    self._roi_locs = roi_locations
    self._sampling_frequency = sampling_frequency
    self._channel_names = channel_names

    # Set accepted_list and rejected_list as properties if provided
    if accepted_list is not None:
        is_accepted = np.array([roi_id in accepted_list for roi_id in self._roi_ids], dtype=bool)
        self.set_property("is_accepted", is_accepted, self._roi_ids)
    if rejected_list is not None:
        is_rejected = np.array([roi_id in rejected_list for roi_id in self._roi_ids], dtype=bool)
        self.set_property("is_rejected", is_rejected, self._roi_ids)
@property
def image_dims(self):
    """Return the field-of-view shape of the ROI masks.

    Returns
    -------
    image_dims: list
        ``field_of_view_shape`` of the stored ROI masks, converted to a list.
        The constructor sets it from the first two axes of the image-mask array.
    """
    fov_shape = self._roi_masks.field_of_view_shape
    return list(fov_shape)
[docs]
def get_accepted_list(self) -> list:
    """Return the ids of the accepted ROIs (deprecated).

    Emits a ``DeprecationWarning`` on every call. When no ``is_accepted``
    property has been set, every ROI id is reported as accepted.

    Returns
    -------
    accepted_list: list
        List of accepted ROI ids.
    """
    deprecation_message = (
        "get_accepted_list is deprecated and will be removed in May 2026. "
        "Use get_property('is_accepted', ids) instead."
    )
    # stacklevel=2 attributes the warning to the caller of this method.
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)

    if "is_accepted" in self.get_property_keys():
        all_ids = self.get_roi_ids()
        flags = self.get_property("is_accepted", all_ids)
        return [candidate for candidate, keep in zip(all_ids, flags) if keep]

    return list(self.get_roi_ids())
[docs]
def get_rejected_list(self) -> list:
    """Return the ids of the rejected ROIs (deprecated).

    Emits a ``DeprecationWarning`` on every call. When no ``is_rejected``
    property has been set, no ROI is reported as rejected.

    Returns
    -------
    rejected_list: list
        List of rejected ROI ids.
    """
    deprecation_message = (
        "get_rejected_list is deprecated and will be removed in May 2026. "
        "Use get_property('is_rejected', ids) instead."
    )
    # stacklevel=2 attributes the warning to the caller of this method.
    warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)

    if "is_rejected" in self.get_property_keys():
        all_ids = self.get_roi_ids()
        flags = self.get_property("is_rejected", all_ids)
        return [candidate for candidate, drop in zip(all_ids, flags) if drop]

    return []
@property
def roi_locations(self):
    """Return a representative location for each ROI.

    Locations passed to the constructor are returned unchanged. Otherwise they
    are computed from the ROI masks: for every ROI, the median row index and
    the median column index of the pixels that hold the mask's maximum value.

    Returns
    -------
    roi_locations: numpy.ndarray
        The stored locations, or an integer array of shape (2, num_rois) whose
        first row holds the median row indices and whose second row holds the
        median column indices.
    """
    if self._roi_locs is not None:
        return self._roi_locs

    roi_count = self.get_num_rois()
    mask_stack = self._roi_masks.data  # (H, W, N) array
    centers = np.empty((2, roi_count), dtype="int")
    for roi_index in range(roi_count):
        mask = mask_stack[:, :, roi_index]
        peak_rows, peak_columns = np.where(mask == np.amax(mask))
        # Float medians are truncated when written into the integer array.
        centers[:, roi_index] = np.array([np.median(peak_rows), np.median(peak_columns)])
    return centers
# defining the abstract class informed methods:
[docs]
def get_roi_ids(self):
    """Return the ids of the ROIs.

    Returns
    -------
    roi_ids: list
        The stored ROI ids, or ``0 .. num_rois - 1`` when none are stored.
    """
    stored_ids = self._roi_ids
    if stored_ids is not None:
        return stored_ids
    return list(range(self.get_num_rois()))
[docs]
def get_frame_shape(self):
    """Return the frame shape stored on the extractor.

    Returns
    -------
    frame_shape: tuple
        The value of ``self._movie_dims`` as set by the constructor: the
        `movie_dims` argument when one was given, otherwise a value derived
        from the shape of the image-mask array.
    """
    frame_shape = self._movie_dims
    return frame_shape
[docs]
def get_num_samples(self):
    """Return the number of samples in the recording (duration of recording).

    The length of the first axis of the first usable trace in
    ``get_traces_dict()`` is reported; traces that are ``None`` or
    zero-dimensional are skipped.

    Returns
    -------
    num_samples: int or None
        Number of samples in the recording, or ``None`` when there is no
        usable trace.
    """
    sample_counts = (
        trace.shape[0]
        for trace in self.get_traces_dict().values()
        if trace is not None and len(trace.shape) > 0
    )
    return next(sample_counts, None)
[docs]
def get_native_timestamps(
self, start_sample: int | None = None, end_sample: int | None = None
) -> np.ndarray | None:
# Numpy arrays do not have native timestamps
return None