Commit 5e20320
examples : fix build
ggml-ci
ggerganov committed Sep 4, 2024
1 parent 842d391 commit 5e20320
Showing 5 changed files with 24 additions and 24 deletions.
common/sampling.cpp (10 changes: 5 additions & 5 deletions)
@@ -52,21 +52,21 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
     lparams.mirostat_eta = params.mirostat_eta;
 
     auto * result = new gpt_sampler {
-        .params = params,
-        .bias = llama_constraint_init_logit_bias(
+        /* .params = */ params,
+        /* .bias = */ llama_constraint_init_logit_bias(
             model,
             params.logit_bias.size(),
             params.logit_bias.data()),
-        .pnlt = llama_constraint_init_penalties(
+        /* .pnlt = */ llama_constraint_init_penalties(
             model,
             params.penalty_last_n,
             params.penalty_repeat,
             params.penalty_freq,
             params.penalty_present,
             params.penalize_nl,
             params.ignore_eos),
-        .grmr = llama_constraint_init_grammar(model, params.grammar.c_str(), "root"),
-        .smpl = llama_sampler_init(model, lparams)
+        /* .grmr = */ llama_constraint_init_grammar(model, params.grammar.c_str(), "root"),
+        /* .smpl = */ llama_sampler_init(model, lparams)
     };
 
     for (const auto & cnstr : params.constraints) {
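Note on the change above: designated initializers of the form .params = params are only standard C++ as of C++20 (earlier standards support them only as a compiler extension), which is the likely source of the build break this commit fixes. The replacement keeps plain positional aggregate initialization and preserves readability with field-name comments. A minimal sketch of the pattern, using a hypothetical struct rather than the repo's gpt_sampler:

// Positional aggregate initialization annotated with field-name comments.
// Unlike designated initializers (.top_k = ...), this compiles under
// pre-C++20 language standards. Hypothetical struct, for illustration only.
struct sampler_config {
    int   top_k;
    float top_p;
};

int main() {
    sampler_config cfg = {
        /* .top_k = */ 40,
        /* .top_p = */ 0.9f,
    };
    return cfg.top_k > 0 ? 0 : 1;
}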
examples/batched.swift/Sources/main.swift (20 changes: 10 additions & 10 deletions)
@@ -50,20 +50,24 @@ defer {
     llama_free(context)
 }
 
-var sparams = llama_sampling_params()
+var sparams = llama_sampler_params()
 sparams.top_k = 40
 sparams.top_p = 0.9
 sparams.temp = 0.4
 
-let smpl = llama_sampling_init(model, sparams)
+let smpl = llama_sampler_init(model, sparams)
 guard smpl != nil else {
     print("Failed to initialize sampling")
     exit(1)
 }
 defer {
-    llama_sampling_free(smpl)
+    llama_sampler_free(smpl)
 }
 
+llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(40, 1));
+llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(0.9, 1));
+llama_sampler_add_constraint(smpl, llama_constraint_init_temp (0.4));
+
 let n_ctx = llama_n_ctx(context)
 
 print("\nn_len = \(n_len), n_ctx = \(n_ctx), n_batch = \(context_params.n_batch), n_parallel = \(n_parallel), n_kv_req = \(n_kv_req)\n")

@@ -138,15 +142,11 @@ while n_cur <= n_len {
 
     var logits = llama_get_logits_ith(context, i_batch[i])
 
-    llama_sampling_set_logits(smpl, logits)
-
-    llama_sampling_top_k(smpl, nil)
-    llama_sampling_top_p(smpl, nil)
-    llama_sampling_temp (smpl, nil)
+    llama_sampler_set_logits(smpl, logits)
 
-    let new_token_id = llama_sampling_sample_dist(smpl, nil)
+    let new_token_id = llama_sampler_sample_dist(smpl, nil)
 
-    // const llama_token new_token_id = llama_sampling_sample_greedy(smpl, nil);
+    // const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nil, false);
 
     // is it an end of stream? -> mark the stream as finished
     if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
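The Swift change above shows the shape of the new API: the per-step calls llama_sampling_top_k/top_p/temp are gone, and the equivalent constraints are attached to the sampler once at setup. The same flow in C++, as a sketch using the function names exactly as they appear in this commit (the interface was in flux at this point, so treat the signatures as illustrative; model, ctx, and the batch index i are assumed to be already set up):

// build the sampler and register constraints once, in application order
llama_sampler * smpl = llama_sampler_init(model, llama_sampler_default_params());
llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(40, 1));
llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(0.9f, 1));
llama_sampler_add_constraint(smpl, llama_constraint_init_temp (0.4f));

// per decode step: feed the logits for the sampled position, then sample
const float * logits = llama_get_logits_ith(ctx, i);
llama_sampler_set_logits(smpl, logits);
const llama_token new_token_id = llama_sampler_sample_dist(smpl, nullptr);

llama_sampler_free(smpl);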
examples/batched/batched.cpp (4 changes: 2 additions & 2 deletions)
@@ -71,7 +71,7 @@ int main(int argc, char ** argv) {
     llama_sampler * smpl = llama_sampler_init(model, sparams);
 
     llama_sampler_add_constraint(smpl, llama_constraint_init_top_k(params.sparams.top_k, params.sparams.min_keep));
-    llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_p));
+    llama_sampler_add_constraint(smpl, llama_constraint_init_top_p(params.sparams.top_p, params.sparams.min_keep));
     llama_sampler_add_constraint(smpl, llama_constraint_init_temp (params.sparams.temp));
 
     if (ctx == NULL) {

@@ -179,7 +179,7 @@ int main(int argc, char ** argv) {
 
     const llama_token new_token_id = llama_sampler_sample_dist(smpl, nullptr);
 
-    //const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr);
+    //const llama_token new_token_id = llama_sampler_sample_greedy(smpl, nullptr, false);
 
     // is it an end of generation? -> mark the stream as finished
     if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
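The substantive fix in this file is the second argument to llama_constraint_init_top_p: it takes min_keep, the minimum number of candidates to retain, and the old code was passing min_p, a probability threshold with entirely different semantics. For illustration only (this is not the library's implementation), a self-contained sketch of what a top-p filter with a min_keep floor computes:

#include <algorithm>
#include <cstddef>
#include <functional>
#include <vector>

// Keep the smallest prefix of tokens (sorted by probability, descending)
// whose cumulative probability reaches p, but never fewer than min_keep.
std::size_t top_p_cutoff(std::vector<float> probs, float p, std::size_t min_keep) {
    std::sort(probs.begin(), probs.end(), std::greater<float>());
    float cum = 0.0f;
    std::size_t n = 0;
    while (n < probs.size()) {
        cum += probs[n++];
        if (cum >= p && n >= min_keep) {
            break;
        }
    }
    return n; // number of candidates to keep
}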
examples/llama.android/llama/src/main/cpp/llama-android.cpp (6 changes: 3 additions & 3 deletions)
@@ -386,7 +386,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
     jobject intvar_ncur
 ) {
     const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto sampling = reinterpret_cast<llama_sampling *>(sampling_pointer);
+    const auto sampling = reinterpret_cast<llama_sampler *>(sampling_pointer);
     const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
     const auto model = llama_get_model(context);

@@ -396,10 +396,10 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
 
     const auto * logits = llama_get_logits_ith(context, batch->n_tokens - 1);
 
-    llama_sampling_set_logits(sampling, logits);
+    llama_sampler_set_logits(sampling, logits);
 
     // sample the most likely token
-    const auto new_token_id = llama_sampling_sample_greedy(sampling, nullptr);
+    const auto new_token_id = llama_sampler_sample_greedy(sampling, nullptr, false);
 
     const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
     if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
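Two things change in this JNI loop: the stored pointer is now cast to the renamed llama_sampler type, and llama_sampler_sample_greedy gained a third boolean parameter, passed as false at every call site in this commit (this diff does not show what the flag controls). Greedy decoding itself is just an argmax over the logits; for illustration, independent of the llama.cpp API:

#include <cstddef>
#include <vector>

// Pick the index of the highest logit: the whole of greedy decoding.
// Assumes a non-empty logits vector.
std::size_t argmax_token(const std::vector<float> & logits) {
    std::size_t best = 0;
    for (std::size_t i = 1; i < logits.size(); ++i) {
        if (logits[i] > logits[best]) {
            best = i;
        }
    }
    return best;
}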
examples/llama.swiftui/llama.cpp.swift/LibLlama.swift (8 changes: 4 additions & 4 deletions)
@@ -43,11 +43,11 @@ actor LlamaContext {
     self.tokens_list = []
     self.batch = llama_batch_init(512, 0, 1)
     self.temporary_invalid_cchars = []
-    self.sampling = llama_sampling_init(context, llama_sampling_default_params())
+    self.sampling = llama_sampler_init(context, llama_sampler_default_params())
 }
 
 deinit {
-    llama_sampling_free(sampling)
+    llama_sampler_free(sampling)
     llama_batch_free(batch)
     llama_free(context)
     llama_free_model(model)

@@ -149,9 +149,9 @@ actor LlamaContext {
     let n_vocab = llama_n_vocab(model)
     let logits = llama_get_logits_ith(context, batch.n_tokens - 1)
 
-    llama_sampling_set_logits(sampling, logits);
+    llama_sampler_set_logits(sampling, logits);
 
-    new_token_id = llama_sampling_sample_greedy(sampling, nil)
+    new_token_id = llama_sampler_sample_greedy(sampling, nil, false)
 
     if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
         print("\n")
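As in the Android example, the Swift actor pairs llama_sampler_init in its constructor with llama_sampler_free in deinit, so the sampler's lifetime follows the owning object. The matching C++ idiom is a RAII wrapper; a sketch against the API as named in this commit, not something the repo itself is shown to provide:

#include <memory>

// Custom deleter so unique_ptr releases the sampler with the C API.
struct llama_sampler_deleter {
    void operator()(llama_sampler * s) const { llama_sampler_free(s); }
};
using sampler_ptr = std::unique_ptr<llama_sampler, llama_sampler_deleter>;

// usage (model assumed initialized):
//   sampler_ptr smpl(llama_sampler_init(model, llama_sampler_default_params()));
//   ... pass smpl.get() wherever a llama_sampler * is expected;
//   llama_sampler_free runs automatically when smpl goes out of scope.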
