-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[torchcodec] initial version of simple video decoder (#42)
Summary: Pull Request resolved: #42 Initial version of `SimpleVideoDecoder`. This diff supports: 1. Creating a simple decoder from file or tensor. 2. Accessing frames through random access: `decoder[i]`. Internally, this gets a frame by index. 3. Iterating over all available frames: `for frame in decoder`. Internally, this calls the core library function to get the next frame. I did not implement slice semantics for random access yet. I think that in order to support that best we'll need to add new capabilities in the core API to support getting a batch of frames from a range. We *could* use `get_frames_by_indices`, but we would end up creating very large lists. This is a partial implementation of the design: https://fburl.com/gdoc/hxotv2by Reviewed By: NicolasHug, ahmadsharif1 Differential Revision: D58605628 fbshipit-source-id: b5bda7b0d63ec341608bcf58e82a0c3c8accf923
- Loading branch information
1 parent
6f2bd31
commit d4b0111
Showing
4 changed files
with
153 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from ._simple_video_decoder import SimpleVideoDecoder # noqa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import json | ||
from typing import Union | ||
|
||
import torch | ||
from torchcodec.decoders import _core as core | ||
|
||
|
||
class SimpleVideoDecoder:
    """Simple frame-level video decoder built on the torchcodec core library.

    Frames can be fetched by integer index (``decoder[i]``; negative indices
    count from the end) or consumed sequentially by iterating over the
    decoder (``for frame in decoder``).
    """

    def __init__(self, source: Union[str, bytes, torch.Tensor]):
        """Create the underlying core decoder and read its metadata.

        Args:
            source: The video to decode. A ``str`` is handed to
                ``core.create_from_file``, ``bytes`` to
                ``core.create_from_bytes`` and a ``torch.Tensor`` to
                ``core.create_from_tensor``.

        Raises:
            TypeError: If ``source`` is not a str, bytes or Tensor.
        """
        # TODO: support Path objects.
        if isinstance(source, str):
            self._decoder = core.create_from_file(source)
        elif isinstance(source, bytes):
            self._decoder = core.create_from_bytes(source)
        elif isinstance(source, torch.Tensor):
            self._decoder = core.create_from_tensor(source)
        else:
            raise TypeError(
                f"Unknown source type: {type(source)}. "
                "Supported types are str, bytes and Tensor."
            )

        # Called with the core library's defaults; the metadata is read only
        # after the stream has been added.
        core.add_video_stream(self._decoder)

        # TODO: We should either implement a specific core library function to
        # retrieve these values, or replace this with a non-JSON metadata
        # retrieval.
        metadata_json = json.loads(core.get_json_metadata(self._decoder))
        self._num_frames = metadata_json["numFrames"]
        self._stream_index = metadata_json["bestVideoStreamIndex"]

    def __len__(self) -> int:
        """Return the number of frames reported by the decoder metadata."""
        return self._num_frames

    def __getitem__(self, key: int) -> torch.Tensor:
        """Return the frame at index ``key``.

        Args:
            key: Frame index. Negative values count from the end, as with
                Python sequences.

        Raises:
            TypeError: If ``key`` is not an int (slices are not supported).
            IndexError: If ``key`` falls outside ``[-len(self), len(self))``.
        """
        if not isinstance(key, int):
            raise TypeError(
                f"Unsupported key type: {type(key)}. Supported type is int."
            )

        # Normalize into a separate variable so that the error message below
        # reports the index the caller actually passed, not the shifted one.
        index = key + self._num_frames if key < 0 else key
        if not 0 <= index < self._num_frames:
            raise IndexError(
                f"Index {key} is out of bounds; length is {self._num_frames}"
            )

        return core.get_frame_at_index(
            self._decoder, frame_index=index, stream_index=self._stream_index
        )

    def __iter__(self) -> "SimpleVideoDecoder":
        """Return the decoder itself; it acts as its own iterator."""
        # NOTE(review): since no fresh iterator object is created, a second
        # loop over the same decoder presumably resumes from the core
        # decoder's current position instead of restarting - confirm.
        return self

    def __next__(self) -> torch.Tensor:
        """Return the next frame, raising StopIteration when decoding stops."""
        # TODO: We should distinguish between expected end-of-file and unexpected
        # runtime error.
        try:
            return core.get_next_frame(self._decoder)
        except RuntimeError:
            # Any RuntimeError from the core library ends the iteration.
            raise StopIteration()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import pytest | ||
|
||
from torchcodec.decoders import SimpleVideoDecoder | ||
|
||
from ..test_utils import ( | ||
assert_equal, | ||
get_reference_video_path, | ||
get_reference_video_tensor, | ||
load_tensor_from_file, | ||
) | ||
|
||
|
||
class TestSimpleDecoder:
    """Tests for SimpleVideoDecoder against the nasa_13013.mp4 reference video.

    The assertions below expect that video to expose 390 frames and a best
    video stream index of 3.
    """

    def test_create_from_file(self):
        """A decoder can be created from a file path given as a str."""
        decoder = SimpleVideoDecoder(str(get_reference_video_path()))
        assert len(decoder) == 390
        assert decoder._stream_index == 3

    def test_create_from_tensor(self):
        """A decoder can be created from the reference video tensor."""
        decoder = SimpleVideoDecoder(get_reference_video_tensor())
        assert len(decoder) == 390
        assert decoder._stream_index == 3

    def test_create_from_bytes(self):
        """A decoder can be created from the raw bytes of the video file."""
        path = str(get_reference_video_path())
        with open(path, "rb") as f:
            video_bytes = f.read()

        decoder = SimpleVideoDecoder(video_bytes)
        assert len(decoder) == 390
        assert decoder._stream_index == 3

    def test_create_fails(self):
        """An unsupported source type is rejected with a TypeError."""
        with pytest.raises(TypeError, match="Unknown source type"):
            decoder = SimpleVideoDecoder(123)  # noqa

    def test_getitem_int(self):
        """Random access by int index returns the expected reference frames."""
        decoder = SimpleVideoDecoder(str(get_reference_video_path()))

        ref_frame0 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt")
        ref_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt")
        ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
        ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt")

        assert_equal(ref_frame0, decoder[0])
        assert_equal(ref_frame1, decoder[1])
        assert_equal(ref_frame180, decoder[180])
        # Negative indices count from the end.
        assert_equal(ref_frame_last, decoder[-1])

    def test_getitem_fails(self):
        """Out-of-range indices and non-int keys raise the right errors."""
        decoder = SimpleVideoDecoder(str(get_reference_video_path()))

        with pytest.raises(IndexError, match="out of bounds"):
            frame = decoder[1000]  # noqa

        with pytest.raises(IndexError, match="out of bounds"):
            frame = decoder[-1000]  # noqa

        with pytest.raises(TypeError, match="Unsupported key type"):
            frame = decoder["0"]  # noqa

    def test_next(self):
        """Iterating yields every frame, matching the reference frames."""
        decoder = SimpleVideoDecoder(str(get_reference_video_path()))

        ref_frame0 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt")
        ref_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt")
        ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
        ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt")

        num_frames_seen = 0
        for i, frame in enumerate(decoder):
            if i == 0:
                assert_equal(ref_frame0, frame)
            elif i == 1:
                assert_equal(ref_frame1, frame)
            elif i == 180:
                assert_equal(ref_frame180, frame)
            elif i == 389:
                assert_equal(ref_frame_last, frame)
            num_frames_seen += 1

        # The per-index checks above are skipped silently if iteration ends
        # early, so pin the total number of frames produced as well.
        assert num_frames_seen == 390
|
||
|
||
# Allow running this test module directly as a script; pytest.main() is
# invoked without explicit arguments.
if __name__ == "__main__":
    pytest.main()