diff --git a/pyproject.toml b/pyproject.toml
index 0fb2c4ee..088d0f70 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,7 +24,7 @@ dependencies = [
     "attrs",
     "h5py>=3.8.0",
     "pynwb",
-    "ndx-pose",
+    "ndx-pose @ git+https://github.com/rly/ndx-pose@a847ad4be75e60ef9e413b8cbfc99c616fc9fd05",
     "pandas",
     "simplejson",
     "imageio",
diff --git a/sleap_io/io/main.py b/sleap_io/io/main.py
index 7fd702f7..cc83bd02 100644
--- a/sleap_io/io/main.py
+++ b/sleap_io/io/main.py
@@ -47,34 +47,50 @@ def save_slp(
     return slp.write_labels(filename, labels, embed=embed)
 
 
-def load_nwb(filename: str) -> Labels:
+def load_nwb(filename: str, as_training: Optional[bool] = None) -> Labels:
     """Load an NWB dataset as a SLEAP `Labels` object.
 
     Args:
         filename: Path to a NWB file (`.nwb`).
+        as_training: If `True`, load the dataset as a training dataset; `None` is treated as `False`.
 
     Returns:
         The dataset as a `Labels` object.
     """
-    return nwb.read_nwb(filename)
+    if as_training is None:
+        as_training = False  # fix: previously returned None instead of loading anything
+
+    if as_training:
+        return nwb.read_nwb_training(filename)
+    else:
+        return nwb.read_nwb(filename)
 
 
-def save_nwb(labels: Labels, filename: str, append: bool = True):
+def save_nwb(labels: Labels, filename: str, as_training: Optional[bool] = None, append: bool = True, **kwargs):
     """Save a SLEAP dataset to NWB format.
 
     Args:
         labels: A SLEAP `Labels` object (see `load_slp`).
         filename: Path to NWB file to save to. Must end in `.nwb`.
+        as_training: If `True`, save the dataset as a training dataset.
         append: If `True` (the default), append to existing NWB file. File will be
             created if it does not exist.
 
     See also: nwb.write_nwb, nwb.append_nwb
     """
-    if append and Path(filename).exists():
-        nwb.append_nwb(labels, filename)
-    else:
-        nwb.write_nwb(labels, filename)
+    if as_training:
+        pose_training = nwb.labels_to_pose_training(labels, **kwargs)
+        if append and Path(filename).exists():
+            nwb.append_nwb_training(pose_training, filename, **kwargs)
+        else:
+            nwb.write_nwb_training(pose_training, filename, **kwargs)
+    else:
+        if append and Path(filename).exists():
+            nwb.append_nwb(labels, filename, **kwargs)
+        else:
+            nwb.write_nwb(labels, filename, **kwargs)
 
 
 def load_labelstudio(
     filename: str, skeleton: Optional[Union[Skeleton, list[str]]] = None
@@ -190,6 +206,8 @@ def load_file(
         return load_jabs(filename, **kwargs)
     elif format == "video":
         return load_video(filename, **kwargs)
+    else:
+        raise ValueError(f"Unknown format '{format}' for filename: '{filename}'.")
 
 
 def save_file(
@@ -219,7 +237,8 @@ def save_file(
 
     if format == "slp":
         save_slp(labels, filename, **kwargs)
-    elif format == "nwb":
+    elif format in ("nwb", "nwb_training", "nwb_predictions"):
+        kwargs.setdefault("as_training", format == "nwb_training")
         save_nwb(labels, filename, **kwargs)
     elif format == "labelstudio":
         save_labelstudio(labels, filename, **kwargs)
diff --git a/sleap_io/io/nwb.py b/sleap_io/io/nwb.py
index 2d87e10b..6b60cc27 100644
--- a/sleap_io/io/nwb.py
+++ b/sleap_io/io/nwb.py
@@ -7,69 +7,214 @@
 import uuid
 import re
 
-import pandas as pd  # type: ignore[import]
+import pandas as pd
 import numpy as np
 try:
     from numpy.typing import ArrayLike
 except ImportError:
     ArrayLike = np.ndarray
-from pynwb import NWBFile, NWBHDF5IO, ProcessingModule  # type: ignore[import]
-from ndx_pose import (PoseEstimationSeries,
-                      PoseEstimation,
-                      TrainingFrame,
-                      TrainingFrames,
-                      PoseTraining,
-                      SourceVideo,
-)  # type: ignore[import]
+
+from pynwb import NWBFile, NWBHDF5IO, ProcessingModule
+from pynwb.image import ImageSeries
+from pynwb.testing.mock.utils import name_generator
+
+from ndx_pose import (
+    PoseEstimationSeries,
+    PoseEstimation,
+    Skeleton as NWBSkeleton,
+    SkeletonInstance,
+    SkeletonInstances,
+    TrainingFrame,
+    TrainingFrames,
+    PoseTraining,
+    SourceVideos,
+)
 
 from sleap_io import (
     Labels,
     Video,
     LabeledFrame,
     Track,
-    Skeleton,
+    Skeleton as SLEAPSkeleton,
     Instance,
     PredictedInstance,
+    Edge,
+    Node,
 )
 from sleap_io.io.utils import convert_predictions_to_dataframe
 
 
-def convert_nwb(nwb_data_structure: Union[TrainingFrame,
-                                          TrainingFrames,
-                                          PoseTraining,
-                                          SourceVideo
-                                          ]) -> Union[LabeledFrame, List[LabeledFrame], Labels, Video]:
-    """Converts an NWB instance to its corresponding SLEAP instance."""
-
-    def convert_frame(frame: TrainingFrame) -> LabeledFrame:
-        """
-        Converts an NWB TrainingFrame instance to a LabeledFrame instance.
-        """
-        return LabeledFrame(
-            video=Video(filename=frame.source_video.data),
-            frame_idx=frame.frame_number.data,
-            instances=[
-                PredictedInstance.from_numpy(
-                    points=frame.points.data,
-                    point_scores=frame.confidence.data,
-                    instance_score=frame.confidence.data.mean(),
-                    skeleton=Skeleton(
-                        nodes=frame.skeleton.nodes.data,
-                        edges=frame.skeleton.edges.data,
-                    ),
-                )
-            ],
+
+def pose_training_to_labels(pose_training: PoseTraining) -> Labels:  # type: ignore[return]
+    """Creates a Labels object from an NWB PoseTraining object.
+
+    Args:
+        pose_training: An NWB PoseTraining object.
+
+    Returns:
+        A Labels object.
+    """
+    labeled_frames = []
+    for training_frame in pose_training.training_frames:
+        video = Video(filename=f"{training_frame.source_videos}")  # FIXME: stringifies the container; resolve to the actual video path
+        frame_idx = training_frame.source_video_frame_index  # fix: was the TrainingFrame object itself (see labels_to_pose_training)
+        instances = [
+            Instance.from_numpy(
+                points=instance.node_locations,
+                skeleton=nwb_skeleton_to_sleap(instance.skeleton),
+            )
+            for instance in training_frame.skeleton_instances
+        ]
+        labeled_frames.append(
+            LabeledFrame(video=video, frame_idx=frame_idx, instances=instances)
         )
-
-    if isinstance(nwb_data_structure, TrainingFrame):
-        return convert_frame(nwb_data_structure)
-    elif isinstance(nwb_data_structure, TrainingFrames):
-        return [convert_frame(frame) for frame in nwb_data_structure.training_frames]
-    elif isinstance(nwb_data_structure, PoseTraining):
-        return Labels([convert_frame(frame) for frame in nwb_data_structure.training_frames])
-    elif isinstance(nwb_data_structure, SourceVideo):
-        return Video(filename=nwb_data_structure.data)
-    else:
-        raise ValueError(f"Cannot convert {type(nwb_data_structure)} to SLEAP instance.")
+    return Labels(labeled_frames=labeled_frames)
+
+
+def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton:  # type: ignore[return]
+    """Converts an NWB skeleton to a SLEAP skeleton.
+
+    Args:
+        skeleton: An NWB skeleton.
+
+    Returns:
+        A SLEAP skeleton.
+    """
+    nodes = [Node(name=node) for node in skeleton.nodes]
+    edges = [Edge(source=edge[0], destination=edge[1]) for edge in skeleton.edges]
+    return SLEAPSkeleton(
+        nodes=nodes,
+        edges=edges,
+        name=skeleton.name,
+    )
+
+
+def labels_to_pose_training(labels: Labels, **kwargs) -> PoseTraining:  # type: ignore[return]
+    """Creates an NWB PoseTraining object from a Labels object.
+
+    Args:
+        labels: A Labels object.
+        kwargs: Unused; accepted for compatibility with call sites that forward kwargs.
+
+    Returns:
+        A PoseTraining object.
+    """
+    training_frame_list = []
+    for i, labeled_frame in enumerate(labels.labeled_frames):
+        training_frame_name = name_generator("training_frame")
+        training_frame_annotator = f"{training_frame_name}{i}"
+        skeleton_instances_list = []
+        for instance in labeled_frame.instances:
+            if isinstance(instance, PredictedInstance):
+                continue
+            skeleton_instance = instance_to_skeleton_instance(instance)
+            skeleton_instances_list.append(skeleton_instance)
+
+        training_frame_skeleton_instances = SkeletonInstances(
+            skeleton_instances=skeleton_instances_list
+        )
+        training_frame_video = labeled_frame.video
+        training_frame_video_index = labeled_frame.frame_idx
+        training_frame = TrainingFrame(
+            name=training_frame_name,
+            annotator=training_frame_annotator,
+            skeleton_instances=training_frame_skeleton_instances,
+            source_video=ImageSeries(
+                name=training_frame_name,
+                description=training_frame_annotator,
+                unit="NA",
+                format="external",
+                external_file=[training_frame_video.filename],
+                dimension=[
+                    training_frame_video.shape[1],
+                    training_frame_video.shape[2],
+                ],
+                starting_frame=[0],
+                rate=30.0,  # NOTE(review): hard-coded rate — confirm against actual video metadata
+            ),
+            source_video_frame_index=training_frame_video_index,
+        )
+        training_frame_list.append(training_frame)
+
+    training_frames = TrainingFrames(training_frames=training_frame_list)
+    pose_training = PoseTraining(
+        training_frames=training_frames,
+        source_videos=videos_to_source_videos(labels.videos),
+    )
+    return pose_training
+
+
+def slp_skeleton_to_nwb(skeleton: SLEAPSkeleton) -> NWBSkeleton:  # type: ignore[return]
+    """Converts SLEAP skeleton to NWB skeleton.
+
+    Args:
+        skeleton: A SLEAP skeleton.
+
+    Returns:
+        An NWB skeleton.
+    """
+    nwb_edges: list[list[int]]
+
+    # Map each node to its index once; each SLEAP edge then becomes an index pair.
+    # (A forward-only pairwise scan would miss edges whose destination precedes the source.)
+    node_index = {node: i for i, node in enumerate(skeleton.nodes)}
+    nwb_edges = [
+        [node_index[edge.source], node_index[edge.destination]]
+        for edge in skeleton.edges
+    ]
+
+    return NWBSkeleton(
+        name=f"Nodes {skeleton.nodes[0].name}, ..., {skeleton.nodes[-1].name}",
+        nodes=skeleton.node_names,
+        edges=np.array(nwb_edges, dtype=np.uint8),  # NOTE(review): uint8 caps node indices at 255
+    )
+
+
+def instance_to_skeleton_instance(instance: Instance) -> SkeletonInstance:  # type: ignore[return]
+    """Converts a SLEAP Instance to an NWB SkeletonInstance.
+
+    Args:
+        instance: A SLEAP Instance.
+
+    Returns:
+        An NWB SkeletonInstance.
+    """
+    skeleton = slp_skeleton_to_nwb(instance.skeleton)
+    points_list = list(instance.points.values())
+    node_locs = [[point.x, point.y] for point in points_list]
+    np_node_locations = np.array(node_locs)
+    return SkeletonInstance(
+        name=name_generator("skeleton_instance"),
+        id=np.uint(10),  # FIXME: hard-coded id — should be unique per instance
+        node_locations=np_node_locations,
+        node_visibility=[point.visible for point in points_list],
+        skeleton=skeleton,
+    )
+
+
+def videos_to_source_videos(videos: List[Video]) -> SourceVideos:  # type: ignore[return]
+    """Converts a list of SLEAP Videos to NWB SourceVideos.
+
+    Args:
+        videos: A list of SLEAP Videos.
+
+    Returns:
+        An NWB SourceVideos object.
+    """
+    source_videos = []
+    for i, video in enumerate(videos):
+        image_series = ImageSeries(
+            name=f"video_{i}",
+            description="Video file",
+            unit="NA",
+            format="external",
+            external_file=[video.filename],
+            dimension=[video.backend.img_shape[0], video.backend.img_shape[1]],  # NOTE(review): confirm ordering vs labels_to_pose_training's shape[1:3]
+            starting_frame=[0],
+            rate=30.0,
+        )
+        source_videos.append(image_series)
+    return SourceVideos(image_series=source_videos)
+
 
 def get_timestamps(series: PoseEstimationSeries) -> np.ndarray:
     """Return a vector of timestamps for a `PoseEstimationSeries`."""
@@ -148,7 +293,7 @@ def read_nwb(path: str) -> Labels:
         )
 
         # Create skeleton
-        skeleton = Skeleton(
+        skeleton = SLEAPSkeleton(
             nodes=node_names,
             edges=edge_inds,
         )
@@ -169,7 +314,7 @@ def read_nwb(path: str) -> Labels:
             if np.isnan(inst_pts).all():
                 continue
             insts.append(
-                PredictedInstance.from_numpy(
+                Instance.from_numpy(
                     points=inst_pts,  # (n_nodes, 2)
                     point_scores=inst_confs,  # (n_nodes,)
                     instance_score=inst_confs.mean(),  # ()
@@ -216,7 +361,7 @@ def write_nwb(
            or the sampling rate with key`video_sample_rate`.
 
             e.g. pose_estimation_metadata["video_timestamps"] = np.array(timestamps)
-            or   pose_estimation_metadata["video_sample_rate] = 15  # In Hz
+            or   pose_estimation_metadata["video_sample_rate"] = 15  # In Hz
 
         2) The other use of this dictionary is to ovewrite sleap-io default arguments
            for the PoseEstimation container.
@@ -246,6 +391,27 @@ def write_nwb(
         io.write(nwbfile)
 
 
+def write_nwb_training(pose_training: PoseTraining,  # type: ignore[return]
+                       nwbfile_path: str,
+                       nwb_file_kwargs: Optional[dict] = None,
+                       pose_estimation_metadata: Optional[dict] = None,
+                       ):
+    """Writes data from a `PoseTraining` object to an NWB file.
+
+    Args:
+        pose_training: A `PoseTraining` object.
+        nwbfile_path: The path where the nwb file is to be written.
+        nwb_file_kwargs: Optional kwargs forwarded to the `NWBFile` constructor.
+        pose_estimation_metadata: Optional metadata forwarded to `append_nwb_data`.
+    """
+    nwb_file_kwargs = nwb_file_kwargs or {}
+
+    nwbfile = NWBFile(**nwb_file_kwargs)
+    nwbfile = append_nwb_data(pose_training, nwbfile, pose_estimation_metadata)  # FIXME: append_nwb_data expects a Labels object, not PoseTraining
+    with NWBHDF5IO(str(nwbfile_path), "w") as io:
+        io.write(nwbfile)
+
+
 def append_nwb_data(
     labels: Labels, nwbfile: NWBFile, pose_estimation_metadata: Optional[dict] = None
 ) -> NWBFile:
@@ -280,7 +444,12 @@ def append_nwb_data(
         sleap_version = provenance.get("sleap_version", None)
         default_metadata["source_software_version"] = sleap_version
 
-    labels_data_df = convert_predictions_to_dataframe(labels)
+    for lf in labels.labeled_frames:
+        if lf.has_predicted_instances:
+            labels_data_df = convert_predictions_to_dataframe(labels)
+            break
+    else:
+        labels_data_df = pd.DataFrame()
 
     # For every video create a processing module
     for video_index, video in enumerate(labels.videos):
@@ -304,7 +473,7 @@ def append_nwb_data(
             .unique()
         )
 
-        for track_index, track_name in enumerate(name_of_tracks_in_video):
+        for track_name in name_of_tracks_in_video:
             pose_estimation_container = build_pose_estimation_container_for_track(
                 labels_data_df,
                 labels,
@@ -338,6 +507,19 @@ def append_nwb(
         io.write(nwb_file)
 
 
+def append_nwb_training(pose_training: PoseTraining, nwbfile_path: str) -> NWBFile:  # type: ignore[return]
+    """Append a PoseTraining object to an existing NWB data file.
+
+    Args:
+        pose_training: A PoseTraining object.
+        nwbfile_path: The path to the NWB file.
+
+    Returns:
+        An in-memory NWB file with the PoseTraining data appended.
+    """
+    raise NotImplementedError
+
+
 def get_processing_module_for_video(
     processing_module_name: str, nwbfile: NWBFile
 ) -> ProcessingModule:
@@ -437,7 +619,7 @@ def build_pose_estimation_container_for_track(
 
 
 def build_track_pose_estimation_list(
-    track_data_df: pd.DataFrame, timestamps: ArrayLike
+    track_data_df: pd.DataFrame, timestamps: ArrayLike  # type: ignore[return]
 ) -> List[PoseEstimationSeries]:
     """Build a list of PoseEstimationSeries from tracks.