Merge pull request #55 from huggingface/trt-engine
TRT dynamic shapes for text gen models
IlyasMoutawwakil authored Sep 15, 2023
2 parents 9c672c2 + a76aa50 commit d0194d1
Showing 6 changed files with 63 additions and 47 deletions.
10 changes: 5 additions & 5 deletions optimum_benchmark/backends/base.py
@@ -165,13 +165,13 @@ def prepare_input(self, input: Dict[str, Any]) -> Dict[str, Any]:
 
         return input
 
-    # compiling in openvino requires input shapes
-    def prepare_for_inference(self, input_shapes: Dict[str, int]) -> Dict[str, Any]:
+    # compiling in openvino requires input shapes, trt ep requires max tokens, etc.
+    def prepare_for_inference(self, **kwargs) -> None:
         pass
 
-    # symbolic tracing in transformers requires input names
-    def prepare_for_profiling(self, input_names: List[str]) -> Dict[str, Any]:
-        pass
+    # # symbolic tracing in transformers requires input names
+    # def prepare_for_profiling(self, input_names: List[str]) -> Dict[str, Any]:
+    #     pass
 
     def forward(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> "ModelOutput":
         return self.pretrained_model(**input, **kwargs)
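Note: the following is an illustrative sketch (not part of the diff) of how a backend subclass can consume the new keyword-based hook; the class name and the kwargs defaults are assumptions.

from typing import Any, Dict


class ExampleBackend:
    # hypothetical subclass: each backend picks out only the kwargs it needs
    def prepare_for_inference(self, **kwargs) -> None:
        input_shapes: Dict[str, Any] = kwargs.get("input_shapes", {})  # e.g. static shapes for OpenVINO
        max_new_tokens: int = kwargs.get("max_new_tokens", 0)  # e.g. upper bound for TensorRT EP profiles
        # backend-specific preparation (reshaping, engine building, ...) would go here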
26 changes: 22 additions & 4 deletions optimum_benchmark/backends/onnxruntime/backend.py
@@ -102,6 +102,7 @@ def configure(self, config: ORTConfig) -> None:
         # Some statefullness to handle the different combinations of options
         self.export = self.config.export
         self.use_merged = self.config.use_merged
+        self.provider_options = self.config.provider_options.copy()
 
         if self.is_diffusion_pipeline():
             self.load_ortmodel()
@@ -131,11 +132,11 @@ def configure(self, config: ORTConfig) -> None:
         if self.config.auto_quantization or self.config.quantization:
             self.quantize_onnx_files()
 
-        self.load_ortmodel()
-        self.tmpdir.cleanup()
+        if not (self.config.provider == "TensorrtExecutionProvider" and self.is_text_generation_model()):
+            self.load_ortmodel()
+            self.tmpdir.cleanup()
 
     def load_automodel_from_config(self) -> None:
-        # TODO: create no_weights tests
         from accelerate import init_empty_weights
 
         LOGGER.info("\t+ Loading AutoModel from config")
@@ -164,8 +165,8 @@ def load_ortmodel(self) -> None:
             export=self.export,
             provider=self.config.provider,
             session_options=self.session_options,
+            provider_options=self.provider_options,
             use_io_binding=self.config.use_io_binding,
-            provider_options=self.config.provider_options,
             **self.ortmodel_kwargs,
             **self.hub_kwargs,
         )
@@ -311,6 +312,23 @@ def quantize_onnx_files(self) -> None:
         )
         self.model = quantized_model_path
 
+    def prepare_for_inference(self, **kwargs) -> None:
+        if self.config.provider == "TensorrtExecutionProvider" and self.is_text_generation_model():
+            max_new_tokens = kwargs["max_new_tokens"]
+            batch_size = kwargs["input_shapes"]["batch_size"]
+            sequence_length = kwargs["input_shapes"]["sequence_length"]
+
+            LOGGER.info("\t+ Creating dynamic shapes for Tensorrt engine, loading will take a while")
+            self.provider_options = {
+                **self.provider_options,
+                "trt_profile_min_shapes": f"input_ids:{batch_size}x{sequence_length},attention_mask:{batch_size}x{sequence_length}",
+                "trt_profile_max_shapes": f"input_ids:{batch_size}x{sequence_length + max_new_tokens},attention_mask:{batch_size}x{sequence_length + max_new_tokens}",
+                "trt_profile_opt_shapes": f"input_ids:{batch_size}x{sequence_length + max_new_tokens},attention_mask:{batch_size}x{sequence_length + max_new_tokens}",
+            }
+
+            self.load_ortmodel()
+            self.tmpdir.cleanup()
+
     def prepare_for_profiling(self, input_names: List[str]) -> None:
         LOGGER.info("Preparing model for profiling")
         LOGGER.info("\t+ Wrapping model inside profiler")
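Note: purely illustrative, with made-up values — the TensorRT profile strings built in prepare_for_inference above expand as follows for batch_size=2, sequence_length=16, max_new_tokens=32 (the opt shapes are set equal to the max shapes in the diff).

batch_size, sequence_length, max_new_tokens = 2, 16, 32  # assumed example values

trt_profile_min_shapes = f"input_ids:{batch_size}x{sequence_length},attention_mask:{batch_size}x{sequence_length}"
trt_profile_max_shapes = f"input_ids:{batch_size}x{sequence_length + max_new_tokens},attention_mask:{batch_size}x{sequence_length + max_new_tokens}"

print(trt_profile_min_shapes)  # input_ids:2x16,attention_mask:2x16
print(trt_profile_max_shapes)  # input_ids:2x48,attention_mask:2x48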
39 changes: 22 additions & 17 deletions optimum_benchmark/backends/onnxruntime/config.py
@@ -13,31 +13,29 @@
 def infer_device_id(device: str) -> int:
     """Infer the device id from the given device string."""
     if "cuda" in device:
-        # here we resolve conflicts between CUDA_VISIBLE_DEVICES and pytorch's indexing
-        # e.g. CUDA_VISIBLE_DEVICES=1 and device=cuda:0 should return 1
-        CUDA_VISIBLE_DEVICES = os.environ.get("CUDA_VISIBLE_DEVICES", None)
-        if CUDA_VISIBLE_DEVICES is None:
-            if ":" in device:
-                return int(device.split(":")[1])
-            else:
-                return 0
+        if ":" in device:
+            # either CUDA_VISIBLE_DEVICES is set or device is set to cuda:0
+            return int(device.split(":")[1])
         else:
-            if ":" in device:
-                return int(CUDA_VISIBLE_DEVICES.split(",")[int(device.split(":")[1])])
-            else:
-                return int(CUDA_VISIBLE_DEVICES.split(",")[0])
+            # device is set to cuda
+            return 0
     elif device == "cpu":
         return -1
     else:
         raise ValueError(f"Unknown device: {device}")
 
 
+DEVICE_PROVIDER_MAP = {
+    "cpu": "CPUExecutionProvider",
+    "cuda": "CUDAExecutionProvider",
+}
+
 OmegaConf.register_new_resolver("onnxruntime_version", onnxruntime_version)
-OmegaConf.register_new_resolver("is_gpu", lambda device: "cuda" in device)
 OmegaConf.register_new_resolver("infer_device_id", lambda device: infer_device_id(device))
+OmegaConf.register_new_resolver("infer_provider", lambda device: DEVICE_PROVIDER_MAP[device])
 OmegaConf.register_new_resolver("is_profiling", lambda benchmark_name: benchmark_name == "profiling")
 OmegaConf.register_new_resolver(
-    "infer_provider", lambda device: "CPUExecutionProvider" if device == "cpu" else "CUDAExecutionProvider"
+    "io_bind", lambda provider: provider in ["CPUExecutionProvider", "CUDAExecutionProvider"]
 )
 
 
@@ -97,6 +95,11 @@ def infer_device_id(device: str) -> int:
     "preprocess_class": "optimum_benchmark.preprocessors.glue.GluePreprocessor",
 }
 
+TRT_PROVIDER_OPTIONS = {
+    "trt_engine_cache_enable": True,
+    "trt_engine_cache_path": "tmp/trt_cache",
+}
+
 
 @dataclass
 class ORTConfig(BackendConfig):
@@ -114,12 +117,10 @@ class ORTConfig(BackendConfig):
 
     # provider options
     provider: str = "${infer_provider:${device}}"
-    device_id: Optional[int] = "${oc.deprecated:backend.provider_options.device_id}"
     provider_options: Dict[str, Any] = field(default_factory=lambda: {"device_id": "${infer_device_id:${device}}"})
 
     # inference options
-    use_io_binding: bool = "${is_gpu:${device}}"
-    enable_profiling: bool = "${oc.deprecated:backend.session_options.enable_profiling}"
+    use_io_binding: bool = "${io_bind:${device}}"
     session_options: Dict[str, Any] = field(
         default_factory=lambda: {"enable_profiling": "${is_profiling:${benchmark.name}}"}
     )
Expand Down Expand Up @@ -161,6 +162,10 @@ def __post_init__(self):
if not self.no_weights and not self.export and self.torch_dtype is not None:
raise NotImplementedError("Can't convert an exported model's weights to a different dtype.")

if self.provider == "TensorrtExecutionProvider":
self.provider_options = OmegaConf.to_object(OmegaConf.merge(TRT_PROVIDER_OPTIONS, self.provider_options))
os.makedirs(self.provider_options["trt_engine_cache_path"], exist_ok=True)

if self.optimization:
self.optimization_config = OmegaConf.to_object(
OmegaConf.merge(OPTIMIZATION_CONFIG, self.optimization_config)
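Note: a self-contained sketch (not from the diff) of how the resolvers above turn the `device` string into provider and device_id values at config-resolution time; the `demo_`-prefixed resolver names and the simplified infer_device_id are assumptions made to keep the example standalone.

from omegaconf import OmegaConf

DEVICE_PROVIDER_MAP = {"cpu": "CPUExecutionProvider", "cuda": "CUDAExecutionProvider"}


def infer_device_id(device: str) -> int:
    # simplified local copy of the function in the diff above
    if "cuda" in device:
        return int(device.split(":")[1]) if ":" in device else 0
    elif device == "cpu":
        return -1
    raise ValueError(f"Unknown device: {device}")


OmegaConf.register_new_resolver("demo_infer_provider", lambda device: DEVICE_PROVIDER_MAP[device])
OmegaConf.register_new_resolver("demo_infer_device_id", infer_device_id)

cfg = OmegaConf.create(
    {
        "device": "cuda",
        "provider": "${demo_infer_provider:${device}}",
        "provider_options": {"device_id": "${demo_infer_device_id:${device}}"},
    }
)
print(cfg.provider)  # CUDAExecutionProvider
print(cfg.provider_options.device_id)  # 0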
3 changes: 2 additions & 1 deletion optimum_benchmark/backends/openvino/backend.py
@@ -97,7 +97,8 @@ def quantize_automodel(self) -> None:
         )
         self.model = quantized_model_path
 
-    def prepare_for_inference(self, input_shapes: Dict[str, int]) -> None:
+    def prepare_for_inference(self, **kwargs) -> None:
+        input_shapes = kwargs["input_shapes"]
         if self.config.reshape:
             static_shapes = {
                 key: value
9 changes: 6 additions & 3 deletions optimum_benchmark/benchmarks/inference/benchmark.py
@@ -47,6 +47,12 @@ def run(self, backend: "Backend") -> None:
             input_shapes=self.config.input_shapes,
         )
 
+        # openvino requires compiling with static shapes and trt ep requires max tokens
+        backend.prepare_for_inference(
+            input_shapes=self.config.input_shapes,
+            max_new_tokens=self.config.generate_kwargs.get("max_new_tokens", 0),
+        )
+
         # run forward pass tracking
         self.run_forward_tracking(backend)
 
@@ -60,9 +66,6 @@ def run_forward_tracking(self, backend: "Backend") -> None:
         LOGGER.info("\t+ Preparing input for the forward pass")
         forward_input = backend.prepare_input(forward_input)
 
-        # for backends that require compilation with static shapes
-        backend.prepare_for_inference(input_shapes=self.config.input_shapes)
-
         LOGGER.info("\t+ Warming up the forward pass")
         for _ in range(self.config.warmup_runs):
             _ = backend.forward(forward_input, self.config.forward_kwargs)
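Note: an illustrative, self-contained example (DummyBackend and the concrete numbers are made up) of what the single prepare_for_inference call above receives for a text-generation config.

class DummyBackend:
    def prepare_for_inference(self, **kwargs) -> None:
        print(kwargs)


input_shapes = {"batch_size": 1, "sequence_length": 128}
generate_kwargs = {"max_new_tokens": 100, "min_new_tokens": 100}

DummyBackend().prepare_for_inference(
    input_shapes=input_shapes,
    # .get(..., 0) keeps the call valid for tasks that don't generate tokens
    max_new_tokens=generate_kwargs.get("max_new_tokens", 0),
)
# prints: {'input_shapes': {'batch_size': 1, 'sequence_length': 128}, 'max_new_tokens': 100}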
23 changes: 6 additions & 17 deletions optimum_benchmark/benchmarks/inference/config.py
@@ -34,7 +34,6 @@ class InferenceConfig(BenchmarkConfig):
     # benchmark options
     duration: int = 10
     warmup_runs: int = 10
-    benchmark_duration: Optional[int] = None  # deprecated
 
     # additional/optional metrics
     memory: bool = False
@@ -57,9 +56,7 @@ class InferenceConfig(BenchmarkConfig):
         },
     )
 
-    # TODO: deprecate this and use `benchamrk.generate_kwargs`
     new_tokens: Optional[int] = None
-
     can_diffuse: bool = "${can_diffuse:${task}}"
     can_generate: bool = "${can_generate:${task}}"
 
@@ -79,17 +76,9 @@ def __post_init__(self):
         if self.generate_kwargs["max_new_tokens"] != self.generate_kwargs["min_new_tokens"]:
             raise ValueError("`max_new_tokens` and `min_new_tokens` must be equal for fixed length output.")
 
-        if self.new_tokens is not None:
-            LOGGER.warning(
-                "The `new_tokens` option is deprecated, please use `generate_kwargs` instead. "
-                "`generate_kwargs.max_new_tokens` and `generate_kwargs.min_new_tokens` will be set to the value of `new_tokens`."
-            )
-            self.generate_kwargs["max_new_tokens"] = self.new_tokens
-            self.generate_kwargs["min_new_tokens"] = self.new_tokens
-
-        if self.benchmark_duration is not None:
-            LOGGER.warning(
-                "The `benchmark_duration` option is deprecated, please use `duration` instead. "
-                "`duration` will be set to the value of `benchmark_duration`."
-            )
-            self.duration = self.benchmark_duration
+        if self.new_tokens is not None:
+            LOGGER.info(
+                f"`new_tokens` was set to {self.new_tokens}. `max_new_tokens` and `min_new_tokens` will be set to {self.new_tokens}."
+            )
+            self.generate_kwargs["max_new_tokens"] = self.new_tokens
+            self.generate_kwargs["min_new_tokens"] = self.new_tokens
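Note: a minimal sketch (not part of the diff) reproducing the reworked `new_tokens` behaviour above with a stripped-down dataclass; MiniInferenceConfig and its defaults are illustrative.

from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class MiniInferenceConfig:
    new_tokens: Optional[int] = None
    generate_kwargs: Dict[str, Any] = field(
        default_factory=lambda: {"max_new_tokens": 100, "min_new_tokens": 100}
    )

    def __post_init__(self):
        if self.new_tokens is not None:
            # new_tokens overrides both bounds so the generated length stays fixed
            self.generate_kwargs["max_new_tokens"] = self.new_tokens
            self.generate_kwargs["min_new_tokens"] = self.new_tokens


print(MiniInferenceConfig(new_tokens=32).generate_kwargs)
# {'max_new_tokens': 32, 'min_new_tokens': 32}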
