diff --git a/sleap_io/model/labels.py b/sleap_io/model/labels.py index bf6a6795..08e77a4e 100644 --- a/sleap_io/model/labels.py +++ b/sleap_io/model/labels.py @@ -473,6 +473,13 @@ def replace_videos( video_map: Alternative input of dictionary where keys are the old videos and values are the new videos. """ + if ( + old_videos is None + and new_videos is not None + and len(new_videos) == len(self.videos) + ): + old_videos = self.videos + if video_map is None: video_map = {o: n for o, n in zip(old_videos, new_videos)} @@ -486,6 +493,9 @@ def replace_videos( if sf.video in video_map: sf.video = video_map[sf.video] + # Update the list of videos. + self.videos = [video_map.get(video, video) for video in self.videos] + def replace_filenames( self, new_filenames: list[str | Path] | None = None, diff --git a/tests/model/test_labels.py b/tests/model/test_labels.py index ff1d95d2..fa080bd1 100644 --- a/tests/model/test_labels.py +++ b/tests/model/test_labels.py @@ -356,9 +356,7 @@ def test_labels_remove_predictions(slp_real_data): def test_replace_videos(slp_real_data): labels = load_slp(slp_real_data) assert labels.video.filename == "tests/data/videos/centered_pair_low_quality.mp4" - labels.replace_videos( - old_videos=[labels.video], new_videos=[Video.from_filename("fake.mp4")] - ) + labels.replace_videos(new_videos=[Video.from_filename("fake.mp4")]) for lf in labels: assert lf.video.filename == "fake.mp4" @@ -366,6 +364,8 @@ def test_replace_videos(slp_real_data): for sf in labels.suggestions: assert sf.video.filename == "fake.mp4" + assert labels.video.filename == "fake.mp4" + def test_replace_filenames(): labels = Labels(videos=[Video.from_filename("a.mp4"), Video.from_filename("b.mp4")])