Skip to content

Commit

Permalink
Labels QOL enhancements (#81)
Browse files Browse the repository at this point in the history
* hi

* hi test

* Add top level imports

* Lint

* coverage

* Better checks for videos, save_file

* tests

* Remove predictions from labeled frame

* Level up __getitem__

* Labels.save

* Remove empty instances

* Clean and remove predictions

* Set a timeout for the CI

* Add save_file to top level import
  • Loading branch information
talmo authored Apr 14, 2024
1 parent 599b207 commit 0035e86
Show file tree
Hide file tree
Showing 6 changed files with 317 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ jobs:
# Tests with pytest
tests:
timeout-minutes: 15
strategy:
fail-fast: false
matrix:
Expand Down
1 change: 1 addition & 0 deletions sleap_io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@
save_jabs,
load_video,
load_file,
save_file,
)
8 changes: 8 additions & 0 deletions sleap_io/model/labeled_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,11 @@ def unused_predictions(self) -> list[Instance]:
]

return unused_predictions

def remove_predictions(self):
"""Remove all `PredictedInstance` objects from the frame."""
self.instances = [inst for inst in self.instances if type(inst) == Instance]

def remove_empty_instances(self):
"""Remove all instances with no visible points."""
self.instances = [inst for inst in self.instances if not inst.is_empty]
117 changes: 117 additions & 0 deletions sleap_io/model/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,27 @@ def __getitem__(self, key: int) -> list[LabeledFrame] | LabeledFrame:
"""Return one or more labeled frames based on indexing criteria."""
if type(key) == int:
return self.labeled_frames[key]
elif type(key) == slice:
return [self.labeled_frames[i] for i in range(*key.indices(len(self)))]
elif type(key) == list:
return [self.labeled_frames[i] for i in key]
elif isinstance(key, np.ndarray):
return [self.labeled_frames[i] for i in key.tolist()]
elif type(key) == tuple and len(key) == 2:
video, frame_idx = key
res = self.find(video, frame_idx)
if len(res) == 1:
return res[0]
elif len(res) == 0:
raise IndexError(
f"No labeled frames found for video {video} and "
f"frame index {frame_idx}."
)
elif type(key) == Video:
res = self.find(key)
if len(res) == 0:
raise IndexError(f"No labeled frames found for video {key}.")
return res
else:
raise IndexError(f"Invalid indexing argument for labels: {key}")

