Skip to content

Commit

Permalink
llama : add check for KV cache shifts (#10401)
Browse files Browse the repository at this point in the history
ggml-ci
  • Loading branch information
ggerganov authored Nov 19, 2024
1 parent a88ad00 commit 8e752a7
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 1 deletion.
6 changes: 6 additions & 0 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,12 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams;
}

if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
LOG_ERR("%s: KV cache shifting is not supported for this model (--no-context-shift to disable)'\n", __func__);
llama_free_model(model);
return iparams;
}

if (!params.control_vectors.empty()) {
if (params.control_vector_layer_start <= 0) params.control_vector_layer_start = 1;
if (params.control_vector_layer_end <= 0) params.control_vector_layer_end = llama_n_layer(model);
Expand Down
3 changes: 3 additions & 0 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,9 @@ extern "C" {
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);

// Check if the context supports KV cache shifting
LLAMA_API bool llama_kv_cache_can_shift(struct llama_context * ctx);

//
// State / sessions
//
Expand Down
6 changes: 5 additions & 1 deletion src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18213,7 +18213,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {

// apply K-shift if needed
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
if (lctx.model.arch == LLM_ARCH_DEEPSEEK2) { // not supported due to MLA
if (!llama_kv_cache_can_shift(&lctx)) {
GGML_ABORT("Deepseek2 does not support K-shift");
}

Expand Down Expand Up @@ -20462,6 +20462,10 @@ void llama_kv_cache_update(struct llama_context * ctx) {
llama_kv_cache_update_internal(*ctx);
}

bool llama_kv_cache_can_shift(struct llama_context * ctx) {
return ctx->model.arch != LLM_ARCH_DEEPSEEK2; // not supported due to MLA
}

// deprecated
size_t llama_get_state_size(struct llama_context * ctx) {
return llama_state_get_size(ctx);
Expand Down

0 comments on commit 8e752a7

Please sign in to comment.