Skip to content

Commit

Permalink
llama : add IBM Granite MoE architecture (ggerganov#9438)
Browse files Browse the repository at this point in the history
* feat(gguf-py): Add granitemoe architecture

This includes the addition of new tensor names for the new moe layers.
These may not be correct at this point due to the need for the hack in
gguf_writer.py to double-check the length of the shape for these layers.

Branch: GraniteMoE

Signed-off-by: Gabe Goodhart <[email protected]>

* feat(convert_hf_to_gguf): Add GraniteMoeModel

GraniteMoe has the same configuration deltas as Granite

Branch: GraniteMoE

Signed-off-by: Gabe Goodhart <[email protected]>

* fix(granitemoe convert): Split the double-sized input layer into gate and up

After a lot of staring and squinting, it's clear that the standard mixtral
expert implementation is equivalent to the vectorized parallel experts in
granite. The difference is that in granite, the w1 and w3 are concatenated
into a single tensor "input_linear." Rather than reimplementing all of the
math on the llama.cpp side, the much simpler route is to just split this
tensor during conversion and follow the standard mixtral route.

Branch: GraniteMoE

Co-Authored-By: [email protected]

Signed-off-by: Gabe Goodhart <[email protected]>

* feat(granitemoe): Implement granitemoe

GraniteMoE follows the mixtral architecture (once the input_linear layers
are split into gate_exps/up_exps). The main delta is the addition of the
same four multipliers used in Granite.

Branch: GraniteMoE

Signed-off-by: Gabe Goodhart <[email protected]>

* Typo fix in docstring

Co-Authored-By: [email protected]

Co-authored-by: Georgi Gerganov <[email protected]>
Signed-off-by: Gabe Goodhart <[email protected]>

* fix(conversion): Simplify tensor name mapping in conversion

Branch: GraniteMoE

Co-Authored-By: [email protected]
Signed-off-by: Gabe Goodhart <[email protected]>

* fix(convert): Remove unused tensor name mappings

Branch: GraniteMoE

Co-Authored-By: [email protected]
Signed-off-by: Gabe Goodhart <[email protected]>

* fix(convert): Sanity check on merged FFN tensor sizes

Branch: GraniteMoE

Co-Authored-By: [email protected]
Signed-off-by: Gabe Goodhart <[email protected]>

* fix: Allow "output" layer in granite moe architecture (convert and cpp)

Branch: GraniteMoE

Co-Authored-By: [email protected]
Signed-off-by: Gabe Goodhart <[email protected]>

* fix(granite): Add missing 'output' tensor for Granite

This is a fix for the previous `granite` architecture PR. Recent snapshots
have included this (`lm_head.weights`) as part of the architecture

Branch: GraniteMoE

Signed-off-by: Gabe Goodhart <[email protected]>

---------

Signed-off-by: Gabe Goodhart <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
  • Loading branch information
2 people authored and arthw committed Nov 18, 2024
1 parent 205f8a6 commit 7d63c36
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 13 deletions.
33 changes: 31 additions & 2 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4102,16 +4102,45 @@ def set_gguf_parameters(self):
# consistency
if attention_scale := self.hparams.get("attention_multiplier"):
self.gguf_writer.add_attention_scale(attention_scale)
logger.info("gguf: (granite) attention_scale = %s", attention_scale)
if embedding_scale := self.hparams.get("embedding_multiplier"):
self.gguf_writer.add_embedding_scale(embedding_scale)
logger.info("gguf: (granite) embedding_scale = %s", embedding_scale)
if residual_scale := self.hparams.get("residual_multiplier"):
self.gguf_writer.add_residual_scale(residual_scale)
if logits_scaling := self.hparams.get("logits_scaling"):
self.gguf_writer.add_logit_scale(logits_scaling)
logger.info("gguf: (granite) residual_scale = %s", residual_scale)
if logits_scale := self.hparams.get("logits_scaling"):
self.gguf_writer.add_logit_scale(logits_scale)
logger.info("gguf: (granite) logits_scale = %s", logits_scale)


@Model.register("GraniteMoeForCausalLM")
class GraniteMoeModel(GraniteModel):
"""Conversion for IBM's GraniteMoeForCausalLM"""
model_arch = gguf.MODEL_ARCH.GRANITE_MOE

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
"""In modeling_granitemoe, the JetMoe implementation of parallel experts
is used. This essentially merges w1 and w3 into a single tensor with 2x
the hidden size that is then split during forward. To keep compatibility
with existing mixtral support, we pull them apart here.
"""

if name.endswith("block_sparse_moe.input_linear.weight"):
ffn_dim = self.hparams["intermediate_size"]
assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size"
gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :]
return [
(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate),
(self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
]

return super().modify_tensors(data_torch, name, bid)


###### CONVERSION LOGIC ######


# tree of lazy tensors
class LazyTorchTensor(gguf.LazyBase):
_tensor_type = torch.Tensor
Expand Down
18 changes: 18 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ class MODEL_ARCH(IntEnum):
NEMOTRON = auto()
EXAONE = auto()
GRANITE = auto()
GRANITE_MOE = auto()


class MODEL_TENSOR(IntEnum):
Expand Down Expand Up @@ -392,6 +393,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.EXAONE: "exaone",
MODEL_ARCH.GRANITE: "granite",
MODEL_ARCH.GRANITE_MOE: "granitemoe",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
Expand Down Expand Up @@ -1232,6 +1234,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.GRANITE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
Expand All @@ -1242,6 +1245,21 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.GRANITE_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
# TODO
}

