From cc5740fd31d0522bd65d6a0fe3924fe85fea9ce2 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Sat, 13 Jan 2024 10:21:42 +0100 Subject: [PATCH] Fix trt llm workflow (#109) --- .github/workflows/test_tensorrt_llm.yaml | 13 +++-- README.md | 3 +- docker/tensorrt_llm.dockerfile | 36 ++++++++++++ examples/trt_llama.yaml | 6 +- optimum_benchmark/backends/tensorrt/config.py | 18 ------ optimum_benchmark/backends/tensorrt/utils.py | 1 - .../{tensorrt => tensorrt_llm}/__init__.py | 0 .../{tensorrt => tensorrt_llm}/backend.py | 56 ++++++++++--------- .../backends/tensorrt_llm/config.py | 46 +++++++++++++++ .../backends/tensorrt_llm/utils.py | 1 + optimum_benchmark/experiment.py | 4 +- tests/configs/tensorrt_llm_inference.yaml | 2 +- 12 files changed, 131 insertions(+), 55 deletions(-) create mode 100644 docker/tensorrt_llm.dockerfile delete mode 100644 optimum_benchmark/backends/tensorrt/config.py delete mode 100644 optimum_benchmark/backends/tensorrt/utils.py rename optimum_benchmark/backends/{tensorrt => tensorrt_llm}/__init__.py (100%) rename optimum_benchmark/backends/{tensorrt => tensorrt_llm}/backend.py (57%) create mode 100644 optimum_benchmark/backends/tensorrt_llm/config.py create mode 100644 optimum_benchmark/backends/tensorrt_llm/utils.py diff --git a/.github/workflows/test_tensorrt_llm.yaml b/.github/workflows/test_tensorrt_llm.yaml index 0454a41c..dafffe96 100644 --- a/.github/workflows/test_tensorrt_llm.yaml +++ b/.github/workflows/test_tensorrt_llm.yaml @@ -18,8 +18,13 @@ jobs: - name: Checkout uses: actions/checkout@v3 - - name: Pull image - run: docker pull huggingface/optimum-nvidia:latest + - name: Build image + run: docker build + --file docker/tensorrt_llm.dockerfile + --build-arg USER_ID=$(id -u) + --build-arg GROUP_ID=$(id -g) + --tag opt-bench-tensorrt-llm:latest + . - name: Run tests run: docker run @@ -34,5 +39,5 @@ jobs: --workdir /workspace/optimum-benchmark --gpus '"device=0,1"' --entrypoint /bin/bash - huggingface/optimum-nvidia:latest - -c "pip install -e .[test] && pytest -k 'tensorrt_llm' -x && chown -R $USER_ID:$GROUP_ID ." + opt-bench-tensorrt-llm:latest + -c "pip install -e .[test] && pytest -k 'tensorrt_llm' -x" diff --git a/README.md b/README.md index 81b5926d..eb3ed937 100644 --- a/README.md +++ b/README.md @@ -26,12 +26,13 @@ Everything else is either optional or inferred from the model's name or path. ### Supported Backends/Devices +- [x] TensorRT-LLM backend for CUDA (NVIDIA GPUs) - [x] Pytorch backend for CPU (Intel, AMD, ARM, etc) - [x] Pytorch backend for CUDA (NVIDIA and AMD GPUs) - [ ] Pytorch backend for Habana Gaudi Processor (HPU) - [x] OnnxRuntime backend for CPUExecutionProvider - [x] OnnxRuntime backend for CUDAExecutionProvider -- [ ] OnnxRuntime backend for ROCMExecutionProvider +- [x] OnnxRuntime backend for ROCMExecutionProvider - [x] OnnxRuntime backend for TensorrtExecutionProvider - [x] Intel Neural Compressor backend for CPU - [x] OpenVINO backend for CPU diff --git a/docker/tensorrt_llm.dockerfile b/docker/tensorrt_llm.dockerfile new file mode 100644 index 00000000..15e0299e --- /dev/null +++ b/docker/tensorrt_llm.dockerfile @@ -0,0 +1,36 @@ +# Copyright 2023 The HuggingFace Team All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +FROM huggingface/optimum-nvidia:latest + +# Ignore interactive questions during `docker build` +ENV DEBIAN_FRONTEND noninteractive + +# Run as non-root user +ARG USER_ID +ARG GROUP_ID + +RUN addgroup --gid $GROUP_ID user +RUN adduser --disabled-password --gecos '' --uid $USER_ID --gid $GROUP_ID user + +# Add local bin to PATH +ENV PATH="/home/user/.local/bin:${PATH}" + +# Add user to sudoers +RUN adduser user sudo +RUN echo '%sudo ALL=(ALL) NOPASSWD:ALL' >>/etc/sudoers + +# Change user +USER user +WORKDIR /home/user diff --git a/examples/trt_llama.yaml b/examples/trt_llama.yaml index 982ad4bb..e0f168d0 100644 --- a/examples/trt_llama.yaml +++ b/examples/trt_llama.yaml @@ -1,7 +1,7 @@ defaults: - - backend: tensorrt # default backend - - launcher: process # default launcher - - benchmark: inference # default benchmark + - launcher: process + - benchmark: inference + - backend: tensorrt-llm - experiment # inheriting experiment schema - _self_ # for hydra 1.1 compatibility - override hydra/job_logging: colorlog # colorful logging diff --git a/optimum_benchmark/backends/tensorrt/config.py b/optimum_benchmark/backends/tensorrt/config.py deleted file mode 100644 index 544f1e0b..00000000 --- a/optimum_benchmark/backends/tensorrt/config.py +++ /dev/null @@ -1,18 +0,0 @@ -from dataclasses import dataclass -from logging import getLogger - -from omegaconf import OmegaConf - -from ...import_utils import tesnorrt_version -from ..config import BackendConfig - -LOGGER = getLogger("tensorrt") - -OmegaConf.register_new_resolver("tensorrt_version", tesnorrt_version) - - -@dataclass -class TRTConfig(BackendConfig): - name: str = "tensorrt" - version: str = "${tensorrt_version:}" - _target_: str = "optimum_benchmark.backends.tensorrt.backend.TRTBackend" diff --git a/optimum_benchmark/backends/tensorrt/utils.py b/optimum_benchmark/backends/tensorrt/utils.py deleted file mode 100644 index e2827309..00000000 --- a/optimum_benchmark/backends/tensorrt/utils.py +++ /dev/null @@ -1 +0,0 @@ -MODEL_TYPE_TO_TRTMODEL = {"llama": "optimum.nvidia.models.llama.LlamaForCausalLM"} diff --git a/optimum_benchmark/backends/tensorrt/__init__.py b/optimum_benchmark/backends/tensorrt_llm/__init__.py similarity index 100% rename from optimum_benchmark/backends/tensorrt/__init__.py rename to optimum_benchmark/backends/tensorrt_llm/__init__.py diff --git a/optimum_benchmark/backends/tensorrt/backend.py b/optimum_benchmark/backends/tensorrt_llm/backend.py similarity index 57% rename from optimum_benchmark/backends/tensorrt/backend.py rename to optimum_benchmark/backends/tensorrt_llm/backend.py index 1832d6f9..dbf003f8 100644 --- a/optimum_benchmark/backends/tensorrt/backend.py +++ b/optimum_benchmark/backends/tensorrt_llm/backend.py @@ -5,14 +5,14 @@ from transformers.utils import ModelOutput from ..base import Backend -from .config import TRTConfig -from .utils import MODEL_TYPE_TO_TRTMODEL +from .config import TRTLLMConfig +from .utils import MODEL_TYPE_TO_TRTLLMMODEL -LOGGER = getLogger("tensorrt") +LOGGER = getLogger("tensorrt-llm") -class TRTBackend(Backend): - NAME: str = "tensorrt" +class TRTLLMBackend(Backend): + NAME = "tensorrt-llm" def __init__(self, model: str, task: str, device: str, hub_kwargs: Dict[str, Any]) -> None: super().__init__(model, task, device, hub_kwargs) @@ -20,33 +20,38 @@ def __init__(self, model: str, task: str, device: str, hub_kwargs: Dict[str, Any self.validate_model_type() def validate_model_type(self) -> None: - if self.model_type not in MODEL_TYPE_TO_TRTMODEL: - raise NotImplementedError(f"TRTBackend does not support model_type {self.model_type}") + if self.model_type not in MODEL_TYPE_TO_TRTLLMMODEL: + raise NotImplementedError(f"TRTLLMBackend does not support model_type {self.model_type}") def validate_device(self) -> None: if self.device != "cuda": - raise NotImplementedError(f"TRTBackend only supports device cuda, got {self.device}") + raise NotImplementedError(f"TRTLLMBackend only supports device cuda, got {self.device}") - def configure(self, config: TRTConfig) -> None: + def configure(self, config: TRTLLMConfig) -> None: super().configure(config) - self.trtmodel_class = get_class(MODEL_TYPE_TO_TRTMODEL[self.model_type]) + self.trtmodel_class = get_class(MODEL_TYPE_TO_TRTLLMMODEL[self.model_type]) ortmodel_name = self.trtmodel_class.__name__ LOGGER.info( - f"\t+ Inferred TRTModel class {ortmodel_name} for task {self.task} and model_type {self.model_type}" + f"\t+ Inferred TRTLLMModel class {ortmodel_name} for task {self.task} and model_type {self.model_type}" ) - # TODO: save engine path for reuse, then maybe re build with max_prompt_size self.load_trtmodel_from_pretrained() - @property - def trtmodel_kwargs(self) -> Dict[str, Any]: - return {} - def load_trtmodel_from_pretrained(self) -> None: self.pretrained_model = self.trtmodel_class.from_pretrained( self.model, - **self.trtmodel_kwargs, + tp=self.config.tp, + pp=self.config.pp, + dtype=self.config.dtype, + use_fp8=self.config.use_fp8, + world_size=self.config.world_size, + gpus_per_node=self.config.gpus_per_node, + use_cuda_graph=self.config.use_cuda_graph, + optimization_level=self.config.optimization_level, + max_prompt_length=self.config.max_prompt_length, + max_batch_size=self.config.max_batch_size, + max_new_tokens=self.config.max_new_tokens, **self.hub_kwargs, ) @@ -59,19 +64,20 @@ def forward(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> ModelOutput: def generate(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> ModelOutput: return self.pretrained_model.generate( - # spelling args to avoid conflict - input_ids=input.get("inputs", None), # diff api + input_ids=input.get("inputs", None), # diff names attention_mask=input.get("attention_mask", None), + # important for benchmarking max_new_tokens=kwargs.get("max_new_tokens", -1), - min_length=kwargs.get("min_new_tokens", -1), # diff api + min_length=kwargs.get("min_new_tokens", -1), # why different ? num_beams=kwargs.get("num_beams", 1), - temperature=kwargs.get("temperature", 1.0), - top_k=kwargs.get("top_k", 50), - top_p=kwargs.get("top_p", 1.0), - repetition_penalty=kwargs.get("repetition_penalty", 1.0), + # not really important but just in case + repetition_penalty=kwargs.get("repetition_penalty", 0.0), length_penalty=kwargs.get("length_penalty", 1.0), - seed=kwargs.get("seed", 42), pad_token_id=kwargs.get("pad_token_id", 0), bos_token_id=kwargs.get("bos_token_id", 1), eos_token_id=kwargs.get("eos_token_id", 2), + temperature=kwargs.get("temperature", 1.0), + top_k=kwargs.get("top_k", 50), + top_p=kwargs.get("top_p", 1.0), + seed=kwargs.get("seed", 42), ) diff --git a/optimum_benchmark/backends/tensorrt_llm/config.py b/optimum_benchmark/backends/tensorrt_llm/config.py new file mode 100644 index 00000000..54ee48bd --- /dev/null +++ b/optimum_benchmark/backends/tensorrt_llm/config.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass +from logging import getLogger + +from omegaconf import OmegaConf + +from ...import_utils import tesnorrt_version +from ..config import BackendConfig + +LOGGER = getLogger("tensorrt-llm") + +OmegaConf.register_new_resolver("tensorrt_llm_version", tesnorrt_version) + +SUPPORTED_DTYPES = ["float16", "bfloat16", "float32"] + + +@dataclass +class TRTLLMConfig(BackendConfig): + name: str = "tensorrt_llm" + version: str = "${tensorrt_llm_version:}" + _target_: str = "optimum_benchmark.backends.tensorrt_llm.backend.TRTLLMBackend" + + # build config + tp: int = 1 + pp: int = 1 + use_fp8: bool = False + dtype: str = "float16" + optimization_level: int = 2 + use_cuda_graph: bool = False + gpus_per_node: int = "${available_gpus:}" + world_size: int = "${backend.gpus_per_node}" + + max_batch_size: int = "${benchmark.input_shapes.batch_size}" + max_prompt_length: int = "${benchmark.input_shapes.sequence_length}" + max_new_tokens: int = "${benchmark.new_tokens}" + + def __post_init__(self) -> None: + super().__post_init__() + + if self.dtype not in SUPPORTED_DTYPES: + raise ValueError(f"dtype must be one of float16, bfloat16, float32, got {self.dtype}") + + if self.gpus_per_node != self.world_size: + raise ValueError(f"gpus_per_node ({self.gpus_per_node}) != world_size ({self.world_size})") + + if self.world_size != self.pp * self.tp: + raise ValueError(f"world_size ({self.gpus_per_node}) != pp ({self.pp}) * tp ({self.tp})") diff --git a/optimum_benchmark/backends/tensorrt_llm/utils.py b/optimum_benchmark/backends/tensorrt_llm/utils.py new file mode 100644 index 00000000..4574da53 --- /dev/null +++ b/optimum_benchmark/backends/tensorrt_llm/utils.py @@ -0,0 +1 @@ +MODEL_TYPE_TO_TRTLLMMODEL = {"llama": "optimum.nvidia.models.llama.LlamaForCausalLM"} diff --git a/optimum_benchmark/experiment.py b/optimum_benchmark/experiment.py index 1a8b9995..5a687fbe 100644 --- a/optimum_benchmark/experiment.py +++ b/optimum_benchmark/experiment.py @@ -12,7 +12,7 @@ from .backends.onnxruntime.config import ORTConfig from .backends.openvino.config import OVConfig from .backends.pytorch.config import PyTorchConfig -from .backends.tensorrt.config import TRTConfig +from .backends.tensorrt_llm.config import TRTLLMConfig from .backends.text_generation_inference.config import TGIConfig from .benchmarks.inference.config import InferenceConfig from .benchmarks.training.config import TrainingConfig @@ -130,9 +130,9 @@ def __post_init__(self) -> None: cs.store(name="experiment", node=ExperimentConfig) # cs.store(group="backend", name="openvino", node=OVConfig) -cs.store(group="backend", name="tensorrt", node=TRTConfig) cs.store(group="backend", name="pytorch", node=PyTorchConfig) cs.store(group="backend", name="onnxruntime", node=ORTConfig) +cs.store(group="backend", name="tensorrt-llm", node=TRTLLMConfig) cs.store(group="backend", name="neural-compressor", node=INCConfig) cs.store(group="backend", name="text-generation-inference", node=TGIConfig) # diff --git a/tests/configs/tensorrt_llm_inference.yaml b/tests/configs/tensorrt_llm_inference.yaml index bae6529a..5f14705a 100644 --- a/tests/configs/tensorrt_llm_inference.yaml +++ b/tests/configs/tensorrt_llm_inference.yaml @@ -1,5 +1,5 @@ defaults: - - backend: tensorrt # default backend + - backend: tensorrt-llm # default backend # order of inheritance, last one overrides previous ones - _base_ # inherits from base config - _inference_ # inherits from inference config