Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite get_many_mut methods #367

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions src/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1730,9 +1730,15 @@ where
where
Q: Hash + Equivalent<K>,
{
let hashes = self.build_hashes_inner(ks);
self.table
.get_many_mut(hashes, |i, (k, _)| ks[i].equivalent(k))
let hash_builder = &self.hash_builder;

let mut iter = ks.into_iter().map(|key| {
(
make_hash::<Q, S>(hash_builder, key),
equivalent_key::<Q, K, V>(key),
)
});
self.table.get_many_mut_from_iter(&mut iter)
}

unsafe fn get_many_unchecked_mut_inner<Q: ?Sized, const N: usize>(
Expand All @@ -1742,20 +1748,15 @@ where
where
Q: Hash + Equivalent<K>,
{
let hashes = self.build_hashes_inner(ks);
self.table
.get_many_unchecked_mut(hashes, |i, (k, _)| ks[i].equivalent(k))
}
let hash_builder = &self.hash_builder;

fn build_hashes_inner<Q: ?Sized, const N: usize>(&self, ks: [&Q; N]) -> [u64; N]
where
Q: Hash + Equivalent<K>,
{
let mut hashes = [0_u64; N];
for i in 0..N {
hashes[i] = make_hash::<Q, S>(&self.hash_builder, ks[i]);
}
hashes
let mut iter = ks.into_iter().map(|key| {
(
make_hash::<Q, S>(hash_builder, key),
equivalent_key::<Q, K, V>(key),
)
});
self.table.get_many_unchecked_mut_from_iter(&mut iter)
}

/// Inserts a key-value pair into the map.
Expand Down
154 changes: 154 additions & 0 deletions src/raw/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,7 @@ impl<T, A: Allocator + Clone> RawTable<T, A> {
///
/// The `eq` argument should be a closure such that `eq(i, k)` returns true if `k` is equal to
/// the `i`th key to be looked up.
#[cfg(feature = "raw")]
pub fn get_many_mut<const N: usize>(
&mut self,
hashes: [u64; N],
Expand All @@ -886,6 +887,7 @@ impl<T, A: Allocator + Clone> RawTable<T, A> {
}
}

#[cfg(feature = "raw")]
pub unsafe fn get_many_unchecked_mut<const N: usize>(
&mut self,
hashes: [u64; N],
Expand All @@ -895,6 +897,7 @@ impl<T, A: Allocator + Clone> RawTable<T, A> {
Some(mem::transmute_copy(&ptrs))
}

#[cfg(feature = "raw")]
unsafe fn get_many_mut_pointers<const N: usize>(
&mut self,
hashes: [u64; N],
Expand All @@ -913,6 +916,157 @@ impl<T, A: Allocator + Clone> RawTable<T, A> {
Some(outs.assume_init())
}

/// Attempts to get mutable references to `N element` in the table at once using
/// `hash` and equality function from iterator.
///
/// The `iter` argument should be an iterator that return `hash` of the stored
/// `element` and closure for checking the equivalence of that `element`.
///
/// This function return `None`:
///
/// - if an `element` is not found for any item from the iterator;
/// - if any of the requested `elements` from table are duplicated;
/// - if the given `const N` is equal to zero (`0`).
/// - if the given iterator length is not equal to the specified `const N`;
#[allow(clippy::explicit_counter_loop)]
pub fn get_many_mut_from_iter<'a, 'b, I, F, const N: usize>(
&'a mut self,
iter: &'b mut I,
) -> Option<[&'a mut T; N]>
where
I: Iterator<Item = (u64, F)>,
F: FnMut(&T) -> bool,
{
let pointers: [*mut T; N] = self.get_many_mut_pointers_from_iter(iter)?;

// Avoid using `Iterator::enumerate` because of double checking
let mut index = 0_usize;
for &current in pointers.iter() {
// SAFETY: we now exactly that the `index` less than `pointers` length
if unsafe { pointers.get_unchecked(..index) }
.iter()
.any(|&previous| previous == current)
{
return None;
}
index += 1;
}

// SAFETY: All bucket are distinct from all previous buckets, `*mut T` and `&T`
// are guaranteed properly aligned and have the same layout, so we're clear to
// return the result of the lookup.
// Also no needance of using mem::forget(pointers) because it is just array of
// pointers.
Some(unsafe { (&pointers as *const _ as *const [&mut T; N]).read() })
}

/// Attempts to get mutable references to `N element` in the table at once using
/// `hash` and equality function from iterator, without checking the uniqueness
/// of the found elements.
///
/// The `iter` argument should be an iterator that return `hash` of the stored
/// `element` and closure for checking the equivalence of that `element`.
///
/// This function return `None`:
///
/// - if an `element` is not found for any item from the iterator;
/// - if the given `const N` is equal to zero (`0`).
/// - if the given iterator length is not equal to the specified `const N`;
///
/// # Safety
///
/// Calling this method is *[undefined behavior]* if iterator contain overlapping
/// items that refer to the same `elements` in the table even if the resulting
/// references to `elements` in the table are not used.
///
/// [undefined behavior]: https://doc.rust-lang.org/reference/behavior-considered-undefined.html
pub unsafe fn get_many_unchecked_mut_from_iter<'a, 'b, I, F, const N: usize>(
&'a mut self,
iter: &'b mut I,
) -> Option<[&'a mut T; N]>
where
I: Iterator<Item = (u64, F)>,
F: FnMut(&T) -> bool,
{
let pointers: [*mut T; N] = self.get_many_mut_pointers_from_iter(iter)?;

// SAFETY: the caller must uphold the safety contract for `get_many_unchecked_mut_from_iter`.
// We only know that `*mut T` and `&T` are guaranteed properly aligned and have the same layout.
// Also we know that there is no needance of using mem::forget(pointers) because it is just
// array of pointers.
Some((&pointers as *const _ as *const [&mut T; N]).read())
}

/// Attempts to get mutable pointers to `N element` in the table at once using
/// `hash` and equality function from iterator, without checking the uniqueness
/// of the found elements.
///
/// The `iter` argument should be an iterator that return `hash` of the stored
/// `element` and closure for checking the equivalence of that `element`.
///
/// This function return `None`:
///
/// - if an `element` is not found for any item from the iterator;
/// - if the given `const N` is equal to zero (`0`).
/// - if the given iterator length is not equal to the specified `const N`;
///
/// # Safety
///
/// Calling this method is safe, but the returned array may contain overlapping
/// items pointing to the same `elements` in the table.
fn get_many_mut_pointers_from_iter<I, F, const N: usize>(
&mut self,
iter: &mut I,
) -> Option<[*mut T; N]>
where
I: Iterator<Item = (u64, F)>,
F: FnMut(&T) -> bool,
{
// Check trivial cases
if N == 0 || N > self.len() {
return None;
}

// If `iterator::size_hint` returns some upper bound, we check
// that it is equal to `const N`, else return from the function
if let (_, Some(upper_bound)) = iter.size_hint() {
if upper_bound != N {
return None;
}
}

// SAFETY: An uninitialized `[MaybeUninit<_>; LEN]` is valid,
// because the type we are claiming to have initialized here is a
// bunch of `MaybeUninit`s, which do not require initialization.
//
// FIXME: Use `MaybeUninit::uninit_array` or `maybe_uninit_uninit_array_transpose`
// (https://github.com/rust-lang/rust/pull/102023) instead as soon as either becomes
// stable
let mut array = unsafe { MaybeUninit::<[MaybeUninit<*mut T>; N]>::uninit().assume_init() };

for element in &mut array {
match iter.next() {
Some((hash, eq)) => match self.find(hash, eq) {
Some(bucket) => {
element.write(bucket.as_ptr());
}
None => return None,
},
None => return None,
}
}
// SAFETY: All elements of the array were populated in the loop above,
// `MaybeUninit<*mut T>` and `*mut T` are guaranteed properly aligned and
// have the same layout.
// Also no needance of using mem::forget(array) because it is just array of
// pointers.
//
// FIXME: Use `MaybeUninit::array_assume_init` or `maybe_uninit_uninit_array_transpose`
// (https://github.com/rust-lang/rust/pull/102023) instead as soon as either becomes
// stable
Some(unsafe { (&array as *const _ as *const [*mut T; N]).read() })
}

/// Returns the number of elements the map can hold without reallocating.
///
/// This number is a lower bound; the table might be able to hold
Expand Down