diff --git a/utils/src/range/mod.rs b/utils/src/range/mod.rs index 66f8fdc..bae83eb 100644 --- a/utils/src/range/mod.rs +++ b/utils/src/range/mod.rs @@ -127,6 +127,40 @@ impl RangeSet { pub fn max(&self) -> Option { self.ranges.last().map(|range| range.end) } + + /// Splits the set into two at the provided value. + /// + /// Returns a new set containing all the existing elements `>= at`. After the call, + /// the original set will be left containing the elements `< at`. + /// + /// # Panics + /// + /// Panics if `at` is not in the set. + pub fn split_off(&mut self, at: &T) -> Self { + // Find the index of the range containing `at` + let idx = self + .ranges + .iter() + .position(|range| range.contains(at)) + .expect("`at` is in the set"); + + // Split off the range containing `at` and all the ranges to the right. + let mut split_ranges = self.ranges.split_off(idx); + + // If the first range starts before `at` we have to push those values back + // into the existing set and truncate. + if *at > split_ranges[0].start { + self.ranges.push(Range { + start: split_ranges[0].start, + end: *at, + }); + split_ranges[0].start = *at; + } + + Self { + ranges: split_ranges, + } + } } impl RangeSet @@ -337,6 +371,7 @@ impl RangeSubset> for Range { #[allow(clippy::all)] mod tests { use super::*; + use rstest::*; #[test] fn test_range_disjoint() { @@ -432,4 +467,34 @@ mod tests { _ = iter.next(); assert_eq!(iter.len(), 2); } + + #[rstest] + #[case(RangeSet::from([(0..1)]), 0)] + #[case(RangeSet::from([(0..5)]), 1)] + #[case(RangeSet::from([(0..5), (6..10)]), 4)] + #[case(RangeSet::from([(0..5), (6..10)]), 6)] + #[case(RangeSet::from([(0..5), (6..10)]), 9)] + fn test_range_set_split_off(#[case] set: RangeSet, #[case] at: usize) { + let mut a = set.clone(); + let b = a.split_off(&at); + + assert!(a + .ranges + .last() + .map(|range| !range.is_empty()) + .unwrap_or(true)); + assert!(b + .ranges + .first() + .map(|range| !range.is_empty()) + .unwrap_or(true)); + assert_eq!(a.len() + b.len(), set.len()); + assert!(a.iter().chain(b.iter()).eq(set.iter())); + } + + #[test] + #[should_panic = "`at` is in the set"] + fn test_range_set_split_off_panic_not_in_set() { + RangeSet::from([0..1]).split_off(&1); + } }