diff --git a/src/sentencepiece_tokenizer.cc b/src/sentencepiece_tokenizer.cc index 6ca31b8..476a820 100644 --- a/src/sentencepiece_tokenizer.cc +++ b/src/sentencepiece_tokenizer.cc @@ -37,7 +37,13 @@ class SentencePieceTokenizer : public Tokenizer { std::string IdToToken(int32_t id) final { return sentence_piece_.IdToPiece(id); } - int32_t TokenToId(const std::string& token) final { return sentence_piece_.PieceToId(token); } + int32_t TokenToId(const std::string& token) final { + int32_t id = sentence_piece_.PieceToId(token); + if (id == sentence_piece_.unk_id()) { + return -1; + } + return id; + } private: // the tokenizer diff --git a/web/src/tokenizers.ts b/web/src/tokenizers.ts index 72c6ab8..d2e77f9 100644 --- a/web/src/tokenizers.ts +++ b/web/src/tokenizers.ts @@ -75,6 +75,17 @@ export class Tokenizer { return res; } + /** + * Convert the given token to its corresponding id if it exists. If not, return -1. + * + * @param token the input token string. + * @returns The id of the token, or -1 if the token does not exist. 
+ */ + tokenToId(token: string): number { + const id = this.handle.TokenToId(token.slice()); + return id; + } + /** * Create a tokenizer from jsonArrayBuffer * diff --git a/web/src/tokenizers_binding.cc b/web/src/tokenizers_binding.cc index ec07032..f85bb96 100644 --- a/web/src/tokenizers_binding.cc +++ b/web/src/tokenizers_binding.cc @@ -23,5 +23,6 @@ EMSCRIPTEN_BINDINGS(tokenizers) { .function("Encode", &tokenizers::Tokenizer::Encode) .function("Decode", &tokenizers::Tokenizer::Decode) .function("GetVocabSize", &tokenizers::Tokenizer::GetVocabSize) - .function("IdToToken", &tokenizers::Tokenizer::IdToToken); + .function("IdToToken", &tokenizers::Tokenizer::IdToToken) + .function("TokenToId", &tokenizers::Tokenizer::TokenToId); } diff --git a/web/tests/src/index.ts b/web/tests/src/index.ts index bb7b5bc..83c1a10 100644 --- a/web/tests/src/index.ts +++ b/web/tests/src/index.ts @@ -27,6 +27,18 @@ async function testJSONTokenizer() { if (tok49407 !== "<|endoftext|>") { throw Error("Expect token 49407 to be <|endoftext|>"); } + + const id0 = tok.tokenToId("!"); + console.log("id0=" + id0); + if (id0 !== 0) { + throw Error("Expect token 0 to be !"); + } + + const id49407 = tok.tokenToId("<|endoftext|>"); + console.log("id49407=" + id49407); + if (id49407 !== 49407) { + throw Error("Expect token 49407 to be <|endoftext|>"); + } } async function testLlamaTokenizer() {