llama : check all graph nodes when searching for result_embd_pooled (g…
…gerganov#8956)


Co-authored-by: Stanisław Szymczyk <[email protected]>
2 people authored and arthw committed Nov 15, 2024
1 parent a76ca96 commit f7aac82
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions src/llama.cpp
@@ -14726,12 +14726,15 @@ static int llama_decode_internal(
             res  = nullptr;
             embd = nullptr;
         } else if (cparams.embeddings) {
-            res  = nullptr; // do not extract logits for embedding case
-            embd = gf->nodes[gf->n_nodes - 1];
-            if (strcmp(embd->name, "result_embd_pooled") != 0) {
-                embd = gf->nodes[gf->n_nodes - 2];
+            res  = nullptr; // do not extract logits for embedding case
+            embd = nullptr;
+            for (int i = gf->n_nodes - 1; i >= 0; --i) {
+                if (strcmp(gf->nodes[i]->name, "result_embd_pooled") == 0) {
+                    embd = gf->nodes[i];
+                    break;
+                }
             }
-            GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor");
+            GGML_ASSERT(embd != nullptr && "missing embeddings tensor");
         } else {
             embd = nullptr; // do not extract embeddings when not needed
             GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor");
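The change above replaces a check of only the last two graph nodes with a backward scan over every node. The following is a minimal, standalone sketch of that lookup, not the code from src/llama.cpp: `mock_tensor`, `mock_graph`, and `find_node_by_name` are hypothetical stand-ins that keep only the fields the search touches (`name`, `nodes`, `n_nodes`).

```cpp
#include <cstring>

// Hypothetical stand-in for a graph tensor: only the name is needed here.
struct mock_tensor {
    char name[64];
};

// Hypothetical stand-in for a compute graph: an array of node pointers.
struct mock_graph {
    int            n_nodes;
    mock_tensor ** nodes;
};

// Scan from the last node toward the first and return the first node whose
// name matches exactly, or nullptr when no node carries that name.
static mock_tensor * find_node_by_name(const mock_graph & gf, const char * name) {
    for (int i = gf.n_nodes - 1; i >= 0; --i) {
        if (std::strcmp(gf.nodes[i]->name, name) == 0) {
            return gf.nodes[i];
        }
    }
    return nullptr;
}
```

With this shape, a node named "result_embd_pooled" is found even when it sits three or more positions from the end of the graph, which the previous two-position check could not handle. A missing node yields nullptr, so the caller can assert on the pointer instead of dereferencing it first.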
