main: option to disable context shift #9484

Merged: 11 commits, Sep 16, 2024
common/arg.cpp (+7, -1)
@@ -685,6 +685,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.n_keep = value;
         }
     ));
+    add_opt(llama_arg(
+        {"--no-context-shift"},
+        format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
+        [](gpt_params & params) {
+            params.ctx_shift = false;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN}));
     add_opt(llama_arg(
         {"--chunks"}, "N",
         format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
@@ -1985,4 +1992,3 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,

     return ctx_arg;
 }
-
common/common.h (+1, -0)
@@ -246,6 +246,7 @@ struct gpt_params {
     bool cont_batching = true; // insert new sequences for decoding on-the-fly
     bool flash_attn = false; // flash attention
     bool no_perf = false; // disable performance metrics
+    bool ctx_shift = true; // context shift on infinite text generation
 
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool logits_all = false; // return logits for all tokens in the batch
examples/main/README.md (+2, -0)
@@ -161,6 +161,8 @@ A value of -1 will enable infinite text generation, even though we have a finite

 If the pause is undesirable, a value of -2 will stop generation immediately when the context is filled.
 
+The `--no-context-shift` option allows you to stop infinite text generation once the finite context window is full.
+
 It is important to note that the generated text may be shorter than the specified number of tokens if an End-of-Sequence (EOS) token or a reverse prompt is encountered. In interactive mode, text generation will pause and control will be returned to the user. In non-interactive mode, the program will end. In both cases, the text generation may stop before reaching the specified `--predict` value. If you want the model to keep going without ever producing End-of-Sequence on its own, you can use the `--ignore-eos` parameter.
 
 ### Temperature
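For reference, a hypothetical invocation combining the new flag with infinite generation would look like `./llama-cli -m model.gguf -p "Once upon a time" -n -1 --no-context-shift` (the binary and model names here are illustrative, not taken from this diff). With `-n -1` requesting unbounded output, the run now stops once the context window fills instead of shifting the context and continuing.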
examples/main/main.cpp (+20, -14)
@@ -559,29 +559,35 @@ int main(int argc, char ** argv) {
             // if we run out of context:
             // - take the n_keep first tokens from the original prompt (via n_past)
             // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
 
             if (n_past + (int) embd.size() >= n_ctx) {
-                if (params.n_predict == -2) {
-                    LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+                if (!params.ctx_shift){
+                    LOG_DBG("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
                     break;
-                }
+                } else {
+                    if (params.n_predict == -2) {
+                        LOG_DBG("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+                        break;
+                    }
 
-                const int n_left    = n_past - params.n_keep;
-                const int n_discard = n_left/2;
+                    const int n_left    = n_past - params.n_keep;
+                    const int n_discard = n_left/2;
 
-                LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
-                        n_past, n_left, n_ctx, params.n_keep, n_discard);
+                    LOG_DBG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
+                            n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
-                llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+                    llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
+                    llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
 
-                n_past -= n_discard;
+                    n_past -= n_discard;
 
-                LOG_DBG("after swap: n_past = %d\n", n_past);
+                    LOG_DBG("after swap: n_past = %d\n", n_past);
 
-                LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
+                    LOG_DBG("embd: %s\n", string_from(ctx, embd).c_str());
 
-                LOG_DBG("clear session path\n");
-                path_session.clear();
+                    LOG_DBG("clear session path\n");
+                    path_session.clear();
+                }
             }
         } else {
            // context extension via Self-Extend
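As a reading aid for the change above: when context shift remains enabled, the code pins the first `n_keep` tokens, discards the oldest half of the remaining `n_left` tokens, and slides the kept tail left in the KV cache. Below is a minimal standalone sketch of that arithmetic; the toy numbers and the `std::vector` stand-in for the KV cache are assumptions for illustration, whereas the real patch operates through `llama_kv_cache_seq_rm` and `llama_kv_cache_seq_add`.

```cpp
#include <cstdio>
#include <vector>

int main() {
    // Toy setup (illustrative, not from the patch): 8 occupied positions,
    // context full, with the first 2 tokens pinned via n_keep.
    int n_past = 8;
    int n_keep = 2;

    // Same arithmetic as the context-shift branch in main.cpp:
    const int n_left    = n_past - n_keep; // 6 tokens eligible for discarding
    const int n_discard = n_left/2;        // drop the oldest half: 3 tokens

    // Stand-in for the KV cache: positions [n_keep, n_keep + n_discard) are
    // erased (the patch calls llama_kv_cache_seq_rm) and the tail implicitly
    // shifts left by n_discard (llama_kv_cache_seq_add with a negative delta).
    std::vector<int> cache = {0, 1, 2, 3, 4, 5, 6, 7};
    cache.erase(cache.begin() + n_keep, cache.begin() + n_keep + n_discard);

    n_past -= n_discard; // 5 positions remain occupied, 3 are free again

    printf("n_left = %d, n_discard = %d, n_past = %d\ncache:", n_left, n_discard, n_past);
    for (int t : cache) {
        printf(" %d", t); // prints: 0 1 5 6 7
    }
    printf("\n");
    return 0;
}
```

With `--no-context-shift` set, none of this runs: the new `!params.ctx_shift` branch logs that the context is full and breaks out of the generation loop instead.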