fix: seed not working properly
- llama_set_rng_seed is called inside beginCompletion, so the seed is now passed via params.seed and beginCompletion runs before loadPrompt (sketched below)
- the `always add a first space` line is removed because llama_tokenize_internal already prepends it (see the note under cpp/rn-llama.hpp)

closes #41
jhen0409 committed Jan 20, 2024
1 parent 9008c85 commit c5287f0
Showing 4 changed files with 8 additions and 9 deletions.
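
Why this fixes the seed: the RNG is seeded inside beginCompletion from params.seed, not at the binding layer. A minimal sketch, assuming cpp/rn-llama.hpp still mirrors llama.cpp's server example at this point (not verbatim from this revision):

    // Sketch (assumption): beginCompletion in cpp/rn-llama.hpp re-seeds the
    // RNG from params.seed. The old bindings called llama_set_rng_seed(ctx, seed)
    // earlier and never assigned params.seed, so this call overwrote the user's
    // seed with the default (-1, i.e. random) on every completion.
    void beginCompletion()
    {
        n_remain = params.n_predict;          // tokens left to predict
        llama_set_rng_seed(ctx, params.seed); // the seeding that actually sticks
        is_predicting = true;
    }

Hence both bindings now assign llama->params.seed and call beginCompletion() before loadPrompt() instead of seeding the context directly.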
android/src/main/jni.cpp: 7 changes (3 additions & 4 deletions)
@@ -301,6 +301,7 @@ Java_com_rnllama_LlamaContext_doCompletion(
     llama_reset_timings(llama->ctx);
 
     llama->params.prompt = env->GetStringUTFChars(prompt, nullptr);
+    llama->params.seed = seed;
 
     int max_threads = std::thread::hardware_concurrency();
     // Use 2 threads by default on 4-core devices, 4 threads on more cores
@@ -328,8 +329,6 @@ Java_com_rnllama_LlamaContext_doCompletion(
     sparams.n_probs = n_probs;
     sparams.grammar = env->GetStringUTFChars(grammar, nullptr);
 
-    llama_set_rng_seed(llama->ctx, seed);
-
     sparams.logit_bias.clear();
     if (ignore_eos) {
         sparams.logit_bias[llama_token_eos(llama->model)] = -INFINITY;
@@ -371,8 +370,8 @@ Java_com_rnllama_LlamaContext_doCompletion(
         putString(env, result, "error", "Failed to initialize sampling");
         return reinterpret_cast<jobject>(result);
     }
-    llama->loadPrompt();
     llama->beginCompletion();
+    llama->loadPrompt();
 
     size_t sent_count = 0;
     size_t sent_token_probs_index = 0;
@@ -550,8 +549,8 @@ Java_com_rnllama_LlamaContext_embedding(
     llama->params.prompt = text_chars;
 
     llama->params.n_predict = 0;
-    llama->loadPrompt();
     llama->beginCompletion();
+    llama->loadPrompt();
     llama->doCompletion();
 
     std::vector<float> embedding = llama->getEmbedding();
cpp/rn-llama.hpp: 1 change (0 additions & 1 deletion)
@@ -253,7 +253,6 @@ struct llama_rn_context
 
     void loadPrompt()
     {
-        params.prompt.insert(0, 1, ' '); // always add a first space
         std::vector<llama_token> prompt_tokens = ::llama_tokenize(ctx, params.prompt, true);
        num_prompt_tokens = prompt_tokens.size();
 
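
Note on the removed line: llama.cpp's llama_tokenize_internal already prepends a leading space for SPM (LLaMA-style) vocabularies, so repeating it in loadPrompt produced a double space and different prompt tokens. A compilable sketch of the duplication, assuming that tokenizer behavior (spm_prepare is a hypothetical stand-in for the internal step, not a real API):

    #include <string>

    // Hypothetical stand-in for the space-prepending step inside
    // llama_tokenize_internal for SPM vocabularies.
    static std::string spm_prepare(const std::string & raw_text) {
        return " " + raw_text;
    }

    // Old loadPrompt first ran params.prompt.insert(0, 1, ' '):
    //   spm_prepare(" Hello") == "  Hello"  -> double space, wrong tokens
    // New loadPrompt passes the prompt through untouched:
    //   spm_prepare("Hello")  == " Hello"   -> matches the reference tokenizer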
example/src/App.tsx: 2 changes (1 addition & 1 deletion)
@@ -105,7 +105,7 @@ export default function App() {
     initLlama({
       model: file.uri,
       use_mlock: true,
-      n_gpu_layers: Platform.OS === 'ios' ? 1 : 0, // > 0: enable GPU
+      n_gpu_layers: Platform.OS === 'ios' ? 0 : 0, // > 0: enable GPU
       // embedding: true,
     })
       .then((ctx) => {
ios/RNLlamaContext.mm: 7 changes (4 additions & 3 deletions)
@@ -67,6 +67,8 @@ + (instancetype)initWithParams:(NSDictionary *)params {
     if (params[@"rope_freq_base"]) defaultParams.rope_freq_base = [params[@"rope_freq_base"] floatValue];
     if (params[@"rope_freq_scale"]) defaultParams.rope_freq_scale = [params[@"rope_freq_scale"] floatValue];
 
+    if (params[@"seed"]) defaultParams.seed = [params[@"seed"] intValue];
+
     int nThreads = params[@"n_threads"] ? [params[@"n_threads"] intValue] : 0;
     const int maxThreads = (int) [[NSProcessInfo processInfo] processorCount];
     // Use 2 threads by default on 4-core devices, 4 threads on more cores
@@ -131,6 +133,7 @@ - (NSDictionary *)completion:(NSDictionary *)params
     NSString *prompt = [params objectForKey:@"prompt"];
 
     llama->params.prompt = [prompt UTF8String];
+    llama->params.seed = params[@"seed"] ? [params[@"seed"] intValue] : -1;
 
     if (params[@"n_threads"]) {
         int nThreads = params[@"n_threads"] ? [params[@"n_threads"] intValue] : llama->params.n_threads;
@@ -164,8 +167,6 @@ - (NSDictionary *)completion:(NSDictionary *)params
 
     if (params[@"typical_p"]) sparams.typical_p = [params[@"typical_p"] doubleValue];
 
-    llama_set_rng_seed(llama->ctx, params[@"seed"] ? [params[@"seed"] intValue] : -1);
-
     if (params[@"grammar"]) {
         sparams.grammar = [params[@"grammar"] UTF8String];
     }
@@ -203,8 +204,8 @@ - (NSDictionary *)completion:(NSDictionary *)params
     if (!llama->initSampling()) {
         @throw [NSException exceptionWithName:@"LlamaException" reason:@"Failed to initialize sampling" userInfo:nil];
     }
-    llama->loadPrompt();
     llama->beginCompletion();
+    llama->loadPrompt();
 
     size_t sent_count = 0;
     size_t sent_token_probs_index = 0;
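
On the -1 default: it follows llama.cpp's convention that a negative seed requests a random seed, while any non-negative value makes sampling reproducible. A sketch of how such a value is conventionally resolved (resolve_seed is a hypothetical helper for illustration, not code from this repo):

    #include <cstdint>
    #include <ctime>

    // Hypothetical helper showing the convention the bindings rely on:
    // seed < 0 means "pick a time-based (random) seed"; anything else is
    // used verbatim so repeated completions match exactly.
    static uint32_t resolve_seed(int32_t seed) {
        return seed < 0 ? (uint32_t) std::time(nullptr) : (uint32_t) seed;
    }

With params.seed carried through on both platforms, the same value now reaches the RNG inside beginCompletion, so completions with an explicit seed are reproducible; see #41.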
