
Updated Swift and Android bindings to use the new llama_sampling_* refactor from #8643
HanClinto committed Jul 23, 2024
1 parent dbf8544 commit b6c9b53
Showing 4 changed files with 15 additions and 9 deletions.
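In short, the #8643 refactor moves sampling off of llama_context and onto a dedicated llama_sampling handle obtained via llama_get_sampling(context). A minimal Swift sketch of the before/after call shape, using only function names that appear in the diffs below (the wrapper function and its parameters are illustrative, not part of the commit):

```swift
import llama

// Old API: sampling functions took the llama_context directly, e.g.
//   llama_sample_top_k(context, &candidates_p, top_k, 1)
//   let id = llama_sample_token(context, &candidates_p)

// New API: fetch a llama_sampling handle from the context, then pass it
// to the llama_sampling_* functions.
func sampleNextToken(context: OpaquePointer,
                     candidates_p: inout llama_token_data_array,
                     topK: Int32, topP: Float, temp: Float) -> llama_token {
    let smpl = llama_get_sampling(context)
    llama_sampling_top_k(smpl, &candidates_p, topK, 1) // keep the topK most likely tokens
    llama_sampling_top_p(smpl, &candidates_p, topP, 1) // nucleus filtering
    llama_sampling_temp(smpl, &candidates_p, temp)     // apply temperature
    return llama_sampling_sample(smpl, &candidates_p)  // draw the next token
}
```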
14 changes: 8 additions & 6 deletions examples/batched.swift/Sources/main.swift
@@ -44,6 +44,8 @@ context_params.n_threads = 8
context_params.n_threads_batch = 8

let context = llama_new_context_with_model(model, context_params)
+let smpl = llama_get_sampling(context)
+
guard context != nil else {
print("Failed to initialize context")
exit(1)
@@ -144,13 +146,13 @@ while n_cur <= n_len {
let top_p: Float = 0.9
let temp: Float = 0.4

-llama_sample_top_k(context, &candidates_p, top_k, 1)
-llama_sample_top_p(context, &candidates_p, top_p, 1)
-llama_sample_temp(context, &candidates_p, temp)
+llama_sampling_top_k(smpl, &candidates_p, top_k, 1)
+llama_sampling_top_p(smpl, &candidates_p, top_p, 1)
+llama_sampling_temp(smpl, &candidates_p, temp)

-let new_token_id = llama_sample_token(context, &candidates_p)
+let new_token_id = llama_sampling_sample(smpl, &candidates_p)

-// const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
+// const llama_token new_token_id = llama_sampling_sample_greedy(smpl, &candidates_p);

// is it an end of stream? -> mark the stream as finished
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
@@ -212,7 +214,7 @@ let t_main_end = ggml_time_us()

print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n")

-llama_print_timings(context)
+llama_print_timings(context, smpl, nil)

private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
let utf8Count = text.utf8.count
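Besides the sampling chain itself, the only other call-site change in this file is llama_print_timings, which now takes the sampling handle as well. A sketch of the new shape (the nil third argument is copied verbatim from the diff; what it may point to is not shown in this commit):

```swift
import llama

// A sketch: report timings after generation. context and smpl are assumed
// to come from llama_new_context_with_model and llama_get_sampling above.
func reportTimings(context: OpaquePointer, smpl: OpaquePointer) {
    // old: llama_print_timings(context)
    // new: the sampling handle is passed explicitly; the third argument is
    // nil here, matching the diff (its role is not documented in this hunk)
    llama_print_timings(context, smpl, nil)
}
```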
3 changes: 2 additions & 1 deletion examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -385,6 +385,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
jobject intvar_ncur
) {
const auto context = reinterpret_cast<llama_context *>(context_pointer);
+const auto sampling = reinterpret_cast<llama_sampling *>(llama_get_sampling(context));
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
const auto model = llama_get_model(context);

@@ -405,7 +406,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

// sample the most likely token
-const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
+const auto new_token_id = llama_sampling_sample_greedy(sampling, &candidates_p);

const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
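Note the design difference between the two bindings: the Android JNI wrapper re-derives the sampling handle from the context on every completion_loop call, while the swiftui wrapper below caches it for the actor's lifetime. The per-call pattern, rendered as a Swift sketch (a hypothetical helper; the C++ above does the same lookup via llama_get_sampling plus a reinterpret_cast):

```swift
import llama

// Per-call lookup, as in llama-android.cpp: fetch the sampling handle from
// the context on each step instead of storing it alongside the context.
func greedyStep(context: OpaquePointer,
                candidates_p: inout llama_token_data_array) -> llama_token {
    let sampling = llama_get_sampling(context)
    return llama_sampling_sample_greedy(sampling, &candidates_p)
}
```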
5 changes: 4 additions & 1 deletion examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
actor LlamaContext {
private var model: OpaquePointer
private var context: OpaquePointer
+private var sampling: OpaquePointer
private var batch: llama_batch
private var tokens_list: [llama_token]
var is_done: Bool = false
@@ -42,12 +43,14 @@ actor LlamaContext {
self.tokens_list = []
self.batch = llama_batch_init(512, 0, 1)
self.temporary_invalid_cchars = []
+self.sampling = llama_get_sampling(context)
}

deinit {
llama_batch_free(batch)
llama_free(context)
llama_free_model(model)
+llama_sampling_free(sampling)
llama_backend_free()
}

@@ -156,7 +159,7 @@ actor LlamaContext {
candidates.withUnsafeMutableBufferPointer() { buffer in
var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)

-new_token_id = llama_sample_token_greedy(context, &candidates_p)
+new_token_id = llama_sampling_sample_greedy(sampling, &candidates_p)
}

if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
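The swiftui wrapper, by contrast, stores the handle as a member and releases it in deinit. A condensed sketch of the ownership pattern this diff introduces (the teardown order matches the diff; whether llama_get_sampling transfers ownership to the caller is an assumption read off the added llama_sampling_free call):

```swift
import llama

// Condensed from LibLlama.swift: cache the sampling handle next to the
// context and free it on teardown, mirroring the diff above.
final class SamplingOwner {
    private let context: OpaquePointer
    private let sampling: OpaquePointer

    init(context: OpaquePointer) {
        self.context = context
        self.sampling = llama_get_sampling(context) // fetched once, reused per step
    }

    func greedyToken(_ candidates_p: inout llama_token_data_array) -> llama_token {
        return llama_sampling_sample_greedy(sampling, &candidates_p)
    }

    deinit {
        // same order as the diff: the context is freed first, then the
        // sampling handle
        llama_free(context)
        llama_sampling_free(sampling)
    }
}
```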
2 changes: 1 addition & 1 deletion include/llama.h
@@ -1144,7 +1144,7 @@ extern "C" {
float * mu);

/// @details Selects the token with the highest probability.
-/// Does not compute the token probabilities. Use llama_sample_softmax() instead.
+/// Does not compute the token probabilities. Use llama_sampling_softmax() instead.
LLAMA_API llama_token llama_sampling_sample_greedy(
struct llama_sampling * smpl,
llama_token_data_array * candidates);
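The header change is documentation-only: the comment now points at llama_sampling_softmax to match the renamed API. Per that comment, greedy selection skips computing token probabilities, so a caller that wants the probabilities filled in would run the softmax first. A hedged sketch (the llama_sampling_softmax signature is assumed to mirror the other llama_sampling_* functions, as the old llama_sample_softmax mirrored llama_sample_*):

```swift
import llama

// Greedy pick with probabilities: llama_sampling_sample_greedy alone does
// not normalize the candidates, so apply the softmax first if the p fields
// of the candidate array are needed (signature assumed, see lead-in).
func greedyWithProbabilities(smpl: OpaquePointer,
                             candidates_p: inout llama_token_data_array) -> llama_token {
    llama_sampling_softmax(smpl, &candidates_p) // fill normalized probabilities
    return llama_sampling_sample_greedy(smpl, &candidates_p)
}
```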
