From 5512a424bfc5ef5f663add407efccb0e1881f938 Mon Sep 17 00:00:00 2001
From: Manish Goregaokar
Date: Tue, 29 Oct 2024 01:44:06 -0700
Subject: [PATCH] Add safety comments (#1651)

* Unsafe comment for from_u32_unchecked

* Add safety comments and type assertion for HashSet parallel iteration

* Add safety comment for String splice

* fixes

* fmt

* pos
---
 tokenizers/src/models/bpe/trainer.rs        | 27 +++++++++++++++++++++------
 tokenizers/src/pre_tokenizers/byte_level.rs |  3 +++
 tokenizers/src/tokenizer/normalizer.rs      |  8 ++++++++
 3 files changed, 32 insertions(+), 6 deletions(-)

diff --git a/tokenizers/src/models/bpe/trainer.rs b/tokenizers/src/models/bpe/trainer.rs
index 2876f1ef5..a1a0aba76 100644
--- a/tokenizers/src/models/bpe/trainer.rs
+++ b/tokenizers/src/models/bpe/trainer.rs
@@ -454,7 +454,7 @@ impl BpeTrainer {
         // 3. Tokenize words
         //
         self.update_progress(&progress, word_counts.len(), "Tokenize words");
-        let (words, counts) =
+        let (mut words, counts) =
             self.tokenize_words(word_counts, &mut word_to_id, &mut id_to_word, &progress);
         self.finalize_progress(&progress, words.len());
 
@@ -530,14 +530,29 @@ impl BpeTrainer {
             merges.push((top.pair, new_token_id));
 
             // Merge the new pair in every word
-            let changes = top
-                .pos
+            // Safety: This is just a type assertion, the code below may no longer be safe
+            // if the type of `pos` changes
+            let pos: &HashSet<usize> = &top.pos;
+
+            let words_len = words.len();
+            struct WordPtr(*mut Word);
+            // Safety: We do not actually use this for concurrent access to the same memory,
+            // only to different chunks within the same allocation.
+            unsafe impl Sync for WordPtr {}
+            let word_start = WordPtr(words.as_mut_ptr());
+
+            let changes = pos
                 .maybe_par_iter()
                 .flat_map(|&i| {
-                    let word = &words[i] as *const _ as *mut Word;
-                    // We can merge each of these words in parallel here because each position
-                    // can be there only once (HashSet). So this is safe.
+                    // Safety:
+                    // We are producing a valid pointer since we are indexing in bounds
+                    //
+                    // We can access each `word` here in parallel because each position
+                    // can be there only once (pos is a HashSet).
                     unsafe {
+                        assert!(i < words_len);
+                        // This is words[i], but avoids needing to go through &T (which triggers UB)
+                        let word = word_start.0.add(i);
                         // let word: &mut Word = &mut (*word);
                         (*word)
                             .merge(top.pair.0, top.pair.1, new_token_id, max_token_length)
diff --git a/tokenizers/src/pre_tokenizers/byte_level.rs b/tokenizers/src/pre_tokenizers/byte_level.rs
index 2d3845b55..8396f1a7b 100644
--- a/tokenizers/src/pre_tokenizers/byte_level.rs
+++ b/tokenizers/src/pre_tokenizers/byte_level.rs
@@ -28,6 +28,9 @@ pub(crate) fn bytes_char() -> HashMap<u8, char> {
         }
     }
 
+    // Safety: cs contains all values from bs (between 0 and 255),
+    // and some values of the form 2⁸ + n, where n is between 0 and 255, i.e. between 256 and 511.
+    // Both ranges are valid Unicode scalar values (which are fully saturated until 0xD800)
     bs.into_iter()
         .zip(cs)
         .map(|(f, t)| (f, unsafe { std::char::from_u32_unchecked(t) }))
diff --git a/tokenizers/src/tokenizer/normalizer.rs b/tokenizers/src/tokenizer/normalizer.rs
index 9cbbccee2..94b568874 100644
--- a/tokenizers/src/tokenizer/normalizer.rs
+++ b/tokenizers/src/tokenizer/normalizer.rs
@@ -411,8 +411,16 @@ impl NormalizedString {
             .collect::<String>();
         self.alignments.splice(n_range.clone(), alignments);
+
+        // This bounds check already happens above (`self.normalized[n_range.clone()]`), but future
+        // code could change to mutate `self` or `self.normalized` in the interim.
+        // Perform it again and hope the optimizer collapses it.
+        assert!(self.normalized.get(n_range.clone()).is_some());
         unsafe {
             self.normalized
+                // Safety: This is safe as long as we do not splice across a
+                // UTF-8 character, and we only add UTF-8 text. `normalized` is a String
+                // so the latter is trivially true, and we assert for the former above.
                 .as_mut_vec()
                 .splice(n_range, normalized.bytes());
         }
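
For reference, the trainer.rs change reduces to a small standalone pattern: wrap a raw pointer in a type that is unsafely marked Sync, then hand each parallel task a distinct index so that no two tasks ever touch the same element. The sketch below is illustrative only (SlicePtr, double_disjoint, and the use of std::thread::scope in place of maybe_par_iter are assumptions, not part of the patch):

use std::thread;

// Illustrative stand-in for WordPtr, pointing at u64 slots instead of Words.
struct SlicePtr(*mut u64);
// Safety: every task below receives a distinct index, so no two tasks ever
// access the same element; only that disjointness makes this sound.
unsafe impl Sync for SlicePtr {}

fn double_disjoint(data: &mut [u64]) {
    let len = data.len();
    let base = SlicePtr(data.as_mut_ptr());
    thread::scope(|s| {
        for i in 0..len {
            let base = &base;
            s.spawn(move || unsafe {
                assert!(i < len);
                // Equivalent to data[i] *= 2, but goes through the raw pointer
                // instead of materializing aliasing &mut references.
                *base.0.add(i) *= 2;
            });
        }
    });
}

fn main() {
    let mut v = vec![1u64, 2, 3, 4];
    double_disjoint(&mut v);
    assert_eq!(v, [2, 4, 6, 8]);
}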
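The byte_level.rs comment can be checked with the safe constructor: every value bytes_char() feeds to from_u32_unchecked is below 512, while char::from_u32 only rejects the surrogate range 0xD800..=0xDFFF and values above 0x10FFFF. A quick check (illustrative, not part of the patch):

fn main() {
    // The whole 0..512 range the table can produce is accepted:
    for t in 0u32..512 {
        assert!(char::from_u32(t).is_some());
    }
    assert!(char::from_u32(0xD7FF).is_some()); // last value before the surrogates
    assert!(char::from_u32(0xD800).is_none()); // first surrogate
}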
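The assert added in normalizer.rs covers both preconditions of the byte-level splice because str::get with a byte range returns None both when the range is out of bounds and when an endpoint falls inside a multi-byte UTF-8 character. For example:

fn main() {
    let s = String::from("héllo"); // 'é' occupies bytes 1..3
    assert!(s.get(0..1).is_some()); // ends on a char boundary: Some("h")
    assert!(s.get(0..2).is_none()); // would split 'é': not a char boundary
    assert!(s.get(0..999).is_none()); // out of bounds
}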