diff --git a/examples/inference_tp.py b/examples/inference_tp.py
index c0a07a6a..a012769c 100644
--- a/examples/inference_tp.py
+++ b/examples/inference_tp.py
@@ -10,9 +10,22 @@
 config.arch_compat_overrides()
 config.no_graphs = True
 model = ExLlamaV2(config)
-model.load_tp(progress = True)
+
+# Load the model in tensor-parallel mode. With no gpu_split specified, the model will attempt to split across
+# all visible devices according to the currently available VRAM on each. expect_cache_tokens is necessary for
+# balancing the split, in case the GPUs are of uneven sizes, or if the number of GPUs doesn't divide the number
+# of KV heads in the model
+#
+# The cache type for a TP model is always ExLlamaV2Cache_TP and should be allocated after the model. To use a
+# quantized cache, add a `base = ExLlamaV2Cache_Q6` etc. argument to the cache constructor. It's advisable
+# to also add `expect_cache_base = ExLlamaV2Cache_Q6` to load_tp() as well so the size can be correctly
+# accounted for when splitting the model.
+
+model.load_tp(progress = True, expect_cache_tokens = 16384)
 cache = ExLlamaV2Cache_TP(model, max_seq_len = 16384)
 
+# After loading the model, all other functions should work the same
+
 print("Loading tokenizer...")
 tokenizer = ExLlamaV2Tokenizer(config)
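
For reference, the quantized-cache variant described in the new comments would look roughly like the sketch below. This is a minimal illustration, not part of the patch: it assumes the same setup as the example script, uses a hypothetical `model_dir` path, and passes `ExLlamaV2Cache_Q6` both as `expect_cache_base` to `load_tp()` and as `base` to the `ExLlamaV2Cache_TP` constructor, as the comments suggest.

```python
from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache_TP,
    ExLlamaV2Cache_Q6,
    ExLlamaV2Tokenizer,
)

# Hypothetical model path, for illustration only
model_dir = "/path/to/model-exl2"

config = ExLlamaV2Config(model_dir)
config.arch_compat_overrides()
config.no_graphs = True
model = ExLlamaV2(config)

# Tell the splitter how many cache tokens to budget for and that the cache will be
# Q6-quantized, so the tensor-parallel split accounts for the correct VRAM per device
model.load_tp(
    progress = True,
    expect_cache_tokens = 16384,
    expect_cache_base = ExLlamaV2Cache_Q6,
)

# The TP cache is still ExLlamaV2Cache_TP; the quantized cache type is passed as `base`
cache = ExLlamaV2Cache_TP(model, max_seq_len = 16384, base = ExLlamaV2Cache_Q6)

tokenizer = ExLlamaV2Tokenizer(config)
```

The key point is that the cache type given to `expect_cache_base` should match the `base` used for the actual cache, otherwise the split is balanced against a differently sized cache than the one that gets allocated.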