From fd9bc4026dc3701553f2f174a36babb613226061 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 24 Jun 2024 16:06:39 +0100 Subject: [PATCH] WIP --- .pre-commit-config.yaml | 14 +- src/torchcodec/decoders/_core/VideoDecoder.h | 5 + .../decoders/_core/VideoDecoderOps.cpp | 122 ++++++++++++++---- .../decoders/_core/VideoDecoderOps.h | 6 + src/torchcodec/decoders/_core/__init__.py | 7 + src/torchcodec/decoders/_core/_metadata.py | 108 ++++++++++++++++ .../decoders/_core/video_decoder_ops.py | 14 ++ .../decoders/_simple_video_decoder.py | 54 ++++++-- 8 files changed, 293 insertions(+), 37 deletions(-) create mode 100644 src/torchcodec/decoders/_core/_metadata.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 100eea43..d381a5a3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,13 +14,13 @@ repos: - id: check-added-large-files args: ['--maxkb=1000'] - # - repo: https://github.com/omnilib/ufmt - # rev: v2.6.0 - # hooks: - # - id: ufmt - # additional_dependencies: - # - black == 24.4.2 - # - usort == 1.0.5 + - repo: https://github.com/omnilib/ufmt + rev: v2.6.0 + hooks: + - id: ufmt + additional_dependencies: + - black == 24.4.2 + - usort == 1.0.5 - repo: https://github.com/PyCQA/flake8 rev: 7.1.0 diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 39cd93b7..8c0033af 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -101,6 +101,11 @@ class VideoDecoder { std::optional height; }; struct ContainerMetadata { + // TODO: in C++ the StreamMetadata vec is part of the ContainerMetadata. In + // Python, the equivalent list isn't part of the containers' metadata: it is + // a separate attribute of the VideoMetaData dataclass, next to the + // container metadata. We can probably align the C++ structure to reflect + // the Python one? 
std::vector<StreamMetadata> streams;
decoder, int stream_index) -> str"); } // ============================== @@ -152,6 +145,25 @@ std::string quoteValue(const std::string& value) { return "\"" + value + "\""; } +std::string mapToJson(const std::map& metadataMap) { + + std::stringstream ss; + ss << "{\n"; + auto it = metadataMap.begin(); + while (it != metadataMap.end()) { + ss << "\"" << it->first << "\": " << it->second; + ++it; + if (it != metadataMap.end()) { + ss << ",\n"; + } else { + ss << "\n"; + } + } + ss << "}"; + + return ss.str(); +} + std::string get_json_metadata(at::Tensor& decoder) { auto videoDecoder = static_cast(decoder.mutable_data_ptr()); @@ -219,23 +231,87 @@ std::string get_json_metadata(at::Tensor& decoder) { std::to_string(*videoMetadata.bestAudioStreamIndex); } - std::stringstream ss; - ss << "{\n"; - auto it = metadataMap.begin(); - while (it != metadataMap.end()) { - ss << "\"" << it->first << "\": " << it->second; - ++it; - if (it != metadataMap.end()) { - ss << ",\n"; - } else { - ss << "\n"; - } + return mapToJson(metadataMap); +} + +std::string get_container_json_metadata(at::Tensor &decoder) { + auto videoDecoder = static_cast(decoder.mutable_data_ptr()); + + auto containerMetadata = videoDecoder->getContainerMetadata(); + + std::map map; + + if (containerMetadata.durationSeconds.has_value()) { + map["durationSeconds"] = std::to_string(*containerMetadata.durationSeconds); } - ss << "}"; - return ss.str(); + if (containerMetadata.bitRate.has_value()) { + map["bitRate"] = std::to_string(*containerMetadata.bitRate); + } + + if (containerMetadata.bestVideoStreamIndex.has_value()) { + map["bestVideoStreamIndex"] = + std::to_string(*containerMetadata.bestVideoStreamIndex); + } + if (containerMetadata.bestAudioStreamIndex.has_value()) { + map["bestAudioStreamIndex"] = + std::to_string(*containerMetadata.bestAudioStreamIndex); + } + + // TODO: Q from Nicolas - is there a better way to retrieve and propagate the + // number of streams? 
+ map["numStreams"] = std::to_string(containerMetadata.streams.size()); + + return mapToJson(map); +} + + +std::string get_stream_json_metadata(at::Tensor &decoder, + int64_t stream_index) { + auto videoDecoder = static_cast(decoder.mutable_data_ptr()); + auto streamMetadata = + videoDecoder->getContainerMetadata().streams[stream_index]; + + std::map map; + + if (streamMetadata.durationSeconds.has_value()) { + map["durationSeconds"] = std::to_string(*streamMetadata.durationSeconds); + } + if (streamMetadata.bitRate.has_value()) { + map["bitRate"] = std::to_string(*streamMetadata.bitRate); + } + if (streamMetadata.numFramesFromScan.has_value()) { + map["numFramesFromScan"] = + std::to_string(*streamMetadata.numFramesFromScan); + } + if (streamMetadata.numFrames.has_value()) { + map["numFrames"] = std::to_string(*streamMetadata.numFrames); + } + if (streamMetadata.minPtsSecondsFromScan.has_value()) { + map["minPtsSecondsFromScan"] = + std::to_string(*streamMetadata.minPtsSecondsFromScan); + } + if (streamMetadata.maxPtsSecondsFromScan.has_value()) { + map["maxPtsSecondsFromScan"] = + std::to_string(*streamMetadata.maxPtsSecondsFromScan); + } + if (streamMetadata.codecName.has_value()) { + map["codec"] = quoteValue(streamMetadata.codecName.value()); + } + if (streamMetadata.width.has_value()) { + map["width"] = std::to_string(*streamMetadata.width); + } + if (streamMetadata.height.has_value()) { + map["height"] = std::to_string(*streamMetadata.height); + } + if (streamMetadata.averageFps.has_value()) { + map["averageFps"] = std::to_string(*streamMetadata.averageFps); + } + return mapToJson(map); } + + TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { m.impl("create_from_file", &create_from_file); m.impl("create_from_tensor", &create_from_tensor); @@ -246,6 +322,8 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { m.impl("add_video_stream", &add_video_stream); m.impl("get_next_frame", &get_next_frame); m.impl("get_json_metadata", &get_json_metadata); + 
m.impl("get_container_json_metadata", &get_container_json_metadata); + m.impl("get_stream_json_metadata", &get_stream_json_metadata); m.impl("get_frame_at_pts", &get_frame_at_pts); m.impl("get_frame_at_index", &get_frame_at_index); m.impl("get_frames_at_indices", &get_frames_at_indices); diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index c029473b..ace029a6 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -60,4 +60,10 @@ at::Tensor get_next_frame(at::Tensor& decoder); // Get the metadata from the video as a string. std::string get_json_metadata(at::Tensor& decoder); +// Get the container metadata as a string. +std::string get_container_json_metadata(at::Tensor& decoder); + +// Get the stream metadata as a string. +std::string get_stream_json_metadata(at::Tensor& decoder); + } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/__init__.py b/src/torchcodec/decoders/_core/__init__.py index 84add756..eb770d7c 100644 --- a/src/torchcodec/decoders/_core/__init__.py +++ b/src/torchcodec/decoders/_core/__init__.py @@ -3,3 +3,10 @@ # TODO: Don't use import * from .video_decoder_ops import * # noqa + +from ._metadata import ( + ContainerMetadata, + get_video_metadata, + StreamMetadata, + VideoMetadata, +) diff --git a/src/torchcodec/decoders/_core/_metadata.py b/src/torchcodec/decoders/_core/_metadata.py new file mode 100644 index 00000000..d7e90c42 --- /dev/null +++ b/src/torchcodec/decoders/_core/_metadata.py @@ -0,0 +1,108 @@ +import json + +from dataclasses import dataclass +from typing import List, Optional + +import torch + +from torchcodec.decoders._core.video_decoder_ops import ( + _get_container_json_metadata, + _get_stream_json_metadata, +) + + +@dataclass +class ContainerMetadata: + duration_seconds: Optional[float] + bit_rate: Optional[float] + best_video_stream_index: Optional[int] + 
best_audio_stream_index: Optional[int] + + +@dataclass +class StreamMetadata: + duration_seconds: Optional[float] + bit_rate: Optional[float] + # TODO Comment from Nicolas: + # Looking at this, it's not immediately obvious to me that "retrieved" means + # "less accurate than 'computed'". + # Are we open to different names? E.g. "num_frames_from_header" and "num_frames_accurate"? + num_frames_retrieved: Optional[int] + num_frames_computed: Optional[int] + min_pts_seconds: Optional[float] + max_pts_seconds: Optional[float] + codec: Optional[str] + width: Optional[int] + height: Optional[int] + average_fps: Optional[float] + + @property + def num_frames(self) -> Optional[int]: + if self.num_frames_computed is not None: + return self.num_frames_computed + else: + return self.num_frames_retrieved + + +@dataclass +class VideoMetadata: + container: ContainerMetadata + streams: List[StreamMetadata] + + @property + def duration_seconds(self) -> Optional[float]: + if ( + self.container.best_video_stream_index is not None + and self.streams[self.container.best_video_stream_index].duration_seconds + is not None + ): + return self.streams[self.container.best_video_stream_index].duration_seconds + else: + return self.container.duration_seconds + + @property + def bit_rate(self) -> Optional[float]: + if ( + self.container.best_video_stream_index is not None + and self.streams[self.container.best_video_stream_index].bit_rate + is not None + ): + return self.streams[self.container.best_video_stream_index].bit_rate + else: + return self.contain.bit_rate + + @property + def best_video_stream(self) -> StreamMetadata: + assert self.container.best_video_stream_index is not None + return self.container.streams[self.container.best_video_stream_index] + + +def get_video_metadata(decoder: torch.tensor) -> VideoMetadata: + + container_dict = json.loads(_get_container_json_metadata(decoder)) + container_metadata = ContainerMetadata( + duration_seconds=container_dict.get("durationSeconds"), + 
bit_rate=container_dict.get("bitRate"), + best_video_stream_index=container_dict.get("bestVideoStreamIndex"), + best_audio_stream_index=container_dict.get("bestAudioStreamIndex"), + ) + + streams_metadata = [] + for stream_index in range(container_dict["numStreams"]): + stream_dict = json.loads(_get_stream_json_metadata(decoder, stream_index)) + streams_metadata.append( + StreamMetadata( + duration_seconds=stream_dict.get("durationSeconds"), + bit_rate=stream_dict.get("bitRate"), + num_frames_retrieved=stream_dict.get("numFrames"), + num_frames_computed=stream_dict.get("numFramesFromScan"), + min_pts_seconds=stream_dict.get("minPtsSecondsFromScan"), + max_pts_seconds=stream_dict.get("maxPtsSecondsFromScan"), + codec=stream_dict.get("codec"), + width=stream_dict.get("width"), + height=stream_dict.get("height"), + average_fps=stream_dict.get("averageFps"), + ) + ) + + return VideoMetadata(container=container_metadata, streams=streams_metadata) diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 3aba20e5..f0556fbe 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -58,6 +58,10 @@ def load_torchcodec_extension(): get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default +_get_container_json_metadata = ( + torch.ops.torchcodec_ns.get_container_json_metadata.default +) +_get_stream_json_metadata = torch.ops.torchcodec_ns.get_stream_json_metadata.default # ============================= @@ -134,3 +138,13 @@ def get_frames_at_indices_abstract( @register_fake("torchcodec_ns::get_json_metadata") def get_json_metadata_abstract(decoder: torch.Tensor) -> str: return torch.empty_like("") + + +@register_fake("torchcodec_ns::get_container_json_metadata") +def 
get_container_json_metadata_abstract(decoder: torch.Tensor) -> str: + return torch.empty_like("") + + +@register_fake("torchcodec_ns::get_stream_json_metadata") +def get_stream_json_metadata_abstract(decoder: torch.Tensor, stream_idx: int) -> str: + return torch.empty_like("") diff --git a/src/torchcodec/decoders/_simple_video_decoder.py b/src/torchcodec/decoders/_simple_video_decoder.py index 9dac8612..895816ea 100644 --- a/src/torchcodec/decoders/_simple_video_decoder.py +++ b/src/torchcodec/decoders/_simple_video_decoder.py @@ -1,7 +1,8 @@ -import json -from typing import Union +from dataclasses import dataclass +from typing import Optional, Union import torch + from torchcodec.decoders import _core as core @@ -24,12 +25,10 @@ def __init__(self, source: Union[str, bytes, torch.Tensor]): core.add_video_stream(self._decoder) - # TODO: We should either implement specific core library function to - # retrieve these values, or replace this with a non-JSON metadata - # retrieval. - metadata_json = json.loads(core.get_json_metadata(self._decoder)) - self._num_frames = metadata_json["numFrames"] - self._stream_index = metadata_json["bestVideoStreamIndex"] + self.metadata = _get_and_validate_simple_video_metadata(self._decoder) + # Note: these fields exist and are not None, as validated in _get_and_validate_simple_video_metadata(). + self._num_frames = self.metadata.stream.num_frames_computed + self._stream_index = self.metadata.container.best_video_stream_index def __len__(self) -> int: return self._num_frames @@ -61,3 +60,42 @@ def __next__(self) -> torch.Tensor: return core.get_next_frame(self._decoder) except RuntimeError: raise StopIteration() + + +@dataclass +class SimpleVideoMetadata: + # TODO: ContainerMetadata and StreamMetadata should be publicly available. + # Right now they're only exposed in _core. 
+ container: core.ContainerMetadata + stream: core.StreamMetadata + + # TODO: is the return really supposed to be Optional + @property + def duration_seconds(self) -> Optional[float]: + return self.stream.duration_seconds + + @property + def bit_rate(self) -> Optional[float]: + return self.stream.bit_rate + + +def _get_and_validate_simple_video_metadata( + decoder: torch.Tensor, +) -> SimpleVideoMetadata: + video_metadata = core.get_video_metadata(decoder) + container_metadata = video_metadata.container + + if container_metadata.best_video_stream_index is None: + raise ValueError( + "The best video stream is unknown. This should never happen. " + "Please report an issue following the steps on " + ) + + stream_metadata = video_metadata.streams[container_metadata.best_video_stream_index] + if stream_metadata.num_frames_computed is None: + raise ValueError( + "The number of frames is unknown. This should never happen. " + "Please report an issue following the steps on " + ) + + return SimpleVideoMetadata(container=container_metadata, stream=stream_metadata)