diff --git a/examples/tokenize/tokenize.cpp b/examples/tokenize/tokenize.cpp index c817be566cf54..f36684ea358f9 100644 --- a/examples/tokenize/tokenize.cpp +++ b/examples/tokenize/tokenize.cpp @@ -32,6 +32,7 @@ static void print_usage_information(const char * argv0, FILE * stream) { fprintf(stream, " --no-parse-special do not parse control tokens.\n"); fprintf(stream, " --log-disable disable logs. Makes stderr quiet when loading the model.\n"); fprintf(stream, " --show-count print the total number of tokens.\n"); + fprintf(stream, " --show-count-only only print the total number of tokens (skip the printing of the actual tokens).\n"); } static void llama_log_callback_null(ggml_log_level level, const char * text, void * user_data) { @@ -199,6 +200,7 @@ int main(int raw_argc, char ** raw_argv) { bool no_parse_special = false; bool disable_logging = false; bool show_token_count = false; + bool show_token_count_only = false; const char * model_path = NULL; const char * prompt_path = NULL; const char * prompt_arg = NULL; @@ -259,6 +261,9 @@ int main(int raw_argc, char ** raw_argv) { else if (arg == "--show-count") { show_token_count = true; } + else if (arg == "--show-count-only") { + show_token_count_only = show_token_count = true; + } else { fprintf(stderr, "Error: unknown option '%s'\n", argv[iarg].c_str()); return 1; @@ -369,30 +374,32 @@ int main(int raw_argc, char ** raw_argv) { std::vector tokens; tokens = ::llama_tokenize(model, prompt, add_bos, parse_special); - if (printing_ids) { - printf("["); - } - - for (int i = 0; i < (int) tokens.size(); i++) { + if (!show_token_count_only) { if (printing_ids) { - if (i > 0) { - printf(", "); - } - printf("%d", tokens[i]); - } else { - bool invalid_utf8 = false; - printf("%6d -> '", tokens[i]); - write_utf8_cstr_to_stdout(llama_token_to_piece(ctx, tokens[i]).c_str(), invalid_utf8); - if (invalid_utf8) { - printf("' (utf-8 decode failure)\n"); + printf("["); + } + + for (int i = 0; i < (int) tokens.size(); i++) { + if (printing_ids) { + if (i > 0) { + printf(", "); + } + printf("%d", tokens[i]); } else { - printf("'\n"); + bool invalid_utf8 = false; + printf("%6d -> '", tokens[i]); + write_utf8_cstr_to_stdout(llama_token_to_piece(ctx, tokens[i]).c_str(), invalid_utf8); + if (invalid_utf8) { + printf("' (utf-8 decode failure)\n"); + } else { + printf("'\n"); + } } } - } - if (printing_ids) { - printf("]\n"); + if (printing_ids) { + printf("]\n"); + } } if (show_token_count) {