From a89427908d04fcf3b4e975724596efddce4db737 Mon Sep 17 00:00:00 2001
From: Andrei Betlen
Date: Sat, 29 Jun 2024 10:17:33 -0400
Subject: [PATCH] Add custom kq scaling from Gemma2Attention

---
 convert-hf-to-gguf.py       | 3 +++
 gguf-py/gguf/constants.py   | 1 +
 gguf-py/gguf/gguf_writer.py | 3 +++
 src/llama.cpp               | 6 +++++-
 4 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 3ef2f69e7c0df..23a3573435e7b 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -2369,6 +2369,9 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_final_logit_softcapping(
             self.hparams["final_logit_softcapping"]
         )
+        self.gguf_writer.add_query_pre_attn_scalar(
+            self.hparams["query_pre_attn_scalar"]
+        )
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index 9bfa891d5dc52..eab5cbf69fac8 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -52,6 +52,7 @@ class LLM:
         DECODER_START_TOKEN_ID  = "{arch}.decoder_start_token_id"
         ATTN_LOGIT_SOFTCAPPING  = "{arch}.attn_logit_softcapping"
         FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
+        QUERY_PRE_ATTN_SCALAR   = "{arch}.query_pre_attn_scalar"
 
     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 1aeb0d9b08685..37c41a5bf64e5 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -522,6 +522,9 @@ def add_attn_logit_softcapping(self, value: float) -> None:
     def add_final_logit_softcapping(self, value: float) -> None:
         self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
 
+    def add_query_pre_attn_scalar(self, value: float) -> None:
+        self.add_float32(Keys.LLM.QUERY_PRE_ATTN_SCALAR.format(arch=self.arch), value)
+
     def add_expert_count(self, count: int) -> None:
         self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)
 
diff --git a/src/llama.cpp b/src/llama.cpp
index 9654ffad39848..56a6898c31215 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -304,6 +304,7 @@ enum llm_kv {
     LLM_KV_DECODER_START_TOKEN_ID,
     LLM_KV_ATTN_LOGIT_SOFTCAPPING,
     LLM_KV_FINAL_LOGIT_SOFTCAPPING,
+    LLM_KV_QUERY_PRE_ATTN_SCALAR,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -396,6 +397,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_DECODER_START_TOKEN_ID,  "%s.decoder_start_token_id"  },
     { LLM_KV_ATTN_LOGIT_SOFTCAPPING,  "%s.attn_logit_softcapping"  },
     { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
+    { LLM_KV_QUERY_PRE_ATTN_SCALAR,   "%s.query_pre_attn_scalar"   },
 
     { LLM_KV_ATTENTION_HEAD_COUNT,    "%s.attention.head_count"    },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -2105,6 +2107,7 @@ struct llama_hparams {
 
     float f_attn_logit_softcapping  = 50.0f;
     float f_final_logit_softcapping = 30.0f;
+    float f_query_pre_attn_scalar   = 144.0f;
 
     float rope_attn_factor = 1.0f;
     float rope_freq_base_train;
@@ -4712,6 +4715,7 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
                 ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
+                ml.get_key(LLM_KV_QUERY_PRE_ATTN_SCALAR, hparams.f_query_pre_attn_scalar, false);
                 hparams.attn_soft_cap = true;
 
                 switch (hparams.n_layer) {
@@ -10948,7 +10952,7 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
-                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(hparams.f_query_pre_attn_scalar));
                 cb(Qcur, "Qcur_scaled", il);
 
                 Kcur = ggml_rope_ext(
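
Note on the scaling (illustrative, not part of the patch): Gemma2Attention scales the query
states by 1 / sqrt(query_pre_attn_scalar) before attention, whereas the previous llama.cpp
code used 1 / sqrt(n_embd_head_k). The two only coincide when query_pre_attn_scalar happens
to equal the head dimension, so this patch stores the value in the GGUF and uses it directly.
A minimal Python sketch of the difference, using hypothetical config values for illustration:

    import math

    # Hypothetical HF-style config values, shown only to illustrate the arithmetic;
    # real checkpoints define their own head_dim and query_pre_attn_scalar.
    config = {
        "head_dim": 128,
        "query_pre_attn_scalar": 144,
    }

    # Scale applied to Q with this patch (read back from the new GGUF key
    # "{arch}.query_pre_attn_scalar" at model-load time):
    new_scale = 1.0 / math.sqrt(config["query_pre_attn_scalar"])

    # Scale the previous code applied, derived from the head dimension:
    old_scale = 1.0 / math.sqrt(config["head_dim"])

    print(new_scale, old_scale)  # differs whenever the two config values differ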