forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun.py
875 lines (778 loc) · 41 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
import time
from pathlib import Path
# isort: off
import torch
import tensorrt as trt
# isort: on
import numpy as np
from transformers import (AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer,
BartForConditionalGeneration,
MBartForConditionalGeneration,
T5ForConditionalGeneration)
import tensorrt_llm
from tensorrt_llm import logger
from tensorrt_llm._utils import torch_to_numpy, trt_dtype_to_torch
from tensorrt_llm.lora_manager import LoraManager
from tensorrt_llm.runtime import ModelConfig, SamplingConfig
def get_engine_name(rank):
    """Return the serialized-engine file name used for ``rank``.

    Example: rank 0 maps to ``'rank0.engine'``.
    """
    return f'rank{rank}.engine'
def print_tensor(tensor_name, tensor, num_elements=10):
    """Print summary statistics and the leading elements of ``tensor``.

    Args:
        tensor_name: Label printed in front of the statistics line.
        tensor: Torch tensor to summarize. Any non-floating, non-complex
            tensor (ints of every width, bool) is converted to float64 first,
            because ``mean()`` needs a floating/complex dtype and ``abs()`` is
            not implemented for bool.
        num_elements: Number of leading elements of the flattened tensor to
            print; a negative value (e.g. -1) prints the whole tensor.
    """
    if not (tensor.is_floating_point() or tensor.is_complex()):
        tensor = tensor.to(dtype=float)
    magnitude = tensor.abs()
    # max() has no identity element, so it raises on an empty tensor; report
    # NaN instead so debug dumps of empty outputs do not abort the run.
    max_abs = magnitude.max().item() if magnitude.numel() > 0 else float(
        'nan')
    print(
        f'{tensor_name}: mean={magnitude.mean().item():.3f}, sum={magnitude.sum().item():.3f}, max={max_abs:.3f}'
    )
    # Pass num_elements=-1 will print the whole tensor
    if num_elements < 0:
        num_elements = torch.numel(tensor)
    print(f'{tensor.flatten()[:num_elements]}')
    print("Tensor Shape: ", tensor.size())
    print("")
def read_config(config_path: Path):
    """Read an engine ``config.json`` and build this rank's ``ModelConfig``.

    Head counts and hidden size stored in the returned ``ModelConfig`` are
    per tensor-parallel rank (divided by ``tp_size``); the KV-head count is
    rounded up so that each rank gets at least one KV head.

    Args:
        config_path: Path to the engine's ``config.json``.

    Returns:
        ``(model_config, tp_size, pp_size, gpus_per_node, dtype)``.

    Raises:
        AssertionError: If the engine's TP*PP world size differs from the MPI
            world size, or ``num_attention_heads`` is not divisible by
            ``tp_size``.
        KeyError: If a required config entry is missing.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        config = json.load(f)

    builder_config = config['build_config']
    plugin_config = builder_config['plugin_config']
    pretrained_config = config['pretrained_config']
    lora_config = builder_config['lora_config']
    auto_parallel_config = builder_config['auto_parallel_config']
    use_gpt_attention_plugin = plugin_config["gpt_attention_plugin"]
    remove_input_padding = plugin_config["remove_input_padding"]
    use_lora_plugin = plugin_config["lora_plugin"]
    tp_size = pretrained_config['mapping']['tp_size']
    pp_size = pretrained_config['mapping']['pp_size']
    gpus_per_node = auto_parallel_config['gpus_per_node']
    world_size = tp_size * pp_size
    assert world_size == tensorrt_llm.mpi_world_size(), \
        f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'
    num_heads = pretrained_config["num_attention_heads"]
    hidden_size = pretrained_config["hidden_size"]
    head_size = pretrained_config["head_size"]
    vocab_size = pretrained_config["vocab_size"]
    max_batch_size = builder_config["max_batch_size"]
    max_beam_width = builder_config["max_beam_width"]
    num_layers = pretrained_config["num_hidden_layers"]
    num_kv_heads = pretrained_config.get('num_kv_heads', num_heads)

    # Convert global sizes to per-TP-rank sizes.
    assert (num_heads % tp_size) == 0
    num_heads = num_heads // tp_size
    hidden_size = hidden_size // tp_size
    # Ceil-divide so a model with fewer KV heads than TP ranks still gets one.
    num_kv_heads = (num_kv_heads + tp_size - 1) // tp_size

    cross_attention = pretrained_config["architecture"] == "DecoderModel"
    skip_cross_qkv = pretrained_config.get('skip_cross_qkv', False)
    has_position_embedding = pretrained_config["has_position_embedding"]
    # `pretrained_config` is a plain dict parsed from JSON, so the optional
    # entry must be looked up as a key: `hasattr(<dict>, "type_vocab_size")`
    # is always False and would silently disable token-type embeddings.
    has_token_type_embedding = pretrained_config.get(
        "type_vocab_size") is not None
    use_custom_all_reduce = plugin_config.get('use_custom_all_reduce', False)
    dtype = pretrained_config["dtype"]

    paged_kv_cache = plugin_config['paged_kv_cache']
    tokens_per_block = plugin_config['tokens_per_block']

    gather_context_logits = builder_config.get('gather_context_logits', False)
    gather_generation_logits = builder_config.get('gather_generation_logits',
                                                  False)
    max_prompt_embedding_table_size = builder_config.get(
        'max_prompt_embedding_table_size', 0)

    model_config = ModelConfig(
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        hidden_size=hidden_size,
        head_size=head_size,
        max_batch_size=max_batch_size,
        max_beam_width=max_beam_width,
        vocab_size=vocab_size,
        num_layers=num_layers,
        gpt_attention_plugin=use_gpt_attention_plugin,
        remove_input_padding=remove_input_padding,
        paged_kv_cache=paged_kv_cache,
        tokens_per_block=tokens_per_block,
        cross_attention=cross_attention,
        has_position_embedding=has_position_embedding,
        has_token_type_embedding=has_token_type_embedding,
        use_custom_all_reduce=use_custom_all_reduce,
        dtype=dtype,
        gather_context_logits=gather_context_logits,
        gather_generation_logits=gather_generation_logits,
        max_prompt_embedding_table_size=max_prompt_embedding_table_size,
        lora_plugin=use_lora_plugin,
        lora_target_modules=lora_config.get('lora_target_modules'),
        trtllm_modules_to_hf_modules=lora_config.get(
            'trtllm_modules_to_hf_modules'),
        skip_cross_qkv=skip_cross_qkv,
    )

    return model_config, tp_size, pp_size, gpus_per_node, dtype
def parse_arguments(argv=None):
    """Parse the command-line options of this example script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. ``None`` (the default) keeps the usual
            command-line behavior; passing a list allows programmatic use.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_new_tokens", type=int, default=64)
    parser.add_argument("--log_level", type=str, default="error")
    parser.add_argument("--engine_dir", "-i", type=str, default="trt_engines")
    parser.add_argument("--engine_name", type=str, default="enc_dec")
    parser.add_argument("--model_name",
                        type=str,
                        help="HuggingFace model name or FairSeq model path",
                        default="t5-small")
    parser.add_argument("--num_beams",
                        type=int,
                        help="Use beam search if num_beams >1",
                        default=1)
    parser.add_argument("--debug_mode",
                        help="Whether or not to turn on the debug mode",
                        action='store_true')
    parser.add_argument("--compare_hf_fp32",
                        help="Compare results with HuggingFace FP32",
                        action='store_true')
    parser.add_argument('--lora_dir',
                        type=str,
                        default=None,
                        nargs="+",
                        help="One or more LoRA model directories to load")
    parser.add_argument('--lora_task_uids',
                        type=str,
                        default=None,
                        nargs="+",
                        help="LoRA task UIDs to apply during generation")
    parser.add_argument(
        "--output_encoder_npy",
        help=
        "Store tensors like encoder outputs used for testing enc-dec C++ runtime.",
        action="store_true")
    return parser.parse_args(argv)
