From b4c1f2c1e8300421397ea24eef75f3fcc9c08452 Mon Sep 17 00:00:00 2001 From: Jhen Date: Tue, 19 Nov 2024 10:56:53 +0800 Subject: [PATCH 01/12] feat: support multi lora params --- android/src/main/CMakeLists.txt | 1 + .../main/java/com/rnllama/LlamaContext.java | 5 +- android/src/main/jni-utils.h | 94 +++++++++++++++++++ android/src/main/jni.cpp | 18 +++- example/ios/Podfile.lock | 4 +- example/src/App.tsx | 2 +- ios/RNLlamaContext.mm | 10 ++ src/NativeRNLlama.ts | 12 ++- 8 files changed, 140 insertions(+), 6 deletions(-) create mode 100644 android/src/main/jni-utils.h diff --git a/android/src/main/CMakeLists.txt b/android/src/main/CMakeLists.txt index ed77fa83..640dc4c8 100644 --- a/android/src/main/CMakeLists.txt +++ b/android/src/main/CMakeLists.txt @@ -33,6 +33,7 @@ set( ${RNLLAMA_LIB_DIR}/sgemm.cpp ${RNLLAMA_LIB_DIR}/ggml-aarch64.c ${RNLLAMA_LIB_DIR}/rn-llama.hpp + ${CMAKE_SOURCE_DIR}/jni-utils.h ${CMAKE_SOURCE_DIR}/jni.cpp ) diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java index b369bb1a..594f9f0a 100644 --- a/android/src/main/java/com/rnllama/LlamaContext.java +++ b/android/src/main/java/com/rnllama/LlamaContext.java @@ -70,6 +70,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa params.hasKey("lora") ? params.getString("lora") : "", // float lora_scaled, params.hasKey("lora_scaled") ? (float) params.getDouble("lora_scaled") : 1.0f, + // ReadableArray lora_adapters, + params.hasKey("lora_adapters") ? params.getArray("lora_adapters") : null, // float rope_freq_base, params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f, // float rope_freq_scale @@ -406,6 +408,7 @@ protected static native long initContext( boolean vocab_only, String lora, float lora_scaled, + ReadableArray lora_adapters, float rope_freq_base, float rope_freq_scale, int pooling_type, @@ -457,7 +460,7 @@ protected static native WritableMap doCompletion( double[][] logit_bias, float dry_multiplier, float dry_base, - int dry_allowed_length, + int dry_allowed_length, int dry_penalty_last_n, String[] dry_sequence_breakers, PartialCompletionCallback partial_completion_callback diff --git a/android/src/main/jni-utils.h b/android/src/main/jni-utils.h new file mode 100644 index 00000000..39bde3e6 --- /dev/null +++ b/android/src/main/jni-utils.h @@ -0,0 +1,94 @@ +#include + +// ReadableMap utils + +namespace readablearray { + +int size(JNIEnv *env, jobject readableArray) { + jclass arrayClass = env->GetObjectClass(readableArray); + jmethodID sizeMethod = env->GetMethodID(arrayClass, "size", "()I"); + return env->CallIntMethod(readableArray, sizeMethod); +} + +jobject getMap(JNIEnv *env, jobject readableArray, int index) { + jclass arrayClass = env->GetObjectClass(readableArray); + jmethodID getMapMethod = env->GetMethodID(arrayClass, "getMap", "(I)Lcom/facebook/react/bridge/ReadableMap;"); + return env->CallObjectMethod(readableArray, getMapMethod, index); +} + +// Other methods not used yet + +} + +namespace readablemap { + +bool hasKey(JNIEnv *env, jobject readableMap, const char *key) { + jclass mapClass = env->GetObjectClass(readableMap); + jmethodID hasKeyMethod = env->GetMethodID(mapClass, "hasKey", "(Ljava/lang/String;)Z"); + jstring jKey = env->NewStringUTF(key); + jboolean result = env->CallBooleanMethod(readableMap, hasKeyMethod, jKey); + env->DeleteLocalRef(jKey); + return result; +} + +int getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue) { + if (!hasKey(env, 
readableMap, key)) { + return defaultValue; + } + jclass mapClass = env->GetObjectClass(readableMap); + jmethodID getIntMethod = env->GetMethodID(mapClass, "getInt", "(Ljava/lang/String;)I"); + jstring jKey = env->NewStringUTF(key); + jint result = env->CallIntMethod(readableMap, getIntMethod, jKey); + env->DeleteLocalRef(jKey); + return result; +} + +bool getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean defaultValue) { + if (!hasKey(env, readableMap, key)) { + return defaultValue; + } + jclass mapClass = env->GetObjectClass(readableMap); + jmethodID getBoolMethod = env->GetMethodID(mapClass, "getBoolean", "(Ljava/lang/String;)Z"); + jstring jKey = env->NewStringUTF(key); + jboolean result = env->CallBooleanMethod(readableMap, getBoolMethod, jKey); + env->DeleteLocalRef(jKey); + return result; +} + +long getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultValue) { + if (!hasKey(env, readableMap, key)) { + return defaultValue; + } + jclass mapClass = env->GetObjectClass(readableMap); + jmethodID getLongMethod = env->GetMethodID(mapClass, "getLong", "(Ljava/lang/String;)J"); + jstring jKey = env->NewStringUTF(key); + jlong result = env->CallLongMethod(readableMap, getLongMethod, jKey); + env->DeleteLocalRef(jKey); + return result; +} + +float getFloat(JNIEnv *env, jobject readableMap, const char *key, jfloat defaultValue) { + if (!hasKey(env, readableMap, key)) { + return defaultValue; + } + jclass mapClass = env->GetObjectClass(readableMap); + jmethodID getFloatMethod = env->GetMethodID(mapClass, "getDouble", "(Ljava/lang/String;)D"); + jstring jKey = env->NewStringUTF(key); + jfloat result = env->CallDoubleMethod(readableMap, getFloatMethod, jKey); + env->DeleteLocalRef(jKey); + return result; +} + +jstring getString(JNIEnv *env, jobject readableMap, const char *key, jstring defaultValue) { + if (!hasKey(env, readableMap, key)) { + return defaultValue; + } + jclass mapClass = env->GetObjectClass(readableMap); + jmethodID getStringMethod = env->GetMethodID(mapClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;"); + jstring jKey = env->NewStringUTF(key); + jstring result = (jstring) env->CallObjectMethod(readableMap, getStringMethod, jKey); + env->DeleteLocalRef(jKey); + return result; +} + +} diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 1ec3d19f..a2ed884a 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -12,6 +12,7 @@ #include "llama-impl.h" #include "ggml.h" #include "rn-llama.hpp" +#include "jni-utils.h" #define UNUSED(x) (void)(x) #define TAG "RNLLAMA_ANDROID_JNI" @@ -235,6 +236,7 @@ Java_com_rnllama_LlamaContext_initContext( jboolean vocab_only, jstring lora_str, jfloat lora_scaled, + jobject lora_adapters, jfloat rope_freq_base, jfloat rope_freq_scale, jint pooling_type, @@ -289,6 +291,20 @@ Java_com_rnllama_LlamaContext_initContext( defaultParams.lora_adapters.push_back({lora_chars, lora_scaled}); } + // lora_adapters: ReadableArray + int lora_adapters_size = readablearray::size(env, lora_adapters); + for (int i = 0; i < lora_adapters_size; i++) { + jobject lora_adapter = readablearray::getMap(env, lora_adapters, i); + jstring path = readablemap::getString(env, lora_adapter, "path", nullptr); + if (path != nullptr) { + const char *path_chars = env->GetStringUTFChars(path, nullptr); + env->ReleaseStringUTFChars(path, path_chars); + float scaled = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f); + + defaultParams.lora_adapters.push_back({path_chars, scaled}); + } + } + 
defaultParams.rope_freq_base = rope_freq_base; defaultParams.rope_freq_scale = rope_freq_scale; @@ -537,7 +553,7 @@ Java_com_rnllama_LlamaContext_doCompletion( jobjectArray logit_bias, jfloat dry_multiplier, jfloat dry_base, - jint dry_allowed_length, + jint dry_allowed_length, jint dry_penalty_last_n, jobjectArray dry_sequence_breakers, jobject partial_completion_callback diff --git a/example/ios/Podfile.lock b/example/ios/Podfile.lock index fa10b287..6b776f8b 100644 --- a/example/ios/Podfile.lock +++ b/example/ios/Podfile.lock @@ -8,7 +8,7 @@ PODS: - hermes-engine/Pre-built (= 0.72.3) - hermes-engine/Pre-built (0.72.3) - libevent (2.1.12) - - llama-rn (0.4.0): + - llama-rn (0.4.1): - RCT-Folly - RCTRequired - RCTTypeSafety @@ -1261,7 +1261,7 @@ SPEC CHECKSUMS: glog: 04b94705f318337d7ead9e6d17c019bd9b1f6b1b hermes-engine: 10fbd3f62405c41ea07e71973ea61e1878d07322 libevent: 4049cae6c81cdb3654a443be001fb9bdceff7913 - llama-rn: d935a3e23a8c1bb15ca58578af852c16d608bcaa + llama-rn: 763672c81a2903020663ad432f2051357e1f20ba RCT-Folly: 424b8c9a7a0b9ab2886ffe9c3b041ef628fd4fb1 RCTRequired: a2faf4bad4e438ca37b2040cb8f7799baa065c18 RCTTypeSafety: cb09f3e4747b6d18331a15eb05271de7441ca0b3 diff --git a/example/src/App.tsx b/example/src/App.tsx index 2cb4c59c..4c9f87e8 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -108,7 +108,7 @@ export default function App() { n_gpu_layers: Platform.OS === 'ios' ? 99 : 0, // > 0: enable GPU // embedding: true, - lora: loraFile?.uri, + lora_list: loraFile ? [{ path: loraFile.uri, scaled: 1.0 }] : undefined, // Or lora: loraFile?.uri, }, (progress) => { setMessages((msgs) => { diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm index c2cb5932..34f41c67 100644 --- a/ios/RNLlamaContext.mm +++ b/ios/RNLlamaContext.mm @@ -116,6 +116,16 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig defaultParams.lora_adapters.push_back({[params[@"lora"] UTF8String], lora_scaled}); } + if (params[@"lora_list"] && [params[@"lora_list"] isKindOfClass:[NSArray class]]) { + NSArray *lora_list = params[@"lora_list"]; + for (NSDictionary *lora_adapter in lora_list) { + NSString *path = lora_adapter[@"path"]; + if (!path) continue; + float scale = [lora_adapter[@"scaled"] floatValue]; + defaultParams.lora_adapters.push_back({[path UTF8String], scale}); + } + } + if (params[@"rope_freq_base"]) defaultParams.rope_freq_base = [params[@"rope_freq_base"] floatValue]; if (params[@"rope_freq_scale"]) defaultParams.rope_freq_scale = [params[@"rope_freq_scale"] floatValue]; diff --git a/src/NativeRNLlama.ts b/src/NativeRNLlama.ts index 5278f394..215aa7e3 100644 --- a/src/NativeRNLlama.ts +++ b/src/NativeRNLlama.ts @@ -34,8 +34,18 @@ export type NativeContextParams = { use_mmap?: boolean vocab_only?: boolean - lora?: string // lora_adaptor + /** + * Single LoRA adapter path + */ + lora?: string + /** + * Single LoRA adapter scale + */ lora_scaled?: number + /** + * LoRA adapter list + */ + lora_list?: Array<{ path: string; scaled?: number }> rope_freq_base?: number rope_freq_scale?: number From 4eb05397f36436b165cea3d240e9d2a41bc87ef7 Mon Sep 17 00:00:00 2001 From: Jhen Date: Tue, 19 Nov 2024 11:47:23 +0800 Subject: [PATCH 02/12] feat: lora apply / remove / list for initialized context --- .../main/java/com/rnllama/LlamaContext.java | 19 ++++ .../src/main/java/com/rnllama/RNLlama.java | 90 +++++++++++++++++++ android/src/main/jni.cpp | 49 ++++++++++ .../java/com/rnllama/RNLlamaModule.java | 15 ++++ .../java/com/rnllama/RNLlamaModule.java | 15 ++++ 
cpp/common.cpp | 6 ++ cpp/common.h | 3 + cpp/rn-llama.hpp | 34 +++++++ example/ios/.xcode.env.local | 2 +- ios/RNLlama.mm | 28 ++++++ ios/RNLlamaContext.mm | 30 +++++++ scripts/common.cpp.patch | 27 ++++-- scripts/common.h.patch | 28 ++++-- src/NativeRNLlama.ts | 15 +++- 14 files changed, 343 insertions(+), 18 deletions(-) diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java index 594f9f0a..66fe2d84 100644 --- a/android/src/main/java/com/rnllama/LlamaContext.java +++ b/android/src/main/java/com/rnllama/LlamaContext.java @@ -303,6 +303,22 @@ public String bench(int pp, int tg, int pl, int nr) { return bench(this.context, pp, tg, pl, nr); } + public int applyLoraAdapters(ReadableArray loraAdapters, boolean removePrevious) { + int result = applyLoraAdapters(this.context, loraAdapters, removePrevious); + if (result != 0) { + throw new IllegalStateException("Failed to apply lora adapters"); + } + return result; + } + + public void removeLoraAdapters() { + removeLoraAdapters(this.context); + } + + public WritableArray getLoadedLoraAdapters() { + return getLoadedLoraAdapters(this.context); + } + public void release() { freeContext(context); } @@ -476,6 +492,9 @@ protected static native WritableMap embedding( int embd_normalize ); protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr); + protected static native int applyLoraAdapters(long contextPtr, ReadableArray loraAdapters, boolean removePrevious); + protected static native void removeLoraAdapters(long contextPtr); + protected static native WritableArray getLoadedLoraAdapters(long contextPtr); protected static native void freeContext(long contextPtr); protected static native void logToAndroid(); } diff --git a/android/src/main/java/com/rnllama/RNLlama.java b/android/src/main/java/com/rnllama/RNLlama.java index aa19731e..f6a8ea3a 100644 --- a/android/src/main/java/com/rnllama/RNLlama.java +++ b/android/src/main/java/com/rnllama/RNLlama.java @@ -413,6 +413,96 @@ protected void onPostExecute(String result) { tasks.put(task, "bench-" + contextId); } + public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final boolean removePrevious, final Promise promise) { + final int contextId = (int) id; + AsyncTask task = new AsyncTask() { + private Exception exception; + + @Override + protected Void doInBackground(Void... voids) { + try { + LlamaContext context = contexts.get(contextId); + if (context == null) { + throw new Exception("Context not found"); + } + context.applyLoraAdapters(loraAdapters, removePrevious); + } catch (Exception e) { + exception = e; + } + return null; + } + + @Override + protected void onPostExecute(Void result) { + if (exception != null) { + promise.reject(exception); + return; + } + } + }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR); + tasks.put(task, "applyLoraAdapters-" + contextId); + } + + public void removeLoraAdapters(double id, final Promise promise) { + final int contextId = (int) id; + AsyncTask task = new AsyncTask() { + private Exception exception; + + @Override + protected Void doInBackground(Void... 
voids) { + try { + LlamaContext context = contexts.get(contextId); + if (context == null) { + throw new Exception("Context not found"); + } + context.removeLoraAdapters(); + } catch (Exception e) { + exception = e; + } + return null; + } + + @Override + protected void onPostExecute(Void result) { + if (exception != null) { + promise.reject(exception); + return; + } + } + }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR); + tasks.put(task, "removeLoraAdapters-" + contextId); + } + + public void getLoadedLoraAdapters(double id, final Promise promise) { + final int contextId = (int) id; + AsyncTask task = new AsyncTask() { + private Exception exception; + + @Override + protected ReadableArray doInBackground(Void... voids) { + try { + LlamaContext context = contexts.get(contextId); + if (context == null) { + throw new Exception("Context not found"); + } + return context.getLoadedLoraAdapters(); + } catch (Exception e) { + exception = e; + } + return null; + } + + @Override + protected void onPostExecute(ReadableArray result) { + if (exception != null) { + promise.reject(exception); + return; + } + } + }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR); + tasks.put(task, "getLoadedLoraAdapters-" + contextId); + } + public void releaseContext(double id, Promise promise) { final int contextId = (int) id; AsyncTask task = new AsyncTask() { diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index a2ed884a..dc5a7fb8 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -892,6 +892,55 @@ Java_com_rnllama_LlamaContext_bench( return env->NewStringUTF(result.c_str()); } +JNIEXPORT jint JNICALL +Java_com_rnllama_LlamaContext_applyLoraAdapters( + JNIEnv *env, jobject thiz, jlong context_ptr, jobjectArray loraAdapters, jboolean removePrevious) { + UNUSED(thiz); + auto llama = context_map[(long) context_ptr]; + + // lora_adapters: ReadableArray + std::vector lora_adapters; + int lora_adapters_size = readablearray::size(env, loraAdapters); + for (int i = 0; i < lora_adapters_size; i++) { + jobject lora_adapter = readablearray::getMap(env, loraAdapters, i); + jstring path = readablemap::getString(env, lora_adapter, "path", nullptr); + if (path != nullptr) { + const char *path_chars = env->GetStringUTFChars(path, nullptr); + env->ReleaseStringUTFChars(path, path_chars); + float scaled = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f); + + lora_adapters.push_back({path_chars, scaled}); + } + } + return llama->applyLoraAdapters(lora_adapters, removePrevious); +} + +JNIEXPORT void JNICALL +Java_com_rnllama_LlamaContext_removeLoraAdapters( + JNIEnv *env, jobject thiz, jlong context_ptr) { + UNUSED(env); + UNUSED(thiz); + auto llama = context_map[(long) context_ptr]; + llama->removeLoraAdapters(); +} + +JNIEXPORT jobject JNICALL +Java_com_rnllama_LlamaContext_getLoadedLoraAdapters( + JNIEnv *env, jobject thiz, jlong context_ptr) { + UNUSED(env); + UNUSED(thiz); + auto llama = context_map[(long) context_ptr]; + auto loaded_lora_adapters = llama->getLoadedLoraAdapters(); + auto result = createWritableArray(env); + for (common_lora_adapter_container &la : loaded_lora_adapters) { + auto map = createWriteableMap(env); + pushString(env, map, la.path.c_str()); + pushDouble(env, map, la.scale); + pushMap(env, result, map); + } + return result; +} + JNIEXPORT void JNICALL Java_com_rnllama_LlamaContext_freeContext( JNIEnv *env, jobject thiz, jlong context_ptr) { diff --git a/android/src/newarch/java/com/rnllama/RNLlamaModule.java 
b/android/src/newarch/java/com/rnllama/RNLlamaModule.java index 19077c87..e0f2d809 100644 --- a/android/src/newarch/java/com/rnllama/RNLlamaModule.java +++ b/android/src/newarch/java/com/rnllama/RNLlamaModule.java @@ -92,6 +92,21 @@ public void bench(double id, final double pp, final double tg, final double pl, rnllama.bench(id, pp, tg, pl, nr, promise); } + @ReactMethod + public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final boolean removePrevious, final Promise promise) { + rnllama.applyLoraAdapters(id, loraAdapters, removePrevious, promise); + } + + @ReactMethod + public void removeLoraAdapters(double id, final Promise promise) { + rnllama.removeLoraAdapters(id, promise); + } + + @ReactMethod + public void getLoadedLoraAdapters(double id, final Promise promise) { + rnllama.getLoadedLoraAdapters(id, promise); + } + @ReactMethod public void releaseContext(double id, Promise promise) { rnllama.releaseContext(id, promise); diff --git a/android/src/oldarch/java/com/rnllama/RNLlamaModule.java b/android/src/oldarch/java/com/rnllama/RNLlamaModule.java index a96bf3ab..e4013ffd 100644 --- a/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +++ b/android/src/oldarch/java/com/rnllama/RNLlamaModule.java @@ -93,6 +93,21 @@ public void bench(double id, final double pp, final double tg, final double pl, rnllama.bench(id, pp, tg, pl, nr, promise); } + @ReactMethod + public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final boolean removePrevious, final Promise promise) { + rnllama.applyLoraAdapters(id, loraAdapters, removePrevious, promise); + } + + @ReactMethod + public void removeLoraAdapters(double id, final Promise promise) { + rnllama.removeLoraAdapters(id, promise); + } + + @ReactMethod + public void getLoadedLoraAdapters(double id, final Promise promise) { + rnllama.getLoadedLoraAdapters(id, promise); + } + @ReactMethod public void releaseContext(double id, Promise promise) { rnllama.releaseContext(id, promise); diff --git a/cpp/common.cpp b/cpp/common.cpp index 3b527a16..ca3b2b5b 100644 --- a/cpp/common.cpp +++ b/cpp/common.cpp @@ -973,6 +973,12 @@ void common_lora_adapters_apply(struct llama_context * ctx, std::vector & lora_adapters) { + for (auto & la : lora_adapters) { + llama_lora_adapter_remove(ctx, la.adapter); + } +} + struct llama_model_params common_model_params_to_llama(const common_params & params) { auto mparams = llama_model_default_params(); diff --git a/cpp/common.h b/cpp/common.h index 4070df28..892d6f4d 100644 --- a/cpp/common.h +++ b/cpp/common.h @@ -476,6 +476,9 @@ struct llama_model * common_load_model_from_hf(const char * repo, const char * f // clear LoRA adapters from context, then apply new list of adapters void common_lora_adapters_apply(struct llama_context * ctx, std::vector & lora_adapters); +// remove LoRA adapters from context +void common_lora_adapters_remove(struct llama_context * ctx, std::vector & lora_adapters); + // Batch utils void common_batch_clear(struct llama_batch & batch); diff --git a/cpp/rn-llama.hpp b/cpp/rn-llama.hpp index aa6f9bdd..b9c05ca1 100644 --- a/cpp/rn-llama.hpp +++ b/cpp/rn-llama.hpp @@ -229,6 +229,8 @@ struct llama_rn_context std::string stopping_word; bool incomplete = false; + std::vector lora_adapters; + ~llama_rn_context() { if (ctx) @@ -723,6 +725,38 @@ struct llama_rn_context std::to_string(tg_std) + std::string("]"); } + + int applyLoraAdapters(std::vector lora_adapters, bool remove_previous = false) { + if (remove_previous) { + common_lora_adapters_remove(ctx, 
this->lora_adapters); + this->lora_adapters.clear(); + } + auto containers = std::vector(); + for (auto & la : lora_adapters) { + common_lora_adapter_container loaded_la; + loaded_la.path = la.path; + loaded_la.scale = la.scale; + loaded_la.adapter = llama_lora_adapter_init(model, la.path.c_str()); + if (loaded_la.adapter == nullptr) { + LOG_ERROR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str()); + return -1; + } + + this->lora_adapters.push_back(loaded_la); + containers.push_back(loaded_la); + } + common_lora_adapters_apply(ctx, containers); + return 0; + } + + void removeLoraAdapters() { + common_lora_adapters_remove(ctx, this->lora_adapters); + this->lora_adapters.clear(); + } + + std::vector getLoadedLoraAdapters() { + return this->lora_adapters; + } }; } diff --git a/example/ios/.xcode.env.local b/example/ios/.xcode.env.local index 3f04e1e9..1fa37491 100644 --- a/example/ios/.xcode.env.local +++ b/example/ios/.xcode.env.local @@ -1 +1 @@ -export NODE_BINARY=/var/folders/4z/1d45cfts3936kdm7v9jl349r0000gn/T/yarn--1731819230881-0.9134550796855203/node +export NODE_BINARY=/var/folders/g8/v75_3l3n23g909mshlzdj4wh0000gn/T/yarn--1731985865125-0.724061577974688/node diff --git a/ios/RNLlama.mm b/ios/RNLlama.mm index 0f61f671..1763abb1 100644 --- a/ios/RNLlama.mm +++ b/ios/RNLlama.mm @@ -271,6 +271,34 @@ - (NSArray *)supportedEvents { } } +RCT_EXPORT_METHOD(applyLoraAdapters:(double)contextId + withLoraAdapters:(NSArray *)loraAdapters + removePrevious:(BOOL)removePrevious + withResolver:(RCTPromiseResolveBlock)resolve + withRejecter:(RCTPromiseRejectBlock)reject) +{ + RNLlamaContext *context = llamaContexts[[NSNumber numberWithDouble:contextId]]; + [context applyLoraAdapters:loraAdapters removePrevious:removePrevious]; + resolve(nil); +} + +RCT_EXPORT_METHOD(removeLoraAdapters:(double)contextId + withResolver:(RCTPromiseResolveBlock)resolve + withRejecter:(RCTPromiseRejectBlock)reject) +{ + RNLlamaContext *context = llamaContexts[[NSNumber numberWithDouble:contextId]]; + [context removeLoraAdapters]; + resolve(nil); +} + +RCT_EXPORT_METHOD(getLoadedLoraAdapters:(double)contextId + withResolver:(RCTPromiseResolveBlock)resolve + withRejecter:(RCTPromiseRejectBlock)reject) +{ + RNLlamaContext *context = llamaContexts[[NSNumber numberWithDouble:contextId]]; + resolve([context getLoadedLoraAdapters]); +} + RCT_EXPORT_METHOD(releaseContext:(double)contextId withResolver:(RCTPromiseResolveBlock)resolve withRejecter:(RCTPromiseRejectBlock)reject) diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm index 34f41c67..266d7129 100644 --- a/ios/RNLlamaContext.mm +++ b/ios/RNLlamaContext.mm @@ -548,6 +548,36 @@ - (NSString *)bench:(int)pp tg:(int)tg pl:(int)pl nr:(int)nr { return [NSString stringWithUTF8String:llama->bench(pp, tg, pl, nr).c_str()]; } +- (void)applyLoraAdapters:(NSArray *)loraAdapters removePrevious:(BOOL)removePrevious { + std::vector lora_adapters; + for (NSDictionary *loraAdapter in loraAdapters) { + common_lora_adapter_info la; + la.path = [loraAdapter[@"path"] UTF8String]; + la.scale = [loraAdapter[@"scaled"] doubleValue]; + lora_adapters.push_back(la); + } + int result = llama->applyLoraAdapters(lora_adapters, removePrevious); + if (result != 0) { + @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to apply lora adapters" userInfo:nil]; + } +} + +- (void)removeLoraAdapters { + [llama removeLoraAdapters]; +} + +- (NSArray *)getLoadedLoraAdapters { + std::vector loaded_lora_adapters = llama->getLoadedLoraAdapters(); + NSMutableArray 
*result = [[NSMutableArray alloc] init]; + for (common_lora_adapter_container &la : loaded_lora_adapters) { + [result addObject:@{ + @"path": [NSString stringWithUTF8String:la.path.c_str()], + @"scale": @(la.scale) + }]; + } + return result; +} + - (void)invalidate { delete llama; // llama_backend_free(); diff --git a/scripts/common.cpp.patch b/scripts/common.cpp.patch index 1f4c63c5..515ff47d 100644 --- a/scripts/common.cpp.patch +++ b/scripts/common.cpp.patch @@ -1,5 +1,5 @@ ---- common.cpp.orig 2024-11-17 12:52:58 -+++ common.cpp 2024-11-17 12:48:35 +--- common.cpp.orig 2024-11-19 11:11:05 ++++ common.cpp 2024-11-19 11:07:02 @@ -4,10 +4,6 @@ #include "common.h" @@ -33,7 +33,19 @@ // // CPU utils -@@ -979,6 +979,8 @@ +@@ -973,12 +973,20 @@ + } + } + ++void common_lora_adapters_remove(struct llama_context * ctx, std::vector & lora_adapters) { ++ for (auto & la : lora_adapters) { ++ llama_lora_adapter_remove(ctx, la.adapter); ++ } ++} ++ + struct llama_model_params common_model_params_to_llama(const common_params & params) { + auto mparams = llama_model_default_params(); + if (params.n_gpu_layers != -1) { mparams.n_gpu_layers = params.n_gpu_layers; } @@ -42,7 +54,7 @@ mparams.rpc_servers = params.rpc_servers.c_str(); mparams.main_gpu = params.main_gpu; mparams.split_mode = params.split_mode; -@@ -993,6 +995,9 @@ +@@ -993,6 +1001,9 @@ mparams.kv_overrides = params.kv_overrides.data(); } @@ -52,10 +64,11 @@ return mparams; } -@@ -1118,220 +1123,6 @@ +@@ -1117,221 +1128,7 @@ + return false; } - +- -static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) { - - // Initialize libcurl @@ -95,7 +108,7 @@ - nlohmann::json metadata; - std::string etag; - std::string last_modified; -- + - if (file_exists) { - // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block). 
- std::ifstream metadata_in(metadata_path); diff --git a/scripts/common.h.patch b/scripts/common.h.patch index 9a2b365c..a17cff18 100644 --- a/scripts/common.h.patch +++ b/scripts/common.h.patch @@ -1,9 +1,10 @@ ---- common.h.orig 2024-11-17 11:56:40 -+++ common.h 2024-11-17 11:56:41 -@@ -41,6 +41,17 @@ +--- common.h.orig 2024-11-19 11:11:05 ++++ common.h 2024-11-19 11:07:10 +@@ -40,6 +40,17 @@ + extern char const * LLAMA_BUILD_TARGET; struct common_control_vector_load_info; - ++ +#define print_build_info() do { \ + fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \ + fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \ @@ -14,10 +15,9 @@ +extern char const *LLAMA_COMMIT; +extern char const *LLAMA_COMPILER; +extern char const *LLAMA_BUILD_TARGET; -+ + // // CPU utils - // @@ -154,6 +165,7 @@ }; @@ -26,13 +26,23 @@ int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 4096; // context size int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) -@@ -270,6 +282,9 @@ +@@ -269,6 +281,9 @@ + bool no_kv_offload = false; // disable KV offloading bool warmup = true; // warmup run bool check_tensors = false; // validate tensor data - ++ + llama_progress_callback progress_callback; + void * progress_callback_user_data; -+ + std::string cache_type_k = "f16"; // KV cache data type for the K std::string cache_type_v = "f16"; // KV cache data type for the V +@@ -461,6 +476,9 @@ + // clear LoRA adapters from context, then apply new list of adapters + void common_lora_adapters_apply(struct llama_context * ctx, std::vector & lora_adapters); + ++// remove LoRA adapters from context ++void common_lora_adapters_remove(struct llama_context * ctx, std::vector & lora_adapters); ++ + // Batch utils + void common_batch_clear(struct llama_batch & batch); diff --git a/src/NativeRNLlama.ts b/src/NativeRNLlama.ts index 215aa7e3..befd93b6 100644 --- a/src/NativeRNLlama.ts +++ b/src/NativeRNLlama.ts @@ -247,7 +247,10 @@ export interface Spec extends TurboModule { setContextLimit(limit: number): Promise modelInfo(path: string, skip?: string[]): Promise - initContext(contextId: number, params: NativeContextParams): Promise + initContext( + contextId: number, + params: NativeContextParams, + ): Promise getFormattedChat( contextId: number, @@ -283,6 +286,16 @@ export interface Spec extends TurboModule { nr: number, ): Promise + applyLoraAdapters( + contextId: number, + loraAdapters: Array<{ path: string; scaled?: number }>, + removePrevious: boolean, + ): Promise + removeLoraAdapters(contextId: number): Promise + getLoadedLoraAdapters( + contextId: number, + ): Promise> + releaseContext(contextId: number): Promise releaseAllContexts(): Promise From 1b6a1c390beb335409efbe2290ce5719c95ce326 Mon Sep 17 00:00:00 2001 From: Jhen Date: Tue, 19 Nov 2024 11:51:22 +0800 Subject: [PATCH 03/12] feat(ts): add methods --- src/index.ts | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/src/index.ts b/src/index.ts index d85d929c..7ac20ba9 100644 --- a/src/index.ts +++ b/src/index.ts @@ -214,6 +214,21 @@ export class LlamaContext { } } + async applyLoraAdapters( + loraAdapters: Array<{ path: string; scaled?: number }>, + removePrevious: boolean, + ): Promise { + return RNLlama.applyLoraAdapters(this.id, loraAdapters, removePrevious) + } + + async removeLoraAdapters(): Promise { + return RNLlama.removeLoraAdapters(this.id) + } + + async getLoadedLoraAdapters(): Promise> { + return 
RNLlama.getLoadedLoraAdapters(this.id) + } + async release(): Promise { return RNLlama.releaseContext(this.id) } From 22ee78719294990b053cceaef346b350877cec31 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Wed, 20 Nov 2024 16:25:45 +0800 Subject: [PATCH 04/12] fix(ts): lora list path --- src/index.ts | 45 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/src/index.ts b/src/index.ts index 7ac20ba9..573d59ac 100644 --- a/src/index.ts +++ b/src/index.ts @@ -14,7 +14,10 @@ import type { NativeCompletionTokenProbItem, NativeCompletionResultTimings, } from './NativeRNLlama' -import type { SchemaGrammarConverterPropOrder, SchemaGrammarConverterBuiltinRule } from './grammar' +import type { + SchemaGrammarConverterPropOrder, + SchemaGrammarConverterBuiltinRule, +} from './grammar' import { SchemaGrammarConverter, convertJsonSchemaToGrammar } from './grammar' import type { RNLlamaMessagePart, RNLlamaOAICompatibleMessage } from './chat' import { formatChat } from './chat' @@ -63,10 +66,26 @@ type TokenNativeEvent = { export type ContextParams = Omit< NativeContextParams, - 'cache_type_k' | 'cache_type_v' | 'pooling_type' + 'cache_type_k' | 'cache_type_v' | 'pooling_type' > & { - cache_type_k?: 'f16' | 'f32' | 'q8_0' | 'q4_0' | 'q4_1' | 'iq4_nl' | 'q5_0' | 'q5_1' - cache_type_v?: 'f16' | 'f32' | 'q8_0' | 'q4_0' | 'q4_1' | 'iq4_nl' | 'q5_0' | 'q5_1' + cache_type_k?: + | 'f16' + | 'f32' + | 'q8_0' + | 'q4_0' + | 'q4_1' + | 'iq4_nl' + | 'q5_0' + | 'q5_1' + cache_type_v?: + | 'f16' + | 'f32' + | 'q8_0' + | 'q4_0' + | 'q4_1' + | 'iq4_nl' + | 'q5_0' + | 'q5_1' pooling_type?: 'none' | 'mean' | 'cls' | 'last' | 'rank' } @@ -145,7 +164,10 @@ export class LlamaContext { let finalPrompt = params.prompt if (params.messages) { // messages always win - finalPrompt = await this.getFormattedChat(params.messages, params.chatTemplate) + finalPrompt = await this.getFormattedChat( + params.messages, + params.chatTemplate, + ) } let tokenListener: any = @@ -225,7 +247,9 @@ export class LlamaContext { return RNLlama.removeLoraAdapters(this.id) } - async getLoadedLoraAdapters(): Promise> { + async getLoadedLoraAdapters(): Promise< + Array<{ path: string; scaled?: number }> + > { return RNLlama.getLoadedLoraAdapters(this.id) } @@ -269,6 +293,7 @@ export async function initLlama( is_model_asset: isModelAsset, pooling_type: poolingType, lora, + lora_list: loraList, ...rest }: ContextParams, onProgress?: (progress: number) => void, @@ -279,6 +304,13 @@ export async function initLlama( let loraPath = lora if (loraPath?.startsWith('file://')) loraPath = loraPath.slice(7) + let loraAdapters: Array<{ path: string; scaled?: number }> = [] + if (loraList) + loraAdapters = loraList.map((l) => ({ + path: l.path.replace(/file:\/\//, ''), + scaled: l.scaled, + })) + const contextId = contextIdCounter + contextIdRandom() contextIdCounter += 1 @@ -304,6 +336,7 @@ export async function initLlama( use_progress_callback: !!onProgress, pooling_type: poolType, lora: loraPath, + lora_list: loraAdapters, ...rest, }).catch((err: any) => { removeProgressListener?.remove() From 4f392a50bcc386c71abfc5df97c70105b039633f Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Wed, 20 Nov 2024 16:38:30 +0800 Subject: [PATCH 05/12] fix: remove removePrevious --- android/src/main/java/com/rnllama/LlamaContext.java | 6 +++--- android/src/main/java/com/rnllama/RNLlama.java | 4 ++-- android/src/main/jni.cpp | 4 ++-- .../src/newarch/java/com/rnllama/RNLlamaModule.java | 4 ++-- 
.../src/oldarch/java/com/rnllama/RNLlamaModule.java | 4 ++-- cpp/rn-llama.hpp | 6 +----- ios/RNLlama.mm | 3 +-- ios/RNLlamaContext.h | 4 +++- ios/RNLlamaContext.mm | 4 ++-- src/NativeRNLlama.ts | 1 - src/index.ts | 11 ++++++++--- 11 files changed, 26 insertions(+), 25 deletions(-) diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java index 66fe2d84..926b2030 100644 --- a/android/src/main/java/com/rnllama/LlamaContext.java +++ b/android/src/main/java/com/rnllama/LlamaContext.java @@ -303,8 +303,8 @@ public String bench(int pp, int tg, int pl, int nr) { return bench(this.context, pp, tg, pl, nr); } - public int applyLoraAdapters(ReadableArray loraAdapters, boolean removePrevious) { - int result = applyLoraAdapters(this.context, loraAdapters, removePrevious); + public int applyLoraAdapters(ReadableArray loraAdapters) { + int result = applyLoraAdapters(this.context, loraAdapters); if (result != 0) { throw new IllegalStateException("Failed to apply lora adapters"); } @@ -492,7 +492,7 @@ protected static native WritableMap embedding( int embd_normalize ); protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr); - protected static native int applyLoraAdapters(long contextPtr, ReadableArray loraAdapters, boolean removePrevious); + protected static native int applyLoraAdapters(long contextPtr, ReadableArray loraAdapters); protected static native void removeLoraAdapters(long contextPtr); protected static native WritableArray getLoadedLoraAdapters(long contextPtr); protected static native void freeContext(long contextPtr); diff --git a/android/src/main/java/com/rnllama/RNLlama.java b/android/src/main/java/com/rnllama/RNLlama.java index f6a8ea3a..2822209a 100644 --- a/android/src/main/java/com/rnllama/RNLlama.java +++ b/android/src/main/java/com/rnllama/RNLlama.java @@ -413,7 +413,7 @@ protected void onPostExecute(String result) { tasks.put(task, "bench-" + contextId); } - public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final boolean removePrevious, final Promise promise) { + public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final Promise promise) { final int contextId = (int) id; AsyncTask task = new AsyncTask() { private Exception exception; @@ -425,7 +425,7 @@ protected Void doInBackground(Void... 
voids) { if (context == null) { throw new Exception("Context not found"); } - context.applyLoraAdapters(loraAdapters, removePrevious); + context.applyLoraAdapters(loraAdapters); } catch (Exception e) { exception = e; } diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index dc5a7fb8..e774337d 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -894,7 +894,7 @@ Java_com_rnllama_LlamaContext_bench( JNIEXPORT jint JNICALL Java_com_rnllama_LlamaContext_applyLoraAdapters( - JNIEnv *env, jobject thiz, jlong context_ptr, jobjectArray loraAdapters, jboolean removePrevious) { + JNIEnv *env, jobject thiz, jlong context_ptr, jobjectArray loraAdapters) { UNUSED(thiz); auto llama = context_map[(long) context_ptr]; @@ -912,7 +912,7 @@ Java_com_rnllama_LlamaContext_applyLoraAdapters( lora_adapters.push_back({path_chars, scaled}); } } - return llama->applyLoraAdapters(lora_adapters, removePrevious); + return llama->applyLoraAdapters(lora_adapters); } JNIEXPORT void JNICALL diff --git a/android/src/newarch/java/com/rnllama/RNLlamaModule.java b/android/src/newarch/java/com/rnllama/RNLlamaModule.java index e0f2d809..0c154d4d 100644 --- a/android/src/newarch/java/com/rnllama/RNLlamaModule.java +++ b/android/src/newarch/java/com/rnllama/RNLlamaModule.java @@ -93,8 +93,8 @@ public void bench(double id, final double pp, final double tg, final double pl, } @ReactMethod - public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final boolean removePrevious, final Promise promise) { - rnllama.applyLoraAdapters(id, loraAdapters, removePrevious, promise); + public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final Promise promise) { + rnllama.applyLoraAdapters(id, loraAdapters, promise); } @ReactMethod diff --git a/android/src/oldarch/java/com/rnllama/RNLlamaModule.java b/android/src/oldarch/java/com/rnllama/RNLlamaModule.java index e4013ffd..6da8e8fd 100644 --- a/android/src/oldarch/java/com/rnllama/RNLlamaModule.java +++ b/android/src/oldarch/java/com/rnllama/RNLlamaModule.java @@ -94,8 +94,8 @@ public void bench(double id, final double pp, final double tg, final double pl, } @ReactMethod - public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final boolean removePrevious, final Promise promise) { - rnllama.applyLoraAdapters(id, loraAdapters, removePrevious, promise); + public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final Promise promise) { + rnllama.applyLoraAdapters(id, loraAdapters); } @ReactMethod diff --git a/cpp/rn-llama.hpp b/cpp/rn-llama.hpp index b9c05ca1..c1740474 100644 --- a/cpp/rn-llama.hpp +++ b/cpp/rn-llama.hpp @@ -726,11 +726,7 @@ struct llama_rn_context std::string("]"); } - int applyLoraAdapters(std::vector lora_adapters, bool remove_previous = false) { - if (remove_previous) { - common_lora_adapters_remove(ctx, this->lora_adapters); - this->lora_adapters.clear(); - } + int applyLoraAdapters(std::vector lora_adapters) { auto containers = std::vector(); for (auto & la : lora_adapters) { common_lora_adapter_container loaded_la; diff --git a/ios/RNLlama.mm b/ios/RNLlama.mm index 1763abb1..18692d13 100644 --- a/ios/RNLlama.mm +++ b/ios/RNLlama.mm @@ -273,12 +273,11 @@ - (NSArray *)supportedEvents { RCT_EXPORT_METHOD(applyLoraAdapters:(double)contextId withLoraAdapters:(NSArray *)loraAdapters - removePrevious:(BOOL)removePrevious withResolver:(RCTPromiseResolveBlock)resolve withRejecter:(RCTPromiseRejectBlock)reject) { RNLlamaContext *context = llamaContexts[[NSNumber 
numberWithDouble:contextId]]; - [context applyLoraAdapters:loraAdapters removePrevious:removePrevious]; + [context applyLoraAdapters:loraAdapters]; resolve(nil); } diff --git a/ios/RNLlamaContext.h b/ios/RNLlamaContext.h index 52c4e92e..82bcccda 100644 --- a/ios/RNLlamaContext.h +++ b/ios/RNLlamaContext.h @@ -33,7 +33,9 @@ - (NSDictionary *)loadSession:(NSString *)path; - (int)saveSession:(NSString *)path size:(int)size; - (NSString *)bench:(int)pp tg:(int)tg pl:(int)pl nr:(int)nr; - +- (void)applyLoraAdapters:(NSArray *)loraAdapters; +- (void)removeLoraAdapters; +- (NSArray *)getLoadedLoraAdapters; - (void)invalidate; @end diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm index 266d7129..307d14e0 100644 --- a/ios/RNLlamaContext.mm +++ b/ios/RNLlamaContext.mm @@ -548,7 +548,7 @@ - (NSString *)bench:(int)pp tg:(int)tg pl:(int)pl nr:(int)nr { return [NSString stringWithUTF8String:llama->bench(pp, tg, pl, nr).c_str()]; } -- (void)applyLoraAdapters:(NSArray *)loraAdapters removePrevious:(BOOL)removePrevious { +- (void)applyLoraAdapters:(NSArray *)loraAdapters { std::vector lora_adapters; for (NSDictionary *loraAdapter in loraAdapters) { common_lora_adapter_info la; @@ -556,7 +556,7 @@ - (void)applyLoraAdapters:(NSArray *)loraAdapters removePrevious:(BOOL)removePre la.scale = [loraAdapter[@"scaled"] doubleValue]; lora_adapters.push_back(la); } - int result = llama->applyLoraAdapters(lora_adapters, removePrevious); + int result = llama->applyLoraAdapters(lora_adapters); if (result != 0) { @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to apply lora adapters" userInfo:nil]; } diff --git a/src/NativeRNLlama.ts b/src/NativeRNLlama.ts index befd93b6..e69f37f9 100644 --- a/src/NativeRNLlama.ts +++ b/src/NativeRNLlama.ts @@ -289,7 +289,6 @@ export interface Spec extends TurboModule { applyLoraAdapters( contextId: number, loraAdapters: Array<{ path: string; scaled?: number }>, - removePrevious: boolean, ): Promise removeLoraAdapters(contextId: number): Promise getLoadedLoraAdapters( diff --git a/src/index.ts b/src/index.ts index 573d59ac..57fa2089 100644 --- a/src/index.ts +++ b/src/index.ts @@ -237,10 +237,15 @@ export class LlamaContext { } async applyLoraAdapters( - loraAdapters: Array<{ path: string; scaled?: number }>, - removePrevious: boolean, + loraList: Array<{ path: string; scaled?: number }> ): Promise { - return RNLlama.applyLoraAdapters(this.id, loraAdapters, removePrevious) + let loraAdapters: Array<{ path: string; scaled?: number }> = [] + if (loraList) + loraAdapters = loraList.map((l) => ({ + path: l.path.replace(/file:\/\//, ''), + scaled: l.scaled, + })) + return RNLlama.applyLoraAdapters(this.id, loraAdapters) } async removeLoraAdapters(): Promise { From 86fb2390ae1e5011dae18de242d14938f2e5c002 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Wed, 20 Nov 2024 17:03:35 +0800 Subject: [PATCH 06/12] fix: use llama->applyLoraAdapters on init --- android/src/main/jni.cpp | 54 ++++++++++++++++++++++++---------------- ios/RNLlamaContext.mm | 45 +++++++++++++++++++++------------ 2 files changed, 62 insertions(+), 37 deletions(-) diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index e774337d..39158ad2 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -236,7 +236,7 @@ Java_com_rnllama_LlamaContext_initContext( jboolean vocab_only, jstring lora_str, jfloat lora_scaled, - jobject lora_adapters, + jobject lora_list, jfloat rope_freq_base, jfloat rope_freq_scale, jint pooling_type, @@ -286,25 +286,6 @@ 
Java_com_rnllama_LlamaContext_initContext( defaultParams.use_mlock = use_mlock; defaultParams.use_mmap = use_mmap; - const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr); - if (lora_chars != nullptr && lora_chars[0] != '\0') { - defaultParams.lora_adapters.push_back({lora_chars, lora_scaled}); - } - - // lora_adapters: ReadableArray - int lora_adapters_size = readablearray::size(env, lora_adapters); - for (int i = 0; i < lora_adapters_size; i++) { - jobject lora_adapter = readablearray::getMap(env, lora_adapters, i); - jstring path = readablemap::getString(env, lora_adapter, "path", nullptr); - if (path != nullptr) { - const char *path_chars = env->GetStringUTFChars(path, nullptr); - env->ReleaseStringUTFChars(path, path_chars); - float scaled = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f); - - defaultParams.lora_adapters.push_back({path_chars, scaled}); - } - } - defaultParams.rope_freq_base = rope_freq_base; defaultParams.rope_freq_scale = rope_freq_scale; @@ -338,7 +319,6 @@ Java_com_rnllama_LlamaContext_initContext( bool is_model_loaded = llama->loadModel(defaultParams); env->ReleaseStringUTFChars(model_path_str, model_path_chars); - env->ReleaseStringUTFChars(lora_str, lora_chars); env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars); env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars); @@ -354,6 +334,38 @@ Java_com_rnllama_LlamaContext_initContext( llama_free(llama->ctx); } + std::vector lora_adapters; + const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr); + if (lora_chars != nullptr && lora_chars[0] != '\0') { + common_lora_adapter_info la; + la.path = lora_chars; + la.scale = lora_scaled; + lora_adapters.push_back(la); + } + env->ReleaseStringUTFChars(lora_str, lora_chars); + + // lora_adapters: ReadableArray + int lora_list_size = readablearray::size(env, lora_list); + for (int i = 0; i < lora_list_size; i++) { + jobject lora_adapter = readablearray::getMap(env, lora_list, i); + jstring path = readablemap::getString(env, lora_adapter, "path", nullptr); + if (path != nullptr) { + const char *path_chars = env->GetStringUTFChars(path, nullptr); + common_lora_adapter_info la; + la.path = path_chars; + la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f); + lora_adapters.push_back(la); + env->ReleaseStringUTFChars(path, path_chars); + } + } + + int result = llama->applyLoraAdapters(lora_adapters); + if (result != 0) { + LOGI("[RNLlama] Failed to apply lora adapters"); + llama_free(llama->ctx); + return -1; + } + return reinterpret_cast(llama->ctx); } diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm index 307d14e0..fa5b496c 100644 --- a/ios/RNLlamaContext.mm +++ b/ios/RNLlamaContext.mm @@ -110,22 +110,6 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig } } - if (params[@"lora"]) { - float lora_scaled = 1.0f; - if (params[@"lora_scaled"]) lora_scaled = [params[@"lora_scaled"] floatValue]; - defaultParams.lora_adapters.push_back({[params[@"lora"] UTF8String], lora_scaled}); - } - - if (params[@"lora_list"] && [params[@"lora_list"] isKindOfClass:[NSArray class]]) { - NSArray *lora_list = params[@"lora_list"]; - for (NSDictionary *lora_adapter in lora_list) { - NSString *path = lora_adapter[@"path"]; - if (!path) continue; - float scale = [lora_adapter[@"scaled"] floatValue]; - defaultParams.lora_adapters.push_back({[path UTF8String], scale}); - } - } - if (params[@"rope_freq_base"]) defaultParams.rope_freq_base = [params[@"rope_freq_base"] floatValue]; if 
(params[@"rope_freq_scale"]) defaultParams.rope_freq_scale = [params[@"rope_freq_scale"] floatValue]; @@ -140,6 +124,7 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig const int defaultNThreads = nThreads == 4 ? 2 : MIN(4, maxThreads); defaultParams.cpuparams.n_threads = nThreads > 0 ? nThreads : defaultNThreads; + RNLlamaContext *context = [[RNLlamaContext alloc] init]; context->llama = new rnllama::llama_rn_context(); context->llama->is_load_interrupted = false; @@ -169,6 +154,34 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig @throw [NSException exceptionWithName:@"LlamaException" reason:@"Embedding is not supported in encoder-decoder models" userInfo:nil]; } + std::vector lora_adapters; + if (params[@"lora"]) { + common_lora_adapter_info la; + la.path = [params[@"lora"] UTF8String]; + la.scale = 1.0f; + if (params[@"lora_scaled"]) la.scale = [params[@"lora_scaled"] floatValue]; + lora_adapters.push_back(la); + } + if (params[@"lora_list"] && [params[@"lora_list"] isKindOfClass:[NSArray class]]) { + NSArray *lora_list = params[@"lora_list"]; + for (NSDictionary *lora_adapter in lora_list) { + NSString *path = lora_adapter[@"path"]; + if (!path) continue; + float scale = [lora_adapter[@"scaled"] floatValue]; + common_lora_adapter_info la; + la.path = [path UTF8String]; + la.scale = scale; + lora_adapters.push_back(la); + } + } + if (lora_adapters.size() > 0) { + int result = context->llama->applyLoraAdapters(lora_adapters); + if (result != 0) { + delete context->llama; + @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to apply lora adapters" userInfo:nil]; + } + } + context->is_metal_enabled = isMetalEnabled; context->reason_no_metal = reasonNoMetal; From 121bc46da4742a26055db9314bb63727277f2814 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Wed, 20 Nov 2024 17:05:08 +0800 Subject: [PATCH 07/12] feat(example): add lora comments --- example/src/App.tsx | 41 ++++++++++++++++++++++++++++++++++------- 1 file changed, 34 insertions(+), 7 deletions(-) diff --git a/example/src/App.tsx b/example/src/App.tsx index 4c9f87e8..6d549fed 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -95,7 +95,7 @@ export default function App() { const handleInitContext = async ( file: DocumentPickerResponse, - loraFile?: DocumentPickerResponse, + loraFile: DocumentPickerResponse | null, ) => { await handleReleaseContext() await getModelInfo(file.uri) @@ -179,6 +179,15 @@ export default function App() { return file } + const pickLora = async () => { + let loraFile + const loraRes = await DocumentPicker.pick({ + type: Platform.OS === 'ios' ? 'public.data' : 'application/octet-stream', + }).catch((e) => console.log('No lora file picked, error: ', e.message)) + if (loraRes?.[0]) loraFile = await copyFileIfNeeded('lora', loraRes[0]) + return loraFile + } + const handlePickModel = async () => { const modelRes = await DocumentPicker.pick({ type: Platform.OS === 'ios' ? 'public.data' : 'application/octet-stream', @@ -186,12 +195,10 @@ export default function App() { if (!modelRes?.[0]) return const modelFile = await copyFileIfNeeded('model', modelRes?.[0]) - let loraFile + let loraFile = null // Example: Apply lora adapter (Currently only select one lora file) (Uncomment to use) - // const loraRes = await DocumentPicker.pick({ - // type: Platform.OS === 'ios' ? 
'public.data' : 'application/octet-stream', - // }).catch(e => console.log('No lora file picked, error: ', e.message)) - // if (loraRes?.[0]) loraFile = await copyFileIfNeeded('lora', loraRes[0]) + // loraFile = await pickLora() + loraFile = null handleInitContext(modelFile, loraFile) } @@ -278,6 +285,26 @@ export default function App() { addSystemMessage(`Session load failed: ${e.message}`) }) return + case '/lora': + pickLora() + .then((loraFile) => { + if (loraFile) + context.applyLoraAdapters([{ path: loraFile.uri }]) + }) + .then(context.getLoadedLoraAdapters) + .then((loraList) => + addSystemMessage( + `Loaded lora adapters: ${JSON.stringify(loraList)}`, + ), + ) + return + case '/lora-list': + context.getLoadedLoraAdapters().then((loraList) => { + addSystemMessage( + `Loaded lora adapters: ${JSON.stringify(loraList)}`, + ) + }) + return } } const textMessage: MessageType.Text = { @@ -417,7 +444,7 @@ export default function App() { dry_base: 1.75, dry_allowed_length: 2, dry_penalty_last_n: -1, - dry_sequence_breakers: ["\n", ":", "\"", "*"], + dry_sequence_breakers: ['\n', ':', '"', '*'], mirostat: 0, mirostat_tau: 5, mirostat_eta: 0.1, From 4446febd669890003562d3f3062ec84a2ecd3289 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Wed, 20 Nov 2024 18:23:05 +0800 Subject: [PATCH 08/12] fix(android): push map --- .../main/java/com/rnllama/LlamaContext.java | 4 +- .../src/main/java/com/rnllama/RNLlama.java | 2 + android/src/main/jni.cpp | 68 ++++++++++--------- cpp/rn-llama.hpp | 1 + example/src/App.tsx | 7 +- 5 files changed, 46 insertions(+), 36 deletions(-) diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java index 926b2030..9c14a3a0 100644 --- a/android/src/main/java/com/rnllama/LlamaContext.java +++ b/android/src/main/java/com/rnllama/LlamaContext.java @@ -71,7 +71,7 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa // float lora_scaled, params.hasKey("lora_scaled") ? (float) params.getDouble("lora_scaled") : 1.0f, // ReadableArray lora_adapters, - params.hasKey("lora_adapters") ? params.getArray("lora_adapters") : null, + params.hasKey("lora_list") ? params.getArray("lora_list") : null, // float rope_freq_base, params.hasKey("rope_freq_base") ? 
(float) params.getDouble("rope_freq_base") : 0.0f, // float rope_freq_scale @@ -424,7 +424,7 @@ protected static native long initContext( boolean vocab_only, String lora, float lora_scaled, - ReadableArray lora_adapters, + ReadableArray lora_list, float rope_freq_base, float rope_freq_scale, int pooling_type, diff --git a/android/src/main/java/com/rnllama/RNLlama.java b/android/src/main/java/com/rnllama/RNLlama.java index 2822209a..4dfa2497 100644 --- a/android/src/main/java/com/rnllama/RNLlama.java +++ b/android/src/main/java/com/rnllama/RNLlama.java @@ -468,6 +468,7 @@ protected void onPostExecute(Void result) { promise.reject(exception); return; } + promise.resolve(null); } }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR); tasks.put(task, "removeLoraAdapters-" + contextId); @@ -498,6 +499,7 @@ protected void onPostExecute(ReadableArray result) { promise.reject(exception); return; } + promise.resolve(result); } }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR); tasks.put(task, "getLoadedLoraAdapters-" + contextId); diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index 39158ad2..d2616571 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -128,7 +128,7 @@ static inline void pushString(JNIEnv *env, jobject arr, const char *value) { // Method to push WritableMap into WritableArray static inline void pushMap(JNIEnv *env, jobject arr, jobject value) { jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray"); - jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/WritableMap;)V"); + jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/ReadableMap;)V"); env->CallVoidMethod(arr, pushMapMethod, value); } @@ -324,41 +324,42 @@ Java_com_rnllama_LlamaContext_initContext( LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? 
"true" : "false")); if (is_model_loaded) { - if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) { - LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported"); - llama_free(llama->ctx); - return -1; - } - context_map[(long) llama->ctx] = llama; + if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) { + LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported"); + llama_free(llama->ctx); + return -1; + } + context_map[(long) llama->ctx] = llama; } else { - llama_free(llama->ctx); + llama_free(llama->ctx); } std::vector lora_adapters; const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr); if (lora_chars != nullptr && lora_chars[0] != '\0') { - common_lora_adapter_info la; - la.path = lora_chars; - la.scale = lora_scaled; - lora_adapters.push_back(la); - } - env->ReleaseStringUTFChars(lora_str, lora_chars); - - // lora_adapters: ReadableArray - int lora_list_size = readablearray::size(env, lora_list); - for (int i = 0; i < lora_list_size; i++) { - jobject lora_adapter = readablearray::getMap(env, lora_list, i); - jstring path = readablemap::getString(env, lora_adapter, "path", nullptr); - if (path != nullptr) { - const char *path_chars = env->GetStringUTFChars(path, nullptr); - common_lora_adapter_info la; - la.path = path_chars; - la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f); - lora_adapters.push_back(la); - env->ReleaseStringUTFChars(path, path_chars); + common_lora_adapter_info la; + la.path = lora_chars; + la.scale = lora_scaled; + lora_adapters.push_back(la); + } + + if (lora_list != nullptr) { + // lora_adapters: ReadableArray + int lora_list_size = readablearray::size(env, lora_list); + for (int i = 0; i < lora_list_size; i++) { + jobject lora_adapter = readablearray::getMap(env, lora_list, i); + jstring path = readablemap::getString(env, lora_adapter, "path", nullptr); + if (path != nullptr) { + const char *path_chars = env->GetStringUTFChars(path, nullptr); + common_lora_adapter_info la; + la.path = path_chars; + la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f); + lora_adapters.push_back(la); + env->ReleaseStringUTFChars(path, path_chars); + } } } - + env->ReleaseStringUTFChars(lora_str, lora_chars); int result = llama->applyLoraAdapters(lora_adapters); if (result != 0) { LOGI("[RNLlama] Failed to apply lora adapters"); @@ -920,8 +921,10 @@ Java_com_rnllama_LlamaContext_applyLoraAdapters( const char *path_chars = env->GetStringUTFChars(path, nullptr); env->ReleaseStringUTFChars(path, path_chars); float scaled = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f); - - lora_adapters.push_back({path_chars, scaled}); + common_lora_adapter_info la; + la.path = path_chars; + la.scale = scaled; + lora_adapters.push_back(la); } } return llama->applyLoraAdapters(lora_adapters); @@ -939,15 +942,14 @@ Java_com_rnllama_LlamaContext_removeLoraAdapters( JNIEXPORT jobject JNICALL Java_com_rnllama_LlamaContext_getLoadedLoraAdapters( JNIEnv *env, jobject thiz, jlong context_ptr) { - UNUSED(env); UNUSED(thiz); auto llama = context_map[(long) context_ptr]; auto loaded_lora_adapters = llama->getLoadedLoraAdapters(); auto result = createWritableArray(env); for (common_lora_adapter_container &la : loaded_lora_adapters) { auto map = createWriteableMap(env); - pushString(env, map, la.path.c_str()); - pushDouble(env, map, la.scale); + putString(env, map, "path", la.path.c_str()); + putDouble(env, map, "scaled", 
la.scale); pushMap(env, result, map); } return result; diff --git a/cpp/rn-llama.hpp b/cpp/rn-llama.hpp index c1740474..ee1922c9 100644 --- a/cpp/rn-llama.hpp +++ b/cpp/rn-llama.hpp @@ -727,6 +727,7 @@ struct llama_rn_context } int applyLoraAdapters(std::vector lora_adapters) { + this->lora_adapters.clear(); auto containers = std::vector(); for (auto & la : lora_adapters) { common_lora_adapter_container loaded_la; diff --git a/example/src/App.tsx b/example/src/App.tsx index 6d549fed..724257db 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -195,7 +195,7 @@ export default function App() { if (!modelRes?.[0]) return const modelFile = await copyFileIfNeeded('model', modelRes?.[0]) - let loraFile = null + let loraFile: any = null // Example: Apply lora adapter (Currently only select one lora file) (Uncomment to use) // loraFile = await pickLora() loraFile = null @@ -298,6 +298,11 @@ export default function App() { ), ) return + case '/remove-lora': + context.removeLoraAdapters().then(() => { + addSystemMessage('Lora adapters removed!') + }) + return case '/lora-list': context.getLoadedLoraAdapters().then((loraList) => { addSystemMessage( From 6f11179109d8854f3fa9adf2e995fb57eb3d1dda Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Thu, 21 Nov 2024 10:17:04 +0800 Subject: [PATCH 09/12] fix(example): getLoadedLoraAdapters usage --- example/src/App.tsx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/example/src/App.tsx b/example/src/App.tsx index 724257db..3849753c 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -291,7 +291,7 @@ export default function App() { if (loraFile) context.applyLoraAdapters([{ path: loraFile.uri }]) }) - .then(context.getLoadedLoraAdapters) + .then(() => context.getLoadedLoraAdapters()) .then((loraList) => addSystemMessage( `Loaded lora adapters: ${JSON.stringify(loraList)}`, From 8560f03620609cd2f779ece7d535579a3066e657 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Thu, 21 Nov 2024 10:27:41 +0800 Subject: [PATCH 10/12] fix(cpp): apply empty list instead of expose new fn --- cpp/common.cpp | 6 ------ cpp/common.h | 3 --- cpp/rn-llama.hpp | 2 +- scripts/common.cpp.patch | 27 +++++++-------------------- scripts/common.h.patch | 28 +++++++++------------------- 5 files changed, 17 insertions(+), 49 deletions(-) diff --git a/cpp/common.cpp b/cpp/common.cpp index ca3b2b5b..3b527a16 100644 --- a/cpp/common.cpp +++ b/cpp/common.cpp @@ -973,12 +973,6 @@ void common_lora_adapters_apply(struct llama_context * ctx, std::vector & lora_adapters) { - for (auto & la : lora_adapters) { - llama_lora_adapter_remove(ctx, la.adapter); - } -} - struct llama_model_params common_model_params_to_llama(const common_params & params) { auto mparams = llama_model_default_params(); diff --git a/cpp/common.h b/cpp/common.h index 892d6f4d..4070df28 100644 --- a/cpp/common.h +++ b/cpp/common.h @@ -476,9 +476,6 @@ struct llama_model * common_load_model_from_hf(const char * repo, const char * f // clear LoRA adapters from context, then apply new list of adapters void common_lora_adapters_apply(struct llama_context * ctx, std::vector & lora_adapters); -// remove LoRA adapters from context -void common_lora_adapters_remove(struct llama_context * ctx, std::vector & lora_adapters); - // Batch utils void common_batch_clear(struct llama_batch & batch); diff --git a/cpp/rn-llama.hpp b/cpp/rn-llama.hpp index ee1922c9..c175744d 100644 --- a/cpp/rn-llama.hpp +++ b/cpp/rn-llama.hpp @@ -747,8 +747,8 @@ struct llama_rn_context } void removeLoraAdapters() { - 
common_lora_adapters_remove(ctx, this->lora_adapters); this->lora_adapters.clear(); + common_lora_adapters_apply(ctx, this->lora_adapters); // apply empty list } std::vector getLoadedLoraAdapters() { diff --git a/scripts/common.cpp.patch b/scripts/common.cpp.patch index 515ff47d..4cc23b7d 100644 --- a/scripts/common.cpp.patch +++ b/scripts/common.cpp.patch @@ -1,5 +1,5 @@ ---- common.cpp.orig 2024-11-19 11:11:05 -+++ common.cpp 2024-11-19 11:07:02 +--- common.cpp.orig 2024-11-21 10:21:53 ++++ common.cpp 2024-11-21 10:22:56 @@ -4,10 +4,6 @@ #include "common.h" @@ -33,19 +33,7 @@ // // CPU utils -@@ -973,12 +973,20 @@ - } - } - -+void common_lora_adapters_remove(struct llama_context * ctx, std::vector & lora_adapters) { -+ for (auto & la : lora_adapters) { -+ llama_lora_adapter_remove(ctx, la.adapter); -+ } -+} -+ - struct llama_model_params common_model_params_to_llama(const common_params & params) { - auto mparams = llama_model_default_params(); - +@@ -979,6 +979,8 @@ if (params.n_gpu_layers != -1) { mparams.n_gpu_layers = params.n_gpu_layers; } @@ -54,7 +42,7 @@ mparams.rpc_servers = params.rpc_servers.c_str(); mparams.main_gpu = params.main_gpu; mparams.split_mode = params.split_mode; -@@ -993,6 +1001,9 @@ +@@ -993,6 +995,9 @@ mparams.kv_overrides = params.kv_overrides.data(); } @@ -64,11 +52,10 @@ return mparams; } -@@ -1117,221 +1128,7 @@ - +@@ -1118,220 +1123,6 @@ return false; } -- + -static bool common_download_file(const std::string & url, const std::string & path, const std::string & hf_token) { - - // Initialize libcurl @@ -108,7 +95,7 @@ - nlohmann::json metadata; - std::string etag; - std::string last_modified; - +- - if (file_exists) { - // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block). - std::ifstream metadata_in(metadata_path); diff --git a/scripts/common.h.patch b/scripts/common.h.patch index a17cff18..26b42531 100644 --- a/scripts/common.h.patch +++ b/scripts/common.h.patch @@ -1,10 +1,9 @@ ---- common.h.orig 2024-11-19 11:11:05 -+++ common.h 2024-11-19 11:07:10 -@@ -40,6 +40,17 @@ - extern char const * LLAMA_BUILD_TARGET; +--- common.h.orig 2024-11-21 10:21:53 ++++ common.h 2024-11-21 10:23:00 +@@ -41,6 +41,17 @@ struct common_control_vector_load_info; -+ + +#define print_build_info() do { \ + fprintf(stderr, "%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT); \ + fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \ @@ -15,9 +14,10 @@ +extern char const *LLAMA_COMMIT; +extern char const *LLAMA_COMPILER; +extern char const *LLAMA_BUILD_TARGET; - ++ // // CPU utils + // @@ -154,6 +165,7 @@ }; @@ -26,23 +26,13 @@ int32_t n_predict = -1; // new tokens to predict int32_t n_ctx = 4096; // context size int32_t n_batch = 2048; // logical batch size for prompt processing (must be >=32 to use BLAS) -@@ -269,6 +281,9 @@ - bool no_kv_offload = false; // disable KV offloading +@@ -270,6 +282,9 @@ bool warmup = true; // warmup run bool check_tensors = false; // validate tensor data -+ + + llama_progress_callback progress_callback; + void * progress_callback_user_data; - ++ std::string cache_type_k = "f16"; // KV cache data type for the K std::string cache_type_v = "f16"; // KV cache data type for the V -@@ -461,6 +476,9 @@ - // clear LoRA adapters from context, then apply new list of adapters - void common_lora_adapters_apply(struct llama_context * ctx, std::vector & lora_adapters); - -+// remove LoRA adapters from context -+void common_lora_adapters_remove(struct llama_context * ctx, 
std::vector & lora_adapters); -+ - // Batch utils - void common_batch_clear(struct llama_batch & batch); From 156979701e78e35753d6ae1d8281578a8f978168 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Thu, 21 Nov 2024 10:28:01 +0800 Subject: [PATCH 11/12] fix(ios): removeLoraAdapters --- ios/RNLlama.mm | 12 ++++++++++++ ios/RNLlamaContext.mm | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/ios/RNLlama.mm b/ios/RNLlama.mm index 18692d13..de3edcab 100644 --- a/ios/RNLlama.mm +++ b/ios/RNLlama.mm @@ -277,6 +277,10 @@ - (NSArray *)supportedEvents { withRejecter:(RCTPromiseRejectBlock)reject) { RNLlamaContext *context = llamaContexts[[NSNumber numberWithDouble:contextId]]; + if (context == nil) { + reject(@"llama_error", @"Context not found", nil); + return; + } [context applyLoraAdapters:loraAdapters]; resolve(nil); } @@ -286,6 +290,10 @@ - (NSArray *)supportedEvents { withRejecter:(RCTPromiseRejectBlock)reject) { RNLlamaContext *context = llamaContexts[[NSNumber numberWithDouble:contextId]]; + if (context == nil) { + reject(@"llama_error", @"Context not found", nil); + return; + } [context removeLoraAdapters]; resolve(nil); } @@ -295,6 +303,10 @@ - (NSArray *)supportedEvents { withRejecter:(RCTPromiseRejectBlock)reject) { RNLlamaContext *context = llamaContexts[[NSNumber numberWithDouble:contextId]]; + if (context == nil) { + reject(@"llama_error", @"Context not found", nil); + return; + } resolve([context getLoadedLoraAdapters]); } diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm index fa5b496c..8d364889 100644 --- a/ios/RNLlamaContext.mm +++ b/ios/RNLlamaContext.mm @@ -576,7 +576,7 @@ - (void)applyLoraAdapters:(NSArray *)loraAdapters { } - (void)removeLoraAdapters { - [llama removeLoraAdapters]; + llama->removeLoraAdapters(); } - (NSArray *)getLoadedLoraAdapters { From ca37546cce1de4eb2b0063ffde6a28d6613158e8 Mon Sep 17 00:00:00 2001 From: Jhen-Jie Hong Date: Thu, 21 Nov 2024 10:31:21 +0800 Subject: [PATCH 12/12] feat: check context is predicting --- android/src/main/java/com/rnllama/RNLlama.java | 6 ++++++ ios/RNLlama.mm | 8 ++++++++ 2 files changed, 14 insertions(+) diff --git a/android/src/main/java/com/rnllama/RNLlama.java b/android/src/main/java/com/rnllama/RNLlama.java index 4dfa2497..ab98a923 100644 --- a/android/src/main/java/com/rnllama/RNLlama.java +++ b/android/src/main/java/com/rnllama/RNLlama.java @@ -425,6 +425,9 @@ protected Void doInBackground(Void... voids) { if (context == null) { throw new Exception("Context not found"); } + if (context.isPredicting()) { + throw new Exception("Context is busy"); + } context.applyLoraAdapters(loraAdapters); } catch (Exception e) { exception = e; @@ -455,6 +458,9 @@ protected Void doInBackground(Void... 
voids) { if (context == null) { throw new Exception("Context not found"); } + if (context.isPredicting()) { + throw new Exception("Context is busy"); + } context.removeLoraAdapters(); } catch (Exception e) { exception = e; diff --git a/ios/RNLlama.mm b/ios/RNLlama.mm index de3edcab..9c6b848c 100644 --- a/ios/RNLlama.mm +++ b/ios/RNLlama.mm @@ -281,6 +281,10 @@ - (NSArray *)supportedEvents { reject(@"llama_error", @"Context not found", nil); return; } + if ([context isPredicting]) { + reject(@"llama_error", @"Context is busy", nil); + return; + } [context applyLoraAdapters:loraAdapters]; resolve(nil); } @@ -294,6 +298,10 @@ - (NSArray *)supportedEvents { reject(@"llama_error", @"Context not found", nil); return; } + if ([context isPredicting]) { + reject(@"llama_error", @"Context is busy", nil); + return; + } [context removeLoraAdapters]; resolve(nil); }
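
---
For reference, below is a minimal JavaScript-side sketch of the multi-LoRA flow this series enables. It is an illustration only, not part of the patches: it uses the context methods exercised in example/src/App.tsx (applyLoraAdapters, getLoadedLoraAdapters, removeLoraAdapters) and the { path, scaled } shape read by the JNI/Obj-C layers; the initLlama call and its exact parameter names are assumptions drawn from llama.rn's existing API, not confirmed by these diffs.

// Sketch under the assumptions above; `scaled` defaults to 1.0 in the native layer
// when omitted (see readablemap::getFloat(..., "scaled", 1.0f) in jni.cpp).
import { initLlama } from 'llama.rn'

async function demoMultiLora(modelPath: string, loraA: string, loraB: string) {
  // Assumed init call; only the lora-related methods below are shown in this series.
  const context = await initLlama({ model: modelPath })

  // Apply several adapters in one call.
  await context.applyLoraAdapters([
    { path: loraA },
    { path: loraB, scaled: 0.5 },
  ])

  // Resolves to an array of { path, scaled } entries, as built by
  // Java_com_rnllama_LlamaContext_getLoadedLoraAdapters.
  const loaded = await context.getLoadedLoraAdapters()
  console.log('Loaded lora adapters:', loaded)

  // Clears the adapter list; as of PATCH 10 this re-applies an empty list
  // via common_lora_adapters_apply rather than a separate remove function.
  await context.removeLoraAdapters()
}

// Usage note: as of PATCH 12 these calls reject with "Context is busy" while a
// completion is running, so callers should wait for prediction to finish first.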