Skip to content

Commit

Permalink
Add tokenizer flag: clean_up_tokenization_spaces
Browse files Browse the repository at this point in the history
  • Loading branch information
jaime-m-p committed Jun 21, 2024
1 parent 6d233bc commit b452e82
Showing 1 changed file with 71 additions and 0 deletions.
71 changes: 71 additions & 0 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2322,6 +2322,7 @@ struct llama_vocab {
bool tokenizer_add_bos = false;
bool tokenizer_add_eos = false;
bool tokenizer_ignore_merges = false;
bool tokenizer_clean_spaces = false; // clean_up_tokenization_spaces

int find_bpe_rank(const std::string & token_left, const std::string & token_right) const {
GGML_ASSERT(token_left.find(' ') == std::string::npos);
Expand Down Expand Up @@ -4823,6 +4824,7 @@ static void llm_load_vocab(
// for now, only BPE models have pre-tokenizers
if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
vocab.tokenizer_add_space_prefix = false;
vocab.tokenizer_clean_spaces = true;
if (tokenizer_pre.empty()) {
LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
LLAMA_LOG_WARN("%s: \n", __func__);
Expand All @@ -4844,9 +4846,11 @@ static void llm_load_vocab(
} else if (
tokenizer_pre == "deepseek-llm") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM;
vocab.tokenizer_clean_spaces = false;
} else if (
tokenizer_pre == "deepseek-coder") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER;
vocab.tokenizer_clean_spaces = false;
} else if (
tokenizer_pre == "falcon") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_FALCON;
Expand All @@ -4858,6 +4862,7 @@ static void llm_load_vocab(
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STARCODER;
} else if (
tokenizer_pre == "gpt-2" ||
tokenizer_pre == "phi-2" ||
tokenizer_pre == "jina-es" ||
tokenizer_pre == "jina-de" ||
tokenizer_pre == "jina-v2-es" ||
Expand All @@ -4873,6 +4878,7 @@ static void llm_load_vocab(
} else if (
tokenizer_pre == "qwen2") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_QWEN2;
vocab.tokenizer_clean_spaces = false;
} else if (
tokenizer_pre == "stablelm2") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_STABLELM2;
Expand All @@ -4888,17 +4894,20 @@ static void llm_load_vocab(
} else if (
tokenizer_pre == "poro-chat") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_PORO;
vocab.tokenizer_clean_spaces = false;
} else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
}
} else if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
vocab.tokenizer_add_space_prefix = true;
vocab.tokenizer_clean_spaces = false;
vocab.tokenizer_add_bos = true;
vocab.tokenizer_add_eos = false;
} else if (vocab.type == LLAMA_VOCAB_TYPE_WPM) {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
vocab.tokenizer_add_space_prefix = false;
vocab.tokenizer_clean_spaces = true;
vocab.tokenizer_add_bos = true;
vocab.tokenizer_add_eos = false;
} else {
Expand Down Expand Up @@ -18519,6 +18528,68 @@ int32_t llama_detokenize(
}
}

if (total > text_len_max) {
return -total;
}

if (model->vocab.tokenizer_clean_spaces) {
text -= total; // restart text

// first pass: characters ?!., //TODO: where do these characters come from?
const int32_t total1 = total;
total = total ? 1 : 0;
for (int32_t i = 1; i < total1; ++i) {
const char x = text[i];
if (text[i - 1] == ' ') {
if (x == '?' || x == '!' || x == '.' || x == ',') { // " ?", " !", " .", " ,"
total--; // remove space
}
}
text[total++] = x;
}

// second pass: strip single apostrophe between spaces
const int32_t total2 = total;
total = total ? 1 : 0;
for (int32_t i = 1; i < total2; ++i) {
const char x = text[i];
if (x == '\'' && i + 1 < total2 && text[i - 1] == ' ' && text[i + 1] == ' ') { // " ' "
total--; // remove prev space
text[++i] = '\0'; // remove next space
}
text[total++] = x;
}

// third pass: apostrophe contractions //NOTE: this makes sense?
const int32_t total3 = total;
total = total ? 1 : 0;
for (int32_t i = 1; i < total3; ++i) {
const char x = text[i];
if (text[i - 1] == ' ') {
if (x == '\'' && i + 1 < total3) {
const char x1 = text[i + 1];
if (x1 == 't' || x1 == 'd') { // " 't", " 'd"
//total--; // remove space
} else if (x1 == 's' || x1 == 'm') { // " 's", " 'm"
total--; // remove space
} else if (i + 2 < total3) {
const char x2 = text[i + 2];
if ((x1 == 'l' && x2 == 'l')) { // " 'll"
//total--; // remove space
} else if ((x1 == 'r' && x2 == 'e') || (x1 == 'v' && x2 == 'e')) { // " 're", " 've"
total--; // remove space
} else {
//total--; // remove space
}
} else {
//total--; // remove space
}
}
}
text[total++] = x;
}
}

return total <= text_len_max ? total : -total;
}

Expand Down

0 comments on commit b452e82

Please sign in to comment.