Skip to content

Commit

Permalink
llama: Add support for Gemma2ForCausalLM (#8156)
Browse files Browse the repository at this point in the history
* Inference support for Gemma 2 model family

* Update convert-hf-to-gguf.py, constants, and tensor mappings

* cleanup

* format fix

* Fix special token vocab bug

* Don't add space prefix

* fix deleted lines

* Update src/llama.cpp

Co-authored-by: slaren <[email protected]>

* Add model type names

* Add control vector

* Fix model type identification

---------

Co-authored-by: Andrei Betlen <[email protected]>
Co-authored-by: slaren <[email protected]>
  • Loading branch information
3 people authored Jun 28, 2024
1 parent a27aa50 commit e57dc62
Show file tree
Hide file tree
Showing 4 changed files with 274 additions and 1 deletion.
40 changes: 40 additions & 0 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2340,6 +2340,46 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return [(self.map_tensor_name(name), data_torch)]


@Model.register("Gemma2ForCausalLM")
class Gemma2Model(Model):
model_arch = gguf.MODEL_ARCH.GEMMA2

def set_vocab(self):
self._set_vocab_llama_hf()
self.gguf_writer.add_add_space_prefix(False)

def set_gguf_parameters(self):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]

self.gguf_writer.add_name(self.dir_model.name if self.model_name is None else self.model_name)
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in hparams else hparams["num_attention_heads"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
self.gguf_writer.add_key_length(hparams["head_dim"])
self.gguf_writer.add_value_length(hparams["head_dim"])
self.gguf_writer.add_file_type(self.ftype)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unusem

# lm_head is not used in llama.cpp, while autoawq will include this tensor in model
# To prevent errors, skip loading lm_head.weight.
if name == "lm_head.weight":
logger.debug(f"Skipping get tensor {name!r} in safetensors so that convert can end normally.")
return []

# ref: https://github.com/huggingface/transformers/blob/fc37f38915372c15992b540dfcbbe00a916d4fc6/src/transformers/models/gemma/modeling_gemma.py#L89
if name.endswith("norm.weight"):
data_torch = data_torch + 1

return [(self.map_tensor_name(name), data_torch)]


@Model.register("Starcoder2ForCausalLM")
class StarCoder2Model(Model):
model_arch = gguf.MODEL_ARCH.STARCODER2
Expand Down
23 changes: 23 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ class MODEL_ARCH(IntEnum):
INTERNLM2 = auto()
MINICPM = auto()
GEMMA = auto()
GEMMA2 = auto()
STARCODER2 = auto()
MAMBA = auto()
XVERSE = auto()
Expand Down Expand Up @@ -180,10 +181,13 @@ class MODEL_TENSOR(IntEnum):
ATTN_NORM = auto()
ATTN_NORM_2 = auto()
ATTN_OUT_NORM = auto()
ATTN_POST_NORM = auto()
ATTN_ROT_EMBD = auto()
FFN_GATE_INP = auto()
FFN_GATE_INP_SHEXP = auto()
FFN_NORM = auto()
FFN_PRE_NORM = auto()
FFN_POST_NORM = auto()
FFN_GATE = auto()
FFN_DOWN = auto()
FFN_UP = auto()
Expand Down Expand Up @@ -270,6 +274,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.INTERNLM2: "internlm2",
MODEL_ARCH.MINICPM: "minicpm",
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.STARCODER2: "starcoder2",
MODEL_ARCH.MAMBA: "mamba",
MODEL_ARCH.XVERSE: "xverse",
Expand Down Expand Up @@ -303,9 +308,12 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm",
MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm",
MODEL_TENSOR.ATTN_OUT_NORM: "blk.{bid}.attn_output_norm",
MODEL_TENSOR.ATTN_POST_NORM: "blk.{bid}.post_attention_norm",
MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp",
MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp",
MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm",
MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_norm",
MODEL_TENSOR.FFN_POST_NORM: "blk.{bid}.post_ffw_norm",
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
Expand Down Expand Up @@ -751,6 +759,21 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_NORM,
],
MODEL_ARCH.GEMMA2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.FFN_PRE_NORM,
MODEL_TENSOR.FFN_POST_NORM,
],
MODEL_ARCH.STARCODER2: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
Expand Down
14 changes: 14 additions & 0 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@ class TensorNameMap:
"transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
),

MODEL_TENSOR.ATTN_POST_NORM: (
"model.layers.{bid}.post_attention_layernorm", # gemma2
),

# Rotary embeddings
MODEL_TENSOR.ATTN_ROT_EMBD: (
"model.layers.{bid}.self_attn.rotary_emb.inv_freq", # llama-hf
Expand All @@ -210,6 +214,16 @@ class TensorNameMap:
"transformer.decoder_layer.{bid}.rms_norm_2", # Grok
),

# Post feed-forward norm
MODEL_TENSOR.FFN_PRE_NORM: (
"model.layers.{bid}.pre_feedforward_layernorm", # gemma2
),

# Post feed-forward norm
MODEL_TENSOR.FFN_POST_NORM: (
"model.layers.{bid}.post_feedforward_layernorm", # gemma2
),

MODEL_TENSOR.FFN_GATE_INP: (
"layers.{bid}.feed_forward.gate", # mixtral
"model.layers.{bid}.block_sparse_moe.gate", # mixtral
Expand Down
Loading

0 comments on commit e57dc62

Please sign in to comment.