diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp index 086d2fc8..3555debe 100644 --- a/test/decoders/VideoDecoderTest.cpp +++ b/test/decoders/VideoDecoderTest.cpp @@ -166,35 +166,35 @@ TEST_P(VideoDecoderTest, ReturnsFirstTwoFramesOfVideo) { createDecoderFromPath(path, GetParam()); ourDecoder->addVideoStreamDecoder(-1); auto output = ourDecoder->getNextDecodedOutput(); - torch::Tensor tensor1FromOurDecoder = output.frame; - EXPECT_EQ(tensor1FromOurDecoder.sizes(), std::vector({270, 480, 3})); + torch::Tensor tensor0FromOurDecoder = output.frame; + EXPECT_EQ(tensor0FromOurDecoder.sizes(), std::vector({270, 480, 3})); EXPECT_EQ(output.ptsSeconds, 0.0); EXPECT_EQ(output.pts, 0); output = ourDecoder->getNextDecodedOutput(); - torch::Tensor tensor2FromOurDecoder = output.frame; - EXPECT_EQ(tensor2FromOurDecoder.sizes(), std::vector({270, 480, 3})); + torch::Tensor tensor1FromOurDecoder = output.frame; + EXPECT_EQ(tensor1FromOurDecoder.sizes(), std::vector({270, 480, 3})); EXPECT_EQ(output.ptsSeconds, 1'001. / 30'000); EXPECT_EQ(output.pts, 1001); + torch::Tensor tensor0FromFFMPEG = + readTensorFromDisk("nasa_13013.mp4.frame000000.pt"); torch::Tensor tensor1FromFFMPEG = readTensorFromDisk("nasa_13013.mp4.frame000001.pt"); - torch::Tensor tensor2FromFFMPEG = - readTensorFromDisk("nasa_13013.mp4.frame000002.pt"); EXPECT_EQ(tensor1FromFFMPEG.sizes(), std::vector({270, 480, 3})); + EXPECT_TRUE(torch::equal(tensor0FromOurDecoder, tensor0FromFFMPEG)); EXPECT_TRUE(torch::equal(tensor1FromOurDecoder, tensor1FromFFMPEG)); - EXPECT_TRUE(torch::equal(tensor2FromOurDecoder, tensor2FromFFMPEG)); EXPECT_TRUE( - torch::allclose(tensor1FromOurDecoder, tensor1FromFFMPEG, 0.1, 20)); - EXPECT_EQ(tensor2FromFFMPEG.sizes(), std::vector({270, 480, 3})); + torch::allclose(tensor0FromOurDecoder, tensor0FromFFMPEG, 0.1, 20)); + EXPECT_EQ(tensor1FromFFMPEG.sizes(), std::vector({270, 480, 3})); EXPECT_TRUE( - torch::allclose(tensor2FromOurDecoder, tensor2FromFFMPEG, 0.1, 20)); + torch::allclose(tensor1FromOurDecoder, tensor1FromFFMPEG, 0.1, 20)); if (FLAGS_dump_frames_for_debugging) { + dumpTensorToDisk(tensor0FromFFMPEG, "tensor0FromFFMPEG.pt"); dumpTensorToDisk(tensor1FromFFMPEG, "tensor1FromFFMPEG.pt"); - dumpTensorToDisk(tensor2FromFFMPEG, "tensor2FromFFMPEG.pt"); + dumpTensorToDisk(tensor0FromOurDecoder, "tensor0FromOurDecoder.pt"); dumpTensorToDisk(tensor1FromOurDecoder, "tensor1FromOurDecoder.pt"); - dumpTensorToDisk(tensor2FromOurDecoder, "tensor2FromOurDecoder.pt"); } } @@ -211,13 +211,13 @@ TEST_P(VideoDecoderTest, DecodesFramesInABatchInNHWC) { auto tensor = output.frames; EXPECT_EQ(tensor.sizes(), std::vector({2, 270, 480, 3})); - torch::Tensor tensor1FromFFMPEG = - readTensorFromDisk("nasa_13013.mp4.frame000001.pt"); - torch::Tensor tensor2FromFFMPEG = + torch::Tensor tensor0FromFFMPEG = + readTensorFromDisk("nasa_13013.mp4.frame000000.pt"); + torch::Tensor tensorTime6FromFFMPEG = readTensorFromDisk("nasa_13013.mp4.time6.000000.pt"); - EXPECT_TRUE(torch::equal(tensor[0], tensor1FromFFMPEG)); - EXPECT_TRUE(torch::equal(tensor[1], tensor2FromFFMPEG)); + EXPECT_TRUE(torch::equal(tensor[0], tensor0FromFFMPEG)); + EXPECT_TRUE(torch::equal(tensor[1], tensorTime6FromFFMPEG)); } TEST_P(VideoDecoderTest, DecodesFramesInABatchInNCHW) { @@ -235,14 +235,14 @@ TEST_P(VideoDecoderTest, DecodesFramesInABatchInNCHW) { auto tensor = output.frames; EXPECT_EQ(tensor.sizes(), std::vector({2, 3, 270, 480})); - torch::Tensor tensor1FromFFMPEG = - readTensorFromDisk("nasa_13013.mp4.frame000001.pt"); - torch::Tensor tensor2FromFFMPEG = + torch::Tensor tensor0FromFFMPEG = + readTensorFromDisk("nasa_13013.mp4.frame000000.pt"); + torch::Tensor tensorTime6FromFFMPEG = readTensorFromDisk("nasa_13013.mp4.time6.000000.pt"); tensor = tensor.permute({0, 2, 3, 1}); - EXPECT_TRUE(torch::equal(tensor[0], tensor1FromFFMPEG)); - EXPECT_TRUE(torch::equal(tensor[1], tensor2FromFFMPEG)); + EXPECT_TRUE(torch::equal(tensor[0], tensor0FromFFMPEG)); + EXPECT_TRUE(torch::equal(tensor[1], tensorTime6FromFFMPEG)); } TEST_P(VideoDecoderTest, SeeksCloseToEof) { diff --git a/test/decoders/simple_video_decoder_test.py b/test/decoders/simple_video_decoder_test.py index 208e9dbb..bd0773ee 100644 --- a/test/decoders/simple_video_decoder_test.py +++ b/test/decoders/simple_video_decoder_test.py @@ -3,25 +3,18 @@ from torchcodec.decoders import _core, SimpleVideoDecoder -from ..test_utils import ( - assert_equal, - EMPTY_REF_TENSOR, - get_reference_video_path, - get_reference_video_tensor, - load_tensor_from_file, - REF_DIMS, -) +from ..test_utils import assert_tensor_equal, NASA_VIDEO class TestSimpleDecoder: @pytest.mark.parametrize("source_kind", ("path", "tensor", "bytes")) def test_create(self, source_kind): if source_kind == "path": - source = str(get_reference_video_path()) + source = str(NASA_VIDEO.path) elif source_kind == "tensor": - source = get_reference_video_tensor() + source = NASA_VIDEO.to_tensor() elif source_kind == "bytes": - path = str(get_reference_video_path()) + path = str(NASA_VIDEO.path) with open(path, "rb") as f: source = f.read() else: @@ -42,99 +35,157 @@ def test_create_fails(self): decoder = SimpleVideoDecoder(123) # noqa def test_getitem_int(self): - decoder = SimpleVideoDecoder(str(get_reference_video_path())) + decoder = SimpleVideoDecoder(str(NASA_VIDEO.path)) - ref_frame0 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt") - ref_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt") - ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") - ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt") + ref_frame0 = NASA_VIDEO.get_tensor_by_index(0) + ref_frame1 = NASA_VIDEO.get_tensor_by_index(1) + ref_frame180 = NASA_VIDEO.get_tensor_by_name("time6.000000") + ref_frame_last = NASA_VIDEO.get_tensor_by_name("time12.979633") - assert_equal(ref_frame0, decoder[0]) - assert_equal(ref_frame1, decoder[1]) - assert_equal(ref_frame180, decoder[180]) - assert_equal(ref_frame_last, decoder[-1]) + assert_tensor_equal(ref_frame0, decoder[0]) + assert_tensor_equal(ref_frame1, decoder[1]) + assert_tensor_equal(ref_frame180, decoder[180]) + assert_tensor_equal(ref_frame_last, decoder[-1]) def test_getitem_slice(self): - decoder = SimpleVideoDecoder(str(get_reference_video_path())) - - ref_frames0_9 = [ - load_tensor_from_file(f"nasa_13013.mp4.frame{i + 1:06d}.pt") - for i in range(0, 9) - ] - - # Ensure that the degenerate case of a range of size 1 works; note that we get - # a tensor which CONTAINS a single frame, rather than a tensor that itself IS a - # single frame. Hence we have to access the 0th element of the return tensor. - slice_0 = decoder[0:1] - assert slice_0.shape == torch.Size([1, *REF_DIMS]) - assert_equal(ref_frames0_9[0], slice_0[0]) - - slice_4 = decoder[4:5] - assert slice_4.shape == torch.Size([1, *REF_DIMS]) - assert_equal(ref_frames0_9[4], slice_4[0]) - - slice_8 = decoder[8:9] - assert slice_8.shape == torch.Size([1, *REF_DIMS]) - assert_equal(ref_frames0_9[8], slice_8[0]) - - ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") - slice_180 = decoder[180:181] - assert slice_180.shape == torch.Size([1, *REF_DIMS]) - assert_equal(ref_frame180, slice_180[0]) + decoder = SimpleVideoDecoder(str(NASA_VIDEO.path)) + + # ensure that the degenerate case of a range of size 1 works + + ref0 = NASA_VIDEO.get_stacked_tensor_by_range(0, 1) + slice0 = decoder[0:1] + assert slice0.shape == torch.Size( + [ + 1, + NASA_VIDEO.height, + NASA_VIDEO.width, + NASA_VIDEO.num_color_channels, + ] + ) + assert_tensor_equal(ref0, slice0) + + ref4 = NASA_VIDEO.get_stacked_tensor_by_range(4, 5) + slice4 = decoder[4:5] + assert slice4.shape == torch.Size( + [ + 1, + NASA_VIDEO.height, + NASA_VIDEO.width, + NASA_VIDEO.num_color_channels, + ] + ) + assert_tensor_equal(ref4, slice4) + + ref8 = NASA_VIDEO.get_stacked_tensor_by_range(8, 9) + slice8 = decoder[8:9] + assert slice8.shape == torch.Size( + [ + 1, + NASA_VIDEO.height, + NASA_VIDEO.width, + NASA_VIDEO.num_color_channels, + ] + ) + assert_tensor_equal(ref8, slice8) + + ref180 = NASA_VIDEO.get_tensor_by_name("time6.000000") + slice180 = decoder[180:181] + assert slice180.shape == torch.Size( + [ + 1, + NASA_VIDEO.height, + NASA_VIDEO.width, + NASA_VIDEO.num_color_channels, + ] + ) + assert_tensor_equal(ref180, slice180[0]) # contiguous ranges - slice_frames0_9 = decoder[0:9] - assert slice_frames0_9.shape == torch.Size([9, *REF_DIMS]) - for ref_frame, slice_frame in zip(ref_frames0_9, slice_frames0_9): - assert_equal(ref_frame, slice_frame) - - slice_frames4_8 = decoder[4:8] - assert slice_frames4_8.shape == torch.Size([4, *REF_DIMS]) - for ref_frame, slice_frame in zip(ref_frames0_9[4:8], slice_frames4_8): - assert_equal(ref_frame, slice_frame) + ref0_9 = NASA_VIDEO.get_stacked_tensor_by_range(0, 9) + slice0_9 = decoder[0:9] + assert slice0_9.shape == torch.Size( + [ + 9, + NASA_VIDEO.height, + NASA_VIDEO.width, + NASA_VIDEO.num_color_channels, + ] + ) + assert_tensor_equal(ref0_9, slice0_9) + + ref4_8 = NASA_VIDEO.get_stacked_tensor_by_range(4, 8) + slice4_8 = decoder[4:8] + assert slice4_8.shape == torch.Size( + [ + 4, + NASA_VIDEO.height, + NASA_VIDEO.width, + NASA_VIDEO.num_color_channels, + ] + ) + assert_tensor_equal(ref4_8, slice4_8) # ranges with a stride - ref_frames15_35 = [ - load_tensor_from_file(f"nasa_13013.mp4.frame{i:06d}.pt") - for i in range(15, 36, 5) - ] - slice_frames15_35 = decoder[15:36:5] - assert slice_frames15_35.shape == torch.Size([5, *REF_DIMS]) - for ref_frame, slice_frame in zip(ref_frames15_35, slice_frames15_35): - assert_equal(ref_frame, slice_frame) - - slice_frames0_9_2 = decoder[0:9:2] - assert slice_frames0_9_2.shape == torch.Size([5, *REF_DIMS]) - for ref_frame, slice_frame in zip(ref_frames0_9[0:0:2], slice_frames0_9_2): - assert_equal(ref_frame, slice_frame) + ref15_35 = NASA_VIDEO.get_stacked_tensor_by_range(15, 36, 5) + slice15_35 = decoder[15:36:5] + assert slice15_35.shape == torch.Size( + [ + 5, + NASA_VIDEO.height, + NASA_VIDEO.width, + NASA_VIDEO.num_color_channels, + ] + ) + assert_tensor_equal(ref15_35, slice15_35) + + ref0_9_2 = NASA_VIDEO.get_stacked_tensor_by_range(0, 9, 2) + slice0_9_2 = decoder[0:9:2] + assert slice0_9_2.shape == torch.Size( + [ + 5, + NASA_VIDEO.height, + NASA_VIDEO.width, + NASA_VIDEO.num_color_channels, + ] + ) + assert_tensor_equal(ref0_9_2, slice0_9_2) # negative numbers in the slice - ref_frames386_389 = [ - load_tensor_from_file(f"nasa_13013.mp4.frame{i:06d}.pt") - for i in range(386, 390) - ] - - slice_frames386_389 = decoder[-4:] - assert slice_frames386_389.shape == torch.Size([4, *REF_DIMS]) - for ref_frame, slice_frame in zip(ref_frames386_389[-4:], slice_frames386_389): - assert_equal(ref_frame, slice_frame) + ref386_389 = NASA_VIDEO.get_stacked_tensor_by_range(386, 390) + slice386_389 = decoder[-4:] + assert slice386_389.shape == torch.Size( + [ + 4, + NASA_VIDEO.height, + NASA_VIDEO.width, + NASA_VIDEO.num_color_channels, + ] + ) + assert_tensor_equal(ref386_389, slice386_389) # an empty range is valid! empty_frame = decoder[5:5] - assert_equal(empty_frame, EMPTY_REF_TENSOR) + assert_tensor_equal(empty_frame, NASA_VIDEO.empty_hwc_tensor) # slices that are out-of-range are also valid - they return an empty tensor also_empty = decoder[10000:] - assert_equal(also_empty, EMPTY_REF_TENSOR) + assert_tensor_equal(also_empty, NASA_VIDEO.empty_hwc_tensor) # should be just a copy all_frames = decoder[:] - assert all_frames.shape == torch.Size([len(decoder), *REF_DIMS]) + assert all_frames.shape == torch.Size( + [ + len(decoder), + NASA_VIDEO.height, + NASA_VIDEO.width, + NASA_VIDEO.num_color_channels, + ] + ) for sliced, ref in zip(all_frames, decoder): - assert_equal(sliced, ref) + assert_tensor_equal(sliced, ref) def test_getitem_fails(self): - decoder = SimpleVideoDecoder(str(get_reference_video_path())) + decoder = SimpleVideoDecoder(str(NASA_VIDEO.path)) with pytest.raises(IndexError, match="out of bounds"): frame = decoder[1000] # noqa @@ -146,22 +197,22 @@ def test_getitem_fails(self): frame = decoder["0"] # noqa def test_next(self): - decoder = SimpleVideoDecoder(str(get_reference_video_path())) + decoder = SimpleVideoDecoder(str(NASA_VIDEO.path)) - ref_frame0 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt") - ref_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt") - ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") - ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt") + ref_frame0 = NASA_VIDEO.get_tensor_by_index(0) + ref_frame1 = NASA_VIDEO.get_tensor_by_index(1) + ref_frame180 = NASA_VIDEO.get_tensor_by_name("time6.000000") + ref_frame_last = NASA_VIDEO.get_tensor_by_name("time12.979633") for i, frame in enumerate(decoder): if i == 0: - assert_equal(ref_frame0, frame) + assert_tensor_equal(ref_frame0, frame) elif i == 1: - assert_equal(ref_frame1, frame) + assert_tensor_equal(ref_frame1, frame) elif i == 180: - assert_equal(ref_frame180, frame) + assert_tensor_equal(ref_frame180, frame) elif i == 389: - assert_equal(ref_frame_last, frame) + assert_tensor_equal(ref_frame_last, frame) if __name__ == "__main__": diff --git a/test/decoders/test_metadata.py b/test/decoders/test_metadata.py index 321dd640..64bbd5c9 100644 --- a/test/decoders/test_metadata.py +++ b/test/decoders/test_metadata.py @@ -6,11 +6,11 @@ StreamMetadata, ) -from ..test_utils import get_reference_video_path +from ..test_utils import NASA_VIDEO def test_get_video_metadata(): - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_VIDEO.path)) metadata = get_video_metadata(decoder) assert len(metadata.streams) == 6 assert metadata.best_video_stream_index == 3 diff --git a/test/decoders/video_decoder_ops_test.py b/test/decoders/video_decoder_ops_test.py index 0a2c402e..267246c8 100644 --- a/test/decoders/video_decoder_ops_test.py +++ b/test/decoders/video_decoder_ops_test.py @@ -24,20 +24,14 @@ seek_to_pts, ) -from ..test_utils import ( - assert_equal, - EMPTY_REF_TENSOR, - get_reference_audio_path, - get_reference_video_path, - load_tensor_from_file, -) +from ..test_utils import assert_tensor_equal, NASA_AUDIO, NASA_VIDEO torch._dynamo.config.capture_dynamic_output_shape_ops = True class ReferenceDecoder: def __init__(self): - self.decoder: torch.Tensor = create_from_file(str(get_reference_video_path())) + self.decoder: torch.Tensor = create_from_file(str(NASA_VIDEO.path)) add_video_stream(self.decoder) def get_next_frame(self) -> torch.Tensor: @@ -52,130 +46,117 @@ def seek(self, pts: float): # TODO: Some of these tests could probably be unified and parametrized? class TestOps: def test_seek_and_next(self): - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder) + frame0 = get_next_frame(decoder) + reference_frame0 = NASA_VIDEO.get_tensor_by_index(0) + assert_tensor_equal(frame0, reference_frame0) + reference_frame1 = NASA_VIDEO.get_tensor_by_index(1) frame1 = get_next_frame(decoder) - reference_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt") - assert_equal(frame1, reference_frame1) - reference_frame2 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt") - img2 = get_next_frame(decoder) - assert_equal(img2, reference_frame2) + assert_tensor_equal(frame1, reference_frame1) seek_to_pts(decoder, 6.0) frame_time6 = get_next_frame(decoder) - reference_frame_time6 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") - assert_equal(frame_time6, reference_frame_time6) + reference_frame_time6 = NASA_VIDEO.get_tensor_by_name("time6.000000") + assert_tensor_equal(frame_time6, reference_frame_time6) def test_get_frame_at_pts(self): - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder) # This frame has pts=6.006 and duration=0.033367, so it should be visible # at timestamps in the range [6.006, 6.039367) (not including the last timestamp). frame6 = get_frame_at_pts(decoder, 6.006) - reference_frame6 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") - assert_equal(frame6, reference_frame6) + reference_frame6 = NASA_VIDEO.get_tensor_by_name("time6.000000") + assert_tensor_equal(frame6, reference_frame6) frame6 = get_frame_at_pts(decoder, 6.02) - assert_equal(frame6, reference_frame6) + assert_tensor_equal(frame6, reference_frame6) frame6 = get_frame_at_pts(decoder, 6.039366) - assert_equal(frame6, reference_frame6) + assert_tensor_equal(frame6, reference_frame6) # Note that this timestamp is exactly on a frame boundary, so it should # return the next frame since the right boundary of the interval is # open. next_frame = get_frame_at_pts(decoder, 6.039367) with pytest.raises(AssertionError): - assert_equal(next_frame, reference_frame6) + assert_tensor_equal(next_frame, reference_frame6) def test_get_frame_at_index(self): - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder) - frame1 = get_frame_at_index(decoder, stream_index=3, frame_index=0) - reference_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt") - assert_equal(frame1, reference_frame1) + frame0 = get_frame_at_index(decoder, stream_index=3, frame_index=0) + reference_frame0 = NASA_VIDEO.get_tensor_by_index(0) + assert_tensor_equal(frame0, reference_frame0) # The frame that is displayed at 6 seconds is frame 180 from a 0-based index. frame6 = get_frame_at_index(decoder, stream_index=3, frame_index=180) - reference_frame6 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") - assert_equal(frame6, reference_frame6) + reference_frame6 = NASA_VIDEO.get_tensor_by_name("time6.000000") + assert_tensor_equal(frame6, reference_frame6) def test_get_frames_at_indices(self): - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder) - frames1and6 = get_frames_at_indices( + frames0and180 = get_frames_at_indices( decoder, stream_index=3, frame_indices=[0, 180] ) - reference_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt") - reference_frame6 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") - assert_equal(frames1and6[0], reference_frame1) - assert_equal(frames1and6[1], reference_frame6) + reference_frame0 = NASA_VIDEO.get_tensor_by_index(0) + reference_frame180 = NASA_VIDEO.get_tensor_by_name("time6.000000") + assert_tensor_equal(frames0and180[0], reference_frame0) + assert_tensor_equal(frames0and180[1], reference_frame180) def test_get_frames_in_range(self): - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder) - ref_frames0_9 = [ - load_tensor_from_file(f"nasa_13013.mp4.frame{i + 1:06d}.pt") - for i in range(0, 9) - ] - ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") - ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt") - # ensure that the degenerate case of a range of size 1 works + ref_frame0 = NASA_VIDEO.get_stacked_tensor_by_range(0, 1) bulk_frame0 = get_frames_in_range(decoder, stream_index=3, start=0, stop=1) - assert_equal(bulk_frame0[0], ref_frames0_9[0]) + assert_tensor_equal(ref_frame0, bulk_frame0) + ref_frame1 = NASA_VIDEO.get_stacked_tensor_by_range(1, 2) bulk_frame1 = get_frames_in_range(decoder, stream_index=3, start=1, stop=2) - assert_equal(bulk_frame1[0], ref_frames0_9[1]) - - bulk_frame180 = get_frames_in_range( - decoder, stream_index=3, start=180, stop=181 - ) - assert_equal(bulk_frame180[0], ref_frame180) + assert_tensor_equal(ref_frame1, bulk_frame1) - bulk_frame_last = get_frames_in_range( + ref_frame389 = NASA_VIDEO.get_stacked_tensor_by_range(389, 390) + bulk_frame389 = get_frames_in_range( decoder, stream_index=3, start=389, stop=390 ) - assert_equal(bulk_frame_last[0], ref_frame_last) + assert_tensor_equal(ref_frame389, bulk_frame389) # contiguous ranges + ref_frames0_9 = NASA_VIDEO.get_stacked_tensor_by_range(0, 9) bulk_frames0_9 = get_frames_in_range(decoder, stream_index=3, start=0, stop=9) - for i in range(0, 9): - assert_equal(ref_frames0_9[i], bulk_frames0_9[i]) + assert_tensor_equal(ref_frames0_9, bulk_frames0_9) + ref_frames4_8 = NASA_VIDEO.get_stacked_tensor_by_range(4, 8) bulk_frames4_8 = get_frames_in_range(decoder, stream_index=3, start=4, stop=8) - for i, bulk_frame in enumerate(bulk_frames4_8): - assert_equal(ref_frames0_9[i + 4], bulk_frame) + assert_tensor_equal(ref_frames4_8, bulk_frames4_8) # ranges with a stride - ref_frames15_35 = [ - load_tensor_from_file(f"nasa_13013.mp4.frame{i:06d}.pt") - for i in range(15, 36, 5) - ] + ref_frames15_35 = NASA_VIDEO.get_stacked_tensor_by_range(15, 36, 5) bulk_frames15_35 = get_frames_in_range( decoder, stream_index=3, start=15, stop=36, step=5 ) - for i, bulk_frame in enumerate(bulk_frames15_35): - assert_equal(ref_frames15_35[i], bulk_frame) + assert_tensor_equal(ref_frames15_35, bulk_frames15_35) + ref_frames0_9_2 = NASA_VIDEO.get_stacked_tensor_by_range(0, 9, 2) bulk_frames0_9_2 = get_frames_in_range( decoder, stream_index=3, start=0, stop=9, step=2 ) - for i, bulk_frame in enumerate(bulk_frames0_9_2): - assert_equal(ref_frames0_9[i * 2], bulk_frame) + assert_tensor_equal(ref_frames0_9_2, bulk_frames0_9_2) # an empty range is valid! empty_frame = get_frames_in_range(decoder, stream_index=3, start=5, stop=5) - assert_equal(empty_frame, EMPTY_REF_TENSOR) + assert_tensor_equal(empty_frame, NASA_VIDEO.empty_hwc_tensor) def test_throws_exception_at_eof(self): - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder) seek_to_pts(decoder, 12.979633) last_frame = get_next_frame(decoder) - reference_last_frame = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt") - assert_equal(last_frame, reference_last_frame) + reference_last_frame = NASA_VIDEO.get_tensor_by_name("time12.979633") + assert_tensor_equal(last_frame, reference_last_frame) with pytest.raises(RuntimeError, match="End of file"): get_next_frame(decoder) def test_throws_exception_if_seek_too_far(self): - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder) # pts=12.979633 is the last frame in the video. seek_to_pts(decoder, 12.979633 + 1.0e-4) @@ -189,19 +170,19 @@ def test_compile_seek_and_next(self): @torch.compile(fullgraph=True, backend="eager") def get_frame1_and_frame_time6(decoder): add_video_stream(decoder) - frame1 = get_next_frame(decoder) + frame0 = get_next_frame(decoder) seek_to_pts(decoder, 6.0) frame_time6 = get_next_frame(decoder) - return frame1, frame_time6 + return frame0, frame_time6 # NB: create needs to happen outside the torch.compile region, # for now. Otherwise torch.compile constant-props it. - decoder = create_from_file(str(get_reference_video_path())) - frame1, frame_time6 = get_frame1_and_frame_time6(decoder) - reference_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt") - reference_frame_time6 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") - assert_equal(frame1, reference_frame1) - assert_equal(frame_time6, reference_frame_time6) + decoder = create_from_file(str(NASA_VIDEO.path)) + frame0, frame_time6 = get_frame1_and_frame_time6(decoder) + reference_frame0 = NASA_VIDEO.get_tensor_by_index(0) + reference_frame_time6 = NASA_VIDEO.get_tensor_by_name("time6.000000") + assert_tensor_equal(frame0, reference_frame0) + assert_tensor_equal(frame_time6, reference_frame_time6) def test_class_based_compile_seek_and_next(self): # TODO(T180277797): Ditto as above. @@ -209,21 +190,21 @@ def test_class_based_compile_seek_and_next(self): def class_based_get_frame1_and_frame_time6( decoder: ReferenceDecoder, ) -> Tuple[torch.Tensor, torch.Tensor]: - frame1 = decoder.get_next_frame() + frame0 = decoder.get_next_frame() decoder.seek(6.0) frame_time6 = decoder.get_next_frame() - return frame1, frame_time6 + return frame0, frame_time6 decoder = ReferenceDecoder() - frame1, frame_time6 = class_based_get_frame1_and_frame_time6(decoder) - reference_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt") - reference_frame_time6 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") - assert_equal(frame1, reference_frame1) - assert_equal(frame_time6, reference_frame_time6) + frame0, frame_time6 = class_based_get_frame1_and_frame_time6(decoder) + reference_frame0 = NASA_VIDEO.get_tensor_by_index(0) + reference_frame_time6 = NASA_VIDEO.get_tensor_by_name("time6.000000") + assert_tensor_equal(frame0, reference_frame0) + assert_tensor_equal(frame_time6, reference_frame_time6) @pytest.mark.parametrize("create_from", ("file", "tensor", "bytes")) def test_create_decoder(self, create_from): - path = str(get_reference_video_path()) + path = str(NASA_VIDEO.path) if create_from == "file": decoder = create_from_file(path) elif create_from == "tensor": @@ -236,16 +217,16 @@ def test_create_decoder(self, create_from): decoder = create_from_bytes(video_bytes) add_video_stream(decoder) + frame0 = get_next_frame(decoder) + reference_frame0 = NASA_VIDEO.get_tensor_by_index(0) + assert_tensor_equal(frame0, reference_frame0) + reference_frame1 = NASA_VIDEO.get_tensor_by_index(1) frame1 = get_next_frame(decoder) - reference_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt") - assert_equal(frame1, reference_frame1) - reference_frame2 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt") - img2 = get_next_frame(decoder) - assert_equal(img2, reference_frame2) + assert_tensor_equal(frame1, reference_frame1) seek_to_pts(decoder, 6.0) frame_time6 = get_next_frame(decoder) - reference_frame_time6 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt") - assert_equal(frame_time6, reference_frame_time6) + reference_frame_time6 = NASA_VIDEO.get_tensor_by_name("time6.000000") + assert_tensor_equal(frame_time6, reference_frame_time6) # TODO: Keeping the metadata tests below for now, but we should remove them # once we remove get_json_metadata(). @@ -255,7 +236,7 @@ def test_create_decoder(self, create_from): # always call scanFileAndUpdateMetadataAndIndex() when creating a decoder # from the core API. def test_video_get_json_metadata(self): - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_VIDEO.path)) metadata = get_json_metadata(decoder) metadata_dict = json.loads(metadata) @@ -271,7 +252,7 @@ def test_video_get_json_metadata(self): assert metadata_dict["bitRate"] == 324915.0 def test_video_get_json_metadata_with_stream(self): - decoder = create_from_file(str(get_reference_video_path())) + decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder) metadata = get_json_metadata(decoder) metadata_dict = json.loads(metadata) @@ -281,7 +262,7 @@ def test_video_get_json_metadata_with_stream(self): assert metadata_dict["maxPtsSecondsFromScan"] == 13.013 def test_audio_get_json_metadata(self): - decoder = create_from_file(str(get_reference_audio_path())) + decoder = create_from_file(str(NASA_AUDIO.path)) metadata = get_json_metadata(decoder) metadata_dict = json.loads(metadata) assert metadata_dict["durationSeconds"] == pytest.approx(13.25, abs=0.01) diff --git a/test/generate_reference_resources.sh b/test/generate_reference_resources.sh index 8f1d6161..f38a0b79 100755 --- a/test/generate_reference_resources.sh +++ b/test/generate_reference_resources.sh @@ -11,22 +11,21 @@ TORCHCODEC_PATH=$HOME/fbsource/fbcode/pytorch/torchcodec RESOURCES_DIR=$TORCHCODEC_PATH/test/resources VIDEO_PATH=$RESOURCES_DIR/nasa_13013.mp4 -# Important note: I used ffmpeg version 6.1.1 to generate these images. We -# must have the version that matches the one that we link against in the test. -# TODO: The first 10 frames are numbered starting from 1, so their name is one more -# than their index. This is confusing. We should unify the naming so files are -# named by their index. This will inovlve also updating the tests that load -# these files. -ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,0)+eq(n\,1)+eq(n\,2)+eq(n\,3)+eq(n\,4)+eq(n\,5)+eq(n\,6)+eq(n\,7)+eq(n\,8)+eq(n\,9)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame%06d.bmp" -ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,15)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000015.bmp" -ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,20)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000020.bmp" -ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,25)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000025.bmp" -ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,30)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000030.bmp" -ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,35)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000035.bmp" -ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,386)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000386.bmp" -ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,387)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000387.bmp" -ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,388)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000388.bmp" -ffmpeg -y -i "$VIDEO_PATH" -vf select='eq(n\,389)' -vsync vfr -q:v 2 "$VIDEO_PATH.frame000389.bmp" +# Last generated with ffmpeg version 4.3 +# +# Note: The naming scheme used here must match the naming scheme used to load +# tensors in test_utils.py. +FRAMES=(0 1 2 3 4 5 6 7 8 9) +FRAMES+=(15 20 25 30 35) +FRAMES+=(386 387 388 389) +for frame in "${FRAMES[@]}"; do + # Note that we are using 0-based index naming. Asking ffmpeg to number output + # frames would result in 1-based index naming. We enforce 0-based index naming + # so that the name of reference frames matches the index when accessing that + # frame in the Python decoder. + frame_name=$(printf "%06d" "$frame") + ffmpeg -y -i "$VIDEO_PATH" -vf select="eq(n\,$frame)" -vsync vfr -q:v 2 "$VIDEO_PATH.frame$frame_name.bmp" +done ffmpeg -y -ss 6.0 -i "$VIDEO_PATH" -frames:v 1 "$VIDEO_PATH.time6.000000.bmp" ffmpeg -y -ss 6.1 -i "$VIDEO_PATH" -frames:v 1 "$VIDEO_PATH.time6.100000.bmp" ffmpeg -y -ss 10.0 -i "$VIDEO_PATH" -frames:v 1 "$VIDEO_PATH.time10.000000.bmp" diff --git a/test/resources/nasa_13013.mp4.frame000000.pt b/test/resources/nasa_13013.mp4.frame000000.pt new file mode 100644 index 00000000..497dd47c Binary files /dev/null and b/test/resources/nasa_13013.mp4.frame000000.pt differ diff --git a/test/resources/nasa_13013.mp4.frame000001.pt b/test/resources/nasa_13013.mp4.frame000001.pt index c49a895e..cff13c10 100644 Binary files a/test/resources/nasa_13013.mp4.frame000001.pt and b/test/resources/nasa_13013.mp4.frame000001.pt differ diff --git a/test/resources/nasa_13013.mp4.frame000002.pt b/test/resources/nasa_13013.mp4.frame000002.pt index 813eaa6b..58df03a5 100644 Binary files a/test/resources/nasa_13013.mp4.frame000002.pt and b/test/resources/nasa_13013.mp4.frame000002.pt differ diff --git a/test/resources/nasa_13013.mp4.frame000003.pt b/test/resources/nasa_13013.mp4.frame000003.pt index 2d9a04cd..cbedb9a5 100644 Binary files a/test/resources/nasa_13013.mp4.frame000003.pt and b/test/resources/nasa_13013.mp4.frame000003.pt differ diff --git a/test/resources/nasa_13013.mp4.frame000004.pt b/test/resources/nasa_13013.mp4.frame000004.pt index 1efb87e8..51cb2196 100644 Binary files a/test/resources/nasa_13013.mp4.frame000004.pt and b/test/resources/nasa_13013.mp4.frame000004.pt differ diff --git a/test/resources/nasa_13013.mp4.frame000005.pt b/test/resources/nasa_13013.mp4.frame000005.pt index 427c06cc..9135c42c 100644 Binary files a/test/resources/nasa_13013.mp4.frame000005.pt and b/test/resources/nasa_13013.mp4.frame000005.pt differ diff --git a/test/resources/nasa_13013.mp4.frame000006.pt b/test/resources/nasa_13013.mp4.frame000006.pt index a6ceb4ce..295d87c3 100644 Binary files a/test/resources/nasa_13013.mp4.frame000006.pt and b/test/resources/nasa_13013.mp4.frame000006.pt differ diff --git a/test/resources/nasa_13013.mp4.frame000007.pt b/test/resources/nasa_13013.mp4.frame000007.pt index 7fd4d482..4e7e868d 100644 Binary files a/test/resources/nasa_13013.mp4.frame000007.pt and b/test/resources/nasa_13013.mp4.frame000007.pt differ diff --git a/test/resources/nasa_13013.mp4.frame000008.pt b/test/resources/nasa_13013.mp4.frame000008.pt index cbbdce81..7b360177 100644 Binary files a/test/resources/nasa_13013.mp4.frame000008.pt and b/test/resources/nasa_13013.mp4.frame000008.pt differ diff --git a/test/resources/nasa_13013.mp4.frame000009.pt b/test/resources/nasa_13013.mp4.frame000009.pt index 7f075cf0..ff7ebd33 100644 Binary files a/test/resources/nasa_13013.mp4.frame000009.pt and b/test/resources/nasa_13013.mp4.frame000009.pt differ diff --git a/test/resources/nasa_13013.mp4.frame000010.pt b/test/resources/nasa_13013.mp4.frame000010.pt deleted file mode 100644 index ce926a7f..00000000 Binary files a/test/resources/nasa_13013.mp4.frame000010.pt and /dev/null differ diff --git a/test/samplers/video_clip_sampler_test.py b/test/samplers/video_clip_sampler_test.py index 67c41ac6..64c8e04e 100644 --- a/test/samplers/video_clip_sampler_test.py +++ b/test/samplers/video_clip_sampler_test.py @@ -11,7 +11,7 @@ ) from ..test_utils import ( # noqa: F401; see use in test_sampler - assert_equal, + assert_tensor_equal, reference_video_tensor, ) @@ -42,7 +42,7 @@ def test_sampler( video_args = VideoArgs(desired_width=desired_width, desired_height=desired_height) sampler = VideoClipSampler(video_args, sampler_args) clips = sampler(reference_video_tensor) - assert_equal(len(clips), sampler_args.clips_per_video) + assert_tensor_equal(len(clips), sampler_args.clips_per_video) clip = clips[0] if isinstance(sampler_args, TimeBasedSamplerArgs): # TODO FIXME: Looks like we have an API inconsistency. diff --git a/test/test_utils.py b/test/test_utils.py index 6073d843..999f20ec 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -2,25 +2,23 @@ import os import pathlib +from dataclasses import dataclass + import numpy as np import pytest import torch -# The dimensions and type have to match the frames in our reference video. -REF_DIMS = (270, 480, 3) -EMPTY_REF_TENSOR = torch.empty([0, *REF_DIMS], dtype=torch.uint8) + +def assert_tensor_equal(*args, **kwargs): + torch.testing.assert_close(*args, **kwargs, atol=0, rtol=0) def in_fbcode() -> bool: return os.environ.get("IN_FBCODE_TORCHCODEC") == "1" -def assert_equal(*args, **kwargs): - torch.testing.assert_close(*args, **kwargs, atol=0, rtol=0) - - -def get_video_path(filename: str) -> pathlib.Path: +def _get_file_path(filename: str) -> pathlib.Path: if in_fbcode(): resource = ( importlib.resources.files(__spec__.parent) @@ -33,25 +31,67 @@ def get_video_path(filename: str) -> pathlib.Path: return pathlib.Path(__file__).parent / "resources" / filename -def get_reference_video_path() -> pathlib.Path: - return get_video_path("nasa_13013.mp4") +def _load_tensor_from_file(filename: str) -> torch.Tensor: + file_path = _get_file_path(filename) + return torch.load(file_path, weights_only=True) -def get_reference_audio_path() -> pathlib.Path: - return get_video_path("nasa_13013.mp4.audio.mp3") +@pytest.fixture() +def reference_video_tensor() -> torch.Tensor: + return NASA_VIDEO.to_tensor() -def load_tensor_from_file(filename: str) -> torch.Tensor: - file_path = get_video_path(filename) - return torch.load(file_path, weights_only=True) +@dataclass +class TestContainerFile: + filename: str + @property + def path(self) -> pathlib.Path: + return _get_file_path(self.filename) -def get_reference_video_tensor() -> torch.Tensor: - arr = np.fromfile(get_reference_video_path(), dtype=np.uint8) - video_tensor = torch.from_numpy(arr) - return video_tensor + def to_tensor(self) -> torch.Tensor: + arr = np.fromfile(self.path, dtype=np.uint8) + return torch.from_numpy(arr) + def get_tensor_by_index(self, idx: int) -> torch.Tensor: + return _load_tensor_from_file(f"{self.filename}.frame{idx:06d}.pt") -@pytest.fixture() -def reference_video_tensor() -> torch.Tensor: - return get_reference_video_tensor() + def get_stacked_tensor_by_range( + self, start: int, stop: int, step: int = 1 + ) -> torch.Tensor: + tensors = [self.get_tensor_by_index(i) for i in range(start, stop, step)] + return torch.stack(tensors) + + def get_tensor_by_name(self, name: str) -> torch.Tensor: + return _load_tensor_from_file(f"{self.filename}.{name}.pt") + + +@dataclass +class TestVideo(TestContainerFile): + """ + Represents a video file used in our testing. + + Note that right now, we implicitly only support a single stream. Our current tests always + use the "best" stream as defined by FFMPEG. In general, height, width and num_color_channels + can vary per-stream. When we start testing multiple streams in the same video, we will have + to generalize this class. + """ + + height: int + width: int + num_color_channels: int + + @property + def empty_hwc_tensor(self) -> torch.Tensor: + return torch.empty( + [0, self.height, self.width, self.num_color_channels], dtype=torch.uint8 + ) + + +NASA_VIDEO = TestVideo( + filename="nasa_13013.mp4", height=270, width=480, num_color_channels=3 +) + +# When we start actually decoding audio-only files, we'll probably need to define +# a TestAudio class with audio specific values. Until then, we only need a filename. +NASA_AUDIO = TestContainerFile(filename="nasa_13013.mp4.audio.mp3")