Skip to content

Commit

Permalink
[torchcodec] refactor test utils to be based around dataclasses (#57)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #57

Refactors our testing utilities to treat the actual reference file as a first-class concept by making it an object; all operations we perform on a reference file are then operations on that object. Some principles I was trying to follow:

1. There should be one clear definition of a reference media file, with all of its important parameters in one place with names that have obvious semantic meaning.
2. Operations that are conceptually connected to a reference media file should be methods on that object.
3. Tests should only have to use the object defined in 1.
4. Formalize the patterns we've already established in how we name reference files.
5. Make adding new reference files easy and obvious.

Right now, we can only support reference tensors by timestamp with a generic approach. I think we should try to make that more of an explicit pattern based on the pts value, but I'm not sure exactly how to do that right now. That will probably require changing some of our current reference file names.

This diff also addresses the awkwardness in how we numbered our reference frames: the first 10 frames used to be 1-indexed, and the rest 0-indexed. This diff makes all frames 0-indexed.

Also in the future, the reference file *generation* should probably use these definitions as well. That will ensure we keep everything consistent.

Reviewed By: ahmadsharif1

Differential Revision: D59161329

fbshipit-source-id: f60b733b1bdeb672832e221c30761dea1d4112a9
  • Loading branch information
scotts authored and facebook-github-bot committed Jul 3, 2024
1 parent faf73b0 commit 7af94c7
Show file tree
Hide file tree
Showing 18 changed files with 320 additions and 249 deletions.
44 changes: 22 additions & 22 deletions test/decoders/VideoDecoderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,35 +166,35 @@ TEST_P(VideoDecoderTest, ReturnsFirstTwoFramesOfVideo) {
createDecoderFromPath(path, GetParam());
ourDecoder->addVideoStreamDecoder(-1);
auto output = ourDecoder->getNextDecodedOutput();
torch::Tensor tensor1FromOurDecoder = output.frame;
EXPECT_EQ(tensor1FromOurDecoder.sizes(), std::vector<long>({270, 480, 3}));
torch::Tensor tensor0FromOurDecoder = output.frame;
EXPECT_EQ(tensor0FromOurDecoder.sizes(), std::vector<long>({270, 480, 3}));
EXPECT_EQ(output.ptsSeconds, 0.0);
EXPECT_EQ(output.pts, 0);
output = ourDecoder->getNextDecodedOutput();
torch::Tensor tensor2FromOurDecoder = output.frame;
EXPECT_EQ(tensor2FromOurDecoder.sizes(), std::vector<long>({270, 480, 3}));
torch::Tensor tensor1FromOurDecoder = output.frame;
EXPECT_EQ(tensor1FromOurDecoder.sizes(), std::vector<long>({270, 480, 3}));
EXPECT_EQ(output.ptsSeconds, 1'001. / 30'000);
EXPECT_EQ(output.pts, 1001);

torch::Tensor tensor0FromFFMPEG =
readTensorFromDisk("nasa_13013.mp4.frame000000.pt");
torch::Tensor tensor1FromFFMPEG =
readTensorFromDisk("nasa_13013.mp4.frame000001.pt");
torch::Tensor tensor2FromFFMPEG =
readTensorFromDisk("nasa_13013.mp4.frame000002.pt");

EXPECT_EQ(tensor1FromFFMPEG.sizes(), std::vector<long>({270, 480, 3}));
EXPECT_TRUE(torch::equal(tensor0FromOurDecoder, tensor0FromFFMPEG));
EXPECT_TRUE(torch::equal(tensor1FromOurDecoder, tensor1FromFFMPEG));
EXPECT_TRUE(torch::equal(tensor2FromOurDecoder, tensor2FromFFMPEG));
EXPECT_TRUE(
torch::allclose(tensor1FromOurDecoder, tensor1FromFFMPEG, 0.1, 20));
EXPECT_EQ(tensor2FromFFMPEG.sizes(), std::vector<long>({270, 480, 3}));
torch::allclose(tensor0FromOurDecoder, tensor0FromFFMPEG, 0.1, 20));
EXPECT_EQ(tensor1FromFFMPEG.sizes(), std::vector<long>({270, 480, 3}));
EXPECT_TRUE(
torch::allclose(tensor2FromOurDecoder, tensor2FromFFMPEG, 0.1, 20));
torch::allclose(tensor1FromOurDecoder, tensor1FromFFMPEG, 0.1, 20));

