Source code for roiextractors.extractors.nwbextractors.nwbextractors
"""Imaging and segmentation extractors for NWB files.
Classes
-------
NwbImagingExtractor
Extracts imaging data from NWB files.
NwbSegmentationExtractor
Extracts segmentation data from NWB files.
"""
import warnings
from pathlib import Path
from typing import Iterable
import numpy as np
from lazy_ops import DatasetView
from pynwb import NWBHDF5IO
from pynwb.ophys import OnePhotonSeries, TwoPhotonSeries
from ...extraction_tools import (
PathType,
)
from ...imagingextractor import ImagingExtractor
from ...segmentationextractor import (
SegmentationExtractor,
_ROIMasks,
_RoiResponse,
)
[docs]
class NwbImagingExtractor(ImagingExtractor):
    """An imaging extractor for NWB files.

    Reads the imaging data of a OnePhotonSeries or TwoPhotonSeries stored in the
    acquisition group of an NWB file.

    Supports both planar (2D) and volumetric (3D) imaging data. NWB stores frames as
    (time, width, height[, depth]); this extractor exposes them as
    (time, height, width[, depth]).
    """

    extractor_name = "NwbImaging"

    def __init__(self, file_path: PathType, optical_series_name: str | None = "TwoPhotonSeries"):
        """Create ImagingExtractor object from NWB file.

        The file is opened read-only and kept open on ``self.io`` because ``get_series``
        reads ``photon_series.data`` on demand.

        Parameters
        ----------
        file_path: PathType
            Path to the .nwb file.
        optical_series_name: str, optional
            The name of the optical series in ``nwbfile.acquisition`` to extract data from.
            Defaults to "TwoPhotonSeries". If None, the file must contain exactly one
            acquisition object, which is then used.

        Raises
        ------
        ValueError
            If ``optical_series_name`` is None and the file has zero or several acquisitions,
            or if the series data is neither 3-dimensional nor 4-dimensional.
        AssertionError
            If the series is not a OnePhotonSeries/TwoPhotonSeries or refers to an external file.
        """
        ImagingExtractor.__init__(self)
        self._path = file_path
        self.io = NWBHDF5IO(str(self._path), "r")
        self.nwbfile = self.io.read()

        if optical_series_name is not None:
            self._optical_series_name = optical_series_name
        else:
            acquisition_names = list(self.nwbfile.acquisition)
            if len(acquisition_names) > 1:
                raise ValueError("More than one acquisition found. You must specify optical_series_name.")
            if len(acquisition_names) == 0:
                raise ValueError("No acquisitions found in the .nwb file.")
            self._optical_series_name = acquisition_names[0]

        self.photon_series = self.nwbfile.acquisition[self._optical_series_name]
        assert isinstance(
            self.photon_series, (OnePhotonSeries, TwoPhotonSeries)
        ), "The optical series must be of type pynwb.ophys.OnePhotonSeries or pynwb.ophys.TwoPhotonSeries."
        # TODO if external file --> return another proper extractor (e.g. TiffImagingExtractor)
        assert self.photon_series.external_file is None, "Only 'raw' format is currently supported"

        # Load the two video layouts that the photon series types support.
        data_shape = self.photon_series.data.shape
        self._num_channels = 1
        if len(data_shape) == 3:
            # Planar 2D imaging: (time, width, height)
            self._num_samples, self._columns, self._num_rows = data_shape
            self._num_planes = 1
            self.is_volumetric = False
        elif len(data_shape) == 4:
            # Volumetric 3D imaging: (time, width, height, depth)
            self._num_samples, self._columns, self._num_rows, self._num_planes = data_shape
            self.is_volumetric = True
        else:
            # Fail here rather than leaving is_volumetric/_num_planes unset for later methods.
            raise ValueError(
                f"Unsupported data shape {tuple(data_shape)} for optical series "
                f"'{self._optical_series_name}': expected (time, width, height) or "
                "(time, width, height, depth)."
            )

        # Set channel names (This should disambiguate which optical channel)
        self._channel_names = [channel.name for channel in self.photon_series.imaging_plane.optical_channel]

        # Set sampling frequency. Compare against None instead of testing truthiness: a
        # multi-element numpy array raises "truth value ... is ambiguous" in a boolean context.
        timestamps = getattr(self.photon_series, "timestamps", None)
        if timestamps is not None and len(timestamps) > 0:
            # Read once; the copy is what set_times receives.
            timestamps = np.array(timestamps)
            self._sampling_frequency = 1.0 / np.median(np.diff(timestamps))
            self._imaging_start_time = timestamps[0]
            self.set_times(timestamps)
        else:
            self._sampling_frequency = self.photon_series.rate
            self._imaging_start_time = self.photon_series.fields.get("starting_time", 0.0)

        # Fill epochs dictionary, keyed by each epoch's first tag.
        self._epochs = {}
        if self.nwbfile.epochs is not None:
            df_epochs = self.nwbfile.epochs.to_dataframe()
            # TODO implement add_epoch() method in base class
            self._epochs = {
                row["tags"][0]: {
                    "start_frame": self.time_to_frame(row["start_time"]),
                    "end_frame": self.time_to_frame(row["stop_time"]),
                }
                for _, row in df_epochs.iterrows()
            }

        self._kwargs = {
            "file_path": str(Path(file_path).absolute()),
            "optical_series_name": optical_series_name,
        }

    def make_nwb_metadata(self, nwbfile, opts):
        """Create metadata dictionary for NWB file.

        Parameters
        ----------
        nwbfile: pynwb.NWBFile
            The NWBFile object associated with the metadata.
        opts: object
            The options object with name of TwoPhotonSeries as an attribute.

        Notes
        -----
        Metadata dictionary is stored in the nwb_metadata attribute.

        .. deprecated:: 0.7.3
            This method is deprecated and will be removed on or after May 2026.
        """
        # TODO: refactor to use two photon series name directly rather than via opts
        warnings.warn(
            "The 'make_nwb_metadata' method is deprecated and will be removed on or after May 2026.",
            FutureWarning,
            stacklevel=2,
        )
        # Metadata dictionary - useful for constructing a nwb file
        self.nwb_metadata = dict(
            NWBFile=dict(
                session_description=nwbfile.session_description,
                identifier=nwbfile.identifier,
                session_start_time=nwbfile.session_start_time,
                institution=nwbfile.institution,
                lab=nwbfile.lab,
            ),
            Ophys=dict(
                Device=[dict(name=dev) for dev in nwbfile.devices],
                TwoPhotonSeries=[dict(name=opts.name)],
            ),
        )

    def get_series(self, start_sample=None, end_sample=None) -> np.ndarray:
        """Return frames ``start_sample:end_sample`` in roiextractors axis order.

        Parameters
        ----------
        start_sample : int, optional
            First sample to read. Defaults to 0.
        end_sample : int, optional
            Sample after the last one to read. Defaults to the number of samples.

        Returns
        -------
        series : numpy.ndarray
            (time, height, width) for planar data or (time, height, width, depth) for
            volumetric data; NWB's width and height axes are swapped.
        """
        start_sample = 0 if start_sample is None else start_sample
        end_sample = self.get_num_samples() if end_sample is None else end_sample
        frames = self.photon_series.data[start_sample:end_sample]
        if self.is_volumetric:
            # NWB: (time, width, height, depth) -> roiextractors: (time, height, width, depth)
            return frames.transpose([0, 2, 1, 3])
        # NWB: (time, width, height) -> roiextractors: (time, height, width)
        return frames.transpose([0, 2, 1])

    def get_image_shape(self) -> tuple[int, int]:
        """Get the shape of the video frame (num_rows, num_columns).

        Returns
        -------
        image_shape: tuple
            Shape of the video frame (num_rows, num_columns).
        """
        return (self._num_rows, self._columns)  # TODO: change name of _columns to _num_cols for consistency

    def get_channel_names(self):
        """Return the names of the imaging plane's optical channels.

        .. deprecated::
            Will be removed in May 2026 or after.
        """
        warnings.warn(
            "get_channel_names is deprecated and will be removed in May 2026 or after.",
            category=FutureWarning,
            stacklevel=2,
        )
        return self._channel_names

    def get_native_timestamps(
        self, start_sample: int | None = None, end_sample: int | None = None
    ) -> np.ndarray | None:
        """Retrieve the original unaltered timestamps for the data.

        This method uses the NWB photon series get_timestamps() method, which properly handles
        explicit timestamps, starting_time attribute, and rate-based calculations.

        Parameters
        ----------
        start_sample : int, optional
            The starting sample index. If None, starts from the beginning.
        end_sample : int, optional
            The ending sample index. If None, goes to the end.

        Returns
        -------
        timestamps : numpy.ndarray
            The timestamps for the data stream.
        """
        start_sample = 0 if start_sample is None else start_sample
        end_sample = self.get_num_samples() if end_sample is None else end_sample
        # Use NWB's native get_timestamps() method which handles timestamps, starting_time, and rate
        timestamps = self.photon_series.get_timestamps()
        return timestamps[start_sample:end_sample]

    def get_num_planes(self) -> int:
        """Get the number of depth planes.

        Returns
        -------
        num_planes : int
            The number of depth planes (1 for planar data).
        """
        return self._num_planes
