Skip to content

Commit

Permalink
Video QOL enhancements (#82)
Browse files Browse the repository at this point in the history
* Video open/close control

* Replace filename

* Remove outdated test

* Update readme examples

* coverage
  • Loading branch information
talmo authored Apr 14, 2024
1 parent 0035e86 commit c6b4e4a
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 25 deletions.
25 changes: 21 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
[![codecov](https://codecov.io/gh/talmolab/sleap-io/branch/main/graph/badge.svg?token=Sj8kIFl3pi)](https://codecov.io/gh/talmolab/sleap-io)
[![Release](https://img.shields.io/github/v/release/talmolab/sleap-io?label=Latest)](https://github.com/talmolab/sleap-io/releases/)
[![PyPI](https://img.shields.io/pypi/v/sleap-io?label=PyPI)](https://pypi.org/project/sleap-io)
<!-- TODO: ![PyPI - Python Version](https://img.shields.io/pypi/pyversions/sleap-io) -->
![PyPI - Python Version](https://img.shields.io/pypi/pyversions/sleap-io)

Standalone utilities for working with animal pose tracking data.

Expand Down Expand Up @@ -35,10 +35,12 @@ See [`CONTRIBUTING.md`](CONTRIBUTING.md) for more information on development.
import sleap_io as sio

# Load from SLEAP file.
labels = sio.load_slp("predictions.slp")
labels = sio.load_file("predictions.slp")

# Save to NWB file.
sio.save_nwb(labels, "predictions.nwb")
sio.save_file(labels, "predictions.nwb")
# Or:
# labels.save("predictions.nwb")
```

### Convert labels to raw arrays
Expand All @@ -59,6 +61,18 @@ n_frames, n_tracks, n_nodes, xy_score = trx.shape
assert xy_score == 3
```

### Read video data

```py
import sleap_io as sio

video = sio.load_video("test.mp4")
n_frames, height, width, channels = video.shape

frame = video[0]
height, width, channels = frame.shape
```

### Create labels from raw data

```py
Expand All @@ -72,7 +86,7 @@ skeleton = sio.Skeleton(
)

# Create video.
video = sio.Video.from_filename("test.mp4")
video = sio.load_video("test.mp4")

# Create instance.
instance = sio.Instance.from_numpy(
Expand All @@ -89,6 +103,9 @@ lf = sio.LabeledFrame(video=video, frame_idx=0, instances=[instance])

# Create labels.
labels = sio.Labels(videos=[video], skeletons=[skeleton], labeled_frames=[lf])

# Save.
labels.save("labels.slp")
```

## Support
Expand Down
95 changes: 82 additions & 13 deletions sleap_io/model/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,6 @@ class Video:

EXTS = MediaVideo.EXTS + HDF5Video.EXTS

def __attrs_post_init__(self):
"""Set the video backend if not already set."""
if self.backend is None:
if Path(self.filename).exists():
# TODO: Automatic path resolution?
self.backend = VideoBackend.from_filename(self.filename)

@classmethod
def from_filename(
cls,
Expand Down Expand Up @@ -135,10 +128,86 @@ def __getitem__(self, inds: int | list[int] | slice) -> np.ndarray:
See also: VideoBackend.get_frame, VideoBackend.get_frames
"""
if self.backend is None:
raise ValueError(
"Video backend is not set. "
"This may be because the video reader could not be determined "
"automatically from the filename."
)
if not self.is_open:
self.open()
return self.backend[inds]

def exists(self) -> bool:
"""Check if the video file exists."""
return Path(self.filename).exists()

@property
def is_open(self) -> bool:
"""Check if the video backend is open."""
return self.exists() and self.backend is not None

def open(
self,
dataset: Optional[str] = None,
grayscale: Optional[str] = None,
keep_open: bool = True,
):
"""Open the video backend for reading.
Args:
dataset: Name of dataset in HDF5 file.
grayscale: Whether to force grayscale. If None, autodetect on first frame
load.
keep_open: Whether to keep the video reader open between calls to read
frames. If False, will close the reader after each call. If True (the
default), it will keep the reader open and cache it for subsequent calls
which may enhance the performance of reading multiple frames.
Notes:
This is useful for opening the video backend to read frames and then closing
it after reading all the necessary frames.
If the backend was already open, it will be closed before opening a new one.
Values for the HDF5 dataset and grayscale will be remembered if not
specified.
"""
if not self.exists():
raise FileNotFoundError(f"Video file not found: {self.filename}")

# Try to remember values from previous backend if available and not specified.
if self.backend is not None:
if dataset is None:
dataset = getattr(self.backend, "dataset", None)
if grayscale is None:
grayscale = getattr(self.backend, "grayscale", None)

# Close previous backend if open.
self.close()

# Create new backend.
self.backend = VideoBackend.from_filename(
self.filename,
dataset=dataset,
grayscale=grayscale,
keep_open=keep_open,
)

def close(self):
"""Close the video backend."""
if self.backend is not None:
del self.backend
self.backend = None

def replace_filename(self, new_filename: str | Path, open: bool = True):
"""Update the filename of the video, optionally opening the backend.
Args:
new_filename: New filename to set for the video.
open: If `True` (the default), open the backend with the new filename. If
the new filename does not exist, no error is raised.
"""
if isinstance(new_filename, Path):
new_filename = str(new_filename)

self.filename = new_filename

if open:
if self.exists():
self.open()
else:
self.close()
77 changes: 69 additions & 8 deletions tests/model/test_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from sleap_io import Video
from sleap_io.io.video import MediaVideo
import numpy as np
import pytest
from pathlib import Path


def test_video_class():
Expand All @@ -18,15 +20,7 @@ def test_video_from_filename(centered_pair_low_quality_path):
test_video = Video.from_filename(centered_pair_low_quality_path)
assert test_video.filename == centered_pair_low_quality_path
assert test_video.shape == (1100, 384, 384, 1)


def test_video_auto_backend(centered_pair_low_quality_path):
"""Test initialization of `Video` object with automatic backend selection."""
test_video = Video(filename=centered_pair_low_quality_path)
assert test_video.backend is not None
assert type(test_video.backend) == MediaVideo
assert test_video.filename == centered_pair_low_quality_path
assert test_video.shape == (1100, 384, 384, 1)


def test_video_getitem(centered_pair_low_quality_video):
Expand All @@ -40,3 +34,70 @@ def test_video_repr(centered_pair_low_quality_video):
'Video(filename="tests/data/videos/centered_pair_low_quality.mp4", '
"shape=(1100, 384, 384, 1), backend=MediaVideo)"
)


def test_video_exists(centered_pair_low_quality_video):
video = Video("test.mp4")
assert video.exists() is False

assert centered_pair_low_quality_video.exists() is True


def test_video_open_close(centered_pair_low_quality_path):
video = Video(centered_pair_low_quality_path)
assert video.is_open is False
assert video.backend is None

img = video[0]
assert img.shape == (384, 384, 1)
assert video.is_open is True

video = Video("test.mp4")
assert video.is_open is False
with pytest.raises(FileNotFoundError):
video[0]

video = Video.from_filename(centered_pair_low_quality_path)
assert video.is_open is True
assert type(video.backend) == MediaVideo

video.close()
assert video.is_open is False
assert video.backend is None
assert video.shape is None

video.open()
assert video.is_open is True
assert type(video.backend) == MediaVideo
assert video[0].shape == (384, 384, 1)

video = Video.from_filename(centered_pair_low_quality_path, grayscale=False)
assert video.shape == (1100, 384, 384, 3)
video.open()
assert video.shape == (1100, 384, 384, 3)
video.open(grayscale=True)
assert video.shape == (1100, 384, 384, 1)


def test_video_replace_filename(centered_pair_low_quality_path):
video = Video.from_filename("test.mp4")
assert video.exists() is False

video.replace_filename(centered_pair_low_quality_path)
assert video.exists() is True
assert video.is_open is True
assert type(video.backend) == MediaVideo

video.replace_filename(Path(centered_pair_low_quality_path))
assert video.exists() is True
assert video.is_open is True
assert type(video.backend) == MediaVideo

video.replace_filename("test.mp4")
assert video.exists() is False
assert video.is_open is False

video.replace_filename(centered_pair_low_quality_path, open=False)
assert video.exists() is True
assert video.is_open is False
assert video.backend is None

0 comments on commit c6b4e4a

Please sign in to comment.