Skip to content

Commit

Permalink
[torchcodec] fix simple decoder iteration (#59)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #59

The previous iterable and iterator implementation had a bug, demonstrated by the test modified in this diff. The problem:

1. We were using the `SimpleVideoDecoder` as its own iterator object by directly implementing `__iter__()` and `__next__()`.
2. In `__next__()`, we were calling a core library function, `get_next_frame()`, that returned the next frame to be decoded, and advanced the internal state of the C++ decoder.
3. But we were not *initializing* the iterator.

Because of the points above, for-based iteration only worked as expected on a freshly-created `SimpleVideoDecoder` object.

The simplest fix is to just remove the implementations of `__iter__()` and `__next__()`. Because it implements `__len__()` and `__getitem__()`, a `SimpleVideoDecoder` is a Python sequence. Python sequences are automatically iterable through `__len__()` and `__getitem__()`. See: https://docs.python.org/3/glossary.html#term-iterable

Differential Revision: D59309882
  • Loading branch information
scotts authored and facebook-github-bot committed Jul 3, 2024
1 parent 7af94c7 commit ec6470a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
11 changes: 0 additions & 11 deletions src/torchcodec/decoders/_simple_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,17 +67,6 @@ def __getitem__(self, key: Union[int, slice]) -> torch.Tensor:
f"Unsupported key type: {type(key)}. Supported types are int and slice."
)

def __iter__(self) -> "SimpleVideoDecoder":
return self

def __next__(self) -> torch.Tensor:
# TODO: We should distinguish between expected end-of-file and unexpected
# runtime error.
try:
return core.get_next_frame(self._decoder)
except RuntimeError:
raise StopIteration()


def _get_and_validate_stream_metadata(decoder: torch.Tensor) -> core.StreamMetadata:
video_metadata = core.get_video_metadata(decoder)
Expand Down
28 changes: 27 additions & 1 deletion test/decoders/simple_video_decoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,24 +196,50 @@ def test_getitem_fails(self):
with pytest.raises(TypeError, match="Unsupported key type"):
frame = decoder["0"] # noqa

def test_next(self):
def test_iteration(self):
decoder = SimpleVideoDecoder(str(NASA_VIDEO.path))

ref_frame0 = NASA_VIDEO.get_tensor_by_index(0)
ref_frame1 = NASA_VIDEO.get_tensor_by_index(1)
ref_frame9 = NASA_VIDEO.get_tensor_by_index(9)
ref_frame35 = NASA_VIDEO.get_tensor_by_index(35)
ref_frame180 = NASA_VIDEO.get_tensor_by_name("time6.000000")
ref_frame_last = NASA_VIDEO.get_tensor_by_name("time12.979633")

# Access an arbitrary frame to make sure that the later iteration
# still works as expected. The underlying C++ decoder object is
# actually stateful, and accessing a frame will move its internal
# cursor.
assert_tensor_equal(ref_frame35, decoder[35])

for i, frame in enumerate(decoder):
if i == 0:
assert_tensor_equal(ref_frame0, frame)
elif i == 1:
assert_tensor_equal(ref_frame1, frame)
elif i == 9:
assert_tensor_equal(ref_frame9, frame)
elif i == 35:
assert_tensor_equal(ref_frame35, frame)
elif i == 180:
assert_tensor_equal(ref_frame180, frame)
elif i == 389:
assert_tensor_equal(ref_frame_last, frame)

def test_iteration_slow(self):
decoder = SimpleVideoDecoder(str(NASA_VIDEO.path))
ref_frame_last = NASA_VIDEO.get_tensor_by_index(389)

# Force the decoder to seek around a lot while iterating; this will
# slow down decoding, but we should still only iterate the exact number
# of total frames.
iterations = 0
for frame in decoder:
assert_tensor_equal(ref_frame_last, decoder[-1])
iterations += 1

assert iterations == len(decoder) == 390


if __name__ == "__main__":
pytest.main()

0 comments on commit ec6470a

Please sign in to comment.