diff --git a/src/torchcodec/decoders/_simple_video_decoder.py b/src/torchcodec/decoders/_simple_video_decoder.py index 992ab623..21ece023 100644 --- a/src/torchcodec/decoders/_simple_video_decoder.py +++ b/src/torchcodec/decoders/_simple_video_decoder.py @@ -67,17 +67,6 @@ def __getitem__(self, key: Union[int, slice]) -> torch.Tensor: f"Unsupported key type: {type(key)}. Supported types are int and slice." ) - def __iter__(self) -> "SimpleVideoDecoder": - return self - - def __next__(self) -> torch.Tensor: - # TODO: We should distinguish between expected end-of-file and unexpected - # runtime error. - try: - return core.get_next_frame(self._decoder) - except RuntimeError: - raise StopIteration() - def _get_and_validate_stream_metadata(decoder: torch.Tensor) -> core.StreamMetadata: video_metadata = core.get_video_metadata(decoder) diff --git a/test/decoders/simple_video_decoder_test.py b/test/decoders/simple_video_decoder_test.py index bd0773ee..124732b4 100644 --- a/test/decoders/simple_video_decoder_test.py +++ b/test/decoders/simple_video_decoder_test.py @@ -196,24 +196,50 @@ def test_getitem_fails(self): with pytest.raises(TypeError, match="Unsupported key type"): frame = decoder["0"] # noqa - def test_next(self): + def test_iteration(self): decoder = SimpleVideoDecoder(str(NASA_VIDEO.path)) ref_frame0 = NASA_VIDEO.get_tensor_by_index(0) ref_frame1 = NASA_VIDEO.get_tensor_by_index(1) + ref_frame9 = NASA_VIDEO.get_tensor_by_index(9) + ref_frame35 = NASA_VIDEO.get_tensor_by_index(35) ref_frame180 = NASA_VIDEO.get_tensor_by_name("time6.000000") ref_frame_last = NASA_VIDEO.get_tensor_by_name("time12.979633") + # Access an arbitrary frame to make sure that the later iteration + # still works as expected. The underlying C++ decoder object is + # actually stateful, and accessing a frame will move its internal + # cursor. + assert_tensor_equal(ref_frame35, decoder[35]) + for i, frame in enumerate(decoder): if i == 0: assert_tensor_equal(ref_frame0, frame) elif i == 1: assert_tensor_equal(ref_frame1, frame) + elif i == 9: + assert_tensor_equal(ref_frame9, frame) + elif i == 35: + assert_tensor_equal(ref_frame35, frame) elif i == 180: assert_tensor_equal(ref_frame180, frame) elif i == 389: assert_tensor_equal(ref_frame_last, frame) + def test_iteration_slow(self): + decoder = SimpleVideoDecoder(str(NASA_VIDEO.path)) + ref_frame_last = NASA_VIDEO.get_tensor_by_index(389) + + # Force the decoder to seek around a lot while iterating; this will + # slow down decoding, but we should still only iterate the exact number + # of total frames. + iterations = 0 + for frame in decoder: + assert_tensor_equal(ref_frame_last, decoder[-1]) + iterations += 1 + + assert iterations == len(decoder) == 390 + if __name__ == "__main__": pytest.main()