Skip to content

Commit

Permalink
Fix serialization and logic for checking for embedded images
Browse files Browse the repository at this point in the history
  • Loading branch information
talmo committed May 5, 2024
1 parent 247f92e commit d35f07e
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 30 deletions.
22 changes: 18 additions & 4 deletions sleap_io/io/slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,15 @@
read_hdf5_attrs,
read_hdf5_dataset,
)
import imageio.v3 as iio
from enum import IntEnum
from pathlib import Path
import imageio.v3 as iio
import sys

try:
import cv2
except ImportError:
pass


class InstanceType(IntEnum):
Expand Down Expand Up @@ -214,9 +220,17 @@ def embed_video(
if image_format == "hdf5":
img_data = frame
else:
img_data = iio.imwrite(
"<bytes>", frame, extension="." + image_format
).astype("int8")
if "cv2" in sys.modules:
img_data = np.squeeze(
cv2.imencode("." + image_format, frame)[1]
).astype("int8")
else:
img_data = np.frombuffer(
iio.imwrite(
"<bytes>", frame.squeeze(axis=-1), extension="." + image_format
),
dtype="int8",
)

imgs_data.append(img_data)

Expand Down
5 changes: 2 additions & 3 deletions sleap_io/io/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,13 +585,12 @@ def has_embedded_images(self) -> bool:
"""Return True if the dataset contains embedded images."""
return self.image_format is not None and self.image_format != "hdf5"

def decode_embedded(self, img_string: np.ndarray, format: str) -> np.ndarray:
def decode_embedded(self, img_string: np.ndarray) -> np.ndarray:
"""Decode an embedded image string into a numpy array.
Args:
img_string: Binary string of the image as a `int8` numpy vector with the
bytes as values corresponding to the format-encoded image.
format: Image format (e.g., "png" or "jpg").
Returns:
The decoded image as a numpy array of shape `(height, width, channels)`. If
Expand All @@ -604,7 +603,7 @@ def decode_embedded(self, img_string: np.ndarray, format: str) -> np.ndarray:
if "cv2" in sys.modules:
img = cv2.imdecode(img_string, cv2.IMREAD_UNCHANGED)
else:
img = iio.imread(BytesIO(img_string), extension=f".{format}")
img = iio.imread(BytesIO(img_string), extension=f".{self.image_format}")

if img.ndim == 2:
img = np.expand_dims(img, axis=-1)
Expand Down
40 changes: 17 additions & 23 deletions tests/io/test_slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,32 +101,26 @@ def test_read_videos_pkg(slp_minimal_pkg):

def test_write_videos(slp_minimal_pkg, centered_pair, tmp_path):

def load_jsons(h5_path, dataset):
return [json.loads(x) for x in read_hdf5_dataset(h5_path, dataset)]

def compare_jsons(jsons_ref, jsons_test):
for jsons_ref, jsons_test in zip(jsons_ref, jsons_test):
for k in jsons_ref["backend"]:
assert jsons_ref["backend"][k] == jsons_test["backend"][k]

videos = read_videos(slp_minimal_pkg)
write_videos(tmp_path / "test_minimal_pkg.slp", videos)
json_fixture = load_jsons(slp_minimal_pkg, "videos_json")
json_test = load_jsons(tmp_path / "test_minimal_pkg.slp", "videos_json")
compare_jsons(json_fixture, json_test)

videos = read_videos(centered_pair)
write_videos(tmp_path / "test_centered_pair.slp", videos)
json_fixture = load_jsons(centered_pair, "videos_json")
json_test = load_jsons(tmp_path / "test_centered_pair.slp", "videos_json")
compare_jsons(json_fixture, json_test)
def compare_videos(videos_ref, videos_test):
assert len(videos_ref) == len(videos_test)
for video_ref, video_test in zip(videos_ref, videos_test):
assert video_ref.shape == video_test.shape
assert (video_ref[0] == video_test[0]).all()

videos_ref = read_videos(slp_minimal_pkg)
write_videos(tmp_path / "test_minimal_pkg.slp", videos_ref)
videos_test = read_videos(tmp_path / "test_minimal_pkg.slp")
compare_videos(videos_ref, videos_test)

videos_ref = read_videos(centered_pair)
write_videos(tmp_path / "test_centered_pair.slp", videos_ref)
videos_test = read_videos(tmp_path / "test_centered_pair.slp")
compare_videos(videos_ref, videos_test)

videos = read_videos(centered_pair) * 2
write_videos(tmp_path / "test_centered_pair_2vids.slp", videos)
json_test = read_hdf5_dataset(
tmp_path / "test_centered_pair_2vids.slp", "videos_json"
)
assert len(json_test) == 2
videos_test = read_videos(tmp_path / "test_centered_pair_2vids.slp")
compare_videos(videos, videos_test)


def test_write_tracks(centered_pair, tmp_path):
Expand Down

0 comments on commit d35f07e

Please sign in to comment.