Skip to content

Commit

Permalink
make n_embd_k_gqa / n_embd_v_gqa lookups dependent on layer index
Browse files Browse the repository at this point in the history
  • Loading branch information
agray3 committed Jul 27, 2024
1 parent 3241b3d commit d9c7b61
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14608,8 +14608,6 @@ static int llama_decode_internal(

const struct llama_hparams & hparams = model.hparams;
const int64_t kv_head = kv_self.head;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();

for (int i = 0; i < gf->n_nodes; i++) {
ggml_tensor * node = gf->nodes[i];
Expand All @@ -14619,6 +14617,7 @@ static int llama_decode_internal(
const char* k_prefix = "k_cache_view-";
if (strncmp(node->src[1]->name, k_prefix, strlen(k_prefix)) == 0) {
int il = atoi(node->src[1]->name + strlen(k_prefix)); // Layer index from name
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
ggml_tensor * tmp_tensor = kv_self.k_l[il];
size_t tmp_offset = (ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa))*kv_head;
node->src[1]->data = static_cast<char*>(tmp_tensor->data) + tmp_offset;
Expand All @@ -14628,6 +14627,7 @@ static int llama_decode_internal(
const char* v_prefix = "v_cache_view-";
if (strncmp(node->src[1]->name, v_prefix, strlen(v_prefix)) == 0) {
int il = atoi(node->src[1]->name + strlen(v_prefix)); // Layer index from name
const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
ggml_tensor * tmp_tensor = kv_self.v_l[il];
size_t tmp_offset;
if (cparams.flash_attn) {
Expand Down

0 comments on commit d9c7b61

Please sign in to comment.