From 10b7531f83117c7510572cdf4f6b2d736c11f90a Mon Sep 17 00:00:00 2001 From: James Delancey Date: Sun, 2 Jun 2024 21:15:44 -0700 Subject: [PATCH] black/isort formatter --- .gitignore | 1 + export.py | 317 ++++++++++++++++++++++++++++++++++----------------- model.py | 99 ++++++++++------ tokenizer.py | 6 +- 4 files changed, 287 insertions(+), 136 deletions(-) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..b2d4dc68 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +.vscode/ diff --git a/export.py b/export.py index 08b13bef..a3c723d9 100644 --- a/export.py +++ b/export.py @@ -14,12 +14,13 @@ This script aspires to provide all of these conversions. """ -import os + +import argparse import gzip +import json +import os import shutil import struct -import argparse -import json from pathlib import Path import numpy as np @@ -31,18 +32,21 @@ # ----------------------------------------------------------------------------- # common utilities + def serialize_fp32(file, tensor): - """ writes one fp32 tensor to file that is open in wb mode """ + """writes one fp32 tensor to file that is open in wb mode""" d = tensor.detach().cpu().view(-1).to(torch.float32).numpy() - b = struct.pack(f'{len(d)}f', *d) + b = struct.pack(f"{len(d)}f", *d) file.write(b) + def serialize_int8(file, tensor): - """ writes one int8 tensor to file that is open in wb mode """ + """writes one int8 tensor to file that is open in wb mode""" d = tensor.detach().cpu().view(-1).numpy().astype(np.int8) - b = struct.pack(f'{len(d)}b', *d) + b = struct.pack(f"{len(d)}b", *d) file.write(b) + def quantize_q80(w, group_size): """ takes a tensor and returns the Q8_0 quantized version @@ -50,18 +54,18 @@ def quantize_q80(w, group_size): """ assert w.numel() % group_size == 0 ori_shape = w.shape - w = w.float() # convert to float32 + w = w.float() # convert to float32 w = w.reshape(-1, group_size) # find the max in each group wmax = torch.abs(w).max(dim=1).values # calculate the scaling factor such that float = quant * scale scale = wmax / 127.0 # scale into range [-127, 127] - quant = w / scale[:,None] + quant = w / scale[:, None] # round to nearest integer int8val = torch.round(quant).to(torch.int8) # dequantize by rescaling - fp32val = (int8val.float() * scale[:,None]).view(-1) + fp32val = (int8val.float() * scale[:, None]).view(-1) fp32valr = fp32val.reshape(-1, group_size) # calculate the max error in each group err = torch.abs(fp32valr - w).max(dim=1).values @@ -69,12 +73,14 @@ def quantize_q80(w, group_size): maxerr = err.max().item() return int8val, scale, maxerr + # ----------------------------------------------------------------------------- # legacy + def legacy_export(model, filepath): - """ Original export of llama2.c bin files, i.e. version v0 """ - out_file = open(filepath, 'wb') + """Original export of llama2.c bin files, i.e. 
version v0""" + out_file = open(filepath, "wb") # first write out the header hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0] @@ -84,8 +90,16 @@ def legacy_export(model, filepath): if not shared_classifier: p.vocab_size = -p.vocab_size n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads - header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads, - n_kv_heads, p.vocab_size, p.max_seq_len) + header = struct.pack( + "iiiiiii", + p.dim, + hidden_dim, + p.n_layers, + p.n_heads, + n_kv_heads, + p.vocab_size, + p.max_seq_len, + ) out_file.write(header) # next write out the embedding weights @@ -115,8 +129,8 @@ def legacy_export(model, filepath): # final rmsnorm serialize_fp32(out_file, model.norm.weight) # freqs_cis - serialize_fp32(out_file, model.freqs_cos[:p.max_seq_len]) - serialize_fp32(out_file, model.freqs_sin[:p.max_seq_len]) + serialize_fp32(out_file, model.freqs_cos[: p.max_seq_len]) + serialize_fp32(out_file, model.freqs_sin[: p.max_seq_len]) # final classifier weights if not shared_classifier: @@ -126,9 +140,11 @@ def legacy_export(model, filepath): out_file.close() print(f"wrote {filepath}") + # ----------------------------------------------------------------------------- # new version + def version1_export(model, filepath): """ Export the model weights in full float32 .bin file to be read from C. @@ -136,25 +152,33 @@ def version1_export(model, filepath): """ version = 1 - out_file = open(filepath, 'wb') + out_file = open(filepath, "wb") # first write out the header. the header will be 256 bytes # 1) write magic, which will be uint32 of "ak42" in ASCII - out_file.write(struct.pack('I', 0x616b3432)) + out_file.write(struct.pack("I", 0x616B3432)) # 2) write version, which will be int - out_file.write(struct.pack('i', version)) + out_file.write(struct.pack("i", version)) # 3) write the params, which will be 7 ints p = model.params hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0] n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads - header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads, - n_kv_heads, p.vocab_size, p.max_seq_len) + header = struct.pack( + "iiiiiii", + p.dim, + hidden_dim, + p.n_layers, + p.n_heads, + n_kv_heads, + p.vocab_size, + p.max_seq_len, + ) out_file.write(header) # 4) write some other flags shared_classifier = torch.equal(model.tok_embeddings.weight, model.output.weight) - out_file.write(struct.pack('B', int(shared_classifier))) - pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos + out_file.write(struct.pack("B", int(shared_classifier))) + pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos assert pad >= 0 - out_file.write(b'\0' * pad) + out_file.write(b"\0" * pad) # now let's write out all the params weights = [ @@ -179,6 +203,7 @@ def version1_export(model, filepath): out_file.close() print(f"wrote {filepath}") + def version2_export(model, filepath, group_size=64): """ Export the model weights in Q8_0 into .bin file to be read from C. @@ -207,36 +232,46 @@ def version2_export(model, filepath, group_size=64): if not shared_classifier: weights.append(model.output.weight) for w in weights: - assert w.numel() % group_size == 0, f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}" + assert ( + w.numel() % group_size == 0 + ), f"weight {i} has numel {w.numel()}, not a multiple of group_size {group_size}" # write - out_file = open(filepath, 'wb') + out_file = open(filepath, "wb") # first write out the header. 
the header will be 256 bytes # 1) write magic, which will be uint32 of "ak42" in ASCII - out_file.write(struct.pack('I', 0x616b3432)) + out_file.write(struct.pack("I", 0x616B3432)) # 2) write version, which will be int - out_file.write(struct.pack('i', version)) + out_file.write(struct.pack("i", version)) # 3) write the params, which will be 7 ints p = model.params hidden_dim = model.layers[0].feed_forward.w1.weight.shape[0] n_kv_heads = p.n_heads if p.n_kv_heads is None else p.n_kv_heads - header = struct.pack('iiiiiii', p.dim, hidden_dim, p.n_layers, p.n_heads, - n_kv_heads, p.vocab_size, p.max_seq_len) + header = struct.pack( + "iiiiiii", + p.dim, + hidden_dim, + p.n_layers, + p.n_heads, + n_kv_heads, + p.vocab_size, + p.max_seq_len, + ) out_file.write(header) # 4) write some other flags - out_file.write(struct.pack('B', int(shared_classifier))) - out_file.write(struct.pack('i', group_size)) # group size used for quantization - pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos + out_file.write(struct.pack("B", int(shared_classifier))) + out_file.write(struct.pack("i", group_size)) # group size used for quantization + pad = 256 - out_file.tell() # pad rest with zeros; tell returns current pos assert pad >= 0 - out_file.write(b'\0' * pad) + out_file.write(b"\0" * pad) # now that the header is done, let's write out the model # first let's write out all the params that we are keeping in fp32: the norms - for layer in model.layers: # attention norms + for layer in model.layers: # attention norms serialize_fp32(out_file, layer.attention_norm.weight) - for layer in model.layers: # MLP norms + for layer in model.layers: # MLP norms serialize_fp32(out_file, layer.ffn_norm.weight) - serialize_fp32(out_file, model.norm.weight) # final pre-classifier norm + serialize_fp32(out_file, model.norm.weight) # final pre-classifier norm # now let's write out all the params that we are quantizing to Q8_0 # note we skip classifier weights, which are shared with the embedding @@ -245,11 +280,13 @@ def version2_export(model, filepath, group_size=64): # quantize this weight q, s, err = quantize_q80(w, group_size) # save the int8 weights to file - serialize_int8(out_file, q) # save the tensor in int8 - serialize_fp32(out_file, s) # save scale factors + serialize_int8(out_file, q) # save the tensor in int8 + serialize_fp32(out_file, s) # save scale factors # logging ew.append((err, w.shape)) - print(f"{i+1}/{len(weights)} quantized {tuple(w.shape)} to Q8_0 with max error {err}") + print( + f"{i+1}/{len(weights)} quantized {tuple(w.shape)} to Q8_0 with max error {err}" + ) # print the highest error across all weights, should be very small, e.g. 
O(~0.001) ew.sort(reverse=True) @@ -259,8 +296,9 @@ def version2_export(model, filepath, group_size=64): out_file.close() print(f"wrote {filepath}") + def hf_export(llama_model, filepath, group_size=64, dtype=torch.float32): - """ Generate the pytorch_model.bin state_dict and config.json for HuggingFace """ + """Generate the pytorch_model.bin state_dict and config.json for HuggingFace""" try: from transformers.models.llama.configuration_llama import LlamaConfig @@ -269,7 +307,7 @@ def hf_export(llama_model, filepath, group_size=64, dtype=torch.float32): print("Please run `pip install transformers` to install it") return None - # Generate LlamaModel state_dict + # Generate LlamaModel state_dict hf_state_dict = {} # Sometimes we have repeated key values for the heads @@ -281,33 +319,62 @@ def hf_export(llama_model, filepath, group_size=64, dtype=torch.float32): # HuggingFace needs the weights permuted. # See: https://github.com/huggingface/transformers/blob/b132c1703eb1c8bd9dfa4ad6a9be2bfd6ef819e9/src/transformers/models/llama/convert_llama_weights_to_hf.py#L122 def permute_original(w, n_heads=llama_model.params.n_heads, dim1=dim, dim2=dim): - return w.view(dim1, dim2).reshape(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + return ( + w.view(dim1, dim2) + .reshape(n_heads, dim1 // n_heads // 2, 2, dim2) + .transpose(1, 2) + .reshape(dim1, dim2) + ) # Transfer weights from llama model to the HF state dictionary format - hf_state_dict['model.embed_tokens.weight'] = llama_model.tok_embeddings.weight.clone().to(dtype) - hf_state_dict['model.norm.weight'] = llama_model.norm.weight.clone().to(dtype) + hf_state_dict["model.embed_tokens.weight"] = ( + llama_model.tok_embeddings.weight.clone().to(dtype) + ) + hf_state_dict["model.norm.weight"] = llama_model.norm.weight.clone().to(dtype) # Add each layer's weights to the HF state dictionary for i, layer in enumerate(llama_model.layers): layer_id = layer.layer_id - hf_state_dict[f'model.layers.{i}.input_layernorm.weight'] = llama_model.layers[layer_id].attention_norm.weight.clone().to(dtype) - hf_state_dict[f'model.layers.{i}.self_attn.q_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wq.weight.clone()).to(dtype) - hf_state_dict[f'model.layers.{i}.self_attn.k_proj.weight'] = permute_original(llama_model.layers[layer_id].attention.wk.weight.clone(), num_key_value_heads, key_value_dim, dim).to(dtype) - hf_state_dict[f'model.layers.{i}.self_attn.v_proj.weight'] = llama_model.layers[layer_id].attention.wv.weight.clone().to(dtype) - hf_state_dict[f'model.layers.{i}.self_attn.o_proj.weight'] = llama_model.layers[layer_id].attention.wo.weight.clone().to(dtype) - hf_state_dict[f'model.layers.{i}.post_attention_layernorm.weight'] = llama_model.layers[layer_id].ffn_norm.weight.clone().to(dtype) - hf_state_dict[f'model.layers.{i}.mlp.gate_proj.weight'] = llama_model.layers[layer_id].feed_forward.w1.weight.clone().to(dtype) - hf_state_dict[f'model.layers.{i}.mlp.down_proj.weight'] = llama_model.layers[layer_id].feed_forward.w2.weight.clone().to(dtype) - hf_state_dict[f'model.layers.{i}.mlp.up_proj.weight'] = llama_model.layers[layer_id].feed_forward.w3.weight.clone().to(dtype) + hf_state_dict[f"model.layers.{i}.input_layernorm.weight"] = ( + llama_model.layers[layer_id].attention_norm.weight.clone().to(dtype) + ) + hf_state_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = permute_original( + llama_model.layers[layer_id].attention.wq.weight.clone() + ).to(dtype) + 
hf_state_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = permute_original( + llama_model.layers[layer_id].attention.wk.weight.clone(), + num_key_value_heads, + key_value_dim, + dim, + ).to(dtype) + hf_state_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = ( + llama_model.layers[layer_id].attention.wv.weight.clone().to(dtype) + ) + hf_state_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = ( + llama_model.layers[layer_id].attention.wo.weight.clone().to(dtype) + ) + hf_state_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = ( + llama_model.layers[layer_id].ffn_norm.weight.clone().to(dtype) + ) + hf_state_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = ( + llama_model.layers[layer_id].feed_forward.w1.weight.clone().to(dtype) + ) + hf_state_dict[f"model.layers.{i}.mlp.down_proj.weight"] = ( + llama_model.layers[layer_id].feed_forward.w2.weight.clone().to(dtype) + ) + hf_state_dict[f"model.layers.{i}.mlp.up_proj.weight"] = ( + llama_model.layers[layer_id].feed_forward.w3.weight.clone().to(dtype) + ) # llama2.c usually uses tied weights -> reference the embed_tokens.weights instead - hf_state_dict['lm_head.weight'] = hf_state_dict['model.embed_tokens.weight'] + hf_state_dict["lm_head.weight"] = hf_state_dict["model.embed_tokens.weight"] # We check that the embeddings are tied, else use manual output weights - _embeddings_are_tied: bool = torch.equal(llama_model.tok_embeddings.weight, llama_model.output.weight) + _embeddings_are_tied: bool = torch.equal( + llama_model.tok_embeddings.weight, llama_model.output.weight + ) if not _embeddings_are_tied: - hf_state_dict['lm_head.weight'] = llama_model.output.weight.clone().to(dtype) - + hf_state_dict["lm_head.weight"] = llama_model.output.weight.clone().to(dtype) # Generate LlamaConfig (seen in transformers.models.llama.configuration_llama) @@ -340,9 +407,8 @@ def permute_original(w, n_heads=llama_model.params.n_heads, dim1=dim, dim2=dim): hidden_act="silu", ) - # Save files in directory filepath - # First make the directory if it doesn't exist + # First make the directory if it doesn't exist os.makedirs(filepath, exist_ok=True) # Save the state dictionary in .bin format, and config as .json @@ -353,29 +419,31 @@ def permute_original(w, n_heads=llama_model.params.n_heads, dim1=dim, dim2=dim): # ----------------------------------------------------------------------------- # Load / import functions + def load_checkpoint(checkpoint): # load the provided model checkpoint - checkpoint_dict = torch.load(checkpoint, map_location='cpu') - gptconf = ModelArgs(**checkpoint_dict['model_args']) + checkpoint_dict = torch.load(checkpoint, map_location="cpu") + gptconf = ModelArgs(**checkpoint_dict["model_args"]) model = Transformer(gptconf) - state_dict = checkpoint_dict['model'] - unwanted_prefix = '_orig_mod.' - for k,v in list(state_dict.items()): + state_dict = checkpoint_dict["model"] + unwanted_prefix = "_orig_mod." 
+ for k, v in list(state_dict.items()): if k.startswith(unwanted_prefix): - state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) + state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k) model.load_state_dict(state_dict, strict=False) model.eval() return model + def load_meta_model(model_path): - params_path = os.path.join(model_path, 'params.json') + params_path = os.path.join(model_path, "params.json") with open(params_path) as f: params = json.load(f) print(params) - model_paths = sorted(list(Path(model_path).glob('consolidated.*.pth'))) - models = [torch.load(p, map_location='cpu') for p in model_paths] + model_paths = sorted(list(Path(model_path).glob("consolidated.*.pth"))) + models = [torch.load(p, map_location="cpu") for p in model_paths] def concat_weights(models): state_dict = {} @@ -385,9 +453,9 @@ def concat_weights(models): state_dict[name] = tensors[0] continue is_axis_1 = ( - name.startswith('tok_embeddings.') - or name.endswith('.attention.wo.weight') - or name.endswith('.feed_forward.w2.weight') + name.startswith("tok_embeddings.") + or name.endswith(".attention.wo.weight") + or name.endswith(".feed_forward.w2.weight") ) axis = 1 if is_axis_1 else 0 state_dict[name] = torch.cat(tensors, dim=axis) @@ -403,37 +471,53 @@ def concat_weights(models): config.dim = params["dim"] config.n_layers = params["n_layers"] config.n_heads = params["n_heads"] - config.n_kv_heads = params.get('n_kv_heads') or params['n_heads'] + config.n_kv_heads = params.get("n_kv_heads") or params["n_heads"] config.multiple_of = params["multiple_of"] config.norm_eps = params["norm_eps"] - config.vocab_size = state_dict['tok_embeddings.weight'].shape[0] + config.vocab_size = state_dict["tok_embeddings.weight"].shape[0] config.max_seq_len = 2048 - # create a new Transformer object and set weights model = Transformer(config) - model.tok_embeddings.weight = nn.Parameter(state_dict['tok_embeddings.weight']) - model.norm.weight = nn.Parameter(state_dict['norm.weight']) + model.tok_embeddings.weight = nn.Parameter(state_dict["tok_embeddings.weight"]) + model.norm.weight = nn.Parameter(state_dict["norm.weight"]) for layer in model.layers: i = layer.layer_id - layer.attention_norm.weight = nn.Parameter(state_dict[f'layers.{i}.attention_norm.weight']) - layer.attention.wq.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wq.weight']) - layer.attention.wk.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wk.weight']) - layer.attention.wv.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wv.weight']) - layer.attention.wo.weight = nn.Parameter(state_dict[f'layers.{i}.attention.wo.weight']) - layer.ffn_norm.weight = nn.Parameter(state_dict[f'layers.{i}.ffn_norm.weight']) - layer.feed_forward.w1.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w1.weight']) - layer.feed_forward.w2.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w2.weight']) - layer.feed_forward.w3.weight = nn.Parameter(state_dict[f'layers.{i}.feed_forward.w3.weight']) + layer.attention_norm.weight = nn.Parameter( + state_dict[f"layers.{i}.attention_norm.weight"] + ) + layer.attention.wq.weight = nn.Parameter( + state_dict[f"layers.{i}.attention.wq.weight"] + ) + layer.attention.wk.weight = nn.Parameter( + state_dict[f"layers.{i}.attention.wk.weight"] + ) + layer.attention.wv.weight = nn.Parameter( + state_dict[f"layers.{i}.attention.wv.weight"] + ) + layer.attention.wo.weight = nn.Parameter( + state_dict[f"layers.{i}.attention.wo.weight"] + ) + layer.ffn_norm.weight = 
nn.Parameter(state_dict[f"layers.{i}.ffn_norm.weight"]) + layer.feed_forward.w1.weight = nn.Parameter( + state_dict[f"layers.{i}.feed_forward.w1.weight"] + ) + layer.feed_forward.w2.weight = nn.Parameter( + state_dict[f"layers.{i}.feed_forward.w2.weight"] + ) + layer.feed_forward.w3.weight = nn.Parameter( + state_dict[f"layers.{i}.feed_forward.w3.weight"] + ) # final classifier - model.output.weight = nn.Parameter(state_dict['output.weight']) + model.output.weight = nn.Parameter(state_dict["output.weight"]) model.eval() return model + def load_hf_model(model_path): try: @@ -461,27 +545,49 @@ def load_hf_model(model_path): # create a new Transformer object and set weights model = Transformer(config) - model.tok_embeddings.weight = nn.Parameter(hf_dict['model.embed_tokens.weight']) - model.norm.weight = nn.Parameter(hf_dict['model.norm.weight']) + model.tok_embeddings.weight = nn.Parameter(hf_dict["model.embed_tokens.weight"]) + model.norm.weight = nn.Parameter(hf_dict["model.norm.weight"]) # huggingface permutes WQ and WK, this function reverses it def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim): - return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2) + return ( + w.view(n_heads, 2, dim1 // n_heads // 2, dim2) + .transpose(1, 2) + .reshape(dim1, dim2) + ) for layer in model.layers: i = layer.layer_id - layer.attention_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.input_layernorm.weight']) - layer.attention.wq.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.q_proj.weight'])) - layer.attention.wk.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.k_proj.weight'])) - layer.attention.wv.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.v_proj.weight']) - layer.attention.wo.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.o_proj.weight']) - layer.ffn_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.post_attention_layernorm.weight']) - layer.feed_forward.w1.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.gate_proj.weight']) - layer.feed_forward.w2.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.down_proj.weight']) - layer.feed_forward.w3.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.up_proj.weight']) + layer.attention_norm.weight = nn.Parameter( + hf_dict[f"model.layers.{i}.input_layernorm.weight"] + ) + layer.attention.wq.weight = nn.Parameter( + permute_reverse(hf_dict[f"model.layers.{i}.self_attn.q_proj.weight"]) + ) + layer.attention.wk.weight = nn.Parameter( + permute_reverse(hf_dict[f"model.layers.{i}.self_attn.k_proj.weight"]) + ) + layer.attention.wv.weight = nn.Parameter( + hf_dict[f"model.layers.{i}.self_attn.v_proj.weight"] + ) + layer.attention.wo.weight = nn.Parameter( + hf_dict[f"model.layers.{i}.self_attn.o_proj.weight"] + ) + layer.ffn_norm.weight = nn.Parameter( + hf_dict[f"model.layers.{i}.post_attention_layernorm.weight"] + ) + layer.feed_forward.w1.weight = nn.Parameter( + hf_dict[f"model.layers.{i}.mlp.gate_proj.weight"] + ) + layer.feed_forward.w2.weight = nn.Parameter( + hf_dict[f"model.layers.{i}.mlp.down_proj.weight"] + ) + layer.feed_forward.w3.weight = nn.Parameter( + hf_dict[f"model.layers.{i}.mlp.up_proj.weight"] + ) # final classifier - model.output.weight = nn.Parameter(hf_dict['lm_head.weight']) + model.output.weight = nn.Parameter(hf_dict["lm_head.weight"]) model.eval() return model @@ -489,6 +595,7 @@ def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim) # 
----------------------------------------------------------------------------- # API entrypoint + def model_export(model, filepath, version, dtype=torch.float32): """ Versions docs: @@ -509,6 +616,7 @@ def model_export(model, filepath, version, dtype=torch.float32): else: raise ValueError(f"unknown version {version}") + def torchscript_export(model, filepath, zero_params=False, gzip_output=False): """ (This was submitted via a PR earlier. Leaving it here, but "orphaned" for now) @@ -537,6 +645,7 @@ def torchscript_export(model, filepath, zero_params=False, gzip_output=False): shutil.copyfileobj(f_in, f_out) os.unlink(filepath) + # ----------------------------------------------------------------------------- # CLI entrypoint @@ -544,8 +653,12 @@ def torchscript_export(model, filepath, zero_params=False, gzip_output=False): parser = argparse.ArgumentParser() parser.add_argument("filepath", type=str, help="the output filepath") - parser.add_argument("--version", default=0, type=int, help="the version to export with") - parser.add_argument("--dtype", type=str, help="dtype of the model (fp16, fp32)", default="fp32") + parser.add_argument( + "--version", default=0, type=int, help="the version to export with" + ) + parser.add_argument( + "--dtype", type=str, help="dtype of the model (fp16, fp32)", default="fp32" + ) group = parser.add_mutually_exclusive_group(required=True) group.add_argument("--checkpoint", type=str, help="model checkpoint, .pt file") group.add_argument("--meta-llama", type=str, help="meta llama model path") diff --git a/model.py b/model.py index 9e4ce220..5d15d3df 100644 --- a/model.py +++ b/model.py @@ -1,6 +1,6 @@ +import inspect import math import struct -import inspect from dataclasses import dataclass from typing import Any, Optional, Tuple @@ -9,6 +9,7 @@ import torch.nn.functional as F from torch import nn + @dataclass class ModelArgs: # default hyperparameters for the Llama 7B model @@ -46,6 +47,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): freqs_sin = torch.sin(freqs) # imaginary part return freqs_cos, freqs_sin + def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): ndim = x.ndim assert 0 <= 1 < ndim @@ -53,11 +55,9 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor): shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] return freqs_cis.view(shape) + def apply_rotary_emb( - xq: torch.Tensor, - xk: torch.Tensor, - freqs_cos: torch.Tensor, - freqs_sin: torch.Tensor + xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor]: # reshape xq and xk to match the complex representation @@ -80,6 +80,7 @@ def apply_rotary_emb( return xq_out.type_as(xq), xk_out.type_as(xk) + def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" bs, slen, n_kv_heads, head_dim = x.shape @@ -91,6 +92,7 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: .reshape(bs, slen, n_kv_heads * n_rep, head_dim) ) + class Attention(nn.Module): def __init__(self, args: ModelArgs): super().__init__() @@ -110,9 +112,11 @@ def __init__(self, args: ModelArgs): self.dropout = args.dropout # use flash attention or a manual implementation? - self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') + self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention") if not self.flash: - print("WARNING: using slow attention. 
Flash Attention requires PyTorch >= 2.0") + print( + "WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0" + ) mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf")) mask = torch.triu(mask, diagonal=1) self.register_buffer("mask", mask) @@ -145,12 +149,21 @@ def forward( # flash implementation if self.flash: - output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True) + output = torch.nn.functional.scaled_dot_product_attention( + xq, + xk, + xv, + attn_mask=None, + dropout_p=self.dropout if self.training else 0.0, + is_causal=True, + ) else: # manual implementation scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim) - assert hasattr(self, 'mask') - scores = scores + self.mask[:, :, :seqlen, :seqlen] # (bs, n_local_heads, seqlen, cache_len + seqlen) + assert hasattr(self, "mask") + scores = ( + scores + self.mask[:, :, :seqlen, :seqlen] + ) # (bs, n_local_heads, seqlen, cache_len + seqlen) scores = F.softmax(scores.float(), dim=-1).type_as(xq) scores = self.attn_dropout(scores) output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim) @@ -221,10 +234,14 @@ def __init__(self, params: ModelArgs): self.output = nn.Linear(params.dim, params.vocab_size, bias=False) # share the unembedding parameters with the embedding parameters - self.tok_embeddings.weight = self.output.weight # https://paperswithcode.com/method/weight-tying + self.tok_embeddings.weight = ( + self.output.weight + ) # https://paperswithcode.com/method/weight-tying # some useful precompute for the RoPE relative positional embeddings - freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len) + freqs_cos, freqs_sin = precompute_freqs_cis( + self.params.dim // self.params.n_heads, self.params.max_seq_len + ) self.register_buffer("freqs_cos", freqs_cos, persistent=False) self.register_buffer("freqs_sin", freqs_sin, persistent=False) @@ -232,8 +249,10 @@ def __init__(self, params: ModelArgs): self.apply(self._init_weights) # apply special scaled init to the residual projections, per GPT-2 paper for pn, p in self.named_parameters(): - if pn.endswith('w3.weight') or pn.endswith('wo.weight'): - torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers)) + if pn.endswith("w3.weight") or pn.endswith("wo.weight"): + torch.nn.init.normal_( + p, mean=0.0, std=0.02 / math.sqrt(2 * params.n_layers) + ) # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor. 
self.last_loss = None @@ -246,7 +265,9 @@ def _init_weights(self, module): elif isinstance(module, nn.Embedding): torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) - def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None + ) -> torch.Tensor: _bsz, seqlen = tokens.shape h = self.tok_embeddings(tokens) h = self.dropout(h) @@ -260,10 +281,14 @@ def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) if targets is not None: # if we are given some desired targets also calculate the loss logits = self.output(h) - self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) + self.last_loss = F.cross_entropy( + logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1 + ) else: # inference-time mini-optimization: only forward the output on the very last position - logits = self.output(h[:, [-1], :]) # note: using list [-1] to preserve the time dim + logits = self.output( + h[:, [-1], :] + ) # note: using list [-1] to preserve the time dim self.last_loss = None return logits @@ -278,35 +303,41 @@ def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] optim_groups = [ - {'params': decay_params, 'weight_decay': weight_decay}, - {'params': nodecay_params, 'weight_decay': 0.0} + {"params": decay_params, "weight_decay": weight_decay}, + {"params": nodecay_params, "weight_decay": 0.0}, ] num_decay_params = sum(p.numel() for p in decay_params) num_nodecay_params = sum(p.numel() for p in nodecay_params) - print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") - print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") + print( + f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters" + ) + print( + f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters" + ) # Create AdamW optimizer and use the fused version if it is available - fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters - use_fused = fused_available and device_type == 'cuda' + fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters + use_fused = fused_available and device_type == "cuda" extra_args = dict(fused=True) if use_fused else dict() - optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) + optimizer = torch.optim.AdamW( + optim_groups, lr=learning_rate, betas=betas, **extra_args + ) print(f"using fused AdamW: {use_fused}") return optimizer def estimate_mfu(self, fwdbwd_per_iter, dt): - """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """ + """estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS""" # first estimate the number of flops we do per iteration. 
# see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 N = sum(p.numel() for p in self.parameters()) cfg = self.params - L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim//cfg.n_heads, cfg.max_seq_len - flops_per_token = 6*N + 12*L*H*Q*T + L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim // cfg.n_heads, cfg.max_seq_len + flops_per_token = 6 * N + 12 * L * H * Q * T flops_per_fwdbwd = flops_per_token * T flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter # express our flops throughput as ratio of A100 bfloat16 peak flops - flops_achieved = flops_per_iter * (1.0/dt) # per second - flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS + flops_achieved = flops_per_iter * (1.0 / dt) # per second + flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS mfu = flops_achieved / flops_promised return mfu @@ -320,10 +351,14 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): """ for _ in range(max_new_tokens): # if the sequence context is growing too long we must crop it at block_size - idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:] + idx_cond = ( + idx + if idx.size(1) <= self.params.max_seq_len + else idx[:, -self.params.max_seq_len :] + ) # forward the model to get the logits for the index in the sequence logits = self(idx_cond) - logits = logits[:, -1, :] # crop to just the final time step + logits = logits[:, -1, :] # crop to just the final time step if temperature == 0.0: # "sample" the single most likely index _, idx_next = torch.topk(logits, k=1, dim=-1) @@ -333,7 +368,7 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): # optionally crop the logits to only the top k options if top_k is not None: v, _ = torch.topk(logits, min(top_k, logits.size(-1))) - logits[logits < v[:, [-1]]] = -float('Inf') + logits[logits < v[:, [-1]]] = -float("Inf") # apply softmax to convert logits to (normalized) probabilities probs = F.softmax(logits, dim=-1) idx_next = torch.multinomial(probs, num_samples=1) diff --git a/tokenizer.py b/tokenizer.py index d7f04c02..5cdde2ba 100644 --- a/tokenizer.py +++ b/tokenizer.py @@ -2,10 +2,10 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 3 Community License Agreement. +import argparse import array import os import struct -import argparse from pathlib import Path from typing import List @@ -107,7 +107,9 @@ def export(self): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("-t", "--tokenizer-model", type=str, help="optional path to custom tokenizer ") + parser.add_argument( + "-t", "--tokenizer-model", type=str, help="optional path to custom tokenizer " + ) args = parser.parse_args()
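
--
A minimal sketch of how the formatting in this patch could be reproduced or checked, assuming isort and black with their default settings (the patch does not record the exact tool versions or configuration used); the check_format.py helper below is hypothetical and not part of the patch:

    # check_format.py (hypothetical helper, not included in this patch)
    # Runs isort then black on each source file in memory, assuming default
    # settings for both tools, and reports whether the file on disk already
    # matches the formatted output.
    from pathlib import Path

    import black
    import isort

    for name in ("export.py", "model.py", "tokenizer.py"):
        src = Path(name).read_text()
        # isort.code() sorts the imports; black.format_str() reformats the result.
        formatted = black.format_str(isort.code(src), mode=black.Mode())
        print(name, "ok" if formatted == src else "needs reformatting")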