diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 100eea43..6417cfd1 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -14,13 +14,13 @@ repos:
       - id: check-added-large-files
         args: ['--maxkb=1000']
 
-  # - repo: https://github.com/omnilib/ufmt
-  #   rev: v2.6.0
-  #   hooks:
-  #     - id: ufmt
-  #       additional_dependencies:
-  #         - black == 24.4.2
-  #         - usort == 1.0.5
+  # - repo: https://github.com/omnilib/ufmt
+  #   rev: v2.6.0
+  #   hooks:
+  #     - id: ufmt
+  #       additional_dependencies:
+  #         - black == 24.4.2
+  #         - usort == 1.0.5
 
   - repo: https://github.com/PyCQA/flake8
     rev: 7.1.0
diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
index 0a842cd9..e66ac03b 100644
--- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
@@ -11,7 +11,6 @@ namespace facebook::torchcodec {
 // ==============================
 // Define the operators
 // ==============================
-
 // All instances of accepting the decoder as a tensor must be annotated with
 // `Tensor(a!)`. The `(a!)` part normally indicates that the tensor is being
 // mutated in place. We need it to make sure that torch.compile does not reorder
@@ -35,6 +34,9 @@ TORCH_LIBRARY(torchcodec_ns, m) {
   m.def(
       "get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> Tensor");
   m.def("get_json_metadata(Tensor(a!) decoder) -> str");
+  m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
+  m.def(
+      "get_stream_json_metadata(Tensor(a!) decoder, int stream_index) -> str");
   m.def("_get_json_ffmpeg_library_versions() -> str");
 }
 
@@ -159,6 +161,29 @@ std::string quoteValue(const std::string& value) {
   return "\"" + value + "\"";
 }
 
+// TODO: we should use a more robust way to serialize the metadata. There are a
+// few alternatives, but ultimately we are limited to what custom ops allow us
+// to return. Current ideas are to use a proper JSON library, or to pack all the
+// info into tensors. *If* we're OK to drop the export support for metadata, we
+// could also easily bind the C++ structs to Python with pybind11.
+std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
+  std::stringstream ss;
+  ss << "{\n";
+  auto it = metadataMap.begin();
+  while (it != metadataMap.end()) {
+    ss << "\"" << it->first << "\": " << it->second;
+    ++it;
+    if (it != metadataMap.end()) {
+      ss << ",\n";
+    } else {
+      ss << "\n";
+    }
+  }
+  ss << "}";
+
+  return ss.str();
+}
+
 std::string get_json_metadata(at::Tensor& decoder) {
   auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
 
@@ -222,21 +247,85 @@ std::string get_json_metadata(at::Tensor& decoder) {
         std::to_string(*videoMetadata.bestAudioStreamIndex);
   }
 
-  std::stringstream ss;
-  ss << "{\n";
-  auto it = metadataMap.begin();
-  while (it != metadataMap.end()) {
-    ss << "\"" << it->first << "\": " << it->second;
-    ++it;
-    if (it != metadataMap.end()) {
-      ss << ",\n";
-    } else {
-      ss << "\n";
-    }
-  }
-  ss << "}";
-
-  return ss.str();
+  return mapToJson(metadataMap);
+}
+
+std::string get_container_json_metadata(at::Tensor& decoder) {
+  auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
+
+  auto containerMetadata = videoDecoder->getContainerMetadata();
+
+  std::map<std::string, std::string> map;
+
+  if (containerMetadata.durationSeconds.has_value()) {
+    map["durationSeconds"] = std::to_string(*containerMetadata.durationSeconds);
+  }
+  if (containerMetadata.bitRate.has_value()) {
+    map["bitRate"] = std::to_string(*containerMetadata.bitRate);
+  }
+
+  if (containerMetadata.bestVideoStreamIndex.has_value()) {
+    map["bestVideoStreamIndex"] =
+        std::to_string(*containerMetadata.bestVideoStreamIndex);
+  }
+  if (containerMetadata.bestAudioStreamIndex.has_value()) {
+    map["bestAudioStreamIndex"] =
+        std::to_string(*containerMetadata.bestAudioStreamIndex);
+  }
+
+  map["numStreams"] = std::to_string(containerMetadata.streams.size());
+
+  return mapToJson(map);
+}
+
+std::string get_stream_json_metadata(
+    at::Tensor& decoder,
+    int64_t stream_index) {
+  auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());
+  auto streams = videoDecoder->getContainerMetadata().streams;
+  if (stream_index < 0 || stream_index >= streams.size()) {
+    throw std::out_of_range(
+        "stream_index out of bounds: " + std::to_string(stream_index));
+  }
+  auto streamMetadata = streams[stream_index];
+
+  std::map<std::string, std::string> map;
+
+  if (streamMetadata.durationSeconds.has_value()) {
+    map["durationSeconds"] = std::to_string(*streamMetadata.durationSeconds);
+  }
+  if (streamMetadata.bitRate.has_value()) {
+    map["bitRate"] = std::to_string(*streamMetadata.bitRate);
+  }
+  if (streamMetadata.numFramesFromScan.has_value()) {
+    map["numFramesFromScan"] =
+        std::to_string(*streamMetadata.numFramesFromScan);
+  }
+  if (streamMetadata.numFrames.has_value()) {
+    map["numFrames"] = std::to_string(*streamMetadata.numFrames);
+  }
+  if (streamMetadata.minPtsSecondsFromScan.has_value()) {
+    map["minPtsSecondsFromScan"] =
+        std::to_string(*streamMetadata.minPtsSecondsFromScan);
+  }
+  if (streamMetadata.maxPtsSecondsFromScan.has_value()) {
+    map["maxPtsSecondsFromScan"] =
+        std::to_string(*streamMetadata.maxPtsSecondsFromScan);
+  }
+  if (streamMetadata.codecName.has_value()) {
+    map["codec"] = quoteValue(streamMetadata.codecName.value());
+  }
+  if (streamMetadata.width.has_value()) {
+    map["width"] = std::to_string(*streamMetadata.width);
+  }
+  if (streamMetadata.height.has_value()) {
+    map["height"] = std::to_string(*streamMetadata.height);
+  }
+  if (streamMetadata.averageFps.has_value()) {
+    map["averageFps"] = std::to_string(*streamMetadata.averageFps);
+  }
+  return mapToJson(map);
 }
 
 std::string _get_json_ffmpeg_library_versions() {
@@ -277,6 +366,8 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
   m.impl("add_video_stream", &add_video_stream);
   m.impl("get_next_frame", &get_next_frame);
   m.impl("get_json_metadata", &get_json_metadata);
+  m.impl("get_container_json_metadata", &get_container_json_metadata);
+  m.impl("get_stream_json_metadata", &get_stream_json_metadata);
   m.impl("get_frame_at_pts", &get_frame_at_pts);
   m.impl("get_frame_at_index", &get_frame_at_index);
   m.impl("get_frames_at_indices", &get_frames_at_indices);
diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h
index 9c663919..9b87ff91 100644
--- a/src/torchcodec/decoders/_core/VideoDecoderOps.h
+++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h
@@ -69,6 +69,12 @@ at::Tensor get_next_frame(at::Tensor& decoder);
 // Get the metadata from the video as a string.
 std::string get_json_metadata(at::Tensor& decoder);
 
+// Get the container metadata as a string.
+std::string get_container_json_metadata(at::Tensor& decoder);
+
+// Get the stream metadata as a string.
+std::string get_stream_json_metadata(at::Tensor& decoder, int64_t stream_index);
+
 // Returns version information about the various FFMPEG libraries that are
 // loaded in the program's address space.
 std::string _get_json_ffmpeg_library_versions();
diff --git a/src/torchcodec/decoders/_core/__init__.py b/src/torchcodec/decoders/_core/__init__.py
index 84add756..56560208 100644
--- a/src/torchcodec/decoders/_core/__init__.py
+++ b/src/torchcodec/decoders/_core/__init__.py
@@ -3,3 +3,5 @@
 
 # TODO: Don't use import *
 from .video_decoder_ops import *  # noqa
+
+from ._metadata import get_video_metadata, StreamMetadata, VideoMetadata
diff --git a/src/torchcodec/decoders/_core/_metadata.py b/src/torchcodec/decoders/_core/_metadata.py
new file mode 100644
index 00000000..8f0442ab
--- /dev/null
+++ b/src/torchcodec/decoders/_core/_metadata.py
@@ -0,0 +1,92 @@
+import json
+
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+
+from torchcodec.decoders._core.video_decoder_ops import (
+    _get_container_json_metadata,
+    _get_stream_json_metadata,
+)
+
+
+@dataclass
+class StreamMetadata:
+    duration_seconds: Optional[float]
+    bit_rate: Optional[float]
+    # TODO: Before release, we should come up with names that better convey the
+    # " 'fast and potentially inaccurate' vs 'slower but accurate' " tradeoff.
+    num_frames_retrieved: Optional[int]
+    num_frames_computed: Optional[int]
+    min_pts_seconds: Optional[float]
+    max_pts_seconds: Optional[float]
+    codec: Optional[str]
+    width: Optional[int]
+    height: Optional[int]
+    average_fps: Optional[float]
+    stream_index: int
+
+    @property
+    def num_frames(self) -> Optional[int]:
+        if self.num_frames_computed is not None:
+            return self.num_frames_computed
+        else:
+            return self.num_frames_retrieved
+
+
+# This may be renamed into e.g. ContainerMetadata in the future to be more generic.
+@dataclass
+class VideoMetadata:
+    duration_seconds_container: Optional[float]
+    bit_rate_container: Optional[float]
+    best_video_stream_index: Optional[int]
+    best_audio_stream_index: Optional[int]
+
+    streams: List[StreamMetadata]
+
+    @property
+    def duration_seconds(self) -> Optional[float]:
+        raise NotImplementedError("TODO: decide on logic and implement this!")
+
+    @property
+    def bit_rate(self) -> Optional[float]:
+        raise NotImplementedError("TODO: decide on logic and implement this!")
+
+    @property
+    def best_video_stream(self) -> StreamMetadata:
+        if self.best_video_stream_index is None:
+            raise ValueError("The best video stream is unknown.")
+        return self.streams[self.best_video_stream_index]
+
+
+def get_video_metadata(decoder: torch.Tensor) -> VideoMetadata:
+
+    container_dict = json.loads(_get_container_json_metadata(decoder))
+    streams_metadata = []
+    for stream_index in range(container_dict["numStreams"]):
+        stream_dict = json.loads(_get_stream_json_metadata(decoder, stream_index))
+        streams_metadata.append(
+            StreamMetadata(
+                duration_seconds=stream_dict.get("durationSeconds"),
+                bit_rate=stream_dict.get("bitRate"),
+                # TODO: We should align the C++ names and the json keys with the Python names
+                num_frames_retrieved=stream_dict.get("numFrames"),
+                num_frames_computed=stream_dict.get("numFramesFromScan"),
+                min_pts_seconds=stream_dict.get("minPtsSecondsFromScan"),
+                max_pts_seconds=stream_dict.get("maxPtsSecondsFromScan"),
+                codec=stream_dict.get("codec"),
+                width=stream_dict.get("width"),
+                height=stream_dict.get("height"),
+                average_fps=stream_dict.get("averageFps"),
+                stream_index=stream_index,
+            )
+        )
+
+    return VideoMetadata(
+        duration_seconds_container=container_dict.get("durationSeconds"),
+        bit_rate_container=container_dict.get("bitRate"),
+        best_video_stream_index=container_dict.get("bestVideoStreamIndex"),
+        best_audio_stream_index=container_dict.get("bestAudioStreamIndex"),
+        streams=streams_metadata,
+    )
diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py
index 9823fc14..31377a54 100644
--- a/src/torchcodec/decoders/_core/video_decoder_ops.py
+++ b/src/torchcodec/decoders/_core/video_decoder_ops.py
@@ -60,6 +60,10 @@ def load_torchcodec_extension():
 get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default
 get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default
 get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default
+_get_container_json_metadata = (
+    torch.ops.torchcodec_ns.get_container_json_metadata.default
+)
+_get_stream_json_metadata = torch.ops.torchcodec_ns.get_stream_json_metadata.default
 _get_json_ffmpeg_library_versions = (
     torch.ops.torchcodec_ns._get_json_ffmpeg_library_versions.default
 )
@@ -154,6 +158,16 @@ def get_json_metadata_abstract(decoder: torch.Tensor) -> str:
     return torch.empty_like("")
 
 
+@register_fake("torchcodec_ns::get_container_json_metadata")
+def get_container_json_metadata_abstract(decoder: torch.Tensor) -> str:
+    return torch.empty_like("")
+
+
+@register_fake("torchcodec_ns::get_stream_json_metadata")
+def get_stream_json_metadata_abstract(decoder: torch.Tensor, stream_idx: int) -> str:
+    return torch.empty_like("")
+
+
 @register_fake("torchcodec_ns::_get_json_ffmpeg_library_versions")
 def _get_json_ffmpeg_library_versions_abstract() -> str:
     return torch.empty_like("")
diff --git a/src/torchcodec/decoders/_simple_video_decoder.py b/src/torchcodec/decoders/_simple_video_decoder.py
index 4e12c82e..992ab623 100644
--- a/src/torchcodec/decoders/_simple_video_decoder.py
+++ b/src/torchcodec/decoders/_simple_video_decoder.py
@@ -1,7 +1,7 @@
-import json
 from typing import Union
 
 import torch
+
 from torchcodec.decoders import _core as core
 
 
@@ -24,12 +24,9 @@ def __init__(self, source: Union[str, bytes, torch.Tensor]):
 
         core.add_video_stream(self._decoder)
 
-        # TODO: We should either implement specific core library function to
-        # retrieve these values, or replace this with a non-JSON metadata
-        # retrieval.
-        metadata_json = json.loads(core.get_json_metadata(self._decoder))
-        self._num_frames = metadata_json["numFrames"]
-        self._stream_index = metadata_json["bestVideoStreamIndex"]
+        self.stream_metadata = _get_and_validate_stream_metadata(self._decoder)
+        self._num_frames = self.stream_metadata.num_frames_computed
+        self._stream_index = self.stream_metadata.stream_index
 
     def __len__(self) -> int:
         return self._num_frames
@@ -80,3 +77,24 @@ def __next__(self) -> torch.Tensor:
             return core.get_next_frame(self._decoder)
         except RuntimeError:
             raise StopIteration()
+
+
+def _get_and_validate_stream_metadata(decoder: torch.Tensor) -> core.StreamMetadata:
+    video_metadata = core.get_video_metadata(decoder)
+
+    if video_metadata.best_video_stream_index is None:
+        raise ValueError(
+            "The best video stream is unknown. This should never happen. "
+            "Please report an issue following the steps in "
+        )
+
+    best_stream_metadata = video_metadata.streams[
+        video_metadata.best_video_stream_index
+    ]
+    if best_stream_metadata.num_frames_computed is None:
+        raise ValueError(
+            "The number of frames is unknown. This should never happen. "
+            "Please report an issue following the steps in "
+        )
+
+    return best_stream_metadata
diff --git a/test/decoders/simple_video_decoder_test.py b/test/decoders/simple_video_decoder_test.py
index 034eeaec..9175c83b 100644
--- a/test/decoders/simple_video_decoder_test.py
+++ b/test/decoders/simple_video_decoder_test.py
@@ -1,7 +1,7 @@
 import pytest
 import torch
 
-from torchcodec.decoders import SimpleVideoDecoder
+from torchcodec.decoders import _core, SimpleVideoDecoder
 
 from ..test_utils import (
     assert_equal,
@@ -14,25 +14,28 @@ class TestSimpleDecoder:
-
-    def test_create_from_file(self):
-        decoder = SimpleVideoDecoder(str(get_reference_video_path()))
-        assert len(decoder) == 390
-        assert decoder._stream_index == 3
-
-    def test_create_from_tensor(self):
-        decoder = SimpleVideoDecoder(get_reference_video_tensor())
-        assert len(decoder) == 390
-        assert decoder._stream_index == 3
-
-    def test_create_from_bytes(self):
-        path = str(get_reference_video_path())
-        with open(path, "rb") as f:
-            video_bytes = f.read()
-
-        decoder = SimpleVideoDecoder(video_bytes)
-        assert len(decoder) == 390
-        assert decoder._stream_index == 3
+    @pytest.mark.parametrize("source_kind", ("path", "tensor", "bytes"))
+    def test_create(self, source_kind):
+        if source_kind == "path":
+            source = str(get_reference_video_path())
+        elif source_kind == "tensor":
+            source = get_reference_video_tensor()
+        elif source_kind == "bytes":
+            path = str(get_reference_video_path())
+            with open(path, "rb") as f:
+                source = f.read()
+        else:
+            raise ValueError("Oops, double check the parametrization of this test!")
+
+        decoder = SimpleVideoDecoder(source)
+        assert isinstance(decoder.stream_metadata, _core.StreamMetadata)
+        assert (
+            len(decoder)
+            == decoder._num_frames
+            == decoder.stream_metadata.num_frames_computed
+            == 390
+        )
+        assert decoder._stream_index == decoder.stream_metadata.stream_index == 3
 
     def test_create_fails(self):
         with pytest.raises(TypeError, match="Unknown source type"):
diff --git a/test/decoders/test_metadata.py b/test/decoders/test_metadata.py
new file mode 100644
index 00000000..321dd640
--- /dev/null
+++ b/test/decoders/test_metadata.py
@@ -0,0 +1,61 @@
+import pytest
+
+from torchcodec.decoders._core import (
+    create_from_file,
+    get_video_metadata,
+    StreamMetadata,
+)
+
+from ..test_utils import get_reference_video_path
+
+
+def test_get_video_metadata():
+    decoder = create_from_file(str(get_reference_video_path()))
+    metadata = get_video_metadata(decoder)
+    assert len(metadata.streams) == 6
+    assert metadata.best_video_stream_index == 3
+    assert metadata.best_audio_stream_index == 3
+
+    with pytest.raises(NotImplementedError, match="TODO: decide on logic"):
+        metadata.duration_seconds
+    with pytest.raises(NotImplementedError, match="TODO: decide on logic"):
+        metadata.bit_rate
+
+    # TODO: put these checks back once D58974580 is landed. The expected values
+    # are different depending on the FFmpeg version.
+    # assert metadata.duration_seconds_container == pytest.approx(16.57, abs=0.001)
+    # assert metadata.bit_rate_container == 324915
+
+    best_stream_metadata = metadata.streams[metadata.best_video_stream_index]
+    assert best_stream_metadata is metadata.best_video_stream
+    assert best_stream_metadata.duration_seconds == pytest.approx(13.013, abs=0.001)
+    assert best_stream_metadata.bit_rate == 128783
+    assert best_stream_metadata.average_fps == pytest.approx(29.97, abs=0.001)
+    assert best_stream_metadata.codec == "h264"
+    assert best_stream_metadata.num_frames_computed == 390
+    assert best_stream_metadata.num_frames_retrieved == 390
+
+
+@pytest.mark.parametrize(
+    "num_frames_retrieved, num_frames_computed, expected_num_frames",
+    [(None, 10, 10), (10, None, 10), (None, None, None)],
+)
+def test_num_frames_fallback(
+    num_frames_retrieved, num_frames_computed, expected_num_frames
+):
+    """Check that num_frames_computed always has priority when accessing `.num_frames`"""
+    metadata = StreamMetadata(
+        duration_seconds=4,
+        bit_rate=123,
+        num_frames_retrieved=num_frames_retrieved,
+        num_frames_computed=num_frames_computed,
+        min_pts_seconds=0,
+        max_pts_seconds=4,
+        codec="whatever",
+        width=123,
+        height=321,
+        average_fps=30,
+        stream_index=0,
+    )
+
+    assert metadata.num_frames == expected_num_frames
diff --git a/test/decoders/video_decoder_ops_test.py b/test/decoders/video_decoder_ops_test.py
index 8fe7c1dd..0a2c402e 100644
--- a/test/decoders/video_decoder_ops_test.py
+++ b/test/decoders/video_decoder_ops_test.py
@@ -247,6 +247,13 @@ def test_create_decoder(self, create_from):
         reference_frame_time6 = load_tensor_from_file("nasa_13013.mp4.time6.000000.pt")
         assert_equal(frame_time6, reference_frame_time6)
 
+    # TODO: Keeping the metadata tests below for now, but we should remove them
+    # once we remove get_json_metadata().
+    # Note that the distinction made between test_video_get_json_metadata and
+    # test_video_get_json_metadata_with_stream is misleading: all of the stream
+    # metadata are available even without adding a video stream, because we
+    # always call scanFileAndUpdateMetadataAndIndex() when creating a decoder
+    # from the core API.
    def test_video_get_json_metadata(self):
        decoder = create_from_file(str(get_reference_video_path()))
        metadata = get_json_metadata(decoder)
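
Reviewer note: the snippet below is a minimal usage sketch of the metadata API introduced by this diff (create_from_file, get_video_metadata, VideoMetadata, StreamMetadata). It is illustrative only and not part of the change; "video.mp4" is a placeholder path, not an asset shipped with the repo.

# Illustrative sketch only -- exercises the metadata API added in this diff.
# "video.mp4" is a placeholder path.
from torchcodec.decoders._core import create_from_file, get_video_metadata

decoder = create_from_file("video.mp4")
metadata = get_video_metadata(decoder)  # returns a VideoMetadata dataclass

print(f"number of streams: {len(metadata.streams)}")
if metadata.best_video_stream_index is not None:
    best = metadata.best_video_stream  # StreamMetadata of the best video stream
    # .num_frames prefers the scan-based count (num_frames_computed) and falls
    # back to the header-reported count (num_frames_retrieved).
    print(best.codec, best.width, best.height, best.average_fps, best.num_frames)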