Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorRT EP] Enable more trt options #237

Merged
merged 32 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,48 @@ TensorRT can be used in conjunction with an ONNX model to further optimize the p
* `trt_engine_cache_enable`: Enable engine caching.
* `trt_engine_cache_path`: Specify engine cache path.

To explore the usage of more parameters, follow the mapping table below and check the [ONNX Runtime doc](https://onnxruntime.ai/docs/execution-providers/TensorRT-ExecutionProvider.html#execution-provider-options) for details.

> Please link to the latest ONNX Runtime binaries in CMake or build from [main branch of ONNX Runtime](https://github.com/microsoft/onnxruntime/tree/main) to enable latest options.
### Parameter mapping between ONNX Runtime and Triton ONNXRuntime Backend

| Key in Triton model configuration | Value in Triton model config | Corresponding TensorRT EP option in ONNX Runtime | Type |
| --------------------------------- | --------------------------------------------------- | :----------------------------------------------- | :----- |
| max_workspace_size_bytes | e.g: "4294967296" | trt_max_workspace_size | int |
| trt_max_partition_iterations | e.g: "1000" | trt_max_partition_iterations | int |
| trt_min_subgraph_size | e.g: "1" | trt_min_subgraph_size | int |
| precision_mode | "FP16" | trt_fp16_enable | bool |
| precision_mode | "INT8" | trt_int8_enable | bool |
| int8_calibration_table_name | | trt_int8_calibration_table_name | string |
| int8_use_native_calibration_table | e.g: "1" or "true", "0" or "false" | trt_int8_use_native_calibration_table | bool |
| trt_dla_enable | | trt_dla_enable | bool |
| trt_dla_core | e.g: "0" | trt_dla_core | int |
| trt_engine_cache_enable | e.g: "1" or "true", "0" or "false" | trt_engine_cache_enable | bool |
| trt_engine_cache_path | | trt_engine_cache_path | string |
| trt_engine_cache_prefix | | trt_engine_cache_prefix | string |
| trt_dump_subgraphs | e.g: "1" or "true", "0" or "false" | trt_dump_subgraphs | bool |
| trt_force_sequential_engine_build | e.g: "1" or "true", "0" or "false" | trt_force_sequential_engine_build | bool |
| trt_context_memory_sharing_enable | e.g: "1" or "true", "0" or "false" | trt_context_memory_sharing_enable | bool |
| trt_layer_norm_fp32_fallback | e.g: "1" or "true", "0" or "false" | trt_layer_norm_fp32_fallback | bool |
| trt_timing_cache_enable | e.g: "1" or "true", "0" or "false" | trt_timing_cache_enable | bool |
| trt_timing_cache_path | | trt_timing_cache_path | string |
| trt_force_timing_cache | e.g: "1" or "true", "0" or "false" | trt_force_timing_cache | bool |
| trt_detailed_build_log | e.g: "1" or "true", "0" or "false" | trt_detailed_build_log | bool |
| trt_build_heuristics_enable | e.g: "1" or "true", "0" or "false" | trt_build_heuristics_enable | bool |
| trt_sparsity_enable | e.g: "1" or "true", "0" or "false" | trt_sparsity_enable | bool |
| trt_builder_optimization_level | e.g: "3" | trt_builder_optimization_level | int |
| trt_auxiliary_streams | e.g: "-1" | trt_auxiliary_streams | int |
| trt_tactic_sources                | e.g: "-CUDNN,+CUBLAS"                               | trt_tactic_sources                               | string |
| trt_extra_plugin_lib_paths | | trt_extra_plugin_lib_paths | string |
| trt_profile_min_shapes            | e.g: "input1:dim1xdim2...,input2:dim1xdim2...,..."  | trt_profile_min_shapes                           | string |
| trt_profile_max_shapes            | e.g: "input1:dim1xdim2...,input2:dim1xdim2...,..."  | trt_profile_max_shapes                           | string |
| trt_profile_opt_shapes            | e.g: "input1:dim1xdim2...,input2:dim1xdim2...,..."  | trt_profile_opt_shapes                           | string |
| trt_cuda_graph_enable | e.g: "1" or "true", "0" or "false" | trt_cuda_graph_enable | bool |
| trt_dump_ep_context_model | e.g: "1" or "true", "0" or "false" | trt_dump_ep_context_model | bool |
| trt_ep_context_file_path | | trt_ep_context_file_path | string |
| trt_ep_context_embed_mode | e.g: "1" | trt_ep_context_embed_mode | int |

The section of model config file specifying these parameters will look like:

```
Expand All @@ -104,6 +146,7 @@ optimization { execution_accelerators {
name : "tensorrt"
parameters { key: "precision_mode" value: "FP16" }
parameters { key: "max_workspace_size_bytes" value: "1073741824" }
parameters { key: "trt_engine_cache_enable" value: "1" }}
]
}}
.
Expand Down
175 changes: 175 additions & 0 deletions src/onnxruntime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,22 @@ ModelState::LoadModel(
value_string, &max_workspace_size_bytes));
key = "trt_max_workspace_size";
value = value_string;
} else if (param_key == "trt_max_partition_iterations") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
int trt_max_partition_iterations;
RETURN_IF_ERROR(ParseIntValue(
value_string, &trt_max_partition_iterations));
key = "trt_max_partition_iterations";
value = value_string;
} else if (param_key == "trt_min_subgraph_size") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
int trt_min_subgraph_size;
RETURN_IF_ERROR(
ParseIntValue(value_string, &trt_min_subgraph_size));
key = "trt_min_subgraph_size";
value = value_string;
} else if (param_key == "int8_calibration_table_name") {
RETURN_IF_ERROR(
params.MemberAsString(param_key.c_str(), &value));
Expand All @@ -485,6 +501,21 @@ ModelState::LoadModel(
value_string, &use_native_calibration_table));
key = "trt_int8_use_native_calibration_table";
value = value_string;
} else if (param_key == "trt_dla_enable") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_dla_enable;
RETURN_IF_ERROR(
ParseBoolValue(value_string, &trt_dla_enable));
key = "trt_dla_enable";
value = value_string;
} else if (param_key == "trt_dla_core") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
int trt_dla_core;
RETURN_IF_ERROR(ParseIntValue(value_string, &trt_dla_core));
key = "trt_dla_core";
value = value_string;
} else if (param_key == "trt_engine_cache_enable") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
Expand All @@ -497,6 +528,150 @@ ModelState::LoadModel(
RETURN_IF_ERROR(
params.MemberAsString(param_key.c_str(), &value));
key = "trt_engine_cache_path";
} else if (param_key == "trt_engine_cache_prefix") {
RETURN_IF_ERROR(
params.MemberAsString(param_key.c_str(), &value));
key = "trt_engine_cache_prefix";
} else if (param_key == "trt_dump_subgraphs") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool dump_subgraphs;
RETURN_IF_ERROR(
ParseBoolValue(value_string, &dump_subgraphs));
key = "trt_dump_subgraphs";
value = value_string;
} else if (param_key == "trt_force_sequential_engine_build") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_force_sequential_engine_build;
RETURN_IF_ERROR(ParseBoolValue(
value_string, &trt_force_sequential_engine_build));
key = "trt_force_sequential_engine_build";
value = value_string;
} else if (param_key == "trt_context_memory_sharing_enable") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_context_memory_sharing_enable;
RETURN_IF_ERROR(ParseBoolValue(
value_string, &trt_context_memory_sharing_enable));
key = "trt_context_memory_sharing_enable";
value = value_string;
} else if (param_key == "trt_layer_norm_fp32_fallback") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_layer_norm_fp32_fallback;
RETURN_IF_ERROR(ParseBoolValue(
value_string, &trt_layer_norm_fp32_fallback));
key = "trt_layer_norm_fp32_fallback";
value = value_string;
} else if (param_key == "trt_timing_cache_enable") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_timing_cache_enable;
RETURN_IF_ERROR(
ParseBoolValue(value_string, &trt_timing_cache_enable));
key = "trt_timing_cache_enable";
value = value_string;
} else if (param_key == "trt_timing_cache_path") {
RETURN_IF_ERROR(
params.MemberAsString(param_key.c_str(), &value));
key = "trt_timing_cache_path";
} else if (param_key == "trt_force_timing_cache") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_force_timing_cache;
RETURN_IF_ERROR(
ParseBoolValue(value_string, &trt_force_timing_cache));
key = "trt_force_timing_cache";
value = value_string;
} else if (param_key == "trt_detailed_build_log") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_detailed_build_log;
RETURN_IF_ERROR(
ParseBoolValue(value_string, &trt_detailed_build_log));
key = "trt_detailed_build_log";
value = value_string;
} else if (param_key == "trt_build_heuristics_enable") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_build_heuristics_enable;
RETURN_IF_ERROR(ParseBoolValue(
value_string, &trt_build_heuristics_enable));
key = "trt_build_heuristics_enable";
value = value_string;
} else if (param_key == "trt_sparsity_enable") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_sparsity_enable;
RETURN_IF_ERROR(
ParseBoolValue(value_string, &trt_sparsity_enable));
key = "trt_sparsity_enable";
value = value_string;
} else if (param_key == "trt_builder_optimization_level") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
int trt_builder_optimization_level;
RETURN_IF_ERROR(ParseIntValue(
value_string, &trt_builder_optimization_level));
key = "trt_builder_optimization_level";
value = value_string;
} else if (param_key == "trt_auxiliary_streams") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
int trt_auxiliary_streams;
RETURN_IF_ERROR(
ParseIntValue(value_string, &trt_auxiliary_streams));
key = "trt_auxiliary_streams";
value = value_string;
} else if (param_key == "trt_tactic_sources") {
RETURN_IF_ERROR(
params.MemberAsString(param_key.c_str(), &value));
key = "trt_tactic_sources";
} else if (param_key == "trt_extra_plugin_lib_paths") {
RETURN_IF_ERROR(
params.MemberAsString(param_key.c_str(), &value));
key = "trt_extra_plugin_lib_paths";
} else if (param_key == "trt_profile_min_shapes") {
RETURN_IF_ERROR(
params.MemberAsString(param_key.c_str(), &value));
key = "trt_profile_min_shapes";
} else if (param_key == "trt_profile_max_shapes") {
RETURN_IF_ERROR(
params.MemberAsString(param_key.c_str(), &value));
key = "trt_profile_max_shapes";
} else if (param_key == "trt_profile_opt_shapes") {
RETURN_IF_ERROR(
params.MemberAsString(param_key.c_str(), &value));
key = "trt_profile_opt_shapes";
} else if (param_key == "trt_cuda_graph_enable") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_cuda_graph_enable;
RETURN_IF_ERROR(
ParseBoolValue(value_string, &trt_cuda_graph_enable));
key = "trt_cuda_graph_enable";
value = value_string;
} else if (param_key == "trt_dump_ep_context_model") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
bool trt_dump_ep_context_model;
RETURN_IF_ERROR(ParseBoolValue(
value_string, &trt_dump_ep_context_model));
key = "trt_dump_ep_context_model";
value = value_string;
} else if (param_key == "trt_ep_context_file_path") {
RETURN_IF_ERROR(
params.MemberAsString(param_key.c_str(), &value));
key = "trt_ep_context_file_path";
} else if (param_key == "trt_ep_context_embed_mode") {
RETURN_IF_ERROR(params.MemberAsString(
param_key.c_str(), &value_string));
int trt_ep_context_embed_mode;
RETURN_IF_ERROR(ParseIntValue(
value_string, &trt_ep_context_embed_mode));
key = "trt_ep_context_embed_mode";
value = value_string;
} else {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
Expand Down
Loading