Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: attempt to reduce the impact of a worst-case scenario on defragmentation #6037

Merged
merged 3 commits into from
Mar 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9036,8 +9036,8 @@ static int llama_decode_internal(
//llama_synchronize(&lctx);

// decide if we need to defrag the kv cache
-if (cparams.defrag_thold >= 0.0f) {
-const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used + n_tokens_all)/float(kv_self.n) : 0.0f;
+if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
+const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;

// queue defragmentation for next llama_kv_cache_update
if (fragmentation > cparams.defrag_thold) {
Expand Down Expand Up @@ -9069,6 +9069,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
// number of cells moved
uint32_t n_moves = 0;

+// each move requires 6*n_layer tensors (see build_defrag)
+// - source view, destination view, copy operation
+// - x2 for keys and values
+const uint32_t max_moves = LLAMA_MAX_NODES/(6*n_layer);

// determine which KV cells to move where
//
// cell i moves to ids[i]
Expand All @@ -9095,15 +9100,6 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
nh++;
}

-// each move requires 6*n_layer tensors (see build_defrag)
-// - source view, destination view, copy operation
-// - x2 for keys and values
-//
-if (6*(n_moves + nh)*n_layer >= LLAMA_MAX_NODES) {
-// the graph is too big, we cannot move more cells
-break;
-}

uint32_t nf = 0;
uint32_t is = n_kv - 1;

Expand Down Expand Up @@ -9133,11 +9129,19 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
// are we moving a continuous block of memory?
bool cont = false;

+// should we stop searching for the next move?
+bool stop = false;

// go back and move the nf cells to the hole
for (; i1 < n_kv; ++i1) {
auto & cell1 = kv_self.cells[i1];

if (cell1.is_empty() || ids[i1] != n_kv) {
+if (n_moves == max_moves) {
+stop = true;
+break;
+}

cont = false;
continue;
}
Expand All @@ -9164,6 +9168,10 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
}
}

+if (stop || n_moves == max_moves) {
+break;
+}

//LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, i1 + 1, i0, i0 + nh);

i0 += nh - 1;
Expand Down
Loading