From 14a8a06aee6016cbf0d9a361673b391b4f30da72 Mon Sep 17 00:00:00 2001
From: Daniel Bevenius
Date: Wed, 23 Oct 2024 12:21:30 +0200
Subject: [PATCH] squash! llama : rename batch.logits to batch.output

Update examples/batched.swift/Sources/main.swift,
examples/llama.android/llama/src/main/cpp/llama-android.cpp, and
examples/llama.swiftui/llama.cpp.swift/LibLlama.swift to use the new
batch.output field instead of batch.logits.
---
 examples/batched.swift/Sources/main.swift                   | 6 +++---
 examples/llama.android/llama/src/main/cpp/llama-android.cpp | 6 +++---
 examples/llama.swiftui/llama.cpp.swift/LibLlama.swift       | 4 ++--
 3 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift
index 10f2e7fd117a15..ef0571ea570ea6 100644
--- a/examples/batched.swift/Sources/main.swift
+++ b/examples/batched.swift/Sources/main.swift
@@ -99,11 +99,11 @@ for (i, token) in tokens.enumerated() {
     if let seq_id = batch.seq_id[i] {
         seq_id[0] = 0
     }
-    batch.logits[i] = 0
+    batch.output[i] = 0
 }
 
 // llama_decode will output logits only for the last token of the prompt
-batch.logits[Int(batch.n_tokens) - 1] = 1
+batch.output[Int(batch.n_tokens) - 1] = 1
 
 if llama_decode(context, batch) != 0 {
     print("llama_decode() failed")
@@ -166,7 +166,7 @@ while n_cur <= n_len {
         if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
             seq_id[0] = Int32(i)
         }
-        batch.logits[Int(batch.n_tokens)] = 1
+        batch.output[Int(batch.n_tokens)] = 1
 
         i_batch[i] = batch.n_tokens
 
diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
index b3858ddfb61683..f614e33fe1d565 100644
--- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp
+++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -193,7 +193,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
             common_batch_add(*batch, 0, i, { 0 }, false);
         }
 
-        batch->logits[batch->n_tokens - 1] = true;
+        batch->output[batch->n_tokens - 1] = true;
         llama_kv_cache_clear(context);
 
         const auto t_pp_start = ggml_time_us();
@@ -297,7 +297,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens,
     for (int i = 0; i < n_tokens; ++i) {
         batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max);
     }
-    batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
+    batch->output = (int8_t *) malloc(sizeof(int8_t) * n_tokens);
 
     return reinterpret_cast<jlong>(batch);
 }
@@ -377,7 +377,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
     }
 
     // llama_decode will output logits only for the last token of the prompt
-    batch->logits[batch->n_tokens - 1] = true;
+    batch->output[batch->n_tokens - 1] = true;
 
     if (llama_decode(context, *batch) != 0) {
         LOGe("llama_decode() failed");
diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
index 65cd4eb515c7f4..68c82b1f2aa08f 100644
--- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
+++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -137,7 +137,7 @@ actor LlamaContext {
             let i = Int(i1)
             llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
         }
-        batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+        batch.output[Int(batch.n_tokens) - 1] = 1 // true
 
         if llama_decode(context, batch) != 0 {
             print("llama_decode() failed")
@@ -206,7 +206,7 @@ actor LlamaContext {
             for i in 0..<n_tokens {
                 llama_batch_add(&batch, 0, Int32(i), [0], false)
             }
-            batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+            batch.output[Int(batch.n_tokens) - 1] = 1 // true
 
             llama_kv_cache_clear(context)
 
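
Note (not part of the patch): a minimal sketch of the renamed field in
use, assuming this branch's llama.h where llama_batch.logits has become
llama_batch.output. The helper name eval_prompt and the single-sequence
setup are illustrative only, not code from the tree.

    #include "llama.h"

    // Evaluate a prompt and request logits only for its last token,
    // mirroring the pattern used by the examples touched above.
    static bool eval_prompt(llama_context * ctx, const llama_token * tokens, int32_t n_tokens) {
        llama_batch batch = llama_batch_init(n_tokens, 0, 1);

        for (int32_t i = 0; i < n_tokens; ++i) {
            batch.token[i]     = tokens[i];
            batch.pos[i]       = i;
            batch.n_seq_id[i]  = 1;
            batch.seq_id[i][0] = 0;
            batch.output[i]    = 0; // previously: batch.logits[i] = 0
        }
        batch.n_tokens = n_tokens;

        // llama_decode will output logits only for the last token of the prompt
        batch.output[n_tokens - 1] = 1; // previously: batch.logits[n_tokens - 1] = 1

        const bool ok = llama_decode(ctx, batch) == 0;
        llama_batch_free(batch);
        return ok;
    }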