Skip to content

Commit

Permalink
Update existing calls
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Dec 2, 2024
1 parent ca7cf75 commit d0798ad
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 14 deletions.
18 changes: 10 additions & 8 deletions test/decoders/test_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from torchcodec.decoders import _core, VideoDecoder

from ..utils import (
assert_tensor_close,
assert_frames_equal,
assert_tensor_close,
cpu_and_cuda,
H265_VIDEO,
NASA_VIDEO,
Expand Down Expand Up @@ -584,8 +584,8 @@ def test_get_frames_in_range(self, stream_index, device):
empty_frames.data,
NASA_VIDEO.get_empty_chw_tensor(stream_index=stream_index).to(device),
)
assert_frames_equal(empty_frames.pts_seconds, NASA_VIDEO.empty_pts_seconds)
assert_frames_equal(
assert_tensor_close(empty_frames.pts_seconds, NASA_VIDEO.empty_pts_seconds)
assert_tensor_close(
empty_frames.duration_seconds, NASA_VIDEO.empty_duration_seconds
)

Expand Down Expand Up @@ -731,12 +731,14 @@ def test_get_frames_by_pts_in_range(self, stream_index, device):
empty_frame.data,
NASA_VIDEO.get_empty_chw_tensor(stream_index=stream_index).to(device),
)
assert_frames_equal(
empty_frame.pts_seconds,
NASA_VIDEO.empty_pts_seconds,
torch.testing.assert_close(
empty_frame.pts_seconds, NASA_VIDEO.empty_pts_seconds, atol=0, rtol=0
)
assert_frames_equal(
empty_frame.duration_seconds, NASA_VIDEO.empty_duration_seconds
torch.testing.assert_close(
empty_frame.duration_seconds,
NASA_VIDEO.empty_duration_seconds,
atol=0,
rtol=0,
)

# Start and stop seconds land within the first frame.
Expand Down
8 changes: 4 additions & 4 deletions test/decoders/test_video_decoder_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,15 +258,15 @@ def test_pts_apis_against_index_ref(self, device):
*[get_frame_at_pts(decoder, seconds=pts) for pts in all_pts_seconds_ref]
)
pts_seconds = torch.tensor(pts_seconds)
assert_frames_equal(pts_seconds, all_pts_seconds_ref)
torch.testing.assert_close(pts_seconds, all_pts_seconds_ref, atol=0, rtol=0)

_, pts_seconds, _ = get_frames_by_pts_in_range(
decoder,
stream_index=stream_index,
start_seconds=0,
stop_seconds=all_pts_seconds_ref[-1] + 1e-4,
)
assert_frames_equal(pts_seconds, all_pts_seconds_ref)
torch.testing.assert_close(pts_seconds, all_pts_seconds_ref, atol=0, rtol=0)

_, pts_seconds, _ = zip(
*[
Expand All @@ -280,12 +280,12 @@ def test_pts_apis_against_index_ref(self, device):
]
)
pts_seconds = torch.tensor(pts_seconds)
assert_frames_equal(pts_seconds, all_pts_seconds_ref)
torch.testing.assert_close(pts_seconds, all_pts_seconds_ref, atol=0, rtol=0)

_, pts_seconds, _ = get_frames_by_pts(
decoder, stream_index=stream_index, timestamps=all_pts_seconds_ref.tolist()
)
assert_frames_equal(pts_seconds, all_pts_seconds_ref)
torch.testing.assert_close(pts_seconds, all_pts_seconds_ref, atol=0, rtol=0)

@pytest.mark.parametrize("device", cpu_and_cuda())
def test_get_frames_in_range(self, device):
Expand Down
8 changes: 6 additions & 2 deletions test/samplers/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,8 +437,12 @@ def test_random_sampler_randomness(sampler):

for clip_1, clip_2 in zip(clips_1, clips_2):
assert_frames_equal(clip_1.data, clip_2.data)
assert_frames_equal(clip_1.pts_seconds, clip_2.pts_seconds)
assert_frames_equal(clip_1.duration_seconds, clip_2.duration_seconds)
torch.testing.assert_close(
clip_1.pts_seconds, clip_2.pts_seconds, rtol=0, atol=0
)
torch.testing.assert_close(
clip_1.duration_seconds, clip_2.duration_seconds, rtol=0, atol=0
)

# Call with a different seed, expect different results
torch.manual_seed(1)
Expand Down

0 comments on commit d0798ad

Please sign in to comment.