Skip to content

Commit

Permalink
Fix video indexing when embedding from labels that already have embed…
Browse files Browse the repository at this point in the history
…ded data (#126)

* Don't reuse old video object

* Cast frame index to int on load

* Add embedded image reading test
  • Loading branch information
talmo authored Oct 3, 2024
1 parent c0e922e commit fa6d4fc
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 15 deletions.
31 changes: 16 additions & 15 deletions sleap_io/io/slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,13 @@ def embed_video(

# Store metadata.
ds.attrs["format"] = image_format
video_shape = video.shape
(
ds.attrs["frames"],
ds.attrs["height"],
ds.attrs["width"],
ds.attrs["channels"],
) = video.shape
) = video_shape

# Store frame indices.
f.create_dataset(f"{group}/frame_numbers", data=frame_inds)
Expand All @@ -320,20 +321,20 @@ def embed_video(
if video.source_video is not None:
# If this is already an embedded dataset, retain the previous source video.
source_video = video.source_video
embedded_video = video
video.replace_filename(labels_path, open=False)
else:
source_video = video
embedded_video = Video(
filename=labels_path,
backend=VideoBackend.from_filename(
labels_path,
dataset=f"{group}/video",
grayscale=video.grayscale,
keep_open=False,
),
source_video=source_video,
)

# Create a new video object with the embedded data.
embedded_video = Video(
filename=labels_path,
backend=VideoBackend.from_filename(
labels_path,
dataset=f"{group}/video",
grayscale=video.grayscale,
keep_open=False,
),
source_video=source_video,
)

grp = f.require_group(f"{group}/source_video")
grp.attrs["json"] = json.dumps(
Expand Down Expand Up @@ -369,7 +370,7 @@ def embed_frames(
to_embed_by_video[video].append(frame_idx)

for video in to_embed_by_video:
to_embed_by_video[video] = np.unique(to_embed_by_video[video])
to_embed_by_video[video] = np.unique(to_embed_by_video[video]).tolist()

replaced_videos = {}
for video, frame_inds in to_embed_by_video.items():
Expand Down Expand Up @@ -1068,7 +1069,7 @@ def read_labels(labels_path: str, open_videos: bool = True) -> Labels:
labeled_frames.append(
LabeledFrame(
video=videos[video_id],
frame_idx=frame_idx,
frame_idx=int(frame_idx),
instances=instances[instance_id_start:instance_id_end],
)
)
Expand Down
5 changes: 5 additions & 0 deletions tests/model/test_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,3 +627,8 @@ def test_make_training_splits_save(slp_real_data, tmp_path, embed):
assert train_.video.filename == labels.video.filename
assert val_.video.filename == labels.video.filename
assert test_.video.filename == labels.video.filename

if embed:
for labels_ in [train_, val_, test_]:
for lf in labels_:
assert lf.image.shape == (384, 384, 1)

0 comments on commit fa6d4fc

Please sign in to comment.