Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torchcodec] refactor test utils to be based around dataclasses #57

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions test/decoders/VideoDecoderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,35 +166,35 @@ TEST_P(VideoDecoderTest, ReturnsFirstTwoFramesOfVideo) {
createDecoderFromPath(path, GetParam());
ourDecoder->addVideoStreamDecoder(-1);
auto output = ourDecoder->getNextDecodedOutput();
torch::Tensor tensor1FromOurDecoder = output.frame;
EXPECT_EQ(tensor1FromOurDecoder.sizes(), std::vector<long>({270, 480, 3}));
torch::Tensor tensor0FromOurDecoder = output.frame;
EXPECT_EQ(tensor0FromOurDecoder.sizes(), std::vector<long>({270, 480, 3}));
EXPECT_EQ(output.ptsSeconds, 0.0);
EXPECT_EQ(output.pts, 0);
output = ourDecoder->getNextDecodedOutput();
torch::Tensor tensor2FromOurDecoder = output.frame;
EXPECT_EQ(tensor2FromOurDecoder.sizes(), std::vector<long>({270, 480, 3}));
torch::Tensor tensor1FromOurDecoder = output.frame;
EXPECT_EQ(tensor1FromOurDecoder.sizes(), std::vector<long>({270, 480, 3}));
EXPECT_EQ(output.ptsSeconds, 1'001. / 30'000);
EXPECT_EQ(output.pts, 1001);

torch::Tensor tensor0FromFFMPEG =
readTensorFromDisk("nasa_13013.mp4.frame000000.pt");
torch::Tensor tensor1FromFFMPEG =
readTensorFromDisk("nasa_13013.mp4.frame000001.pt");
torch::Tensor tensor2FromFFMPEG =
readTensorFromDisk("nasa_13013.mp4.frame000002.pt");

EXPECT_EQ(tensor1FromFFMPEG.sizes(), std::vector<long>({270, 480, 3}));
EXPECT_TRUE(torch::equal(tensor0FromOurDecoder, tensor0FromFFMPEG));
EXPECT_TRUE(torch::equal(tensor1FromOurDecoder, tensor1FromFFMPEG));
EXPECT_TRUE(torch::equal(tensor2FromOurDecoder, tensor2FromFFMPEG));
EXPECT_TRUE(
torch::allclose(tensor1FromOurDecoder, tensor1FromFFMPEG, 0.1, 20));
EXPECT_EQ(tensor2FromFFMPEG.sizes(), std::vector<long>({270, 480, 3}));
torch::allclose(tensor0FromOurDecoder, tensor0FromFFMPEG, 0.1, 20));
EXPECT_EQ(tensor1FromFFMPEG.sizes(), std::vector<long>({270, 480, 3}));
EXPECT_TRUE(
torch::allclose(tensor2FromOurDecoder, tensor2FromFFMPEG, 0.1, 20));
torch::allclose(tensor1FromOurDecoder, tensor1FromFFMPEG, 0.1, 20));

if (FLAGS_dump_frames_for_debugging) {
dumpTensorToDisk(tensor0FromFFMPEG, "tensor0FromFFMPEG.pt");
dumpTensorToDisk(tensor1FromFFMPEG, "tensor1FromFFMPEG.pt");
dumpTensorToDisk(tensor2FromFFMPEG, "tensor2FromFFMPEG.pt");
dumpTensorToDisk(tensor0FromOurDecoder, "tensor0FromOurDecoder.pt");
dumpTensorToDisk(tensor1FromOurDecoder, "tensor1FromOurDecoder.pt");
dumpTensorToDisk(tensor2FromOurDecoder, "tensor2FromOurDecoder.pt");
}
}

