# Source code for roiextractors.extractors.simaextractor.simasegmentationextractor

"""A segmentation extractor for Sima.

Classes
-------
SimaSegmentationExtractor
    A segmentation extractor for Sima.
"""

import importlib
import os
import pickle
import re
from shutil import copyfile

import numpy as np

from ...extraction_tools import PathType
from ...segmentationextractor import (
    SegmentationExtractor,
    _ROIMasks,
    _RoiResponse,
)


class SimaSegmentationExtractor(SegmentationExtractor):
    """A segmentation extractor for Sima.

    This class inherits from the SegmentationExtractor class, having all
    its functionality specifically applied to the dataset output from
    the 'SIMA' ROI segmentation method.
    """

    extractor_name = "SimaSegmentation"
    mode = "file"
    # error message when not installed
    installation_mesg = "To use the SimaSegmentationExtractor install sima and dill: \n\n pip install sima/dill\n\n"

    def __init__(self, file_path: PathType, sima_segmentation_label: str = "auto_ROIs"):
        """Create a SegmentationExtractor instance from a sima file.

        Parameters
        ----------
        file_path: str or Path
            The location of the folder containing dataset.sima file and the raw
            image file(s) (tiff, h5, .zip)
        sima_segmentation_label: str
            name of the ROIs in the dataset from which to extract all ROI info

        Raises
        ------
        AssertionError
            If either ``sima`` or ``dill`` is not installed.
        """
        # ``import importlib`` alone does not guarantee that the ``util``
        # submodule is bound, so import the function explicitly.
        from importlib.util import find_spec

        have_sima = find_spec("sima") is not None and find_spec("dill") is not None
        # Kept as an assert so that the raised exception type is unchanged.
        assert have_sima, self.installation_mesg

        SegmentationExtractor.__init__(self)
        self.file_path = file_path
        self._convert_sima(file_path)
        self._dataset_file = self._file_extractor_read()
        self._channel_names = [str(i) for i in self._dataset_file.channel_names]
        self._num_of_channels = len(self._channel_names)
        self.sima_segmentation_label = sima_segmentation_label

        # Read traces first to get number of ROIs
        traces = self._trace_extractor_read()
        cell_ids = list(range(traces.shape[1]))
        self._roi_ids = cell_ids
        self._roi_responses.append(_RoiResponse("raw", traces, cell_ids))

        # Create ROI representations from dense image masks
        image_masks_data = self._image_mask_extractor_read()  # (H, W, N) array
        roi_id_map = {roi_id: index for index, roi_id in enumerate(cell_ids)}
        self._roi_masks = _ROIMasks(
            data=image_masks_data,
            mask_tpe="nwb-image_mask",
            # The field of view is taken from the mask array itself:
            # get_frame_shape() reads self._roi_masks, which is only being
            # constructed here, so it cannot be called at this point.
            field_of_view_shape=tuple(image_masks_data.shape[:2]),
            roi_id_map=roi_id_map,
        )

        mean_image = self._summary_image_read()
        if mean_image is not None:
            self._summary_images["mean"] = mean_image

    @staticmethod
    def _convert_sima(old_pkl_loc):
        """Convert the pickles inside a sima folder to python 3 pickles.

        This function is used to convert python 2 pickles to python 3 pickles.
        Forward compatibility of '*.sima' files containing .pkl dataset, rois,
        sequences, signals, time_averages.

        Replaces each pickle file with a python 3 version with the same name.
        Saves the old Py2 pickle as 'oldpicklename_p2.pkl'. Directories that
        already contain a '*_p2.pkl' backup are skipped.

        Parameters
        ----------
        old_pkl_loc: str or Path
            Path of the folder that is walked recursively for '.pkl' files
            to be converted.
        """
        import dill

        # Accept both str and os.PathLike; plain string concatenation with a
        # pathlib.Path would raise TypeError.
        root = os.fspath(old_pkl_loc)
        for dirpath, _dirnames, filenames in os.walk(root):
            # A '*_p2.pkl' backup marks a directory as already converted.
            if any("_p2.pkl" in name for name in filenames):
                print("pickle already in Py3 format")
                continue
            for file in filenames:
                if ".pkl" not in file:
                    continue
                old_pkl = os.path.join(dirpath, file)
                print(old_pkl)
                # Make a name for the backup of the original pickle
                new_pkl_name = os.path.splitext(os.path.basename(old_pkl))[0] + "_p2.pkl"
                new_pkl = os.path.join(os.path.split(old_pkl)[0], new_pkl_name)
                # Convert Python 2 "ObjectType" to Python 3 object
                dill._dill._reverse_typemap["ObjectType"] = object
                # Open the pickle using latin1 encoding
                # NOTE(security): pickle.load executes arbitrary code embedded
                # in the file; only use this on sima folders from a trusted source.
                with open(old_pkl, "rb") as f:
                    loaded = pickle.load(f, encoding="latin1")
                # Back up the original before it is replaced.
                copyfile(old_pkl, new_pkl)
                os.remove(old_pkl)
                # Re-save as Python 3 pickle
                with open(old_pkl, "wb") as outfile:
                    pickle.dump(loaded, outfile)

    def _file_extractor_read(self):
        """Read the sima file and return the sima.ImagingDataset object."""
        import sima

        _img_dataset = sima.ImagingDataset.load(self.file_path)
        _img_dataset._savedir = self.file_path
        return _img_dataset

    def _image_mask_extractor_read(self):
        """Read the image mask from the sima.ImagingDataset object (self._dataset_file).

        When the dataset holds several ROI sets, the one named
        ``self.sima_segmentation_label`` is used. When it holds exactly one,
        that one is used and ``self.sima_segmentation_label`` is updated to
        its name.

        Returns
        -------
        numpy.ndarray
            The stacked ROI masks.

        Raises
        ------
        Exception
            If the dataset has no ROIs, or has several ROI sets and none is
            named ``self.sima_segmentation_label``.
        """
        _sima_rois = self._dataset_file.ROIs
        available_labels = list(_sima_rois.keys())
        if not available_labels:
            raise Exception("no ROIs found in the sima file")

        if len(available_labels) == 1:
            # Only one ROI set: use it regardless of the requested label.
            self.sima_segmentation_label = available_labels[0]
        elif self.sima_segmentation_label not in available_labels:
            raise Exception("Enter a valid name of ROIs from: {}".format(",".join(available_labels)))
        _sima_rois_data = _sima_rois[self.sima_segmentation_label]

        image_masks_ = [np.squeeze(np.array(roi_dat)).T for roi_dat in _sima_rois_data]
        return np.array(image_masks_).T

    def _trace_extractor_read(self):
        """Read the traces from the sima.ImagingDataset object (self._dataset_file).

        Returns
        -------
        numpy.ndarray
            The "raw" extracted signals of the selected channel and label.

        Raises
        ------
        ValueError
            If none of the channels has any extracted signals.
        """
        _active_channel = None
        for channel_now in self._channel_names:
            # NOTE: the scan is not stopped at the first hit, so when several
            # channels carry signals the last one is used (existing behaviour).
            if any(self._dataset_file.signals(channel=channel_now)):
                _active_channel = channel_now
        if _active_channel is None:
            raise ValueError("no extracted signals found in any channel of the sima file")

        print(
            "extracting signal from channel {} from {} no of channels".format(
                _active_channel, self._num_of_channels
            )
        )

        # label for the extraction method in SIMA: prefer labels that do not
        # contain a date stamp of the form "YYYY-MM-DD-".
        all_labels = list(self._dataset_file.signals(channel=_active_channel))
        named_labels = [label for label in all_labels if not re.findall(r"[\d]{4}-[\d]{2}-[\d]{2}-", label)]
        if named_labels:
            _label = named_labels[0]
            if len(named_labels) > 1:
                print("multiple labels found for extract method using {}".format(_label))
        else:
            # Every label carries a date stamp: fall back to the last one.
            _label = all_labels[-1]
            print("no label found for extract method using {}".format(_label))

        extracted_signals = np.array(self._dataset_file.signals(channel=_active_channel)[_label]["raw"][0])
        return extracted_signals

    def _summary_image_read(self):
        """Read the summary image from the sima.ImagingDataset object (self._dataset_file).

        Returns
        -------
        numpy.ndarray
            A copy of the first entry of the dataset's ``time_averages`` with
            singleton dimensions removed.
        """
        # The original transposed twice, which cancels out.
        return np.array(np.squeeze(self._dataset_file.time_averages[0]))

    def get_frame_shape(self):
        """Get the frame shape (height, width) of the movie.

        Returns
        -------
        tuple
            The frame shape as (height, width).
        """
        return self._roi_masks.field_of_view_shape

    def get_native_timestamps(
        self, start_sample: int | None = None, end_sample: int | None = None
    ) -> np.ndarray | None:
        """Retrieve the original unaltered timestamps for the data in this interface.

        Parameters
        ----------
        start_sample: int, optional
            Unused; no timestamps are available.
        end_sample: int, optional
            Unused; no timestamps are available.

        Returns
        -------
        timestamps: numpy.ndarray or None
            The timestamps for the data stream, or None if native timestamps are not available.
        """
        # SIMA segmentation data does not have native timestamps
        return None