Expand Down
20 changes: 11 additions & 9 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,12 @@ class TensorNameMap:
),

MODEL_TENSOR.FFN_GATE_INP: (
"layers.{bid}.feed_forward.gate", # mixtral
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
"model.layers.{bid}.mlp.gate", # qwen2moe olmoe
"transformer.decoder_layer.{bid}.router", # Grok
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
"layers.{bid}.feed_forward.gate", # mixtral
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
"model.layers.{bid}.mlp.gate", # qwen2moe olmoe
"transformer.decoder_layer.{bid}.router", # Grok
"transformer.blocks.{bid}.ffn.router.layer", # dbrx
"model.layers.{bid}.block_sparse_moe.router.layer", # granitemoe
),

MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
Expand Down Expand Up @@ -364,10 +365,11 @@ class TensorNameMap:
),

MODEL_TENSOR.FFN_DOWN_EXP: (
"layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
"layers.{bid}.feed_forward.experts.w2", # mixtral (merged)
"transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged)
"transformer.blocks.{bid}.ffn.experts.mlp.w2", # dbrx
"model.layers.{bid}.mlp.experts.down_proj", # qwen2moe olmoe (merged)
"model.layers.{bid}.block_sparse_moe.output_linear", # granitemoe
),

MODEL_TENSOR.FFN_DOWN_SHEXP: (
Expand Down
30 changes: 28 additions & 2 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ enum llm_arch {
LLM_ARCH_EXAONE,
LLM_ARCH_RWKV6,
LLM_ARCH_GRANITE,
LLM_ARCH_GRANITE_MOE,
LLM_ARCH_UNKNOWN,
};

Expand Down Expand Up @@ -266,6 +267,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_EXAONE, "exaone" },
{ LLM_ARCH_RWKV6, "rwkv6" },
{ LLM_ARCH_GRANITE, "granite" },
{ LLM_ARCH_GRANITE_MOE, "granitemoe" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};

Expand Down Expand Up @@ -1467,6 +1469,7 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
Expand All @@ -1478,6 +1481,24 @@ static const std::map<llm_arch, std::map<llm_tensor, std::string>> LLM_TENSOR_NA
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_GRANITE_MOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_UNKNOWN,
{
Expand Down Expand Up @@ -2396,7 +2417,7 @@ struct llama_hparams {
float f_max_alibi_bias = 0.0f;
float f_logit_scale = 0.0f;

// Additional scale factors (Granite)
// Additional scale factors (Granite/Granite MoE)
float f_residual_scale = 0.0f;
float f_embedding_scale = 0.0f;
float f_attention_scale = 0.0f;
Expand Down Expand Up @@ -6048,6 +6069,7 @@ static void llm_load_hparams(
}
} break;
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale);
Expand All @@ -6056,6 +6078,7 @@ static void llm_load_hparams(
ml.get_key(LLM_KV_ATTENTION_SCALE, hparams.f_attention_scale);

switch (hparams.n_layer) {
case 32: model.type = e_model::MODEL_3B; break;
case 40: model.type = e_model::MODEL_3B; break;
// Add additional layer/vocab/etc checks here for other model sizes
default: model.type = e_model::MODEL_UNKNOWN;
Expand Down Expand Up @@ -6810,7 +6833,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: n_ff_shexp = %d\n", __func__, hparams.n_ff_shexp);
}

if (model.arch == LLM_ARCH_GRANITE) {
if (model.arch == LLM_ARCH_GRANITE || model.arch == LLM_ARCH_GRANITE_MOE) {
LLAMA_LOG_INFO("%s: f_embedding_scale = %f\n", __func__, hparams.f_embedding_scale);
LLAMA_LOG_INFO("%s: f_residual_scale = %f\n", __func__, hparams.f_residual_scale);
LLAMA_LOG_INFO("%s: f_attention_scale = %f\n", __func__, hparams.f_attention_scale);
Expand Down Expand Up @@ -6988,6 +7011,7 @@ static bool llm_load_tensors(
case LLM_ARCH_REFACT:
case LLM_ARCH_MINICPM:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
{
model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab});

Expand Down Expand Up @@ -15934,6 +15958,7 @@ static struct ggml_cgraph * llama_build_graph(
switch (model.arch) {
case LLM_ARCH_LLAMA:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
{
result = llm.build_llama();
} break;
Expand Down Expand Up @@ -19236,6 +19261,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_CHATGLM:
case LLM_ARCH_GRANITE:
case LLM_ARCH_GRANITE_MOE:
return LLAMA_ROPE_TYPE_NORM;

// the pairs of head values are offset by n_rot/2
Expand Down

0 comments on commit 7d63c36

Please sign in to comment.