Skip to content

Commit

Permalink
llama : add llama_get_pooling_type function (ggerganov#6862)
Browse files Browse the repository at this point in the history
* add llama_get_pooling_type function

* fix argument name, move with ctx funcs
  • Loading branch information
iamlemec authored Apr 24, 2024
1 parent 3fe847b commit b4e4b8a
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
4 changes: 2 additions & 2 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,8 @@ struct gpt_params {

ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED;

llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings

// // sampling parameters
struct llama_sampling_params sparams;
Expand Down
4 changes: 4 additions & 0 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15599,6 +15599,10 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
return LLAMA_ROPE_TYPE_NONE;
}

enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) {
return ctx->cparams.pooling_type;
}

int32_t llama_n_vocab(const struct llama_model * model) {
return model->hparams.n_vocab;
}
Expand Down
6 changes: 4 additions & 2 deletions llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,10 @@ extern "C" {
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);

LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);
LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx);

LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model);
LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model);

LLAMA_API int32_t llama_n_vocab (const struct llama_model * model);
LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);
Expand Down

0 comments on commit b4e4b8a

Please sign in to comment.