-
Notifications
You must be signed in to change notification settings - Fork 10.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow pooled embeddings on any model #7477
Changes from all commits
0105714
1756c4b
7c37ae9
d4e6972
8093253
5cc7b45
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -44,6 +44,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve | |
|
||
// clear previous kv_cache values (irrelevant for embeddings) | ||
llama_kv_cache_clear(ctx); | ||
llama_set_embeddings(ctx, true); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have a small question here: in the case when both There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In general, it's possible to run with |
||
llama_set_causal_attn(ctx, false); | ||
|
||
// run model | ||
|
@@ -98,7 +99,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo | |
llama_token eos_token = llama_token_eos(mdl); | ||
|
||
llama_kv_cache_clear(ctx); | ||
llama_set_embeddings(ctx, false); | ||
llama_set_causal_attn(ctx, true); | ||
|
||
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); | ||
|
||
std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true); | ||
|
@@ -166,8 +169,7 @@ int main(int argc, char * argv[]) { | |
|
||
llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams); | ||
|
||
// create new context - set to embedding mode | ||
cparams.embeddings = true; | ||
// create generation context | ||
llama_context * ctx = llama_new_context_with_model(mdl, cparams); | ||
|
||
// ### Embedding/Representation ### | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -7435,6 +7435,50 @@ struct llm_build_context { | |||||
return lctx.inp_s_seq; | ||||||
} | ||||||
|
||||||
// Appends a pooling stage to an already-built graph `gf` and returns the same graph.
// The input is the last node in `gf` named "result_norm" or "result_embd"; the pooled
// output is labelled "result_embd_pooled" through cb() and added to the graph with
// ggml_build_forward_expand().
struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) { | ||||||
// scan the graph nodes from the end to find the result_norm/result_embd tensor used as pooling input | ||||||
struct ggml_tensor * inp = nullptr; | ||||||
for (int i = gf->n_nodes - 1; i >= 0; --i) { | ||||||
inp = gf->nodes[i]; | ||||||
if (strcmp(inp->name, "result_norm") == 0 || strcmp(inp->name, "result_embd") == 0) { | ||||||
break; | ||||||
} else { | ||||||
// reset so that `inp` stays null when no node matches
inp = nullptr; | ||||||
} | ||||||
} | ||||||
GGML_ASSERT(inp != nullptr && "missing result_norm/result_embd tensor"); | ||||||
|
||||||
struct ggml_tensor * cur; | ||||||
|
||||||
switch (pooling_type) { | ||||||
case LLAMA_POOLING_TYPE_MEAN: | ||||||
{ | ||||||
// multiply the transposed input by inp_mean
// NOTE(review): inp_mean presumably holds per-sequence averaging weights - confirm against llama_set_inputs
struct ggml_tensor * inp_mean = build_inp_mean(); | ||||||
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean); | ||||||
} break; | ||||||
case LLAMA_POOLING_TYPE_CLS: | ||||||
case LLAMA_POOLING_TYPE_LAST: | ||||||
{ | ||||||
// CLS and LAST share one graph op: pick rows of the input by the indices in inp_cls;
// the two modes differ only in which row indices are written into inp_cls
struct ggml_tensor * inp_cls = build_inp_cls(); | ||||||
cur = ggml_get_rows(ctx0, inp, inp_cls); | ||||||
} break; | ||||||
case LLAMA_POOLING_TYPE_NONE: | ||||||
{ | ||||||
// no pooling: the input tensor is passed through unchanged
cur = inp; | ||||||
} break; | ||||||
default: | ||||||
{ | ||||||
// `cur` is not assigned on this path; relies on the assertion stopping execution
GGML_ASSERT(false && "unknown pooling type"); | ||||||
} break; | ||||||
} | ||||||
|
||||||
// the output is named "result_embd_pooled" for every pooling type, including NONE
cb(cur, "result_embd_pooled", -1); | ||||||
|
||||||
ggml_build_forward_expand(gf, cur); | ||||||
|
||||||
return gf; | ||||||
} | ||||||
|
||||||
struct ggml_cgraph * build_llama() { | ||||||
struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); | ||||||
|
||||||
|
@@ -8415,8 +8459,6 @@ struct llm_build_context { | |||||
if (model.arch != LLM_ARCH_JINA_BERT_V2) { | ||||||
inp_pos = build_inp_pos(); | ||||||
} | ||||||
struct ggml_tensor * inp_mean = build_inp_mean(); | ||||||
struct ggml_tensor * inp_cls = build_inp_cls(); | ||||||
|
||||||
// construct input embeddings (token, type, position) | ||||||
inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); | ||||||
|
@@ -8591,28 +8633,6 @@ struct llm_build_context { | |||||
cur = inpL; | ||||||
cb(cur, "result_embd", -1); | ||||||
|
||||||
// pooling layer | ||||||
switch (pooling_type) { | ||||||
case LLAMA_POOLING_TYPE_NONE: | ||||||
{ | ||||||
// nop | ||||||
} break; | ||||||
case LLAMA_POOLING_TYPE_MEAN: | ||||||
{ | ||||||
cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean); | ||||||
cb(cur, "result_embd_pooled", -1); | ||||||
} break; | ||||||
case LLAMA_POOLING_TYPE_CLS: | ||||||
{ | ||||||
cur = ggml_get_rows(ctx0, cur, inp_cls); | ||||||
cb(cur, "result_embd_pooled", -1); | ||||||
} break; | ||||||
case LLAMA_POOLING_TYPE_UNSPECIFIED: | ||||||
{ | ||||||
GGML_ASSERT(false && "Invalid pooling type"); | ||||||
} break; | ||||||
} | ||||||
|
||||||
ggml_build_forward_expand(gf, cur); | ||||||
|
||||||
return gf; | ||||||
|
@@ -11697,6 +11717,11 @@ static struct ggml_cgraph * llama_build_graph( | |||||
GGML_ASSERT(false); | ||||||
} | ||||||
|
||||||
// add on pooling layer | ||||||
if (lctx.cparams.embeddings) { | ||||||
result = llm.append_pooling(result); | ||||||
} | ||||||
|
||||||
llm.free(); | ||||||
|
||||||
return result; | ||||||
|
@@ -11786,7 +11811,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { | |||||
// (!a || b) is a logical implication (a -> b) | ||||||
// !hparams.causal_attn -> !cparams.causal_attn | ||||||
(hparams.causal_attn || !cparams.causal_attn) && | ||||||
"causal attention with embedding models is not supported" | ||||||
"causal attention is not supported by this model" | ||||||
); | ||||||
|
||||||
if (lctx.inp_KQ_mask) { | ||||||
|
@@ -11918,6 +11943,37 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { | |||||
} | ||||||
} | ||||||
|
||||||
if (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { | ||||||
const int64_t n_tokens = batch.n_tokens; | ||||||
|
||||||
GGML_ASSERT(lctx.inp_cls); | ||||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); | ||||||
|
||||||
uint32_t * data = (uint32_t *) lctx.inp_cls->data; | ||||||
memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls)); | ||||||
|
||||||
std::vector<int> last_pos(n_tokens, -1); | ||||||
std::vector<int> last_row(n_tokens, -1); | ||||||
|
||||||
for (int i = 0; i < n_tokens; ++i) { | ||||||
const llama_seq_id seq_id = batch.seq_id[i][0]; | ||||||
const llama_pos pos = batch.pos[i]; | ||||||
|
||||||
GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST"); | ||||||
|
||||||
if (pos >= last_pos[seq_id]) { | ||||||
last_pos[seq_id] = pos; | ||||||
last_row[seq_id] = i; | ||||||
} | ||||||
} | ||||||
|
||||||
for (int i = 0; i < n_tokens; ++i) { | ||||||
if (last_row[i] >= 0) { | ||||||
data[i] = last_row[i]; | ||||||
} | ||||||
} | ||||||
} | ||||||
|
||||||
if (kv_self.recurrent) { | ||||||
const int64_t n_kv = kv_self.n; | ||||||
|
||||||
|
@@ -11979,8 +12035,8 @@ static size_t llama_output_reserve(llama_context & lctx, size_t n_outputs) { | |||||
const auto n_embd = hparams.n_embd; | ||||||
|
||||||
// TODO: use a per-batch flag for logits presence instead | ||||||
const bool has_logits = cparams.causal_attn; | ||||||
const bool has_embd = cparams.embeddings && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); | ||||||
const bool has_logits = !cparams.embeddings; | ||||||
const bool has_embd = cparams.embeddings && (cparams.pooling_type == LLAMA_POOLING_TYPE_NONE); | ||||||
|
||||||
const size_t logits_size = has_logits ? n_vocab*n_outputs_max : 0; | ||||||
const size_t embd_size = has_embd ? n_embd*n_outputs_max : 0; | ||||||
|
@@ -12110,11 +12166,13 @@ static int llama_decode_internal( | |||||
std::vector<std::vector<llama_seq_id>> seq_id; | ||||||
|
||||||
// count outputs | ||||||
if (batch_all.logits) { | ||||||
if (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE) { | ||||||
n_outputs = n_tokens_all; | ||||||
} else if (batch_all.logits) { | ||||||
for (uint32_t i = 0; i < n_tokens_all; ++i) { | ||||||
n_outputs += batch_all.logits[i] != 0; | ||||||
} | ||||||
} else if (lctx.logits_all || (cparams.embeddings && cparams.pooling_type != LLAMA_POOLING_TYPE_NONE)) { | ||||||
} else if (lctx.logits_all) { | ||||||
n_outputs = n_tokens_all; | ||||||
} else { | ||||||
// keep last output only | ||||||
|
@@ -12245,30 +12303,13 @@ static int llama_decode_internal( | |||||
// no output | ||||||
res = nullptr; | ||||||
embd = nullptr; | ||||||
} else if (!hparams.causal_attn) { | ||||||
res = nullptr; // do not extract logits for embedding models such as BERT | ||||||
|
||||||
// token or sequence embeddings | ||||||
embd = gf->nodes[gf->n_nodes - 1]; | ||||||
|
||||||
GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0); | ||||||
} else if (cparams.embeddings) { | ||||||
// the embeddings could be in the second to last tensor, or any of the previous tensors | ||||||
int i_embd = gf->n_nodes - 2; | ||||||
for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) { | ||||||
i_embd = gf->n_nodes - i; | ||||||
if (i_embd < 0) { break; } | ||||||
embd = gf->nodes[i_embd]; | ||||||
} | ||||||
GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor"); | ||||||
|
||||||
// TODO: use a per-batch flag to know when to skip logits while keeping embeddings | ||||||
if (!cparams.causal_attn) { | ||||||
res = nullptr; // do not extract logits when not needed | ||||||
// skip computing logits | ||||||
// TODO: is this safe? | ||||||
gf->n_nodes = i_embd + 1; | ||||||
res = nullptr; // do not extract logits for embedding case | ||||||
embd = gf->nodes[gf->n_nodes - 1]; | ||||||
if (strcmp(embd->name, "result_embd_pooled") != 0) { | ||||||
embd = gf->nodes[gf->n_nodes - 2]; | ||||||
} | ||||||
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor"); | ||||||
} else { | ||||||
embd = nullptr; // do not extract embeddings when not needed | ||||||
GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor"); | ||||||
Comment on lines
-12248
to
12315
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So an embeddings model will crash on the first decode when There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, though I can't think of any case where you'd use an embedding model without There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Well, There might be a need for a dedicated metadata key-value pair for embedding-only models if non-causal text generation models are a thing. (T5? Or is it causal?) Anyway, I think there should at least be some abstraction (exported in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I see, so at least for now, it looks like Then I guess we want to assert
Comment on lines
12303
to
12315
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There are places that need to know when embeddings or logits will be output, like Lines 11064 to 11065 in cd93a28
This will need to be updated to reflect exactly how this affects what happens later in this function near the comments There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So can we get away with saying you're either getting logits or embeddings but never both, and that behavior is exclusively controlled by const bool has_logits = !cparams.embeddings;
const bool has_embd = cparams.embeddings; There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I can't really think of a use-case where both would be needed at the same time. Except maybe for a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed, but for a given call to |
||||||
|
@@ -12337,11 +12378,10 @@ static int llama_decode_internal( | |||||
ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float)); | ||||||
} | ||||||
} break; | ||||||
case LLAMA_POOLING_TYPE_CLS: | ||||||
case LLAMA_POOLING_TYPE_MEAN: | ||||||
case LLAMA_POOLING_TYPE_CLS: | ||||||
case LLAMA_POOLING_TYPE_LAST: | ||||||
{ | ||||||
GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0); | ||||||
|
||||||
// extract sequence embeddings | ||||||
auto & embd_seq_out = lctx.embd_seq; | ||||||
embd_seq_out.clear(); | ||||||
|
@@ -17870,6 +17910,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback) | |||||
ctx->abort_callback_data = abort_callback_data; | ||||||
} | ||||||
|
||||||
// Sets the embeddings flag (cparams.embeddings) on an existing context.
// In the hunks shown here, that flag decides whether llama_build_graph appends the
// pooling layer and whether llama_output_reserve reserves logits (flag false) or
// embeddings output.
void llama_set_embeddings(struct llama_context * ctx, bool embeddings) { | |||||
ctx->cparams.embeddings = embeddings; | |||||
} | |||||
|
||||||
// Sets the causal-attention flag (cparams.causal_attn) on an existing context.
// NOTE(review): llama_set_inputs appears to assert (hparams.causal_attn || !cparams.causal_attn),
// so enabling this on a model whose hparams disable causal attention is expected to fail - confirm.
void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) { | |||||
ctx->cparams.causal_attn = causal_attn; | |||||
} | |||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why remove support for `LLAMA_POOLING_TYPE_NONE` in the `embedding` example?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mostly because we're not actually printing out the entire token-level embeddings anyway. The way it was implemented before was essentially doing last-token pooling (not necessarily the last position in the sequence though, just the last one in the order the batch was loaded), but now that last-token pooling is an official option, we may as well encourage the user to make that choice consciously.