Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed Jun 24, 2024
1 parent a2ee580 commit fd9bc40
Show file tree
Hide file tree
Showing 8 changed files with 293 additions and 37 deletions.
14 changes: 7 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ repos:
- id: check-added-large-files
args: ['--maxkb=1000']

# - repo: https://github.com/omnilib/ufmt
# rev: v2.6.0
# hooks:
# - id: ufmt
# additional_dependencies:
# - black == 24.4.2
# - usort == 1.0.5
- repo: https://github.com/omnilib/ufmt
rev: v2.6.0
hooks:
- id: ufmt
additional_dependencies:
- black == 24.4.2
- usort == 1.0.5

- repo: https://github.com/PyCQA/flake8
rev: 7.1.0
Expand Down
5 changes: 5 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ class VideoDecoder {
std::optional<int64_t> height;
};
struct ContainerMetadata {
// TODO: in C++ the StreamMetadata vector is part of the ContainerMetadata. In
// Python, the equivalent list isn't part of the container's metadata: it is
// a separate attribute of the VideoMetadata dataclass, next to the
// container metadata. We can probably align the C++ structure to reflect
// the Python one?
std::vector<StreamMetadata> streams;
int numAudioStreams = 0;
int numVideoStreams = 0;
Expand Down
122 changes: 100 additions & 22 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,6 @@ namespace facebook::torchcodec {
// ==============================
// Define the operators
// ==============================

// Returns the result of adding 1 to `t`. The input tensor is taken by value
// and the sum is returned as the function result.
torch::Tensor plus_one(torch::Tensor t) {
return t + 1;
}

// Registers `plus_one` under the operator name "plus_one" in the `plusoneops`
// library.
TORCH_LIBRARY(plusoneops, m) {
m.def("plus_one", plus_one);
}

// All instances of accepting the decoder as a tensor must be annotated with
// `Tensor(a!)`. The `(a!)` part normally indicates that the tensor is being
// mutated in place. We need it to make sure that torch.compile does not reorder
Expand All @@ -41,6 +32,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def(
"get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices, int stream_index) -> Tensor");
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
m.def("get_stream_json_metadata(Tensor(a!) decoder, int stream_index) -> str");
}

// ==============================
Expand Down Expand Up @@ -152,6 +145,25 @@ std::string quoteValue(const std::string& value) {
return "\"" + value + "\"";
}

// Renders `metadataMap` as a JSON object string, one `"key": value` entry per
// line, in the map's iteration order. Keys are wrapped in double quotes; values
// are written verbatim, so callers must pass values that are already valid JSON
// (e.g. quoted strings or numbers). An empty map yields "{\n}".
std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {
  std::stringstream json;
  json << "{\n";

  // Emit the separator before every entry except the first, so that no
  // trailing comma is produced after the last one.
  const char* separator = "";
  for (const auto& [key, value] : metadataMap) {
    json << separator << "\"" << key << "\": " << value;
    separator = ",\n";
  }

  // The last entry (if any) is terminated by a newline before the closing
  // brace.
  if (!metadataMap.empty()) {
    json << "\n";
  }
  json << "}";

  return json.str();
}

std::string get_json_metadata(at::Tensor& decoder) {
auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());

Expand Down Expand Up @@ -219,23 +231,87 @@ std::string get_json_metadata(at::Tensor& decoder) {
std::to_string(*videoMetadata.bestAudioStreamIndex);
}

std::stringstream ss;
ss << "{\n";
auto it = metadataMap.begin();
while (it != metadataMap.end()) {
ss << "\"" << it->first << "\": " << it->second;
++it;
if (it != metadataMap.end()) {
ss << ",\n";
} else {
ss << "\n";
}
return mapToJson(metadataMap);
}

// Serializes the container-level metadata of the decoder into a JSON object
// string (see mapToJson for the exact layout).
//
// `decoder` is the tensor whose data pointer is interpreted as a
// VideoDecoder*. Optional container fields that hold no value are omitted from
// the output; "numStreams" is always present.
std::string get_container_json_metadata(at::Tensor& decoder) {
  auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());

  auto containerMetadata = videoDecoder->getContainerMetadata();

  std::map<std::string, std::string> map;

