-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmlp.py
325 lines (240 loc) · 12.1 KB
/
mlp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
from __future__ import annotations
import torch
import torch.nn.functional as F
from torch import nn
from exllamav2.module import ExLlamaV2Module
from exllamav2.rmsnorm import ExLlamaV2RMSNorm
from exllamav2.layernorm import ExLlamaV2LayerNorm
from exllamav2.linear import ExLlamaV2Linear
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
from exllamav2.lora import ExLlamaV2Lora
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from exllamav2.model import ExLlamaV2
class ExLlamaV2MLP(ExLlamaV2Module):
    """
    Feed-forward (MLP) sub-layer of a decoder layer.

    Built from an optional post-attention norm, an optional gate projection and the
    up/down projections. When the projections are quantized, ``load()`` creates a native
    handle with ``ext_c.make_q_mlp`` and ``forward()`` runs through the extension.
    Otherwise (or when intermediates are requested) ``forward_torch()`` computes

        down(act(gate(x)) * up(x))   when a gate projection exists, or
        down(act(up(x)))             when it does not,

    where ``x`` is the (optionally normed) input. The input is optionally added back as
    a residual.
    """

    name: str = "MLP"

    layer_idx: int
    post_attention_layernorm: ExLlamaV2RMSNorm | ExLlamaV2LayerNorm | None
    gate_proj: ExLlamaV2Linear | None
    up_proj: ExLlamaV2Linear | None
    down_proj: ExLlamaV2Linear | None

    # Handle returned by ext_c.make_q_mlp; None until a quantized load() has run
    q_handle: int | None

    # Element count of the torch.half scratch tensor allocated for a LoRA forward (0 = none)
    temp_lora_size: int

    has_norm: bool
    has_residual: bool

    def __init__(self,
                 model: ExLlamaV2,
                 key: str,
                 layer_idx: int,
                 has_norm: bool = True,
                 has_residual: bool = True):
        """
        :param model: Parent model; its ``config`` supplies sizes and architecture keys.
        :param key: Tensor-name prefix of the decoder layer this MLP belongs to.
        :param layer_idx: Index of the decoder layer.
        :param has_norm: Create ``post_attention_layernorm`` and apply it before the projections.
        :param has_residual: Add the layer input to the down-projection output.
        :raises ValueError: If ``has_norm`` is set and ``cfg.arch.norm`` is neither
            "layernorm" nor "rmsnorm".
        """
        super().__init__(model, key)
        cfg = self.model.config

        self.layer_idx = layer_idx
        self.has_norm = has_norm
        self.has_residual = has_residual

        self.q_handle = None
        self.temp_lora_size = 0

        # Row ranges inside a fused gate/up tensor: gate = [f_a, f_b), up = [f_b, f_c)
        f_a = 0
        f_b = cfg.intermediate_size
        f_c = f_b + cfg.intermediate_size
        f_key = (key + ".mlp." + cfg.arch.fused_mlp_key_12) if cfg.arch.fused_mlp_key_12 else None

        if self.has_norm:
            if cfg.arch.norm == "layernorm":
                self.post_attention_layernorm = ExLlamaV2LayerNorm(model, key + cfg.arch.norm_key_2)
            elif cfg.arch.norm == "rmsnorm":
                self.post_attention_layernorm = ExLlamaV2RMSNorm(model, key + cfg.arch.norm_key_2)
            else:
                # Fail here with a clear message rather than leaving the attribute unset
                raise ValueError(f"Unsupported norm type for MLP: {cfg.arch.norm!r}")
        else:
            self.post_attention_layernorm = None

        self.up_proj = ExLlamaV2Linear(model,
                                       key + cfg.arch.mlp_key_up,
                                       cfg.hidden_size,
                                       cfg.intermediate_size,
                                       self.model.config.arch.mlp_bias,
                                       f_key = f_key,
                                       f_beg = f_b,
                                       f_end = f_c)
        self.down_proj = ExLlamaV2Linear(model,
                                         key + cfg.arch.mlp_key_down,
                                         cfg.intermediate_size,
                                         cfg.hidden_size,
                                         self.model.config.arch.mlp_bias)

        self.submodules = [self.up_proj,
                           self.down_proj]
        if self.has_norm:
            self.submodules += [self.post_attention_layernorm]

        if cfg.arch.mlp_gate:
            self.gate_proj = ExLlamaV2Linear(model,
                                             key + cfg.arch.mlp_key_gate,
                                             cfg.hidden_size,
                                             cfg.intermediate_size,
                                             self.model.config.arch.mlp_bias,
                                             f_key = f_key,
                                             f_beg = f_a,
                                             f_end = f_b)
            self.submodules += [self.gate_proj]
        else:
            self.gate_proj = None

    def numel(self) -> int:
        """Return the total parameter count of the projections and the norm (if present)."""
        numel = self.up_proj.numel() + \
                self.down_proj.numel()

        if self.gate_proj is not None:
            numel += self.gate_proj.numel()

        if self.post_attention_layernorm is not None:
            numel += self.post_attention_layernorm.numel()

        return numel

    def load(self):
        """
        Load the norm and projection weights.

        If the up projection is quantized, also allocate scratch slices on this module's
        device and create the native q_mlp handle (stored in ``self.q_handle``).

        :raises AssertionError: If the up projection is quantized but the gate or down
            projection is not.
        """
        cfg = self.model.config

        if self.post_attention_layernorm is not None:
            self.post_attention_layernorm.load()

        if cfg.checkpoint_fused_mlp:
            # The checkpoint stacks gate and up in one tensor: rows [0, I) go to gate_proj,
            # rows [I, 2I) go to up_proj (I = intermediate_size)
            w12 = self.load_weight(self.key + cfg.arch.fused_mlp_key_12)
            w1 = nn.Parameter(w12[:cfg.intermediate_size, :].contiguous())
            w2 = nn.Parameter(w12[cfg.intermediate_size:, :].contiguous())
            w3 = self.load_weight(self.key + cfg.arch.fused_mlp_key_3)
            self.gate_proj.load(w1)
            self.up_proj.load(w2)
            self.down_proj.load(w3)
        else:
            if self.gate_proj is not None:
                self.gate_proj.load()
            self.up_proj.load()
            self.down_proj.load()

        if not self.up_proj.is_quant():
            return

        # The native kernel takes q_handles for every projection, so a mixed
        # quantized / unquantized layer is rejected here
        assert self.gate_proj is None or self.gate_proj.is_quant(), "Partially quantized MLP layer"
        assert self.down_proj.is_quant(), "Partially quantized MLP layer"

        device_tensors = self.model.get_device_tensors(self.device_idx)
        device_tensors.begin_scratch_alloc()

        if self.has_norm:
            norm = self.post_attention_layernorm
            norm_weight = norm.weight if norm.weight is not None else none_tensor
            norm_bias = norm.bias if norm.bias is not None else none_tensor
            is_rms = isinstance(norm, ExLlamaV2RMSNorm)
            eps = norm.variance_epsilon
        else:
            norm_weight = none_tensor
            norm_bias = none_tensor
            is_rms = False
            eps = 0

        # A gate handle of 0 signals "no gate projection" to the extension. The activation
        # is passed as a single "use gelu" flag (presumably silu otherwise — confirm in ext).
        self.q_handle = ext_c.make_q_mlp(norm_weight,
                                         norm_bias,
                                         is_rms,
                                         eps,
                                         0 if self.gate_proj is None else self.gate_proj.q_handle,
                                         self.up_proj.q_handle,
                                         self.down_proj.q_handle,
                                         device_tensors.get_scratch_slice(self.temp_state_size()),
                                         device_tensors.get_scratch_slice(self.temp_a_size()),
                                         device_tensors.get_scratch_slice(self.temp_b_size()),
                                         device_tensors.get_scratch_slice(self.temp_dq_size()),
                                         cfg.max_input_len * cfg.max_batch_size,
                                         cfg.arch.mlp_act_func == "gelu",
                                         self.has_residual)

    def unload(self):
        """Free the native q_mlp handle (if any) and unload the norm and projection weights."""
        if self.q_handle is not None:
            ext_c.free_q_mlp(self.q_handle)
            self.q_handle = None

        if self.post_attention_layernorm is not None:
            self.post_attention_layernorm.unload()
        if self.gate_proj is not None:
            self.gate_proj.unload()
        self.up_proj.unload()
        self.down_proj.unload()

    def weight_footprint(self) -> int:
        """
        Return the weight footprint of this layer.

        For fused checkpoints this is estimated as three I x H matrices at 2 per element
        (presumably bytes of fp16). Otherwise it is the sum of the submodules' own footprints.
        """
        if self.model.config.checkpoint_fused_mlp:
            fp = 3 * self.model.config.intermediate_size * self.model.config.hidden_size * 2
        else:
            fp = self.up_proj.weight_footprint() + \
                 self.down_proj.weight_footprint()
            if self.gate_proj is not None:
                fp += self.gate_proj.weight_footprint()

        if self.post_attention_layernorm is not None:
            fp += self.post_attention_layernorm.weight_footprint()

        return fp

    def scratch_space_fixed(self) -> int:
        """Return the total size of the four scratch slices requested in ``load()``."""
        return self.temp_state_size() + \
               self.temp_a_size() + \
               self.temp_b_size() + \
               self.temp_dq_size()

    def scratch_space(self) -> int:
        """
        Return the same total as ``scratch_space_fixed()``.

        Also asserts that ``intermediate_size >= hidden_size``.
        """
        assert self.model.config.intermediate_size >= self.model.config.hidden_size
        return self.temp_state_size() + \
               self.temp_a_size() + \
               self.temp_b_size() + \
               self.temp_dq_size()

    def temp_state_size(self) -> int:
        """Scratch size for max tokens x hidden_size at 2 per element, plus 128 padding."""
        return self.model.config.max_input_len * self.model.config.max_batch_size * self.model.config.hidden_size * 2 + 128

    def temp_a_size(self) -> int:
        """Scratch size for max tokens x intermediate_size at 2 per element, plus 128 padding."""
        return self.model.config.max_input_len * self.model.config.max_batch_size * self.model.config.intermediate_size * 2 + 128

    def temp_b_size(self) -> int:
        """Scratch size for max tokens x intermediate_size at 2 per element, plus 128 padding."""
        return self.model.config.max_input_len * self.model.config.max_batch_size * self.model.config.intermediate_size * 2 + 128

    def temp_dq_size(self) -> int:
        """Return the largest dequantization scratch size required by any projection."""
        return max(0 if self.gate_proj is None else self.gate_proj.temp_dq_size(),
                   self.up_proj.temp_dq_size(),
                   self.down_proj.temp_dq_size())

    def set_device_idx(self, idx: int):
        """Assign device *idx* to this module and to every submodule."""
        super().set_device_idx(idx)

        if self.post_attention_layernorm is not None:
            self.post_attention_layernorm.set_device_idx(idx)
        if self.gate_proj is not None:
            self.gate_proj.set_device_idx(idx)
        self.up_proj.set_device_idx(idx)
        self.down_proj.set_device_idx(idx)

    def forward(self,
                hidden_states: torch.Tensor,
                cache = None,
                attn_params = None,
                past_len = None,
                intermediates: bool = False,
                loras: list[ExLlamaV2Lora] | None = None,
                **kwargs) -> torch.Tensor | dict[str, torch.Tensor]:
        """
        Run the MLP.

        Uses the native quantized kernel when a q_handle exists and no intermediates are
        requested; otherwise defers to ``forward_torch()``.

        On the native path, ``hidden_states`` is passed to ``ext_c.q_mlp_forward_`` as a
        2-D view and the same tensor object is returned. The extension is expected to
        write its result in place.

        :param hidden_states: Layer input; its last dimension is treated as the feature dim.
        :param cache: Unused here; forwarded to ``forward_torch()``.
        :param attn_params: Unused here; forwarded to ``forward_torch()``.
        :param past_len: Unused here; forwarded to ``forward_torch()``.
        :param intermediates: If True, return a dict of intermediate tensors (torch path).
        :param loras: Active LoRA adapters, or None.
        :return: Output tensor, or a dict when ``intermediates`` is True.
        """
        if self.q_handle is None or intermediates:
            return self.forward_torch(hidden_states, cache, attn_params, past_len, intermediates, loras = loras, **kwargs)

        if loras is None or self.temp_lora_size == 0:
            pass_loras = []
            pass_lora_temp = none_tensor
        else:
            # LoRAs are identified by id(), matching the keys registered in update_loras()
            pass_loras = [id(x) for x in loras]
            pass_lora_temp = torch.empty((self.temp_lora_size,), dtype = torch.half, device = hidden_states.device)

        ext_c.q_mlp_forward_(self.q_handle,
                             hidden_states.view(-1, hidden_states.shape[-1]),
                             pass_loras,
                             pass_lora_temp)

        return hidden_states

    def _activation(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the configured MLP activation ("silu" or "gelu") to *x*.

        :raises ValueError: If ``config.arch.mlp_act_func`` is any other value.
        """
        act = self.model.config.arch.mlp_act_func
        if act == "silu":
            return F.silu(x)
        if act == "gelu":
            return F.gelu(x)
        raise ValueError(f"Unsupported MLP activation function: {act!r}")

    def forward_torch(self,
                      hidden_states: torch.Tensor,
                      cache = None,
                      attn_params = None,
                      past_len = None,
                      intermediates: bool = False,
                      loras: list[ExLlamaV2Lora] | None = None,
                      **kwargs) -> torch.Tensor | dict[str, torch.Tensor]:
        """
        Compute the MLP with PyTorch ops through the submodules' own ``forward()``.

        :param hidden_states: Layer input.
        :param intermediates: If True, return ``{"post_norm", "pre_down", "hidden_states"}``.
        :param loras: Active LoRA adapters, passed to each projection.
        :return: Output tensor, or the intermediates dict.
        :raises ValueError: If the configured activation function is unsupported.
        """
        residual = hidden_states
        post_norm = self.post_attention_layernorm.forward(hidden_states) \
            if self.has_norm else hidden_states

        if self.gate_proj is not None:
            gate = self.gate_proj.forward(post_norm, loras = loras)
            y = self._activation(gate)
            up = self.up_proj.forward(post_norm, loras = loras)
            y *= up
            # Keep the product inside the finite float16 range (+/- 65504)
            y.clamp_(min = -65504.0, max = 65504.0)
        else:
            up = self.up_proj.forward(post_norm, loras = loras)
            y = self._activation(up)

        down = self.down_proj.forward(y, loras = loras)
        hidden_states = down + residual if self.has_residual else down

        if intermediates:
            return {"post_norm": post_norm,
                    "pre_down": y,
                    "hidden_states": hidden_states}
        else:
            return hidden_states

    def update_loras(self):
        """
        Register the projections' current LoRA tensors with the native q_mlp handle.

        Also sizes ``temp_lora_size`` for ``forward()``. Does nothing when the layer has
        no q_handle.
        """
        if self.q_handle is None:
            return

        if self.gate_proj is None:
            gate_proj_lora_a = {}
            gate_proj_lora_b = {}
        else:
            gate_proj_lora_a = { id(k): v for k, v in self.gate_proj.lora_a_tensors.items() }
            gate_proj_lora_b = { id(k): v for k, v in self.gate_proj.lora_b_tensors.items() }

        up_proj_lora_a = { id(k): v for k, v in self.up_proj.lora_a_tensors.items() }
        up_proj_lora_b = { id(k): v for k, v in self.up_proj.lora_b_tensors.items() }
        down_proj_lora_a = { id(k): v for k, v in self.down_proj.lora_a_tensors.items() }
        down_proj_lora_b = { id(k): v for k, v in self.down_proj.lora_b_tensors.items() }

        temp_lora_size = ext_c.q_mlp_set_loras(self.q_handle,
                                               gate_proj_lora_a,
                                               gate_proj_lora_b,
                                               up_proj_lora_a,
                                               up_proj_lora_b,
                                               down_proj_lora_a,
                                               down_proj_lora_b)

        # Scale the extension's returned size by the maximum token count
        # (presumably a per-token size — confirm against the extension)
        self.temp_lora_size = temp_lora_size * self.model.config.max_batch_size * self.model.config.max_input_len

    def is_quant(self):
        """Return True once a native quantized handle has been created by ``load()``."""
        return self.q_handle is not None

    def rank_reduce(self, k):
        """Forward ``rank_reduce(k)`` to every projection."""
        if self.gate_proj is not None:
            self.gate_proj.rank_reduce(k)
        self.up_proj.rank_reduce(k)
        self.down_proj.rank_reduce(k)