forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patheval.py
293 lines (264 loc) · 10.8 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import aiohttp
import datasets
import torch
from transformers import AutoProcessor
from utils import add_common_args
import tensorrt_llm
import tensorrt_llm.profiler as profiler
from tensorrt_llm import logger
from tensorrt_llm.runtime import MultimodalModelRunner
# Maps each accepted --model_type value to the name of the `transformers`
# class that load_hf_model looks up for the Hugging Face reference run.
# A value of None means no HF class is registered for that model type;
# load_hf_model raises ValueError for those entries.
SUPPORTED_MODEL_TYPES = {
    'blip2': 'Blip2ForConditionalGeneration',
    'fuyu': 'FuyuForCausalLM',
    'kosmos-2': 'Kosmos2ForConditionalGeneration',
    'llava': 'LlavaForConditionalGeneration',
    'llava_next': 'LlavaNextForConditionalGeneration',
    'llava_onevision': 'LlavaOnevisionForConditionalGeneration',
    'phi-3-vision': 'AutoModelForCausalLM',
    # not tested for TRT-LLM yet
    'qwen2_vl': 'Qwen2VLForConditionalGeneration',
    'mllama': 'MllamaForConditionalGeneration',
    'vila': None,
    # not tested for TRT-LLM yet
    'cogvlm': None,
    # not tested for TRT-LLM yet
    'neva': None,
    'internvl': None,
}

# Dataset identifiers accepted by --eval_task.
EVAL_TASKS = [
    'lmms-lab/ai2d',
    'lmms-lab/VQAv2',
    'lmms-lab/MME',
]
def parse_arguments(args=None):
    """Build the evaluation command line and parse it.

    The options registered by ``add_common_args`` are added first, then the
    evaluation-specific options defined here.

    Args:
        args: Optional list of argument strings. When ``None``, argparse
            reads the process command line.

    Returns:
        The ``argparse.Namespace`` returned by ``parse_args``.
    """
    parser = add_common_args(argparse.ArgumentParser())

    # Backend switches; both default to None (falsy) when not given.
    backend_flags = (
        ('--test_trtllm', "Evaluate the TensorRT-LLM."),
        ('--test_hf', "Evaluate the Huggingface."),
    )
    for flag, description in backend_flags:
        parser.add_argument(flag,
                            action='store_true',
                            default=None,
                            help=description)

    # Number of dataset samples to evaluate.
    parser.add_argument('--max_ite', type=int, default=20)

    parser.add_argument(
        '--eval_task',
        type=str,
        choices=EVAL_TASKS,
        default='lmms-lab/VQAv2',
    )
    parser.add_argument(
        '--model_type',
        type=str,
        default=None,
        choices=SUPPORTED_MODEL_TYPES.keys(),
    )
    parser.add_argument(
        '--accuracy_threshold',
        type=float,
        default=None,
        help='used to check the accuracy of test_trtllm. '
        'Should be between 0 and 100.',
    )
    parser.add_argument(
        '--dataset_dir',
        type=str,
        default=None,
        help="The local directory of the dataset for evaluation; "
        "will download the dataset from huggingface hub if not specified.",
    )
    parser.add_argument(
        '--dataset_cache_dir',
        type=str,
        default=None,
        help="The local cache directory for dataset; "
        "will use `~/.cache/huggingface/datasets` if not specified.",
    )

    parsed = parser.parse_args(args=args)
    return parsed
def load_dataset(args) -> datasets.Dataset:
    """Load the evaluation split for the selected task.

    Args:
        args: Parsed arguments; reads ``eval_task``, ``dataset_dir`` and
            ``dataset_cache_dir``.

    Returns:
        The object returned by ``datasets.load_dataset`` for the chosen split.
    """
    # Tasks whose name contains 'VQAv2' use the 'validation' split; every
    # other task uses 'test'.
    split_name = 'test'
    if 'VQAv2' in args.eval_task:
        split_name = 'validation'

    # A local directory, when given, takes precedence over the hub dataset id.
    source = args.dataset_dir or args.eval_task

    # Allow up to one hour in total for each HTTP transfer.
    http_timeout = aiohttp.ClientTimeout(total=3600)

    return datasets.load_dataset(
        source,
        cache_dir=args.dataset_cache_dir,
        split=split_name,
        storage_options={'client_kwargs': {
            'timeout': http_timeout
        }},
    )
