Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support multiple lora files & dynamic apply / remove lora #92

Merged
merged 12 commits into from
Nov 21, 2024
1 change: 1 addition & 0 deletions android/src/main/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ set(
${RNLLAMA_LIB_DIR}/sgemm.cpp
${RNLLAMA_LIB_DIR}/ggml-aarch64.c
${RNLLAMA_LIB_DIR}/rn-llama.hpp
${CMAKE_SOURCE_DIR}/jni-utils.h
${CMAKE_SOURCE_DIR}/jni.cpp
)

Expand Down
24 changes: 23 additions & 1 deletion android/src/main/java/com/rnllama/LlamaContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ public LlamaContext(int id, ReactApplicationContext reactContext, ReadableMap pa
params.hasKey("lora") ? params.getString("lora") : "",
// float lora_scaled,
params.hasKey("lora_scaled") ? (float) params.getDouble("lora_scaled") : 1.0f,
// ReadableArray lora_list,
params.hasKey("lora_list") ? params.getArray("lora_list") : null,
// float rope_freq_base,
params.hasKey("rope_freq_base") ? (float) params.getDouble("rope_freq_base") : 0.0f,
// float rope_freq_scale
Expand Down Expand Up @@ -301,6 +303,22 @@ public String bench(int pp, int tg, int pl, int nr) {
return bench(this.context, pp, tg, pl, nr);
}

/**
 * Applies the given LoRA adapters to this context through the native layer.
 *
 * @param loraAdapters adapter descriptors, forwarded unchanged to the native call
 * @return the native status code; always {@code 0} when this method returns normally
 * @throws IllegalStateException if the native call reports a non-zero status
 */
public int applyLoraAdapters(ReadableArray loraAdapters) {
  final int status = applyLoraAdapters(this.context, loraAdapters);
  if (status == 0) {
    return status;
  }
  throw new IllegalStateException("Failed to apply lora adapters");
}

/** Asks the native layer to remove the LoRA adapters from this context. */
public void removeLoraAdapters() {
  final long nativeContext = this.context;
  removeLoraAdapters(nativeContext);
}

/**
 * Queries the native layer for the LoRA adapters loaded on this context.
 *
 * @return the array produced by the native call, returned as-is
 */
public WritableArray getLoadedLoraAdapters() {
  final WritableArray loadedAdapters = getLoadedLoraAdapters(this.context);
  return loadedAdapters;
}

/** Hands this context's native pointer to {@code freeContext}. */
public void release() {
  freeContext(this.context);
}
Expand Down Expand Up @@ -406,6 +424,7 @@ protected static native long initContext(
boolean vocab_only,
String lora,
float lora_scaled,
ReadableArray lora_list,
float rope_freq_base,
float rope_freq_scale,
int pooling_type,
Expand Down Expand Up @@ -457,7 +476,7 @@ protected static native WritableMap doCompletion(
double[][] logit_bias,
float dry_multiplier,
float dry_base,
int dry_allowed_length,
int dry_allowed_length,
int dry_penalty_last_n,
String[] dry_sequence_breakers,
PartialCompletionCallback partial_completion_callback
Expand All @@ -473,6 +492,9 @@ protected static native WritableMap embedding(
int embd_normalize
);
// Native benchmark entry point for the context at contextPtr; the result is returned as a String.
protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
// Applies the given LoRA adapters to the native context. Returns an int status code;
// the instance-level wrapper in this class treats any non-zero value as a failure.
protected static native int applyLoraAdapters(long contextPtr, ReadableArray loraAdapters);
// Removes LoRA adapters from the native context.
protected static native void removeLoraAdapters(long contextPtr);
// Returns the native layer's view of the loaded LoRA adapters as a WritableArray.
protected static native WritableArray getLoadedLoraAdapters(long contextPtr);
// Invoked by release() with this context's native pointer.
protected static native void freeContext(long contextPtr);
// NOTE(review): presumably routes native log output to Android logging — confirm in jni.cpp.
protected static native void logToAndroid();
}
98 changes: 98 additions & 0 deletions android/src/main/java/com/rnllama/RNLlama.java
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,104 @@ protected void onPostExecute(String result) {
tasks.put(task, "bench-" + contextId);
}

/**
 * Applies LoRA adapters to the context identified by {@code id} on the AsyncTask thread pool.
 *
 * <p>The promise is rejected when the context does not exist, is currently predicting, or the
 * underlying apply call throws; otherwise it is resolved with {@code null}.
 *
 * @param id           context id (truncated to {@code int})
 * @param loraAdapters adapter descriptors forwarded to {@code LlamaContext#applyLoraAdapters}
 * @param promise      settled on the main thread once the background work finishes
 */
public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final Promise promise) {
  final int contextId = (int) id;
  AsyncTask task = new AsyncTask<Void, Void, Void>() {
    private Exception exception;

    @Override
    protected Void doInBackground(Void... voids) {
      try {
        LlamaContext context = contexts.get(contextId);
        if (context == null) {
          throw new Exception("Context not found");
        }
        if (context.isPredicting()) {
          throw new Exception("Context is busy");
        }
        context.applyLoraAdapters(loraAdapters);
      } catch (Exception e) {
        // Captured here and surfaced to the promise in onPostExecute.
        exception = e;
      }
      return null;
    }

    @Override
    protected void onPostExecute(Void result) {
      if (exception != null) {
        promise.reject(exception);
        return;
      }
      // Settle the promise on success as well; without this the caller would wait forever.
      promise.resolve(null);
    }
  }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
  tasks.put(task, "applyLoraAdapters-" + contextId);
}

/**
 * Removes the LoRA adapters from the context identified by {@code id} on the AsyncTask
 * thread pool.
 *
 * <p>The promise is rejected when the context is missing, busy predicting, or the removal
 * throws; otherwise it is resolved with {@code null}.
 *
 * @param id      context id (truncated to {@code int})
 * @param promise settled once the background work finishes
 */
public void removeLoraAdapters(double id, final Promise promise) {
  final int contextId = (int) id;
  final String taskLabel = "removeLoraAdapters-" + contextId;
  AsyncTask task = new AsyncTask<Void, Void, Void>() {
    private Exception failure;

    @Override
    protected Void doInBackground(Void... ignored) {
      try {
        final LlamaContext target = contexts.get(contextId);
        if (target == null) {
          throw new Exception("Context not found");
        }
        if (target.isPredicting()) {
          throw new Exception("Context is busy");
        }
        target.removeLoraAdapters();
      } catch (Exception e) {
        failure = e;
      }
      return null;
    }

    @Override
    protected void onPostExecute(Void unused) {
      if (failure == null) {
        promise.resolve(null);
      } else {
        promise.reject(failure);
      }
    }
  }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
  tasks.put(task, taskLabel);
}

/**
 * Fetches the loaded LoRA adapters of the context identified by {@code id} on the AsyncTask
 * thread pool.
 *
 * <p>The promise is rejected when the context is missing or the lookup throws; otherwise it
 * is resolved with the array returned by the context.
 *
 * @param id      context id (truncated to {@code int})
 * @param promise settled once the background work finishes
 */
public void getLoadedLoraAdapters(double id, final Promise promise) {
  final int contextId = (int) id;
  final String taskLabel = "getLoadedLoraAdapters-" + contextId;
  AsyncTask task = new AsyncTask<Void, Void, ReadableArray>() {
    private Exception failure;

    @Override
    protected ReadableArray doInBackground(Void... ignored) {
      ReadableArray adapters = null;
      try {
        final LlamaContext target = contexts.get(contextId);
        if (target == null) {
          throw new Exception("Context not found");
        }
        adapters = target.getLoadedLoraAdapters();
      } catch (Exception e) {
        failure = e;
      }
      return adapters;
    }

    @Override
    protected void onPostExecute(ReadableArray adapters) {
      if (failure == null) {
        promise.resolve(adapters);
      } else {
        promise.reject(failure);
      }
    }
  }.executeOnExecutor(AsyncTask.THREAD_POOL_EXECUTOR);
  tasks.put(task, taskLabel);
}

public void releaseContext(double id, Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, Void>() {
Expand Down
94 changes: 94 additions & 0 deletions android/src/main/jni-utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#include <jni.h>

// ReadableArray / ReadableMap utils

namespace readablearray {

// Returns readableArray.size() by calling the Java method through JNI.
// Declared inline because this definition lives in a header.
inline int size(JNIEnv *env, jobject readableArray) {
  jclass arrayClass = env->GetObjectClass(readableArray);
  jmethodID sizeMethod = env->GetMethodID(arrayClass, "size", "()I");
  jint count = env->CallIntMethod(readableArray, sizeMethod);
  // Release the class reference so repeated calls in one native frame do not pile up local refs.
  env->DeleteLocalRef(arrayClass);
  return count;
}

// Returns readableArray.getMap(index) as a local reference; the caller owns that reference.
// Declared inline because this definition lives in a header.
inline jobject getMap(JNIEnv *env, jobject readableArray, int index) {
  jclass arrayClass = env->GetObjectClass(readableArray);
  jmethodID getMapMethod = env->GetMethodID(arrayClass, "getMap", "(I)Lcom/facebook/react/bridge/ReadableMap;");
  jobject element = env->CallObjectMethod(readableArray, getMapMethod, index);
  // Release the class reference so calls made in a loop do not pile up local refs.
  env->DeleteLocalRef(arrayClass);
  return element;
}

// Other methods not used yet

}

namespace readablemap {

// Returns readableMap.hasKey(key) by calling the Java method through JNI.
// Declared inline because this definition lives in a header.
inline bool hasKey(JNIEnv *env, jobject readableMap, const char *key) {
  jclass mapClass = env->GetObjectClass(readableMap);
  jmethodID hasKeyMethod = env->GetMethodID(mapClass, "hasKey", "(Ljava/lang/String;)Z");
  jstring jKey = env->NewStringUTF(key);
  jboolean result = env->CallBooleanMethod(readableMap, hasKeyMethod, jKey);
  // Release both temporaries; the original leaked the class reference.
  env->DeleteLocalRef(jKey);
  env->DeleteLocalRef(mapClass);
  return result;
}

// Returns readableMap.getInt(key), or defaultValue when the map has no such key.
// Declared inline because this definition lives in a header.
inline int getInt(JNIEnv *env, jobject readableMap, const char *key, jint defaultValue) {
  if (!hasKey(env, readableMap, key)) {
    return defaultValue;
  }
  jclass mapClass = env->GetObjectClass(readableMap);
  jmethodID getIntMethod = env->GetMethodID(mapClass, "getInt", "(Ljava/lang/String;)I");
  jstring jKey = env->NewStringUTF(key);
  jint result = env->CallIntMethod(readableMap, getIntMethod, jKey);
  // Release both temporaries; the original leaked the class reference.
  env->DeleteLocalRef(jKey);
  env->DeleteLocalRef(mapClass);
  return result;
}

// Returns readableMap.getBoolean(key), or defaultValue when the map has no such key.
// Declared inline because this definition lives in a header.
inline bool getBool(JNIEnv *env, jobject readableMap, const char *key, jboolean defaultValue) {
  if (!hasKey(env, readableMap, key)) {
    return defaultValue;
  }
  jclass mapClass = env->GetObjectClass(readableMap);
  jmethodID getBoolMethod = env->GetMethodID(mapClass, "getBoolean", "(Ljava/lang/String;)Z");
  jstring jKey = env->NewStringUTF(key);
  jboolean result = env->CallBooleanMethod(readableMap, getBoolMethod, jKey);
  // Release both temporaries; the original leaked the class reference.
  env->DeleteLocalRef(jKey);
  env->DeleteLocalRef(mapClass);
  return result;
}

// Returns readableMap.getLong(key), or defaultValue when the map has no such key.
// Declared inline because this definition lives in a header.
// NOTE(review): confirm ReadableMap exposes getLong(String) in the targeted React Native
// version; if it does not, GetMethodID returns null and the call below is invalid.
inline long getLong(JNIEnv *env, jobject readableMap, const char *key, jlong defaultValue) {
  if (!hasKey(env, readableMap, key)) {
    return defaultValue;
  }
  jclass mapClass = env->GetObjectClass(readableMap);
  jmethodID getLongMethod = env->GetMethodID(mapClass, "getLong", "(Ljava/lang/String;)J");
  jstring jKey = env->NewStringUTF(key);
  jlong result = env->CallLongMethod(readableMap, getLongMethod, jKey);
  // Release both temporaries; the original leaked the class reference.
  env->DeleteLocalRef(jKey);
  env->DeleteLocalRef(mapClass);
  return result;
}

// Returns readableMap.getDouble(key) narrowed to float, or defaultValue when the map has no
// such key. The value is read via getDouble because that is the accessor this code looks up.
// Declared inline because this definition lives in a header.
inline float getFloat(JNIEnv *env, jobject readableMap, const char *key, jfloat defaultValue) {
  if (!hasKey(env, readableMap, key)) {
    return defaultValue;
  }
  jclass mapClass = env->GetObjectClass(readableMap);
  jmethodID getDoubleMethod = env->GetMethodID(mapClass, "getDouble", "(Ljava/lang/String;)D");
  jstring jKey = env->NewStringUTF(key);
  // Make the double -> float narrowing explicit rather than implicit.
  jfloat result = static_cast<jfloat>(env->CallDoubleMethod(readableMap, getDoubleMethod, jKey));
  // Release both temporaries; the original leaked the class reference.
  env->DeleteLocalRef(jKey);
  env->DeleteLocalRef(mapClass);
  return result;
}

// Returns readableMap.getString(key) as a local reference, or defaultValue when the map has
// no such key. The caller owns the returned reference.
// Declared inline because this definition lives in a header.
inline jstring getString(JNIEnv *env, jobject readableMap, const char *key, jstring defaultValue) {
  if (!hasKey(env, readableMap, key)) {
    return defaultValue;
  }
  jclass mapClass = env->GetObjectClass(readableMap);
  jmethodID getStringMethod = env->GetMethodID(mapClass, "getString", "(Ljava/lang/String;)Ljava/lang/String;");
  jstring jKey = env->NewStringUTF(key);
  jstring result = (jstring) env->CallObjectMethod(readableMap, getStringMethod, jKey);
  // Release both temporaries; the original leaked the class reference.
  env->DeleteLocalRef(jKey);
  env->DeleteLocalRef(mapClass);
  return result;
}

}
Loading
Loading