Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix soundness bug and refactor away unnecessary UnsafeCell. #290

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions src/iter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,13 +243,9 @@ impl<'a, K: Eq + Hash, V, S: 'a + BuildHasher + Clone, M: Map<'a, K, V, S>> Iter
if let Some((k, v)) = current.1.next() {
let guard = current.0.clone();

unsafe {
let k = util::change_lifetime_const(k);
let v = v.get_mut();

let v = &mut *v.as_ptr();

return Some(RefMutMulti::new(guard, k, v));
}
return Some(unsafe { RefMutMulti::new(guard, k, v) });
}
}

Expand Down
82 changes: 31 additions & 51 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -953,15 +953,10 @@ impl<'a, K: 'a + Eq + Hash, V: 'a, S: 'a + BuildHasher + Clone> Map<'a, K, V, S>
let mut shard = unsafe { self._yield_write_shard(idx) };

if let Some((kptr, vptr)) = shard.get_key_value(key) {
unsafe {
let kptr: *const K = kptr;
let vptr: *mut V = vptr.as_ptr();

if f(&*kptr, &mut *vptr) {
shard.remove_entry(key).map(|(k, v)| (k, v.into_inner()))
} else {
None
}
if f(kptr, vptr.get()) {
shard.remove_entry(key).map(|(k, v)| (k, v.into_inner()))
} else {
None
}
} else {
None
Expand All @@ -979,16 +974,11 @@ impl<'a, K: 'a + Eq + Hash, V: 'a, S: 'a + BuildHasher + Clone> Map<'a, K, V, S>

let mut shard = unsafe { self._yield_write_shard(idx) };

if let Some((kptr, vptr)) = shard.get_key_value(key) {
unsafe {
let kptr: *const K = kptr;
let vptr: *mut V = vptr.as_ptr();

if f(&*kptr, &mut *vptr) {
shard.remove_entry(key).map(|(k, v)| (k, v.into_inner()))
} else {
None
}
if let Some((kptr, vptr)) = shard.get_key_value_mut(key) {
if f(kptr, vptr.get_mut()) {
shard.remove_entry(key).map(|(k, v)| (k, v.into_inner()))
} else {
None
}
} else {
None
Expand Down Expand Up @@ -1034,14 +1024,12 @@ impl<'a, K: 'a + Eq + Hash, V: 'a, S: 'a + BuildHasher + Clone> Map<'a, K, V, S>

let idx = self.determine_shard(hash);

let shard = unsafe { self._yield_write_shard(idx) };
let mut shard = unsafe { self._yield_write_shard(idx) };

if let Some((kptr, vptr)) = shard.get_key_value(key) {
unsafe {
let kptr: *const K = kptr;
let vptr: *mut V = vptr.as_ptr();
Some(RefMut::new(shard, kptr, vptr))
}
if let Some((kptr, vptr)) = shard.get_key_value_mut(key) {
let kptr: *const K = kptr;
let vptr: *mut V = vptr.get_mut();
Some(unsafe { RefMut::new(shard, kptr, vptr) })
} else {
None
}
Expand Down Expand Up @@ -1081,17 +1069,15 @@ impl<'a, K: 'a + Eq + Hash, V: 'a, S: 'a + BuildHasher + Clone> Map<'a, K, V, S>

let idx = self.determine_shard(hash);

let shard = match unsafe { self._try_yield_write_shard(idx) } {
let mut shard = match unsafe { self._try_yield_write_shard(idx) } {
Some(shard) => shard,
None => return TryResult::Locked,
};

if let Some((kptr, vptr)) = shard.get_key_value(key) {
unsafe {
let kptr: *const K = kptr;
let vptr: *mut V = vptr.as_ptr();
TryResult::Present(RefMut::new(shard, kptr, vptr))
}
if let Some((kptr, vptr)) = shard.get_key_value_mut(key) {
let kptr: *const K = kptr;
let vptr: *mut V = vptr.get_mut();
TryResult::Present(unsafe { RefMut::new(shard, kptr, vptr) })
} else {
TryResult::Absent
}
Expand Down Expand Up @@ -1149,14 +1135,12 @@ impl<'a, K: 'a + Eq + Hash, V: 'a, S: 'a + BuildHasher + Clone> Map<'a, K, V, S>

let idx = self.determine_shard(hash);

let shard = unsafe { self._yield_write_shard(idx) };
let mut shard = unsafe { self._yield_write_shard(idx) };

if let Some((kptr, vptr)) = shard.get_key_value(&key) {
unsafe {
let kptr: *const K = kptr;
let vptr: *mut V = vptr.as_ptr();
Entry::Occupied(OccupiedEntry::new(shard, key, (kptr, vptr)))
}
if let Some((kptr, vptr)) = shard.get_key_value_mut(&key) {
let kptr: *const K = kptr;
let vptr: *mut V = vptr.get_mut();
Entry::Occupied(unsafe { OccupiedEntry::new(shard, key, (kptr, vptr)) })
} else {
unsafe { Entry::Vacant(VacantEntry::new(shard, key)) }
}
Expand All @@ -1167,22 +1151,18 @@ impl<'a, K: 'a + Eq + Hash, V: 'a, S: 'a + BuildHasher + Clone> Map<'a, K, V, S>

let idx = self.determine_shard(hash);

let shard = match unsafe { self._try_yield_write_shard(idx) } {
let mut shard = match unsafe { self._try_yield_write_shard(idx) } {
Some(shard) => shard,
None => return None,
};

if let Some((kptr, vptr)) = shard.get_key_value(&key) {
unsafe {
let kptr: *const K = kptr;
let vptr: *mut V = vptr.as_ptr();
if let Some((kptr, vptr)) = shard.get_key_value_mut(&key) {
let kptr: *const K = kptr;
let vptr: *mut V = vptr.get_mut();

Some(Entry::Occupied(OccupiedEntry::new(
shard,
key,
(kptr, vptr),
)))
}
Some(Entry::Occupied(unsafe {
OccupiedEntry::new(shard, key, (kptr, vptr))
}))
} else {
unsafe { Some(Entry::Vacant(VacantEntry::new(shard, key))) }
}
Expand Down
22 changes: 9 additions & 13 deletions src/mapref/entry.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use super::one::RefMut;
use crate::lock::RwLockWriteGuard;
use crate::util;
use crate::util::SharedValue;
use crate::HashMap;
use core::hash::{BuildHasher, Hash};
use core::mem;
use core::ptr;
use std::collections::hash_map::RandomState;
use std::{collections::hash_map::RandomState, mem::ManuallyDrop};

pub enum Entry<'a, K, V, S = RandomState> {
Occupied(OccupiedEntry<'a, K, V, S>),
Expand Down Expand Up @@ -129,21 +128,18 @@ impl<'a, K: Eq + Hash, V, S: BuildHasher> VacantEntry<'a, K, V, S> {

pub fn insert(mut self, value: V) -> RefMut<'a, K, V, S> {
unsafe {
let c: K = ptr::read(&self.key);
// Use ManuallyDrop here instead of ptr::read because it doesn't cause a double drop if we unexpectedly panic.
let c = ManuallyDrop::new(ptr::read(&self.key));

self.shard.insert(self.key, SharedValue::new(value));

let (k, v) = self.shard.get_key_value(&c).unwrap();
let (k, v) = self.shard.get_key_value_mut(&*c).unwrap();

let k = util::change_lifetime_const(k);
let k: *const K = k;

let v = &mut *v.as_ptr();
let v: *mut V = v.get_mut();

let r = RefMut::new(self.shard, k, v);

mem::forget(c);

r
RefMut::new(self.shard, k, v)
}
}

Expand All @@ -155,10 +151,10 @@ impl<'a, K: Eq + Hash, V, S: BuildHasher> VacantEntry<'a, K, V, S> {
unsafe {
self.shard.insert(self.key.clone(), SharedValue::new(value));

let (k, v) = self.shard.get_key_value(&self.key).unwrap();
let (k, v) = self.shard.get_key_value_mut(&self.key).unwrap();

let kptr: *const K = k;
let vptr: *mut V = v.as_ptr();
let vptr: *mut V = v.get_mut();
OccupiedEntry::new(self.shard, self.key, (kptr, vptr))
}
}
Expand Down
22 changes: 6 additions & 16 deletions src/util.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
//! This module is full of hackery and dark magic.
//! Either spend a day fixing it and quietly submit a PR or don't mention it to anybody.
use core::cell::UnsafeCell;
use core::{mem, ptr};

pub const fn ptr_size_bits() -> usize {
Expand Down Expand Up @@ -45,16 +44,14 @@ pub unsafe fn change_lifetime_mut<'a, 'b, T>(x: &'a mut T) -> &'b mut T {
/// This type is meant to be an implementation detail, but must be exposed due to the `Dashmap::shards`
#[repr(transparent)]
pub struct SharedValue<T> {
value: UnsafeCell<T>,
value: T,
}

impl<T: Clone> Clone for SharedValue<T> {
fn clone(&self) -> Self {
let inner = self.get().clone();

Self {
value: UnsafeCell::new(inner),
}
Self { value: inner }
}
}

Expand All @@ -65,29 +62,22 @@ unsafe impl<T: Sync> Sync for SharedValue<T> {}
impl<T> SharedValue<T> {
/// Create a new `SharedValue<T>`
pub const fn new(value: T) -> Self {
Self {
value: UnsafeCell::new(value),
}
Self { value }
}

/// Get a shared reference to `T`
pub fn get(&self) -> &T {
unsafe { &*self.value.get() }
&self.value
}

/// Get an unique reference to `T`
pub fn get_mut(&mut self) -> &mut T {
unsafe { &mut *self.value.get() }
&mut self.value
}

/// Unwraps the value
pub fn into_inner(self) -> T {
self.value.into_inner()
}

/// Get a mutable raw pointer to the underlying value
pub(crate) fn as_ptr(&self) -> *mut T {
self.value.get()
self.value
}
}

Expand Down
Loading