Skip to content

Commit

Permalink
fix: fix soundness issue related to storing a reference in Ref and …
Browse files Browse the repository at this point in the history
…`RefMut`.

Fixes: zakarumych#2
  • Loading branch information
zicklag committed Aug 26, 2023
1 parent 42cf408 commit 134b76a
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 73 deletions.
71 changes: 43 additions & 28 deletions src/refs/immutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use core::{
fmt::{self, Debug, Display},
hash::{Hash, Hasher},
ops::{Deref, RangeBounds},
ptr::NonNull,
};

use crate::borrow::AtomicBorrow;
Expand All @@ -19,10 +20,14 @@ use crate::borrow::AtomicBorrow;
/// [`AtomicCell`]: struct.AtomicCell.html
/// [`&T`]: https://doc.rust-lang.org/core/primitive.reference.html
pub struct Ref<'a, T: ?Sized> {
value: &'a T,
value: NonNull<T>,
borrow: AtomicBorrow<'a>,
}

// SAFETY: `Ref<'_, T> acts as a reference. `AtomicBorrowR` is a reference to an atomic.
unsafe impl<'b, T: ?Sized + 'b> Sync for Ref<'b, T> where for<'a> &'a T: Sync {}
unsafe impl<'b, T: ?Sized + 'b> Send for Ref<'b, T> where for<'a> &'a T: Send {}

impl<'a, T> Clone for Ref<'a, T>
where
T: ?Sized,
Expand Down Expand Up @@ -50,7 +55,8 @@ where

#[inline(always)]
fn deref(&self) -> &T {
self.value
// SAFETY: We hold a shared borrow lock.
unsafe { self.value.as_ref() }
}
}

Expand All @@ -60,7 +66,7 @@ where
{
#[inline(always)]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
Debug::fmt(self.value, f)
<T as Debug>::fmt(self, f)
}
}

Expand All @@ -70,7 +76,7 @@ where
{
#[inline(always)]
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
Display::fmt(self.value, f)
<T as Display>::fmt(self, f)
}
}

Expand All @@ -80,7 +86,7 @@ where
{
#[inline(always)]
fn eq(&self, other: &U) -> bool {
PartialEq::eq(self.value, other)
<T as PartialEq<U>>::eq(self, other)
}
}

Expand All @@ -90,7 +96,7 @@ where
{
#[inline(always)]
fn partial_cmp(&self, other: &U) -> Option<Ordering> {
PartialOrd::partial_cmp(self.value, other)
<T as PartialOrd<U>>::partial_cmp(self, other)
}
}

Expand All @@ -103,14 +109,14 @@ where
where
H: Hasher,
{
Hash::hash(self.value, state)
<T as Hash>::hash(self, state)
}
}

impl<'a, T> Borrow<T> for Ref<'a, T> {
#[inline(always)]
fn borrow(&self) -> &T {
self.value
self
}
}

Expand All @@ -120,7 +126,7 @@ where
{
#[inline(always)]
fn as_ref(&self) -> &U {
self.value.as_ref()
<T as AsRef<U>>::as_ref(self)
}
}

Expand All @@ -145,7 +151,7 @@ where
#[inline]
pub fn new(r: &'a T) -> Self {
Ref {
value: r,
value: NonNull::from(r),
borrow: AtomicBorrow::dummy(),
}
}
Expand All @@ -171,16 +177,20 @@ where
/// ```
#[inline]
pub fn with_borrow(r: &'a T, borrow: AtomicBorrow<'a>) -> Self {
Ref { value: r, borrow }
Ref {
value: NonNull::from(r),
borrow,
}
}

