Skip to content

Commit

Permalink
[torchcodec] initial version of simple video decoder (#42)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #42

Initial version of `SimpleVideoDecoder`. This diff supports:
1. Creating a simple decoder from file or tensor.
2. Accessing frames through random access: `decoder[i]`. Internally, this gets a frame by index.
3. Iterating over all available frames: `for frame in decoder`. Internally, this calls the core library function to get the next frame.

I did not implement slice semantics for random access yet. I think that in order to support that best we'll need to add new capabilities in the core API to support getting a batch of frames from a range. We *could* use `get_frames_by_indices`, but we would end up creating very large lists.

This is a partial implementation of the design: https://fburl.com/gdoc/hxotv2by

Reviewed By: NicolasHug, ahmadsharif1

Differential Revision: D58605628

fbshipit-source-id: b5bda7b0d63ec341608bcf58e82a0c3c8accf923
  • Loading branch information
scotts authored and facebook-github-bot committed Jun 19, 2024
1 parent 6f2bd31 commit d4b0111
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 7 deletions.
14 changes: 7 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ repos:
- id: check-added-large-files
args: ['--maxkb=1000']

- repo: https://github.com/omnilib/ufmt
rev: v2.6.0
hooks:
- id: ufmt
additional_dependencies:
- black == 24.4.2
- usort == 1.0.5
# - repo: https://github.com/omnilib/ufmt
# rev: v2.6.0
# hooks:
# - id: ufmt
# additional_dependencies:
# - black == 24.4.2
# - usort == 1.0.5

- repo: https://github.com/PyCQA/flake8
rev: 7.1.0
Expand Down
1 change: 1 addition & 0 deletions src/torchcodec/decoders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ._simple_video_decoder import SimpleVideoDecoder # noqa
62 changes: 62 additions & 0 deletions src/torchcodec/decoders/_simple_video_decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import json
from typing import Union

import torch
from torchcodec.decoders import _core as core


class SimpleVideoDecoder:

def __init__(self, source: Union[str, bytes, torch.Tensor]):
# TODO: support Path objects.
if isinstance(source, str):
self._decoder = core.create_from_file(source)
elif isinstance(source, bytes):
self._decoder = core.create_from_bytes(source)
elif isinstance(source, torch.Tensor):
self._decoder = core.create_from_tensor(source)
else:
raise TypeError(
f"Unknown source type: {type(source)}. "
"Supported types are str, bytes and Tensor."
)

core.add_video_stream(self._decoder)

# TODO: We should either implement specific core library function to
# retrieve these values, or replace this with a non-JSON metadata
# retrieval.
metadata_json = json.loads(core.get_json_metadata(self._decoder))
self._num_frames = metadata_json["numFrames"]
self._stream_index = metadata_json["bestVideoStreamIndex"]

def __len__(self) -> int:
return self._num_frames

def __getitem__(self, key: int) -> torch.Tensor:
if not isinstance(key, int):
raise TypeError(
f"Unsupported key type: {type(key)}. Supported type is int."
)

if key < 0:
key += self._num_frames
if key >= self._num_frames or key < 0:
raise IndexError(
f"Index {key} is out of bounds; length is {self._num_frames}"
)

return core.get_frame_at_index(
self._decoder, frame_index=key, stream_index=self._stream_index
)

def __iter__(self) -> "SimpleVideoDecoder":
return self

def __next__(self) -> torch.Tensor:
# TODO: We should distinguish between expected end-of-file and unexpected
# runtime error.
try:
return core.get_next_frame(self._decoder)
except RuntimeError:
raise StopIteration()
83 changes: 83 additions & 0 deletions test/decoders/simple_video_decoder_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import pytest

from torchcodec.decoders import SimpleVideoDecoder

from ..test_utils import (
assert_equal,
get_reference_video_path,
get_reference_video_tensor,
load_tensor_from_file,
)


class TestSimpleDecoder:

def test_create_from_file(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))
assert len(decoder) == 390
assert decoder._stream_index == 3

def test_create_from_tensor(self):
decoder = SimpleVideoDecoder(get_reference_video_tensor())
assert len(decoder) == 390
assert decoder._stream_index == 3

def test_create_from_bytes(self):
path = str(get_reference_video_path())
with open(path, "rb") as f:
video_bytes = f.read()

decoder = SimpleVideoDecoder(video_bytes)
assert len(decoder) == 390
assert decoder._stream_index == 3

def test_create_fails(self):
with pytest.raises(TypeError, match="Unknown source type"):
decoder = SimpleVideoDecoder(123) # noqa

def test_getitem_int(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))

ref_frame0 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt")
ref_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt")
ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt")

assert_equal(ref_frame0, decoder[0])
assert_equal(ref_frame1, decoder[1])
assert_equal(ref_frame180, decoder[180])
assert_equal(ref_frame_last, decoder[-1])

def test_getitem_fails(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))

with pytest.raises(IndexError, match="out of bounds"):
frame = decoder[1000] # noqa

with pytest.raises(IndexError, match="out of bounds"):
frame = decoder[-1000] # noqa

with pytest.raises(TypeError, match="Unsupported key type"):
frame = decoder["0"] # noqa

def test_next(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))

ref_frame0 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt")
ref_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt")
ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt")

for i, frame in enumerate(decoder):
if i == 0:
assert_equal(ref_frame0, frame)
elif i == 1:
assert_equal(ref_frame1, frame)
elif i == 180:
assert_equal(ref_frame180, frame)
elif i == 389:
assert_equal(ref_frame_last, frame)


if __name__ == "__main__":
pytest.main()

0 comments on commit d4b0111

Please sign in to comment.