Commit 4446feb

fix(android): push map
jhen0409 committed Nov 20, 2024
1 parent 121bc46 commit 4446feb
Showing 5 changed files with 46 additions and 36 deletions.
android/src/main/java/com/rnllama/LlamaContext.java (4 changes: 2 additions & 2 deletions)
@@ -71,7 +71,7 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
// float lora_scaled,
params.hasKey("lora_scaled") ? (float) params.getDouble("lora_scaled") : 1.0f,
// ReadableArray lora_adapters,
params.hasKey("lora_adapters") ? params.getArray("lora_adapters") : null,
params.hasKey("lora_list") ? params.getArray("lora_list") : null,
// float rope_freq_base,
params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
// float rope_freq_scale
@@ -424,7 +424,7 @@ protected static native long initContext(
boolean vocab_only,
String lora,
float lora_scaled,
ReadableArray lora_adapters,
ReadableArray lora_list,
float rope_freq_base,
float rope_freq_scale,
int pooling_type,
android/src/main/java/com/rnllama/RNLlama.java (2 changes: 2 additions & 0 deletions)
@@ -468,6 +468,7 @@ protected void onPostExecute(Void result) {
promise.reject(exception);
return;
}
promise.resolve(null);
}
}.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
tasks.put(task, "removeLoraAdapters-" + contextId);
@@ -498,6 +499,7 @@ protected void onPostExecute(ReadableArray result) {
promise.reject(exception);
return;
}
promise.resolve(result);
}
}.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
tasks.put(task, "getLoadedLoraAdapters-" + contextId);
android/src/main/jni.cpp (68 changes: 35 additions & 33 deletions)
@@ -128,7 +128,7 @@ static inline void pushString(JNIEnv *env, jobject arr, const char *value) {
// Method to push WritableMap into WritableArray
static inline void pushMap(JNIEnv *env, jobject arr, jobject value) {
jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableArray");
jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/WritableMap;)V");
jmethodID pushMapMethod = env->GetMethodID(mapClass, "pushMap", "(Lcom/facebook/react/bridge/ReadableMap;)V");

env->CallVoidMethod(arr, pushMapMethod, value);
}
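
For context on the descriptor change (my reading, not stated in the commit): JNI's GetMethodID only resolves a method whose descriptor matches the Java declaration, and in recent React Native versions WritableArray declares pushMap(ReadableMap), so looking it up with the WritableMap descriptor returns null and leaves a pending NoSuchMethodError that surfaces on the following CallVoidMethod. Below is a minimal defensive sketch of the same helper; the pushMapChecked name, the null checks, and the local-ref cleanup are illustrative additions, not part of jni.cpp:

```cpp
// Sketch: same lookup as the helper above, with failure handling added.
static inline void pushMapChecked(JNIEnv *env, jobject arr, jobject value) {
    jclass arrayClass = env->FindClass("com/facebook/react/bridge/WritableArray");
    if (arrayClass == nullptr) {
        env->ExceptionClear(); // class lookup failed; clear the pending exception
        return;
    }
    // The descriptor must match the Java declaration. WritableArray.pushMap takes
    // a ReadableMap, so "(Lcom/facebook/react/bridge/WritableMap;)V" would make
    // GetMethodID return null with a pending NoSuchMethodError.
    jmethodID pushMapMethod = env->GetMethodID(
        arrayClass, "pushMap", "(Lcom/facebook/react/bridge/ReadableMap;)V");
    if (pushMapMethod != nullptr) {
        env->CallVoidMethod(arr, pushMapMethod, value);
    } else {
        env->ExceptionClear();
    }
    env->DeleteLocalRef(arrayClass);
}
```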
@@ -324,41 +324,42 @@ Java_com_rnllama_LlamaContext_initContext(

LOGI("[RNLlama] is_model_loaded %s", (is_model_loaded ? "true" : "false"));
if (is_model_loaded) {
if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) {
LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
llama_free(llama->ctx);
return -1;
}
context_map[(long) llama->ctx] = llama;
if (embedding && llama_model_has_encoder(llama->model) && llama_model_has_decoder(llama->model)) {
LOGI("[RNLlama] computing embeddings in encoder-decoder models is not supported");
llama_free(llama->ctx);
return -1;
}
context_map[(long) llama->ctx] = llama;
} else {
llama_free(llama->ctx);
llama_free(llama->ctx);
}

std::vector<common_lora_adapter_info> lora_adapters;
const char *lora_chars = env->GetStringUTFChars(lora_str, nullptr);
if (lora_chars != nullptr && lora_chars[0] != '\0') {
common_lora_adapter_info la;
la.path = lora_chars;
la.scale = lora_scaled;
lora_adapters.push_back(la);
}
env->ReleaseStringUTFChars(lora_str, lora_chars);

// lora_adapters: ReadableArray<ReadableMap>
int lora_list_size = readablearray::size(env, lora_list);
for (int i = 0; i < lora_list_size; i++) {
jobject lora_adapter = readablearray::getMap(env, lora_list, i);
jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
if (path != nullptr) {
const char *path_chars = env->GetStringUTFChars(path, nullptr);
common_lora_adapter_info la;
la.path = path_chars;
la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
lora_adapters.push_back(la);
env->ReleaseStringUTFChars(path, path_chars);
common_lora_adapter_info la;
la.path = lora_chars;
la.scale = lora_scaled;
lora_adapters.push_back(la);
}

if (lora_list != nullptr) {
// lora_adapters: ReadableArray<ReadableMap>
int lora_list_size = readablearray::size(env, lora_list);
for (int i = 0; i < lora_list_size; i++) {
jobject lora_adapter = readablearray::getMap(env, lora_list, i);
jstring path = readablemap::getString(env, lora_adapter, "path", nullptr);
if (path != nullptr) {
const char *path_chars = env->GetStringUTFChars(path, nullptr);
common_lora_adapter_info la;
la.path = path_chars;
la.scale = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);
lora_adapters.push_back(la);
env->ReleaseStringUTFChars(path, path_chars);
}
}
}

env->ReleaseStringUTFChars(lora_str, lora_chars);
int result = llama->applyLoraAdapters(lora_adapters);
if (result != 0) {
LOGI("[RNLlama] Failed to apply lora adapters");
@@ -920,8 +921,10 @@ Java_com_rnllama_LlamaContext_applyLoraAdapters(
const char *path_chars = env->GetStringUTFChars(path, nullptr);
env->ReleaseStringUTFChars(path, path_chars);
float scaled = readablemap::getFloat(env, lora_adapter, "scaled", 1.0f);

lora_adapters.push_back({path_chars, scaled});
common_lora_adapter_info la;
la.path = path_chars;
la.scale = scaled;
lora_adapters.push_back(la);
}
}
return llama->applyLoraAdapters(lora_adapters);
@@ -939,15 +942,14 @@ Java_com_rnllama_LlamaContext_removeLoraAdapters(
JNIEXPORT jobject JNICALL
Java_com_rnllama_LlamaContext_getLoadedLoraAdapters(
JNIEnv *env, jobject thiz, jlong context_ptr) {
UNUSED(env);
UNUSED(thiz);
auto llama = context_map[(long) context_ptr];
auto loaded_lora_adapters = llama->getLoadedLoraAdapters();
auto result = createWritableArray(env);
for (common_lora_adapter_container &la : loaded_lora_adapters) {
auto map = createWriteableMap(env);
pushString(env, map, la.path.c_str());
pushDouble(env, map, la.scale);
putString(env, map, "path", la.path.c_str());
putDouble(env, map, "scaled", la.scale);
pushMap(env, result, map);
}
return result;
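
With putString/putDouble instead of pushString/pushDouble, every adapter entry gets named keys, so the array returned to JS holds maps shaped like { path, scaled } rather than loose positional values. The put helpers themselves sit outside this hunk; the body below is an assumption about what putString wraps (only its name and call sites come from the diff), with putDouble analogous via the "(Ljava/lang/String;D)V" descriptor:

```cpp
// Assumed shape of the putString helper used above: WritableMap entries need a
// key, unlike WritableArray pushes, which is why the loop switched helpers.
static inline void putString(JNIEnv *env, jobject map, const char *key, const char *value) {
    jclass mapClass = env->FindClass("com/facebook/react/bridge/WritableMap");
    jmethodID putStringMethod = env->GetMethodID(
        mapClass, "putString", "(Ljava/lang/String;Ljava/lang/String;)V");

    jstring jKey = env->NewStringUTF(key);
    jstring jValue = env->NewStringUTF(value);
    env->CallVoidMethod(map, putStringMethod, jKey, jValue);

    env->DeleteLocalRef(jKey);
    env->DeleteLocalRef(jValue);
    env->DeleteLocalRef(mapClass);
}
```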
cpp/rn-llama.hpp (1 change: 1 addition & 0 deletions)
@@ -727,6 +727,7 @@ struct llama_rn_context
}

int applyLoraAdapters(std::vector<common_lora_adapter_info> lora_adapters) {
this->lora_adapters.clear();
auto containers = std::vector<common_lora_adapter_container>();
for (auto & la : lora_adapters) {
common_lora_adapter_container loaded_la;
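
The added clear() changes applyLoraAdapters from accumulate to replace semantics: after this, getLoadedLoraAdapters reports only the set passed in the most recent call. A standalone toy illustration of that distinction (the types here are stand-ins, not the library's):

```cpp
#include <string>
#include <vector>

// Stand-in for common_lora_adapter_info, for illustration only.
struct adapter_info {
    std::string path;
    float scale;
};

struct adapter_tracker {
    std::vector<adapter_info> loaded;

    // Mirrors the fixed behaviour: clearing first means repeated calls replace
    // the tracked set instead of appending to whatever was applied before.
    void apply(const std::vector<adapter_info> &adapters) {
        loaded.clear();
        loaded.insert(loaded.end(), adapters.begin(), adapters.end());
    }
};
```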
example/src/App.tsx (7 changes: 6 additions & 1 deletion)
@@ -195,7 +195,7 @@ export default function App() {
if (!modelRes?.[0]) return
const modelFile = await copyFileIfNeeded('model', modelRes?.[0])

let loraFile = null
let loraFile: any = null
// Example: Apply lora adapter (Currently only select one lora file) (Uncomment to use)
// loraFile = await pickLora()
loraFile = null
@@ -298,6 +298,11 @@ export default function App() {
),
)
return
case '/remove-lora':
context.removeLoraAdapters().then(() => {
addSystemMessage('Lora adapters removed!')
})
return
case '/lora-list':
context.getLoadedLoraAdapters().then((loraList) => {
addSystemMessage(