/// Splits wrapper into two parts.
/// One is reference to the value
/// and the other is [`AtomicBorrow`] that guards it from being borrowed mutably.
/// Splits wrapper into two parts. One is reference to the value and the other is
/// [`AtomicBorrow`] that guards it from being borrowed mutably.
///
/// # Safety
///
/// User must ensure reference is not used after [`AtomicBorrow`] is dropped.
/// User must ensure [`NonNull`] is not dereferenced after [`AtomicBorrow`] is dropped.
///
/// Also, the [`NonNull<T>`] that is returned is still only valid for reads, not writes.
///
/// # Examples
///
Expand All @@ -192,7 +202,7 @@ where
///
/// unsafe {
/// let (r, borrow) = Ref::into_split(r);
/// assert_eq!(*r, 42);
/// assert_eq!(*r.as_ref(), 42);
///
/// assert!(cell.try_borrow().is_some(), "Must be able to borrow immutably");
/// assert!(cell.try_borrow_mut().is_none(), "Must not be able to borrow mutably yet");
Expand All @@ -201,7 +211,7 @@ where
/// }
/// ```
#[inline]
pub unsafe fn into_split(r: Ref<'a, T>) -> (&'a T, AtomicBorrow<'a>) {
pub fn into_split(r: Ref<'a, T>) -> (NonNull<T>, AtomicBorrow<'a>) {
(r.value, r.borrow)
}

Expand Down Expand Up @@ -231,7 +241,7 @@ where
U: ?Sized,
{
Ref {
value: f(r.value),
value: NonNull::from(f(&*r)),
borrow: r.borrow,
}
}
Expand Down Expand Up @@ -259,9 +269,9 @@ where
where
F: FnOnce(&T) -> Option<&U>,
{
match f(r.value) {
match f(&*r) {
Some(value) => Ok(Ref {
value,
value: NonNull::from(value),
borrow: r.borrow,
}),
None => Err(r),
Expand Down Expand Up @@ -297,15 +307,16 @@ where
let borrow_u = r.borrow.clone();
let borrow_v = r.borrow;

let (u, v) = f(r.value);
// SAFETY: we have a shared reference lock on the pointer.
let (u, v) = f(unsafe { r.value.as_ref() });

(
Ref {
value: u,
value: NonNull::from(u),
borrow: borrow_u,
},
Ref {
value: v,
value: NonNull::from(v),
borrow: borrow_v,
},
)
Expand Down Expand Up @@ -337,7 +348,8 @@ where
/// ```
pub fn leak(r: Ref<'a, T>) -> &'a T {
core::mem::forget(r.borrow);
r.value
// SAFETY: we have a shared reference lock on the pointer.
unsafe { r.value.as_ref() }
}

/// Converts reference and returns result wrapped in the [`Ref`].
Expand Down Expand Up @@ -366,7 +378,7 @@ where
T: AsRef<U>,
{
Ref {
value: r.value.as_ref(),
value: NonNull::from(<T as AsRef<U>>::as_ref(&r)),
borrow: r.borrow,
}
}
Expand Down Expand Up @@ -396,7 +408,8 @@ where
T: Deref,
{
Ref {
value: &r.value,
value: NonNull::from(<T as Deref>::deref(&*r)),

borrow: r.borrow,
}
}
Expand Down Expand Up @@ -432,7 +445,7 @@ impl<'a, T> Ref<'a, Option<T>> {
#[inline]
pub fn transpose(r: Ref<'a, Option<T>>) -> Option<Ref<'a, T>> {
Some(Ref {
value: r.value.as_ref()?,
value: r.as_ref().map(NonNull::from)?,
borrow: r.borrow,
})
}
Expand Down Expand Up @@ -464,8 +477,10 @@ impl<'a, T> Ref<'a, [T]> {
R: RangeBounds<usize>,
{
let bounds = (range.start_bound().cloned(), range.end_bound().cloned());
let slice = &*r;
let slice = &slice[bounds];
Ref {
value: &r.value[bounds],
value: NonNull::from(slice),
borrow: r.borrow,
}
}
Expand Down
Loading

0 comments on commit 134b76a

Please sign in to comment.