llama_sampling_sample with default args is more naively usable #6519

Merged 2 commits on Apr 8, 2024
2 changes: 1 addition & 1 deletion common/sampling.h
@@ -129,7 +129,7 @@ llama_token llama_sampling_sample(
struct llama_sampling_context * ctx_sampling,
struct llama_context * ctx_main,
struct llama_context * ctx_cfg,
- int idx = 0);
+ int idx = -1);

// Prepares and adjusts the set of token candidates for sampling based on penalties, biases, and sampling parameters.
llama_token_data_array llama_sampling_prepare(
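With the default index changed from 0 to -1, a caller that has just decoded a batch can sample from the most recent logits without tracking which batch position produced them. A minimal sketch of the simplified call site (the helper name sample_last is illustrative, not part of this PR):

    #include "common/sampling.h"

    // Sketch: sample the next token from the last output in the batch.
    // Before this change, idx defaulted to 0 (the first output), so callers
    // typically passed an explicit index such as batch.n_tokens - 1.
    static llama_token sample_last(struct llama_sampling_context * ctx_sampling,
                                   struct llama_context          * ctx) {
        // idx now defaults to -1, i.e. "the most recent logits";
        // ctx_cfg may be null when no guidance context is used
        return llama_sampling_sample(ctx_sampling, ctx, /*ctx_cfg =*/ nullptr);
    }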
38 changes: 29 additions & 9 deletions llama.cpp
@@ -2177,7 +2177,7 @@ struct llama_context {

std::vector<int32_t> output_ids; // map batch token positions to ids of the logits and embd buffers
size_t output_size = 0; // capacity (of tokens positions) for the output buffers
- int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch
+ int32_t n_outputs = 0; // number of actually-used outputs in the current ubatch or last logical batch

bool logits_all = false;

@@ -10411,6 +10411,9 @@ static int llama_decode_internal(
n_outputs_prev += lctx.n_outputs;
}

+ // set to total number of outputs in the batch, for use in llama_get_logits_ith
+ lctx.n_outputs = n_outputs;

// wait for the computation to finish (automatically done when obtaining the model output)
//llama_synchronize(&lctx);
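A worked example with illustrative values: decoding a batch of five tokens with batch.logits = {0, 0, 1, 0, 1} produces two output rows, so this assignment leaves lctx.n_outputs == 2 after the ubatch loop, while output_ids maps positions 2 and 4 to rows 0 and 1. llama_get_logits_ith can then resolve a negative index i as j = n_outputs + i and bounds-check it against the whole batch rather than only the last ubatch.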

@@ -15511,23 +15514,31 @@ float * llama_get_logits(struct llama_context * ctx) {
}

float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
+ int32_t j = -1;
llama_synchronize(ctx);

try {
if (ctx->logits == nullptr) {
throw std::runtime_error("no logits");
}
- if ((size_t) i >= ctx->output_ids.size()) {
+
+ if (i < 0) {
+     j = ctx->n_outputs + i;
+     if (j < 0) {
+         throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
+     }
+ } else if ((size_t) i >= ctx->output_ids.size()) {
      throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+ } else {
+     j = ctx->output_ids[i];
  }
- const int32_t j = ctx->output_ids[i];

if (j < 0) {
throw std::runtime_error(format("batch.logits[%d] != true", i));
}
- if ((size_t) j >= ctx->output_size) {
+ if (j >= ctx->n_outputs) {
      // This should not happen
-     throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size));
+     throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
  }

return ctx->logits + j*ctx->model.hparams.n_vocab;
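A minimal usage sketch of the new negative-index path, assuming a context that has just decoded a batch in which only the final token requested logits (the helper name last_logits is illustrative):

    #include "llama.h"

    // Sketch: with a single output row, output_ids[n_tokens - 1] == 0 and
    // n_outputs == 1, so i = -1 resolves to j = n_outputs + (-1) = 0 and the
    // two calls below return the same pointer. Invalid indices return NULL,
    // per the header comment.
    static float * last_logits(struct llama_context * ctx, int32_t n_tokens) {
        float * by_neg = llama_get_logits_ith(ctx, -1);           // new negative path
        float * by_pos = llama_get_logits_ith(ctx, n_tokens - 1); // via output_ids
        return by_neg == by_pos ? by_neg : nullptr;
    }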
@@ -15547,23 +15558,32 @@ float * llama_get_embeddings(struct llama_context * ctx) {
}

float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) {
+ int32_t j = -1;
+
llama_synchronize(ctx);

try {
if (ctx->embd == nullptr) {
throw std::runtime_error("no embeddings");
}
- if ((size_t) i >= ctx->output_ids.size()) {
+
+ if (i < 0) {
+     j = ctx->n_outputs + i;
+     if (j < 0) {
+         throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
+     }
+ } else if ((size_t) i >= ctx->output_ids.size()) {
      throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
+ } else {
+     j = ctx->output_ids[i];
  }
- const int32_t j = ctx->output_ids[i];

if (j < 0) {
throw std::runtime_error(format("batch.logits[%d] != true", i));
}
- if ((size_t) j >= ctx->output_size) {
+ if (j >= ctx->n_outputs) {
      // This should not happen
-     throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size));
+     throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
  }

return ctx->embd + j*ctx->model.hparams.n_embd;
6 changes: 4 additions & 2 deletions llama.h
@@ -684,8 +684,9 @@ extern "C" {
// Cols: n_vocab
LLAMA_API float * llama_get_logits(struct llama_context * ctx);

- // Logits for the ith token. Equivalent to:
+ // Logits for the ith token. For positive indices, equivalent to:
      // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab
+ // Negative indices can be used to access logits in reverse order, -1 is the last logit.
// returns NULL for invalid ids.
LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i);

@@ -697,8 +698,9 @@ extern "C" {
// Otherwise, returns NULL.
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);

- // Get the embeddings for the ith token. Equivalent to:
+ // Get the embeddings for the ith token. For positive indices, equivalent to:
      // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
+ // Negative indices can be used to access embeddings in reverse order, -1 is the last embedding.
// shape: [n_embd] (1-dimensional)
// returns NULL for invalid ids.
LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i);
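The same convention applies on the embeddings side; a small sketch, assuming the context was created with embeddings enabled (the helper name last_embedding is illustrative):

    #include "llama.h"

    // Sketch: read the embedding row of the last output; -1 avoids having to
    // track which batch position produced the final row. Returns NULL for
    // invalid indices, per the header comment.
    static float * last_embedding(struct llama_context * ctx) {
        return llama_get_embeddings_ith(ctx, -1);
    }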