Expand Down Expand Up @@ -248,3 +269,99 @@ def find(
results.append(LabeledFrame(video=video, frame_idx=frame_ind))

return results

def save(self, filename: str, format: Optional[str] = None, **kwargs):
"""Save labels to file in specified format.
Args:
filename: Path to save labels to.
format: The format to save the labels in. If `None`, the format will be
inferred from the file extension. Available formats are "slp", "nwb",
"labelstudio", and "jabs".
"""
from sleap_io import save_file

save_file(self, filename, format=format, **kwargs)

def clean(
self,
frames: bool = True,
empty_instances: bool = False,
skeletons: bool = True,
tracks: bool = True,
videos: bool = False,
):
"""Remove empty frames, unused skeletons, tracks and videos.
Args:
frames: If `True` (the default), remove empty frames.
empty_instances: If `True` (NOT default), remove instances that have no
visible points.
skeletons: If `True` (the default), remove unused skeletons.
tracks: If `True` (the default), remove unused tracks.
videos: If `True` (NOT default), remove videos that have no labeled frames.
"""
used_skeletons = []
used_tracks = []
used_videos = []
kept_frames = []
for lf in self.labeled_frames:

if empty_instances:
lf.remove_empty_instances()

if frames and len(lf) == 0:
continue

if videos and lf.video not in used_videos:
used_videos.append(lf.video)

if skeletons or tracks:
for inst in lf:
if skeletons and inst.skeleton not in used_skeletons:
used_skeletons.append(inst.skeleton)
if (
tracks
and inst.track is not None
and inst.track not in used_tracks
):
used_tracks.append(inst.track)

if frames:
kept_frames.append(lf)

if videos:
self.videos = [video for video in self.videos if video in used_videos]

if skeletons:
self.skeletons = [
skeleton for skeleton in self.skeletons if skeleton in used_skeletons
]

if tracks:
self.tracks = [track for track in self.tracks if track in used_tracks]

if frames:
self.labeled_frames = kept_frames

def remove_predictions(self, clean: bool = True):
"""Remove all predicted instances from the labels.
Args:
clean: If `True` (the default), also remove any empty frames and unused
tracks and skeletons. It does NOT remove videos that have no labeled
frames or instances with no visible points.
See also: `Labels.clean`
"""
for lf in self.labeled_frames:
lf.remove_predictions()

if clean:
self.clean(
frames=True,
empty_instances=False,
skeletons=True,
tracks=True,
videos=False,
)
49 changes: 49 additions & 0 deletions tests/model/test_labeled_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from numpy.testing import assert_equal
from sleap_io import Video, Skeleton, Instance, PredictedInstance
from sleap_io.model.labeled_frame import LabeledFrame
import numpy as np


def test_labeled_frame():
Expand All @@ -26,3 +27,51 @@ def test_labeled_frame():

# Test LabeledFrame.__getitem__ method
assert lf[0] == inst


def test_remove_predictions():
"""Test removing predictions from `LabeledFrame`."""
inst = Instance([[0, 1], [2, 3]], skeleton=Skeleton(["A", "B"]))
lf = LabeledFrame(
video=Video(filename="test"),
frame_idx=0,
instances=[
inst,
PredictedInstance([[4, 5], [6, 7]], skeleton=Skeleton(["A", "B"])),
],
)

assert len(lf) == 2
assert len(lf.predicted_instances) == 1

# Remove predictions
lf.remove_predictions()

assert len(lf) == 1
assert len(lf.predicted_instances) == 0
assert type(lf[0]) == Instance
assert_equal(lf.numpy(), [[[0, 1], [2, 3]]])


def test_remove_empty_instances():
"""Test removing empty instances from `LabeledFrame`."""
inst = Instance([[0, 1], [2, 3]], skeleton=Skeleton(["A", "B"]))
lf = LabeledFrame(
video=Video(filename="test"),
frame_idx=0,
instances=[
inst,
Instance(
[[np.nan, np.nan], [np.nan, np.nan]], skeleton=Skeleton(["A", "B"])
),
],
)

assert len(lf) == 2

# Remove empty instances
lf.remove_empty_instances()

assert len(lf) == 1
assert type(lf[0]) == Instance
assert_equal(lf.numpy(), [[[0, 1], [2, 3]]])
141 changes: 141 additions & 0 deletions tests/model/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@
Instance,
PredictedInstance,
LabeledFrame,
Track,
load_slp,
load_video,
)
from sleap_io.model.labels import Labels
import numpy as np


def test_labels():
Expand Down Expand Up @@ -117,3 +120,141 @@ def test_labels_skeleton():
labels.skeletons.append(Skeleton(["B"]))
with pytest.raises(ValueError):
labels.skeleton


def test_labels_getitem(slp_typical):
labels = load_slp(slp_typical)
labels.labeled_frames.append(LabeledFrame(video=labels.video, frame_idx=1))
assert len(labels) == 2
assert labels[0].frame_idx == 0
assert len(labels[:2]) == 2
assert len(labels[[0, 1]]) == 2
assert len(labels[np.array([0, 1])]) == 2
assert labels[(labels.video, 0)].frame_idx == 0

with pytest.raises(IndexError):
labels[(labels.video, 2000)]

assert len(labels[labels.video]) == 2

with pytest.raises(IndexError):
labels[Video(filename="test")]

with pytest.raises(IndexError):
labels[None]


