Source code for roiextractors.extractors.nwbextractors.nwbextractors

"""Imaging and segmentation extractors for NWB files.

Classes
-------
NwbImagingExtractor
    Extracts imaging data from NWB files.
NwbSegmentationExtractor
    Extracts segmentation data from NWB files.
"""

import warnings
from pathlib import Path
from typing import Iterable

import numpy as np
from lazy_ops import DatasetView
from pynwb import NWBHDF5IO
from pynwb.ophys import OnePhotonSeries, TwoPhotonSeries

from ...extraction_tools import (
    PathType,
)
from ...imagingextractor import ImagingExtractor
from ...segmentationextractor import (
    SegmentationExtractor,
    _ROIMasks,
    _RoiResponse,
)


class NwbImagingExtractor(ImagingExtractor):
    """An imaging extractor for NWB files.

    Reads imaging data stored in a ``OnePhotonSeries`` or ``TwoPhotonSeries``
    of an NWB file. Both planar (2D) and volumetric (3D) imaging data are
    supported.
    """

    extractor_name = "NwbImaging"

    def __init__(self, file_path: PathType, optical_series_name: str | None = "TwoPhotonSeries"):
        """Create ImagingExtractor object from NWB file.

        Parameters
        ----------
        file_path: PathType
            Path to the .nwb file.
        optical_series_name: str, optional
            The name of the optical series (in the file's acquisition) to extract data from.
            If None, the file must contain exactly one acquisition, which is then used.

        Raises
        ------
        ValueError
            If ``optical_series_name`` is None and the file holds zero or several acquisitions,
            or if the series data is neither 3- nor 4-dimensional.
        AssertionError
            If the selected acquisition is not a One/TwoPhotonSeries or uses an external file.
        """
        ImagingExtractor.__init__(self)
        self._path = file_path

        self.io = NWBHDF5IO(str(self._path), "r")
        self.nwbfile = self.io.read()

        if optical_series_name is not None:
            self._optical_series_name = optical_series_name
        else:
            acquisition_names = list(self.nwbfile.acquisition)
            if len(acquisition_names) > 1:
                raise ValueError("More than one acquisition found. You must specify optical_series_name.")
            if len(acquisition_names) == 0:
                raise ValueError("No acquisitions found in the .nwb file.")
            self._optical_series_name = acquisition_names[0]

        self.photon_series = self.nwbfile.acquisition[self._optical_series_name]
        assert isinstance(
            self.photon_series, (OnePhotonSeries, TwoPhotonSeries)
        ), "The optical series must be of type pynwb.ophys.OnePhotonSeries or pynwb.ophys.TwoPhotonSeries."

        # TODO if external file --> return another proper extractor (e.g. TiffImagingExtractor)
        assert self.photon_series.external_file is None, "Only 'raw' format is currently supported"

        # Load the two video structures that TwoPhotonSeries supports.
        data_shape = self.photon_series.data.shape
        if len(data_shape) == 3:
            # Planar 2D imaging: (time, width, height)
            self._num_samples, self._columns, self._num_rows = data_shape
            self._num_planes = 1
            self.is_volumetric = False
        elif len(data_shape) == 4:
            # Volumetric 3D imaging: (time, width, height, depth)
            self._num_samples, self._columns, self._num_rows, self._num_planes = data_shape
            self.is_volumetric = True
        else:
            # Fail early instead of leaving the geometry attributes undefined.
            raise ValueError(
                f"Unsupported photon series data shape {tuple(data_shape)}: "
                "expected 3 dimensions (planar) or 4 dimensions (volumetric)."
            )
        self._num_channels = 1

        # Set channel names (This should disambiguate which optical channel)
        self._channel_names = [channel.name for channel in self.photon_series.imaging_plane.optical_channel]

        # Set sampling frequency.
        # Explicit None/length checks: truthiness of an array-like is ambiguous.
        timestamps = getattr(self.photon_series, "timestamps", None)
        if timestamps is not None and len(timestamps) > 0:
            timestamps = np.array(timestamps)  # read the dataset only once
            self._sampling_frequency = 1.0 / np.median(np.diff(timestamps))
            self._imaging_start_time = timestamps[0]
            self.set_times(timestamps)
        else:
            self._sampling_frequency = self.photon_series.rate
            self._imaging_start_time = self.photon_series.fields.get("starting_time", 0.0)

        # Fill epochs dictionary
        self._epochs = {}
        if self.nwbfile.epochs is not None:
            df_epochs = self.nwbfile.epochs.to_dataframe()
            # TODO implement add_epoch() method in base class
            self._epochs = {
                row["tags"][0]: {
                    "start_frame": self.time_to_frame(row["start_time"]),
                    "end_frame": self.time_to_frame(row["stop_time"]),
                }
                for _, row in df_epochs.iterrows()
            }

        self._kwargs = {
            "file_path": str(Path(file_path).absolute()),
            "optical_series_name": optical_series_name,
        }

    def __del__(self):
        """Close the NWB file."""
        # __init__ may have failed before the IO handle was created.
        io = getattr(self, "io", None)
        if io is not None:
            io.close()

    def make_nwb_metadata(
        self, nwbfile, opts
    ):  # TODO: refactor to use two photon series name directly rather than via opts
        """Create metadata dictionary for NWB file.

        Parameters
        ----------
        nwbfile: pynwb.NWBFile
            The NWBFile object associated with the metadata.
        opts: object
            The options object with name of TwoPhotonSeries as an attribute.

        Notes
        -----
        Metadata dictionary is stored in the nwb_metadata attribute.

        .. deprecated:: 0.7.3
            This method is deprecated and will be removed on or after May 2026.
        """
        warnings.warn(
            "The 'make_nwb_metadata' method is deprecated and will be removed on or after May 2026.",
            FutureWarning,
            stacklevel=2,
        )

        # Metadata dictionary - useful for constructing a nwb file
        self.nwb_metadata = dict(
            NWBFile=dict(
                session_description=nwbfile.session_description,
                identifier=nwbfile.identifier,
                session_start_time=nwbfile.session_start_time,
                institution=nwbfile.institution,
                lab=nwbfile.lab,
            ),
            Ophys=dict(
                Device=[dict(name=dev) for dev in nwbfile.devices],
                TwoPhotonSeries=[dict(name=opts.name)],
            ),
        )

    def get_series(self, start_sample=None, end_sample=None) -> np.ndarray:
        """Get a contiguous range of samples in roiextractors axis order.

        Parameters
        ----------
        start_sample : int, optional
            First sample to return. If None, starts from the beginning.
        end_sample : int, optional
            Sample at which to stop (exclusive). If None, goes to the last sample.

        Returns
        -------
        series : numpy.ndarray
            The requested samples with the width and height axes swapped relative to
            the stored NWB data: (time, height, width) or (time, height, width, depth).
        """
        start_sample = start_sample if start_sample is not None else 0
        end_sample = end_sample if end_sample is not None else self.get_num_samples()

        series = self.photon_series.data
        if self.is_volumetric:
            # NWB: (time, width, height, depth) -> roiextractors: (time, height, width, depth)
            series = series[start_sample:end_sample].transpose([0, 2, 1, 3])
        else:
            # NWB: (time, width, height) -> roiextractors: (time, height, width)
            series = series[start_sample:end_sample].transpose([0, 2, 1])
        return series

    def get_image_shape(self) -> tuple[int, int]:
        """Get the shape of the video frame (num_rows, num_columns).

        Returns
        -------
        image_shape: tuple
            Shape of the video frame (num_rows, num_columns).
        """
        return (self._num_rows, self._columns)  # TODO: change name of _columns to _num_cols for consistency

    def get_num_samples(self):
        """Get the number of samples (first axis of the photon series data).

        Returns
        -------
        num_samples: int
            Number of samples in the series.
        """
        return self._num_samples

    def get_sampling_frequency(self):
        """Get the sampling frequency determined at construction.

        Returns
        -------
        sampling_frequency: float
            Derived from the median timestamp difference when timestamps are stored,
            otherwise the ``rate`` of the photon series.
        """
        return self._sampling_frequency

    def get_channel_names(self):
        """Get the names of the optical channels of the imaging plane.

        Returns
        -------
        channel_names: list
            Names of the optical channels.

        Notes
        -----
        Deprecated: emits a FutureWarning and will be removed in May 2026 or after.
        """
        warnings.warn(
            "get_channel_names is deprecated and will be removed in May 2026 or after.",
            category=FutureWarning,
            stacklevel=2,
        )
        return self._channel_names

    def get_native_timestamps(
        self, start_sample: int | None = None, end_sample: int | None = None
    ) -> np.ndarray | None:
        """Retrieve the original unaltered timestamps for the data.

        This method uses the NWB photon series get_timestamps() method, which properly handles
        explicit timestamps, starting_time attribute, and rate-based calculations.

        Parameters
        ----------
        start_sample : int, optional
            The starting sample index. If None, starts from the beginning.
        end_sample : int, optional
            The ending sample index. If None, goes to the end.

        Returns
        -------
        timestamps : numpy.ndarray
            The timestamps for the data stream.
        """
        # Set defaults
        if start_sample is None:
            start_sample = 0
        if end_sample is None:
            end_sample = self.get_num_samples()

        # Use NWB's native get_timestamps() method which handles timestamps, starting_time, and rate
        timestamps = self.photon_series.get_timestamps()
        return timestamps[start_sample:end_sample]

    def get_num_planes(self) -> int:
        """Get the number of depth planes.

        Returns
        -------
        num_planes : int
            The number of depth planes (1 for planar data).
        """
        return self._num_planes
