Skip to content

Commit

Permalink
[TRT][OP] Fix Upsample TRT LayerBuilder BUG for dim upsample under dy…
Browse files Browse the repository at this point in the history
…namic batch cases
  • Loading branch information
doxutx committed Mar 28, 2024
1 parent f453576 commit 7cea315
Showing 1 changed file with 35 additions and 20 deletions.
55 changes: 35 additions & 20 deletions source/tnn/network/tensorrt/layer_builder/upsample_layer_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,34 +66,49 @@ ILayer* UpsampleTRTPluginLayerBuilder::AddToNetwork(INetworkDefinition* network)
layer->setName(layer_name_.c_str());
if (input_blobs_.size() == 1) {
if (!paramlist->dims.empty()) {
if (!output_dims.empty()) {
nvinfer1::Dims4 dims(output_dims[0], output_dims[1], output_dims[2], output_dims[3]);
layer->setOutputDimensions(dims);
} else {
if (input_tensor->getDimensions().nbDims != 4) {
LOGE("Upsample with 1 input only support 4d input.");
return nullptr;
}
if (paramlist->dims.size() != 2 ||
paramlist->dims[0] <= 0 || paramlist->dims[1] <= 0) {
LOGE("Upsample with 1 input should have positive 2-element param->dims.");
return nullptr;
}
auto trt_dim = input_tensor->getDimensions();
// trt_dim may have one of the following values:
// [-1,3,32,32], [-1,2,-1,-1], [1,16,256,256]
if ((trt_dim.d[0] <= 0 || trt_dim.d[1] <= 0) &&
(trt_dim.d[2] > 0 && trt_dim.d[3] > 0)) {
auto trt_dim = input_tensor->getDimensions();
if (trt_dim.nbDims != 4) {
LOGE("Upsample with 1 input only support 4d input.");
return nullptr;
}

// trt_dim may have one of the following values:
// [-1,3,32,32], [-1,2,-1,-1], [1,16,256,256]
if (trt_dim.d[0] <= 0 || trt_dim.d[1] <= 0) {
// Cases When At least One of N, C be dynamic
if (trt_dim.d[2] > 0 && trt_dim.d[3] > 0) {
// Cases when H,W are fixed, turn to scale mode
// e.g [-1,3,32,32]
float scale[4];
scale[0] = 1;
scale[1] = 1;
scale[2] = paramlist->dims[0] / float(trt_dim.d[2]);
scale[3] = paramlist->dims[1] / float(trt_dim.d[3]);
layer->setScales(scale, 4);
} else {
// WARNING, trt_dims may have -1 values.
nvinfer1::Dims4 dims(trt_dim.d[0], trt_dim.d[1], paramlist->dims[0], paramlist->dims[1]);
// Cases When Both N,C and H+W are dynamic
// In this case, We cannot turn to Scale mode.
// Also layer->SetOutputDimensions() API does not accept -1 as dim
// Have to use TNN Upsample Plugin.
// e.g [-1,2,-1,-1]
LOGI("WARNING: Dynamic NCHW Upsample with fixed dims provided, NOT SUPPORTED by TensorRT, use TNN Upsample Plugin instead.");
return TensorRTPluginLayerBuilder::AddToNetwork(network);
}
} else {
// Cases When Both N and C are fixed
// e.g [1,16,256,256]
if (!output_dims.empty() && output_dims[2] > 0 && output_dims[3] > 0) {
nvinfer1::Dims4 dims(trt_dim.d[0], trt_dim.d[1],
output_dims[2], output_dims[3]);
layer->setOutputDimensions(dims);
} else if (paramlist->dims.size() >= 2 &&
paramlist->dims[0] > 0 && paramlist->dims[1] > 0) {
nvinfer1::Dims4 dims(trt_dim.d[0], trt_dim.d[1],
paramlist->dims[0], paramlist->dims[1]);
layer->setOutputDimensions(dims);
} else {
LOGE("Upsample with 1 input Fix N,C + Fixed dims does not have standard positive dim, Unsupported.");
return nullptr;
}
}
} else {
Expand Down

0 comments on commit 7cea315

Please sign in to comment.