Skip to content

Commit

Permalink
added error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
keyaloding committed Jul 19, 2024
1 parent 055e3e8 commit dcc3da8
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 12 deletions.
120 changes: 120 additions & 0 deletions load_test.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"import sleap_io as sio\n",
"import pynwb\n",
"import numpy as np\n",
"from numpy.testing import assert_equal\n",
"from sleap_io import save_file\n",
"from sleap_io.model.skeleton import Node, Edge, Symmetry, Skeleton\n",
"from sleap_io.model.instance import *\n",
"from pynwb.image import ImageSeries\n",
"from ndx_pose import (\n",
" PoseEstimation,\n",
" PoseEstimationSeries,\n",
" TrainingFrame,\n",
" TrainingFrames,\n",
" PoseTraining,\n",
" SourceVideos,\n",
")\n",
"from sleap_io.io.nwb import *"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Labels(labeled_frames=1, videos=1, skeletons=1, tracks=0, suggestions=0)\n"
]
}
],
"source": [
"labels_original = sio.load_slp(\"tests/data/slp/minimal_instance.pkg.slp\")\n",
"skel = Skeleton([Node(\"A\"), Node(\"B\")])\n",
"skel.name = \"name\"\n",
"skel.edges = [Edge(Node(\"A\"), Node(\"B\"))]\n",
"slp_skeleton_to_nwb(skel)\n",
"\n",
"inst = Instance({\"A\": [0, 1], \"B\": [2, 3]}, skeleton=Skeleton([\"A\", \"B\"]))\n",
"inst.skeleton.name = \"name\"\n",
"inst.skeleton.edges = [Edge(Node(\"A\"), Node(\"B\"))]\n",
"instance_to_skeleton_instance(inst)\n",
"\n",
"pose = labels_to_pose_training(labels_original)\n",
"labels = pose_training_to_labels(pose)\n",
"print(labels)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"ename": "KeyError",
"evalue": "'tests/data/slp/minimal_instance.pkg.slp'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[7], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m labels_original \u001b[38;5;241m=\u001b[39m sio\u001b[38;5;241m.\u001b[39mload_slp(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtests/data/slp/minimal_instance.pkg.slp\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m \u001b[43mlabels_original\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msave\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mminimal_instance.pkg.nwb\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mformat\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mnwb_training\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4\u001b[0m labels_loaded \u001b[38;5;241m=\u001b[39m sio\u001b[38;5;241m.\u001b[39mload_nwb(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtests/data/slp/minimal_instance.pkg.nwb\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28mprint\u001b[39m(labels_loaded)\n",
"File \u001b[0;32m~/salk/io_fork/sleap_io/model/labels.py:372\u001b[0m, in \u001b[0;36mLabels.save\u001b[0;34m(self, filename, format, embed, **kwargs)\u001b[0m\n\u001b[1;32m 348\u001b[0m \u001b[38;5;250m\u001b[39m\u001b[38;5;124;03m\"\"\"Save labels to file in specified format.\u001b[39;00m\n\u001b[1;32m 349\u001b[0m \n\u001b[1;32m 350\u001b[0m \u001b[38;5;124;03mArgs:\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 368\u001b[0m \u001b[38;5;124;03m This argument is only valid for the SLP backend.\u001b[39;00m\n\u001b[1;32m 369\u001b[0m \u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 370\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01msleap_io\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m save_file\n\u001b[0;32m--> 372\u001b[0m \u001b[43msave_file\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mformat\u001b[39;49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mformat\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43membed\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43membed\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/salk/io_fork/sleap_io/io/main.py:235\u001b[0m, in \u001b[0;36msave_file\u001b[0;34m(labels, filename, format, **kwargs)\u001b[0m\n\u001b[1;32m 233\u001b[0m save_slp(labels, filename, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 234\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mformat\u001b[39m \u001b[38;5;129;01min\u001b[39;00m (\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnwb\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnwb_training\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mnwb_predictions\u001b[39m\u001b[38;5;124m\"\u001b[39m):\n\u001b[0;32m--> 235\u001b[0m \u001b[43msave_nwb\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 236\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m \u001b[38;5;28mformat\u001b[39m \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabelstudio\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n\u001b[1;32m 237\u001b[0m save_labelstudio(labels, filename, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n",
"File \u001b[0;32m~/salk/io_fork/sleap_io/io/main.py:86\u001b[0m, in \u001b[0;36msave_nwb\u001b[0;34m(labels, filename, as_training, append, **kwargs)\u001b[0m\n\u001b[1;32m 84\u001b[0m nwb\u001b[38;5;241m.\u001b[39mappend_nwb(labels, filename, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs)\n\u001b[1;32m 85\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 86\u001b[0m \u001b[43mnwb\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite_nwb\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mfilename\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/salk/io_fork/sleap_io/io/nwb.py:408\u001b[0m, in \u001b[0;36mwrite_nwb\u001b[0;34m(labels, nwbfile_path, nwb_file_kwargs, pose_estimation_metadata)\u001b[0m\n\u001b[1;32m 401\u001b[0m nwb_file_kwargs\u001b[38;5;241m.\u001b[39mupdate(\n\u001b[1;32m 402\u001b[0m session_description\u001b[38;5;241m=\u001b[39msession_description,\n\u001b[1;32m 403\u001b[0m session_start_time\u001b[38;5;241m=\u001b[39msession_start_time,\n\u001b[1;32m 404\u001b[0m identifier\u001b[38;5;241m=\u001b[39midentifier,\n\u001b[1;32m 405\u001b[0m )\n\u001b[1;32m 407\u001b[0m nwbfile \u001b[38;5;241m=\u001b[39m NWBFile(\u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mnwb_file_kwargs)\n\u001b[0;32m--> 408\u001b[0m nwbfile \u001b[38;5;241m=\u001b[39m \u001b[43mappend_nwb_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mlabels\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnwbfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpose_estimation_metadata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 410\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m NWBHDF5IO(\u001b[38;5;28mstr\u001b[39m(nwbfile_path), \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;28;01mas\u001b[39;00m io:\n\u001b[1;32m 411\u001b[0m io\u001b[38;5;241m.\u001b[39mwrite(nwbfile)\n",
"File \u001b[0;32m~/salk/io_fork/sleap_io/io/nwb.py:491\u001b[0m, in \u001b[0;36mappend_nwb_data\u001b[0;34m(labels, nwbfile, pose_estimation_metadata)\u001b[0m\n\u001b[1;32m 487\u001b[0m default_metadata\u001b[38;5;241m.\u001b[39mupdate(pose_estimation_metadata)\n\u001b[1;32m 489\u001b[0m \u001b[38;5;66;03m# For every track in that video create a PoseEstimation container\u001b[39;00m\n\u001b[1;32m 490\u001b[0m name_of_tracks_in_video \u001b[38;5;241m=\u001b[39m (\n\u001b[0;32m--> 491\u001b[0m \u001b[43mlabels_data_df\u001b[49m\u001b[43m[\u001b[49m\u001b[43mvideo\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfilename\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 492\u001b[0m \u001b[38;5;241m.\u001b[39mcolumns\u001b[38;5;241m.\u001b[39mget_level_values(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtrack_name\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 493\u001b[0m \u001b[38;5;241m.\u001b[39munique()\n\u001b[1;32m 494\u001b[0m )\n\u001b[1;32m 496\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m track_name \u001b[38;5;129;01min\u001b[39;00m name_of_tracks_in_video:\n\u001b[1;32m 497\u001b[0m pose_estimation_container \u001b[38;5;241m=\u001b[39m build_pose_estimation_container_for_track(\n\u001b[1;32m 498\u001b[0m labels_data_df,\n\u001b[1;32m 499\u001b[0m labels,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 502\u001b[0m default_metadata,\n\u001b[1;32m 503\u001b[0m )\n",
"File \u001b[0;32m~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/pandas/core/frame.py:4102\u001b[0m, in \u001b[0;36mDataFrame.__getitem__\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 4100\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcolumns\u001b[38;5;241m.\u001b[39mnlevels \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 4101\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_getitem_multilevel(key)\n\u001b[0;32m-> 4102\u001b[0m indexer \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcolumns\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mget_loc\u001b[49m\u001b[43m(\u001b[49m\u001b[43mkey\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 4103\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m is_integer(indexer):\n\u001b[1;32m 4104\u001b[0m indexer \u001b[38;5;241m=\u001b[39m [indexer]\n",
"File \u001b[0;32m~/mambaforge3/envs/io_dev/lib/python3.12/site-packages/pandas/core/indexes/range.py:417\u001b[0m, in \u001b[0;36mRangeIndex.get_loc\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m 415\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(key) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01merr\u001b[39;00m\n\u001b[1;32m 416\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(key, Hashable):\n\u001b[0;32m--> 417\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(key)\n\u001b[1;32m 418\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_check_indexing_error(key)\n\u001b[1;32m 419\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mKeyError\u001b[39;00m(key)\n",
"\u001b[0;31mKeyError\u001b[0m: 'tests/data/slp/minimal_instance.pkg.slp'"
]
}
],
"source": [
"labels_original = sio.load_slp(\"tests/data/slp/minimal_instance.pkg.slp\")\n",
"labels_original.save(\"minimal_instance.pkg.nwb\", format=\"nwb_training\")\n",
"\n",
"labels_loaded = sio.load_nwb(\"tests/data/slp/minimal_instance.pkg.nwb\")\n",
"print(labels_loaded)\n",
"assert len(labels_original.labeled_frames) == len(labels_loaded.labeled_frames)\n",
"assert len(labels_original.videos) == len(labels_loaded.videos)\n",
"assert len(labels_original.skeletons) == len(labels_loaded.skeletons)\n",
"assert len(labels_original.tracks) == len(labels_loaded.tracks)\n",
"assert len(labels_original.suggestions) == len(labels_loaded.suggestions)\n",
"assert labels_original.provenance == labels_loaded.provenance"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "io_dev",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
1 change: 1 addition & 0 deletions sleap_io/io/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def save_nwb(
filename: str,
as_training: bool = None,
append: bool = True,
img_paths: Optional[list[str]] = None,
**kwargs,
):
"""Save a SLEAP dataset to NWB format.
Expand Down
65 changes: 53 additions & 12 deletions sleap_io/io/nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
Node,
)
from sleap_io.io.utils import convert_predictions_to_dataframe
from sleap_io.io.main import load_slp
from sleap_io.io.main import load_slp, save_nwb


def pose_training_to_labels(pose_training: PoseTraining) -> Labels: # type: ignore[return]
Expand All @@ -58,6 +58,9 @@ def pose_training_to_labels(pose_training: PoseTraining) -> Labels: # type: ign
Returns:
A Labels object.
"""
if not isinstance(pose_training, PoseTraining):
raise ValueError("The input must be an NWB PoseTraining object.")

labeled_frames = []
for training_frame in pose_training.training_frames.training_frames.values():
video = Video(filename=f"{training_frame.source_video}")
Expand All @@ -84,6 +87,9 @@ def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton: # type: igno
Returns:
A SLEAP skeleton.
"""
if not isinstance(skeleton, NWBSkeleton):
raise ValueError("The input must be an NWB Skeleton object.")

nodes = [Node(name=node) for node in skeleton.nodes]
edges = [Edge(source=edge[0], destination=edge[1]) for edge in skeleton.edges]
return SLEAPSkeleton(
Expand All @@ -93,21 +99,27 @@ def nwb_skeleton_to_sleap(skeleton: NWBSkeleton) -> SLEAPSkeleton: # type: igno
)


def labels_to_pose_training(labels: Labels, **kwargs) -> PoseTraining: # type: ignore[return]
def labels_to_pose_training(
labels: Labels, img_paths: Optional[list[str]] = None, **kwargs
) -> PoseTraining: # type: ignore[return]
"""Creates an NWB PoseTraining object from a Labels object.
Args:
labels: A Labels object.
filename: The filename of the source video.
img_paths: An optional list of image paths for the labeled frames.
Returns:
A PoseTraining object.
"""
if not isinstance(labels, Labels):
raise ValueError("The input must be a SLEAP Labels object.")

training_frame_list = []
for i, labeled_frame in enumerate(labels.labeled_frames):
training_frame_name = name_generator("training_frame")
training_frame_annotator = f"{training_frame_name}_{i}"
skeleton_instances_list = []

for instance in labeled_frame.instances:
if isinstance(instance, PredictedInstance):
continue
Expand All @@ -119,10 +131,22 @@ def labels_to_pose_training(labels: Labels, **kwargs) -> PoseTraining: # type:
)
training_frame_video = labeled_frame.video
training_frame_video_index = labeled_frame.frame_idx
training_frame = TrainingFrame(
name=training_frame_name,
annotator=training_frame_annotator,
skeleton_instances=training_frame_skeleton_instances,

if img_paths:
source_video = ImageSeries(
name=training_frame_name,
description=training_frame_annotator,
unit="NA",
format="external",
external_file=img_paths,
dimension=[
training_frame_video.backend.img_shape[0],
training_frame_video.backend.img_shape[1],
],
starting_frame=[0],
rate=30.0, # change to `video.backend.fps` when available
)
else:
source_video=ImageSeries(
name=training_frame_name,
description=training_frame_annotator,
Expand All @@ -134,8 +158,13 @@ def labels_to_pose_training(labels: Labels, **kwargs) -> PoseTraining: # type:
training_frame_video.shape[2],
],
starting_frame=[0],
rate=30.0,
),
rate=30.0, # change to `video.backend.fps` when available
)
training_frame = TrainingFrame(
name=training_frame_name,
annotator=training_frame_annotator,
skeleton_instances=training_frame_skeleton_instances,
source_video=source_video,
source_video_frame_index=training_frame_video_index,
)
training_frame_list.append(training_frame)
Expand All @@ -158,6 +187,8 @@ def slp_skeleton_to_nwb(skeleton: SLEAPSkeleton) -> NWBSkeleton: # type: ignore
An NWB skeleton.
"""
nwb_edges: list[list[int, int]]
if not isinstance(skeleton, SLEAPSkeleton):
raise ValueError("The input must be a SLEAP Skeleton object.")

skeleton_edges = {i: node for i, node in enumerate(skeleton.nodes)}
nwb_edges = []
Expand Down Expand Up @@ -188,6 +219,9 @@ def instance_to_skeleton_instance(instance: Instance) -> SkeletonInstance: # ty
Returns:
An NWB SkeletonInstance.
"""
if not isinstance(instance, Instance):
raise ValueError("The input must be a SLEAP Instance object.")

skeleton = slp_skeleton_to_nwb(instance.skeleton)
points_list = list(instance.points.values())
node_locs = [[point.x, point.y] for point in points_list]
Expand All @@ -210,6 +244,9 @@ def videos_to_source_videos(videos: List[Video]) -> SourceVideos: # type: ignor
Returns:
An NWB SourceVideos object.
"""
if not isinstance(videos, list) or not all(isinstance(video, Video) for video in videos):
raise ValueError("The input must be a list of SLEAP Video objects.")

source_videos = []
for video in videos:
image_series = ImageSeries(
Expand All @@ -231,13 +268,13 @@ def sleap_pkg_to_nwb(filename: str, **kwargs) -> NWBFile:
Args:
filename: The path to the SLEAP package.
Returns:
An NWBFile object.
"""
if not filename.endswith(".pkg.slp"):
raise ValueError("The filename must end with '.pkg.slp'.")

labels = load_slp(filename)

save_path = Path(filename.replace(".slp", ".nwb"))
Expand All @@ -250,7 +287,11 @@ def sleap_pkg_to_nwb(filename: str, **kwargs) -> NWBFile:
else:
imwrite(img_path, labeled_frame.image)
img_paths.append(img_path)
return img_paths

# then use img_paths when saving the NWB TrainingFrames with references
# to the appropriate image files
save_nwb(labels, save_path, img_paths=img_paths, **kwargs)
raise NotImplementedError("This function is not yet implemented.")


def get_timestamps(series: PoseEstimationSeries) -> np.ndarray:
Expand Down

0 comments on commit dcc3da8

Please sign in to comment.