From d4b011184e37cf5a473414d23fe28d3287549364 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 18 Jun 2024 17:01:35 -0700 Subject: [PATCH] [torchcodec] initial version of simple video decoder (#42) Summary: Pull Request resolved: https://github.com/pytorch-labs/torchcodec/pull/42 Initial version of `SimpleVideoDecoder`. This diff supports: 1. Creating a simple decoder from file or tensor. 2. Accessing frames through random access: `decoder[i]`. Internally, this gets a frame by index. 3. Iterating over all available frames: `for frame in decoder`. Internally, this calls the core library function to get the next frame. I did not implement slice semantics for random access yet. I think that in order to support that best we'll need to add new capabilities in the core API to support getting a batch of frames from a range. We *could* use `get_frames_by_indices`, but we would end up creating very large lists. This is a partial implementation of the design: https://fburl.com/gdoc/hxotv2by Reviewed By: NicolasHug, ahmadsharif1 Differential Revision: D58605628 fbshipit-source-id: b5bda7b0d63ec341608bcf58e82a0c3c8accf923 --- .pre-commit-config.yaml | 14 ++-- src/torchcodec/decoders/__init__.py | 1 + .../decoders/_simple_video_decoder.py | 62 ++++++++++++++ test/decoders/simple_video_decoder_test.py | 83 +++++++++++++++++++ 4 files changed, 153 insertions(+), 7 deletions(-) create mode 100644 src/torchcodec/decoders/__init__.py create mode 100644 src/torchcodec/decoders/_simple_video_decoder.py create mode 100644 test/decoders/simple_video_decoder_test.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d381a5a3..100eea43 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,13 +14,13 @@ repos: - id: check-added-large-files args: ['--maxkb=1000'] - - repo: https://github.com/omnilib/ufmt - rev: v2.6.0 - hooks: - - id: ufmt - additional_dependencies: - - black == 24.4.2 - - usort == 1.0.5 + # - repo: https://github.com/omnilib/ufmt + # rev: v2.6.0 + # hooks: + # - id: ufmt + # additional_dependencies: + # - black == 24.4.2 + # - usort == 1.0.5 - repo: https://github.com/PyCQA/flake8 rev: 7.1.0 diff --git a/src/torchcodec/decoders/__init__.py b/src/torchcodec/decoders/__init__.py new file mode 100644 index 00000000..8a8adca4 --- /dev/null +++ b/src/torchcodec/decoders/__init__.py @@ -0,0 +1 @@ +from ._simple_video_decoder import SimpleVideoDecoder # noqa diff --git a/src/torchcodec/decoders/_simple_video_decoder.py b/src/torchcodec/decoders/_simple_video_decoder.py new file mode 100644 index 00000000..d628b1c3 --- /dev/null +++ b/src/torchcodec/decoders/_simple_video_decoder.py @@ -0,0 +1,62 @@ +import json +from typing import Union + +import torch +from torchcodec.decoders import _core as core + + +class SimpleVideoDecoder: + + def __init__(self, source: Union[str, bytes, torch.Tensor]): + # TODO: support Path objects. + if isinstance(source, str): + self._decoder = core.create_from_file(source) + elif isinstance(source, bytes): + self._decoder = core.create_from_bytes(source) + elif isinstance(source, torch.Tensor): + self._decoder = core.create_from_tensor(source) + else: + raise TypeError( + f"Unknown source type: {type(source)}. " + "Supported types are str, bytes and Tensor." + ) + + core.add_video_stream(self._decoder) + + # TODO: We should either implement specific core library function to + # retrieve these values, or replace this with a non-JSON metadata + # retrieval. + metadata_json = json.loads(core.get_json_metadata(self._decoder)) + self._num_frames = metadata_json["numFrames"] + self._stream_index = metadata_json["bestVideoStreamIndex"] + + def __len__(self) -> int: + return self._num_frames + + def __getitem__(self, key: int) -> torch.Tensor: + if not isinstance(key, int): + raise TypeError( + f"Unsupported key type: {type(key)}. Supported type is int." + ) + + if key < 0: + key += self._num_frames + if key >= self._num_frames or key < 0: + raise IndexError( + f"Index {key} is out of bounds; length is {self._num_frames}" + ) + + return core.get_frame_at_index( + self._decoder, frame_index=key, stream_index=self._stream_index + ) + + def __iter__(self) -> "SimpleVideoDecoder": + return self + + def __next__(self) -> torch.Tensor: + # TODO: We should distinguish between expected end-of-file and unexpected + # runtime error. + try: + return core.get_next_frame(self._decoder) + except RuntimeError: + raise StopIteration() diff --git a/test/decoders/simple_video_decoder_test.py b/test/decoders/simple_video_decoder_test.py new file mode 100644 index 00000000..2ee5c3e7 --- /dev/null +++ b/test/decoders/simple_video_decoder_test.py @@ -0,0 +1,83 @@ +import pytest + +from torchcodec.decoders import SimpleVideoDecoder + +from ..test_utils import ( + assert_equal, + get_reference_video_path, + get_reference_video_tensor, + load_tensor_from_file, +) + + +class TestSimpleDecoder: + + def test_create_from_file(self): + decoder = SimpleVideoDecoder(str(get_reference_video_path())) + assert len(decoder) == 390 + assert decoder._stream_index == 3 + + def test_create_from_tensor(self): + decoder = SimpleVideoDecoder(get_reference_video_tensor()) + assert len(decoder) == 390 + assert decoder._stream_index == 3 + + def test_create_from_bytes(self): + path = str(get_reference_video_path()) + with open(path, "rb") as f: + video_bytes = f.read() + + decoder = SimpleVideoDecoder(video_bytes) + assert len(decoder) == 390 + assert decoder._stream_index == 3 + + def test_create_fails(self): + with pytest.raises(TypeError, match="Unknown source type"): + decoder = SimpleVideoDecoder(123) # noqa + + def test_getitem_int(self): + decoder = SimpleVideoDecoder(str(get_reference_video_path())) + + ref_frame0 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt") + ref_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt") + ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") + ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt") + + assert_equal(ref_frame0, decoder[0]) + assert_equal(ref_frame1, decoder[1]) + assert_equal(ref_frame180, decoder[180]) + assert_equal(ref_frame_last, decoder[-1]) + + def test_getitem_fails(self): + decoder = SimpleVideoDecoder(str(get_reference_video_path())) + + with pytest.raises(IndexError, match="out of bounds"): + frame = decoder[1000] # noqa + + with pytest.raises(IndexError, match="out of bounds"): + frame = decoder[-1000] # noqa + + with pytest.raises(TypeError, match="Unsupported key type"): + frame = decoder["0"] # noqa + + def test_next(self): + decoder = SimpleVideoDecoder(str(get_reference_video_path())) + + ref_frame0 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt") + ref_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt") + ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") + ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt") + + for i, frame in enumerate(decoder): + if i == 0: + assert_equal(ref_frame0, frame) + elif i == 1: + assert_equal(ref_frame1, frame) + elif i == 180: + assert_equal(ref_frame180, frame) + elif i == 389: + assert_equal(ref_frame_last, frame) + + +if __name__ == "__main__": + pytest.main()