From 3e884ac703778161d3825fdfb0fb77ec68934f81 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Fri, 3 Jan 2025 14:34:26 +0100 Subject: [PATCH] update test --- .../src/models/backtracking_bpe/model.rs | 70 ++++++++++++++++--- 1 file changed, 62 insertions(+), 8 deletions(-) diff --git a/tokenizers/src/models/backtracking_bpe/model.rs b/tokenizers/src/models/backtracking_bpe/model.rs index 1afd57dfa..b317121b9 100644 --- a/tokenizers/src/models/backtracking_bpe/model.rs +++ b/tokenizers/src/models/backtracking_bpe/model.rs @@ -234,7 +234,7 @@ fn token_iter<'a>(all_tokens: &'a [u8], token_starts: &'a [u32]) -> impl Iterato token_starts .iter() .tuple_windows() - .map( move |(start, end)| &all_tokens[*start as usize..*end as usize]) + .map(move |(start, end)| &all_tokens[*start as usize..*end as usize]) } fn next_match(longest_searcher: &DoubleArrayAhoCorasick, text: &[u8]) -> Option { @@ -483,12 +483,21 @@ impl BacktrackingBpe { } let vocab: HashMap = token_iter(&all_tokens, &token_starts) .enumerate() - .map(|(id, item)| (unsafe { String::from_utf8_unchecked(Vec::from(item.clone())) }, id as u32)) + .map(|(id, item)| { + ( + unsafe { String::from_utf8_unchecked(Vec::from(item.clone())) }, + id as u32, + ) + }) .collect(); let vocab_r: HashMap = token_iter(&all_tokens, &token_starts) .enumerate() - .map(|(id, item)| (id as u32, unsafe { String::from_utf8_unchecked(Vec::from(item.clone())) })) + .map(|(id, item)| { + (id as u32, unsafe { + String::from_utf8_unchecked(Vec::from(item.clone())) + }) + }) .collect(); let bpe = Self { @@ -658,7 +667,8 @@ impl BacktrackingBpe { } pub fn encode_via_backtracking(&self, text: &[u8]) -> Vec { - let mut enc = BacktrackState::new(text, None); + let next_token = self.next_match(text); + let mut enc = BacktrackState::new(text, next_token); while self.step(&mut enc).is_some() {} println!("_______________________________"); enc.into_tokens() @@ -801,16 +811,60 @@ mod tests { "aac", "ac", "cc", "cca", "aacc", "aaccca", "acca", "acc", "aa", "aaa", "aaaa", // 2 characters each ]; - let mut bpe = BacktrackingBpe::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None); + let mut bpe = + BacktrackingBpe::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None, None); // bpe.encode_via_backtracking(b"baacca"); - bpe.encode_via_backtracking(b"aaaacc"); - + let tokens = bpe.tokenize("aaaacc").unwrap(); + println!("{:?}", bpe.tokenize("aaaacc")); + assert_eq!( + tokens, + vec![ + Token { + id: 12, + value: String::from("aaa"), + offsets: (0, 0) + }, + Token { + id: 10, + value: String::from("acc"), + offsets: (0, 0) + } + ] + ); + println!("{:?}", bpe.tokenize("baaaaccca")); + let tokens = bpe.tokenize("baaaaccca").unwrap(); + assert_eq!( + tokens, + vec![ + Token { + id: 1, + value: String::from("b"), + offsets: (0, 0) + }, + Token { + id: 12, + value: String::from("aaa"), + offsets: (0, 0) + }, + Token { + id: 4, + value: String::from("ac"), + offsets: (0, 0) + }, + Token { + id: 6, + value: String::from("cca"), + offsets: (0, 0) + } + ] + ); bpe.encode_via_backtracking(b"baaaaccca"); let tokens = [ "a", "b", "c", // 1 character each "acca", "cc", "ac", "aac", "cca", ]; - let mut bpe = BacktrackingBpe::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None); + let mut bpe = + BacktrackingBpe::from_dictionary(tokens.map(|t| t.as_bytes().to_vec()), None, None); bpe.encode_via_backtracking(b"baacca"); } }