From bebf5d741b53ffa241ff54c3d6a0492da95ddbe0 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Tue, 6 Aug 2024 18:31:19 +0300
Subject: [PATCH] cont

---
 examples/batched.swift/Sources/main.swift        | 13 +++++++++----
 .../llama/src/main/cpp/llama-android.cpp         |  3 ++-
 .../llama.swiftui/llama.cpp.swift/LibLlama.swift |  4 ++--
 include/llama.h                                  |  1 +
 4 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift
index 2bc5fce7dfb6e..00c1dbecb480a 100644
--- a/examples/batched.swift/Sources/main.swift
+++ b/examples/batched.swift/Sources/main.swift
@@ -27,7 +27,6 @@ guard let model = llama_load_model_from_file(modelPath.cString(using: .utf8), mo
     print("Failed to load model")
     exit(1)
 }
-
 defer {
     llama_free_model(model)
 }
@@ -44,17 +43,23 @@ context_params.n_threads = 8
 context_params.n_threads_batch = 8
 
 let context = llama_new_context_with_model(model, context_params)
-let smpl = llama_get_sampling(context)
-
 guard context != nil else {
     print("Failed to initialize context")
     exit(1)
 }
-
 defer {
     llama_free(context)
 }
 
+let smpl = llama_sampling_init(model, nil, nil)
+guard smpl != nil else {
+    print("Failed to initialize sampling")
+    exit(1)
+}
+defer {
+    llama_sampling_free(smpl)
+}
+
 let n_ctx = llama_n_ctx(context)
 
 print("\nn_len = \(n_len), n_ctx = \(n_ctx), n_batch = \(context_params.n_batch), n_parallel = \(n_parallel), n_kv_req = \(n_kv_req)\n")
diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
index af3b356bc04b7..ed303e61ff714 100644
--- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp
+++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -380,12 +380,13 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
         JNIEnv * env,
         jobject,
         jlong context_pointer,
+        jlong sampling_pointer,
         jlong batch_pointer,
         jint n_len,
         jobject intvar_ncur
 ) {
     const auto context = reinterpret_cast<llama_context *>(context_pointer);
-    const auto sampling = reinterpret_cast<llama_sampling *>(llama_get_sampling(context));
+    const auto sampling = reinterpret_cast<llama_sampling *>(sampling_pointer);
     const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
     const auto model = llama_get_model(context);
 
diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
index bfd273072e627..2a7f476ce3939 100644
--- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
+++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -43,14 +43,14 @@ actor LlamaContext {
         self.tokens_list = []
         self.batch = llama_batch_init(512, 0, 1)
         self.temporary_invalid_cchars = []
-        self.sampling = llama_get_sampling(context)
+        self.sampling = llama_sampling_init(context, nil, nil);
     }
 
     deinit {
+        llama_sampling_free(sampling)
         llama_batch_free(batch)
         llama_free(context)
         llama_free_model(model)
-        llama_sampling_free(sampling)
         llama_backend_free()
     }
 
diff --git a/include/llama.h b/include/llama.h
index 9b1939297f48a..31210d828fa33 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -406,6 +406,7 @@ extern "C" {
 
     LLAMA_API void llama_free_model(struct llama_model * model);
 
+    // TODO: rename to llama_init_from_model
     LLAMA_API struct llama_context * llama_new_context_with_model(
                      struct llama_model * model,
             struct llama_context_params   params);
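
Note for readers: the net effect of this patch on API consumers is that the sampling state is no longer fetched from the llama_context via the removed llama_get_sampling accessor; the caller now creates it with llama_sampling_init and releases it with llama_sampling_free, and passes it explicitly where needed (see the new jlong sampling_pointer JNI parameter above). A minimal Swift sketch of the resulting call pattern, using only calls that appear in this diff and assuming a loaded model as in examples/batched.swift/Sources/main.swift; the two nil arguments are passed exactly as the diff does, and their semantics are not visible from this patch. Note also that the first argument differs between the two Swift call sites in the diff (model in batched.swift, context in LibLlama.swift); the sketch follows batched.swift.

    // Caller-owned sampling state, per this patch:
    // create it explicitly instead of reading it off the context...
    let smpl = llama_sampling_init(model, nil, nil)
    guard smpl != nil else {
        print("Failed to initialize sampling")
        exit(1)
    }
    // ...and free it explicitly when done (before freeing the context,
    // matching the reordered deinit in LibLlama.swift above).
    defer {
        llama_sampling_free(smpl)
    }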