Skip to content

Commit

Permalink
llama : add EXAONE model support (ggerganov#9025)
Browse files Browse the repository at this point in the history
* add exaone model support

* add chat template

* fix whitespace

Co-authored-by: Georgi Gerganov <[email protected]>

* add ftype

* add exaone pre-tokenizer in `llama-vocab.cpp`

Co-Authored-By: compilade <[email protected]>

* fix lint

Co-Authored-By: compilade <[email protected]>

* add `EXAONE` to supported models in `README.md`

* fix space

Co-authored-by: compilade <[email protected]>

---------

Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: compilade <[email protected]>
Co-authored-by: compilade <[email protected]>
  • Loading branch information
4 people authored and arthw committed Nov 15, 2024
1 parent 057a4f3 commit 32e4a87
Show file tree
Hide file tree
Showing 8 changed files with 320 additions and 7 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ Typically finetunes of the base models below are supported as well.
- [x] [Open Elm models](https://huggingface.co/collections/apple/openelm-instruct-models-6619ad295d7ae9f868b759ca)
- [x] [ChatGLM3-6b](https://huggingface.co/THUDM/chatglm3-6b) + [ChatGLM4-9b](https://huggingface.co/THUDM/glm-4-9b)
- [x] [SmolLM](https://huggingface.co/collections/HuggingFaceTB/smollm-6695016cad7167254ce15966)
- [x] [EXAONE-3.0-7.8B-Instruct](https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct)

(instructions for supporting more models: [HOWTO-add-model.md](./docs/development/HOWTO-add-model.md))

Expand Down
74 changes: 74 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "bc01ce58980e1db43859146dc51b1758b3b88729b217a74792e9f8d43e479d21":
# ref: https://huggingface.co/TurkuNLP/gpt3-finnish-small
res = "gpt3-finnish"
if chkhsh == "4e2b24cc4770243d65a2c9ec19770a72f08cffc161adbb73fcbb6b7dd45a0aae":
# ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct
res = "exaone"

if res is None:
logger.warning("\n")
Expand Down Expand Up @@ -3781,6 +3784,77 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

return [(self.map_tensor_name(name), data_torch)]


@Model.register("ExaoneForCausalLM")
class ExaoneModel(Model):
model_arch = gguf.MODEL_ARCH.EXAONE

def set_gguf_parameters(self):
hparams = self.hparams

assert(hparams["activation_function"] == "silu")

max_position_embeddings = hparams["max_position_embeddings"]
embed_dim = hparams["hidden_size"]
num_heads = hparams["num_attention_heads"]
num_kv_heads = hparams.get("num_key_value_heads", num_heads)
layer_norm_eps = hparams["layer_norm_epsilon"]
intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim
num_layers = hparams["num_layers"]
# ignore for now as EXAONE-3.0-7.8B-Instruct attentino_dropout is 0.0
# attention_dropout_rate = hparams["attention_dropout"]
# ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0
# embed_dropout_rate = hparams["embed_dropout"]
self.gguf_writer.add_embedding_length(embed_dim)
self.gguf_writer.add_head_count(num_heads)
self.gguf_writer.add_head_count_kv(num_kv_heads)
self.gguf_writer.add_context_length(max_position_embeddings)
self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps)
self.gguf_writer.add_feed_forward_length(intermediate_size)
self.gguf_writer.add_block_count(num_layers)
self.gguf_writer.add_file_type(self.ftype)

if (rope_theta := self.hparams.get("rope_theta")) is not None:
self.gguf_writer.add_rope_freq_base(rope_theta)
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True)
rotary_factor = rotary_factor if rotary_factor is not None else 1.0
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
if hparams.get("rope_scaling") is not None and "factor" in hparams["rope_scaling"]:
if hparams["rope_scaling"].get("type") == "linear":
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])

def prepare_tensors(self):
if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
if rope_scaling.get("rope_type", '').lower() == "llama3":
base = self.hparams.get("rope_theta", 10000.0)
dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

factor = rope_scaling.get("factor", 8.0)
low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
assert low_freq_wavelen != high_freq_wavelen

rope_factors = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
rope_factors.append(1)
elif wavelen > low_freq_wavelen:
rope_factors.append(factor)
else:
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
rope_factors.append(1 / ((1 - smooth) / factor + smooth))

self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))

super().prepare_tensors()

###### CONVERSION LOGIC ######


Expand Down
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", },
{'name': "bloom", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/bigscience/bloom", },
{'name': "gpt3-finnish", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/TurkuNLP/gpt3-finnish-small", },
{"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", },
]


Expand Down
18 changes: 18 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ class MODEL_ARCH(IntEnum):
T5ENCODER = auto()
JAIS = auto()
NEMOTRON = auto()
EXAONE = auto()


class MODEL_TENSOR(IntEnum):
Expand Down Expand Up @@ -349,6 +350,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.T5ENCODER: "t5encoder",
MODEL_ARCH.JAIS: "jais",
MODEL_ARCH.NEMOTRON: "nemotron",
MODEL_ARCH.EXAONE: "exaone",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
Expand Down Expand Up @@ -1082,6 +1084,22 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.EXAONE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
# TODO
}

