diff --git a/llama-burn/src/llama.rs b/llama-burn/src/llama.rs
index e1d15c9..5548d78 100644
--- a/llama-burn/src/llama.rs
+++ b/llama-burn/src/llama.rs
@@ -332,7 +332,7 @@ impl<B: Backend, T: Tokenizer> Llama<B, T> {
     ) -> GenerationOutput {
         let mut tokens = self.tokenize(prompt);
         let prompt_len = tokens.dims()[0];
-        let eos_token = self.tokenizer.eos_id();
+        let eos_token = self.tokenizer.eos_id() as i64;
 
         let mut num_tokens: usize = 0;
         let mut input_pos = Tensor::<B, 1, Int>::arange(0..tokens.dims()[0] as i64, &self.device);
@@ -368,7 +368,7 @@
 
         let tokens = tokens.into_data().value[prompt_len..]
             .iter()
-            .map(|t| t.elem::<i64>())
+            .map(|t| t.elem::<u32>())
             .collect::<Vec<_>>();
 
         let generated = self.tokenizer.decode(tokens);
diff --git a/llama-burn/src/tokenizer/base.rs b/llama-burn/src/tokenizer/base.rs
index c70d79d..de185e7 100644
--- a/llama-burn/src/tokenizer/base.rs
+++ b/llama-burn/src/tokenizer/base.rs
@@ -5,14 +5,14 @@ pub trait Tokenizer {
         Self: Sized;
 
     /// Encode a string into a list of token identifiers.
-    fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<i64>;
+    fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<u32>;
 
     /// Decode a list of token identifiers into a string.
-    fn decode(&self, tokens: Vec<i64>) -> String;
+    fn decode(&self, tokens: Vec<u32>) -> String;
 
     /// Beginning of sentence token identifier.
-    fn bos_id(&self) -> i64;
+    fn bos_id(&self) -> u32;
 
     /// End of sentence token identifier.
-    fn eos_id(&self) -> i64;
+    fn eos_id(&self) -> u32;
 }
diff --git a/llama-burn/src/tokenizer/sentence_piece.rs b/llama-burn/src/tokenizer/sentence_piece.rs
index a0f9f35..801fe00 100644
--- a/llama-burn/src/tokenizer/sentence_piece.rs
+++ b/llama-burn/src/tokenizer/sentence_piece.rs
@@ -1,23 +1,20 @@
-use rust_tokenizers::tokenizer::{
-    SentencePieceBpeTokenizer, Tokenizer as BaseTokenizer, TruncationStrategy,
-};
+use tokenizers::Tokenizer as BaseTokenizer;
 
 use super::Tokenizer;
 
-const BOS_TOKEN_ID: i64 = 1;
-const EOS_TOKEN_ID: i64 = 2;
+const BOS_TOKEN_ID: u32 = 1;
+const EOS_TOKEN_ID: u32 = 2;
 
 pub struct SentiencePieceTokenizer {
-    bpe: SentencePieceBpeTokenizer,
-    bos_token_id: i64,
-    eos_token_id: i64,
+    bpe: BaseTokenizer,
+    bos_token_id: u32,
+    eos_token_id: u32,
 }
 
 impl Tokenizer for SentiencePieceTokenizer {
     /// Load the [SentenciePiece](https://github.com/google/sentencepiece) tokenizer.
     fn new(tokenizer_path: &str) -> Result<Self, String> {
-        let bpe = SentencePieceBpeTokenizer::from_file(tokenizer_path, false)
-            .map_err(|e| e.to_string())?;
+        let bpe = BaseTokenizer::from_file(tokenizer_path).map_err(|e| e.to_string())?;
 
         Ok(Self {
             bpe,
@@ -26,34 +23,32 @@ impl Tokenizer for SentiencePieceTokenizer {
         })
     }
 
-    fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<i64> {
+    fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<u32> {
         let bos_token = if bos { vec![self.bos_token_id] } else { vec![] };
         let eos_token = if eos { vec![self.eos_token_id] } else { vec![] };
 
-        // No text combination
-        let tokens = self
-            .bpe
-            .encode(text, None, usize::MAX, &TruncationStrategy::LongestFirst, 0)
-            .token_ids;
+        let tokens = self.bpe.encode(text, false).unwrap().get_ids().to_vec();
 
         [bos_token, tokens, eos_token]
             .into_iter()
             .flat_map(|t| t.into_iter())
-            .map(|t| t as i64)
             .collect()
     }
 
-    fn decode(&self, tokens: Vec<i64>) -> String {
+    fn decode(&self, tokens: Vec<u32>) -> String {
         self.bpe
-            .decode(&tokens, true, false)
-            .replace("<0x0A>", "\n")
+            .decode(
+                &tokens.into_iter().map(|t| t as u32).collect::<Vec<u32>>(),
+                true,
+            )
+            .unwrap()
     }
 
-    fn bos_id(&self) -> i64 {
-        self.bos_token_id as i64
+    fn bos_id(&self) -> u32 {
+        self.bos_token_id
     }
 
-    fn eos_id(&self) -> i64 {
-        self.eos_token_id as i64
+    fn eos_id(&self) -> u32 {
+        self.eos_token_id
     }
 }
diff --git a/llama-burn/src/tokenizer/tiktoken.rs b/llama-burn/src/tokenizer/tiktoken.rs
index 648e0e2..ee18a00 100644
--- a/llama-burn/src/tokenizer/tiktoken.rs
+++ b/llama-burn/src/tokenizer/tiktoken.rs
@@ -85,7 +85,7 @@ impl Tokenizer for Tiktoken {
         })
     }
 
-    fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<i64> {
+    fn encode(&self, text: &str, bos: bool, eos: bool) -> Vec<u32> {
         let bos_token = if bos { vec![self.bos_token_id] } else { vec![] };
         let eos_token = if eos { vec![self.eos_token_id] } else { vec![] };
 
@@ -95,21 +95,21 @@
         [bos_token, tokens, eos_token]
             .into_iter()
             .flat_map(|t| t.into_iter())
-            .map(|t| t as i64)
+            .map(|t| t as u32)
             .collect()
     }
 
-    fn decode(&self, tokens: Vec<i64>) -> String {
+    fn decode(&self, tokens: Vec<u32>) -> String {
         self.bpe
             .decode(tokens.into_iter().map(|t| t as usize).collect())
             .expect("Should decode tokens")
     }
 
-    fn bos_id(&self) -> i64 {
-        self.bos_token_id as i64
+    fn bos_id(&self) -> u32 {
+        self.bos_token_id as u32
     }
 
-    fn eos_id(&self) -> i64 {
-        self.eos_token_id as i64
+    fn eos_id(&self) -> u32 {
+        self.eos_token_id as u32
    }
 }
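
For reference, a minimal round-trip sketch of the `Tokenizer` trait after this change: token identifiers are `u32` end to end, and only `Llama::generate` casts the EOS id to `i64` for tensor comparisons. The import path `llama_burn::tokenizer` and the `tokenizer.json` file name are assumptions for illustration, not part of the diff.

```rust
use llama_burn::tokenizer::{SentiencePieceTokenizer, Tokenizer};

fn main() -> Result<(), String> {
    // Hypothetical path to a HuggingFace-format tokenizer file.
    let tokenizer = SentiencePieceTokenizer::new("tokenizer.json")?;

    // encode() now returns Vec<u32>; bos_id()/eos_id() are u32 as well.
    let ids: Vec<u32> = tokenizer.encode("Hello, Llama!", true, false);
    assert_eq!(ids.first(), Some(&tokenizer.bos_id()));

    // decode() consumes the same Vec<u32>, with no i64 round-trip.
    let text = tokenizer.decode(ids);
    println!("{text}");
    Ok(())
}
```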