Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

llama : support Jamba hybrid Transformer-Mamba models #7531

Draft
wants to merge 41 commits into
base: master
Choose a base branch
from

Conversation

compilade
Copy link
Collaborator

@compilade compilade commented May 25, 2024

This adds support for Jamba (fixes #6372). (https://arxiv.org/abs/2403.19887)

To complement llama_kv_cache, I propose to add llama_rs_cache, as well as a top-level llama_past to more easily manage both at once.

The current implementation of recurrent states (initially written for Mamba) re-uses the tensors allocated for the KV cache to store its recurrent states. Obviously, when both Attention and recurrent states are used at the same time, this previous approach does not work.

Note that since this uses some of the same operators as Mamba, this is CPU-only for now. (see #6758)

API changes

Most of the changes are backward-compatible, but the llama_kv_cache_seq_rm and llama_kv_cache_seq_cp functions have been renamed and now return the token position of the next token after the end of the sequence(s) they affect.

This is necessary to properly handle recurrent state checkpoints with models that also use the KV cache (like Jamba (and eventually Griffin)), in case the last valid state doesn't line up with the requested removal when using, for example, llama_past_seq_rm.

  • Deprecate most llama_kv_cache_* functions to rename them to llama_past_*.
    • Not strictly necessary, but since the return type and meaning changed, it might be easier to migrate existing code-bases with backward-compatible wrappers for any differing return behavior.
    • It's no longer only a KV cache, so removing _kv_ from the names could make them less confusing when working with pure or mixed recurrent models
    • I'm open to suggestions for a better name prefix! I think llama_past_* might be a bit too different from the previous name.
      • I could also re-use the old names, but this will be a breaking change because of the change in meaning of the return type of llama_kv_cache_seq_rm. It would also be confusing to figure out at a glance which functions are specific to the KV cache and which are specific the the recurrent state cache.
  • llama_past_seq_rm and llama_past_seq_cp now return n_past, which is the number of previous tokens in the affected sequence (or it can also be interpreted as the next token position at end of the sequence).
    • It should be handled by processing the tokens again from this point until the desired end.
    • Note that nothing needs to be handled when using these function on whole sequences (i.e. when -1 is passed to both p0 and p1)
  • llama_past_seq_max_pos returns -1 when there are no cells matching the specified seq_id, to allow calculating n_past by adding one to its result.
    • llama_kv_cache_seq_max_pos previously returned 0 in this case, which makes it indistinguishable from a when there's a single cell with pos 0

New features

  • Jamba support
    • The first hybrid Transformer+Mamba model in llama.cpp
  • State checkpoints for recurrent models
    • Works best when n_parallel is at least 3 or 4 times the number of actual users
    • Allows backtracking tokens from the end of the last generation without having to reprocess the whole context
      • Very useful with the server example when trimming the stop string
  • No longer unnecessarily allocate a KV cache for non-causal models (like BERT)
  • Variable GQA
    • GGUF metadata {model}.attention.head_count_kv can now also be an array of int32_t, one value per layer
    • Layers with 0 kv heads are considered recurrent layers (Mamba, in the case of Jamba).
    • This will make proper support of DeciLM possible

Internal changes

  • new struct llama_rs_cache, a ring-buffered tree of recurrent states
    • might be possible to simplify, but the data structure for recurrent states needs quick access to at least
      • the last state of a sequence (the tail cell)
      • the number of sequences for which a particular cell is the last state
      • how many sequences a cell is part of
      • the number of cells used by a sequence
      • the number of "active" sequences (which use cells they don't share with other sequences)
      • the number of cells used by "shared" sequences (e.g. the system prompt)
  • new struct llama_cache which contains both llama_kv_cache and llama_rs_cache
  • simpler Mamba state processing
    • RS cells can be the tail of multiple sequences which allow
      • one-to-one instead of one-to-many state processing (for the ggml_ssm_* operators)
      • llama_past_seq_cp doesn't use more RS cells the more sequences there are
    • RS slots are always contiguous, and are transparently defragmented if necessary when chosen.
  • new struct llama_ubatch for more metadata about sequences
  • batches are split with equal-length sequences for recurrent models
    • This allows to simplify the SSM operations
    • But the logits of the split batches have to be re-ordered when directly using llama_get_logits to match the old expected output. This is not a problem with llama_get_logits_ith, because there was already an indirection with lctx.output_ids which is reused.

TODO

  • Find a better prefix than llama_past_* Anybody has better name suggestions?
    • llama_cache_*
    • llama_past_*
    • llama_kv_cache_*
      • Will be confusing with recurrent models, and doesn't offer the possibility of backward-compatible wrappers if the same name is used
    • llama_ctx_cache_*
    • llama_llm_cache_*
    • llama_seq_cache_*
      • Could work, but would not help with discerning sequence-wise functions from cache-wise functions
    • llama_tok_cache_*
    • llama_comp_cache_*
    • llama_past_cache_*
    • llama_work_cache_*
    • llama_kvrs_cache_*
    • llama_causal_cache_*
    • llama_context_cache_*
  • session file save and restore
  • handle the returned n_past from the llama_past_* functions used in the various examples
    • server, main
    • speculative, lookup, lookahead
  • add consistency tests (perhaps in tests/test-llama-past.cpp)
  • Make the recurrent state checkpoint interval configurable
  • Make the minimum number of recurrent states per client configurable to more than one

Future ideas

  • Fairly split the available KV cells among active sequences, similarly to RS cells.
  • Handle token shift (and Self-Extend?) when finding a slot.
    • This could help with the fair split of KV cells by freeing cells of sequences which use more than their fair share of cells.

Testing

Example output of jamba-900M-v0.13-KIx2 (click to expand)
$  ./bin/main -m /srv/LLMstash/tmp/jamba-900M.bf16.gguf --temp 0 -e -p "I believe the meaning of life is" --repeat-penalty 1.2 --repeat-last-n 256 -c 16384 -n 256
Log start
main: build = 3003 (0fd13e94)
main: built with gcc (GCC) 13.2.0 for x86_64-unknown-linux-gnu
main: seed  = 1716594011
llama_model_loader: loaded meta data with 26 key-value pairs and 189 tensors from /srv/LLMstash/tmp/jamba-900M.bf16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = jamba
llama_model_loader: - kv   1:                               general.name str              = jamba-900M-v0.13-KIx2
llama_model_loader: - kv   2:                          jamba.block_count u32              = 12
llama_model_loader: - kv   3:                       jamba.context_length u32              = 16384
llama_model_loader: - kv   4:                     jamba.embedding_length u32              = 1024
llama_model_loader: - kv   5:                  jamba.feed_forward_length u32              = 4096
llama_model_loader: - kv   6:                 jamba.attention.head_count u32              = 32
llama_model_loader: - kv   7:              jamba.attention.head_count_kv arr[i32,12]      = [0, 0, 8, 0, 0, 0, 8, 0, 0, 0, 8, 0]
llama_model_loader: - kv   8:                      jamba.ssm.conv_kernel u32              = 4
llama_model_loader: - kv   9:                       jamba.ssm.inner_size u32              = 2048
llama_model_loader: - kv  10:                       jamba.ssm.state_size u32              = 16
llama_model_loader: - kv  11:                   jamba.ssm.time_step_rank u32              = 256
llama_model_loader: - kv  12:     jamba.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  13:                         jamba.expert_count u32              = 8
llama_model_loader: - kv  14:                    jamba.expert_used_count u32              = 2
llama_model_loader: - kv  15:                          general.file_type u32              = 32
llama_model_loader: - kv  16:                       tokenizer.ggml.model str              = gpt2
llama_model_loader: - kv  17:                         tokenizer.ggml.pre str              = gpt-2
llama_model_loader: - kv  18:                      tokenizer.ggml.tokens arr[str,65024]   = ["<EOT>", "<META>", "<META_START>", "...
llama_model_loader: - kv  19:                  tokenizer.ggml.token_type arr[i32,65024]   = [3, 3, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, ...
llama_model_loader: - kv  20:                      tokenizer.ggml.merges arr[str,64739]   = ["Ġ Ġ", "ĠĠ ĠĠ", "i n", "ĠĠ �...
llama_model_loader: - kv  21:                tokenizer.ggml.bos_token_id u32              = 0
llama_model_loader: - kv  22:                tokenizer.ggml.eos_token_id u32              = 0
llama_model_loader: - kv  23:            tokenizer.ggml.unknown_token_id u32              = 0
llama_model_loader: - kv  24:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  25:               general.quantization_version u32              = 2
llama_model_loader: - type  f32:  121 tensors
llama_model_loader: - type bf16:   68 tensors
llm_load_vocab: special tokens definition check successful ( 29/65024 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = jamba
llm_load_print_meta: vocab type       = BPE
llm_load_print_meta: n_vocab          = 65024
llm_load_print_meta: n_merges         = 64739
llm_load_print_meta: n_ctx_train      = 16384
llm_load_print_meta: n_embd           = 1024
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 32
llm_load_print_meta: n_layer          = 12
llm_load_print_meta: n_rot            = 32
llm_load_print_meta: n_embd_head_k    = 32
llm_load_print_meta: n_embd_head_v    = 32
llm_load_print_meta: n_gqa            = 0
llm_load_print_meta: n_embd_k_gqa     = 0
llm_load_print_meta: n_embd_v_gqa     = 0
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-06
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 4096
llm_load_print_meta: n_expert         = 8
llm_load_print_meta: n_expert_used    = 2
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = -1
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 16384
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 4
llm_load_print_meta: ssm_d_inner      = 2048
llm_load_print_meta: ssm_d_state      = 16
llm_load_print_meta: ssm_dt_rank      = 256
llm_load_print_meta: model type       = ?B
llm_load_print_meta: model ftype      = BF16
llm_load_print_meta: model params     = 887.66 M
llm_load_print_meta: model size       = 1.67 GiB (16.19 BPW) 
llm_load_print_meta: general.name     = jamba-900M-v0.13-KIx2
llm_load_print_meta: BOS token        = 0 '<EOT>'
llm_load_print_meta: EOS token        = 0 '<EOT>'
llm_load_print_meta: UNK token        = 0 '<EOT>'
llm_load_print_meta: PAD token        = 0 '<EOT>'
llm_load_print_meta: LF token         = 133 'Ä'
llm_load_tensors: ggml ctx size =    0.09 MiB
llm_load_tensors:        CPU buffer size =  1713.16 MiB
......................................
llama_new_context_with_model: n_ctx      = 16384
llama_new_context_with_model: n_batch    = 2048
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_cache_init:        CPU cache buf size =    49.34 MiB
llama_new_context_with_model: SSM state size =     1.34 MiB, R (f32):    0.21 MiB, S (f32):    1.12 MiB
llama_new_context_with_model: KV cache size  =    48.00 MiB, K (f16):   24.00 MiB, V (f16):   24.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.25 MiB
llama_new_context_with_model:        CPU compute buffer size =  1062.03 MiB
llama_new_context_with_model: graph nodes  = 621
llama_new_context_with_model: graph splits = 1

system_info: n_threads = 2 / 4 | AVX = 1 | AVX_VNNI = 0 | AVX2 = 1 | AVX512 = 0 | AVX512_VBMI = 0 | AVX512_VNNI = 0 | AVX512_BF16 = 0 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 
sampling: 
        repeat_last_n = 256, repeat_penalty = 1.200, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.000
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature 
generate: n_ctx = 16384, n_batch = 2048, n_predict = 256, n_keep = 0


<EOT>I believe the meaning of life is not to be found in a single word, but rather as an expression of one's own feelings and thoughts.

The idea that we are all born with our bodies, whether they are human or animal, has been around for centuries. It was believed by some that it was something like a body made up of bones, which were attached to each other at birth. The most common form of this type of bone is called a "bone." This is what makes it so hard to tell if you're alive or dead. In fact, there are many different types of bones, including those that have been used for various purposes such as healing wounds, wounding wounds, etc.

In ancient times, people had a lot of teeth, and these were often very small. They could also be placed on top of their heads, where they would sit down and look at them. These were usually large, round stones, which were sometimes covered with hair. When the skin was removed from the head, the bones became more prominent, and the muscles began to grow larger.

This kind of bone was known as a "bone" because it was made out of two parts: the outermost part (the innermost portion) and the innermost part (the outermost
llama_print_timings:        load time =     252.28 ms
llama_print_timings:      sample time =     303.07 ms /   256 runs   (    1.18 ms per token,   844.68 tokens per second)
llama_print_timings: prompt eval time =     200.72 ms /     8 tokens (   25.09 ms per token,    39.86 tokens per second)
llama_print_timings:        eval time =   12516.79 ms /   255 runs   (   49.09 ms per token,    20.37 tokens per second)
llama_print_timings:       total time =   13213.95 ms /   263 tokens
Log end

@compilade compilade added enhancement New feature or request model Model specific refactoring Refactoring need feedback Testing and feedback with results are needed embeddings embedding related topics python python script changes Review Complexity : High Generally require indepth knowledge of LLMs or GPUs labels May 25, 2024
@compilade compilade marked this pull request as draft May 25, 2024 03:38
llama.cpp Outdated
Comment on lines 5244 to 5248
switch (hparams.n_layer) {
// TODO: Jamba layers are a bit heterogenous, so naming this is hard.
case 12: // 900M 8x???M
case 32: // 51B 16x?B
default: model.type = e_model::MODEL_UNKNOWN;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure what model size type(s) I should give to Jamba models.

@github-actions github-actions bot added the ggml changes relating to the ggml tensor library for machine learning label May 25, 2024
Copy link
Contributor

github-actions bot commented May 25, 2024

📈 llama.cpp server for bench-server-baseline on Standard_NC4as_T4_v3 for phi-2-q4_0: 557 iterations 🚀

Expand details for performance related PR only
  • Concurrent users: 8, duration: 10m
  • HTTP request : avg=8384.34ms p(95)=20451.68ms fails=, finish reason: stop=510 truncated=47
  • Prompt processing (pp): avg=102.96tk/s p(95)=478.95tk/s
  • Token generation (tg): avg=36.48tk/s p(95)=48.13tk/s
  • ggml-org/models/phi-2/ggml-model-q4_0.gguf parallel=8 ctx-size=16384 ngl=33 batch-size=2048 ubatch-size=256 pp=1024 pp+tg=2048 branch=compilade/refactor-kv-cache commit=fee3c1d740c0e027c81e2f2f3fb48d619857175f

prompt_tokens_seconds

More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 557 iterations"
    y-axis "llamacpp:prompt_tokens_seconds"
    x-axis "llamacpp:prompt_tokens_seconds" 1717475210 --> 1717475834
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 306.61, 306.61, 306.61, 306.61, 306.61, 572.5, 572.5, 572.5, 572.5, 572.5, 579.51, 579.51, 579.51, 579.51, 579.51, 601.73, 601.73, 601.73, 601.73, 601.73, 638.34, 638.34, 638.34, 638.34, 638.34, 702.62, 702.62, 702.62, 702.62, 702.62, 704.56, 704.56, 704.56, 704.56, 704.56, 718.91, 718.91, 718.91, 718.91, 718.91, 723.54, 723.54, 723.54, 723.54, 723.54, 739.59, 739.59, 739.59, 739.59, 739.59, 771.46, 771.46, 771.46, 771.46, 771.46, 802.48, 802.48, 802.48, 802.48, 802.48, 815.12, 815.12, 815.12, 815.12, 815.12, 804.65, 804.65, 804.65, 804.65, 804.65, 797.38, 797.38, 797.38, 797.38, 797.38, 800.86, 800.86, 800.86, 800.86, 800.86, 805.61, 805.61, 805.61, 805.61, 805.61, 803.64, 803.64, 803.64, 803.64, 803.64, 824.04, 824.04, 824.04, 824.04, 824.04, 823.3, 823.3, 823.3, 823.3, 823.3, 830.32, 830.32, 830.32, 830.32, 830.32, 832.47, 832.47, 832.47, 832.47, 832.47, 846.38, 846.38, 846.38, 846.38, 846.38, 842.07, 842.07, 842.07, 842.07, 842.07, 844.76, 844.76, 844.76, 844.76, 844.76, 861.96, 861.96, 861.96, 861.96, 861.96, 855.54, 855.54, 855.54, 855.54, 855.54, 854.58, 854.58, 854.58, 854.58, 854.58, 856.84, 856.84, 856.84, 856.84, 856.84, 860.17, 860.17, 860.17, 860.17, 860.17, 858.21, 858.21, 858.21, 858.21, 858.21, 861.33, 861.33, 861.33, 861.33, 861.33, 871.29, 871.29, 871.29, 871.29, 871.29, 847.29, 847.29, 847.29, 847.29, 847.29, 832.73, 832.73, 832.73, 832.73, 832.73, 831.59, 831.59, 831.59, 831.59, 831.59, 831.76, 831.76, 831.76, 831.76, 831.76, 835.52, 835.52, 835.52, 835.52, 835.52, 836.15, 836.15, 836.15, 836.15, 836.15, 836.37, 836.37, 836.37, 836.37, 836.37, 817.57, 817.57, 817.57, 817.57, 817.57, 820.16, 820.16, 820.16, 820.16, 820.16, 820.49, 820.49, 820.49, 820.49, 820.49, 820.0, 820.0, 820.0, 820.0, 820.0, 817.08, 817.08, 817.08, 817.08, 817.08, 820.83, 820.83, 820.83, 820.83, 820.83, 823.82, 823.82, 823.82, 823.82, 823.82, 823.03, 823.03, 823.03, 823.03, 823.03, 827.7, 827.7, 827.7, 827.7, 827.7, 826.96, 826.96, 826.96, 826.96, 826.96, 833.12, 833.12, 833.12, 833.12, 833.12, 832.75, 832.75, 832.75, 832.75, 832.75, 832.65, 832.65, 832.65, 832.65, 832.65, 826.23, 826.23, 826.23, 826.23, 826.23, 827.38, 827.38, 827.38, 827.38, 827.38, 827.43, 827.43, 827.43, 827.43, 827.43, 827.46, 827.46, 827.46, 827.46, 827.46, 825.87, 825.87, 825.87, 825.87, 825.87, 828.84, 828.84, 828.84, 828.84, 828.84, 829.05, 829.05, 829.05, 829.05, 829.05, 829.15, 829.15, 829.15]
                    
Loading
predicted_tokens_seconds
More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 557 iterations"
    y-axis "llamacpp:predicted_tokens_seconds"
    x-axis "llamacpp:predicted_tokens_seconds" 1717475210 --> 1717475834
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 42.1, 42.1, 42.1, 42.1, 42.1, 30.42, 30.42, 30.42, 30.42, 30.42, 28.2, 28.2, 28.2, 28.2, 28.2, 28.69, 28.69, 28.69, 28.69, 28.69, 29.63, 29.63, 29.63, 29.63, 29.63, 30.55, 30.55, 30.55, 30.55, 30.55, 32.02, 32.02, 32.02, 32.02, 32.02, 32.76, 32.76, 32.76, 32.76, 32.76, 33.41, 33.41, 33.41, 33.41, 33.41, 33.56, 33.56, 33.56, 33.56, 33.56, 34.05, 34.05, 34.05, 34.05, 34.05, 33.99, 33.99, 33.99, 33.99, 33.99, 33.35, 33.35, 33.35, 33.35, 33.35, 33.38, 33.38, 33.38, 33.38, 33.38, 32.25, 32.25, 32.25, 32.25, 32.25, 31.71, 31.71, 31.71, 31.71, 31.71, 30.36, 30.36, 30.36, 30.36, 30.36, 30.81, 30.81, 30.81, 30.81, 30.81, 30.82, 30.82, 30.82, 30.82, 30.82, 30.39, 30.39, 30.39, 30.39, 30.39, 30.41, 30.41, 30.41, 30.41, 30.41, 30.5, 30.5, 30.5, 30.5, 30.5, 30.85, 30.85, 30.85, 30.85, 30.85, 30.97, 30.97, 30.97, 30.97, 30.97, 31.24, 31.24, 31.24, 31.24, 31.24, 31.45, 31.45, 31.45, 31.45, 31.45, 31.23, 31.23, 31.23, 31.23, 31.23, 31.18, 31.18, 31.18, 31.18, 31.18, 31.36, 31.36, 31.36, 31.36, 31.36, 31.43, 31.43, 31.43, 31.43, 31.43, 31.63, 31.63, 31.63, 31.63, 31.63, 31.71, 31.71, 31.71, 31.71, 31.71, 31.78, 31.78, 31.78, 31.78, 31.78, 31.61, 31.61, 31.61, 31.61, 31.61, 31.48, 31.48, 31.48, 31.48, 31.48, 31.35, 31.35, 31.35, 31.35, 31.35, 31.43, 31.43, 31.43, 31.43, 31.43, 31.54, 31.54, 31.54, 31.54, 31.54, 31.71, 31.71, 31.71, 31.71, 31.71, 31.79, 31.79, 31.79, 31.79, 31.79, 31.85, 31.85, 31.85, 31.85, 31.85, 31.71, 31.71, 31.71, 31.71, 31.71, 31.42, 31.42, 31.42, 31.42, 31.42, 31.06, 31.06, 31.06, 31.06, 31.06, 29.65, 29.65, 29.65, 29.65, 29.65, 29.37, 29.37, 29.37, 29.37, 29.37, 29.37, 29.37, 29.37, 29.37, 29.37, 29.4, 29.4, 29.4, 29.4, 29.4, 29.46, 29.46, 29.46, 29.46, 29.46, 29.58, 29.58, 29.58, 29.58, 29.58, 29.61, 29.61, 29.61, 29.61, 29.61, 29.57, 29.57, 29.57, 29.57, 29.57, 29.58, 29.58, 29.58, 29.58, 29.58, 29.45, 29.45, 29.45, 29.45, 29.45, 29.55, 29.55, 29.55, 29.55, 29.55, 29.69, 29.69, 29.69, 29.69, 29.69, 29.83, 29.83, 29.83, 29.83, 29.83, 29.9, 29.9, 29.9, 29.9, 29.9, 29.96, 29.96, 29.96, 29.96, 29.96, 29.97, 29.97, 29.97, 29.97, 29.97, 30.03, 30.03, 30.03]
                    
Loading

Details

kv_cache_usage_ratio

More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 557 iterations"
    y-axis "llamacpp:kv_cache_usage_ratio"
    x-axis "llamacpp:kv_cache_usage_ratio" 1717475210 --> 1717475834
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.14, 0.14, 0.14, 0.14, 0.14, 0.37, 0.37, 0.37, 0.37, 0.37, 0.25, 0.25, 0.25, 0.25, 0.25, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.15, 0.14, 0.14, 0.14, 0.14, 0.14, 0.15, 0.15, 0.15, 0.15, 0.15, 0.14, 0.14, 0.14, 0.14, 0.14, 0.12, 0.12, 0.12, 0.12, 0.12, 0.16, 0.16, 0.16, 0.16, 0.16, 0.25, 0.25, 0.25, 0.25, 0.25, 0.26, 0.26, 0.26, 0.26, 0.26, 0.12, 0.12, 0.12, 0.12, 0.12, 0.36, 0.36, 0.36, 0.36, 0.36, 0.36, 0.36, 0.36, 0.36, 0.36, 0.41, 0.41, 0.41, 0.41, 0.41, 0.14, 0.14, 0.14, 0.14, 0.14, 0.15, 0.15, 0.15, 0.15, 0.15, 0.23, 0.23, 0.23, 0.23, 0.23, 0.2, 0.2, 0.2, 0.2, 0.2, 0.19, 0.19, 0.19, 0.19, 0.19, 0.16, 0.16, 0.16, 0.16, 0.16, 0.19, 0.19, 0.19, 0.19, 0.19, 0.11, 0.11, 0.11, 0.11, 0.11, 0.13, 0.13, 0.13, 0.13, 0.13, 0.32, 0.32, 0.32, 0.32, 0.32, 0.21, 0.21, 0.21, 0.21, 0.21, 0.1, 0.1, 0.1, 0.1, 0.1, 0.12, 0.12, 0.12, 0.12, 0.12, 0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.13, 0.23, 0.23, 0.23, 0.23, 0.23, 0.21, 0.21, 0.21, 0.21, 0.21, 0.28, 0.28, 0.28, 0.28, 0.28, 0.3, 0.3, 0.3, 0.3, 0.3, 0.2, 0.2, 0.2, 0.2, 0.2, 0.09, 0.09, 0.09, 0.09, 0.09, 0.14, 0.14, 0.14, 0.14, 0.14, 0.13, 0.13, 0.13, 0.13, 0.13, 0.16, 0.16, 0.16, 0.16, 0.16, 0.45, 0.45, 0.45, 0.45, 0.45, 0.56, 0.56, 0.56, 0.56, 0.56, 0.56, 0.56, 0.56, 0.56, 0.56, 0.64, 0.64, 0.64, 0.64, 0.64, 0.36, 0.36, 0.36, 0.36, 0.36, 0.21, 0.21, 0.21, 0.21, 0.21, 0.2, 0.2, 0.2, 0.2, 0.2, 0.17, 0.17, 0.17, 0.17, 0.17, 0.13, 0.13, 0.13, 0.13, 0.13, 0.11, 0.11, 0.11, 0.11, 0.11, 0.29, 0.29, 0.29, 0.29, 0.29, 0.27, 0.27, 0.27, 0.27, 0.27, 0.24, 0.24, 0.24, 0.24, 0.24, 0.19, 0.19, 0.19, 0.19, 0.19, 0.13, 0.13, 0.13, 0.13, 0.13, 0.12, 0.12, 0.12, 0.12, 0.12, 0.17, 0.17, 0.17, 0.17, 0.17, 0.15, 0.15, 0.15, 0.15, 0.15, 0.22, 0.22, 0.22, 0.22, 0.22, 0.12, 0.12, 0.12, 0.12, 0.12, 0.16, 0.16, 0.16]
                    
Loading
requests_processing
More
---
config:
    xyChart:
        titleFontSize: 12
        width: 900
        height: 600
    themeVariables:
        xyChart:
            titleColor: "#000000"
---
xychart-beta
    title "llama.cpp bench-server-baseline on Standard_NC4as_T4_v3
 duration=10m 557 iterations"
    y-axis "llamacpp:requests_processing"
    x-axis "llamacpp:requests_processing" 1717475210 --> 1717475834
    line [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 7.0, 7.0, 7.0, 7.0, 7.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.0, 3.0, 3.0, 3.0, 3.0, 6.0, 6.0, 6.0, 6.0, 6.0, 4.0, 4.0, 4.0, 4.0, 4.0, 8.0, 8.0, 8.0, 8.0, 8.0, 6.0, 6.0, 6.0, 6.0, 6.0, 7.0, 7.0, 7.0, 7.0, 7.0, 5.0, 5.0, 5.0, 5.0, 5.0, 4.0, 4.0, 4.0, 4.0, 4.0, 8.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 4.0, 4.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 6.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 4.0, 4.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 2.0, 2.0, 2.0, 2.0, 2.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 7.0, 4.0, 4.0, 4.0, 4.0, 4.0, 2.0, 2.0, 2.0, 2.0, 2.0, 4.0, 4.0, 4.0, 4.0, 4.0, 1.0, 1.0, 1.0, 1.0, 1.0, 5.0, 5.0, 5.0, 5.0, 5.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0, 3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 7.0, 7.0, 7.0, 7.0, 7.0, 6.0, 6.0, 6.0, 6.0, 6.0, 2.0, 2.0, 2.0, 2.0, 2.0, 3.0, 3.0, 3.0, 3.0, 3.0, 2.0, 2.0, 2.0, 2.0, 2.0, 4.0, 4.0, 4.0, 4.0, 4.0, 6.0, 6.0, 6.0, 6.0, 6.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 2.0, 2.0, 2.0, 2.0, 2.0, 5.0, 5.0, 5.0, 5.0, 5.0, 8.0, 8.0, 8.0, 8.0, 8.0, 5.0, 5.0, 5.0, 5.0, 5.0, 6.0, 6.0, 6.0, 6.0, 6.0, 3.0, 3.0, 3.0]
                    
Loading

@arch-btw
Copy link
Contributor

Great job! Works for me too, it's very fast. There were some warnings during compilation, but nothing major.

<EOT>Hello!

I'll get a new one for you and I think this is going to be really cool, so good. And I'm sure there's lots of ways in which [...]

llama_print_timings:        load time =     286.42 ms
llama_print_timings:      sample time =     155.94 ms /   256 runs   (    0.61 ms per token,  1641.63 tokens per second)
llama_print_timings: prompt eval time =      70.77 ms /     3 tokens (   23.59 ms per token,    42.39 tokens per second)
llama_print_timings:        eval time =    9368.54 ms /   255 runs   (   36.74 ms per token,    27.22 tokens per second)
llama_print_timings:       total time =    9686.16 ms /   258 tokens

@TechxGenus
Copy link

Amazing work!
I initially tested Jamba-v0.1 on a machine with 500G RAM and it worked great!

./main -m ./Jamba-v0.1-hf-00001-of-00024.gguf -n 120 --prompt "def max(arr):" --temp 0
Log start
main: build = 3006 (fc59407e)
main: built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
main: seed  = 1716710334
llama_model_loader: additional 23 GGUFs metadata loaded.
llama_model_loader: loaded meta data with 31 key-value pairs and 531 tensors from ./Jamba-v0.1-hf-00001-of-00024.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = jamba
llama_model_loader: - kv   1:                               general.name str              = Jamba-v0.1-hf
llama_model_loader: - kv   2:                          jamba.block_count u32              = 32
llama_model_loader: - kv   3:                       jamba.context_length u32              = 262144
llama_model_loader: - kv   4:                     jamba.embedding_length u32              = 4096
llama_model_loader: - kv   5:                  jamba.feed_forward_length u32              = 14336
llama_model_loader: - kv   6:                 jamba.attention.head_count u32              = 32
llama_model_loader: - kv   7:              jamba.attention.head_count_kv arr[i32,32]      = [0, 0, 0, 0, 8, 0, 0, 0, 0, 0, 0, 0, ...
llama_model_loader: - kv   8:                      jamba.ssm.conv_kernel u32              = 4
llama_model_loader: - kv   9:                       jamba.ssm.inner_size u32              = 8192
llama_model_loader: - kv  10:                       jamba.ssm.state_size u32              = 16
llama_model_loader: - kv  11:                   jamba.ssm.time_step_rank u32              = 256
llama_model_loader: - kv  12:     jamba.attention.layer_norm_rms_epsilon f32              = 0.000001
llama_model_loader: - kv  13:                         jamba.expert_count u32              = 16
llama_model_loader: - kv  14:                    jamba.expert_used_count u32              = 2
llama_model_loader: - kv  15:                          general.file_type u32              = 32
llama_model_loader: - kv  16:                       tokenizer.ggml.model str              = llama
llama_model_loader: - kv  17:                         tokenizer.ggml.pre str              = default
llama_model_loader: - kv  18:                      tokenizer.ggml.tokens arr[str,65536]   = ["<|pad|>", "<|startoftext|>", "<|end...
llama_model_loader: - kv  19:                      tokenizer.ggml.scores arr[f32,65536]   = [0.000000, 0.000000, 0.000000, 0.0000...
llama_model_loader: - kv  20:                  tokenizer.ggml.token_type arr[i32,65536]   = [3, 3, 3, 2, 3, 3, 3, 3, 3, 3, 3, 3, ...
llama_model_loader: - kv  21:                tokenizer.ggml.bos_token_id u32              = 1
llama_model_loader: - kv  22:                tokenizer.ggml.eos_token_id u32              = 2
llama_model_loader: - kv  23:            tokenizer.ggml.unknown_token_id u32              = 3
llama_model_loader: - kv  24:            tokenizer.ggml.padding_token_id u32              = 0
llama_model_loader: - kv  25:               tokenizer.ggml.add_bos_token bool             = true
llama_model_loader: - kv  26:               tokenizer.ggml.add_eos_token bool             = false
llama_model_loader: - kv  27:               general.quantization_version u32              = 2
llama_model_loader: - kv  28:                                   split.no u16              = 0
llama_model_loader: - kv  29:                                split.count u16              = 24
llama_model_loader: - kv  30:                        split.tensors.count i32              = 531
llama_model_loader: - type  f32:  361 tensors
llama_model_loader: - type bf16:  170 tensors
llm_load_vocab: special tokens definition check successful ( 1799/65536 ).
llm_load_print_meta: format           = GGUF V3 (latest)
llm_load_print_meta: arch             = jamba
llm_load_print_meta: vocab type       = SPM
llm_load_print_meta: n_vocab          = 65536
llm_load_print_meta: n_merges         = 0
llm_load_print_meta: n_ctx_train      = 262144
llm_load_print_meta: n_embd           = 4096
llm_load_print_meta: n_head           = 32
llm_load_print_meta: n_head_kv        = 32
llm_load_print_meta: n_layer          = 32
llm_load_print_meta: n_rot            = 128
llm_load_print_meta: n_embd_head_k    = 128
llm_load_print_meta: n_embd_head_v    = 128
llm_load_print_meta: n_gqa            = 0
llm_load_print_meta: n_embd_k_gqa     = 0
llm_load_print_meta: n_embd_v_gqa     = 0
llm_load_print_meta: f_norm_eps       = 0.0e+00
llm_load_print_meta: f_norm_rms_eps   = 1.0e-06
llm_load_print_meta: f_clamp_kqv      = 0.0e+00
llm_load_print_meta: f_max_alibi_bias = 0.0e+00
llm_load_print_meta: f_logit_scale    = 0.0e+00
llm_load_print_meta: n_ff             = 14336
llm_load_print_meta: n_expert         = 16
llm_load_print_meta: n_expert_used    = 2
llm_load_print_meta: causal attn      = 1
llm_load_print_meta: pooling type     = 0
llm_load_print_meta: rope type        = -1
llm_load_print_meta: rope scaling     = linear
llm_load_print_meta: freq_base_train  = 10000.0
llm_load_print_meta: freq_scale_train = 1
llm_load_print_meta: n_yarn_orig_ctx  = 262144
llm_load_print_meta: rope_finetuned   = unknown
llm_load_print_meta: ssm_d_conv       = 4
llm_load_print_meta: ssm_d_inner      = 8192
llm_load_print_meta: ssm_d_state      = 16
llm_load_print_meta: ssm_dt_rank      = 256
llm_load_print_meta: model type       = ?B
llm_load_print_meta: model ftype      = BF16
llm_load_print_meta: model params     = 51.57 B
llm_load_print_meta: model size       = 96.30 GiB (16.04 BPW) 
llm_load_print_meta: general.name     = Jamba-v0.1-hf
llm_load_print_meta: BOS token        = 1 '<|startoftext|>'
llm_load_print_meta: EOS token        = 2 '<|endoftext|>'
llm_load_print_meta: UNK token        = 3 '<|unk|>'
llm_load_print_meta: PAD token        = 0 '<|pad|>'
llm_load_print_meta: LF token         = 1554 '<0x0A>'
llm_load_print_meta: EOT token        = 2 '<|endoftext|>'
llm_load_tensors: ggml ctx size =    0.24 MiB
llm_load_tensors:        CPU buffer size =  4851.72 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  4210.03 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  5095.47 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  3584.03 MiB
llm_load_tensors:        CPU buffer size =  4210.03 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  4210.03 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  4339.75 MiB
llm_load_tensors:        CPU buffer size =  4210.03 MiB
llm_load_tensors:        CPU buffer size =  3584.00 MiB
llm_load_tensors:        CPU buffer size =  4851.77 MiB
llm_load_tensors:        CPU buffer size =  3584.03 MiB
..............................................................
llama_new_context_with_model: n_ctx      = 512
llama_new_context_with_model: n_batch    = 512
llama_new_context_with_model: n_ubatch   = 512
llama_new_context_with_model: flash_attn = 0
llama_new_context_with_model: freq_base  = 10000.0
llama_new_context_with_model: freq_scale = 1
llama_cache_init:        CPU cache buf size =    24.63 MiB
llama_new_context_with_model: SSM state size =    16.62 MiB, R (f32):    2.62 MiB, S (f32):   14.00 MiB
llama_new_context_with_model: KV cache size  =     8.00 MiB, K (f16):    4.00 MiB, V (f16):    4.00 MiB
llama_new_context_with_model:        CPU  output buffer size =     0.25 MiB
llama_new_context_with_model:        CPU compute buffer size =   145.10 MiB
llama_new_context_with_model: graph nodes  = 1730
llama_new_context_with_model: graph splits = 1

system_info: n_threads = 32 / 64 | AVX = 1 | AVX_VNNI = 1 | AVX2 = 1 | AVX512 = 1 | AVX512_VBMI = 1 | AVX512_VNNI = 1 | AVX512_BF16 = 1 | FMA = 1 | NEON = 0 | ARM_FMA = 0 | F16C = 1 | FP16_VA = 0 | WASM_SIMD = 0 | BLAS = 0 | SSE3 = 1 | SSSE3 = 1 | VSX = 0 | MATMUL_INT8 = 0 | LLAMAFILE = 1 | 
sampling: 
        repeat_last_n = 64, repeat_penalty = 1.000, frequency_penalty = 0.000, presence_penalty = 0.000
        top_k = 40, tfs_z = 1.000, top_p = 0.950, min_p = 0.050, typical_p = 1.000, temp = 0.000
        mirostat = 0, mirostat_lr = 0.100, mirostat_ent = 5.000
sampling order: 
CFG -> Penalties -> top_k -> tfs_z -> typical_p -> top_p -> min_p -> temperature 
generate: n_ctx = 512, n_batch = 2048, n_predict = 120, n_keep = 1


<|startoftext|> def max(arr):
    return max(arr)


def min(arr):
    return min(arr)


def mean(arr):
    return sum(arr) / len(arr)


def median(arr):
    arr.sort()
    if len(arr) % 2 == 0:
        return (arr[len(arr) // 2] + arr[len(arr) // 2 - 1]) / 2
    else:
        return arr[len(arr) // 2]


llama_print_timings:        load time =   82494.54 ms
llama_print_timings:      sample time =       9.61 ms /   120 runs   (    0.08 ms per token, 12490.89 tokens per second)
llama_print_timings: prompt eval time =     666.33 ms /     6 tokens (  111.06 ms per token,     9.00 tokens per second)
llama_print_timings:        eval time =   27656.31 ms /   119 runs   (  232.41 ms per token,     4.30 tokens per second)
llama_print_timings:       total time =   28862.18 ms /   125 tokens
Log end

ggml.c Outdated
Comment on lines 16264 to 16267
if (n_rs > 1) {
// multiple sequences means it's hard to know when it's the first time a state is read,
// so copy them all over to the destination, just to be sure.
for (int i3 = 0; i3 < n_kv; ++i3) {
for (int i3 = 0; i3 < n_rs; ++i3) {
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm looking at adding the missing Metal kernels for SSM_CONV and SSM_SCAN. I'm wondering if this part of the kernels where we copy src0 -> dst could be extracted outside of the operation via ggml_cpy + ggml_view or ggml_acc? Would simplify the implementation

Also, I still haven't understood the details of the computation, but if we find a way to express these ops via existing ops all together (e.g. using ggml_conv, ggml_mul_mat, ...), it would be preferred to do so, in order to reduce the amount of kernels that we have to write.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if this part of the kernels where we copy src0 -> dst could be extracted outside of the operation via ggml_cpy + ggml_view or ggml_acc? Would simplify the implementation

Yes, this is definitely possible. I'll find a way to extract the copies outside.

if we find a way to express these ops via existing ops all together (e.g. using ggml_conv, ggml_mul_mat, ...), it would be preferred to do so, in order to reduce the amount of kernels that we have to write.

For SSM_SCAN, I think there's a way to fully express it in terms of other ops, though it will use much more memory because of the big intermediate tensors, and new operators like SOFT_PLUS and EXP would be needed instead. But different lengths of simultaneous sequences might make a custom operator still necessary. I'll think about ways to make it simpler, especially since other recurrent architectures (like RWKV) will also need to work on multiple sequences per batch.

For simplifying SSM_CONV, I don't think ggml_conv supports working on independent 1D rolling windows with varying sequence lengths.

When working on a single sequence, though, it's quite simple to do the equivalent of ggml_ssm_conv with a self-overlapping view, as I did in my original implementation which I described in more detail in #5328 (comment):

llama.cpp/llama.cpp

Lines 6973 to 6982 in 64fbce0

// prepare convolution for all tokens in the batch with a self-overlapping view,
// shifting by one column each ... depth? ... with a window of d_conv columns.
// {(d_conv-1)+n_tok, d_inner} => {d_conv, d_inner, n_tok}
conv_x = ggml_view_3d(ctx0, conv_x, d_conv, d_inner, n_tok, conv_x_nb1, 1*ggml_element_size(conv_x), 0);
// perform convolution
// => {1, d_inner, n_tok}
x = ggml_sum_rows(ctx0, ggml_mul(ctx0, conv_x, model.layers[il].ssm_conv1d));
// => {d_inner, n_tok, 1}
x = ggml_permute(ctx0, x, 2, 0, 1, 3);

Setting nb[2] to the element size makes the view self-overlapping.

But this would create too many nodes in the compute graph when done with multiple sequences (unless they're always all the same length in which case the 4th dimension could be used), so a custom operator is necessary.

Copy link
Owner

@ggerganov ggerganov May 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One idea that we might consider is to unfuse the n_rs dimension from the SSM ops and make them work per 1 recurrent state. Then, during inference and right before the SSM operations, we split the batch into same-sequence chunks and SSM them individually. After that we concat back the results into the full hidden state for the batch

The main goal would be to simplify the SSM operators, and potentially express them as other existing ops if possible. But additionally, I'm considering a similar processing mode for the standard transformer KV cache in which we don't rely on a "unified" buffer for all the sequences, but instead each sequence has it's own separate KV cache buffer. In that mode, we would do a similar same-sequence batch split before the attention. The main purpose of supporting this mode would be to achieve reproducible results during parallel decoding (currently, decoding the same sequence in parallel can yield slightly different results due to the unified KV cache).

Just throwing some thoughts that I have so far - will continue looking at the PR in the next days

Edit: I was writing this comment before I saw you posted - will take a look tomorrow

Copy link
Collaborator Author

@compilade compilade May 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One idea that we might consider is to unfuse the n_rs dimension from the SSM ops and make them work per 1 recurrent state. Then, during inference and right before the SSM operations, we split the batch into same-sequence chunks and SSM them individually. After that we concat back the results into the full hidden state for the batch

Yes, this would be doable, but would make the number of compute graph nodes scale with the number of sequences. (EDIT: if it's split when making ubatches, then the number of compute graph nodes can stay constant)

Another way would be to make all sequences have the same number of new tokens in a ubatch, to allow using another dimension instead of having to loop when building the compute graphs. This would still allow batching multiple sequences with recurrent models, but without the need for new custom operators for each architecture, and still with a constant number of compute graph nodes.

The recurrent steps are simpler for ubatches with sequence lengths of 1, but prompt processing performance would be much slower than with a per-recurrent-architecture operator for longer sequences. Still thinking about ways to generalize this while keeping good performance.

But additionally, I'm considering a similar processing mode for the standard transformer KV cache in which we don't rely on a "unified" buffer for all the sequences, but instead each sequence has it's own separate KV cache buffer. In that mode, we would do a similar same-sequence batch split before the attention.

For the transformer KV cache, if there's logic to make all sequences within a ubatch to have the same number of new tokens, I think a mode to split batches sequence-wise will be simpler and could re-use much of the same code.

I also think there's a way to keep the unified KV cache (one buffer) and chunk it to make each sequence have their own independent contiguous reserved cells. Batching sequences together might still be possible though, if the KQ mask gets another dimension (the number of sequences in the ubatch, and the number of new tokens per sequence instead of the batch size) so that these equal-sized "chunks" get processed independently in parallel. But this might not work (because the newly-calculated KV cells have to be copied in a bunch of not-regularly-spaced places), unless... unless maybe with some kind of ggml_set_rows? Not sure about the transposed V cache, though.

A sequence-wise processing mode is likely simpler, although it's not really parallel processing then (the model weights are all read at each ubatch).

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if it's split when making ubatches, then the number of compute graph nodes can stay constant

No, it has to be split only for the attention so that the rest of the ops are still batched. Otherwise we will sacrifice a lot of performance

Another way would be to make all sequences have the same number of new tokens in a ubatch, to allow using another dimension instead of having to loop when building the compute graphs. This would still allow batching multiple sequences with recurrent models, but without the need for new custom operators for each architecture, and still with a constant number of compute graph nodes.

For the transformer KV cache, if there's logic to make all sequences within a ubatch to have the same number of new tokens, I think a mode to split batches sequence-wise will be simpler and could re-use much of the same code.

Not sure how that would work. Adding dummy tokens sounds too much overhead (at least
in the case of the regular transformer). Any other ideas?

A sequence-wise processing mode is likely simpler, although it's not really parallel processing then (the model weights are all read at each ubatch).

From a broad PoV, if we have an implementation that works with a single-sequence and any batch size, then to extend it to multi-sequence batches we can split the batch into same-sequence tokens right before the attention and merge it back after the attention. Each split will do what we already do for the single-sequence solution, using separate cache for each sequence. I didn't consider the number of nodes until you noted - so that might be a problem indeed.

I'm currently working on a big refactor of how Mamba (and Jamba) works to make all sequences of a sub-batch be of the same length (initially only for models with recurrent states), and to make recurrent state slots contiguous, with the goal of simplifying the SSM operations (and removing GGML_OP_SSM_CONV), so that GPU support will be much easier to implement after that.

Looking forward to this!

Copy link
Collaborator Author

@compilade compilade May 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it has to be split only for the attention so that the rest of the ops are still batched. Otherwise we will sacrifice a lot of performance

It will sacrifice some performance, but only in the cases where a batch contains an unequal number of tokens for each affected sequence. So this should not affect large prompt processing or parallel text generation, if both are not done in the same batch.

Not sure how that would work. Adding dummy tokens sounds too much overhead (at least
in the case of the regular transformer). Any other ideas?

This is not about adding dummy tokens, but about making the number of new tokens in each ubatch the same per sequence. I think the overhead will be minmal, though there is still some.

Let me illustrate.

Let's say there's a batch with new tokens for 4 sequences of length 16, 7, 1, 1, respectively.

0: ################
1: #######
2: #
3: #

Splitting that into equal-length sequences would make 3 ubatches, like so:

0: #
1: #
2: #
3: #
0: ######
1: ######
0: #########

Each of these shapes are nice and rectangular, which is good for recurrent architectures because their operations can be more easily batched across sequences this way.

But I'm not yet sure if it would also benefit Transformers, which is why I'm thinking of initially only enabling the equal-length splitting for recurrent (or hybrid) model architectures.

From a broad PoV, if we have an implementation that works with a single-sequence and any batch size, then to extend it to multi-sequence batches we can split the batch into same-sequence tokens right before the attention and merge it back after the attention. Each split will do what we already do for the single-sequence solution, using separate cache for each sequence. I didn't consider the number of nodes until you noted - so that might be a problem indeed.

Doing this with a constant number of graph nodes is pretty much what using same-length sequences (as illustrated above) allows, because the split into same-sequence tokens can then simply become another tensor dimension.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha, got it. Good idea. I'm also not sure if this can help Transformers, but it's something to think about 👍

@compilade
Copy link
Collaborator Author

The change is quite big and I'm having a bit of trouble to merge it all at once. Wonder if we should take a more step-by-step approach.

I agree that this is quite big. Sorry about that. I'll see what I can do.

The ggml changes alone are good - could these be merged alone and used for the existing mamba-only implementation on master?

Unfortunately, the ggml changes to the mamba-related operators depend on equal sequence length u-batches and contiguous (and ordered) allocation for recurrent states. It might still be possible to extract enough of the new behavior onto the current way recurrent states are managed on master, or not. I'll look into ways to do this.

I think I might be able to separate some parts of this PR.

These are the main separable parts:

  • Variable GQA support
    • makes {arch}.attention.head_count_kv also capable of being an array of integers
    • Isn't really used outside of DeciLM and hybrid models. I originally added it to simplify the allocation of the KV cache to reserve space only for the layers that need it in Jamba, and also to identify which layers use Attention and which don't in Jamba. This seemed like a good way to solve three problems at once.
  • Advanced batch splits
    • Can be useful on its own, since the buffers it adds eliminates the need for extra allocations in llama_decode_internal when llama_batch_get_one is used.
  • ggml improvements to GGML_OP_SSM_CONV and GGML_OP_SSM_SCAN
    • depends on equal-sequence-length u-batches and contiguous (and ordered) recurrent state slot allocation.
      • There might be a way to retro-fit contiguous allocation on the old way the KV cache was re-used for recurrent states. I'll need to think more about this.
  • Separate recurrent state cache from the KV cache
    • This is a big one, since this includes the (maybe over-engineered) recurrent state management which allows keeping state checkpoints and which makes recurrent state slot allocation always contiguous and use the same order the associated seq_id have in the batch (which benefits from equal-sequence-length u-batch splitting). This also simplifies how copies of cells between sequences are made, since recurrent state cells can now be shared between seq_id with llama_kv_cache_seq_cp/llama_past_seq_cp, while the latest states are unaliased during slot allocation.
    • Maybe after that try to decompose the KV cache changes somehow. Probably after refactoring the code a bit to prepare for this change and splitting the llama.cpp code into more source files.

      • I'll think about how to make this more easily manageable. But this is inherently a lot of interlinked changes, since the KV cache API has one more type of cache to manage simultaneously (!) for hybrid models, and some operations get their p0 and/or p1 ranges modified depending on the presence of state checkpoints.
  • Session file support for the separate recurrent state cache
    • This is not yet done
  • Jamba support
    • depends on all of the above

@ggerganov
Copy link
Owner

Variable GQA support

Could we extend this point a bit more and add support for OpenELM together with it? The PR for OpenELM is almost ready, but has some quick hacks that seem relevant to this: #7359

@compilade compilade mentioned this pull request Jun 18, 2024
@compilade
Copy link
Collaborator Author

Now that variable GQA support is in master (because of #7359 which has been merged), I plan to separate the advanced batch splits feature in its own PR for easier review.

(for some context, this allows splitting batches as described in #7531 (comment), and also single-sequence ubatches, as well as the current simple split used on master)

@Autumnlight02
Copy link

Any updates on this since Jamba 1.5 is now out?

jploski added a commit to jploski/llama.cpp that referenced this pull request Aug 25, 2024
@compilade
Copy link
Collaborator Author

Any updates on this since Jamba 1.5 is now out?

@Autumnlight02

Basically, since #8526 was merged, now I need to resolve a very big merge conflict because I didn't keep the code identical. This will probably take a few days.

@compilade
Copy link
Collaborator Author

compilade commented Sep 1, 2024

Some progress update on Jamba:

I began resolving the merge conflicts, and there were at least 2000+ lines of conflicts (basically half of this PR). This is manageable.

While I've solved most of them, the result is not usable (and it doesn't build, and so I did not push it here yet (sorry), I will push once it works) because of the state saving and restoring code which was changed in #8699, and this doesn't yet handle two caches.

My problem right now is with the single-sequence session restoring, which uses llama_kv_cache_find_slot in master, but this won't really work for how transparent recurrent state checkpoints are implemented here, so I'm thinking of other ways.

(EDIT: on further thought llama_kv_cache_find_slot can work, but only for a single checkpoint per sequence. This might be sufficient. I'm still leaving the rest of this comment intact because it's still somewhat relevant to know the tradeoffs of the implementation of recurrent state checkpoints)

To make single-sequence session restores simpler, I could either

  • Keep using llama_kv_cache_find_slot for that because it turns out it's not a problem
    • Only realized this after writing this whole comment.
    • Would only work to restore a single state checkpoint per sequence.
  • Throw away state checkpoints and postpone them for a future PR.
    • This would simplify everything, but would result in a bad user experience with recurrent and hybrid models due to excessive prompt reprocessing when using llama-server for conversations, because recurrent states can't be rolled back (yet?), and prompt processing has to start back from the beginning when the server removes more than one token at the end (extremely common).
      • This is also currently the situation for purely recurrent models on master (might not be that bad?)
  • Make llama_kv_cache_defrag defragment the whole KV cache to get an easy contiguous slot at the "end"
    • Requires deeply refactoring kv cache defrag to use ggml_get_rows instead of potentially thousands of individual tensor copies (otherwise defragmenting the whole KV cache won't really be doable in one shot)
  • Store fragmented cache
    • Might be more complicated (and sometimes less efficient) than defragmenting beforehand.
  • Explicitly fail for single-sequence session restore for recurrent (and hybrid) models

The least bad option (EDIT: apart from simply using llama_kv_cache_find_slot) seems to be to improve KV cache defragmentation, and again it seems like it could be its own PR. I'll begin working on that.

But I'm starting to think that maybe state checkpoints add too much complexity.

The current implementation uses a unified pool of recurrent state cells to allocate checkpoints and/or current states for each seq_id while ensuring the available cells are allocated fairly to each "used" seq_id. If there are only 2 "used" seq_id but there are 8 allocated cells, then each seq_id will get 4 cells (1 for the "tail" cell, and 3 for the checkpoints). If there's a third seq_id appearing, then they will each get at least 2 cells, while some of them will use the remaining 2 cells. That behavior requires keeping track of a lot of things including the relationship of cells in a tree of sequences. (some cells can be common between sequences, and the count of shared cells is managed differently; the explanation in this paragraph glosses over some details)

Some alternatives:

  • Manually use one more seq_id (with llama_kv_cache_seq_cp) at each "checkpoint"
    • Internally simpler
    • Potentially better checkpoint placement
    • Harder to manage in the examples (and for 3rd party apps using the KV cache API) than automatic checkpoints
      • When using llama_kv_cache_seq_rm with a partial token range instead of with a whole sequence, the same problems as before apply: with recurrent models, processing the prompt would need to start over from the beginning, and would not detect the most appropriate checkpoint to use (if there was one).
        • Although with automatic checkpoints, this also has to be handled, but "starting over" happens closer to the end of the prompt.
    • super-sequences have to be tracked manually
      • seq_id re-use can be complicated
      • slot-ids in llama-server would no longer directly map to seq_ids
    • Has to be explicitly managed anywhere it could be useful
      • simpler on the inside but more complicated on the outside
    • Not sure if -np should still also be the number of distinct recurrent states, because it's also the slot count in the server.
  • Calculate states in reverse
    • best memory usage
    • would be very cool
    • not sure it's possible
    • not a general solution for all recurrent models
      • needs further research for each recurrent architecture
    • would need keeping track of tokens in the cache
  • Pre-allocate a fixed number of recurrent state checkpoints for each sequence
    • Simpler, but not really (contiguous slot allocation could make this more complicated)
    • (considering a minimum of 3 checkpoints per sequence is necessary to properly benefit from checkpoints)
    • Not ideal memory usage
      • Especially with the dedicated sequence for the system prompt in llama-server and llama-parallel
    • Would need some way of specifying the number of checkpoints per sequence (can't be -c or -np, for reasons)

Manual checkpoint management seems tempting, but would offload the complexity to llama-server, which might not be desirable (since in the end it's mostly the same things which need to be tracked).

Meanwhile I will attempt to refactor KV cache defragmentation soon (which should be useful anyway).

@ggerganov
Copy link
Owner

Regarding the manual checkpoint management - recently, the commonly used APIs in the cloud (e.g. Anthropic, OpenAI, etc) introduced "prompt caching" [0], which adds a "cache control" parameter to the requests. It can be used to cache prompts, but I guess it fits well with the idea of manual recurrent state checkpointing from the user code.

I'm thinking that the changes for Jamba should be kept to a minimum for now, even if this would require longer processing times for common use cases. The reason is that the architecture is not yet well adopted, so increasing the complexity of the codebase to support it is not very justified. The better approach would be to improve the support for the existing transformer and mamba arches, by refactoring the KV cache and state management implementation and adding unit tests. I suppose a large part of the complexity with Jamba comes from the fact that we are trying to fit the logic into the existing KV cache implementation, which is not well-fit for this architecture.

[0] - https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching

@compilade
Copy link
Collaborator Author

compilade commented Sep 2, 2024

For the first time state saving and reloading works for Jamba (both for the whole state and single-sequences). 🎉

This is implemented in fcb889c

I'm thinking that the changes for Jamba should be kept to a minimum for now, even if this would require longer processing times for common use cases. The reason is that the architecture is not yet well adopted, so increasing the complexity of the codebase to support it is not very justified.

Agreed. I'll start simplifying the code and will think about how to best approach manual/explicit checkpoints for a future PR. The implicit checkpoints implemented here are a bit over-engineered, and do not fit in the idea of a "minimal" change.

I suppose a large part of the complexity with Jamba comes from the fact that we are trying to fit the logic into the existing KV cache implementation, which is not well-fit for this architecture.

I did not find the existing KV cache implementation to be particularly limiting. Most of the complexity in the Jamba implementation here comes from the allocation of recurrent states and implicit checkpoints. The only necessary complexity needed for Jamba is that both the KV cache and the recurrent state cache should be kept in sync, and even then most of the complexity is in keeping the metadata of the tree of sequences consistent (some of which is only there to allow fairly allocating the cache between seq_ids).


My plan for this PR in the next days/weeks:

  • Remove implicit recurrent state checkpoints to remove unnecessary complexity (I estimate this will reduce the change by 1000+ lines)
    • Proper explicit checkpoint handling (likely with some new API to make it more convenient) is postponed to a future pull-request.
      • I'll still keep the code for implicit checkpoints somewhere because it's possible to make it explicit instead with only a few lines changed
      • I want to explore simpler ways to do this.
    • The advantage of explicit checkpoints is that the minimum useful number of states per user is 2 instead of 3, because llama-server "knows" when the stop token (or the beginning of the stop string!) is sampled.
  • Rename llama_past back to llama_kv_cache (which will still contain both Attention's KV cache and the recurrent state cache)
    • llama_kv_cache (which currently contains only Attention's KV cache) will be renamed to
      • llama_kv_self_cache, for self-attention, because even in T5 it's only used for self-attention
    • llama_rs_cache (which contains recurrent states) will be renamed to either
      • llama_kv_rect_cache (stands for "RECurrenT")
      • llama_kv_rest_cache (stands for "REcurrent STate", but might be confusing)
      • llama_kv_iter_cache (might be confused with Iterators in the C++ language even if unrelated)
      • llama_rs_cache (same name, no renaming)
      • llama_rs_self_cache (for consistency with kv_self? self has no particular meaning here)
        • Probably what I'll pick
    • n_rs and rs_head (vs n_kv and kv_head) will keep their name.
    • This will avoid changing the names of the functions of the KV cache API and will let us still call it "the KV cache API"

What will be left intact is:

  • Jamba support
  • The KV cache API which also simultaneously manages recurrent states
    • Same drawbacks as for Mamba and RWKV-v6 on master, i.e. only one state per sequence (no rollback).
  • Session saving and reloading for hybrid models

As before, hybrid models of different architectures should be able to work on top of that (like how RWKV-v6 and Mamba can share the same recurrent state management code), as long as it's about hybrids between Attention and some recurrent block. This will mean models like Zamba (Mamba + Attention), Zamba2 (Mamba-2 + Attention), RecurrentGemma (RG-LRU + Attention), and others should be easier to implement without worrying about the KV cache API too much, and they will benefit from future improvements in state checkpoint management.

(Note that mixing different recurrent architectures in the same model is out of scope, but I don't think this will be a problem)

leafspark added a commit to leafspark/AutoGGUF that referenced this pull request Sep 20, 2024
- added support for MiniCPM3, RWKVv6, OLMoE, IBM Granite, and Jamba (conversion only: ggerganov/llama.cpp#7531)
- update gguf library from upstream
@hg0428
Copy link

hg0428 commented Oct 1, 2024

How's this going?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
android Issues specific to Android embeddings embedding related topics enhancement New feature or request examples ggml changes relating to the ggml tensor library for machine learning model Model specific need feedback Testing and feedback with results are needed python python script changes refactoring Refactoring Review Complexity : High Generally require indepth knowledge of LLMs or GPUs server
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Suport for Jamba JambaForCausalLM
7 participants