diff --git a/.gitignore b/.gitignore index f8070c49..4c9ab20a 100644 --- a/.gitignore +++ b/.gitignore @@ -50,6 +50,7 @@ coverage.xml *.py,cover .hypothesis/ .pytest_cache/ +lcov.info # Translations *.mo diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 220d37d9..e58b85ff 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -93,6 +93,15 @@ We check for coverage by parsing the outputs from `pytest` and uploading to [Cod All changes should aim to increase or maintain test coverage. +### Live coverage + +*The following steps are based on [this guide](https://jasonstitt.com/perfect-python-live-test-coverage).* + +1. If you already have an environment installed, `pip install -e ."[dev]"` to make sure you have the latest dev tools (namely `pytest-watch`). +2. Install the [Coverage Gutters extension](https://marketplace.visualstudio.com/items?itemName=ryanluker.vscode-coverage-gutters) in VS Code. +3. Open a terminal, `conda activate sleap-io` and then run `ptw` to automatically run tests. This will generate a new `lcov.info` file when it's done. +4. Enable the coverage gutters by using **Ctrl/Cmd**+**Shift**+**P**, then **Coverage Gutters: Display Coverage**. + ### Code style To standardize formatting conventions, we use [`black`](https://black.readthedocs.io/en/stable/). diff --git a/pyproject.toml b/pyproject.toml index 6ee831cf..0fb2c4ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,8 +7,7 @@ name = "sleap-io" authors = [ {name = "Liezl Maree", email = "lmaree@salk.edu"}, {name = "David Samy", email = "davidasamy@gmail.com"}, - {name = "Talmo Pereira", email = "talmo@salk.edu"} -] + {name = "Talmo Pereira", email = "talmo@salk.edu"}] description="Standalone utilities for working with pose data from SLEAP and other tools." 
requires-python = ">=3.7" keywords = ["sleap", "pose tracking", "pose estimation", "behavior"] @@ -19,8 +18,7 @@ classifiers = [ "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12" -] + "Programming Language :: Python :: 3.12"] dependencies = [ "numpy", "attrs", @@ -31,8 +29,7 @@ dependencies = [ "simplejson", "imageio", "imageio-ffmpeg", - "av" -] + "av"] dynamic = ["version", "readme"] [tool.setuptools.dynamic] @@ -43,6 +40,7 @@ readme = {file = ["README.md"], content-type="text/markdown"} dev = [ "pytest", "pytest-cov", + "pytest-watch", "black", "pydocstyle", "toml", @@ -52,16 +50,24 @@ dev = [ "mkdocs-jupyter", "mkdocstrings[python]>=0.18", "mkdocs-gen-files", - "mkdocs-literate-nav" -] + "mkdocs-literate-nav"] [project.urls] -Homepage = "https://sleap.ai" +Homepage = "https://io.sleap.ai" Repository = "https://github.com/talmolab/sleap-io" +[tool.setuptools.packages.find] +exclude = ["site"] + [tool.black] line-length = 88 [pydocstyle] convention = "google" match-dir = "sleap_io" + +[tool.coverage.run] +source = ["livecov"] + +[tool.pytest.ini_options] +addopts = "--cov sleap_io --cov-report=lcov:lcov.info --cov-report=term" diff --git a/sleap_io/io/jabs.py b/sleap_io/io/jabs.py index 6aca6f7d..fd687f72 100644 --- a/sleap_io/io/jabs.py +++ b/sleap_io/io/jabs.py @@ -86,6 +86,7 @@ def read_labels( frames: List[LabeledFrame] = [] # Video name is the pose file minus the suffix video_name = re.sub(r"(_pose_est_v[2-6])?\.h5", ".avi", labels_path) + video = Video.from_filename(video_name) if not skeleton: skeleton = JABS_DEFAULT_SKELETON tracks = {} @@ -166,7 +167,7 @@ def read_labels( ) if new_instance: instances.append(new_instance) - frame_label = LabeledFrame(Video(video_name), frame_idx, instances) + frame_label = LabeledFrame(video, frame_idx, instances) frames.append(frame_label) return Labels(frames) diff --git a/sleap_io/io/main.py 
b/sleap_io/io/main.py index 23720b86..32432ac4 100644 --- a/sleap_io/io/main.py +++ b/sleap_io/io/main.py @@ -19,14 +19,20 @@ def load_slp(filename: str) -> Labels: return slp.read_labels(filename) -def save_slp(labels: Labels, filename: str): +def save_slp( + labels: Labels, filename: str, embed: str | list[tuple[Video, int]] | None = None +): """Save a SLEAP dataset to a `.slp` file. Args: labels: A SLEAP `Labels` object (see `load_slp`). filename: Path to save labels to ending with `.slp`. + embed: One of `"user"`, `"suggestions"`, `"user+suggestions"`, `"source"` or + list of tuples of `(video, frame_idx)` specifying the frames to embed. If + `"source"` is specified, no images will be embedded and the source video + will be restored if available. """ - return slp.write_labels(filename, labels) + return slp.write_labels(filename, labels, embed=embed) def load_nwb(filename: str) -> Labels: diff --git a/sleap_io/io/slp.py b/sleap_io/io/slp.py index 0218645c..4a1f5f85 100644 --- a/sleap_io/io/slp.py +++ b/sleap_io/io/slp.py @@ -20,17 +20,20 @@ LabeledFrame, Labels, ) -from sleap_io.io.video import ImageVideo, MediaVideo, HDF5Video +from sleap_io.io.video import VideoBackend, ImageVideo, MediaVideo, HDF5Video from sleap_io.io.utils import ( read_hdf5_attrs, read_hdf5_dataset, - write_hdf5_dataset, - write_hdf5_group, - write_hdf5_attrs, ) -from sleap_io.io.video import VideoBackend from enum import IntEnum from pathlib import Path +import imageio.v3 as iio +import sys + +try: + import cv2 +except ImportError: + pass class InstanceType(IntEnum): @@ -40,6 +43,76 @@ class InstanceType(IntEnum): PREDICTED = 1 +def make_video( + labels_path: str, video_json: dict, video_ind: int | None = None +) -> Video: + """Create a `Video` object from a JSON dictionary. + + Args: + labels_path: A string path to the SLEAP labels file. + video_json: A dictionary containing the video metadata. + video_ind: The index of the video in the labels file. 
This is used to try to + recover the source video for embedded videos. This is skipped if `None`. + """ + backend_metadata = video_json["backend"] + video_path = backend_metadata["filename"] + + # Marker for embedded videos. + source_video = None + is_embedded = False + if video_path == ".": + video_path = labels_path + is_embedded = True + + # Basic path resolution. + video_path = Path(video_path) + if not video_path.exists(): + # Check for the same filename in the same directory as the labels file. + video_path_ = Path(labels_path).parent / video_path.name + if video_path_.exists(): + video_path = video_path_ + else: + # TODO (TP): Expand capabilities of path resolution to support more + # complex path finding strategies. + pass + + # Convert video path to string. + video_path = video_path.as_posix() + + if is_embedded: + # Try to recover the source video. + with h5py.File(labels_path, "r") as f: + if f"video{video_ind}" in f: + source_video_json = json.loads( + f[f"video{video_ind}/source_video"].attrs["json"] + ) + source_video = make_video( + labels_path, source_video_json, video_ind=None + ) + + if "filenames" in backend_metadata: + # This is an ImageVideo. + # TODO: Path resolution. + video_path = backend_metadata["filenames"] + + try: + backend = VideoBackend.from_filename( + video_path, + dataset=backend_metadata.get("dataset", None), + grayscale=backend_metadata.get("grayscale", None), + input_format=backend_metadata.get("input_format", None), + ) + except ValueError: + backend = None + + return Video( + filename=video_path, + backend=backend, + backend_metadata=backend_metadata, + source_video=source_video, + ) + + def read_videos(labels_path: str) -> list[Video]: """Read `Video` dataset in a SLEAP labels file. @@ -49,105 +122,296 @@ def read_videos(labels_path: str) -> list[Video]: Returns: A list of `Video` objects. 
""" - # TODO (DS) - Find shape of video - videos = [json.loads(x) for x in read_hdf5_dataset(labels_path, "videos_json")] - video_objects = [] - for video in videos: - backend = video["backend"] - video_path = backend["filename"] - - # Marker for embedded videos. - if video_path == ".": - video_path = labels_path - - # Basic path resolution. - video_path = Path(video_path) - if not video_path.exists(): - # Check for the same filename in the same directory as the labels file. - video_path_ = Path(labels_path).parent / video_path.name - if video_path_.exists(): - video_path = video_path_ + videos = [] + for video_ind, video_data in enumerate( + read_hdf5_dataset(labels_path, "videos_json") + ): + video_json = json.loads(video_data) + video = make_video(labels_path, video_json, video_ind=video_ind) + videos.append(video) + return videos + + +def video_to_dict(video: Video) -> dict: + """Convert a `Video` object to a JSON-compatible dictionary. + + Args: + video: A `Video` object to convert. + + Returns: + A dictionary containing the video metadata. + """ + if video.backend is None: + return {"filename": video.filename, "backend": video.backend_metadata} + + if type(video.backend) == MediaVideo: + return { + "filename": video.filename, + "backend": { + "type": "MediaVideo", + "shape": video.shape, + "filename": video.filename, + "grayscale": video.grayscale, + "bgr": True, + "dataset": "", + "input_format": "", + }, + } + + elif type(video.backend) == HDF5Video: + return { + "filename": video.filename, + "backend": { + "type": "HDF5Video", + "shape": video.shape, + "filename": ( + "." 
if video.backend.has_embedded_images else video.filename + ), + "dataset": video.backend.dataset, + "input_format": video.backend.input_format, + "convert_range": False, + "has_embedded_images": video.backend.has_embedded_images, + }, + } + + elif type(video.backend) == ImageVideo: + return { + "filename": video.filename, + "backend": { + "type": "ImageVideo", + "shape": video.shape, + "filename": video.backend.filename[0], + "filenames": video.backend.filename, + "dataset": video.backend_metadata.get("dataset", None), + "grayscale": video.grayscale, + "input_format": video.backend_metadata.get("input_format", None), + }, + } + + +def embed_video( + labels_path: str, + video: Video, + group: str, + frame_inds: list[int], + image_format: str = "png", + fixed_length: bool = True, +) -> Video: + """Embed frames of a video in a SLEAP labels file. + + Args: + labels_path: A string path to the SLEAP labels file. + video: A `Video` object to embed in the labels file. + group: The name of the group to store the embedded video in. Image data will be + stored in a dataset named `{group}/video`. Frame indices will be stored + in a data set named `{group}/frame_numbers`. + frame_inds: A list of frame indices to embed. + image_format: The image format to use for embedding. Valid formats are "png" + (the default), "jpg" or "hdf5". + fixed_length: If `True` (the default), the embedded images will be padded to the + length of the largest image. If `False`, the images will be stored as + variable length, which is smaller but may not be supported by all readers. + + Returns: + An embedded `Video` object. + + If the video is already embedded, the original video will be returned. If not, + a new `Video` object will be created with the embedded data. + """ + # Load the image data and optionally encode it. 
+ imgs_data = [] + for frame_idx in frame_inds: + frame = video[frame_idx] + + if image_format == "hdf5": + img_data = frame + else: + if "cv2" in sys.modules: + img_data = np.squeeze( + cv2.imencode("." + image_format, frame)[1] + ).astype("int8") else: - # TODO (TP): Expand capabilities of path resolution to support more - # complex path finding strategies. - pass - - video_path = video_path.as_posix() - - if "filenames" in backend: - # This is an ImageVideo. - # TODO: Path resolution. - video_path = backend["filenames"] - - try: - backend = VideoBackend.from_filename( - video_path, - dataset=backend.get("dataset", None), - grayscale=backend.get("grayscale", None), - input_format=backend.get("input_format", None), + img_data = np.frombuffer( + iio.imwrite( + "", frame.squeeze(axis=-1), extension="." + image_format + ), + dtype="int8", + ) + + imgs_data.append(img_data) + + # Write the image data to the labels file. + with h5py.File(labels_path, "a") as f: + if image_format == "hdf5": + f.create_dataset( + f"{group}/video", data=imgs_data, compression="gzip", chunks=True ) - except ValueError: - backend = None - video_objects.append(Video(filename=video_path, backend=backend)) - return video_objects + else: + if fixed_length: + ds = f.create_dataset( + f"{group}/video", + shape=(len(imgs_data), max(len(img) for img in imgs_data)), + dtype="int8", + compression="gzip", + ) + for i, img in enumerate(imgs_data): + ds[i, : len(img)] = img + else: + ds = f.create_dataset( + f"{group}/video", + shape=(len(imgs_data),), + dtype=h5py.special_dtype(vlen=np.dtype("int8")), + ) + for i, img in enumerate(imgs_data): + ds[i] = img + + # Store metadata. + ds.attrs["format"] = image_format + ( + ds.attrs["frames"], + ds.attrs["height"], + ds.attrs["width"], + ds.attrs["channels"], + ) = video.shape + + # Store frame indices. + f.create_dataset(f"{group}/frame_numbers", data=frame_inds) + + # Store source video. 
+ if video.source_video is not None: + # If this is already an embedded dataset, retain the previous source video. + source_video = video.source_video + embedded_video = video + video.replace_filename(labels_path, open=False) + else: + source_video = video + embedded_video = Video( + filename=labels_path, + backend=VideoBackend.from_filename( + labels_path, + dataset=f"{group}/video", + grayscale=video.grayscale, + keep_open=False, + ), + source_video=source_video, + ) + + grp = f.require_group(f"{group}/source_video") + grp.attrs["json"] = json.dumps( + video_to_dict(source_video), separators=(",", ":") + ) + + return embedded_video + + +def embed_frames( + labels_path: str, + labels: Labels, + embed: list[tuple[Video, int]], + image_format: str = "png", +): + """Embed frames in a SLEAP labels file. + + Args: + labels_path: A string path to the SLEAP labels file. + labels: A `Labels` object to embed in the labels file. + embed: A list of tuples of `(video, frame_idx)` specifying the frames to embed. + image_format: The image format to use for embedding. Valid formats are "png" + (the default), "jpg" or "hdf5". + + Notes: + This function will embed the frames in the labels file and update the `Videos` + and `Labels` objects in place. 
+ """ + to_embed_by_video = {} + for video, frame_idx in embed: + if video not in to_embed_by_video: + to_embed_by_video[video] = [] + to_embed_by_video[video].append(frame_idx) + + replaced_videos = {} + for video, frame_inds in to_embed_by_video.items(): + video_ind = labels.videos.index(video) + embedded_video = embed_video( + labels_path, + video, + group=f"video{video_ind}", + frame_inds=frame_inds, + image_format=image_format, + ) + + labels.videos[video_ind] = embedded_video + replaced_videos[video] = embedded_video + + if len(replaced_videos) > 0: + labels.replace_videos(video_map=replaced_videos) -def write_videos(labels_path: str, videos: list[Video]): +def embed_videos( + labels_path: str, labels: Labels, embed: str | list[tuple[Video, int]] +): + """Embed videos in a SLEAP labels file. + + Args: + labels_path: A string path to the SLEAP labels file to save. + labels: A `Labels` object to save. + embed: One of `"user"`, `"suggestions"`, `"user+suggestions"`, `"source"` or + list of tuples of `(video, frame_idx)` specifying the frames to embed. If + `"source"` is specified, no images will be embedded and the source video + will be restored if available. + """ + if embed == "user": + embed = [(lf.video, lf.frame_idx) for lf in labels.user_labeled_frames] + elif embed == "suggestions": + embed = [(sf.video, sf.frame_idx) for sf in labels.suggestions] + elif embed == "user+suggestions": + embed = [(lf.video, lf.frame_idx) for lf in labels.user_labeled_frames] + embed += [(sf.video, sf.frame_idx) for sf in labels.suggestions] + elif embed == "source": + embed = [] + elif isinstance(embed, list): + embed = embed + else: + raise ValueError(f"Invalid value for embed: {embed}") + + embed_frames(labels_path, labels, embed) + + +def write_videos(labels_path: str, videos: list[Video], restore_source: bool = False): """Write video metadata to a SLEAP labels file. Args: labels_path: A string path to the SLEAP labels file. 
videos: A list of `Video` objects to store the metadata for. + restore_source: If `True`, restore source videos if available and will not + re-embed the embedded images. If `False` (the default), will re-embed images + that were previously embedded. """ video_jsons = [] - for video in videos: - if type(video.backend) == MediaVideo: - video_json = { - "backend": { - "filename": video.filename, - "grayscale": video.backend.grayscale, - "bgr": True, - "dataset": "", - "input_format": "", - } - } + for video_ind, video in enumerate(videos): - elif type(video.backend) == HDF5Video: - video_json = { - "backend": { - "filename": ( - "." if video.backend.has_embedded_images else video.filename - ), - "dataset": video.backend.dataset, - "input_format": video.backend.input_format, - "convert_range": False, - } - } - # TODO: Handle saving embedded images or restoring source video. - # Ref: https://github.com/talmolab/sleap/blob/fb61b6ce7a9ac9613d99303111f3daafaffc299b/sleap/io/format/hdf5.py#L246-L273 - - elif type(video.backend) == ImageVideo: - shape = video.shape - if shape is None: - height, width, channels = 0, 0, 1 + if type(video.backend) == HDF5Video and video.backend.has_embedded_images: + if restore_source: + video = video.source_video else: - height, width, channels = shape[1:] - - video_json = { - "backend": { - "filename": video.filename[0], - "filenames": video.filename, - "height_": height, - "width_": width, - "channels_": channels, - "grayscale": video.backend.grayscale, - } - } + # If the video has embedded images, embed the images again if we haven't + # already.
+ already_embedded = False + if Path(labels_path).exists(): + with h5py.File(labels_path, "r") as f: + already_embedded = f"video{video_ind}/video" in f + + if not already_embedded: + video = embed_video( + labels_path, + video, + group=f"video{video_ind}", + frame_inds=video.backend.source_inds, + image_format=video.backend.image_format, + ) + + video_json = video_to_dict(video) - else: - raise NotImplementedError( - f"Cannot serialize video backend for video: {video}" - ) video_jsons.append(np.string_(json.dumps(video_json, separators=(",", ":")))) with h5py.File(labels_path, "a") as f: @@ -753,16 +1017,25 @@ def read_labels(labels_path: str) -> Labels: return labels -def write_labels(labels_path: str, labels: Labels): +def write_labels( + labels_path: str, labels: Labels, embed: str | list[tuple[Video, int]] | None = None +): """Write a SLEAP labels file. Args: labels_path: A string path to the SLEAP labels file to save. labels: A `Labels` object to save. + embed: One of `"user"`, `"suggestions"`, `"user+suggestions"`, `"source"`, + `None` or list of tuples of `(video, frame_idx)` specifying the frames to + embed. If `"source"` is specified, no images will be embedded and the source + video will be restored if available. If `None` is specified (the default), + existing embedded images will be re-embedded. 
""" if Path(labels_path).exists(): Path(labels_path).unlink() - write_videos(labels_path, labels.videos) + if embed is not None: + embed_videos(labels_path, labels, embed) + write_videos(labels_path, labels.videos, restore_source=(embed == "source")) write_tracks(labels_path, labels.tracks) write_suggestions(labels_path, labels.suggestions, labels.videos) write_metadata(labels_path, labels) diff --git a/sleap_io/io/video.py b/sleap_io/io/video.py index b121c223..74c3d316 100644 --- a/sleap_io/io/video.py +++ b/sleap_io/io/video.py @@ -160,7 +160,8 @@ def num_frames(self) -> int: @property def img_shape(self) -> Tuple[int, int, int]: """Shape of a single frame in the video.""" - return self.get_frame(0).shape + height, width, channels = self.get_frame(0).shape + return int(height), int(width), int(channels) @property def shape(self) -> Tuple[int, int, int, int]: @@ -478,6 +479,8 @@ class HDF5Video(VideoBackend): when reading embedded image datasets. source_inds: Indices of the frames in the source video file. This is metadata and only used when reading embedded image datasets. + image_format: Format of the images in the embedded dataset. This is metadata and + only used when reading embedded image datasets. """ dataset: Optional[str] = None @@ -488,6 +491,7 @@ class HDF5Video(VideoBackend): frame_map: dict[int, int] = attrs.field(init=False, default=attrs.Factory(dict)) source_filename: Optional[str] = None source_inds: Optional[np.ndarray] = None + image_format: str = "hdf5" EXTS = ("h5", "hdf5", "slp") @@ -530,6 +534,9 @@ def find_embedded(name, obj): # This may be an embedded video dataset. Check for frame map. 
ds = f[self.dataset] + if "format" in ds.attrs: + self.image_format = ds.attrs["format"] + if "frame_numbers" in ds.parent: frame_numbers = ds.parent["frame_numbers"][:] self.frame_map = {frame: idx for idx, frame in enumerate(frame_numbers)} @@ -563,7 +570,7 @@ def img_shape(self) -> Tuple[int, int, int]: img_shape = ds.shape[1:] if self.input_format == "channels_first": img_shape = img_shape[::-1] - return img_shape + return int(img_shape[0]), int(img_shape[1]), int(img_shape[2]) def read_test_frame(self) -> np.ndarray: """Read a single frame from the video to test for grayscale.""" @@ -576,17 +583,19 @@ def read_test_frame(self) -> np.ndarray: @property def has_embedded_images(self) -> bool: """Return True if the dataset contains embedded images.""" - with h5py.File(self.filename, "r") as f: - ds = f[self.dataset] - return "format" in ds.attrs + return self.image_format is not None and self.image_format != "hdf5" - def decode_embedded(self, img_string: np.ndarray, format: str) -> np.ndarray: + @property + def embedded_frame_inds(self) -> list[int]: + """Return the frame indices of the embedded images.""" + return list(self.frame_map.keys()) + + def decode_embedded(self, img_string: np.ndarray) -> np.ndarray: """Decode an embedded image string into a numpy array. Args: img_string: Binary string of the image as a `int8` numpy vector with the bytes as values corresponding to the format-encoded image. - format: Image format (e.g., "png" or "jpg"). Returns: The decoded image as a numpy array of shape `(height, width, channels)`. 
If @@ -599,7 +608,7 @@ def decode_embedded(self, img_string: np.ndarray, format: str) -> np.ndarray: if "cv2" in sys.modules: img = cv2.imdecode(img_string, cv2.IMREAD_UNCHANGED) else: - img = iio.imread(BytesIO(img_string), extension=f".{format}") + img = iio.imread(BytesIO(img_string), extension=f".{self.image_format}") if img.ndim == 2: img = np.expand_dims(img, axis=-1) @@ -632,8 +641,8 @@ def _read_frame(self, frame_idx: int) -> np.ndarray: img = ds[frame_idx] - if "format" in ds.attrs: - img = self.decode_embedded(img, ds.attrs["format"]) + if self.has_embedded_images: + img = self.decode_embedded(img) if self.input_format == "channels_first": img = np.transpose(img, (2, 1, 0)) @@ -670,7 +679,7 @@ def _read_frames(self, frame_inds: list) -> np.ndarray: if "format" in ds.attrs: imgs = np.stack( - [self.decode_embedded(img, ds.attrs["format"]) for img in imgs], + [self.decode_embedded(img) for img in imgs], axis=0, ) diff --git a/sleap_io/model/labeled_frame.py b/sleap_io/model/labeled_frame.py index e3b131e0..959f3360 100644 --- a/sleap_io/model/labeled_frame.py +++ b/sleap_io/model/labeled_frame.py @@ -11,14 +11,19 @@ import numpy as np -@define(auto_attribs=True) +@define(eq=False) class LabeledFrame: """Labeled data for a single frame of a video. Attributes: - video: The :class:`Video` associated with this `LabeledFrame`. + video: The `Video` associated with this `LabeledFrame`. frame_idx: The index of the `LabeledFrame` in the `Video`. instances: List of `Instance` objects associated with this `LabeledFrame`. + + Notes: + Instances of this class are hashed by identity, not by value. This means that + two `LabeledFrame` instances with the same attributes will NOT be considered + equal in a set or dict. 
""" video: Video @@ -42,11 +47,27 @@ def user_instances(self) -> list[Instance]: """Frame instances that are user-labeled (`Instance` objects).""" return [inst for inst in self.instances if type(inst) == Instance] + @property + def has_user_instances(self) -> bool: + """Return True if the frame has any user-labeled instances.""" + for inst in self.instances: + if type(inst) == Instance: + return True + return False + @property def predicted_instances(self) -> list[Instance]: """Frame instances that are predicted by a model (`PredictedInstance` objects).""" return [inst for inst in self.instances if type(inst) == PredictedInstance] + @property + def has_predicted_instances(self) -> bool: + """Return True if the frame has any predicted instances.""" + for inst in self.instances: + if type(inst) == PredictedInstance: + return True + return False + def numpy(self) -> np.ndarray: """Return all instances in the frame as a numpy array. diff --git a/sleap_io/model/labels.py b/sleap_io/model/labels.py index a9b979b2..841c272a 100644 --- a/sleap_io/model/labels.py +++ b/sleap_io/model/labels.py @@ -279,7 +279,13 @@ def find( return results - def save(self, filename: str, format: Optional[str] = None, **kwargs): + def save( + self, + filename: str, + format: Optional[str] = None, + embed: str | list[tuple[Video, int]] | None = None, + **kwargs, + ): """Save labels to file in specified format. Args: @@ -287,10 +293,15 @@ def save(self, filename: str, format: Optional[str] = None, **kwargs): format: The format to save the labels in. If `None`, the format will be inferred from the file extension. Available formats are "slp", "nwb", "labelstudio", and "jabs". + embed: One of `"user"`, `"suggestions"`, `"user+suggestions"`, `"source"` or + list of tuples of `(video, frame_idx)` specifying the frames to embed. + If `"source"` is specified, no images will be embedded and the source + video will be restored if available. This argument is only valid for the + SLP backend. 
""" from sleap_io import save_file - save_file(self, filename, format=format, **kwargs) + save_file(self, filename, format=format, embed=embed, **kwargs) def clean( self, @@ -374,3 +385,35 @@ def remove_predictions(self, clean: bool = True): tracks=True, videos=False, ) + + @property + def user_labeled_frames(self) -> list[LabeledFrame]: + """Return all labeled frames with user (non-predicted) instances.""" + return [lf for lf in self.labeled_frames if lf.has_user_instances] + + def replace_videos( + self, + old_videos: list[Video] | None = None, + new_videos: list[Video] | None = None, + video_map: dict[Video, Video] | None = None, + ): + """Replace videos and update all references. + + Args: + old_videos: List of videos to be replaced. + new_videos: List of videos to replace with. + video_map: Alternative input of dictionary where keys are the old videos and + values are the new videos. + """ + if video_map is None: + video_map = {o: n for o, n in zip(old_videos, new_videos)} + + # Update the labeled frames with the new videos. + for lf in self.labeled_frames: + if lf.video in video_map: + lf.video = video_map[lf.video] + + # Update suggestions with the new videos. + for sf in self.suggestions: + if sf.video in video_map: + sf.video = video_map[sf.video] diff --git a/sleap_io/model/video.py b/sleap_io/model/video.py index fe9c70cf..2c547804 100644 --- a/sleap_io/model/video.py +++ b/sleap_io/model/video.py @@ -5,14 +5,14 @@ """ from __future__ import annotations -from attrs import define +import attrs from typing import Tuple, Optional, Optional import numpy as np from sleap_io.io.video import VideoBackend, MediaVideo, HDF5Video from pathlib import Path -@define +@attrs.define(eq=False) class Video: """`Video` class used by sleap to represent videos and data associated with them. @@ -26,15 +26,32 @@ class Video: filename: The filename(s) of the video. 
backend: An object that implements the basic methods for reading and manipulating frames of a specific video type. + backend_metadata: A dictionary of metadata specific to the backend. This is + useful for storing metadata that requires an open backend (e.g., shape + information) without having access to the video file itself. + source_video: The source video object if this is a proxy video. This is present + when the video contains an embedded subset of frames from another video. + + Notes: + Instances of this class are hashed by identity, not by value. This means that + two `Video` instances with the same attributes will NOT be considered equal in a + set or dict. See also: VideoBackend """ filename: str | list[str] backend: Optional[VideoBackend] = None + backend_metadata: dict[str, any] = attrs.field(factory=dict) + source_video: Optional[Video] = None EXTS = MediaVideo.EXTS + HDF5Video.EXTS + def __attrs_post_init__(self): + """Post init syntactic sugar.""" + if self.backend is None and self.exists(): + self.open() + @classmethod def from_filename( cls, @@ -42,6 +59,7 @@ def from_filename( dataset: Optional[str] = None, grayscale: Optional[bool] = None, keep_open: bool = True, + source_video: Optional[Video] = None, **kwargs, ) -> VideoBackend: """Create a Video from a filename. @@ -55,6 +73,9 @@ def from_filename( frames. If False, will close the reader after each call. If True (the default), it will keep the reader open and cache it for subsequent calls which may enhance the performance of reading multiple frames. + source_video: The source video object if this is a proxy video. This is + present when the video contains an embedded subset of frames from + another video. Returns: Video instance with the appropriate backend instantiated. 
@@ -68,6 +89,7 @@ def from_filename( keep_open=keep_open, **kwargs, ), + source_video=source_video, ) @property @@ -88,6 +110,23 @@ def _get_shape(self) -> Tuple[int, int, int, int] | None: try: return self.backend.shape except: + if "shape" in self.backend_metadata: + return self.backend_metadata["shape"] + return None + + @property + def grayscale(self) -> bool | None: + """Return whether the video is grayscale. + + If the video backend is not set or it cannot determine whether the video is + grayscale, this will return None. + """ + shape = self.shape + if shape is not None: + return shape[-1] == 1 + else: + if "grayscale" in self.backend_metadata: + return self.backend_metadata["grayscale"] return None def __len__(self) -> int: @@ -189,6 +228,12 @@ def open( if grayscale is None: grayscale = getattr(self.backend, "grayscale", None) + else: + if dataset is None and "dataset" in self.backend_metadata: + dataset = self.backend_metadata["dataset"] + if grayscale is None and "grayscale" in self.backend_metadata: + grayscale = self.backend_metadata["grayscale"] + # Close previous backend if open. 
self.close() diff --git a/tests/io/test_main.py b/tests/io/test_main.py index a9255b63..882c3295 100644 --- a/tests/io/test_main.py +++ b/tests/io/test_main.py @@ -55,6 +55,7 @@ def test_jabs(tmp_path, jabs_real_data_v2, jabs_real_data_v5): labels_single_written = load_jabs(str(tmp_path / jabs_real_data_v2)) # Confidence field is not preserved, so just check number of labels assert len(labels_single) == len(labels_single_written) + assert len(labels_single.videos) == len(labels_single_written.videos) assert type(load_file(jabs_real_data_v2)) == Labels labels_multi = load_jabs(jabs_real_data_v5) @@ -66,6 +67,7 @@ def test_jabs(tmp_path, jabs_real_data_v2, jabs_real_data_v5): # v5 contains all v4 and v3 data, so only need to check v5 # Confidence field and ordering of identities is not preserved, so just check number of labels assert len(labels_v5_written) == len(labels_multi) + assert len(labels_v5_written.videos) == len(labels_multi.videos) def test_load_video(centered_pair_low_quality_path): diff --git a/tests/io/test_slp.py b/tests/io/test_slp.py index c2ea39ad..7d229bc4 100644 --- a/tests/io/test_slp.py +++ b/tests/io/test_slp.py @@ -35,8 +35,11 @@ ) from sleap_io.io.utils import read_hdf5_dataset import numpy as np +import simplejson as json +import pytest +from pathlib import Path -from sleap_io.io.video import ImageVideo +from sleap_io.io.video import ImageVideo, HDF5Video, MediaVideo def test_read_labels(slp_typical, slp_simple_skel, slp_minimal): @@ -101,24 +104,27 @@ def test_read_videos_pkg(slp_minimal_pkg): def test_write_videos(slp_minimal_pkg, centered_pair, tmp_path): - videos = read_videos(slp_minimal_pkg) - write_videos(tmp_path / "test_minimal_pkg.slp", videos) - json_fixture = read_hdf5_dataset(slp_minimal_pkg, "videos_json") - json_test = read_hdf5_dataset(tmp_path / "test_minimal_pkg.slp", "videos_json") - assert json_fixture == json_test - videos = read_videos(centered_pair) - write_videos(tmp_path / "test_centered_pair.slp", videos) - 
json_fixture = read_hdf5_dataset(centered_pair, "videos_json") - json_test = read_hdf5_dataset(tmp_path / "test_centered_pair.slp", "videos_json") - assert json_fixture == json_test + def compare_videos(videos_ref, videos_test): + assert len(videos_ref) == len(videos_test) + for video_ref, video_test in zip(videos_ref, videos_test): + assert video_ref.shape == video_test.shape + assert (video_ref[0] == video_test[0]).all() + + videos_ref = read_videos(slp_minimal_pkg) + write_videos(tmp_path / "test_minimal_pkg.slp", videos_ref) + videos_test = read_videos(tmp_path / "test_minimal_pkg.slp") + compare_videos(videos_ref, videos_test) + + videos_ref = read_videos(centered_pair) + write_videos(tmp_path / "test_centered_pair.slp", videos_ref) + videos_test = read_videos(tmp_path / "test_centered_pair.slp") + compare_videos(videos_ref, videos_test) videos = read_videos(centered_pair) * 2 write_videos(tmp_path / "test_centered_pair_2vids.slp", videos) - json_test = read_hdf5_dataset( - tmp_path / "test_centered_pair_2vids.slp", "videos_json" - ) - assert len(json_test) == 2 + videos_test = read_videos(tmp_path / "test_centered_pair_2vids.slp") + compare_videos(videos, videos_test) def test_write_tracks(centered_pair, tmp_path): @@ -257,3 +263,79 @@ def test_suggestions(tmpdir): write_videos(tmpdir / "test2.slp", labels.videos) loaded_suggestions = read_suggestions(tmpdir / "test2.slp", labels.videos) assert len(loaded_suggestions) == 0 + + +def test_pkg_roundtrip(tmpdir, slp_minimal_pkg): + labels = read_labels(slp_minimal_pkg) + assert type(labels.video.backend) == HDF5Video + assert labels.video.shape == (1, 384, 384, 1) + assert labels.video.backend.embedded_frame_inds == [0] + assert labels.video.filename == slp_minimal_pkg + + write_labels(str(tmpdir / "roundtrip.pkg.slp"), labels) + labels = read_labels(str(tmpdir / "roundtrip.pkg.slp")) + assert type(labels.video.backend) == HDF5Video + assert labels.video.shape == (1, 384, 384, 1) + assert 
labels.video.backend.embedded_frame_inds == [0] + assert ( + Path(labels.video.filename).as_posix() + == Path(tmpdir / "roundtrip.pkg.slp").as_posix() + ) + + +@pytest.mark.parametrize("to_embed", ["user", "suggestions", "user+suggestions"]) +def test_embed(tmpdir, slp_real_data, to_embed): + base_labels = read_labels(slp_real_data) + assert type(base_labels.video.backend) == MediaVideo + assert ( + Path(base_labels.video.filename).as_posix() + == "tests/data/videos/centered_pair_low_quality.mp4" + ) + assert base_labels.video.shape == (1100, 384, 384, 1) + assert len(base_labels) == 10 + assert len(base_labels.suggestions) == 10 + assert len(base_labels.user_labeled_frames) == 5 + + labels_path = Path(tmpdir / "labels.pkg.slp").as_posix() + write_labels(labels_path, base_labels, embed=to_embed) + labels = read_labels(labels_path) + assert len(labels) == 10 + assert type(labels.video.backend) == HDF5Video + assert Path(labels.video.filename).as_posix() == labels_path + assert ( + Path(labels.video.source_video.filename).as_posix() + == "tests/data/videos/centered_pair_low_quality.mp4" + ) + if to_embed == "user": + assert labels.video.backend.embedded_frame_inds == [0, 990, 440, 220, 770] + elif to_embed == "suggestions": + assert len(labels.video.backend.embedded_frame_inds) == 10 + elif to_embed == "user+suggestions": + assert len(labels.video.backend.embedded_frame_inds) == 10 + + +def test_embed_two_rounds(tmpdir, slp_real_data): + base_labels = read_labels(slp_real_data) + labels_path = str(tmpdir / "labels.pkg.slp") + write_labels(labels_path, base_labels, embed="user") + labels = read_labels(labels_path) + + assert labels.video.backend.embedded_frame_inds == [0, 990, 440, 220, 770] + + labels2_path = str(tmpdir / "labels2.pkg.slp") + write_labels(labels2_path, labels) + labels2 = read_labels(labels2_path) + assert ( + Path(labels2.video.source_video.filename).as_posix() + == "tests/data/videos/centered_pair_low_quality.mp4" + ) + assert
labels2.video.backend.embedded_frame_inds == [0, 990, 440, 220, 770] + + labels3_path = str(tmpdir / "labels3.slp") + write_labels(labels3_path, labels, embed="source") + labels3 = read_labels(labels3_path) + assert ( + Path(labels3.video.filename).as_posix() + == "tests/data/videos/centered_pair_low_quality.mp4" + ) + assert type(labels3.video.backend) == MediaVideo diff --git a/tests/model/test_labeled_frame.py b/tests/model/test_labeled_frame.py index b542124e..42f40428 100644 --- a/tests/model/test_labeled_frame.py +++ b/tests/model/test_labeled_frame.py @@ -1,7 +1,7 @@ """Tests for methods in sleap_io.model.labeled_frame file.""" from numpy.testing import assert_equal -from sleap_io import Video, Skeleton, Instance, PredictedInstance +from sleap_io import Video, Skeleton, Instance, PredictedInstance, Track from sleap_io.model.labeled_frame import LabeledFrame import numpy as np @@ -28,6 +28,9 @@ def test_labeled_frame(): # Test LabeledFrame.__getitem__ method assert lf[0] == inst + assert lf.has_predicted_instances + assert lf.has_user_instances + def test_remove_predictions(): """Test removing predictions from `LabeledFrame`.""" @@ -43,6 +46,8 @@ def test_remove_predictions(): assert len(lf) == 2 assert len(lf.predicted_instances) == 1 + assert lf.has_predicted_instances + assert lf.has_user_instances # Remove predictions lf.remove_predictions() @@ -51,6 +56,8 @@ def test_remove_predictions(): assert len(lf.predicted_instances) == 0 assert type(lf[0]) == Instance assert_equal(lf.numpy(), [[[0, 1], [2, 3]]]) + assert not lf.has_predicted_instances + assert lf.has_user_instances def test_remove_empty_instances(): @@ -75,3 +82,44 @@ def test_remove_empty_instances(): assert len(lf) == 1 assert type(lf[0]) == Instance assert_equal(lf.numpy(), [[[0, 1], [2, 3]]]) + + +def test_labeled_frame_image(centered_pair_low_quality_path): + video = Video.from_filename(centered_pair_low_quality_path) + lf = LabeledFrame(video=video, frame_idx=0) + assert_equal(lf.image, 
video[0]) + + +def test_labeled_frame_unused_predictions(): + video = Video("test.mp4") + skel = Skeleton(["A", "B"]) + track = Track("trk") + + lf1 = LabeledFrame(video=video, frame_idx=0) + lf1.instances.append( + Instance.from_numpy([[0, 0], [0, 0]], skeleton=skel, track=track) + ) + lf1.instances.append( + PredictedInstance.from_numpy( + [[0, 0], [0, 0]], [1, 1], 1, skeleton=skel, track=track + ) + ) + lf1.instances.append( + PredictedInstance.from_numpy([[1, 1], [1, 1]], [1, 1], 1, skeleton=skel) + ) + + assert len(lf1.unused_predictions) == 1 + assert (lf1.unused_predictions[0].numpy() == 1).all() + + lf2 = LabeledFrame(video=video, frame_idx=1) + lf2.instances.append( + PredictedInstance.from_numpy([[0, 0], [0, 0]], [1, 1], 1, skeleton=skel) + ) + lf2.instances.append(Instance.from_numpy([[0, 0], [0, 0]], skeleton=skel)) + lf2.instances[-1].from_predicted = lf2.instances[-2] + lf2.instances.append( + PredictedInstance.from_numpy([[1, 1], [1, 1]], [1, 1], 1, skeleton=skel) + ) + + assert len(lf2.unused_predictions) == 1 + assert (lf2.unused_predictions[0].numpy() == 1).all() diff --git a/tests/model/test_labels.py b/tests/model/test_labels.py index 0060a220..52ced7c6 100644 --- a/tests/model/test_labels.py +++ b/tests/model/test_labels.py @@ -258,3 +258,17 @@ def test_labels_remove_predictions(slp_real_data): labels.remove_predictions(clean=True) assert len(labels) == 5 assert sum([len(lf.predicted_instances) for lf in labels]) == 0 + + +def test_replace_videos(slp_real_data): + labels = load_slp(slp_real_data) + assert labels.video.filename == "tests/data/videos/centered_pair_low_quality.mp4" + labels.replace_videos( + old_videos=[labels.video], new_videos=[Video.from_filename("fake.mp4")] + ) + + for lf in labels: + assert lf.video.filename == "fake.mp4" + + for sf in labels.suggestions: + assert sf.video.filename == "fake.mp4" diff --git a/tests/model/test_video.py b/tests/model/test_video.py index 0ccd13c2..03a0bdba 100644 --- a/tests/model/test_video.py 
+++ b/tests/model/test_video.py @@ -20,6 +20,7 @@ def test_video_from_filename(centered_pair_low_quality_path): test_video = Video.from_filename(centered_pair_low_quality_path) assert test_video.filename == centered_pair_low_quality_path assert test_video.shape == (1100, 384, 384, 1) + assert len(test_video) == 1100 assert type(test_video.backend) == MediaVideo @@ -57,8 +58,8 @@ def test_video_exists(centered_pair_low_quality_video, centered_pair_frame_paths def test_video_open_close(centered_pair_low_quality_path): video = Video(centered_pair_low_quality_path) - assert video.is_open is False - assert video.backend is None + assert video.is_open + assert type(video.backend) == MediaVideo img = video[0] assert img.shape == (384, 384, 1)