Commit
fix(llama.cpp): Use separate switch clause for granite in llm_load_hparams

Branch: GraniteLM

Signed-off-by: Gabe Goodhart <[email protected]>
gabe-l-hart committed Sep 16, 2024
1 parent 65c5bb9 commit 5d054a4
Showing 1 changed file with 15 additions and 9 deletions.
src/llama.cpp — 24 changes: 15 additions & 9 deletions
@@ -5438,7 +5438,6 @@ static void llm_load_hparams(
     // arch-specific KVs
     switch (model.arch) {
         case LLM_ARCH_LLAMA:
-        case LLM_ARCH_GRANITE:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

@@ -5455,20 +5454,13 @@ static void llm_load_hparams(
                         // granite uses a vocab with len 49152
                         case 32: model.type = hparams.n_vocab == 49152 ? e_model::MODEL_3B : (hparams.n_vocab < 40000 ? e_model::MODEL_7B : e_model::MODEL_8B); break;
                         case 36: model.type = e_model::MODEL_8B; break; // granite
-                        case 40: model.type = (hparams.n_vocab == 49152 || hparams.n_vocab == 49156) ? e_model::MODEL_3B : e_model::MODEL_13B; break;
+                        case 40: model.type = e_model::MODEL_13B; break;
                         case 48: model.type = e_model::MODEL_34B; break;
                         case 60: model.type = e_model::MODEL_30B; break;
                         case 80: model.type = hparams.n_head() == hparams.n_head_kv() ? e_model::MODEL_65B : e_model::MODEL_70B; break;
                         default: model.type = e_model::MODEL_UNKNOWN;
                     }
                 }
-                // Extra multipliers for Granite architecture
-                if (model.arch == LLM_ARCH_GRANITE) {
-                    ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
-                    ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale);
-                    ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale);
-                    ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);
-                }
             } break;
         case LLM_ARCH_MINICPM:
             {
@@ -6059,6 +6051,20 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_GRANITE:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
+                ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale);
+                ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale);
+                ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);
+
+                switch (hparams.n_layer) {
+                    case 40: model.type = e_model::MODEL_3B; break;
+                    // Add additional layer/vocab/etc checks here for other model sizes
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }


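For reference, the following is a minimal, self-contained sketch of the dispatch shape this commit produces. It is not llama.cpp code: the enums, the hard-coded scale values, and the return-by-value mapping are stand-ins for the real hparams plumbing. It illustrates why Granite gets its own clause, namely that it reads extra scale keys and maps n_layer to model size differently than LLAMA.

// Standalone sketch (not llama.cpp itself) of the dispatch shape after this
// commit: LLM_ARCH_LLAMA and LLM_ARCH_GRANITE are handled by separate switch
// clauses, so the Granite clause can read its extra scale keys and use its
// own n_layer -> model-size mapping. Types and values are stand-ins.
#include <cstdio>

enum llm_arch { LLM_ARCH_LLAMA, LLM_ARCH_GRANITE };
enum e_model  { MODEL_UNKNOWN, MODEL_3B, MODEL_8B, MODEL_13B };

struct hparams_t {
    unsigned n_layer = 0;
    float f_logit_scale = 0.0f, f_residual_scale = 0.0f;
    float f_embedding_scale = 0.0f, f_attention_scale = 0.0f;
};

static e_model load_hparams_sketch(llm_arch arch, hparams_t & hp) {
    switch (arch) {
        case LLM_ARCH_LLAMA:
            switch (hp.n_layer) {
                case 32: return MODEL_8B;
                case 40: return MODEL_13B;   // no Granite vocab check needed here anymore
                default: return MODEL_UNKNOWN;
            }
        case LLM_ARCH_GRANITE:
            // Granite-only multipliers live in their own clause now
            // (placeholder constants stand in for values read via ml.get_key).
            hp.f_logit_scale     = 1.0f;
            hp.f_residual_scale  = 1.0f;
            hp.f_embedding_scale = 1.0f;
            hp.f_attention_scale = 1.0f;
            switch (hp.n_layer) {
                case 40: return MODEL_3B;    // per the diff, 40 layers maps to MODEL_3B for Granite
                default: return MODEL_UNKNOWN;
            }
    }
    return MODEL_UNKNOWN;
}

int main() {
    hparams_t hp; hp.n_layer = 40;
    printf("llama   n_layer=40 -> %d\n", load_hparams_sketch(LLM_ARCH_LLAMA,   hp));
    printf("granite n_layer=40 -> %d\n", load_hparams_sketch(LLM_ARCH_GRANITE, hp));
}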