Skip to content

Commit

Permalink
feat: expose DRY sampler params (#91)
Browse files Browse the repository at this point in the history
* fix: loadSession not taking paths with file://

* feat: exposed dry sampling

* feat(ios): expose DRY sampler params

* feat(example): use default values in completion & remove comments

---------

Co-authored-by: Jhen-Jie Hong <[email protected]>
  • Loading branch information
Vali-98 and jhen0409 authored Nov 18, 2024
1 parent 8e7f439 commit 0c04b5e
Show file tree
Hide file tree
Showing 5 changed files with 82 additions and 14 deletions.
15 changes: 15 additions & 0 deletions android/src/main/java/com/rnllama/LlamaContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,16 @@ public WritableMap completion(ReadableMap params) {
params.hasKey("ignore_eos") ? params.getBoolean("ignore_eos") : false,
// double[][] logit_bias,
logit_bias,
// float dry_multiplier,
params.hasKey("dry_multiplier") ? (float) params.getDouble("dry_multiplier") : 0.00f,
// float dry_base,
params.hasKey("dry_base") ? (float) params.getDouble("dry_base") : 1.75f,
// int dry_allowed_length,
params.hasKey("dry_allowed_length") ? params.getInt("dry_allowed_length") : 2,
// int dry_penalty_last_n,
params.hasKey("dry_penalty_last_n") ? params.getInt("dry_penalty_last_n") : -1,
// String[] dry_sequence_breakers, when undef, we use the default definition from common.h
params.hasKey("dry_sequence_breakers") ? params.getArray("dry_sequence_breakers").toArrayList().toArray(new String[0]) : new String[]{"\n", ":", "\"", "*"},
// PartialCompletionCallback partial_completion_callback
new PartialCompletionCallback(
this,
Expand Down Expand Up @@ -445,6 +455,11 @@ protected static native WritableMap doCompletion(
String[] stop,
boolean ignore_eos,
double[][] logit_bias,
float dry_multiplier,
float dry_base,
int dry_allowed_length,
int dry_penalty_last_n,
String[] dry_sequence_breakers,
PartialCompletionCallback partial_completion_callback
);
protected static native void stopCompletion(long contextPtr);
Expand Down
25 changes: 25 additions & 0 deletions android/src/main/jni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,11 @@ Java_com_rnllama_LlamaContext_doCompletion(
jobjectArray stop,
jboolean ignore_eos,
jobjectArray logit_bias,
jfloat dry_multiplier,
jfloat dry_base,
jint dry_allowed_length,
jint dry_penalty_last_n,
jobjectArray dry_sequence_breakers,
jobject partial_completion_callback
) {
UNUSED(thiz);
Expand Down Expand Up @@ -573,12 +578,32 @@ Java_com_rnllama_LlamaContext_doCompletion(
sparams.grammar = env->GetStringUTFChars(grammar, nullptr);
sparams.xtc_threshold = xtc_threshold;
sparams.xtc_probability = xtc_probability;
sparams.dry_multiplier = dry_multiplier;
sparams.dry_base = dry_base;
sparams.dry_allowed_length = dry_allowed_length;
sparams.dry_penalty_last_n = dry_penalty_last_n;

sparams.logit_bias.clear();
if (ignore_eos) {
sparams.logit_bias[llama_token_eos(llama->model)].bias = -INFINITY;
}

// dry break seq

jint size = env->GetArrayLength(dry_sequence_breakers);
std::vector<std::string> dry_sequence_breakers_vector;

for (jint i = 0; i < size; i++) {
jstring javaString = (jstring)env->GetObjectArrayElement(dry_sequence_breakers, i);
const char *nativeString = env->GetStringUTFChars(javaString, 0);
dry_sequence_breakers_vector.push_back(std::string(nativeString));
env->ReleaseStringUTFChars(javaString, nativeString);
env->DeleteLocalRef(javaString);
}

sparams.dry_sequence_breakers = dry_sequence_breakers_vector;

// logit bias
const int n_vocab = llama_n_vocab(llama_get_model(llama->ctx));
jsize logit_bias_len = env->GetArrayLength(logit_bias);

Expand Down
37 changes: 23 additions & 14 deletions example/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -397,22 +397,32 @@ export default function App() {
{
messages: msgs,
n_predict: 100,
grammar,
seed: -1,
n_probs: 0,

// Sampling params
top_k: 40,
top_p: 0.5,
min_p: 0.05,
xtc_probability: 0.5,
xtc_threshold: 0.1,
typical_p: 1.0,
temperature: 0.7,
top_k: 40, // <= 0 to use vocab size
top_p: 0.5, // 1.0 = disabled
typical_p: 1.0, // 1.0 = disabled
penalty_last_n: 256, // 0 = disable penalty, -1 = context size
penalty_repeat: 1.18, // 1.0 = disabled
penalty_freq: 0.0, // 0.0 = disabled
penalty_present: 0.0, // 0.0 = disabled
mirostat: 0, // 0/1/2
mirostat_tau: 5, // target entropy
mirostat_eta: 0.1, // learning rate
penalize_nl: false, // penalize newlines
seed: -1, // random seed
n_probs: 0, // Show probabilities
penalty_last_n: 64,
penalty_repeat: 1.0,
penalty_freq: 0.0,
penalty_present: 0.0,
dry_multiplier: 0,
dry_base: 1.75,
dry_allowed_length: 2,
dry_penalty_last_n: -1,
dry_sequence_breakers: ["\n", ":", "\"", "*"],
mirostat: 0,
mirostat_tau: 5,
mirostat_eta: 0.1,
penalize_nl: false,
ignore_eos: false,
stop: [
'</s>',
'<|end|>',
Expand All @@ -424,7 +434,6 @@ export default function App() {
'<|end_of_turn|>',
'<|endoftext|>',
],
grammar,
// n_threads: 4,
// logit_bias: [[15043,1.0]],
},
Expand Down
13 changes: 13 additions & 0 deletions ios/RNLlamaContext.mm
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,19 @@ - (NSDictionary *)completion:(NSDictionary *)params
if (params[@"xtc_probability"]) sparams.xtc_probability = [params[@"xtc_probability"] doubleValue];
if (params[@"typical_p"]) sparams.typ_p = [params[@"typical_p"] doubleValue];

if (params[@"dry_multiplier"]) sparams.dry_multiplier = [params[@"dry_multiplier"] doubleValue];
if (params[@"dry_base"]) sparams.dry_base = [params[@"dry_base"] doubleValue];
if (params[@"dry_allowed_length"]) sparams.dry_allowed_length = [params[@"dry_allowed_length"] intValue];
if (params[@"dry_penalty_last_n"]) sparams.dry_penalty_last_n = [params[@"dry_penalty_last_n"] intValue];

// dry break seq
if (params[@"dry_sequence_breakers"] && [params[@"dry_sequence_breakers"] isKindOfClass:[NSArray class]]) {
NSArray *dry_sequence_breakers = params[@"dry_sequence_breakers"];
for (NSString *s in dry_sequence_breakers) {
sparams.dry_sequence_breakers.push_back([s UTF8String]);
}
}

if (params[@"grammar"]) {
sparams.grammar = [params[@"grammar"] UTF8String];
}
Expand Down
6 changes: 6 additions & 0 deletions src/NativeRNLlama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ export type NativeCompletionParams = {
penalize_nl?: boolean
seed?: number

dry_multiplier?: number
dry_base?: number
dry_allowed_length?: number
dry_penalty_last_n?: number
dry_sequence_breakers?: Array<string>

ignore_eos?: boolean
logit_bias?: Array<Array<number>>

Expand Down

0 comments on commit 0c04b5e

Please sign in to comment.