Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tokenizer : special token handling #3538

Merged
merged 11 commits into from
Oct 17, 2023
12 changes: 7 additions & 5 deletions common/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -862,21 +862,23 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
std::vector<llama_token> llama_tokenize(
const struct llama_context * ctx,
const std::string & text,
bool add_bos) {
return llama_tokenize(llama_get_model(ctx), text, add_bos);
bool add_bos,
bool allow_special_tokens) {
return llama_tokenize(llama_get_model(ctx), text, add_bos, allow_special_tokens);
}

std::vector<llama_token> llama_tokenize(
const struct llama_model * model,
const std::string & text,
bool add_bos) {
bool add_bos,
bool allow_special_tokens) {
// upper limit for the number of tokens
int n_tokens = text.length() + add_bos;
std::vector<llama_token> result(n_tokens);
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
n_tokens = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, allow_special_tokens);
if (n_tokens < 0) {
result.resize(-n_tokens);
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos);
int check = llama_tokenize(model, text.data(), text.length(), result.data(), result.size(), add_bos, allow_special_tokens);
GGML_ASSERT(check == -n_tokens);
} else {
result.resize(n_tokens);
Expand Down
6 changes: 4 additions & 2 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,14 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
std::vector<llama_token> llama_tokenize(
const struct llama_context * ctx,
const std::string & text,
bool add_bos);
bool add_bos,
bool allow_special_tokens = false);

std::vector<llama_token> llama_tokenize(
const struct llama_model * model,
const std::string & text,
bool add_bos);
bool add_bos,
bool allow_special_tokens = false);

// tokenizes a token into a piece
// should work similar to Python's `tokenizer.id_to_piece`
Expand Down
8 changes: 4 additions & 4 deletions common/train.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,7 @@ size_t tokenize_file(
(int) buf.size(),
out_tokens.data(),
(int) out_tokens.size(),
false);
false,false);
if (n_tokens < 0) {
out_tokens.resize(-n_tokens);
n_tokens = llama_tokenize(
Expand All @@ -872,7 +872,7 @@ size_t tokenize_file(
(int) buf.size(),
out_tokens.data(),
(int) out_tokens.size(),
false);
false,false);
}
if (n_tokens >= 0) {
out_tokens.resize(n_tokens);
Expand Down Expand Up @@ -966,15 +966,15 @@ size_t tokenize_file(
(int) buf_sample.size(),
tok_sample.data(),
(int) tok_sample.size(),
false);
false,false);
if (n_tokens < 0) {
tok_sample.resize(-n_tokens);
n_tokens = llama_tokenize(llama_get_model(lctx),
buf_sample.data(),
(int) buf_sample.size(),
tok_sample.data(),
(int) tok_sample.size(),
false);
false,false);
GGML_ASSERT(n_tokens >= 0);
}
GGML_ASSERT(n_tokens <= (int) tok_sample.size());
Expand Down
14 changes: 7 additions & 7 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ int main(int argc, char ** argv) {

if (params.interactive_first || params.instruct || !params.prompt.empty() || session_tokens.empty()) {
LOG("tokenize the prompt\n");
embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
embd_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
} else {
LOG("use session tokens\n");
embd_inp = session_tokens;
Expand All @@ -259,10 +259,10 @@ int main(int argc, char ** argv) {
if (ctx_guidance) {
LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(params.cfg_negative_prompt));

guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos);
guidance_inp = ::llama_tokenize(ctx_guidance, params.cfg_negative_prompt, add_bos, true);
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp));

std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos);
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp));

original_prompt_len = original_inp.size();
Expand Down Expand Up @@ -316,8 +316,8 @@ int main(int argc, char ** argv) {
}

// prefix & suffix for instruct mode
const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos);
const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false);
const auto inp_pfx = ::llama_tokenize(ctx, "\n\n### Instruction:\n\n", add_bos, true);
const auto inp_sfx = ::llama_tokenize(ctx, "\n\n### Response:\n\n", false, true);

LOG("inp_pfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_pfx));
LOG("inp_sfx: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, inp_sfx));
Expand Down Expand Up @@ -715,7 +715,7 @@ int main(int argc, char ** argv) {
if (params.interactive) {
if (!params.antiprompt.empty()) {
// tokenize and inject first reverse prompt
const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false);
const auto first_antiprompt = ::llama_tokenize(ctx, params.antiprompt.front(), false, true);
embd_inp.insert(embd_inp.end(), first_antiprompt.begin(), first_antiprompt.end());
is_antiprompt = true;
}
Expand Down Expand Up @@ -780,7 +780,7 @@ int main(int argc, char ** argv) {
embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
}

const auto line_inp = ::llama_tokenize(ctx, buffer, false);
const auto line_inp = ::llama_tokenize(ctx, buffer, false, true);
LOG("input tokens: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, line_inp));

embd_inp.insert(embd_inp.end(), line_inp.begin(), line_inp.end());
Expand Down
Loading
Loading