diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp
index 20a4e380..b4ebb273 100644
--- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp
+++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp
@@ -17,7 +17,7 @@ namespace facebook::torchcodec {
 void convertAVFrameToDecodedOutputOnCuda(
     const torch::Device& device,
     const VideoDecoder::VideoStreamDecoderOptions& options,
-    AVCodecContext* codecContext,
+    const VideoDecoder::StreamMetadata& metadata,
     VideoDecoder::RawDecodedOutput& rawOutput,
     VideoDecoder::DecodedOutput& output,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp
index b15684a2..7c3964cc 100644
--- a/src/torchcodec/decoders/_core/CudaDevice.cpp
+++ b/src/torchcodec/decoders/_core/CudaDevice.cpp
@@ -154,18 +154,6 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
 #endif
 }
 
-torch::Tensor allocateDeviceTensor(
-    at::IntArrayRef shape,
-    torch::Device device,
-    const torch::Dtype dtype = torch::kUInt8) {
-  return torch::empty(
-      shape,
-      torch::TensorOptions()
-          .dtype(dtype)
-          .layout(torch::kStrided)
-          .device(device));
-}
-
 void throwErrorIfNonCudaDevice(const torch::Device& device) {
   TORCH_CHECK(
       device.type() != torch::kCPU,
@@ -199,7 +187,7 @@ void initializeContextOnCuda(
 void convertAVFrameToDecodedOutputOnCuda(
     const torch::Device& device,
     const VideoDecoder::VideoStreamDecoderOptions& options,
-    AVCodecContext* codecContext,
+    const VideoDecoder::StreamMetadata& metadata,
     VideoDecoder::RawDecodedOutput& rawOutput,
     VideoDecoder::DecodedOutput& output,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
@@ -209,8 +197,9 @@ void convertAVFrameToDecodedOutputOnCuda(
       src->format == AV_PIX_FMT_CUDA,
       "Expected format to be AV_PIX_FMT_CUDA, got " +
           std::string(av_get_pix_fmt_name((AVPixelFormat)src->format)));
-  int width = options.width.value_or(codecContext->width);
-  int height = options.height.value_or(codecContext->height);
+  auto frameDims = getHeightAndWidthFromOptionsOrMetadata(options, metadata);
+  int height = frameDims.height;
+  int width = frameDims.width;
   NppiSize oSizeROI = {width, height};
   Npp8u* input[2] = {src->data[0], src->data[1]};
   torch::Tensor& dst = output.frame;
@@ -227,7 +216,7 @@ void convertAVFrameToDecodedOutputOnCuda(
         "x3, got ",
         shape);
   } else {
-    dst = allocateDeviceTensor({height, width, 3}, options.device);
+    dst = allocateEmptyHWCTensor(height, width, options.device);
   }
 
   // Use the user-requested GPU for running the NPP kernel.
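For reference, this is the allocate-or-validate pattern the CUDA path follows
after this change, as a minimal standalone sketch (`getOutputTensor` and its
scaffolding are illustrative stand-ins, not code from this patch):

    #include <torch/torch.h>

    #include <optional>

    torch::Tensor getOutputTensor(
        int height, // resolved from options-or-metadata, not from the AVFrame
        int width,
        torch::Device device,
        std::optional<torch::Tensor> preAllocated) {
      if (preAllocated.has_value()) {
        // The metadata-derived dims may disagree with the decoded AVFrame, so
        // validate before the NPP kernel writes into this buffer.
        auto shape = preAllocated->sizes();
        TORCH_CHECK(
            shape.size() == 3 && shape[0] == height && shape[1] == width &&
                shape[2] == 3,
            "Expected tensor of shape ", height, "x", width, "x3, got ", shape);
        return preAllocated.value();
      }
      // Same allocation as allocateEmptyHWCTensor(height, width, device).
      return torch::empty(
          {height, width, 3},
          torch::TensorOptions().dtype(torch::kUInt8).device(device));
    }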
diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h
index 772bdfe6..0fed5e00 100644
--- a/src/torchcodec/decoders/_core/DeviceInterface.h
+++ b/src/torchcodec/decoders/_core/DeviceInterface.h
@@ -35,7 +35,7 @@ void initializeContextOnCuda(
 void convertAVFrameToDecodedOutputOnCuda(
     const torch::Device& device,
     const VideoDecoder::VideoStreamDecoderOptions& options,
-    AVCodecContext* codecContext,
+    const VideoDecoder::StreamMetadata& metadata,
     VideoDecoder::RawDecodedOutput& rawOutput,
     VideoDecoder::DecodedOutput& output,
     std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index b20a5e54..169dc216 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -195,14 +195,13 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
     int64_t numFrames,
     const VideoStreamDecoderOptions& options,
     const StreamMetadata& metadata)
-    : frames(torch::empty(
-          {numFrames,
-           options.height.value_or(*metadata.height),
-           options.width.value_or(*metadata.width),
-           3},
-          at::TensorOptions(options.device).dtype(torch::kUInt8))),
-      ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
-      durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {}
+    : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
+      durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {
+  auto frameDims = getHeightAndWidthFromOptionsOrMetadata(options, metadata);
+  int height = frameDims.height;
+  int width = frameDims.width;
+  frames = allocateEmptyHWCTensor(height, width, options.device, numFrames);
+}
 
 VideoDecoder::VideoDecoder() {}
 
@@ -364,12 +363,11 @@ void VideoDecoder::initializeFilterGraphForStream(
   inputs->pad_idx = 0;
   inputs->next = nullptr;
   char description[512];
-  int width = activeStream.codecContext->width;
-  int height = activeStream.codecContext->height;
-  if (options.height.has_value() && options.width.has_value()) {
-    width = *options.width;
-    height = *options.height;
-  }
+  auto frameDims = getHeightAndWidthFromOptionsOrMetadata(
+      options, containerMetadata_.streams[streamIndex]);
+  int height = frameDims.height;
+  int width = frameDims.width;
+
   std::snprintf(
       description,
       sizeof(description),
@@ -869,7 +867,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
       convertAVFrameToDecodedOutputOnCuda(
          streamInfo.options.device,
          streamInfo.options,
-          streamInfo.codecContext.get(),
+          containerMetadata_.streams[streamIndex],
          rawOutput,
          output,
          preAllocatedOutputTensor);
@@ -899,8 +897,10 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   torch::Tensor tensor;
   if (output.streamType == AVMEDIA_TYPE_VIDEO) {
     if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      int width = streamInfo.options.width.value_or(frame->width);
-      int height = streamInfo.options.height.value_or(frame->height);
+      auto frameDims =
+          getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *frame);
+      int height = frameDims.height;
+      int width = frameDims.width;
       if (preAllocatedOutputTensor.has_value()) {
         tensor = preAllocatedOutputTensor.value();
         auto shape = tensor.sizes();
@@ -914,8 +914,7 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
             "x3, got ",
             shape);
       } else {
-        tensor = torch::empty(
-            {height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
+        tensor = allocateEmptyHWCTensor(height, width, torch::kCPU);
       }
      rawOutput.data = tensor.data_ptr<uint8_t>();
      convertFrameToBufferUsingSwsScale(rawOutput);
@@ -1315,8 +1314,10 @@ void VideoDecoder::convertFrameToBufferUsingSwsScale(
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(frame->format);
   StreamInfo& activeStream = streams_[streamIndex];
-  int outputWidth = activeStream.options.width.value_or(frame->width);
-  int outputHeight = activeStream.options.height.value_or(frame->height);
+  auto frameDims =
+      getHeightAndWidthFromOptionsOrAVFrame(activeStream.options, *frame);
+  int outputHeight = frameDims.height;
+  int outputWidth = frameDims.width;
   if (activeStream.swsContext.get() == nullptr) {
     SwsContext* swsContext = sws_getContext(
         frame->width,
@@ -1382,7 +1383,11 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
   ffmpegStatus =
       av_buffersink_get_frame(filterState.sinkContext, filteredFrame.get());
   TORCH_CHECK_EQ(filteredFrame->format, AV_PIX_FMT_RGB24);
-  std::vector<int64_t> shape = {filteredFrame->height, filteredFrame->width, 3};
+  auto frameDims = getHeightAndWidthFromOptionsOrAVFrame(
+      streams_[streamIndex].options, *filteredFrame.get());
+  int height = frameDims.height;
+  int width = frameDims.width;
+  std::vector<int64_t> shape = {height, width, 3};
   std::vector<int64_t> strides = {filteredFrame->linesize[0], 3, 1};
   AVFrame* filteredFramePtr = filteredFrame.release();
   auto deleter = [filteredFramePtr](void*) {
@@ -1406,6 +1411,43 @@ VideoDecoder::~VideoDecoder() {
   }
 }
 
+FrameDims getHeightAndWidthFromOptionsOrMetadata(
+    const VideoDecoder::VideoStreamDecoderOptions& options,
+    const VideoDecoder::StreamMetadata& metadata) {
+  return FrameDims(
+      options.height.value_or(*metadata.height),
+      options.width.value_or(*metadata.width));
+}
+
+FrameDims getHeightAndWidthFromOptionsOrAVFrame(
+    const VideoDecoder::VideoStreamDecoderOptions& options,
+    const AVFrame& avFrame) {
+  return FrameDims(
+      options.height.value_or(avFrame.height),
+      options.width.value_or(avFrame.width));
+}
+
+torch::Tensor allocateEmptyHWCTensor(
+    int height,
+    int width,
+    torch::Device device,
+    std::optional<int> numFrames) {
+  auto tensorOptions = torch::TensorOptions()
+                           .dtype(torch::kUInt8)
+                           .layout(torch::kStrided)
+                           .device(device);
+  TORCH_CHECK(height > 0, "height must be > 0, got: ", height);
+  TORCH_CHECK(width > 0, "width must be > 0, got: ", width);
+  if (numFrames.has_value()) {
+    auto numFramesValue = numFrames.value();
+    TORCH_CHECK(
+        numFramesValue >= 0, "numFrames must be >= 0, got: ", numFramesValue);
+    return torch::empty({numFramesValue, height, width, 3}, tensorOptions);
+  } else {
+    return torch::empty({height, width, 3}, tensorOptions);
+  }
+}
+
 std::ostream& operator<<(
     std::ostream& os,
     const VideoDecoder::DecodeStats& stats) {
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index ce4a0cc1..f944c01a 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -243,6 +243,7 @@ class VideoDecoder {
       const VideoStreamDecoderOptions& options,
       const StreamMetadata& metadata);
   };
+
   // Returns frames at the given indices for a given stream as a single stacked
   // Tensor.
   BatchDecodedOutput getFramesAtIndices(
@@ -413,6 +414,69 @@ class VideoDecoder {
   bool scanned_all_streams_ = false;
 };
 
+// --------------------------------------------------------------------------
+// FRAME TENSOR ALLOCATION APIs
+// --------------------------------------------------------------------------
+
+// Note [Frame Tensor allocation and height and width]
+//
+// We always allocate [N]HWC tensors. The low-level decoding functions all
+// assume HWC tensors, since this is what FFmpeg natively handles. It's up to
+// the high-level decoding entry-points to permute that back to CHW, by calling
+// MaybePermuteHWC2CHW().
+//
+// Also, importantly, the way we figure out the height and width of the output
+// frame varies and depends on the decoding entry-point:
+// - In all cases, if the user requested specific height and width from the
+// options, we honor that. Otherwise we fall into one of the categories below.
+// - In batch decoding APIs (e.g. getFramesAtIndices), we get height and width
+// from the stream metadata, which itself got its value from the CodecContext
+// when the stream was added.
+// - In single-frame APIs:
+//   - On CPU, we get height and width from the AVFrame.
+//   - On GPU, we get height and width from the metadata (same as batch APIs).
+//
+// These two strategies are encapsulated within
+// getHeightAndWidthFromOptionsOrMetadata() and
+// getHeightAndWidthFromOptionsOrAVFrame(). The reason they exist is to make it
+// very obvious which logic is used in which place, and they allow for `git
+// grep`ing.
+//
+// The source of truth for height and width really is the AVFrame: it's the
+// decoded output from FFmpeg. The info from the metadata (i.e. from the
+// CodecContext) may not be as accurate. However, the AVFrame is only available
+// late in the call stack, when the frame is decoded, while the CodecContext is
+// available early, when a stream is added. This is why we use the CodecContext
+// for pre-allocating batched output tensors (we could pre-allocate those only
+// once we decode the first frame to get the info from the AVFrame, but that's
+// more complex logic).
+//
+// Because the sources for height and width may disagree, we may end up with
+// conflicts: e.g. we may pre-allocate a batch output tensor based on the
+// metadata info while the decoded AVFrame has a different height and width.
+// It is very important to check the height and width assumptions where the
+// tensors' memory is used/filled in order to avoid segfaults.
+
+struct FrameDims {
+  int height;
+  int width;
+  FrameDims(int h, int w) : height(h), width(w) {}
+};
+
+FrameDims getHeightAndWidthFromOptionsOrMetadata(
+    const VideoDecoder::VideoStreamDecoderOptions& options,
+    const VideoDecoder::StreamMetadata& metadata);
+
+FrameDims getHeightAndWidthFromOptionsOrAVFrame(
+    const VideoDecoder::VideoStreamDecoderOptions& options,
+    const AVFrame& avFrame);
+
+torch::Tensor allocateEmptyHWCTensor(
+    int height,
+    int width,
+    torch::Device device,
+    std::optional<int> numFrames = std::nullopt);
+
 // Prints the VideoDecoder::DecodeStats to the ostream.
 std::ostream& operator<<(
     std::ostream& os,
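The precedence rule that both getHeightAndWidthFromOptionsOrMetadata() and
getHeightAndWidthFromOptionsOrAVFrame() implement (user options win, otherwise
fall back to the metadata or AVFrame source) can be exercised in isolation. A
minimal sketch with hypothetical stand-in types, where `Options` and `Metadata`
mirror only the height/width fields of the real VideoStreamDecoderOptions and
StreamMetadata:

    #include <cassert>
    #include <optional>

    struct Options {
      std::optional<int> height;
      std::optional<int> width;
    };

    struct Metadata {
      // Optional, like StreamMetadata: dereferenced on fallback, matching the
      // *metadata.height in the patch.
      std::optional<int> height;
      std::optional<int> width;
    };

    struct Dims {
      int height;
      int width;
    };

    // Mirrors getHeightAndWidthFromOptionsOrMetadata(): user-requested dims
    // take precedence over the CodecContext-derived metadata.
    Dims fromOptionsOrMetadata(const Options& o, const Metadata& m) {
      return {o.height.value_or(*m.height), o.width.value_or(*m.width)};
    }

    int main() {
      Metadata m;
      m.height = 480;
      m.width = 640;

      Options noResize; // no user override -> metadata dims
      Dims a = fromOptionsOrMetadata(noResize, m);
      assert(a.height == 480 && a.width == 640);

      Options resize; // user requested a resize -> options win
      resize.height = 240;
      resize.width = 320;
      Dims b = fromOptionsOrMetadata(resize, m);
      assert(b.height == 240 && b.width == 320);
      return 0;
    }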