if (FLAGS_dump_frames_for_debugging) {
dumpTensorToDisk(tensor0FromFFMPEG, "tensor0FromFFMPEG.pt");
dumpTensorToDisk(tensor1FromFFMPEG, "tensor1FromFFMPEG.pt");
dumpTensorToDisk(tensor2FromFFMPEG, "tensor2FromFFMPEG.pt");
dumpTensorToDisk(tensor0FromOurDecoder, "tensor0FromOurDecoder.pt");
dumpTensorToDisk(tensor1FromOurDecoder, "tensor1FromOurDecoder.pt");
dumpTensorToDisk(tensor2FromOurDecoder, "tensor2FromOurDecoder.pt");
}
}

Expand All @@ -211,13 +211,13 @@ TEST_P(VideoDecoderTest, DecodesFramesInABatchInNHWC) {
auto tensor = output.frames;
EXPECT_EQ(tensor.sizes(), std::vector<long>({2, 270, 480, 3}));

torch::Tensor tensor1FromFFMPEG =
readTensorFromDisk("nasa_13013.mp4.frame000001.pt");
torch::Tensor tensor2FromFFMPEG =
torch::Tensor tensor0FromFFMPEG =
readTensorFromDisk("nasa_13013.mp4.frame000000.pt");
torch::Tensor tensorTime6FromFFMPEG =
readTensorFromDisk("nasa_13013.mp4.time6.000000.pt");

EXPECT_TRUE(torch::equal(tensor[0], tensor1FromFFMPEG));
EXPECT_TRUE(torch::equal(tensor[1], tensor2FromFFMPEG));
EXPECT_TRUE(torch::equal(tensor[0], tensor0FromFFMPEG));
EXPECT_TRUE(torch::equal(tensor[1], tensorTime6FromFFMPEG));
}

TEST_P(VideoDecoderTest, DecodesFramesInABatchInNCHW) {
Expand All @@ -235,14 +235,14 @@ TEST_P(VideoDecoderTest, DecodesFramesInABatchInNCHW) {
auto tensor = output.frames;
EXPECT_EQ(tensor.sizes(), std::vector<long>({2, 3, 270, 480}));

torch::Tensor tensor1FromFFMPEG =
readTensorFromDisk("nasa_13013.mp4.frame000001.pt");
torch::Tensor tensor2FromFFMPEG =
torch::Tensor tensor0FromFFMPEG =
readTensorFromDisk("nasa_13013.mp4.frame000000.pt");
torch::Tensor tensorTime6FromFFMPEG =
readTensorFromDisk("nasa_13013.mp4.time6.000000.pt");

tensor = tensor.permute({0, 2, 3, 1});
EXPECT_TRUE(torch::equal(tensor[0], tensor1FromFFMPEG));
EXPECT_TRUE(torch::equal(tensor[1], tensor2FromFFMPEG));
EXPECT_TRUE(torch::equal(tensor[0], tensor0FromFFMPEG));
EXPECT_TRUE(torch::equal(tensor[1], tensorTime6FromFFMPEG));
}

TEST_P(VideoDecoderTest, SeeksCloseToEof) {
Expand Down
233 changes: 142 additions & 91 deletions test/decoders/simple_video_decoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,18 @@

from torchcodec.decoders import _core, SimpleVideoDecoder

from ..test_utils import (
assert_equal,
EMPTY_REF_TENSOR,
get_reference_video_path,
get_reference_video_tensor,
load_tensor_from_file,
REF_DIMS,
)
from ..test_utils import assert_tensor_equal, NASA_VIDEO


class TestSimpleDecoder:
@pytest.mark.parametrize("source_kind", ("path", "tensor", "bytes"))
def test_create(self, source_kind):
if source_kind == "path":
source = str(get_reference_video_path())
source = str(NASA_VIDEO.path)
elif source_kind == "tensor":
source = get_reference_video_tensor()
source = NASA_VIDEO.to_tensor()
elif source_kind == "bytes":
path = str(get_reference_video_path())
path = str(NASA_VIDEO.path)
with open(path, "rb") as f:
source = f.read()
else:
Expand All @@ -42,99 +35,157 @@ def test_create_fails(self):
decoder = SimpleVideoDecoder(123) # noqa

def test_getitem_int(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))
decoder = SimpleVideoDecoder(str(NASA_VIDEO.path))

ref_frame0 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt")
ref_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt")
ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt")
ref_frame0 = NASA_VIDEO.get_tensor_by_index(0)
ref_frame1 = NASA_VIDEO.get_tensor_by_index(1)
ref_frame180 = NASA_VIDEO.get_tensor_by_name("time6.000000")
ref_frame_last = NASA_VIDEO.get_tensor_by_name("time12.979633")

assert_equal(ref_frame0, decoder[0])
assert_equal(ref_frame1, decoder[1])
assert_equal(ref_frame180, decoder[180])
assert_equal(ref_frame_last, decoder[-1])
assert_tensor_equal(ref_frame0, decoder[0])
assert_tensor_equal(ref_frame1, decoder[1])
assert_tensor_equal(ref_frame180, decoder[180])
assert_tensor_equal(ref_frame_last, decoder[-1])

def test_getitem_slice(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))

