From 16af28bee75eca7c847f98ae774d143438167a35 Mon Sep 17 00:00:00 2001 From: Arthur Zucker Date: Wed, 24 Jul 2024 20:09:45 +0200 Subject: [PATCH] revert and cleanup --- tokenizers/src/lib.rs | 2 -- tokenizers/src/tokenizer/added_vocabulary.rs | 22 +------------------- tokenizers/src/tokenizer/mod.rs | 3 ++- tokenizers/src/tokenizer/pre_tokenizer.rs | 19 ----------------- tokenizers/src/utils/parallelism.rs | 2 +- 5 files changed, 4 insertions(+), 44 deletions(-) diff --git a/tokenizers/src/lib.rs b/tokenizers/src/lib.rs index 071a72e50..eb89b9315 100644 --- a/tokenizers/src/lib.rs +++ b/tokenizers/src/lib.rs @@ -151,5 +151,3 @@ pub use utils::parallelism; // Re-export for from_pretrained #[cfg(feature = "http")] pub use utils::from_pretrained::FromPretrainedParameters; -#[global_allocator] -static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc; diff --git a/tokenizers/src/tokenizer/added_vocabulary.rs b/tokenizers/src/tokenizer/added_vocabulary.rs index 30bcae530..18d5800a5 100644 --- a/tokenizers/src/tokenizer/added_vocabulary.rs +++ b/tokenizers/src/tokenizer/added_vocabulary.rs @@ -453,26 +453,6 @@ impl AddedVocabulary { splits } - fn fast_split_with_indices( - &self, - sentence: NormalizedString, - split_re: &MatchingSet, - ) -> Vec<(NormalizedString, Option>)> { - self.find_matches(sentence.get(), split_re) - .into_iter() - .map(|(id, byte_offsets)| { - let slice = sentence - .slice(Range::Normalized(byte_offsets.0..byte_offsets.1)) - .expect("AddedVocabulary bad split"); - if let Some(id) = id { - (slice, Some(vec![Token::new(id, String::new(), (0, 0))])) - } else { - (slice, None) - } - }) - .collect() - } - /// Split the input sentence to extract anything we found from the `MatchingSet`, as well as /// the list of corresponding IDs /// The list of IDs have the exact same number of elements than the Iterator. @@ -514,7 +494,7 @@ impl AddedVocabulary { // 1. We extract all the non-normalized tokens from the non-normalized string pretokenized .split(|_, sequence| { - Ok(self.fast_split_with_indices( + Ok(self.split_with_indices( sequence, &self.split_trie_vec[hash_current_thread() % MAX_NUM_THREADS], )) diff --git a/tokenizers/src/tokenizer/mod.rs b/tokenizers/src/tokenizer/mod.rs index c7c879800..e0ee2eab4 100644 --- a/tokenizers/src/tokenizer/mod.rs +++ b/tokenizers/src/tokenizer/mod.rs @@ -895,7 +895,7 @@ where ) -> Result { let mut pretokenized: PreTokenizedString = pretokenized.into(); pretokenized.tokenize(|normalized| self.model.tokenize(normalized.get()))?; - pretokenized.fast_into_encoding() + pretokenized.into_encoding(word_idx, type_id, offsets_type) } } @@ -1070,6 +1070,7 @@ where .num_threads(num_threads) .build() .unwrap(); + let mut encodings = pool.install(|| { let result = inputs .into_maybe_par_iter() diff --git a/tokenizers/src/tokenizer/pre_tokenizer.rs b/tokenizers/src/tokenizer/pre_tokenizer.rs index c645d0da9..54e24f76a 100644 --- a/tokenizers/src/tokenizer/pre_tokenizer.rs +++ b/tokenizers/src/tokenizer/pre_tokenizer.rs @@ -186,25 +186,6 @@ impl PreTokenizedString { } } - pub fn fast_into_encoding(self) -> Result { - if self.splits.is_empty() { - Ok(Encoding::default()) - } else if !self.splits.iter().all(|split| split.tokens.is_some()) { - Err("Split has not been tokenized.".into()) - } else { - let tokens = self - .splits - .into_iter() - .flat_map(|split| { - split.tokens.unwrap().into_iter().map(|token| { - // Replace this with the actual fields you need for the Encoding type - (token.id, String::new(), (0, 0), None, 0) - }) - }) - .collect(); - Ok(tokens) - } - } /// Returns a list of splits, each of them being a slice of the normalized /// string, the associated offsets either in original or normalized /// referential, as well as the potention tokens diff --git a/tokenizers/src/utils/parallelism.rs b/tokenizers/src/utils/parallelism.rs index a59d7102f..b955731d1 100644 --- a/tokenizers/src/utils/parallelism.rs +++ b/tokenizers/src/utils/parallelism.rs @@ -73,7 +73,7 @@ where if parallelism { USED_PARALLELISM.store(true, Ordering::SeqCst); } - CondIterator::new(self, true) + CondIterator::new(self, parallelism) } fn into_maybe_par_iter_cond(self, cond: bool) -> CondIterator {