Skip to content

Commit

Permalink
100% safe implementation of RepeatN
Browse files Browse the repository at this point in the history
  • Loading branch information
Soveu committed Nov 29, 2024
1 parent a45391f commit ff05378
Showing 1 changed file with 44 additions and 86 deletions.
130 changes: 44 additions & 86 deletions library/core/src/iter/sources/repeat_n.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use crate::fmt;
use crate::iter::{FusedIterator, TrustedLen, UncheckedIterator};
use crate::mem::{self, MaybeUninit};
use crate::num::NonZero;

/// Creates a new iterator that repeats a single element a given number of times.
Expand Down Expand Up @@ -57,78 +56,48 @@ use crate::num::NonZero;
#[inline]
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
pub fn repeat_n<T: Clone>(element: T, count: usize) -> RepeatN<T> {
let element = if count == 0 {
// `element` gets dropped eagerly.
MaybeUninit::uninit()
} else {
MaybeUninit::new(element)
};

RepeatN { element, count }
RepeatN { inner: RepeatNInner::new(element, count) }
}

#[derive(Clone)]
struct RepeatNInner<T> {
element: T,
count: NonZero<usize>,
}

impl<T> RepeatNInner<T> {
fn new(element: T, count: usize) -> Option<Self> {
let count = NonZero::<usize>::new(count)?;
Some(Self { element, count })
}
}

/// An iterator that repeats an element an exact number of times.
///
/// This `struct` is created by the [`repeat_n()`] function.
/// See its documentation for more.
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
#[derive(Clone)]
pub struct RepeatN<A> {
count: usize,
// Invariant: uninit iff count == 0.
element: MaybeUninit<A>,
inner: Option<RepeatNInner<A>>,
}

impl<A> RepeatN<A> {
/// Returns the element if it hasn't been dropped already.
fn element_ref(&self) -> Option<&A> {
if self.count > 0 {
// SAFETY: The count is non-zero, so it must be initialized.
Some(unsafe { self.element.assume_init_ref() })
} else {
None
}
}
/// If we haven't already dropped the element, return it in an option.
///
/// Clears the count so it won't be dropped again later.
#[inline]
fn take_element(&mut self) -> Option<A> {
if self.count > 0 {
self.count = 0;
let element = mem::replace(&mut self.element, MaybeUninit::uninit());
// SAFETY: We just set count to zero so it won't be dropped again,
// and it used to be non-zero so it hasn't already been dropped.
unsafe { Some(element.assume_init()) }
} else {
None
}
}
}

#[stable(feature = "iter_repeat_n", since = "1.82.0")]
impl<A: Clone> Clone for RepeatN<A> {
fn clone(&self) -> RepeatN<A> {
RepeatN {
count: self.count,
element: self.element_ref().cloned().map_or_else(MaybeUninit::uninit, MaybeUninit::new),
}
self.inner.take().map(|inner| inner.element)
}
}

#[stable(feature = "iter_repeat_n", since = "1.82.0")]
impl<A: fmt::Debug> fmt::Debug for RepeatN<A> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("RepeatN")
.field("count", &self.count)
.field("element", &self.element_ref())
.finish()
}
}

#[stable(feature = "iter_repeat_n", since = "1.82.0")]
impl<A> Drop for RepeatN<A> {
fn drop(&mut self) {
self.take_element();
let (count, element) = match self.inner.as_ref() {
Some(inner) => (inner.count.get(), Some(&inner.element)),
None => (0, None),
};
f.debug_struct("RepeatN").field("count", &count).field("element", &element).finish()
}
}

Expand All @@ -138,12 +107,17 @@ impl<A: Clone> Iterator for RepeatN<A> {

#[inline]
fn next(&mut self) -> Option<A> {
if self.count > 0 {
// SAFETY: Just checked it's not empty
unsafe { Some(self.next_unchecked()) }
} else {
None
let inner = self.inner.as_mut()?;
let count = inner.count.get();

if let Some(decremented) = NonZero::<usize>::new(count - 1) {
// Order of these is important for optimization
let tmp = inner.element.clone();
inner.count = decremented;
return Some(tmp);
}

return self.take_element();
}

#[inline]
Expand All @@ -154,19 +128,19 @@ impl<A: Clone> Iterator for RepeatN<A> {

#[inline]
fn advance_by(&mut self, skip: usize) -> Result<(), NonZero<usize>> {
let len = self.count;
let Some(inner) = self.inner.as_mut() else {
return NonZero::<usize>::new(skip).map(Err).unwrap_or(Ok(()));
};

if skip >= len {
self.take_element();
}
let len = inner.count.get();

if skip > len {
// SAFETY: we just checked that the difference is positive
Err(unsafe { NonZero::new_unchecked(skip - len) })
} else {
self.count = len - skip;
Ok(())
if let Some(new_len) = len.checked_sub(skip).and_then(NonZero::<usize>::new) {
inner.count = new_len;
return Ok(());
}

self.inner = None;
return NonZero::<usize>::new(skip - len).map(Err).unwrap_or(Ok(()));
}

#[inline]
Expand All @@ -183,7 +157,7 @@ impl<A: Clone> Iterator for RepeatN<A> {
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
impl<A: Clone> ExactSizeIterator for RepeatN<A> {
fn len(&self) -> usize {
self.count
self.inner.as_ref().map(|inner| inner.count.get()).unwrap_or(0)
}
}

Expand Down Expand Up @@ -211,20 +185,4 @@ impl<A: Clone> FusedIterator for RepeatN<A> {}
#[unstable(feature = "trusted_len", issue = "37572")]
unsafe impl<A: Clone> TrustedLen for RepeatN<A> {}
#[stable(feature = "iter_repeat_n", since = "1.82.0")]
impl<A: Clone> UncheckedIterator for RepeatN<A> {
#[inline]
unsafe fn next_unchecked(&mut self) -> Self::Item {
// SAFETY: The caller promised the iterator isn't empty
self.count = unsafe { self.count.unchecked_sub(1) };
if self.count == 0 {
// SAFETY: the check above ensured that the count used to be non-zero,
// so element hasn't been dropped yet, and we just lowered the count to
// zero so it won't be dropped later, and thus it's okay to take it here.
unsafe { mem::replace(&mut self.element, MaybeUninit::uninit()).assume_init() }
} else {
// SAFETY: the count is non-zero, so it must have not been dropped yet.
let element = unsafe { self.element.assume_init_ref() };
A::clone(element)
}
}
}
impl<A: Clone> UncheckedIterator for RepeatN<A> {}

0 comments on commit ff05378

Please sign in to comment.