class TRTLLMEncDecModel:
    """Encoder-decoder runtime built on TensorRT-LLM engines.

    The encoder engine is executed through ``tensorrt_llm.runtime.Session``
    (one forward pass per request) and the decoder engine through
    ``tensorrt_llm.runtime.GenerationSession`` (autoregressive decoding).
    Engines and their ``config.json`` are read from ``<engine_dir>/encoder``
    and ``<engine_dir>/decoder``. The encoder can be skipped
    (``skip_encoder=True``), in which case ``generate`` uses
    ``prompt_embedding_table`` as the encoder output.
    """

    def __init__(
        self,
        engine_name,
        engine_dir,
        lora_dir=None,
        lora_task_uids=None,
        debug_mode=False,
        skip_encoder=False,
        stream: torch.cuda.Stream = None,
    ):
        """Bind this process to its GPU, load the engines and set up sessions.

        Args:
            engine_name: Kept for interface compatibility; it is not used to
                locate engine files (those are named by ``get_engine_name``).
            engine_dir: Directory containing ``encoder/`` and ``decoder/``
                sub-directories, each with a ``config.json`` and per-rank
                ``rank<N>.engine`` files.
            lora_dir: LoRA model directories forwarded to
                ``LoraManager.load_from_hf`` for engines built with the LoRA
                plugin.
            lora_task_uids: LoRA task UIDs used when building LoRA inputs.
            debug_mode: Forwarded to the decoder ``GenerationSession``.
            skip_encoder: If True, no encoder engine/session is created.
            stream: CUDA stream to run on; a new one is created when None.
        """
        # in multi-node setup, it's important to set_device at the very beginning so .to('cuda') refers to current device
        # accordingly, all input & output tensors should be moved to current device
        # otherwise, it's default to 'cuda:0'
        self.runtime_rank = tensorrt_llm.mpi_rank()
        device_id = self.runtime_rank % torch.cuda.device_count()
        torch.cuda.set_device(device_id)
        self.device = torch.cuda.current_device()
        self.skip_encoder = skip_encoder
        self.lora_task_uids = lora_task_uids

        # when enc-dec runs by itself, stream can be None and we create new stream here
        # when enc-dec has to run as a component in a bigger workflow (e.g., multimodal), earlier components in the workflow may have results in its stream, which we should pass that stream in to avoid unnecessary stream sync
        self.stream = stream
        if self.stream is None:
            self.stream = torch.cuda.Stream(self.device)
        torch.cuda.set_stream(self.stream)

        engine_dir = Path(engine_dir)

        def engine_setup(component):
            # Load <engine_dir>/<component>/config.json, build this rank's
            # mapping and read the serialized engine bytes for this rank.
            # model config
            config_path = engine_dir / component / "config.json"
            logger.info(f"Using config path {config_path}")
            model_config, tp_size, pp_size, gpus_per_node, dtype = read_config(
                config_path)

            # MGMN config
            world_size = tp_size * pp_size
            runtime_rank = tensorrt_llm.mpi_rank()
            assert runtime_rank < world_size, "Runtime GPU rank exceeds MPI world size. Did you launch more MPI processes than required?"
            runtime_mapping = tensorrt_llm.Mapping(world_size,
                                                   runtime_rank,
                                                   tp_size=tp_size,
                                                   pp_size=pp_size,
                                                   gpus_per_node=gpus_per_node)

            # load engine
            engine_fname = get_engine_name(runtime_rank)
            with open(engine_dir / component / engine_fname, "rb") as f:
                engine_buffer = f.read()

            return model_config, runtime_mapping, engine_buffer

        # Note: encoder and decoder doesn't necessarily have the same TP & PP config

        if not skip_encoder:
            self.encoder_model_config, self.encoder_runtime_mapping, encoder_engine_buffer = engine_setup(
                component='encoder')

            # for Pipeline Parallelism in encoder
            self.nccl_comm = torch.classes.trtllm.NcclCommunicatorOp(
                self.encoder_runtime_mapping.tp_size,
                self.encoder_runtime_mapping.pp_size,
                self.encoder_runtime_mapping.rank)

            # session setup
            self.encoder_session = tensorrt_llm.runtime.Session.from_serialized_engine(
                encoder_engine_buffer)

            # encoder lora manager setup
            if self.encoder_model_config.lora_plugin:
                self.encoder_lora_manager = LoraManager()
                # TODO: this is only for bart
                self.encoder_lora_manager.load_from_hf(
                    model_dirs=lora_dir,
                    model_config=self.encoder_model_config,
                    runtime_mapping=self.encoder_runtime_mapping,
                    component='encoder',
                )
            else:
                self.encoder_lora_manager = None
        else:
            self.encoder_model_config, self.encoder_runtime_mapping, encoder_engine_buffer = None, None, None
            self.nccl_comm, self.encoder_session = None, None
            # Keep the attribute defined in both branches so callers can test
            # it uniformly.
            self.encoder_lora_manager = None

        self.decoder_model_config, self.decoder_runtime_mapping, decoder_engine_buffer = engine_setup(
            component='decoder')

        self.decoder_session = tensorrt_llm.runtime.GenerationSession(
            self.decoder_model_config,
            decoder_engine_buffer,
            self.decoder_runtime_mapping,
            debug_mode=debug_mode)

        # decoder lora manager setup
        if self.decoder_model_config.lora_plugin:
            self.decoder_lora_manager = LoraManager()
            # TODO: this is only for bart
            self.decoder_lora_manager.load_from_hf(
                model_dirs=lora_dir,
                model_config=self.decoder_model_config,
                runtime_mapping=self.decoder_runtime_mapping,
                component='decoder',
            )
        else:
            self.decoder_lora_manager = None

    @classmethod
    def from_engine(cls,
                    engine_name,
                    engine_dir,
                    lora_dir=None,
                    lora_task_uids=None,
                    debug_mode=False,
                    skip_encoder=False,
                    stream=None):
        """Alternate constructor; forwards all arguments to ``__init__``."""
        return cls(engine_name,
                   engine_dir,
                   lora_dir,
                   lora_task_uids,
                   debug_mode=debug_mode,
                   skip_encoder=skip_encoder,
                   stream=stream)

    def process_input(self,
                      input_ids,
                      remove_input_padding=False,
                      pad_token_id=0,
                      prompt_tasks=None):
        """Compute per-row lengths and optionally strip padding.

        The first token of every row is always kept and counted, even when it
        equals ``pad_token_id`` (e.g. a decoder start id that coincides with
        the pad id). After it, tokens equal to ``pad_token_id`` are dropped
        (remove-padding mode) and excluded from the length.

        Args:
            input_ids: 2-D ``[batch_size, padded_len]`` token-id tensor.
            remove_input_padding: If True, return one flattened 1-D tensor
                ``[num_tokens]`` of the kept tokens; otherwise return
                ``input_ids`` unchanged.
            pad_token_id: Id treated as padding.
            prompt_tasks: Optional tensor; in remove-padding mode it is
                truncated to the number of kept tokens.

        Returns:
            ``(input_ids, input_lengths, max_input_length, prompt_tasks)``,
            where ``input_lengths`` is an int32 tensor on ``self.device`` and
            ``max_input_length`` is a Python int.
        """
        if remove_input_padding:
            # in remove padding mode --> flatten input, calculate actual length and max length
            # Note: 1st token should never be removed, even if it is pad_token_id
            first_ids = input_ids[:, 0]
            input_ids = input_ids[:, 1:]
            input_lengths = 1 + (input_ids != pad_token_id).sum(dim=1).to(
                dtype=torch.int32, device=self.device)  # [batch_size]
            new_ids = []
            for first_id, row in zip(first_ids, input_ids):
                # Reshape the scalar in place instead of building a CPU tensor
                # from a (possibly GPU) element for every row.
                new_ids.append(
                    torch.cat((first_id.reshape(1).to(dtype=torch.int32,
                                                      device=self.device),
                               row[row != pad_token_id])))
            input_ids = torch.cat(new_ids)  # [num_tokens]
            if prompt_tasks is not None:
                prompt_tasks = prompt_tasks[:input_ids.shape[0]]
        else:
            # in padding mode --> keep input, just calculate actual length and max length
            # Note: 1st token should always count, even if it is pad_token_id. e.g., decoder start id in enc-dec models could be a single pad_token_id, we should count
            # (`.to()` instead of torch.tensor(<tensor>), which warns about
            # copy-constructing from a tensor.)
            input_lengths = (1 + (input_ids[:, 1:] != pad_token_id).sum(
                dim=1)).to(dtype=torch.int32, device=self.device)
        max_input_length = torch.max(input_lengths).item()
        return input_ids, input_lengths, max_input_length, prompt_tasks

    def encoder_run(self,
                    input_ids,
                    input_lengths,
                    max_input_length,
                    position_ids=None,
                    token_type_ids=None,
                    debug_mode=False,
                    prompt_embedding_table=None,
                    prompt_tasks=None,
                    prompt_vocab_size=None,
                    attention_mask=None):
        """Run the encoder engine and return the encoder hidden states.

        With pipeline parallelism, the last PP rank sends its
        ``encoder_output`` to every other rank in its PP group, so every rank
        returns the final encoder output. The CUDA stream is synchronized
        before returning.
        """
        # each engine has hidden_dim/TP, don't forget to multiply TP
        hidden_size = self.encoder_model_config.hidden_size * self.encoder_runtime_mapping.tp_size
        if input_ids.dim() == 1:
            hidden_states_shape = (input_ids.shape[0], hidden_size
                                   )  # [num_tokens,D]
        else:
            hidden_states_shape = (input_ids.shape[0], input_ids.shape[1],
                                   hidden_size)  # [BS,seqlen,D]

        def hidden_states_dtype(name):
            # torch dtype the encoder engine declares for I/O tensor `name`
            return trt_dtype_to_torch(
                self.encoder_session.engine.get_tensor_dtype(name))

        # input tensors. only first PP rank has id input, others are hidden_states input
        inputs = {}
        if self.encoder_runtime_mapping.is_first_pp_rank():
            inputs['input_ids'] = input_ids.contiguous()
            if self.encoder_model_config.has_position_embedding:
                if position_ids is None:
                    if self.encoder_model_config.remove_input_padding:
                        position_ids = [
                            torch.arange(sample_length,
                                         dtype=torch.int32,
                                         device=input_ids.device)
                            for sample_length in torch_to_numpy(input_lengths)
                        ]
                        position_ids = torch.cat(position_ids)
                    else:
                        bsz, seq_len = input_ids.shape[:2]
                        position_ids = torch.arange(
                            seq_len, dtype=torch.int32,
                            device=input_ids.device).expand(bsz, -1)
                inputs['position_ids'] = position_ids.contiguous()
            if self.encoder_model_config.has_token_type_embedding:
                inputs['token_type_ids'] = token_type_ids.contiguous()

            if self.encoder_model_config.max_prompt_embedding_table_size > 0:
                inputs[
                    'prompt_embedding_table'] = prompt_embedding_table.contiguous(
                    )
                inputs['tasks'] = prompt_tasks.contiguous()
                inputs['prompt_vocab_size'] = prompt_vocab_size.contiguous()
        else:
            # just need a placeholder, engine will call NCCL to recv and fill data from previous rank
            inputs['hidden_states_input'] = torch.empty(
                hidden_states_shape,
                dtype=hidden_states_dtype('hidden_states_input'),
                device=self.device).contiguous()
        if attention_mask is not None and not self.encoder_model_config.gpt_attention_plugin:
            inputs['attention_mask'] = attention_mask.contiguous()

        inputs['input_lengths'] = input_lengths
        # use shape info to pass max length info in remove padding mode
        inputs['max_input_length'] = torch.empty(
            (max_input_length, ),
            dtype=hidden_states_dtype('max_input_length'),
            device=self.device).contiguous()
        batch_size = input_lengths.size(0)
        inputs['host_request_types'] = torch.IntTensor([0] *
                                                       batch_size).to('cpu')
        if self.encoder_model_config.remove_input_padding:
            inputs['host_context_lengths'] = input_lengths.to('cpu')
        if self.encoder_model_config.lora_plugin and self.encoder_lora_manager is not None:
            inputs.update(
                self.encoder_lora_manager.input_buffers(
                    self.lora_task_uids,
                    self.encoder_runtime_mapping,
                    self.encoder_model_config.num_layers,
                ))

        # Note: runtime.Session's run() method will set input/output tensor address, here we only need to provide tensor shape
        self.encoder_session.set_shapes(inputs)

        # output tensors. only last PP rank final encoder output, others are intermediate hidden_states output. Need broadcast later
        outputs = {}
        if self.encoder_runtime_mapping.is_last_pp_rank():
            outputs['encoder_output'] = torch.empty(
                hidden_states_shape,
                dtype=hidden_states_dtype('encoder_output'),
                device=self.device).contiguous()
        else:
            outputs['hidden_states_output'] = torch.empty(
                hidden_states_shape,
                dtype=hidden_states_dtype('hidden_states_output'),
                device=self.device).contiguous()

        # -------------------------------------------
        if debug_mode:
            engine = self.encoder_session.engine
            context = self.encoder_session.context
            # setup debugging buffer for the encoder
            for i in range(self.encoder_session.engine.num_io_tensors):
                name = engine.get_tensor_name(i)
                if engine.get_tensor_mode(
                        name
                ) == trt.TensorIOMode.OUTPUT and name not in outputs.keys():
                    dtype = engine.get_tensor_dtype(name)
                    shape = context.get_tensor_shape(name)
                    outputs[name] = torch.zeros(tuple(shape),
                                                dtype=trt_dtype_to_torch(dtype),
                                                device=self.device)
                    context.set_tensor_address(name, outputs[name].data_ptr())
        # -------------------------------------------

        # TRT session run
        # Note: need cuda stream ID, not a torch Stream
        ok = self.encoder_session.run(inputs, outputs, self.stream.cuda_stream)
        assert ok, "Runtime execution failed"
        self.stream.synchronize()

        # Tensor Parallelism is handled by model/engine definition
        # But we need to broadcast among PP group at the end of encoder's Pipeline Parallelism
        # After this, all ranks should recv the encoder output, and world might be re-configured using decoder's TP-PP config
        def pp_communicate_encoder_output(encoder_output):
            if self.encoder_runtime_mapping.is_last_pp_rank():
                for pp_rank in self.encoder_runtime_mapping.pp_group:
                    if pp_rank != self.encoder_runtime_mapping.rank:
                        self.nccl_comm.send(encoder_output, pp_rank)
                return encoder_output
            else:
                self.nccl_comm.recv(encoder_output,
                                    self.encoder_runtime_mapping.pp_group[-1])
                return encoder_output

        if self.encoder_runtime_mapping.has_pp():
            # use hidden_states output buffer to receive output as the shapes are same
            encoder_output_buf = outputs[
                'encoder_output'] if self.encoder_runtime_mapping.is_last_pp_rank(
                ) else outputs['hidden_states_output']
            encoder_output = pp_communicate_encoder_output(encoder_output_buf)
        else:
            encoder_output = outputs['encoder_output']

        # -------------------------------------------
        if debug_mode and self.encoder_runtime_mapping.tp_rank == 0:  # only tp_rank 0 print encoder output
            torch.cuda.synchronize()
            # use print_tensor() to print the tensors registered in the encoder network
            print("--------------------------------------")
            print("Debug output for Encoder")
            print("--------------------------------------")
            print("Registered output tensors are: ", outputs.keys())
            for k, v in outputs.items():
                print_tensor(k, v, num_elements=30)
            print_tensor('encoder_output', encoder_output)
            print("--------------------------------------")
        # -------------------------------------------

        return encoder_output

    def generate(self,
                 encoder_input_ids,
                 decoder_input_ids,
                 max_new_tokens,
                 num_beams=1,
                 pad_token_id=None,
                 eos_token_id=None,
                 bos_token_id=None,
                 debug_mode=False,
                 return_dict=False,
                 prompt_embedding_table=None,
                 prompt_tasks=None,
                 prompt_vocab_size=None,
                 attention_mask=None,
                 time_encoder=False,
                 return_encoder_output=False):
        """Encode ``encoder_input_ids`` and decode starting from ``decoder_input_ids``.

        Args:
            encoder_input_ids: Padded 2-D encoder token ids.
            decoder_input_ids: Padded 2-D decoder prompt ids (e.g. the decoder
                start token for each row).
            max_new_tokens: Maximum number of tokens to generate.
            num_beams: Beam width for the sampling config and decoder session.
            pad_token_id: Padding id (also used to compute input lengths).
            eos_token_id: End id for the sampling config.
            bos_token_id: Accepted for interface compatibility; unused here.
            debug_mode: Print encoder debug tensors.
            return_dict: Forwarded to ``SamplingConfig`` and ``decode``; also
                enables cumulative and per-token log-prob outputs.
            prompt_embedding_table: Prompt-tuning table for the encoder; used
                directly as the encoder output when ``skip_encoder`` is set.
            prompt_tasks: Prompt-tuning task ids for the encoder.
            prompt_vocab_size: Prompt-tuning vocab size for the encoder.
            attention_mask: Optional 2-D ``[batch, encoder_len]`` mask
                (tensor or nested list); it is also reshaped into the decoder
                cross-attention mask.
            time_encoder: Print the wall-clock encoder run time.
            return_encoder_output: Also return the encoder output.

        Returns:
            The decoder output, or ``(output, encoder_output)`` when
            ``return_encoder_output`` is True.
        """
        ## ensure all externally provided tensors are on the correct device.
        encoder_input_ids = encoder_input_ids.to(self.device)
        decoder_input_ids = decoder_input_ids.to(self.device)

        if attention_mask is not None:
            # Fresh int32 copy on this device. torch.as_tensor accepts tensors
            # and nested lists alike, without the copy-construct warning that
            # torch.tensor(<tensor>) emits.
            attention_mask = torch.as_tensor(attention_mask).to(
                dtype=torch.int32, device=self.device, copy=True)

        ## encoder run
        encoder_remove_input_padding = self.encoder_model_config.remove_input_padding if self.encoder_model_config else self.decoder_model_config.remove_input_padding

        encoder_input_ids, encoder_input_lengths, encoder_max_input_length, prompt_tasks = self.process_input(
            encoder_input_ids, encoder_remove_input_padding, pad_token_id,
            prompt_tasks)

        if not self.skip_encoder:
            logger.info(f"Rank {self.runtime_rank} Running encoder engine ...")
            if time_encoder:
                tik = time.time()
            encoder_output = self.encoder_run(
                encoder_input_ids,
                encoder_input_lengths,
                encoder_max_input_length,
                debug_mode=debug_mode,
                prompt_embedding_table=prompt_embedding_table,
                prompt_tasks=prompt_tasks,
                prompt_vocab_size=prompt_vocab_size,
                attention_mask=attention_mask)
            if time_encoder:
                # encoder_run synchronizes its stream before returning
                tok = time.time()
                print(f"TRT-LLM Encoder time {(tok-tik)*1000}ms")
        else:
            encoder_output = prompt_embedding_table
            if encoder_input_ids.dim() > 1:
                encoder_output = encoder_output.unsqueeze(0)

        ## decoder run
        logger.info(f"Rank {self.runtime_rank} Running decoder engine ...")
        decoder_input_ids, decoder_input_lengths, decoder_max_input_length, _ = self.process_input(
            decoder_input_ids, self.decoder_model_config.remove_input_padding,
            pad_token_id)

        # `cross_attention_mask` in context phase [batch_size, query_len, encoder_input_len]
        # where query_len happens to be 1 in current cases, but not necessarily always, and
        # `cross_attention_mask` in generation phase [batch_size, 1, encoder_input_len] where
        # the query_len is always 1 since we have kv cache.
        cross_attention_mask = None
        if attention_mask is not None:
            # separate copy, as before, so the decoder never aliases the
            # encoder's mask
            cross_attention_mask = attention_mask.clone().reshape(
                attention_mask.shape[0], 1, attention_mask.shape[1])

        # generation config
        sampling_config = SamplingConfig(end_id=eos_token_id,
                                         pad_id=pad_token_id,
                                         num_beams=num_beams,
                                         min_length=1,
                                         return_dict=return_dict)
        sampling_config.update(output_cum_log_probs=return_dict,
                               output_log_probs=return_dict)

        # decoder autoregressive generation
        self.decoder_session.setup(
            decoder_input_lengths.size(0),
            decoder_max_input_length,
            max_new_tokens,
            num_beams,
            max_attention_window_size=None,
            encoder_max_input_length=encoder_max_input_length,
            lora_manager=self.decoder_lora_manager,
            lora_uids=self.lora_task_uids,
        )

        output = self.decoder_session.decode(
            decoder_input_ids,
            decoder_input_lengths,
            sampling_config,
            encoder_output=encoder_output,
            encoder_input_lengths=encoder_input_lengths,
            return_dict=return_dict,
            cross_attention_mask=cross_attention_mask)

        if return_encoder_output:
            return output, encoder_output
        return output
