# helper.py (from a fork of NVIDIA/TensorRT-LLM)
import typing
from typing import Union
import numpy as np
import torch # pytype: disable=import-error
from tensorrt_llm._utils import str_dtype_to_torch
def split(v: Union[np.ndarray, torch.Tensor],
          tp_size: int,
          tp_rank: int,
          dim=0):
    """Return shard ``tp_rank`` of ``tp_size`` equal shards of ``v`` along ``dim``.

    The result is always a copy, never a view of ``v``. Torch results are
    also detached from the autograd graph, and NumPy results are
    C-contiguous.

    Args:
        v: NumPy array or torch tensor to shard.
        tp_size: Number of shards. When it is 1, a full copy of ``v`` is
            returned.
        tp_rank: Index of the shard to return.
        dim: Axis to split along. A 1-D input may only be split along dim 0.

    Returns:
        The selected shard, of the same container type as ``v``.

    Raises:
        AssertionError: If ``v`` is 1-D and ``dim != 0``, or (torch only) if
            ``v.shape[dim]`` is not divisible by ``tp_size``.
        ValueError: (NumPy only) If ``np.split`` cannot split ``v`` into
            ``tp_size`` equal parts.
    """
    if tp_size == 1:
        # Nothing to shard, but still copy so the caller never aliases `v`.
        if isinstance(v, np.ndarray):
            return np.ascontiguousarray(v.copy())
        return v.clone().detach()

    assert len(v.shape) > 1 or dim == 0

    if isinstance(v, np.ndarray):
        return np.ascontiguousarray(
            np.split(v, tp_size, axis=dim)[tp_rank].copy())

    assert v.shape[dim] % tp_size == 0, \
        f'Unable to split: shape={v.shape} (dim={dim}) tp_size={tp_size}.'
    # torch.Tensor.split takes a chunk *size*, not a chunk count.
    split_size = v.shape[dim] // tp_size
    return v.split(split_size, dim=dim)[tp_rank].clone().detach()
def reshape(v: torch.Tensor, shape=None):
    """Return ``v`` as a contiguous tensor, optionally reshaped first.

    Args:
        v: Input tensor.
        shape: Target shape forwarded to ``torch.Tensor.reshape``. ``None``
            keeps the current shape.

    Returns:
        A contiguous tensor holding the same values as ``v``.
    """
    reshaped = v if shape is None else v.reshape(shape)
    return reshaped.contiguous()
def fuse_qkv_one_layer(params, attn_module_name, trtllm_layer_name, tp_size,
                       tp_rank, model_type, weight_shape, bias_shape):
    """Fuse one attention layer's q/k/v projections into a single QKV entry.

    The q, k and v weights are concatenated along dim 0 and viewed as
    ``(3, do, din)``, where ``(do, din)`` is the q weight's shape. The result
    is sharded along dim 1 with ``split`` and then reshaped to
    ``weight_shape``. When a non-``None`` q bias exists, the biases are fused
    the same way, viewed as ``(3, do)``, sharded along dim 1 and reshaped to
    ``bias_shape``.

    Args:
        params: Mapping from parameter name to tensor. Keys are looked up as
            ``f'{attn_module_name}.{proj}.{weight|bias}'``, with ``proj``
            taken from ``get_qkv_module_name(model_type)``.
        attn_module_name: Name prefix of the source attention module.
        trtllm_layer_name: Name prefix for the returned fused entries.
        tp_size: Tensor-parallel world size, passed to ``split``.
        tp_rank: Tensor-parallel rank, passed to ``split``.
        model_type: Model family; selects the q/k/v submodule names.
        weight_shape: Final shape of the fused weight shard.
        bias_shape: Final shape of the fused bias shard.

    Returns:
        A dict containing ``f'{trtllm_layer_name}.qkv.weight'`` and, when a
        q bias is present, ``f'{trtllm_layer_name}.qkv.bias'``.

    Note:
        The ``reshape`` to ``(3, do, din)`` requires k and v weights with the
        same shape as the q weight. If a q bias exists, k and v biases must
        exist as well. NOTE(review): confirm this holds for every supported
        model_type.
    """
    names = get_qkv_module_name(model_type)

    def fetch(kind):
        # Look up the q, k and v tensors of one parameter kind, in that order.
        return [
            params[f'{attn_module_name}.{names[proj]}.{kind}']
            for proj in ('q', 'k', 'v')
        ]

    fused = {}

    q_w, k_w, v_w = fetch('weight')
    stacked_w = torch.cat([q_w, k_w, v_w],
                          dim=0).reshape([3, q_w.shape[0], q_w.shape[1]])
    fused[f'{trtllm_layer_name}.qkv.weight'] = reshape(
        split(stacked_w, tp_size, tp_rank, dim=1), shape=weight_shape)

    # Biases are optional: skip them when the q bias is absent or None.
    q_bias_key = f'{attn_module_name}.{names["q"]}.bias'
    if q_bias_key not in params.keys() or params[q_bias_key] is None:
        return fused

    q_b, k_b, v_b = fetch('bias')
    stacked_b = torch.cat([q_b, k_b, v_b], dim=0).reshape([3, q_b.shape[0]])
    fused[f'{trtllm_layer_name}.qkv.bias'] = reshape(
        split(stacked_b, tp_size, tp_rank, dim=1), shape=bias_shape)
    return fused
def get_qkv_module_name(model_type):
    """Return the q/k/v projection submodule names used by ``model_type``.

    Args:
        model_type: One of ``"t5"``, ``"blip2"``, ``"bart"``, ``"nmt"`` or
            ``"pix2struct"``.

    Returns:
        A dict with keys ``"q"``, ``"k"`` and ``"v"``. Each value is the
        attribute name of the matching projection under the attention module.

    Raises:
        ValueError: If ``model_type`` is not one of the supported types.
    """
    if model_type in ("t5", "blip2"):
        q, k, v = "q", "k", "v"
    elif model_type in ("bart", "nmt"):
        q, k, v = "q_proj", "k_proj", "v_proj"
    elif model_type == "pix2struct":
        q, k, v = "query", "key", "value"
    else:
        # Without this branch, q/k/v stay unbound and the return below fails
        # with an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported model_type {model_type!r}; expected one of "
            "'t5', 'blip2', 'bart', 'nmt', 'pix2struct'.")
    return {"q": q, "k": k, "v": v}
def convert_weight_to_dtype(params: typing.Dict[str, torch.Tensor],
                            dtype: typing.Optional[str] = None):
    """Cast every tensor in ``params`` to ``dtype``, in place.

    Args:
        params: Mapping from parameter name to tensor. Each value is replaced
            with the result of ``tensor.to(...)``.
        dtype: Dtype name resolved with tensorrt_llm's ``str_dtype_to_torch``.
            ``None`` leaves ``params`` untouched.

    Raises:
        AssertionError: If ``dtype`` is neither ``None`` nor a ``str``.
    """
    if dtype is None:
        return
    assert isinstance(dtype,
                      str), f"dtype must be str, but get type {type(dtype)}"
    # Resolve the torch dtype once rather than once per tensor.
    torch_dtype = str_dtype_to_torch(dtype)
    # Reassigning existing keys does not resize the dict, so this is safe
    # while iterating over it.
    for name, tensor in params.items():
        params[name] = tensor.to(torch_dtype)