
Feature Request: Support for DeciLMForCausalLM #10028

Open
ymcki opened this issue Oct 24, 2024 · 15 comments

Labels
enhancement New feature or request

Comments


ymcki commented Oct 24, 2024

Prerequisites

  • I am running the latest code. Mention the version if possible as well.
  • I carefully followed the README.md.
  • I searched using keywords relevant to my issue to make sure that I am creating a new issue that is not already open (or closed).
  • I reviewed the Discussions, and have a new and useful enhancement to share.

Feature Description

I downloaded nvidia/Llama-3_1-Nemotron-51B-Instruct, but I am getting this error:

python3 convert_hf_to_gguf.py ~/Llama-3_1-Nemotron-51B-Instruct/ --outfile ~/Llama-3_1-Nemotron-51B-Instruct.f16.gguf --outtype f16
INFO:hf-to-gguf:Loading model: Llama-3_1-Nemotron-51B-Instruct
ERROR:hf-to-gguf:Model DeciLMForCausalLM is not supported

Motivation

Is the DeciLMForCausalLM model type going to be supported soon? It seems the Q4_0 of this model could fit on a 3090/4090 by offloading a few layers to the CPU, which would be a pretty good use case for llama.cpp.

Possible Implementation

No response

ymcki added the enhancement label Oct 24, 2024

ymcki commented Oct 24, 2024

https://huggingface.co/Deci/DeciLM-7B-instruct-GGUF

Interestingly, the author of DeciLM created GGUFs for their model. How did they do that?


ymcki commented Oct 24, 2024

https://www.calcalistech.com/ctechnews/article/bkj6phggr

Nvidia acquired Deci, which is why they are using its technology now. If Nvidia LLMs are going to be mainstream, then llama.cpp had better support DeciLM.

compilade (Collaborator) commented

https://huggingface.co/Deci/DeciLM-7B-instruct-GGUF

Interestingly, the author of DeciLM created GGUFs for their model. How did they do that?

I think they somehow made it not use variable GQA, hinted at by the "uniform-gqa" part of the GGUF file names.

But since #7359, variable GQA has been implemented, so it should be relatively straightforward to adapt the convert scripts to DeciLMForCausalLM.


ymcki commented Nov 8, 2024

By checking the hashes, I find that the tokenizer of DeciLM-7B-instruct is the same as that of Mistral-7B-Instruct-v0.2:
https://huggingface.co/Deci/DeciLM-7B-instruct/tree/main
https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/tree/main
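
For reference, a minimal sketch of that kind of check (the paths are placeholders for locally downloaded copies of the tokenizer files):

    import hashlib
    from pathlib import Path

    def sha256sum(path: Path) -> str:
        """Return the SHA-256 hex digest of a file, read in 1 MiB chunks."""
        h = hashlib.sha256()
        with path.open("rb") as f:
            for chunk in iter(lambda: f.read(1 << 20), b""):
                h.update(chunk)
        return h.hexdigest()

    # Placeholder local paths to the two downloaded tokenizer files
    deci = Path("DeciLM-7B-instruct/tokenizer.model")
    mistral = Path("Mistral-7B-Instruct-v0.2/tokenizer.model")
    print(sha256sum(deci) == sha256sum(mistral))  # True if the tokenizers are byte-identical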


ymcki commented Nov 10, 2024

Since Mistral also uses sliding-window grouped-query attention, I figured it might be possible by adapting the Mistral-related code.

I simply registered this model alongside Mistral and Llama in convert_hf_to_gguf.py:
@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "DeciLMForCausalLM")

It can convert to f16 gguf without errors. But when I run the gguf, there seems to be a dimension mismatch error.

llama_model_load: error loading model: check_tensor_dims: tensor 'blk.0.attn_k.weight' has wrong shape; expected 4096, 4096, got 4096, 512, 1, 1
llama_load_model_from_file: failed to load model

compilade (Collaborator) commented

But when I run the gguf, there seems to be a dimension mismatch error.

@ymcki

You need to handle the variable-GQA-related metadata in the convert script so that the shapes are correctly handled when loading. gguf_writer.add_head_count and gguf_writer.add_head_count_kv support getting a list of per-layer sizes for variable GQA.

I think you will need to use the num_key_value_heads_per_layer field from config.json to correctly set the head counts.
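
For illustration, a sketch of what that could look like in the convert script, assuming hparams is the parsed config.json and self.gguf_writer is the usual gguf.GGUFWriter (names follow convert_hf_to_gguf.py conventions, but treat this as a sketch rather than a drop-in patch):

    # Sketch: write per-layer KV head counts so the loader can handle variable GQA.
    n_head = hparams["num_attention_heads"]
    self.gguf_writer.add_head_count(n_head)

    kv_heads = hparams.get("num_key_value_heads_per_layer")
    if kv_heads is not None:
        # one entry per transformer block, e.g. [4, 4, 4, 4, 4, 2, ...]
        assert len(kv_heads) == self.block_count
        self.gguf_writer.add_head_count_kv(kv_heads)
    else:
        self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"])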


ymcki commented Nov 12, 2024

Thanks for your hint. After consulting the OpenELM code, I added

    if hparams.get("num_key_value_heads_per_layer") is not None:
        self._num_kv_heads: list[int] = hparams["num_key_value_heads_per_layer"]
        assert self.block_count == len(self._num_kv_heads)
        self.gguf_writer.add_head_count_kv(self._num_kv_heads)

to the set_gguf_parameters method of LlamaModel. Again, it converts without errors, but I got a slightly different error:

llama_model_load: error loading model: check_tensor_dims: tensor 'blk.5.attn_k.weight' has wrong shape; expected 4096, 512, got 4096, 256, 1, 1
llama_load_model_from_file: failed to load model

In config.json of DeciLM-7B-Instruct:
"num_key_value_heads_per_layer": [4, 4, 4, 4, 4, 2, 2, 2, 2, 2, 4, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4],

So I suspect it only crashes starting at blk.5 (the sixth layer, the first one whose KV head count drops from 4 to 2) because something is off by one somewhere. How can I fix this?
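
As a quick sanity check on what those shapes should be (assuming head_dim = hidden_size / num_attention_heads = 4096 / 32 = 128 for DeciLM-7B-Instruct), the width the converter wrote for blk.5 actually matches the config; it is the loader's "expected" value that is still the blk.0 size:

    # Quick check: expected width of blk.i.attn_k.weight is n_kv_heads[i] * head_dim
    # (assuming head_dim = hidden_size / num_attention_heads = 4096 / 32 = 128).
    hidden_size, n_head = 4096, 32
    head_dim = hidden_size // n_head  # 128
    n_kv_heads = [4, 4, 4, 4, 4, 2, 2, 2, 2, 2, 4, 2, 2, 2, 2, 2,
                  2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 4]

    print(n_kv_heads[0] * head_dim)  # 512 -> blk.0.attn_k.weight is 4096 x 512
    print(n_kv_heads[5] * head_dim)  # 256 -> blk.5.attn_k.weight should be 4096 x 256,
                                     #        which is exactly what the converter wrote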


ymcki commented Nov 12, 2024

I found that if I modify llama.cpp b4067 this way,

--- src/llama.cpp	2024-11-12 17:31:39.083117718 +0800
+++ llama.cpp-b4067/src/llama.cpp	2024-11-12 17:26:03.638700023 +0800
@@ -7513,14 +7513,14 @@
                         layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
 
                         layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
-                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
-                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, hparams.n_embd_k_gqa(i)}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, hparams.n_embd_v_gqa(i)}, 0);
                         layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
 
                         // optional bias tensors
                         layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
-                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "bias", i), {hparams.n_embd_v_gqa(i)}, llama_model_loader::TENSOR_NOT_REQUIRED);
+                        layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "bias", i), {hparams.n_embd_v_gqa(i)}, llama_model_loader::TENSOR_NOT_REQUIRED);
                         layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd},     llama_model_loader::TENSOR_NOT_REQUIRED);
 
                         layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
@@ -10584,6 +10584,7 @@
         const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
+            const int64_t n_head_kv = hparams.n_head_kv(il);
 
             // norm
             cur = llm_build_norm(ctx0, inpL, hparams,
then there are no errors in either the conversion or llama-cli. However, llama-cli only replies with #####, so something is still wrong. What is it?

I noticed that DeciLM-7B-Instruct uses dynamic RoPE scaling, which is not implemented in llama.cpp. Could that be the cause?


ymcki commented Nov 15, 2024

I find that "dynamic" is actually the dynamic NTK-aware RoPE scaling method according to
https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_rope_utils.py

I would like to implement it in ggml/src/ggml-cpu.c

However, it requires seq_len as a parameter in the _compute_dynamic_ntk_parameters in transformers modeling_rope_utils.py.

Which variable in ggml_compute_forward_rope_f32 corresponds to seq_len?
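
For reference, the seq_len-dependent part of _compute_dynamic_ntk_parameters essentially rescales the RoPE base and recomputes the inverse frequencies; a rough Python sketch of that formula (based on my reading of modeling_rope_utils.py, not ggml code):

    def dynamic_ntk_inv_freq(head_dim: int, base: float, factor: float,
                             original_max_pos: int, seq_len: int) -> list[float]:
        """Rough sketch of dynamic NTK-aware RoPE scaling: below the original
        context length nothing changes; beyond it, the RoPE base grows with seq_len."""
        seq_len = max(seq_len, original_max_pos)
        scaled_base = base * (
            (factor * seq_len / original_max_pos) - (factor - 1)
        ) ** (head_dim / (head_dim - 2))
        return [1.0 / (scaled_base ** (2 * i / head_dim)) for i in range(head_dim // 2)]

So the question really is where the current sequence length would be available at the point where the rope frequencies are computed.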


ymcki commented Nov 15, 2024

I found that, unlike the smaller DeciLM-7B, Nemotron-51B has some layers that use linear attention:
https://huggingface.co/nvidia/Llama-3_1-Nemotron-51B-Instruct/blob/main/config.json

For example, layer 10 has a normal config, so it can be handled by existing llama code:

    {
      "attention": {
        "n_heads_in_group": 32,
        "no_op": false,
        "replace_with_linear": false
      },
      "ffn": {
        "ffn_mult": 2.625,
        "no_op": false,
        "replace_with_linear": false
      }
    },

However, layer 11 is a linear-attention layer without attention heads:

    {
      "attention": {
        "n_heads_in_group": null,
        "no_op": false,
        "replace_with_linear": true
      },
      "ffn": {
        "ffn_mult": 2.625,
        "no_op": false,
        "replace_with_linear": false
      }
    },

So my conversion script crashes:

Traceback (most recent call last):
  File "/tank/ai/llama.cpp-b4067/chg.py", line 4438, in <module>
    main()
  File "/tank/ai/llama.cpp-b4067/chg.py", line 4432, in main
    model_instance.write()
  File "/tank/ai/llama.cpp-b4067/chg.py", line 434, in write
    self.prepare_tensors()
  File "/tank/ai/llama.cpp-b4067/chg.py", line 1664, in prepare_tensors
    super().prepare_tensors()
  File "/tank/ai/llama.cpp-b4067/chg.py", line 298, in prepare_tensors
    for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)):
  File "/tank/ai/llama.cpp-b4067/chg.py", line 1632, in modify_tensors
    return [(self.map_tensor_name(name), data_torch)]
  File "/tank/ai/llama.cpp-b4067/chg.py", line 214, in map_tensor_name
    raise ValueError(f"Can not map tensor {name!r}")
ValueError: Can not map tensor 'model.layers.11.self_attn.linear_attn.weight'

Is there any existing code that already supports this, so that I can plug and play?

Its implementation is supposedly in
https://huggingface.co/nvidia/Llama-3_1-Nemotron-51B-Instruct/blob/main/modeling_decilm.py
and it seems quite straightforward:

class DeciLMLinearAttention(nn.Module):
    # DeciLM-specific code
    def __init__(self,
                 config: DeciLMConfig,
                 ):
        super().__init__() 
        self.linear_attn = nn.Linear(in_features=config.hidden_size,
                                     out_features=config.hidden_size,
                                     bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear_attn.forward(x)
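
Since that is just a single hidden_size x hidden_size projection, one possible (untested) way past the converter crash is to special-case the tensor name in modify_tensors; note that "blk.{bid}.attn_linear.weight" below is a made-up GGUF tensor name for illustration, and whatever name is actually used would also have to be registered in gguf-py and loaded on the llama.cpp side:

    # Sketch only: map the DeciLM linear-attention weight to a dedicated GGUF tensor.
    def modify_tensors(self, data_torch, name, bid):
        if name.endswith("self_attn.linear_attn.weight"):
            assert bid is not None
            # hypothetical target name; must match whatever llama.cpp is taught to load
            return [(f"blk.{bid}.attn_linear.weight", data_torch)]
        return super().modify_tensors(data_torch, name, bid)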


ymcki commented Nov 20, 2024

By adding the following code to the modify_tensors method of the LlamaModel class in convert_hf_to_gguf.py, I was able to convert DeciLM-7B-Instruct to GGUF; the result is deposited at
https://huggingface.co/ymcki/DeciLM-7B-Instruct-GGUF

        if bid is not None and "num_key_value_heads_per_layer" in self.hparams:
            n_kv_head = self.hparams["num_key_value_heads_per_layer"][bid]
        else:
            n_kv_head = self.hparams.get("num_key_value_heads")
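
For context, that fragment sits near the top of LlamaModel.modify_tensors, where the per-layer value then feeds the same K-projection permute that the uniform num_key_value_heads used to; roughly (simplified from my reading of convert_hf_to_gguf.py, a sketch rather than the exact code):

    # Sketch: pick the per-layer KV head count, then the usual Q/K permutes use it.
    n_head = self.hparams["num_attention_heads"]
    if bid is not None and "num_key_value_heads_per_layer" in self.hparams:
        n_kv_head = self.hparams["num_key_value_heads_per_layer"][bid]
    else:
        n_kv_head = self.hparams.get("num_key_value_heads")

    if name.endswith(("q_proj.weight", "q_proj.bias")):
        data_torch = LlamaModel.permute(data_torch, n_head, n_head)
    if name.endswith(("k_proj.weight", "k_proj.bias")):
        data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)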

The GGUFs seem to work even though I haven't implemented dynamic NTK-aware RoPE scaling. If I figure out how to implement it, I will see how different the responses are.

Anyway, the original purpose of this exercise is to convert Llama-3.1-Nemotron-51B-Instruct, and that model doesn't use dynamic NTK-aware RoPE scaling. However, it does use linear attention. Does anyone know of other models that use linear attention, so that I can copy and paste code? Thanks a lot in advance.


ymcki commented Nov 24, 2024

I found that there are three types of layers in DeciLMForCausalLM. The first is exactly the same as a regular llama layer:

INFO:hf-to-gguf:blk.0.attn_norm.weight,              torch.bfloat16 --> F32, shape = {8192}
INFO:hf-to-gguf:blk.0.ffn_down.weight,               torch.bfloat16 --> F16, shape = {7168, 8192}
INFO:hf-to-gguf:blk.0.ffn_gate.weight,               torch.bfloat16 --> F16, shape = {8192, 7168}
INFO:hf-to-gguf:blk.0.ffn_up.weight,                 torch.bfloat16 --> F16, shape = {8192, 7168}
INFO:hf-to-gguf:blk.0.ffn_norm.weight,               torch.bfloat16 --> F32, shape = {8192}
INFO:hf-to-gguf:blk.0.attn_k.weight,                 torch.bfloat16 --> F16, shape = {8192, 1024}
INFO:hf-to-gguf:blk.0.attn_output.weight,            torch.bfloat16 --> F16, shape = {8192, 8192}
INFO:hf-to-gguf:blk.0.attn_q.weight,                 torch.bfloat16 --> F16, shape = {8192, 8192}
INFO:hf-to-gguf:blk.0.attn_v.weight,                 torch.bfloat16 --> F16, shape = {8192, 1024}

The second type is a linear-attention layer whose single weight replaces attn_q, attn_k, attn_v and attn_output:

INFO:hf-to-gguf:blk.11.attn_norm.weight,             torch.bfloat16 --> F32, shape = {8192}
INFO:hf-to-gguf:blk.11.ffn_down.weight,              torch.bfloat16 --> F16, shape = {14336, 8192}
INFO:hf-to-gguf:blk.11.ffn_gate.weight,              torch.bfloat16 --> F16, shape = {8192, 14336}
INFO:hf-to-gguf:blk.11.ffn_up.weight,                torch.bfloat16 --> F16, shape = {8192, 14336}
INFO:hf-to-gguf:blk.11.ffn_norm.weight,              torch.bfloat16 --> F32, shape = {8192}
INFO:hf-to-gguf:blk.11.self_attn.linear_attn.weight, torch.bfloat16 --> F16, shape = {8192, 8192}

The third type is an attention-free layer that has only four weights:

INFO:hf-to-gguf:blk.50.ffn_down.weight,              torch.bfloat16 --> F16, shape = {7168, 8192}
INFO:hf-to-gguf:blk.50.ffn_gate.weight,              torch.bfloat16 --> F16, shape = {8192, 7168}
INFO:hf-to-gguf:blk.50.ffn_up.weight,                torch.bfloat16 --> F16, shape = {8192, 7168}
INFO:hf-to-gguf:blk.50.ffn_norm.weight,              torch.bfloat16 --> F32, shape = {8192}

I believe I can handle the first type exactly like any other llama model, but what is the proper way to handle the second and third?

To distinguish the second and third types, I set n_head_kv to zero for layers of the second type and both n_head and n_head_kv to zero for layers of the third type, roughly as sketched below.
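
A sketch of how those per-layer lists could be derived on the conversion side from the block_configs entries quoted earlier (field names as in Nemotron-51B's config.json; an illustration of the encoding, not tested conversion code):

    # Sketch: encode the three layer types as per-layer head counts.
    #   normal attention       -> n_head, n_kv = n_head / n_heads_in_group
    #   replace_with_linear    -> n_kv = 0 (single n_embd x n_embd projection)
    #   no_op (attention-free) -> n_head = n_kv = 0
    n_head = self.hparams["num_attention_heads"]
    head_counts: list[int] = []
    kv_counts: list[int] = []
    for block in self.hparams["block_configs"]:
        attn = block["attention"]
        if attn.get("no_op"):
            head_counts.append(0)
            kv_counts.append(0)
        elif attn.get("replace_with_linear"):
            head_counts.append(n_head)
            kv_counts.append(0)
        else:
            head_counts.append(n_head)
            kv_counts.append(n_head // attn["n_heads_in_group"])

    self.gguf_writer.add_head_count(head_counts)
    self.gguf_writer.add_head_count_kv(kv_counts)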

Then I made these changes to llm_load_tensors in the LLM_ARCH_LLAMA case:

                        if (n_head_kv == 0 && n_head > 0) {
                            // linear attention for DeciLMForCausalLM
                            layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
                            layer.wqkv = create_tensor(tn(LLM_TENSOR_SELF_ATTN_LINEAR_ATTN, "weight", i), {n_embd, n_embd}, 0);
                        }
                        else if (n_head_kv > 0) { // original code to load attn_norm and four attn weights
                        }

In build_llama, I made these changes:

            if (n_head > 0) {
                // norm
                cur = llm_build_norm(ctx0, inpL, hparams,
                        model.layers[il].attn_norm, NULL,
                        LLM_NORM_RMS, cb, il);
                cb(cur, "attn_norm", il);
            } else {
                cur = inpL;
            }

            // self-attention
            if (n_head_kv == 0 && n_head > 0) { // linear attention
                cur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wqkv, cur);
                cb(cur, "wqkv", il);
            } else if (n_head > 0)
            { // original code
            }

These changes allow llama-cli to run. However, I am getting gibberish replies. How can I fix this? Thanks a lot in advance.


slaren commented Nov 24, 2024

You need to look at the python inference code and ensure that the same operations are being run in llama.cpp. llama-eval-callback can be useful to check the results of intermediate operations.


ymcki commented Nov 27, 2024

You need to look at the python inference code and ensure that the same operations are being run in llama.cpp. llama-eval-callback can be useful to check the results of intermediate operations.

Thanks slaren for the hint. Is there an equivalent tool that can print similar numbers when running the Hugging Face model? I want to see the numbers generated by the model I downloaded from
https://huggingface.co/nvidia/Llama-3_1-Nemotron-51B-Instruct
so that I can compare them to the GGUF numbers and see what went wrong.

Thanks a lot in advance.


slaren commented Nov 27, 2024

There isn't any specific tool to do that, but I suppose you could modify the code in modeling_decilm.py to print the values of the tensors at different points during the evaluation, and compare that to the output of llama-eval-callback.
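
For example, one low-effort variant of that is to register plain PyTorch forward hooks on the Hugging Face model instead of editing modeling_decilm.py directly; a sketch (the module-name suffixes are guesses and should be checked against model.named_modules()):

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "nvidia/Llama-3_1-Nemotron-51B-Instruct"
    tok = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16,
                                                 trust_remote_code=True)

    def make_hook(name):
        def hook(module, inputs, output):
            out = output[0] if isinstance(output, tuple) else output
            # print a few values per module to line up against llama-eval-callback output
            print(name, tuple(out.shape), out.flatten()[:5].float().tolist())
        return hook

    for name, module in model.named_modules():
        if name.endswith(("self_attn", "mlp", "input_layernorm")):
            module.register_forward_hook(make_hook(name))

    with torch.no_grad():
        model(**tok("Hello", return_tensors="pt"))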
