Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support MiniCPM3. #9322

Merged
merged 1 commit into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -1819,6 +1819,60 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
return [(self.map_tensor_name(name), data_torch)]


@Model.register("MiniCPM3ForCausalLM")
Copy link
Collaborator

@HanClinto HanClinto Sep 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@Model.register("MiniCPM3ForCausalLM")
@Model.register("MiniCPM3ForCausalLM")

Lint is currently breaking on this -- need to add an additional blank line above your class definition.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your comments.

class MiniCPM3Model(Model):
model_arch = gguf.MODEL_ARCH.MINICPM3

def set_gguf_parameters(self):
hparams = self.hparams

rope_dims = hparams["qk_rope_head_dim"]

self.gguf_writer.add_file_type(self.ftype)
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"])
self.gguf_writer.add_layer_norm_rms_eps(hparams["rms_norm_eps"])
self.gguf_writer.add_vocab_size(hparams["vocab_size"])
if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
self.gguf_writer.add_kv_lora_rank(hparams["kv_lora_rank"])
self.gguf_writer.add_key_length(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])

rope_scaling = self.find_hparam(['rope_scaling'], True)
if rope_scaling is None:
return

long_factors = rope_scaling.get('long_factor', None)
short_factors = rope_scaling.get('short_factor', None)

if long_factors is None or short_factors is None:
raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor')

if len(long_factors) != len(short_factors) or len(long_factors) != rope_dims / 2:
raise ValueError(f'The length of rope long and short factors must be {rope_dims / 2}')

self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_LONG] + ".weight", np.array(long_factors, dtype=np.float32))
self.gguf_writer.add_tensor(gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT] + ".weight", np.array(short_factors, dtype=np.float32))

def set_vocab(self):
self._set_vocab_llama_hf()

def _reverse_hf_permute(self, weights: Tensor, n_head: int, n_kv_head: int | None = None) -> Tensor:
if n_kv_head is not None and n_head != n_kv_head:
n_head //= n_kv_head

return (
weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
.swapaxes(1, 2)
.reshape(weights.shape)
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
)
)

Lint is also breaking on this -- need to add an additional blank line below your class definition as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your comments.



@Model.register("QWenLMHeadModel")
class QwenModel(Model):
model_arch = gguf.MODEL_ARCH.QWEN
Expand Down
19 changes: 19 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ class MODEL_ARCH(IntEnum):
ORION = auto()
INTERNLM2 = auto()
MINICPM = auto()
MINICPM3 = auto()
GEMMA = auto()
GEMMA2 = auto()
STARCODER2 = auto()
Expand Down Expand Up @@ -364,6 +365,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.ORION: "orion",
MODEL_ARCH.INTERNLM2: "internlm2",
MODEL_ARCH.MINICPM: "minicpm",
MODEL_ARCH.MINICPM3: "minicpm3",
MODEL_ARCH.GEMMA: "gemma",
MODEL_ARCH.GEMMA2: "gemma2",
MODEL_ARCH.STARCODER2: "starcoder2",
Expand Down Expand Up @@ -867,6 +869,23 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.MINICPM3: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q_A,
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.GEMMA: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
Expand Down
Loading
Loading