From 0035e86279c69686315f766ffacc168a33338c59 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Sat, 13 Apr 2024 23:22:06 -0700 Subject: [PATCH] Labels QOL enhancements (#81) * hi * hi test * Add top level imports * Lint * coverage * Better checks for videos, save_file * tests * Remove predictions from labeled frame * Level up __getitem__ * Labels.save * Remove empty instances * Clean and remove predictions * Set a timeout for the CI * Add save_file to top level import --- .github/workflows/ci.yml | 1 + sleap_io/__init__.py | 1 + sleap_io/model/labeled_frame.py | 8 ++ sleap_io/model/labels.py | 117 +++++++++++++++++++++++++ tests/model/test_labeled_frame.py | 49 +++++++++++ tests/model/test_labels.py | 141 ++++++++++++++++++++++++++++++ 6 files changed, 317 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7bcc18f8..11fe5c6a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -54,6 +54,7 @@ jobs: # Tests with pytest tests: + timeout-minutes: 15 strategy: fail-fast: false matrix: diff --git a/sleap_io/__init__.py b/sleap_io/__init__.py index 9e63e48b..a872bf00 100644 --- a/sleap_io/__init__.py +++ b/sleap_io/__init__.py @@ -26,4 +26,5 @@ save_jabs, load_video, load_file, + save_file, ) diff --git a/sleap_io/model/labeled_frame.py b/sleap_io/model/labeled_frame.py index c771d49f..be7b957b 100644 --- a/sleap_io/model/labeled_frame.py +++ b/sleap_io/model/labeled_frame.py @@ -104,3 +104,11 @@ def unused_predictions(self) -> list[Instance]: ] return unused_predictions + + def remove_predictions(self): + """Remove all `PredictedInstance` objects from the frame.""" + self.instances = [inst for inst in self.instances if type(inst) == Instance] + + def remove_empty_instances(self): + """Remove all instances with no visible points.""" + self.instances = [inst for inst in self.instances if not inst.is_empty] diff --git a/sleap_io/model/labels.py b/sleap_io/model/labels.py index f3581ceb..0419a5db 100644 --- a/sleap_io/model/labels.py +++ b/sleap_io/model/labels.py @@ -63,6 +63,27 @@ def __getitem__(self, key: int) -> list[LabeledFrame] | LabeledFrame: """Return one or more labeled frames based on indexing criteria.""" if type(key) == int: return self.labeled_frames[key] + elif type(key) == slice: + return [self.labeled_frames[i] for i in range(*key.indices(len(self)))] + elif type(key) == list: + return [self.labeled_frames[i] for i in key] + elif isinstance(key, np.ndarray): + return [self.labeled_frames[i] for i in key.tolist()] + elif type(key) == tuple and len(key) == 2: + video, frame_idx = key + res = self.find(video, frame_idx) + if len(res) == 1: + return res[0] + elif len(res) == 0: + raise IndexError( + f"No labeled frames found for video {video} and " + f"frame index {frame_idx}." + ) + elif type(key) == Video: + res = self.find(key) + if len(res) == 0: + raise IndexError(f"No labeled frames found for video {key}.") + return res else: raise IndexError(f"Invalid indexing argument for labels: {key}") @@ -248,3 +269,99 @@ def find( results.append(LabeledFrame(video=video, frame_idx=frame_ind)) return results + + def save(self, filename: str, format: Optional[str] = None, **kwargs): + """Save labels to file in specified format. + + Args: + filename: Path to save labels to. + format: The format to save the labels in. If `None`, the format will be + inferred from the file extension. Available formats are "slp", "nwb", + "labelstudio", and "jabs". + """ + from sleap_io import save_file + + save_file(self, filename, format=format, **kwargs) + + def clean( + self, + frames: bool = True, + empty_instances: bool = False, + skeletons: bool = True, + tracks: bool = True, + videos: bool = False, + ): + """Remove empty frames, unused skeletons, tracks and videos. + + Args: + frames: If `True` (the default), remove empty frames. + empty_instances: If `True` (NOT default), remove instances that have no + visible points. + skeletons: If `True` (the default), remove unused skeletons. + tracks: If `True` (the default), remove unused tracks. + videos: If `True` (NOT default), remove videos that have no labeled frames. + """ + used_skeletons = [] + used_tracks = [] + used_videos = [] + kept_frames = [] + for lf in self.labeled_frames: + + if empty_instances: + lf.remove_empty_instances() + + if frames and len(lf) == 0: + continue + + if videos and lf.video not in used_videos: + used_videos.append(lf.video) + + if skeletons or tracks: + for inst in lf: + if skeletons and inst.skeleton not in used_skeletons: + used_skeletons.append(inst.skeleton) + if ( + tracks + and inst.track is not None + and inst.track not in used_tracks + ): + used_tracks.append(inst.track) + + if frames: + kept_frames.append(lf) + + if videos: + self.videos = [video for video in self.videos if video in used_videos] + + if skeletons: + self.skeletons = [ + skeleton for skeleton in self.skeletons if skeleton in used_skeletons + ] + + if tracks: + self.tracks = [track for track in self.tracks if track in used_tracks] + + if frames: + self.labeled_frames = kept_frames + + def remove_predictions(self, clean: bool = True): + """Remove all predicted instances from the labels. + + Args: + clean: If `True` (the default), also remove any empty frames and unused + tracks and skeletons. It does NOT remove videos that have no labeled + frames or instances with no visible points. + + See also: `Labels.clean` + """ + for lf in self.labeled_frames: + lf.remove_predictions() + + if clean: + self.clean( + frames=True, + empty_instances=False, + skeletons=True, + tracks=True, + videos=False, + ) diff --git a/tests/model/test_labeled_frame.py b/tests/model/test_labeled_frame.py index f8093ced..b542124e 100644 --- a/tests/model/test_labeled_frame.py +++ b/tests/model/test_labeled_frame.py @@ -3,6 +3,7 @@ from numpy.testing import assert_equal from sleap_io import Video, Skeleton, Instance, PredictedInstance from sleap_io.model.labeled_frame import LabeledFrame +import numpy as np def test_labeled_frame(): @@ -26,3 +27,51 @@ def test_labeled_frame(): # Test LabeledFrame.__getitem__ method assert lf[0] == inst + + +def test_remove_predictions(): + """Test removing predictions from `LabeledFrame`.""" + inst = Instance([[0, 1], [2, 3]], skeleton=Skeleton(["A", "B"])) + lf = LabeledFrame( + video=Video(filename="test"), + frame_idx=0, + instances=[ + inst, + PredictedInstance([[4, 5], [6, 7]], skeleton=Skeleton(["A", "B"])), + ], + ) + + assert len(lf) == 2 + assert len(lf.predicted_instances) == 1 + + # Remove predictions + lf.remove_predictions() + + assert len(lf) == 1 + assert len(lf.predicted_instances) == 0 + assert type(lf[0]) == Instance + assert_equal(lf.numpy(), [[[0, 1], [2, 3]]]) + + +def test_remove_empty_instances(): + """Test removing empty instances from `LabeledFrame`.""" + inst = Instance([[0, 1], [2, 3]], skeleton=Skeleton(["A", "B"])) + lf = LabeledFrame( + video=Video(filename="test"), + frame_idx=0, + instances=[ + inst, + Instance( + [[np.nan, np.nan], [np.nan, np.nan]], skeleton=Skeleton(["A", "B"]) + ), + ], + ) + + assert len(lf) == 2 + + # Remove empty instances + lf.remove_empty_instances() + + assert len(lf) == 1 + assert type(lf[0]) == Instance + assert_equal(lf.numpy(), [[[0, 1], [2, 3]]]) diff --git a/tests/model/test_labels.py b/tests/model/test_labels.py index e52a8103..0060a220 100644 --- a/tests/model/test_labels.py +++ b/tests/model/test_labels.py @@ -8,9 +8,12 @@ Instance, PredictedInstance, LabeledFrame, + Track, load_slp, + load_video, ) from sleap_io.model.labels import Labels +import numpy as np def test_labels(): @@ -117,3 +120,141 @@ def test_labels_skeleton(): labels.skeletons.append(Skeleton(["B"])) with pytest.raises(ValueError): labels.skeleton + + +def test_labels_getitem(slp_typical): + labels = load_slp(slp_typical) + labels.labeled_frames.append(LabeledFrame(video=labels.video, frame_idx=1)) + assert len(labels) == 2 + assert labels[0].frame_idx == 0 + assert len(labels[:2]) == 2 + assert len(labels[[0, 1]]) == 2 + assert len(labels[np.array([0, 1])]) == 2 + assert labels[(labels.video, 0)].frame_idx == 0 + + with pytest.raises(IndexError): + labels[(labels.video, 2000)] + + assert len(labels[labels.video]) == 2 + + with pytest.raises(IndexError): + labels[Video(filename="test")] + + with pytest.raises(IndexError): + labels[None] + + +def test_labels_save(tmp_path, slp_typical): + labels = load_slp(slp_typical) + labels.save(tmp_path / "test.slp") + assert (tmp_path / "test.slp").exists() + + +def test_labels_clean_unchanged(slp_real_data): + labels = load_slp(slp_real_data) + assert len(labels) == 10 + assert labels[0].frame_idx == 0 + assert len(labels[0]) == 2 + assert labels[1].frame_idx == 990 + assert len(labels[1]) == 2 + assert len(labels.skeletons) == 1 + assert len(labels.videos) == 1 + assert len(labels.tracks) == 0 + labels.clean( + frames=True, empty_instances=True, skeletons=True, tracks=True, videos=True + ) + assert len(labels) == 10 + assert labels[0].frame_idx == 0 + assert len(labels[0]) == 2 + assert labels[1].frame_idx == 990 + assert len(labels[1]) == 2 + assert len(labels.skeletons) == 1 + assert len(labels.videos) == 1 + assert len(labels.tracks) == 0 + + +def test_labels_clean_frames(slp_real_data): + labels = load_slp(slp_real_data) + assert labels[0].frame_idx == 0 + assert len(labels[0]) == 2 + labels[0].instances = [] + labels.clean( + frames=True, empty_instances=False, skeletons=False, tracks=False, videos=False + ) + assert len(labels) == 9 + assert labels[0].frame_idx == 990 + assert len(labels[0]) == 2 + + +def test_labels_clean_empty_instances(slp_real_data): + labels = load_slp(slp_real_data) + assert labels[0].frame_idx == 0 + assert len(labels[0]) == 2 + labels[0].instances = [ + Instance.from_numpy( + np.full((len(labels.skeleton), 2), np.nan), skeleton=labels.skeleton + ) + ] + labels.clean( + frames=False, empty_instances=True, skeletons=False, tracks=False, videos=False + ) + assert len(labels) == 10 + assert labels[0].frame_idx == 0 + assert len(labels[0]) == 0 + + labels.clean( + frames=True, empty_instances=True, skeletons=False, tracks=False, videos=False + ) + assert len(labels) == 9 + + +def test_labels_clean_skeletons(slp_real_data): + labels = load_slp(slp_real_data) + labels.skeletons.append(Skeleton(["A", "B"])) + assert len(labels.skeletons) == 2 + labels.clean( + frames=False, empty_instances=False, skeletons=True, tracks=False, videos=False + ) + assert len(labels) == 10 + assert len(labels.skeletons) == 1 + + +def test_labels_clean_tracks(slp_real_data): + labels = load_slp(slp_real_data) + labels.tracks.append(Track(name="test1")) + labels.tracks.append(Track(name="test2")) + assert len(labels.tracks) == 2 + labels[0].instances[0].track = labels.tracks[1] + labels.clean( + frames=False, empty_instances=False, skeletons=False, tracks=True, videos=False + ) + assert len(labels) == 10 + assert len(labels.tracks) == 1 + assert labels[0].instances[0].track == labels.tracks[0] + assert labels.tracks[0].name == "test2" + + +def test_labels_clean_videos(slp_real_data): + labels = load_slp(slp_real_data) + labels.videos.append(Video(filename="test2")) + assert len(labels.videos) == 2 + labels.clean( + frames=False, empty_instances=False, skeletons=False, tracks=False, videos=True + ) + assert len(labels) == 10 + assert len(labels.videos) == 1 + assert labels.video.filename == "tests/data/videos/centered_pair_low_quality.mp4" + + +def test_labels_remove_predictions(slp_real_data): + labels = load_slp(slp_real_data) + assert len(labels) == 10 + assert sum([len(lf.predicted_instances) for lf in labels]) == 12 + labels.remove_predictions(clean=False) + assert len(labels) == 10 + assert sum([len(lf.predicted_instances) for lf in labels]) == 0 + + labels = load_slp(slp_real_data) + labels.remove_predictions(clean=True) + assert len(labels) == 5 + assert sum([len(lf.predicted_instances) for lf in labels]) == 0