  if (containerMetadata.durationSeconds.has_value()) {
    map["durationSeconds"] = std::to_string(*containerMetadata.durationSeconds);
  }
  if (containerMetadata.bitRate.has_value()) {
    map["bitRate"] = std::to_string(*containerMetadata.bitRate);
  }
  if (containerMetadata.bestVideoStreamIndex.has_value()) {
    map["bestVideoStreamIndex"] =
        std::to_string(*containerMetadata.bestVideoStreamIndex);
  }
  if (containerMetadata.bestAudioStreamIndex.has_value()) {
    map["bestAudioStreamIndex"] =
        std::to_string(*containerMetadata.bestAudioStreamIndex);
  }

  // TODO: Q from Nicolas - is there a better way to retrieve and propagate the
  // number of streams?
  map["numStreams"] = std::to_string(containerMetadata.streams.size());

  return mapToJson(map);
}


// Serializes the metadata of the stream at `stream_index` into a JSON object
// string (see mapToJson for the exact layout).
//
// `decoder` is the tensor whose data pointer is interpreted as a
// VideoDecoder*. Optional stream fields that hold no value are omitted from
// the output.
//
// Throws std::out_of_range if `stream_index` is negative or not smaller than
// the number of streams in the container.
std::string get_stream_json_metadata(
    at::Tensor& decoder,
    int64_t stream_index) {
  auto videoDecoder = static_cast<VideoDecoder*>(decoder.mutable_data_ptr());

  // Use the bounds-checked at() instead of operator[]: the index comes from
  // the caller and an invalid one must not read past the vector. A negative
  // index wraps to a huge unsigned value and is rejected by at() as well.
  auto streamMetadata = videoDecoder->getContainerMetadata().streams.at(
      static_cast<std::size_t>(stream_index));

  std::map<std::string, std::string> map;

  if (streamMetadata.durationSeconds.has_value()) {
    map["durationSeconds"] = std::to_string(*streamMetadata.durationSeconds);
  }
  if (streamMetadata.bitRate.has_value()) {
    map["bitRate"] = std::to_string(*streamMetadata.bitRate);
  }
  if (streamMetadata.numFramesFromScan.has_value()) {
    map["numFramesFromScan"] =
        std::to_string(*streamMetadata.numFramesFromScan);
  }
  if (streamMetadata.numFrames.has_value()) {
    map["numFrames"] = std::to_string(*streamMetadata.numFrames);
  }
  if (streamMetadata.minPtsSecondsFromScan.has_value()) {
    map["minPtsSecondsFromScan"] =
        std::to_string(*streamMetadata.minPtsSecondsFromScan);
  }
  if (streamMetadata.maxPtsSecondsFromScan.has_value()) {
    map["maxPtsSecondsFromScan"] =
        std::to_string(*streamMetadata.maxPtsSecondsFromScan);
  }
  if (streamMetadata.codecName.has_value()) {
    // The codec name is a string, so it must be quoted to be valid JSON.
    map["codec"] = quoteValue(streamMetadata.codecName.value());
  }
  if (streamMetadata.width.has_value()) {
    map["width"] = std::to_string(*streamMetadata.width);
  }
  if (streamMetadata.height.has_value()) {
    map["height"] = std::to_string(*streamMetadata.height);
  }
  if (streamMetadata.averageFps.has_value()) {
    map["averageFps"] = std::to_string(*streamMetadata.averageFps);
  }

  return mapToJson(map);
}



TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
m.impl("create_from_file", &create_from_file);
m.impl("create_from_tensor", &create_from_tensor);
Expand All @@ -246,6 +322,8 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
m.impl("add_video_stream", &add_video_stream);
m.impl("get_next_frame", &get_next_frame);
m.impl("get_json_metadata", &get_json_metadata);
m.impl("get_container_json_metadata", &get_container_json_metadata);
m.impl("get_stream_json_metadata", &get_stream_json_metadata);
m.impl("get_frame_at_pts", &get_frame_at_pts);
m.impl("get_frame_at_index", &get_frame_at_index);
m.impl("get_frames_at_indices", &get_frames_at_indices);
Expand Down
6 changes: 6 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoderOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,10 @@ at::Tensor get_next_frame(at::Tensor& decoder);
// Get the metadata from the video as a string.
std::string get_json_metadata(at::Tensor& decoder);

