llama : move random seed generation to the samplers (ggerganov#9398)
* llama_sampler_penalties : clamp penalty_last_n to zero
slaren authored and arthw committed Nov 15, 2024
1 parent 8e9e0ac commit d2fa884
Showing 10 changed files with 92 additions and 34 deletions.
7 changes: 1 addition & 6 deletions common/arg.cpp
@@ -173,7 +173,6 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx
std::string arg;
const std::string arg_prefix = "--";
gpt_params & params = ctx_arg.params;
gpt_sampler_params & sparams = params.sparams;

std::unordered_map<std::string, llama_arg *> arg_to_options;
for (auto & opt : ctx_arg.options) {
@@ -283,10 +282,6 @@ static bool gpt_params_parse_ex(int argc, char ** argv, gpt_params_context & ctx
params.kv_overrides.back().key[0] = 0;
}

if (sparams.seed == LLAMA_DEFAULT_SEED) {
sparams.seed = time(NULL);
}

return true;
}

@@ -909,7 +904,7 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
).set_sparam());
add_opt(llama_arg(
{"-s", "--seed"}, "SEED",
format("RNG seed (default: %d, use random seed for < 0)", params.sparams.seed),
format("RNG seed (default: %u, use random seed for %u)", params.sparams.seed, LLAMA_DEFAULT_SEED),
[](gpt_params & params, const std::string & value) {
params.sparams.seed = std::stoul(value);
}
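Note: seed resolution no longer happens during argument parsing. The --seed handler stores the parsed value as-is, and LLAMA_DEFAULT_SEED (0xFFFFFFFF in llama.h) is passed through so that the sampler itself draws the random seed. A minimal standalone sketch (illustrative, not part of the commit) of why "-s -1" still requests a random seed:

    #include <cstdint>
    #include <cstdio>
    #include <string>

    int main() {
        // std::stoul accepts a leading minus and negates in unsigned
        // arithmetic, so "-1" parses to ULONG_MAX; truncated to uint32_t
        // that is 0xFFFFFFFF, which equals LLAMA_DEFAULT_SEED.
        uint32_t seed = (uint32_t) std::stoul("-1");
        printf("%u\n", seed); // prints 4294967295
    }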
4 changes: 4 additions & 0 deletions common/sampling.cpp
@@ -310,6 +310,10 @@ llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context
return cur_p.data[cur_p.selected].id;
}

uint32_t gpt_sampler_get_seed(const struct gpt_sampler * gsmpl) {
return llama_sampler_get_seed(gsmpl->chain);
}

// helpers

llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl) {
2 changes: 2 additions & 0 deletions common/sampling.h
@@ -60,6 +60,8 @@ void gpt_perf_print(const struct llama_context * ctx, const struct gpt_sampler *
//
llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);

uint32_t gpt_sampler_get_seed(const struct gpt_sampler * gsmpl);

// helpers

// access the internal list of current candidate tokens
2 changes: 0 additions & 2 deletions examples/embedding/embedding.cpp
@@ -90,8 +90,6 @@ int main(int argc, char ** argv) {

print_build_info();

LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);

llama_backend_init();
llama_numa_init(params.numa);

7 changes: 3 additions & 4 deletions examples/infill/infill.cpp
@@ -159,8 +159,6 @@ int main(int argc, char ** argv) {

print_build_info();

LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);

LOG("%s: llama backend init\n", __func__);
llama_backend_init();
llama_numa_init(params.numa);
@@ -301,6 +299,9 @@ int main(int argc, char ** argv) {
LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
}
}
smpl = gpt_sampler_init(model, sparams);

LOG_TEE("sampling seed: %u\n", gpt_sampler_get_seed(smpl));
LOG_TEE("sampling: \n%s\n", sparams.print().c_str());
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
LOG_TEE("\n\n");
@@ -340,8 +341,6 @@ int main(int argc, char ** argv) {

std::vector<llama_token> embd;

smpl = gpt_sampler_init(model, sparams);

while (n_remain != 0 || params.interactive) {
// predict
if (!embd.empty()) {
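Note on the infill.cpp change above: gpt_sampler_init is moved up next to the prompt setup, so the effective seed can be logged via gpt_sampler_get_seed before generation starts; previously the sampler was only constructed just ahead of the main loop.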
6 changes: 3 additions & 3 deletions examples/main/main.cpp
@@ -191,8 +191,6 @@ int main(int argc, char ** argv) {

print_build_info();

LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);

LOG("%s: llama backend init\n", __func__);
llama_backend_init();
llama_numa_init(params.numa);
@@ -470,8 +468,10 @@ int main(int argc, char ** argv) {
exit(1);
}

LOG_TEE("sampling seed: %u\n", gpt_sampler_get_seed(smpl));
LOG_TEE("sampling params: \n%s\n", sparams.print().c_str());
LOG_TEE(" sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str());
LOG_TEE("sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str());

LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);

// group-attention state
2 changes: 0 additions & 2 deletions examples/perplexity/perplexity.cpp
@@ -2007,8 +2007,6 @@ int main(int argc, char ** argv) {

print_build_info();

LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);

llama_backend_init();
llama_numa_init(params.numa);

1 change: 1 addition & 0 deletions examples/server/server.cpp
@@ -1266,6 +1266,7 @@ struct server_context {
{"n_predict", slot.n_predict}, // Server configured n_predict
{"model", params.model_alias},
{"seed", slot.sparams.seed},
{"seed_cur", slot.smpl ? gpt_sampler_get_seed(slot.smpl) : 0},
{"temperature", slot.sparams.temp},
{"dynatemp_range", slot.sparams.dynatemp_range},
{"dynatemp_exponent", slot.sparams.dynatemp_exponent},
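Note on the server change above: "seed" reports the value that was requested (which may be LLAMA_DEFAULT_SEED), while the new "seed_cur" field reports the seed the sampler actually used, as returned by gpt_sampler_get_seed; it falls back to 0 when the slot has no sampler yet. As an illustration (values are made up), a request with seed -1 could report "seed": 4294967295 alongside "seed_cur": 1699042834.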
4 changes: 4 additions & 0 deletions include/llama.h
@@ -1127,6 +1127,10 @@ extern "C" {
int32_t n_logit_bias,
const llama_logit_bias * logit_bias);


// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);

/// @details Sample and accept a token from the idx-th output of the last evaluation
//
// Shorthand for:
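The new llama_sampler_get_seed API lets callers recover the seed that was actually used, even when LLAMA_DEFAULT_SEED was requested. A minimal usage sketch, assuming the sampler-chain API declared elsewhere in llama.h (error handling omitted):

    #include "llama.h"
    #include <cstdio>

    int main() {
        llama_sampler_chain_params cparams = llama_sampler_chain_default_params();
        llama_sampler * chain = llama_sampler_chain_init(cparams);

        // A dist sampler created with LLAMA_DEFAULT_SEED draws a random seed
        // internally; llama_sampler_get_seed reports the value actually used.
        llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

        printf("effective seed: %u\n", llama_sampler_get_seed(chain));

        llama_sampler_free(chain);
        return 0;
    }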
91 changes: 74 additions & 17 deletions src/llama-sampling.cpp
@@ -8,6 +8,7 @@
#include <cstring>
#include <ctime>
#include <cfloat>
#include <chrono>
#include <cmath>
#include <numeric>
#include <random>
@@ -162,6 +163,19 @@ static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k)
cur_p->size = k;
}

static uint32_t get_rng_seed(uint32_t seed) {
if (seed == LLAMA_DEFAULT_SEED) {
// use system clock if std::random_device is not a true RNG
static bool is_rd_prng = std::random_device().entropy() == 0;
if (is_rd_prng) {
return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
}
std::random_device rd;
return rd();
}
return seed;
}

// llama_sampler API

const char * llama_sampler_name(const struct llama_sampler * smpl) {
@@ -387,6 +401,7 @@ struct llama_sampler * llama_sampler_init_greedy() {

struct llama_sampler_dist {
const uint32_t seed;
uint32_t seed_cur;

std::mt19937 rng;
};
@@ -416,7 +431,8 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample

static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_dist *) smpl->ctx;
ctx->rng = std::mt19937(ctx->seed);
ctx->seed_cur = get_rng_seed(ctx->seed);
ctx->rng.seed(ctx->seed_cur);
}

static void llama_sampler_dist_free(struct llama_sampler * smpl) {
@@ -433,11 +449,13 @@ static struct llama_sampler_i llama_sampler_dist_i = {
};

struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
auto seed_cur = get_rng_seed(seed);
return new llama_sampler {
/* .iface = */ &llama_sampler_dist_i,
/* .ctx = */ new llama_sampler_dist {
/* .seed = */ seed,
/* .rng = */ std::mt19937(seed),
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .rng = */ std::mt19937(seed_cur),
},
};
}
@@ -1032,6 +1050,7 @@ struct llama_sampler_mirostat {
const int32_t n_vocab;

const uint32_t seed;
uint32_t seed_cur;

const float tau;
const float eta;
@@ -1100,7 +1119,8 @@ static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sa
static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
ctx->mu = 2.0f*ctx->tau;
ctx->rng = std::mt19937(ctx->seed);
ctx->seed_cur = get_rng_seed(ctx->seed);
ctx->rng.seed(ctx->seed_cur);
}

static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
@@ -1117,16 +1137,18 @@ static struct llama_sampler_i llama_sampler_mirostat_i = {
};

struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
auto seed_cur = get_rng_seed(seed);
return new llama_sampler {
/* .iface = */ &llama_sampler_mirostat_i,
/* .ctx = */ new llama_sampler_mirostat {
/* .n_vocab = */ n_vocab,
/* .seed = */ seed,
/* .tau = */ tau,
/* .eta = */ eta,
/* .m = */ m,
/* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed),
/* .n_vocab = */ n_vocab,
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .tau = */ tau,
/* .eta = */ eta,
/* .m = */ m,
/* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed_cur),
},
};
}
@@ -1135,6 +1157,7 @@ struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t see

struct llama_sampler_mirostat_v2 {
const uint32_t seed;
uint32_t seed_cur;

const float tau;
const float eta;
@@ -1179,7 +1202,8 @@ static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_t
static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
ctx->mu = 2.0f*ctx->tau;
ctx->rng = std::mt19937(ctx->seed);
ctx->seed_cur = get_rng_seed(ctx->seed);
ctx->rng.seed(ctx->seed_cur);
}

static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
@@ -1212,14 +1236,16 @@ static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
};

struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
auto seed_cur = get_rng_seed(seed);
return new llama_sampler {
/* .iface = */ &llama_sampler_mirostat_v2_i,
/* .ctx = */ new llama_sampler_mirostat_v2 {
/* .seed = */ seed,
/* .tau = */ tau,
/* .eta = */ eta,
/* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed),
/* .seed = */ seed,
/* .seed_cur = */ seed_cur,
/* .tau = */ tau,
/* .eta = */ eta,
/* .mu = */ 2.0f*tau,
/* .rng = */ std::mt19937(seed_cur),
},
};
}
@@ -1505,6 +1531,8 @@ struct llama_sampler * llama_sampler_init_penalties(
ignore_eos = false;
}

penalty_last_n = std::max(penalty_last_n, 0);

return new llama_sampler {
/* .iface = */ &llama_sampler_penalties_i,
/* .ctx = */ new llama_sampler_penalties {
@@ -1568,6 +1596,7 @@ static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_to
}
}
}

static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
@@ -1599,3 +1628,31 @@ struct llama_sampler * llama_sampler_init_logit_bias(
},
};
}

// utils

uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
if (smpl->iface == &llama_sampler_dist_i) {
return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
}

if (smpl->iface == &llama_sampler_mirostat_i) {
return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
}

if (smpl->iface == &llama_sampler_mirostat_v2_i) {
return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
}

if (smpl->iface == &llama_sampler_chain_i) {
const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
const uint32_t seed = llama_sampler_get_seed(*it);
if (seed != LLAMA_DEFAULT_SEED) {
return seed;
}
}
}

return LLAMA_DEFAULT_SEED;
}
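A few details worth noting in llama-sampling.cpp above. The chain lookup in llama_sampler_get_seed walks the samplers in reverse, so the reported seed comes from the last seeded sampler in the chain, i.e. the one that ultimately draws the token. get_rng_seed only falls back to the system clock when std::random_device reports zero entropy, meaning it is a deterministic PRNG rather than a real entropy source. And the std::max(penalty_last_n, 0) clamp keeps a negative value from being used as the size of the penalty ring buffer. One behavioral consequence of resolving seeds inside the samplers: resetting a sampler created with LLAMA_DEFAULT_SEED draws a fresh random seed, while an explicit seed stays stable. A sketch (illustrative, using the llama.h sampler API):

    #include "llama.h"
    #include <cstdio>

    int main() {
        // Explicit seed: stable across resets.
        llama_sampler * fixed = llama_sampler_init_dist(42);
        printf("fixed:  %u\n", llama_sampler_get_seed(fixed)); // 42
        llama_sampler_reset(fixed);
        printf("fixed:  %u\n", llama_sampler_get_seed(fixed)); // still 42

        // Default seed: a fresh random seed is drawn on construction
        // and again on every reset.
        llama_sampler * rnd = llama_sampler_init_dist(LLAMA_DEFAULT_SEED);
        printf("random: %u\n", llama_sampler_get_seed(rnd));
        llama_sampler_reset(rnd);
        printf("random: %u\n", llama_sampler_get_seed(rnd)); // almost surely different

        llama_sampler_free(fixed);
        llama_sampler_free(rnd);
        return 0;
    }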
