diff --git a/src/lib.rs b/src/lib.rs index 2b9e15ff..b466edd1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,85 +15,61 @@ use rustc_hash::FxHashMap as HashMap; type Rank = u32; -fn _byte_pair_merge( - ranks: &HashMap, Rank>, - piece: &[u8], -) -> Vec<(usize, Rank)> { +fn _byte_pair_merge(ranks: &HashMap, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> { // This is a vector of (start, rank). - // The rank is of the byte pair starting at position start. - // The rank of the last item in the vector is not a valid value. - let mut parts: Vec<(usize, Rank)> = (0..piece.len() + 1).map(|i| (i, Rank::MAX)).collect(); + // The rank is of the pair starting at position start. + let mut parts = Vec::with_capacity(piece.len() + 1); + + // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE + // the way we currently do, this is equivalent. An easy way to break this would be to decouple + // merge priority from token index or to prevent specific token merges. + let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX); + for i in 0..piece.len() - 1 { + let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX); + if rank < min_rank.0 { + min_rank = (rank, i); + } + parts.push((i, rank)); + } + parts.push((piece.len() - 1, Rank::MAX)); + parts.push((piece.len(), Rank::MAX)); let get_rank = { #[inline(always)] - |parts: &Vec<(usize, Rank)>, start_idx: usize, skip: usize| { - if (start_idx + skip + 2) < parts.len() { - ranks - .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0]) - .copied() + |parts: &Vec<(usize, Rank)>, i: usize| { + if (i + 3) < parts.len() { + // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted + // parts[i + 1], see comment in the main loop. + *ranks + .get(&piece[parts[i].0..parts[i + 3].0]) + .unwrap_or(&Rank::MAX) } else { - None + Rank::MAX } } }; - // We look up the ranks once in the beginning and iteratively update - // them during each merge, which reduces the number of rank lookups. - for i in 0..parts.len() - 2 { - match get_rank(&parts, i, 0) { - Some(rank) => { - // Rank::MAX is a sentinel value and cannot be a valid rank - debug_assert!(rank != Rank::MAX); - parts[i].1 = rank; - } - None => { - continue; - } - }; - } - // If you have n parts and m merges, this does O(mn) work. // We could do something with a heap and do O(m log n) work. - // It is important to consider that n is often small (<100), and as such - // the cache-locality benefits outweigh the algorithmic complexity downsides - // of the `parts` vector data structure above. - - // Note that we hash bytes, not token pairs. As long as we train BPE the way we - // currently do, this is equivalent. An easy way to break this would be to decouple - // merge priority from token index or to prevent specific token merges. - loop { - if parts.len() == 1 { - break; + // n is often very small so considerations like cache-locality outweigh the algorithmic + // complexity downsides of the `parts` vector. + while min_rank.0 != Rank::MAX { + let i = min_rank.1; + // Update parts[i] and parts[i - 1] before removing parts[i + 1], since + // `parts.remove(i + 1)` will thrash the cache. + if i > 0 { + parts[i - 1].1 = get_rank(&parts, i - 1); } + parts[i].1 = get_rank(&parts, i); + parts.remove(i + 1); - // Rank::MAX is a sentinel rank value allowing us to - // take the min more quickly - let mut min_rank: (Rank, usize) = (Rank::MAX, 0); + min_rank = (Rank::MAX, usize::MAX); for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() { if rank < min_rank.0 { min_rank = (rank, i); } } - - if min_rank.0 != Rank::MAX { - let i = min_rank.1; - - // NOTE: We are about to remove parts[i + 1]. We do not do it - // yet because there are cache-locality benefits to updating - // parts[i] and parts[i-1] before removing, which could thrash - // the cache. Thus, we update the rank calculation by skipping over - // parts[i + 1], by invoking `get_rank!` with `skip = 1`. - parts[i].1 = get_rank(&parts, i, 1).unwrap_or(Rank::MAX); - if i > 0 { - parts[i - 1].1 = get_rank(&parts, i - 1, 1).unwrap_or(Rank::MAX); - } - - parts.remove(i + 1); - } else { - break; - } } - parts }