diff --git a/openraft/src/lib.rs b/openraft/src/lib.rs index ba80f0d2d..4f87b13d0 100644 --- a/openraft/src/lib.rs +++ b/openraft/src/lib.rs @@ -62,6 +62,7 @@ pub mod metrics; pub mod network; pub mod raft; pub mod storage; +pub mod sync; pub mod testing; pub mod type_config; diff --git a/openraft/src/raft/mod.rs b/openraft/src/raft/mod.rs index 3d5aa6acd..539932ddc 100644 --- a/openraft/src/raft/mod.rs +++ b/openraft/src/raft/mod.rs @@ -37,7 +37,6 @@ pub use message::InstallSnapshotResponse; pub use message::SnapshotResponse; pub use message::VoteRequest; pub use message::VoteResponse; -use tokio::sync::Mutex; use tracing::trace_span; use tracing::Instrument; use tracing::Level; @@ -78,6 +77,7 @@ pub use crate::raft::runtime_config_handle::RuntimeConfigHandle; use crate::raft::trigger::Trigger; use crate::storage::RaftLogStorage; use crate::storage::RaftStateMachine; +use crate::sync::Mutex; use crate::type_config::alias::JoinErrorOf; use crate::type_config::alias::ResponderOf; use crate::type_config::alias::ResponderReceiverOf; @@ -318,7 +318,7 @@ where C: RaftTypeConfig rx_metrics, rx_data_metrics, rx_server_metrics, - tx_shutdown: Mutex::new(Some(tx_shutdown)), + tx_shutdown: std::sync::Mutex::new(Some(tx_shutdown)), core_state: Mutex::new(CoreState::Running(core_handle)), snapshot: Mutex::new(None), @@ -919,7 +919,7 @@ where C: RaftTypeConfig /// /// It sends a shutdown signal and waits until `RaftCore` returns. pub async fn shutdown(&self) -> Result<(), JoinErrorOf> { - if let Some(tx) = self.inner.tx_shutdown.lock().await.take() { + if let Some(tx) = self.inner.tx_shutdown.lock().unwrap().take() { // A failure to send means the RaftCore is already shutdown. Continue to check the task // return value. 
let send_res = tx.send(()); diff --git a/openraft/src/raft/raft_inner.rs b/openraft/src/raft/raft_inner.rs index 9fb2caa25..06cb588e6 100644 --- a/openraft/src/raft/raft_inner.rs +++ b/openraft/src/raft/raft_inner.rs @@ -3,7 +3,6 @@ use std::fmt::Debug; use std::future::Future; use std::sync::Arc; -use tokio::sync::Mutex; use tracing::Level; use crate::async_runtime::MpscUnboundedSender; @@ -16,6 +15,7 @@ use crate::error::RaftError; use crate::metrics::RaftDataMetrics; use crate::metrics::RaftServerMetrics; use crate::raft::core_state::CoreState; +use crate::sync::Mutex; use crate::type_config::alias::MpscUnboundedSenderOf; use crate::type_config::alias::OneshotReceiverOf; use crate::type_config::alias::OneshotSenderOf; @@ -40,13 +40,12 @@ where C: RaftTypeConfig pub(in crate::raft) rx_data_metrics: WatchReceiverOf>, pub(in crate::raft) rx_server_metrics: WatchReceiverOf>, - // TODO(xp): it does not need to be a async mutex. #[allow(clippy::type_complexity)] - pub(in crate::raft) tx_shutdown: Mutex>>, - pub(in crate::raft) core_state: Mutex>, + pub(in crate::raft) tx_shutdown: std::sync::Mutex>>, + pub(in crate::raft) core_state: Mutex>, /// The ongoing snapshot transmission. 
- pub(in crate::raft) snapshot: Mutex>>, + pub(in crate::raft) snapshot: Mutex>>, } impl RaftInner diff --git a/openraft/src/replication/mod.rs b/openraft/src/replication/mod.rs index 9fe857d0b..2d2ec0f46 100644 --- a/openraft/src/replication/mod.rs +++ b/openraft/src/replication/mod.rs @@ -19,7 +19,6 @@ use request::Replicate; use response::ReplicationResult; pub(crate) use response::Response; use tokio::select; -use tokio::sync::Mutex; use tracing_futures::Instrument; use crate::async_runtime::MpscUnboundedReceiver; @@ -50,6 +49,7 @@ use crate::replication::request_id::RequestId; use crate::storage::RaftLogReader; use crate::storage::RaftLogStorage; use crate::storage::Snapshot; +use crate::sync::Mutex; use crate::type_config::alias::InstantOf; use crate::type_config::alias::JoinHandleOf; use crate::type_config::alias::LogIdOf; @@ -114,7 +114,7 @@ where /// Another `RaftNetwork` specific for snapshot replication. /// /// Snapshot transmitting is a long running task, and is processed in a separate task. - snapshot_network: Arc>, + snapshot_network: Arc>, /// The current snapshot replication state. 
/// @@ -754,7 +754,7 @@ where async fn send_snapshot( request_id: RequestId, - network: Arc>, + network: Arc>, vote: Vote, snapshot: Snapshot, option: RPCOption, diff --git a/openraft/src/sync/mod.rs b/openraft/src/sync/mod.rs new file mode 100644 index 000000000..c54dd4de9 --- /dev/null +++ b/openraft/src/sync/mod.rs @@ -0,0 +1,5 @@ +pub(crate) mod mutex; + +pub(crate) use mutex::Mutex; +#[allow(unused_imports)] +pub(crate) use mutex::MutexGuard; diff --git a/openraft/src/sync/mutex.rs b/openraft/src/sync/mutex.rs new file mode 100644 index 000000000..fa6d830ed --- /dev/null +++ b/openraft/src/sync/mutex.rs @@ -0,0 +1,168 @@ +use std::cell::UnsafeCell; +use std::ops::Deref; +use std::ops::DerefMut; + +use crate::type_config::alias::OneshotReceiverOf; +use crate::type_config::alias::OneshotSenderOf; +use crate::type_config::TypeConfigExt; +use crate::RaftTypeConfig; + +/// A simple async mutex implementation that uses oneshot channels to notify the next waiting task. +/// +/// Openraft uses an async mutex only in non-performance-critical paths, +/// so it's ok to use this simple implementation. +/// +/// Since a oneshot channel is already required by the AsyncRuntime implementation, +/// there is no need for the application to implement Mutex. +pub(crate) struct Mutex +where C: RaftTypeConfig +{ + /// The current lock holder. + /// + /// When the acquired `MutexGuard` is dropped, it will notify the next waiting task via this + /// oneshot channel. + lock_holder: std::sync::Mutex>>, + + /// The value protected by the mutex. + value: UnsafeCell, +} + +impl Mutex +where C: RaftTypeConfig +{ + pub(crate) fn new(value: T) -> Self { + Self { + lock_holder: std::sync::Mutex::new(None), + value: UnsafeCell::new(value), + } + } + + pub(crate) async fn lock(&self) -> MutexGuard<'_, C, T> { + // Every lock() call puts a oneshot receiver into the holder + // and takes out the existing one. + // If the existing one is Some(rx), + // it means the lock is already held by another task.
+ // In this case, the current task should wait for the lock to be released. + // + // This approach forms a queue in which every task waits for the previous one. + + let (tx, rx) = C::oneshot(); + let current_rx = { + let mut l = self.lock_holder.lock().unwrap(); + l.replace(rx) + }; + + if let Some(rx) = current_rx { + let _ = rx.await; + } + + MutexGuard { guard: tx, lock: self } + } + + #[allow(dead_code)] + pub(crate) fn into_inner(self) -> T { + self.value.into_inner() + } +} + +/// The guard of the mutex. +pub(crate) struct MutexGuard<'a, C, T> +where C: RaftTypeConfig +{ + /// This is only used to trigger `Drop` to notify the next waiting task. + #[allow(dead_code)] + guard: OneshotSenderOf, + lock: &'a Mutex, +} + +impl<'a, C, T> Deref for MutexGuard<'a, C, T> +where C: RaftTypeConfig +{ + type Target = T; + + fn deref(&self) -> &Self::Target { + unsafe { &*self.lock.value.get() } + } +} + +impl<'a, C, T> DerefMut for MutexGuard<'a, C, T> +where C: RaftTypeConfig +{ + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.lock.value.get() } + } +} + +/// T must be `Send` to make Mutex `Send` +unsafe impl Send for Mutex where T: Send {} + +/// To allow multiple threads to access T through a `&Mutex`, T must be `Send`, +/// because the caller acquires the ownership through `Mutex::lock()`. +unsafe impl Sync for Mutex where T: Send {} + +/// MutexGuard needs to be `Sync` so that it can be held across an `.await` point. +unsafe impl Sync for MutexGuard<'_, C, T> where T: Send + Sync {} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use super::*; + use crate::engine::testing::UTConfig; + + #[test] + fn bounds() { + fn check_send() {} + fn check_unpin() {} + // This has to take a value, since the async fn's return type is unnameable.
+ fn check_send_sync_val(_t: T) {} + fn check_send_sync() {} + + check_send::>(); + check_unpin::>(); + check_send_sync::>(); + + let mutex = Mutex::::new(1); + check_send_sync_val(mutex.lock()); + } + + #[test] + fn test_mutex() { + let mutex = Arc::new(Mutex::::new(0)); + + let rt = tokio::runtime::Builder::new_multi_thread() + .worker_threads(8) + .enable_all() + .build() + .expect("Failed building the Runtime"); + + let big_prime_num = 1_000_000_009; + let n = 100_000; + let n_task = 10; + let mut tasks = vec![]; + + for _i in 0..n_task { + let mutex = mutex.clone(); + let h = rt.spawn(async move { + for k in 0..n { + { + let mut guard = mutex.lock().await; + *guard = (*guard + k) % big_prime_num; + } + } + }); + + tasks.push(h); + } + + let got = rt.block_on(async { + for t in tasks { + let _ = t.await; + } + *mutex.lock().await + }); + + println!("got: {}", got); + assert_eq!(got, n_task * n * (n - 1) / 2 % big_prime_num); + } +} diff --git a/openraft/src/type_config/async_runtime/oneshot.rs b/openraft/src/type_config/async_runtime/oneshot.rs index 08048ebac..38af0c930 100644 --- a/openraft/src/type_config/async_runtime/oneshot.rs +++ b/openraft/src/type_config/async_runtime/oneshot.rs @@ -25,6 +25,8 @@ pub trait Oneshot { where T: OptionalSend; } +/// This `Sender` must implement `Drop` to notify the [`Oneshot::Receiver`] that the sending end has +/// been dropped, causing the receiver to return a [`Oneshot::ReceiverError`]. pub trait OneshotSender: OptionalSend + OptionalSync + Sized where T: OptionalSend {