forked from NVIDIA/TensorRT-LLM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
204 lines (170 loc) · 7.72 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
from ..._utils import pad_vocab_size
from ...functional import PositionEmbeddingType, Tensor, allreduce
from ...layers import (MLP, Attention, AttentionMaskType, ColumnLinear,
Embedding, LayerNorm)
from ...mapping import Mapping
from ...module import Module
from ..modeling_utils import DecoderLayerList, DecoderModelForCausalLM
from .config import GPTJConfig
from .convert import load_weights_from_hf_model
class GPTJDecoderLayer(Module):
def __init__(self, config: GPTJConfig, layer_idx: int):
super().__init__()
self.layer_idx = layer_idx
self.config = config
hidden_size = config.hidden_size
num_attention_heads = config.num_attention_heads
rotary_dim = config.rotary_dim
dtype = config.dtype
tp_size = config.mapping.tp_size
tp_rank = config.mapping.tp_rank
layernorm_epsilon = config.norm_epsilon
self.input_layernorm = LayerNorm(normalized_shape=hidden_size,
eps=layernorm_epsilon,
dtype=dtype)
layers_range = config.mapping.pp_layers(config.num_hidden_layers)
local_layer_idx = layer_idx - layers_range[0]
self.attention = Attention(
local_layer_idx=local_layer_idx,
hidden_size=hidden_size,
num_attention_heads=num_attention_heads,
rotary_embedding_percentage=rotary_dim /
(hidden_size // num_attention_heads),
max_position_embeddings=config.max_position_embeddings,
attention_mask_type=AttentionMaskType.causal,
dtype=dtype,
tp_group=None,
tp_size=tp_size,
tp_rank=tp_rank,
bias=False,
position_embedding_type=PositionEmbeddingType.rope_gptj,
quant_mode=config.quant_mode)
self.mlp = MLP(hidden_size=hidden_size,
ffn_hidden_size=hidden_size * 4,
hidden_act=config.hidden_act,
dtype=dtype,
bias=True,
tp_group=None,
tp_size=tp_size,
quant_mode=config.quant_mode)
def forward(self,
hidden_states: Tensor,
attention_mask=None,
use_cache=False,
kv_cache_params=None,
attention_params=None):
assert isinstance(hidden_states, Tensor)
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attention_output = self.attention(hidden_states,
attention_mask=attention_mask,
use_cache=use_cache,
kv_cache_params=kv_cache_params,
attention_params=attention_params)
if use_cache:
attention_output, presents = attention_output
attention_output = attention_output
feed_forward_hidden_states = self.mlp(hidden_states)
hidden_states = attention_output + feed_forward_hidden_states
if self.config.mapping.tp_size > 1:
hidden_states = allreduce(hidden_states,
self.config.mapping.tp_group)
hidden_states = hidden_states + residual
if use_cache:
return (hidden_states, presents)
return hidden_states
class GPTJModel(Module):
def __init__(self, config: GPTJConfig):
super().__init__()
self.config = config
if config.mapping.is_first_pp_rank():
self.vocab_embedding = Embedding(config.vocab_size,
config.hidden_size,
dtype=config.dtype)
self.layers = DecoderLayerList(GPTJDecoderLayer, config)
if config.mapping.is_last_pp_rank():
self.ln_f = LayerNorm(normalized_shape=config.hidden_size,
dtype=config.dtype)
def forward(self,
input_ids: Tensor,
position_ids=None,
use_cache=False,
attention_mask=None,
kv_cache_params=None,
attention_params=None):
hidden_states = self.vocab_embedding(input_ids)
hidden_states = self.layers(hidden_states,
use_cache=use_cache,
attention_mask=attention_mask,
kv_cache_params=kv_cache_params,
attention_params=attention_params)
if use_cache:
hidden_states, presents = hidden_states
hidden_states = self.ln_f(hidden_states)
if use_cache:
return (hidden_states, tuple(presents))
return hidden_states
class GPTJForCausalLM(DecoderModelForCausalLM):
config_class = GPTJConfig
def __init__(self, config: GPTJConfig):
transformer = GPTJModel(config)
vocab_size_padded = pad_vocab_size(config.vocab_size,
config.mapping.tp_size)
if config.mapping.is_last_pp_rank():
lm_head = ColumnLinear(config.hidden_size,
vocab_size_padded,
bias=True,
dtype=config.dtype,
tp_group=config.mapping.tp_group,
tp_size=config.mapping.tp_size,
gather_output=True)
else:
lm_head = None
super().__init__(config, transformer, lm_head)
@classmethod
def from_hugging_face(
cls,
hf_model_or_dir: Union[str, 'transformers.PreTrainedModel'],
dtype: str = 'auto',
mapping: Optional[Mapping] = None,
quant_config=None,
**kwargs):
import transformers
use_preloading = isinstance(hf_model_or_dir,
transformers.PreTrainedModel)
if use_preloading:
hf_model = hf_model_or_dir
hf_config_or_dir = hf_model.config
else:
hf_model_dir = hf_model_or_dir
hf_config_or_dir = hf_model_or_dir
config = GPTJConfig.from_hugging_face(hf_config_or_dir,
dtype=dtype,
mapping=mapping,
quant_config=quant_config,
**kwargs)
if not use_preloading:
trust_remote_code = kwargs.pop('trust_remote_code', True)
hf_model = transformers.AutoModelForCausalLM.from_pretrained(
hf_model_dir,
torch_dtype='auto',
trust_remote_code=trust_remote_code)
weights = load_weights_from_hf_model(hf_model, config)
model = GPTJForCausalLM(config)
model.load(weights)
return model