From d64b9ce87d54edb4f4fbcf74a85bba291b09e7e2 Mon Sep 17 00:00:00 2001 From: Bwallker Date: Sun, 12 Nov 2023 15:48:22 +0200 Subject: [PATCH] Fix soundness bug and refactor away unnecessary UnsafeCell. - Fix a potential soundness bug by using ManuallyDrop. - Change SharedValue to contain a T instead of a UnsafeCell and remove as_ptr method. - Swap several calls to `HashMap::get_key_value` with calls to `Hashmap::get_key_value_mut` in order to make using UnsafeCell redundant. - Make unsafe blocks smaller and more refined. --- src/iter.rs | 8 ++--- src/lib.rs | 82 +++++++++++++++++---------------------------- src/mapref/entry.rs | 22 +++++------- src/util.rs | 22 ++++-------- 4 files changed, 48 insertions(+), 86 deletions(-) diff --git a/src/iter.rs b/src/iter.rs index ce50e739..8523ff87 100644 --- a/src/iter.rs +++ b/src/iter.rs @@ -243,13 +243,9 @@ impl<'a, K: Eq + Hash, V, S: 'a + BuildHasher + Clone, M: Map<'a, K, V, S>> Iter if let Some((k, v)) = current.1.next() { let guard = current.0.clone(); - unsafe { - let k = util::change_lifetime_const(k); + let v = v.get_mut(); - let v = &mut *v.as_ptr(); - - return Some(RefMutMulti::new(guard, k, v)); - } + return Some(unsafe { RefMutMulti::new(guard, k, v) }); } } diff --git a/src/lib.rs b/src/lib.rs index 34e5d8ba..fec0c289 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -953,15 +953,10 @@ impl<'a, K: 'a + Eq + Hash, V: 'a, S: 'a + BuildHasher + Clone> Map<'a, K, V, S> let mut shard = unsafe { self._yield_write_shard(idx) }; if let Some((kptr, vptr)) = shard.get_key_value(key) { - unsafe { - let kptr: *const K = kptr; - let vptr: *mut V = vptr.as_ptr(); - - if f(&*kptr, &mut *vptr) { - shard.remove_entry(key).map(|(k, v)| (k, v.into_inner())) - } else { - None - } + if f(kptr, vptr.get()) { + shard.remove_entry(key).map(|(k, v)| (k, v.into_inner())) + } else { + None } } else { None @@ -979,16 +974,11 @@ impl<'a, K: 'a + Eq + Hash, V: 'a, S: 'a + BuildHasher + Clone> Map<'a, K, V, S> let mut shard = unsafe { self._yield_write_shard(idx) }; - if let Some((kptr, vptr)) = shard.get_key_value(key) { - unsafe { - let kptr: *const K = kptr; - let vptr: *mut V = vptr.as_ptr(); - - if f(&*kptr, &mut *vptr) { - shard.remove_entry(key).map(|(k, v)| (k, v.into_inner())) - } else { - None - } + if let Some((kptr, vptr)) = shard.get_key_value_mut(key) { + if f(kptr, vptr.get_mut()) { + shard.remove_entry(key).map(|(k, v)| (k, v.into_inner())) + } else { + None } } else { None @@ -1034,14 +1024,12 @@ impl<'a, K: 'a + Eq + Hash, V: 'a, S: 'a + BuildHasher + Clone> Map<'a, K, V, S> let idx = self.determine_shard(hash); - let shard = unsafe { self._yield_write_shard(idx) }; + let mut shard = unsafe { self._yield_write_shard(idx) }; - if let Some((kptr, vptr)) = shard.get_key_value(key) { - unsafe { - let kptr: *const K = kptr; - let vptr: *mut V = vptr.as_ptr(); - Some(RefMut::new(shard, kptr, vptr)) - } + if let Some((kptr, vptr)) = shard.get_key_value_mut(key) { + let kptr: *const K = kptr; + let vptr: *mut V = vptr.get_mut(); + Some(unsafe { RefMut::new(shard, kptr, vptr) }) } else { None } @@ -1081,17 +1069,15 @@ impl<'a, K: 'a + Eq + Hash, V: 'a, S: 'a + BuildHasher + Clone> Map<'a, K, V, S> let idx = self.determine_shard(hash); - let shard = match unsafe { self._try_yield_write_shard(idx) } { + let mut shard = match unsafe { self._try_yield_write_shard(idx) } { Some(shard) => shard, None => return TryResult::Locked, }; - if let Some((kptr, vptr)) = shard.get_key_value(key) { - unsafe { - let kptr: *const K = kptr; - let vptr: *mut V = vptr.as_ptr(); - TryResult::Present(RefMut::new(shard, kptr, vptr)) - } + if let Some((kptr, vptr)) = shard.get_key_value_mut(key) { + let kptr: *const K = kptr; + let vptr: *mut V = vptr.get_mut(); + TryResult::Present(unsafe { RefMut::new(shard, kptr, vptr) }) } else { TryResult::Absent } @@ -1149,14 +1135,12 @@ impl<'a, K: 'a + Eq + Hash, V: 'a, S: 'a + BuildHasher + Clone> Map<'a, K, V, S> let idx = self.determine_shard(hash); - let shard = unsafe { self._yield_write_shard(idx) }; + let mut shard = unsafe { self._yield_write_shard(idx) }; - if let Some((kptr, vptr)) = shard.get_key_value(&key) { - unsafe { - let kptr: *const K = kptr; - let vptr: *mut V = vptr.as_ptr(); - Entry::Occupied(OccupiedEntry::new(shard, key, (kptr, vptr))) - } + if let Some((kptr, vptr)) = shard.get_key_value_mut(&key) { + let kptr: *const K = kptr; + let vptr: *mut V = vptr.get_mut(); + Entry::Occupied(unsafe { OccupiedEntry::new(shard, key, (kptr, vptr)) }) } else { unsafe { Entry::Vacant(VacantEntry::new(shard, key)) } } @@ -1167,22 +1151,18 @@ impl<'a, K: 'a + Eq + Hash, V: 'a, S: 'a + BuildHasher + Clone> Map<'a, K, V, S> let idx = self.determine_shard(hash); - let shard = match unsafe { self._try_yield_write_shard(idx) } { + let mut shard = match unsafe { self._try_yield_write_shard(idx) } { Some(shard) => shard, None => return None, }; - if let Some((kptr, vptr)) = shard.get_key_value(&key) { - unsafe { - let kptr: *const K = kptr; - let vptr: *mut V = vptr.as_ptr(); + if let Some((kptr, vptr)) = shard.get_key_value_mut(&key) { + let kptr: *const K = kptr; + let vptr: *mut V = vptr.get_mut(); - Some(Entry::Occupied(OccupiedEntry::new( - shard, - key, - (kptr, vptr), - ))) - } + Some(Entry::Occupied(unsafe { + OccupiedEntry::new(shard, key, (kptr, vptr)) + })) } else { unsafe { Some(Entry::Vacant(VacantEntry::new(shard, key))) } } diff --git a/src/mapref/entry.rs b/src/mapref/entry.rs index e9e6b913..30c576df 100644 --- a/src/mapref/entry.rs +++ b/src/mapref/entry.rs @@ -1,12 +1,11 @@ use super::one::RefMut; use crate::lock::RwLockWriteGuard; -use crate::util; use crate::util::SharedValue; use crate::HashMap; use core::hash::{BuildHasher, Hash}; use core::mem; use core::ptr; -use std::collections::hash_map::RandomState; +use std::{collections::hash_map::RandomState, mem::ManuallyDrop}; pub enum Entry<'a, K, V, S = RandomState> { Occupied(OccupiedEntry<'a, K, V, S>), @@ -129,21 +128,18 @@ impl<'a, K: Eq + Hash, V, S: BuildHasher> VacantEntry<'a, K, V, S> { pub fn insert(mut self, value: V) -> RefMut<'a, K, V, S> { unsafe { - let c: K = ptr::read(&self.key); + // Use ManuallyDrop here instead of ptr::read because it doesn't cause a double drop if we unexpectedly panic. + let c = ManuallyDrop::new(ptr::read(&self.key)); self.shard.insert(self.key, SharedValue::new(value)); - let (k, v) = self.shard.get_key_value(&c).unwrap(); + let (k, v) = self.shard.get_key_value_mut(&*c).unwrap(); - let k = util::change_lifetime_const(k); + let k: *const K = k; - let v = &mut *v.as_ptr(); + let v: *mut V = v.get_mut(); - let r = RefMut::new(self.shard, k, v); - - mem::forget(c); - - r + RefMut::new(self.shard, k, v) } } @@ -155,10 +151,10 @@ impl<'a, K: Eq + Hash, V, S: BuildHasher> VacantEntry<'a, K, V, S> { unsafe { self.shard.insert(self.key.clone(), SharedValue::new(value)); - let (k, v) = self.shard.get_key_value(&self.key).unwrap(); + let (k, v) = self.shard.get_key_value_mut(&self.key).unwrap(); let kptr: *const K = k; - let vptr: *mut V = v.as_ptr(); + let vptr: *mut V = v.get_mut(); OccupiedEntry::new(self.shard, self.key, (kptr, vptr)) } } diff --git a/src/util.rs b/src/util.rs index d84e37db..f5f4373b 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,6 +1,5 @@ //! This module is full of hackery and dark magic. //! Either spend a day fixing it and quietly submit a PR or don't mention it to anybody. -use core::cell::UnsafeCell; use core::{mem, ptr}; pub const fn ptr_size_bits() -> usize { @@ -45,16 +44,14 @@ pub unsafe fn change_lifetime_mut<'a, 'b, T>(x: &'a mut T) -> &'b mut T { /// This type is meant to be an implementation detail, but must be exposed due to the `Dashmap::shards` #[repr(transparent)] pub struct SharedValue { - value: UnsafeCell, + value: T, } impl Clone for SharedValue { fn clone(&self) -> Self { let inner = self.get().clone(); - Self { - value: UnsafeCell::new(inner), - } + Self { value: inner } } } @@ -65,29 +62,22 @@ unsafe impl Sync for SharedValue {} impl SharedValue { /// Create a new `SharedValue` pub const fn new(value: T) -> Self { - Self { - value: UnsafeCell::new(value), - } + Self { value } } /// Get a shared reference to `T` pub fn get(&self) -> &T { - unsafe { &*self.value.get() } + &self.value } /// Get an unique reference to `T` pub fn get_mut(&mut self) -> &mut T { - unsafe { &mut *self.value.get() } + &mut self.value } /// Unwraps the value pub fn into_inner(self) -> T { - self.value.into_inner() - } - - /// Get a mutable raw pointer to the underlying value - pub(crate) fn as_ptr(&self) -> *mut T { - self.value.get() + self.value } }