-
Notifications
You must be signed in to change notification settings - Fork 10k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat: Support for falcon-mamba
architecture
#9074
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
src/llama.cpp
Outdated
// Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers | ||
if (ssm_b_dt_rms) { | ||
dt = ggml_rms_norm(ctx0, dt, norm_rms_eps); | ||
B = ggml_rms_norm(ctx0, B, norm_rms_eps); | ||
C = ggml_rms_norm(ctx0, C, norm_rms_eps); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This will eventually be rewritten to use llm_build_norm
, because some Mamba-based architectures like Jamba use RMS norms with learnable parameters here.
But for now I think this is fine.
Co-authored-by: compilade <[email protected]>
Co-authored-by: compilade <[email protected]>
Co-authored-by: compilade <[email protected]>
Thanks for the detailed review @compilade ! Should be all addressed now |
Co-authored-by: compilade <[email protected]>
Co-authored-by: compilade <[email protected]>
5aeaca0
to
bf5e344
Compare
2f1391d
to
ca4db9e
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My only remaining nitpick is related vertical alignment of a printed string. This will be good to merge after fixing that.
I can confirm old Mamba models still work correctly with this change. I'll try Falcon-Mamba next to see if perplexity looks reasonable.
Co-authored-by: compilade <[email protected]>
Getting an error during conversion of https://huggingface.co/tiiuae/falcon-mamba-7b-instruct, will investigate
This is likely related with metadata extraction from the model card (when the |
@compilade this commit: https://huggingface.co/tiiuae/falcon-mamba-7b-instruct/commit/5e1687c297b82872dc38b33878d4601810e2ed67 should fix it, somehow gguf is not happy when |
I've been experimenting with quantization, and from what I've seen with Mamba-2, I think it could be safe to quantize Mamba-1's Here's the patch I used if you're interested (click to expand)diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 4b843991..108c822c 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -295,6 +295,7 @@ class Model:
gguf.MODEL_TENSOR.FFN_GATE_INP,
gguf.MODEL_TENSOR.POS_EMBD,
gguf.MODEL_TENSOR.TOKEN_TYPES,
+ gguf.MODEL_TENSOR.SSM_CONV1D,
)
)
or not name.endswith(".weight")
@@ -2786,23 +2787,6 @@ class MambaModel(Model):
return [(new_name, data_torch)]
- def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
- if bid is not None and new_name in (
- self.format_tensor_name(
- n, bid, ".weight" if name.endswith(".weight") else ""
- )
- for n in [
- gguf.MODEL_TENSOR.SSM_CONV1D,
- gguf.MODEL_TENSOR.SSM_X,
- gguf.MODEL_TENSOR.SSM_DT,
- gguf.MODEL_TENSOR.SSM_A,
- gguf.MODEL_TENSOR.SSM_D,
- ]
- ):
- return gguf.GGMLQuantizationType.F32
-
- return super().tensor_force_quant(name, new_name, bid, n_dims)
-
@Model.register("CohereForCausalLM")
class CommandR2Model(Model):
diff --git a/src/llama.cpp b/src/llama.cpp
index 84fe4967..b8fa7684 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -16450,8 +16450,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
// do not quantize Mamba's small yet 2D weights
// NOTE: can't use LLM_TN here because the layer number is not known
quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
- quantize &= name.find("ssm_x.weight") == std::string::npos;
- quantize &= name.find("ssm_dt.weight") == std::string::npos;
// do not quantize relative position bias (T5)
quantize &= name.find("attn_rel_b.weight") == std::string::npos; This reduces the At I did not measure perplexity yet, because my hardware is a low-power laptop with 8GB of RAM and I only get 2 tokens per second with that model (with 12% of the time in $ ./bin/llama-perplexity -m /path/to/falcon-mamba-7B-chat-Q4_K_S.gguf -f /path/to/wiki.test.txt -b 512 -c 512 For me it would take 4 minutes per chunk, and the 5GB model barely fits in my free RAM. |
Hi @compilade
I will let you decide here, happy to push the patch in this PR and I'll upload the converted quants on the TII HF org |
Yeah I thought I already tackled this issue in #8774 . Double checked by copying your falcon mamba model card metadata to my test repo and running against it. No issues detected. Basically added a 'zero length array' check to In that case then... Should we throw an error if e.g. def add_key_value(self, key: str, val: Any, vtype: GGUFValueType) -> None:
assert(GGUFValueType.get_type(val) == vtype)
... edit: Ah, my sketch above won't deal with different sized integers... but the point still stands |
The problematic model card is no longer the latest version (which works fine), it was https://huggingface.co/tiiuae/falcon-mamba-7b/blob/503c3d4eaf202d970aabd81376c9f0d0e3defe2c/README.md.
Note that the problem here was that a list was passed to The failure was during serialization of the metadata, so a way to reproduce the error would be to make a vocab-only conversion.
Yes, failing early would make it easier to debug. This is a good idea, because the traceback I got doesn't explicitly mention the source of the problem, only that the types are wrong somewhere. But ideally incorrect metadata in the model card should not prevent conversion, which means the types should be checked (and/or coerced) in Actually, in this case Anyway, I think that more type checking in |
I've ran some further tests on a small Mamba model, and I realized that my initial decision to avoid quantizing these tensors was because I've added a fallback to the fallback quantization types: diff --git a/src/llama.cpp b/src/llama.cpp
index b8fa7684..fe3c0db6 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -16122,6 +16122,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break;
default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
}
+ if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
+ new_type = GGML_TYPE_F16;
+ }
LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
++qs.n_fallback;
} Full patch so far for the changes to Mamba's quantization (click to expand)diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py
index 4b843991..108c822c 100755
--- a/convert_hf_to_gguf.py
+++ b/convert_hf_to_gguf.py
@@ -295,6 +295,7 @@ class Model:
gguf.MODEL_TENSOR.FFN_GATE_INP,
gguf.MODEL_TENSOR.POS_EMBD,
gguf.MODEL_TENSOR.TOKEN_TYPES,
+ gguf.MODEL_TENSOR.SSM_CONV1D,
)
)
or not name.endswith(".weight")
@@ -2786,23 +2787,6 @@ class MambaModel(Model):
return [(new_name, data_torch)]
- def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
- if bid is not None and new_name in (
- self.format_tensor_name(
- n, bid, ".weight" if name.endswith(".weight") else ""
- )
- for n in [
- gguf.MODEL_TENSOR.SSM_CONV1D,
- gguf.MODEL_TENSOR.SSM_X,
- gguf.MODEL_TENSOR.SSM_DT,
- gguf.MODEL_TENSOR.SSM_A,
- gguf.MODEL_TENSOR.SSM_D,
- ]
- ):
- return gguf.GGMLQuantizationType.F32
-
- return super().tensor_force_quant(name, new_name, bid, n_dims)
-
@Model.register("CohereForCausalLM")
class CommandR2Model(Model):
diff --git a/src/llama.cpp b/src/llama.cpp
index 84fe4967..fe3c0db6 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -16122,6 +16122,9 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n
case GGML_TYPE_Q6_K: new_type = GGML_TYPE_Q8_0; break;
default: throw std::runtime_error("\nUnsupported tensor size encountered\n");
}
+ if (tensor->ne[0] % ggml_blck_size(new_type) != 0) {
+ new_type = GGML_TYPE_F16;
+ }
LLAMA_LOG_WARN(" - using fallback quantization %s\n", ggml_type_name(new_type));
++qs.n_fallback;
}
@@ -16450,8 +16453,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
// do not quantize Mamba's small yet 2D weights
// NOTE: can't use LLM_TN here because the layer number is not known
quantize &= name.find("ssm_conv1d.weight") == std::string::npos;
- quantize &= name.find("ssm_x.weight") == std::string::npos;
- quantize &= name.find("ssm_dt.weight") == std::string::npos;
// do not quantize relative position bias (T5)
quantize &= name.find("attn_rel_b.weight") == std::string::npos; When doing this, a In summary, for Mamba-130M:
EDIT: for Mamba-370M:
This seems reasonable considering the file sizes differ by So I think this patch is worth it, at least when comparing the perplexity for given file sizes with Mamba-130M. I think this should also apply to Falcon-Mamba-7B. @younesbelkada Do you want me to push this directly here or do you want to commit the patch by yourself? Either way is fine with me.
To save you some bandwidth, be aware that currently, for Mamba models, there is no difference within variants like |
Thank you very much for the detailed answer, with respect to the experiments this are clear on my side! |
* llama : use f16 as the fallback of fallback quant types
Hi, I try this pr.
and model response is weird: ...............................................................................................
llama_new_context_with_model: n_ctx = 1048576
llama_new_context_with_model: n_batch = 2048
llama_new_context_with_model: n_ubatch = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base = 10000.0
llama_new_context_with_model: freq_scale = 1
ggml_metal_init: allocating
ggml_metal_init: found device: Apple M2 Ultra
ggml_metal_init: picking default device: Apple M2 Ultra
ggml_metal_init: using embedded metal library
ggml_metal_init: GPU name: Apple M2 Ultra
ggml_metal_init: GPU family: MTLGPUFamilyApple8 (1008)
ggml_metal_init: GPU family: MTLGPUFamilyCommon3 (3003)
ggml_metal_init: GPU family: MTLGPUFamilyMetal3 (5001)
ggml_metal_init: simdgroup reduction support = true
ggml_metal_init: simdgroup matrix mul. support = true
ggml_metal_init: hasUnifiedMemory = true
ggml_metal_init: recommendedMaxWorkingSetSize = 103079.22 MB
llama_kv_cache_init: Metal KV buffer size = 38.00 MiB
llama_new_context_with_model: KV self size = 38.00 MiB, K (f32): 6.00 MiB, V (f32): 32.00 MiB
llama_new_context_with_model: CPU output buffer size = 0.25 MiB
llama_new_context_with_model: Metal compute buffer size = 151.13 MiB
llama_new_context_with_model: CPU compute buffer size = 16.51 MiB
llama_new_context_with_model: graph nodes = 2568
llama_new_context_with_model: graph splits = 384
main: chat template example: <|im_start|>system
You are a helpful assistant<|im_end|>
<|im_start|>user
Hello<|im_end|>
<|im_start|>assistant
Hi there<|im_end|>
<|im_start|>user
How are you?<|im_end|>
<|im_start|>assistant
system_info: n_threads = 16 / 24 | AVX = 0 | AVX_VNNI = 0 | AVX2 = 0 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 0 | NEON = 1 | SVE = 0 | ARM_FMA = 1 | F16C = 0 | FP16_VA = 1 | WASM_SIMD = 0 | BLAS = 1 | SSE3 = 0 | SSSE3 = 0 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 |
main: interactive mode on.
sampling:
repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.800
mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order:
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature
generate: n_ctx = 1048576, n_batch = 2048, n_predict = -1, n_keep = 0
== Running in interactive mode. ==
- Press Ctrl+C to interject at any time.
- Press Return to return control to the AI.
- To return control without starting a new line, end your input with '/'.
- If you want to submit another line, end your input with '\'.
system
You are a helpful
> how to build a website by python
TalesFromYourServer
“I don’t think you deserve a tip for this.” I have a similar story to this. My husband and I went to the casino and I won $1200 in one hour. I was so excited and went to cash in my tickets at the counter, it was $1200 in quarters and nickels. I asked the cashier if I could have them counted so I could take the cash out of the machine and she said she was busy and that it would be about 10 minutes. I said okay and went to sit down with my husband to wait. When I got up 10 minutes later, she had counted them and I had $1100 in $1 bills. I was so upset, I went to the manager and he gave me $200 in cash for the $100 I lost. I was still upset and went to another cashier and told her the story and asked if she could count it for me. She said yes and she counted it in 2 minutes and had $1100 in $1 bills for me. I thanked her and she said that it was no problem and to have a nice night. I was very happy with that and I went to tell my husband that the second cashier counted them for me and she said it was no problem. I told her the whole story and she was shocked. She said that she would never have done that. She said that it was a lot of work and she could have gotten into trouble. I told her that it was the right thing to do and that I was grateful for her help. I don't get it, why wouldn't you just have the $1000 in $1 coins, then you'd have $1000 in $1 coins and you wouldn't need to count them. I didn’t even think about that.
> |
@LiuChaoXD can you make sure you have compiled llama.cpp using this branch with the command |
@compilade @ggerganov thanks for all the reviews, is there anything to do before merging this PR that I can help ? |
yes. |
@LiuChaoXD using the 4bit quantized model I get coherent results with the same system prompt you shared, see below: And using the prompt you shared: |
Thanks, appreciate. |
* feat: initial support for llama.cpp * fix: lint * refactor: better refactor * Update src/llama.cpp Co-authored-by: compilade <[email protected]> * Update src/llama.cpp Co-authored-by: compilade <[email protected]> * fix: address comments * Update convert_hf_to_gguf.py Co-authored-by: compilade <[email protected]> * fix: add more cleanup and harmonization * fix: lint * Update gguf-py/gguf/gguf_writer.py Co-authored-by: compilade <[email protected]> * fix: change name * Apply suggestions from code review Co-authored-by: compilade <[email protected]> * add in operator * fix: add `dt_b_c_rms` in `llm_load_print_meta` * fix: correct printf format for bool * fix: correct print format * Update src/llama.cpp Co-authored-by: compilade <[email protected]> * llama : quantize more Mamba tensors * llama : use f16 as the fallback of fallback quant types --------- Co-authored-by: compilade <[email protected]>
* feat: initial support for llama.cpp * fix: lint * refactor: better refactor * Update src/llama.cpp Co-authored-by: compilade <[email protected]> * Update src/llama.cpp Co-authored-by: compilade <[email protected]> * fix: address comments * Update convert_hf_to_gguf.py Co-authored-by: compilade <[email protected]> * fix: add more cleanup and harmonization * fix: lint * Update gguf-py/gguf/gguf_writer.py Co-authored-by: compilade <[email protected]> * fix: change name * Apply suggestions from code review Co-authored-by: compilade <[email protected]> * add in operator * fix: add `dt_b_c_rms` in `llm_load_print_meta` * fix: correct printf format for bool * fix: correct print format * Update src/llama.cpp Co-authored-by: compilade <[email protected]> * llama : quantize more Mamba tensors * llama : use f16 as the fallback of fallback quant types --------- Co-authored-by: compilade <[email protected]>
What does this PR do?
Fixes: #9009
Fixes: #9048
This PR adds the support for FalconMamba architecture in
llama.cpp
. I followed the suggestion from @compilade here: #9009 (comment) by simply extending the current Mamba architecture to be able to perform RMS norm operations for B / dt & C projections, in order to make things simple.Output from the model converted locally:
cc @compilade @ggerganov