From 2d50f81ec7398f6de4c80c0797f9ab2d5ac63235 Mon Sep 17 00:00:00 2001
From: Icecream95
Date: Sat, 18 May 2024 19:41:42 +1200
Subject: [PATCH] Initial OpenELM support (270M only so far)

---
 convert-hf-to-gguf.py          |  44 ++++++
 gguf-py/gguf/constants.py      |  14 ++
 gguf-py/gguf/tensor_mapping.py |  17 ++-
 llama.cpp                      | 250 +++++++++++++++++++++++++++++++++
 4 files changed, 322 insertions(+), 3 deletions(-)

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index cd1750aa3f3ba2..a128dd88f87807 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -2395,6 +2395,50 @@ def set_vocab(self, *args, **kwargs):
         self.gguf_writer.add_add_eos_token(True)


+@Model.register("OpenELMForCausalLM")
+class OpenELMModel(Model):
+    model_arch = gguf.MODEL_ARCH.OPENELM
+
+    # Copied from LlamaModel
+    def set_vocab(self):
+        try:
+            self._set_vocab_sentencepiece()
+        except FileNotFoundError:
+            self._set_vocab_llama_hf()
+
+    def set_gguf_parameters(self):
+        # TODO: Look closer at these
+
+        self.gguf_writer.add_name("OpenELM")
+        self.block_count = self.find_hparam(["num_transformer_layers"])
+        self.gguf_writer.add_layer_norm_eps(1e-5)
+        # https://huggingface.co/apple/OpenELM-270M-Instruct/blob/c401df2/modeling_openelm.py#L30
+        self.gguf_writer.add_layer_norm_rms_eps(1e-6)
+        n_embd = self.find_hparam(["model_dim"])
+        self.gguf_writer.add_embedding_length(n_embd)
+        head_dim = self.find_hparam(["head_dim"])
+        n_head = n_embd // head_dim
+        rot_pct = 1.0
+        self.gguf_writer.add_context_length(self.find_hparam(["max_context_length"]))
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_head_count_kv(n_head*10)
+        self.gguf_writer.add_head_count(n_head*10)
+        self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head)
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_feed_forward_length(0)  # dynamically calculated
+
+    def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
+        # TODO: Read configuration!
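+        # NOTE: the hard-coded values below correspond to the 270M checkpoint
+        # (16 transformer layers, model_dim 1280, head_dim 64); the other
+        # model sizes need the real config to be read here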
+        if "n_layers" in keys:
+            return 16    # num_transformer_layers
+        if "hidden_size" in keys:
+            return 1280  # model_dim
+        if "num_attention_heads" in keys:
+            return 64    # head_dim
+
+        return super().find_hparam(keys, optional)
+
+
 ###### CONVERSION LOGIC ######
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 978fcada3b42c5..fcd9995eb08abf 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -139,6 +139,7 @@ class MODEL_ARCH(IntEnum):
     COMMAND_R = auto()
     DBRX = auto()
     OLMO = auto()
+    OPENELM = auto()


 class MODEL_TENSOR(IntEnum):
@@ -217,6 +218,7 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.COMMAND_R: "command-r",
     MODEL_ARCH.DBRX: "dbrx",
     MODEL_ARCH.OLMO: "olmo",
+    MODEL_ARCH.OPENELM: "openelm",
 }

 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -743,6 +745,18 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.OPENELM: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_QKV,
+        MODEL_TENSOR.ATTN_Q_NORM,
+        MODEL_TENSOR.ATTN_K_NORM,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_UP,
+        MODEL_TENSOR.FFN_DOWN,
+    ],
     # TODO
 }

diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py
index 8e1cac9152f55e..58167663fc633a 100644
--- a/gguf-py/gguf/tensor_mapping.py
+++ b/gguf-py/gguf/tensor_mapping.py
@@ -24,6 +24,7 @@ class TensorNameMap:
             "backbone.embedding",           # mamba
             "backbone.embeddings",          # mamba-hf
             "transformer.in_out_embed",     # Grok
+            "transformer.token_embeddings", # openelm
         ),

         # Token type embeddings
@@ -36,6 +37,7 @@ class TensorNameMap:
             "word_embeddings_layernorm",  # bloom
             "embeddings.LayerNorm",       # bert
             "emb_ln",                     # nomic-bert
+            "transformer.norm",           # openelm
         ),

         # Position embeddings
@@ -68,6 +70,7 @@ class TensorNameMap:
             "model.norm_f",         # mamba-qbert
             "backbone.norm_f",      # mamba
             "transformer.rms_norm", # Grok
+            "transformer.norm",     # openelm
         ),

         # Rope frequencies
@@ -97,6 +100,7 @@ class TensorNameMap:
             "backbone.layers.{bid}.norm",                     # mamba
             "transformer.decoder_layer.{bid}.rms_norm",       # Grok
             "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
+            "transformer.layers.{bid}.attn_norm",             # openelm
         ),

         # Attention norm 2
@@ -117,7 +121,8 @@ class TensorNameMap:
             "h.{bid}.attn.c_attn",                    # gpt2
             "transformer.h.{bid}.mixer.Wqkv",         # phi2
             "encoder.layers.{bid}.attn.Wqkv",         # nomic-bert
-            "model.layers.{bid}.self_attn.qkv_proj"   # phi3
+            "model.layers.{bid}.self_attn.qkv_proj",  # phi3
+            "transformer.layers.{bid}.attn.qkv_proj", # openelm
         ),

         # Attention query
@@ -175,6 +180,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.attn.out_proj",                           # nomic-bert
             "transformer.decoder_layer.{bid}.multi_head_attention.linear",  # Grok
             "transformer.blocks.{bid}.norm_attn_norm.attn.out_proj",        # dbrx
+            "transformer.layers.{bid}.attn.out_proj",                       # openelm
         ),

         # Attention output norm
@@ -206,6 +212,7 @@ class TensorNameMap:
             "h.{bid}.ln_2",                               # gpt2
             "model.layers.{bid}.ffn_norm",                # internlm2
             "transformer.decoder_layer.{bid}.rms_norm_2", # Grok
+            "transformer.layers.{bid}.ffn_norm",          # openelm
         ),

         MODEL_TENSOR.FFN_GATE_INP: (
@@ -244,6 +251,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.mlp.fc11",          # nomic-bert
             "model.layers.{bid}.mlp.c_fc",            # starcoder2
             "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
+            "transformer.layers.{bid}.ffn.proj_1",    # openelm
         ),

         MODEL_TENSOR.FFN_UP_EXP: (
@@ -306,6 +314,7 @@ class TensorNameMap:
             "encoder.layers.{bid}.mlp.fc2",        # nomic-bert
             "model.layers.{bid}.mlp.c_proj",       # starcoder2
             "encoder.layer.{bid}.mlp.wo",          # jina-bert-v2
+            "transformer.layers.{bid}.ffn.proj_2", # openelm
         ),

         MODEL_TENSOR.FFN_DOWN_EXP: (
@@ -324,7 +333,8 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.q_layernorm",       # persimmon
             "model.layers.{bid}.self_attn.q_norm",            # cohere
             "transformer.blocks.{bid}.attn.q_ln",             # sea-lion
-            "encoder.layer.{bid}.attention.self.layer_norm_q" # jina-bert-v2
+            "encoder.layer.{bid}.attention.self.layer_norm_q", # jina-bert-v2
+            "transformer.layers.{bid}.attn.q_norm",             # openelm
         ),

         MODEL_TENSOR.ATTN_K_NORM: (
@@ -332,7 +342,8 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.k_layernorm",       # persimmon
             "model.layers.{bid}.self_attn.k_norm",            # cohere
             "transformer.blocks.{bid}.attn.k_ln",             # sea-lion
-            "encoder.layer.{bid}.attention.self.layer_norm_k" # jina-bert-v2
+            "encoder.layer.{bid}.attention.self.layer_norm_k", # jina-bert-v2
+            "transformer.layers.{bid}.attn.k_norm",             # openelm
         ),

         MODEL_TENSOR.ROPE_FREQS: (
diff --git a/llama.cpp b/llama.cpp
index b752ddc6b401fb..a35bd884abdd50 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -229,6 +229,7 @@ enum llm_arch {
     LLM_ARCH_COMMAND_R,
     LLM_ARCH_DBRX,
     LLM_ARCH_OLMO,
+    LLM_ARCH_OPENELM,
     LLM_ARCH_UNKNOWN,
 };

@@ -266,6 +267,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_COMMAND_R, "command-r" },
     { LLM_ARCH_DBRX,      "dbrx"      },
     { LLM_ARCH_OLMO,      "olmo"      },
+    { LLM_ARCH_OPENELM,   "openelm"   },
     { LLM_ARCH_UNKNOWN,   "(unknown)" },
 };

@@ -1052,6 +1054,21 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
             { LLM_TENSOR_FFN_UP,      "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_OPENELM,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,  "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_ATTN_NORM,   "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV,    "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_OUT,    "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,    "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_UP,      "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_DOWN,    "blk.%d.ffn_down" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -1715,7 +1732,9 @@ enum e_model {
     MODEL_33M,
     MODEL_109M,
     MODEL_137M,
+    MODEL_270M,
     MODEL_335M,
+    MODEL_450M,
     MODEL_0_5B,
     MODEL_1B,
     MODEL_2B,
@@ -4261,6 +4280,18 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_OPENELM:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 16: model.type = e_model::MODEL_270M; break;
+                    case 20: model.type = e_model::MODEL_450M; break;
+                    case 28: model.type = e_model::MODEL_1B; break;
+                    case 36: model.type = e_model::MODEL_3B; break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
@@ -4767,6 +4798,24 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     if (vocab.special_eot_id != -1) { LLAMA_LOG_INFO( "%s: EOT token = %d '%s'\n", __func__, vocab.special_eot_id, vocab.id_to_token[vocab.special_eot_id].text.c_str() ); }
 }

+static int make_divisible(
+    double v,
+    int divisor = 8,
+    double min_value = 0
+) {
+    if (min_value == 0) {
+        min_value = divisor;
+    }
+    int new_v = int(v + divisor / 2.0) / divisor * divisor;
+    if (new_v < min_value) {
+        new_v = min_value;
+    }
+    if (new_v < 0.9 * v) {
+        new_v += divisor;
+    }
+    return new_v;
+}
+
 // Returns false if cancelled by progress_callback
 static bool llm_load_tensors(
     llama_model_loader & ml,
@@ -6060,6 +6109,52 @@ static bool llm_load_tensors(
                     layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
                 }
            } break;
+            case LLM_ARCH_OPENELM:
+                {
+                    std::vector<int64_t> num_kv_heads    = {3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5};
+                    std::vector<int64_t> num_query_heads = {12, 12, 12, 12, 12, 16, 16, 16, 16, 16, 16, 16, 20, 20, 20, 20};
+                    std::vector<float>   ffn_multipliers = {0.5, 0.73, 0.97, 1.2, 1.43, 1.67, 1.9, 2.13, 2.37, 2.6, 2.83, 3.07, 3.3, 3.53, 3.77, 4.0};
+
+                    llama_hparams modified_hparams(hparams);
+                    model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab });
+
+                    {
+                        model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd });
+                        model.output      = ml.create_tensor(ctx_output, tn(LLM_TENSOR_TOKEN_EMBD,  "weight"), { n_embd, n_vocab });
+                        ml.n_created--; // artificial tensor
+                        ml.size_data += ggml_nbytes(model.output);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        const int64_t n_head_k  = num_kv_heads[i];
+                        const int64_t n_head_v  = num_kv_heads[i];
+                        const int64_t n_head_kv = n_head_k + n_head_v;
+                        const int64_t n_head    = n_head_kv + num_query_heads[i];
+
+                        modified_hparams.n_head        = n_head;
+                        modified_hparams.n_embd_head_v = 64;
+                        modified_hparams.n_embd_head_k = 64;
+                        int64_t n_embd_head = modified_hparams.n_embd_head_v;
+
+                        modified_hparams.n_head_kv = n_head_kv;
+                        const int64_t ffn_inter = make_divisible(n_embd*ffn_multipliers[i], 256);
+
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM,   "weight", i), { n_embd });
+                        layer.wqkv        = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV,    "weight", i), { n_embd, n_embd_head*n_head });
+                        layer.attn_q_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head });
+                        layer.attn_k_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head });
+                        layer.wo          = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT,    "weight", i), { n_head_kv*n_embd_head*2, n_embd });
+                        layer.ffn_norm    = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM,    "weight", i), { n_embd });
+                        layer.ffn_up      = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,      "weight", i), { n_embd, 2 * ffn_inter });
+                        layer.ffn_down    = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN,    "weight", i), { ffn_inter, n_embd });
+                    }
+                } break;
         default:
             throw std::runtime_error("unknown architecture");
     }
@@ -10780,6 +10875,156 @@ struct llm_build_context {
         return gf;
     }

+    struct ggml_cgraph * build_openelm() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = 64;
+        // TODO: get this from config
+        std::vector<int64_t> num_kv_heads    = {3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5};
+        std::vector<int64_t> num_query_heads = {12, 12, 12, 12, 12, 16, 16, 16, 16, 16, 16, 16, 20, 20, 20, 20};
+
+        llama_hparams modified_hparams(hparams);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            const int64_t n_head_q  = num_query_heads[il];
+            const int64_t n_head_k  = num_kv_heads[il];
+            const int64_t n_head_v  = num_kv_heads[il];
+            const int64_t n_head_kv = n_head_k + n_head_v;
+            const int64_t n_head    = n_head_kv + n_head_q;
+
+            modified_hparams.n_head = n_head_q;
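+            // OpenELM changes the number of query/KV heads from layer to layer,
+            // so a mutable copy of the hparams is updated for every layer before
+            // calling the graph helpers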
+            modified_hparams.n_head_kv = n_head_k;
+            modified_hparams.n_embd_head_k = 64;
+            modified_hparams.n_embd_head_v = 64;
+
+            cur = inpL;
+            struct ggml_tensor * residual = cur;
+
+            // norm
+            cur = llm_build_norm(ctx0, inpL, modified_hparams,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_reshape_3d(ctx0, cur, n_embd_head, n_head, n_tokens);
+
+                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_q, n_tokens, cur->nb[1], cur->nb[2], 0));
+                cb(Qcur, "Qcur", il);
+
+                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_k, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*n_head_q));
+                cb(Kcur, "Kcur", il);
+
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_3d(ctx0, cur, n_embd_head, n_head_v, n_tokens, cur->nb[1], cur->nb[2], cur->nb[1]*(n_head_q+n_head_k)));
+                cb(Vcur, "Vcur", il);
+
+                Qcur = llm_build_norm(ctx0, Qcur, modified_hparams,
+                        model.layers[il].attn_q_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(Qcur, "Qcur", il);
+
+                Kcur = llm_build_norm(ctx0, Kcur, modified_hparams,
+                        model.layers[il].attn_k_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(Kcur, "Kcur", il);
+
+                Qcur = ggml_rope_custom(
+                    ctx0, Qcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_custom(
+                    ctx0, Kcur, inp_pos, n_rot, rope_type, 0, n_orig_ctx,
+                    freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                Vcur = ggml_reshape_2d(ctx0, Vcur, n_embd_head * n_head_v, n_tokens);
+                cb(Vcur, "Vcur", il);
+
+                cur = llm_build_kv(ctx0, model, modified_hparams, cparams, kv_self, gf,
+                        model.layers[il].wo, NULL,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+            }
+
+            struct ggml_tensor * ffn_inp = ggml_add(ctx0, residual, cur);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            {
+                cur = llm_build_norm(ctx0, ffn_inp, modified_hparams,
+                        model.layers[il].ffn_norm, NULL,
+                        LLM_NORM_RMS, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                // TODO: Split ffn_up during conversion?
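+                // ffn_up holds OpenELM's fused proj_1 weight; the two halves along
+                // its second dimension are passed to llm_build_ffn as the gate and
+                // up projections below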
+                struct ggml_tensor * ffn_gate =
+                    ggml_view_2d(ctx0,
+                        model.layers[il].ffn_up,
+                        model.layers[il].ffn_down->ne[1],
+                        model.layers[il].ffn_down->ne[0],
+                        model.layers[il].ffn_up->nb[1],
+                        0);
+
+                struct ggml_tensor * ffn_up =
+                    ggml_view_2d(ctx0,
+                        model.layers[il].ffn_up,
+                        model.layers[il].ffn_down->ne[1],
+                        model.layers[il].ffn_down->ne[0],
+                        model.layers[il].ffn_up->nb[1],
+                        model.layers[il].ffn_up->nb[1] * model.layers[il].ffn_down->ne[0]);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        ffn_up,   NULL,
+                        ffn_gate, NULL,
+                        model.layers[il].ffn_down, NULL,
+                        NULL,
+                        LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+                cb(cur, "ffn_out", il);
+            }
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "l_out", il);
+
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        // norm
+        cur = llm_build_norm(ctx0, cur, hparams,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };

 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -10994,6 +11239,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_olmo();
             } break;
+        case LLM_ARCH_OPENELM:
+            {
+                result = llm.build_openelm();
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -16036,6 +16285,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_PHI3:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_STARCODER2:
+        case LLM_ARCH_OPENELM:
             return LLAMA_ROPE_TYPE_NEOX;

         // all model arches should be listed explicitly here