Skip to content

Commit

Permalink
Add video writing (#129)
Browse files Browse the repository at this point in the history
* Refactor io.video -> io.video_reading

* Add VideoWriter

* Fix valid frame index testing

* Add pixformat to writer
  • Loading branch information
talmo authored Oct 4, 2024
1 parent 2aa891d commit 0990f1a
Show file tree
Hide file tree
Showing 13 changed files with 219 additions and 8 deletions.
2 changes: 2 additions & 0 deletions docs/formats.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

::: sleap_io.load_video

::: sleap_io.save_video

::: sleap_io.load_slp

::: sleap_io.save_slp
Expand Down
3 changes: 3 additions & 0 deletions sleap_io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
load_jabs,
save_jabs,
load_video,
save_video,
load_file,
save_file,
)
from sleap_io.io.video_reading import VideoBackend
from sleap_io.io.video_writing import VideoWriter
2 changes: 2 additions & 0 deletions sleap_io/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
"""This sub-package contains I/O-related modules such as specific format backends."""

from . import video_reading as video
30 changes: 29 additions & 1 deletion sleap_io/io/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from __future__ import annotations
from sleap_io import Labels, Skeleton, Video
from sleap_io.io import slp, nwb, labelstudio, jabs
from sleap_io.io import slp, nwb, labelstudio, jabs, video_writing
from typing import Optional, Union
from pathlib import Path
import numpy as np


def load_slp(filename: str, open_videos: bool = True) -> Labels:
Expand Down Expand Up @@ -149,6 +150,33 @@ def load_video(filename: str, **kwargs) -> Video:
return Video.from_filename(filename, **kwargs)


def save_video(frames: np.ndarray | Video, filename: str | Path, **kwargs):
"""Write a list of frames to a video file.
Args:
frames: Sequence of frames to write to video. Each frame should be a 2D or 3D
numpy array with dimensions (height, width) or (height, width, channels).
filename: Path to output video file.
fps: Frames per second. Defaults to 30.
pixelformat: Pixel format for video. Defaults to "yuv420p".
codec: Codec to use for encoding. Defaults to "libx264".
crf: Constant rate factor to control lossiness of video. Values go from 2 to 32,
with numbers in the 18 to 30 range being most common. Lower values mean less
compressed/higher quality. Defaults to 25. No effect if codec is not
"libx264".
preset: H264 encoding preset. Defaults to "superfast". No effect if codec is not
"libx264".
output_params: Additional output parameters for FFMPEG. This should be a list of
strings corresponding to command line arguments for FFMPEG and libx264. Use
`ffmpeg -h encoder=libx264` to see all options for libx264 output_params.
See also: `sio.VideoWriter`
"""
with video_writing.VideoWriter(filename, **kwargs) as writer:
for frame in frames:
writer(frame)


