Skip to content

Commit

Permalink
Add high level APIs (#80)
Browse files Browse the repository at this point in the history
* hi

* hi test

* Add top level imports

* Lint

* coverage

* Better checks for videos, save_file

* tests
  • Loading branch information
talmo authored Apr 14, 2024
1 parent 12b15fa commit 599b207
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 7 deletions.
2 changes: 2 additions & 0 deletions sleap_io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,6 @@
save_labelstudio,
load_jabs,
save_jabs,
load_video,
load_file,
)
116 changes: 110 additions & 6 deletions sleap_io/io/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""This module contains high-level wrappers for utilizing different I/O backends."""

from __future__ import annotations
from sleap_io import Labels, Skeleton
from sleap_io import Labels, Skeleton, Video
from sleap_io.io import slp, nwb, labelstudio, jabs
from typing import Optional, Union
from pathlib import Path
Expand Down Expand Up @@ -77,7 +77,12 @@ def load_labelstudio(


def save_labelstudio(labels: Labels, filename: str):
"""Save a SLEAP dataset to Label Studio format."""
"""Save a SLEAP dataset to Label Studio format.
Args:
labels: A SLEAP `Labels` object (see `load_slp`).
filename: Path to save labels to ending with `.json`.
"""
labelstudio.write_labels(labels, filename)


Expand All @@ -95,11 +100,110 @@ def load_jabs(filename: str, skeleton: Optional[Skeleton] = None) -> Labels:


def save_jabs(labels: Labels, pose_version: int, root_folder: Optional[str] = None):
"""Save a SLEAP dataset to JABS pose file format. Filenames for JABS poses are based on video filenames.
"""Save a SLEAP dataset to JABS pose file format.
Args:
labels: SLEAP `Labels` object
pose_version: The JABS pose version to write data out
root_folder: Optional root folder where the files should be saved
labels: SLEAP `Labels` object.
pose_version: The JABS pose version to write data out.
root_folder: Optional root folder where the files should be saved.
Note:
Filenames for JABS poses are based on video filenames.
"""
jabs.write_labels(labels, pose_version, root_folder)


def load_video(filename: str, **kwargs) -> Video:
"""Load a video file.
Args:
filename: Path to a video file.
Returns:
A `Video` object.
"""
return Video.from_filename(filename, **kwargs)


def load_file(
filename: str | Path, format: Optional[str] = None, **kwargs
) -> Union[Labels, Video]:
"""Load a file and return the appropriate object.
Args:
filename: Path to a file.
format: Optional format to load as. If not provided, will be inferred from the
file extension. Available formats are: "slp", "nwb", "labelstudio", "jabs"
and "video".
Returns:
A `Labels` or `Video` object.
"""
if isinstance(filename, Path):
filename = str(filename)

if format is None:
if filename.endswith(".slp"):
format = "slp"
elif filename.endswith(".nwb"):
format = "nwb"
elif filename.endswith(".json"):
format = "json"
elif filename.endswith(".h5"):
format = "jabs"
else:
for vid_ext in Video.EXTS:
if filename.endswith(vid_ext):
format = "video"
break
if format is None:
raise ValueError(f"Could not infer format from filename: '{filename}'.")

if filename.endswith(".slp"):
return load_slp(filename, **kwargs)
elif filename.endswith(".nwb"):
return load_nwb(filename, **kwargs)
elif filename.endswith(".json"):
return load_labelstudio(filename, **kwargs)
elif filename.endswith(".h5"):
return load_jabs(filename, **kwargs)
elif format == "video":
return load_video(filename, **kwargs)


def save_file(
labels: Labels, filename: str | Path, format: Optional[str] = None, **kwargs
):
"""Save a file based on the extension.
Args:
labels: A SLEAP `Labels` object (see `load_slp`).
filename: Path to save labels to.
format: Optional format to save as. If not provided, will be inferred from the
file extension. Available formats are: "slp", "nwb", "labelstudio" and
"jabs".
"""
if isinstance(filename, Path):
filename = str(filename)

if format is None:
if filename.endswith(".slp"):
format = "slp"
elif filename.endswith(".nwb"):
format = "nwb"
elif filename.endswith(".json"):
format = "labelstudio"
elif "pose_version" in kwargs:
format = "jabs"