// Get the container metadata as a string.
std::string get_container_json_metadata(at::Tensor& decoder);

// Get the stream metadata as a string.
std::string get_stream_json_metadata(at::Tensor& decoder);

} // namespace facebook::torchcodec
7 changes: 7 additions & 0 deletions src/torchcodec/decoders/_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,10 @@
# TODO: Don't use import *

from .video_decoder_ops import * # noqa

from ._metadata import (
ContainerMetadata,
get_video_metadata,
StreamMetadata,
VideoMetadata,
)
108 changes: 108 additions & 0 deletions src/torchcodec/decoders/_core/_metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import json

from dataclasses import dataclass
from typing import List, Optional

import torch

from torchcodec.decoders._core.video_decoder_ops import (
_get_container_json_metadata,
_get_stream_json_metadata,
)


@dataclass
class ContainerMetadata:
    """Container-level metadata of a video.

    Instances are built by ``get_video_metadata`` from the JSON emitted by the
    ``get_container_json_metadata`` op. A field is ``None`` when the
    corresponding key is absent from that JSON.
    """

    # Value of the "durationSeconds" JSON key.
    duration_seconds: Optional[float]
    # Value of the "bitRate" JSON key.
    bit_rate: Optional[float]
    # Index, within the list of streams, of the best video stream.
    best_video_stream_index: Optional[int]
    # Index, within the list of streams, of the best audio stream.
    best_audio_stream_index: Optional[int]


@dataclass
class StreamMetadata:
    """Metadata of a single stream of a video.

    Instances are built by ``get_video_metadata`` from the JSON emitted by the
    ``get_stream_json_metadata`` op. A field is ``None`` when the
    corresponding key is absent from that JSON.
    """

    duration_seconds: Optional[float]
    bit_rate: Optional[float]
    # TODO Comment from Nicolas:
    # Looking at this, it's not immediately obvious to me that "retrieved" means
    # "less accurate than 'computed'".
    # Are we open to different names? E.g. "num_frames_from_header" and "num_frames_accurate"?
    # Populated from the "numFrames" JSON key.
    num_frames_retrieved: Optional[int]
    # Populated from the "numFramesFromScan" JSON key.
    num_frames_computed: Optional[int]
    # Populated from the "minPtsSecondsFromScan" JSON key.
    min_pts_seconds: Optional[float]
    # Populated from the "maxPtsSecondsFromScan" JSON key.
    max_pts_seconds: Optional[float]
    codec: Optional[str]
    width: Optional[int]
    height: Optional[int]
    average_fps: Optional[float]

    @property
    def num_frames(self) -> Optional[int]:
        """Number of frames, preferring the computed value when available.

        Returns ``num_frames_computed`` if it is not ``None`` (including when
        it is 0), otherwise ``num_frames_retrieved``, which may itself be
        ``None``.
        """
        if self.num_frames_computed is not None:
            return self.num_frames_computed
        return self.num_frames_retrieved


@dataclass
class VideoMetadata:
    """Metadata of a video: its container metadata plus one entry per stream.

    ``streams`` is indexed by stream index, so
    ``container.best_video_stream_index`` (when not ``None``) can be used
    directly as an index into it.
    """

    container: ContainerMetadata
    streams: List[StreamMetadata]

    @property
    def duration_seconds(self) -> Optional[float]:
        """Duration of the best video stream, else of the container.

        Falls back to the container's duration when there is no best video
        stream or when that stream has no duration.
        """
        best_index = self.container.best_video_stream_index
        if (
            best_index is not None
            and self.streams[best_index].duration_seconds is not None
        ):
            return self.streams[best_index].duration_seconds
        return self.container.duration_seconds

    @property
    def bit_rate(self) -> Optional[float]:
        """Bit rate of the best video stream, else of the container.

        Falls back to the container's bit rate when there is no best video
        stream or when that stream has no bit rate.
        """
        best_index = self.container.best_video_stream_index
        if best_index is not None and self.streams[best_index].bit_rate is not None:
            return self.streams[best_index].bit_rate
        # Fixed: this used to read `self.contain.bit_rate`, which raised
        # AttributeError whenever the fallback path was taken.
        return self.container.bit_rate

    @property
    def best_video_stream(self) -> StreamMetadata:
        """Metadata of the best video stream.

        Asserts that the container reports a best video stream index.
        """
        assert self.container.best_video_stream_index is not None
        # Fixed: the streams live on this object, not on the container
        # metadata (ContainerMetadata has no `streams` attribute).
        return self.streams[self.container.best_video_stream_index]