def test_fairseq_models(args):
    """Run the FairSeq WMT NMT engine and check it against fixed reference ids.

    NMT is the only FairSeq model. Adding a FairSeq dependency is too heavy
    for the CI workflow, so fixed input/output ids are used for the
    correctness check. The FairSeq reference code is kept below for
    documentation; users can follow Encoder-Decoder's README to install
    FairSeq and run it locally:

        from fairseq.models.transformer import TransformerModel
        fairseq_model = TransformerModel.from_pretrained(model_name_or_path=args.model_name, data_name_or_path=args.model_name, bpe='subword_nmt', tokenizer='moses').cuda()
        input_text = "Good Morning! How are you doing today?"
        input_ids = fairseq_model.encode(input_text)
        tik = time.time()
        # Note: FairSeq sampling=True results are not deterministic, disable during accuracy check
        fairseq_output_ids = fairseq_model.generate(input_ids, beam=1, sampling=False)
        tok = time.time()
        fairseq_output_ids = fairseq_output_ids[0]['tokens']
        fairseq_output_text = fairseq_model.decode(fairseq_output_ids)
        print("--------------------------------------")
        print("input text: ", input_text)
        print("input ids: ", input_ids) # [9938, 5384, 9328, 812, 3619, 53, 181, 3829, 1735, 171, 2]
        print("fairseq_output ids: ", fairseq_output_ids) # [9804, 391, 4, 4625, 167, 25, 1003, 5123, 17, 167, 1466, 1234, 171, 2]
        print("fairseq_output text: ", fairseq_output_text) # "Bonjour, Comment vous en tirez-vous aujourd'hui ?"
        print(f"FairSeq E2E time {(tok-tik)*1000}ms")
        print("--------------------------------------")

    Args:
        args: Parsed command-line namespace (uses ``max_new_tokens``,
            ``engine_name``, ``engine_dir``, ``num_beams``, ``debug_mode``).

    Raises:
        AssertionError: On MPI rank 0, if the TRT-LLM output ids (with EOS
            ids removed) differ from the FairSeq reference ids.
    """
    max_new_tokens = args.max_new_tokens
    bos_token_id = 2
    pad_token_id = 0
    eos_token_id = 2
    decoder_start_token_id = bos_token_id

    # Reference ids produced by the FairSeq snippet in the docstring.
    input_ids = torch.tensor(
        [9938, 5384, 9328, 812, 3619, 53, 181, 3829, 1735, 171, 2])
    fairseq_output_ids = torch.tensor(
        [9804, 391, 4, 4625, 167, 25, 1003, 5123, 17, 167, 1466, 1234, 171, 2])

    input_ids = torch.tensor([input_ids.tolist()]).type(torch.IntTensor).cuda()
    decoder_input_ids = torch.IntTensor([[decoder_start_token_id]]).cuda()
    decoder_input_ids = decoder_input_ids.repeat((input_ids.shape[0], 1))

    tllm_model = TRTLLMEncDecModel.from_engine(args.engine_name,
                                               args.engine_dir,
                                               debug_mode=args.debug_mode)

    inference_dtype = tllm_model.encoder_model_config.dtype

    return_dict = False  # when set return_dict=True, get outputs by key
    tik = time.time()
    tllm_output = tllm_model.generate(
        encoder_input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        max_new_tokens=max_new_tokens,
        num_beams=args.num_beams,
        bos_token_id=bos_token_id,
        pad_token_id=pad_token_id,
        eos_token_id=eos_token_id,
        debug_mode=args.debug_mode,
        return_dict=return_dict,
    )
    # Wait for queued GPU work before stopping the clock so the reported E2E
    # time includes it (same order as the HF timing path).
    torch.cuda.synchronize()
    tok = time.time()

    if return_dict:
        tllm_output_ids = tllm_output['output_ids']
    else:
        tllm_output_ids = tllm_output

    if tensorrt_llm.mpi_rank() == 0:
        output_ids = tllm_output_ids[:, 0, :]
        output_ids = output_ids[output_ids != eos_token_id]
        fairseq_output_ids = fairseq_output_ids[
            fairseq_output_ids != eos_token_id]

        print("--------------------------------------")
        print("TRT-LLM output_ids: ", output_ids)
        print(f"TRT-LLM E2E time {(tok-tik)*1000}ms")
        print("Precision:", inference_dtype)
        print("--------------------------------------")

        assert output_ids.tolist() == fairseq_output_ids.tolist(
        ), f"TRT-LLM output ids {output_ids} does not match Fairseq ids {fairseq_output_ids}"
