Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AsyncRuntime::oneshot #1026

Merged
merged 11 commits into from
Mar 2, 2024
11 changes: 8 additions & 3 deletions cluster_benchmark/tests/benchmark/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use openraft::Entry;
use openraft::EntryPayload;
use openraft::LogId;
use openraft::OptionalSend;
use openraft::OptionalSync;
use openraft::RaftLogId;
use openraft::RaftTypeConfig;
use openraft::SnapshotMeta;
Expand Down Expand Up @@ -225,8 +224,14 @@ impl RaftLogStorage<TypeConfig> for Arc<LogStore> {
}

#[tracing::instrument(level = "trace", skip_all)]
async fn append<I>(&mut self, entries: I, callback: LogFlushed<NodeId>) -> Result<(), StorageError<NodeId>>
where I: IntoIterator<Item = Entry<TypeConfig>> + Send {
async fn append<I>(
&mut self,
entries: I,
callback: LogFlushed<<TypeConfig as RaftTypeConfig>::AsyncRuntime, NodeId>,
) -> Result<(), StorageError<NodeId>>
where
I: IntoIterator<Item = Entry<TypeConfig>> + Send,
{
{
let mut log = self.log.write().await;
log.extend(entries.into_iter().map(|entry| (entry.get_log_id().index, entry)));
Expand Down
12 changes: 9 additions & 3 deletions examples/memstore/src/log_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,14 @@ impl<C: RaftTypeConfig> LogStoreInner<C> {
Ok(self.vote)
}

async fn append<I>(&mut self, entries: I, callback: LogFlushed<C::NodeId>) -> Result<(), StorageError<C::NodeId>>
where I: IntoIterator<Item = C::Entry> {
async fn append<I>(
&mut self,
entries: I,
callback: LogFlushed<C::AsyncRuntime, C::NodeId>,
) -> Result<(), StorageError<C::NodeId>>
where
I: IntoIterator<Item = C::Entry>,
{
// Simple implementation that calls the flush-before-return `append_to_log`.
for entry in entries {
self.log.insert(entry.get_log_id().index, entry);
Expand Down Expand Up @@ -191,7 +197,7 @@ mod impl_log_store {
async fn append<I>(
&mut self,
entries: I,
callback: LogFlushed<C::NodeId>,
callback: LogFlushed<C::AsyncRuntime, C::NodeId>,
) -> Result<(), StorageError<C::NodeId>>
where
I: IntoIterator<Item = C::Entry>,
Expand Down
10 changes: 8 additions & 2 deletions examples/raft-kv-memstore-singlethreaded/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -321,8 +321,14 @@ impl RaftLogStorage<TypeConfig> for Rc<LogStore> {
}

#[tracing::instrument(level = "trace", skip(self, entries, callback))]
async fn append<I>(&mut self, entries: I, callback: LogFlushed<NodeId>) -> Result<(), StorageError<NodeId>>
where I: IntoIterator<Item = Entry<TypeConfig>> {
async fn append<I>(
&mut self,
entries: I,
callback: LogFlushed<<TypeConfig as RaftTypeConfig>::AsyncRuntime, NodeId>,
) -> Result<(), StorageError<NodeId>>
where
I: IntoIterator<Item = Entry<TypeConfig>>,
{
// Simple implementation that calls the flush-before-return `append_to_log`.
let mut log = self.log.borrow_mut();
for entry in entries {
Expand Down
7 changes: 6 additions & 1 deletion examples/raft-kv-rocksdb/src/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use openraft::LogId;
use openraft::OptionalSend;
use openraft::RaftLogReader;
use openraft::RaftSnapshotBuilder;
use openraft::RaftTypeConfig;
use openraft::SnapshotMeta;
use openraft::StorageError;
use openraft::StorageIOError;
Expand Down Expand Up @@ -436,7 +437,11 @@ impl RaftLogStorage<TypeConfig> for LogStore {
}

#[tracing::instrument(level = "trace", skip_all)]
async fn append<I>(&mut self, entries: I, callback: LogFlushed<NodeId>) -> StorageResult<()>
async fn append<I>(
&mut self,
entries: I,
callback: LogFlushed<<TypeConfig as RaftTypeConfig>::AsyncRuntime, NodeId>,
) -> StorageResult<()>
where
I: IntoIterator<Item = Entry<TypeConfig>> + Send,
I::IntoIter: Send,
Expand Down
59 changes: 59 additions & 0 deletions openraft/src/async_runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ pub trait AsyncRuntime: Debug + Default + OptionalSend + OptionalSync + 'static
/// Type of a thread-local random number generator.
type ThreadLocalRng: rand::Rng;

/// Type of a `oneshot` sender.
type OneshotSender<T: OptionalSend>: AsyncOneshotSendExt<T> + OptionalSend + OptionalSync + Debug + Sized;

/// Type of a `oneshot` receiver error.
type OneshotReceiverError: std::error::Error + OptionalSend;

/// Type of a `oneshot` receiver.
type OneshotReceiver<T: OptionalSend>: OptionalSend
+ OptionalSync
+ Future<Output = Result<T, Self::OneshotReceiverError>>
+ Unpin;

/// Spawn a new task.
fn spawn<T>(future: T) -> Self::JoinHandle<T::Output>
where
Expand Down Expand Up @@ -72,12 +84,24 @@ pub trait AsyncRuntime: Debug + Default + OptionalSend + OptionalSync + 'static
/// This is a per-thread instance, which cannot be shared across threads or
/// sent to another thread.
fn thread_rng() -> Self::ThreadLocalRng;

/// Creates a new one-shot channel for sending single values.
///
/// The function returns separate "send" and "receive" handles. The `Sender`
/// handle is used by the producer to send the value. The `Receiver` handle is
/// used by the consumer to receive the value.
///
/// Each handle can be used on separate tasks.
fn oneshot<T>() -> (Self::OneshotSender<T>, Self::OneshotReceiver<T>)
where T: OptionalSend;
}

/// `Tokio` is the default asynchronous executor.
#[derive(Debug, Default)]
pub struct TokioRuntime;

pub struct TokioOneShotSender<T: OptionalSend>(pub tokio::sync::oneshot::Sender<T>);

impl AsyncRuntime for TokioRuntime {
type JoinError = tokio::task::JoinError;
type JoinHandle<T: OptionalSend + 'static> = tokio::task::JoinHandle<T>;
Expand All @@ -86,6 +110,9 @@ impl AsyncRuntime for TokioRuntime {
type TimeoutError = tokio::time::error::Elapsed;
type Timeout<R, T: Future<Output = R> + OptionalSend> = tokio::time::Timeout<T>;
type ThreadLocalRng = rand::rngs::ThreadRng;
type OneshotSender<T: OptionalSend> = TokioOneShotSender<T>;
type OneshotReceiver<T: OptionalSend> = tokio::sync::oneshot::Receiver<T>;
type OneshotReceiverError = tokio::sync::oneshot::error::RecvError;

#[inline]
fn spawn<T>(future: T) -> Self::JoinHandle<T::Output>
Expand Down Expand Up @@ -132,4 +159,36 @@ impl AsyncRuntime for TokioRuntime {
fn thread_rng() -> Self::ThreadLocalRng {
rand::thread_rng()
}

#[inline]
fn oneshot<T>() -> (Self::OneshotSender<T>, Self::OneshotReceiver<T>)
where T: OptionalSend {
let (tx, rx) = tokio::sync::oneshot::channel();
(TokioOneShotSender(tx), rx)
}
}

pub trait AsyncOneshotSendExt<T>: Unpin {
/// Attempts to send a value on this channel, returning it back if it could
/// not be sent.
///
/// This method consumes `self` as only one value may ever be sent on a `oneshot`
/// channel. It is not marked async because sending a message to an `oneshot`
/// channel never requires any form of waiting. Because of this, the `send`
/// method can be used in both synchronous and asynchronous code without
/// problems.
fn send(self, t: T) -> Result<(), T>;
}

impl<T: OptionalSend> AsyncOneshotSendExt<T> for TokioOneShotSender<T> {
#[inline]
fn send(self, t: T) -> Result<(), T> {
self.0.send(t)
}
}

impl<T: OptionalSend> Debug for TokioOneShotSender<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_tuple("TokioSendWrapper").finish()
}
}
35 changes: 20 additions & 15 deletions openraft/src/core/raft_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ use futures::TryFutureExt;
use maplit::btreeset;
use tokio::select;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::watch;
use tracing::Instrument;
use tracing::Level;
use tracing::Span;

use crate::async_runtime::AsyncOneshotSendExt;
use crate::config::Config;
use crate::config::RuntimeConfig;
use crate::core::balancer::Balancer;
Expand Down Expand Up @@ -215,7 +215,10 @@ where
SM: RaftStateMachine<C>,
{
/// The main loop of the Raft protocol.
pub(crate) async fn main(mut self, rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal<C::NodeId>> {
pub(crate) async fn main(
mut self,
rx_shutdown: <C::AsyncRuntime as AsyncRuntime>::OneshotReceiver<()>,
) -> Result<(), Fatal<C::NodeId>> {
let span = tracing::span!(parent: &self.span, Level::DEBUG, "main");
let res = self.do_main(rx_shutdown).instrument(span).await;

Expand All @@ -239,7 +242,10 @@ where
}

#[tracing::instrument(level="trace", skip_all, fields(id=display(self.id), cluster=%self.config.cluster_name))]
async fn do_main(&mut self, rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal<C::NodeId>> {
async fn do_main(
&mut self,
rx_shutdown: <C::AsyncRuntime as AsyncRuntime>::OneshotReceiver<()>,
) -> Result<(), Fatal<C::NodeId>> {
tracing::debug!("raft node is initializing");

self.engine.startup();
Expand Down Expand Up @@ -432,7 +438,7 @@ where
&mut self,
changes: ChangeMembers<C::NodeId, C::Node>,
retain: bool,
tx: ResultSender<ClientWriteResponse<C>, ClientWriteError<C::NodeId, C::Node>>,
tx: ResultSender<C, ClientWriteResponse<C>, ClientWriteError<C::NodeId, C::Node>>,
) {
let res = self.engine.state.membership_state.change_handler().apply(changes, retain);
let new_membership = match res {
Expand Down Expand Up @@ -593,7 +599,7 @@ where
pub(crate) fn handle_initialize(
&mut self,
member_nodes: BTreeMap<C::NodeId, C::Node>,
tx: ResultSender<(), InitializeError<C::NodeId, C::Node>>,
tx: ResultSender<C, (), InitializeError<C::NodeId, C::Node>>,
) {
tracing::debug!(member_nodes = debug(&member_nodes), "{}", func_name!());

Expand All @@ -616,7 +622,7 @@ where

/// Reject a request due to the Raft node being in a state which prohibits the request.
#[tracing::instrument(level = "trace", skip(self, tx))]
pub(crate) fn reject_with_forward_to_leader<T, E>(&self, tx: ResultSender<T, E>)
pub(crate) fn reject_with_forward_to_leader<T: OptionalSend, E: OptionalSend>(&self, tx: ResultSender<C, T, E>)
where E: From<ForwardToLeader<C::NodeId, C::Node>> {
let mut leader_id = self.current_leader();
let leader_node = self.get_leader_node(leader_id);
Expand Down Expand Up @@ -680,7 +686,7 @@ where
{
tracing::debug!("append_to_log");

let (tx, rx) = oneshot::channel();
let (tx, rx) = C::AsyncRuntime::oneshot();
let callback = LogFlushed::new(Some(last_log_id), tx);
self.log_store.append(entries, callback).await?;
rx.await
Expand Down Expand Up @@ -865,7 +871,10 @@ where

/// Run an event handling loop
#[tracing::instrument(level="debug", skip_all, fields(id=display(self.id)))]
async fn runtime_loop(&mut self, mut rx_shutdown: oneshot::Receiver<()>) -> Result<(), Fatal<C::NodeId>> {
async fn runtime_loop(
&mut self,
mut rx_shutdown: <C::AsyncRuntime as AsyncRuntime>::OneshotReceiver<()>,
) -> Result<(), Fatal<C::NodeId>> {
// Ratio control the ratio of number of RaftMsg to process to number of Notify to process.
let mut balancer = Balancer::new(10_000);

Expand Down Expand Up @@ -1067,7 +1076,7 @@ where
}

#[tracing::instrument(level = "debug", skip_all)]
pub(super) fn handle_vote_request(&mut self, req: VoteRequest<C::NodeId>, tx: VoteTx<C::NodeId>) {
pub(super) fn handle_vote_request(&mut self, req: VoteRequest<C::NodeId>, tx: VoteTx<C>) {
tracing::info!(req = display(req.summary()), func = func_name!());

let resp = self.engine.handle_vote_req(req);
Expand All @@ -1078,11 +1087,7 @@ where
}

#[tracing::instrument(level = "debug", skip_all)]
pub(super) fn handle_append_entries_request(
&mut self,
req: AppendEntriesRequest<C>,
tx: AppendEntriesTx<C::NodeId>,
) {
pub(super) fn handle_append_entries_request(&mut self, req: AppendEntriesRequest<C>, tx: AppendEntriesTx<C>) {
tracing::debug!(req = display(req.summary()), func = func_name!());

let is_ok = self.engine.handle_append_entries(&req.vote, req.prev_log_id, req.entries, Some(tx));
Expand Down Expand Up @@ -1657,7 +1662,7 @@ where

// Create a channel to let state machine worker to send the snapshot and the replication
// worker to receive it.
let (tx, rx) = oneshot::channel();
let (tx, rx) = C::AsyncRuntime::oneshot();

let cmd = sm::Command::get_snapshot(tx);
self.sm_handle
Expand Down
2 changes: 1 addition & 1 deletion openraft/src/core/raft_msg/external_command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub(crate) enum ExternalCommand<C: RaftTypeConfig> {
Snapshot,

/// Get a snapshot from the state machine, send back via a oneshot::Sender.
GetSnapshot { tx: ResultSender<Option<Snapshot<C>>> },
GetSnapshot { tx: ResultSender<C, Option<Snapshot<C>>> },

/// Purge logs covered by a snapshot up to a specified index.
///
Expand Down
Loading
Loading