Source code for roiextractors.extractors.suite2p.suite2psegmentationextractor
"""A segmentation extractor for Suite2p.
Classes
-------
Suite2pSegmentationExtractor
A segmentation extractor for Suite2p.
"""
import warnings
from pathlib import Path
from warnings import warn
import numpy as np
from ...extraction_tools import PathType
from ...segmentationextractor import (
SegmentationExtractor,
_ROIMasks,
_RoiResponse,
)
[docs]
class Suite2pSegmentationExtractor(SegmentationExtractor):
"""A segmentation extractor for Suite2p.

Reads the ``.npy`` files found in one plane folder of a Suite2p output folder
(``ops.npy``, ``stat.npy``, the fluorescence/neuropil traces and, when present,
``spks.npy``, ``iscell.npy`` and ``redcell.npy``) and exposes the data of a single
channel (``chan1`` or ``chan2``).

Use ``get_available_planes`` and ``get_available_channels`` to list the valid
``plane_name`` and ``channel_name`` values for a given folder.
"""
extractor_name = "Suite2pSegmentationExtractor"
[docs]
@classmethod
def get_available_channels(cls, folder_path: PathType) -> list[str]:
    """List the channel names present in a Suite2p output folder.

    The first channel is always reported. The second channel is reported only when a
    file named ``F_chan2.npy`` is found inside the first available plane folder.

    Parameters
    ----------
    folder_path : PathType
        Path to Suite2p output path.

    Returns
    -------
    channel_names: list
        ``["chan1"]`` or ``["chan1", "chan2"]``.
    """
    first_plane_name = cls.get_available_planes(folder_path=folder_path)[0]
    first_plane_folder = Path(folder_path) / first_plane_name
    # Only the first plane is inspected for second-channel traces.
    has_second_channel = any(True for _ in first_plane_folder.glob("F_chan2.npy"))
    return ["chan1", "chan2"] if has_second_channel else ["chan1"]
