fix: remove removePrevious

jhen0409 committed Nov 20, 2024 · 1 parent 22ee787 · commit 4f392a5

Showing 11 changed files with 26 additions and 25 deletions.
6 changes: 3 additions & 3 deletions android/src/main/java/com/rnllama/LlamaContext.java
@@ -303,8 +303,8 @@ public String bench(int pp, int tg, int pl, int nr) {
     return bench(this.context, pp, tg, pl, nr);
   }
 
-  public int applyLoraAdapters(ReadableArray loraAdapters, boolean removePrevious) {
-    int result = applyLoraAdapters(this.context, loraAdapters, removePrevious);
+  public int applyLoraAdapters(ReadableArray loraAdapters) {
+    int result = applyLoraAdapters(this.context, loraAdapters);
     if (result != 0) {
       throw new IllegalStateException("Failed to apply lora adapters");
     }
@@ -492,7 +492,7 @@ protected static native WritableMap embedding(
     int embd_normalize
   );
   protected static native String bench(long contextPtr, int pp, int tg, int pl, int nr);
-  protected static native int applyLoraAdapters(long contextPtr, ReadableArray loraAdapters, boolean removePrevious);
+  protected static native int applyLoraAdapters(long contextPtr, ReadableArray loraAdapters);
   protected static native void removeLoraAdapters(long contextPtr);
   protected static native WritableArray getLoadedLoraAdapters(long contextPtr);
   protected static native void freeContext(long contextPtr);
4 changes: 2 additions & 2 deletions android/src/main/java/com/rnllama/RNLlama.java
@@ -413,7 +413,7 @@ protected void onPostExecute(String result) {
     tasks.put(task, "bench-" + contextId);
   }
 
-  public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final boolean removePrevious, final Promise promise) {
+  public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final Promise promise) {
     final int contextId = (int) id;
     AsyncTask task = new AsyncTask<Void, Void, Void>() {
       private Exception exception;
@@ -425,7 +425,7 @@ protected Void doInBackground(Void... voids) {
           if (context == null) {
             throw new Exception("Context not found");
           }
-          context.applyLoraAdapters(loraAdapters, removePrevious);
+          context.applyLoraAdapters(loraAdapters);
        } catch (Exception e) {
          exception = e;
        }
4 changes: 2 additions & 2 deletions android/src/main/jni.cpp
@@ -894,7 +894,7 @@ Java_com_rnllama_LlamaContext_bench(
 
 JNIEXPORT jint JNICALL
 Java_com_rnllama_LlamaContext_applyLoraAdapters(
-  JNIEnv *env, jobject thiz, jlong context_ptr, jobjectArray loraAdapters, jboolean removePrevious) {
+  JNIEnv *env, jobject thiz, jlong context_ptr, jobjectArray loraAdapters) {
   UNUSED(thiz);
   auto llama = context_map[(long) context_ptr];
 
@@ -912,7 +912,7 @@ Java_com_rnllama_LlamaContext_applyLoraAdapters(
       lora_adapters.push_back({path_chars, scaled});
     }
   }
-  return llama->applyLoraAdapters(lora_adapters, removePrevious);
+  return llama->applyLoraAdapters(lora_adapters);
 }
 
 JNIEXPORT void JNICALL
4 changes: 2 additions & 2 deletions android/src/newarch/java/com/rnllama/RNLlamaModule.java
@@ -93,8 +93,8 @@ public void bench(double id, final double pp, final double tg, final double pl,
   }
 
   @ReactMethod
-  public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final boolean removePrevious, final Promise promise) {
-    rnllama.applyLoraAdapters(id, loraAdapters, removePrevious, promise);
+  public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final Promise promise) {
+    rnllama.applyLoraAdapters(id, loraAdapters, promise);
   }
 
   @ReactMethod
4 changes: 2 additions & 2 deletions android/src/oldarch/java/com/rnllama/RNLlamaModule.java
@@ -94,8 +94,8 @@ public void bench(double id, final double pp, final double tg, final double pl,
   }
 
   @ReactMethod
-  public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final boolean removePrevious, final Promise promise) {
-    rnllama.applyLoraAdapters(id, loraAdapters, removePrevious, promise);
+  public void applyLoraAdapters(double id, final ReadableArray loraAdapters, final Promise promise) {
+    rnllama.applyLoraAdapters(id, loraAdapters, promise);
   }
 
   @ReactMethod
6 changes: 1 addition & 5 deletions cpp/rn-llama.hpp
@@ -726,11 +726,7 @@ struct llama_rn_context
       std::string("]");
   }
 
-  int applyLoraAdapters(std::vector<common_lora_adapter_info> lora_adapters, bool remove_previous = false) {
-    if (remove_previous) {
-      common_lora_adapters_remove(ctx, this->lora_adapters);
-      this->lora_adapters.clear();
-    }
+  int applyLoraAdapters(std::vector<common_lora_adapter_info> lora_adapters) {
     auto containers = std::vector<common_lora_adapter_container>();
     for (auto & la : lora_adapters) {
       common_lora_adapter_container loaded_la;
3 changes: 1 addition & 2 deletions ios/RNLlama.mm
@@ -273,12 +273,11 @@ - (NSArray *)supportedEvents {
 
 RCT_EXPORT_METHOD(applyLoraAdapters:(double)contextId
                   withLoraAdapters:(NSArray *)loraAdapters
-                  removePrevious:(BOOL)removePrevious
                   withResolver:(RCTPromiseResolveBlock)resolve
                   withRejecter:(RCTPromiseRejectBlock)reject)
 {
   RNLlamaContext *context = llamaContexts[[NSNumber numberWithDouble:contextId]];
-  [context applyLoraAdapters:loraAdapters removePrevious:removePrevious];
+  [context applyLoraAdapters:loraAdapters];
   resolve(nil);
 }
4 changes: 3 additions & 1 deletion ios/RNLlamaContext.h
@@ -33,7 +33,9 @@
 - (NSDictionary *)loadSession:(NSString *)path;
 - (int)saveSession:(NSString *)path size:(int)size;
 - (NSString *)bench:(int)pp tg:(int)tg pl:(int)pl nr:(int)nr;
-- (void)applyLoraAdapters:(NSArray *)loraAdapters removePrevious:(BOOL)removePrevious;
+
+- (void)applyLoraAdapters:(NSArray *)loraAdapters;
 - (void)removeLoraAdapters;
 - (NSArray *)getLoadedLoraAdapters;
 - (void)invalidate;
 
 @end
4 changes: 2 additions & 2 deletions ios/RNLlamaContext.mm
@@ -548,15 +548,15 @@ - (NSString *)bench:(int)pp tg:(int)tg pl:(int)pl nr:(int)nr {
   return [NSString stringWithUTF8String:llama->bench(pp, tg, pl, nr).c_str()];
 }
 
-- (void)applyLoraAdapters:(NSArray *)loraAdapters removePrevious:(BOOL)removePrevious {
+- (void)applyLoraAdapters:(NSArray *)loraAdapters {
   std::vector<common_lora_adapter_info> lora_adapters;
   for (NSDictionary *loraAdapter in loraAdapters) {
     common_lora_adapter_info la;
     la.path = [loraAdapter[@"path"] UTF8String];
     la.scale = [loraAdapter[@"scaled"] doubleValue];
     lora_adapters.push_back(la);
   }
-  int result = llama->applyLoraAdapters(lora_adapters, removePrevious);
+  int result = llama->applyLoraAdapters(lora_adapters);
   if (result != 0) {
     @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to apply lora adapters" userInfo:nil];
   }
1 change: 0 additions & 1 deletion src/NativeRNLlama.ts
@@ -289,7 +289,6 @@ export interface Spec extends TurboModule {
   applyLoraAdapters(
     contextId: number,
     loraAdapters: Array<{ path: string; scaled?: number }>,
-    removePrevious: boolean,
  ): Promise<void>
  removeLoraAdapters(contextId: number): Promise<void>
  getLoadedLoraAdapters(
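
With `removePrevious` gone from the native spec, replacing the loaded adapters is now an explicit two-step call: clear first, then apply. A minimal migration sketch against the public `LlamaContext` API (the `context` instance and the adapter path below are hypothetical placeholders):

  // Before: context.applyLoraAdapters(adapters, true) also cleared loaded adapters.
  // After: clear explicitly, then apply the new set.
  await context.removeLoraAdapters()
  await context.applyLoraAdapters([
    { path: '/data/local/tmp/lora-adapter.gguf', scaled: 1.0 },
  ])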
11 changes: 8 additions & 3 deletions src/index.ts
@@ -237,10 +237,15 @@ export class LlamaContext {
   }
 
   async applyLoraAdapters(
-    loraAdapters: Array<{ path: string; scaled?: number }>,
-    removePrevious: boolean,
+    loraList: Array<{ path: string; scaled?: number }>
   ): Promise<void> {
-    return RNLlama.applyLoraAdapters(this.id, loraAdapters, removePrevious)
+    let loraAdapters: Array<{ path: string; scaled?: number }> = []
+    if (loraList)
+      loraAdapters = loraList.map((l) => ({
+        path: l.path.replace(/file:\/\//, ''),
+        scaled: l.scaled,
+      }))
+    return RNLlama.applyLoraAdapters(this.id, loraAdapters)
   }
 
   async removeLoraAdapters(): Promise<void> {
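
The reworked wrapper also normalizes `file://` URIs before calling the native module, so callers can pass either a URI or a plain path. A small sketch of that mapping (the input value is a hypothetical example):

  // What applyLoraAdapters sends across the bridge:
  const loraList = [{ path: 'file:///data/adapters/style.gguf', scaled: 0.8 }]
  const loraAdapters = loraList.map((l) => ({
    path: l.path.replace(/file:\/\//, ''), // strip the URI scheme; native code expects a filesystem path
    scaled: l.scaled,
  }))
  // => [{ path: '/data/adapters/style.gguf', scaled: 0.8 }]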
