diff --git a/source/tnn/network/tensorrt/layer_builder/upsample_layer_builder.cc b/source/tnn/network/tensorrt/layer_builder/upsample_layer_builder.cc index 4a8d9c840..721a32b5a 100644 --- a/source/tnn/network/tensorrt/layer_builder/upsample_layer_builder.cc +++ b/source/tnn/network/tensorrt/layer_builder/upsample_layer_builder.cc @@ -66,24 +66,19 @@ ILayer* UpsampleTRTPluginLayerBuilder::AddToNetwork(INetworkDefinition* network) layer->setName(layer_name_.c_str()); if (input_blobs_.size() == 1) { if (!paramlist->dims.empty()) { - if (!output_dims.empty()) { - nvinfer1::Dims4 dims(output_dims[0], output_dims[1], output_dims[2], output_dims[3]); - layer->setOutputDimensions(dims); - } else { - if (input_tensor->getDimensions().nbDims != 4) { - LOGE("Upsample with 1 input only support 4d input."); - return nullptr; - } - if (paramlist->dims.size() != 2 || - paramlist->dims[0] <= 0 || paramlist->dims[1] <= 0) { - LOGE("Upsample with 1 input should have positive 2-element param->dims."); - return nullptr; - } - auto trt_dim = input_tensor->getDimensions(); - // trt_dim may have one of the following values: - // [-1,3,32,32], [-1,2,-1,-1], [1,16,256,256] - if ((trt_dim.d[0] <= 0 || trt_dim.d[1] <= 0) && - (trt_dim.d[2] > 0 && trt_dim.d[3] > 0)) { + auto trt_dim = input_tensor->getDimensions(); + if (trt_dim.nbDims != 4) { + LOGE("Upsample with 1 input only support 4d input."); + return nullptr; + } + + // trt_dim may have one of the following values: + // [-1,3,32,32], [-1,2,-1,-1], [1,16,256,256] + if (trt_dim.d[0] <= 0 || trt_dim.d[1] <= 0) { + // Cases When At least One of N, C be dynamic + if (trt_dim.d[2] > 0 && trt_dim.d[3] > 0) { + // Cases when H,W are fixed, turn to scale mode + // e.g [-1,3,32,32] float scale[4]; scale[0] = 1; scale[1] = 1; @@ -91,9 +86,29 @@ ILayer* UpsampleTRTPluginLayerBuilder::AddToNetwork(INetworkDefinition* network) scale[3] = paramlist->dims[1] / float(trt_dim.d[3]); layer->setScales(scale, 4); } else { - // WARNING, trt_dims may have -1 values. - nvinfer1::Dims4 dims(trt_dim.d[0], trt_dim.d[1], paramlist->dims[0], paramlist->dims[1]); + // Cases When Both N,C and H+W are dynamic + // In this case, We cannot turn to Scale mode. + // Also layer->SetOutputDimensions() API does not accept -1 as dim + // Have to use TNN Upsample Plugin. + // e.g [-1,2,-1,-1] + LOGI("WARNING: Dynamic NCHW Upsample with fixed dims provided, NOT SUPPORTED by TensorRT, use TNN Upsample Plugin instead."); + return TensorRTPluginLayerBuilder::AddToNetwork(network); + } + } else { + // Cases When Both N and C are fixed + // e.g [1,16,256,256] + if (!output_dims.empty() && output_dims[2] > 0 && output_dims[3] > 0) { + nvinfer1::Dims4 dims(trt_dim.d[0], trt_dim.d[1], + output_dims[2], output_dims[3]); layer->setOutputDimensions(dims); + } else if (paramlist->dims.size() >= 2 && + paramlist->dims[0] > 0 && paramlist->dims[1] > 0) { + nvinfer1::Dims4 dims(trt_dim.d[0], trt_dim.d[1], + paramlist->dims[0], paramlist->dims[1]); + layer->setOutputDimensions(dims); + } else { + LOGE("Upsample with 1 input Fix N,C + Fixed dims does not have standard positive dim, Unsupported."); + return nullptr; } } } else {