Feature Request: Support for DeciLMForCausalLM #10028
Comments
Interestingly, the author of DeciLM created GGUFs for his own model: https://huggingface.co/Deci/DeciLM-7B-instruct-GGUF. How could he do that?
Nvidia acquired Deci (https://www.calcalistech.com/ctechnews/article/bkj6phggr), so that's why they are using its technology now. If Nvidia LLMs are going to be mainstream, then llama.cpp had better support DeciLM.
I think they somehow made it not use variable GQA, as hinted by the "uniform-gqa" part of the GGUF file names. But since #7359, variable GQA is implemented, so it should be relatively straightforward to adapt the convert scripts to handle it.
I found that the tokenizer of DeciLM-7B-instruct is identical to that of Mistral-7B-Instruct-v0.2 by comparing the file hashes.
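For anyone who wants to repeat that check, here is a minimal sketch (the file paths are placeholders, not paths from this thread):

```python
# Sketch: compare the two tokenizer files by hash (paths are placeholders for
# wherever the checkpoints were downloaded).
import hashlib

def sha256sum(path: str) -> str:
    # Stream the file so it does not have to fit in memory at once.
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

print(sha256sum("DeciLM-7B-instruct/tokenizer.model"))
print(sha256sum("Mistral-7B-Instruct-v0.2/tokenizer.model"))
```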
Since Mistral also uses sliding-window grouped-query attention, I figured this might be doable by fiddling with the Mistral-related code. I simply registered this model alongside Mistral and Llama in convert_hf_to_gguf.py. It converts to an f16 GGUF without errors, but when I run the GGUF there is a dimension mismatch error:
llama_model_load: error loading model: check_tensor_dims: tensor 'blk.0.attn_k.weight' has wrong shape; expected 4096, 4096, got 4096, 512, 1, 1
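Presumably that registration amounts to something like the following edit to the existing decorator in convert_hf_to_gguf.py (a sketch, not the exact change used above):

```python
# convert_hf_to_gguf.py (sketch): add the DeciLM architecture string to the
# existing LlamaModel registration so the converter picks this class for it.
@Model.register("LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM",
                "DeciLMForCausalLM")
class LlamaModel(Model):
    model_arch = gguf.MODEL_ARCH.LLAMA
```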
You need to handle the variable-GQA metadata in the convert script so that the shapes are correctly handled when loading. I think you will need to use the per-layer KV head counts, the way the OpenELM conversion does.
Thanks for your hint. After consulting the OpenELM code, I added the per-layer KV head counts to the set_gguf_parameters method of LlamaModel. Again, it converts without errors, but I get a slightly different error:
llama_model_load: error loading model: check_tensor_dims: tensor 'blk.5.attn_k.weight' has wrong shape; expected 4096, 512, got 4096, 256, 1, 1
Since config.json of DeciLM-7B-Instruct specifies the KV head count per layer, I suspect that crashing only at the 5th (zero-indexed, so actually the sixth) layer is due to an off-by-one somewhere. How do I fix this?
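For context, the kind of addition being described might look roughly like this. It is a hedged sketch: it assumes DeciLM's config.json exposes a per-layer list named num_key_value_heads_per_layer, and that gguf's add_head_count_kv accepts a list, as the OpenELM conversion relies on.

```python
# Sketch of a set_gguf_parameters override on LlamaModel (assumptions noted in
# the comments; not the exact code used above).
def set_gguf_parameters(self):
    kv_heads = self.hparams.get("num_key_value_heads_per_layer")  # assumed key name
    if kv_heads is not None:
        # Prevent the base class from writing a scalar head_count_kv first,
        # then record one KV head count per transformer block.
        self.hparams.pop("num_key_value_heads", None)
        super().set_gguf_parameters()
        self.gguf_writer.add_head_count_kv(kv_heads)
    else:
        super().set_gguf_parameters()
```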
I found that if I modify llama.cpp (build b4067) this way:
Then there are no errors in either conversion or llama-cli. However, llama-cli replies with nothing but "#####", so something is still wrong. What is it? I noticed that DeciLM-7B-Instruct uses dynamic RoPE scaling, which is not implemented in llama.cpp. Could that be the cause?
I found that "dynamic" here is actually the dynamic NTK-aware RoPE scaling method. I would like to implement it in ggml/src/ggml-cpu.c. However, it requires seq_len as a parameter, as in _compute_dynamic_ntk_parameters in transformers' modeling_rope_utils.py. Which variable in ggml_compute_forward_rope_f32 corresponds to seq_len?
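For reference, the transformers rule boils down to recomputing the RoPE base from the current sequence length; a minimal Python sketch of that formula (mirroring _compute_dynamic_ntk_parameters, independent of the ggml code):

```python
# Minimal sketch of dynamic NTK-aware RoPE scaling: the RoPE base grows once the
# sequence length exceeds the original training context, then the usual inverse
# frequencies are derived from the rescaled base.
def dynamic_ntk_inv_freq(base: float, dim: int, max_pos: int,
                         factor: float, seq_len: int) -> list[float]:
    seq_len = max(seq_len, max_pos)  # no rescaling below the training length
    base = base * ((factor * seq_len / max_pos) - (factor - 1)) ** (dim / (dim - 2))
    return [1.0 / (base ** (2 * i / dim)) for i in range(dim // 2)]
```

The seq_len here is the number of positions processed so far, which on the ggml side would have to come from the position inputs or KV cache state rather than from a fixed hyperparameter; that is essentially what the question above is asking.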
I found that, unlike the smaller DeciLM-7B, Nemotron-51B has some layers that use linear attention. For example, layer 10 has a normal attention config, so it can be handled by the existing llama code.
However, layer 11 contains a linear attention layer without attention heads, so my conversion script crashes.
Is there any existing code that already supports this, so that I can plug and play? Supposedly, its implementation is in the model's own modeling code.
By adding the following code to modify_tensors of the LlamaModel class in convert_hf_to_gguf.py, I am able to convert DeciLM-7B-Instruct to GGUF, and I have uploaded the resulting files.
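A hypothetical sketch of what such a modify_tensors change could look like (LlamaModel.permute and map_tensor_name are existing helpers in convert_hf_to_gguf.py; the num_key_value_heads_per_layer key and the overall shape of the code are assumptions, not the poster's actual snippet):

```python
# Hypothetical sketch: permute the q/k projections with the per-layer KV head
# count instead of a single global value.
def modify_tensors(self, data_torch, name, bid):
    n_head = self.hparams["num_attention_heads"]
    per_layer_kv = self.hparams.get("num_key_value_heads_per_layer")  # assumed key
    if per_layer_kv is not None and bid is not None:
        n_kv_head = per_layer_kv[bid]
    else:
        n_kv_head = self.hparams.get("num_key_value_heads")

    if name.endswith(("q_proj.weight", "q_proj.bias")):
        data_torch = LlamaModel.permute(data_torch, n_head, n_head)
    if name.endswith(("k_proj.weight", "k_proj.bias")):
        data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)

    return [(self.map_tensor_name(name), data_torch)]
```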
The GGUFs seem to be working even though I haven't implemented dynamic NTK-aware RoPE scaling. If I figure out how to implement it, I will see how different the responses are. Anyway, the original purpose of this exercise is to convert Llama-3.1-Nemotron-51B-Instruct, and that model doesn't use dynamic NTK-aware RoPE scaling. However, it uses linear attention. Does anyone know other models that use linear attention, so that I can copy and paste code? Thanks a lot in advance.
I found that there are three types of layers in DeciLMForCausalLM. The first is exactly the same as Llama.
The second type is a linear attention layer that replaces attn_k, attn_q, attn_v, and attn_output.
The third type is an attention-free layer that only has four weights.
I believe I can handle the first type exactly like any other Llama model, but what is the proper way to handle the second and third? To distinguish them, I set n_head_kv to zero for layers of the second type, and both n_head and n_head_kv to zero for layers of the third type. Then I made these changes to llm_load_tensors in the LLM_ARCH_LLAMA case:
In build_llama, I made these changes
These changes allow llama-cli to run. However, I am getting gibberish in the reply. How can I fix this? Thanks a lot in advance.
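On the conversion side, the encoding described above (n_head_kv = 0 for linear-attention layers, both counts 0 for attention-free layers) could be derived from the per-block configuration roughly like this. The block_configs/attention key names are assumptions about the Nemotron config rather than verified field names:

```python
# Hypothetical sketch: build per-layer head-count lists that encode the three
# layer types, for writing with add_head_count / add_head_count_kv.
def per_layer_head_counts(hparams: dict) -> tuple[list[int], list[int]]:
    n_head = hparams["num_attention_heads"]
    heads, kv_heads = [], []
    for block in hparams["block_configs"]:        # assumed: one entry per layer
        attn = block["attention"]                 # assumed: per-block attention info
        if attn.get("no_op"):                     # attention-free layer
            heads.append(0)
            kv_heads.append(0)
        elif attn.get("replace_with_linear"):     # linear "attention" layer
            heads.append(n_head)
            kv_heads.append(0)
        else:                                     # regular grouped-query attention
            heads.append(n_head)
            kv_heads.append(n_head // attn["n_heads_in_group"])  # assumed field
    return heads, kv_heads
```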
You need to look at the Python inference code and ensure that the same operations are being run in llama.cpp.
Thanks, slaren, for the hint. Is there a tool that can print similar numbers (intermediate tensor values) when running the original Hugging Face model I downloaded? Thanks a lot in advance.
There isn't any specific tool to do that, but I suppose you could modify the modeling code to print them.
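One way to get comparable numbers on the Hugging Face side is to dump the per-layer hidden states with the standard transformers API (a sketch; the model id is a placeholder and these checkpoints may need trust_remote_code):

```python
# Sketch: print per-layer hidden-state statistics from a Hugging Face checkpoint
# so they can be lined up against llama.cpp's intermediate tensors.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "Deci/DeciLM-7B-instruct"  # placeholder: use the local download path
tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float32, trust_remote_code=True
)

inputs = tok("Hello world", return_tensors="pt")
with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)

# hidden_states[0] is the embedding output, followed by one tensor per layer.
for i, h in enumerate(out.hidden_states):
    print(i, tuple(h.shape), h.abs().mean().item())
```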
Prerequisites
Feature Description
I downloaded nvidia/Llama-3_1-Nemotron-51B-Instruct, but I am getting this error:
Motivation
Is this DeciLMForCausalLM model type going to be supported soon? It seems like the Q4_0 quantization of this model could fit on a 3090/4090 by offloading a few layers to the CPU, which is a pretty good use case for llama.cpp.
Possible Implementation
No response