diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index e90ec3f6..4f80cbc0 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -11,13 +11,7 @@ from torchcodec.decoders import _core, VideoDecoder -from ..utils import ( - assert_frames_equal, - assert_tensor_close, - cpu_and_cuda, - H265_VIDEO, - NASA_VIDEO, -) +from ..utils import assert_frames_equal, cpu_and_cuda, H265_VIDEO, NASA_VIDEO class TestVideoDecoder: @@ -538,13 +532,17 @@ def test_get_frames_in_range(self, stream_index, device): ] ) assert_frames_equal(ref_frames0_9, frames0_9.data) - assert_tensor_close( + torch.testing.assert_close( NASA_VIDEO.get_pts_seconds_by_range(0, 10, stream_index=stream_index), frames0_9.pts_seconds, + atol=1e-6, + rtol=1e-6, ) - assert_tensor_close( + torch.testing.assert_close( NASA_VIDEO.get_duration_seconds_by_range(0, 10, stream_index=stream_index), frames0_9.duration_seconds, + atol=1e-6, + rtol=1e-6, ) # test steps @@ -561,15 +559,19 @@ def test_get_frames_in_range(self, stream_index, device): ] ) assert_frames_equal(ref_frames0_8_2, frames0_8_2.data) - assert_tensor_close( + torch.testing.assert_close( NASA_VIDEO.get_pts_seconds_by_range(0, 10, 2, stream_index=stream_index), frames0_8_2.pts_seconds, + atol=1e-6, + rtol=1e-6, ) - assert_tensor_close( + torch.testing.assert_close( NASA_VIDEO.get_duration_seconds_by_range( 0, 10, 2, stream_index=stream_index ), frames0_8_2.duration_seconds, + atol=1e-6, + rtol=1e-6, ) # test numpy.int64 for indices @@ -584,8 +586,10 @@ def test_get_frames_in_range(self, stream_index, device): empty_frames.data, NASA_VIDEO.get_empty_chw_tensor(stream_index=stream_index).to(device), ) - assert_tensor_close(empty_frames.pts_seconds, NASA_VIDEO.empty_pts_seconds) - assert_tensor_close( + torch.testing.assert_close( + empty_frames.pts_seconds, NASA_VIDEO.empty_pts_seconds + ) + torch.testing.assert_close( empty_frames.duration_seconds, NASA_VIDEO.empty_duration_seconds ) diff --git a/test/utils.py b/test/utils.py index 67bb815d..8ab9602f 100644 --- a/test/utils.py +++ b/test/utils.py @@ -41,15 +41,6 @@ def assert_frames_equal(*args, **kwargs): torch.testing.assert_close(*args, **kwargs, atol=absolute_tolerance, rtol=0) -# For use with floating point metadata, or in other instances where we are not confident -# that reference and test tensors can be exactly equal. This is true for pts and duration -# in seconds, as the reference values are from ffprobe's JSON output. In that case, it is -# limiting the floating point precision when printing the value as a string. The value from -# JSON and the value we retrieve during decoding are not exactly the same. -def assert_tensor_close(*args, **kwargs): - torch.testing.assert_close(*args, **kwargs, atol=1e-6, rtol=1e-6) - - def in_fbcode() -> bool: return os.environ.get("IN_FBCODE_TORCHCODEC") == "1"