Skip to content

Commit

Permalink
use birth_era correctly
Browse files Browse the repository at this point in the history
  • Loading branch information
wyang5 committed Oct 10, 2024
1 parent 86294e5 commit fa50d79
Showing 1 changed file with 95 additions and 124 deletions.
219 changes: 95 additions & 124 deletions src/ibr.rs → src/smr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,17 @@ use std::collections::VecDeque;
use std::mem::zeroed;
use std::ptr::NonNull;
use std::sync::atomic::Ordering::{Relaxed, SeqCst};
use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicUsize};
use std::sync::atomic::{AtomicBool, AtomicPtr, AtomicU64};

use crate::utils::{Stack, ULL};

const HELP_FLAG: usize = usize::MAX;
const SLOTS_PER_NODE: usize = 8;

pub struct Reclaimer {
slots: ULL<Slot, SLOTS_PER_NODE>,

// the global epoch value.
epoch: AtomicUsize,
// the global era value.
era: AtomicU64,

// limbo lists may be transferred here on drop.
drop_cache: Stack<Vec<RetiredFn>>,
Expand All @@ -25,7 +24,7 @@ impl Reclaimer {
pub const fn new() -> Self {
Self {
slots: unsafe { zeroed() },
epoch: AtomicUsize::new(1),
era: AtomicU64::new(1),
drop_cache: Stack::new(),
}
}
Expand Down Expand Up @@ -65,18 +64,11 @@ impl Reclaimer {
intervals: RefCell::default(),
}
}
pub fn increment_epoch(&self) {
let new_epoch = self.epoch.load(SeqCst) + 1;
for slot in self.slots.into_iter() {
if slot.end_epoch.load(SeqCst) == HELP_FLAG {
_ = slot
.end_epoch
.compare_exchange(HELP_FLAG, new_epoch, SeqCst, Relaxed)
}
}
_ = self
.epoch
.compare_exchange(new_epoch - 1, new_epoch, SeqCst, Relaxed)
pub fn increment_era(&self) {
self.era.fetch_add(1, SeqCst);
}
pub fn load_era(&self) -> u64 {
self.era.load(SeqCst)
}
}

Expand Down Expand Up @@ -127,8 +119,8 @@ impl Default for Reclaimer {
repr(align(64))
)]
struct Slot {
start_epoch: AtomicUsize,
end_epoch: AtomicUsize,
start_era: AtomicU64,
end_era: AtomicU64,
is_claimed: AtomicBool,
}

Expand All @@ -149,78 +141,57 @@ pub struct ThreadContext<'a> {
cleanup_freq: usize,
cleanup_counter: Cell<usize>,

// a monotonically increasing queue consisting of (epoch, count) tuples.
counts: RefCell<VecDeque<(usize, usize)>>,
// a monotonically increasing queue consisting of (era, count) tuples.
counts: RefCell<VecDeque<(u64, usize)>>,
// a reusable Vec for storing hazardous intervals when scanning slots.
intervals: RefCell<Vec<(usize, usize)>>,
intervals: RefCell<Vec<(u64, u64)>>,
}

