Skip to content

Commit

Permalink
[TensorRT EP] Enable more trt options (#237)
Browse files Browse the repository at this point in the history
  • Loading branch information
yf711 authored Mar 15, 2024
1 parent 6b896f0 commit 72189f9
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 0 deletions.
43 changes: 43 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,48 @@ TensorRT can be used in conjunction with an ONNX model to further optimize the p
* `trt_engine_cache_enable`: Enable engine caching.
* `trt_engine_cache_path`: Specify engine cache path.

To explore the usage of more parameters, follow the mapping table below and check the [ONNX Runtime documentation](https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#execution-provider-options) for details.

> Please link to the latest ONNX Runtime binaries in CMake or build from [main branch of ONNX Runtime](https://github.com/microsoft/onnxruntime/tree/main) to enable latest options.
### Parameter mapping between ONNX Runtime and Triton ONNXRuntime Backend

| Key in Triton model configuration | Value in Triton model config | Corresponding TensorRT EP option in ONNX Runtime | Type |
| --------------------------------- | --------------------------------------------------- | :----------------------------------------------- | :----- |
| max_workspace_size_bytes | e.g: "4294967296" | trt_max_workspace_size | int |
| trt_max_partition_iterations | e.g: "1000" | trt_max_partition_iterations | int |
| trt_min_subgraph_size | e.g: "1" | trt_min_subgraph_size | int |
| precision_mode | "FP16" | trt_fp16_enable | bool |
| precision_mode | "INT8" | trt_int8_enable | bool |
| int8_calibration_table_name | | trt_int8_calibration_table_name | string |
| int8_use_native_calibration_table | e.g: "1" or "true", "0" or "false" | trt_int8_use_native_calibration_table | bool |
| trt_dla_enable | | trt_dla_enable | bool |
| trt_dla_core | e.g: "0" | trt_dla_core | int |
| trt_engine_cache_enable | e.g: "1" or "true", "0" or "false" | trt_engine_cache_enable | bool |
| trt_engine_cache_path | | trt_engine_cache_path | string |
| trt_engine_cache_prefix | | trt_engine_cache_prefix | string |
| trt_dump_subgraphs | e.g: "1" or "true", "0" or "false" | trt_dump_subgraphs | bool |
| trt_force_sequential_engine_build | e.g: "1" or "true", "0" or "false" | trt_force_sequential_engine_build | bool |
| trt_context_memory_sharing_enable | e.g: "1" or "true", "0" or "false" | trt_context_memory_sharing_enable | bool |
| trt_layer_norm_fp32_fallback | e.g: "1" or "true", "0" or "false" | trt_layer_norm_fp32_fallback | bool |
| trt_timing_cache_enable | e.g: "1" or "true", "0" or "false" | trt_timing_cache_enable | bool |
| trt_timing_cache_path | | trt_timing_cache_path | string |
| trt_force_timing_cache | e.g: "1" or "true", "0" or "false" | trt_force_timing_cache | bool |
| trt_detailed_build_log | e.g: "1" or "true", "0" or "false" | trt_detailed_build_log | bool |
| trt_build_heuristics_enable | e.g: "1" or "true", "0" or "false" | trt_build_heuristics_enable | bool |
| trt_sparsity_enable | e.g: "1" or "true", "0" or "false" | trt_sparsity_enable | bool |
| trt_builder_optimization_level | e.g: "3" | trt_builder_optimization_level | int |
| trt_auxiliary_streams | e.g: "-1" | trt_auxiliary_streams | int |
| trt_tactic_sources                | e.g: "-CUDNN,+CUBLAS"                               | trt_tactic_sources                               | string |
| trt_extra_plugin_lib_paths | | trt_extra_plugin_lib_paths | string |
| trt_profile_min_shapes            | e.g: "input1:dim1xdim2...,input2:dim1xdim2...,..."  | trt_profile_min_shapes                           | string |
| trt_profile_max_shapes            | e.g: "input1:dim1xdim2...,input2:dim1xdim2...,..."  | trt_profile_max_shapes                           | string |
| trt_profile_opt_shapes            | e.g: "input1:dim1xdim2...,input2:dim1xdim2...,..."  | trt_profile_opt_shapes                           | string |
| trt_cuda_graph_enable | e.g: "1" or "true", "0" or "false" | trt_cuda_graph_enable | bool |
| trt_dump_ep_context_model | e.g: "1" or "true", "0" or "false" | trt_dump_ep_context_model | bool |
| trt_ep_context_file_path | | trt_ep_context_file_path | string |
| trt_ep_context_embed_mode | e.g: "1" | trt_ep_context_embed_mode | int |

The section of the model config file specifying these parameters will look like:

```
Expand All @@ -104,6 +146,7 @@ optimization { execution_accelerators {
name : "tensorrt"
parameters { key: "precision_mode" value: "FP16" }
parameters { key: "max_workspace_size_bytes" value: "1073741824" }
parameters { key: "trt_engine_cache_enable" value: "1" }}
]
}}
.
Expand Down
175 changes: 175 additions & 0 deletions src/onnxruntime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,22 @@ ModelState::LoadModel(
value_string, &max_workspace_size_bytes));
key = "trt_max_workspace_size";
value = value_string;
} else if (param_key == "trt_max_partition_iterations") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
int trt_max_partition_iterations;
RETURN_IF_ERROR(ParseIntValue(
value_string, &trt_max_partition_iterations));
key = "trt_max_partition_iterations";
value = value_string;
} else if (param_key == "trt_min_subgraph_size") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
int trt_min_subgraph_size;
RETURN_IF_ERROR(
ParseIntValue(value_string, &trt_min_subgraph_size));
key = "trt_min_subgraph_size";
value = value_string;
} else if (param_key == "int8_calibration_table_name") {
RETURN_IF_ERROR(
params.MemberAsString(param_key.c_str(), &value));
Expand All @@ -485,6 +501,21 @@ ModelState::LoadModel(
value_string, &use_native_calibration_table));
key = "trt_int8_use_native_calibration_table";
value = value_string;
} else if (param_key == "trt_dla_enable") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_dla_enable;
RETURN_IF_ERROR(
ParseBoolValue(value_string, &trt_dla_enable));
key = "trt_dla_enable";
value = value_string;
} else if (param_key == "trt_dla_core") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
int trt_dla_core;
RETURN_IF_ERROR(ParseIntValue(value_string, &trt_dla_core));
key = "trt_dla_core";
value = value_string;
} else if (param_key == "trt_engine_cache_enable") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
Expand All @@ -497,6 +528,150 @@ ModelState::LoadModel(
RETURN_IF_ERROR(
params.MemberAsString(param_key.c_str(), &value));
key = "trt_engine_cache_path";
} else if (param_key == "trt_engine_cache_prefix") {
RETURN_IF_ERROR(
params.MemberAsString(param_key.c_str(), &value));
key = "trt_engine_cache_prefix";
} else if (param_key == "trt_dump_subgraphs") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool dump_subgraphs;
RETURN_IF_ERROR(
ParseBoolValue(value_string, &dump_subgraphs));
key = "trt_dump_subgraphs";
value = value_string;
} else if (param_key == "trt_force_sequential_engine_build") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_force_sequential_engine_build;
RETURN_IF_ERROR(ParseBoolValue(
value_string, &trt_force_sequential_engine_build));
key = "trt_force_sequential_engine_build";
value = value_string;
} else if (param_key == "trt_context_memory_sharing_enable") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_context_memory_sharing_enable;
RETURN_IF_ERROR(ParseBoolValue(
value_string, &trt_context_memory_sharing_enable));
key = "trt_context_memory_sharing_enable";
value = value_string;
} else if (param_key == "trt_layer_norm_fp32_fallback") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_layer_norm_fp32_fallback;
RETURN_IF_ERROR(ParseBoolValue(
value_string, &trt_layer_norm_fp32_fallback));
key = "trt_layer_norm_fp32_fallback";
value = value_string;
} else if (param_key == "trt_timing_cache_enable") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_timing_cache_enable;
RETURN_IF_ERROR(
ParseBoolValue(value_string, &trt_timing_cache_enable));
key = "trt_timing_cache_enable";
value = value_string;
} else if (param_key == "trt_timing_cache_path") {
RETURN_IF_ERROR(
params.MemberAsString(param_key.c_str(), &value));
key = "trt_timing_cache_path";
} else if (param_key == "trt_force_timing_cache") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_force_timing_cache;
RETURN_IF_ERROR(
ParseBoolValue(value_string, &trt_force_timing_cache));
key = "trt_force_timing_cache";
value = value_string;
} else if (param_key == "trt_detailed_build_log") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_detailed_build_log;
RETURN_IF_ERROR(
ParseBoolValue(value_string, &trt_detailed_build_log));
key = "trt_detailed_build_log";
value = value_string;
} else if (param_key == "trt_build_heuristics_enable") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_build_heuristics_enable;
RETURN_IF_ERROR(ParseBoolValue(
value_string, &trt_build_heuristics_enable));
key = "trt_build_heuristics_enable";
value = value_string;
} else if (param_key == "trt_sparsity_enable") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_sparsity_enable;
RETURN_IF_ERROR(
ParseBoolValue(value_string, &trt_sparsity_enable));
key = "trt_sparsity_enable";
value = value_string;
} else if (param_key == "trt_builder_optimization_level") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
int trt_builder_optimization_level;
RETURN_IF_ERROR(ParseIntValue(
value_string, &trt_builder_optimization_level));
key = "trt_builder_optimization_level";
value = value_string;
} else if (param_key == "trt_auxiliary_streams") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
int trt_auxiliary_streams;
RETURN_IF_ERROR(
ParseIntValue(value_string, &trt_auxiliary_streams));
key = "trt_auxiliary_streams";
value = value_string;
} else if (param_key == "trt_tactic_sources") {
RETURN_IF_ERROR(
params.MemberAsString(param_key.c_str(), &value));
key = "trt_tactic_sources";
} else if (param_key == "trt_extra_plugin_lib_paths") {
RETURN_IF_ERROR(
params.MemberAsString(param_key.c_str(), &value));
key = "trt_extra_plugin_lib_paths";
} else if (param_key == "trt_profile_min_shapes") {
RETURN_IF_ERROR(
params.MemberAsString(param_key.c_str(), &value));
key = "trt_profile_min_shapes";
} else if (param_key == "trt_profile_max_shapes") {
RETURN_IF_ERROR(
params.MemberAsString(param_key.c_str(), &value));
key = "trt_profile_max_shapes";
} else if (param_key == "trt_profile_opt_shapes") {
RETURN_IF_ERROR(
params.MemberAsString(param_key.c_str(), &value));
key = "trt_profile_opt_shapes";
} else if (param_key == "trt_cuda_graph_enable") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_cuda_graph_enable;
RETURN_IF_ERROR(
ParseBoolValue(value_string, &trt_cuda_graph_enable));
key = "trt_cuda_graph_enable";
value = value_string;
} else if (param_key == "trt_dump_ep_context_model") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_dump_ep_context_model;
RETURN_IF_ERROR(ParseBoolValue(
value_string, &trt_dump_ep_context_model));
key = "trt_dump_ep_context_model";
value = value_string;
} else if (param_key == "trt_ep_context_file_path") {
RETURN_IF_ERROR(
params.MemberAsString(param_key.c_str(), &value));
key = "trt_ep_context_file_path";
} else if (param_key == "trt_ep_context_embed_mode") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
int trt_ep_context_embed_mode;
RETURN_IF_ERROR(ParseIntValue(
value_string, &trt_ep_context_embed_mode));
key = "trt_ep_context_embed_mode";
value = value_string;
} else {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
Expand Down

0 comments on commit 72189f9

Please sign in to comment.