Fully support trt-llm #109

Merged (4 commits) on Jan 13, 2024
13 changes: 9 additions & 4 deletions .github/workflows/test_tensorrt_llm.yaml
@@ -18,8 +18,13 @@ jobs:
- name: Checkout
uses: actions/checkout@v3

- name: Pull image
run: docker pull huggingface/optimum-nvidia:latest
- name: Build image
run: docker build
--file docker/tensorrt_llm.dockerfile
--build-arg USER_ID=$(id -u)
--build-arg GROUP_ID=$(id -g)
--tag opt-bench-tensorrt-llm:latest
.

- name: Run tests
run: docker run
@@ -34,5 +39,5 @@ jobs:
--workdir /workspace/optimum-benchmark
--gpus '"device=0,1"'
--entrypoint /bin/bash
huggingface/optimum-nvidia:latest
-c "pip install -e .[test] && pytest -k 'tensorrt_llm' -x && chown -R $USER_ID:$GROUP_ID ."
opt-bench-tensorrt-llm:latest
-c "pip install -e .[test] && pytest -k 'tensorrt_llm' -x"
3 changes: 2 additions & 1 deletion README.md
@@ -26,12 +26,13 @@ Everything else is either optional or inferred from the model's name or path.

### Supported Backends/Devices

- [x] TensorRT-LLM backend for CUDA (NVIDIA GPUs)
- [x] Pytorch backend for CPU (Intel, AMD, ARM, etc)
- [x] Pytorch backend for CUDA (NVIDIA and AMD GPUs)
- [ ] Pytorch backend for Habana Gaudi Processor (HPU)
- [x] OnnxRuntime backend for CPUExecutionProvider
- [x] OnnxRuntime backend for CUDAExecutionProvider
- [ ] OnnxRuntime backend for ROCMExecutionProvider
- [x] OnnxRuntime backend for ROCMExecutionProvider
- [x] OnnxRuntime backend for TensorrtExecutionProvider
- [x] Intel Neural Compressor backend for CPU
- [x] OpenVINO backend for CPU
36 changes: 36 additions & 0 deletions docker/tensorrt_llm.dockerfile
@@ -0,0 +1,36 @@
# Copyright 2023 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

FROM huggingface/optimum-nvidia:latest

# Ignore interactive questions during `docker build`
ENV DEBIAN_FRONTEND noninteractive

# Run as non-root user
ARG USER_ID
ARG GROUP_ID

RUN addgroup --gid $GROUP_ID user
RUN adduser --disabled-password --gecos '' --uid $USER_ID --gid $GROUP_ID user

# Add local bin to PATH
ENV PATH="/home/user/.local/bin:${PATH}"

# Add user to sudoers
RUN adduser user sudo
RUN echo '%sudo ALL=(ALL) NOPASSWD:ALL' >>/etc/sudoers

# Change user
USER user
WORKDIR /home/user
6 changes: 3 additions & 3 deletions examples/trt_llama.yaml
@@ -1,7 +1,7 @@
defaults:
- backend: tensorrt # default backend
- launcher: process # default launcher
- benchmark: inference # default benchmark
- launcher: process
- benchmark: inference
- backend: tensorrt-llm
- experiment # inheriting experiment schema
- _self_ # for hydra 1.1 compatibility
- override hydra/job_logging: colorlog # colorful logging
18 changes: 0 additions & 18 deletions optimum_benchmark/backends/tensorrt/config.py

This file was deleted.

1 change: 0 additions & 1 deletion optimum_benchmark/backends/tensorrt/utils.py

This file was deleted.

optimum_benchmark/backends/tensorrt_llm/backend.py
@@ -5,48 +5,53 @@
from transformers.utils import ModelOutput

from ..base import Backend
from .config import TRTConfig
from .utils import MODEL_TYPE_TO_TRTMODEL
from .config import TRTLLMConfig
from .utils import MODEL_TYPE_TO_TRTLLMMODEL

LOGGER = getLogger("tensorrt")
LOGGER = getLogger("tensorrt-llm")


class TRTBackend(Backend):
NAME: str = "tensorrt"
class TRTLLMBackend(Backend):
NAME = "tensorrt-llm"

def __init__(self, model: str, task: str, device: str, hub_kwargs: Dict[str, Any]) -> None:
super().__init__(model, task, device, hub_kwargs)
self.validate_device()
self.validate_model_type()

def validate_model_type(self) -> None:
if self.model_type not in MODEL_TYPE_TO_TRTMODEL:
raise NotImplementedError(f"TRTBackend does not support model_type {self.model_type}")
if self.model_type not in MODEL_TYPE_TO_TRTLLMMODEL:
raise NotImplementedError(f"TRTLLMBackend does not support model_type {self.model_type}")

def validate_device(self) -> None:
if self.device != "cuda":
raise NotImplementedError(f"TRTBackend only supports device cuda, got {self.device}")
raise NotImplementedError(f"TRTLLMBackend only supports device cuda, got {self.device}")

def configure(self, config: TRTConfig) -> None:
def configure(self, config: TRTLLMConfig) -> None:
super().configure(config)

self.trtmodel_class = get_class(MODEL_TYPE_TO_TRTMODEL[self.model_type])
self.trtmodel_class = get_class(MODEL_TYPE_TO_TRTLLMMODEL[self.model_type])
ortmodel_name = self.trtmodel_class.__name__
LOGGER.info(
f"\t+ Inferred TRTModel class {ortmodel_name} for task {self.task} and model_type {self.model_type}"
f"\t+ Inferred TRTLLMModel class {ortmodel_name} for task {self.task} and model_type {self.model_type}"
)

# TODO: save engine path for reuse, then maybe re build with max_prompt_size
self.load_trtmodel_from_pretrained()

@property
def trtmodel_kwargs(self) -> Dict[str, Any]:
return {}

def load_trtmodel_from_pretrained(self) -> None:
self.pretrained_model = self.trtmodel_class.from_pretrained(
self.model,
**self.trtmodel_kwargs,
tp=self.config.tp,
pp=self.config.pp,
dtype=self.config.dtype,
use_fp8=self.config.use_fp8,
world_size=self.config.world_size,
gpus_per_node=self.config.gpus_per_node,
use_cuda_graph=self.config.use_cuda_graph,
optimization_level=self.config.optimization_level,
max_prompt_length=self.config.max_prompt_length,
max_batch_size=self.config.max_batch_size,
max_new_tokens=self.config.max_new_tokens,
**self.hub_kwargs,
)

@@ -59,19 +64,20 @@ def forward(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> ModelOutput:

def generate(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> ModelOutput:
return self.pretrained_model.generate(
# spelling args to avoid conflict
input_ids=input.get("inputs", None), # diff api
input_ids=input.get("inputs", None), # diff names
attention_mask=input.get("attention_mask", None),
# important for benchmarking
max_new_tokens=kwargs.get("max_new_tokens", -1),
min_length=kwargs.get("min_new_tokens", -1), # diff api
min_length=kwargs.get("min_new_tokens", -1), # why different ?
num_beams=kwargs.get("num_beams", 1),
temperature=kwargs.get("temperature", 1.0),
top_k=kwargs.get("top_k", 50),
top_p=kwargs.get("top_p", 1.0),
repetition_penalty=kwargs.get("repetition_penalty", 1.0),
# not really important but just in case
repetition_penalty=kwargs.get("repetition_penalty", 0.0),
length_penalty=kwargs.get("length_penalty", 1.0),
seed=kwargs.get("seed", 42),
pad_token_id=kwargs.get("pad_token_id", 0),
bos_token_id=kwargs.get("bos_token_id", 1),
eos_token_id=kwargs.get("eos_token_id", 2),
temperature=kwargs.get("temperature", 1.0),
top_k=kwargs.get("top_k", 50),
top_p=kwargs.get("top_p", 1.0),
seed=kwargs.get("seed", 42),
)
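
For context, a rough usage sketch of the renamed backend outside the Hydra-driven experiment runner. Everything here is an assumption for illustration: the checkpoint name is a placeholder, the config is filled with explicit values in place of the `${...}` interpolations that Hydra would normally resolve, and it presumes the opt-bench-tensorrt-llm image above (CUDA GPU, optimum-nvidia and tensorrt_llm installed).

```python
# Illustrative sketch only; not part of this PR. Assumes a CUDA machine inside
# the opt-bench-tensorrt-llm image, with enough memory to build a llama engine.
from transformers import AutoTokenizer

from optimum_benchmark.backends.tensorrt_llm.backend import TRTLLMBackend
from optimum_benchmark.backends.tensorrt_llm.config import TRTLLMConfig

config = TRTLLMConfig(
    # explicit values instead of the "${...}" interpolations,
    # which only resolve when the config is composed by Hydra/OmegaConf
    tp=1,
    pp=1,
    world_size=1,
    gpus_per_node=1,
    max_batch_size=1,
    max_prompt_length=128,
    max_new_tokens=32,
)

backend = TRTLLMBackend(
    model="meta-llama/Llama-2-7b-hf",  # placeholder llama checkpoint
    task="text-generation",
    device="cuda",  # any other device raises NotImplementedError
    hub_kwargs={},
)
backend.configure(config)  # resolves LlamaForCausalLM and builds/loads the engine

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
inputs = tokenizer("Hello, TensorRT-LLM!", return_tensors="pt")
outputs = backend.generate(
    input={"inputs": inputs["input_ids"], "attention_mask": inputs["attention_mask"]},
    kwargs={"max_new_tokens": 32, "min_new_tokens": -1, "num_beams": 1},
)
```

In the benchmark itself these values come from the composed Hydra config, as wired up in config.py below.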
46 changes: 46 additions & 0 deletions optimum_benchmark/backends/tensorrt_llm/config.py
@@ -0,0 +1,46 @@
from dataclasses import dataclass
from logging import getLogger

from omegaconf import OmegaConf

from ...import_utils import tesnorrt_version
from ..config import BackendConfig

LOGGER = getLogger("tensorrt-llm")

OmegaConf.register_new_resolver("tensorrt_llm_version", tesnorrt_version)

SUPPORTED_DTYPES = ["float16", "bfloat16", "float32"]


@dataclass
class TRTLLMConfig(BackendConfig):
name: str = "tensorrt_llm"
version: str = "${tensorrt_llm_version:}"
_target_: str = "optimum_benchmark.backends.tensorrt_llm.backend.TRTLLMBackend"

# build config
tp: int = 1
pp: int = 1
use_fp8: bool = False
dtype: str = "float16"
optimization_level: int = 2
use_cuda_graph: bool = False
gpus_per_node: int = "${available_gpus:}"
world_size: int = "${backend.gpus_per_node}"

max_batch_size: int = "${benchmark.input_shapes.batch_size}"
max_prompt_length: int = "${benchmark.input_shapes.sequence_length}"
max_new_tokens: int = "${benchmark.new_tokens}"

def __post_init__(self) -> None:
super().__post_init__()

if self.dtype not in SUPPORTED_DTYPES:
raise ValueError(f"dtype must be one of float16, bfloat16, float32, got {self.dtype}")

if self.gpus_per_node != self.world_size:
raise ValueError(f"gpus_per_node ({self.gpus_per_node}) != world_size ({self.world_size})")

if self.world_size != self.pp * self.tp:
raise ValueError(f"world_size ({self.gpus_per_node}) != pp ({self.pp}) * tp ({self.tp})")
1 change: 1 addition & 0 deletions optimum_benchmark/backends/tensorrt_llm/utils.py
@@ -0,0 +1 @@
MODEL_TYPE_TO_TRTLLMMODEL = {"llama": "optimum.nvidia.models.llama.LlamaForCausalLM"}
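
The backend turns this dotted path into a class at runtime via the `get_class` helper it imports; the helper itself is not part of this diff, but a minimal equivalent (an assumption, the real one may differ) is just importlib plus getattr:

```python
# Assumed shape of the get_class helper used in backend.py; the actual helper
# lives elsewhere in optimum_benchmark and may differ in detail.
from importlib import import_module
from typing import Any

MODEL_TYPE_TO_TRTLLMMODEL = {"llama": "optimum.nvidia.models.llama.LlamaForCausalLM"}

def get_class(target: str) -> Any:
    """Import and return the class named by a fully qualified dotted path."""
    module_path, class_name = target.rsplit(".", 1)
    return getattr(import_module(module_path), class_name)

# Resolves to optimum.nvidia.models.llama.LlamaForCausalLM
# (requires optimum-nvidia to be installed).
llama_cls = get_class(MODEL_TYPE_TO_TRTLLMMODEL["llama"])
```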
4 changes: 2 additions & 2 deletions optimum_benchmark/experiment.py
@@ -12,7 +12,7 @@
from .backends.onnxruntime.config import ORTConfig
from .backends.openvino.config import OVConfig
from .backends.pytorch.config import PyTorchConfig
from .backends.tensorrt.config import TRTConfig
from .backends.tensorrt_llm.config import TRTLLMConfig
from .backends.text_generation_inference.config import TGIConfig
from .benchmarks.inference.config import InferenceConfig
from .benchmarks.training.config import TrainingConfig
@@ -130,9 +130,9 @@ def __post_init__(self) -> None:
cs.store(name="experiment", node=ExperimentConfig)
#
cs.store(group="backend", name="openvino", node=OVConfig)
cs.store(group="backend", name="tensorrt", node=TRTConfig)
cs.store(group="backend", name="pytorch", node=PyTorchConfig)
cs.store(group="backend", name="onnxruntime", node=ORTConfig)
cs.store(group="backend", name="tensorrt-llm", node=TRTLLMConfig)
cs.store(group="backend", name="neural-compressor", node=INCConfig)
cs.store(group="backend", name="text-generation-inference", node=TGIConfig)
#
2 changes: 1 addition & 1 deletion tests/configs/tensorrt_llm_inference.yaml
@@ -1,5 +1,5 @@
defaults:
- backend: tensorrt # default backend
- backend: tensorrt-llm # default backend
# order of inheritance, last one overrides previous ones
- _base_ # inherits from base config
- _inference_ # inherits from inference config