diff --git a/src/ibr.rs b/src/smr.rs
similarity index 66%
rename from src/ibr.rs
rename to src/smr.rs
index 7d88b7a..0a972c5 100644
--- a/src/ibr.rs
+++ b/src/smr.rs
@@ -4,18 +4,17 @@
 use std::collections::VecDeque;
 use std::mem::zeroed;
 use std::ptr::NonNull;
 use std::sync::atomic::Ordering::{Relaxed, SeqCst};
-use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize};
+use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicU64};
 
 use crate::utils::{Stack, ULL};
 
-const HELP_FLAG: usize = usize::MAX;
 const SLOTS_PER_NODE: usize = 8;
 
 pub struct Reclaimer {
     slots: ULL,
 
-    // the global epoch value.
-    epoch: AtomicUsize,
+    // the global era value.
+    era: AtomicU64,
 
     // limbo lists may be transferred here on drop.
     drop_cache: Stack<Vec<RetiredFn>>,
@@ -25,7 +24,7 @@ impl Reclaimer {
     pub const fn new() -> Self {
         Self {
             slots: unsafe { zeroed() },
-            epoch: AtomicUsize::new(1),
+            era: AtomicU64::new(1),
             drop_cache: Stack::new(),
         }
     }
@@ -65,18 +64,11 @@ impl Reclaimer {
             intervals: RefCell::default(),
         }
     }
-    pub fn increment_epoch(&self) {
-        let new_epoch = self.epoch.load(SeqCst) + 1;
-        for slot in self.slots.into_iter() {
-            if slot.end_epoch.load(SeqCst) == HELP_FLAG {
-                _ = slot
-                    .end_epoch
-                    .compare_exchange(HELP_FLAG, new_epoch, SeqCst, Relaxed)
-            }
-        }
-        _ = self
-            .epoch
-            .compare_exchange(new_epoch - 1, new_epoch, SeqCst, Relaxed)
+    pub fn increment_era(&self) {
+        self.era.fetch_add(1, SeqCst);
+    }
+    pub fn load_era(&self) -> u64 {
+        self.era.load(SeqCst)
+    }
 }
 
@@ -127,8 +119,8 @@ impl Default for Reclaimer {
     repr(align(64))
 )]
 struct Slot {
-    start_epoch: AtomicUsize,
-    end_epoch: AtomicUsize,
+    start_era: AtomicU64,
+    end_era: AtomicU64,
     is_claimed: AtomicBool,
 }
 
@@ -149,78 +141,57 @@ pub struct ThreadContext<'a> {
     cleanup_freq: usize,
     cleanup_counter: Cell<usize>,
 
-    // a monotonically increasing queue consisting of (epoch, count) tuples.
-    counts: RefCell<VecDeque<(usize, usize)>>,
+    // a monotonically increasing queue consisting of (era, count) tuples.
+    counts: RefCell<VecDeque<(u64, u64)>>,
 
     // a reusable Vec for storing hazardous intervals when scanning slots.
-    intervals: RefCell<Vec<(usize, usize)>>,
+    intervals: RefCell<Vec<(u64, u64)>>,
 }
 
 impl<'a> ThreadContext<'a> {
+    // loads the pointer from `src` and protects it in one step; returns None
+    // if `src` is null.
+    pub fn load<T>(&self, src: &AtomicPtr<T>) -> Option<Guard<T>> {
+        self.protect(src, NonNull::new(src.load(SeqCst))?)
+    }
     pub fn protect<T>(&self, src: &AtomicPtr<T>, ptr: NonNull<T>) -> Option<Guard<T>> {
         let mut counts = self.counts.borrow_mut();
-        let mut initial_end_epoch = 0;
-
-        let mut epoch = self.reclaimer.epoch.load(SeqCst);
+        let mut initial_end_era = 0;
+        let mut era = self.reclaimer.era.load(SeqCst);
         if let Some(back) = counts.back_mut() {
-            initial_end_epoch = back.0;
-            if initial_end_epoch >= epoch {
-                // the current epoch was already protected by a previous call to this method.
-                // simply increment the count and return.
+            initial_end_era = back.0;
+            if initial_end_era == era {
+                // the current era was already protected by a previous call to this method.
+                // simply increment the count of the last protected era.
                 back.1 += 1;
                 return Some(Guard {
                     ctx: self,
-                    epoch: initial_end_epoch,
+                    era,
                     ptr,
                 });
             }
         }
-
-        self.slot.end_epoch.store(epoch, SeqCst);
-        let Some(ptr) = NonNull::new(src.load(SeqCst)) else {
-            // null pointers don't need protection; reset end_epoch to what it was before.
-            self.slot.end_epoch.store(initial_end_epoch, SeqCst);
-            return None;
-        };
-        if epoch == self.reclaimer.epoch.load(SeqCst) {
-            counts.push_back((epoch, 1));
-            if initial_end_epoch == 0 {
-                // this is our first reservation, so start_epoch should also be updated.
-                self.slot.start_epoch.store(epoch, Relaxed);
+        // publish the era we intend to protect, then re-read the pointer and
+        // the global era; the loaded pointer is only safely protected if the
+        // era did not change in between.
+        self.slot.end_era.store(era, SeqCst);
+        while let Some(ptr) = NonNull::new(src.load(SeqCst)) {
+            let next_era = self.reclaimer.era.load(SeqCst);
+            if era == next_era {
+                counts.push_back((era, 1));
+                if counts.len() == 1 {
+                    // this is our first reservation, so start_era must also be updated.
+                    self.slot.start_era.store(era, SeqCst);
+                }
+                return Some(Guard {
+                    ctx: self,
+                    era,
+                    ptr,
+                });
             }
-            return Some(Guard {
-                ctx: self,
-                epoch,
-                ptr,
-            });
+            era = next_era;
+            self.slot.end_era.store(era, SeqCst);
         }
-
-        // the global epoch changed; fall back to the slow path.
-        self.slot.end_epoch.store(HELP_FLAG, SeqCst);
-
-        let Some(ptr) = NonNull::new(src.load(SeqCst)) else {
-            self.slot.end_epoch.store(initial_end_epoch, SeqCst);
-            return None;
-        };
-        epoch = self.reclaimer.epoch.load(SeqCst);
-        if let Err(actual) = self
-            .slot
-            .end_epoch
-            .compare_exchange(HELP_FLAG, epoch, SeqCst, SeqCst)
-        {
-            epoch = actual;
-        }
-        counts.push_back((epoch, 1));
-        if initial_end_epoch == 0 {
-            self.slot.start_epoch.store(epoch, Relaxed);
-        }
-        Some(Guard {
-            ctx: self,
-            epoch,
-            ptr,
-        })
+        // null pointers don't need protection; reset end_era to what it was before.
+        self.slot.end_era.store(initial_end_era, SeqCst);
+        None
     }
-    pub fn retire(&self, ptr: *mut u8, f: fn(*mut u8)) {
+    // safety: callers must guarantee that `ptr` can no longer be reached by
+    // threads other than the caller, that `f` is a valid deallocator for
+    // `ptr`, and that `birth_era` is the global era observed when `ptr` was
+    // allocated.
+    pub unsafe fn retire(&self, ptr: *mut u8, f: fn(*mut u8), birth_era: u64) {
         if self.cleanup_freq == 0 {
             panic!("cannot retire using this context: cleanup_freq is 0.")
         }
@@ -229,11 +200,11 @@ impl<'a> ThreadContext<'a> {
         if self.cleanup_counter.get() == 0 {
             self.scan_and_cleanup();
         }
-        let epoch_retired = self.reclaimer.epoch.load(SeqCst);
+        let retire_era = self.reclaimer.era.load(SeqCst);
         self.limbo_list.borrow_mut().push(RetiredFn {
             ptr,
             f,
-            epoch_retired,
+            span: (birth_era, retire_era),
         });
     }
 
@@ -243,24 +214,14 @@
 
         // scan the global array of reservations.
         for slot in self.reclaimer.slots.into_iter() {
-            let mut end = slot.end_epoch.load(SeqCst);
+            let end = slot.end_era.load(SeqCst);
             if end == 0 {
                 // this thread has no reservations.
                 continue;
             }
-            if end == HELP_FLAG {
-                // this thread has requested help.
-                end = self.reclaimer.epoch.load(SeqCst);
-                if let Err(actual) = slot
-                    .end_epoch
-                    .compare_exchange(HELP_FLAG, end, SeqCst, SeqCst)
-                {
-                    end = actual;
-                }
-            }
-            let mut start = slot.start_epoch.load(SeqCst);
+            let mut start = slot.start_era.load(SeqCst);
             if start == 0 {
-                // this slot has one reservation, defined by end_epoch.
+                // this slot has one reservation, defined by end_era.
                 start = end;
             }
             intervals.push((start, end));
@@ -286,10 +247,9 @@
 
         // go through the limbo list and delete the entries without conflicts.
         let mut i = 0;
         while i < limbo_list.len() {
-            let epoch = limbo_list[i].epoch_retired;
             let has_conflict = intervals
                 .iter()
-                .any(|(start, end)| *start <= epoch && epoch <= *end);
+                .any(|x| intervals_overlap(limbo_list[i].span, *x));
             if has_conflict {
                 i += 1;
             } else {
@@ -300,6 +260,10 @@
         }
     }
 }
 
+fn intervals_overlap(a: (u64, u64), b: (u64, u64)) -> bool {
+    // both intervals are closed, so a shared endpoint counts as overlap.
+    a.0 <= b.1 && b.0 <= a.1
+}
+
 impl<'a> Drop for ThreadContext<'a> {
     fn drop(&mut self) {
         debug_assert!(self.counts.borrow_mut().is_empty());
@@ -334,7 +298,7 @@
 
 pub struct Guard<'a, 'b: 'a, T> {
     ctx: &'b ThreadContext<'a>,
-    epoch: usize,
+    era: u64,
     ptr: NonNull<T>,
 }
 
@@ -349,59 +313,51 @@ impl<'a, 'b: 'a, T> Drop for Guard<'a, 'b, T> {
         let mut counts = self.ctx.counts.borrow_mut();
 
         // decrement the count.
-        let pair = counts.iter_mut().find(|(e, _)| *e == self.epoch).unwrap();
+        let pair = counts.iter_mut().find(|(e, _)| *e == self.era).unwrap();
         pair.1 -= 1;
 
-        let mut start_epoch_changed = false;
-        let mut end_epoch_changed = false;
+        let mut start_era_changed = false;
+        let mut end_era_changed = false;
 
-        // pop from the front of the queue to shrink the interval.
+        // pop from the front and back of the queue to shrink the interval.
         while let Some((_, count)) = counts.front() {
             if *count > 0 {
                 break;
             }
             counts.pop_front();
-            start_epoch_changed = true;
+            start_era_changed = true;
         }
         while let Some((_, count)) = counts.back() {
             if *count > 0 {
                 break;
            }
             counts.pop_back();
-            end_epoch_changed = true;
+            end_era_changed = true;
         }
 
-        // publish our interval update.
+        // update our interval.
         if counts.is_empty() {
             // we have no more reservations; zero out our interval.
-            self.ctx.slot.end_epoch.store(0, SeqCst);
-            self.ctx.slot.start_epoch.store(0, SeqCst);
-        } else if start_epoch_changed {
+            self.ctx.slot.end_era.store(0, SeqCst);
+            self.ctx.slot.start_era.store(0, SeqCst);
+        } else if start_era_changed {
             self.ctx
                 .slot
-                .start_epoch
+                .start_era
                 .store(counts.front().unwrap().0, SeqCst);
-        } else if end_epoch_changed {
+        } else if end_era_changed {
             self.ctx
                 .slot
-                .end_epoch
+                .end_era
                 .store(counts.back().unwrap().0, SeqCst);
         }
-        debug_assert_eq!(
-            self.ctx.slot.start_epoch.load(Relaxed),
-            counts.front().map_or(0, |x| x.0)
-        );
-        debug_assert_eq!(
-            self.ctx.slot.end_epoch.load(Relaxed),
-            counts.back().map_or(0, |x| x.0)
-        );
     }
 }
 
 struct RetiredFn {
     ptr: *mut u8,
     f: fn(*mut u8),
-    epoch_retired: usize,
+    // the closed (birth_era, retire_era) span during which ptr was reachable.
+    span: (u64, u64),
 }
 
 impl Drop for RetiredFn {
@@ -413,12 +369,16 @@
 #[cfg(test)]
 mod tests {
     use std::mem::zeroed;
-    use std::ptr::NonNull;
     use std::sync::atomic::Ordering::{Relaxed, SeqCst};
     use std::sync::atomic::{AtomicPtr, AtomicUsize};
     use std::thread;
 
-    use crate::ibr::Reclaimer;
+    use crate::smr::Reclaimer;
+
+    struct Obj<T> {
+        val: T,
+        birth_era: u64,
+    }
 
     #[test]
     fn test_protect_retire() {
@@ -429,25 +389,36 @@
 
         let counts: [AtomicUsize; MAX_VAL] = unsafe { zeroed() };
 
-        let x = AtomicPtr::new(Box::into_raw(Box::new(0)));
+        let x = AtomicPtr::new(Box::into_raw(Box::new(Obj {
+            val: 0,
+            birth_era: r.load_era(),
+        })));
 
         thread::scope(|scope| {
             for _ in 0..THREADS_COUNT {
                 scope.spawn(|| {
                     let ctx = r.join(1);
                     for val in 0..MAX_VAL {
-                        if let Some(p) = NonNull::new(x.load(SeqCst)) {
-                            if let Some(guard) = ctx.protect(&x, p) {
-                                unsafe {
-                                    counts[*guard.as_ptr()].fetch_add(1, Relaxed);
-                                }
+                        if let Some(guard) = ctx.load(&x) {
+                            unsafe {
+                                counts[(*guard.as_ptr()).val].fetch_add(1, Relaxed);
                             }
                         }
-                        let swapped = x.swap(Box::into_raw(Box::new(val)), SeqCst);
+                        let obj = Obj {
+                            val,
+                            birth_era: r.load_era(),
+                        };
+                        let swapped = x.swap(Box::into_raw(Box::new(obj)), SeqCst);
                         if !swapped.is_null() {
-                            ctx.retire(swapped as *mut u8, dealloc_boxed_ptr::<usize>);
+                            unsafe {
+                                ctx.retire(
+                                    swapped as *mut u8,
+                                    dealloc_boxed_ptr::<Obj<usize>>,
+                                    (*swapped).birth_era,
+                                );
+                            }
                         }
-                        r.increment_epoch();
+                        r.increment_era();
                     }
                 });
             }
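Note on the reclamation condition: a retired node is freed only when its closed (birth_era, retire_era) span avoids every reader's closed [start_era, end_era] reservation, and sharing even a single era counts as a conflict. A small property test pinning that behavior down could be added to the tests module; the test below is hypothetical and not part of this change:

    #[test]
    fn test_intervals_overlap() {
        use super::intervals_overlap;
        // a shared endpoint is a conflict: the intervals are closed.
        assert!(intervals_overlap((1, 3), (3, 5)));
        assert!(intervals_overlap((3, 5), (1, 3)));
        // nested spans conflict.
        assert!(intervals_overlap((2, 4), (1, 5)));
        // disjoint spans don't, so such a node may be reclaimed.
        assert!(!intervals_overlap((1, 2), (3, 5)));
    }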
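The test passes `dealloc_boxed_ptr::<Obj<usize>>` as the `f: fn(*mut u8)` argument of `retire`. That helper is defined elsewhere in the file and is not shown in this diff; a minimal sketch consistent with its call sites, which always hand it a pointer produced by `Box::into_raw`, would be:

    fn dealloc_boxed_ptr<T>(p: *mut u8) {
        // safety: `p` originated from Box::into_raw::<T>, and RetiredFn's Drop
        // impl only calls this once no reservation overlaps the node's span.
        unsafe { drop(Box::from_raw(p as *mut T)) };
    }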