diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index e774337..39158ad 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -236,7 +236,7 @@ Java_com_rnllama_LlamaContext_initContext( jboolean vocab_only, jstring lora_str, jfloat lora_scaled, - jobject lora_adapters, + jobject lora_list, jfloat rope_freq_base, jfloat rope_freq_scale, jint pooling_type, @@ -286,25 +286,6 @@ Java_com_rnllama_LlamaContext_initContext( defaultParams.use_mlock = use_mlock; defaultParams.use_mmap = use_mmap; - const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr); - if (lora_chars != nullptr && lora_chars[0] != '\0') { - defaultParams.lora_adapters.push_back({lora_chars, lora_scaled}); - } - - // lora_adapters: ReadableArray - int lora_adapters_size = readablearray::size(env, lora_adapters); - for (int i = 0; i < lora_adapters_size; i++) { - jobject lora_adapter = readablearray::getMap(env, lora_adapters, i); - jstring path = readablemap::getString(env, lora_adapter, "path", nullptr); - if (path != nullptr) { - const char *path_chars = env->GetStringUTFChars(path, nullptr); - env->ReleaseStringUTFChars(path, path_chars); - float scaled = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f); - - defaultParams.lora_adapters.push_back({path_chars, scaled}); - } - } - defaultParams.rope_freq_base = rope_freq_base; defaultParams.rope_freq_scale = rope_freq_scale; @@ -338,7 +319,6 @@ Java_com_rnllama_LlamaContext_initContext( bool is_model_loaded = llama->loadModel(defaultParams); env->ReleaseStringUTFChars(model_path_str, model_path_chars); - env->ReleaseStringUTFChars(lora_str, lora_chars); env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars); env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars); @@ -354,6 +334,38 @@ Java_com_rnllama_LlamaContext_initContext( llama_free(llama->ctx); } + std::vector<common_lora_adapter_info> lora_adapters; + const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr); + if (lora_chars != nullptr && 
lora_chars[0] != '\0') { + common_lora_adapter_info la; + la.path = lora_chars; + la.scale = lora_scaled; + lora_adapters.push_back(la); + } + env->ReleaseStringUTFChars(lora_str, lora_chars); + + // lora_adapters: ReadableArray + int lora_list_size = readablearray::size(env, lora_list); + for (int i = 0; i < lora_list_size; i++) { + jobject lora_adapter = readablearray::getMap(env, lora_list, i); + jstring path = readablemap::getString(env, lora_adapter, "path", nullptr); + if (path != nullptr) { + const char *path_chars = env->GetStringUTFChars(path, nullptr); + common_lora_adapter_info la; + la.path = path_chars; + la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f); + lora_adapters.push_back(la); + env->ReleaseStringUTFChars(path, path_chars); + } + } + + int result = llama->applyLoraAdapters(lora_adapters); + if (result != 0) { + LOGI("[RNLlama] Failed to apply lora adapters"); + llama_free(llama->ctx); + return -1; + } + return reinterpret_cast<jlong>(llama->ctx); } diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm index 307d14e..fa5b496 100644 --- a/ios/RNLlamaContext.mm +++ b/ios/RNLlamaContext.mm @@ -110,22 +110,6 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig } } - if (params[@"lora"]) { - float lora_scaled = 1.0f; - if (params[@"lora_scaled"]) lora_scaled = [params[@"lora_scaled"] floatValue]; - defaultParams.lora_adapters.push_back({[params[@"lora"] UTF8String], lora_scaled}); - } - - if (params[@"lora_list"] && [params[@"lora_list"] isKindOfClass:[NSArray class]]) { - NSArray *lora_list = params[@"lora_list"]; - for (NSDictionary *lora_adapter in lora_list) { - NSString *path = lora_adapter[@"path"]; - if (!path) continue; - float scale = [lora_adapter[@"scaled"] floatValue]; - defaultParams.lora_adapters.push_back({[path UTF8String], scale}); - } - } - if (params[@"rope_freq_base"]) defaultParams.rope_freq_base = [params[@"rope_freq_base"] floatValue]; if (params[@"rope_freq_scale"]) 
defaultParams.rope_freq_scale = [params[@"rope_freq_scale"] floatValue]; @@ -140,6 +124,7 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig const int defaultNThreads = nThreads == 4 ? 2 : MIN(4, maxThreads); defaultParams.cpuparams.n_threads = nThreads > 0 ? nThreads : defaultNThreads; + RNLlamaContext *context = [[RNLlamaContext alloc] init]; context->llama = new rnllama::llama_rn_context(); context->llama->is_load_interrupted = false; @@ -169,6 +154,34 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig @throw [NSException exceptionWithName:@"LlamaException" reason:@"Embedding is not supported in encoder-decoder models" userInfo:nil]; } + std::vector<common_lora_adapter_info> lora_adapters; + if (params[@"lora"]) { + common_lora_adapter_info la; + la.path = [params[@"lora"] UTF8String]; + la.scale = 1.0f; + if (params[@"lora_scaled"]) la.scale = [params[@"lora_scaled"] floatValue]; + lora_adapters.push_back(la); + } + if (params[@"lora_list"] && [params[@"lora_list"] isKindOfClass:[NSArray class]]) { + NSArray *lora_list = params[@"lora_list"]; + for (NSDictionary *lora_adapter in lora_list) { + NSString *path = lora_adapter[@"path"]; + if (!path) continue; + float scale = [lora_adapter[@"scaled"] floatValue]; + common_lora_adapter_info la; + la.path = [path UTF8String]; + la.scale = scale; + lora_adapters.push_back(la); + } + } + if (lora_adapters.size() > 0) { + int result = context->llama->applyLoraAdapters(lora_adapters); + if (result != 0) { + delete context->llama; + @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to apply lora adapters" userInfo:nil]; + } + } + context->is_metal_enabled = isMetalEnabled; context->reason_no_metal = reasonNoMetal;