ref_frames0_9 = [
load_tensor_from_file(f"nasa_13013.mp4.frame{i + 1:06d}.pt")
for i in range(0, 9)
]

# Ensure that the degenerate case of a range of size 1 works; note that we get
# a tensor which CONTAINS a single frame, rather than a tensor that itself IS a
# single frame. Hence we have to access the 0th element of the return tensor.
slice_0 = decoder[0:1]
assert slice_0.shape == torch.Size([1, *REF_DIMS])
assert_equal(ref_frames0_9[0], slice_0[0])

slice_4 = decoder[4:5]
assert slice_4.shape == torch.Size([1, *REF_DIMS])
assert_equal(ref_frames0_9[4], slice_4[0])

slice_8 = decoder[8:9]
assert slice_8.shape == torch.Size([1, *REF_DIMS])
assert_equal(ref_frames0_9[8], slice_8[0])

ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
slice_180 = decoder[180:181]
assert slice_180.shape == torch.Size([1, *REF_DIMS])
assert_equal(ref_frame180, slice_180[0])
decoder = SimpleVideoDecoder(str(NASA_VIDEO.path))

# ensure that the degenerate case of a range of size 1 works

ref0 = NASA_VIDEO.get_stacked_tensor_by_range(0, 1)
slice0 = decoder[0:1]
assert slice0.shape == torch.Size(
[
1,
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.num_color_channels,
]
)
assert_tensor_equal(ref0, slice0)

ref4 = NASA_VIDEO.get_stacked_tensor_by_range(4, 5)
slice4 = decoder[4:5]
assert slice4.shape == torch.Size(
[
1,
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.num_color_channels,
]
)
assert_tensor_equal(ref4, slice4)

ref8 = NASA_VIDEO.get_stacked_tensor_by_range(8, 9)
slice8 = decoder[8:9]
assert slice8.shape == torch.Size(
[
1,
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.num_color_channels,
]
)
assert_tensor_equal(ref8, slice8)

ref180 = NASA_VIDEO.get_tensor_by_name("time6.000000")
slice180 = decoder[180:181]
assert slice180.shape == torch.Size(
[
1,
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.num_color_channels,
]
)
assert_tensor_equal(ref180, slice180[0])

# contiguous ranges
slice_frames0_9 = decoder[0:9]
assert slice_frames0_9.shape == torch.Size([9, *REF_DIMS])
for ref_frame, slice_frame in zip(ref_frames0_9, slice_frames0_9):
assert_equal(ref_frame, slice_frame)

slice_frames4_8 = decoder[4:8]
assert slice_frames4_8.shape == torch.Size([4, *REF_DIMS])
for ref_frame, slice_frame in zip(ref_frames0_9[4:8], slice_frames4_8):
assert_equal(ref_frame, slice_frame)
ref0_9 = NASA_VIDEO.get_stacked_tensor_by_range(0, 9)
slice0_9 = decoder[0:9]
assert slice0_9.shape == torch.Size(
[
9,
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.num_color_channels,
]
)
assert_tensor_equal(ref0_9, slice0_9)

ref4_8 = NASA_VIDEO.get_stacked_tensor_by_range(4, 8)
slice4_8 = decoder[4:8]
assert slice4_8.shape == torch.Size(
[
4,
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.num_color_channels,
]
)
assert_tensor_equal(ref4_8, slice4_8)

# ranges with a stride
ref_frames15_35 = [
load_tensor_from_file(f"nasa_13013.mp4.frame{i:06d}.pt")
for i in range(15, 36, 5)
]
slice_frames15_35 = decoder[15:36:5]
assert slice_frames15_35.shape == torch.Size([5, *REF_DIMS])
for ref_frame, slice_frame in zip(ref_frames15_35, slice_frames15_35):
assert_equal(ref_frame, slice_frame)

