diff --git a/src/torchcodec/decoders/__init__.py b/src/torchcodec/decoders/__init__.py new file mode 100644 index 00000000..8a8adca4 --- /dev/null +++ b/src/torchcodec/decoders/__init__.py @@ -0,0 +1 @@ +from ._simple_video_decoder import SimpleVideoDecoder # noqa diff --git a/src/torchcodec/decoders/_simple_video_decoder.py b/src/torchcodec/decoders/_simple_video_decoder.py new file mode 100644 index 00000000..f37df591 --- /dev/null +++ b/src/torchcodec/decoders/_simple_video_decoder.py @@ -0,0 +1,61 @@ +import json +from typing import Union + +import torch +from torchcodec.decoders import _core as core + + +class SimpleVideoDecoder: + + def __init__(self, source: Union[str, bytes, torch.Tensor]): + # TODO: support Path objects. + if isinstance(source, str): + self._decoder = core.create_from_file(source) + elif isinstance(source, bytes): + self._decoder = core.create_from_bytes(source) + elif isinstance(source, torch.Tensor): + self._decoder = core.create_from_tensor(source) + else: + raise TypeError( + f"Unknown source type: {type(source)}. Supported types are str, bytes and Tensor." + ) + + core.add_video_stream(self._decoder) + + # TODO: We should either implement specific core library function to + # retrieve these values, or replace this with a non-JSON metadata + # retrieval. + metadata_json = json.loads(core.get_json_metadata(self._decoder)) + self._num_frames = metadata_json["numFrames"] + self._stream_index = metadata_json["bestVideoStreamIndex"] + + def __len__(self) -> int: + return self._num_frames + + def __getitem__(self, key: int) -> torch.Tensor: + if not isinstance(key, int): + raise TypeError( + f"Unsupported key type: {type(key)}. Supported type is int." + ) + + if key < 0: + key += self._num_frames + if key >= self._num_frames or key < 0: + raise IndexError( + f"Index {key} is out of bounds; length is {self._num_frames}" + ) + + return core.get_frame_at_index( + self._decoder, frame_index=key, stream_index=self._stream_index + ) + + def __iter__(self) -> "SimpleVideoDecoder": + return self + + def __next__(self) -> torch.Tensor: + # TODO: We should distinguish between expected end-of-file and unexpected + # runtime error. + try: + return core.get_next_frame(self._decoder) + except RuntimeError: + raise StopIteration() diff --git a/test/decoders/simple_video_decoder_test.py b/test/decoders/simple_video_decoder_test.py new file mode 100644 index 00000000..2ee5c3e7 --- /dev/null +++ b/test/decoders/simple_video_decoder_test.py @@ -0,0 +1,83 @@ +import pytest + +from torchcodec.decoders import SimpleVideoDecoder + +from ..test_utils import ( + assert_equal, + get_reference_video_path, + get_reference_video_tensor, + load_tensor_from_file, +) + + +class TestSimpleDecoder: + + def test_create_from_file(self): + decoder = SimpleVideoDecoder(str(get_reference_video_path())) + assert len(decoder) == 390 + assert decoder._stream_index == 3 + + def test_create_from_tensor(self): + decoder = SimpleVideoDecoder(get_reference_video_tensor()) + assert len(decoder) == 390 + assert decoder._stream_index == 3 + + def test_create_from_bytes(self): + path = str(get_reference_video_path()) + with open(path, "rb") as f: + video_bytes = f.read() + + decoder = SimpleVideoDecoder(video_bytes) + assert len(decoder) == 390 + assert decoder._stream_index == 3 + + def test_create_fails(self): + with pytest.raises(TypeError, match="Unknown source type"): + decoder = SimpleVideoDecoder(123) # noqa + + def test_getitem_int(self): + decoder = SimpleVideoDecoder(str(get_reference_video_path())) + + ref_frame0 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt") + ref_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt") + ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") + ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt") + + assert_equal(ref_frame0, decoder[0]) + assert_equal(ref_frame1, decoder[1]) + assert_equal(ref_frame180, decoder[180]) + assert_equal(ref_frame_last, decoder[-1]) + + def test_getitem_fails(self): + decoder = SimpleVideoDecoder(str(get_reference_video_path())) + + with pytest.raises(IndexError, match="out of bounds"): + frame = decoder[1000] # noqa + + with pytest.raises(IndexError, match="out of bounds"): + frame = decoder[-1000] # noqa + + with pytest.raises(TypeError, match="Unsupported key type"): + frame = decoder["0"] # noqa + + def test_next(self): + decoder = SimpleVideoDecoder(str(get_reference_video_path())) + + ref_frame0 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt") + ref_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt") + ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") + ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt") + + for i, frame in enumerate(decoder): + if i == 0: + assert_equal(ref_frame0, frame) + elif i == 1: + assert_equal(ref_frame1, frame) + elif i == 180: + assert_equal(ref_frame180, frame) + elif i == 389: + assert_equal(ref_frame_last, frame) + + +if __name__ == "__main__": + pytest.main()