Expand Down
19 changes: 13 additions & 6 deletions gguf-py/gguf/tensor_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class TensorNameMap:
# Token embeddings
MODEL_TENSOR.TOKEN_EMBD: (
"gpt_neox.embed_in", # gptneox
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
"model.embed_tokens", # llama-hf nemotron
Expand Down Expand Up @@ -52,7 +52,7 @@ class TensorNameMap:
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais nemotron exaone
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
Expand All @@ -62,7 +62,7 @@ class TensorNameMap:
# Output norm
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 gpt-j falcon jais
"transformer.ln_f", # gpt2 gpt-j falcon jais exaone
"model.norm", # llama-hf baichuan internlm2
"norm", # llama-pth
"transformer.norm_f", # mpt dbrx
Expand All @@ -89,7 +89,7 @@ class TensorNameMap:
# Attention norm
MODEL_TENSOR.ATTN_NORM: (
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais exaone
"transformer.blocks.{bid}.norm_1", # mpt
"transformer.h.{bid}.input_layernorm", # falcon7b
"h.{bid}.input_layernorm", # bloom
Expand Down Expand Up @@ -143,6 +143,7 @@ class TensorNameMap:
"model.layers.layers.{bid}.self_attn.q_proj", # plamo
"model.layers.{bid}.attention.wq", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
"transformer.h.{bid}.attn.attention.q_proj", # exaone
),

# Attention key
Expand All @@ -155,6 +156,7 @@ class TensorNameMap:
"model.layers.layers.{bid}.self_attn.k_proj", # plamo
"model.layers.{bid}.attention.wk", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
"transformer.h.{bid}.attn.attention.k_proj", # exaone
),

# Attention value
Expand All @@ -166,7 +168,8 @@ class TensorNameMap:
"transformer.h.{bid}.attn.v", # refact
"model.layers.layers.{bid}.self_attn.v_proj", # plamo
"model.layers.{bid}.attention.wv", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok
"transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok
"transformer.h.{bid}.attn.attention.v_proj", # exaone
),

# Attention output
Expand All @@ -191,6 +194,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
"encoder.layers.{bid}.self_attention.dense", # chatglm
"transformer.layers.{bid}.attn.out_proj", # openelm
"transformer.h.{bid}.attn.attention.out_proj", # exaone
),

# Attention output norm
Expand All @@ -216,7 +220,7 @@ class TensorNameMap:
# Feed-forward norm
MODEL_TENSOR.FFN_NORM: (
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
"transformer.h.{bid}.ln_2", # gpt2 refact qwen jais
"transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone
"h.{bid}.post_attention_layernorm", # bloom
"transformer.blocks.{bid}.norm_2", # mpt
"model.layers.{bid}.post_attention_layernorm", # llama-hf nemotron
Expand Down Expand Up @@ -278,6 +282,7 @@ class TensorNameMap:
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
"model.layers.{bid}.residual_mlp.w3", # arctic
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
"transformer.h.{bid}.mlp.c_fc_1", # exaone
),

MODEL_TENSOR.FFN_UP_EXP: (
Expand Down Expand Up @@ -309,6 +314,7 @@ class TensorNameMap:
"encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2
"transformer.h.{bid}.mlp.linear_1", # refact
"model.layers.{bid}.residual_mlp.w1", # arctic
"transformer.h.{bid}.mlp.c_fc_0", # exaone
),

MODEL_TENSOR.FFN_GATE_EXP: (
Expand Down Expand Up @@ -348,6 +354,7 @@ class TensorNameMap:
"model.layers.{bid}.residual_mlp.w2", # arctic
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
"model.layers.h.{bid}.mlp.c_proj", # exaone
),

MODEL_TENSOR.FFN_DOWN_EXP: (
Expand Down
1 change: 1 addition & 0 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
LLAMA_VOCAB_PRE_TYPE_BLOOM = 23,
LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24,
LLAMA_VOCAB_PRE_TYPE_EXAONE = 25,
};

enum llama_rope_type {
Expand Down
1 change: 1 addition & 0 deletions src/llama-vocab.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ struct llm_tokenizer_bpe {
case LLAMA_VOCAB_PRE_TYPE_COMMAND_R:
case LLAMA_VOCAB_PRE_TYPE_SMOLLM:
case LLAMA_VOCAB_PRE_TYPE_CODESHELL:
case LLAMA_VOCAB_PRE_TYPE_EXAONE:
regex_exprs = {
"\\p{N}",
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
Expand Down
Loading

0 comments on commit 32e4a87

Please sign in to comment.