Skip to content

Commit

Permalink
nwb to sleap conversion function
Browse files Browse the repository at this point in the history
  • Loading branch information
keyaloding committed Jun 28, 2024
1 parent 5a30110 commit 70a34ed
Showing 1 changed file with 43 additions and 1 deletion.
44 changes: 43 additions & 1 deletion sleap_io/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@
except ImportError:
ArrayLike = np.ndarray
from pynwb import NWBFile, NWBHDF5IO, ProcessingModule # type: ignore[import]
from ndx_pose import PoseEstimationSeries, PoseEstimation # type: ignore[import]
from ndx_pose import (PoseEstimationSeries,
PoseEstimation,
TrainingFrame,
TrainingFrames,
PoseTraining,
SourceVideo,
) # type: ignore[import]

from sleap_io import (
Labels,
Expand All @@ -28,6 +34,42 @@
)
from sleap_io.io.utils import convert_predictions_to_dataframe

def convert_nwb(nwb_data_structure: Union[TrainingFrame,
TrainingFrames,
PoseTraining,
SourceVideo
]) -> Union[LabeledFrame, List[LabeledFrame], Labels, Video]:
"""Converts an NWB instance to its corresponding SLEAP instance."""

def convert_frame(frame: TrainingFrame) -> LabeledFrame:
"""
Converts an NWB TrainingFrame instance to a LabeledFrame instance.
"""
return LabeledFrame(
video=Video(filename=frame.source_video.data),
frame_idx=frame.frame_number.data,
instances=[
PredictedInstance.from_numpy(
points=frame.points.data,
point_scores=frame.confidence.data,
instance_score=frame.confidence.data.mean(),
skeleton=Skeleton(
nodes=frame.skeleton.nodes.data,
edges=frame.skeleton.edges.data,
),
)
],
)
if isinstance(nwb_data_structure, TrainingFrame):
return convert_frame(nwb_data_structure)
elif isinstance(nwb_data_structure, TrainingFrames):
return [convert_frame(frame) for frame in nwb_data_structure.training_frames]
elif isinstance(nwb_data_structure, PoseTraining):
return Labels([convert_frame(frame) for frame in nwb_data_structure.training_frames])
elif isinstance(nwb_data_structure, SourceVideo):
return Video(filename=nwb_data_structure.data)
else:
raise ValueError(f"Cannot convert {type(nwb_data_structure)} to SLEAP instance.")

def get_timestamps(series: PoseEstimationSeries) -> np.ndarray:
"""Return a vector of timestamps for a `PoseEstimationSeries`."""
Expand Down

0 comments on commit 70a34ed

Please sign in to comment.