def test_labels_save(tmp_path, slp_typical):
labels = load_slp(slp_typical)
labels.save(tmp_path / "test.slp")
assert (tmp_path / "test.slp").exists()


def test_labels_clean_unchanged(slp_real_data):
labels = load_slp(slp_real_data)
assert len(labels) == 10
assert labels[0].frame_idx == 0
assert len(labels[0]) == 2
assert labels[1].frame_idx == 990
assert len(labels[1]) == 2
assert len(labels.skeletons) == 1
assert len(labels.videos) == 1
assert len(labels.tracks) == 0
labels.clean(
frames=True, empty_instances=True, skeletons=True, tracks=True, videos=True
)
assert len(labels) == 10
assert labels[0].frame_idx == 0
assert len(labels[0]) == 2
assert labels[1].frame_idx == 990
assert len(labels[1]) == 2
assert len(labels.skeletons) == 1
assert len(labels.videos) == 1
assert len(labels.tracks) == 0


def test_labels_clean_frames(slp_real_data):
labels = load_slp(slp_real_data)
assert labels[0].frame_idx == 0
assert len(labels[0]) == 2
labels[0].instances = []
labels.clean(
frames=True, empty_instances=False, skeletons=False, tracks=False, videos=False
)
assert len(labels) == 9
assert labels[0].frame_idx == 990
assert len(labels[0]) == 2


def test_labels_clean_empty_instances(slp_real_data):
labels = load_slp(slp_real_data)
assert labels[0].frame_idx == 0
assert len(labels[0]) == 2
labels[0].instances = [
Instance.from_numpy(
np.full((len(labels.skeleton), 2), np.nan), skeleton=labels.skeleton
)
]
labels.clean(
frames=False, empty_instances=True, skeletons=False, tracks=False, videos=False
)
assert len(labels) == 10
assert labels[0].frame_idx == 0
assert len(labels[0]) == 0

labels.clean(
frames=True, empty_instances=True, skeletons=False, tracks=False, videos=False
)
assert len(labels) == 9


def test_labels_clean_skeletons(slp_real_data):
labels = load_slp(slp_real_data)
labels.skeletons.append(Skeleton(["A", "B"]))
assert len(labels.skeletons) == 2
labels.clean(
frames=False, empty_instances=False, skeletons=True, tracks=False, videos=False
)
assert len(labels) == 10
assert len(labels.skeletons) == 1


def test_labels_clean_tracks(slp_real_data):
labels = load_slp(slp_real_data)
labels.tracks.append(Track(name="test1"))
labels.tracks.append(Track(name="test2"))
assert len(labels.tracks) == 2
labels[0].instances[0].track = labels.tracks[1]
labels.clean(
frames=False, empty_instances=False, skeletons=False, tracks=True, videos=False
)
assert len(labels) == 10
assert len(labels.tracks) == 1
assert labels[0].instances[0].track == labels.tracks[0]
assert labels.tracks[0].name == "test2"


def test_labels_clean_videos(slp_real_data):
labels = load_slp(slp_real_data)
labels.videos.append(Video(filename="test2"))
assert len(labels.videos) == 2
labels.clean(
frames=False, empty_instances=False, skeletons=False, tracks=False, videos=True
)
assert len(labels) == 10
assert len(labels.videos) == 1
assert labels.video.filename == "tests/data/videos/centered_pair_low_quality.mp4"


def test_labels_remove_predictions(slp_real_data):
labels = load_slp(slp_real_data)
assert len(labels) == 10
assert sum([len(lf.predicted_instances) for lf in labels]) == 12
labels.remove_predictions(clean=False)
assert len(labels) == 10
assert sum([len(lf.predicted_instances) for lf in labels]) == 0

labels = load_slp(slp_real_data)
labels.remove_predictions(clean=True)
assert len(labels) == 5
assert sum([len(lf.predicted_instances) for lf in labels]) == 0

0 comments on commit 0035e86

Please sign in to comment.