From 1c5eba6f8e628fb0a98afb27d8aaeb3b0e136451 Mon Sep 17 00:00:00 2001
From: Andrei
Date: Sat, 29 Jun 2024 20:44:08 -0700
Subject: [PATCH] llama: Add attention and final logit soft-capping, update scaling factor to Gemma2 (#8197)

* Add attention and final logit softcapping.

* fix

* Add custom add_ functions

* Disable flash attention for Gemma2

* Update src/llama.cpp

Co-authored-by: slaren

* Add default value for attention and final logit softcap value

* Add custom kq scaling from Gemma2Attention

* Remove custom pre attention scaling and use computed value instead.

---------

Co-authored-by: slaren
---
 convert-hf-to-gguf.py       |  6 ++++++
 gguf-py/gguf/constants.py   |  2 ++
 gguf-py/gguf/gguf_writer.py |  6 ++++++
 src/llama.cpp               | 35 ++++++++++++++++++++++++++++++++---
 4 files changed, 46 insertions(+), 3 deletions(-)

diff --git a/convert-hf-to-gguf.py b/convert-hf-to-gguf.py
index 5bcc849db999d..3ef2f69e7c0df 100755
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -2363,6 +2363,12 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_key_length(hparams["head_dim"])
         self.gguf_writer.add_value_length(hparams["head_dim"])
         self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_attn_logit_softcapping(
+            self.hparams["attn_logit_softcapping"]
+        )
+        self.gguf_writer.add_final_logit_softcapping(
+            self.hparams["final_logit_softcapping"]
+        )
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py
index cf3d09e70d3e7..9bfa891d5dc52 100644
--- a/gguf-py/gguf/constants.py
+++ b/gguf-py/gguf/constants.py
@@ -50,6 +50,8 @@ class LLM:
         POOLING_TYPE            = "{arch}.pooling_type"
         LOGIT_SCALE             = "{arch}.logit_scale"
         DECODER_START_TOKEN_ID  = "{arch}.decoder_start_token_id"
+        ATTN_LOGIT_SOFTCAPPING  = "{arch}.attn_logit_softcapping"
+        FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
 
     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"
diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py
index 9869f6fe3445a..1aeb0d9b08685 100644
--- a/gguf-py/gguf/gguf_writer.py
+++ b/gguf-py/gguf/gguf_writer.py
@@ -516,6 +516,12 @@ def add_clamp_kqv(self, value: float) -> None:
     def add_logit_scale(self, value: float) -> None:
         self.add_float32(Keys.LLM.LOGIT_SCALE.format(arch=self.arch), value)
 
+    def add_attn_logit_softcapping(self, value: float) -> None:
+        self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
+
+    def add_final_logit_softcapping(self, value: float) -> None:
+        self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
+
     def add_expert_count(self, count: int) -> None:
         self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)
diff --git a/src/llama.cpp b/src/llama.cpp
index 3edaa98e8d01b..2a4d73856fcd9 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -302,6 +302,8 @@ enum llm_kv {
     LLM_KV_POOLING_TYPE,
     LLM_KV_LOGIT_SCALE,
     LLM_KV_DECODER_START_TOKEN_ID,
+    LLM_KV_ATTN_LOGIT_SOFTCAPPING,
+    LLM_KV_FINAL_LOGIT_SOFTCAPPING,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -392,6 +394,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_POOLING_TYPE,            "%s.pooling_type"            },
    { LLM_KV_LOGIT_SCALE,             "%s.logit_scale"             },
    { LLM_KV_DECODER_START_TOKEN_ID,  "%s.decoder_start_token_id"  },
+   { LLM_KV_ATTN_LOGIT_SOFTCAPPING,  "%s.attn_logit_softcapping"  },
+   { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
 
    { LLM_KV_ATTENTION_HEAD_COUNT,    "%s.attention.head_count"    },
    { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -2099,6 +2103,9 @@ struct llama_hparams {
     float f_norm_eps;
     float f_norm_rms_eps;
 
+    float f_attn_logit_softcapping = 50.0f;
+    float f_final_logit_softcapping = 30.0f;
+
     float rope_attn_factor = 1.0f;
     float rope_freq_base_train;
     float rope_freq_scale_train;
@@ -2115,8 +2122,9 @@ struct llama_hparams {
     float f_max_alibi_bias = 0.0f;
     float f_logit_scale = 0.0f;
 
-    bool causal_attn = true;
-    bool use_alibi   = false;
+    bool causal_attn   = true;
+    bool use_alibi     = false;
+    bool attn_soft_cap = false;
 
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
     enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
@@ -4702,6 +4710,9 @@ static void llm_load_hparams(
         case LLM_ARCH_GEMMA2:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
+                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
+                hparams.attn_soft_cap = true;
 
                 switch (hparams.n_layer) {
                     case 42: model.type = e_model::MODEL_9B; break;
@@ -7579,6 +7590,12 @@ static struct ggml_tensor * llm_build_kqv(
             kq = ggml_scale(ctx, kq, 30);
         }
 
+        if (hparams.attn_soft_cap) {
+            kq = ggml_scale(ctx, kq, 1.0f / hparams.f_attn_logit_softcapping);
+            kq = ggml_tanh(ctx, kq);
+            kq = ggml_scale(ctx, kq, hparams.f_attn_logit_softcapping);
+        }
+
         kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
         cb(kq, "kq_soft_max_ext", il);
 
@@ -11039,7 +11056,7 @@ struct llm_build_context {
                         ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
-                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd / n_head)));
                 cb(Qcur, "Qcur_scaled", il);
 
                 Kcur = ggml_rope_ext(
@@ -11106,6 +11123,12 @@ struct llm_build_context {
 
         // lm_head
         cur = ggml_mul_mat(ctx0, model.output, cur);
+
+        // final logit soft-capping
+        cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+        cur = ggml_tanh(ctx0, cur);
+        cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+
         cb(cur, "result_output", -1);
 
         ggml_build_forward_expand(gf, cur);
@@ -17379,6 +17402,12 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }
 
+    if (params.flash_attn && model->hparams.attn_soft_cap) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
+        params.flash_attn = false;
+    }
+
+
     if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
         LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
         params.flash_attn = false;
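
Reviewer note, not part of the patch: the two ggml_scale/ggml_tanh sequences above implement logit soft-capping, i.e. squashing a score x into (-cap, cap) via cap * tanh(x / cap), which stays close to the identity while |x| is much smaller than cap. The defaults wired into llama_hparams are 50.0 for attention logits and 30.0 for the final logits. The sketch below only restates that arithmetic; the softcap() helper and the sample inputs are made up for illustration.

import math

def softcap(x: float, cap: float) -> float:
    # same arithmetic as: ggml_scale(1/cap) -> ggml_tanh -> ggml_scale(cap)
    return cap * math.tanh(x / cap)

# attention-logit capping (f_attn_logit_softcapping = 50.0 by default)
print(softcap(10.0, 50.0))    # ~9.87, nearly unchanged
print(softcap(500.0, 50.0))   # ~50.0, saturates at the cap

# final-logit capping (f_final_logit_softcapping = 30.0 by default)
print(softcap(100.0, 30.0))   # ~29.92

This is also why the patch forces flash_attn off when attn_soft_cap is set: the capping has to be applied to the raw KQ scores before the softmax, which the fused flash-attention path did not support at the time of this patch.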
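
A second illustrative note on the Qcur change: the query pre-scale becomes 1.0f / sqrtf(n_embd / n_head) instead of 1.0f / sqrtf(n_embd_head_k), and the two differ because Gemma2's head_dim is not equal to n_embd / n_head. The dimensions below are the commonly reported Gemma2-9B values and are used here only as an assumed example; for that configuration n_embd / n_head matches the model's query_pre_attn_scalar.

import math

n_embd        = 3584   # assumed Gemma2-9B hidden size
n_head        = 16     # assumed attention head count
n_embd_head_k = 256    # assumed per-head K/Q dimension (head_dim)

old_scale = 1.0 / math.sqrt(n_embd_head_k)    # 1/sqrt(256) = 0.0625
new_scale = 1.0 / math.sqrt(n_embd / n_head)  # 1/sqrt(224) ≈ 0.0668
print(old_scale, new_scale)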