llama : support RWKV v6 models #8980
Conversation
A few things I've noticed. I'll review this more deeply in the coming days.
Synchronized the changes and made everything work again after #8526 was merged.
I'm impressed that ggml_rwkv_wkv only takes around 2% of the CPU time during inference of the 1.6B RWKV-v6 model (when measured with perf record --call-graph=lbr).

I have some styling comments, some suggestions, and I also found some problems.
Indeed. I did consider writing a Metal kernel for wkv, but it turned out that the wkv kernels didn't eat much CPU time.
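For context, here is a minimal NumPy sketch of the per-head RWKV-6 WKV recurrence that ggml_rwkv_wkv computes; the shapes, names, and explicit timestep loop are illustrative assumptions, not the actual ggml kernel layout:

```python
import numpy as np

def wkv6_head(r, k, v, w, u, state):
    """Single-head RWKV-6 WKV recurrence (illustrative, not the ggml layout).

    r, k, v, w: (T, S) per-timestep receptance/key/value/decay for one head
    u:          (S,)   per-channel bonus applied to the current token
    state:      (S, S) recurrent state carried across tokens
    Returns (T, S) outputs and the updated state.
    """
    T, S = r.shape
    out = np.zeros((T, S))
    for t in range(T):
        kv = np.outer(k[t], v[t])                  # (S, S) rank-1 update
        out[t] = r[t] @ (u[:, None] * kv + state)  # current token gets the u bonus
        state = w[t][:, None] * state + kv         # per-channel decay, then accumulate
    return out, state
```

The strictly sequential loop over T and the small per-head state help explain why the CPU implementation stays such a small fraction of total inference time.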
Let's look to merge soon. @MollySophia Which HF model do you recommend to run a few tests with this branch?
https://huggingface.co/RWKV/v6-Finch-1B6-HF should be enough for testing the functionality.
I've updated the tokenizer to use a trie for string search (7004323). With this change the time for tokenizing …
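As a rough illustration of the approach (not the actual C++ implementation), greedy longest-match tokenization can walk a byte trie once per input position instead of scanning the whole vocabulary:

```python
class TrieNode:
    __slots__ = ("children", "token_id")
    def __init__(self):
        self.children = {}    # byte value -> TrieNode
        self.token_id = None  # set when a vocab token ends at this node

def build_trie(vocab):
    """vocab: dict mapping token bytes -> token id."""
    root = TrieNode()
    for tok, tid in vocab.items():
        node = root
        for b in tok:
            node = node.children.setdefault(b, TrieNode())
        node.token_id = tid
    return root

def tokenize(data, root):
    """Greedy longest match: at each position, walk the trie as far as the
    input allows and emit the last complete token seen along the way."""
    ids, pos = [], 0
    while pos < len(data):
        node, best_id, best_len = root, None, 0
        for i in range(pos, len(data)):
            node = node.children.get(data[i])
            if node is None:
                break
            if node.token_id is not None:
                best_id, best_len = node.token_id, i - pos + 1
        ids.append(best_id)        # assumes every single byte is in the vocab,
        pos += max(best_len, 1)    # as with the rwkv_world tokenizer
    return ids
```

Each position now costs at most the length of the longest match rather than a scan over the full vocabulary, which is where the speed-up comes from.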
BTW, what's next for this PR?

@MollySophia It looks ready to me, at least. Nice work!

There's some potential division by zero with hparams.rescale_every_n_layers, which I think should be fixed before merging. Improvements to ggml_rwkv_wkv (if relevant) can be done later in a follow-up PR, so after that I think this will be ready to merge.
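The hazard is that rescale_every_n_layers is used as a modulus, so a value of zero must be guarded before the modulo is taken. A hypothetical sketch of such a check (the real fix lives in llama.cpp's graph-building code and may differ in detail):

```python
def maybe_rescale(x, layer_index, rescale_every):
    # Hypothetical guard: rescale_every == 0 means "never rescale", so the
    # modulo must only run when the parameter is positive.
    if rescale_every > 0 and (layer_index + 1) % rescale_every == 0:
        x = x * 0.5  # halve activations every `rescale_every` layers
    return x
```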
* convert_hf_to_gguf: Add support for RWKV v6
* Add RWKV tokenization
* Fix build
* Do not use special tokens when matching in RWKV tokenizer
* Fix model loading
* Add (broken) placeholder graph builder for RWKV
* Add workaround for kv cache
* Add logits conversion to rwkv5
* Add rwkv5 layer norms
* Add time mix KVRG & correct merge mistake
* Add remaining time mix parameters
* Add time mix output loading
* Add placeholder llm_build_time_mix
* Fix build
* Load more tensors for rwkv v6
* Fix rwkv tokenizer
* ggml: Add unary operator Exp
* RWKV v6 graph building
* Add ``rescale_every_n_layers`` parameter
* Add ``wkv.head_size`` key for RWKV so it doesn't reuse Mamba ssm parameters
* Fix offloading layers to CUDA
* Fix parallel inferencing for RWKV
* Remove trailing whitespaces
* build_rwkv: Avoid using inplace operations
* convert_hf_to_gguf: rwkv: Avoid using ``eval``
* convert_hf_to_gguf: rwkv tokenizer: Don't escape sequences manually
* Update convert_hf_to_gguf.py
* ggml: Add backward computation for unary op ``exp``
* Update convert_hf_to_gguf.py
* Update convert_hf_to_gguf.py
* Use MODEL_ARCH.RWKV6 instead of MODEL_ARCH.RWKV
* build_rwkv6: Simplify graph
* llama: rwkv6: Detect model.type
* llama: rwkv6: Fix tensor loading for 7B/14B models
* llama: rwkv6: Fix group_norm assertion failure with Metal
* llama: rwkv6: Clean up
* llama: rwkv6: Add quantization tensor exclusion
* llama: rwkv6: Use the new advanced batch splits
* Update src/llama.cpp
* llama: rwkv6: Use ``ggml_norm`` instead of ``ggml_group_norm``
* llama: rwkv6: Apply code style and misc changes
* converter: Use class name ``Rwkv6Model``
* llama: rwkv6: Make use of key ``feed_forward_length``
* llama: rwkv6: Add kv ``time_mix_extra_dim`` and ``time_decay_extra_dim``
* converter: Match ``new_name`` instead of ``name`` for float32 explicit tensors
* llama: rwkv6: Keep ``time_mix_w1/w2`` as F32
* llama: rwkv6: Remove unused nodes
* llama: rwkv6: Apply code format changes
* llama: rwkv6: Add lora for some supported tensors (currently att.key/receptance/value/gate/output, ffn.receptance/key/value, as well as head.weight)
* rwkv : speed-up tokenization using trie
* minor : style + indentation
* llama: rwkv6: Avoid division by zero
* ggml: rwkv_wkv: Avoid copying the state

---------

Signed-off-by: Molly Sophia <[email protected]>
Co-authored-by: Layl Bongers <[email protected]>
Co-authored-by: compilade <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
This should fix #846.
Added:

ggml:
* Exp unary operator
* rwkv_wkv operation with CPU impl
* rwkv_token_shift operation with CPU impl to handle multiple sequences in parallel (may not be necessary after llama : simplify Mamba with advanced batch splits #8526 is done); see the sketch after this list

llama.cpp:
* rwkv_world tokenizer support (by @LaylBongers)
* convert_hf_to_gguf.py support for converting RWKV v6 HF models

TODO:
* Do modifications accordingly after llama : simplify Mamba with advanced batch splits #8526 is ready (done)
* Add CUDA or Metal implementation for the rwkv_wkv operation (maybe in a next PR)
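As an aside, a minimal sketch of what the token-shift state carries per sequence (illustrative NumPy, not the ggml operation itself):

```python
import numpy as np

def token_shift(x, prev):
    """Illustrative token shift for one sequence: each position sees the
    previous token's embedding.

    x:    (T, C) embeddings for the current chunk
    prev: (C,)   embedding carried over from the previous chunk
    Returns the shifted (T, C) tensor and the new carry (the chunk's last row).
    """
    shifted = np.concatenate([prev[None, :], x[:-1]], axis=0)
    return shifted, x[-1]
```

Handling multiple sequences in parallel amounts to keeping one such carry vector per sequence, which is what the rwkv_token_shift operation manages batch-wide.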