slice_frames0_9_2 = decoder[0:9:2]
assert slice_frames0_9_2.shape == torch.Size([5, *REF_DIMS])
for ref_frame, slice_frame in zip(ref_frames0_9[0:0:2], slice_frames0_9_2):
assert_equal(ref_frame, slice_frame)
ref15_35 = NASA_VIDEO.get_stacked_tensor_by_range(15, 36, 5)
slice15_35 = decoder[15:36:5]
assert slice15_35.shape == torch.Size(
[
5,
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.num_color_channels,
]
)
assert_tensor_equal(ref15_35, slice15_35)

ref0_9_2 = NASA_VIDEO.get_stacked_tensor_by_range(0, 9, 2)
slice0_9_2 = decoder[0:9:2]
assert slice0_9_2.shape == torch.Size(
[
5,
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.num_color_channels,
]
)
assert_tensor_equal(ref0_9_2, slice0_9_2)

# negative numbers in the slice
ref_frames386_389 = [
load_tensor_from_file(f"nasa_13013.mp4.frame{i:06d}.pt")
for i in range(386, 390)
]

slice_frames386_389 = decoder[-4:]
assert slice_frames386_389.shape == torch.Size([4, *REF_DIMS])
for ref_frame, slice_frame in zip(ref_frames386_389[-4:], slice_frames386_389):
assert_equal(ref_frame, slice_frame)
ref386_389 = NASA_VIDEO.get_stacked_tensor_by_range(386, 390)
slice386_389 = decoder[-4:]
assert slice386_389.shape == torch.Size(
[
4,
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.num_color_channels,
]
)
assert_tensor_equal(ref386_389, slice386_389)

# an empty range is valid!
empty_frame = decoder[5:5]
assert_equal(empty_frame, EMPTY_REF_TENSOR)
assert_tensor_equal(empty_frame, NASA_VIDEO.empty_hwc_tensor)

# slices that are out-of-range are also valid - they return an empty tensor
also_empty = decoder[10000:]
assert_equal(also_empty, EMPTY_REF_TENSOR)
assert_tensor_equal(also_empty, NASA_VIDEO.empty_hwc_tensor)

# should be just a copy
all_frames = decoder[:]
assert all_frames.shape == torch.Size([len(decoder), *REF_DIMS])
assert all_frames.shape == torch.Size(
[
len(decoder),
NASA_VIDEO.height,
NASA_VIDEO.width,
NASA_VIDEO.num_color_channels,
]
)
for sliced, ref in zip(all_frames, decoder):
assert_equal(sliced, ref)
assert_tensor_equal(sliced, ref)

def test_getitem_fails(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))
decoder = SimpleVideoDecoder(str(NASA_VIDEO.path))

with pytest.raises(IndexError, match="out of bounds"):
frame = decoder[1000] # noqa
Expand All @@ -146,22 +197,22 @@ def test_getitem_fails(self):
frame = decoder["0"] # noqa

def test_next(self):
decoder = SimpleVideoDecoder(str(get_reference_video_path()))
decoder = SimpleVideoDecoder(str(NASA_VIDEO.path))

ref_frame0 = load_tensor_from_file("nasa_13013.mp4.frame000001.pt")
ref_frame1 = load_tensor_from_file("nasa_13013.mp4.frame000002.pt")
ref_frame180 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
ref_frame_last = load_tensor_from_file("nasa_13013.mp4.time12.979633.pt")
ref_frame0 = NASA_VIDEO.get_tensor_by_index(0)
ref_frame1 = NASA_VIDEO.get_tensor_by_index(1)
ref_frame180 = NASA_VIDEO.get_tensor_by_name("time6.000000")
ref_frame_last = NASA_VIDEO.get_tensor_by_name("time12.979633")

for i, frame in enumerate(decoder):
if i == 0:
assert_equal(ref_frame0, frame)
assert_tensor_equal(ref_frame0, frame)
elif i == 1:
assert_equal(ref_frame1, frame)
assert_tensor_equal(ref_frame1, frame)
elif i == 180:
assert_equal(ref_frame180, frame)
assert_tensor_equal(ref_frame180, frame)
elif i == 389:
assert_equal(ref_frame_last, frame)
assert_tensor_equal(ref_frame_last, frame)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions test/decoders/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@
StreamMetadata,
)

from ..test_utils import get_reference_video_path
from ..test_utils import NASA_VIDEO


def test_get_video_metadata():
decoder = create_from_file(str(get_reference_video_path()))
decoder = create_from_file(str(NASA_VIDEO.path))
metadata = get_video_metadata(decoder)
assert len(metadata.streams) == 6
assert metadata.best_video_stream_index == 3
Expand Down
Loading

0 comments on commit 7af94c7

Please sign in to comment.