Reworked to directly update KV cache params using info from name
agray3 committed Jul 27, 2024
1 parent 5289a6a commit 3241b3d
Showing 3 changed files with 35 additions and 51 deletions.
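The previous approach tagged the K/V copy nodes with a dedicated kv_cache_flag on ggml_tensor; this rework drops that field and instead identifies the nodes in the cached graph by tensor name. The graph-build callback appends the layer index to the cache-view names, so the destinations of the copies are called k_cache_view-<il> and v_cache_view-<il>, and the layer index can be recovered by matching the prefix and parsing the numeric suffix. A minimal standalone sketch of that lookup (layer_from_name and the test strings are illustrative, not part of llama.cpp):

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

// If name starts with prefix (e.g. "k_cache_view-"), parse the trailing
// layer index; return -1 when the prefix does not match.
static int layer_from_name(const char * name, const char * prefix) {
    const size_t len = strlen(prefix);
    if (strncmp(name, prefix, len) != 0) {
        return -1;
    }
    return atoi(name + len); // "k_cache_view-12" -> 12
}

int main(void) {
    printf("%d\n", layer_from_name("k_cache_view-12", "k_cache_view-")); // 12
    printf("%d\n", layer_from_name("v_cache_view-3",  "v_cache_view-")); // 3
    printf("%d\n", layer_from_name("kq_mask",         "k_cache_view-")); // -1
    return 0;
}

This mirrors the strncmp/atoi pattern used in llama_decode_internal below.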
11 changes: 1 addition & 10 deletions ggml/include/ggml.h
@@ -552,13 +552,6 @@ extern "C" {
GGML_TENSOR_FLAG_PARAM = 4,
};

// Flag (used on GGML_OP_CPY nodes) on whether node is associated with K or V cache
enum ggml_kv_cache_flag {
GGML_KV_CACHE_FLAG_NONE = 0,
GGML_KV_CACHE_FLAG_K = 1,
GGML_KV_CACHE_FLAG_V = 2
};

// ggml object
struct ggml_object {
size_t offs;
@@ -593,8 +586,6 @@ extern "C" {
// op params - allocated as int32_t for alignment
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];

enum ggml_kv_cache_flag kv_cache_flag;

int32_t flags;

struct ggml_tensor * grad;
@@ -610,7 +601,7 @@ extern "C" {

void * extra; // extra things e.g. for ggml-cuda.cu

char padding[1];
// char padding[4];
};

static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
3 changes: 1 addition & 2 deletions ggml/src/ggml.c
@@ -3638,7 +3638,6 @@ static struct ggml_tensor * ggml_new_tensor_impl(
/*.nb =*/ { 0, 0, 0, 0 },
/*.op =*/ GGML_OP_NONE,
/*.op_params =*/ { 0 },
/*.kv_cache_flag=*/ GGML_KV_CACHE_FLAG_NONE,
/*.flags =*/ 0,
/*.grad =*/ NULL,
/*.src =*/ { NULL },
@@ -3647,7 +3646,7 @@ static struct ggml_tensor * ggml_new_tensor_impl(
/*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data,
/*.name =*/ { 0 },
/*.extra =*/ NULL,
/*.padding =*/ { 0 },
///*.padding =*/ { 0 },
};

#ifdef __clang__
72 changes: 33 additions & 39 deletions src/llama.cpp
@@ -7794,9 +7792,7 @@ static void llm_build_kv_store(
cb(k_cache_view, "k_cache_view", il);

// note: storing RoPE-ed version of K in the KV cache
ggml_tensor * tmp = ggml_cpy(ctx, k_cur, k_cache_view);
tmp->kv_cache_flag = GGML_KV_CACHE_FLAG_K;
ggml_build_forward_expand(graph, tmp);
ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));

assert(v_cur->ne[0] == n_embd_v_gqa && v_cur->ne[1] == n_tokens);

@@ -7814,9 +7812,7 @@ static void llm_build_kv_store(
v_cur = ggml_transpose(ctx, v_cur);
}
cb(v_cache_view, "v_cache_view", il);
tmp=ggml_cpy(ctx, v_cur, v_cache_view);
tmp->kv_cache_flag = GGML_KV_CACHE_FLAG_V;
ggml_build_forward_expand(graph, tmp);
ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
}

static struct ggml_tensor * llm_build_norm(
@@ -14607,43 +14603,41 @@ static int llama_decode_internal(
}
lctx.cached_graph.gf = gf;

if(ggml_use_cached_graph(lctx.sched)) {

// Temporarily store KV cache parameters that will need updated in cached graph.
// Update K and V cache parameters in cached graph.
if(gf != nullptr && gf->nodes != nullptr && ggml_use_cached_graph(lctx.sched)) {

const struct llama_hparams & hparams = model.hparams;
const int64_t n_layer = hparams.n_layer;
const int64_t kv_head = kv_self.head;
std::vector<void *> k_cache_ptrs;
std::vector<void *> v_cache_ptrs;
for (int il = 0; il < n_layer; ++il) {
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();
ggml_tensor * tmp_tensor = kv_self.k_l[il];
size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
k_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
tmp_tensor = kv_self.v_l[il];
if (cparams.flash_attn) {
tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
} else {
tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
}
v_cache_ptrs.push_back(static_cast<char*>(tmp_tensor->data) + tmp_offset);
}

// Update KV cache parameters in cached graph.
int k_count = 0;
int v_count = 0;
if(gf != nullptr && gf->nodes != nullptr){
for (int i = 0; i < gf->n_nodes; i++) {
ggml_tensor * node = gf->nodes[i];
if (node->op == GGML_OP_CPY) {
if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_K) {
node->src[1]->data = k_cache_ptrs[k_count++];
}
if (node->kv_cache_flag == GGML_KV_CACHE_FLAG_V) {
node->src[1]->data = v_cache_ptrs[v_count++];
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();

for (int i = 0; i < gf->n_nodes; i++) {
ggml_tensor * node = gf->nodes[i];
if (node->op == GGML_OP_CPY) {

// K cache
const char* k_prefix = "k_cache_view-";
if (strncmp(node->src[1]->name, k_prefix, strlen(k_prefix)) == 0) {
int il = atoi(node->src[1]->name + strlen(k_prefix)); // Layer index from name
ggml_tensor * tmp_tensor = kv_self.k_l[il];
size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
node->src[1]->data = static_cast<char*>(tmp_tensor->data) + tmp_offset;
}

// V cache
const char* v_prefix = "v_cache_view-";
if (strncmp(node->src[1]->name, v_prefix, strlen(v_prefix)) == 0) {
int il = atoi(node->src[1]->name + strlen(v_prefix)); // Layer index from name
ggml_tensor * tmp_tensor = kv_self.v_l[il];
size_t tmp_offset;
if (cparams.flash_attn) {
tmp_offset = (kv_head)*ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
} else {
tmp_offset = (kv_head)*ggml_element_size(kv_self.v_l[il]);
}
node->src[1]->data = static_cast<char*>(tmp_tensor->data) + tmp_offset;
}

}
}

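Once a copy node is matched, only its destination data pointer (node->src[1]->data) has to be refreshed to point at the current kv_head position. The K offset advances by whole rows of n_embd_k_gqa elements; the V offset does the same when flash attention is enabled (V is then laid out like K), but without flash attention the V cache is stored transposed, so the view advances by kv_head elements within each row. A rough sketch of the two offset computations, using the real ggml_row_size/ggml_element_size helpers but a hypothetical wrapper function:

#include <stdbool.h>
#include <stdint.h>
#include <stddef.h>
#include "ggml.h"

// Hypothetical helper mirroring the offsets computed in the diff above.
// cache is kv_self.k_l[il] or kv_self.v_l[il]; n_embd_gqa is the matching
// n_embd_k_gqa or n_embd_v_gqa.
static size_t kv_view_offset(const struct ggml_tensor * cache,
                             int64_t n_embd_gqa,
                             int64_t kv_head,
                             bool    is_v,
                             bool    flash_attn) {
    if (!is_v || flash_attn) {
        // K cache, or V cache with flash attention: skip kv_head full rows.
        return ggml_row_size(cache->type, n_embd_gqa) * kv_head;
    }
    // V cache without flash attention is stored transposed:
    // skip kv_head elements within each row.
    return (size_t) kv_head * ggml_element_size(cache);
}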
