-
Notifications
You must be signed in to change notification settings - Fork 20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Towards a clearer logic for determining output height and width #332
Changes from all commits
0256e18
4a9c00c
e529764
dcb50d5
a07fb2d
733e62b
52abe9b
b178250
e386b56
2cd2059
26e6203
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -195,14 +195,13 @@ VideoDecoder::BatchDecodedOutput::BatchDecodedOutput( | |
int64_t numFrames, | ||
const VideoStreamDecoderOptions& options, | ||
const StreamMetadata& metadata) | ||
: frames(torch::empty( | ||
{numFrames, | ||
options.height.value_or(*metadata.height), | ||
options.width.value_or(*metadata.width), | ||
3}, | ||
at::TensorOptions(options.device).dtype(torch::kUInt8))), | ||
ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})), | ||
durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) {} | ||
: ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})), | ||
durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) { | ||
auto frameDims = getHeightAndWidthFromOptionsOrMetadata(options, metadata); | ||
int height = frameDims.height; | ||
int width = frameDims.width; | ||
frames = allocateEmptyHWCTensor(height, width, options.device, numFrames); | ||
} | ||
|
||
VideoDecoder::VideoDecoder() {} | ||
|
||
|
@@ -364,12 +363,11 @@ void VideoDecoder::initializeFilterGraphForStream( | |
inputs->pad_idx = 0; | ||
inputs->next = nullptr; | ||
char description[512]; | ||
int width = activeStream.codecContext->width; | ||
int height = activeStream.codecContext->height; | ||
if (options.height.has_value() && options.width.has_value()) { | ||
width = *options.width; | ||
height = *options.height; | ||
} | ||
auto frameDims = getHeightAndWidthFromOptionsOrMetadata( | ||
options, containerMetadata_.streams[streamIndex]); | ||
int height = frameDims.height; | ||
int width = frameDims.width; | ||
|
||
std::snprintf( | ||
description, | ||
sizeof(description), | ||
|
@@ -869,7 +867,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( | |
convertAVFrameToDecodedOutputOnCuda( | ||
streamInfo.options.device, | ||
streamInfo.options, | ||
streamInfo.codecContext.get(), | ||
containerMetadata_.streams[streamIndex], | ||
rawOutput, | ||
output, | ||
preAllocatedOutputTensor); | ||
|
@@ -899,8 +897,10 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( | |
torch::Tensor tensor; | ||
if (output.streamType == AVMEDIA_TYPE_VIDEO) { | ||
if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) { | ||
int width = streamInfo.options.width.value_or(frame->width); | ||
int height = streamInfo.options.height.value_or(frame->height); | ||
auto frameDims = | ||
getHeightAndWidthFromOptionsOrAVFrame(streamInfo.options, *frame); | ||
int height = frameDims.height; | ||
int width = frameDims.width; | ||
if (preAllocatedOutputTensor.has_value()) { | ||
tensor = preAllocatedOutputTensor.value(); | ||
auto shape = tensor.sizes(); | ||
|
@@ -914,8 +914,7 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( | |
"x3, got ", | ||
shape); | ||
} else { | ||
tensor = torch::empty( | ||
{height, width, 3}, torch::TensorOptions().dtype({torch::kUInt8})); | ||
tensor = allocateEmptyHWCTensor(height, width, torch::kCPU); | ||
} | ||
rawOutput.data = tensor.data_ptr<uint8_t>(); | ||
convertFrameToBufferUsingSwsScale(rawOutput); | ||
|
@@ -1315,8 +1314,10 @@ void VideoDecoder::convertFrameToBufferUsingSwsScale( | |
enum AVPixelFormat frameFormat = | ||
static_cast<enum AVPixelFormat>(frame->format); | ||
StreamInfo& activeStream = streams_[streamIndex]; | ||
int outputWidth = activeStream.options.width.value_or(frame->width); | ||
int outputHeight = activeStream.options.height.value_or(frame->height); | ||
auto frameDims = | ||
getHeightAndWidthFromOptionsOrAVFrame(activeStream.options, *frame); | ||
int outputHeight = frameDims.height; | ||
int outputWidth = frameDims.width; | ||
if (activeStream.swsContext.get() == nullptr) { | ||
SwsContext* swsContext = sws_getContext( | ||
frame->width, | ||
|
@@ -1382,7 +1383,11 @@ torch::Tensor VideoDecoder::convertFrameToTensorUsingFilterGraph( | |
ffmpegStatus = | ||
av_buffersink_get_frame(filterState.sinkContext, filteredFrame.get()); | ||
TORCH_CHECK_EQ(filteredFrame->format, AV_PIX_FMT_RGB24); | ||
std::vector<int64_t> shape = {filteredFrame->height, filteredFrame->width, 3}; | ||
auto frameDims = getHeightAndWidthFromOptionsOrAVFrame( | ||
streams_[streamIndex].options, *filteredFrame.get()); | ||
int height = frameDims.height; | ||
int width = frameDims.width; | ||
std::vector<int64_t> shape = {height, width, 3}; | ||
std::vector<int64_t> strides = {filteredFrame->linesize[0], 3, 1}; | ||
AVFrame* filteredFramePtr = filteredFrame.release(); | ||
auto deleter = [filteredFramePtr](void*) { | ||
|
@@ -1406,6 +1411,43 @@ VideoDecoder::~VideoDecoder() { | |
} | ||
} | ||
|
||
// Resolve the output frame dimensions from the decoder options, falling back
// to the dimensions recorded in the stream metadata when an option is unset.
// NOTE(review): metadata.height/width optionals are dereferenced
// unconditionally -- assumes callers only reach this with populated
// metadata; confirm.
FrameDims getHeightAndWidthFromOptionsOrMetadata(
    const VideoDecoder::VideoStreamDecoderOptions& options,
    const VideoDecoder::StreamMetadata& metadata) {
  const int height = options.height.value_or(*metadata.height);
  const int width = options.width.value_or(*metadata.width);
  return FrameDims(height, width);
}
|
||
// Resolve the output frame dimensions from the decoder options, falling back
// to the dimensions of the decoded AVFrame when an option is unset.
FrameDims getHeightAndWidthFromOptionsOrAVFrame(
    const VideoDecoder::VideoStreamDecoderOptions& options,
    const AVFrame& avFrame) {
  const int height = options.height.value_or(avFrame.height);
  const int width = options.width.value_or(avFrame.width);
  return FrameDims(height, width);
}
|
||
// Allocate an uninitialized uint8 tensor in HWC layout on `device`.
// When `numFrames` has a value the result is batched NHWC
// ({numFrames, height, width, 3}); otherwise it is a single frame
// ({height, width, 3}).
// Throws (via TORCH_CHECK) when height or width is non-positive, or when
// numFrames is negative.
torch::Tensor allocateEmptyHWCTensor(
    int height,
    int width,
    torch::Device device,
    std::optional<int> numFrames) {
  // Validate dimensions up front, before building TensorOptions or touching
  // the allocator, so invalid input fails fast with a clear message.
  TORCH_CHECK(height > 0, "height must be > 0, got: ", height);
  TORCH_CHECK(width > 0, "width must be > 0, got: ", width);
  auto tensorOptions = torch::TensorOptions()
                           .dtype(torch::kUInt8)
                           .layout(torch::kStrided)
                           .device(device);
  if (numFrames.has_value()) {
    const int numFramesValue = numFrames.value();
    // numFrames == 0 is allowed: it yields an empty batch.
    TORCH_CHECK(
        numFramesValue >= 0, "numFrames must be >= 0, got: ", numFramesValue);
    return torch::empty({numFramesValue, height, width, 3}, tensorOptions);
  }
  return torch::empty({height, width, 3}, tensorOptions);
}
|
||
std::ostream& operator<<( | ||
std::ostream& os, | ||
const VideoDecoder::DecodeStats& stats) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd actually prefer that we just call both of these functions `getHeightAndWidth`
and allow C++ function overloading to determine which one to call. In C++, because the types of parameters are formally a part of the function signature, it's less common to encode the type of one of the parameters in the name. This, however, is a question of style, and I know @ahmadsharif1 may feel differently. Read the comment below — I get that you're purposefully marking in the code which strategy we're using where.