Skip to content

Commit

Permalink
Delete assert_tensor_close util
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Dec 2, 2024
1 parent d0798ad commit 2cb26b0
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 22 deletions.
30 changes: 17 additions & 13 deletions test/decoders/test_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,7 @@

from torchcodec.decoders import _core, VideoDecoder

from ..utils import (
assert_frames_equal,
assert_tensor_close,
cpu_and_cuda,
H265_VIDEO,
NASA_VIDEO,
)
from ..utils import assert_frames_equal, cpu_and_cuda, H265_VIDEO, NASA_VIDEO


class TestVideoDecoder:
Expand Down Expand Up @@ -538,13 +532,17 @@ def test_get_frames_in_range(self, stream_index, device):
]
)
assert_frames_equal(ref_frames0_9, frames0_9.data)
assert_tensor_close(
torch.testing.assert_close(
NASA_VIDEO.get_pts_seconds_by_range(0, 10, stream_index=stream_index),
frames0_9.pts_seconds,
atol=1e-6,
rtol=1e-6,
)
assert_tensor_close(
torch.testing.assert_close(
NASA_VIDEO.get_duration_seconds_by_range(0, 10, stream_index=stream_index),
frames0_9.duration_seconds,
atol=1e-6,
rtol=1e-6,
)

# test steps
Expand All @@ -561,15 +559,19 @@ def test_get_frames_in_range(self, stream_index, device):
]
)
assert_frames_equal(ref_frames0_8_2, frames0_8_2.data)
assert_tensor_close(
torch.testing.assert_close(
NASA_VIDEO.get_pts_seconds_by_range(0, 10, 2, stream_index=stream_index),
frames0_8_2.pts_seconds,
atol=1e-6,
rtol=1e-6,
)
assert_tensor_close(
torch.testing.assert_close(
NASA_VIDEO.get_duration_seconds_by_range(
0, 10, 2, stream_index=stream_index
),
frames0_8_2.duration_seconds,
atol=1e-6,
rtol=1e-6,
)

# test numpy.int64 for indices
Expand All @@ -584,8 +586,10 @@ def test_get_frames_in_range(self, stream_index, device):
empty_frames.data,
NASA_VIDEO.get_empty_chw_tensor(stream_index=stream_index).to(device),
)
assert_tensor_close(empty_frames.pts_seconds, NASA_VIDEO.empty_pts_seconds)
assert_tensor_close(
torch.testing.assert_close(
empty_frames.pts_seconds, NASA_VIDEO.empty_pts_seconds
)
torch.testing.assert_close(
empty_frames.duration_seconds, NASA_VIDEO.empty_duration_seconds
)

Expand Down
9 changes: 0 additions & 9 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,6 @@ def assert_frames_equal(*args, **kwargs):
torch.testing.assert_close(*args, **kwargs, atol=absolute_tolerance, rtol=0)


# For use with floating point metadata, or in other instances where we are not confident
# that reference and test tensors can be exactly equal. This is true for pts and duration
# in seconds, as the reference values are from ffprobe's JSON output. In that case, it is
# limiting the floating point precision when printing the value as a string. The value from
# JSON and the value we retrieve during decoding are not exactly the same.
def assert_tensor_close(*args, **kwargs):
torch.testing.assert_close(*args, **kwargs, atol=1e-6, rtol=1e-6)


def in_fbcode() -> bool:
return os.environ.get("IN_FBCODE_TORCHCODEC") == "1"

Expand Down

0 comments on commit 2cb26b0

Please sign in to comment.