if format == "slp":
save_slp(labels, filename, **kwargs)
elif format == "nwb":
save_nwb(labels, filename, **kwargs)
elif format == "labelstudio":
save_labelstudio(labels, filename, **kwargs)
elif format == "jabs":
pose_version = kwargs.pop("pose_version", 5)
save_jabs(labels, pose_version, filename, **kwargs)
else:
raise ValueError(f"Unknown format '{format}' for filename: '{filename}'.")
4 changes: 3 additions & 1 deletion sleap_io/model/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from attrs import define
from typing import Tuple, Optional, Optional
import numpy as np
from sleap_io.io.video import VideoBackend
from sleap_io.io.video import VideoBackend, MediaVideo, HDF5Video
from pathlib import Path


Expand All @@ -33,6 +33,8 @@ class Video:
filename: str
backend: Optional[VideoBackend] = None

EXTS = MediaVideo.EXTS + HDF5Video.EXTS

def __attrs_post_init__(self):
"""Set the video backend if not already set."""
if self.backend is None:
Expand Down
44 changes: 44 additions & 0 deletions tests/io/test_main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tests for functions in the sleap_io.io.main file."""

import pytest
from sleap_io import Labels
from sleap_io.io.main import (
load_slp,
Expand All @@ -9,19 +10,24 @@
save_labelstudio,
load_jabs,
save_jabs,
load_video,
load_file,
save_file,
)


def test_load_slp(slp_typical):
"""Test `load_slp` loads a .slp to a `Labels` object."""
assert type(load_slp(slp_typical)) == Labels
assert type(load_file(slp_typical)) == Labels


def test_nwb(tmp_path, slp_typical):
labels = load_slp(slp_typical)
save_nwb(labels, tmp_path / "test_nwb.nwb")
loaded_labels = load_nwb(tmp_path / "test_nwb.nwb")
assert type(loaded_labels) == Labels
assert type(load_file(tmp_path / "test_nwb.nwb")) == Labels
assert len(loaded_labels) == len(labels)

labels2 = load_slp(slp_typical)
Expand All @@ -38,6 +44,7 @@ def test_labelstudio(tmp_path, slp_typical):
save_labelstudio(labels, tmp_path / "test_labelstudio.json")
loaded_labels = load_labelstudio(tmp_path / "test_labelstudio.json")
assert type(loaded_labels) == Labels
assert type(load_file(tmp_path / "test_labelstudio.json")) == Labels
assert len(loaded_labels) == len(labels)


Expand All @@ -48,6 +55,7 @@ def test_jabs(tmp_path, jabs_real_data_v2, jabs_real_data_v5):
labels_single_written = load_jabs(str(tmp_path / jabs_real_data_v2))
# Confidence field is not preserved, so just check number of labels
assert len(labels_single) == len(labels_single_written)
assert type(load_file(jabs_real_data_v2)) == Labels

labels_multi = load_jabs(jabs_real_data_v5)
assert isinstance(labels_multi, Labels)
Expand All @@ -58,3 +66,39 @@ def test_jabs(tmp_path, jabs_real_data_v2, jabs_real_data_v5):
# v5 contains all v4 and v3 data, so only need to check v5
# Confidence field and ordering of identities is not preserved, so just check number of labels
assert len(labels_v5_written) == len(labels_multi)


def test_load_video(centered_pair_low_quality_path):
assert load_video(centered_pair_low_quality_path).shape == (1100, 384, 384, 1)
assert load_file(centered_pair_low_quality_path).shape == (1100, 384, 384, 1)


@pytest.mark.parametrize("format", ["slp", "nwb", "labelstudio", "jabs"])
def test_load_save_file(format, tmp_path, slp_typical, jabs_real_data_v5):
if format == "slp":
labels = load_slp(slp_typical)
save_file(labels, tmp_path / "test.slp")
assert type(load_file(tmp_path / "test.slp")) == Labels
elif format == "nwb":
labels = load_slp(slp_typical)
save_file(labels, tmp_path / "test.nwb")
assert type(load_file(tmp_path / "test.nwb")) == Labels
elif format == "labelstudio":
labels = load_slp(slp_typical)
save_file(labels, tmp_path / "test.json")
assert type(load_file(tmp_path / "test.json")) == Labels
elif format == "jabs":
labels = load_jabs(jabs_real_data_v5)
save_file(labels, tmp_path, pose_version=5)
assert type(load_file(tmp_path / jabs_real_data_v5)) == Labels

save_file(labels, tmp_path, format="jabs")
assert type(load_file(tmp_path / jabs_real_data_v5)) == Labels


def test_load_save_file_invalid():
with pytest.raises(ValueError):
load_file("invalid_file.ext")

with pytest.raises(ValueError):
save_file(Labels(), "invalid_file.ext")

0 comments on commit 599b207

Please sign in to comment.