def load_hf_model(args):
    """Load the Hugging Face reference model for ``args.model_type``.

    The class name is looked up in ``SUPPORTED_MODEL_TYPES`` and resolved on
    the ``transformers`` package. The model is loaded in float16 onto
    ``cuda:0`` with ``trust_remote_code=True``, and the load time is logged.

    Args:
        args: Parsed arguments; reads ``model_type`` and ``hf_model_dir``.

    Returns:
        The model instance returned by ``from_pretrained``.

    Raises:
        ValueError: If ``args.model_type`` has no Hugging Face class
            registered (mapped to ``None`` or not present at all).
    """
    # `.get` makes an unknown or unset model_type raise the descriptive
    # ValueError below rather than a bare KeyError.
    model_class_name = SUPPORTED_MODEL_TYPES.get(args.model_type)
    if model_class_name is None:
        raise ValueError(f"Unsupported HF model_type: {args.model_type}")

    timer_tag = 'load HF model'
    profiler.start(timer_tag)
    # Local import: the module top level only imports AutoProcessor from
    # transformers, so the package name itself is not bound there.
    import transformers
    model_class = getattr(transformers, model_class_name)
    hf_model = model_class.from_pretrained(args.hf_model_dir,
                                           torch_dtype=torch.float16,
                                           device_map="cuda:0",
                                           trust_remote_code=True)
    profiler.stop(timer_tag)
    logger.info(
        f'Load HF model takes: {profiler.elapsed_time_in_sec(timer_tag)} sec')
    return hf_model
def load_trtllm_model(args):
    """Construct the TensorRT-LLM multimodal runner and log the load time.

    Args:
        args: Parsed arguments, passed unchanged to ``MultimodalModelRunner``.

    Returns:
        The constructed ``MultimodalModelRunner``.
    """
    timer_tag = 'load TensorRT-LLM model'

    profiler.start(timer_tag)
    runner = MultimodalModelRunner(args)
    profiler.stop(timer_tag)

    elapsed = profiler.elapsed_time_in_sec(timer_tag)
    logger.info(f'Load TensorRT-LLM model takes: {elapsed} sec')
    return runner
def prepare_prompts(task, data, model_type, processor) -> str:
    """Build the text prompt for one dataset sample.

    The question is taken from ``data['question']`` and gets a trailing
    ``?`` when it does not already end with one. For the ``lmms-lab/ai2d``
    task the options are appended as `` (0) opt0 (1) opt1 ...``. The result
    is then wrapped in the prompt format of the given model type.

    Args:
        task: Evaluation task name (one of the ``--eval_task`` choices).
        data: Sample mapping; must contain ``'question'`` and, for the ai2d
            task, ``'options'``.
        model_type: Model family key selecting the prompt format.
        processor: Hugging Face processor. Only used for the model types
            that go through a chat template (llava*, qwen2_vl, phi-3-vision);
            may be ``None`` for the others.

    Returns:
        The formatted prompt string.

    Raises:
        ValueError: If ``model_type`` has no prompt format defined here.
    """
    question = data['question']
    # `endswith` also handles an empty question, where indexing the last
    # character would raise IndexError.
    if not question.endswith('?'):
        question += '?'

    if task == 'lmms-lab/ai2d':
        # Multiple choice: list the options with their zero-based index.
        for j, option in enumerate(data['options']):
            question += f" ({j}) {option}"

    if model_type in ['blip2', 'neva']:
        return f"Question: {question} Answer: "

    if model_type == 'fuyu':
        return f"Answer the following {task} question based on the image: {question}"

    if model_type == 'kosmos-2':
        return f"<grounding> Question: {question} Answer: "

    if model_type == 'cogvlm':
        return f"A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {question} ASSISTANT:"

    if model_type in ['llava', 'llava_next', 'llava_onevision', 'qwen2_vl']:
        # These families share the same single-turn conversation layout:
        # one image placeholder followed by the question text.
        conversation = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image"
                    },
                    {
                        "type": "text",
                        "text": question
                    },
                ],
            },
        ]
        return processor.apply_chat_template(conversation,
                                             add_generation_prompt=True)

    if model_type in ['vila', 'internvl']:
        return f"<image>\n{question}"

    if model_type == 'phi-3-vision':
        messages = [
            {
                "role": "user",
                "content": f"<|image_1|>\n{question}"
            },
        ]
        return processor.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True)

    if model_type == 'mllama':
        # TODO: use apply_chat_template when llama 3.2 model is updated
        return f"<|image|><|begin_of_text|> {question}; answer: "

    raise ValueError(f"Unsupported model_type: {model_type}")
def eval(output, task, data) -> bool:
    """Decide whether a model output counts as correct for one sample.

    The output is stripped and lower-cased first. Then:

    * ``lmms-lab/VQAv2``: correct if any reference answer in
      ``data['answers']`` occurs as a substring of the output.
    * ``lmms-lab/ai2d``: correct if the first character of the output
      equals ``data['answer']``.
    * any other task: correct if ``data['answer']`` (lower-cased) occurs as
      a substring of the output.

    NOTE: the name shadows the builtin ``eval``; it is kept because callers
    in this file use it.

    Args:
        output: Raw text generated by the model.
        task: Evaluation task name.
        data: Sample mapping holding the reference answer(s).

    Returns:
        ``True`` if the output matches the reference, else ``False``.
    """
    output = output.strip().lower()
    if task == 'lmms-lab/VQAv2':
        # Lower-case the references too, so the comparison is
        # case-insensitive on both sides, as in the fallback branch.
        return any(answer['answer'].lower() in output
                   for answer in data['answers'])
    elif task == 'lmms-lab/ai2d':
        # Slicing instead of indexing: an empty output is simply wrong
        # rather than an IndexError.
        return data['answer'] == output[:1]
    else:
        return data['answer'].lower() in output
