Add custom kq scaling from Gemma2Attention
abetlen committed Jun 29, 2024
1 parent 6f2464e commit a894279
Showing 4 changed files with 12 additions and 1 deletion.
3 changes: 3 additions & 0 deletions convert-hf-to-gguf.py
@@ -2369,6 +2369,9 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_final_logit_softcapping(
             self.hparams["final_logit_softcapping"]
         )
+        self.gguf_writer.add_query_pre_attn_scalar(
+            self.hparams["query_pre_attn_scalar"]
+        )
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
1 change: 1 addition & 0 deletions gguf-py/gguf/constants.py
@@ -52,6 +52,7 @@ class LLM:
         DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
         ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
         FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
+        QUERY_PRE_ATTN_SCALAR = "{arch}.query_pre_attn_scalar"
 
     class Attention:
         HEAD_COUNT = "{arch}.attention.head_count"
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -522,6 +522,9 @@ def add_attn_logit_softcapping(self, value: float) -> None:
     def add_final_logit_softcapping(self, value: float) -> None:
         self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
 
+    def add_query_pre_attn_scalar(self, value: float) -> None:
+        self.add_float32(Keys.LLM.QUERY_PRE_ATTN_SCALAR.format(arch=self.arch), value)
+
     def add_expert_count(self, count: int) -> None:
         self.add_uint32(Keys.LLM.EXPERT_COUNT.format(arch=self.arch), count)
 
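The two gguf-py changes above define the metadata key and a typed writer for it. As a quick sketch (not part of the patch), this is the key string the new method formats before storing the value as a float32; the architecture name gemma2 is assumed here for illustration:

# Format string added to gguf-py/gguf/constants.py above.
QUERY_PRE_ATTN_SCALAR = "{arch}.query_pre_attn_scalar"

# add_query_pre_attn_scalar() formats the key with the writer's architecture
# and stores the value via add_float32(); "gemma2" is an assumed arch name.
key = QUERY_PRE_ATTN_SCALAR.format(arch="gemma2")
print(key)  # gemma2.query_pre_attn_scalar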
6 changes: 5 additions & 1 deletion src/llama.cpp
@@ -304,6 +304,7 @@ enum llm_kv {
     LLM_KV_DECODER_START_TOKEN_ID,
     LLM_KV_ATTN_LOGIT_SOFTCAPPING,
     LLM_KV_FINAL_LOGIT_SOFTCAPPING,
+    LLM_KV_QUERY_PRE_ATTN_SCALAR,
 
     LLM_KV_ATTENTION_HEAD_COUNT,
     LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -396,6 +397,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
     { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
     { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
+    { LLM_KV_QUERY_PRE_ATTN_SCALAR, "%s.query_pre_attn_scalar" },
 
     { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
     { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -2105,6 +2107,7 @@ struct llama_hparams {
 
     float f_attn_logit_softcapping = 50.0f;
     float f_final_logit_softcapping = 30.0f;
+    float f_query_pre_attn_scalar = 144.0f;
 
     float rope_attn_factor = 1.0f;
     float rope_freq_base_train;
@@ -4712,6 +4715,7 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
                 ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
+                ml.get_key(LLM_KV_QUERY_PRE_ATTN_SCALAR, hparams.f_query_pre_attn_scalar, false);
                 hparams.attn_soft_cap = true;
 
                 switch (hparams.n_layer) {
@@ -10948,7 +10952,7 @@ struct llm_build_context {
                     ext_factor, attn_factor, beta_fast, beta_slow);
                 cb(Qcur, "Qcur", il);
 
-                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head_k)));
+                Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(hparams.f_query_pre_attn_scalar));
                 cb(Qcur, "Qcur_scaled", il);
 
                 Kcur = ggml_rope_ext(
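The only change on the compute path is the final hunk: the query tensor is now scaled by 1.0f / sqrtf(hparams.f_query_pre_attn_scalar) instead of 1.0f / sqrtf(float(n_embd_head_k)). A minimal NumPy sketch of the difference (not part of the patch); 144.0 mirrors the default added to llama_hparams, while head_dim = 128 is an assumed head size for illustration:

import numpy as np

head_dim = 128                 # assumed n_embd_head_k for this sketch
query_pre_attn_scalar = 144.0  # default added to llama_hparams in this commit

q = np.random.randn(4, head_dim).astype(np.float32)  # a few query rows

q_old = q * (1.0 / np.sqrt(head_dim))               # previous scaling: 1/sqrt(n_embd_head_k)
q_new = q * (1.0 / np.sqrt(query_pre_attn_scalar))  # new scaling: 1/sqrt(query_pre_attn_scalar)

# The two only coincide when query_pre_attn_scalar == head_dim,
# which is why the value has to travel through the GGUF metadata.
print(np.sqrt(head_dim / query_pre_attn_scalar))  # ratio q_new / q_old, ~0.943 here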
