Skip to content

Commit

Permalink
Add byte token type to HF format
Browse files Browse the repository at this point in the history
  • Loading branch information
strutive07 authored Dec 26, 2023
1 parent a206137 commit 9f297f8
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ def __init__(self, params: Params, fname_tokenizer: Path) -> None:
for tok in self.tokenizer.all_special_tokens
}
self.special_ids: set[int] = set(self.tokenizer.all_special_ids)
self.reverse_vocab = {id: encoded_tok for encoded_tok, id in self.tokenizer.get_vocab().items()}
self.vocab_size_base: int = self.tokenizer.vocab_size
self.vocab_size: int = self.vocab_size_base + len(self.added_tokens_dict)
self.fname_tokenizer: Path = fname_tokenizer
Expand All @@ -371,14 +372,13 @@ def __init__(self, params: Params, fname_tokenizer: Path) -> None:

def hf_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
    """Yield (utf-8 text, score, token type) for each base-vocabulary token.

    Iterates token ids ``[0, self.vocab_size_base)`` and skips ids that
    belong to added tokens — presumably those are emitted by a separate
    added-tokens iterator elsewhere in the class (TODO confirm).
    """
    # NOTE: the old local ``tokenizer = self.tokenizer`` was dead code once
    # reverse_vocab moved to __init__; it has been removed.
    added_tokens_ids = set(self.added_tokens_dict.values())

    for i in range(self.vocab_size_base):
        # Added tokens are not part of the contiguous base-vocab range
        # handled here, even if their ids fall inside it.
        if i in added_tokens_ids:
            continue

        # self.reverse_vocab maps id -> token string (built in __init__
        # from tokenizer.get_vocab()).
        text = self.reverse_vocab[i].encode("utf-8")
        yield text, self.get_token_score(i), self.get_token_type(i)

def get_token_type(self, token_id: int) -> gguf.TokenType:
Expand All @@ -394,10 +394,13 @@ def get_token_type(self, token_id: int) -> gguf.TokenType:
if self.spm.is_byte(token_id):
toktype = gguf.TokenType.BYTE
else:
token = self.reverse_vocab[token_id]
if token_id == self.unk_token_id:
toktype = gguf.TokenType.UNKNOWN
if token_id in self.special_ids:
elif token_id in self.special_ids:
toktype = gguf.TokenType.CONTROL
elif len(token) == 6 and token.startswith("<0x") and token.endswith(">"):
toktype = gguf.TokenType.BYTE

return toktype

Expand Down

0 comments on commit 9f297f8

Please sign in to comment.