# Script body: load the dataset and the requested backend(s), run up to
# --max_ite samples through each, and report the accuracy on rank 0.

# Set before any tokenizer is created.
# NOTE(review): presumably silences the HF tokenizers parallelism warning --
# confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

args = parse_arguments()
if args.model_type not in SUPPORTED_MODEL_TYPES:
    raise ValueError(f"Unsupported model_type: {args.model_type}")

logger.set_level(args.log_level)
runtime_rank = tensorrt_llm.mpi_rank()

dataset = load_dataset(args)
hf_model = load_hf_model(args) if args.test_hf else None
trtllm_model = load_trtllm_model(args) if args.test_trtllm else None

# Model types without a registered HF class get no processor;
# prepare_prompts does not use one for those types.
if SUPPORTED_MODEL_TYPES[args.model_type] is None:
    hf_processor = None
else:
    hf_processor = AutoProcessor.from_pretrained(args.hf_model_dir,
                                                 trust_remote_code=True)

hf_correct = trtllm_correct = 0

if args.test_trtllm or args.test_hf:
    # Never index past the end of the dataset, and use the number of samples
    # actually evaluated as the accuracy denominator.
    num_iterations = min(args.max_ite, len(dataset))
    if num_iterations <= 0:
        raise ValueError(
            "No samples to evaluate: --max_ite must be positive and the "
            "dataset must not be empty")

    for i in range(num_iterations):
        logger.debug(f"Ite: {i:3d}")
        data = dataset[i]
        prompts = prepare_prompts(args.eval_task, data, args.model_type,
                                  hf_processor)
        image = data['image']

        if args.test_hf:
            assert hf_model is not None, f"Unsupported HF model_type: {args.model_type}"
            profiler.start('hf')
            inputs = hf_processor(
                images=image,
                text=prompts,
                return_tensors="pt",
            ).to(hf_model.device,
                 torch.float16)  # add torch.float16 for llava-onevision
            input_length = inputs.input_ids.shape[-1]
            hf_output = hf_model.generate(**inputs,
                                          max_new_tokens=args.max_new_tokens)
            if args.model_type in ['blip2']:
                # blip2: decode the full generated sequence.
                hf_result = hf_processor.batch_decode(
                    hf_output, skip_special_tokens=True)[0]
            else:
                # Other models: drop the prompt tokens before decoding.
                hf_result = hf_processor.decode(hf_output[0][input_length:],
                                                skip_special_tokens=True)
            hf_correct += eval(hf_result, args.eval_task, data)
            profiler.stop('hf')

        if args.test_trtllm:
            profiler.start('tensorrt_llm')
            _, output_text = trtllm_model.run(
                prompts, image, max_new_tokens=args.max_new_tokens)
            # Only rank 0 scores the output.
            if runtime_rank == 0:
                trtllm_result = output_text[0][0]
                trtllm_correct += eval(trtllm_result, args.eval_task, data)
            profiler.stop('tensorrt_llm')

        if runtime_rank == 0:
            # VQAv2 samples carry a list of answers under 'answers'; the
            # other tasks carry a single 'answer'.
            if args.eval_task == 'lmms-lab/VQAv2':
                answer = data['answers']
            else:
                answer = data['answer']
            logger.debug(f"prompts: {prompts}")
            logger.debug(f"reference answer: {answer}")
            if args.test_hf:
                logger.debug(f"HF's answer: {hf_result}")
            if args.test_trtllm:
                logger.debug(f"TRT-LLM's answer: {trtllm_result}")

    if runtime_rank == 0:
        logger.info(f"total iterations: {num_iterations}")
        if args.test_hf:
            logger.info(
                f"HF's accuracy: {100 * hf_correct / num_iterations:4.2f}%")
        if args.test_trtllm:
            logger.info(
                f"TRT-LLM's accuracy: {100 * trtllm_correct / num_iterations:4.2f}%"
            )
        # check if the accuracy is above the threshold
        # Raised explicitly (same exception type as the former assert) so the
        # gate is still enforced when Python runs with -O.
        if args.accuracy_threshold is not None and args.test_trtllm:
            if trtllm_correct / num_iterations < args.accuracy_threshold / 100:
                raise AssertionError(
                    f"TRT-LLM's accuracy is below the threshold: {args.accuracy_threshold}%"
                )
else:
    logger.info("Neither enable test_trtllm nor enable test_hf")