From b4e4b8a9351d918a56831c73cf9f25c1837b80d1 Mon Sep 17 00:00:00 2001 From: Douglas Hanley Date: Wed, 24 Apr 2024 08:10:07 -0500 Subject: [PATCH] llama : add llama_get_pooling_type function (#6862) * add llama_get_pooling_type function * fix argument name, move with ctx funcs --- common/common.h | 4 ++-- llama.cpp | 4 ++++ llama.h | 6 ++++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/common/common.h b/common/common.h index 157b54a3e9e08..87361e8e91500 100644 --- a/common/common.h +++ b/common/common.h @@ -86,8 +86,8 @@ struct gpt_params { ggml_numa_strategy numa = GGML_NUMA_STRATEGY_DISABLED; - llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; - llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings + enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; + enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings // // sampling parameters struct llama_sampling_params sparams; diff --git a/llama.cpp b/llama.cpp index 3a4a03d8f29fb..3a84b4916bd30 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15599,6 +15599,10 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { return LLAMA_ROPE_TYPE_NONE; } +enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) { + return ctx->cparams.pooling_type; +} + int32_t llama_n_vocab(const struct llama_model * model) { return model->hparams.n_vocab; } diff --git a/llama.h b/llama.h index 7bfd13740cf25..0eb2a1e9ab0a2 100644 --- a/llama.h +++ b/llama.h @@ -390,8 +390,10 @@ extern "C" { LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx); LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx); - LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model); - LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); + LLAMA_API enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx); + + LLAMA_API enum llama_vocab_type llama_vocab_type (const struct llama_model * model); + LLAMA_API enum llama_rope_type llama_rope_type (const struct llama_model * model); LLAMA_API int32_t llama_n_vocab (const struct llama_model * model); LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model);