From 68912a18b4d1dee1fa4a027c6881e7095ed207ab Mon Sep 17 00:00:00 2001 From: Jaehwang Jung Date: Sun, 8 Jan 2023 00:59:09 +0900 Subject: [PATCH] optimize hp and fix hp DSs --- .gitignore | 14 ++ Cargo.toml | 10 ++ src/domain.rs | 25 ++++ src/hazard.rs | 304 +++++++++++++++++++++++++++++++++++++++ src/lib.rs | 56 ++++++++ src/retire.rs | 16 +++ src/tag.rs | 36 +++++ src/thread.rs | 127 +++++++++++++++++ tests/test.rs | 384 ++++++++++++++++++++++++++++++++++++++++++++++++++ 9 files changed, 972 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 src/domain.rs create mode 100644 src/hazard.rs create mode 100644 src/lib.rs create mode 100644 src/retire.rs create mode 100644 src/tag.rs create mode 100644 src/thread.rs create mode 100644 tests/test.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6985cf1 --- /dev/null +++ b/.gitignore @@ -0,0 +1,14 @@ +# Generated by Cargo +# will have compiled files and executables +debug/ +target/ + +# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries +# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html +Cargo.lock + +# These are backup files generated by rustfmt +**/*.rs.bk + +# MSVC Windows builds of rustc generate these, which store debugging information +*.pdb diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..c4e395b --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "hp_pp" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +membarrier = { git = "https://github.com/jeehoonkang/membarrier-rs.git", branch = "pebr" } +crossbeam-utils = "0.8.14" diff --git a/src/domain.rs b/src/domain.rs new file mode 100644 index 0000000..bdb5e68 --- /dev/null +++ b/src/domain.rs @@ -0,0 +1,25 @@ +use std::collections::HashSet; + +use crate::hazard::{ThreadRecord, ThreadRecords}; + +pub struct Domain { + pub(crate) threads: ThreadRecords, +} + +impl Domain { + pub const fn new() -> Self { + Self { + threads: ThreadRecords::new(), + } + } + pub fn acquire(&self) -> &ThreadRecord { + self.threads.acquire() + } + + pub fn collect_guarded_ptrs(&self) -> HashSet<*mut u8> { + self.threads + .iter() + .flat_map(|thread| thread.hazptrs.iter()) + .collect() + } +} diff --git a/src/hazard.rs b/src/hazard.rs new file mode 100644 index 0000000..d818fd8 --- /dev/null +++ b/src/hazard.rs @@ -0,0 +1,304 @@ +use core::cell::Cell; +use core::marker::PhantomData; +use core::ptr; +use core::sync::atomic::{AtomicBool, AtomicPtr, Ordering}; + +use crate::untagged; +use crate::DEFAULT_THREAD; + +pub struct HazardPointer<'domain> { + hazard: &'domain ThreadHazPtrRecord, + thread: &'domain ThreadRecord, + _marker: PhantomData<*mut ()>, // !Send + !Sync +} + +pub enum ProtectError { + Stopped, + Changed(*mut T), +} + +impl Default for HazardPointer<'static> { + fn default() -> Self { + DEFAULT_THREAD.with(|t| HazardPointer::new(t.hazards)) + } +} + +impl<'domain> HazardPointer<'domain> { + /// Creat a hazard pointer in the given thread + pub fn new(thread: &'domain ThreadRecord) -> Self { + Self { + hazard: thread.hazptrs.acquire(), + thread, + _marker: PhantomData, + } + } + + /// Protect the given address. 
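+ ///
+ /// A minimal sketch, assuming a free-standing `AtomicPtr<usize>` as the source, of
+ /// pairing `protect_raw` with [`Self::validate`] to close the race described below:
+ ///
+ /// ```
+ /// use core::sync::atomic::{AtomicPtr, Ordering};
+ /// use hp_pp::HazardPointer;
+ ///
+ /// let src = AtomicPtr::new(Box::into_raw(Box::new(0usize)));
+ /// let mut hp = HazardPointer::default();
+ /// let ptr = src.load(Ordering::Relaxed);
+ /// hp.protect_raw(ptr);
+ /// // Re-check `src` before trusting `ptr`; `Ok(())` means the protection holds.
+ /// let _ = HazardPointer::validate(ptr, &src);
+ /// ```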
+ /// + /// You will very rarely want to use this method, and should prefer the other protection + /// methods instead, as they guard against races between when the value of a shared pointer was + /// read and any changes to the shared pointer address. + /// + /// Note that protecting a given pointer only has an effect if any thread that may drop the + /// pointer does so through the same [`Domain`] as this hazard pointer is associated with. + /// + pub fn protect_raw(&mut self, ptr: *mut T) { + self.hazard.protect(ptr as *mut u8); + } + + /// Release the protection awarded by this hazard pointer, if any. + /// + /// If the hazard pointer was protecting an object, that object may now be reclaimed when + /// retired (assuming it isn't protected by any _other_ hazard pointers). + pub fn reset_protection(&mut self) { + self.hazard.reset(); + } + + /// Check if `src` still points to `pointer`. If not, returns the current value. + /// + /// For a pointer `p`, if "`src` still pointing to `pointer`" implies that `p` is not retired, + /// then `Ok(())` means that shields set to `p` are validated. + pub fn validate(pointer: *mut T, src: &AtomicPtr) -> Result<(), *mut T> { + membarrier::light_membarrier(); + // relaxed is ok thanks to the previous load (that created `pointer`) + the fence above + let new = src.load(Ordering::Relaxed); + if pointer as usize == new as usize { + Ok(()) + } else { + Err(new) + } + } + + /// Try protecting `pointer` obtained from `src`. If not, returns the current value. + /// + /// If "`src` still pointing to `pointer`" implies that `pointer` is not retired, then `Ok(())` + /// means that this shield is validated. + pub fn try_protect(&mut self, pointer: *mut T, src: &AtomicPtr) -> Result<(), *mut T> { + self.protect_raw(pointer); + Self::validate(pointer, src).map_err(|new| { + self.reset_protection(); + new + }) + } + + /// Get a protected pointer from `src`. + /// + /// See `try_protect()`. + pub fn protect(&mut self, src: &AtomicPtr) -> *mut T { + let mut pointer = src.load(Ordering::Relaxed); + while let Err(new) = self.try_protect(pointer, src) { + pointer = new; + } + pointer + } + + /// hp++ protection + pub fn try_protect_pp( + &mut self, + ptr: *mut T, + src: &S, + src_link: &F1, + check_stop: &F2, + ) -> Result<*mut T, ProtectError> + where + F1: Fn(&S) -> &AtomicPtr, + F2: Fn(&S) -> bool, + { + self.protect_raw(ptr); + membarrier::light_membarrier(); + if check_stop(src) { + return Err(ProtectError::Stopped); + } + let ptr_new = untagged(src_link(src).load(Ordering::Acquire)); + if ptr == ptr_new { + return Ok(ptr); + } + Err(ProtectError::Changed(ptr_new)) + } +} + +impl Drop for HazardPointer<'_> { + fn drop(&mut self) { + self.hazard.reset(); + self.thread.hazptrs.release(self.hazard); + } +} + +/// Push-only list of thread records +pub(crate) struct ThreadRecords { + head: AtomicPtr, +} + +pub struct ThreadRecord { + pub(crate) hazptrs: ThreadHazPtrRecords, + pub(crate) next: *mut ThreadRecord, + pub(crate) available: AtomicBool, +} + +/// Single-writer hazard pointer bag. +/// - push only +/// - efficient recycling +/// - No need to use CAS. +// TODO: This can be array, like Chase-Lev deque. 
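+ //
+ // Illustrative slot lifecycle (this is the pattern `HazardPointer` drives through
+ // `new`, `protect_raw`, and `Drop`; the local variable names are hypothetical):
+ //
+ //     let slot = thread.hazptrs.acquire(); // reuse a slot from `head_available`, or
+ //                                          // push a fresh record onto `head` (plain
+ //                                          // store; only the owner ever pushes)
+ //     slot.protect(ptr as *mut u8);        // announce the pointer to reclaimers
+ //     /* ... dereference ptr ... */
+ //     slot.reset();
+ //     thread.hazptrs.release(slot);        // prepend to the owner-local free list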
+pub(crate) struct ThreadHazPtrRecords { + head: AtomicPtr, + // this is cell because it's only used by the owning thread + head_available: Cell<*mut ThreadHazPtrRecord>, +} + +pub(crate) struct ThreadHazPtrRecord { + pub(crate) ptr: AtomicPtr, + pub(crate) next: *mut ThreadHazPtrRecord, + // this is cell because it's only used by the owning thread + pub(crate) available_next: Cell<*mut ThreadHazPtrRecord>, +} + +impl ThreadRecords { + pub(crate) const fn new() -> Self { + Self { + head: AtomicPtr::new(ptr::null_mut()), + } + } + + pub(crate) fn acquire(&self) -> &ThreadRecord { + if let Some(avail) = self.try_acquire_available() { + return avail; + } + self.acquire_new() + } + + fn try_acquire_available(&self) -> Option<&ThreadRecord> { + let mut cur = self.head.load(Ordering::Acquire); + while let Some(cur_ref) = unsafe { cur.as_ref() } { + if cur_ref.available.load(Ordering::Relaxed) + && cur_ref + .available + .compare_exchange(true, false, Ordering::Relaxed, Ordering::Relaxed) + .is_ok() + { + return Some(cur_ref); + } + cur = cur_ref.next; + } + None + } + + fn acquire_new(&self) -> &ThreadRecord { + let new = Box::leak(Box::new(ThreadRecord { + hazptrs: ThreadHazPtrRecords { + head: AtomicPtr::new(ptr::null_mut()), + head_available: Cell::new(ptr::null_mut()), + }, + next: ptr::null_mut(), + available: AtomicBool::new(false), + })); + + let mut head = self.head.load(Ordering::Relaxed); + loop { + new.next = head; + match self + .head + .compare_exchange(head, new, Ordering::Release, Ordering::Relaxed) + { + Ok(_) => return new, + Err(head_new) => head = head_new, + } + } + } + + pub(crate) fn release(&self, rec: &ThreadRecord) { + rec.available.store(true, Ordering::Release); + } + + pub(crate) fn iter(&self) -> ThreadRecordsIter<'_> { + ThreadRecordsIter { + cur: self.head.load(Ordering::Acquire).cast_const(), + _marker: PhantomData, + } + } +} + +pub(crate) struct ThreadRecordsIter<'domain> { + cur: *const ThreadRecord, + _marker: PhantomData<&'domain ThreadRecord>, +} + +impl<'domain> Iterator for ThreadRecordsIter<'domain> { + type Item = &'domain ThreadRecord; + + fn next(&mut self) -> Option { + if let Some(cur_ref) = unsafe { self.cur.as_ref() } { + self.cur = cur_ref.next; + Some(cur_ref) + } else { + None + } + } +} + +impl ThreadHazPtrRecords { + pub(crate) fn acquire(&self) -> &ThreadHazPtrRecord { + if let Some(avail) = self.try_acquire_available() { + return avail; + } + self.acquire_new() + } + + fn try_acquire_available(&self) -> Option<&ThreadHazPtrRecord> { + let head = self.head_available.get(); + let head_ref = unsafe { head.as_ref()? 
}; + let next = head_ref.available_next.get(); + self.head_available.set(next); + Some(head_ref) + } + + fn acquire_new(&self) -> &ThreadHazPtrRecord { + let head = self.head.load(Ordering::Relaxed); + let hazptr = Box::leak(Box::new(ThreadHazPtrRecord { + ptr: AtomicPtr::new(ptr::null_mut()), + next: head, + available_next: Cell::new(ptr::null_mut()), + })); + self.head.store(hazptr, Ordering::Release); + hazptr + } + + pub(crate) fn release(&self, rec: &ThreadHazPtrRecord) { + let avail = self.head_available.get(); + rec.available_next.set(avail); + self.head_available.set(rec as *const _ as *mut _); + } + + pub(crate) fn iter(&self) -> ThreadHazPtrRecordsIter { + ThreadHazPtrRecordsIter { + cur: self.head.load(Ordering::Acquire), + } + } +} + +pub(crate) struct ThreadHazPtrRecordsIter { + cur: *mut ThreadHazPtrRecord, +} + +impl Iterator for ThreadHazPtrRecordsIter { + type Item = *mut u8; + + fn next(&mut self) -> Option { + if let Some(cur_ref) = unsafe { self.cur.as_ref() } { + self.cur = cur_ref.next; + Some(cur_ref.ptr.load(Ordering::Acquire)) + } else { + None + } + } +} + +impl ThreadHazPtrRecord { + fn reset(&self) { + self.ptr.store(ptr::null_mut(), Ordering::Release); + } + + fn protect(&self, ptr: *mut u8) { + self.ptr.store(ptr, Ordering::Release); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..7d488cf --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,56 @@ +mod domain; +mod hazard; +mod retire; +mod tag; +mod thread; + +pub use hazard::HazardPointer; +pub use membarrier::light_membarrier; +pub use tag::*; + +use std::thread_local; + +use crate::domain::Domain; +use crate::thread::Thread; + +static DEFAULT_DOMAIN: Domain = Domain::new(); + +thread_local! { + static DEFAULT_THREAD: Thread<'static> = Thread::new(&DEFAULT_DOMAIN); +} + +/// Retire a pointer, in the thread-local retired pointer bag. +/// +/// # Safety +/// TODO +#[inline] +pub unsafe fn retire(ptr: *mut T) { + DEFAULT_THREAD.with(|t| t.retire(ptr)) +} + +/// Protects `links`, try unlinking `to_be_unlinked`, if successful, mark them as stopped and +/// retire them. +/// +/// # Safety +/// * The memory blocks in `to_be_unlinked` are no longer modified. +/// * TODO +pub unsafe fn try_unlink( + links: &[*mut T], + to_be_unlinked: &[*mut T], + do_unlink: F1, + set_stop: F2, +) -> bool +where + F1: FnOnce() -> bool, + F2: Fn(*mut T), +{ + DEFAULT_THREAD.with(|t| t.try_unlink(links, to_be_unlinked, do_unlink, set_stop)) +} + +/// Trigger reclamation +pub fn do_reclamation() { + DEFAULT_THREAD.with(|t| { + let mut reclaim = t.reclaim.borrow_mut(); + t.do_reclamation(&mut reclaim); + }) +} diff --git a/src/retire.rs b/src/retire.rs new file mode 100644 index 0000000..0b38911 --- /dev/null +++ b/src/retire.rs @@ -0,0 +1,16 @@ +pub(crate) struct Retired { + pub(crate) ptr: *mut u8, + pub(crate) deleter: unsafe fn(ptr: *mut u8), +} + +impl Retired { + pub(crate) fn new(ptr: *mut T) -> Self { + unsafe fn free(ptr: *mut u8) { + drop(Box::from_raw(ptr as *mut T)) + } + Self { + ptr: ptr as *mut u8, + deleter: free::, + } + } +} diff --git a/src/tag.rs b/src/tag.rs new file mode 100644 index 0000000..3abdd4b --- /dev/null +++ b/src/tag.rs @@ -0,0 +1,36 @@ +use core::mem; + +/// Returns a bitmask containing the unused least significant bits of an aligned pointer to `T`. 
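+ ///
+ /// For example, for a type with `align_of::<T>() == 8` the mask is `0b111`, which is
+ /// what lets the public helpers below pack a small tag into a pointer. A minimal
+ /// sketch using a hypothetical 8-byte-aligned type:
+ ///
+ /// ```
+ /// use hp_pp::{tag, tagged, untagged};
+ ///
+ /// #[repr(align(8))]
+ /// struct Node(u64);
+ ///
+ /// let p: *mut Node = core::ptr::null_mut();
+ /// let t = tagged(p, 0b101);
+ /// assert_eq!(tag(t), 0b101);
+ /// assert_eq!(untagged(t), p);
+ /// ```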
+#[inline] +const fn low_bits() -> usize { + (1 << mem::align_of::().trailing_zeros()) - 1 +} + +/// Returns the pointer with the given tag +#[inline] +pub fn tagged(ptr: *mut T, tag: usize) -> *mut T { + ((ptr as usize & !low_bits::()) | (tag & low_bits::())) as *mut T +} + +/// Decomposes a tagged pointer `data` into the pointer and the tag. +#[inline] +pub fn decompose_ptr(ptr: *mut T) -> (*mut T, usize) { + let ptr = ptr as usize; + let raw = (ptr & !low_bits::()) as *mut T; + let tag = ptr & low_bits::(); + (raw, tag) +} + +/// Extract the actual address out of a tagged pointer +#[inline] +pub fn untagged(ptr: *mut T) -> *mut T { + let ptr = ptr as usize; + (ptr & !low_bits::()) as *mut T +} + +/// Extracts the tag out of a tagged pointer +#[inline] +pub fn tag(ptr: *mut T) -> usize { + let ptr = ptr as usize; + ptr & low_bits::() +} diff --git a/src/thread.rs b/src/thread.rs new file mode 100644 index 0000000..9dbe485 --- /dev/null +++ b/src/thread.rs @@ -0,0 +1,127 @@ +use core::cell::RefCell; +use core::mem; + +use crate::domain::Domain; +use crate::hazard::ThreadRecord; +use crate::retire::Retired; +use crate::HazardPointer; + +pub struct Thread<'domain> { + pub(crate) domain: &'domain Domain, + pub(crate) hazards: &'domain ThreadRecord, + pub(crate) reclaim: RefCell>, +} + +pub(crate) struct Reclamation<'domain> { + // Used for HP++ + pub(crate) hps: Vec>, + pub(crate) retired: Vec, + pub(crate) collect_count: usize, +} + +impl<'domain> Thread<'domain> { + const COUNTS_BETWEEN_COLLECT: usize = 128; + + pub fn new(domain: &'domain Domain) -> Self { + let thread = domain.acquire(); + Self { + domain, + hazards: thread, + reclaim: RefCell::new(Reclamation { + hps: Vec::new(), + retired: Vec::new(), + collect_count: 0, + }), + } + } + + pub unsafe fn retire(&self, ptr: *mut T) { + let mut reclaim = self.reclaim.borrow_mut(); + self.retire_inner(&mut reclaim, ptr); + } + + // NOTE: T: Send not required because we reclaim only locally. + #[inline] + unsafe fn retire_inner(&self, reclaim: &mut Reclamation, ptr: *mut T) { + reclaim.retired.push(Retired::new(ptr)); + + let collect_count = reclaim.collect_count.wrapping_add(1); + reclaim.collect_count = collect_count; + if collect_count % Self::COUNTS_BETWEEN_COLLECT == 0 { + self.do_reclamation(reclaim); + } + } + + pub unsafe fn try_unlink( + &self, + links: &[*mut T], + to_be_unlinked: &[*mut T], + do_unlink: F1, + set_stop: F2, + ) -> bool + where + F1: FnOnce() -> bool, + F2: Fn(*mut T), + { + let mut reclaim = self.reclaim.borrow_mut(); + + let mut hps: Vec<_> = links + .iter() + .map(|&ptr| { + let mut hp = HazardPointer::new(self.hazards); + hp.protect_raw(ptr); + hp + }) + .collect(); + + let unlinked = do_unlink(); + if unlinked { + for &ptr in to_be_unlinked { + set_stop(ptr); + } + reclaim.hps.append(&mut hps); + for &ptr in to_be_unlinked { + unsafe { self.retire_inner(&mut reclaim, ptr) } + } + } else { + drop(hps); + } + unlinked + } + + #[inline] + pub(crate) fn do_reclamation(&self, reclaim: &mut Reclamation) { + membarrier::heavy(); + + // only for hp++, but this doesn't introduce big cost for plain hp. 
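+ // (`reclaim.hps` holds the protections on `links` deferred by `try_unlink`; they are
+ // released only here, after the heavy barrier above.)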
+ drop(mem::take(&mut reclaim.hps)); + + let guarded_ptrs = self.domain.collect_guarded_ptrs(); + reclaim.retired = reclaim + .retired + .iter() + .filter_map(|element| { + if guarded_ptrs.contains(&(element.ptr as *mut u8)) { + Some(Retired { + ptr: element.ptr, + deleter: element.deleter, + }) + } else { + unsafe { (element.deleter)(element.ptr) }; + None + } + }) + .collect(); + } +} + +impl<'domain> Drop for Thread<'domain> { + fn drop(&mut self) { + self.domain.threads.release(self.hazards); + let mut reclaim = self.reclaim.borrow_mut(); + while !reclaim.retired.is_empty() { + self.do_reclamation(&mut reclaim); + core::hint::spin_loop(); + } + } +} diff --git a/tests/test.rs b/tests/test.rs new file mode 100644 index 0000000..1336fd4 --- /dev/null +++ b/tests/test.rs @@ -0,0 +1,384 @@ +use core::sync::atomic::{AtomicPtr, Ordering::*}; +use std::thread::sleep; +use std::time::Duration; + +use hp_pp::*; +use queue::Queue; +use stack::Stack; +use std::thread::scope; + +#[test] +fn counter() { + const THREADS: usize = 4; + const ITER: usize = 1024 * 16; + + let count = AtomicPtr::new(Box::leak(Box::new(0usize))); + scope(|s| { + for _ in 0..THREADS { + s.spawn(|| { + for _ in 0..ITER { + let mut new = Box::new(0); + let mut hp = HazardPointer::default(); + loop { + let cur_ptr = hp.protect(&count); + let value = unsafe { *cur_ptr }; + *new = value + 1; + let new_ptr = Box::leak(new); + if count + .compare_exchange(cur_ptr, new_ptr, AcqRel, Acquire) + .is_ok() + { + unsafe { retire(cur_ptr) }; + break; + } else { + new = unsafe { Box::from_raw(new_ptr) }; + } + } + } + }); + } + }); + let cur = count.load(Acquire); + // exclusive access + assert_eq!(unsafe { *cur }, THREADS * ITER); + unsafe { retire(cur) }; +} + +// like `counter`, but trigger interesting interleaving using `sleep` and always call +// `do_reclamation`. 
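+ // It also exercises the manual `try_protect` retry loop: on `Err`, the caller re-reads
+ // the returned pointer instead of relying on `protect`.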
+#[test] +fn counter_sleep() { + const THREADS: usize = 4; + const ITER: usize = 1024 * 16; + + let count = AtomicPtr::new(Box::leak(Box::new(0usize))); + scope(|s| { + for _ in 0..THREADS { + s.spawn(|| { + for _ in 0..ITER { + let mut new = Box::new(0); + let mut hp = HazardPointer::default(); + loop { + let cur_ptr = { + let mut cur = count.load(Relaxed); + loop { + match hp.try_protect(cur, &count) { + Ok(_) => break cur, + Err(new) => { + sleep(Duration::from_micros(1)); + cur = new; + } + } + } + }; + sleep(Duration::from_micros(1)); + let value = unsafe { *cur_ptr }; + *new = value + 1; + let new_ptr = Box::leak(new); + if count + .compare_exchange(cur_ptr, new_ptr, AcqRel, Acquire) + .is_ok() + { + unsafe { retire(cur_ptr) }; + do_reclamation(); + break; + } else { + new = unsafe { Box::from_raw(new_ptr) }; + } + } + } + }); + } + }); + let cur = count.load(Acquire); + // exclusive access + assert_eq!(unsafe { *cur }, THREADS * ITER); + unsafe { retire(cur) }; +} + +#[test] +fn stack() { + const THREADS: usize = 8; + const ITER: usize = 1024 * 16; + + let stack = Stack::default(); + scope(|s| { + for _ in 0..THREADS { + s.spawn(|| { + for i in 0..ITER { + stack.push(i); + assert!(stack.try_pop().is_some()); + do_reclamation(); + } + }); + } + }); + assert!(stack.try_pop().is_none()); +} + +#[test] +fn queue() { + const THREADS: usize = 8; + const ITER: usize = 1024 * 32; + + let queue = Queue::default(); + scope(|s| { + for _ in 0..THREADS { + s.spawn(|| { + for i in 0..ITER { + queue.push(i); + assert!(queue.try_pop().is_some()); + do_reclamation(); + } + }); + } + }); +} + +#[test] +fn stack_queue() { + const THREADS: usize = 8; + const ITER: usize = 1024 * 16; + + let stack = Stack::default(); + let queue = Queue::default(); + scope(|s| { + for _ in 0..THREADS { + s.spawn(|| { + for i in 0..ITER { + stack.push(i); + queue.push(i); + stack.try_pop(); + queue.try_pop(); + do_reclamation(); + } + }); + } + }); + assert!(stack.try_pop().is_none()); +} + +mod stack { + use core::mem::ManuallyDrop; + use core::ptr; + use core::sync::atomic::{AtomicPtr, Ordering::*}; + + use hp_pp::*; + + /// Treiber's lock-free stack. 
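+ ///
+ /// `try_pop` protects the head with a `HazardPointer` before dereferencing it, and
+ /// passes popped nodes to `hp_pp::retire`, so a node is never freed while another
+ /// popper may still be reading it.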
+ #[derive(Debug)] + pub struct Stack { + head: AtomicPtr>, + } + + #[derive(Debug)] + struct Node { + data: ManuallyDrop, + next: *mut Node, + } + + unsafe impl Send for Node {} + unsafe impl Sync for Node {} + + impl Default for Stack { + fn default() -> Self { + Stack { + head: AtomicPtr::new(ptr::null_mut()), + } + } + } + + impl Stack { + pub fn push(&self, t: T) { + let new = Box::leak(Box::new(Node { + data: ManuallyDrop::new(t), + next: ptr::null_mut(), + })); + + loop { + let head = self.head.load(Relaxed); + new.next = head; + + if self + .head + .compare_exchange(head, new, Release, Relaxed) + .is_ok() + { + break; + } + } + } + + pub fn try_pop(&self) -> Option { + let mut hp = HazardPointer::default(); + loop { + let head_ptr = hp.protect(&self.head); + let head_ref = unsafe { head_ptr.as_ref() }?; + + if self + .head + .compare_exchange(head_ptr, head_ref.next, Relaxed, Relaxed) + .is_ok() + { + let data = unsafe { ManuallyDrop::take(&mut (*head_ptr).data) }; + unsafe { retire(head_ptr) }; + return Some(data); + } + } + } + } + + impl Drop for Stack { + fn drop(&mut self) { + let mut curr = *self.head.get_mut(); + while !curr.is_null() { + let curr_ref = unsafe { Box::from_raw(curr) }; + drop(ManuallyDrop::into_inner(curr_ref.data)); + curr = curr_ref.next; + } + } + } +} + +mod queue { + use core::mem::MaybeUninit; + use core::ptr; + use core::sync::atomic::{AtomicPtr, Ordering::*}; + + use hp_pp::*; + + /// Michael-Scott queue. + #[derive(Debug)] + pub struct Queue { + head: AtomicPtr>, + tail: AtomicPtr>, + } + + #[derive(Debug)] + struct Node { + data: MaybeUninit, + next: AtomicPtr>, + } + + unsafe impl Sync for Queue {} + unsafe impl Send for Queue {} + + impl Default for Queue { + fn default() -> Self { + let q = Self { + head: AtomicPtr::new(ptr::null_mut()), + tail: AtomicPtr::new(ptr::null_mut()), + }; + let sentinel = Box::leak(Box::new(Node { + data: MaybeUninit::uninit(), + next: AtomicPtr::new(ptr::null_mut()), + })); + q.head.store(sentinel, Relaxed); + q.tail.store(sentinel, Relaxed); + q + } + } + + impl Queue { + pub fn push(&self, t: T) { + let new = Box::leak(Box::new(Node { + data: MaybeUninit::new(t), + next: AtomicPtr::new(ptr::null_mut()), + })); + let mut hp = HazardPointer::default(); + + loop { + // We push onto the tail, so we'll start optimistically by looking there first. + let tail = hp.protect(&self.tail); + // SAFETY + // 1. queue's `tail` is always valid as it will be CASed with valid nodes only. + // 2. `tail` is protected & validated. + let tail_ref = unsafe { tail.as_ref().unwrap() }; + + let next = tail_ref.next.load(Acquire); + if !next.is_null() { + let _ = self.tail.compare_exchange(tail, next, Release, Relaxed); + continue; + } + + if tail_ref + .next + .compare_exchange(ptr::null_mut(), new, Release, Relaxed) + .is_ok() + { + let _ = self.tail.compare_exchange(tail, new, Release, Relaxed); + break; + } + } + } + + /// Attempts to dequeue from the front. + /// + /// Returns `None` if the queue is empty. + pub fn try_pop(&self) -> Option { + let mut head_hp = HazardPointer::default(); + let mut next_hp = HazardPointer::default(); + let mut head = self.head.load(Acquire); + loop { + if let Err(new) = head_hp.try_protect(head, &self.head) { + head = new; + continue; + } + // SAFETY: + // 1. queue's `head` is always valid as it will be CASed with valid nodes only. + // 2. `head` is protected & validated. 
+ let head_ref = unsafe { &*head }; + + let next = head_ref.next.load(Acquire); + if next.is_null() { + return None; + } + next_hp.protect_raw(next); + let next_ref = match HazardPointer::validate(head, &self.head) { + Ok(_) => { + // SAFETY: + // 1. If `next` was not null, then it must be a valid node that another + // thread has `push()`ed. + // 2. Validation: If `head` is not retired, then `next` is not retired. So + // re-validating `head` also validates `next. + unsafe { &*next } + } + Err(new) => { + next_hp.reset_protection(); + head = new; + continue; + } + }; + + // Moves `tail` if it's stale. Relaxed load is enough because if tail == head, then + // the messages for that node are already acquired. + let tail = self.tail.load(Relaxed); + if tail == head { + let _ = self.tail.compare_exchange(tail, next, Release, Relaxed); + } + + if self + .head + .compare_exchange(head, next, Release, Relaxed) + .is_ok() + { + let result = unsafe { next_ref.data.assume_init_read() }; + unsafe { retire(head) }; + return Some(result); + } + } + } + } + + impl Drop for Queue { + fn drop(&mut self) { + let sentinel = unsafe { Box::from_raw(*self.head.get_mut()) }; + let mut curr = sentinel.next.into_inner(); + while !curr.is_null() { + let curr_ref = unsafe { Box::from_raw(curr) }; + drop(unsafe { curr_ref.data.assume_init() }); + curr = curr_ref.next.load(Relaxed); + } + } + } +}
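+
+ // A rough sketch of the intended HP++ workflow for a linked structure. The `Node` type
+ // below is hypothetical (a `next: AtomicPtr<Node<T>>` link plus an `AtomicBool` field
+ // `stopped`); none of these names come from the crate itself, and the orderings follow
+ // the `use Ordering::*` import at the top of this file.
+ //
+ // Writer side, physically removing `curr` from between `prev_ref: &Node<T>` and `succ`:
+ //
+ //     let unlinked = unsafe {
+ //         hp_pp::try_unlink(
+ //             &[succ],              // `links`: stay protected until the next reclamation
+ //             &[curr],              // `to_be_unlinked`: marked as stopped, then retired
+ //             || prev_ref.next.compare_exchange(curr, succ, AcqRel, Acquire).is_ok(),
+ //             |node| unsafe { (*node).stopped.store(true, Release) },
+ //         )
+ //     };
+ //
+ // Reader side, protecting `curr` loaded from `prev_ref.next` with `try_protect_pp`:
+ //
+ //     match hp.try_protect_pp(curr, prev_ref, &|s| &s.next, &|s| s.stopped.load(Acquire)) {
+ //         Ok(ptr) => { /* safe to dereference `ptr` */ }
+ //         Err(ProtectError::Stopped) => { /* `prev_ref` was unlinked; restart */ }
+ //         Err(ProtectError::Changed(new)) => { /* retry with `new` */ }
+ //     }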