if __name__ == "__main__":
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
args = parse_arguments()
logger.set_level(args.log_level)
# FairSeq NMT test logic is different from HuggingFace models
if 'wmt' in args.model_name:
test_fairseq_models(args)
exit()
test_remove_padding = True
if not test_remove_padding:
if 't5' in args.model_name:
input_text = "translate English to German: The house is wonderful, radiating timeless charm and offering a warm, inviting interior with beautiful details and a serene backyard."
elif 'bart' in args.model_name:
input_text = "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct."
else:
raise RuntimeError('Unsupported model type!')
else:
input_text = [
"translate English to German: The house is wonderful.",
"summarize: I am a high-performance inference optimizer and runtime.",
"During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world",
]
# TRT-LLM runtime
tllm_model = TRTLLMEncDecModel.from_engine(args.engine_name,
args.engine_dir,
args.lora_dir,
args.lora_task_uids,
debug_mode=args.debug_mode)
inference_dtype = tllm_model.encoder_model_config.dtype
if inference_dtype == 'float32':
if "byt5" in args.model_name:
print(
"ByT5 models tokenize input by bytes instead of words, causing the input text in this example to be longer than the default value during build stage. Please adjust --max_input_len during trtllm-build to select the right length limit for ByT5 models."
)
else:
input_text.append(
"Summarize this article in one sentence.\n\nKristine Watts (Molie Weeks) is broken apart, missing her lover; she is not able to overcome her love for him that is lost in the past. She hires a stranger (Douglas Davis) and gives a list of her mistakes to him with things to fix. But time is irreversible and sometimes the cure for the pain is a tragic end.\n\nThe first point that impresses in \"The Cure\" is the stylish cinematography that alternates black and white with color. The concise and sharp screenplay is capable to develop a tragic and bleak tale of love with an unexpected plot point in the very end in less than eight minutes. The soundtrack is beautiful but the volume is a little loud and associated to the fact that English is not my native language, in some moments I needed to repeat some words whispered by the narrator. The unknown lead actress has magnificent performance and is extremely gorgeous. I hope to have a chance to see her again on the screen. Last but not the least, the debut of the director and writer Ryan Jafri could not be better. My vote is nine.\n\nTitle (Brazil): Not Available",
)
tokenizer = AutoTokenizer.from_pretrained(
args.model_name) # TODO: use model path instead
tokenized_inputs = tokenizer(input_text, return_tensors="pt", padding=True)
max_new_tokens = args.max_new_tokens
input_ids = tokenized_inputs.input_ids.type(torch.IntTensor).to(
'cuda') # [batch_size, padded_length]
# by default int64, must cast to int32! otherwise C++ kernel will interpret as [a, 0, b, 0, c, 0, ...]
CPP_RESULTS_SAVED_DIR = 'cpp/tests/resources/data/enc_dec'
if tensorrt_llm.mpi_rank() == 0:
if args.output_encoder_npy:
if not os.path.isdir(CPP_RESULTS_SAVED_DIR):
os.mkdir(os.path.join(CPP_RESULTS_SAVED_DIR))
np_input_ids = tokenized_inputs.input_ids.type(torch.IntTensor)
np_input_ids = np_input_ids.numpy()
np.save(os.path.join(CPP_RESULTS_SAVED_DIR, 'enc_input_ids.npy'),
np_input_ids)
input_lengths = tokenized_inputs.attention_mask.sum(dim=1).type(
torch.IntTensor).numpy()
np.save(
os.path.join(CPP_RESULTS_SAVED_DIR, 'enc_input_lengths.npy'),
input_lengths)
print("--------------------------------------")
print(
f"BOS={tokenizer.bos_token_id}, PAD={tokenizer.pad_token_id}, EOS={tokenizer.eos_token_id}"
)
print("input text: ", input_text)
print("input ids: ", input_ids)
print("input lengths: ", tokenized_inputs.attention_mask.sum(dim=1))
print("--------------------------------------")
model_config = AutoConfig.from_pretrained(args.model_name)
# start_id for decoder (could add more input_ids as forced_decoder_ids)
decoder_input_ids = torch.IntTensor([[model_config.decoder_start_token_id]
]).to('cuda')
decoder_input_ids = decoder_input_ids.repeat((input_ids.shape[0], 1))
# Optional reference run: generate the same batch with the Hugging Face model
# on rank 0 so the TRT-LLM output can be compared against it later. The model
# is loaded without a torch_dtype override (see the commented-out line).
if args.compare_hf_fp32:
    if tensorrt_llm.mpi_rank() == 0:
        hf_model = AutoModelForSeq2SeqLM.from_pretrained(
            args.model_name,  # TODO: use model path instead
            # torch_dtype=torch.float16 if '16' in dtype else torch.float32, # TODO: use matched torch dtype
        ).to('cuda').eval()  # TODO: create config model path instead
        # isinstance (rather than an exact type() match) also accepts
        # subclasses of the supported architectures.
        assert isinstance(
            hf_model,
            (T5ForConditionalGeneration, BartForConditionalGeneration,
             MBartForConditionalGeneration)), 'Unsupported model!'
        if args.lora_dir is not None:
            assert len(args.lora_dir
                       ) >= 1, "At least one lora model dir is required"
            # we can only test single lora with HF
            from peft import PeftModel
            hf_model = PeftModel.from_pretrained(
                hf_model, args.lora_dir[0]).to('cuda').eval()
        tik = time.time()
        hf_gen_output = hf_model.generate(
            input_ids=input_ids,
            decoder_input_ids=decoder_input_ids,
            max_new_tokens=max_new_tokens,
            num_beams=args.num_beams,
            bos_token_id=tokenizer.bos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            # control logits processors
            no_repeat_ngram_size=0,  # disable no repeat post-processor
            forced_bos_token_id=None,  # disable forced first/last token
            forced_eos_token_id=None,
            min_length=0,
            # for debug
            output_scores=True,
            output_hidden_states=True,
            return_dict_in_generate=True)
        # return_dict_in_generate=True, so the generated token ids are
        # exposed as `.sequences` on the returned object.
        hf_output_ids = hf_gen_output.sequences
        # Wait for queued GPU work so the wall-clock measurement is accurate.
        torch.cuda.synchronize()
        tok = time.time()
        output_ids = hf_output_ids.squeeze(dim=1)
        hf_output_text = tokenizer.batch_decode(output_ids,
                                                skip_special_tokens=True)
        decoder_input_lengths = (decoder_input_ids !=
                                 tokenizer.pad_token_id).sum(dim=1)
        # NOTE(review): this counts every token that is not EOS, so any pad
        # tokens present in output_ids are included in the reported length.
        output_gen_lengths = (output_ids != tokenizer.eos_token_id).sum(
            dim=1) - decoder_input_lengths
        print("--------------------------------------")
        print("HF output_ids: ", output_ids)
        print("HF output text: ", hf_output_text)
        print("HF output generated lengths: ", output_gen_lengths)
        print(f"HF E2E time {(tok-tik)*1000}ms")
        print("--------------------------------------")
