Skip to content

Commit

Permalink
Update high level
Browse files Browse the repository at this point in the history
  • Loading branch information
talmo committed May 1, 2024
1 parent 1d2a9ba commit cb0c07e
Showing 1 changed file with 66 additions and 9 deletions.
75 changes: 66 additions & 9 deletions sleap_io/io/video.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Backends for reading and writing videos."""

from __future__ import annotations
from pathlib import Path

import simplejson as json
import sys
Expand Down Expand Up @@ -42,15 +43,15 @@ class VideoBackend:
constructor to create a backend instance.
Attributes:
filename: Path to video file.
filename: Path to video file(s).
grayscale: Whether to force grayscale. If None, autodetect on first frame load.
keep_open: Whether to keep the video reader open between calls to read frames.
If False, will close the reader after each call. If True (the default), it
will keep the reader open and cache it for subsequent calls which may
enhance the performance of reading multiple frames.
"""

filename: str
filename: str | list[str]
grayscale: Optional[bool] = None
keep_open: bool = True
_cached_shape: Optional[Tuple[int, int, int, int]] = None
Expand All @@ -59,7 +60,7 @@ class VideoBackend:
@classmethod
def from_filename(
cls,
filename: str,
filename: str | list[str],
dataset: Optional[str] = None,
grayscale: Optional[bool] = None,
keep_open: bool = True,
Expand All @@ -68,7 +69,7 @@ def from_filename(
"""Create a VideoBackend from a filename.
Args:
filename: Path to video file.
filename: Path to video file(s).
dataset: Name of dataset in HDF5 file.
grayscale: Whether to force grayscale. If None, autodetect on first frame
load.
Expand All @@ -80,10 +81,21 @@ def from_filename(
Returns:
VideoBackend subclass instance.
"""
if type(filename) != str:
if isinstance(filename, Path):
filename = str(filename)

if filename.endswith(MediaVideo.EXTS):
if type(filename) == str and Path(filename).is_dir():
filename = ImageVideo.find_images(filename)

if type(filename) == list:
return ImageVideo(
filename, grayscale=grayscale, **_get_valid_kwargs(ImageVideo, kwargs)
)
elif filename.endswith(ImageVideo.EXTS):
return ImageVideo(
[filename], grayscale=grayscale, **_get_valid_kwargs(ImageVideo, kwargs)
)
elif filename.endswith(MediaVideo.EXTS):
return MediaVideo(
filename,
grayscale=grayscale,
Expand All @@ -106,8 +118,8 @@ def _read_frame(self, frame_idx: int) -> np.ndarray:
raise NotImplementedError

def _read_frames(self, frame_inds: list) -> np.ndarray:
"""Read a list of frames from the video. Must be implemented in subclasses."""
return np.stack([self._read_frame(i) for i in frame_inds], axis=0)
"""Read a list of frames from the video."""
return np.stack([self.get_frame(i) for i in frame_inds], axis=0)

def read_test_frame(self) -> np.ndarray:
"""Read a single frame from the video to test for grayscale.
Expand Down Expand Up @@ -146,7 +158,7 @@ def num_frames(self) -> int:

@property
def img_shape(self) -> Tuple[int, int, int]:
"""Shape of a single frame in the video. Must be implemented in subclasses."""
"""Shape of a single frame in the video."""
return self.get_frame(0).shape

@property
Expand Down Expand Up @@ -668,3 +680,48 @@ def _read_frames(self, frame_inds: list) -> np.ndarray:
f.close()

return imgs


@attrs.define
class ImageVideo(VideoBackend):
    """Video backend for reading videos stored as image files.

    This backend treats an ordered list of image files as a video, with one
    image per frame.

    Attributes:
        filename: List of paths to the image files, one per frame, in frame order.
        grayscale: Whether to force grayscale. If None, autodetect on first frame load.
    """

    # Deliberately left unannotated so that it stays a plain class-level constant
    # rather than being collected as an instance field by `attrs.define`.
    EXTS = ("png", "jpg", "jpeg", "tif", "tiff", "bmp")

    @staticmethod
    def find_images(folder: str) -> list[str]:
        """Find images in a folder and return a sorted list of filenames.

        Args:
            folder: Path to the folder to search. Only direct children of the
                folder are considered (no recursion).

        Returns:
            Lexicographically sorted list of paths (as strings) to the files in
            `folder` whose extension is one of `ImageVideo.EXTS`. Extensions are
            matched case-insensitively so that files such as `frame.PNG` or
            `IMG_0001.JPG` are not silently skipped.
        """
        folder = Path(folder)
        return sorted(
            str(path)
            for path in folder.glob("*")
            # Skip sub-directories that merely look like images (e.g. "thumbs.png/").
            if path.is_file() and path.suffix[1:].lower() in ImageVideo.EXTS
        )

    @property
    def num_frames(self) -> int:
        """Number of frames in the video (one frame per image file)."""
        return len(self.filename)

    def _read_frame(self, frame_idx: int) -> np.ndarray:
        """Read a single frame from the video.

        Args:
            frame_idx: Index of frame to read. This indexes into the list of image
                filenames.

        Returns:
            The frame as a numpy array of shape `(height, width, channels)`.

        Notes:
            This does not apply grayscale conversion. It is recommended to use the
            `get_frame` method of the `VideoBackend` class instead.
        """
        img = iio.imread(self.filename[frame_idx])
        if img.ndim == 2:
            # Single-channel images are read without a channel axis; add a trailing
            # one so every frame is 3-dimensional.
            img = np.expand_dims(img, axis=-1)
        return img

0 comments on commit cb0c07e

Please sign in to comment.