From ec6470ab814bcddb46ccb4f821fbb019001c6947 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 3 Jul 2024 08:23:16 -0700 Subject: [PATCH] [torchcodec] fix simple decoder iteration (#59) Summary: Pull Request resolved: https://github.com/pytorch-labs/torchcodec/pull/59 The previous iterable and iterator implementation had a bug, demonstrated by the test modified in this diff. The problem: 1. We were using the `SimpleVideoDecoder` as its own iterator object by directly implementing `__iter__()` and `__next__()`. 2. In `__next__()`, we were calling a core library function, `get_next_frame()`, that returned the next frame to be decoded, and advanced the internal state of the C++ decoder. 3. But we were not *initializing* the iterator. Because of the points above, for-based iteration only worked as expected on a freshly-created `SimpleVideoDecoder` object. The simplest fix is to just remove the implementations of `__iter__()` and `__next__()`. Because it implements `__len__()` and `__getitem__()`, a `SimpleVideoDecoder` is a Python sequence. Python sequences are automatically iterable: when a class defines no `__iter__()`, `iter()` falls back to calling `__getitem__()` with 0, 1, 2, ... and stops at the first `IndexError` (`__len__()` is not consulted), so every for-loop starts again from index 0. See: https://docs.python.org/3/glossary.html#term-iterable Differential Revision: D59309882 --- .../decoders/_simple_video_decoder.py | 11 -------- test/decoders/simple_video_decoder_test.py | 28 ++++++++++++++++++- 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/src/torchcodec/decoders/_simple_video_decoder.py b/src/torchcodec/decoders/_simple_video_decoder.py index 992ab623..21ece023 100644 --- a/src/torchcodec/decoders/_simple_video_decoder.py +++ b/src/torchcodec/decoders/_simple_video_decoder.py @@ -67,17 +67,6 @@ def __getitem__(self, key: Union[int, slice]) -> torch.Tensor: f"Unsupported key type: {type(key)}. Supported types are int and slice." 
) - def __iter__(self) -> "SimpleVideoDecoder": - return self - - def __next__(self) -> torch.Tensor: - # TODO: We should distinguish between expected end-of-file and unexpected - # runtime error. - try: - return core.get_next_frame(self._decoder) - except RuntimeError: - raise StopIteration() - def _get_and_validate_stream_metadata(decoder: torch.Tensor) -> core.StreamMetadata: video_metadata = core.get_video_metadata(decoder) diff --git a/test/decoders/simple_video_decoder_test.py b/test/decoders/simple_video_decoder_test.py index bd0773ee..124732b4 100644 --- a/test/decoders/simple_video_decoder_test.py +++ b/test/decoders/simple_video_decoder_test.py @@ -196,24 +196,50 @@ def test_getitem_fails(self): with pytest.raises(TypeError, match="Unsupported key type"): frame = decoder["0"] # noqa - def test_next(self): + def test_iteration(self): decoder = SimpleVideoDecoder(str(NASA_VIDEO.path)) ref_frame0 = NASA_VIDEO.get_tensor_by_index(0) ref_frame1 = NASA_VIDEO.get_tensor_by_index(1) + ref_frame9 = NASA_VIDEO.get_tensor_by_index(9) + ref_frame35 = NASA_VIDEO.get_tensor_by_index(35) ref_frame180 = NASA_VIDEO.get_tensor_by_name("time6.000000") ref_frame_last = NASA_VIDEO.get_tensor_by_name("time12.979633") + # Access an arbitrary frame to make sure that the later iteration + # still works as expected. The underlying C++ decoder object is + # actually stateful, and accessing a frame will move its internal + # cursor. 
+ assert_tensor_equal(ref_frame35, decoder[35]) + for i, frame in enumerate(decoder): if i == 0: assert_tensor_equal(ref_frame0, frame) elif i == 1: assert_tensor_equal(ref_frame1, frame) + elif i == 9: + assert_tensor_equal(ref_frame9, frame) + elif i == 35: + assert_tensor_equal(ref_frame35, frame) elif i == 180: assert_tensor_equal(ref_frame180, frame) elif i == 389: assert_tensor_equal(ref_frame_last, frame) + def test_iteration_slow(self): + decoder = SimpleVideoDecoder(str(NASA_VIDEO.path)) + ref_frame_last = NASA_VIDEO.get_tensor_by_index(389) + + # Force the decoder to seek around a lot while iterating; this will + # slow down decoding, but we should still only iterate the exact number + # of total frames. + iterations = 0 + for frame in decoder: + assert_tensor_equal(ref_frame_last, decoder[-1]) + iterations += 1 + + assert iterations == len(decoder) == 390 + if __name__ == "__main__": pytest.main()