[docs]
class NwbSegmentationExtractor(SegmentationExtractor):
    """A segmentation extractor for NWB files.

    Reads ROI response series from the ``Fluorescence``/``DfOverF`` interfaces of the
    ``ophys`` processing module, image masks from ``ImageSegmentation``/``PlaneSegmentation``,
    and summary images from ``SegmentationImages``.
    """

    extractor_name = "NwbSegmentationExtractor"
    installation_mesg = ""  # error message when not installed

    def __init__(self, file_path: PathType):
        """Create NwbSegmentationExtractor object from nwb file.

        Parameters
        ----------
        file_path: PathType
            .nwb file location

        Raises
        ------
        FileNotFoundError
            If ``file_path`` is not an existing file.
        Exception
            If neither Fluorescence nor DfOverF exists, or no known ROI response series is found.
        ValueError
            If the ROI ids cannot be determined.
        """
        super().__init__()
        file_path = Path(file_path)
        if not file_path.is_file():
            raise FileNotFoundError(f"file does not exist: {file_path}")
        self.file_path = file_path
        self._roi_locs = None
        # Kept open: responses and masks below are lazy DatasetViews over the file's data.
        self._io = NWBHDF5IO(str(file_path), mode="r")
        self.nwbfile = self._io.read()
        assert "ophys" in self.nwbfile.processing, "Ophys processing module is not in nwbfile."
        ophys = self.nwbfile.processing.get("ophys")

        # Extract roi_responses:
        fluorescence = None
        df_over_f = None
        collected_responses: list[tuple[str, DatasetView]] = []
        if "Fluorescence" in ophys.data_interfaces:
            fluorescence = ophys.data_interfaces["Fluorescence"]
        if "DfOverF" in ophys.data_interfaces:
            df_over_f = ophys.data_interfaces["DfOverF"]
        if fluorescence is None and df_over_f is None:
            raise Exception("Could not find Fluorescence/DfOverF module in nwbfile.")

        for trace_name in ("raw", "dff", "neuropil", "deconvolved", "denoised", "baseline", "background"):
            # "raw" (in Fluorescence) and "dff" (in DfOverF) are both stored as "RoiResponseSeries";
            # the other traces live in Fluorescence under their capitalized name.
            series_name = "RoiResponseSeries" if trace_name in ("raw", "dff") else trace_name.capitalize()
            container = df_over_f if trace_name == "dff" else fluorescence
            if container is None or series_name not in container.roi_response_series:
                continue
            series = container.roi_response_series[series_name]
            collected_responses.append((trace_name, DatasetView(series.data)))
            # Only filled in while still unset, i.e. from the first series that provides it.
            if self._sampling_frequency is None:
                self._sampling_frequency = series.rate
        if not collected_responses:
            raise Exception(
                "could not find any of 'RoiResponseSeries'/'Neuropil'/'Deconvolved'/'Denoised'/"
                "'Baseline'/'Background' named RoiResponseSeries in nwbfile"
            )

        # Extract image_mask/background:
        if "ImageSegmentation" in ophys.data_interfaces:
            image_seg = ophys.data_interfaces["ImageSegmentation"]
            assert len(image_seg.plane_segmentations), "Could not find any PlaneSegmentation in nwbfile."
            if "PlaneSegmentation" in image_seg.plane_segmentations:  # this requirement in nwbfile is enforced
                ps = image_seg.plane_segmentations["PlaneSegmentation"]
                assert "image_mask" in ps.colnames, "Could not find any image_masks in nwbfile."
                # Reverse the axes lazily so ROIs lie on the last axis (shape[-1] is the ROI count).
                image_masks_data = DatasetView(ps["image_mask"].data).lazy_transpose([2, 1, 0])
                self._roi_locs = ps["ROICentroids"] if "ROICentroids" in ps.colnames else None
                accepted_data = ps["Accepted"].data[:] if "Accepted" in ps.colnames else None
                rejected_data = ps["Rejected"].data[:] if "Rejected" in ps.colnames else None
                if hasattr(ps, "id"):
                    self._roi_ids = ps.id.data[:].tolist()
                else:
                    self._roi_ids = list(range(image_masks_data.shape[-1]))

                # Create ROI representations
                roi_id_map = {roi_id: index for index, roi_id in enumerate(self._roi_ids)}
                self._roi_masks = _ROIMasks(
                    data=image_masks_data,
                    mask_tpe="nwb-image_mask",
                    # Spatial extent of the transposed masks. Must not call self.get_frame_shape():
                    # it reads self._roi_masks, which is only assigned after this call returns.
                    field_of_view_shape=tuple(image_masks_data.shape[:-1]),
                    roi_id_map=roi_id_map,
                )

                # Set accepted/rejected as properties if columns exist in NWB file
                if accepted_data is not None:
                    self.set_property("accepted", accepted_data.astype(bool), self._roi_ids)
                if rejected_data is not None:
                    self.set_property("rejected", rejected_data.astype(bool), self._roi_ids)

        # Extracting stored images as GrayscaleImages:
        self._segmentation_images = None
        if "SegmentationImages" in ophys.data_interfaces:
            images_container = ophys.data_interfaces["SegmentationImages"]
            self._segmentation_images = images_container.images

        # Imaging plane:
        if "ImagingPlane" in self.nwbfile.imaging_planes:
            imaging_plane = self.nwbfile.imaging_planes["ImagingPlane"]
            self._channel_names = [channel.name for channel in imaging_plane.optical_channel]

        if self._roi_ids is None and self._roi_masks is not None:
            self._roi_ids = list(range(self._roi_masks.num_rois))
        if self._roi_ids is None:
            raise ValueError("Unable to determine ROI ids from NWB file.")

        for trace_name, dataset in collected_responses:
            self._roi_responses.append(_RoiResponse(trace_name, dataset, list(self._roi_ids)))

    def get_accepted_list(self) -> list:
        """Get a list of accepted ROI ids.

        Returns
        -------
        accepted_list: list
            List of accepted ROI ids. All ROI ids when no "accepted" property is set.
        """
        warnings.warn(
            "get_accepted_list is deprecated and will be removed in May 2026. "
            "Use get_property('accepted', ids) instead to access NWB's acceptance data.",
            DeprecationWarning,
            stacklevel=2,
        )
        if "accepted" not in self.get_property_keys():
            return list(self.get_roi_ids())
        accepted = self.get_property("accepted", self.get_roi_ids())
        return [roi_id for roi_id, is_accepted in zip(self.get_roi_ids(), accepted) if is_accepted]

    def get_rejected_list(self) -> list:
        """Get a list of rejected ROI ids.

        Returns
        -------
        rejected_list: list
            List of rejected ROI ids. Empty when no "rejected" property is set.
        """
        warnings.warn(
            "get_rejected_list is deprecated and will be removed in May 2026. "
            "Use get_property('rejected', ids) instead to access NWB's rejection data.",
            DeprecationWarning,
            stacklevel=2,
        )
        if "rejected" not in self.get_property_keys():
            return []
        rejected = self.get_property("rejected", self.get_roi_ids())
        return [roi_id for roi_id, is_rejected in zip(self.get_roi_ids(), rejected) if is_rejected]

    def get_images_dict(self):
        """Return the images used in segmentation, keyed by image name.

        Returns
        -------
        images_dict: dict
            The base-class images dictionary, updated with every image of the NWB
            ``SegmentationImages`` container (e.g. mean, correlation image), each transposed.
        """
        images_dict = super().get_images_dict()
        if self._segmentation_images is not None:
            images_dict.update(
                (image_name, image_data[:].T) for image_name, image_data in self._segmentation_images.items()
            )
        return images_dict

    def get_roi_locations(self, roi_ids: Iterable[int] | None = None) -> np.ndarray | None:
        """Return the locations of the Regions of Interest (ROIs).

        Parameters
        ----------
        roi_ids: array_like
            A list or 1D array of ids of the ROIs. Length is the number of ROIs
            requested. If None, all ROIs are returned.

        Returns
        -------
        roi_locs: numpy.ndarray or None
            2-D array: 2 X no_ROIs. The pixel ids (x,y) where the centroid of the ROI is.
            None when the file has no "ROICentroids" column.
        """
        if self._roi_locs is None:
            return None
        all_ids = self.get_roi_ids()
        roi_idxs = slice(None) if roi_ids is None else [all_ids.index(roi_id) for roi_id in roi_ids]
        # ROIExtractors uses height x width x (depth), but NWB uses width x height x depth
        transpose_image_convention = [1, 0] if len(self.get_image_shape()) == 2 else [1, 0, 2]
        # Load once (h5py fancy indexing is slow), then pick rows and reorder columns in two
        # steps: a single arr[row_list, col_list] would broadcast the two index arrays
        # against each other instead of selecting a rows x columns sub-array.
        centroids = np.asarray(self._roi_locs.data)
        return centroids[roi_idxs][:, transpose_image_convention].T

    def get_frame_shape(self):
        """Get the shape of the video frame (num_rows, num_columns).

        Returns
        -------
        frame_shape: tuple
            Shape of the video frame (num_rows, num_columns).
        """
        return self._roi_masks.field_of_view_shape

    def get_image_shape(self):
        """Get the shape of the video frame (num_rows, num_columns).

        Returns
        -------
        image_shape: tuple
            Shape of the video frame (num_rows, num_columns).
        """
        return self._roi_masks.field_of_view_shape

    def get_native_timestamps(
        self, start_sample: int | None = None, end_sample: int | None = None
    ) -> np.ndarray | None:
        """Return None: native timestamps are not read from the NWB file yet.

        Callers fall back to timestamps calculated from the sampling frequency.
        """
        # TODO: check if the RoiResponseSeries has timestamps
        return None