
Commit

added trt-llm supported features
IlyasMoutawwakil committed Jan 13, 2024
1 parent b8913d0 commit 44af91a
Showing 4 changed files with 54 additions and 19 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -26,12 +26,13 @@ Everything else is either optional or inferred from the model's name or path.

### Supported Backends/Devices

- [x] TensorRT-LLM backend for CUDA (NVIDIA GPUs)
- [x] Pytorch backend for CPU (Intel, AMD, ARM, etc)
- [x] Pytorch backend for CUDA (NVIDIA and AMD GPUs)
- [ ] Pytorch backend for Habana Gaudi Processor (HPU)
- [x] OnnxRuntime backend for CPUExecutionProvider
- [x] OnnxRuntime backend for CUDAExecutionProvider
- [ ] OnnxRuntime backend for ROCMExecutionProvider
- [x] OnnxRuntime backend for ROCMExecutionProvider
- [x] OnnxRuntime backend for TensorrtExecutionProvider
- [x] Intel Neural Compressor backend for CPU
- [x] OpenVINO backend for CPU
6 changes: 3 additions & 3 deletions examples/trt_llama.yaml
@@ -1,7 +1,7 @@
defaults:
- backend: tensorrt # default backend
- launcher: process # default launcher
- benchmark: inference # default benchmark
- launcher: process
- benchmark: inference
- backend: tensorrt-llm
- experiment # inheriting experiment schema
- _self_ # for hydra 1.1 compatibility
- override hydra/job_logging: colorlog # colorful logging
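For reference, a minimal sketch of composing this example config programmatically with Hydra, assuming the file sits at examples/trt_llama.yaml and the composed config exposes the backend group defined further below; the config_path, override value, and print are illustrative, not part of the commit:

from hydra import compose, initialize

# Compose examples/trt_llama.yaml the way the CLI would; the override value
# is an arbitrary illustration.
with initialize(config_path="examples"):
    cfg = compose(config_name="trt_llama", overrides=["backend.max_batch_size=4"])

print(cfg.backend.name)  # "tensorrt_llm", per the backend config dataclass below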
36 changes: 21 additions & 15 deletions optimum_benchmark/backends/tensorrt_llm/backend.py
@@ -12,7 +12,7 @@


class TRTLLMBackend(Backend):
NAME: str = "tensorrt-llm"
NAME = "tensorrt-llm"

def __init__(self, model: str, task: str, device: str, hub_kwargs: Dict[str, Any]) -> None:
super().__init__(model, task, device, hub_kwargs)
@@ -36,17 +36,22 @@ def configure(self, config: TRTLLMConfig) -> None:
f"\t+ Inferred TRTLLMModel class {ortmodel_name} for task {self.task} and model_type {self.model_type}"
)

# TODO: save engine path for reuse, then maybe re build with max_prompt_size
self.load_trtmodel_from_pretrained()

@property
def trtmodel_kwargs(self) -> Dict[str, Any]:
return {}

def load_trtmodel_from_pretrained(self) -> None:
self.pretrained_model = self.trtmodel_class.from_pretrained(
self.model,
**self.trtmodel_kwargs,
tp=self.config.tp,
pp=self.config.pp,
dtype=self.config.dtype,
use_fp8=self.config.use_fp8,
world_size=self.config.world_size,
gpus_per_node=self.config.gpus_per_node,
use_cuda_graph=self.config.use_cuda_graph,
optimization_level=self.config.optimization_level,
max_prompt_length=self.config.max_prompt_length,
max_batch_size=self.config.max_batch_size,
max_new_tokens=self.config.max_new_tokens,
**self.hub_kwargs,
)
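A minimal usage sketch built only from the signatures visible in this diff (TRTLLMBackend(model, task, device, hub_kwargs) and configure(config)); the model id and the concrete build values are placeholders, not taken from the commit:

from optimum_benchmark.backends.tensorrt_llm.backend import TRTLLMBackend
from optimum_benchmark.backends.tensorrt_llm.config import TRTLLMConfig

# Placeholder model/task/device; hub_kwargs left empty for the sketch.
backend = TRTLLMBackend(
    model="meta-llama/Llama-2-7b-hf",
    task="text-generation",
    device="cuda",
    hub_kwargs={},
)

# Concrete build options stand in for the Hydra interpolations used as defaults.
config = TRTLLMConfig(
    tp=1,
    pp=1,
    world_size=1,
    gpus_per_node=1,
    dtype="float16",
    max_batch_size=1,
    max_prompt_length=128,
    max_new_tokens=100,
)
backend.configure(config)  # infers the TRTLLM model class, then loads/builds the engine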

@@ -59,19 +64,20 @@ def forward(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> ModelOutput:

def generate(self, input: Dict[str, Any], kwargs: Dict[str, Any]) -> ModelOutput:
return self.pretrained_model.generate(
# spelling args to avoid conflict
input_ids=input.get("inputs", None), # diff api
input_ids=input.get("inputs", None), # diff names
attention_mask=input.get("attention_mask", None),
# important for benchmarking
max_new_tokens=kwargs.get("max_new_tokens", -1),
min_length=kwargs.get("min_new_tokens", -1), # diff api
min_length=kwargs.get("min_new_tokens", -1), # why different ?
num_beams=kwargs.get("num_beams", 1),
temperature=kwargs.get("temperature", 1.0),
top_k=kwargs.get("top_k", 50),
top_p=kwargs.get("top_p", 1.0),
repetition_penalty=kwargs.get("repetition_penalty", 1.0),
# not really important but just in case
repetition_penalty=kwargs.get("repetition_penalty", 0.0),
length_penalty=kwargs.get("length_penalty", 1.0),
seed=kwargs.get("seed", 42),
pad_token_id=kwargs.get("pad_token_id", 0),
bos_token_id=kwargs.get("bos_token_id", 1),
eos_token_id=kwargs.get("eos_token_id", 2),
temperature=kwargs.get("temperature", 1.0),
top_k=kwargs.get("top_k", 50),
top_p=kwargs.get("top_p", 1.0),
seed=kwargs.get("seed", 42),
)
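Continuing the sketch above, the two dictionaries generate() expects, using only the keys it reads here; tensor shapes and values are illustrative:

import torch

inputs = {
    "inputs": torch.ones(1, 128, dtype=torch.long),  # forwarded as input_ids
    "attention_mask": torch.ones(1, 128, dtype=torch.long),
}
gen_kwargs = {
    "max_new_tokens": 100,
    "min_new_tokens": 100,  # forwarded as min_length
    "num_beams": 1,
}

outputs = backend.generate(inputs, gen_kwargs)  # backend from the sketch above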
28 changes: 28 additions & 0 deletions optimum_benchmark/backends/tensorrt_llm/config.py
@@ -10,9 +10,37 @@

OmegaConf.register_new_resolver("tensorrt_llm_version", tesnorrt_version)

SUPPORTED_DTYPES = ["float16", "bfloat16", "float32"]


@dataclass
class TRTLLMConfig(BackendConfig):
name: str = "tensorrt_llm"
version: str = "${tensorrt_llm_version:}"
_target_: str = "optimum_benchmark.backends.tensorrt_llm.backend.TRTLLMBackend"

# build config
tp: int = 1
pp: int = 1
use_fp8: bool = False
dtype: str = "float16"
optimization_level: int = 2
use_cuda_graph: bool = False
gpus_per_node: int = "${available_gpus:}"
world_size: int = "${backend.gpus_per_node}"

max_batch_size: int = "${benchmark.input_shapes.batch_size}"
max_prompt_length: int = "${benchmark.input_shapes.sequence_length}"
max_new_tokens: int = "${benchmark.new_tokens}"

def __post_init__(self) -> None:
super().__post_init__()

if self.dtype not in SUPPORTED_DTYPES:
raise ValueError(f"dtype must be one of float16, bfloat16, float32, got {self.dtype}")

if self.gpus_per_node != self.world_size:
raise ValueError(f"gpus_per_node ({self.gpus_per_node}) != world_size ({self.world_size})")

if self.world_size != self.pp * self.tp:
raise ValueError(f"world_size ({self.gpus_per_node}) != pp ({self.pp}) * tp ({self.tp})")
