Skip to content

Commit

Permalink
[torchcodec] simple video decoder constructor now accepts path objects (
Browse files Browse the repository at this point in the history
#60)

Summary:
Pull Request resolved: #60

Addresses TODO in code. `SimpleVideoDecoder` now accepts a `pathlib.Path` object directly. We still have to call a string constructor on it, as the underlying core function needs to take a string, but this is a usability improvement.

Reviewed By: ahmadsharif1

Differential Revision: D59325727

fbshipit-source-id: 2d2e1b8414046fb48d5d2d6631b0abda56a8b86b
  • Loading branch information
scotts authored and facebook-github-bot committed Jul 3, 2024
1 parent a388b06 commit 8f1c911
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
6 changes: 4 additions & 2 deletions src/torchcodec/decoders/_simple_video_decoder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path
from typing import Union

import torch
Expand All @@ -9,17 +10,18 @@ class SimpleVideoDecoder:
"""TODO: Add docstring."""

def __init__(self, source: Union[str, bytes, torch.Tensor]):
# TODO: support Path objects.
if isinstance(source, str):
self._decoder = core.create_from_file(source)
elif isinstance(source, Path):
self._decoder = core.create_from_file(str(source))
elif isinstance(source, bytes):
self._decoder = core.create_from_bytes(source)
elif isinstance(source, torch.Tensor):
self._decoder = core.create_from_tensor(source)
else:
raise TypeError(
f"Unknown source type: {type(source)}. "
"Supported types are str, bytes and Tensor."
"Supported types are str, Path, bytes and Tensor."
)

core.add_video_stream(self._decoder)
Expand Down
14 changes: 8 additions & 6 deletions test/decoders/simple_video_decoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@


class TestSimpleDecoder:
@pytest.mark.parametrize("source_kind", ("path", "tensor", "bytes"))
@pytest.mark.parametrize("source_kind", ("str", "path", "tensor", "bytes"))
def test_create(self, source_kind):
if source_kind == "path":
if source_kind == "str":
source = str(NASA_VIDEO.path)
elif source_kind == "path":
source = NASA_VIDEO.path
elif source_kind == "tensor":
source = NASA_VIDEO.to_tensor()
elif source_kind == "bytes":
Expand All @@ -35,7 +37,7 @@ def test_create_fails(self):
decoder = SimpleVideoDecoder(123) # noqa

def test_getitem_int(self):
decoder = SimpleVideoDecoder(str(NASA_VIDEO.path))
decoder = SimpleVideoDecoder(NASA_VIDEO.path)

ref_frame0 = NASA_VIDEO.get_tensor_by_index(0)
ref_frame1 = NASA_VIDEO.get_tensor_by_index(1)
Expand All @@ -48,7 +50,7 @@ def test_getitem_int(self):
assert_tensor_equal(ref_frame_last, decoder[-1])

def test_getitem_slice(self):
decoder = SimpleVideoDecoder(str(NASA_VIDEO.path))
decoder = SimpleVideoDecoder(NASA_VIDEO.path)

# ensure that the degenerate case of a range of size 1 works

Expand Down Expand Up @@ -185,7 +187,7 @@ def test_getitem_slice(self):
assert_tensor_equal(sliced, ref)

def test_getitem_fails(self):
decoder = SimpleVideoDecoder(str(NASA_VIDEO.path))
decoder = SimpleVideoDecoder(NASA_VIDEO.path)

with pytest.raises(IndexError, match="out of bounds"):
frame = decoder[1000] # noqa
Expand All @@ -197,7 +199,7 @@ def test_getitem_fails(self):
frame = decoder["0"] # noqa

def test_iteration(self):
decoder = SimpleVideoDecoder(str(NASA_VIDEO.path))
decoder = SimpleVideoDecoder(NASA_VIDEO.path)

ref_frame0 = NASA_VIDEO.get_tensor_by_index(0)
ref_frame1 = NASA_VIDEO.get_tensor_by_index(1)
Expand Down

0 comments on commit 8f1c911

Please sign in to comment.