Towards a clearer logic for determining output height and width #332
@@ -154,18 +154,6 @@ AVBufferRef* getCudaContext(const torch::Device& device) {
 #endif
 }
 
-torch::Tensor allocateDeviceTensor(
-    at::IntArrayRef shape,
-    torch::Device device,
-    const torch::Dtype dtype = torch::kUInt8) {
-  return torch::empty(
-      shape,
-      torch::TensorOptions()
-          .dtype(dtype)
-          .layout(torch::kStrided)
-          .device(device));
-}
-
 void throwErrorIfNonCudaDevice(const torch::Device& device) {
   TORCH_CHECK(
       device.type() != torch::kCPU,
@@ -199,7 +187,7 @@ void initializeContextOnCuda(
 void convertAVFrameToDecodedOutputOnCuda(
     const torch::Device& device,
     const VideoDecoder::VideoStreamDecoderOptions& options,
-    AVCodecContext* codecContext,
+    const VideoDecoder::StreamMetadata& metadata,
     VideoDecoder::RawDecodedOutput& rawOutput,
     VideoDecoder::DecodedOutput& output,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
@@ -209,8 +197,9 @@ void convertAVFrameToDecodedOutputOnCuda(
       src->format == AV_PIX_FMT_CUDA,
       "Expected format to be AV_PIX_FMT_CUDA, got " +
           std::string(av_get_pix_fmt_name((AVPixelFormat)src->format)));
-  int width = options.width.value_or(codecContext->width);
-  int height = options.height.value_or(codecContext->height);
+  int height = 0, width = 0;
+  std::tie(height, width) =
+      getHeightAndWidthFromOptionsOrMetadata(options, metadata);
   NppiSize oSizeROI = {width, height};
   Npp8u* input[2] = {src->data[0], src->data[1]};
   torch::Tensor& dst = output.frame;

Review comment: IMO this is less maintainable/safe than returning a struct with named members. Someone could write:
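A hypothetical illustration of the mistake being described, and of the struct alternative (neither snippet is from the PR; FrameDims and getFrameDimsFromOptionsOrMetadata are made-up names):

// With a std::tuple, this compiles but silently swaps the two dimensions,
// because the helper returns (height, width) in that order:
std::tie(width, height) =
    getHeightAndWidthFromOptionsOrMetadata(options, metadata);

// With named members, the call site cannot transpose the values:
struct FrameDims {
  int height = 0;
  int width = 0;
};
FrameDims dims = getFrameDimsFromOptionsOrMetadata(options, metadata);
NppiSize oSizeROI = {dims.width, dims.height};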
@@ -227,7 +216,7 @@ void convertAVFrameToDecodedOutputOnCuda(
         "x3, got ",
         shape);
   } else {
-    dst = allocateDeviceTensor({height, width, 3}, options.device);
+    dst = allocateEmptyHWCTensor(height, width, options.device);
   }
 
   // Use the user-requested GPU for running the NPP kernel.
@@ -195,14 +195,13 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput(
     int64_t numFrames,
     const VideoStreamDecoderOptions& options,
     const StreamMetadata& metadata)
-    : frames(torch::empty(
-          {numFrames,
-           options.height.value_or(*metadata.height),
-           options.width.value_or(*metadata.width),
-           3},
-          at::TensorOptions(options.device).dtype(torch::kUInt8))),
-      ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
-      durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {}
+    : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})),
+      durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {
+  int height = 0, width = 0;
+  std::tie(height, width) =
+      getHeightAndWidthFromOptionsOrMetadata(options, metadata);
+  frames = allocateEmptyHWCTensor(height, width, options.device, numFrames);
+}
 
 VideoDecoder::VideoDecoder() {}

Review comment: Same nit, and there are a few other places.
@@ -364,12 +363,10 @@ void VideoDecoder::initializeFilterGraphForStream(
   inputs->pad_idx = 0;
   inputs->next = nullptr;
   char description[512];
-  int width = activeStream.codecContext->width;
-  int height = activeStream.codecContext->height;
-  if (options.height.has_value() && options.width.has_value()) {
-    width = *options.width;
-    height = *options.height;
-  }
+  int height = 0, width = 0;
+  std::tie(height, width) = getHeightAndWidthFromOptionsOrMetadata(
+      options, containerMetadata_.streams[streamIndex]);
+
   std::snprintf(
       description,
       sizeof(description),
@@ -869,7 +866,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
     convertAVFrameToDecodedOutputOnCuda(
         streamInfo.options.device,
         streamInfo.options,
-        streamInfo.codecContext.get(),
+        containerMetadata_.streams[streamIndex],
         rawOutput,
         output,
         preAllocatedOutputTensor);
@@ -899,8 +896,9 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   torch::Tensor tensor;
   if (output.streamType == AVMEDIA_TYPE_VIDEO) {
     if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      int width = streamInfo.options.width.value_or(frame->width);
-      int height = streamInfo.options.height.value_or(frame->height);
+      int height = 0, width = 0;
+      std::tie(height, width) =
+          getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, frame);
       if (preAllocatedOutputTensor.has_value()) {
         tensor = preAllocatedOutputTensor.value();
         auto shape = tensor.sizes();
@@ -914,8 +912,7 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
             "x3, got ",
             shape);
       } else {
-        tensor = torch::empty(
-            {height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8}));
+        tensor = allocateEmptyHWCTensor(height, width, torch::kCPU);
       }
       rawOutput.data = tensor.data_ptr<uint8_t>();
       convertFrameToBufferUsingSwsScale(rawOutput);
@@ -1315,8 +1312,9 @@ void VideoDecoder::convertFrameToBufferUsingSwsScale(
   enum AVPixelFormat frameFormat =
       static_cast<enum AVPixelFormat>(frame->format);
   StreamInfo& activeStream = streams_[streamIndex];
-  int outputWidth = activeStream.options.width.value_or(frame->width);
-  int outputHeight = activeStream.options.height.value_or(frame->height);
+  int outputHeight = 0, outputWidth = 0;
+  std::tie(outputHeight, outputWidth) =
+      getHeightAndWidthFromOptionsOrAVFrame(activeStream.options, frame);
   if (activeStream.swsContext.get() == nullptr) {
     SwsContext* swsContext = sws_getContext(
         frame->width,
@@ -1382,7 +1380,11 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph(
   ffmpegStatus =
       av_buffersink_get_frame(filterState.sinkContext, filteredFrame.get());
   TORCH_CHECK_EQ(filteredFrame->format, AV_PIX_FMT_RGB24);
-  std::vector<int64_t> shape = {filteredFrame->height, filteredFrame->width, 3};
+  int height = 0, width = 0;
+  std::tie(height, width) = getHeightAndWidthFromOptionsOrAVFrame(
+      streams_[streamIndex].options, filteredFrame.get());
+  std::vector<int64_t> shape = {height, width, 3};
+
   std::vector<int64_t> strides = {filteredFrame->linesize[0], 3, 1};
   AVFrame* filteredFramePtr = filteredFrame.release();
   auto deleter = [filteredFramePtr](void*) {

Review thread:
- So this is the only place where the logic is slightly changed (but I think the behavior is the same): we go through
- Looks correct to me, but @ahmadsharif1 should also reason through it.
- Maybe add a TORCH_CHECK to make sure the filteredFrame has the expected dimensions?
- I'll add the validity checks as a follow-up!
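A rough sketch of the check being suggested (an assumption about the follow-up, not code in this diff):

// Hypothetical validity check: the filter graph output should match the
// height/width we just computed, since the tensor shape and strides below
// depend on them.
TORCH_CHECK(
    filteredFrame->height == height && filteredFrame->width == width,
    "Expected filtered frame of size ", height, "x", width,
    ", got ", filteredFrame->height, "x", filteredFrame->width);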
@@ -1406,6 +1408,38 @@ VideoDecoder::~VideoDecoder() {
   }
 }
 
+std::tuple<int, int> getHeightAndWidthFromOptionsOrMetadata(
+    const VideoDecoder::VideoStreamDecoderOptions& options,
+    const VideoDecoder::StreamMetadata& metadata) {
+  return std::make_tuple(
+      options.height.value_or(*metadata.height),
+      options.width.value_or(*metadata.width));
+}
+
+std::tuple<int, int> getHeightAndWidthFromOptionsOrAVFrame(
+    const VideoDecoder::VideoStreamDecoderOptions& options,
+    AVFrame* avFrame) {
+  return std::make_tuple(
+      options.height.value_or(avFrame->height),
+      options.width.value_or(avFrame->width));
+}

Review comment: Read the comment below, I get that you're purposefully marking in the code which strategy we're doing where.

+torch::Tensor allocateEmptyHWCTensor(
+    int height,
+    int width,
+    torch::Device device,
+    std::optional<int> numFrames) {
+  auto tensorOptions = torch::TensorOptions()
+                           .dtype(torch::kUInt8)
+                           .layout(torch::kStrided)
+                           .device(device);

Review comment: Not critical or blocking, but this may be a good time to put in

+  if (numFrames.has_value()) {
+    return torch::empty({numFrames.value(), height, width, 3}, tensorOptions);
+  } else {
+    return torch::empty({height, width, 3}, tensorOptions);
+  }
+}
+
 std::ostream& operator<<(
     std::ostream& os,
     const VideoDecoder::DecodeStats& stats) {
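Putting the new helpers together, the intended call pattern looks roughly like this (a sketch based on the diff above; options, metadata, avFrame, and numFrames stand in for whatever is in scope at the call site):

// Batch entry-point: fall back to stream metadata and pre-allocate one
// NHWC uint8 tensor covering all requested frames.
int height = 0, width = 0;
std::tie(height, width) =
    getHeightAndWidthFromOptionsOrMetadata(options, metadata);
torch::Tensor frames =
    allocateEmptyHWCTensor(height, width, options.device, numFrames);

// Single-frame CPU entry-point: fall back to the decoded AVFrame instead,
// and omit numFrames to get a single HWC tensor.
std::tie(height, width) =
    getHeightAndWidthFromOptionsOrAVFrame(options, avFrame);
torch::Tensor frame = allocateEmptyHWCTensor(height, width, torch::kCPU);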
@@ -243,6 +243,7 @@ class VideoDecoder {
       const VideoStreamDecoderOptions& options,
       const StreamMetadata& metadata);
   };
+
   // Returns frames at the given indices for a given stream as a single stacked
   // Tensor.
   BatchDecodedOutput getFramesAtIndices(
@@ -413,6 +414,63 @@ class VideoDecoder {
   bool scanned_all_streams_ = false;
 };
 
+// --------------------------------------------------------------------------
+// FRAME TENSOR ALLOCATION APIs
+// --------------------------------------------------------------------------
+
+// Note [Frame Tensor allocation and height and width]
+//
+// We always allocate [N]HWC tensors. The low-level decoding functions all
+// assume HWC tensors, since this is what FFmpeg natively handles. It's up to
+// the high-level decoding entry-points to permute that back to CHW, by calling
+// MaybePermuteHWC2CHW().
+//
+// Also, importantly, the way we figure out the height and width of the
+// output frame varies and depends on the decoding entry-point:
+// - In all cases, if the user requested specific height and width from the
+//   options, we honor that. Otherwise we fall into one of the categories
+//   below.
+// - In batch decoding APIs (e.g. getFramesAtIndices), we get height and width
+//   from the stream metadata, which itself got its value from the CodecContext
+//   when the stream was added.
+// - In single-frame APIs:
+//   - On CPU, we get height and width from the AVFrame.
+//   - On GPU, we get height and width from the metadata (same as batch APIs).
+//
+// These two strategies are encapsulated within
+// getHeightAndWidthFromOptionsOrMetadata() and
+// getHeightAndWidthFromOptionsOrAVFrame(). The reason they exist is to make it
+// very obvious which logic is used in which place, and they allow for `git
+// grep`ing.
+//
+// The source of truth for height and width really is the AVFrame: it's the
+// decoded output from FFmpeg. The info from the metadata (i.e. from the
+// CodecContext) may not be as accurate. However, the AVFrame is only available
+// late in the call stack, when the frame is decoded, while the CodecContext is
+// available early, when a stream is added. This is why we use the CodecContext
+// for pre-allocating batched output tensors (we could pre-allocate those only
+// once we decode the first frame to get the info from the AVFrame, but that's
+// more complex logic).
+//
+// Because the sources for height and width may disagree, we may end up with
+// conflicts: e.g. the batch output tensor is pre-allocated based on the
+// metadata info, but the decoded AVFrame has a different height and width.
+// It is very important to check the height and width assumptions where the
+// tensors' memory is used/filled in order to avoid segfaults.
+
+std::tuple<int, int> getHeightAndWidthFromOptionsOrMetadata(
+    const VideoDecoder::VideoStreamDecoderOptions& options,
+    const VideoDecoder::StreamMetadata& metadata);
+
+std::tuple<int, int> getHeightAndWidthFromOptionsOrAVFrame(
+    const VideoDecoder::VideoStreamDecoderOptions& options,
+    AVFrame* avFrame);

Review comment: This should be const too; a const& would be preferred since this is a mandatory parameter and should never be null.
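The signature the reviewer is asking for would presumably look like this (an assumption, not part of the PR):

// Hypothetical const-correct variant: the frame is mandatory and read-only.
std::tuple<int, int> getHeightAndWidthFromOptionsOrAVFrame(
    const VideoDecoder::VideoStreamDecoderOptions& options,
    const AVFrame& avFrame);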
+
+torch::Tensor allocateEmptyHWCTensor(
+    int height,
+    int width,
+    torch::Device device,
+    std::optional<int> numFrames = std::nullopt);
+
 // Prints the VideoDecoder::DecodeStats to the ostream.
 std::ostream& operator<<(
     std::ostream& os,
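To make the fallback semantics of the note above concrete, here is a small hypothetical example (options.height/width are std::optional, as the value_or() calls imply; the metadata values are made up):

// Suppose the stream metadata reports height == 1080, width == 1920.
VideoDecoder::VideoStreamDecoderOptions options;  // nothing requested
auto [height, width] =
    getHeightAndWidthFromOptionsOrMetadata(options, metadata);
// height == 1080, width == 1920: falls back to the metadata.

options.height = 720;
options.width = 1280;
std::tie(height, width) =
    getHeightAndWidthFromOptionsOrMetadata(options, metadata);
// height == 720, width == 1280: explicit user options always win.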
Review comment: Nit: please declare only one variable per line.
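Applied to the pattern used throughout this diff, that would be something like:

// One declaration per line, per the nit:
int height = 0;
int width = 0;
std::tie(height, width) =
    getHeightAndWidthFromOptionsOrMetadata(options, metadata);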