Towards a clearer logic for determining output height and width (#332)
NicolasHug authored Nov 5, 2024
1 parent ab79e67 commit 373d1c5
Showing 5 changed files with 135 additions and 40 deletions.
2 changes: 1 addition & 1 deletion src/torchcodec/decoders/_core/CPUOnlyDevice.cpp
@@ -17,7 +17,7 @@ namespace facebook::torchcodec {
void convertAVFrameToDecodedOutputOnCuda(
const torch::Device& device,
const VideoDecoder::VideoStreamDecoderOptions& options,
AVCodecContext* codecContext,
const VideoDecoder::StreamMetadata& metadata,
VideoDecoder::RawDecodedOutput& rawOutput,
VideoDecoder::DecodedOutput& output,
std::optional<torch::Tensor> preAllocatedOutputTensor) {
21 changes: 5 additions & 16 deletions src/torchcodec/decoders/_core/CudaDevice.cpp
@@ -154,18 +154,6 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
#endif
}

torch::Tensor allocateDeviceTensor(
at::IntArrayRef shape,
torch::Device device,
const torch::Dtype dtype = torch::kUInt8) {
return torch::empty(
shape,
torch::TensorOptions()
.dtype(dtype)
.layout(torch::kStrided)
.device(device));
}

void throwErrorIfNonCudaDevice(const torch::Device& device) {
TORCH_CHECK(
device.type() != torch::kCPU,
@@ -199,7 +187,7 @@ void initializeContextOnCuda(
void convertAVFrameToDecodedOutputOnCuda(
const torch::Device& device,
const VideoDecoder::VideoStreamDecoderOptions& options,
AVCodecContext* codecContext,
const VideoDecoder::StreamMetadata& metadata,
VideoDecoder::RawDecodedOutput& rawOutput,
VideoDecoder::DecodedOutput& output,
std::optional<torch::Tensor> preAllocatedOutputTensor) {
@@ -209,8 +197,9 @@ void convertAVFrameToDecodedOutputOnCuda(
src->format == AV_PIX_FMT_CUDA,
"Expected format to be AV_PIX_FMT_CUDA, got " +
std::string(av_get_pix_fmt_name((AVPixelFormat)src->format)));
int width = options.width.value_or(codecContext->width);
int height = options.height.value_or(codecContext->height);
auto frameDims = getHeightAndWidthFromOptionsOrMetadata(options, metadata);
int height = frameDims.height;
int width = frameDims.width;
NppiSize oSizeROI = {width, height};
Npp8u* input[2] = {src->data[0], src->data[1]};
torch::Tensor& dst = output.frame;
@@ -227,7 +216,7 @@ void convertAVFrameToDecodedOutputOnCuda(
"x3, got ",
shape);
} else {
dst = allocateDeviceTensor({height, width, 3}, options.device);
dst = allocateEmptyHWCTensor(height, width, options.device);
}

// Use the user-requested GPU for running the NPP kernel.
2 changes: 1 addition & 1 deletion src/torchcodec/decoders/_core/DeviceInterface.h
@@ -35,7 +35,7 @@ void initializeContextOnCuda(
void convertAVFrameToDecodedOutputOnCuda(
const torch::Device& device,
const VideoDecoder::VideoStreamDecoderOptions& options,
AVCodecContext* codecContext,
const VideoDecoder::StreamMetadata& metadata,
VideoDecoder::RawDecodedOutput& rawOutput,
VideoDecoder::DecodedOutput& output,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
86 changes: 64 additions & 22 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -195,14 +195,13 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
int64_t numFrames,
const VideoStreamDecoderOptions& options,
const StreamMetadata& metadata)
: frames(torch::empty(
{numFrames,
options.height.value_or(*metadata.height),
options.width.value_or(*metadata.width),
3},
at::TensorOptions(options.device).dtype(torch::kUInt8))),
ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {}
: ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {
auto frameDims = getHeightAndWidthFromOptionsOrMetadata(options, metadata);
int height = frameDims.height;
int width = frameDims.width;
frames = allocateEmptyHWCTensor(height, width, options.device, numFrames);
}

VideoDecoder::VideoDecoder() {}

@@ -364,12 +363,11 @@ void VideoDecoder::initializeFilterGraphForStream(
inputs->pad_idx = 0;
inputs->next = nullptr;
char description[512];
int width = activeStream.codecContext->width;
int height = activeStream.codecContext->height;
if (options.height.has_value() && options.width.has_value()) {
width = *options.width;
height = *options.height;
}
auto frameDims = getHeightAndWidthFromOptionsOrMetadata(
options, containerMetadata_.streams[streamIndex]);
int height = frameDims.height;
int width = frameDims.width;

std::snprintf(
description,
sizeof(description),
@@ -869,7 +867,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
convertAVFrameToDecodedOutputOnCuda(
streamInfo.options.device,
streamInfo.options,
streamInfo.codecContext.get(),
containerMetadata_.streams[streamIndex],
rawOutput,
output,
preAllocatedOutputTensor);
@@ -899,8 +897,10 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
torch::Tensor tensor;
if (output.streamType == AVMEDIA_TYPE_VIDEO) {
if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
int width = streamInfo.options.width.value_or(frame->width);
int height = streamInfo.options.height.value_or(frame->height);
auto frameDims =
getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *frame);
int height = frameDims.height;
int width = frameDims.width;
if (preAllocatedOutputTensor.has_value()) {
tensor = preAllocatedOutputTensor.value();
auto shape = tensor.sizes();
@@ -914,8 +914,7 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
"x3, got ",
shape);
} else {
tensor = torch::empty(
{height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
tensor = allocateEmptyHWCTensor(height, width, torch::kCPU);
}
rawOutput.data = tensor.data_ptr<uint8_t>();
convertFrameToBufferUsingSwsScale(rawOutput);
@@ -1315,8 +1314,10 @@ void VideoDecoder::convertFrameToBufferUsingSwsScale(
enum AVPixelFormat frameFormat =
static_cast<enum AVPixelFormat>(frame->format);
StreamInfo& activeStream = streams_[streamIndex];
int outputWidth = activeStream.options.width.value_or(frame->width);
int outputHeight = activeStream.options.height.value_or(frame->height);
auto frameDims =
getHeightAndWidthFromOptionsOrAVFrame(activeStream.options, *frame);
int outputHeight = frameDims.height;
int outputWidth = frameDims.width;
if (activeStream.swsContext.get() == nullptr) {
SwsContext* swsContext = sws_getContext(
frame->width,
@@ -1382,7 +1383,11 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
ffmpegStatus =
av_buffersink_get_frame(filterState.sinkContext, filteredFrame.get());
TORCH_CHECK_EQ(filteredFrame->format, AV_PIX_FMT_RGB24);
std::vector<int64_t> shape = {filteredFrame->height, filteredFrame->width, 3};
auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(
streams_[streamIndex].options, *filteredFrame.get());
int height = frameDims.height;
int width = frameDims.width;
std::vector<int64_t> shape = {height, width, 3};
std::vector<int64_t> strides = {filteredFrame->linesize[0], 3, 1};
AVFrame* filteredFramePtr = filteredFrame.release();
auto deleter = [filteredFramePtr](void*) {
@@ -1406,6 +1411,43 @@ VideoDecoder::~VideoDecoder() {
}
}

FrameDims getHeightAndWidthFromOptionsOrMetadata(
const VideoDecoder::VideoStreamDecoderOptions& options,
const VideoDecoder::StreamMetadata& metadata) {
return FrameDims(
options.height.value_or(*metadata.height),
options.width.value_or(*metadata.width));
}

FrameDims getHeightAndWidthFromOptionsOrAVFrame(
const VideoDecoder::VideoStreamDecoderOptions& options,
const AVFrame& avFrame) {
return FrameDims(
options.height.value_or(avFrame.height),
options.width.value_or(avFrame.width));
}

torch::Tensor allocateEmptyHWCTensor(
int height,
int width,
torch::Device device,
std::optional<int> numFrames) {
auto tensorOptions = torch::TensorOptions()
.dtype(torch::kUInt8)
.layout(torch::kStrided)
.device(device);
TORCH_CHECK(height > 0, "height must be > 0, got: ", height);
TORCH_CHECK(width > 0, "width must be > 0, got: ", width);
if (numFrames.has_value()) {
auto numFramesValue = numFrames.value();
TORCH_CHECK(
numFramesValue >= 0, "numFrames must be >= 0, got: ", numFramesValue);
return torch::empty({numFramesValue, height, width, 3}, tensorOptions);
} else {
return torch::empty({height, width, 3}, tensorOptions);
}
}

std::ostream& operator<<(
std::ostream& os,
const VideoDecoder::DecodeStats& stats) {
64 changes: 64 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoder.h
@@ -243,6 +243,7 @@ class VideoDecoder {
const VideoStreamDecoderOptions& options,
const StreamMetadata& metadata);
};

// Returns frames at the given indices for a given stream as a single stacked
// Tensor.
BatchDecodedOutput getFramesAtIndices(
@@ -413,6 +414,69 @@ class VideoDecoder {
bool scanned_all_streams_ = false;
};

// --------------------------------------------------------------------------
// FRAME TENSOR ALLOCATION APIs
// --------------------------------------------------------------------------

// Note [Frame Tensor allocation and height and width]
//
// We always allocate [N]HWC tensors. The low-level decoding functions all
// assume HWC tensors, since this is what FFmpeg natively handles. It's up to
// the high-level decoding entry-points to permute that back to CHW, by calling
// MaybePermuteHWC2CHW().
//
// Also, importantly, the way we figure out the height and width of the
// output frame varies and depends on the decoding entry-point:
// - In all cases, if the user requested specific height and width from the
// options, we honor that. Otherwise we fall into one of the categories below.
// - In Batch decoding APIs (e.g. getFramesAtIndices), we get height and width
// from the stream metadata, which itself got its value from the CodecContext,
// when the stream was added.
// - In single frames APIs:
// - On CPU we get height and width from the AVFrame.
// - On GPU, we get height and width from the metadata (same as batch APIs)
//
// These 2 strategies are encapsulated within
// getHeightAndWidthFromOptionsOrMetadata() and
// getHeightAndWidthFromOptionsOrAVFrame(). The reason they exist is to make it
// very obvious which logic is used in which place, and they allow for `git
// grep`ing.
//
// The source of truth for height and width really is the AVFrame: it's the
// decoded output from FFmpeg. The info from the metadata (i.e. from the
// CodecContext) may not be as accurate. However, the AVFrame is only available
// late in the call stack, when the frame is decoded, while the CodecContext is
// available early when a stream is added. This is why we use the CodecContext
// for pre-allocating batched output tensors (we could pre-allocate those only
// once we decode the first frame to get the info from the AVFrame, but that
// would involve more complex logic).
//
// Because the sources for height and width may disagree, we may end up with
// conflicts: e.g. if we pre-allocate a batch output tensor based on the
// metadata info, but the decoded AVFrame has a different height and width.
// It is very important to check the height and width assumptions wherever the
// tensors' memory is used or filled, in order to avoid segfaults.

struct FrameDims {
int height;
int width;
FrameDims(int h, int w) : height(h), width(w) {}
};

FrameDims getHeightAndWidthFromOptionsOrMetadata(
const VideoDecoder::VideoStreamDecoderOptions& options,
const VideoDecoder::StreamMetadata& metadata);

FrameDims getHeightAndWidthFromOptionsOrAVFrame(
const VideoDecoder::VideoStreamDecoderOptions& options,
const AVFrame& avFrame);

torch::Tensor allocateEmptyHWCTensor(
int height,
int width,
torch::Device device,
std::optional<int> numFrames = std::nullopt);

// Prints the VideoDecoder::DecodeStats to the ostream.
std::ostream& operator<<(
std::ostream& os,
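For illustration, here is a minimal sketch (not part of this commit) of the batch pre-allocation strategy described in the note above. FakeOptions and FakeMetadata are hypothetical stand-ins for VideoStreamDecoderOptions and StreamMetadata, used only to show how the options-or-metadata fallback feeds an [N]HWC allocation.

// Minimal sketch, assuming simplified stand-in types; not the library's API.
#include <torch/torch.h>
#include <optional>

struct FakeOptions { // stand-in for VideoStreamDecoderOptions
  std::optional<int> height;
  std::optional<int> width;
  torch::Device device{torch::kCPU};
};

struct FakeMetadata { // stand-in for StreamMetadata
  std::optional<int> height;
  std::optional<int> width;
};

// Batch entry-points resolve the dims from the options, falling back to the
// stream metadata, so the output tensor can be allocated before any frame is
// decoded.
torch::Tensor preAllocateBatchHWC(
    const FakeOptions& options,
    const FakeMetadata& metadata,
    int64_t numFrames) {
  int height = options.height.value_or(*metadata.height);
  int width = options.width.value_or(*metadata.width);
  TORCH_CHECK(height > 0 && width > 0, "height and width must be > 0");
  return torch::empty(
      {numFrames, height, width, 3},
      torch::TensorOptions().dtype(torch::kUInt8).device(options.device));
}

// If a decoded AVFrame later disagrees with the metadata, the shape checks at
// the fill sites (as in convertAVFrameToDecodedOutputOnCPU above) are what
// prevent writing past the pre-allocated memory.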
