diff --git a/src/lib.rs b/src/lib.rs
index 02dacd68..257d92e6 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -220,3 +220,4 @@ pub use storage::{
 };
 mod vm;
 pub use vm::{ExecutionError, PevmTxExecutionResult};
+mod unsafe_vec;
diff --git a/src/mv_memory.rs b/src/mv_memory.rs
index 4186712c..1cf576b5 100644
--- a/src/mv_memory.rs
+++ b/src/mv_memory.rs
@@ -5,8 +5,8 @@ use alloy_primitives::Address;
 use dashmap::{mapref::one::Ref, DashMap};
 
 use crate::{
-    BuildAddressHasher, BuildIdentityHasher, MemoryEntry, MemoryLocationHash, NewLazyAddresses,
-    ReadOrigin, ReadSet, TxIdx, TxVersion, WriteSet,
+    unsafe_vec::UnsafeVec, BuildAddressHasher, BuildIdentityHasher, MemoryEntry,
+    MemoryLocationHash, NewLazyAddresses, ReadOrigin, ReadSet, TxIdx, TxVersion, WriteSet,
 };
 
 #[derive(Default, Debug)]
@@ -37,7 +37,7 @@ pub struct MvMemory {
     // Nevertheless, the compiler should be good enough to optimize these cases anyway.
     data: DashMap<MemoryLocationHash, BTreeMap<TxIdx, MemoryEntry>, BuildIdentityHasher>,
     /// Last read & written locations of each transaction
-    last_locations: Vec<Mutex<LastLocations>>,
+    last_locations: UnsafeVec<LastLocations>,
     /// Lazy addresses that need full evaluation at the end of the block
     lazy_addresses: Mutex<LazyAddresses>,
 }
@@ -66,7 +66,7 @@ impl MvMemory {
         }
         Self {
             data,
-            last_locations: (0..block_size).map(|_| Mutex::default()).collect(),
+            last_locations: UnsafeVec::new((0..block_size).map(|_| Default::default()).collect()),
            lazy_addresses: Mutex::new(lazy_addresses),
         }
     }
@@ -92,7 +92,7 @@
             );
         }
         // TODO: Faster "difference" function when there are many locations
-        let mut last_locations = index_mutex!(self.last_locations, tx_version.tx_idx);
+        let last_locations = self.last_locations.get_mut(tx_version.tx_idx);
         for prev_location in last_locations.write.iter() {
             if !new_locations.contains(prev_location) {
                 if let Some(mut written_transactions) = self.data.get_mut(prev_location) {
@@ -136,7 +136,7 @@
     // validations that successfully abort affect the state and each incarnation
     // can be aborted at most once).
     pub(crate) fn validate_read_locations(&self, tx_idx: TxIdx) -> bool {
-        for (location, prior_origins) in index_mutex!(self.last_locations, tx_idx).read.iter() {
+        for (location, prior_origins) in self.last_locations.get(tx_idx).read.iter() {
             if let Some(written_transactions) = self.read_location(location) {
                 let mut iter = written_transactions.range(..tx_idx);
                 for prior_origin in prior_origins {
@@ -177,7 +177,7 @@
     // structure with special ESTIMATE markers to quickly abort higher transactions
     // that read them.
     pub(crate) fn convert_writes_to_estimates(&self, tx_idx: TxIdx) {
-        for location in index_mutex!(self.last_locations, tx_idx).write.iter() {
+        for location in self.last_locations.get(tx_idx).write.iter() {
             if let Some(mut written_transactions) = self.data.get_mut(location) {
                 written_transactions.insert(tx_idx, MemoryEntry::Estimate);
             }
diff --git a/src/unsafe_vec.rs b/src/unsafe_vec.rs
new file mode 100644
index 00000000..7b76e2af
--- /dev/null
+++ b/src/unsafe_vec.rs
@@ -0,0 +1,84 @@
+use std::cell::UnsafeCell;
+use std::marker::PhantomData;
+use std::ops::{Deref, DerefMut};
+
+/// A vector that allows for unsafe concurrent updates without locking.
+/// The user must ensure that each index is accessed by only one thread at a time.
+#[derive(Debug)]
+pub(crate) struct UnsafeVec<T> {
+    data: UnsafeCell<Vec<T>>,
+    _marker: PhantomData<T>,
+}
+
+// SAFETY: sharing `UnsafeVec<T>` hands out `&mut T` across threads, so `T: Send` is required (as for `Mutex<T>`).
+unsafe impl<T: Send> Sync for UnsafeVec<T> {}
+
+impl<T> UnsafeVec<T> {
+    pub(crate) fn new(vec: Vec<T>) -> UnsafeVec<T> {
+        UnsafeVec {
+            data: UnsafeCell::new(vec),
+            _marker: PhantomData,
+        }
+    }
+
+    /// Sets the value at the specified index.
+    ///
+    /// # Safety
+    ///
+    /// This method is unsafe because it allows for concurrent mutable access to the vector.
+    /// The caller must ensure that no other threads are accessing the same index concurrently.
+    #[allow(dead_code)]
+    pub(crate) fn set(&self, index: usize, value: T) {
+        unsafe {
+            (*self.data.get())[index] = value;
+        }
+    }
+
+    /// Gets a reference to the value at the specified index.
+    ///
+    /// # Safety
+    ///
+    /// This method is unsafe for two reasons:
+    ///
+    /// 1. It allows for concurrent immutable access to the vector.
+    /// The caller must ensure that no other threads are mutating the same index concurrently.
+    ///
+    /// 2. The caller must ensure that the index is within the bounds of the vector.
+    /// Accessing an out-of-bounds index can lead to undefined behavior.
+    pub(crate) fn get(&self, index: usize) -> &T {
+        unsafe { (*self.data.get()).get_unchecked(index) }
+    }
+
+    /// Gets a mutable reference to the value at the specified index.
+    ///
+    /// # Safety
+    ///
+    /// This method is unsafe for two reasons:
+    ///
+    /// 1. It allows for concurrent mutable access to the vector.
+    /// The caller must ensure that no other threads are accessing the same index concurrently,
+    /// and that there are no overlapping mutable references to the same index.
+    ///
+    /// 2. The caller must ensure that the index is within the bounds of the vector.
+    /// Accessing an out-of-bounds index can lead to undefined behavior.
+    #[allow(clippy::mut_from_ref)]
+    pub(crate) fn get_mut(&self, index: usize) -> &mut T {
+        unsafe { (*self.data.get()).get_unchecked_mut(index) }
+    }
+}
+
+// Implementing Deref to delegate method calls to the underlying vector.
+impl<T> Deref for UnsafeVec<T> {
+    type Target = Vec<T>;
+
+    fn deref(&self) -> &Self::Target {
+        unsafe { &*self.data.get() }
+    }
+}
+
+// Implementing DerefMut to delegate mutable method calls to the underlying vector.
+impl<T> DerefMut for UnsafeVec<T> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        unsafe { &mut *self.data.get() }
+    }
+}
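
Note for reviewers: below is a minimal sketch of the access pattern this patch relies on, mirroring how MvMemory gives each executing transaction exclusive ownership of its last_locations slot. The demo function is illustrative only, not part of the patch; it assumes UnsafeVec is in scope (e.g. in a unit test inside the crate).

use std::thread;

// Illustrative only: each scoped thread is the sole accessor of its own index,
// so the `&mut T` returned by `get_mut` never aliases another reference.
fn demo() {
    let locations = UnsafeVec::new(vec![0usize; 4]);
    thread::scope(|s| {
        // `locations.len()` delegates to the inner Vec via the Deref impl.
        for idx in 0..locations.len() {
            let locations = &locations;
            s.spawn(move || {
                // No other thread touches `idx`, upholding the documented contract.
                *locations.get_mut(idx) = idx * 10;
            });
        }
    });
    // The scope has joined all threads, so plain reads are safe again.
    assert_eq!(*locations.get(2), 20);
}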