add exaone model support
mscheong01 committed Aug 14, 2024
1 parent 3071c0a commit 70dab0f
Showing 6 changed files with 301 additions and 7 deletions.
72 changes: 72 additions & 0 deletions convert_hf_to_gguf.py
@@ -590,6 +590,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
if chkhsh == "855059429035d75a914d1eda9f10a876752e281a054a7a3d421ef0533e5b6249":
# ref: https://huggingface.co/HuggingFaceTB/SmolLM-135M
res = "smollm"
if chkhsh == "4e2b24cc4770243d65a2c9ec19770a72f08cffc161adbb73fcbb6b7dd45a0aae":
# ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct
res = "exaone"

if res is None:
logger.warning("\n")
@@ -3595,6 +3598,75 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
name = name.removeprefix("transformer.")
return [(self.map_tensor_name(name), data_torch)]

@Model.register("ExaoneForCausalLM")
class ExaoneModel(Model):
model_arch = gguf.MODEL_ARCH.EXAONE

def set_gguf_parameters(self):
hparams = self.hparams

assert(hparams["activation_function"] == "silu")

max_position_embeddings = hparams["max_position_embeddings"]
embed_dim = hparams["hidden_size"]
num_heads = hparams["num_attention_heads"]
num_kv_heads = hparams.get("num_key_value_heads", num_heads)
layer_norm_eps = hparams["layer_norm_epsilon"]
intermediate_size = hparams["intermediate_size"] if "intermediate_size" in hparams else 4 * embed_dim
num_layers = hparams["num_layers"]
# ignore for now as EXAONE-3.0-7.8B-Instruct attention_dropout is 0.0
# attention_dropout_rate = hparams["attention_dropout"]
# ignore for now as EXAONE-3.0-7.8B-Instruct embed_dropout is 0.0
# embed_dropout_rate = hparams["embed_dropout"]
self.gguf_writer.add_embedding_length(embed_dim)
self.gguf_writer.add_head_count(num_heads)
self.gguf_writer.add_head_count_kv(num_kv_heads)
self.gguf_writer.add_context_length(max_position_embeddings)
self.gguf_writer.add_layer_norm_rms_eps(layer_norm_eps)
self.gguf_writer.add_feed_forward_length(intermediate_size)
self.gguf_writer.add_block_count(num_layers)

if (rope_theta := self.hparams.get("rope_theta")) is not None:
self.gguf_writer.add_rope_freq_base(rope_theta)
rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True)
rotary_factor = rotary_factor if rotary_factor is not None else 1.0
self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"])))
if hparams.get("rope_scaling") is not None and "factor" in hparams["rope_scaling"]:
if hparams["rope_scaling"].get("type") == "linear":
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(hparams["rope_scaling"]["factor"])

def prepare_tensors(self):
if rope_scaling := self.find_hparam(["rope_scaling"], optional=True):
if rope_scaling.get("rope_type", '').lower() == "llama3":
base = self.hparams.get("rope_theta", 10000.0)
dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

factor = rope_scaling.get("factor", 8.0)
low_freq_factor = rope_scaling.get("low_freq_factor", 1.0)
high_freq_factor = rope_scaling.get("high_freq_factor", 4.0)
old_context_len = self.hparams.get("original_max_position_embeddings", 8192)

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
assert low_freq_wavelen != high_freq_wavelen

rope_factors = []
for freq in freqs:
wavelen = 2 * math.pi / freq
if wavelen < high_freq_wavelen:
rope_factors.append(1)
elif wavelen > low_freq_wavelen:
rope_factors.append(factor)
else:
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
rope_factors.append(1 / ((1 - smooth) / factor + smooth))

self.gguf_writer.add_tensor(self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), np.array(rope_factors, dtype=np.float32))

super().prepare_tensors()

###### CONVERSION LOGIC ######


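For context, the prepare_tensors override above precomputes per-dimension RoPE scaling factors when the config declares llama3-style rope_scaling. The snippet below is a minimal standalone sketch of that smoothing formula; the head dimension, rope_theta, and scaling parameters are assumed example values, not EXAONE's actual config, which the converter reads from config.json.

# Standalone sketch of the llama3-style rope-factor smoothing used above.
# All hyperparameter values here are illustrative assumptions.
import math
import numpy as np

head_dim = 128                 # assumed hidden_size // num_attention_heads
base = 1_000_000.0             # assumed rope_theta
factor = 8.0
low_freq_factor = 1.0
high_freq_factor = 4.0
old_context_len = 8192

freqs = 1.0 / (base ** (np.arange(0, head_dim, 2, dtype=np.float32) / head_dim))
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor

rope_factors = []
for freq in freqs:
    wavelen = 2 * math.pi / freq
    if wavelen < high_freq_wavelen:       # high-frequency dims are left untouched
        rope_factors.append(1.0)
    elif wavelen > low_freq_wavelen:      # low-frequency dims get the full factor
        rope_factors.append(factor)
    else:                                 # smooth interpolation in between
        smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
        rope_factors.append(1.0 / ((1.0 - smooth) / factor + smooth))

print(np.array(rope_factors, dtype=np.float32))   # one factor per rotary dimension pair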
1 change: 1 addition & 0 deletions convert_hf_to_gguf_update.py
@@ -94,6 +94,7 @@ class TOKENIZER_TYPE(IntEnum):
{"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", },
{"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", },
{"name": "smollm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/HuggingFaceTB/SmolLM-135M", },
{"name": "exaone", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", },
]


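The update script downloads each listed tokenizer and records a hash of its output on a fixed probe text; get_vocab_base_pre in convert_hf_to_gguf.py then matches that hash at conversion time. A rough sketch of the idea, assuming the Hugging Face transformers package (the probe string here is illustrative, not the exact text the script uses):

# Sketch: reproducing a pre-tokenizer checksum like the "exaone" chkhsh above.
# The probe text is an illustrative stand-in for the script's real test string.
from hashlib import sha256
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct", trust_remote_code=True
)
probe = "Hello World! 3.14 ½ 한국어 🙂"   # illustrative probe text only
token_ids = tokenizer.encode(probe)
chkhsh = sha256(str(token_ids).encode()).hexdigest()
print(chkhsh)   # compared against the hard-coded hashes in get_vocab_base_pre()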
18 changes: 18 additions & 0 deletions gguf-py/gguf/constants.py
@@ -218,6 +218,7 @@ class MODEL_ARCH(IntEnum):
BITNET = auto()
T5 = auto()
JAIS = auto()
EXAONE = auto()


class MODEL_TENSOR(IntEnum):
@@ -345,6 +346,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.BITNET: "bitnet",
MODEL_ARCH.T5: "t5",
MODEL_ARCH.JAIS: "jais",
MODEL_ARCH.EXAONE: "exaone",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -1048,6 +1050,22 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.EXAONE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_ROT_EMBD,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
# TODO
}

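The new MODEL_ARCH.EXAONE entry and its tensor list determine which GGUF tensor names an EXAONE conversion may emit. A quick way to inspect them, assuming the enclosing dictionaries are exported as gguf.MODEL_ARCH_NAMES and gguf.MODEL_TENSORS as elsewhere in gguf-py:

# List the GGUF-side names behind the new EXAONE architecture entry.
# gguf.MODEL_ARCH_NAMES / gguf.MODEL_TENSORS are assumed to be the enclosing dicts.
import gguf

print(gguf.MODEL_ARCH_NAMES[gguf.MODEL_ARCH.EXAONE])       # "exaone"
for tensor in gguf.MODEL_TENSORS[gguf.MODEL_ARCH.EXAONE]:
    print(gguf.TENSOR_NAMES[tensor])                        # e.g. "blk.{bid}.attn_q"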
19 changes: 13 additions & 6 deletions gguf-py/gguf/tensor_mapping.py
@@ -10,7 +10,7 @@ class TensorNameMap:
# Token embeddings
MODEL_TENSOR.TOKEN_EMBD: (
"gpt_neox.embed_in", # gptneox
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais
"transformer.wte", # gpt2 gpt-j mpt refact qwen dbrx jais exaone
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
"model.embed_tokens", # llama-hf
@@ -52,7 +52,7 @@ class TensorNameMap:
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais
"lm_head", # gpt2 mpt falcon llama-hf baichuan qwen mamba dbrx jais exaone
"output", # llama-pth bloom internlm2
"word_embeddings_for_head", # persimmon
"lm_head.linear", # phi2
@@ -62,7 +62,7 @@ class TensorNameMap:
# Output norm
MODEL_TENSOR.OUTPUT_NORM: (
"gpt_neox.final_layer_norm", # gptneox
"transformer.ln_f", # gpt2 gpt-j falcon jais
"transformer.ln_f", # gpt2 gpt-j falcon jais exaone
"model.norm", # llama-hf baichuan internlm2
"norm", # llama-pth
"transformer.norm_f", # mpt dbrx
@@ -88,7 +88,7 @@ class TensorNameMap:
# Attention norm
MODEL_TENSOR.ATTN_NORM: (
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais
"transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen jais exaone
"transformer.blocks.{bid}.norm_1", # mpt
"transformer.h.{bid}.input_layernorm", # falcon7b
"h.{bid}.input_layernorm", # bloom
@@ -142,6 +142,7 @@ class TensorNameMap:
"model.layers.layers.{bid}.self_attn.q_proj", # plamo
"model.layers.{bid}.attention.wq", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
"transformer.h.{bid}.attn.attention.q_proj", # exaone
),

# Attention key
@@ -154,6 +155,7 @@ class TensorNameMap:
"model.layers.layers.{bid}.self_attn.k_proj", # plamo
"model.layers.{bid}.attention.wk", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
"transformer.h.{bid}.attn.attention.k_proj", # exaone
),

# Attention value
@@ -165,7 +167,8 @@ class TensorNameMap:
"transformer.h.{bid}.attn.v", # refact
"model.layers.layers.{bid}.self_attn.v_proj", # plamo
"model.layers.{bid}.attention.wv", # internlm2
"transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok
"transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok
"transformer.h.{bid}.attn.attention.v_proj", # exaone
),

# Attention output
@@ -190,6 +193,7 @@ class TensorNameMap:
"transformer.blocks.{bid}.norm_attn_norm.attn.out_proj", # dbrx
"encoder.layers.{bid}.self_attention.dense", # chatglm
"transformer.layers.{bid}.attn.out_proj", # openelm
"transformer.h.{bid}.attn.attention.out_proj", # exaone
),

# Attention output norm
@@ -215,7 +219,7 @@ class TensorNameMap:
# Feed-forward norm
MODEL_TENSOR.FFN_NORM: (
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
"transformer.h.{bid}.ln_2", # gpt2 refact qwen jais
"transformer.h.{bid}.ln_2", # gpt2 refact qwen jais exaone
"h.{bid}.post_attention_layernorm", # bloom
"transformer.blocks.{bid}.norm_2", # mpt
"model.layers.{bid}.post_attention_layernorm", # llama-hf
@@ -277,6 +281,7 @@ class TensorNameMap:
"encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2
"model.layers.{bid}.residual_mlp.w3", # arctic
"encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm
"transformer.h.{bid}.mlp.c_fc_1", # exaone
),

MODEL_TENSOR.FFN_UP_EXP: (
@@ -308,6 +313,7 @@ class TensorNameMap:
"encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2
"transformer.h.{bid}.mlp.linear_1", # refact
"model.layers.{bid}.residual_mlp.w1", # arctic
"transformer.h.{bid}.mlp.c_fc_0", # exaone
),

MODEL_TENSOR.FFN_GATE_EXP: (
@@ -347,6 +353,7 @@ class TensorNameMap:
"model.layers.{bid}.residual_mlp.w2", # arctic
"encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2
"encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm
"model.layers.h.{bid}.mlp.c_proj", # exaone
),

MODEL_TENSOR.FFN_DOWN_EXP: (
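With the entries above, EXAONE checkpoint names resolve through TensorNameMap to the standard GGUF names. A minimal sketch, assuming the gguf-py helper get_tensor_name_map and an example layer count of 32 (the converter takes the real count from the model's hyperparameters):

# Sketch: resolving an EXAONE HF tensor name to its GGUF name.
# The layer count is an assumed example value.
import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.EXAONE, 32)
name = tmap.get_name("transformer.h.0.attn.attention.q_proj.weight",
                     try_suffixes=(".weight", ".bias"))
print(name)   # expected: "blk.0.attn_q.weight"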
1 change: 1 addition & 0 deletions include/llama.h
@@ -93,6 +93,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20,
LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21,
LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22,
LLAMA_VOCAB_PRE_TYPE_EXAONE = 23,
};

// note: these values should be synchronized with ggml_rope