Skip to content

Commit

Permalink
feat: add size property for saveSession
Browse files Browse the repository at this point in the history
  • Loading branch information
jhen0409 committed Oct 20, 2023
1 parent 8a4f916 commit 829e219
Show file tree
Hide file tree
Showing 10 changed files with 25 additions and 18 deletions.
7 changes: 4 additions & 3 deletions android/src/main/java/com/rnllama/LlamaContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ public WritableMap loadSession(String path) {
return result;
}

public int saveSession(String path) {
return saveSession(this.context, path);
// Persists the current cached session tokens to the file at `path`.
// `size` caps how many tokens are written; the native layer falls back to the
// full cached token count when `size` is <= 0 or exceeds what is available.
// Returns the number of tokens saved, or a negative value on failure.
public int saveSession(String path, int size) {
return saveSession(this.context, path, size);
}

public WritableMap completion(ReadableMap params) {
Expand Down Expand Up @@ -286,7 +286,8 @@ protected static native WritableMap loadSession(
);
protected static native int saveSession(
long contextPtr,
String path
String path,
int size
);
protected static native WritableMap doCompletion(
long context_ptr,
Expand Down
4 changes: 2 additions & 2 deletions android/src/main/java/com/rnllama/RNLlama.java
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ protected void onPostExecute(WritableMap result) {
tasks.put(task, "loadSession-" + contextId);
}

public void saveSession(double id, final String path, Promise promise) {
public void saveSession(double id, final String path, double size, Promise promise) {
final int contextId = (int) id;
AsyncTask task = new AsyncTask<Void, Void, Integer>() {
private Exception exception;
Expand All @@ -124,7 +124,7 @@ protected Integer doInBackground(Void... voids) {
if (context == null) {
throw new Exception("Context not found");
}
Integer count = context.saveSession(path);
Integer count = context.saveSession(path, (int) size);
return count;
} catch (Exception e) {
exception = e;
Expand Down
7 changes: 5 additions & 2 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,15 +222,18 @@ Java_com_rnllama_LlamaContext_saveSession(
JNIEnv *env,
jobject thiz,
jlong context_ptr,
jstring path
jstring path,
jint size
) {
UNUSED(thiz);
auto llama = context_map[(long) context_ptr];

const char *path_chars = env->GetStringUTFChars(path, nullptr);

std::vector<llama_token> session_tokens = llama->embd;
if (!llama_save_session_file(llama->ctx, path_chars, session_tokens.data(), session_tokens.size())) {
int default_size = session_tokens.size();
int save_size = size > 0 && size <= default_size ? size : default_size;
if (!llama_save_session_file(llama->ctx, path_chars, session_tokens.data(), save_size)) {
env->ReleaseStringUTFChars(path, path_chars);
return -1;
}
Expand Down
4 changes: 2 additions & 2 deletions android/src/newarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ public void loadSession(double id, String path, Promise promise) {
}

@ReactMethod
public void saveSession(double id, String path, Promise promise) {
rnllama.saveSession(id, path, promise);
// Bridges JS saveSession(contextId, filepath, size) to RNLlama.
// `size` arrives as double (React Native numbers are doubles); RNLlama
// narrows it to int before calling into the native context.
public void saveSession(double id, String path, double size, Promise promise) {
rnllama.saveSession(id, path, size, promise);
}

@ReactMethod
Expand Down
4 changes: 2 additions & 2 deletions android/src/oldarch/java/com/rnllama/RNLlamaModule.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ public void loadSession(double id, String path, Promise promise) {
}

@ReactMethod
public void saveSession(double id, String path, Promise promise) {
rnllama.saveSession(id, path, promise);
// Bridges JS saveSession(contextId, filepath, size) to RNLlama.
// `size` is declared as double to match the TurboModule spec (`size: number`)
// and the new-architecture module, so both architectures present the same
// signature and share RNLlama.saveSession(double, String, double, Promise);
// the previous `int` declaration would have truncated the JS number at the
// bridge instead of in RNLlama, diverging from the new-arch path.
public void saveSession(double id, String path, double size, Promise promise) {
rnllama.saveSession(id, path, size, promise);
}

@ReactMethod
Expand Down
3 changes: 2 additions & 1 deletion ios/RNLlama.mm
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ @implementation RNLlama

RCT_EXPORT_METHOD(saveSession:(double)contextId
withFilePath:(NSString *)filePath
withSize:(double)size
withResolver:(RCTPromiseResolveBlock)resolve
withRejecter:(RCTPromiseRejectBlock)reject)
{
Expand All @@ -98,7 +99,7 @@ @implementation RNLlama
dispatch_async(dispatch_get_main_queue(), ^{ // TODO: Fix for use in llamaDQue
@try {
@autoreleasepool {
int count = [context saveSession:filePath];
int count = [context saveSession:filePath size:(int)size];
resolve(@(count));
}
} @catch (NSException *exception) {
Expand Down
2 changes: 1 addition & 1 deletion ios/RNLlamaContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
- (NSString *)detokenize:(NSArray *)tokens;
- (NSArray *)embedding:(NSString *)text;
- (NSDictionary *)loadSession:(NSString *)path;
- (int)saveSession:(NSString *)path;
- (int)saveSession:(NSString *)path size:(int)size;

- (void)invalidate;

Expand Down
6 changes: 4 additions & 2 deletions ios/RNLlamaContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,11 @@ - (NSDictionary *)loadSession:(NSString *)path {
};
}

// Saves the cached session tokens (llama->embd) to the file at `path`.
// `size` caps how many tokens are written; a value <= 0, or one larger than
// the cached token count, falls back to saving everything.
// Returns the number of tokens actually written; throws LlamaException on
// save failure.
- (int)saveSession:(NSString *)path size:(int)size {
    std::vector<llama_token> session_tokens = llama->embd;
    int default_size = (int) session_tokens.size();
    int save_size = size > 0 && size <= default_size ? size : default_size;
    if (!llama_save_session_file(llama->ctx, [path UTF8String], session_tokens.data(), save_size)) {
        @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to save session" userInfo:nil];
    }
    // Bug fix: report the count actually written (save_size), not the full
    // cached size — the two differ whenever the caller passes a smaller cap.
    return save_size;
}
Expand Down
2 changes: 1 addition & 1 deletion src/NativeRNLlama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ export interface Spec extends TurboModule {
initContext(params: NativeContextParams): Promise<NativeLlamaContext>;

loadSession(contextId: number, filepath: string): Promise<NativeSessionLoadResult>;
saveSession(contextId: number, filepath: string): Promise<number>;
saveSession(contextId: number, filepath: string, size: number): Promise<number>;
completion(contextId: number, params: NativeCompletionParams): Promise<NativeCompletionResult>;
stopCompletion(contextId: number): Promise<void>;
tokenize(contextId: number, text: string): Promise<NativeTokenizeResult>;
Expand Down
4 changes: 2 additions & 2 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ export class LlamaContext {
/**
* Save current cached prompt & completion state to a file.
*/
async saveSession(filepath: string): Promise<number> {
return RNLlama.saveSession(this.id, filepath)
async saveSession(filepath: string, options: { tokenSize: number }): Promise<number> {
return RNLlama.saveSession(this.id, filepath, options?.tokenSize)
}

async completion(
Expand Down

0 comments on commit 829e219

Please sign in to comment.