def get_video_metadata(decoder: torch.Tensor) -> VideoMetadata:
    """Build a ``VideoMetadata`` for ``decoder``.

    The container metadata is fetched once as JSON, then the metadata of each
    stream is fetched individually, for stream indices ``0`` to
    ``numStreams - 1`` as reported by the container JSON.

    Args:
        decoder: The decoder tensor passed through to the
            ``_get_container_json_metadata`` and ``_get_stream_json_metadata``
            ops.

    Returns:
        A ``VideoMetadata`` holding the container metadata and the list of
        per-stream metadata, ordered by stream index. Fields whose key is
        missing from the JSON are set to ``None``.

    Raises:
        KeyError: If the container JSON has no "numStreams" key.
    """
    container_dict = json.loads(_get_container_json_metadata(decoder))
    container_metadata = ContainerMetadata(
        duration_seconds=container_dict.get("durationSeconds"),
        bit_rate=container_dict.get("bitRate"),
        best_video_stream_index=container_dict.get("bestVideoStreamIndex"),
        best_audio_stream_index=container_dict.get("bestAudioStreamIndex"),
    )

    streams_metadata = []
    for stream_index in range(container_dict["numStreams"]):
        stream_dict = json.loads(_get_stream_json_metadata(decoder, stream_index))
        streams_metadata.append(
            StreamMetadata(
                duration_seconds=stream_dict.get("durationSeconds"),
                bit_rate=stream_dict.get("bitRate"),
                num_frames_retrieved=stream_dict.get("numFrames"),
                num_frames_computed=stream_dict.get("numFramesFromScan"),
                min_pts_seconds=stream_dict.get("minPtsSecondsFromScan"),
                max_pts_seconds=stream_dict.get("maxPtsSecondsFromScan"),
                codec=stream_dict.get("codec"),
                width=stream_dict.get("width"),
                height=stream_dict.get("height"),
                average_fps=stream_dict.get("averageFps"),
            )
        )

    return VideoMetadata(container=container_metadata, streams=streams_metadata)
14 changes: 14 additions & 0 deletions src/torchcodec/decoders/_core/video_decoder_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def load_torchcodec_extension():
get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default
get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default
get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default
_get_container_json_metadata = (
torch.ops.torchcodec_ns.get_container_json_metadata.default
)
_get_stream_json_metadata = torch.ops.torchcodec_ns.get_stream_json_metadata.default


# =============================
Expand Down Expand Up @@ -134,3 +138,13 @@ def get_frames_at_indices_abstract(
@register_fake("torchcodec_ns::get_json_metadata")
def get_json_metadata_abstract(decoder: torch.Tensor) -> str:
    """Fake implementation of ``get_json_metadata``: returns a placeholder str."""
    # torch.empty_like() expects a tensor and raises when given a str, so
    # return a plain empty string, which matches the op's `str` return type.
    return ""


@register_fake("torchcodec_ns::get_container_json_metadata")
def get_container_json_metadata_abstract(decoder: torch.Tensor) -> str:
    """Fake implementation of ``get_container_json_metadata``: returns a placeholder str."""
    # torch.empty_like() expects a tensor and raises when given a str, so
    # return a plain empty string, which matches the op's `str` return type.
    return ""


@register_fake("torchcodec_ns::get_stream_json_metadata")
def get_stream_json_metadata_abstract(decoder: torch.Tensor, stream_index: int) -> str:
    """Fake implementation of ``get_stream_json_metadata``: returns a placeholder str."""
    # The parameter is named `stream_index` to match the op schema
    # ("get_stream_json_metadata(Tensor(a!) decoder, int stream_index) -> str"),
    # so that keyword calls to the op also reach this function correctly.
    # torch.empty_like() expects a tensor and raises when given a str, so
    # return a plain empty string, which matches the op's `str` return type.
    return ""
Loading

0 comments on commit fd9bc40

Please sign in to comment.