# Run TRT-LLM generation on every rank. With return_dict=False the call returns
# the output ids directly; when the encoder output is requested (rank 0 with
# --output_encoder_npy) it returns an (output, encoder_output) pair instead.
return_dict = False  # when set return_dict=True, get outputs by key
tik = time.time()
tllm_output = tllm_model.generate(
    encoder_input_ids=input_ids,
    decoder_input_ids=decoder_input_ids,
    max_new_tokens=max_new_tokens,
    num_beams=args.num_beams,
    bos_token_id=tokenizer.bos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    debug_mode=args.debug_mode,
    return_dict=return_dict,
    attention_mask=tokenized_inputs.attention_mask,
    time_encoder=True,
    return_encoder_output=args.output_encoder_npy
    and tensorrt_llm.mpi_rank() == 0)
# Wait for outstanding GPU work before stopping the clock, matching how the
# HF path is timed, so the two E2E times are measured the same way.
torch.cuda.synchronize()
tok = time.time()
if args.output_encoder_npy and tensorrt_llm.mpi_rank() == 0:
    tllm_output, encoder_output = tllm_output
    encoder_output = encoder_output.cpu().numpy()
    np.save(os.path.join(CPP_RESULTS_SAVED_DIR, 'encoder_output.npy'),
            encoder_output)
