diff --git a/common/common.cpp b/common/common.cpp index 9fa18472512ab..ba003e44b0173 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2541,7 +2541,7 @@ struct llama_init_result llama_init_from_gpt_params(gpt_params & params) { if (llama_model_has_decoder(model)) { llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0)); } - llama_kv_cache_clear(lctx); + llama_past_clear(lctx); llama_synchronize(lctx); llama_reset_timings(lctx); } diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 27ac34b810acd..e9bb4b20bd6d3 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2874,6 +2874,120 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(new_name, data_torch)] +@Model.register("JambaForCausalLM") +class JambaModel(Model): + model_arch = gguf.MODEL_ARCH.JAMBA + + def get_vocab_base_pre(self, tokenizer) -> str: + del tokenizer # unused + + return "gpt-2" + + def set_vocab(self): + if (self.dir_model / "tokenizer.model").is_file(): + # Using Jamba's tokenizer.json causes errors on model load + # (something about "byte not found in vocab"), + # but there's a working tokenizer.model + self._set_vocab_sentencepiece() + else: + # Some Jamba models only have a tokenizer.json, which works. + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + d_model = self.find_hparam(["hidden_size", "mamba_d_model"]) + d_conv = self.find_hparam(["mamba_d_conv"], optional=True) or 4 + d_inner = self.hparams["mamba_expand"] * d_model + d_state = self.find_hparam(["mamba_d_state"], optional=True) or 16 + # ceiling division + # ref: https://stackoverflow.com/a/17511341/22827863 + # ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58 + dt_rank = self.find_hparam(["mamba_dt_rank"], optional=True) or -(d_model // -16) + rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-6 + n_kv_head = self.hparams["num_key_value_heads"] + attn_offset = self.hparams["attn_layer_offset"] + attn_period = self.hparams["attn_layer_period"] + n_kv_vec = [0 for _ in range(attn_offset)] + [ + n_kv_head if (i - attn_offset) % attn_period == 0 else 0 for i in range(attn_offset, self.block_count) + ] + + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_context_length(self.find_hparam(["max_position_embeddings", "n_ctx"])) + self.gguf_writer.add_embedding_length(d_model) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(n_kv_vec) + self.gguf_writer.add_ssm_conv_kernel(d_conv) + self.gguf_writer.add_ssm_inner_size(d_inner) + self.gguf_writer.add_ssm_state_size(d_state) + self.gguf_writer.add_ssm_time_step_rank(dt_rank) + self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_expert_count(self.hparams["num_experts"]) + self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"]) + self.gguf_writer.add_file_type(self.ftype) + + _experts: list[dict[str, Tensor]] | None = None + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + + # Mini-Jamba + name = name.replace(".moe.", ".feed_forward.") + if bid is not None: + moe_offset = self.hparams["expert_layer_offset"] + moe_period = self.hparams["expert_layer_period"] + + if not (bid >= moe_offset and (bid - 
moe_offset) % moe_period == 0): + name = name.replace(".experts.0.", ".") + + # process the experts separately + if ".feed_forward.experts." in name: + n_experts = self.hparams["num_experts"] + + assert bid is not None + + if self._experts is None: + self._experts = [{} for _ in range(self.block_count)] + + self._experts[bid][name] = data_torch + + if len(self._experts[bid]) >= n_experts * 3: + + # merge the experts into a single 3d tensor + for wid in ["down_proj", "gate_proj", "up_proj"]: + datas: list[Tensor] = [] + + for xid in range(n_experts): + ename = f"model.layers.{bid}.feed_forward.experts.{xid}.{wid}.weight" + datas.append(self._experts[bid][ename]) + del self._experts[bid][ename] + + data_torch = torch.stack(datas, dim=0) + + # using the same merged name as qwen2moe + merged_name = f"model.layers.{bid}.mlp.experts.{wid}.weight" + + new_name = self.map_tensor_name(merged_name) + + yield new_name, data_torch + return + + new_name = self.map_tensor_name(name) + + if name.endswith(".A_log"): + logger.debug("A_log --> A ==> " + new_name) + data_torch = -torch.exp(data_torch) + + yield new_name, data_torch + + def prepare_tensors(self): + super().prepare_tensors() + + if self._experts is not None: + # flatten `list[dict[str, Tensor]]` into `list[str]` + experts = [k for d in self._experts for k in d.keys()] + if len(experts) > 0: + raise ValueError(f"Unprocessed experts: {experts}") + + @Model.register("CohereForCausalLM") class CommandR2Model(Model): model_arch = gguf.MODEL_ARCH.COMMAND_R diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 25e7c775a0095..f3ce4964f442f 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -153,7 +153,7 @@ int main(int argc, char ** argv) { const auto t_pp_start = ggml_time_us(); - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); if (!decode_helper(ctx, batch, ctx_params.n_batch)) { LOG_TEE("%s: llama_decode() failed\n", __func__); @@ -162,7 +162,7 @@ int main(int argc, char ** argv) { if (is_pp_shared) { for (int32_t i = 1; i < pl; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + llama_past_seq_cp(ctx, 0, i, -1, -1); } } diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift index 616494d2d841d..2a7324e6b2839 100644 --- a/examples/batched.swift/Sources/main.swift +++ b/examples/batched.swift/Sources/main.swift @@ -98,7 +98,7 @@ if llama_decode(context, batch) != 0 { } for i in 1 ..< n_parallel { - llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens) + llama_past_seq_cp(context, 0, Int32(i), -1, -1) } if n_parallel > 1 { diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 53fbfb0a8cf2a..00cd744d4f6b9 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -132,7 +132,7 @@ int main(int argc, char ** argv) { //// assign the system KV cache to all parallel sequences //// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them //for (int32_t i = 1; i < n_parallel; ++i) { - // llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + // llama_past_seq_cp(ctx, 0, i, -1, -1); //} if (n_parallel > 1) { diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index a68268388389d..0acdfdf381ab2 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -338,7 +338,7 @@ static bool cb_eval(struct ggml_tensor * t, 
bool ask, void * user_data) { } static bool get_hidden_layers(llama_context * ctx, std::vector & tokens) { - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size(), 0, 0))) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index b05aa006e7da5..a98b0811d15f5 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -35,7 +35,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu const struct llama_model * model = llama_get_model(ctx); // clear previous kv_cache values (irrelevant for embeddings) - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // run model fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 2c61c2e1eb3bc..6d78237877fd6 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -43,7 +43,7 @@ static std::vector> encode(llama_context * ctx, const std::ve } // clear previous kv_cache values (irrelevant for embeddings) - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); llama_set_embeddings(ctx, true); llama_set_causal_attn(ctx, false); @@ -98,7 +98,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo const llama_model * mdl = llama_get_model(ctx); llama_token eos_token = llama_token_eos(mdl); - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); llama_set_embeddings(ctx, false); llama_set_causal_attn(ctx, true); diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 83b85d72b043a..48a7c366e4420 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -499,7 +499,7 @@ static bool compute_imatrix(llama_context * ctx, const gpt_params & params) { const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 05700c1d591d9..778461b1fa8b5 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -385,8 +385,8 @@ int main(int argc, char ** argv) { LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", n_past, n_left, n_ctx, params.n_keep, n_discard); - llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); - llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); + llama_past_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1); + llama_past_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard); n_past -= n_discard; diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index 8edadef909f42..d86bae1f3be2e 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1515,7 +1515,7 @@ int main(int argc, char ** argv) { test t(inst, lmodel, ctx); - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // cool off before the test if (params.delay) { @@ -1549,7 +1549,7 @@ int main(int argc, char ** argv) { } for (int i = 0; i < params.reps; i++) { - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); uint64_t t_start = get_time_ns(); diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp 
b/examples/llama.android/llama/src/main/cpp/llama-android.cpp index 2aafe23167557..c5366d97f243e 100644 --- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp +++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp @@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( } batch->logits[batch->n_tokens - 1] = true; - llama_kv_cache_clear(context); + llama_past_clear(context); const auto t_pp_start = ggml_time_us(); if (llama_decode(context, *batch) != 0) { @@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( LOGi("Benchmark text generation (tg)"); - llama_kv_cache_clear(context); + llama_past_clear(context); const auto t_tg_start = ggml_time_us(); for (i = 0; i < tg; i++) { @@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( const auto t_tg_end = ggml_time_us(); - llama_kv_cache_clear(context); + llama_past_clear(context); const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0; const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0; @@ -439,5 +439,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( extern "C" JNIEXPORT void JNICALL Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) { - llama_kv_cache_clear(reinterpret_cast(context)); + llama_past_clear(reinterpret_cast(context)); } diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 48b7840ae49c3..f893de7a577e1 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -216,7 +216,7 @@ actor LlamaContext { } batch.logits[Int(batch.n_tokens) - 1] = 1 // true - llama_kv_cache_clear(context) + llama_past_clear(context) let t_pp_start = ggml_time_us() @@ -229,7 +229,7 @@ actor LlamaContext { // bench text generation - llama_kv_cache_clear(context) + llama_past_clear(context) let t_tg_start = ggml_time_us() @@ -248,7 +248,7 @@ actor LlamaContext { let t_tg_end = ggml_time_us() - llama_kv_cache_clear(context) + llama_past_clear(context) let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0 let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0 @@ -298,7 +298,7 @@ actor LlamaContext { func clear() { tokens_list.removeAll() temporary_invalid_cchars.removeAll() - llama_kv_cache_clear(context) + llama_past_clear(context) } private func tokenize(text: String, add_bos: Bool) -> [llama_token] { diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 81cf1629c5b6a..49288ab358407 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -96,7 +96,7 @@ int main(int argc, char ** argv) { llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0)); for (int s = 1; s < W + G + 1; ++s) { - llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); + llama_past_seq_cp(ctx, 0, s, -1, -1); } const auto t_enc_end = ggml_time_us(); @@ -438,17 +438,18 @@ int main(int argc, char ** argv) { // KV cache management // if no verification token matched, we simply remove all cells from this batch -> no fragmentation - llama_kv_cache_seq_rm(ctx, -1, n_past, -1); + // FIXME: recurrent and hybrid models + llama_past_seq_rm(ctx, -1, n_past, -1); if (seq_id_best != 0) { // if a verification token matched, we keep the best sequence and remove the rest // this leads to some KV cache fragmentation - llama_kv_cache_seq_keep(ctx, seq_id_best); - llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1); - llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1); + 
llama_past_seq_keep(ctx, seq_id_best); + llama_past_seq_cp (ctx, seq_id_best, 0, -1, -1); + llama_past_seq_rm (ctx, seq_id_best, -1, -1); for (int s = 1; s < W + G + 1; ++s) { - llama_kv_cache_seq_cp(ctx, 0, s, -1, -1); + llama_past_seq_cp(ctx, 0, s, -1, -1); } } } diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index d53a9828c2ea2..4f1bb51408983 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -194,7 +194,8 @@ int main(int argc, char ** argv){ // KV cache management // clean the cache of draft tokens that weren't accepted - llama_kv_cache_seq_rm(ctx, 0, n_past, -1); + // FIXME: recurrent and hybrid models + llama_past_seq_rm(ctx, 0, n_past, -1); llama_batch_clear(batch_tgt); llama_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 2c05afb048c7b..6304ec4fc94f0 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -369,6 +369,10 @@ int main(int argc, char ** argv) { } n_matching_session_tokens++; } + + // remove any "future" tokens that we might have inherited from the previous session + n_matching_session_tokens = llama_past_seq_rm(ctx, -1, n_matching_session_tokens, -1); + if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) { LOG_TEE("%s: using full prompt from session file\n", __func__); } else if (n_matching_session_tokens >= embd_inp.size()) { @@ -380,9 +384,6 @@ int main(int argc, char ** argv) { LOG_TEE("%s: session file matches %zu / %zu tokens of prompt\n", __func__, n_matching_session_tokens, embd_inp.size()); } - - // remove any "future" tokens that we might have inherited from the previous session - llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1); } LOGLN( @@ -395,6 +396,8 @@ int main(int argc, char ** argv) { LOGLN("recalculate the cached logits (do): session_tokens.resize( %zu )", embd_inp.size() - 1); session_tokens.resize(embd_inp.size() - 1); + } else { + session_tokens.resize(n_matching_session_tokens); } // number of tokens to keep when resetting context @@ -624,8 +627,8 @@ int main(int argc, char ** argv) { LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n", n_past, n_left, n_ctx, params.n_keep, n_discard); - llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard); - llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard); + llama_past_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard); + llama_past_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard); n_past -= n_discard; @@ -652,9 +655,9 @@ int main(int argc, char ** argv) { LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n); LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd); - llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd); - llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); - llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); + llama_past_seq_add(ctx, 0, ga_i, n_past, ib*bd); + llama_past_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n); + llama_past_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd); n_past -= bd; @@ -668,6 +671,8 @@ int main(int argc, char ** argv) { if (n_session_consumed < (int) session_tokens.size()) { size_t i = 0; for ( ; i < embd.size(); i++) { + // TODO: are the session tokens guaranteed to all 
be matching here? + // Should n_matching_session_tokens be re-used instead? if (embd[i] != session_tokens[n_session_consumed]) { session_tokens.resize(n_session_consumed); break; diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 621a1c9590622..32c3ba2b0f1a9 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -200,7 +200,7 @@ int main(int argc, char ** argv) { // assign the system KV cache to all parallel sequences for (int32_t i = 1; i <= n_clients; ++i) { - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + llama_past_seq_cp(ctx, 0, i, -1, -1); } LOG_TEE("\n"); @@ -232,9 +232,9 @@ int main(int argc, char ** argv) { if (batch.n_tokens == 0) { // all sequences have ended - clear the entire KV cache for (int i = 1; i <= n_clients; ++i) { - llama_kv_cache_seq_rm(ctx, i, -1, -1); + llama_past_seq_rm(ctx, i, -1, -1); // but keep the system prompt - llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + llama_past_seq_cp(ctx, 0, i, -1, -1); } LOG_TEE("%s: clearing the KV cache\n", __func__); @@ -371,8 +371,8 @@ int main(int argc, char ** argv) { } // delete only the generated part of the sequence, i.e. keep the system prompt in the cache - llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1); - llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1); + llama_past_seq_rm(ctx, client.id + 1, -1, -1); + llama_past_seq_cp(ctx, 0, client.id + 1, -1, -1); const auto t_main_end = ggml_time_us(); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index d03215cd1e0a9..c6564c5cfd4c7 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -126,11 +126,11 @@ int main(int argc, char ** argv) { const int ib = i/n_batch - 1; const int bd = n_batch_grp*(n_grp - 1); - llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd); - llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); - llama_kv_cache_update (ctx); + llama_past_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd); + llama_past_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp); + llama_kv_cache_update(ctx); - n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; + n_past = llama_past_seq_pos_max(ctx, 0) + 1; } llama_batch_clear(batch); @@ -160,12 +160,12 @@ int main(int argc, char ** argv) { LOG_TEE("%s: shifting KV cache with %d\n", __func__, n_discard); - llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); - //llama_kv_cache_defrag (ctx); - llama_kv_cache_update (ctx); + llama_past_seq_rm (ctx, 0, n_keep , n_keep + n_discard); + llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + //llama_kv_cache_defrag(ctx); + llama_kv_cache_update(ctx); - n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; + n_past = llama_past_seq_pos_max(ctx, 0) + 1; llama_batch_clear(batch); @@ -191,12 +191,12 @@ int main(int argc, char ** argv) { if (n_discard > 0) { LOG_TEE("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard); - llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard); - llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard); - //llama_kv_cache_defrag (ctx); - llama_kv_cache_update (ctx); + llama_past_seq_rm (ctx, 0, n_keep , n_keep + n_discard); + llama_past_seq_add (ctx, 0, n_keep + n_discard, n_ctx, -n_discard); + //llama_kv_cache_defrag(ctx); + llama_kv_cache_update(ctx); - n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1; + n_past = llama_past_seq_pos_max(ctx, 0) + 1; } } diff --git 
a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 484dd589109c7..f2b0b9df8e93e 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -400,7 +400,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -575,7 +575,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -944,7 +944,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { return; } - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1221,7 +1221,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) { return; } - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1594,7 +1594,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params return; } - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // decode all tasks [i0, i1) if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { @@ -1780,7 +1780,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { } // clear the KV cache - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index aab9d81058af9..54bdf5d1fa517 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -82,7 +82,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector & toke static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { // clear previous kv_cache values (irrelevant for embeddings) - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); // run model fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 3ea7c790d2bf7..906a1970bcaf6 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -199,7 +199,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy); // erase whole kv - llama_kv_cache_clear(ctx3); + llama_past_clear(ctx3); fprintf(stderr, "%s : kv cache cleared\n", __func__); // restore kv into seq 1 diff --git a/examples/server/server.cpp b/examples/server/server.cpp index cc938e80d6a6d..e7f0e3ac9bd84 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1123,7 +1123,7 @@ struct server_context { LOG_VERBOSE("clearing KV cache", {}); // clear the entire KV cache - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); clean_kv_cache = false; } @@ -1158,7 +1158,7 @@ struct server_context { // assign the system KV cache to all parallel sequences for (int32_t i = 1; i <= params.n_parallel; ++i) { - 
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1); + llama_past_seq_cp(ctx, 0, i, -1, -1); } } @@ -1835,7 +1835,7 @@ struct server_context { // Erase token cache const size_t n_erased = slot->cache_tokens.size(); - llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1); + llama_past_seq_rm(ctx, slot->id + 1, -1, -1); slot->cache_tokens.clear(); server_task_result result; @@ -1960,8 +1960,8 @@ struct server_context { {"n_cache_tokens", slot.cache_tokens.size()} }); - llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard); - llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); + llama_past_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard); + llama_past_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard); if (slot.params.cache_prompt) { for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) { @@ -2200,23 +2200,28 @@ struct server_context { } // keep only the common part - int p0 = (int) system_tokens.size() + slot.n_past; - if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) { - // could not partially delete (likely using a non-Transformer model) - llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1); - - p0 = (int) system_tokens.size(); - if (p0 != 0) { - // copy over the system prompt when there is one - llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1); + llama_pos p0 = (llama_pos) system_tokens.size() + slot.n_past; + + // for recurrent and hybrid models, sometimes it goes back further than asked + llama_pos new_p0 = llama_past_seq_rm(ctx, slot.id + 1, p0, -1); + + if (new_p0 < p0) { + GGML_ASSERT(new_p0 >= (llama_pos) system_tokens.size()); + + slot.n_past -= p0 - new_p0; + if (slot.ga_i > 0) { + // TODO: test with an hybrid model (e.g. Jamba) + slot.n_past_se -= p0 - new_p0; } - // there is no common part left (except for the system prompt) - slot.n_past = 0; - slot.n_past_se = 0; - slot.ga_i = 0; - // TODO: is the system prompt ever in the sampling context? 
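+                    // NOTE: llama_past_seq_rm() may discard more than requested for recurrent and hybrid models,
+                    // so the common prefix (slot.n_past) was shortened above and the sampling context
+                    // has to be reset and re-filled with the remaining cached tokens below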
+ // TODO: find a way to avoid rolling back the sampling context twice llama_sampling_reset(slot.ctx_sampling); + // push the prompt into the sampling context (do not apply grammar) + for (int i = 0; i < slot.n_past; ++i) { + llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false); + } + + p0 = new_p0; } // remove the non-common part from the cache @@ -2321,9 +2326,9 @@ struct server_context { LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n); LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd); - llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd); - llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n); - llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd); + llama_past_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd); + llama_past_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n); + llama_past_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd); slot.n_past_se -= bd; diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 1616edecbbef6..ee881b679bb50 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -399,14 +399,15 @@ int main(int argc, char ** argv) { { LOG("keeping sequence %d, n_past_tgt = %d, n_past_dft = %d\n", s_keep, n_past_tgt, n_past_dft); - llama_kv_cache_seq_keep(ctx_dft, s_keep); - llama_kv_cache_seq_cp (ctx_dft, s_keep, 0, -1, -1); - llama_kv_cache_seq_keep(ctx_dft, 0); - - llama_kv_cache_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1); - llama_kv_cache_seq_keep(ctx_tgt, s_keep); - llama_kv_cache_seq_cp (ctx_tgt, s_keep, 0, -1, -1); - llama_kv_cache_seq_keep(ctx_tgt, 0); + llama_past_seq_keep(ctx_dft, s_keep); + llama_past_seq_cp (ctx_dft, s_keep, 0, -1, -1); + llama_past_seq_keep(ctx_dft, 0); + + // FIXME: recurrent and hybrid models + llama_past_seq_rm (ctx_tgt, s_keep, n_past_tgt, -1); + llama_past_seq_keep(ctx_tgt, s_keep); + llama_past_seq_cp (ctx_tgt, s_keep, 0, -1, -1); + llama_past_seq_keep(ctx_tgt, 0); } for (int s = 0; s < n_seq_dft; ++s) { @@ -423,7 +424,8 @@ int main(int argc, char ** argv) { llama_batch_clear(batch_dft); llama_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true); - llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1); + // FIXME: recurrent and hybrid models + llama_past_seq_rm(ctx_dft, 0, n_past_dft, -1); // LOG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); llama_decode(ctx_dft, batch_dft); @@ -479,8 +481,8 @@ int main(int argc, char ** argv) { if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) { LOG("splitting seq %3d into %3d\n", s, n_seq_cur); - llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1); - llama_kv_cache_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); + llama_past_seq_rm(ctx_dft, n_seq_cur, -1, -1); + llama_past_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); // all previous tokens from this branch are now also part of the new branch for (int t = 0; t < batch_tgt.n_tokens; ++t) { @@ -558,9 +560,9 @@ int main(int argc, char ** argv) { // evaluate the target model on the drafted tokens { - llama_kv_cache_seq_keep(ctx_tgt, 0); + 
llama_past_seq_keep(ctx_tgt, 0); for (int s = 1; s < n_seq_dft; ++s) { - llama_kv_cache_seq_cp(ctx_tgt, 0, s, -1, -1); + llama_past_seq_cp(ctx_tgt, 0, s, -1, -1); } // LOG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 83140d757d0ce..941f015c3a25c 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -19817,7 +19817,6 @@ struct ggml_cplan ggml_graph_plan( cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 } } break; - case GGML_OP_CROSS_ENTROPY_LOSS: { cur = ggml_type_size(node->type)*(n_tasks + node->src[0]->ne[0]*n_tasks); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index a48c4fb676a46..487fcd07e0e0f 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -215,6 +215,7 @@ class MODEL_ARCH(IntEnum): STARCODER2 = auto() RWKV6 = auto() MAMBA = auto() + JAMBA = auto() XVERSE = auto() COMMAND_R = auto() DBRX = auto() @@ -274,7 +275,10 @@ class MODEL_TENSOR(IntEnum): SSM_CONV1D = auto() SSM_X = auto() SSM_DT = auto() + SSM_DT_NORM = auto() SSM_A = auto() + SSM_B_NORM = auto() + SSM_C_NORM = auto() SSM_D = auto() SSM_OUT = auto() TIME_MIX_W1 = auto() @@ -369,6 +373,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.RWKV6: "rwkv6", MODEL_ARCH.MAMBA: "mamba", + MODEL_ARCH.JAMBA: "jamba", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.DBRX: "dbrx", @@ -428,7 +433,10 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_CONV1D: "blk.{bid}.ssm_conv1d", MODEL_TENSOR.SSM_X: "blk.{bid}.ssm_x", MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", + MODEL_TENSOR.SSM_DT_NORM: "blk.{bid}.ssm_dt_norm", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", + MODEL_TENSOR.SSM_B_NORM: "blk.{bid}.ssm_b_norm", + MODEL_TENSOR.SSM_C_NORM: "blk.{bid}.ssm_c_norm", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1", @@ -954,6 +962,34 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_OUT, ], + MODEL_ARCH.JAMBA: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_X, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_DT_NORM, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_B_NORM, + MODEL_TENSOR.SSM_C_NORM, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_OUT, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + ], MODEL_ARCH.XVERSE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index bc9a13ee5bdf5..1391a365ee225 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -238,6 +238,8 @@ class TensorNameMap: "transformer.decoder_layer.{bid}.rms_norm_2", # Grok "encoder.layers.{bid}.post_attention_layernorm", # chatglm "transformer.layers.{bid}.ffn_norm", # openelm + "model.layers.{bid}.pre_ff_layernorm", # jamba + "model.layers.{bid}.pre_moe_layernorm", # mini-jamba ), # Post feed-forward norm @@ -256,6 +258,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.gate", # qwen2moe "transformer.decoder_layer.{bid}.router", # Grok "transformer.blocks.{bid}.ffn.router.layer", # dbrx + 
"model.layers.{bid}.feed_forward.router", # jamba ), MODEL_TENSOR.FFN_GATE_INP_SHEXP: ( @@ -287,6 +290,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.c_fc", # starcoder2 "encoder.layer.{bid}.mlp.gated_layers_v", # jina-bert-v2 "model.layers.{bid}.residual_mlp.w3", # arctic + "model.layers.{bid}.feed_forward.up_proj", # jamba "encoder.layers.{bid}.mlp.dense_h_to_4h", # chatglm "transformer.h.{bid}.mlp.c_fc_1", # exaone ), @@ -320,6 +324,7 @@ class TensorNameMap: "encoder.layer.{bid}.mlp.gated_layers_w", # jina-bert-v2 "transformer.h.{bid}.mlp.linear_1", # refact "model.layers.{bid}.residual_mlp.w1", # arctic + "model.layers.{bid}.feed_forward.gate_proj", # jamba "transformer.h.{bid}.mlp.c_fc_0", # exaone ), @@ -359,6 +364,7 @@ class TensorNameMap: "transformer.layers.{bid}.ffn.proj_2", # openelm "model.layers.{bid}.residual_mlp.w2", # arctic "encoder.layer.{bid}.mlp.down_layer", # jina-bert-v2 + "model.layers.{bid}.feed_forward.down_proj", # jamba "encoder.layers.{bid}.mlp.dense_4h_to_h", # chatglm "model.layers.h.{bid}.mlp.c_proj", # exaone ), @@ -406,38 +412,59 @@ class TensorNameMap: ), MODEL_TENSOR.SSM_IN: ( - "model.layers.{bid}.in_proj", - "backbone.layers.{bid}.mixer.in_proj", + "model.layers.{bid}.in_proj", # mamba-hf + "backbone.layers.{bid}.mixer.in_proj", # mamba + "model.layers.{bid}.mamba.in_proj", # jamba ), MODEL_TENSOR.SSM_CONV1D: ( - "model.layers.{bid}.conv1d", - "backbone.layers.{bid}.mixer.conv1d", + "model.layers.{bid}.conv1d", # mamba-hf + "backbone.layers.{bid}.mixer.conv1d", # mamba + "model.layers.{bid}.mamba.conv1d", # jamba ), MODEL_TENSOR.SSM_X: ( - "model.layers.{bid}.x_proj", - "backbone.layers.{bid}.mixer.x_proj", + "model.layers.{bid}.x_proj", # mamba-hf + "backbone.layers.{bid}.mixer.x_proj", # mamba + "model.layers.{bid}.mamba.x_proj", # jamba ), MODEL_TENSOR.SSM_DT: ( - "model.layers.{bid}.dt_proj", - "backbone.layers.{bid}.mixer.dt_proj", + "model.layers.{bid}.dt_proj", # mamba-hf + "backbone.layers.{bid}.mixer.dt_proj", # mamba + "model.layers.{bid}.mamba.dt_proj", # jamba + ), + + MODEL_TENSOR.SSM_DT_NORM: ( + "model.layers.{bid}.mamba.dt_layernorm", # jamba ), MODEL_TENSOR.SSM_A: ( - "model.layers.{bid}.A_log", - "backbone.layers.{bid}.mixer.A_log", + "model.layers.{bid}.A_log", # mamba-hf + "backbone.layers.{bid}.mixer.A_log", # mamba + "model.layers.{bid}.mamba.A_log", # jamba + ), + + MODEL_TENSOR.SSM_B_NORM: ( + "model.layers.{bid}.mamba.b_layernorm", # jamba + "model.layers.{bid}.mamba.B_layernorm", # mini-jamba + ), + + MODEL_TENSOR.SSM_C_NORM: ( + "model.layers.{bid}.mamba.c_layernorm", # jamba + "model.layers.{bid}.mamba.C_layernorm", # mini-jamba ), MODEL_TENSOR.SSM_D: ( - "model.layers.{bid}.D", - "backbone.layers.{bid}.mixer.D", + "model.layers.{bid}.D", # mamba-hf + "backbone.layers.{bid}.mixer.D", # mamba + "model.layers.{bid}.mamba.D", # jamba ), MODEL_TENSOR.SSM_OUT: ( - "model.layers.{bid}.out_proj", - "backbone.layers.{bid}.mixer.out_proj", + "model.layers.{bid}.out_proj", # mamba-hf + "backbone.layers.{bid}.mixer.out_proj", # mamba + "model.layers.{bid}.mamba.out_proj", # jamba ), MODEL_TENSOR.TIME_MIX_W1: ( diff --git a/include/llama.h b/include/llama.h index bfc37e88bbb74..49ff5f5c77e2d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -38,10 +38,10 @@ #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq' #define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN -#define LLAMA_SESSION_VERSION 8 +#define LLAMA_SESSION_VERSION 9 #define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ -#define LLAMA_STATE_SEQ_VERSION 2 +#define 
LLAMA_STATE_SEQ_VERSION 3 #ifdef __cplusplus extern "C" { @@ -621,6 +621,12 @@ extern "C" { // Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes) LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view); + // Rebuild and check the validity of the recurrent state cache's tree of sequences. + // (slow, use only for debugging purposes) + // Returns whether or not the rs cache was valid. + // The errors are always corrected, but only logged when debug is true. + LLAMA_API bool llama_rs_cache_rebuild(struct llama_context * ctx, bool debug); + // Returns the number of tokens in the KV cache (slow, use only for debug) // If a KV cell has multiple sequences assigned to it, it will be counted multiple times LLAMA_API int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx); @@ -628,36 +634,62 @@ extern "C" { // Returns the number of used KV cells (i.e. have at least one sequence assigned to them) LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx); - // Clear the KV cache - both cell info is erased and KV data is zeroed - LLAMA_API void llama_kv_cache_clear( + // Returns the number of used recurrent state cells (i.e. have at least one sequence assigned to them) + LLAMA_API int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx); + + // Clear the KV cache and recurrent states - both cell info is erased and KV data is zeroed + LLAMA_API void llama_past_clear( struct llama_context * ctx); + LLAMA_API DEPRECATED(void llama_kv_cache_clear( + struct llama_context * ctx), + "use llama_past_clear instead"); // Removes all tokens that belong to the specified sequence and have positions in [p0, p1) - // Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails // seq_id < 0 : match any sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API bool llama_kv_cache_seq_rm( + // Returns n_past (one more than the largest remaining pos in the seq_id) + // which is only meaningful to handle for partial removals. + LLAMA_API llama_pos llama_past_seq_rm( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1); + LLAMA_API DEPRECATED(bool llama_kv_cache_seq_rm( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1), + "use llama_past_seq_rm instead, and handle its return value for partial removals"); // Copy all tokens that belong to the specified sequence to another sequence - // Note that this does not allocate extra KV cache memory - it simply assigns the tokens to the new sequence + // Note that this does not allocate extra KV or RS cache memory - it simply assigns the tokens to the new sequence // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_cp( + // Returns n_past (one more than the largest remaining pos in the destination seq_id) + // which is only meaningful to handle when partially copying. 
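+    // e.g. the parallel and server examples copy a shared system prompt from sequence 0
+    //      to every other sequence with llama_past_seq_cp(ctx, 0, i, -1, -1)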
+ LLAMA_API llama_pos llama_past_seq_cp( struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1); + LLAMA_API DEPRECATED(void llama_kv_cache_seq_cp( + struct llama_context * ctx, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1), + "use llama_past_seq_cp instead, and handle its return value for partial copies"); // Removes all tokens that do not belong to the specified sequence - LLAMA_API void llama_kv_cache_seq_keep( + LLAMA_API void llama_past_seq_keep( struct llama_context * ctx, llama_seq_id seq_id); + LLAMA_API DEPRECATED(void llama_kv_cache_seq_keep( + struct llama_context * ctx, + llama_seq_id seq_id), + "use llama_past_seq_keep instead"); // Adds relative position "delta" to all tokens that belong to the specified sequence and have positions in [p0, p1) // If the KV cache is RoPEd, the KV data is updated accordingly: @@ -665,12 +697,19 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_add( + LLAMA_API void llama_past_seq_add( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta); + LLAMA_API DEPRECATED(void llama_kv_cache_seq_add( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta), + "use llama_past_seq_add instead"); // Integer division of the positions by factor of `d > 1` // If the KV cache is RoPEd, the KV data is updated accordingly: @@ -678,17 +717,28 @@ extern "C" { // - explicitly with llama_kv_cache_update() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - LLAMA_API void llama_kv_cache_seq_div( + LLAMA_API void llama_past_seq_div( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d); + LLAMA_API DEPRECATED(void llama_kv_cache_seq_div( + struct llama_context * ctx, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d), + "use llama_past_seq_div instead"); - // Returns the largest position present in the KV cache for the specified sequence - LLAMA_API llama_pos llama_kv_cache_seq_pos_max( + // Returns the largest position present in the KV and/or RS cache for the specified sequence + LLAMA_API llama_pos llama_past_seq_pos_max( struct llama_context * ctx, llama_seq_id seq_id); + LLAMA_API DEPRECATED(llama_pos llama_kv_cache_seq_pos_max( + struct llama_context * ctx, + llama_seq_id seq_id), + "use llama_past_seq_pos_max instead, which now returns -1 instead of 0 when the seq_id has no cells"); // Defragment the KV cache // This will be applied: diff --git a/src/llama.cpp b/src/llama.cpp index 2113c72f3c90b..043f3d7ec7853 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -198,6 +198,7 @@ enum llm_arch { LLM_ARCH_GEMMA2, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, + LLM_ARCH_JAMBA, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, @@ -246,6 +247,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_JAMBA, "jamba" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, @@ -527,7 +529,10 @@ enum llm_tensor { LLM_TENSOR_SSM_CONV1D, LLM_TENSOR_SSM_X, LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_DT_NORM, LLM_TENSOR_SSM_A, + LLM_TENSOR_SSM_B_NORM, + LLM_TENSOR_SSM_C_NORM, LLM_TENSOR_SSM_D, LLM_TENSOR_SSM_OUT, LLM_TENSOR_TIME_MIX_W1, @@ -1103,6 +1108,37 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, }, }, + { + 
LLM_ARCH_JAMBA, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" }, + { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + }, + }, { LLM_ARCH_XVERSE, { @@ -2434,7 +2470,9 @@ struct llama_hparams { return n_embd_head_v * n_head_kv; } - uint32_t n_embd_k_s() const { // dimension of the rolling state embeddings + uint32_t n_embd_r(uint32_t il) const { // dimension of the rolling state embeddings + // TODO: support using an SSM in place of the MLP of a Transformer + if (n_head_kv(il) != 0) { return 0; } // corresponds to Mamba's conv_states size or RWKV's token_shift states size if (wkv_head_size != 0) { // for RWKV models @@ -2446,7 +2484,10 @@ struct llama_hparams { } } - uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings + uint32_t n_embd_s(uint32_t il) const { // dimension of the recurrent state embeddings + // TODO: support using an SSM in place of the MLP of a Transformer + if (n_head_kv(il) != 0) { return 0; } + if (wkv_head_size != 0) { // corresponds to RWKV's wkv_states size return n_embd * wkv_head_size; @@ -2508,6 +2549,9 @@ struct llama_layer { struct ggml_tensor * attn_sub_norm; struct ggml_tensor * attn_post_norm; struct ggml_tensor * ffn_sub_norm; + struct ggml_tensor * ssm_dt_norm; + struct ggml_tensor * ssm_b_norm; + struct ggml_tensor * ssm_c_norm; struct ggml_tensor * attn_norm_cross; struct ggml_tensor * attn_norm_enc; @@ -2658,8 +2702,6 @@ struct llama_ubatch { struct llama_kv_cell { llama_pos pos = -1; llama_pos delta = 0; - int32_t src = -1; // used by recurrent state models to copy states - int32_t tail = -1; std::set seq_id; @@ -2680,7 +2722,6 @@ struct llama_kv_cell { struct llama_kv_cache { bool has_shift = false; bool do_defrag = false; - bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token bool v_trans = true; // the value tensor is transposed // Note: The value of head isn't only used to optimize searching @@ -2701,9 +2742,719 @@ struct llama_kv_cache { std::vector k_l; // per layer std::vector v_l; + size_t total_size() const { + size_t size = 0; + for (struct ggml_tensor * k : k_l) { + size += ggml_nrows(k) * ggml_row_size(k->type, k->ne[0]); + } + for (struct ggml_tensor * v : v_l) { + size += ggml_nrows(v) * ggml_row_size(v->type, v->ne[0]); + } + return size; + } +}; + +// for recurrent models, use a tree of sequences to simplify +// quickly finding the tail cell of each sequence +// TODO: drop the _rs_ 
infix +struct llama_rs_seq_node { + llama_seq_id seq_id = -1; + int32_t next_cell = -1; + + // needed for automatic typecasting from a llama_seq_id + llama_rs_seq_node(const llama_seq_id s = -1, int32_t i = -1) : seq_id(s), next_cell(i) {} + + // needed for more convenient std::find + bool operator==(const llama_rs_seq_node & other) const { + return seq_id == other.seq_id; + } + + bool is_tail() const { + return next_cell < 0; + } +}; + +struct llama_rs_cell { + llama_pos pos = -1; + int32_t src = -1; // copy source id (cleared next when -1) + + // Link to previous cell in this sequence. + // Sequences can only diverge, never converge, + // so this works when there are multiple seq_ids per cell too. + int32_t prev = -1; + + // ref count of tails (should match the number of next_cell == -1 in seq_nodes) + uint32_t tail_rc = 0; + + // seq_ids by insertion order, to simplify updating n_cells compared to a set + std::vector seq_nodes; + + void insert_node(const llama_rs_seq_node & node) { + auto node_dest = std::find(seq_nodes.begin(), seq_nodes.end(), node); + if (node_dest == seq_nodes.end()) { + seq_nodes.push_back(node); + } else { + // overwrite the pre-existing node with the same seq_id if it exists + *node_dest = node; + } + } + + bool has_seq_id(const llama_seq_id & id) const { + return std::find(seq_nodes.begin(), seq_nodes.end(), id) != seq_nodes.end(); + } + + bool is_empty() const { + return seq_nodes.empty(); + } +}; + +struct llama_rs_seq_meta { + // cell id of the latest state of this seq_id + int32_t tail = -1; + // number of cells for which this seq_id is the first + // (useful to know if cells in this sequence should be pruned) + int32_t n_cells = 0; + // the last pos of this sequence if it is in the current ubatch, + // only set and used when finding a slot. + llama_pos ubatch_end_pos = -1; +}; + +// ring-buffered tree of cached recurrent state data +struct llama_rs_cache { + + uint32_t head = 0; // first state used for the last slot + uint32_t size = 0; + uint32_t used = 0; + + // computed when finding a slot + uint32_t n = 0; // range of states used for the last slot + + // only counts cells which are tails of all of their sequences. + // useful to know the minimum reserved cell count per seq_id. + uint32_t n_seqs = 0; + // cells part of multiple sequences, + // but which are only the tail of some of them. + // useful to dismiss sequences used as a shared prompt + uint32_t n_shared_tail_cells = 0; + + // with state models, a cell can hold the state for more than one past token + // TODO: it's probably not possible to always use contiguous cells + std::vector cells; + + // find tail cells faster + std::vector seq_tails; // map seq_ids to cell ids + + // freeable cell ids, computed when finding a slot + // useful to find the smallest range to defrag + std::vector freeable; + + // per layer + // NOTE: the naming of r and s is arbitrary + std::vector r_l; // rolling/shift states + std::vector s_l; // ssm (recurrent) states + + // TODO: maybe use a simpler data structure than a tree + + // Inefficient, but thorough verification and rebuilding of the rs cache + // from only the cells list with `pos` and seq_ids. + // Should not be called in a hot loop except when desperate and/or debugging. 
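+    // Returns whether the cache was already valid; inconsistencies are always corrected,
+    // and only logged when debug is true.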
+ bool rebuild(bool debug) { + bool was_valid = true; + // skip for non-recurrent models + if (size == 0) { return true; } + // the source of truth is the cells list + // buffer sizes + if (size != cells.size()) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells has wrong size (%zu instead of %u)\n", + __func__, cells.size(), size); + } + cells.resize(size); + was_valid = false; + } + if (size != seq_tails.size()) { + if (debug) { + LLAMA_LOG_ERROR("%s: seq_tails has wrong size (%zu instead of %u)\n", + __func__, seq_tails.size(), size); + } + seq_tails.resize(size); + was_valid = false; + } + // cells consistency + uint32_t used_verif = 0; + for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { + llama_rs_cell & cell = cells[cell_id]; + if (cell.seq_nodes.empty()) { + if (cell.pos >= 0) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells[%d].pos is %d while it's empty (should be -1)\n", + __func__, cell_id, cell.pos); + } + cell.pos = -1; + was_valid = false; + } + } + if (cell.pos < 0) { + if (cell.pos != -1) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells[%d].pos is %d while it's empty (should be -1)\n", + __func__, cell_id, cell.pos); + } + cell.pos = -1; + was_valid = false; + } + if (!cell.seq_nodes.empty()) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells[%d] has %zu seq_ids while it's empty (should have none)\n", + __func__, cell_id, cell.seq_nodes.size()); + } + cell.seq_nodes.clear(); + was_valid = false; + } + cell.src = -1; + if (cell.prev != -1) { + if (debug) { + LLAMA_LOG_ERROR("%s: cells[%d].prev is %d while it's empty (should be -1)\n", + __func__, cell_id, cell.prev); + } + cell.prev = -1; + was_valid = false; + } + } else if (!debug) { + // Assuming the cache should be actually rebuilt when not debugging + cell.src = cell_id; + } + if (!cell.seq_nodes.empty()) { + used_verif += 1; + } + } + if (used != used_verif) { + if (debug) { + LLAMA_LOG_ERROR("%s: invalid used cell count (%u instead of %u)\n", + __func__, used, used_verif); + } + used = used_verif; + was_valid = false; + } + // tail verification + std::vector> seq_cells; + for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) { + auto & seq = seq_tails[seq_id]; + seq_cells.clear(); + for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { + llama_rs_cell & cell = cells[cell_id]; + if (cell.has_seq_id(seq_id)) { + seq_cells.push_back({cell.pos, cell_id}); + } + } + // sort by pos and then by cell_id + std::sort(seq_cells.begin(), seq_cells.end()); + int32_t tail = seq_cells.empty() ? -1 : seq_cells[seq_cells.size() - 1].second; + if (tail != seq.tail) { + if (debug) { + LLAMA_LOG_ERROR("%s: invalid tail for seq_id %d (%d instead of %d)\n", + __func__, seq_id, seq.tail, tail); + } + seq.tail = tail; + was_valid = false; + } + int32_t prev = -1; + for (size_t i = 0; i < seq_cells.size(); ++i) { + uint32_t cell_id = seq_cells[i].second; + llama_rs_cell & cell = cells[cell_id]; + if (cell.prev != prev) { + // TODO: relax the error when multiple cells have the same pos + if (debug) { + LLAMA_LOG_ERROR("%s: invalid prev cell for cells[%u] (%d instead of %d)\n", + __func__, cell_id, cell.prev, prev); + } + cell.prev = prev; + was_valid = false; + } + prev = cell_id; + } + int32_t n_cells = 0; + int32_t next = -1; + for (size_t i = seq_cells.size(); i-- > 0;) { + uint32_t cell_id = seq_cells[i].second; + llama_rs_cell & cell = cells[cell_id]; + // assuming it's always found, because how else would it end up in the list of cells for this seq_id? 
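+                // (only the first seq_node of a cell contributes to that seq_id's n_cells,
+                //  matching the meaning of llama_rs_seq_meta::n_cells)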
+ auto seq_node = std::find(cell.seq_nodes.begin(), cell.seq_nodes.end(), seq_id); + if (seq_node == cell.seq_nodes.begin()) { + n_cells += 1; + } + if (seq_node->next_cell != next) { + // TODO: relax the error when multiple cells have the same pos + if (debug) { + LLAMA_LOG_ERROR("%s: invalid next cell for seq_id %d in cells[%u] (%d instead of %d)\n", + __func__, seq_id, cell_id, seq_node->next_cell, next); + } + seq_node->next_cell = next; + was_valid = false; + } + next = cell_id; + } + if (seq.n_cells != n_cells) { + if (debug) { + LLAMA_LOG_ERROR("%s: invalid n_cells for seq_id %d (%d instead of %d)\n", + __func__, seq_id, seq.n_cells, n_cells); + } + seq.n_cells = n_cells; + } + } + // tail_rc + for (uint32_t cell_id = 0; cell_id < size; ++cell_id) { + llama_rs_cell & cell = cells[cell_id]; + uint32_t tail_rc = 0; + for (llama_seq_id seq_id = 0; (uint32_t) seq_id < size; ++seq_id) { + auto & seq = seq_tails[seq_id]; + if (seq.tail >= 0 && (uint32_t) seq.tail == cell_id) { + tail_rc += 1; + } + } + if (cell.tail_rc != tail_rc) { + if (debug) { + LLAMA_LOG_ERROR("%s: invalid tail_rc for cells[%u] (%u instead of %u)\n", + __func__, cell_id, cell.tail_rc, tail_rc); + } + cell.tail_rc = tail_rc; + was_valid = false; + } + } + // n_seqs + uint32_t n_seqs_verif = 0; + uint32_t n_shared_tail_cells_verif = 0; + for (uint32_t cell_id = 0; (uint32_t) cell_id < size; ++cell_id) { + llama_rs_cell & rs_cell = cells[cell_id]; + if (!rs_cell.seq_nodes.empty()) { + if (rs_cell.seq_nodes.size() == rs_cell.tail_rc) { + n_seqs_verif += 1; + } else if (rs_cell.tail_rc > 0) { + n_shared_tail_cells_verif += 1; + } + } + } + if (n_seqs != n_seqs_verif) { + if (debug) { + LLAMA_LOG_ERROR("%s: wrong n_seqs (%u instead of %u)\n", + __func__, n_seqs, n_seqs_verif); + } + n_seqs = n_seqs_verif; + was_valid = false; + } + if (n_shared_tail_cells != n_shared_tail_cells_verif) { + if (debug) { + LLAMA_LOG_ERROR("%s: wrong n_shared_tail_cells (%u instead of %u)\n", + __func__, n_shared_tail_cells, n_shared_tail_cells_verif); + } + n_shared_tail_cells = n_shared_tail_cells_verif; + was_valid = false; + } + return was_valid; + } + + // each seq_id should have access to at least this many cells + // (to use when pruning (to avoid over-pruning)) + uint32_t min_cells_per_seq(const llama_ubatch & batch) const { + uint32_t seqs = n_seqs; + for (uint32_t i = 0; i < batch.n_seqs; ++i) { + llama_seq_id seq_id = batch.seq_id[i][0]; + const llama_rs_seq_meta & new_seq = seq_tails[seq_id]; + if (new_seq.tail < 0 || new_seq.n_cells == 0) { + seqs += 1; + } + } + return (size - n_shared_tail_cells) / (seqs > 0 ? seqs : 1); + } + + void freeable_for_batch(const llama_ubatch & batch, llama_pos checkpoint_interval) { + GGML_ASSERT(batch.equal_seqs); + int32_t min_cells = min_cells_per_seq(batch); + + // TODO: minimize work required to find freeable cells + // currently, this finds freeable cells by excluding non-freeable cells, + // because some conditions are more easily expressed this way. 
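+        // start by assuming every cell is freeable, then clear the flag for the cells which need to be kept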
+ + freeable.assign(size, 1); + + for (llama_rs_seq_meta & seq : seq_tails) { + seq.ubatch_end_pos = -1; + } + + for (uint32_t i = 0; i < batch.n_seqs; ++i) { + int32_t n_seq_id = batch.n_seq_id[i]; + for (int32_t j = 0; j < n_seq_id; j++) { + llama_seq_id seq_id = batch.seq_id[i][j]; + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_tails.size()); + llama_rs_seq_meta & seq = seq_tails[seq_id]; + seq.ubatch_end_pos = batch.pos[i * batch.n_seq_tokens + batch.n_seq_tokens - 1]; + } + } + + for (llama_rs_seq_meta & seq : seq_tails) { + if (seq.tail >= 0 && freeable[seq.tail] != 0) { + llama_pos end_pos = seq.ubatch_end_pos; + // When is a tail cell not freeable? + if (end_pos < 0) { + // when any of its tails are not in the batch + freeable[seq.tail] = 0; + } else if (min_cells > 1) { + // TODO: fallback to this less often + llama_rs_cell & tail = cells[seq.tail]; + GGML_ASSERT(tail.pos < end_pos); + if (tail.prev < 0 || tail.pos + checkpoint_interval <= end_pos) { + // make a checkpoint before prompt processing + // TODO: should it always be done after instead? + freeable[seq.tail] = 0; + } else { + llama_rs_cell & prev = cells[tail.prev]; + if (prev.pos + checkpoint_interval <= end_pos) { + // make a checkpoint during text generation + freeable[seq.tail] = 0; + } + } + } + } + } + + for (uint32_t i = 0; i < size; ++i) { + llama_rs_cell & cell = cells[i]; + if (!cell.is_empty() && cell.tail_rc == 0) { + // TODO: reduce indirection here + llama_rs_seq_node & seq_node = cell.seq_nodes[0]; + llama_rs_seq_meta & seq = seq_tails[seq_node.seq_id]; + bool keep_tail = freeable[seq.tail] == 0; + // kept tails use an additional cell, so make them allow freeing a checkpoint + int32_t really_min_cells = keep_tail ? min_cells - 1 : min_cells; + // A checkpoint is kept if there's enough alloted space for this sequence + // or if it's the state right before the tail + if (seq.n_cells <= really_min_cells || (really_min_cells > 1 && seq_node.next_cell == seq.tail)) { + freeable[i] = 0; + } + } + } + } + + // returns an iterator to the seq_node after the removed one, or the same which was passed if it wasn't removed. + // Why an iterator? Because it allows using std::vector::erase. 
+ std::vector::iterator remove_seq_node_from_cell(llama_rs_cell & rs_cell, std::vector::iterator node_iter) { + GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); + // The iterator needs to point inside the correct vector + GGML_ASSERT(&(*node_iter) >= rs_cell.seq_nodes.data() && &(*node_iter) < rs_cell.seq_nodes.data() + rs_cell.seq_nodes.size()); + if (node_iter != rs_cell.seq_nodes.end()) { + // update the tree + llama_rs_seq_node node = *node_iter; + if (node.next_cell >= 0 && (uint32_t) node.next_cell < size) { + // NOTE: because of this, partially removing seq_ids from cells should only be done from the tail + cells[node.next_cell].prev = rs_cell.prev; + } + if (rs_cell.prev >= 0 && (uint32_t) rs_cell.prev < size) { + llama_rs_cell & prev_cell = cells[rs_cell.prev]; + auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), node); + // assuming the previous node is always found + GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); + prev_node->next_cell = node.next_cell; + if (node.is_tail()) { + // move the tail back to the previous cell + prev_cell.tail_rc += 1; + if (prev_cell.seq_nodes.size() > 1) { + if (rs_cell.tail_rc == rs_cell.seq_nodes.size()) { + if (prev_cell.tail_rc == 1) { + n_shared_tail_cells += 1; + } + + if (rs_cell.tail_rc == 1) { + if (prev_cell.tail_rc != prev_cell.seq_nodes.size()) { + // o oo oo + // |/ -> o/ + // | | + // e.g. when removing the leaf of a split tree + n_seqs -= 1; + } else { + // o + // o -> oo + // | | + // e.g. when merging back with a previous tail + n_shared_tail_cells -= 1; + } + } + } + } + } + } + if ((uint32_t) node.seq_id < seq_tails.size()) { + auto & seq = seq_tails[node.seq_id]; + if (node.is_tail()) { + seq.tail = rs_cell.prev; + if (rs_cell.tail_rc == 1) { + if (seq.tail < 0) { + // no more tail, no more sequence + if (rs_cell.seq_nodes.size() > 1) { + n_shared_tail_cells -= 1; + } else { + n_seqs -= 1; + } + } + } + GGML_ASSERT(rs_cell.tail_rc > 0); + rs_cell.tail_rc -= 1; + } else if (rs_cell.tail_rc == rs_cell.seq_nodes.size() - 1) { + // will fully become a tail cell + if (rs_cell.tail_rc > 0) { + n_seqs += 1; + n_shared_tail_cells -= 1; + } + } + if (node_iter == rs_cell.seq_nodes.begin()) { + // this seq_id was the first in the list + seq.n_cells -= 1; + + auto next_node = std::next(node_iter); + if (next_node != rs_cell.seq_nodes.end()) { + // the next node is the new first one, so update its n_cells + if ((uint32_t) next_node->seq_id < seq_tails.size()) { + auto & next_seq = seq_tails[next_node->seq_id]; + next_seq.n_cells += 1; + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + } else { + // this was the last seq_id of the cell + used -= 1; + rs_cell.pos = -1; + rs_cell.src = -1; + rs_cell.prev = -1; + // the other fields *should* have already been updated elsewhere + } + } + } else { + GGML_ASSERT(false && "invalid seq_id"); + } + return rs_cell.seq_nodes.erase(node_iter); + } + return node_iter; + } + + void clear_cell(llama_rs_cell & rs_cell) { + GGML_ASSERT(&rs_cell >= cells.data() && &rs_cell < cells.data() + cells.size()); + for (auto node_iter = rs_cell.seq_nodes.begin(); node_iter != rs_cell.seq_nodes.end();) { + node_iter = remove_seq_node_from_cell(rs_cell, node_iter); + } + } + + // returns whether or not the seq_id was removed + bool remove_seq_from_cell_id(uint32_t i_cell, const llama_seq_id & id) { + if (i_cell < size && (size_t) id < size) { + llama_rs_cell & rs_cell = cells[i_cell]; + auto node_iter = std::find(rs_cell.seq_nodes.begin(), 
rs_cell.seq_nodes.end(), id); // search once + return node_iter != remove_seq_node_from_cell(rs_cell, node_iter); + } + return false; + } + + bool swap_cells(uint32_t i_src, uint32_t i_dst) { + if (i_src < size && i_dst < size && i_src != i_dst) { + llama_rs_cell & src = cells[i_src]; + llama_rs_cell & dst = cells[i_dst]; + + for (llama_rs_seq_node & seq_node : src.seq_nodes) { + if (seq_node.next_cell >= 0) { + llama_rs_cell & next = cells[seq_node.next_cell]; + next.prev = i_dst; + if ((uint32_t) seq_node.next_cell == i_dst) { + seq_node.next_cell = i_src; + } + } else { + // this is a tail + seq_tails[seq_node.seq_id].tail = i_dst; + } + } + for (llama_rs_seq_node & seq_node : dst.seq_nodes) { + if (seq_node.next_cell >= 0) { + llama_rs_cell & next = cells[seq_node.next_cell]; + next.prev = i_src; + if ((uint32_t) seq_node.next_cell == i_src) { + seq_node.next_cell = i_dst; + } + } else { + // this is a tail + seq_tails[seq_node.seq_id].tail = i_src; + } + } + + if (src.prev == dst.prev) { + // avoid swapping them twice + if (src.prev >= 0) { + llama_rs_cell & prev = cells[src.prev]; + for (llama_rs_seq_node & seq_node : prev.seq_nodes) { + if ((uint32_t) seq_node.next_cell == i_src) { + seq_node.next_cell = i_dst; + } else if ((uint32_t) seq_node.next_cell == i_dst) { + seq_node.next_cell = i_src; + } + } + } + } else { + if (src.prev >= 0) { + llama_rs_cell & prev = cells[src.prev]; + for (llama_rs_seq_node & seq_node : prev.seq_nodes) { + if ((uint32_t) seq_node.next_cell == i_src) { + seq_node.next_cell = i_dst; + } + } + } + if (dst.prev >= 0) { + llama_rs_cell & prev = cells[dst.prev]; + for (llama_rs_seq_node & seq_node : prev.seq_nodes) { + if ((uint32_t) seq_node.next_cell == i_dst) { + seq_node.next_cell = i_src; + } + } + } + } + + std::swap(src.pos, dst.pos); + std::swap(src.src, dst.src); + std::swap(src.prev, dst.prev); + std::swap(src.tail_rc, dst.tail_rc); + std::swap(src.seq_nodes, dst.seq_nodes); + + return true; + } + return false; + } + + bool insert_seq_tail_to_cell_id(uint32_t i_cell, llama_seq_id id, llama_pos end_pos = -1) { + if (i_cell < size && (size_t) id < seq_tails.size()) { + llama_rs_cell & rs_cell = cells[i_cell]; + auto & seq = seq_tails[id]; + int32_t prev = rs_cell.prev; + if (end_pos >= 0) { + if (end_pos <= rs_cell.pos) { + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d\n", + __func__, end_pos, rs_cell.pos, id); + } + rs_cell.pos = end_pos; + } else { + // if no pos was specified, then the target cell should already have a valid one. 
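swap_cells exchanges the contents of two slots of the flat cell array, so every stored index that referred to either slot (the prev of a successor, the next_cell of a predecessor, and the per-sequence tail) has to be re-pointed. A minimal sketch of the same fix-up on a generic index-linked chain; chain_node is a simplified stand-in, and unlike the code above it simply rescans all nodes instead of only touching the neighbours:

    #include <cstdint>
    #include <utility>
    #include <vector>

    // doubly-linked chains stored by index in a flat array (illustrative stand-in)
    struct chain_node {
        int32_t prev = -1;
        int32_t next = -1;
    };

    // swap the *contents* of slots a and b while keeping every chain intact
    static void swap_slots(std::vector<chain_node> & nodes, int32_t a, int32_t b) {
        if (a == b) { return; }
        // any index that referred to slot a must now refer to slot b, and vice versa
        for (chain_node & n : nodes) {
            if (n.prev == a) { n.prev = b; } else if (n.prev == b) { n.prev = a; }
            if (n.next == a) { n.next = b; } else if (n.next == b) { n.next = a; }
        }
        std::swap(nodes[a], nodes[b]);
    }

Touching only the neighbours keeps the real swap O(1) per call, which is why it needs the extra handling for src.prev == dst.prev and for adjacent cells.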
+ GGML_ASSERT(!rs_cell.is_empty()); + } + if ((uint32_t) seq.tail == i_cell) { + // the cell is already the tail of this seq_id + if (rs_cell.tail_rc != rs_cell.seq_nodes.size()) { + GGML_ASSERT(end_pos >= 0); // make sure this is the first re-added seq_id + // remove non-tail seq_ids (branch off them) + for (size_t i = rs_cell.seq_nodes.size(); i-- > 0;) { + if (!rs_cell.seq_nodes[i].is_tail()) { + remove_seq_node_from_cell(rs_cell, rs_cell.seq_nodes.begin() + i); + } + } + } + return true; + } + if (rs_cell.is_empty()) { + prev = seq.tail; + } + // ensure the new tail won't mess up the tree + GGML_ASSERT(seq.tail == -1 || seq.tail == prev); + if (prev >= 0 && (uint32_t) prev < size) { + // the targeted cell has a previous cell + llama_rs_cell & prev_cell = cells[prev]; + auto prev_node = std::find(prev_cell.seq_nodes.begin(), prev_cell.seq_nodes.end(), id); + GGML_ASSERT(prev_node != prev_cell.seq_nodes.end()); // TODO: recursive insert instead of failing + GGML_ASSERT(prev_node->next_cell == -1); // or else a chain is broken + if (rs_cell.is_empty()) { + rs_cell.src = prev_cell.src; + } + prev_node->next_cell = i_cell; + rs_cell.prev = prev; + if (seq.tail == prev) { + // What to do when the tail moves... + // (Legend: tail: O, one or more non-tails: o, one or more tails O+, empty: _) + // O -> oO (n_seqs--, n_shared_tail_cells++) + // O -> O (seq.n_cells++) + // OO+ -> oO (n_seqs--, n_shared_tail_cells += 2) + // OO+ -> O+ (n_shared_tail_cells++ (the previous cell becomes oO+)) + // _ -> oO (n_shared_tail_cells++) + // _ -> O (seq.n_cells++, n_seqs++) + // Oo -> O (seq.n_cells++, n_seqs++, n_shared_tail_cell--) + // Oo -> OO+ (n_shared_tail_cell--) + // OOo -> O (seq.n_cells++, n_seqs++) + if (prev_cell.seq_nodes.size() == prev_cell.tail_rc) { + // from fully tail + if (prev_cell.tail_rc > 1) { + // the previous tail becomes shared with a non-tail + n_shared_tail_cells += 1; + } + if (!rs_cell.is_empty() && rs_cell.tail_rc == 0) { + // the new tail cell was previously a fully non-tail cell + n_shared_tail_cells += 1; + n_seqs -= 1; + } + } else { + if (rs_cell.is_empty()) { + // from shared to unique + n_seqs += 1; + } + if (prev_cell.tail_rc == 1 && rs_cell.seq_nodes.size() == rs_cell.tail_rc) { + // from last shared to fully tail + n_shared_tail_cells -= 1; + } + } + } + prev_cell.tail_rc -= 1; + } + if (rs_cell.is_empty()) { + // to unique + seq.n_cells += 1; + if (seq.tail < 0) { + // from empty to unique + n_seqs += 1; + // make sure it's cleared + rs_cell.src = -1; + } + used += 1; + } else if (rs_cell.tail_rc == 0) { + // to shared + if (seq.tail < 0) { + // from empty to shared + n_shared_tail_cells += 1; + } + } + // the target cell was not already a tail of this seq_id + rs_cell.insert_node(id); // next_cell == -1 by default + rs_cell.tail_rc += 1; + seq.tail = i_cell; + return true; + } + return false; + } + + size_t total_size() const { + size_t size = 0; + for (struct ggml_tensor * r : r_l) { + size += ggml_nrows(r) * ggml_row_size(r->type, r->ne[0]); + } + for (struct ggml_tensor * s : s_l) { + size += ggml_nrows(s) * ggml_row_size(s->type, s->ne[0]); + } + return size; + } +}; + +struct llama_past { + // key + value cache for self attention + llama_kv_cache kv; + + // recurrent state cache for state space models + llama_rs_cache rs; + std::vector ctxs; std::vector bufs; + // NOTE: padding may make this bigger than kv.total_size() + rs.total_size() size_t total_size() const { size_t size = 0; for (ggml_backend_buffer_t buf : bufs) { @@ -2712,7 +3463,7 @@ struct 
llama_kv_cache { return size; } - ~llama_kv_cache() { + ~llama_past() { for (struct ggml_context * ctx : ctxs) { ggml_free(ctx); } @@ -3198,7 +3949,7 @@ struct llama_context { struct llama_cparams cparams; struct llama_sampling sampling; struct llama_sbatch sbatch; - struct llama_kv_cache kv_self; + struct llama_past cache; struct llama_control_vector cvec; std::unordered_map lora_adapters; @@ -3274,9 +4025,8 @@ struct llama_context { struct ggml_tensor * inp_K_shift; // I32 [kv_size] struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch] - struct ggml_tensor * inp_s_copy; // I32 [kv_size] - struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] - struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] + struct ggml_tensor * inp_s_copy; // I32 [n_rs] + struct ggml_tensor * inp_s_mask; // F32 [1, n_rs] struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] @@ -3442,37 +4192,45 @@ static size_t llama_get_device_memory(const llama_model & model, int device) { } // -// kv cache helpers +// kv and rs cache helpers // -static bool llama_kv_cache_init( - struct llama_kv_cache & cache, +static bool llama_past_init( + struct llama_past & cache, const llama_context * ctx, ggml_type type_k, ggml_type type_v, - uint32_t kv_size, bool offload) { const llama_model & model = ctx->model; const llama_cparams & cparams = ctx->cparams; const struct llama_hparams & hparams = model.hparams; + const auto & n_head_kv_iter = hparams.n_head_kv_arr.begin(); + const int64_t n_layer = hparams.n_layer; + const bool has_kv = std::any_of(n_head_kv_iter, n_head_kv_iter + n_layer, [](uint32_t n) { return n > 0; }) && hparams.causal_attn; + const bool has_rs = llama_model_is_recurrent(&model); + const uint32_t kv_size = has_kv ? cparams.n_ctx : 0; + const uint32_t rs_size = has_rs ? 
cparams.n_seq_max : 0; + + cache.kv.size = kv_size; - cache.has_shift = false; + cache.kv.v_trans = !cparams.flash_attn; - cache.recurrent = llama_model_is_recurrent(&model); - cache.v_trans = !cache.recurrent && !cparams.flash_attn; + cache.kv.type_k = type_k; + cache.kv.type_v = type_v; - cache.head = 0; - cache.size = kv_size; - cache.used = 0; + cache.kv.cells.clear(); + cache.kv.cells.resize(kv_size); - cache.type_k = type_k; - cache.type_v = type_v; + cache.rs.size = rs_size; - cache.cells.clear(); - cache.cells.resize(kv_size); + cache.rs.cells.clear(); + cache.rs.cells.resize(rs_size); + cache.rs.seq_tails.clear(); + cache.rs.seq_tails.resize(rs_size); + cache.rs.freeable.reserve(rs_size); // count used buffer types std::map buft_layer_count; @@ -3488,8 +4246,9 @@ static bool llama_kv_cache_init( std::map ctx_map; for (auto & it : buft_layer_count) { int n_layers = it.second; + // TODO: for mixed architectures, avoid allocating empty recurrent state or kv cache tensors struct ggml_init_params params = { - /*.mem_size =*/ 2u*n_layers*ggml_tensor_overhead(), + /*.mem_size =*/ 2*(has_kv + has_rs)*n_layers*ggml_tensor_overhead(), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -3502,20 +4261,33 @@ static bool llama_kv_cache_init( cache.ctxs.push_back(ctx); } - cache.k_l.reserve(n_layer); - cache.v_l.reserve(n_layer); + if (has_kv) { + cache.kv.k_l.reserve(n_layer); + cache.kv.v_l.reserve(n_layer); + } + if (has_rs) { + cache.rs.r_l.reserve(n_layer); + cache.rs.s_l.reserve(n_layer); + } for (int i = 0; i < (int) n_layer; i++) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); - struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front(); - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); - ggml_format_name(k, "cache_k_l%d", i); - ggml_format_name(v, "cache_v_l%d", i); - cache.k_l.push_back(k); - cache.v_l.push_back(v); + if (has_kv) { + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_gqa(i)*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_gqa(i)*kv_size); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); + cache.kv.k_l.push_back(k); + cache.kv.v_l.push_back(v); + } + if (has_rs) { + ggml_tensor * r = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_r(i)*rs_size); + ggml_tensor * s = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hparams.n_embd_s(i)*rs_size); + ggml_format_name(r, "cache_r_l%d", i); + ggml_format_name(s, "cache_s_l%d", i); + cache.rs.r_l.push_back(r); + cache.rs.s_l.push_back(s); + } } // allocate tensors and initialize the buffers to avoid NaNs in the padding @@ -3524,11 +4296,15 @@ static bool llama_kv_cache_init( ggml_context * ctx = it.second; ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); if (!buf) { - LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__); + if (!has_kv && !has_rs) { + // no buffer was needed, so this is fine + return true; + } + LLAMA_LOG_ERROR("%s: failed to allocate buffer for past cache\n", __func__); return false; } ggml_backend_buffer_clear(buf, 0); - LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + LLAMA_LOG_INFO("%s: %10s past cache size = %8.2f MiB\n", __func__, 
ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); cache.bufs.push_back(buf); } @@ -3539,226 +4315,254 @@ static bool llama_kv_cache_init( // updates the cache head // Note: On success, it's important that cache.head points // to the first cell of the slot. -static bool llama_kv_cache_find_slot( - struct llama_kv_cache & cache, - const struct llama_ubatch & batch) { +static bool llama_past_find_slot( + struct llama_past & cache, + const struct llama_ubatch & batch) { + const uint32_t kv_size = cache.kv.size; + const uint32_t rs_size = cache.rs.size; const uint32_t n_tokens = batch.n_tokens; const uint32_t n_seqs = batch.n_seqs; const uint32_t n_seq_tokens = batch.n_seq_tokens; - if (cache.recurrent) { - // For recurrent state architectures (like Mamba or RWKV), - // each cache cell can store the state for a whole sequence. - // A slot should be always be contiguous. - - // can only process batches with an equal number of new tokens in each sequence - GGML_ASSERT(batch.equal_seqs); - - int32_t min = cache.size - 1; - int32_t max = 0; - + // only check first, to allow failing gracefully + if (rs_size > 0) { // everything should fit if all seq_ids are smaller than the max - for (uint32_t s = 0; s < n_seqs; ++s) { - const uint32_t n_seq_id = batch.n_seq_id[s]; - for (uint32_t j = 0; j < n_seq_id; ++j) { - const llama_seq_id seq_id = batch.seq_id[s][j]; + for (uint32_t i = 0; i < n_seqs; ++i) { + int32_t n_seq_id = batch.n_seq_id[i]; + for (int32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id = batch.seq_id[i][j]; - if (seq_id < 0 || (uint32_t) seq_id >= cache.size) { + if (seq_id < 0 || (uint32_t) seq_id >= rs_size) { // too big seq_id - // TODO: would it be possible to resize the cache instead? - LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size); + // TODO: would it be possible to resize the rs cache size instead? 
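With the split cache, a hybrid model pays for attention KV only on layers that actually have KV heads and for conv/ssm states only on its recurrent layers, and the recurrent part is sized by n_seq_max instead of n_ctx. A rough sizing sketch under assumed dimensions (the layer sizes below are illustrative, not Jamba's real ones; the r/s element counts assume the usual (d_conv - 1) * d_inner and d_state * d_inner layouts):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        // hypothetical per-layer dimensions: k/v are 0 on pure Mamba layers, r/s are 0 on attention layers
        struct layer_dims { uint32_t k_gqa, v_gqa, r, s; };
        const std::vector<layer_dims> layers = {
            {0, 0, 12288, 65536}, // Mamba layer: d_conv=4, d_inner=4096, d_state=16 (assumed)
            {1024, 1024, 0, 0},   // attention layer: n_embd_gqa=1024 (assumed)
        };
        const uint32_t kv_size = 4096; // one KV cell per context position (cparams.n_ctx)
        const uint32_t rs_size = 4;    // one state cell per sequence (cparams.n_seq_max)
        const size_t kv_bytes = 2;     // e.g. F16 cache types
        const size_t rs_bytes = 4;     // recurrent states are allocated as F32

        size_t total = 0;
        for (const auto & l : layers) {
            total += (size_t) (l.k_gqa + l.v_gqa) * kv_size * kv_bytes;
            total += (size_t) (l.r + l.s) * rs_size * rs_bytes;
        }
        printf("past cache size: %.2f MiB\n", total / 1024.0 / 1024.0);
        return 0;
    }

Note how little the recurrent half costs: its footprint scales with the number of parallel sequences, not with the context length.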
+ LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.rs.size); return false; } - if (j > 0) { - llama_kv_cell & seq = cache.cells[seq_id]; - if (seq.tail >= 0) { - llama_kv_cell & cell = cache.cells[seq.tail]; - // clear cells from seq_ids that become shared - // (should not normally happen, but let's handle it anyway) - cell.seq_id.erase(seq_id); - seq.tail = -1; - if (cell.seq_id.empty()) { - cell.pos = -1; - cell.src = -1; - cache.used -= 1; - } - } - } } } - -#ifndef NDEBUG + // TODO: configurable checkpoint interval + cache.rs.freeable_for_batch(batch, 8); { - std::vector tails_verif; - tails_verif.assign(cache.size, -1); - for (uint32_t i = 0; i < cache.size; ++i) { - llama_kv_cell & cell = cache.cells[i]; - for (llama_seq_id seq_id : cell.seq_id) { - if (tails_verif[seq_id] != -1) { - LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); - } - tails_verif[seq_id] = i; + uint32_t freeable_rs_cell_count = 0; + for (uint32_t is_freeable : cache.rs.freeable) { + freeable_rs_cell_count += (uint32_t) (is_freeable != 0); + if (freeable_rs_cell_count >= n_seqs) { + // there's enough, no need to count them all + break; } } - for (uint32_t i = 0; i < cache.size; ++i) { - if (tails_verif[i] != cache.cells[i].tail) { - LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cache.cells[i].tail, tails_verif[i]); - } + if (n_seqs > freeable_rs_cell_count) { + // This should not happen + LLAMA_LOG_ERROR("%s: n_seqs=%d > freeable_rs_cell_count=%d\n", __func__, n_seqs, freeable_rs_cell_count); + return false; } } -#endif - - // find next empty cell - uint32_t next_empty_cell = cache.head; + } - for (uint32_t i = 0; i < cache.size; ++i) { - if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; } - llama_kv_cell & cell = cache.cells[next_empty_cell]; - if (cell.is_empty()) { break; } - next_empty_cell += 1; + if (kv_size > 0) { + // one KV cell per token + if (n_tokens > kv_size) { + LLAMA_LOG_ERROR("%s: n_tokens=%d > kv_size=%d\n", __func__, n_tokens, kv_size); + return false; } - // find usable cell range - for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = batch.seq_id[s][0]; - llama_kv_cell & seq_meta = cache.cells[seq_id]; - bool has_cell = false; - if (seq_meta.tail >= 0) { - llama_kv_cell & cell = cache.cells[seq_meta.tail]; - GGML_ASSERT(cell.has_seq_id(seq_id)); - // does this seq_id "own" the cell? 
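On the self-attention side, finding a slot is unchanged in spirit: scan for n_tokens contiguous free cells, wrapping the head around the end of the buffer, and give up once the whole cache has been tested. A compact standalone version of that ring search over a plain occupancy mask (standing in for the pos >= 0 check on real cells):

    #include <cstdint>
    #include <vector>

    // returns the start of a contiguous run of n_tokens free cells, or -1 if there is none
    static int32_t find_contiguous_slot(const std::vector<bool> & occupied, uint32_t head, uint32_t n_tokens) {
        const uint32_t size = (uint32_t) occupied.size();
        if (n_tokens > size) { return -1; }
        uint32_t n_tested = 0;
        while (true) {
            if (head + n_tokens > size) {
                n_tested += size - head;
                head = 0;
                continue;
            }
            bool found = true;
            for (uint32_t i = 0; i < n_tokens; ++i) {
                if (occupied[head + i]) {
                    found     = false;
                    head     += i + 1;
                    n_tested += i + 1;
                    break;
                }
            }
            if (found)            { return (int32_t) head; }
            if (n_tested >= size) { return -1; }
        }
    }

The difference is only where this lives now: it touches cache.kv exclusively, while the recurrent slot is handled separately.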
- if (cell.seq_id.size() == 1) { has_cell = true; } - } - if (!has_cell) { - llama_kv_cell & empty_cell = cache.cells[next_empty_cell]; - GGML_ASSERT(empty_cell.is_empty()); - // copy old tail into the empty cell - if (seq_meta.tail >= 0) { - llama_kv_cell & orig_cell = cache.cells[seq_meta.tail]; - empty_cell.pos = orig_cell.pos; - empty_cell.src = orig_cell.src; - orig_cell.seq_id.erase(seq_id); - empty_cell.seq_id.insert(seq_id); // will be overwritten - } - seq_meta.tail = next_empty_cell; - // find next empty cell - if (s + 1 < n_seqs) { - next_empty_cell += 1; - for (uint32_t i = 0; i < cache.size; ++i) { - if (next_empty_cell >= cache.size) { next_empty_cell -= cache.size; } - llama_kv_cell & cell = cache.cells[next_empty_cell]; - if (cell.is_empty()) { break; } - next_empty_cell += 1; - } - } - } - if (min > seq_meta.tail) { min = seq_meta.tail; } - if (max < seq_meta.tail) { max = seq_meta.tail; } + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (cache.kv.head > cache.kv.used + 2*n_tokens) { + cache.kv.head = 0; } - // gather and re-order - for (uint32_t s = 0; s < n_seqs; ++s) { - int32_t dst_id = s + min; - int32_t src_id = cache.cells[batch.seq_id[s][0]].tail; - if (dst_id != src_id) { - llama_kv_cell & dst_cell = cache.cells[dst_id]; - llama_kv_cell & src_cell = cache.cells[src_id]; + uint32_t n_tested = 0; - std::swap(dst_cell.pos, src_cell.pos); - std::swap(dst_cell.src, src_cell.src); - std::swap(dst_cell.seq_id, src_cell.seq_id); + while (true) { + if (cache.kv.head + n_tokens > kv_size) { + n_tested += kv_size - cache.kv.head; + cache.kv.head = 0; + continue; + } - // swap tails (assuming they NEVER overlap) - for (const llama_seq_id seq_id : src_cell.seq_id) { - cache.cells[seq_id].tail = src_id; - } - for (const llama_seq_id seq_id : dst_cell.seq_id) { - cache.cells[seq_id].tail = dst_id; + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + if (cache.kv.cells[cache.kv.head + i].pos >= 0) { + found = false; + cache.kv.head += i + 1; + n_tested += i + 1; + break; } } - } - - // update the pos of the used seqs - for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_pos last_pos = batch.pos[n_seq_tokens * s + n_seq_tokens - 1]; - int32_t cell_id = s + min; - llama_kv_cell & cell = cache.cells[cell_id]; - if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { - // What should happen when the pos backtracks or skips a value? - // Clearing the state mid-batch would require special-casing which isn't done. - LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", - __func__, last_pos, cell.pos, batch.seq_id[s][0], n_seq_tokens); + if (found) { + break; } - cell.pos = last_pos; - cell.seq_id.clear(); - for (int32_t j = 0; j < batch.n_seq_id[s]; ++j) { - const llama_seq_id seq_id = batch.seq_id[s][j]; - cell.seq_id.insert(seq_id); - cache.cells[seq_id].tail = cell_id; + + if (n_tested >= kv_size) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return false; } } + } - // allow getting the range of used cells, from head to head + n - cache.head = min; - cache.n = max - min + 1; + // now modification can be done, and should NOT fail + + if (rs_size > 0) { + // For recurrent state architectures (like Mamba or RWKV), + // each cache cell can store the state for a whole sequence. + // A slot should be always be contiguous. 
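The recurrent slot search that follows first compacts the 0/1 freeable mask into a list of cell ids and, while doing so, tracks the smallest contiguous window [min_head, min_head + min_n) containing at least n_seqs freeable cells. A minimal sketch of that window search on a plain mask:

    #include <cstdint>
    #include <vector>

    struct slot_window { uint32_t head; uint32_t n; bool ok; };

    // smallest window covering n_seqs freeable cells (1s in the mask)
    static slot_window smallest_window(const std::vector<uint8_t> & freeable, uint32_t n_seqs) {
        if (n_seqs == 0) { return {0, 0, true}; }
        std::vector<uint32_t> ids; // compacted list of freeable cell ids
        slot_window best = {0, (uint32_t) freeable.size(), false};
        for (uint32_t i = 0; i < (uint32_t) freeable.size(); ++i) {
            if (!freeable[i]) { continue; }
            ids.push_back(i);
            if (ids.size() >= n_seqs) {
                const uint32_t head = ids[ids.size() - n_seqs];
                const uint32_t n    = i - head + 1;
                if (!best.ok || n < best.n) {
                    best = {head, n, true};
                    if (n == n_seqs) { break; } // cannot get any smaller
                }
            }
        }
        return best;
    }

Once the window is known, the same freeable buffer is re-used to coalesce the free cells to the front of the window with swap_cells and then to order them to match the batch's sequence order.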
+ + uint32_t min_head = 0; + uint32_t min_n = cache.rs.size; + uint32_t min_free = 0; + + // compact the freeable cell list + // e.g. 0,1,0,0,1,1,0,1,0,1 -> 1,4,5,7,9 + // while also finding the smallest cell range for the slot + { + uint32_t next_free = 0; + for (size_t i = 0; i < cache.rs.freeable.size(); ++i) { + if (cache.rs.freeable[i]) { + cache.rs.freeable[next_free] = i; + next_free += 1; + + if (next_free >= n_seqs) { + uint32_t head = cache.rs.freeable[next_free - n_seqs]; + // i is the last seen freeable cell id + uint32_t n = i - head + 1; + // keep the first smallest big enough slot + if (n < min_n) { + min_free = next_free - n_seqs; + min_head = head; + min_n = n; + if (n == n_seqs) { + // it's the smallest it can be + break; + } + } + } + } + } + } // sanity check - return cache.n >= n_seqs; - } - // otherwise, one cell per token. + GGML_ASSERT(min_head + min_n <= cache.rs.size); - if (n_tokens > cache.size) { - LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size); - return false; - } + // keep only the necessary range + cache.rs.freeable.resize(min_free + n_seqs); + cache.rs.freeable.erase(cache.rs.freeable.begin(), cache.rs.freeable.begin() + min_free); + GGML_ASSERT(cache.rs.freeable.size() == n_seqs); + GGML_ASSERT(min_n >= n_seqs); + cache.rs.freeable.resize(min_n); - uint32_t n_tested = 0; + // expand the free list + // e.g. 2,4,5,8 -> 1,0,1,1,0,0,1 + for (uint32_t i = n_seqs; i-- > 0;) { + uint32_t dst = cache.rs.freeable[i] - min_head; + if (dst != i) { + cache.rs.freeable[i] = 0; + } + GGML_ASSERT(dst >= i); + cache.rs.freeable[dst] = 1; + } - while (true) { - if (cache.head + n_tokens > cache.size) { - n_tested += cache.size - cache.head; - cache.head = 0; - continue; + // coalesce the free cells together + // e.g. 1,0,1,1,0,0,1 -> 1,1,1,1,0,0,0 + // or 1,0,1,1,1,1 -> 1,1,1,1,1,0 + { + uint32_t top_free = min_n - 1; + for (uint32_t i = min_n; i-- > 1;) { + uint32_t is_free = cache.rs.freeable[i]; + if (!is_free) { + GGML_ASSERT(top_free > i); + cache.rs.swap_cells(min_head + i, min_head + top_free); + std::swap(cache.rs.freeable[i], cache.rs.freeable[top_free]); + // the previous one has to be free, + // otherwise it would already have been swapped. + top_free -= 1; + } + // stop early if all freeable cells have already been put at the beginning + if (top_free < n_seqs) { break; } + } } - bool found = true; - for (uint32_t i = 0; i < n_tokens; i++) { - if (cache.cells[cache.head + i].pos >= 0) { - found = false; - cache.head += i + 1; - n_tested += i + 1; - break; + // order the re-used cells identically to their batch order + // (and clear the non-reused cells) + { + for (uint32_t i = 0; i < n_seqs; ++i) { + // ignore the already-swapped cells + if (cache.rs.freeable[i]) { + llama_rs_cell & cell = cache.rs.cells[min_head + i]; + if (!cell.is_empty()) { + if (cell.tail_rc == 0) { + cache.rs.clear_cell(cell); + } else { + // Find the seq_id of the first tail of this cell + llama_seq_id seq_id = -1; + for (llama_rs_seq_node & seq_node : cell.seq_nodes) { + if (seq_node.is_tail()) { + seq_id = seq_node.seq_id; + break; + } + } + GGML_ASSERT(seq_id != -1); + + // Which seq_id of the batch is it? 
+ int32_t nth_seq_id = -1; + for (int32_t s = 0; (uint32_t) s < n_seqs; ++s) { + if (seq_id == batch.seq_id[s][0]) { + nth_seq_id = s; + break; + } + } + GGML_ASSERT(nth_seq_id != -1); + + cache.rs.swap_cells(min_head + i, min_head + nth_seq_id); + cache.rs.freeable[i] = 0; + std::swap(cache.rs.freeable[i], cache.rs.freeable[nth_seq_id]); + i -= 1; // check this cell again, now that it was swapped + } + } + } } } - if (found) { - break; + // reserve + { + for (uint32_t i = 0; i < n_seqs; ++i) { + uint32_t i_cell = min_head + i; + int32_t n_seq_id = batch.n_seq_id[i]; + llama_pos end_pos = batch.pos[(i * n_seq_tokens) + n_seq_tokens - 1]; + // set the pos with the first seq_id + cache.rs.insert_seq_tail_to_cell_id(i_cell, batch.seq_id[i][0], end_pos); + // insert the rest of the seq_ids by re-using the cell's pos + for (int j = 1; j < n_seq_id; ++j) { + cache.rs.insert_seq_tail_to_cell_id(i_cell, batch.seq_id[i][j]); + } + } } - if (n_tested >= cache.size) { - //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return false; - } + // allow getting the range of used cells, from head to head + n + cache.rs.head = min_head; + cache.rs.n = min_n; } - for (uint32_t s = 0; s < n_seqs; s++) { - for (uint32_t i = 0; i < n_seq_tokens; ++i) { - uint32_t k = s*n_seq_tokens + i; - cache.cells[cache.head + k].pos = batch.pos[k]; + if (kv_size > 0) { + for (uint32_t s = 0; s < n_seqs; s++) { + for (uint32_t i = 0; i < n_seq_tokens; ++i) { + uint32_t k = s*n_seq_tokens + i; + cache.kv.cells[cache.kv.head + k].pos = batch.pos[k]; - for (int32_t j = 0; j < batch.n_seq_id[s]; j++) { - cache.cells[cache.head + k].seq_id.insert(batch.seq_id[s][j]); + for (int32_t j = 0; j < batch.n_seq_id[s]; j++) { + cache.kv.cells[cache.kv.head + k].seq_id.insert(batch.seq_id[s][j]); + } } } - } - cache.used += n_tokens; + cache.kv.used += n_tokens; + } return true; } -// find how many cells are currently in use +// find how many KV cells are currently in use static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { for (uint32_t i = cache.size; i > 0; --i) { const llama_kv_cell & cell = cache.cells[i - 1]; @@ -3771,248 +4575,395 @@ static uint32_t llama_kv_cache_cell_max(const struct llama_kv_cache & cache) { return 0; } -static void llama_kv_cache_clear(struct llama_kv_cache & cache) { - for (int32_t i = 0; i < (int32_t) cache.size; ++i) { - cache.cells[i].pos = -1; - cache.cells[i].seq_id.clear(); - cache.cells[i].src = -1; - cache.cells[i].tail = -1; +// find how many recurrent state cells are currently in use +static uint32_t llama_rs_cache_cell_max(const struct llama_rs_cache & cache) { + for (uint32_t i = cache.size; i > 0; --i) { + const llama_rs_cell & cell = cache.cells[i - 1]; + + if (cell.pos >= 0 && !cell.is_empty()) { + return i; + } } - cache.head = 0; - cache.used = 0; + return 0; +} + +static void llama_past_clear(struct llama_past & cache) { + if (cache.kv.size > 0) { + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + kv_cell.pos = -1; + kv_cell.delta = 0; + kv_cell.seq_id.clear(); + } + cache.kv.has_shift = false; + cache.kv.do_defrag = false; + cache.kv.head = 0; + cache.kv.used = 0; + } + if (cache.rs.size > 0) { + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + rs_cell.pos = -1; + rs_cell.src = -1; + rs_cell.prev = -1; + rs_cell.tail_rc = 0; + rs_cell.seq_nodes.clear(); + } + cache.rs.head = 0; + cache.rs.used = 0; + cache.rs.n_seqs = 0; + 
cache.rs.n_shared_tail_cells = 0; + cache.rs.seq_tails.clear(); + cache.rs.seq_tails.resize(cache.rs.size); + } for (auto & buf : cache.bufs) { ggml_backend_buffer_clear(buf, 0); } } -static bool llama_kv_cache_seq_rm( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1) { - uint32_t new_head = cache.size; +static llama_pos llama_past_seq_rm( + struct llama_past & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1) { + + if (p0 < 0) { p0 = 0; } + if (p1 < 0) { p1 = std::numeric_limits::max(); } - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); + llama_pos n_past = p0; - // models like Mamba or RWKV can't have a state partially erased - if (cache.recurrent) { - if (seq_id >= (int64_t) cache.size) { + if (cache.rs.size > 0) { + if (seq_id >= (int64_t) cache.rs.size) { // could be fatal - return false; - } - if (0 <= seq_id) { - int32_t & tail_id = cache.cells[seq_id].tail; - if (tail_id >= 0) { - const llama_kv_cell & cell = cache.cells[tail_id]; - // partial intersection is invalid - if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { - return false; - } - if (p0 <= cell.pos && p1 < cell.pos) { - tail_id = -1; + return n_past; + } + uint32_t new_head = cache.rs.size; + // adjust p0 and p1 according to the states found + llama_pos new_p0 = 0; + llama_pos new_p1 = std::numeric_limits::max(); + + // partial seq_id removal has to happen from the tail + llama_rs_seq_meta & seq = cache.rs.seq_tails[seq_id]; + int32_t cell_id = seq.tail; + + while (cell_id >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + // copy before the cell is potentially changed + int32_t prev_id = rs_cell.prev; + if (rs_cell.pos >= p1 && rs_cell.seq_nodes.size() > 1) { + // non-tail removal for shared cells can only be done when clearing a cell + // (i.e. when the next cell's link to the previous cell can be safely changed) + p1 = rs_cell.pos + 1; + } + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id); + // if the node isn't found, the sequence tree is malformed + GGML_ASSERT(node_iter != rs_cell.seq_nodes.end()); + cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); + // get the smallest removed cell id + if (new_head > (uint32_t) cell_id) { new_head = cell_id; } + } else { + // one more than the biggest non-removed cell of this sequence + if (rs_cell.pos >= n_past) { n_past = rs_cell.pos + 1; } + + if (rs_cell.pos < p0) { + // new_p0 should be right after the max pos in the states before p0 + if (rs_cell.pos >= new_p0) { new_p0 = rs_cell.pos + 1; } + } else { // (rs_cell.pos >= p1) + // new_p1 should be the min pos in the states after p1 + if (rs_cell.pos < new_p1) { new_p1 = rs_cell.pos; } } } - } else { - // seq_id is negative, then the range should include everything or nothing - if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { - return false; - } + cell_id = prev_id; + } + p0 = new_p0; + p1 = new_p1; + + // If we freed up a slot, set head to it so searching can start there. 
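llama_past_seq_rm now reports a position instead of a bool: n_past, one past the largest position still stored for the sequence (and at least p0), so the caller knows from where it has to re-evaluate. A simplified model of that return-value contract, over a flat list of stored positions rather than the real cell structures:

    #include <cstdint>
    #include <limits>
    #include <utility>
    #include <vector>

    using pos_t = int32_t; // stand-in for llama_pos

    // erase positions in [p0, p1) and return n_past
    static pos_t seq_rm_sketch(std::vector<pos_t> & positions, pos_t p0, pos_t p1) {
        if (p0 < 0) { p0 = 0; }
        if (p1 < 0) { p1 = std::numeric_limits<pos_t>::max(); }
        pos_t n_past = p0;
        std::vector<pos_t> kept;
        for (pos_t pos : positions) {
            if (pos >= p0 && pos < p1) { continue; }  // removed
            if (pos >= n_past) { n_past = pos + 1; }  // one past the largest surviving position
            kept.push_back(pos);
        }
        positions = std::move(kept);
        return n_past;
    }

The recurrent branch above additionally widens or narrows [p0, p1) to whole states, since a state cell cannot be partially erased.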
+ if (new_head != cache.rs.size && new_head < cache.rs.head) { + cache.rs.head = new_head; } } - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - if (seq_id < 0) { - cache.cells[i].seq_id.clear(); - } else if (cache.cells[i].has_seq_id(seq_id)) { - cache.cells[i].seq_id.erase(seq_id); - } else { - continue; - } - if (cache.cells[i].is_empty()) { - // keep count of the number of used cells - if (cache.cells[i].pos >= 0) cache.used--; + if (cache.kv.size > 0) { + uint32_t new_head = cache.kv.size; + + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; - cache.cells[i].pos = -1; - cache.cells[i].src = -1; - if (new_head == cache.size) new_head = i; + if (seq_id < 0 || kv_cell.has_seq_id(seq_id)) { + if (kv_cell.pos >= p0 && kv_cell.pos < p1) { + if (seq_id < 0) { + kv_cell.seq_id.clear(); + } else { // (kv_cell.has_seq_id(seq_id)) + kv_cell.seq_id.erase(seq_id); + } + if (kv_cell.is_empty()) { + // keep count of the number of used cells + if (kv_cell.pos >= 0) { cache.kv.used--; } + + kv_cell.pos = -1; + if (new_head == cache.kv.size) { new_head = i; } + } + } else if (kv_cell.pos >= n_past) { + n_past = kv_cell.pos + 1; + } } } - } - // If we freed up a slot, set head to it so searching can start there. - if (new_head != cache.size && new_head < cache.head) cache.head = new_head; + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.kv.size && new_head < cache.kv.head) { + cache.kv.head = new_head; + } + } - return true; + return n_past; } -static void llama_kv_cache_seq_cp( - struct llama_kv_cache & cache, - llama_seq_id seq_id_src, - llama_seq_id seq_id_dst, - llama_pos p0, - llama_pos p1) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); - - if (cache.recurrent) { - if ((uint32_t) seq_id_dst < cache.size && (uint32_t) seq_id_src < cache.size) { - llama_kv_cell & tail_src = cache.cells[seq_id_src]; - llama_kv_cell & tail_dst = cache.cells[seq_id_dst]; - if (tail_dst.tail >= 0) { - // clear destination seq_id if it wasn't empty - llama_kv_cell & cell_dst = cache.cells[tail_dst.tail]; - - cell_dst.seq_id.erase(seq_id_dst); - tail_dst.tail = -1; - if (cell_dst.seq_id.empty()) { - cell_dst.pos = -1; - cell_dst.delta = -1; - cell_dst.src = -1; - cache.used -= 1; +static llama_pos llama_past_seq_cp( + struct llama_past & cache, + llama_seq_id seq_id_src, + llama_seq_id seq_id_dst, + llama_pos p0, + llama_pos p1) { + if (p0 < 0) { p0 = 0; } + if (p1 < 0) { p1 = std::numeric_limits::max(); } + + // TODO: in practice this seems to be only used on whole sequences; + // should partial sequence copy support be removed? + // TODO: What if the destination sequence is not empty? 
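For recurrent states a copy cannot start mid-sequence, so the branch that follows forces p0 to 0, locates the root cell of the source sequence (the one with no prev), and then walks next_cell links forward, attaching the destination seq_id as a tail to each visited cell. A minimal sketch of that forward walk on a simplified chain type:

    #include <cstdint>
    #include <vector>

    // simplified forward/backward-linked state cell (illustrative only)
    struct state_cell {
        int32_t pos  = -1;
        int32_t prev = -1;
        int32_t next = -1;
    };

    // visit the cells of one chain from its root to its tail, in order
    static std::vector<uint32_t> chain_cells(const std::vector<state_cell> & cells, int32_t root) {
        std::vector<uint32_t> out;
        for (int32_t c = root; c >= 0; c = cells[c].next) {
            out.push_back((uint32_t) c);
        }
        return out;
    }

Because each visited cell only gains an extra seq node for the destination, the copy shares the stored states instead of duplicating any tensor data.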
+ + llama_pos n_past = 0; + + if (cache.rs.size > 0) { + // have to start from the beginning for recurrent models + p0 = 0; + if ((uint32_t) seq_id_dst < cache.rs.size && (uint32_t) seq_id_src < cache.rs.size) { + int32_t src_head = -1; + int32_t head_pos = p1; + int32_t src_next = -1; + // find the start of the sequence + for (uint32_t i = 0; i < cache.rs.size; ++i) { + llama_rs_cell & rs_cell = cache.rs.cells[i]; + if (!rs_cell.is_empty() && rs_cell.prev < 0) { + auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id_src); + if (seq_node != rs_cell.seq_nodes.end()) { + src_head = i; + head_pos = rs_cell.pos; + src_next = seq_node->next_cell; + break; + } } } - if (tail_src.tail >= 0) { - llama_kv_cell & cell_src = cache.cells[tail_src.tail]; - - cell_src.seq_id.insert(seq_id_dst); - tail_dst.tail = tail_src.tail; + while (src_head >= 0 && head_pos < p1) { + cache.rs.insert_seq_tail_to_cell_id(src_head, seq_id_dst); + src_head = src_next; + if (head_pos >= n_past) { n_past = head_pos + 1; } + if (src_next >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[src_next]; + auto seq_node = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), seq_id_src); + head_pos = rs_cell.pos; + // it should always be found if the seq tree is valid + GGML_ASSERT(seq_node != rs_cell.seq_nodes.end()); + src_next = seq_node->next_cell; + } } } + p1 = n_past; + } - return; + if (cache.kv.size > 0) { + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (kv_cell.pos >= p0 && kv_cell.pos < p1 && kv_cell.has_seq_id(seq_id_src)) { + kv_cell.seq_id.insert(seq_id_dst); + if (kv_cell.pos >= n_past) { n_past = kv_cell.pos + 1; } + } + } } - // otherwise, this is the KV cache of a Transformer-like model - cache.head = 0; + return n_past; +} + +static void llama_past_seq_keep(struct llama_past & cache, llama_seq_id seq_id) { + if (cache.rs.size > 0) { + uint32_t new_head = cache.rs.size; + + // partial seq_id removal has to happen from the tail(s) + for (uint32_t i = 0; i < cache.rs.seq_tails.size(); ++i) { + if (i == (uint32_t) seq_id) { continue; } + llama_rs_seq_meta & seq = cache.rs.seq_tails[i]; + int32_t cell_id = seq.tail; + while (cell_id >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + auto node_iter = std::find(rs_cell.seq_nodes.begin(), rs_cell.seq_nodes.end(), i); + GGML_ASSERT(node_iter != rs_cell.seq_nodes.end()); + cache.rs.remove_seq_node_from_cell(rs_cell, node_iter); + cell_id = rs_cell.prev; + if (new_head > (uint32_t) cell_id && rs_cell.is_empty()) { + new_head = cell_id; + } + } + } - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.cells[i].seq_id.insert(seq_id_dst); + // If we freed up a slot, set head to it so searching can start there. 
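Both this keep path and the removal path walk a sequence from its tail towards its root while erasing nodes, which only stays correct if the predecessor index is read before the node is removed, because removal can reset a cell's prev once the cell becomes empty. A small sketch of that backward-walk-with-removal pattern, following the copy-prev-first approach used in llama_past_seq_rm:

    #include <cstdint>
    #include <functional>
    #include <vector>

    struct walk_cell {
        int32_t prev = -1;
        // payload omitted
    };

    // walk a chain from its tail to its root, applying erase() to each cell;
    // prev is copied *before* erasing, since erasing may invalidate it
    static void erase_chain_from_tail(std::vector<walk_cell> & cells, int32_t tail,
                                      const std::function<void(int32_t)> & erase) {
        for (int32_t cell_id = tail; cell_id >= 0;) {
            const int32_t prev_id = cells[cell_id].prev; // copy before the cell is potentially changed
            erase(cell_id);
            cell_id = prev_id;
        }
    }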
+ if (new_head != cache.rs.size && new_head < cache.rs.head) { + cache.rs.head = new_head; } } -} -static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { - uint32_t new_head = cache.size; + if (cache.kv.size > 0) { + uint32_t new_head = cache.kv.size; - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.recurrent && (llama_seq_id) i != seq_id) { - cache.cells[i].tail = -1; + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (!kv_cell.has_seq_id(seq_id)) { + if (kv_cell.pos >= 0) { cache.kv.used--; } + kv_cell.pos = -1; + kv_cell.seq_id.clear(); + if (new_head == cache.kv.size) { new_head = i; } + } else { + kv_cell.seq_id.clear(); + kv_cell.seq_id.insert(seq_id); + } } - if (!cache.cells[i].has_seq_id(seq_id)) { - if (cache.cells[i].pos >= 0) cache.used--; - cache.cells[i].pos = -1; - cache.cells[i].src = -1; - cache.cells[i].seq_id.clear(); - if (new_head == cache.size) new_head = i; - } else { - cache.cells[i].seq_id.clear(); - cache.cells[i].seq_id.insert(seq_id); + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.kv.size && new_head < cache.kv.head) { + cache.kv.head = new_head; } } - - // If we freed up a slot, set head to it so searching can start there. - if (new_head != cache.size && new_head < cache.head) cache.head = new_head; } -static void llama_kv_cache_seq_add( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - llama_pos delta) { - uint32_t new_head = cache.size; - - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); - // If there is no range then return early to avoid looping over the cache. - if (p0 == p1) return; - - if (cache.recurrent) { - // for Mamba-like or RWKV models, only the pos needs to be shifted - if (0 <= seq_id && seq_id < (int64_t) cache.size) { - const int32_t tail_id = cache.cells[seq_id].tail; - if (tail_id >= 0) { - llama_kv_cell & cell = cache.cells[tail_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos += delta; +static void llama_past_seq_add( + struct llama_past & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + llama_pos delta) { + + if (p0 < 0) { p0 = 0; } + if (p1 < 0) { p1 = std::numeric_limits::max(); } + + if (cache.rs.size > 0) { + // for Mamba-like or RKWV models, only the pos needs to be shifted + auto & seq = cache.rs.seq_tails[seq_id]; + // follow the sequence from its tail + int32_t cell_id = seq.tail; + uint32_t new_head = cache.rs.size; + while (cell_id >= 0) { + GGML_ASSERT((uint32_t) cell_id < cache.rs.size); + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + cell_id = rs_cell.prev; + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + rs_cell.pos += delta; + if (rs_cell.pos < 0) { + // NOTE: this affects the other sequences which share the cell + cache.rs.clear_cell(rs_cell); + if (new_head > (uint32_t) cell_id) { + new_head = cell_id; + } } } } - return; + + // If we freed up a slot, set head to it so searching can start there. + // Otherwise we just start the next search from the beginning. + cache.rs.head = new_head != cache.rs.size ? 
new_head : 0; } - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.has_shift = true; - cache.cells[i].pos += delta; - cache.cells[i].delta += delta; + if (cache.kv.size > 0) { + uint32_t new_head = cache.kv.size; - if (cache.cells[i].pos < 0) { - if (!cache.cells[i].is_empty()) { - cache.used--; - } - cache.cells[i].pos = -1; - cache.cells[i].seq_id.clear(); - if (new_head == cache.size) { - new_head = i; + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (kv_cell.has_seq_id(seq_id)) { + if (kv_cell.pos >= p0 && kv_cell.pos < p1) { + cache.kv.has_shift = true; + kv_cell.pos += delta; + kv_cell.delta += delta; + + if (kv_cell.pos < 0) { + if (!kv_cell.is_empty()) { + cache.kv.used--; + } + kv_cell.pos = -1; + kv_cell.seq_id.clear(); + if (new_head == cache.kv.size) { + new_head = i; + } + } } } } - } - // If we freed up a slot, set head to it so searching can start there. - // Otherwise we just start the next search from the beginning. - cache.head = new_head != cache.size ? new_head : 0; + // If we freed up a slot, set head to it so searching can start there. + // Otherwise we just start the next search from the beginning. + cache.kv.head = new_head != cache.kv.size ? new_head : 0; + } } -static void llama_kv_cache_seq_div( - struct llama_kv_cache & cache, - llama_seq_id seq_id, - llama_pos p0, - llama_pos p1, - int d) { - if (p0 < 0) p0 = 0; - if (p1 < 0) p1 = std::numeric_limits::max(); - // If there is no range then return early to avoid looping over the cache. - if (p0 == p1) return; - - if (cache.recurrent) { +static void llama_past_seq_div( + struct llama_past & cache, + llama_seq_id seq_id, + llama_pos p0, + llama_pos p1, + int d) { + if (p0 < 0) { p0 = 0; } + if (p1 < 0) { p1 = std::numeric_limits::max(); } + + if (cache.rs.size > 0) { // for Mamba-like or RWKV models, only the pos needs to be changed - if (0 <= seq_id && seq_id < (int64_t) cache.size) { - const int32_t tail_id = cache.cells[seq_id].tail; - if (tail_id >= 0) { - llama_kv_cell & cell = cache.cells[tail_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos /= d; - } + auto & seq = cache.rs.seq_tails[seq_id]; + int32_t cell_id = seq.tail; + while (cell_id >= 0) { + GGML_ASSERT((uint32_t) cell_id < cache.rs.size); + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + if (rs_cell.pos >= p0 && rs_cell.pos < p1) { + rs_cell.pos /= d; } + cell_id = rs_cell.prev; } - return; } - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { - cache.has_shift = true; + if (cache.kv.size > 0) { + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (kv_cell.has_seq_id(seq_id)) { + if (kv_cell.pos >= p0 && kv_cell.pos < p1) { + cache.kv.has_shift = true; - { - llama_pos p_old = cache.cells[i].pos; - cache.cells[i].pos /= d; - cache.cells[i].delta += cache.cells[i].pos - p_old; + { + llama_pos p_old = kv_cell.pos; + kv_cell.pos /= d; + kv_cell.delta += kv_cell.pos - p_old; + } + } } } } } -static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama_seq_id seq_id) { - llama_pos result = 0; +static llama_pos llama_past_seq_pos_max(struct llama_past & cache, llama_seq_id seq_id) { + llama_pos result = -1; - for (uint32_t i = 0; i < cache.size; ++i) { - if (cache.cells[i].has_seq_id(seq_id)) { - result = 
std::max(result, cache.cells[i].pos); + if (cache.rs.size > 0) { + int32_t cell_id = cache.rs.seq_tails[seq_id].tail; + if (cell_id >= 0) { + llama_rs_cell & rs_cell = cache.rs.cells[cell_id]; + result = rs_cell.pos; + } + // exit early + return result; + } + + if (cache.kv.size > 0) { + for (uint32_t i = 0; i < cache.kv.size; ++i) { + llama_kv_cell & kv_cell = cache.kv.cells[i]; + if (kv_cell.has_seq_id(seq_id)) { + result = std::max(result, kv_cell.pos); + } } } @@ -4020,9 +4971,7 @@ static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama } static void llama_kv_cache_defrag(struct llama_kv_cache & cache) { - if (!cache.recurrent) { - cache.do_defrag = true; - } + cache.do_defrag = true; } static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams) { @@ -5699,6 +6648,22 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_JAMBA: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + // TODO: Jamba layers are a bit heterogenous, so naming this is hard. + case 12: // 900M 8x???M + case 32: // 51B 16x?B + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -7821,6 +8786,117 @@ static bool llm_load_tensors( layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); } } break; + case LLM_ARCH_JAMBA: + { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t dt_rank = hparams.ssm_dt_rank; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + const int64_t n_head_kv = hparams.n_head_kv(i); + const int64_t n_embd_gqa = hparams.n_embd_v_gqa(i); + + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + // norm + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + if (n_head_kv == 0) { + // Mamba layer + layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}); + + layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}); + layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}); + + layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}); + 
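Jamba interleaves Mamba and attention blocks, and the tensor loader here tells them apart purely through the per-layer KV head count: n_head_kv(i) == 0 selects the SSM tensor set, anything else selects Q/K/V/O. A tiny sketch of that per-layer classification over an assumed head-count vector:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    enum class layer_kind { mamba, attention };

    static std::vector<layer_kind> classify_layers(const std::vector<uint32_t> & n_head_kv_per_layer) {
        std::vector<layer_kind> kinds;
        kinds.reserve(n_head_kv_per_layer.size());
        for (uint32_t n_head_kv : n_head_kv_per_layer) {
            kinds.push_back(n_head_kv == 0 ? layer_kind::mamba : layer_kind::attention);
        }
        return kinds;
    }

    int main() {
        // illustrative interleaving only, not the real Jamba layout
        const std::vector<uint32_t> n_head_kv = {0, 0, 0, 0, 8, 0, 0, 0};
        const auto kinds = classify_layers(n_head_kv);
        for (size_t i = 0; i < kinds.size(); ++i) {
            printf("layer %zu: %s\n", i, kinds[i] == layer_kind::mamba ? "mamba" : "attention");
        }
        return 0;
    }

The same vector also drives the cache: layers with zero KV heads contribute essentially nothing to the KV buffers allocated in llama_past_init.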
+ layer.ssm_dt_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT_NORM, "weight", i), {dt_rank}); + + layer.ssm_dt = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_rank, d_inner}); + layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}); + + layer.ssm_b_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_B_NORM, "weight", i), {d_state}); + layer.ssm_c_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_C_NORM, "weight", i), {d_state}); + + // no "weight" suffix for these + layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}); + layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner}); + + // out_proj + layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); + + layer.wq = nullptr; + layer.wk = nullptr; + layer.wv = nullptr; + layer.wo = nullptr; + + } else { + // Attention layers + + layer.wq = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}); + layer.wk = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}); + layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}); + layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}); + + layer.ssm_in = nullptr; + layer.ssm_conv1d = nullptr; + layer.ssm_conv1d_b = nullptr; + layer.ssm_x = nullptr; + layer.ssm_dt_norm = nullptr; + layer.ssm_dt = nullptr; + layer.ssm_dt_b = nullptr; + layer.ssm_b_norm = nullptr; + layer.ssm_c_norm = nullptr; + layer.ssm_a = nullptr; + layer.ssm_d = nullptr; + layer.ssm_out = nullptr; + } + + layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}); + + layer.ffn_gate_inp = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, llama_model_loader::TENSOR_NOT_REQUIRED); + + if (layer.ffn_gate_inp) { + // MoE + layer.ffn_gate_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + layer.ffn_down_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff, n_embd, n_expert}); + layer.ffn_up_exps = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}); + + layer.ffn_gate = nullptr; + layer.ffn_down = nullptr; + layer.ffn_up = nullptr; + } else { + // FFN (no MoE) + layer.ffn_gate = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}); + layer.ffn_down = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}); + layer.ffn_up = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}); + + layer.ffn_gate_exps = nullptr; + layer.ffn_down_exps = nullptr; + layer.ffn_up_exps = nullptr; + } + } + } break; case LLM_ARCH_XVERSE: { model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); @@ -9225,15 +10301,15 @@ static struct ggml_tensor * llm_build_copy_mask_state( struct ggml_tensor * state_copy, struct ggml_tensor * state_mask, int32_t n_state, - int32_t kv_size, - int32_t kv_head, - int32_t n_kv, + int32_t rs_size, + int32_t rs_head, + int32_t n_rs, int32_t n_seqs) { - struct ggml_tensor * states = ggml_reshape_2d(ctx, s, n_state, kv_size); + struct ggml_tensor * states = ggml_reshape_2d(ctx, s, n_state, rs_size); // copy states - // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + 
n_kv - // this shrinks the tensors's ne[1] to n_kv + // NOTE: assuming the copy destinations are ALL contained between rs_head and rs_head + n_rs + // this shrinks the tensors's ne[1] to n_rs states = ggml_get_rows(ctx, states, state_copy); // clear states of sequences which are starting at the beginning of this batch @@ -9243,8 +10319,8 @@ static struct ggml_tensor * llm_build_copy_mask_state( // copy states which won't be changed further (between n_seqs and n_rs) ggml_build_forward_expand(graph, ggml_cpy(ctx, - ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*ggml_element_size(states)), - ggml_view_1d(ctx, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s)))); + ggml_view_1d(ctx, states, n_state*(n_rs - n_seqs), n_seqs*n_state*ggml_element_size(states)), + ggml_view_1d(ctx, s, n_state*(n_rs - n_seqs), (rs_head + n_seqs)*n_state*ggml_element_size(s)))); // the part of the states that will be used and modified return ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0); @@ -9259,13 +10335,16 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_tensor * cur, struct ggml_tensor * state_copy, struct ggml_tensor * state_mask, - int32_t kv_head, - int32_t n_kv, + struct ggml_tensor * w_dt_norm, + struct ggml_tensor * w_b_norm, + struct ggml_tensor * w_c_norm, + int32_t rs_head, + int32_t n_rs, const llm_build_cb & cb, int il) { const llama_model & model = lctx.model; const llama_hparams & hparams = model.hparams; - const llama_kv_cache & kv = lctx.kv_self; + const llama_rs_cache & rs = lctx.cache.rs; const int64_t d_conv = hparams.ssm_d_conv; const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; @@ -9273,8 +10352,6 @@ static struct ggml_tensor * llm_build_mamba( const int64_t n_seqs = batch.n_seqs; // Some variants of Mamba arch (e.g. 
FalconMamba do apply layer norm on B and Dt layers) const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; - // Use the same RMS norm as the final layer norm - const float norm_rms_eps = hparams.f_norm_rms_eps; const int64_t n_seq_tokens = batch.n_seq_tokens; @@ -9282,17 +10359,16 @@ static struct ggml_tensor * llm_build_mamba( GGML_ASSERT(batch.equal_seqs); GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs); - struct ggml_tensor * conv_states_all = kv.k_l[il]; - struct ggml_tensor * ssm_states_all = kv.v_l[il]; + struct ggml_tensor * conv_states_all = rs.r_l[il]; + struct ggml_tensor * ssm_states_all = rs.s_l[il]; - // (ab)using the KV cache to store the states struct ggml_tensor * conv = llm_build_copy_mask_state(ctx, graph, conv_states_all, state_copy, state_mask, - hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs); + hparams.n_embd_r(il), rs.size, rs_head, n_rs, n_seqs); conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner, n_seqs); struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, graph, ssm_states_all, state_copy, state_mask, - hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); + hparams.n_embd_s(il), rs.size, rs_head, n_rs, n_seqs); ssm = ggml_reshape_3d(ctx, ssm, d_state, d_inner, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} @@ -9317,7 +10393,7 @@ static struct ggml_tensor * llm_build_mamba( ggml_cpy(ctx, last_conv, ggml_view_1d(ctx, conv_states_all, (d_conv - 1)*(d_inner)*(n_seqs), - kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + rs_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); // 1D convolution // The equivalent is to make a self-overlapping view of conv_x @@ -9328,7 +10404,34 @@ static struct ggml_tensor * llm_build_mamba( // then permute away the ne[0] dimension, // and then you're left with the resulting x tensor. // For simultaneous sequences, all sequences need to have the same length. + + // TODO: remove unused implementations +#if 0 + // For some reason, im2col expects a F16 kernel, but doesn't even read from it. + // TODO: make im2col accept F32 kernels to directly pass ssm_conv1d to it. 
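Both the variants kept here for reference and the ggml_ssm_conv path used in practice compute the same thing: a depthwise causal 1D convolution, i.e. for every inner channel a dot product of its d_conv most recent inputs (the first d_conv - 1 of which come from the saved conv state) with that channel's kernel. A plain reference implementation of that semantic for a single sequence, independent of ggml (assumed layout: conv_x is d_inner rows of d_conv - 1 + n_tokens values, kernel is d_inner rows of d_conv weights):

    #include <cstdint>
    #include <vector>

    // conv_x[c][t] : channel c, padded position t in [0, d_conv - 1 + n_tokens)
    // kernel[c][k] : per-channel weights, k in [0, d_conv)
    // returns y[c][t] for t in [0, n_tokens)
    static std::vector<std::vector<float>> ssm_conv_ref(
            const std::vector<std::vector<float>> & conv_x,
            const std::vector<std::vector<float>> & kernel,
            int64_t n_tokens) {
        const int64_t d_inner = (int64_t) kernel.size();
        const int64_t d_conv  = (int64_t) kernel[0].size();
        std::vector<std::vector<float>> y(d_inner, std::vector<float>(n_tokens, 0.0f));
        for (int64_t c = 0; c < d_inner; ++c) {
            for (int64_t t = 0; t < n_tokens; ++t) {
                float acc = 0.0f;
                for (int64_t k = 0; k < d_conv; ++k) {
                    acc += conv_x[c][t + k] * kernel[c][k]; // window ends at the current token
                }
                y[c][t] = acc;
            }
        }
        return y;
    }

The self-overlapping-view trick and the im2col variant are just different ways of materializing the same sliding window so that the multiply-accumulate can be expressed with existing ggml ops.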
+ // => { d_conv * d_inner, n_seq_tokens, n_seqs} + x = ggml_im2col(ctx, + ggml_new_tensor_2d(ctx, GGML_TYPE_F16, d_conv, d_inner), + conv_x, 1, 0, 0, 0, 1, 0, false, GGML_TYPE_F32); + + #if 0 + // TODO: CUDA, SYCL, and Vulkan don't (yet) support broadcasting the ne[3] dimension on MUL_MAT + x = ggml_reshape_4d(ctx, x, d_conv, 1, d_inner, n_seq_tokens * n_seqs); + + // => {1, 1, d_inner, n_seq_tokens * n_seqs} + x = ggml_mul_mat(ctx, ggml_reshape_3d(ctx, model.layers[il].ssm_conv1d, d_conv, 1, d_inner), x); + #else + x = ggml_reshape_4d(ctx, x, d_conv, d_inner, n_seq_tokens, n_seqs); + + // NOTE: it seems this is very slighly more performant than MUL_MAT on CPU for small row sizes + // => {1, d_inner, n_seq_tokens, n_seqs} + x = ggml_sum_rows(ctx, ggml_mul(ctx, x, model.layers[il].ssm_conv1d)); + #endif + x = ggml_reshape_3d(ctx, x, d_inner, n_seq_tokens, n_seqs); +#else + // Alternatively, this does the same as the above, but faster x = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d); +#endif // bias x = ggml_add(ctx, x, model.layers[il].ssm_conv1d_b); @@ -9345,11 +10448,12 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_tensor * B = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); struct ggml_tensor * C = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); - // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers - if (ssm_dt_b_c_rms) { - dt = ggml_rms_norm(ctx, dt, norm_rms_eps); - B = ggml_rms_norm(ctx, B, norm_rms_eps); - C = ggml_rms_norm(ctx, C, norm_rms_eps); + + // Some Mamba variants (e.g. FalconMamba, Jamba) apply RMS norm in B, C & Dt layers + if (ssm_dt_b_c_rms || (w_dt_norm && w_b_norm && w_c_norm)) { + dt = llm_build_norm(ctx, dt, hparams, w_dt_norm, NULL, LLM_NORM_RMS, cb, il); + B = llm_build_norm(ctx, B, hparams, w_b_norm, NULL, LLM_NORM_RMS, cb, il); + C = llm_build_norm(ctx, C, hparams, w_c_norm, NULL, LLM_NORM_RMS, cb, il); } // {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} @@ -9365,7 +10469,7 @@ static struct ggml_tensor * llm_build_mamba( ggml_build_forward_expand(graph, ggml_cpy(ctx, ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]), - ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, rs_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0); @@ -9558,6 +10662,7 @@ struct llm_build_context { const llama_cparams & cparams; const llama_ubatch & batch; const llama_kv_cache & kv_self; + const llama_rs_cache & rs_self; const int64_t n_embd; const int64_t n_layer; @@ -9566,9 +10671,7 @@ struct llm_build_context { const int64_t n_head; const int64_t n_head_kv; const int64_t n_embd_head_k; - const int64_t n_embd_k_gqa; const int64_t n_embd_head_v; - const int64_t n_embd_v_gqa; const int64_t n_expert; const int64_t n_expert_used; @@ -9581,11 +10684,15 @@ struct llm_build_context { const float norm_eps; const float norm_rms_eps; + const int32_t n_seqs; + const int32_t n_seq_tokens; const int32_t n_tokens; const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size) + const int32_t n_rs; const int32_t n_outputs; const int32_t n_outputs_enc; const int32_t kv_head; // index of where we store new KV data 
in the cache + const int32_t rs_head; const int32_t n_ctx_orig; const bool flash_attn; @@ -9610,7 +10717,8 @@ struct llm_build_context { hparams (model.hparams), cparams (lctx.cparams), batch (batch), - kv_self (lctx.kv_self), + kv_self (lctx.cache.kv), + rs_self (lctx.cache.rs), n_embd (hparams.n_embd), n_layer (hparams.n_layer), n_rot (hparams.n_rot), @@ -9618,9 +10726,7 @@ struct llm_build_context { n_head (hparams.n_head()), n_head_kv (hparams.n_head_kv()), n_embd_head_k (hparams.n_embd_head_k), - n_embd_k_gqa (hparams.n_embd_k_gqa()), n_embd_head_v (hparams.n_embd_head_v), - n_embd_v_gqa (hparams.n_embd_v_gqa()), n_expert (hparams.n_expert), n_expert_used (hparams.n_expert_used), freq_base (cparams.rope_freq_base), @@ -9631,11 +10737,15 @@ struct llm_build_context { beta_slow (cparams.yarn_beta_slow), norm_eps (hparams.f_norm_eps), norm_rms_eps (hparams.f_norm_rms_eps), + n_seqs (batch.n_seqs), + n_seq_tokens (batch.n_seq_tokens), n_tokens (batch.n_tokens), n_kv (worst_case ? kv_self.size : kv_self.n), - n_outputs (worst_case ? n_tokens : lctx.n_outputs), - n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd), - kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), + n_rs (worst_case ? rs_self.size : rs_self.n), + n_outputs (worst_case ? n_tokens : lctx.n_outputs), + n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd), + kv_head (worst_case ? kv_self.size - n_tokens : kv_self.head), + rs_head (worst_case ? 0 : rs_self.head), n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), pooling_type (cparams.pooling_type), @@ -9665,7 +10775,6 @@ struct llm_build_context { lctx.inp_cls = nullptr; lctx.inp_s_copy = nullptr; lctx.inp_s_mask = nullptr; - lctx.inp_s_seq = nullptr; lctx.inp_pos_bucket = nullptr; lctx.inp_embd_enc = nullptr; lctx.inp_KQ_mask_cross = nullptr; @@ -9843,14 +10952,14 @@ struct llm_build_context { } struct ggml_tensor * build_inp_s_copy() { - lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); + lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_rs); cb(lctx.inp_s_copy, "inp_s_copy", -1); ggml_set_input(lctx.inp_s_copy); return lctx.inp_s_copy; } struct ggml_tensor * build_inp_s_mask() { - lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv); + lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_rs); cb(lctx.inp_s_mask, "inp_s_mask", -1); ggml_set_input(lctx.inp_s_mask); return lctx.inp_s_mask; @@ -13202,8 +14311,94 @@ struct llm_build_context { cb(cur, "attn_norm", il); cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, - state_copy, state_mask, - kv_head, n_kv, cb, il); + state_copy, state_mask, NULL, NULL, NULL, + rs_head, n_rs, cb, il); + + if (il == n_layer - 1) { + // skip computing output for unused tokens + struct ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // residual + cur = ggml_add(ctx0, cur, inpL); + cur = lctx.cvec.apply_to(ctx0, cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + // final rmsnorm + cur = llm_build_norm(ctx0, inpL, hparams, + model.output_norm, NULL, + LLM_NORM_RMS, cb, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = ggml_mul_mat(ctx0, model.output, cur); + cb(cur, "result_output", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + + struct ggml_cgraph * build_jamba() { + struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, 
llama_model_max_nodes(model), false); + + const int64_t n_embd_head = hparams.n_embd_head_v; + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); + + struct ggml_tensor * state_copy = build_inp_s_copy(); + struct ggml_tensor * state_mask = build_inp_s_mask(); + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = build_inp_KQ_mask(); + + for (int il = 0; il < n_layer; ++il) { + const int64_t n_head_kv = hparams.n_head_kv(il); + + cur = llm_build_norm(ctx0, inpL, hparams, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "attn_norm", il); + + if (n_head_kv == 0) { + // Mamba + cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, + state_copy, state_mask, + model.layers[il].ssm_dt_norm, model.layers[il].ssm_b_norm, model.layers[il].ssm_c_norm, + rs_head, n_rs, cb, il); + } else { + // Attention + + struct ggml_tensor * Qcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wq, cur); + struct ggml_tensor * Kcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wk, cur); + struct ggml_tensor * Vcur = llm_build_lora_mm(lctx, ctx0, model.layers[il].wv, cur); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + + // No RoPE :) + + cur = llm_build_kv(ctx0, lctx, kv_self, gf, + model.layers[il].wo, NULL, + Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il); + } if (il == n_layer - 1) { // skip computing output for unused tokens @@ -13213,7 +14408,40 @@ struct llm_build_context { } // residual - cur = ggml_add(ctx0, cur, inpL); + struct ggml_tensor * ffn_inp = ggml_add(ctx0, inpL, cur); + cb(cur, "ffn_inp", il); + + cur = llm_build_norm(ctx0, ffn_inp, hparams, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, cb, il); + cb(cur, "ffn_norm", il); + + // feed-forward network + if (model.layers[il].ffn_gate_inp == nullptr) { + // FFN + cur = llm_build_ffn(ctx0, lctx, cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, cb, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = llm_build_moe_ffn(ctx0, lctx, cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + n_expert, n_expert_used, + LLM_FFN_SILU, false, + false, 0.0, + cb, il); + cb(cur, "ffn_moe_out", il); + } + + // residual + cur = ggml_add(ctx0, ffn_inp, cur); cur = lctx.cvec.apply_to(ctx0, cur, il); cb(cur, "l_out", il); @@ -14463,7 +15691,7 @@ struct llm_build_context { struct ggml_tensor * k = ggml_view_3d(ctx0, kv_self.k_l[il], n_embd_head_k, n_kv, n_head_kv, - ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv_self.k_l[il]->type, hparams.n_embd_k_gqa()), ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k), 0); cb(k, "k", il); @@ -15074,10 +16302,10 @@ struct llm_build_context { } ggml_cgraph * build_rwkv6() { - ggml_cgraph *gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); + ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); // Token shift state dimensions should be 2 * n_emb - GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2); + 
GGML_ASSERT(n_embd == hparams.n_embd_r(0) / 2); const int64_t n_seqs = batch.n_seqs; const int64_t n_seq_tokens = batch.n_seq_tokens; @@ -15099,11 +16327,11 @@ struct llm_build_context { // (ab)using the KV cache to store the states struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0, - gf, kv_self.k_l[il], state_copy, state_mask, - hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs); + gf, rs_self.r_l[il], state_copy, state_mask, + hparams.n_embd_r(il), rs_self.size, rs_head, n_rs, n_seqs); struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0, - gf, kv_self.v_l[il], state_copy, state_mask, - hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs); + gf, rs_self.s_l[il], state_copy, state_mask, + hparams.n_embd_s(il), rs_self.size, rs_head, n_rs, n_seqs); cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); token_shift = ggml_reshape_3d(ctx0, token_shift, n_embd, 2, n_seqs); @@ -15128,9 +16356,9 @@ struct llm_build_context { wkv_states, ggml_view_1d( ctx0, - kv_self.v_l[il], - hparams.n_embd_v_s() * n_seqs, - hparams.n_embd_v_s() * kv_head * ggml_element_size(kv_self.v_l[il]) + rs_self.s_l[il], + hparams.n_embd_s(il) * n_seqs, + hparams.n_embd_s(il) * rs_head * ggml_element_size(rs_self.s_l[il]) ) ) ); @@ -15155,7 +16383,7 @@ struct llm_build_context { ggml_cpy( ctx0, ggml_view_1d(ctx0, token_shift, n_embd * n_seqs * 2, 0), - ggml_view_1d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s() * n_seqs, hparams.n_embd_k_s() * kv_head * ggml_element_size(kv_self.k_l[il])) + ggml_view_1d(ctx0, rs_self.r_l[il], hparams.n_embd_r(il) * n_seqs, hparams.n_embd_r(il) * rs_head * ggml_element_size(rs_self.r_l[il])) ) ); @@ -15365,6 +16593,10 @@ static struct ggml_cgraph * llama_build_graph( { result = llm.build_mamba(); } break; + case LLM_ARCH_JAMBA: + { + result = llm.build_jamba(); + } break; case LLM_ARCH_XVERSE: { result = llm.build_xverse(); @@ -15448,26 +16680,14 @@ static struct ggml_cgraph * llama_build_graph( } static void llama_set_k_shift(llama_context & lctx) { - const int64_t kv_size = lctx.kv_self.size; + const int64_t kv_size = lctx.cache.kv.size; assert(ggml_backend_buffer_is_host(lctx.inp_K_shift->buffer)); int32_t * data = (int32_t *) lctx.inp_K_shift->data; for (int i = 0; i < kv_size; ++i) { - data[i] = lctx.kv_self.cells[i].delta; - } -} - -static void llama_set_s_copy(llama_context & lctx) { - const int64_t kv_size = lctx.kv_self.size; - - assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); - - int32_t * data = (int32_t *) lctx.inp_s_copy->data; - - for (int i = 0; i < kv_size; ++i) { - data[i] = lctx.kv_self.cells[i].src; + data[i] = lctx.cache.kv.cells[i].delta; } } @@ -15502,7 +16722,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { const auto & hparams = lctx.model.hparams; const auto & cparams = lctx.cparams; - const auto & kv_self = lctx.kv_self; + const auto & kv_self = lctx.cache.kv; + const auto & rs_self = lctx.cache.rs; if (batch.token) { const int64_t n_tokens = batch.n_tokens; @@ -15783,46 +17004,47 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { } } - if (kv_self.recurrent) { - const int64_t n_kv = kv_self.n; + if (rs_self.size > 0) { + const int64_t n_rs = rs_self.n; if (lctx.inp_s_mask) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); float * data = (float *) lctx.inp_s_mask->data; // clear unused states - for (int i = 0; i < n_kv; ++i) { - uint32_t cell_id = i + kv_self.head; - llama_kv_cell & kv_cell = 
lctx.kv_self.cells[cell_id]; + for (int i = 0; i < n_rs; ++i) { + uint32_t cell_id = i + rs_self.head; + llama_rs_cell & rs_cell = lctx.cache.rs.cells[cell_id]; - data[i] = (float) (kv_cell.src >= 0); + data[i] = (float) (rs_cell.src >= 0); // only clear once - if (kv_cell.src < 0) { - kv_cell.src = cell_id; + if (rs_cell.src < 0) { + rs_cell.src = cell_id; } } } + // checkpoints require copies between cells if (lctx.inp_s_copy) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); int32_t * data = (int32_t *) lctx.inp_s_copy->data; // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n - for (uint32_t i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self.head; - llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; + for (uint32_t i = 0; i < n_rs; ++i) { + const uint32_t cell_id = i + rs_self.head; + llama_rs_cell & rs_cell = lctx.cache.rs.cells[cell_id]; // prevent out-of-bound sources - if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) { - kv_cell.src = cell_id; + if (rs_cell.src < 0 || (uint32_t) rs_cell.src >= rs_self.size) { + rs_cell.src = cell_id; } - data[i] = kv_cell.src; + data[i] = rs_cell.src; // ensure copy only happens once - if (kv_cell.src != (int32_t) cell_id) { - kv_cell.src = cell_id; + if (rs_cell.src != (int32_t) cell_id) { + rs_cell.src = cell_id; } } } @@ -15836,12 +17058,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { int32_t * data = (int32_t *) lctx.inp_pos_bucket->data; + // FIXME: use batch.n_seqs if (!lctx.is_encoding) { const int64_t n_kv = kv_self.n; for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { for (int i = 0; i < n_kv; ++i) { - data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding); + data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.cache.kv.cells[i].pos, batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding); } } } @@ -15872,6 +17095,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { float * data = (float *) lctx.inp_KQ_mask_cross->data; + // FIXME: use batch.n_seqs for (int h = 0; h < 1; ++h) { for (int j = 0; j < n_tokens; ++j) { for (int i = 0; i < n_output_enc; ++i) { @@ -16063,7 +17287,8 @@ static int llama_decode_internal( } lctx.n_queued_tokens += n_tokens_all; - auto & kv_self = lctx.kv_self; + auto & kv_self = lctx.cache.kv; + auto & rs_self = lctx.cache.rs; const int64_t n_embd = hparams.n_embd; const int64_t n_vocab = hparams.n_vocab; @@ -16091,7 +17316,7 @@ static int llama_decode_internal( } lctx.sbatch.from_batch(batch_all, n_embd, - /* simple_split */ !kv_self.recurrent, + /* simple_split */ rs_self.size == 0, /* logits_all */ n_outputs == n_tokens_all); // reserve output buffer @@ -16102,7 +17327,7 @@ static int llama_decode_internal( while (lctx.sbatch.n_tokens > 0) { llama_ubatch ubatch; - if (kv_self.recurrent) { + if (rs_self.size > 0) { if (embd_pooled) { // Pooled embeddings cannot be split across ubatches (yet) ubatch = lctx.sbatch.split_seq(n_ubatch); @@ -16148,11 +17373,12 @@ static int llama_decode_internal( kv_self.head = 0; } - if (!llama_kv_cache_find_slot(kv_self, ubatch)) { + if (!llama_past_find_slot(lctx.cache, ubatch)) { return 1; } - if (!kv_self.recurrent) { + // TODO: move into llama_past_find_slot + if (kv_self.size > 0) { // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the 
benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important @@ -16202,11 +17428,15 @@ static int llama_decode_internal( // update the kv ring buffer { kv_self.head += n_tokens; + rs_self.head += rs_self.n; // Ensure kv cache head points to a valid index. if (kv_self.head >= kv_self.size) { kv_self.head = 0; } + if (rs_self.head >= rs_self.size) { + rs_self.head = 0; + } } // plot the computation graph in dot format (for debugging purposes) @@ -16272,6 +17502,10 @@ static int llama_decode_internal( } } n_outputs_prev += lctx.n_outputs; + +#ifndef NDEBUG + GGML_ASSERT(lctx.cache.rs.rebuild(true)); +#endif } // set output mappings @@ -16479,7 +17713,7 @@ static int llama_encode_internal( // find holes from the beginning of the KV cache and fill them by moving data from the end of the cache static void llama_kv_cache_defrag_internal(struct llama_context & lctx) { - auto & kv_self = lctx.kv_self; + auto & kv_self = lctx.cache.kv; const auto & hparams = lctx.model.hparams; @@ -16702,7 +17936,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { bool need_reserve = false; // apply K-shift if needed - if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) { + if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.cache.kv.has_shift) { if (lctx.model.arch == LLM_ARCH_DEEPSEEK2) { // not supported due to MLA GGML_ABORT("Deepseek2 does not support K-shift"); } @@ -16722,7 +17956,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { } { - auto & kv_self = lctx.kv_self; + auto & kv_self = lctx.cache.kv; kv_self.has_shift = false; @@ -16733,12 +17967,12 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) { } // defragment the KV cache if needed - if (lctx.kv_self.do_defrag) { + if (lctx.cache.kv.do_defrag) { llama_kv_cache_defrag_internal(lctx); need_reserve = true; - lctx.kv_self.do_defrag = false; + lctx.cache.kv.do_defrag = false; } // reserve a worst case graph again @@ -18170,18 +19404,8 @@ struct llama_context * llama_new_context_with_model( // build worst-case graph for encoder if a model contains encoder ctx->is_encoding = llama_model_has_encoder(model); - uint32_t kv_size = cparams.n_ctx; - ggml_type type_k = params.type_k; - ggml_type type_v = params.type_v; - - // Mamba only needs a constant number of KV cache cells per sequence - if (llama_model_is_recurrent(model)) { - // Mamba needs at least as many KV cells as there are sequences kept at any time - kv_size = std::max((uint32_t) 1, params.n_seq_max); - // it's probably best to keep as much precision as possible for the states - type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states - type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states - } + const ggml_type type_k = params.type_k; + const ggml_type type_v = params.type_v; GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0); GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0); @@ -18333,25 +19557,42 @@ struct llama_context * llama_new_context_with_model( } ctx->backends.push_back(ctx->backend_cpu); - if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) { + if (!llama_past_init(ctx->cache, ctx, type_k, type_v, cparams.offload_kqv)) { LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__); llama_free(ctx); return nullptr; } - { + if (ctx->cache.rs.size > 0) { + size_t 
memory_size_r = 0; + size_t memory_size_s = 0; + + for (auto & r : ctx->cache.rs.r_l) { + memory_size_r += ggml_nbytes(r); + } + + for (auto & s : ctx->cache.rs.s_l) { + memory_size_s += ggml_nbytes(s); + } + + LLAMA_LOG_INFO("%s: SSM state size = %8.2f MiB, R (%s): %7.2f MiB, S (%s): %7.2f MiB\n", __func__, + (float)(memory_size_r + memory_size_s) / (1024.0f * 1024.0f), + ggml_type_name(GGML_TYPE_F32), (float)memory_size_r / (1024.0f * 1024.0f), + ggml_type_name(GGML_TYPE_F32), (float)memory_size_s / (1024.0f * 1024.0f)); + } + if (ctx->cache.kv.size > 0) { size_t memory_size_k = 0; size_t memory_size_v = 0; - for (auto & k : ctx->kv_self.k_l) { + for (auto & k : ctx->cache.kv.k_l) { memory_size_k += ggml_nbytes(k); } - for (auto & v : ctx->kv_self.v_l) { + for (auto & v : ctx->cache.kv.v_l) { memory_size_v += ggml_nbytes(v); } - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + LLAMA_LOG_INFO("%s: KV cache size = %8.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); @@ -18466,7 +19707,11 @@ uint32_t llama_n_ubatch(const struct llama_context * ctx) { } uint32_t llama_n_seq_max(const struct llama_context * ctx) { - return ctx->kv_self.size; + if (ctx->cache.rs.size > 0) { + return ctx->cache.rs.size; + } else { + return ctx->cache.kv.size; + } } enum llama_vocab_type llama_vocab_type(const struct llama_model * model) { @@ -18482,6 +19727,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_REFACT: case LLM_ARCH_BLOOM: case LLM_ARCH_MAMBA: + case LLM_ARCH_JAMBA: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_T5: case LLM_ARCH_T5ENCODER: @@ -18654,9 +19900,12 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) { bool llama_model_is_recurrent(const struct llama_model * model) { switch (model->arch) { - case LLM_ARCH_MAMBA: return true; - case LLM_ARCH_RWKV6: return true; - default: return false; + case LLM_ARCH_JAMBA: + case LLM_ARCH_MAMBA: + case LLM_ARCH_RWKV6: + return true; + default: + return false; } } @@ -18803,8 +20052,9 @@ void llama_kv_cache_view_free(struct llama_kv_cache_view * view) { } void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) { - if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) { - view->n_cells = int32_t(ctx->kv_self.size); + const llama_kv_cache & kv_self = ctx->cache.kv; + if (uint32_t(view->n_cells) < kv_self.size || view->cells == nullptr) { + view->n_cells = int32_t(kv_self.size); void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells); GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells"); view->cells = (struct llama_kv_cache_view_cell *)p; @@ -18813,7 +20063,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k view->cells_sequences = (llama_seq_id *)p; } - const std::vector & kv_cells = ctx->kv_self.cells; + const std::vector & kv_cells = kv_self.cells; llama_kv_cache_view_cell * c_curr = view->cells; llama_seq_id * cs_curr = view->cells_sequences; int32_t used_cells = 0; @@ -18822,7 +20072,7 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k uint32_t max_contig = 0; int32_t max_contig_idx = -1; - for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, 
cs_curr += view->n_seq_max) { + for (int32_t i = 0; i < int32_t(kv_self.size); i++, c_curr++, cs_curr += view->n_seq_max) { const size_t curr_size = kv_cells[i].seq_id.size(); token_count += curr_size; c_curr->pos = kv_cells[i].pos + kv_cells[i].delta; @@ -18860,67 +20110,118 @@ void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_k view->max_contiguous_idx = max_contig_idx; view->token_count = token_count; view->used_cells = used_cells; - if (uint32_t(used_cells) != ctx->kv_self.used) { + if (uint32_t(used_cells) != kv_self.used) { LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n", - __func__, ctx->kv_self.used, used_cells); + __func__, kv_self.used, used_cells); } } +bool llama_rs_cache_rebuild(struct llama_context * ctx, bool debug) { + return ctx->cache.rs.rebuild(debug); +} + int32_t llama_get_kv_cache_token_count(const struct llama_context * ctx) { int result = 0; - for (uint32_t i = 0; i < ctx->kv_self.size; i++) { - result += ctx->kv_self.cells[i].seq_id.size(); + for (uint32_t i = 0; i < ctx->cache.kv.size; i++) { + result += ctx->cache.kv.cells[i].seq_id.size(); } return result; } int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx) { - return ctx->kv_self.used; + return ctx->cache.kv.used; +} + +int32_t llama_get_rs_cache_used_cells(const struct llama_context * ctx) { + return ctx->cache.rs.used; +} + +void llama_past_clear(struct llama_context * ctx) { + llama_past_clear(ctx->cache); } +// deprecated void llama_kv_cache_clear(struct llama_context * ctx) { - llama_kv_cache_clear(ctx->kv_self); + llama_past_clear(ctx); } +llama_pos llama_past_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return 0; } + return llama_past_seq_rm(ctx->cache, seq_id, p0, p1); +} + +// deprecated bool llama_kv_cache_seq_rm(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - return llama_kv_cache_seq_rm(ctx->kv_self, seq_id, p0, p1); + llama_pos n_past = llama_past_seq_rm(ctx, seq_id, p0, p1); + return n_past >= p0; } -void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + +llama_pos llama_past_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + uint32_t n_seq_max = llama_n_seq_max(ctx); + if (seq_id_src < 0 || seq_id_dst < 0 || (uint32_t) seq_id_src >= n_seq_max || (uint32_t) seq_id_dst >= n_seq_max) { + return 0; + } if (seq_id_src == seq_id_dst) { - return; + return llama_past_seq_pos_max(ctx->cache, seq_id_dst) + 1; } - llama_kv_cache_seq_cp(ctx->kv_self, seq_id_src, seq_id_dst, p0, p1); + return llama_past_seq_cp(ctx->cache, seq_id_src, seq_id_dst, p0, p1); } +// deprecated +void llama_kv_cache_seq_cp(struct llama_context * ctx, llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + llama_past_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); +} + +void llama_past_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } + llama_past_seq_keep(ctx->cache, seq_id); +} + +// deprecated void llama_kv_cache_seq_keep(struct llama_context * ctx, llama_seq_id seq_id) { - llama_kv_cache_seq_keep(ctx->kv_self, seq_id); + llama_past_seq_keep(ctx, seq_id); +} + +void llama_past_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, 
llama_pos delta) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } + if (delta == 0) { return; } + + llama_past_seq_add(ctx->cache, seq_id, p0, p1, delta); } +// deprecated void llama_kv_cache_seq_add(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { - if (delta == 0) { - return; - } + llama_past_seq_add(ctx, seq_id, p0, p1, delta); +} - llama_kv_cache_seq_add(ctx->kv_self, seq_id, p0, p1, delta); +void llama_past_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return; } + if (d == 1) { return; } + + llama_past_seq_div(ctx->cache, seq_id, p0, p1, d); } +// deprecated void llama_kv_cache_seq_div(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - if (d == 1) { - return; - } + llama_past_seq_div(ctx, seq_id, p0, p1, d); +} - llama_kv_cache_seq_div(ctx->kv_self, seq_id, p0, p1, d); +llama_pos llama_past_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { return -1; } + return llama_past_seq_pos_max(ctx->cache, seq_id); } +// deprecated llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id seq_id) { - return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id); + llama_pos max_pos = llama_past_seq_pos_max(ctx, seq_id); + return max_pos < 0 ? 0 : max_pos; } void llama_kv_cache_defrag(struct llama_context * ctx) { - llama_kv_cache_defrag(ctx->kv_self); + llama_kv_cache_defrag(ctx->cache.kv); } void llama_kv_cache_update(struct llama_context * ctx) { @@ -19052,8 +20353,28 @@ struct llama_data_write { } } + void write_rs_cache_meta(const llama_rs_cache & rs_self, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) { + + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + const auto & cell = rs_self.cells[i]; + const llama_pos pos = cell.pos; + const uint32_t n_seq_id = seq_id == -1 ? cell.seq_nodes.size() : 0; + + write(&pos, sizeof(pos)); + write(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id) { + for (auto seq_node : cell.seq_nodes) { + write(&seq_node.seq_id, sizeof(seq_node.seq_id)); + } + } + } + } + } + void write_kv_cache_data(const struct llama_context * ctx, const std::vector> & cell_ranges) { - const struct llama_kv_cache & kv_self = ctx->kv_self; + const struct llama_kv_cache & kv_self = ctx->cache.kv; const struct llama_hparams & hparams = ctx->model.hparams; const uint32_t v_trans = kv_self.v_trans ? 
1 : 0; @@ -19062,12 +20383,10 @@ struct llama_data_write { write(&v_trans, sizeof(v_trans)); write(&n_layer, sizeof(n_layer)); - std::vector tmp_buf; - // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); // Write key type const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; @@ -19087,7 +20406,7 @@ struct llama_data_write { if (!kv_self.v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Write value type const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; @@ -19108,7 +20427,7 @@ struct llama_data_write { // When v is transposed, we also need the element size and get the element ranges from each row const uint32_t kv_size = kv_self.size; for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Write value type const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; @@ -19135,43 +20454,151 @@ struct llama_data_write { } } - void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) { - const struct llama_kv_cache & kv_self = ctx->kv_self; - std::vector> cell_ranges; // ranges, from inclusive, to exclusive - uint32_t cell_count = 0; + void write_rs_cache_data(const struct llama_context * ctx, const std::vector> & cell_ranges) { + const struct llama_rs_cache & rs_self = ctx->cache.rs; + const struct llama_hparams & hparams = ctx->model.hparams; - // Count the number of cells with the specified seq_id - // Find all the ranges of cells with this seq id (or all, when -1) - uint32_t cell_range_begin = kv_self.size; - for (uint32_t i = 0; i < kv_self.size; ++i) { - const auto & cell = kv_self.cells[i]; - if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { - ++cell_count; - if (cell_range_begin == kv_self.size) { - cell_range_begin = i; - } - } else { - if (cell_range_begin != kv_self.size) { - cell_ranges.emplace_back(cell_range_begin, i); - cell_range_begin = kv_self.size; + const uint32_t n_layer = hparams.n_layer; + + write(&n_layer, sizeof(n_layer)); + + // Iterate and write all recurrent states, each row is a cell + // Get whole range at a time + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_r = hparams.n_embd_r(il); + + // Write type + const int32_t r_type_i = (int32_t)rs_self.r_l[il]->type; + write(&r_type_i, sizeof(r_type_i)); + + // Write row size + const uint64_t r_size_row = ggml_row_size(rs_self.r_l[il]->type, n_embd_r); + write(&r_size_row, sizeof(r_size_row)); + + // Read each range of cells of r_size length each and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * r_size_row; + write_tensor_data(rs_self.r_l[il], range.first * r_size_row, buf_size); + } + } + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_s = hparams.n_embd_s(il); + + // Write type + const int32_t s_type_i = (int32_t)rs_self.s_l[il]->type; + write(&s_type_i, sizeof(s_type_i)); + + // Write row size + const uint64_t s_size_row = ggml_row_size(rs_self.s_l[il]->type, n_embd_s); + write(&s_size_row, sizeof(s_size_row)); + + // Read each range of cells 
of s_size length each and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * s_size_row; + write_tensor_data(rs_self.s_l[il], range.first * s_size_row, buf_size); + } + } + } + + void write_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) { + const struct llama_kv_cache & kv_self = ctx->cache.kv; + const struct llama_rs_cache & rs_self = ctx->cache.rs; + std::vector> kv_cell_ranges; // ranges, from inclusive, to exclusive + std::vector> rs_cell_ranges; // ranges, from inclusive, to exclusive + uint32_t kv_cell_count = 0; + uint32_t rs_cell_count = 0; + // Transformer KV cache + { + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = kv_self.size; + for (uint32_t i = 0; i < kv_self.size; ++i) { + const auto & cell = kv_self.cells[i]; + if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { + ++kv_cell_count; + if (cell_range_begin == kv_self.size) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != kv_self.size) { + kv_cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = kv_self.size; + } } } + if (cell_range_begin != kv_self.size) { + kv_cell_ranges.emplace_back(cell_range_begin, kv_self.size); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : kv_cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(kv_cell_count == cell_count_check); } - if (cell_range_begin != kv_self.size) { - cell_ranges.emplace_back(cell_range_begin, kv_self.size); + // Recurrent state cache + if (seq_id == -1) { + // Find all the ranges of cells + uint32_t cell_range_begin = rs_self.size; + for (uint32_t i = 0; i < rs_self.size; ++i) { + const auto & cell = rs_self.cells[i]; + if (!cell.is_empty()) { + ++rs_cell_count; + if (cell_range_begin == rs_self.size) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != rs_self.size) { + rs_cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = rs_self.size; + } + } + } + if (cell_range_begin != rs_self.size) { + rs_cell_ranges.emplace_back(cell_range_begin, rs_self.size); + } + + } else { + // Find the cell ranges of the specified seq_id + if ((size_t) seq_id < rs_self.seq_tails.size()) { + int32_t tail_cell_id = rs_self.seq_tails[seq_id].tail; + if (tail_cell_id >= 0) { + ++rs_cell_count; + rs_cell_ranges.emplace_back(tail_cell_id, tail_cell_id + 1); + } + } } - // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count - uint32_t cell_count_check = 0; - for (const auto & range : cell_ranges) { - cell_count_check += range.second - range.first; + { + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : rs_cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(rs_cell_count == cell_count_check); } - GGML_ASSERT(cell_count == cell_count_check); - write(&cell_count, sizeof(cell_count)); + write(&kv_cell_count, sizeof(kv_cell_count)); + write(&rs_cell_count, sizeof(rs_cell_count)); - write_kv_cache_meta(kv_self, cell_ranges, seq_id); - write_kv_cache_data(ctx, cell_ranges); + if (seq_id == -1) { + // write metadata for both when the whole cache needs to be saved + write_kv_cache_meta(kv_self, kv_cell_ranges, seq_id); + 
write_rs_cache_meta(rs_self, rs_cell_ranges, seq_id); + } else if (kv_cell_count > 0) { + write_kv_cache_meta(kv_self, kv_cell_ranges, seq_id); + } else { + write_rs_cache_meta(rs_self, rs_cell_ranges, seq_id); + } + if (kv_cell_count > 0) { + write_kv_cache_data(ctx, kv_cell_ranges); + } + if (rs_cell_count > 0) { + write_rs_cache_data(ctx, rs_cell_ranges); + } } }; @@ -19263,108 +20690,98 @@ struct llama_data_read { } } - bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) { - struct llama_kv_cache & kv_self = ctx->kv_self; + bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count) { + if (cell_count == 0) { return true; } + struct llama_past & cache = ctx->cache; + struct llama_kv_cache & kv_self = cache.kv; - if (dest_seq_id != -1) { - // single sequence + // whole KV cache restore - llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + if (cell_count > kv_self.size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); + return false; + } + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_kv_cell & cell = kv_self.cells[i]; - llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false); - batch.n_tokens = cell_count; - batch.n_seq_tokens = cell_count; - batch.n_seqs = 1; + llama_pos pos; + uint32_t n_seq_id; - for (uint32_t i = 0; i < cell_count; ++i) { - llama_pos pos; - uint32_t n_seq_id; + read_to(&pos, sizeof(pos)); + read_to(&n_seq_id, sizeof(n_seq_id)); - read_to(&pos, sizeof(pos)); - read_to(&n_seq_id, sizeof(n_seq_id)); + cell.pos = pos; - if (n_seq_id != 0) { - LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + read_to(&seq_id, sizeof(seq_id)); + + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); return false; } - batch.pos[i] = pos; - } - batch.n_seq_id[0] = 1; - batch.seq_id[0] = &dest_seq_id; - if (!llama_kv_cache_find_slot(kv_self, batch)) { - LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); - return false; + cell.seq_id.insert(seq_id); } + } - // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values) - // Assume that this is one contiguous block of cells - GGML_ASSERT(kv_self.head + cell_count <= kv_self.size); - GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]); - GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]); - GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id)); - GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id)); - } else { - // whole KV cache restore + kv_self.head = 0; + kv_self.used = cell_count; - if (cell_count > kv_self.size) { - LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); - return false; - } + return true; + } - llama_kv_cache_clear(kv_self); + bool read_rs_cache_meta(struct llama_context * ctx, uint32_t cell_count) { + if (cell_count == 0) { return true; } + struct llama_past & cache = ctx->cache; + struct llama_rs_cache & rs_self = cache.rs; - for (uint32_t i = 0; i < cell_count; ++i) { - llama_kv_cell & cell = kv_self.cells[i]; + // whole RS cache restore - llama_pos pos; - uint32_t n_seq_id; + if (cell_count > rs_self.size) { + LLAMA_LOG_ERROR("%s: not enough cells in rs cache\n", __func__); + return false; + } - 
read_to(&pos, sizeof(pos)); - read_to(&n_seq_id, sizeof(n_seq_id)); + for (uint32_t i = 0; i < cell_count; ++i) { + llama_rs_cell & cell = rs_self.cells[i]; - cell.pos = pos; + llama_pos pos; + uint32_t n_seq_id; - for (uint32_t j = 0; j < n_seq_id; ++j) { - llama_seq_id seq_id; - read_to(&seq_id, sizeof(seq_id)); + read_to(&pos, sizeof(pos)); + read_to(&n_seq_id, sizeof(n_seq_id)); - if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { - LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); - return false; - } + cell.pos = pos; + cell.src = i; - cell.seq_id.insert(seq_id); + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + read_to(&seq_id, sizeof(seq_id)); - if (kv_self.recurrent) { - int32_t & tail = kv_self.cells[seq_id].tail; - if (tail != -1) { - LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); - return false; - } - tail = i; - } + if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); + return false; } - } - kv_self.head = 0; - kv_self.used = cell_count; - } + cell.insert_node(seq_id); - if (kv_self.recurrent) { - for (uint32_t i = 0; i < cell_count; ++i) { - uint32_t cell_id = kv_self.head + i; - // make sure the recurrent states will keep their restored state - kv_self.cells[cell_id].src = cell_id; } } + rs_self.head = 0; + rs_self.used = cell_count; + + rs_self.rebuild(/* debug */ false); + return true; } bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) { + if (cell_count == 0) { return true; } const struct llama_hparams & hparams = ctx->model.hparams; - struct llama_kv_cache & kv_self = ctx->kv_self; + struct llama_kv_cache & kv_self = ctx->cache.kv; uint32_t v_trans; uint32_t n_layer; read_to(&v_trans, sizeof(v_trans)); @@ -19385,7 +20802,7 @@ struct llama_data_read { // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); // Read type of key int32_t k_type_i_ref; @@ -19405,15 +20822,13 @@ struct llama_data_read { return false; } - if (cell_count) { - // Read and set the keys for the whole cell range - ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row); - } + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row); } if (!kv_self.v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Read type of value int32_t v_type_i_ref; @@ -19433,15 +20848,13 @@ struct llama_data_read { return false; } - if (cell_count) { - // Read and set the values for the whole cell range - ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row); - } + // Read and set the values for the whole cell range + ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row); } } else { // For each layer, read the values for each cell (transposed) for (uint32_t il = 
0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Read type of value int32_t v_type_i_ref; @@ -19469,29 +20882,174 @@ struct llama_data_read { return false; } - if (cell_count) { - // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el; - ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); - } + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el; + ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); } } } return true; } - void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) { - uint32_t cell_count; - read_to(&cell_count, sizeof(cell_count)); + bool read_rs_cache_data(struct llama_context * ctx, uint32_t cell_count) { + if (cell_count == 0) { return true; } + const struct llama_hparams & hparams = ctx->model.hparams; + struct llama_rs_cache & rs_self = ctx->cache.rs; + uint32_t n_layer; + read_to(&n_layer, sizeof(n_layer)); + + if (n_layer != hparams.n_layer) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); + return false; + } + if (cell_count > rs_self.size) { + LLAMA_LOG_ERROR("%s: not enough cells in rs cache to restore state (%u > %u)\n", __func__, cell_count, rs_self.size); + return false; + } + + // For each layer, one row is one cell, read as one contiguous block + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_r = hparams.n_embd_r(il); + + // Read type of key + int32_t r_type_i_ref; + read_to(&r_type_i_ref, sizeof(r_type_i_ref)); + const int32_t r_type_i = (int32_t)rs_self.r_l[il]->type; + if (r_type_i != r_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il); + return false; + } + + // Read row size of key + uint64_t r_size_row_ref; + read_to(&r_size_row_ref, sizeof(r_size_row_ref)); + const size_t r_size_row = ggml_row_size(rs_self.r_l[il]->type, n_embd_r); + if (r_size_row != r_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il); + return false; + } + + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(rs_self.r_l[il], read(cell_count * r_size_row), rs_self.head * r_size_row, cell_count * r_size_row); + } + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_s = hparams.n_embd_s(il); + + // Read type of key + int32_t s_type_i_ref; + read_to(&s_type_i_ref, sizeof(s_type_i_ref)); + const int32_t s_type_i = (int32_t)rs_self.s_l[il]->type; + if (s_type_i != s_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il); + return false; + } + + // Read row size of key + uint64_t s_size_row_ref; + read_to(&s_size_row_ref, sizeof(s_size_row_ref)); + const size_t s_size_row = ggml_row_size(rs_self.s_l[il]->type, n_embd_s); + if (s_size_row != s_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il); + return false; + } + + // Read 
and set the keys for the whole cell range + ggml_backend_tensor_set(rs_self.s_l[il], read(cell_count * s_size_row), rs_self.head * s_size_row, cell_count * s_size_row); + } + + return true; + } + + bool read_cache_seq_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1) { + + if (seq_id < 0 || seq_id >= (llama_seq_id) llama_n_seq_max(ctx)) { + LLAMA_LOG_ERROR("%s: seq_id out of range [0, %d): %d\n", __func__, llama_n_seq_max(ctx), seq_id); + return false; + } + + // single sequence + + llama_past & cache = ctx->cache; + llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false); + batch.n_tokens = cell_count; + batch.n_seq_tokens = cell_count; + batch.n_seqs = 1; + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + read_to(&pos, sizeof(pos)); + read_to(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id != 0) { + LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + return false; + } + + batch.pos[i] = pos; + } + batch.n_seq_id[0] = 1; + batch.seq_id[0] = &seq_id; + if (!llama_past_find_slot(cache, batch)) { + LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + return false; + } + + if (cache.kv.size > 0) { + // DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(cache.kv.head + cell_count <= cache.kv.size); + GGML_ASSERT(cache.kv.cells[cache.kv.head].pos == batch.pos[0]); + GGML_ASSERT(cache.kv.cells[cache.kv.head + cell_count - 1].pos == batch.pos[cell_count - 1]); + GGML_ASSERT(cache.kv.cells[cache.kv.head].has_seq_id(seq_id)); + GGML_ASSERT(cache.kv.cells[cache.kv.head + cell_count - 1].has_seq_id(seq_id)); + } + if (cache.rs.size > 0) { + GGML_ASSERT(cache.rs.head + cache.rs.n <= cache.rs.size); + GGML_ASSERT(cache.rs.n == 1); + GGML_ASSERT(cache.rs.cells[cache.rs.head + cache.rs.n - 1].pos == batch.pos[cell_count - 1]); + GGML_ASSERT(cache.rs.cells[cache.rs.head].has_seq_id(seq_id)); + GGML_ASSERT(cache.rs.cells[cache.rs.head + cache.rs.n - 1].has_seq_id(seq_id)); + // Prevent cells from being cleared + for (uint32_t i = cache.rs.head; i < cache.rs.head + cache.rs.n; ++i) { + cache.rs.cells[i].src = i; + } + } + + return true; + } + + void read_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) { + uint32_t kv_cell_count; + read_to(&kv_cell_count, sizeof(kv_cell_count)); + uint32_t rs_cell_count; + read_to(&rs_cell_count, sizeof(rs_cell_count)); + + bool res = true; + + if (seq_id == -1) { + llama_past_clear(ctx); + res = read_kv_cache_meta(ctx, kv_cell_count) && read_rs_cache_meta(ctx, rs_cell_count); + } else { + llama_past_seq_rm(ctx, seq_id, -1, -1); + // Only a single recurrent cell at most, + // because otherwise the cells can be shuffled when a slot is allocated + if (rs_cell_count > 1) { + LLAMA_LOG_ERROR("%s: too many recurrent state cells for single-sequence session\n", __func__); + res = false; + } + res = res && read_cache_seq_meta(ctx, std::max(kv_cell_count, rs_cell_count), seq_id); + } - bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count); + res = res && read_kv_cache_data(ctx, kv_cell_count) && read_rs_cache_data(ctx, rs_cell_count); if (!res) { if (seq_id == -1) { - llama_kv_cache_clear(ctx); + llama_past_clear(ctx); } else { - llama_kv_cache_seq_rm(ctx, seq_id, -1, -1); + llama_past_seq_rm(ctx, seq_id, -1, -1); } throw 
std::runtime_error("failed to restore kv cache"); } @@ -19646,7 +21204,7 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da data_ctx.write_logits(ctx); data_ctx.write_embeddings(ctx); - data_ctx.write_kv_cache(ctx); + data_ctx.write_cache(ctx); return data_ctx.get_size_written(); } @@ -19686,7 +21244,7 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da data_ctx.read_logits(ctx); data_ctx.read_embeddings(ctx); - data_ctx.read_kv_cache(ctx); + data_ctx.read_cache(ctx); return data_ctx.get_size_read(); } @@ -19782,7 +21340,7 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) { llama_synchronize(ctx); - data_ctx.write_kv_cache(ctx, seq_id); + data_ctx.write_cache(ctx, seq_id); return data_ctx.get_size_written(); } @@ -19805,7 +21363,7 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_ static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) { llama_synchronize(ctx); - data_ctx.read_kv_cache(ctx, dest_seq_id); + data_ctx.read_cache(ctx, dest_seq_id); return data_ctx.get_size_read(); } @@ -20006,11 +21564,19 @@ int32_t llama_encode( int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch) { +#ifndef NDEBUG + GGML_ASSERT(ctx->cache.rs.rebuild(true)); +#endif + const int ret = llama_decode_internal(*ctx, batch); if (ret < 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); } +#ifndef NDEBUG + GGML_ASSERT(ctx->cache.rs.rebuild(true)); +#endif + return ret; }
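A note on the `state_copy` / `state_mask` inputs consumed by `llm_build_copy_mask_state` above: `llama_set_inputs` fills `inp_s_copy` with the source cell index of each recurrent-state cell in the current window (`rs_cell.src`) and `inp_s_mask` with 1.0 for cells that already hold a valid state and 0.0 for sequences starting fresh in this batch. Below is a minimal plain-CPU sketch of the gather-and-clear that the graph performs with `ggml_get_rows` plus the masking step; the function and variable names are illustrative only and not part of the patch.

    #include <cstdint>
    #include <vector>

    // Sketch of the gather-and-clear applied to the recurrent states:
    //   out[i] = cache[copy[i]] * mask[i]   for i in [0, n_rs)
    // where each row is n_state floats, `copy` comes from inp_s_copy
    // (rs_cell.src) and `mask` from inp_s_mask (1.0 = keep, 0.0 = clear).
    static std::vector<float> copy_mask_states_ref(
            const std::vector<float>   & cache, // rs_size rows of n_state floats
            const std::vector<int32_t> & copy,  // n_rs source indices into `cache`
            const std::vector<float>   & mask,  // n_rs values in {0.0f, 1.0f}
            int64_t n_state) {
        const int64_t n_rs = (int64_t) copy.size();
        std::vector<float> out((size_t) (n_state * n_rs));
        for (int64_t i = 0; i < n_rs; ++i) {
            for (int64_t j = 0; j < n_state; ++j) {
                out[i*n_state + j] = cache[copy[i]*n_state + j] * mask[i];
            }
        }
        return out;
    }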
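The `#if 0` block kept around `ggml_ssm_conv` documents the equivalent im2col + mul + sum_rows formulation of Mamba's convolution. For readers of the patch, the operation is a per-channel causal 1D convolution over `conv_x` (the previous `d_conv - 1` columns of conv state concatenated with the new tokens). The loop below is a reference sketch of that computation for a single sequence, with illustrative names and plain nested vectors rather than ggml tensors.

    #include <vector>

    // conv_x[c] holds d_conv - 1 + n_tokens values for channel c;
    // w[c] holds the d_conv kernel weights for channel c (ssm_conv1d row).
    // Output: y[c][t] = sum_{k=0..d_conv-1} conv_x[c][t + k] * w[c][k]
    static std::vector<std::vector<float>> ssm_conv_ref(
            const std::vector<std::vector<float>> & conv_x, // [d_inner][d_conv - 1 + n_tokens]
            const std::vector<std::vector<float>> & w,      // [d_inner][d_conv]
            int n_tokens) {
        const int d_inner = (int) conv_x.size();
        const int d_conv  = d_inner > 0 ? (int) w[0].size() : 0;
        std::vector<std::vector<float>> y(d_inner, std::vector<float>(n_tokens, 0.0f));
        for (int c = 0; c < d_inner; ++c) {
            for (int t = 0; t < n_tokens; ++t) {
                float acc = 0.0f;
                for (int k = 0; k < d_conv; ++k) {
                    // self-overlapping window of size d_conv ending at the current token
                    acc += conv_x[c][t + k] * w[c][k];
                }
                y[c][t] = acc;
            }
        }
        return y;
    }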
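On the public API side, the patch introduces the `llama_past_*` entry points and keeps the old `llama_kv_cache_*` calls as deprecated wrappers. The visible behavioral changes are that the sequence operations now validate `seq_id` against `llama_n_seq_max()` and return a `llama_pos` instead of `void`/`bool`. A hedged migration sketch for caller code follows, assuming the matching declarations land in `llama.h` and that the returned position is where the sequence ends after the operation (as suggested by the deprecated `llama_kv_cache_seq_rm` wrapper computing `n_past >= p0`).

    #include "llama.h"

    // Old: bool ok = llama_kv_cache_seq_rm(ctx, seq_id, p0, -1);
    // New: inspect the returned llama_pos instead of a bool.
    static bool truncate_sequence(struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0) {
        const llama_pos n_past = llama_past_seq_rm(ctx, seq_id, p0, -1);
        // With recurrent states a partial removal may not be possible, so n_past
        // can come back smaller than p0; the caller would then re-process from
        // n_past (the deprecated wrapper maps this case to the old `false`).
        return n_past >= p0;
    }

    static void reset_all_sequences(struct llama_context * ctx) {
        // Replaces llama_kv_cache_clear(); clears both the KV cells and the
        // recurrent-state cells of the unified cache.
        llama_past_clear(ctx);
    }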
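Finally, the `NDEBUG`-gated `GGML_ASSERT(ctx->cache.rs.rebuild(true))` calls around `llama_decode` have a public counterpart, `llama_rs_cache_rebuild(ctx, debug)`. Assuming, as those asserts suggest, that a `true` return means the recurrent-state metadata is consistent, a test harness could use it like this (names other than the API call are illustrative):

    #include <cassert>
    #include "llama.h"

    // Consistency check for the recurrent-state cache after cache surgery
    // (sequence removals, copies, session restore) in a debug build.
    static void check_rs_cache(struct llama_context * ctx) {
        const bool consistent = llama_rs_cache_rebuild(ctx, /* debug */ true);
        assert(consistent);
        (void) consistent;
    }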