-
Notifications
You must be signed in to change notification settings - Fork 78
/
llama3.py
290 lines (225 loc) · 11.2 KB
/
llama3.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
from __future__ import annotations
import math
import sys
import time
from typing import TypeVar, Generic, Optional
import numpy as np
from config import ModelArgs
from tokenizer import Tokenizer
from utils import load_parameters
Shape = TypeVar("Shape")


class Array(np.ndarray, Generic[Shape]):
    """Annotation-only alias of ``np.ndarray`` carrying a free-form shape label.

    Used as e.g. ``Array["B, L, D"]`` purely for readers; nothing validates the
    label at runtime.
    """
def softmax(x):
    """Return the softmax of ``x`` over its last axis.

    The row maximum is subtracted first so ``np.exp`` cannot overflow for
    large logits; ``-inf`` entries become exact zeros.
    """
    shifted = x - np.max(x, axis=-1, keepdims=True)
    unnormalized = np.exp(shifted)
    normalizer = np.sum(unnormalized, axis=-1, keepdims=True)
    return unnormalized / normalizer
def silu(x):
    """SiLU (swish) activation: ``x * sigmoid(x)``, applied elementwise."""
    sigmoid = 1 / (1 + np.exp(-x))
    return x * sigmoid
def compute_cos_sin_cache(head_dim: int, max_seq_len: int, base: int = 10000):
    """Precompute the rotary-embedding (RoPE) cosine and sine tables.

    Returns ``(cos, sin)``, each of shape ``(max_seq_len, head_dim // 2)``.
    Entry ``[p, i]`` holds the cos/sin of ``p * base ** (-2 * i / head_dim)``.
    """
    # Even feature indices 0, 2, 4, ... give one frequency per rotated pair.
    exponents = 2 * np.arange(head_dim // 2) / head_dim
    inv_freq: Array["HD//2"] = 1.0 / (base ** exponents)
    positions: Array["M"] = np.arange(max_seq_len)
    # Outer product: angle for every (position, frequency) combination.
    angles: Array["M, HD//2"] = positions[:, np.newaxis] * inv_freq[np.newaxis, :]
    return np.cos(angles), np.sin(angles)
def apply_rotary_emb(xq: Array["B, L or 1, QHN, HD"], xk: Array["B, L or 1, KVHN, HD"],
                     freqs_cos: Array["L or 1, HD//2"], freqs_sin: Array["L or 1, HD//2"]):
    """Rotate queries and keys by their position-dependent RoPE angles.

    Adjacent features ``(x[..., 2i], x[..., 2i + 1])`` are treated as the real
    and imaginary parts of a complex number. Each pair is rotated by the angle
    whose cosine/sine sit in row ``p`` of ``freqs_cos``/``freqs_sin`` for
    sequence position ``p``.

    Returns:
        ``(xq, xk)`` rotated, each with its input shape.
    """
    # ["L or 1, HD//2"] -> ["1, L or 1, 1, HD//2"]: broadcast over batch and heads.
    cos = np.expand_dims(freqs_cos, axis=(0, 2))
    sin = np.expand_dims(freqs_sin, axis=(0, 2))

    def rotate_pairs(x):
        # Even-indexed features act as the real part, odd-indexed as the imaginary part.
        real = x[..., 0::2]
        imag = x[..., 1::2]
        rotated_real = real * cos - imag * sin
        rotated_imag = real * sin + imag * cos
        # Re-interleave (real, imag) back into the head dimension.
        interleaved = np.stack([rotated_real, rotated_imag], axis=-1)
        return interleaved.reshape(interleaved.shape[:-2] + (-1,))

    return rotate_pairs(xq), rotate_pairs(xk)
def repeat_kv(x: Array["B, L, KVHN, HD"], n_rep: int):
    """Expand KV heads for grouped-query attention.

    Each head along axis 2 is repeated ``n_rep`` times consecutively, giving
    ``KVHN * n_rep`` heads. When ``n_rep == 1`` the input object is returned
    unchanged (no copy).
    """
    return x if n_rep == 1 else np.repeat(x, n_rep, axis=2)
class FeedForward:
    """SwiGLU feed-forward layer: ``(silu(x @ gate) * (x @ up)) @ down``."""

    def __init__(self, up_weight: Array["FD, D"], gate_weight: Array["FD, D"], down_weight: Array["D, FD"]):
        # Matrices are kept transposed so each projection is a plain ``x @ W``.
        self.up_weight, self.gate_weight, self.down_weight = up_weight.T, gate_weight.T, down_weight.T

    def __call__(self, x: Array["B, L or 1, D"]):
        # FD is the hidden width, taken from the weight shapes.
        hidden: Array["B, L or 1, FD"] = silu(x @ self.gate_weight) * (x @ self.up_weight)
        return hidden @ self.down_weight
class RMSNorm:
    """Root-mean-square normalization over the last axis, with a learned gain."""

    def __init__(self, weight: Array["D"], eps: float):
        # Per-feature gain multiplied in after normalization.
        self.weight = weight
        # Added to the mean square before the square root, guarding against division by zero.
        self.eps = eps

    def __call__(self, x: Array["B, L or 1, D"]):
        mean_square: Array["B, L or 1, 1"] = np.mean(x ** 2, axis=-1, keepdims=True)
        normalized: Array["B, L or 1, D"] = x / np.sqrt(mean_square + self.eps)
        return normalized * self.weight
class Attention:
    """Grouped-query self-attention with rotary embeddings and a per-layer KV cache."""

    def __init__(self, q_weight: Array["QHN*HD, D"], k_weight: Array["KVHN*HD, D"], v_weight: Array["KVHN*HD, D"],
                 o_weight: Array["D, QHN*HD"], args: ModelArgs):
        n_heads = args.n_heads
        n_kv_heads = n_heads if args.n_kv_heads is None else args.n_kv_heads
        # Each KV head must be shared by a whole number of query heads.
        assert n_heads % n_kv_heads == 0
        self.n_kv_heads = n_kv_heads
        self.n_local_heads = n_heads
        self.n_local_kv_heads = n_kv_heads
        self.n_rep = n_heads // n_kv_heads
        self.head_dim = args.dim // n_heads
        # Stored transposed so every projection is a plain ``x @ W``.
        self.q_weight, self.k_weight, self.v_weight, self.o_weight = (
            q_weight.T, k_weight.T, v_weight.T, o_weight.T)
        # Zero-initialized cache holding keys/values for every position up to max_seq_len.
        cache_shape = (args.max_batch_size, args.max_seq_len, n_kv_heads, self.head_dim)
        self.cache_k = np.zeros(cache_shape)
        self.cache_v = np.zeros(cache_shape)

    def __call__(self, x: Array["B, L or 1, D"], start_pos: int, mask: Optional[Array["L, start_pos + L"]],
                 freqs_cos: Array["L or 1, HD//2"], freqs_sin: Array["L or 1, HD//2"]):
        B, L, _ = x.shape
        end = start_pos + L

        # Project and split into heads.
        xq = (x @ self.q_weight).reshape(B, L, self.n_local_heads, self.head_dim)
        xk = (x @ self.k_weight).reshape(B, L, self.n_local_kv_heads, self.head_dim)
        xv = (x @ self.v_weight).reshape(B, L, self.n_local_kv_heads, self.head_dim)

        # RoPE #2
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # Write this chunk into the cache, then attend over everything in [0, end).
        self.cache_k[:B, start_pos:end] = xk
        self.cache_v[:B, start_pos:end] = xv
        keys = repeat_kv(self.cache_k[:B, :end], self.n_rep)
        values = repeat_kv(self.cache_v[:B, :end], self.n_rep)

        # ["B, T, HN, HD"] -> ["B, HN, T, HD"]
        q = xq.transpose(0, 2, 1, 3)
        k = keys.transpose(0, 2, 1, 3)
        v = values.transpose(0, 2, 1, 3)

        # Scaled dot-product attention: ["B, HN, L, HD"] @ ["B, HN, HD, T"] -> ["B, HN, L, T"]
        scores = q @ k.transpose(0, 1, 3, 2) / math.sqrt(self.head_dim)
        # mask is None for single-token steps; otherwise it hides future positions.
        if mask is not None:
            scores = scores + mask[None, None, :, :]
        probs = softmax(scores)

        # ["B, HN, L, HD"] -> ["B, L, HN*HD"] -> output projection.
        context = (probs @ v).transpose(0, 2, 1, 3).reshape(B, L, -1)
        return context @ self.o_weight
class TransformerBlock:
    """One pre-norm decoder layer: attention then SwiGLU FFN, each with a residual add."""

    def __init__(self, weight: dict, layer_id: int, args: ModelArgs):
        prefix = f"model.layers.{layer_id}"

        def param(name):
            # e.g. "model.layers.3.self_attn.q_proj.weight"
            return weight.get(f"{prefix}.{name}.weight")

        self.attention = Attention(
            param("self_attn.q_proj"),
            param("self_attn.k_proj"),
            param("self_attn.v_proj"),
            param("self_attn.o_proj"),
            args,
        )
        self.feed_forward = FeedForward(param("mlp.up_proj"), param("mlp.gate_proj"), param("mlp.down_proj"))
        self.input_layernorm = RMSNorm(param("input_layernorm"), eps=args.norm_eps)
        self.post_attention_layernorm = RMSNorm(param("post_attention_layernorm"), eps=args.norm_eps)

    def __call__(self, x: Array["B, L or 1, D"], start_pos: int, mask: Array["L, L"],
                 freqs_cos: Array["L or 1, HD//2"], freqs_sin: Array["L or 1, HD//2"]):
        # Normalize -> masked attention -> residual.
        attn_out = self.attention(self.input_layernorm(x), start_pos, mask, freqs_cos, freqs_sin)
        hidden: Array["B, L or 1, D"] = x + attn_out
        # Normalize -> feed-forward -> residual.
        return hidden + self.feed_forward(self.post_attention_layernorm(hidden))
class Llama:
    """NumPy Llama decoder: token embedding, TransformerBlocks, final RMSNorm, LM head."""

    def __init__(self, model_path: str, args: ModelArgs):
        self.args = args
        weight = load_parameters(model_path)
        self.tok_embedding: Array["VS, D"] = weight.get("model.embed_tokens.weight")
        # RoPE #1: cos/sin tables for every position < max_seq_len.
        self.freqs_cos, self.freqs_sin = compute_cos_sin_cache(args.dim // args.n_heads, args.max_seq_len)
        self.layers = [TransformerBlock(weight, layer_id, args) for layer_id in range(args.n_layers)]
        self.norm = RMSNorm(weight.get("model.norm.weight"), eps=args.norm_eps)
        lm_head = weight.get("lm_head.weight")
        if lm_head is None:
            # Checkpoints with tied word embeddings ship no separate output
            # projection; reuse the embedding matrix instead of crashing on None.
            lm_head = self.tok_embedding
        self.lm_head_weight: Array["D, VS"] = lm_head.T
        # Drop the local reference to the loaded parameter container.
        del weight

    def __call__(self, input_ids: Array["B, L"], start_pos: int):
        """Run ``input_ids`` (positions ``start_pos .. start_pos + L - 1``) through the model.

        Returns the logits for the last position only, shape ``(B, 1, VS)``.
        """
        _, L = input_ids.shape
        h: Array["B, L or 1, D"] = self.tok_embedding[input_ids]
        # ["M, HD//2"] -> ["L or 1, HD//2"]
        freqs_cos: Array["L or 1, HD//2"] = self.freqs_cos[start_pos: start_pos + L]
        freqs_sin: Array["L or 1, HD//2"] = self.freqs_sin[start_pos: start_pos + L]
        # A multi-token chunk needs a causal mask; a single new token may attend
        # to every cached position, so it gets none.
        mask: Optional[Array["L, start_pos + L"]] = None
        if L > 1:
            causal = np.triu(np.full((L, L), float("-inf")), k=1)
            # Already-cached positions [0, start_pos) are always visible.
            mask = np.concatenate([np.zeros((L, start_pos)), causal], axis=1)
        for layer in self.layers:
            h = layer(h, start_pos, mask, freqs_cos, freqs_sin)
        h = self.norm(h)
        # Only the last position is needed: ["B, 1, D"] @ ["D, VS"] -> ["B, 1, VS"]
        logit: Array["B, 1, VS"] = h[:, [-1], :] @ self.lm_head_weight
        return logit

    def generate(self, input_ids: Array["B, L"], max_new_tokens: int):
        """Greedily decode after ``input_ids``, yielding each token as a ``(B, 1)`` array.

        ``max_new_tokens`` caps the total sequence length (prompt + generated):
        an ``L``-token prompt yields at most ``max_new_tokens - L`` tokens. The
        loop is also bounded so no position reaches ``args.max_seq_len``, the
        size of the KV cache and RoPE tables.
        """
        _, L = input_ids.shape
        # The token yielded at step curr_pos is only written to the cache on the
        # following step (at slot curr_pos - 1), so curr_pos may reach max_seq_len.
        stop = min(max_new_tokens, self.args.max_seq_len + 1)
        inputs, pos = input_ids, 0  # Prefill phase: the whole prompt from position 0.
        for curr_pos in range(L, stop):
            logits: Array["B, 1, VS"] = self(inputs, pos)
            next_id = logits[:, -1, :].argmax(-1, keepdims=True)
            yield next_id
            # Decode phase: next_id sits at position curr_pos, so it is fed at
            # exactly that position (not curr_pos + 1, which would skip a cache slot).
            inputs, pos = next_id, curr_pos
if __name__ == '__main__':
    args = ModelArgs()
    tokenizer = Tokenizer("./tokenizer.model.np")
    model = Llama("./stories15M.model.npz", args)
    # An optional first command-line argument overrides the default prompt.
    prompt = sys.argv[1] if len(sys.argv) > 1 else "I have a dream"
    print(f"\n{prompt}", end="")
    input_ids = np.array([tokenizer.encode(prompt)])
    # perf_counter is monotonic, so the throughput figure is not skewed by clock adjustments.
    start = time.perf_counter()
    # L counts prompt tokens plus every token received from generate().
    _, L = input_ids.shape
    for next_id in model.generate(input_ids, args.max_new_tokens):
        L += 1
        output_id = next_id[0].tolist()
        # Stop on an end- or begin-of-sequence token without printing it.
        if output_id[-1] in [tokenizer.eos_id, tokenizer.bos_id]:
            break
        print(tokenizer.decode(output_id), end="", flush=True)
    elapsed = time.perf_counter() - start
    print(f"\n\nToken count: {L}, elapsed: {elapsed:.2f}s, {round(L / elapsed)} tokens/s")