Skip to content

Commit

Permalink
infill : assert prefix/suffix tokens + remove old space logic (#8351)
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov authored Jul 8, 2024
1 parent ffd0079 commit 6f0dbf6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 18 deletions.
2 changes: 1 addition & 1 deletion common/log.h
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,7 @@ inline std::string LOG_TOKENS_TOSTR_PRETTY(const C & ctx, const T & tokens)
buf << "[ ";

bool first = true;
-for (const auto &token : tokens)
+for (const auto & token : tokens)
{
if (!first) {
buf << ", ";
Expand Down
25 changes: 8 additions & 17 deletions examples/infill/infill.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,21 +204,17 @@ int main(int argc, char ** argv) {
GGML_ASSERT(llama_add_eos_token(model) != 1);
LOG("add_bos: %d\n", add_bos);

-bool suff_rm_leading_spc = params.escape;
-if (suff_rm_leading_spc && params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
-params.input_suffix.erase(0, 1);
-suff_rm_leading_spc = false;
-}
std::vector<llama_token> embd_inp;
std::vector<llama_token> embd_end;
std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false);
std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false);
-const int space_token = 29871;
-if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
-inp_sfx.erase(inp_sfx.begin());
-}

+GGML_ASSERT(llama_token_prefix(model) >= 0);
+GGML_ASSERT(llama_token_suffix(model) >= 0);

inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));

embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
embd_end = params.spm_infill ? inp_pfx : inp_sfx;
if (add_bos) {
Expand Down Expand Up @@ -516,19 +512,14 @@ int main(int argc, char ** argv) {
string_process_escapes(params.input_prefix);
string_process_escapes(params.input_suffix);
}
-suff_rm_leading_spc = params.escape;
-if (suff_rm_leading_spc && params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1) {
-params.input_suffix.erase(0, 1);
-suff_rm_leading_spc = false;
-}

// tokenize new prefix and suffix
std::vector<llama_token> inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false);
std::vector<llama_token> inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false);
-if (suff_rm_leading_spc && inp_sfx[0] == space_token) {
-inp_sfx.erase(inp_sfx.begin());
-}

inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(model));
inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(model));

embd_inp = params.spm_infill ? inp_sfx : inp_pfx;
embd_end = params.spm_infill ? inp_pfx : inp_sfx;
if (add_bos) {
Expand Down

0 comments on commit 6f0dbf6

Please sign in to comment.