From 4af9f38edcd40f5d99f2f93144230aec428f57aa Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Wed, 22 May 2024 08:58:26 -0700 Subject: [PATCH] Fix suggestions deserialization (#95) * Fix suggestions deserialization * Bump version --- sleap_io/io/slp.py | 8 +++++--- sleap_io/version.py | 2 +- tests/io/test_slp.py | 20 ++++++++++++++++++++ 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/sleap_io/io/slp.py b/sleap_io/io/slp.py index 1f172fe4..0218645c 100644 --- a/sleap_io/io/slp.py +++ b/sleap_io/io/slp.py @@ -197,9 +197,11 @@ def read_suggestions(labels_path: str, videos: list[Video]) -> list[SuggestionFr Returns: A list of `SuggestionFrame` objects. """ - suggestions = [ - json.loads(x) for x in read_hdf5_dataset(labels_path, "suggestions_json") - ] + try: + suggestions = read_hdf5_dataset(labels_path, "suggestions_json") + except KeyError: + return [] + suggestions = [json.loads(x) for x in suggestions] suggestions_objects = [] for suggestion in suggestions: suggestions_objects.append( diff --git a/sleap_io/version.py b/sleap_io/version.py index 49f63fa1..764ea5db 100644 --- a/sleap_io/version.py +++ b/sleap_io/version.py @@ -2,4 +2,4 @@ # Define package version. # This is read dynamically by setuptools in pyproject.toml to determine the release version. -__version__ = "0.1.1" +__version__ = "0.1.2" diff --git a/tests/io/test_slp.py b/tests/io/test_slp.py index e2ab9470..c2ea39ad 100644 --- a/tests/io/test_slp.py +++ b/tests/io/test_slp.py @@ -12,6 +12,7 @@ PredictedPoint, PredictedInstance, Labels, + SuggestionFrame, ) from sleap_io.io.slp import ( read_videos, @@ -29,6 +30,8 @@ write_lfs, read_labels, write_labels, + read_suggestions, + write_suggestions, ) from sleap_io.io.utils import read_hdf5_dataset import numpy as np @@ -237,3 +240,20 @@ def test_slp_imgvideo(tmpdir, slp_imgvideo): assert type(videos[0].backend) == ImageVideo assert len(videos[0].filename) == 2 assert videos[0].shape is None + + +def test_suggestions(tmpdir): + labels = Labels() + labels.videos.append(Video.from_filename("fake.mp4")) + labels.suggestions.append(SuggestionFrame(video=labels.video, frame_idx=0)) + + write_suggestions(tmpdir / "test.slp", labels.suggestions, labels.videos) + loaded_suggestions = read_suggestions(tmpdir / "test.slp", labels.videos) + assert len(loaded_suggestions) == 1 + assert loaded_suggestions[0].video.filename == "fake.mp4" + assert loaded_suggestions[0].frame_idx == 0 + + # Handle missing suggestions dataset + write_videos(tmpdir / "test2.slp", labels.videos) + loaded_suggestions = read_suggestions(tmpdir / "test2.slp", labels.videos) + assert len(loaded_suggestions) == 0