impl<'a> ThreadContext<'a> {
pub fn load<T>(&self, src: &AtomicPtr<T>) -> Option<Guard<'_, 'a, T>> {
self.protect(&src, NonNull::new(src.load(SeqCst))?)
}
pub fn protect<T>(&self, src: &AtomicPtr<T>, ptr: NonNull<T>) -> Option<Guard<'_, 'a, T>> {
let mut counts = self.counts.borrow_mut();
let mut initial_end_epoch = 0;

let mut epoch = self.reclaimer.epoch.load(SeqCst);
let mut initial_end_era = 0;
let mut era = self.reclaimer.era.load(SeqCst);
if let Some(back) = counts.back_mut() {
initial_end_epoch = back.0;
if initial_end_epoch >= epoch {
// the current epoch was already protected by a previous call to this method.
// simply increment the count and return.
initial_end_era = back.0;
if initial_end_era == era {
// the current era was already protected by a previous call to this method.
// simply increment the count of the last protected era.
back.1 += 1;
return Some(Guard {
ctx: self,
epoch: initial_end_epoch,
era,
ptr,
});
}
}

self.slot.end_epoch.store(epoch, SeqCst);
let Some(ptr) = NonNull::new(src.load(SeqCst)) else {
// null pointers don't need protection; reset end_epoch to what it was before.
self.slot.end_epoch.store(initial_end_epoch, SeqCst);
return None;
};
if epoch == self.reclaimer.epoch.load(SeqCst) {
counts.push_back((epoch, 1));
if initial_end_epoch == 0 {
// this is our first reservation, so start_epoch should also be updated.
self.slot.start_epoch.store(epoch, Relaxed);
self.slot.end_era.store(era, SeqCst);
while let Some(ptr) = NonNull::new(src.load(SeqCst)) {
let next_era = self.reclaimer.era.load(SeqCst);
if era == next_era {
counts.push_back((era, 1));
if counts.len() == 1 {
// this is our first reservation, so start_era must also be updated.
self.slot.start_era.store(era, SeqCst);
}
return Some(Guard {
ctx: self,
era,
ptr,
});
}
return Some(Guard {
ctx: self,
epoch,
ptr,
});
era = next_era;
self.slot.end_era.store(era, SeqCst);
}

// the global epoch changed; fall back to the slow path.
self.slot.end_epoch.store(HELP_FLAG, SeqCst);

let Some(ptr) = NonNull::new(src.load(SeqCst)) else {
self.slot.end_epoch.store(initial_end_epoch, SeqCst);
return None;
};
epoch = self.reclaimer.epoch.load(SeqCst);
if let Err(actual) = self
.slot
.end_epoch
.compare_exchange(HELP_FLAG, epoch, SeqCst, SeqCst)
{
epoch = actual;
}
counts.push_back((epoch, 1));
if initial_end_epoch == 0 {
self.slot.start_epoch.store(epoch, Relaxed);
}
Some(Guard {
ctx: self,
epoch,
ptr,
})
// null ptrs don't need protection; reset end_era to what it was before.
self.slot.end_era.store(initial_end_era, SeqCst);
None
}

pub fn retire(&self, ptr: *mut u8, f: fn(*mut u8)) {
pub unsafe fn retire(&self, ptr: *mut u8, f: fn(*mut u8), birth_era: u64) {
if self.cleanup_freq == 0 {
panic!("cannot retire using this context: cleanup_freq is 0.")
}
Expand All @@ -229,11 +200,11 @@ impl<'a> ThreadContext<'a> {
if self.cleanup_counter.get() == 0 {
self.scan_and_cleanup();
}
let epoch_retired = self.reclaimer.epoch.load(SeqCst);
let retire_era = self.reclaimer.era.load(SeqCst);
self.limbo_list.borrow_mut().push(RetiredFn {
ptr,
f,
epoch_retired,
span: (birth_era, retire_era),
});
}

Expand All @@ -243,24 +214,14 @@ impl<'a> ThreadContext<'a> {

// scan the global array of reservations.
for slot in self.reclaimer.slots.into_iter() {
let mut end = slot.end_epoch.load(SeqCst);
let end = slot.end_era.load(SeqCst);
if end == 0 {
// this thread has no reservations.
continue;
}
if end == HELP_FLAG {
// this thread has requested help.
end = self.reclaimer.epoch.load(SeqCst);
if let Err(actual) = slot
.end_epoch
.compare_exchange(HELP_FLAG, end, SeqCst, SeqCst)
{
end = actual;
}
}
let mut start = slot.start_epoch.load(SeqCst);
let mut start = slot.start_era.load(SeqCst);
if start == 0 {
// this slot has one reservation, defined by end_epoch.
// this slot has one reservation, defined by end_era.
start = end;
}
intervals.push((start, end));
Expand All @@ -286,10 +247,9 @@ impl<'a> ThreadContext<'a> {
// go through the limbo list and delete the entries without conflicts.
let mut i = 0;
while i < limbo_list.len() {
let epoch = limbo_list[i].epoch_retired;
let has_conflict = intervals
.iter()
.any(|(start, end)| *start <= epoch && epoch <= *end);
.any(|x| intervals_overlap(limbo_list[i].span, *x));
if has_conflict {
i += 1;
} else {
Expand All @@ -300,6 +260,10 @@ impl<'a> ThreadContext<'a> {
}
}

fn intervals_overlap(a: (u64, u64), b: (u64, u64)) -> bool {
a.0 <= b.1 && b.0 <= a.1
}

impl<'a> Drop for ThreadContext<'a> {
fn drop(&mut self) {
debug_assert!(self.counts.borrow_mut().is_empty());
Expand Down Expand Up @@ -334,7 +298,7 @@ impl<'a> Drop for ThreadContext<'a> {

pub struct Guard<'a, 'b: 'a, T> {
ctx: &'b ThreadContext<'a>,
epoch: usize,
era: u64,
ptr: NonNull<T>,
}

Expand All @@ -349,59 +313,51 @@ impl<'a, 'b: 'a, T> Drop for Guard<'a, 'b, T> {
let mut counts = self.ctx.counts.borrow_mut();

// decrement the count.
let pair = counts.iter_mut().find(|(e, _)| *e == self.epoch).unwrap();
let pair = counts.iter_mut().find(|(e, _)| *e == self.era).unwrap();
pair.1 -= 1;

let mut start_epoch_changed = false;
let mut end_epoch_changed = false;
let mut start_era_changed = false;
let mut end_era_changed = false;

// pop from the front of the queue to shrink the interval.
// pop from the front and back of the queue to shrink the interval.
while let Some((_, count)) = counts.front() {
if *count > 0 {
break;
}
counts.pop_front();
start_epoch_changed = true;
start_era_changed = true;
}
while let Some((_, count)) = counts.back() {
if *count > 0 {
break;
}
counts.pop_back();
end_epoch_changed = true;
end_era_changed = true;
}

// publish our interval update.
// update our interval.
if counts.is_empty() {
// we have no more reservations; zero out our interval.
self.ctx.slot.end_epoch.store(0, SeqCst);
self.ctx.slot.start_epoch.store(0, SeqCst);
} else if start_epoch_changed {
self.ctx.slot.end_era.store(0, SeqCst);
self.ctx.slot.start_era.store(0, SeqCst);
} else if start_era_changed {
self.ctx
.slot
.start_epoch
.start_era
.store(counts.front().unwrap().0, SeqCst);
} else if end_epoch_changed {
} else if end_era_changed {
self.ctx
.slot
.end_epoch
.end_era
.store(counts.back().unwrap().0, SeqCst);
}
debug_assert_eq!(
self.ctx.slot.start_epoch.load(Relaxed),
counts.front().map_or(0, |x| x.0)
);
debug_assert_eq!(
self.ctx.slot.end_epoch.load(Relaxed),
counts.back().map_or(0, |x| x.0)
);
}
}

struct RetiredFn {
ptr: *mut u8,
f: fn(*mut u8),
epoch_retired: usize,
span: (u64, u64),
}

impl Drop for RetiredFn {
Expand All @@ -413,12 +369,16 @@ impl Drop for RetiredFn {
#[cfg(test)]
mod tests {
use std::mem::zeroed;
use std::ptr::NonNull;
use std::sync::atomic::Ordering::{Relaxed, SeqCst};
use std::sync::atomic::{AtomicPtr, AtomicUsize};
use std::thread;

use crate::ibr::Reclaimer;
use crate::smr::Reclaimer;

struct Obj<T> {
val: T,
birth_era: u64,
}

#[test]
fn test_protect_retire() {
Expand All @@ -429,25 +389,36 @@ mod tests {

let counts: [AtomicUsize; MAX_VAL] = unsafe { zeroed() };

let x = AtomicPtr::new(Box::into_raw(Box::new(0)));
let x = AtomicPtr::new(Box::into_raw(Box::new(Obj {
val: 0,
birth_era: r.load_era(),
})));

thread::scope(|scope| {
for _ in 0..THREADS_COUNT {
scope.spawn(|| {
let ctx = r.join(1);
for val in 0..MAX_VAL {
if let Some(p) = NonNull::new(x.load(SeqCst)) {
if let Some(guard) = ctx.protect(&x, p) {
unsafe {
counts[*guard.as_ptr()].fetch_add(1, Relaxed);
}
if let Some(guard) = ctx.load(&x) {
unsafe {
counts[(*guard.as_ptr()).val].fetch_add(1, Relaxed);
}
}
let swapped = x.swap(Box::into_raw(Box::new(val)), SeqCst);
let obj = Obj {
val,
birth_era: r.load_era(),
};
let swapped = x.swap(Box::into_raw(Box::new(obj)), SeqCst);
if !swapped.is_null() {
ctx.retire(swapped as *mut u8, dealloc_boxed_ptr::<usize>);
unsafe {
ctx.retire(
swapped as *mut u8,
dealloc_boxed_ptr::<Obj<usize>>,
(*swapped).birth_era,
);
}
}
r.increment_epoch();
r.increment_era();
}
});
}
Expand Down

0 comments on commit fa50d79

Please sign in to comment.