Skip to content

Commit

Permalink
llama : sketching new sampling API
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Sep 3, 2024
1 parent f648ca2 commit dcf1359
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 2 deletions.
39 changes: 37 additions & 2 deletions include/llama.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ extern "C" {
// struct llama_vocab; // TODO: add in the future
struct llama_model;
struct llama_context;
struct llama_sampler;
struct llama_constraint;
struct llama_sampling;

typedef int32_t llama_pos;
Expand Down Expand Up @@ -410,6 +412,11 @@ extern "C" {
bool ignore_eos; // ignore the end-of-sequence token
} llama_sampling_params;

typedef struct llama_sampler_params {
bool dummy;
// TODO: add type of sampler: greedy, dist, mirostat, etc.
} llama_sampler_params;

// performance timing information
struct llama_timings {
double t_start_ms;
Expand Down Expand Up @@ -438,8 +445,10 @@ extern "C" {
struct llama_lora_adapter;

// Helpers for getting default parameters
// TODO: update API to start accepting pointers to params structs (https://github.com/ggerganov/llama.cpp/discussions/9172)
LLAMA_API struct llama_model_params llama_model_default_params(void);
LLAMA_API struct llama_context_params llama_context_default_params(void);
LLAMA_API struct llama_sampler_params llama_sampler_default_params(void);
LLAMA_API struct llama_sampling_params llama_sampling_default_params(void);
LLAMA_API struct llama_model_quantize_params llama_model_quantize_default_params(void);

Expand All @@ -463,7 +472,7 @@ extern "C" {

LLAMA_API struct llama_model * llama_load_model_from_file(
const char * path_model,
struct llama_model_params params);
struct llama_model_params params);

LLAMA_API void llama_free_model(struct llama_model * model);

Expand Down Expand Up @@ -1029,7 +1038,7 @@ extern "C" {
int32_t length);

//
// Sampling functions
// Sampling API
//

// TODO: llama_model should become llama_vocab
Expand Down Expand Up @@ -1154,6 +1163,32 @@ extern "C" {
/// returns LLAMA_TOKEN_NULL if there are no accepted tokens
LLAMA_API llama_token llama_sampling_last(const struct llama_sampling * smpl);

//
// Sampling v2 API
//

// samplers

LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_params params);
LLAMA_API void llama_sampler_free ( struct llama_sampler * smpl);
LLAMA_API struct llama_sampler * llama_sampler_cp (const struct llama_sampler * smpl);
LLAMA_API void llama_sampler_reset( struct llama_sampler * smpl);

LLAMA_API void llama_sampler_add_constraint(struct llama_sampler * smpl, struct llama_constraint * cnstr);

LLAMA_API void llama_sampler_accept(struct llama_sampler * smpl, llama_token token);
LLAMA_API llama_token llama_sampler_sample(struct llama_sampler * smpl, const struct llama_context * ctx, int32_t i);

// constraints

LLAMA_API struct llama_constraint * llama_constraint_init_top_k(int32_t k, int32_t min_keep);
LLAMA_API struct llama_constraint * llama_constraint_init_top_p(float p, int32_t min_keep);
// ...
LLAMA_API void llama_constraint_free(struct llama_constraint * cnstr);

LLAMA_API void llama_constraint_accept(struct llama_constraint * cnstr, llama_token token);
LLAMA_API void llama_constraint_apply (struct llama_constraint * cnstr, llama_token_data_array * candidates);

//
// Model split
//
Expand Down
8 changes: 8 additions & 0 deletions src/llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17906,6 +17906,14 @@ struct llama_context_params llama_context_default_params() {
return result;
}

struct llama_sampler_params llama_sampler_default_params() {
struct llama_sampler_params result = {
/*.dummy =*/ false,
};

return result;
}

struct llama_sampling_params llama_sampling_default_params() {
struct llama_sampling_params result = {
/*.seed =*/ LLAMA_DEFAULT_SEED,
Expand Down

0 comments on commit dcf1359

Please sign in to comment.