Skip to content

Commit

Permalink
Clean and remove predictions
Browse files Browse the repository at this point in the history
  • Loading branch information
talmo committed Apr 14, 2024
1 parent cff0490 commit aca30c0
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 0 deletions.
83 changes: 83 additions & 0 deletions sleap_io/model/labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,3 +282,86 @@ def save(self, filename: str, format: Optional[str] = None, **kwargs):
from sleap_io import save_file

save_file(self, filename, format=format, **kwargs)

def clean(
self,
frames: bool = True,
empty_instances: bool = False,
skeletons: bool = True,
tracks: bool = True,
videos: bool = False,
):
"""Remove empty frames, unused skeletons, tracks and videos.
Args:
frames: If `True` (the default), remove empty frames.
empty_instances: If `True` (NOT default), remove instances that have no
visible points.
skeletons: If `True` (the default), remove unused skeletons.
tracks: If `True` (the default), remove unused tracks.
videos: If `True` (NOT default), remove videos that have no labeled frames.
"""
used_skeletons = []
used_tracks = []
used_videos = []
kept_frames = []
for lf in self.labeled_frames:

if empty_instances:
lf.remove_empty_instances()

if frames and len(lf) == 0:
continue

if videos and lf.video not in used_videos:
used_videos.append(lf.video)

if skeletons or tracks:
for inst in lf:
if skeletons and inst.skeleton not in used_skeletons:
used_skeletons.append(inst.skeleton)
if (
tracks
and inst.track is not None
and inst.track not in used_tracks
):
used_tracks.append(inst.track)

if frames:
kept_frames.append(lf)

if videos:
self.videos = [video for video in self.videos if video in used_videos]

if skeletons:
self.skeletons = [
skeleton for skeleton in self.skeletons if skeleton in used_skeletons
]

if tracks:
self.tracks = [track for track in self.tracks if track in used_tracks]

if frames:
self.labeled_frames = kept_frames

def remove_predictions(self, clean: bool = True):
"""Remove all predicted instances from the labels.
Args:
clean: If `True` (the default), also remove any empty frames and unused
tracks and skeletons. It does NOT remove videos that have no labeled
frames or instances with no visible points.
See also: `Labels.clean`
"""
for lf in self.labeled_frames:
lf.remove_predictions()

if clean:
self.clean(
frames=True,
empty_instances=False,
skeletons=True,
tracks=True,
videos=False,
)
111 changes: 111 additions & 0 deletions tests/model/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Instance,
PredictedInstance,
LabeledFrame,
Track,
load_slp,
load_video,
)
Expand Down Expand Up @@ -147,3 +148,113 @@ def test_labels_save(tmp_path, slp_typical):
labels = load_slp(slp_typical)
labels.save(tmp_path / "test.slp")
assert (tmp_path / "test.slp").exists()


def test_labels_clean_unchanged(slp_real_data):
labels = load_slp(slp_real_data)
assert len(labels) == 10
assert labels[0].frame_idx == 0
assert len(labels[0]) == 2
assert labels[1].frame_idx == 990
assert len(labels[1]) == 2
assert len(labels.skeletons) == 1
assert len(labels.videos) == 1
assert len(labels.tracks) == 0
labels.clean(
frames=True, empty_instances=True, skeletons=True, tracks=True, videos=True
)
assert len(labels) == 10
assert labels[0].frame_idx == 0
assert len(labels[0]) == 2
assert labels[1].frame_idx == 990
assert len(labels[1]) == 2
assert len(labels.skeletons) == 1
assert len(labels.videos) == 1
assert len(labels.tracks) == 0


def test_labels_clean_frames(slp_real_data):
labels = load_slp(slp_real_data)
assert labels[0].frame_idx == 0
assert len(labels[0]) == 2
labels[0].instances = []
labels.clean(
frames=True, empty_instances=False, skeletons=False, tracks=False, videos=False
)
assert len(labels) == 9
assert labels[0].frame_idx == 990
assert len(labels[0]) == 2


def test_labels_clean_empty_instances(slp_real_data):
labels = load_slp(slp_real_data)
assert labels[0].frame_idx == 0
assert len(labels[0]) == 2
labels[0].instances = [
Instance.from_numpy(
np.full((len(labels.skeleton), 2), np.nan), skeleton=labels.skeleton
)
]
labels.clean(
frames=False, empty_instances=True, skeletons=False, tracks=False, videos=False
)
assert len(labels) == 10
assert labels[0].frame_idx == 0
assert len(labels[0]) == 0

labels.clean(
frames=True, empty_instances=True, skeletons=False, tracks=False, videos=False
)
assert len(labels) == 9


def test_labels_clean_skeletons(slp_real_data):
labels = load_slp(slp_real_data)
labels.skeletons.append(Skeleton(["A", "B"]))
assert len(labels.skeletons) == 2
labels.clean(
frames=False, empty_instances=False, skeletons=True, tracks=False, videos=False
)
assert len(labels) == 10
assert len(labels.skeletons) == 1


def test_labels_clean_tracks(slp_real_data):
labels = load_slp(slp_real_data)
labels.tracks.append(Track(name="test1"))
labels.tracks.append(Track(name="test2"))
assert len(labels.tracks) == 2
labels[0].instances[0].track = labels.tracks[1]
labels.clean(
frames=False, empty_instances=False, skeletons=False, tracks=True, videos=False
)
assert len(labels) == 10
assert len(labels.tracks) == 1
assert labels[0].instances[0].track == labels.tracks[0]
assert labels.tracks[0].name == "test2"


def test_labels_clean_videos(slp_real_data):
labels = load_slp(slp_real_data)
labels.videos.append(Video(filename="test2"))
assert len(labels.videos) == 2
labels.clean(
frames=False, empty_instances=False, skeletons=False, tracks=False, videos=True
)
assert len(labels) == 10
assert len(labels.videos) == 1
assert labels.video.filename == "tests/data/videos/centered_pair_low_quality.mp4"


def test_labels_remove_predictions(slp_real_data):
labels = load_slp(slp_real_data)
assert len(labels) == 10
assert sum([len(lf.predicted_instances) for lf in labels]) == 12
labels.remove_predictions(clean=False)
assert len(labels) == 10
assert sum([len(lf.predicted_instances) for lf in labels]) == 0

labels = load_slp(slp_real_data)
labels.remove_predictions(clean=True)
assert len(labels) == 5
assert sum([len(lf.predicted_instances) for lf in labels]) == 0

0 comments on commit aca30c0

Please sign in to comment.