Skip to content

Commit

Permalink
fix: use llama->applyLoraAdapters on init
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Nov 20, 2024
1 parent 4f392a5 commit 86fb239
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 37 deletions.
54 changes: 33 additions & 21 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ Java_com_rnllama_LlamaContext_initContext(
jboolean vocab_only,
jstring lora_str,
jfloat lora_scaled,
jobject lora_adapters,
jobject lora_list,
jfloat rope_freq_base,
jfloat rope_freq_scale,
jint pooling_type,
Expand Down Expand Up @@ -286,25 +286,6 @@ Java_com_rnllama_LlamaContext_initContext(
defaultParams.use_mlock = use_mlock;
defaultParams.use_mmap = use_mmap;

const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
if (lora_chars != nullptr && lora_chars[0] != '\0') {
defaultParams.lora_adapters.push_back({lora_chars, lora_scaled});
}

// lora_adapters: ReadableArray<ReadableMap>
int lora_adapters_size = readablearray::size(env, lora_adapters);
for (int i = 0; i < lora_adapters_size; i++) {
jobject lora_adapter = readablearray::getMap(env, lora_adapters, i);
jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
if (path != nullptr) {
const char *path_chars = env->GetStringUTFChars(path, nullptr);
env->ReleaseStringUTFChars(path, path_chars);
float scaled = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);

defaultParams.lora_adapters.push_back({path_chars, scaled});
}
}

defaultParams.rope_freq_base = rope_freq_base;
defaultParams.rope_freq_scale = rope_freq_scale;

Expand Down Expand Up @@ -338,7 +319,6 @@ Java_com_rnllama_LlamaContext_initContext(
bool is_model_loaded = llama->loadModel(defaultParams);

env->ReleaseStringUTFChars(model_path_str, model_path_chars);
env->ReleaseStringUTFChars(lora_str, lora_chars);
env->ReleaseStringUTFChars(cache_type_k, cache_type_k_chars);
env->ReleaseStringUTFChars(cache_type_v, cache_type_v_chars);

Expand All @@ -354,6 +334,38 @@ Java_com_rnllama_LlamaContext_initContext(
llama_free(llama->ctx);
}

std::vector<common_lora_adapter_info> lora_adapters;
const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
if (lora_chars != nullptr && lora_chars[0] != '\0') {
common_lora_adapter_info la;
la.path = lora_chars;
la.scale = lora_scaled;
lora_adapters.push_back(la);
}
env->ReleaseStringUTFChars(lora_str, lora_chars);

// lora_list: ReadableArray<ReadableMap>
int lora_list_size = readablearray::size(env, lora_list);
for (int i = 0; i < lora_list_size; i++) {
jobject lora_adapter = readablearray::getMap(env, lora_list, i);
jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
if (path != nullptr) {
const char *path_chars = env->GetStringUTFChars(path, nullptr);
common_lora_adapter_info la;
la.path = path_chars;
la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
lora_adapters.push_back(la);
env->ReleaseStringUTFChars(path, path_chars);
}
}

int result = llama->applyLoraAdapters(lora_adapters);
if (result != 0) {
LOGI("[RNLlama] Failed to apply lora adapters");
llama_free(llama->ctx);
return -1;
}

return reinterpret_cast<jlong>(llama->ctx);
}

Expand Down
45 changes: 29 additions & 16 deletions ios/RNLlamaContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -110,22 +110,6 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig
}
}

if (params[@"lora"]) {
float lora_scaled = 1.0f;
if (params[@"lora_scaled"]) lora_scaled = [params[@"lora_scaled"] floatValue];
defaultParams.lora_adapters.push_back({[params[@"lora"] UTF8String], lora_scaled});
}

if (params[@"lora_list"] && [params[@"lora_list"] isKindOfClass:[NSArray class]]) {
NSArray *lora_list = params[@"lora_list"];
for (NSDictionary *lora_adapter in lora_list) {
NSString *path = lora_adapter[@"path"];
if (!path) continue;
float scale = [lora_adapter[@"scaled"] floatValue];
defaultParams.lora_adapters.push_back({[path UTF8String], scale});
}
}

if (params[@"rope_freq_base"]) defaultParams.rope_freq_base = [params[@"rope_freq_base"] floatValue];
if (params[@"rope_freq_scale"]) defaultParams.rope_freq_scale = [params[@"rope_freq_scale"] floatValue];

Expand All @@ -140,6 +124,7 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig
const int defaultNThreads = nThreads == 4 ? 2 : MIN(4, maxThreads);
defaultParams.cpuparams.n_threads = nThreads > 0 ? nThreads : defaultNThreads;


RNLlamaContext *context = [[RNLlamaContext alloc] init];
context->llama = new rnllama::llama_rn_context();
context->llama->is_load_interrupted = false;
Expand Down Expand Up @@ -169,6 +154,34 @@ + (instancetype)initWithParams:(NSDictionary *)params onProgress:(void (^)(unsig
@throw [NSException exceptionWithName:@"LlamaException" reason:@"Embedding is not supported in encoder-decoder models" userInfo:nil];
}

std::vector<common_lora_adapter_info> lora_adapters;
if (params[@"lora"]) {
common_lora_adapter_info la;
la.path = [params[@"lora"] UTF8String];
la.scale = 1.0f;
if (params[@"lora_scaled"]) la.scale = [params[@"lora_scaled"] floatValue];
lora_adapters.push_back(la);
}
if (params[@"lora_list"] && [params[@"lora_list"] isKindOfClass:[NSArray class]]) {
NSArray *lora_list = params[@"lora_list"];
for (NSDictionary *lora_adapter in lora_list) {
NSString *path = lora_adapter[@"path"];
if (!path) continue;
float scale = [lora_adapter[@"scaled"] floatValue];
common_lora_adapter_info la;
la.path = [path UTF8String];
la.scale = scale;
lora_adapters.push_back(la);
}
}
if (lora_adapters.size() > 0) {
int result = context->llama->applyLoraAdapters(lora_adapters);
if (result != 0) {
delete context->llama;
@throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to apply lora adapters" userInfo:nil];
}
}

context->is_metal_enabled = isMetalEnabled;
context->reason_no_metal = reasonNoMetal;

Expand Down

0 comments on commit 86fb239

Please sign in to comment.