diff --git a/android/src/main/java/com/rnllama/LlamaContext.java b/android/src/main/java/com/rnllama/LlamaContext.java index ee1d571..b369bb1 100644 --- a/android/src/main/java/com/rnllama/LlamaContext.java +++ b/android/src/main/java/com/rnllama/LlamaContext.java @@ -237,6 +237,16 @@ public WritableMap completion(ReadableMap params) { params.hasKey("ignore_eos") ? params.getBoolean("ignore_eos") : false, // double[][] logit_bias, logit_bias, + // float dry_multiplier, + params.hasKey("dry_multiplier") ? (float) params.getDouble("dry_multiplier") : 0.00f, + // float dry_base, + params.hasKey("dry_base") ? (float) params.getDouble("dry_base") : 1.75f, + // int dry_allowed_length, + params.hasKey("dry_allowed_length") ? params.getInt("dry_allowed_length") : 2, + // int dry_penalty_last_n, + params.hasKey("dry_penalty_last_n") ? params.getInt("dry_penalty_last_n") : -1, + // String[] dry_sequence_breakers, when undef, we use the default definition from common.h + params.hasKey("dry_sequence_breakers") ? params.getArray("dry_sequence_breakers").toArrayList().toArray(new String[0]) : new String[]{"\n", ":", "\"", "*"}, // PartialCompletionCallback partial_completion_callback new PartialCompletionCallback( this, @@ -445,6 +455,11 @@ protected static native WritableMap doCompletion( String[] stop, boolean ignore_eos, double[][] logit_bias, + float dry_multiplier, + float dry_base, + int dry_allowed_length, + int dry_penalty_last_n, + String[] dry_sequence_breakers, PartialCompletionCallback partial_completion_callback ); protected static native void stopCompletion(long contextPtr); diff --git a/android/src/main/jni.cpp b/android/src/main/jni.cpp index db496fd..1ec3d19 100644 --- a/android/src/main/jni.cpp +++ b/android/src/main/jni.cpp @@ -535,6 +535,11 @@ Java_com_rnllama_LlamaContext_doCompletion( jobjectArray stop, jboolean ignore_eos, jobjectArray logit_bias, + jfloat dry_multiplier, + jfloat dry_base, + jint dry_allowed_length, + jint dry_penalty_last_n, + jobjectArray dry_sequence_breakers, jobject partial_completion_callback ) { UNUSED(thiz); @@ -573,12 +578,32 @@ Java_com_rnllama_LlamaContext_doCompletion( sparams.grammar = env->GetStringUTFChars(grammar, nullptr); sparams.xtc_threshold = xtc_threshold; sparams.xtc_probability = xtc_probability; + sparams.dry_multiplier = dry_multiplier; + sparams.dry_base = dry_base; + sparams.dry_allowed_length = dry_allowed_length; + sparams.dry_penalty_last_n = dry_penalty_last_n; sparams.logit_bias.clear(); if (ignore_eos) { sparams.logit_bias[llama_token_eos(llama->model)].bias = -INFINITY; } + // dry break seq + + jint size = env->GetArrayLength(dry_sequence_breakers); + std::vector dry_sequence_breakers_vector; + + for (jint i = 0; i < size; i++) { + jstring javaString = (jstring)env->GetObjectArrayElement(dry_sequence_breakers, i); + const char *nativeString = env->GetStringUTFChars(javaString, 0); + dry_sequence_breakers_vector.push_back(std::string(nativeString)); + env->ReleaseStringUTFChars(javaString, nativeString); + env->DeleteLocalRef(javaString); + } + + sparams.dry_sequence_breakers = dry_sequence_breakers_vector; + + // logit bias const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx)); jsize logit_bias_len = env->GetArrayLength(logit_bias); diff --git a/example/src/App.tsx b/example/src/App.tsx index 628ea07..2cb4c59 100644 --- a/example/src/App.tsx +++ b/example/src/App.tsx @@ -397,22 +397,32 @@ export default function App() { { messages: msgs, n_predict: 100, + grammar, + seed: -1, + n_probs: 0, + + // Sampling params + top_k: 40, + top_p: 0.5, + min_p: 0.05, xtc_probability: 0.5, xtc_threshold: 0.1, + typical_p: 1.0, temperature: 0.7, - top_k: 40, // <= 0 to use vocab size - top_p: 0.5, // 1.0 = disabled - typical_p: 1.0, // 1.0 = disabled - penalty_last_n: 256, // 0 = disable penalty, -1 = context size - penalty_repeat: 1.18, // 1.0 = disabled - penalty_freq: 0.0, // 0.0 = disabled - penalty_present: 0.0, // 0.0 = disabled - mirostat: 0, // 0/1/2 - mirostat_tau: 5, // target entropy - mirostat_eta: 0.1, // learning rate - penalize_nl: false, // penalize newlines - seed: -1, // random seed - n_probs: 0, // Show probabilities + penalty_last_n: 64, + penalty_repeat: 1.0, + penalty_freq: 0.0, + penalty_present: 0.0, + dry_multiplier: 0, + dry_base: 1.75, + dry_allowed_length: 2, + dry_penalty_last_n: -1, + dry_sequence_breakers: ["\n", ":", "\"", "*"], + mirostat: 0, + mirostat_tau: 5, + mirostat_eta: 0.1, + penalize_nl: false, + ignore_eos: false, stop: [ '', '<|end|>', @@ -424,7 +434,6 @@ export default function App() { '<|end_of_turn|>', '<|endoftext|>', ], - grammar, // n_threads: 4, // logit_bias: [[15043,1.0]], }, diff --git a/ios/RNLlamaContext.mm b/ios/RNLlamaContext.mm index ac4e601..c2cb593 100644 --- a/ios/RNLlamaContext.mm +++ b/ios/RNLlamaContext.mm @@ -292,6 +292,19 @@ - (NSDictionary *)completion:(NSDictionary *)params if (params[@"xtc_probability"]) sparams.xtc_probability = [params[@"xtc_probability"] doubleValue]; if (params[@"typical_p"]) sparams.typ_p = [params[@"typical_p"] doubleValue]; + if (params[@"dry_multiplier"]) sparams.dry_multiplier = [params[@"dry_multiplier"] doubleValue]; + if (params[@"dry_base"]) sparams.dry_base = [params[@"dry_base"] doubleValue]; + if (params[@"dry_allowed_length"]) sparams.dry_allowed_length = [params[@"dry_allowed_length"] intValue]; + if (params[@"dry_penalty_last_n"]) sparams.dry_penalty_last_n = [params[@"dry_penalty_last_n"] intValue]; + + // dry break seq + if (params[@"dry_sequence_breakers"] && [params[@"dry_sequence_breakers"] isKindOfClass:[NSArray class]]) { + NSArray *dry_sequence_breakers = params[@"dry_sequence_breakers"]; + for (NSString *s in dry_sequence_breakers) { + sparams.dry_sequence_breakers.push_back([s UTF8String]); + } + } + if (params[@"grammar"]) { sparams.grammar = [params[@"grammar"] UTF8String]; } diff --git a/src/NativeRNLlama.ts b/src/NativeRNLlama.ts index 400e84b..04e06e8 100644 --- a/src/NativeRNLlama.ts +++ b/src/NativeRNLlama.ts @@ -72,6 +72,12 @@ export type NativeCompletionParams = { penalize_nl?: boolean seed?: number + dry_multiplier?: number + dry_base?: number + dry_allowed_length?: number + dry_penalty_last_n?: number + dry_sequence_breakers?: Array + ignore_eos?: boolean logit_bias?: Array>