if return_dict:
    tllm_output_ids = tllm_output['output_ids']
else:
    tllm_output_ids = tllm_output
# Rank 0 decodes and reports the TRT-LLM output and, if requested, checks it
# against the HF reference text produced earlier.
if tensorrt_llm.mpi_rank() == 0:
    # tllm_output_ids is indexed as rank 3; the middle axis is presumably the
    # beam axis, so this keeps the first beam only — TODO confirm.
    output_ids = tllm_output_ids[:, 0, :]
    output_text = tokenizer.batch_decode(output_ids,
                                         skip_special_tokens=True)
    decoder_input_lengths = (decoder_input_ids !=
                             tokenizer.pad_token_id).sum(dim=1)
    # NOTE(review): counts every non-EOS token, so any pad tokens in
    # output_ids are included in the reported length.
    output_gen_lengths = (output_ids != tokenizer.eos_token_id).sum(
        dim=1) - decoder_input_lengths
    print("--------------------------------------")
    print("TRT-LLM output_ids: ", output_ids)
    print("TRT-LLM output text: ", output_text)
    print("TRT-LLM output generated lengths: ", output_gen_lengths)
    print(f"TRT-LLM E2E time {(tok-tik)*1000}ms")
    print("Precision:", inference_dtype)
    print("--------------------------------------")
    # simple accuracy check
    if args.compare_hf_fp32:
        from difflib import SequenceMatcher
        # Literal character-level similarity of the joined batch outputs.
        match_rate = SequenceMatcher(None, "\n".join(output_text),
                                     "\n".join(hf_output_text)).ratio()
        print(output_text)
        print(hf_output_text)
        # Explicit raises instead of `assert`: the accuracy gate must not be
        # stripped under `python -O`, where the success message below would
        # otherwise be printed for mismatching results. AssertionError is
        # kept so callers/CI see the same exception type as before.
        if inference_dtype != "float32":
            print("")
            print(
                f"[CAVEAT] Comparing TRT-LLM {inference_dtype} results with HF float32 results. Close match are not expected!"
            )
            if not match_rate > 0.8:
                raise AssertionError(
                    f"Incorrect results! Match rate {match_rate}")
        else:
            if not match_rate > 0.95:
                raise AssertionError(
                    f"Incorrect results! Match rate {match_rate}")
        print(
            f"TRT-LLM results match HF FP32 results with literal match rate {match_rate}"
        )