Expand All @@ -211,13 +211,13 @@ TEST_P(VideoDecoderTest, DecodesFramesInABatchInNHWC) {
auto tensor = output.frames;
EXPECT_EQ(tensor.sizes(), std::vector<long>({2, 270, 480, 3}));

torch::Tensor tensor1FromFFMPEG =
readTensorFromDisk("nasa_13013.mp4.frame000001.pt");
torch::Tensor tensor2FromFFMPEG =
torch::Tensor tensor0FromFFMPEG =
readTensorFromDisk("nasa_13013.mp4.frame000000.pt");
torch::Tensor tensorTime6FromFFMPEG =
readTensorFromDisk("nasa_13013.mp4.time6.000000.pt");

EXPECT_TRUE(torch::equal(tensor[0], tensor1FromFFMPEG));
EXPECT_TRUE(torch::equal(tensor[1], tensor2FromFFMPEG));
EXPECT_TRUE(torch::equal(tensor[0], tensor0FromFFMPEG));
EXPECT_TRUE(torch::equal(tensor[1], tensorTime6FromFFMPEG));
}

TEST_P(VideoDecoderTest, DecodesFramesInABatchInNCHW) {
Expand All @@ -235,14 +235,14 @@ TEST_P(VideoDecoderTest, DecodesFramesInABatchInNCHW) {
auto tensor = output.frames;
EXPECT_EQ(tensor.sizes(), std::vector<long>({2, 3, 270, 480}));

torch::Tensor tensor1FromFFMPEG =
readTensorFromDisk("nasa_13013.mp4.frame000001.pt");
torch::Tensor tensor2FromFFMPEG =
torch::Tensor tensor0FromFFMPEG =
readTensorFromDisk("nasa_13013.mp4.frame000000.pt");
torch::Tensor tensorTime6FromFFMPEG =
readTensorFromDisk("nasa_13013.mp4.time6.000000.pt");

tensor = tensor.permute({0, 2, 3, 1});
EXPECT_TRUE(torch::equal(tensor[0], tensor1FromFFMPEG));
EXPECT_TRUE(torch::equal(tensor[1], tensor2FromFFMPEG));
EXPECT_TRUE(torch::equal(tensor[0], tensor0FromFFMPEG));
EXPECT_TRUE(torch::equal(tensor[1], tensorTime6FromFFMPEG));
}

TEST_P(VideoDecoderTest, SeeksCloseToEof) {
Expand Down
233 changes: 142 additions & 91 deletions test/decoders/simple_video_decoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,18 @@

from torchcodec.decoders import _core, SimpleVideoDecoder

from ..test_utils import (
assert_equal,
EMPTY_REF_TENSOR,
get_reference_video_path,
get_reference_video_tensor,
load_tensor_from_file,
REF_DIMS,
)
from ..test_utils import assert_tensor_equal, NASA_VIDEO


class TestSimpleDecoder:
@pytest.mark.parametrize("source_kind", ("path", "tensor", "bytes"))
def test_create(self, source_kind):
if source_kind == "path":
source = str(get_reference_video_path())
source = str(NASA_VIDEO.path)
elif source_kind == "tensor":
source = get_reference_video_tensor()
source = NASA_VIDEO.to_tensor()
elif source_kind == "bytes":
path = str(get_reference_video_path())
path = str(NASA_VIDEO.path)
with open(path, "rb") as f:
source = f.read()
else:
Expand All @@ -42,99 +35,157 @@ def test_create_fails(self):
decoder = SimpleVideoDecoder(123) # noqa

def test_getitem_int(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))
decoder = SimpleVideoDecoder(str(NASA_VIDEO.path))

ref_frame0 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt")
ref_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt")
ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt")
ref_frame0 = NASA_VIDEO.get_tensor_by_index(0)
ref_frame1 = NASA_VIDEO.get_tensor_by_index(1)
ref_frame180 = NASA_VIDEO.get_tensor_by_name("time6.000000")
ref_frame_last = NASA_VIDEO.get_tensor_by_name("time12.979633")

assert_equal(ref_frame0, decoder[0])
assert_equal(ref_frame1, decoder[1])
assert_equal(ref_frame180, decoder[180])
assert_equal(ref_frame_last, decoder[-1])
assert_tensor_equal(ref_frame0, decoder[0])
assert_tensor_equal(ref_frame1, decoder[1])
assert_tensor_equal(ref_frame180, decoder[180])
assert_tensor_equal(ref_frame_last, decoder[-1])

def test_getitem_slice(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))

