Allow pooled embeddings on any model #7477

Merged · 6 commits · Jun 21, 2024
Changes from 1 commit
take out attention_type; add in llama_set_embeddings
iamlemec committed Jun 14, 2024
commit 8093253b41dcc475ba160c97a7435f41b746d04a
12 changes: 0 additions & 12 deletions common/common.cpp
@@ -546,17 +546,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
else { invalid_param = true; }
return true;
}
if (arg == "--attention") {
if (++i >= argc) {
invalid_param = true;
return true;
}
std::string value(argv[i]);
/**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; }
else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL; }
else { invalid_param = true; }
return true;
}
if (arg == "--defrag-thold" || arg == "-dt") {
if (++i >= argc) {
invalid_param = true;
@@ -2460,7 +2449,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.yarn_beta_slow = params.yarn_beta_slow;
cparams.yarn_orig_ctx = params.yarn_orig_ctx;
cparams.pooling_type = params.pooling_type;
cparams.attention_type = params.attention_type;
cparams.defrag_thold = params.defrag_thold;
cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data;
1 change: 0 additions & 1 deletion common/common.h
@@ -94,7 +94,6 @@ struct gpt_params {
enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs
enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type

// // sampling parameters
struct llama_sampling_params sparams;
22 changes: 10 additions & 12 deletions examples/gritlm/gritlm.cpp
@@ -44,6 +44,8 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve

// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_cache_clear(ctx);
llama_set_embeddings(ctx, true);
@ngxson (Collaborator) commented on Jun 20, 2024:
I have a small question here: in the case when both embeddings and causal_attn are enabled, will it still be correct?

@iamlemec (Collaborator, Author) replied:
In general, it's possible to run with embeddings=true and causal_attn=true, as long as the underlying model supports causal attention. For the GritLM case, I just checked here, and it will run but give incorrect results since it expects the embeddings to be run non-causally.

llama_set_causal_attn(ctx, false);

// run model
llama_decode(ctx, batch);
@@ -97,6 +99,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
llama_token eos_token = llama_token_eos(mdl);

llama_kv_cache_clear(ctx);
llama_set_embeddings(ctx, false);
llama_set_causal_attn(ctx, true);

llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);

std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
@@ -165,13 +170,7 @@ int main(int argc, char * argv[]) {
llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);

// create generation context
llama_context * ctx_gen = llama_new_context_with_model(mdl, cparams);

// create embedding context
cparams.embeddings = true;
cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
cparams.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL;
llama_context * ctx_emb = llama_new_context_with_model(mdl, cparams);
llama_context * ctx = llama_new_context_with_model(mdl, cparams);

// ### Embedding/Representation ###
// samples taken from: https://github.com/ContextualAI/gritlm#basic
@@ -189,8 +188,8 @@ int main(int argc, char * argv[]) {
};

// No need to add instruction for retrieval documents
const std::vector<std::vector<float>> d_rep = encode(ctx_emb, documents, gritlm_instruction(""));
const std::vector<std::vector<float>> q_rep = encode(ctx_emb, queries, gritlm_instruction(instruction));
const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));

const int n_embd = llama_n_embd(mdl);

@@ -209,11 +208,10 @@ int main(int argc, char * argv[]) {
// GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
{
const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
std::string response = generate(ctx_gen, prompt, true);
std::string response = generate(ctx, prompt, true);
}

llama_free(ctx_gen);
llama_free(ctx_emb);
llama_free(ctx);
llama_free_model(mdl);
llama_backend_free();
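
The net effect of the gritlm.cpp changes above is that one context is reused for both embedding and generation, with `llama_set_embeddings` / `llama_set_causal_attn` toggled before each pass — which is also why the causal-attention question above does not bite here: encoding explicitly switches to non-causal. Below is a minimal, hedged sketch of that pattern (not part of the PR); the model path is a placeholder and tokenization, batching, and sampling are elided.

```cpp
#include "llama.h"

// Minimal sketch of the single-context pattern used by the updated gritlm example.
// "model.gguf" is a placeholder path; batch construction and sampling are elided.
int main() {
    llama_backend_init();

    llama_model_params   mparams = llama_model_default_params();
    llama_model        * mdl     = llama_load_model_from_file("model.gguf", mparams);

    llama_context_params cparams = llama_context_default_params();
    llama_context      * ctx     = llama_new_context_with_model(mdl, cparams);

    // --- embedding pass: embeddings on, non-causal attention ---
    llama_kv_cache_clear(ctx);
    llama_set_embeddings(ctx, true);
    llama_set_causal_attn(ctx, false);
    // ... build a batch, llama_decode(ctx, batch), read llama_get_embeddings_ith/_seq ...

    // --- generation pass: embeddings off, causal attention back on ---
    llama_kv_cache_clear(ctx);
    llama_set_embeddings(ctx, false);
    llama_set_causal_attn(ctx, true);
    // ... build a batch, llama_decode(ctx, batch), sample from llama_get_logits_ith ...

    llama_free(ctx);
    llama_free_model(mdl);
    llama_backend_free();
    return 0;
}
```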

12 changes: 5 additions & 7 deletions llama.cpp
@@ -15931,7 +15931,6 @@ struct llama_context_params llama_context_default_params() {
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
/*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
/*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
/*.rope_freq_base =*/ 0.0f,
/*.rope_freq_scale =*/ 0.0f,
/*.yarn_ext_factor =*/ -1.0f,
@@ -16173,12 +16172,7 @@ struct llama_context * llama_new_context_with_model(
}

cparams.yarn_attn_factor *= hparams.rope_attn_factor;

if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
cparams.causal_attn = hparams.causal_attn;
} else {
cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
}
cparams.causal_attn = hparams.causal_attn;

if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@@ -17914,6 +17908,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
ctx->abort_callback_data = abort_callback_data;
}

void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
ctx->cparams.embeddings = embeddings;
}

void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
ctx->cparams.causal_attn = causal_attn;
}
11 changes: 4 additions & 7 deletions llama.h
@@ -177,12 +177,6 @@ extern "C" {
LLAMA_POOLING_TYPE_LAST = 3,
};

enum llama_attention_type {
LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1,
LLAMA_ATTENTION_TYPE_CAUSAL = 0,
LLAMA_ATTENTION_TYPE_NONCAUSAL = 1,
};

enum llama_split_mode {
LLAMA_SPLIT_MODE_NONE = 0, // single GPU
LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
@@ -300,7 +294,6 @@ extern "C" {

enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
enum llama_attention_type attention_type; // causal, non-causal, or unspecified

// ref: https://github.com/ggerganov/llama.cpp/pull/2054
float rope_freq_base; // RoPE base frequency, 0 = from model
@@ -793,6 +786,10 @@ extern "C" {
// Get the number of threads used for prompt and batch processing (multiple token).
LLAMA_API uint32_t llama_n_threads_batch(struct llama_context * ctx);

// Set whether the model is in embeddings mode or not
// If true, embeddings will be returned but logits will not
LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);

// Set whether to use causal attention or not
// If set to true, the model will only attend to the past tokens
LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
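
As a usage note for the declarations above: once embeddings are enabled, the pooled per-sequence vector can be read back after `llama_decode`. A hedged sketch follows (not from the PR); it assumes the context was created with a pooling type other than `LLAMA_POOLING_TYPE_NONE` and that the caller has already built a batch assigned to sequence id 0.

```cpp
#include "llama.h"

#include <vector>

// Hedged sketch: read a pooled per-sequence embedding once embeddings mode is enabled.
// Assumes `ctx` uses a pooling type other than LLAMA_POOLING_TYPE_NONE and `batch`
// already holds a tokenized prompt assigned to sequence id 0.
static bool get_pooled_embedding(llama_context * ctx, const llama_batch & batch, std::vector<float> & out) {
    llama_set_embeddings(ctx, true);    // return embeddings instead of logits

    if (llama_decode(ctx, batch) != 0) {
        return false;                   // decode failed
    }

    const int     n_embd = llama_n_embd(llama_get_model(ctx));
    const float * emb    = llama_get_embeddings_seq(ctx, 0);   // pooled vector for seq 0

    if (emb == nullptr) {
        return false;                   // e.g. no pooled output available
    }

    out.assign(emb, emb + n_embd);
    return true;
}
```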