Skip to content

Commit

Permalink
AsyncRuntime::oneshot (#1026)
Browse files Browse the repository at this point in the history
* Feature: Have oneshot as a Runtime implementation

* Refactor: Change type definition to pass a 'RaftTypeConfig' instead

* Refactor: Change type definition to pass a 'RaftTypeConfig' instead

* Refactor: simplify types for Respond

* Refactor: Use OneshotSenderOf

* Refactor: LogPush with RaftTypeConfig type parameters

* Refactor: ValueSender use RaftTypeConfig instead of just AsyncRuntime

* Refactor: Enforce 'AsyncRuntime' to be 'PartialEq' + 'Eq' and remove manual impl of 'PartialEq' and 'Eq'

---------

Signed-off-by: Anthony Griffon <[email protected]>
  • Loading branch information
Miaxos authored Mar 2, 2024
1 parent 685fe8d commit 5a0d974
Show file tree
Hide file tree
Showing 27 changed files with 247 additions and 149 deletions.
11 changes: 8 additions & 3 deletions cluster_benchmark/tests/benchmark/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use openraft::Entry;
use openraft::EntryPayload;
use openraft::LogId;
use openraft::OptionalSend;
use openraft::OptionalSync;
use openraft::RaftLogId;
use openraft::RaftTypeConfig;
use openraft::SnapshotMeta;
Expand Down Expand Up @@ -225,8 +224,14 @@ impl RaftLogStorage<TypeConfig> for Arc<LogStore> {
}

#[tracing::instrument(level = "trace", skip_all)]
async fn append<I>(&mut self, entries: I, callback: LogFlushed<NodeId>) -> Result<(), StorageError<NodeId>>
where I: IntoIterator<Item = Entry<TypeConfig>> + Send {
async fn append<I>(
&mut self,
entries: I,
callback: LogFlushed<TypeConfig>,
) -> Result<(), StorageError<NodeId>>
where
I: IntoIterator<Item = Entry<TypeConfig>> + Send,
{
{
let mut log = self.log.write().await;
log.extend(entries.into_iter().map(|entry| (entry.get_log_id().index, entry)));
Expand Down
12 changes: 3 additions & 9 deletions examples/memstore/src/log_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ impl<C: RaftTypeConfig> LogStoreInner<C> {
Ok(self.vote)
}

async fn append<I>(&mut self, entries: I, callback: LogFlushed<C::NodeId>) -> Result<(), StorageError<C::NodeId>>
async fn append<I>(&mut self, entries: I, callback: LogFlushed<C>) -> Result<(), StorageError<C::NodeId>>
where I: IntoIterator<Item = C::Entry> {
// Simple implementation that calls the flush-before-return `append_to_log`.
for entry in entries {
Expand Down Expand Up @@ -188,14 +188,8 @@ mod impl_log_store {
inner.read_vote().await
}

async fn append<I>(
&mut self,
entries: I,
callback: LogFlushed<C::NodeId>,
) -> Result<(), StorageError<C::NodeId>>
where
I: IntoIterator<Item = C::Entry>,
{
async fn append<I>(&mut self, entries: I, callback: LogFlushed<C>) -> Result<(), StorageError<C::NodeId>>
where I: IntoIterator<Item = C::Entry> {
let mut inner = self.inner.lock().await;
inner.append(entries, callback).await
}
Expand Down
2 changes: 1 addition & 1 deletion examples/raft-kv-memstore-singlethreaded/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ impl RaftLogStorage<TypeConfig> for Rc<LogStore> {
}

#[tracing::instrument(level = "trace", skip(self, entries, callback))]
async fn append<I>(&mut self, entries: I, callback: LogFlushed<NodeId>) -> Result<(), StorageError<NodeId>>
async fn append<I>(&mut self, entries: I, callback: LogFlushed<TypeConfig>) -> Result<(), StorageError<NodeId>>
where I: IntoIterator<Item = Entry<TypeConfig>> {
// Simple implementation that calls the flush-before-return `append_to_log`.
let mut log = self.log.borrow_mut();
Expand Down
2 changes: 1 addition & 1 deletion examples/raft-kv-rocksdb/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ impl RaftLogStorage<TypeConfig> for LogStore {
}

#[tracing::instrument(level = "trace", skip_all)]
async fn append<I>(&mut self, entries: I, callback: LogFlushed<NodeId>) -> StorageResult<()>
async fn append<I>(&mut self, entries: I, callback: LogFlushed<TypeConfig>) -> StorageResult<()>
where
I: IntoIterator<Item = Entry<TypeConfig>> + Send,
I::IntoIter: Send,
Expand Down
63 changes: 61 additions & 2 deletions openraft/src/async_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use crate::TokioInstant;
/// ## Note
///
/// The default asynchronous runtime is `tokio`.
pub trait AsyncRuntime: Debug + Default + OptionalSend + OptionalSync + 'static {
pub trait AsyncRuntime: Debug + Default + PartialEq + Eq + OptionalSend + OptionalSync + 'static {
/// The error type of [`Self::JoinHandle`].
type JoinError: Debug + Display + OptionalSend;

Expand All @@ -44,6 +44,18 @@ pub trait AsyncRuntime: Debug + Default + OptionalSend + OptionalSync + 'static
/// Type of a thread-local random number generator.
type ThreadLocalRng: rand::Rng;

/// Type of a `oneshot` sender.
type OneshotSender<T: OptionalSend>: AsyncOneshotSendExt<T> + OptionalSend + OptionalSync + Debug + Sized;

/// Type of a `oneshot` receiver error.
type OneshotReceiverError: std::error::Error + OptionalSend;

/// Type of a `oneshot` receiver.
type OneshotReceiver<T: OptionalSend>: OptionalSend
+ OptionalSync
+ Future<Output = Result<T, Self::OneshotReceiverError>>
+ Unpin;

/// Spawn a new task.
fn spawn<T>(future: T) -> Self::JoinHandle<T::Output>
where
Expand Down Expand Up @@ -72,12 +84,24 @@ pub trait AsyncRuntime: Debug + Default + OptionalSend + OptionalSync + 'static
/// This is a per-thread instance, which cannot be shared across threads or
/// sent to another thread.
fn thread_rng() -> Self::ThreadLocalRng;

/// Creates a new one-shot channel for sending single values.
///
/// The function returns separate "send" and "receive" handles. The `Sender`
/// handle is used by the producer to send the value. The `Receiver` handle is
/// used by the consumer to receive the value.
///
/// Each handle can be used on separate tasks.
fn oneshot<T>() -> (Self::OneshotSender<T>, Self::OneshotReceiver<T>)
where T: OptionalSend;
}

/// `Tokio` is the default asynchronous executor.
#[derive(Debug, Default)]
#[derive(Debug, Default, PartialEq, Eq)]
pub struct TokioRuntime;

pub struct TokioOneShotSender<T: OptionalSend>(pub tokio::sync::oneshot::Sender<T>);

impl AsyncRuntime for TokioRuntime {
type JoinError = tokio::task::JoinError;
type JoinHandle<T: OptionalSend + 'static> = tokio::task::JoinHandle<T>;
Expand All @@ -86,6 +110,9 @@ impl AsyncRuntime for TokioRuntime {
type TimeoutError = tokio::time::error::Elapsed;
type Timeout<R, T: Future<Output = R> + OptionalSend> = tokio::time::Timeout<T>;
type ThreadLocalRng = rand::rngs::ThreadRng;
type OneshotSender<T: OptionalSend> = TokioOneShotSender<T>;
type OneshotReceiver<T: OptionalSend> = tokio::sync::oneshot::Receiver<T>;
type OneshotReceiverError = tokio::sync::oneshot::error::RecvError;

#[inline]
fn spawn<T>(future: T) -> Self::JoinHandle<T::Output>
Expand Down Expand Up @@ -132,4 +159,36 @@ impl AsyncRuntime for TokioRuntime {
fn thread_rng() -> Self::ThreadLocalRng {
rand::thread_rng()
}

#[inline]
fn oneshot<T>() -> (Self::OneshotSender<T>, Self::OneshotReceiver<T>)
where T: OptionalSend {
let (tx, rx) = tokio::sync::oneshot::channel();
(TokioOneShotSender(tx), rx)
}
}

pub trait AsyncOneshotSendExt<T>: Unpin {
/// Attempts to send a value on this channel, returning it back if it could
/// not be sent.
///
/// This method consumes `self` as only one value may ever be sent on a `oneshot`
/// channel. It is not marked async because sending a message to an `oneshot`
/// channel never requires any form of waiting. Because of this, the `send`
/// method can be used in both synchronous and asynchronous code without
/// problems.
fn send(self, t: T) -> Result<(), T>;
}

impl<T: OptionalSend> AsyncOneshotSendExt<T> for TokioOneShotSender<T> {
#[inline]
fn send(self, t: T) -> Result<(), T> {
self.0.send(t)
}
}

impl<T: OptionalSend> Debug for TokioOneShotSender<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("TokioSendWrapper").finish()
}
}
35 changes: 20 additions & 15 deletions openraft/src/core/raft_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ use futures::TryFutureExt;
use maplit::btreeset;
use tokio::select;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::watch;
use tracing::Instrument;
use tracing::Level;
use tracing::Span;

use crate::async_runtime::AsyncOneshotSendExt;
use crate::config::Config;
use crate::config::RuntimeConfig;
use crate::core::balancer::Balancer;
Expand Down Expand Up @@ -215,7 +215,10 @@ where
SM: RaftStateMachine<C>,
{
/// The main loop of the Raft protocol.
pub(crate) async fn main(mut self, rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal<C::NodeId>> {
pub(crate) async fn main(
mut self,
rx_shutdown: <C::AsyncRuntime as AsyncRuntime>::OneshotReceiver<()>,
) -> Result<(), Fatal<C::NodeId>> {
let span = tracing::span!(parent: &self.span, Level::DEBUG, "main");
let res = self.do_main(rx_shutdown).instrument(span).await;

Expand All @@ -239,7 +242,10 @@ where
}

#[tracing::instrument(level="trace", skip_all, fields(id=display(self.id), cluster=%self.config.cluster_name))]
async fn do_main(&mut self, rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal<C::NodeId>> {
async fn do_main(
&mut self,
rx_shutdown: <C::AsyncRuntime as AsyncRuntime>::OneshotReceiver<()>,
) -> Result<(), Fatal<C::NodeId>> {
tracing::debug!("raft node is initializing");

self.engine.startup();
Expand Down Expand Up @@ -432,7 +438,7 @@ where
&mut self,
changes: ChangeMembers<C::NodeId, C::Node>,
retain: bool,
tx: ResultSender<ClientWriteResponse<C>, ClientWriteError<C::NodeId, C::Node>>,
tx: ResultSender<C, ClientWriteResponse<C>, ClientWriteError<C::NodeId, C::Node>>,
) {
let res = self.engine.state.membership_state.change_handler().apply(changes, retain);
let new_membership = match res {
Expand Down Expand Up @@ -593,7 +599,7 @@ where
pub(crate) fn handle_initialize(
&mut self,
member_nodes: BTreeMap<C::NodeId, C::Node>,
tx: ResultSender<(), InitializeError<C::NodeId, C::Node>>,
tx: ResultSender<C, (), InitializeError<C::NodeId, C::Node>>,
) {
tracing::debug!(member_nodes = debug(&member_nodes), "{}", func_name!());

Expand All @@ -616,7 +622,7 @@ where

/// Reject a request due to the Raft node being in a state which prohibits the request.
#[tracing::instrument(level = "trace", skip(self, tx))]
pub(crate) fn reject_with_forward_to_leader<T, E>(&self, tx: ResultSender<T, E>)
pub(crate) fn reject_with_forward_to_leader<T: OptionalSend, E: OptionalSend>(&self, tx: ResultSender<C, T, E>)
where E: From<ForwardToLeader<C::NodeId, C::Node>> {
let mut leader_id = self.current_leader();
let leader_node = self.get_leader_node(leader_id);
Expand Down Expand Up @@ -680,7 +686,7 @@ where
{
tracing::debug!("append_to_log");

let (tx, rx) = oneshot::channel();
let (tx, rx) = C::AsyncRuntime::oneshot();
let callback = LogFlushed::new(Some(last_log_id), tx);
self.log_store.append(entries, callback).await?;
rx.await
Expand Down Expand Up @@ -865,7 +871,10 @@ where

/// Run an event handling loop
#[tracing::instrument(level="debug", skip_all, fields(id=display(self.id)))]
async fn runtime_loop(&mut self, mut rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal<C::NodeId>> {
async fn runtime_loop(
&mut self,
mut rx_shutdown: <C::AsyncRuntime as AsyncRuntime>::OneshotReceiver<()>,
) -> Result<(), Fatal<C::NodeId>> {
// Ratio control the ratio of number of RaftMsg to process to number of Notify to process.
let mut balancer = Balancer::new(10_000);

Expand Down Expand Up @@ -1067,7 +1076,7 @@ where
}

#[tracing::instrument(level = "debug", skip_all)]
pub(super) fn handle_vote_request(&mut self, req: VoteRequest<C::NodeId>, tx: VoteTx<C::NodeId>) {
pub(super) fn handle_vote_request(&mut self, req: VoteRequest<C::NodeId>, tx: VoteTx<C>) {
tracing::info!(req = display(req.summary()), func = func_name!());

let resp = self.engine.handle_vote_req(req);
Expand All @@ -1078,11 +1087,7 @@ where
}

#[tracing::instrument(level = "debug", skip_all)]
pub(super) fn handle_append_entries_request(
&mut self,
req: AppendEntriesRequest<C>,
tx: AppendEntriesTx<C::NodeId>,
) {
pub(super) fn handle_append_entries_request(&mut self, req: AppendEntriesRequest<C>, tx: AppendEntriesTx<C>) {
tracing::debug!(req = display(req.summary()), func = func_name!());

let is_ok = self.engine.handle_append_entries(&req.vote, req.prev_log_id, req.entries, Some(tx));
Expand Down Expand Up @@ -1657,7 +1662,7 @@ where

// Create a channel to let state machine worker to send the snapshot and the replication
// worker to receive it.
let (tx, rx) = oneshot::channel();
let (tx, rx) = C::AsyncRuntime::oneshot();

let cmd = sm::Command::get_snapshot(tx);
self.sm_handle
Expand Down
2 changes: 1 addition & 1 deletion openraft/src/core/raft_msg/external_command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub(crate) enum ExternalCommand<C: RaftTypeConfig> {
Snapshot,

/// Get a snapshot from the state machine, send back via a oneshot::Sender.
GetSnapshot { tx: ResultSender<Option<Snapshot<C>>> },
GetSnapshot { tx: ResultSender<C, Option<Snapshot<C>>> },

/// Purge logs covered by a snapshot up to a specified index.
///
Expand Down
30 changes: 16 additions & 14 deletions openraft/src/core/raft_msg/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
use std::collections::BTreeMap;

use tokio::sync::oneshot;

use crate::core::raft_msg::external_command::ExternalCommand;
use crate::error::CheckIsLeaderError;
use crate::error::ClientWriteError;
Expand All @@ -15,10 +13,13 @@ use crate::raft::ClientWriteResponse;
use crate::raft::SnapshotResponse;
use crate::raft::VoteRequest;
use crate::raft::VoteResponse;
use crate::type_config::alias::AsyncRuntimeOf;
use crate::type_config::alias::LogIdOf;
use crate::type_config::alias::NodeIdOf;
use crate::type_config::alias::NodeOf;
use crate::type_config::alias::OneshotSenderOf;
use crate::type_config::alias::SnapshotDataOf;
use crate::AsyncRuntime;
use crate::ChangeMembers;
use crate::MessageSummary;
use crate::RaftTypeConfig;
Expand All @@ -28,22 +29,23 @@ use crate::Vote;
pub(crate) mod external_command;

/// A oneshot TX to send result from `RaftCore` to external caller, e.g. `Raft::append_entries`.
pub(crate) type ResultSender<T, E = Infallible> = oneshot::Sender<Result<T, E>>;
pub(crate) type ResultSender<C, T, E = Infallible> = OneshotSenderOf<C, Result<T, E>>;

pub(crate) type ResultReceiver<T, E = Infallible> = oneshot::Receiver<Result<T, E>>;
pub(crate) type ResultReceiver<C, T, E = Infallible> =
<AsyncRuntimeOf<C> as AsyncRuntime>::OneshotReceiver<Result<T, E>>;

/// TX for Vote Response
pub(crate) type VoteTx<NID> = ResultSender<VoteResponse<NID>>;
pub(crate) type VoteTx<C> = ResultSender<C, VoteResponse<NodeIdOf<C>>>;

/// TX for Append Entries Response
pub(crate) type AppendEntriesTx<NID> = ResultSender<AppendEntriesResponse<NID>>;
pub(crate) type AppendEntriesTx<C> = ResultSender<C, AppendEntriesResponse<NodeIdOf<C>>>;

/// TX for Client Write Response
pub(crate) type ClientWriteTx<C> = ResultSender<ClientWriteResponse<C>, ClientWriteError<NodeIdOf<C>, NodeOf<C>>>;
pub(crate) type ClientWriteTx<C> = ResultSender<C, ClientWriteResponse<C>, ClientWriteError<NodeIdOf<C>, NodeOf<C>>>;

/// TX for Linearizable Read Response
pub(crate) type ClientReadTx<C> =
ResultSender<(Option<LogIdOf<C>>, Option<LogIdOf<C>>), CheckIsLeaderError<NodeIdOf<C>, NodeOf<C>>>;
ResultSender<C, (Option<LogIdOf<C>>, Option<LogIdOf<C>>), CheckIsLeaderError<NodeIdOf<C>, NodeOf<C>>>;

/// A message sent by application to the [`RaftCore`].
///
Expand All @@ -53,18 +55,18 @@ where C: RaftTypeConfig
{
AppendEntries {
rpc: AppendEntriesRequest<C>,
tx: AppendEntriesTx<C::NodeId>,
tx: AppendEntriesTx<C>,
},

RequestVote {
rpc: VoteRequest<C::NodeId>,
tx: VoteTx<C::NodeId>,
tx: VoteTx<C>,
},

InstallFullSnapshot {
vote: Vote<C::NodeId>,
snapshot: Snapshot<C>,
tx: ResultSender<SnapshotResponse<C::NodeId>>,
tx: ResultSender<C, SnapshotResponse<C::NodeId>>,
},

/// Begin receiving a snapshot from the leader.
Expand All @@ -74,7 +76,7 @@ where C: RaftTypeConfig
/// will be returned in a Err
BeginReceivingSnapshot {
vote: Vote<C::NodeId>,
tx: ResultSender<Box<SnapshotDataOf<C>>, HigherVote<C::NodeId>>,
tx: ResultSender<C, Box<SnapshotDataOf<C>>, HigherVote<C::NodeId>>,
},

ClientWriteRequest {
Expand All @@ -88,7 +90,7 @@ where C: RaftTypeConfig

Initialize {
members: BTreeMap<C::NodeId, C::Node>,
tx: ResultSender<(), InitializeError<C::NodeId, C::Node>>,
tx: ResultSender<C, (), InitializeError<C::NodeId, C::Node>>,
},

ChangeMembership {
Expand All @@ -98,7 +100,7 @@ where C: RaftTypeConfig
/// config will be converted into learners, otherwise they will be removed.
retain: bool,

tx: ResultSender<ClientWriteResponse<C>, ClientWriteError<C::NodeId, C::Node>>,
tx: ResultSender<C, ClientWriteResponse<C>, ClientWriteError<C::NodeId, C::Node>>,
},

ExternalCoreRequest {
Expand Down
Loading

0 comments on commit 5a0d974

Please sign in to comment.