[docs]
@classmethod
def get_available_planes(cls, folder_path: PathType) -> list[str]:
"""Get the available plane names from the folder produced by Suite2p.
Parameters
----------
folder_path : PathType
Path to Suite2p output path.
Returns
-------
plane_names: list
List of plane names.
"""
from natsort import natsorted
folder_path = Path(folder_path)
prefix = "plane"
plane_paths = natsorted(folder_path.glob(pattern=prefix + "*"))
assert len(plane_paths), f"No planes found in '{folder_path}'."
plane_names = [plane_path.stem for plane_path in plane_paths]
return plane_names
def __init__(
    self,
    folder_path: PathType,
    channel_name: str | None = None,
    plane_name: str | None = None,
):
    """Create SegmentationExtractor object out of suite 2p data type.

    Parameters
    ----------
    folder_path: str or Path
        The path to the 'suite2p' folder.
    channel_name: str, optional
        The name of the channel to load, to determine what channels are available use
        Suite2pSegmentationExtractor.get_available_channels(folder_path).
        When omitted, the first available channel is used (with a warning if several exist).
    plane_name: str, optional
        The name of the plane to load, to determine what planes are available use
        Suite2pSegmentationExtractor.get_available_planes(folder_path).
        When omitted, the first available plane is used (with a warning if several exist).

    Raises
    ------
    ValueError
        If ``channel_name`` or ``plane_name`` is not among the available names.
    FileNotFoundError
        If ``ops.npy`` or ``stat.npy`` is missing from the selected plane folder.
    """
    # --- Resolve and validate the channel -------------------------------------------------
    channel_names = self.get_available_channels(folder_path=folder_path)
    channel_was_omitted = channel_name is None
    if channel_was_omitted and len(channel_names) > 1:
        # For backward compatibility a warning is issued instead of an error.
        warn(
            "More than one channel is detected! Please specify which channel you wish to load with the `channel_name` argument. "
            "To see what channels are available, call `Suite2pSegmentationExtractor.get_available_channels(folder_path=...)`.",
            UserWarning,
        )
    if channel_was_omitted:
        channel_name = channel_names[0]
    self.channel_name = channel_name
    if self.channel_name not in channel_names:
        raise ValueError(
            f"The selected channel '{channel_name}' is not a valid channel name. To see what channels are available, "
            f"call `Suite2pSegmentationExtractor.get_available_channels(folder_path=...)`."
        )

    # --- Resolve and validate the plane ---------------------------------------------------
    plane_names = self.get_available_planes(folder_path=folder_path)
    plane_was_omitted = plane_name is None
    if plane_was_omitted and len(plane_names) > 1:
        # For backward compatibility a warning is issued instead of an error.
        warn(
            "More than one plane is detected! Please specify which plane you wish to load with the `plane_name` argument. "
            "To see what planes are available, call `Suite2pSegmentationExtractor.get_available_planes(folder_path=...)`.",
            UserWarning,
        )
    if plane_was_omitted:
        plane_name = plane_names[0]
    if plane_name not in plane_names:
        raise ValueError(
            f"The selected plane '{plane_name}' is not a valid plane name. To see what planes are available, "
            f"call `Suite2pSegmentationExtractor.get_available_planes(folder_path=...)`."
        )
    self.plane_name = plane_name

    super().__init__()

    # --- Settings and ROI statistics (both files are mandatory) ---------------------------
    self.folder_path = Path(folder_path)
    self.options = self._load_npy(file_name="ops.npy", require=True).item()
    self._sampling_frequency = self.options["fs"]
    self._num_frames = self.options["nframes"]
    self._image_shape = (self.options["Ly"], self.options["Lx"])
    self.stat = self._load_npy(file_name="stat.npy", require=True)

    # --- Traces (each file is optional; loaded memory-mapped and transposed) --------------
    is_first_channel = channel_name == "chan1"
    fluorescence_traces_file_name = "F.npy" if is_first_channel else "F_chan2.npy"
    neuropil_traces_file_name = "Fneu.npy" if is_first_channel else "Fneu_chan2.npy"
    raw_traces = self._load_npy(file_name=fluorescence_traces_file_name, mmap_mode="r", transpose=True)
    neuropil_traces = self._load_npy(file_name=neuropil_traces_file_name, mmap_mode="r", transpose=True)
    # Deconvolved traces are only read for the first channel.
    deconvolved_traces = None
    if is_first_channel:
        deconvolved_traces = self._load_npy(file_name="spks.npy", mmap_mode="r", transpose=True)

    # The ROI ids are derived from the second axis of the first trace array that exists.
    cell_ids = None
    available_traces = (
        ("raw", raw_traces),
        ("neuropil", neuropil_traces),
        ("deconvolved", deconvolved_traces),
    )
    for response_name, traces in available_traces:
        if traces is None:
            continue
        if cell_ids is None:
            cell_ids = list(range(traces.shape[1]))
        self._roi_responses.append(_RoiResponse(response_name, traces, list(cell_ids)))
    if cell_ids is None:
        # No trace file was found: fall back to the number of entries in stat.npy.
        cell_ids = list(range(self.stat.size))
    self._roi_ids = list(cell_ids)

    # --- ROI classification ---------------------------------------------------------------
    # ROIs segmented from the imaging acquired with the second channel (red/anatomical)
    # that match the first channel segmentation.
    redcell = self._load_npy(file_name="redcell.npy", mmap_mode="r")
    use_redcell = redcell is not None and channel_name == "chan2"
    self.iscell = redcell if use_redcell else self._load_npy("iscell.npy", mmap_mode="r")
    # Set iscell as a property for ROI classification (first column is the binary classification)
    if self.iscell is not None:
        self.set_property("iscell", self.iscell[:, 0], self._roi_ids)

    # The name of the OpticalChannel object is "OpticalChannel" if there is only one channel,
    # otherwise it is "Chan1" or "Chan2".
    if len(channel_names) == 1:
        optical_channel_name = "OpticalChannel"
    else:
        optical_channel_name = channel_name.capitalize()
    self._channel_names = [optical_channel_name]

    # --- Summary images -------------------------------------------------------------------
    correlation_image = self._correlation_image_read()
    if correlation_image is not None:
        self._summary_images["correlation"] = correlation_image

    image_mean_name = "meanImg" if is_first_channel else "meanImg_chan2"
    if image_mean_name in self.options:
        mean_image = self.options[image_mean_name]
        if mean_image is not None:
            self._summary_images["mean"] = mean_image

    # --- ROI masks ------------------------------------------------------------------------
    # Suite2p stores per-ROI sparse pixel lists; each mask is the (ypix, xpix, lam) columns
    # of the ROI's stat entry, handed over as "nwb-pixel_mask" data.
    pixel_masks = [
        np.column_stack([self.stat[roi_index][key] for key in ("ypix", "xpix", "lam")])
        for roi_index in range(len(cell_ids))
    ]
    # Maps each ROI id to its position in `pixel_masks`.
    roi_id_map = dict(zip(cell_ids, range(len(cell_ids))))
    self._roi_masks = _ROIMasks(
        data=pixel_masks,
        mask_tpe="nwb-pixel_mask",
        field_of_view_shape=self._image_shape,
        roi_id_map=roi_id_map,
    )