ref_frames0_9 = [
load_tensor_from_file(f"nasa_13013.mp4.frame{i + 1:06d}.pt")
for i in range(0, 9)
]

# Ensure that the degenerate case of a range of size 1 works; note that we get
# a tensor which CONTAINS a single frame, rather than a tensor that itself IS a
# single frame. Hence we have to access the 0th element of the return tensor.
slice_0 = decoder[0:1]
assert slice_0.shape == torch.Size([1, *REF_DIMS])
assert_equal(ref_frames0_9[0], slice_0[0])

slice_4 = decoder[4:5]
assert slice_4.shape == torch.Size([1, *REF_DIMS])
assert_equal(ref_frames0_9[4], slice_4[0])

slice_8 = decoder[8:9]
assert slice_8.shape == torch.Size([1, *REF_DIMS])
assert_equal(ref_frames0_9[8], slice_8[0])

ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
slice_180 = decoder[180:181]
assert slice_180.shape == torch.Size([1, *REF_DIMS])
assert_equal(ref_frame180, slice_180[0])
decoder = SimpleVideoDecoder(str(NASA_VIDEO.path))

# ensure that the degenerate case of a range of size 1 works

ref0 = NASA_VIDEO.get_stacked_tensor_by_range(0, 1)
slice0 = decoder[0:1]
assert slice0.shape == torch.Size(
[
1,
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.num_color_channels,
]
)
assert_tensor_equal(ref0, slice0)

ref4 = NASA_VIDEO.get_stacked_tensor_by_range(4, 5)
slice4 = decoder[4:5]
assert slice4.shape == torch.Size(
[
1,
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.num_color_channels,
]
)
assert_tensor_equal(ref4, slice4)

ref8 = NASA_VIDEO.get_stacked_tensor_by_range(8, 9)
slice8 = decoder[8:9]
assert slice8.shape == torch.Size(
[
1,
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.num_color_channels,
]
)
assert_tensor_equal(ref8, slice8)

ref180 = NASA_VIDEO.get_tensor_by_name("time6.000000")
slice180 = decoder[180:181]
assert slice180.shape == torch.Size(
[
1,
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.num_color_channels,
]
)
assert_tensor_equal(ref180, slice180[0])

# contiguous ranges
slice_frames0_9 = decoder[0:9]
assert slice_frames0_9.shape == torch.Size([9, *REF_DIMS])
for ref_frame, slice_frame in zip(ref_frames0_9, slice_frames0_9):
assert_equal(ref_frame, slice_frame)

slice_frames4_8 = decoder[4:8]
assert slice_frames4_8.shape == torch.Size([4, *REF_DIMS])
for ref_frame, slice_frame in zip(ref_frames0_9[4:8], slice_frames4_8):
assert_equal(ref_frame, slice_frame)
ref0_9 = NASA_VIDEO.get_stacked_tensor_by_range(0, 9)
slice0_9 = decoder[0:9]
assert slice0_9.shape == torch.Size(
[
9,
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.num_color_channels,
]
)
assert_tensor_equal(ref0_9, slice0_9)

ref4_8 = NASA_VIDEO.get_stacked_tensor_by_range(4, 8)
slice4_8 = decoder[4:8]
assert slice4_8.shape == torch.Size(
[
4,
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.num_color_channels,
]
)
assert_tensor_equal(ref4_8, slice4_8)

# ranges with a stride
ref_frames15_35 = [
load_tensor_from_file(f"nasa_13013.mp4.frame{i:06d}.pt")
for i in range(15, 36, 5)
]
slice_frames15_35 = decoder[15:36:5]
assert slice_frames15_35.shape == torch.Size([5, *REF_DIMS])
for ref_frame, slice_frame in zip(ref_frames15_35, slice_frames15_35):
assert_equal(ref_frame, slice_frame)

