diff --git a/python/llm/dev/benchmark/all-in-one/run.py b/python/llm/dev/benchmark/all-in-one/run.py index 1a1d9490f9e..af254d1a05b 100644 --- a/python/llm/dev/benchmark/all-in-one/run.py +++ b/python/llm/dev/benchmark/all-in-one/run.py @@ -452,7 +452,7 @@ def run_transformer_int4_gpu(repo_id, if fp16: torch_dtype = torch.float16 else: - torch_dtype = 'auto' + torch_dtype = torch.float32 st = time.perf_counter() origin_repo_id = repo_id.replace("-4bit", "") if origin_repo_id in CHATGLM_IDS: