Skip to content

Commit

Permalink
Refactor: gate tokio rt with feature tokio-rt
Browse files Browse the repository at this point in the history
  • Loading branch information
SteveLauC committed Jul 28, 2024
1 parent 590d943 commit d0ae564
Show file tree
Hide file tree
Showing 13 changed files with 59 additions and 23 deletions.
10 changes: 9 additions & 1 deletion openraft/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ serde = { workspace = true, optional = true }
serde_json = { workspace = true, optional = true }
tempfile = { workspace = true, optional = true }
thiserror = { workspace = true }
tokio = { workspace = true }
tokio = { workspace = true, optional = true }
tracing = { workspace = true }
tracing-futures = { workspace = true }
validit = { workspace = true }

or07 = { package = "openraft", version = "0.7.4", optional = true }

async-lock = "3.4.0"

[dev-dependencies]
anyhow = { workspace = true }
async-entry = { workspace = true }
Expand All @@ -44,6 +46,10 @@ serde_json = { workspace = true }


[features]
default = ["tokio-rt"]

# Enable the default Tokio runtime
tokio-rt = ["dep:tokio"]

# Enables benchmarks in unittest.
#
Expand Down Expand Up @@ -113,6 +119,8 @@ features = [
"tracing-log",
]

no-default-features = false

# Do not use this to enable all features:
# "singlethreaded" makes `Raft<C>` a `!Send`, which confuses users.
# all-features = true
Expand Down
20 changes: 9 additions & 11 deletions openraft/src/core/raft_core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ use std::time::Duration;

use anyerror::AnyError;
use futures::stream::FuturesUnordered;
use futures::FutureExt;
use futures::StreamExt;
use futures::TryFutureExt;
use maplit::btreeset;
use tokio::select;
use tracing::Instrument;
use tracing::Level;
use tracing::Span;
Expand Down Expand Up @@ -910,19 +910,17 @@ where
// In each loop, the first step is blocking waiting for any message from any channel.
// Then if there is any message, process as many as possible to maximize throughput.

select! {
// Check shutdown in each loop first so that a message flood in `tx_api` won't block shutting down.
// `select!` without `biased` provides a random fairness.
// We want to check shutdown prior to other channels.
// See: https://docs.rs/tokio/latest/tokio/macro.select.html#fairness
biased;

_ = &mut rx_shutdown => {
// Check shutdown in each loop first so that a message flood in `tx_api` won't block shutting down.
// `select!` without `biased` provides a random fairness.
// We want to check shutdown prior to other channels.
// See: https://docs.rs/tokio/latest/tokio/macro.select.html#fairness
futures::select_biased! {
_ = (&mut rx_shutdown).fuse() => {
tracing::info!("recv from rx_shutdown");
return Err(Fatal::Stopped);
}

notify_res = self.rx_notification.recv() => {
notify_res = self.rx_notification.recv().fuse() => {
match notify_res {
Some(notify) => self.handle_notification(notify)?,
None => {
Expand All @@ -932,7 +930,7 @@ where
};
}

msg_res = self.rx_api.recv() => {
msg_res = self.rx_api.recv().fuse() => {
match msg_res {
Some(msg) => self.handle_api_msg(msg).await,
None => {
Expand Down
1 change: 1 addition & 0 deletions openraft/src/engine/log_id_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ where C: RaftTypeConfig
self.key_log_ids.last()
}

#[cfg(feature = "tokio-rt")]
pub(crate) fn key_log_ids(&self) -> &[LogId<C::NodeId>] {
&self.key_log_ids
}
Expand Down
1 change: 1 addition & 0 deletions openraft/src/impls/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ pub use crate::entry::Entry;
pub use crate::node::BasicNode;
pub use crate::node::EmptyNode;
pub use crate::raft::responder::impls::OneshotResponder;
#[cfg(feature = "tokio-rt")]
pub use crate::type_config::async_runtime::impls::TokioRuntime;
2 changes: 2 additions & 0 deletions openraft/src/instant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@ pub trait Instant:
}
}

#[cfg(feature = "tokio-rt")]
pub type TokioInstant = tokio::time::Instant;

#[cfg(feature = "tokio-rt")]
impl Instant for tokio::time::Instant {
#[inline]
fn now() -> Self {
Expand Down
3 changes: 3 additions & 0 deletions openraft/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ pub mod metrics;
pub mod network;
pub mod raft;
pub mod storage;
#[cfg(feature = "tokio-rt")]
pub mod testing;
pub mod type_config;

Expand All @@ -71,6 +72,7 @@ pub use anyerror;
pub use anyerror::AnyError;
pub use openraft_macros::add_async_trait;
pub use type_config::async_runtime;
#[cfg(feature = "tokio-rt")]
pub use type_config::async_runtime::impls::TokioRuntime;
pub use type_config::AsyncRuntime;

Expand All @@ -82,6 +84,7 @@ pub use crate::core::ServerState;
pub use crate::entry::Entry;
pub use crate::entry::EntryPayload;
pub use crate::instant::Instant;
#[cfg(feature = "tokio-rt")]
pub use crate::instant::TokioInstant;
pub use crate::log_id::LogId;
pub use crate::log_id::LogIdOptionExt;
Expand Down
8 changes: 5 additions & 3 deletions openraft/src/metrics/wait.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use core::time::Duration;
use std::collections::BTreeSet;

use futures::FutureExt;

use crate::async_runtime::watch::WatchReceiver;
use crate::core::ServerState;
use crate::metrics::Condition;
Expand Down Expand Up @@ -62,12 +64,12 @@ where C: RaftTypeConfig
tracing::debug!(?sleep_time, "wait timeout");
let delay = C::sleep(sleep_time);

tokio::select! {
_ = delay => {
futures::select_biased! {
_ = delay.fuse() => {
tracing::debug!( "id={} timeout wait {:} latest: {}", latest.id, msg.to_string(), latest );
return Err(WaitError::Timeout(self.timeout, format!("{} latest: {}", msg.to_string(), latest)));
}
changed = rx.changed() => {
changed = rx.changed().fuse() => {
match changed {
Ok(_) => {
// metrics changed, continue the waiting loop
Expand Down
16 changes: 16 additions & 0 deletions openraft/src/network/snapshot_transport.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,46 @@
//! AsyncWrite + AsyncRead + AsyncSeek + Unpin.
use std::future::Future;
#[cfg(feature = "tokio-rt")]
use std::io::SeekFrom;
#[cfg(feature = "tokio-rt")]
use std::time::Duration;

#[cfg(feature = "tokio-rt")]
use futures::FutureExt;
use openraft_macros::add_async_trait;
#[cfg(feature = "tokio-rt")]
use tokio::io::AsyncReadExt;
#[cfg(feature = "tokio-rt")]
use tokio::io::AsyncSeekExt;
#[cfg(feature = "tokio-rt")]
use tokio::io::AsyncWriteExt;

use crate::error::Fatal;
use crate::error::InstallSnapshotError;
#[cfg(feature = "tokio-rt")]
use crate::error::RPCError;
use crate::error::RaftError;
use crate::error::ReplicationClosed;
use crate::error::StreamingError;
use crate::network::RPCOption;
use crate::raft::InstallSnapshotRequest;
use crate::raft::SnapshotResponse;
#[cfg(feature = "tokio-rt")]
use crate::type_config::TypeConfigExt;
#[cfg(feature = "tokio-rt")]
use crate::ErrorSubject;
#[cfg(feature = "tokio-rt")]
use crate::ErrorVerb;
use crate::OptionalSend;
use crate::Raft;
use crate::RaftNetwork;
use crate::RaftTypeConfig;
use crate::Snapshot;
use crate::SnapshotId;
#[cfg(feature = "tokio-rt")]
use crate::StorageError;
#[cfg(feature = "tokio-rt")]
use crate::ToStorageResult;
use crate::Vote;

Expand Down Expand Up @@ -94,6 +106,7 @@ pub trait SnapshotTransport<C: RaftTypeConfig> {
pub struct Chunked {}

/// This chunk based implementation requires `SnapshotData` to be `AsyncRead + AsyncSeek`.
#[cfg(feature = "tokio-rt")]
impl<C: RaftTypeConfig> SnapshotTransport<C> for Chunked
where C::SnapshotData: tokio::io::AsyncRead + tokio::io::AsyncWrite + tokio::io::AsyncSeek + Unpin
{
Expand Down Expand Up @@ -282,6 +295,8 @@ pub struct Streaming<C>
where C: RaftTypeConfig
{
/// The offset of the last byte written to the snapshot.
#[cfg_attr(not(feature = "tokio-rt"), allow(dead_code))]
// This field will only be read when feature tokio-rt is on
offset: u64,

/// The ID of the snapshot being written.
Expand Down Expand Up @@ -312,6 +327,7 @@ where C: RaftTypeConfig
}
}

#[cfg(feature = "tokio-rt")]
impl<C> Streaming<C>
where
C: RaftTypeConfig,
Expand Down
1 change: 1 addition & 0 deletions openraft/src/network/v2/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#[cfg(feature = "tokio-rt")]
mod adapt_v1;
mod network;

Expand Down
5 changes: 3 additions & 2 deletions openraft/src/raft/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use std::future::Future;
use std::sync::Arc;
use std::time::Duration;

use async_lock::Mutex;
use core_state::CoreState;
pub use message::AppendEntriesRequest;
pub use message::AppendEntriesResponse;
Expand All @@ -40,7 +41,6 @@ pub use message::InstallSnapshotResponse;
pub use message::SnapshotResponse;
pub use message::VoteRequest;
pub use message::VoteResponse;
use tokio::sync::Mutex;
use tracing::trace_span;
use tracing::Instrument;
use tracing::Level;
Expand Down Expand Up @@ -168,7 +168,7 @@ macro_rules! declare_raft_types {
(NodeId , , u64 ),
(Node , , $crate::impls::BasicNode ),
(Entry , , $crate::impls::Entry<Self> ),
(SnapshotData , , Cursor<Vec<u8>> ),
(SnapshotData , , std::io::Cursor<Vec<u8>> ),
(Responder , , $crate::impls::OneshotResponder<Self> ),
(AsyncRuntime , , $crate::impls::TokioRuntime ),
);
Expand Down Expand Up @@ -437,6 +437,7 @@ where C: RaftTypeConfig
/// If receiving is finished `done == true`, it installs the snapshot to the state machine.
/// Nothing will be done if the input snapshot is older than the state machine.
#[tracing::instrument(level = "debug", skip_all)]
#[cfg(feature = "tokio-rt")]
pub async fn install_snapshot(
&self,
req: InstallSnapshotRequest<C>,
Expand Down
4 changes: 3 additions & 1 deletion openraft/src/raft/raft_inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::fmt::Debug;
use std::future::Future;
use std::sync::Arc;

use tokio::sync::Mutex;
use async_lock::Mutex;
use tracing::Level;

use crate::async_runtime::MpscUnboundedSender;
Expand Down Expand Up @@ -46,6 +46,8 @@ where C: RaftTypeConfig
pub(in crate::raft) core_state: Mutex<CoreState<C>>,

/// The ongoing snapshot transmission.
#[cfg_attr(not(feature = "tokio-rt"), allow(dead_code))]
// This field will only be read when feature tokio-rt is on
pub(in crate::raft) snapshot: Mutex<Option<crate::network::snapshot_transport::Streaming<C>>>,
}

Expand Down
9 changes: 4 additions & 5 deletions openraft/src/replication/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,14 @@ use std::sync::Arc;
use std::time::Duration;

use anyerror::AnyError;
use async_lock::Mutex;
use futures::future::FutureExt;
pub(crate) use replication_session_id::ReplicationSessionId;
use request::Data;
use request::DataWithId;
use request::Replicate;
use response::ReplicationResult;
pub(crate) use response::Response;
use tokio::select;
use tokio::sync::Mutex;
use tracing_futures::Instrument;

use crate::async_runtime::MpscUnboundedReceiver;
Expand Down Expand Up @@ -587,12 +586,12 @@ where

tracing::debug!("backoff timeout: {:?}", sleep_duration);

select! {
_ = sleep => {
futures::select! {
_ = sleep.fuse() => {
tracing::debug!("backoff timeout");
return Ok(());
}
recv_res = recv => {
recv_res = recv.fuse() => {
let event = recv_res.ok_or(ReplicationClosed::new("RaftCore closed replication"))?;
self.process_event(event);
}
Expand Down
2 changes: 2 additions & 0 deletions openraft/src/type_config/async_runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
//! `async-std`, etc.
pub(crate) mod impls {
#[cfg(feature = "tokio-rt")]
mod tokio_runtime;

#[cfg(feature = "tokio-rt")]
pub use tokio_runtime::TokioRuntime;
}
pub mod mpsc_unbounded;
Expand Down

0 comments on commit d0ae564

Please sign in to comment.