slice_frames0_9_2 = decoder[0:9:2]
assert slice_frames0_9_2.shape == torch.Size([5, *REF_DIMS])
for ref_frame, slice_frame in zip(ref_frames0_9[0:0:2], slice_frames0_9_2):
assert_equal(ref_frame, slice_frame)
ref15_35 = NASA_VIDEO.get_stacked_tensor_by_range(15, 36, 5)
slice15_35 = decoder[15:36:5]
assert slice15_35.shape == torch.Size(
[
5,
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.num_color_channels,
]
)
assert_tensor_equal(ref15_35, slice15_35)

ref0_9_2 = NASA_VIDEO.get_stacked_tensor_by_range(0, 9, 2)
slice0_9_2 = decoder[0:9:2]
assert slice0_9_2.shape == torch.Size(
[
5,
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.num_color_channels,
]
)
assert_tensor_equal(ref0_9_2, slice0_9_2)

# negative numbers in the slice
ref_frames386_389 = [
load_tensor_from_file(f"nasa_13013.mp4.frame{i:06d}.pt")
for i in range(386, 390)
]

slice_frames386_389 = decoder[-4:]
assert slice_frames386_389.shape == torch.Size([4, *REF_DIMS])
for ref_frame, slice_frame in zip(ref_frames386_389[-4:], slice_frames386_389):
assert_equal(ref_frame, slice_frame)
ref386_389 = NASA_VIDEO.get_stacked_tensor_by_range(386, 390)
slice386_389 = decoder[-4:]
assert slice386_389.shape == torch.Size(
[
4,
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.num_color_channels,
]
)
assert_tensor_equal(ref386_389, slice386_389)

# an empty range is valid!
empty_frame = decoder[5:5]
assert_equal(empty_frame, EMPTY_REF_TENSOR)
assert_tensor_equal(empty_frame, NASA_VIDEO.empty_hwc_tensor)

# slices that are out-of-range are also valid - they return an empty tensor
also_empty = decoder[10000:]
assert_equal(also_empty, EMPTY_REF_TENSOR)
assert_tensor_equal(also_empty, NASA_VIDEO.empty_hwc_tensor)

# should be just a copy
all_frames = decoder[:]
assert all_frames.shape == torch.Size([len(decoder), *REF_DIMS])
assert all_frames.shape == torch.Size(
[
len(decoder),
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.num_color_channels,
]
)
for sliced, ref in zip(all_frames, decoder):
assert_equal(sliced, ref)
assert_tensor_equal(sliced, ref)

def test_getitem_fails(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))
decoder = SimpleVideoDecoder(str(NASA_VIDEO.path))

with pytest.raises(IndexError, match="out of bounds"):
frame = decoder[1000] # noqa
Expand All @@ -146,22 +197,22 @@ def test_getitem_fails(self):
frame = decoder["0"] # noqa

def test_next(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))
decoder = SimpleVideoDecoder(str(NASA_VIDEO.path))

ref_frame0 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt")
ref_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt")
ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt")
ref_frame0 = NASA_VIDEO.get_tensor_by_index(0)
ref_frame1 = NASA_VIDEO.get_tensor_by_index(1)
ref_frame180 = NASA_VIDEO.get_tensor_by_name("time6.000000")
ref_frame_last = NASA_VIDEO.get_tensor_by_name("time12.979633")

for i, frame in enumerate(decoder):
if i == 0:
assert_equal(ref_frame0, frame)
assert_tensor_equal(ref_frame0, frame)
elif i == 1:
assert_equal(ref_frame1, frame)
assert_tensor_equal(ref_frame1, frame)
elif i == 180:
assert_equal(ref_frame180, frame)
assert_tensor_equal(ref_frame180, frame)
elif i == 389:
assert_equal(ref_frame_last, frame)
assert_tensor_equal(ref_frame_last, frame)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions test/decoders/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
StreamMetadata,
)

from ..test_utils import get_reference_video_path
from ..test_utils import NASA_VIDEO


def test_get_video_metadata():
decoder = create_from_file(str(get_reference_video_path()))
decoder = create_from_file(str(NASA_VIDEO.path))
metadata = get_video_metadata(decoder)
assert len(metadata.streams) == 6
assert metadata.best_video_stream_index == 3
Expand Down
Loading
Loading