diff --git a/sleap_io/__init__.py b/sleap_io/__init__.py index f945d711..9e63e48b 100644 --- a/sleap_io/__init__.py +++ b/sleap_io/__init__.py @@ -24,4 +24,6 @@ save_labelstudio, load_jabs, save_jabs, + load_video, + load_file, ) diff --git a/sleap_io/io/main.py b/sleap_io/io/main.py index 433457df..1bf25e3c 100644 --- a/sleap_io/io/main.py +++ b/sleap_io/io/main.py @@ -1,7 +1,7 @@ """This module contains high-level wrappers for utilizing different I/O backends.""" from __future__ import annotations -from sleap_io import Labels, Skeleton +from sleap_io import Labels, Skeleton, Video from sleap_io.io import slp, nwb, labelstudio, jabs from typing import Optional, Union from pathlib import Path @@ -77,7 +77,12 @@ def load_labelstudio( def save_labelstudio(labels: Labels, filename: str): - """Save a SLEAP dataset to Label Studio format.""" + """Save a SLEAP dataset to Label Studio format. + + Args: + labels: A SLEAP `Labels` object (see `load_slp`). + filename: Path to save labels to ending with `.json`. + """ labelstudio.write_labels(labels, filename) @@ -95,11 +100,110 @@ def load_jabs(filename: str, skeleton: Optional[Skeleton] = None) -> Labels: def save_jabs(labels: Labels, pose_version: int, root_folder: Optional[str] = None): - """Save a SLEAP dataset to JABS pose file format. Filenames for JABS poses are based on video filenames. + """Save a SLEAP dataset to JABS pose file format. Args: - labels: SLEAP `Labels` object - pose_version: The JABS pose version to write data out - root_folder: Optional root folder where the files should be saved + labels: SLEAP `Labels` object. + pose_version: The JABS pose version to write data out. + root_folder: Optional root folder where the files should be saved. + + Note: + Filenames for JABS poses are based on video filenames. """ jabs.write_labels(labels, pose_version, root_folder) + + +def load_video(filename: str, **kwargs) -> Video: + """Load a video file. + + Args: + filename: Path to a video file. + + Returns: + A `Video` object. + """ + return Video.from_filename(filename, **kwargs) + + +def load_file( + filename: str | Path, format: Optional[str] = None, **kwargs +) -> Union[Labels, Video]: + """Load a file and return the appropriate object. + + Args: + filename: Path to a file. + format: Optional format to load as. If not provided, will be inferred from the + file extension. Available formats are: "slp", "nwb", "labelstudio", "jabs" + and "video". + + Returns: + A `Labels` or `Video` object. + """ + if isinstance(filename, Path): + filename = str(filename) + + if format is None: + if filename.endswith(".slp"): + format = "slp" + elif filename.endswith(".nwb"): + format = "nwb" + elif filename.endswith(".json"): + format = "json" + elif filename.endswith(".h5"): + format = "jabs" + else: + for vid_ext in Video.EXTS: + if filename.endswith(vid_ext): + format = "video" + break + if format is None: + raise ValueError(f"Could not infer format from filename: '{filename}'.") + + if filename.endswith(".slp"): + return load_slp(filename, **kwargs) + elif filename.endswith(".nwb"): + return load_nwb(filename, **kwargs) + elif filename.endswith(".json"): + return load_labelstudio(filename, **kwargs) + elif filename.endswith(".h5"): + return load_jabs(filename, **kwargs) + elif format == "video": + return load_video(filename, **kwargs) + + +def save_file( + labels: Labels, filename: str | Path, format: Optional[str] = None, **kwargs +): + """Save a file based on the extension. + + Args: + labels: A SLEAP `Labels` object (see `load_slp`). + filename: Path to save labels to. + format: Optional format to save as. If not provided, will be inferred from the + file extension. Available formats are: "slp", "nwb", "labelstudio" and + "jabs". + """ + if isinstance(filename, Path): + filename = str(filename) + + if format is None: + if filename.endswith(".slp"): + format = "slp" + elif filename.endswith(".nwb"): + format = "nwb" + elif filename.endswith(".json"): + format = "labelstudio" + elif "pose_version" in kwargs: + format = "jabs" + + if format == "slp": + save_slp(labels, filename, **kwargs) + elif format == "nwb": + save_nwb(labels, filename, **kwargs) + elif format == "labelstudio": + save_labelstudio(labels, filename, **kwargs) + elif format == "jabs": + pose_version = kwargs.pop("pose_version", 5) + save_jabs(labels, pose_version, filename, **kwargs) + else: + raise ValueError(f"Unknown format '{format}' for filename: '{filename}'.") diff --git a/sleap_io/model/video.py b/sleap_io/model/video.py index f78d9219..3ee7e17f 100644 --- a/sleap_io/model/video.py +++ b/sleap_io/model/video.py @@ -8,7 +8,7 @@ from attrs import define from typing import Tuple, Optional, Optional import numpy as np -from sleap_io.io.video import VideoBackend +from sleap_io.io.video import VideoBackend, MediaVideo, HDF5Video from pathlib import Path @@ -33,6 +33,8 @@ class Video: filename: str backend: Optional[VideoBackend] = None + EXTS = MediaVideo.EXTS + HDF5Video.EXTS + def __attrs_post_init__(self): """Set the video backend if not already set.""" if self.backend is None: diff --git a/tests/io/test_main.py b/tests/io/test_main.py index f70fbacd..a9255b63 100644 --- a/tests/io/test_main.py +++ b/tests/io/test_main.py @@ -1,5 +1,6 @@ """Tests for functions in the sleap_io.io.main file.""" +import pytest from sleap_io import Labels from sleap_io.io.main import ( load_slp, @@ -9,12 +10,16 @@ save_labelstudio, load_jabs, save_jabs, + load_video, + load_file, + save_file, ) def test_load_slp(slp_typical): """Test `load_slp` loads a .slp to a `Labels` object.""" assert type(load_slp(slp_typical)) == Labels + assert type(load_file(slp_typical)) == Labels def test_nwb(tmp_path, slp_typical): @@ -22,6 +27,7 @@ def test_nwb(tmp_path, slp_typical): save_nwb(labels, tmp_path / "test_nwb.nwb") loaded_labels = load_nwb(tmp_path / "test_nwb.nwb") assert type(loaded_labels) == Labels + assert type(load_file(tmp_path / "test_nwb.nwb")) == Labels assert len(loaded_labels) == len(labels) labels2 = load_slp(slp_typical) @@ -38,6 +44,7 @@ def test_labelstudio(tmp_path, slp_typical): save_labelstudio(labels, tmp_path / "test_labelstudio.json") loaded_labels = load_labelstudio(tmp_path / "test_labelstudio.json") assert type(loaded_labels) == Labels + assert type(load_file(tmp_path / "test_labelstudio.json")) == Labels assert len(loaded_labels) == len(labels) @@ -48,6 +55,7 @@ def test_jabs(tmp_path, jabs_real_data_v2, jabs_real_data_v5): labels_single_written = load_jabs(str(tmp_path / jabs_real_data_v2)) # Confidence field is not preserved, so just check number of labels assert len(labels_single) == len(labels_single_written) + assert type(load_file(jabs_real_data_v2)) == Labels labels_multi = load_jabs(jabs_real_data_v5) assert isinstance(labels_multi, Labels) @@ -58,3 +66,39 @@ def test_jabs(tmp_path, jabs_real_data_v2, jabs_real_data_v5): # v5 contains all v4 and v3 data, so only need to check v5 # Confidence field and ordering of identities is not preserved, so just check number of labels assert len(labels_v5_written) == len(labels_multi) + + +def test_load_video(centered_pair_low_quality_path): + assert load_video(centered_pair_low_quality_path).shape == (1100, 384, 384, 1) + assert load_file(centered_pair_low_quality_path).shape == (1100, 384, 384, 1) + + +@pytest.mark.parametrize("format", ["slp", "nwb", "labelstudio", "jabs"]) +def test_load_save_file(format, tmp_path, slp_typical, jabs_real_data_v5): + if format == "slp": + labels = load_slp(slp_typical) + save_file(labels, tmp_path / "test.slp") + assert type(load_file(tmp_path / "test.slp")) == Labels + elif format == "nwb": + labels = load_slp(slp_typical) + save_file(labels, tmp_path / "test.nwb") + assert type(load_file(tmp_path / "test.nwb")) == Labels + elif format == "labelstudio": + labels = load_slp(slp_typical) + save_file(labels, tmp_path / "test.json") + assert type(load_file(tmp_path / "test.json")) == Labels + elif format == "jabs": + labels = load_jabs(jabs_real_data_v5) + save_file(labels, tmp_path, pose_version=5) + assert type(load_file(tmp_path / jabs_real_data_v5)) == Labels + + save_file(labels, tmp_path, format="jabs") + assert type(load_file(tmp_path / jabs_real_data_v5)) == Labels + + +def test_load_save_file_invalid(): + with pytest.raises(ValueError): + load_file("invalid_file.ext") + + with pytest.raises(ValueError): + save_file(Labels(), "invalid_file.ext")