def load_file(
filename: str | Path, format: Optional[str] = None, **kwargs
) -> Union[Labels, Video]:
Expand Down
2 changes: 1 addition & 1 deletion sleap_io/io/slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
LabeledFrame,
Labels,
)
from sleap_io.io.video import VideoBackend, ImageVideo, MediaVideo, HDF5Video
from sleap_io.io.video_reading import VideoBackend, ImageVideo, MediaVideo, HDF5Video
from sleap_io.io.utils import read_hdf5_attrs, read_hdf5_dataset, is_file_accessible
from enum import IntEnum
from pathlib import Path
Expand Down
30 changes: 29 additions & 1 deletion sleap_io/io/video.py → sleap_io/io/video_reading.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Backends for reading and writing videos."""
"""Backends for reading videos."""

from __future__ import annotations
from pathlib import Path
Expand Down Expand Up @@ -193,6 +193,17 @@ def __len__(self) -> int:
"""Return number of frames in the video."""
return self.shape[0]

def has_frame(self, frame_idx: int) -> bool:
"""Check if a frame index is contained in the video.
Args:
frame_idx: Index of frame to check.
Returns:
`True` if the index is contained in the video, otherwise `False`.
"""
return frame_idx < len(self)

def get_frame(self, frame_idx: int) -> np.ndarray:
"""Read a single frame from the video.
Expand All @@ -212,6 +223,9 @@ def get_frame(self, frame_idx: int) -> np.ndarray:
See also: `get_frames`
"""
if not self.has_frame(frame_idx):
raise IndexError(f"Frame index {frame_idx} out of range.")

img = self._read_frame(frame_idx)

if self.grayscale is None:
Expand Down Expand Up @@ -620,6 +634,20 @@ def decode_embedded(self, img_string: np.ndarray) -> np.ndarray:
img = np.expand_dims(img, axis=-1)
return img

def has_frame(self, frame_idx: int) -> bool:
"""Check if a frame index is contained in the video.
Args:
frame_idx: Index of frame to check.
Returns:
`True` if the index is contained in the video, otherwise `False`.
"""
if self.frame_map:
return frame_idx in self.frame_map
else:
return frame_idx < len(self)

def _read_frame(self, frame_idx: int) -> np.ndarray:
"""Read a single frame from the video.
Expand Down
119 changes: 119 additions & 0 deletions sleap_io/io/video_writing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""Utilities for writing videos."""

from __future__ import annotations
from typing import Type, Optional
from types import TracebackType
import numpy as np
import imageio
import imageio.v2 as iio_v2
import attrs
from pathlib import Path


@attrs.define
class VideoWriter:
"""Simple video writer using imageio and FFMPEG.
Attributes:
filename: Path to output video file.
fps: Frames per second. Defaults to 30.
pixelformat: Pixel format for video. Defaults to "yuv420p".
codec: Codec to use for encoding. Defaults to "libx264".
crf: Constant rate factor to control lossiness of video. Values go from 2 to 32,
with numbers in the 18 to 30 range being most common. Lower values mean less
compressed/higher quality. Defaults to 25. No effect if codec is not
"libx264".
preset: H264 encoding preset. Defaults to "superfast". No effect if codec is not
"libx264".
output_params: Additional output parameters for FFMPEG. This should be a list of
strings corresponding to command line arguments for FFMPEG and libx264. Use
`ffmpeg -h encoder=libx264` to see all options for libx264 output_params.
Notes:
This class can be used as a context manager to ensure the video is properly
closed after writing. For example:
```python
with VideoWriter("output.mp4") as writer:
for frame in frames:
writer(frame)
```
"""

filename: Path = attrs.field(converter=Path)
fps: float = 30
pixelformat: str = "yuv420p"
codec: str = "libx264"
crf: int = 25
preset: str = "superfast"
output_params: list[str] = attrs.field(factory=list)
_writer: "imageio.plugins.ffmpeg.FfmpegFormat.Writer" | None = None

def build_output_params(self) -> list[str]:
"""Build the output parameters for FFMPEG."""
output_params = []
if self.codec == "libx264":
output_params.extend(
[
"-crf",
str(self.crf),
"-preset",
self.preset,
]
)
return output_params + self.output_params

def open(self):
"""Open the video writer."""
self.close()

self.filename.parent.mkdir(parents=True, exist_ok=True)
self._writer = iio_v2.get_writer(
self.filename.as_posix(),
format="FFMPEG",
fps=self.fps,
codec=self.codec,
pixelformat=self.pixelformat,
output_params=self.build_output_params(),
)

def close(self):
"""Close the video writer."""
if self._writer is not None:
self._writer.close()
self._writer = None

def write_frame(self, frame: np.ndarray):
"""Write a frame to the video.
Args:
frame: Frame to write to video. Should be a 2D or 3D numpy array with
dimensions (height, width) or (height, width, channels).
"""
if self._writer is None:
self.open()

self._writer.append_data(frame)

def __enter__(self):
"""Context manager entry."""
return self

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> Optional[bool]:
"""Context manager exit."""
self.close()
return False

def __call__(self, frame: np.ndarray):
"""Write a frame to the video.
Args:
frame: Frame to write to video. Should be a 2D or 3D numpy array with
dimensions (height, width) or (height, width, channels).
"""
self.write_frame(frame)
2 changes: 1 addition & 1 deletion sleap_io/model/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import attrs
from typing import Tuple, Optional, Optional
import numpy as np
from sleap_io.io.video import VideoBackend, MediaVideo, HDF5Video, ImageVideo
from sleap_io.io.video_reading import VideoBackend, MediaVideo, HDF5Video, ImageVideo
from sleap_io.io.utils import is_file_accessible
from pathlib import Path

Expand Down
11 changes: 11 additions & 0 deletions tests/io/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
load_jabs,
save_jabs,
load_video,
save_video,
load_file,
save_file,
)
Expand Down Expand Up @@ -104,3 +105,13 @@ def test_load_save_file_invalid():

with pytest.raises(ValueError):
save_file(Labels(), "invalid_file.ext")


def test_save_video(centered_pair_low_quality_video, tmp_path):
imgs = centered_pair_low_quality_video[:4]
save_video(imgs, tmp_path / "output.mp4")
vid = load_video(tmp_path / "output.mp4")
assert vid.shape == (4, 384, 384, 1)
save_video(vid, tmp_path / "output2.mp4")
vid2 = load_video(tmp_path / "output2.mp4")
assert vid2.shape == (4, 384, 384, 1)
2 changes: 1 addition & 1 deletion tests/io/test_slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
import pytest
from pathlib import Path
import shutil
from sleap_io.io.video import ImageVideo, HDF5Video, MediaVideo
from sleap_io.io.video_reading import ImageVideo, HDF5Video, MediaVideo
import sys


Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Tests for methods in the sleap_io.io.video file."""
"""Tests for methods in the sleap_io.io.video_reading file."""

from sleap_io.io.video import VideoBackend, MediaVideo, HDF5Video, ImageVideo
from sleap_io.io.video_reading import VideoBackend, MediaVideo, HDF5Video, ImageVideo
import numpy as np
from numpy.testing import assert_equal
import h5py
Expand Down Expand Up @@ -56,6 +56,9 @@ def test_get_frame(centered_pair_low_quality_path):
assert_equal(backend[-3:], backend.get_frames(range(1097, 1100)))
assert_equal(backend[-3:-1], backend.get_frames(range(1097, 1099)))

with pytest.raises(IndexError):
backend.get_frame(1100)


@pytest.mark.parametrize("keep_open", [False, True])
def test_mediavideo(centered_pair_low_quality_path, keep_open):
Expand Down
15 changes: 15 additions & 0 deletions tests/io/test_video_writing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Tests for the sleap_io.io.video_writing module."""

import sleap_io as sio
from sleap_io.io.video_writing import VideoWriter


def test_video_writer(centered_pair_low_quality_video, tmp_path):
imgs = centered_pair_low_quality_video[:4]
with VideoWriter(tmp_path / "output.mp4") as writer:
for img in imgs:
writer.write_frame(img)

assert (tmp_path / "output.mp4").exists()
vid = sio.load_video(tmp_path / "output.mp4")
assert vid.shape == (4, 384, 384, 1)
2 changes: 1 addition & 1 deletion tests/model/test_video.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Tests for methods in the sleap_io.model.video file."""

from sleap_io import Video
from sleap_io.io.video import MediaVideo, ImageVideo
from sleap_io.io.video_reading import MediaVideo, ImageVideo
import numpy as np
import pytest
from pathlib import Path
Expand Down

0 comments on commit 0990f1a

Please sign in to comment.