class NwbSegmentationExtractor(SegmentationExtractor):
    """A segmentation extractor for NWB files."""

    extractor_name = "NwbSegmentationExtractor"
    installation_mesg = ""  # error message when not installed

    def __init__(self, file_path: PathType):
        """Create NwbSegmentationExtractor object from nwb file.

        Parameters
        ----------
        file_path: PathType
            .nwb file location

        Raises
        ------
        FileNotFoundError
            If ``file_path`` is not an existing file.
        AssertionError
            If the "ophys" processing module, any PlaneSegmentation, or the
            "image_mask" column is missing.
        Exception
            If no Fluorescence/DfOverF interface or no known RoiResponseSeries is found.
        ValueError
            If the ROI ids cannot be determined.
        """
        super().__init__()
        file_path = Path(file_path)
        if not file_path.is_file():
            # FileNotFoundError is still caught by callers handling Exception.
            raise FileNotFoundError("file does not exist")
        self.file_path = file_path
        self._roi_locs = None

        self._io = NWBHDF5IO(str(file_path), mode="r")
        self.nwbfile = self._io.read()
        assert "ophys" in self.nwbfile.processing, "Ophys processing module is not in nwbfile."
        ophys = self.nwbfile.processing.get("ophys")

        # Extract roi_responses:
        fluorescence = None
        df_over_f = None
        collected_responses: list[tuple[str, DatasetView]] = []
        if "Fluorescence" in ophys.data_interfaces:
            fluorescence = ophys.data_interfaces["Fluorescence"]
        if "DfOverF" in ophys.data_interfaces:
            df_over_f = ophys.data_interfaces["DfOverF"]
        if fluorescence is None and df_over_f is None:
            raise Exception("Could not find Fluorescence/DfOverF module in nwbfile.")

        for trace_name in ("raw", "dff", "neuropil", "deconvolved", "denoised", "baseline", "background"):
            # "raw" and "dff" are both stored under the default series name, in different containers.
            trace_name_segext = "RoiResponseSeries" if trace_name in ["raw", "dff"] else trace_name.capitalize()
            container = df_over_f if trace_name == "dff" else fluorescence
            if container is not None and trace_name_segext in container.roi_response_series:
                response_series = container.roi_response_series[trace_name_segext]
                collected_responses.append((trace_name, DatasetView(response_series.data)))
                if self._sampling_frequency is None:
                    self._sampling_frequency = response_series.rate

        if not collected_responses:
            raise Exception(
                "could not find any of 'RoiResponseSeries'/'Dff'/'Neuropil'/'Background'/'Deconvolved' "
                "named RoiResponseSeries in nwbfile"
            )

        # Extract image_mask/background:
        if "ImageSegmentation" in ophys.data_interfaces:
            image_seg = ophys.data_interfaces["ImageSegmentation"]
            assert len(image_seg.plane_segmentations), "Could not find any PlaneSegmentation in nwbfile."
            if "PlaneSegmentation" in image_seg.plane_segmentations:  # this requirement in nwbfile is enforced
                ps = image_seg.plane_segmentations["PlaneSegmentation"]
                assert "image_mask" in ps.colnames, "Could not find any image_masks in nwbfile."
                # Reverse the axes so that the ROI axis comes last.
                image_masks_data = DatasetView(ps["image_mask"].data).lazy_transpose([2, 1, 0])
                self._roi_locs = ps["ROICentroids"] if "ROICentroids" in ps.colnames else None
                accepted_data = ps["Accepted"].data[:] if "Accepted" in ps.colnames else None
                rejected_data = ps["Rejected"].data[:] if "Rejected" in ps.colnames else None

                if hasattr(ps, "id"):
                    self._roi_ids = ps.id.data[:].tolist()
                else:
                    self._roi_ids = list(range(image_masks_data.shape[-1]))

                # Create ROI representations.
                # The field of view is read from the mask data itself: get_frame_shape()
                # reads self._roi_masks, which is only being constructed here.
                roi_id_map = {roi_id: index for index, roi_id in enumerate(self._roi_ids)}
                self._roi_masks = _ROIMasks(
                    data=image_masks_data,
                    mask_tpe="nwb-image_mask",
                    field_of_view_shape=tuple(image_masks_data.shape[:2]),
                    roi_id_map=roi_id_map,
                )

                # Set accepted/rejected as properties if columns exist in NWB file
                if accepted_data is not None:
                    self.set_property("accepted", accepted_data.astype(bool), self._roi_ids)
                if rejected_data is not None:
                    self.set_property("rejected", rejected_data.astype(bool), self._roi_ids)

        # Extracting stored images as GrayscaleImages:
        self._segmentation_images = None
        if "SegmentationImages" in ophys.data_interfaces:
            images_container = ophys.data_interfaces["SegmentationImages"]
            self._segmentation_images = images_container.images

        # Imaging plane:
        if "ImagingPlane" in self.nwbfile.imaging_planes:
            imaging_plane = self.nwbfile.imaging_planes["ImagingPlane"]
            self._channel_names = [channel.name for channel in imaging_plane.optical_channel]

        if self._roi_ids is None and self._roi_masks is not None:
            self._roi_ids = list(range(self._roi_masks.num_rois))
        if self._roi_ids is None:
            raise ValueError("Unable to determine ROI ids from NWB file.")

        for trace_name, dataset in collected_responses:
            # Each response gets its own copy of the id list.
            self._roi_responses.append(_RoiResponse(trace_name, dataset, list(self._roi_ids)))

    def __del__(self):
        """Close the NWB file."""
        # __init__ may have failed (e.g. missing file) before the IO handle was created.
        io = getattr(self, "_io", None)
        if io is not None:
            io.close()

    def get_accepted_list(self) -> list:
        """Get a list of accepted ROI ids.

        Returns
        -------
        accepted_list: list
            List of accepted ROI ids. All ROI ids when no "accepted" property is set.
        """
        warnings.warn(
            "get_accepted_list is deprecated and will be removed in May 2026. "
            "Use get_property('accepted', ids) instead to access NWB's acceptance data.",
            DeprecationWarning,
            stacklevel=2,
        )
        if "accepted" not in self.get_property_keys():
            return list(self.get_roi_ids())
        accepted = self.get_property("accepted", self.get_roi_ids())
        return [roi_id for roi_id, is_accepted in zip(self.get_roi_ids(), accepted) if is_accepted]

    def get_rejected_list(self) -> list:
        """Get a list of rejected ROI ids.

        Returns
        -------
        rejected_list: list
            List of rejected ROI ids. Empty when no "rejected" property is set.
        """
        warnings.warn(
            "get_rejected_list is deprecated and will be removed in May 2026. "
            "Use get_property('rejected', ids) instead to access NWB's rejection data.",
            DeprecationWarning,
            stacklevel=2,
        )
        if "rejected" not in self.get_property_keys():
            return []
        rejected = self.get_property("rejected", self.get_roi_ids())
        return [roi_id for roi_id, is_rejected in zip(self.get_roi_ids(), rejected) if is_rejected]

    def get_images_dict(self):
        """Return the segmentation images as a dictionary keyed by image name.

        Returns
        -------
        images_dict: dict
            The base class images dictionary, updated with the transposed data of every
            image stored in the "SegmentationImages" container (if present).
        """
        images_dict = super().get_images_dict()
        if self._segmentation_images is not None:
            images_dict.update(
                (image_name, image_data[:].T) for image_name, image_data in self._segmentation_images.items()
            )
        return images_dict

    def get_roi_locations(self, roi_ids: Iterable[int] | None = None) -> np.ndarray | None:
        """Return the locations of the Regions of Interest (ROIs).

        Parameters
        ----------
        roi_ids: array_like
            A list or 1D array of ids of the ROIs. Length is the number of ROIs requested.
            If None, the locations of all ROIs are returned.

        Returns
        -------
        roi_locs: numpy.ndarray or None
            2-D array with one column per requested ROI, holding its centroid with the
            first two coordinates swapped relative to the stored values.
            None when the file has no "ROICentroids" column.
        """
        if self._roi_locs is None:
            return None

        # Load everything into memory first: h5py fancy indexing is slow.
        locations = np.array(self._roi_locs.data)
        if roi_ids is not None:
            all_ids = self.get_roi_ids()
            row_indices = np.asarray([all_ids.index(roi_id) for roi_id in roi_ids], dtype=int)
            locations = locations[row_indices]

        # ROIExtractors uses height x width x (depth), but NWB uses width x height x depth
        axis_order = [1, 0] if len(self.get_image_shape()) == 2 else [1, 0, 2]
        # Rows and columns are selected in two separate steps: indexing with both lists
        # at once would broadcast them against each other instead of taking the sub-grid.
        return locations[:, axis_order].T

    def get_frame_shape(self):
        """Get the shape of the video frame (num_rows, num_columns).

        Returns
        -------
        frame_shape: tuple
            Shape of the video frame (num_rows, num_columns).
        """
        return self._roi_masks.field_of_view_shape

    def get_image_shape(self):
        """Get the shape of the video frame (num_rows, num_columns).

        Returns
        -------
        image_shape: tuple
            Shape of the video frame (num_rows, num_columns).
        """
        return self._roi_masks.field_of_view_shape

    def get_native_timestamps(
        self, start_sample: int | None = None, end_sample: int | None = None
    ) -> np.ndarray | None:
        """Return the native timestamps; currently always None.

        Parameters
        ----------
        start_sample : int, optional
            Unused.
        end_sample : int, optional
            Unused.

        Returns
        -------
        timestamps : None
            Always None, so that timestamps calculated from the sampling frequency are used.
        """
        # NWB files may have timestamps but need to check the specific implementation
        # For now, return None to use calculated timestamps based on sampling frequency
        # TODO: check if the RoiResponseSeries has timestamps
        return None