llama : remove sampling from llama_context
ggml-ci
ggerganov committed Aug 10, 2024
1 parent 97ab664 commit 5a9753b
Showing 25 changed files with 75 additions and 137 deletions.
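In short: the sampler no longer lives inside llama_context. Call sites that used to borrow it with llama_get_sampling(ctx) now create and free their own llama_sampling object. A minimal sketch of the new pattern, pieced together from the call sites in this diff (the null grammar arguments and the empty loop body are illustrative placeholders, not part of the commit):

#include "llama.h"

static void run_example(llama_model * model, llama_context_params cparams) {
    llama_context  * ctx  = llama_new_context_with_model(model, cparams);

    // previously: llama_sampling * smpl = llama_get_sampling(ctx);
    llama_sampling * smpl = llama_sampling_init(model, /*grammar*/ nullptr, /*grammar_root*/ nullptr);

    // ... llama_decode() + per-token sampling loop would go here ...

    llama_sampling_free(smpl); // the caller owns the sampler now; freeing the context no longer frees it
    llama_free(ctx);
}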
10 changes: 5 additions & 5 deletions common/common.cpp
@@ -264,6 +264,10 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
params.kv_overrides.back().key[0] = 0;
}

if (params.sparams.seed == LLAMA_DEFAULT_SEED) {
params.sparams.seed = time(NULL);
}

return true;
}

@@ -294,8 +298,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa

if (arg == "-s" || arg == "--seed") {
CHECK_ARG
// TODO: this is temporary, in the future the sampling state will be moved fully to llama_sampling_context.
params.seed = std::stoul(argv[i]);
sparams.seed = std::stoul(argv[i]);
return true;
}
@@ -1414,7 +1416,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
options.push_back({ "*", " --verbose-prompt", "print a verbose prompt before generation (default: %s)", params.verbose_prompt ? "true" : "false" });
options.push_back({ "*", " --no-display-prompt", "don't print prompt at generation (default: %s)", !params.display_prompt ? "true" : "false" });
options.push_back({ "*", "-co, --color", "colorise output to distinguish prompt and user input from generations (default: %s)", params.use_color ? "true" : "false" });
options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", params.seed });
options.push_back({ "*", "-t, --threads N", "number of threads to use during generation (default: %d)", params.n_threads });
options.push_back({ "*", "-tb, --threads-batch N", "number of threads to use during batch and prompt processing (default: same as --threads)" });
options.push_back({ "speculative", "-td, --threads-draft N", "number of threads to use during generation (default: same as --threads)" });
@@ -1465,6 +1466,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
" --spm-infill", "use Suffix/Prefix/Middle pattern for infill (instead of Prefix/Suffix/Middle) as some models prefer this. (default: %s)", params.spm_infill ? "enabled" : "disabled" });

options.push_back({ "sampling" });
options.push_back({ "*", "-s, --seed SEED", "RNG seed (default: %d, use random seed for < 0)", sparams.seed });
options.push_back({ "*", " --samplers SAMPLERS", "samplers that will be used for generation in the order, separated by \';\'\n"
"(default: %s)", sampler_type_names.c_str() });
options.push_back({ "*", " --sampling-seq SEQUENCE",
@@ -2237,7 +2239,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
cparams.n_ubatch = params.n_ubatch;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.seed = params.seed;
cparams.logits_all = params.logits_all;
cparams.embeddings = params.embedding;
cparams.rope_scaling_type = params.rope_scaling_type;
@@ -3247,7 +3248,6 @@ void yaml_dump_non_result_info(FILE * stream, const gpt_params & params, const l

fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base);
fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale);
fprintf(stream, "seed: %u # default: -1 (random seed)\n", params.seed);
fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false");
fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false");
fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false");
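Together with the removal of gpt_params::seed in common.h just below, these hunks move seed ownership into the sampling parameters: -s/--seed writes only sparams.seed, the context params no longer carry a seed, and an unset seed falls back to the current time at parse time. A small sketch mirroring that behaviour (the helper name is made up for illustration; gpt_params and LLAMA_DEFAULT_SEED come from common.h/llama.h):

#include <ctime>

static void finalize_seed(gpt_params & params) {
    // gpt_params no longer has its own seed field - only params.sparams.seed remains
    if (params.sparams.seed == LLAMA_DEFAULT_SEED) {
        params.sparams.seed = time(NULL); // same fallback gpt_params_parse_ex() applies above
    }
}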
2 changes: 0 additions & 2 deletions common/common.h
@@ -68,8 +68,6 @@ enum dimre_method {
};

struct gpt_params {
uint32_t seed = LLAMA_DEFAULT_SEED; // RNG seed

int32_t n_threads = cpu_get_num_math();
int32_t n_threads_draft = -1;
int32_t n_threads_batch = -1; // number of threads to use for batch processing (-1 = use n_threads)
15 changes: 2 additions & 13 deletions common/sampling.cpp
@@ -3,19 +3,10 @@
#include <random>

struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, const struct llama_model * model) {
auto result = llama_sampling_init(params, llama_sampling_init(model, params.grammar.c_str(), "root"));

result->owned = true;

return result;
}

struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_sampling * smpl) {
struct llama_sampling_context * result = new llama_sampling_context();

result->params = params;
result->owned = false;
result->smpl = smpl;
result->smpl = llama_sampling_init(model, params.grammar.c_str(), "root");

result->prev.resize(params.n_prev);

@@ -27,9 +18,7 @@ struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_
}

void llama_sampling_free(struct llama_sampling_context * ctx) {
if (ctx->owned) {
llama_sampling_free(ctx->smpl);
}
llama_sampling_free(ctx->smpl);

delete ctx;
}
3 changes: 0 additions & 3 deletions common/sampling.h
@@ -71,8 +71,6 @@ struct llama_sampling_context {
// mirostat sampler state
float mirostat_mu;

bool owned;

llama_sampling * smpl;

// TODO: replace with ring-buffer
@@ -86,7 +84,6 @@

// Create a new sampling context instance.
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, const struct llama_model * model);
struct llama_sampling_context * llama_sampling_init(const struct llama_sampling_params & params, struct llama_sampling * smpl);

void llama_sampling_free(struct llama_sampling_context * ctx);

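With the owned flag and the second overload gone, there is exactly one way to build the common-level sampling context, and freeing it always releases the wrapped llama_sampling. A short round-trip sketch, assuming a params/model/ctx trio like the examples below use (the per-token calls are taken from the minicpmv-cli hunk further down):

#include "common.h"
#include "sampling.h"

static void sampling_roundtrip(const gpt_params & params, llama_model * model, llama_context * ctx) {
    // single remaining constructor: grammar and sampler settings come from params.sparams
    llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, model);

    const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);
    llama_sampling_accept(ctx_sampling, id, true); // note: no llama_context argument anymore

    llama_sampling_free(ctx_sampling); // now unconditionally frees the internal llama_sampling
}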
14 changes: 9 additions & 5 deletions examples/batched.swift/Sources/main.swift
@@ -27,7 +27,6 @@ guard let model = llama_load_model_from_file(modelPath.cString(using: .utf8), mo
print("Failed to load model")
exit(1)
}

defer {
llama_free_model(model)
}
@@ -37,24 +36,29 @@ var tokens = tokenize(text: prompt, add_bos: true)
let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel)

var context_params = llama_context_default_params()
context_params.seed = 1234
context_params.n_ctx = n_kv_req
context_params.n_batch = UInt32(max(n_len, n_parallel))
context_params.n_threads = 8
context_params.n_threads_batch = 8

let context = llama_new_context_with_model(model, context_params)
let smpl = llama_get_sampling(context)

guard context != nil else {
print("Failed to initialize context")
exit(1)
}

defer {
llama_free(context)
}

let smpl = llama_sampling_init(model, nil, nil)
guard smpl != nil else {
print("Failed to initialize sampling")
exit(1)
}
defer {
llama_sampling_free(smpl)
}

let n_ctx = llama_n_ctx(context)

print("\nn_len = \(n_len), n_ctx = \(n_ctx), n_batch = \(context_params.n_batch), n_parallel = \(n_parallel), n_kv_req = \(n_kv_req)\n")
2 changes: 1 addition & 1 deletion examples/batched/batched.cpp
@@ -64,7 +64,7 @@ int main(int argc, char ** argv) {
ctx_params.n_batch = std::max(n_predict, n_parallel);

llama_context * ctx = llama_new_context_with_model(model, ctx_params);
llama_sampling * smpl = llama_get_sampling(ctx);
llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr);

if (ctx == NULL) {
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
8 changes: 1 addition & 7 deletions examples/embedding/embedding.cpp
@@ -68,13 +68,7 @@ int main(int argc, char ** argv) {

print_build_info();

if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
}

fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);

std::mt19937 rng(params.seed);
LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);

llama_backend_init();
llama_numa_init(params.numa);
2 changes: 0 additions & 2 deletions examples/eval-callback/eval-callback.cpp
@@ -151,8 +151,6 @@ int main(int argc, char ** argv) {

print_build_info();

std::mt19937 rng(params.seed);

llama_backend_init();
llama_numa_init(params.numa);

10 changes: 6 additions & 4 deletions examples/gritlm/gritlm.cpp
@@ -92,11 +92,10 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
return result;
}

static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) {
static std::string generate(llama_context * ctx, llama_sampling * smpl, const std::string & prompt, bool stream) {
std::string result;

const llama_model * model = llama_get_model(ctx);
llama_sampling * smpl = llama_get_sampling(ctx);
llama_token eos_token = llama_token_eos(model);

llama_kv_cache_clear(ctx);
@@ -117,7 +116,7 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
inputs.clear();

llama_decode(ctx, bat);
auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
auto * logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);

auto candidates = std::vector<llama_token_data>(llama_n_vocab(model));
auto n_candidates = (int32_t)candidates.size();
@@ -173,6 +172,8 @@ int main(int argc, char * argv[]) {
// create generation context
llama_context * ctx = llama_new_context_with_model(model, cparams);

llama_sampling * smpl = llama_sampling_init(model, nullptr, nullptr);

// ### Embedding/Representation ###
// samples taken from: https://github.com/ContextualAI/gritlm#basic
{
@@ -209,9 +210,10 @@
// GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
{
const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
std::string response = generate(ctx, prompt, true);
std::string response = generate(ctx, smpl, prompt, true);
}

llama_sampling_free(smpl);
llama_free(ctx);
llama_free_model(model);
llama_backend_free();
13 changes: 3 additions & 10 deletions examples/infill/infill.cpp
@@ -156,16 +156,9 @@ int main(int argc, char ** argv) {
LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
}

LOG_TEE("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
print_build_info();

if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
}

LOG_TEE("%s: seed = %u\n", __func__, params.seed);

std::mt19937 rng(params.seed);
LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);

LOG("%s: llama backend init\n", __func__);
llama_backend_init();
@@ -351,7 +344,7 @@ int main(int argc, char ** argv) {

std::vector<llama_token> embd;

ctx_sampling = llama_sampling_init(sparams, llama_get_sampling(ctx));
ctx_sampling = llama_sampling_init(sparams, model);

while (n_remain != 0 || params.interactive) {
// predict
4 changes: 2 additions & 2 deletions examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -120,7 +120,6 @@ Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmo
LOGi("Using %d threads", n_threads);

llama_context_params ctx_params = llama_context_default_params();
ctx_params.seed = 1234;
ctx_params.n_ctx = 2048;
ctx_params.n_threads = n_threads;
ctx_params.n_threads_batch = n_threads;
@@ -380,12 +379,13 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
JNIEnv * env,
jobject,
jlong context_pointer,
jlong sampling_pointer,
jlong batch_pointer,
jint n_len,
jobject intvar_ncur
) {
const auto context = reinterpret_cast<llama_context *>(context_pointer);
const auto sampling = reinterpret_cast<llama_sampling *>(llama_get_sampling(context));
const auto sampling = reinterpret_cast<llama_sampling *>(sampling_pointer);
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
const auto model = llama_get_model(context);

5 changes: 2 additions & 3 deletions examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -43,14 +43,14 @@ actor LlamaContext {
self.tokens_list = []
self.batch = llama_batch_init(512, 0, 1)
self.temporary_invalid_cchars = []
self.sampling = llama_get_sampling(context)
self.sampling = llama_sampling_init(context, nil, nil);
}

deinit {
llama_sampling_free(sampling)
llama_batch_free(batch)
llama_free(context)
llama_free_model(model)
llama_sampling_free(sampling)
llama_backend_free()
}

@@ -72,7 +72,6 @@
print("Using \(n_threads) threads")

var ctx_params = llama_context_default_params()
ctx_params.seed = 1234
ctx_params.n_ctx = 2048
ctx_params.n_threads = UInt32(n_threads)
ctx_params.n_threads_batch = UInt32(n_threads)
2 changes: 1 addition & 1 deletion examples/llava/llava-cli.cpp
@@ -191,7 +191,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_

LOG_TEE("\n");

struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams, llama_get_sampling(ctx_llava->ctx_llama));
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams, ctx_llava->model);
if (!ctx_sampling) {
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
exit(1);
6 changes: 3 additions & 3 deletions examples/llava/minicpmv-cli.cpp
@@ -161,7 +161,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_llama,
int * n_past) {
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
llama_sampling_accept(ctx_sampling, id, true);
static std::string ret;
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
ret = "</s>";
@@ -218,7 +218,7 @@ static struct llama_sampling_context * llama_init(struct llava_context * ctx_lla

LOG_TEE("\n");

struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams, ctx_llava->model);
return ctx_sampling;
}

@@ -299,7 +299,7 @@ int main(int argc, char ** argv) {
}
}
printf("\n");
llama_print_timings(ctx_llava->ctx_llama);
llama_print_timings(ctx_llava->ctx_llama, nullptr);

ctx_llava->model = NULL;
llava_free(ctx_llava);
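One more API shift is visible in the minicpmv-cli hunks above: llama_print_timings now takes a second argument. The hunk passes nullptr; reading that parameter as an optional llama_sampling whose timings get included is an assumption, since the diff only shows the nullptr call:

// as in the hunk above: print the context timings, passing no sampler (assumed meaning of nullptr)
llama_print_timings(ctx_llava->ctx_llama, nullptr);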
3 changes: 1 addition & 2 deletions examples/lookahead/lookahead.cpp
@@ -1,7 +1,6 @@
#include "common.h"
#include "llama.h"

#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
@@ -118,7 +117,7 @@ int main(int argc, char ** argv) {
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);

// target model sampling context
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, llama_get_sampling(ctx));
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, model);

// verification n-grams
std::vector<ngram_data> ngrams_cur(G);
4 changes: 1 addition & 3 deletions examples/lookup/lookup.cpp
@@ -3,13 +3,11 @@
#include "common.h"
#include "ngram-cache.h"

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <string>
#include <vector>
#include <unordered_map>

int main(int argc, char ** argv){
gpt_params params;
@@ -106,7 +104,7 @@ int main(int argc, char ** argv){

bool has_eos = false;

struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, llama_get_sampling(ctx));
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams, model);

std::vector<llama_token> draft;

(diffs for the remaining changed files were not loaded in this capture)