This document shows how to build and run a Mamba model in TensorRT-LLM on a single GPU.

The TensorRT-LLM Mamba implementation can be found in `tensorrt_llm/models/mamba/model.py`. The TensorRT-LLM Mamba example code is located in `examples/mamba`. There is one main file:

- `convert_checkpoint.py` to convert a checkpoint from the HuggingFace (HF) Transformers format to the TensorRT-LLM format.

In addition, there are two shared files in the parent folder `examples` for inference and evaluation:

- `../run.py` to run the inference on an input text (an example invocation is shown at the end of this document);
- `../summarize.py` to summarize the articles in the cnn_dailymail dataset.
| Model Name | FP16 | BF16 | TP |
|---|---|---|---|
| Mamba1 | Y | Y | N |
| Mamba2 | Y | Y | Y |
- Mamba2: TensorRT-LLM currently supports only pure Mamba2 models; support for hybrid models will be added later.
The next sections describe how to fetch the HF checkpoints and convert the weights from the HuggingFace (HF) Transformers format to the TensorRT-LLM format.
Please install the required packages first and set up `git-lfs`:

```bash
pip install -r requirements.txt
git lfs install
```
There are different HF checkpoints available. For Mamba1, TensorRT-LLM supports the Transformers-compatible models. Here are some examples of fetching the checkpoints.
```bash
# mamba-2.8b
git clone https://huggingface.co/state-spaces/mamba-2.8b-hf ./mamba_model/mamba-2.8b

# mamba-130m
git clone https://huggingface.co/state-spaces/mamba-130m-hf ./mamba_model/mamba-130m

# mamba2-2.7b
git clone https://huggingface.co/state-spaces/mamba2-2.7b ./mamba_model/mamba2-2.7b

# mamba2-130m
git clone https://huggingface.co/state-spaces/mamba2-130m ./mamba_model/mamba2-130m

# mamba-codestral-7B-v0.1
git clone https://huggingface.co/mistralai/mamba-codestral-7B-v0.1 ./mamba_model/mamba-codestral-7B-v0.1
```
Since the state-spaces Mamba models use the tokenizer from the gpt-neox-20b model (mamba-codestral-7B-v0.1 ships its own tokenizer), use the following command to fetch the gpt-neox-20b checkpoint.
```bash
# gpt-neox-20b
git clone https://huggingface.co/EleutherAI/gpt-neox-20b ./mamba_model/gpt-neox-20b
```
The `convert_checkpoint.py` script converts HF weights to TensorRT-LLM checkpoints.

Mamba2 models support tensor parallelism, so you can run them with 1, 2, 4, or 8 GPUs; here we use mamba-codestral-7B-v0.1 as an example.
```bash
# mamba-2.8b
python convert_checkpoint.py --model_dir ./mamba_model/mamba-2.8b/ \
                             --dtype bfloat16 \
                             --output_dir ./mamba_model/mamba-2.8b/trt_ckpt/bf16/1-gpu/

# mamba-130m
python convert_checkpoint.py --model_dir ./mamba_model/mamba-130m/ \
                             --dtype float16 \
                             --output_dir ./mamba_model/mamba-130m/trt_ckpt/fp16/1-gpu/

# mamba2-2.7b
python convert_checkpoint.py --model_dir ./mamba_model/mamba2-2.7b/ \
                             --dtype float16 \
                             --output_dir ./mamba_model/mamba2-2.7b/trt_ckpt/fp16/1-gpu/

# mamba2-130m
python convert_checkpoint.py --model_dir ./mamba_model/mamba2-130m/ \
                             --dtype float16 \
                             --output_dir ./mamba_model/mamba2-130m/trt_ckpt/fp16/1-gpu/

# mamba-codestral-7B-v0.1
python convert_checkpoint.py --model_dir ./mamba_model/mamba-codestral-7B-v0.1/ \
                             --dtype float16 \
                             --output_dir ./mamba_model/mamba-codestral-7B-v0.1/trt_ckpt/fp16/1-gpu/

# mamba-codestral-7B-v0.1 with 2-way tensor parallelism
python convert_checkpoint.py --model_dir ./mamba_model/mamba-codestral-7B-v0.1/ \
                             --dtype float16 \
                             --world_size 2 \
                             --output_dir ./mamba_model/mamba-codestral-7B-v0.1/trt_ckpt/fp16/2-gpu/
```
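Each output directory should now hold a TensorRT-LLM checkpoint, i.e. a `config.json` plus one `rank<i>.safetensors` file per rank (this layout is an assumption based on the standard TensorRT-LLM checkpoint format). A quick sanity check for the 2-way tensor-parallel conversion:

```bash
# List the converted 2-way TP checkpoint; expect config.json plus
# rank0.safetensors and rank1.safetensors (assumed standard checkpoint layout).
ls ./mamba_model/mamba-codestral-7B-v0.1/trt_ckpt/fp16/2-gpu/
```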
The `trtllm-build` command builds TensorRT-LLM engines from TensorRT-LLM checkpoints.
```bash
# mamba-2.8b
trtllm-build --checkpoint_dir ./mamba_model/mamba-2.8b/trt_ckpt/bf16/1-gpu/ \
             --paged_kv_cache disable \
             --gemm_plugin auto \
             --max_batch_size 8 \
             --max_input_len 924 \
             --max_seq_len 1024 \
             --output_dir ./mamba_model/mamba-2.8b/trt_engines/bf16/1-gpu/

# mamba-130m
trtllm-build --checkpoint_dir ./mamba_model/mamba-130m/trt_ckpt/fp16/1-gpu/ \
             --paged_kv_cache disable \
             --gemm_plugin auto \
             --max_batch_size 8 \
             --max_input_len 924 \
             --max_seq_len 1024 \
             --output_dir ./mamba_model/mamba-130m/trt_engines/fp16/1-gpu/

# mamba2-2.7b
trtllm-build --checkpoint_dir ./mamba_model/mamba2-2.7b/trt_ckpt/fp16/1-gpu/ \
             --paged_kv_cache disable \
             --gemm_plugin auto \
             --max_batch_size 8 \
             --max_input_len 924 \
             --max_seq_len 1024 \
             --output_dir ./mamba_model/mamba2-2.7b/trt_engines/fp16/1-gpu/

# mamba2-130m
trtllm-build --checkpoint_dir ./mamba_model/mamba2-130m/trt_ckpt/fp16/1-gpu/ \
             --paged_kv_cache disable \
             --gemm_plugin auto \
             --max_batch_size 8 \
             --max_input_len 924 \
             --max_seq_len 1024 \
             --output_dir ./mamba_model/mamba2-130m/trt_engines/fp16/1-gpu/

# mamba-codestral-7B-v0.1
trtllm-build --checkpoint_dir ./mamba_model/mamba-codestral-7B-v0.1/trt_ckpt/fp16/1-gpu/ \
             --paged_kv_cache disable \
             --gemm_plugin auto \
             --max_batch_size 8 \
             --max_input_len 924 \
             --max_seq_len 1024 \
             --output_dir ./mamba_model/mamba-codestral-7B-v0.1/trt_engines/fp16/1-gpu/

# mamba-codestral-7B-v0.1 with 2-way tensor parallelism
trtllm-build --checkpoint_dir ./mamba_model/mamba-codestral-7B-v0.1/trt_ckpt/fp16/2-gpu/ \
             --paged_kv_cache disable \
             --gemm_plugin auto \
             --max_batch_size 8 \
             --max_input_len 924 \
             --max_seq_len 1024 \
             --output_dir ./mamba_model/mamba-codestral-7B-v0.1/trt_engines/fp16/2-gpu/
```
Note that when building Mamba models, you need to disable `paged_kv_cache`, since it is only used for transformer-based models. Mamba models use `paged_state` instead, which is enabled by default. If `paged_state` is disabled, the engine will be built with the contiguous state cache.
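If you prefer the contiguous state cache, the sketch below disables the paged state at build time. The `--paged_state disable` flag is an assumption that mirrors the `--paged_kv_cache disable` syntax used above; check `trtllm-build --help` for the exact option name.

```bash
# Build mamba-130m with the contiguous state cache instead of the default paged state.
# NOTE: --paged_state is assumed to mirror --paged_kv_cache; verify via `trtllm-build --help`.
trtllm-build --checkpoint_dir ./mamba_model/mamba-130m/trt_ckpt/fp16/1-gpu/ \
             --paged_kv_cache disable \
             --paged_state disable \
             --gemm_plugin auto \
             --max_batch_size 8 \
             --max_input_len 924 \
             --max_seq_len 1024 \
             --output_dir ./mamba_model/mamba-130m/trt_engines/fp16/1-gpu-contiguous-state/
```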
The following section describes how to run a TensorRT-LLM Mamba model to summarize the articles from the cnn_dailymail dataset. For each summary, the script computes the ROUGE scores and uses the ROUGE-1 score to validate the implementation.
```bash
# mamba-2.8b
python ../summarize.py --test_trt_llm \
                       --hf_model_dir ./mamba_model/mamba-2.8b/ \
                       --tokenizer_dir ./mamba_model/gpt-neox-20b/ \
                       --data_type bf16 \
                       --engine_dir ./mamba_model/mamba-2.8b/trt_engines/bf16/1-gpu/

# mamba-130m
python ../summarize.py --test_trt_llm \
                       --hf_model_dir ./mamba_model/mamba-130m/ \
                       --tokenizer_dir ./mamba_model/gpt-neox-20b/ \
                       --data_type fp16 \
                       --engine_dir ./mamba_model/mamba-130m/trt_engines/fp16/1-gpu/

# mamba2-2.7b
python ../summarize.py --test_trt_llm \
                       --hf_model_dir ./mamba_model/mamba2-2.7b/ \
                       --tokenizer_dir ./mamba_model/gpt-neox-20b/ \
                       --data_type fp16 \
                       --engine_dir ./mamba_model/mamba2-2.7b/trt_engines/fp16/1-gpu/

# mamba2-130m
python ../summarize.py --test_trt_llm \
                       --hf_model_dir ./mamba_model/mamba2-130m/ \
                       --tokenizer_dir ./mamba_model/gpt-neox-20b/ \
                       --data_type fp16 \
                       --engine_dir ./mamba_model/mamba2-130m/trt_engines/fp16/1-gpu/

# mamba-codestral-7B-v0.1
python ../summarize.py --test_trt_llm \
                       --hf_model_dir ./mamba_model/mamba-codestral-7B-v0.1/ \
                       --tokenizer_dir ./mamba_model/mamba-codestral-7B-v0.1/ \
                       --data_type fp16 \
                       --engine_dir ./mamba_model/mamba-codestral-7B-v0.1/trt_engines/fp16/1-gpu/

# mamba-codestral-7B-v0.1 with 2-way tensor parallelism
mpirun -n 2 --allow-run-as-root \
    python ../summarize.py --test_trt_llm \
                           --hf_model_dir ./mamba_model/mamba-codestral-7B-v0.1/ \
                           --tokenizer_dir ./mamba_model/mamba-codestral-7B-v0.1/ \
                           --data_type fp16 \
                           --engine_dir ./mamba_model/mamba-codestral-7B-v0.1/trt_engines/fp16/2-gpu/
```
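You can also run inference on a single input prompt with the shared `../run.py` script. Below is a minimal sketch for the mamba-2.8b engine built above, using the common `run.py` arguments from the parent `examples` folder; adjust the prompt and output length as needed.

```bash
# Generate a continuation for a single prompt with the mamba-2.8b engine.
python ../run.py --engine_dir ./mamba_model/mamba-2.8b/trt_engines/bf16/1-gpu/ \
                 --tokenizer_dir ./mamba_model/gpt-neox-20b/ \
                 --input_text "Born in north-east France, Soyer trained as a" \
                 --max_output_len 50
```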