diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java
index 926b203..9c14a3a 100644
--- a/android/src/main/java/com/rnllama/LlamaContext.java
+++ b/android/src/main/java/com/rnllama/LlamaContext.java
@@ -71,7 +71,7 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
       // float lora_scaled,
       params.hasKey("lora_scaled") ? (float) params.getDouble("lora_scaled") : 1.0f,
       // ReadableArray lora_adapters,
-      params.hasKey("lora_adapters") ? params.getArray("lora_adapters") : null,
+      params.hasKey("lora_list") ? params.getArray("lora_list") : null,
       // float rope_freq_base,
       params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
       // float rope_freq_scale
@@ -424,7 +424,7 @@ protected static native long initContext(
     boolean vocab_only,
     String lora,
     float lora_scaled,
-    ReadableArray lora_adapters,
+    ReadableArray lora_list,
     float rope_freq_base,
     float rope_freq_scale,
     int pooling_type,
diff --git a/android/src/main/java/com/rnllama/RNLlama.java b/android/src/main/java/com/rnllama/RNLlama.java
index 2822209..4dfa249 100644
--- a/android/src/main/java/com/rnllama/RNLlama.java
+++ b/android/src/main/java/com/rnllama/RNLlama.java
@@ -468,6 +468,7 @@ protected void onPostExecute(Void result) {
           promise.reject(exception);
           return;
         }
+        promise.resolve(null);
       }
     }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
     tasks.put(task, "removeLoraAdapters-" + contextId);
@@ -498,6 +499,7 @@ protected void onPostExecute(ReadableArray result) {
           promise.reject(exception);
           return;
         }
+        promise.resolve(result);
       }
     }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
     tasks.put(task, "getLoadedLoraAdapters-" + contextId);
diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp
index 39158ad..d261657 100644
--- a/android/src/main/jni.cpp
+++ b/android/src/main/jni.cpp
@@ -128,7 +128,7 @@ static inline void pushString(JNIEnv *env, jobject arr, const char *value) {
 // Method to push WritableMap into WritableArray
 static inline void pushMap(JNIEnv *env, jobject arr, jobject value) {
   jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
-  jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/WritableMap;)V");
+  jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/ReadableMap;)V");
   env->CallVoidMethod(arr, pushMapMethod, value);
 }
 
@@ -324,41 +324,42 @@ Java_com_rnllama_LlamaContext_initContext(
 
   LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
   if (is_model_loaded) {
-      if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) {
-        LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
-        llama_free(llama->ctx);
-        return -1;
-      }
-      context_map[(long) llama->ctx] = llama;
+    if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) {
+      LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
+      llama_free(llama->ctx);
+      return -1;
+    }
+    context_map[(long) llama->ctx] = llama;
   } else {
-      llama_free(llama->ctx);
+    llama_free(llama->ctx);
   }
 
   std::vector<common_lora_adapter_info> lora_adapters;
   const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
   if (lora_chars != nullptr && lora_chars[0] != '\0') {
-      common_lora_adapter_info la;
-      la.path = lora_chars;
-      la.scale = lora_scaled;
-      lora_adapters.push_back(la);
-    }
-  env->ReleaseStringUTFChars(lora_str, lora_chars);
-
-  // lora_adapters: ReadableArray
-  int lora_list_size = readablearray::size(env, lora_list);
-  for (int i = 0; i < lora_list_size; i++) {
-    jobject lora_adapter = readablearray::getMap(env, lora_list, i);
-    jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
-    if (path != nullptr) {
-      const char *path_chars = env->GetStringUTFChars(path, nullptr);
-      common_lora_adapter_info la;
-      la.path = path_chars;
-      la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
-      lora_adapters.push_back(la);
-      env->ReleaseStringUTFChars(path, path_chars);
+    common_lora_adapter_info la;
+    la.path = lora_chars;
+    la.scale = lora_scaled;
+    lora_adapters.push_back(la);
+  }
+
+  if (lora_list != nullptr) {
+    // lora_adapters: ReadableArray
+    int lora_list_size = readablearray::size(env, lora_list);
+    for (int i = 0; i < lora_list_size; i++) {
+      jobject lora_adapter = readablearray::getMap(env, lora_list, i);
+      jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
+      if (path != nullptr) {
+        const char *path_chars = env->GetStringUTFChars(path, nullptr);
+        common_lora_adapter_info la;
+        la.path = path_chars;
+        la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
+        lora_adapters.push_back(la);
+        env->ReleaseStringUTFChars(path, path_chars);
+      }
     }
   }
-
+  env->ReleaseStringUTFChars(lora_str, lora_chars);
   int result = llama->applyLoraAdapters(lora_adapters);
   if (result != 0) {
     LOGI("[RNLlama] Failed to apply lora adapters");
@@ -920,8 +921,10 @@ Java_com_rnllama_LlamaContext_applyLoraAdapters(
       const char *path_chars = env->GetStringUTFChars(path, nullptr);
       env->ReleaseStringUTFChars(path, path_chars);
       float scaled = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
-
-      lora_adapters.push_back({path_chars, scaled});
+      common_lora_adapter_info la;
+      la.path = path_chars;
+      la.scale = scaled;
+      lora_adapters.push_back(la);
     }
   }
   return llama->applyLoraAdapters(lora_adapters);
@@ -939,15 +942,14 @@ Java_com_rnllama_LlamaContext_removeLoraAdapters(
 JNIEXPORT jobject JNICALL
 Java_com_rnllama_LlamaContext_getLoadedLoraAdapters(
   JNIEnv *env, jobject thiz, jlong context_ptr) {
-  UNUSED(env);
   UNUSED(thiz);
   auto llama = context_map[(long) context_ptr];
   auto loaded_lora_adapters = llama->getLoadedLoraAdapters();
   auto result = createWritableArray(env);
   for (common_lora_adapter_container &la : loaded_lora_adapters) {
     auto map = createWriteableMap(env);
-    pushString(env, map, la.path.c_str());
-    pushDouble(env, map, la.scale);
+    putString(env, map, "path", la.path.c_str());
+    putDouble(env, map, "scaled", la.scale);
     pushMap(env, result, map);
   }
   return result;
diff --git a/cpp/rn-llama.hpp b/cpp/rn-llama.hpp
index c174047..ee1922c 100644
--- a/cpp/rn-llama.hpp
+++ b/cpp/rn-llama.hpp
@@ -727,6 +727,7 @@ struct llama_rn_context
     }
 
     int applyLoraAdapters(std::vector<common_lora_adapter_info> lora_adapters) {
+        this->lora_adapters.clear();
        auto containers = std::vector<common_lora_adapter_container>();
        for (auto & la : lora_adapters) {
            common_lora_adapter_container loaded_la;
diff --git a/example/src/App.tsx b/example/src/App.tsx
index 6d549fe..724257d 100644
--- a/example/src/App.tsx
+++ b/example/src/App.tsx
@@ -195,7 +195,7 @@ export default function App() {
     if (!modelRes?.[0]) return
     const modelFile = await copyFileIfNeeded('model', modelRes?.[0])
 
-    let loraFile = null
+    let loraFile: any = null
     // Example: Apply lora adapter (Currently only select one lora file) (Uncomment to use)
     // loraFile = await pickLora()
     loraFile = null
@@ -298,6 +298,11 @@ export default function App() {
           ),
         )
         return
+      case '/remove-lora':
+        context.removeLoraAdapters().then(() => {
+          addSystemMessage('Lora adapters removed!')
+        })
+        return
       case '/lora-list':
         context.getLoadedLoraAdapters().then((loraList) => {
           addSystemMessage(
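
Usage note (not part of the diff): a minimal TypeScript sketch of how the adapter API touched above is driven from the JS side. It assumes the llama.rn bindings (initLlama returning a LlamaContext with applyLoraAdapters / removeLoraAdapters / getLoadedLoraAdapters) and uses placeholder file paths; the { path, scaled } shape mirrors the keys read and written in jni.cpp.

// Sketch only: paths are placeholders, and the JS surface is assumed to
// match the native methods changed in this diff.
import { initLlama } from 'llama.rn'

async function loraRoundTrip() {
  const context = await initLlama({
    model: '/path/to/model.gguf', // placeholder
    // Feeds the `lora_list` ReadableArray read in LlamaContext.java / jni.cpp
    lora_list: [{ path: '/path/to/adapter.gguf', scaled: 1.0 }], // placeholder
  })

  // Resolves to [{ path, scaled }] entries, matching putString/putDouble
  // in Java_com_rnllama_LlamaContext_getLoadedLoraAdapters
  const loaded = await context.getLoadedLoraAdapters()
  console.log('loaded adapters:', loaded)

  // Resolves once the adapters are removed (the added promise.resolve(null)
  // in RNLlama.java)
  await context.removeLoraAdapters()

  // Re-apply without re-creating the context; applyLoraAdapters now clears
  // the previously recorded adapter list first (see the rn-llama.hpp hunk)
  await context.applyLoraAdapters([{ path: '/path/to/adapter.gguf', scaled: 0.8 }])
}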