def _load_npy(self, file_name: str, mmap_mode=None, transpose: bool = False, require: bool = False):
"""Load a .npy file with specified filename. Returns None if file is missing.
Parameters
----------
file_name: str
The name of the .npy file to load.
mmap_mode: str
The mode to use for memory mapping. See numpy.load for details.
transpose: bool, optional
Whether to transpose the loaded array.
require: bool, optional
Whether to raise an error if the file is missing.
Returns
-------
The loaded .npy file.
"""
file_path = self.folder_path / self.plane_name / file_name
if not file_path.exists():
if require:
raise FileNotFoundError(f"File {file_path} not found.")
return
data = np.load(file_path, mmap_mode=mmap_mode, allow_pickle=mmap_mode is None)
if transpose:
return data.T
return data
[docs]
def get_num_samples(self) -> int:
"""Get the number of samples in the recording (duration of recording).
Returns
-------
num_samples: int
Number of samples in the recording.
"""
return self._num_frames
[docs]
def get_accepted_list(self) -> list:
"""Get a list of accepted ROI ids.
Returns
-------
accepted_list: list
List of accepted ROI ids.
"""
warnings.warn(
"get_accepted_list is deprecated and will be removed in May 2026. "
"Use get_property('iscell', ids) instead to access Suite2p's native classification.",
DeprecationWarning,
stacklevel=2,
)
if self.iscell is None:
return list(self.get_roi_ids())
return [roi_id for roi_id, is_cell in zip(self.get_roi_ids(), self.iscell[:, 0]) if is_cell]
[docs]
def get_rejected_list(self) -> list:
"""Get a list of rejected ROI ids.
Returns
-------
rejected_list: list
List of rejected ROI ids.
"""
warnings.warn(
"get_rejected_list is deprecated and will be removed in May 2026. "
"Use get_property('iscell', ids) instead to access Suite2p's native classification.",
DeprecationWarning,
stacklevel=2,
)
if self.iscell is None:
return []
return [roi_id for roi_id, is_cell in zip(self.get_roi_ids(), self.iscell[:, 0]) if not is_cell]
def _correlation_image_read(self) -> np.ndarray | None:
"""Read correlation image from ops (settings) dict.
Returns
-------
img : numpy.ndarray | None
The correlation image.
"""
if "Vcorr" not in self.options:
return None
correlation_image = self.options["Vcorr"]
if (self.options["yrange"][-1], self.options["xrange"][-1]) == self._image_shape:
return correlation_image
img = np.zeros(self._image_shape, correlation_image.dtype)
img[
(self.options["Ly"] - self.options["yrange"][-1]) : (self.options["Ly"] - self.options["yrange"][0]),
self.options["xrange"][0] : self.options["xrange"][-1],
] = correlation_image
return img
@property
def roi_locations(self) -> np.ndarray:
"""Returns the center locations (x, y) of each ROI."""
return np.array([j["med"] for j in self.stat]).T.astype(int)
[docs]
def get_roi_pixel_masks(self, roi_ids=None) -> list[np.ndarray]:
pixel_mask = []
for i in range(self.get_num_rois()):
pixel_mask.append(
np.vstack(
[
self.stat[i]["ypix"],
self.stat[i]["xpix"],
self.stat[i]["lam"],
]
).T
)
if roi_ids is None:
roi_idx_ = range(self.get_num_rois())
else:
roi_idx = [np.where(np.array(i) == self.get_roi_ids())[0] for i in roi_ids]
ele = [i for i, j in enumerate(roi_idx) if j.size == 0]
roi_idx_ = [j[0] for i, j in enumerate(roi_idx) if i not in ele]
return [pixel_mask[i] for i in roi_idx_]
[docs]
def get_native_timestamps(
self, start_sample: int | None = None, end_sample: int | None = None
) -> np.ndarray | None:
# Suite2p segmentation data does not have native timestamps
return None