From ee3b94216b1964100d6c83a0335072adb589da18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E7=82=8E=E6=B3=BC?= Date: Wed, 6 Dec 2023 11:26:29 +0800 Subject: [PATCH] Feature: add `Raft::with_raft_state()` to access `RaftState` with a function This new method serves as a convenience wrapper around `Raft::external_request()`, streamlining use cases where only a single value needs to be returned. Thanks to @tvsfx --- openraft/src/raft/mod.rs | 41 +++++++++++++++++++ tests/tests/client_api/main.rs | 1 + tests/tests/client_api/t16_with_raft_state.rs | 40 ++++++++++++++++++ tests/tests/fixtures/mod.rs | 18 +++++--- tests/tests/life_cycle/t10_initialization.rs | 6 +-- 5 files changed, 97 insertions(+), 9 deletions(-) create mode 100644 tests/tests/client_api/t16_with_raft_state.rs diff --git a/openraft/src/raft/mod.rs b/openraft/src/raft/mod.rs index f585db347..01294f388 100644 --- a/openraft/src/raft/mod.rs +++ b/openraft/src/raft/mod.rs @@ -666,6 +666,47 @@ where C: RaftTypeConfig } } + /// Provides read-only access to [`RaftState`] through a user-provided function. + /// + /// The function `func` is applied to the current [`RaftState`]. The result of this function, + /// of type `V`, is returned wrapped in `Result>`. `Fatal` error will be + /// returned if failed to receive a reply from `RaftCore`. + /// + /// A `Fatal` error is returned if: + /// - Raft core task is stopped normally. + /// - Raft core task is panicked due to programming error. + /// - Raft core task is encountered a storage error. + /// + /// Example for getting the current committed log id: + /// ```ignore + /// let committed = my_raft.with_raft_state(|st| st.committed).await?; + /// ``` + pub async fn with_raft_state(&self, func: F) -> Result> + where + F: FnOnce(&RaftState::Instant>) -> V + Send + 'static, + V: Send + 'static, + { + let (tx, rx) = oneshot::channel(); + + self.external_request(|st| { + let result = func(st); + if let Err(_err) = tx.send(result) { + tracing::error!("{}: to-Raft tx send error", func_name!()); + } + }); + + match rx.await { + Ok(res) => Ok(res), + Err(err) => { + tracing::error!(error = display(&err), "{}: rx recv error", func_name!()); + + let when = format!("{}: rx recv", func_name!()); + let fatal = self.inner.get_core_stopped_error(when, None::).await; + Err(fatal) + } + } + } + /// Send a request to the Raft core loop in a fire-and-forget manner. /// /// The request functor will be called with a mutable reference to both the state machine diff --git a/tests/tests/client_api/main.rs b/tests/tests/client_api/main.rs index f77c67b18..27840277a 100644 --- a/tests/tests/client_api/main.rs +++ b/tests/tests/client_api/main.rs @@ -11,4 +11,5 @@ mod t10_client_writes; mod t11_client_reads; mod t12_trigger_purge_log; mod t13_trigger_snapshot; +mod t16_with_raft_state; mod t50_lagging_network_write; diff --git a/tests/tests/client_api/t16_with_raft_state.rs b/tests/tests/client_api/t16_with_raft_state.rs new file mode 100644 index 000000000..a1f900e85 --- /dev/null +++ b/tests/tests/client_api/t16_with_raft_state.rs @@ -0,0 +1,40 @@ +use std::sync::Arc; + +use anyhow::Result; +use maplit::btreeset; +use openraft::error::Fatal; +use openraft::testing::log_id; +use openraft::Config; + +use crate::fixtures::init_default_ut_tracing; +use crate::fixtures::RaftRouter; + +/// Access Raft state via `Raft::with_raft_state()` +#[async_entry::test(worker_threads = 8, init = "init_default_ut_tracing()", tracing_span = "debug")] +async fn with_raft_state() -> Result<()> { + let config = Arc::new( + Config { + enable_heartbeat: false, + ..Default::default() + } + .validate()?, + ); + + let mut router = RaftRouter::new(config.clone()); + + tracing::info!("--- initializing cluster"); + let log_index = router.new_cluster(btreeset! {0,1,2}, btreeset! {}).await?; + + let n0 = router.get_raft_handle(&0)?; + + let committed = n0.with_raft_state(|st| st.committed).await?; + assert_eq!(committed, Some(log_id(1, 0, log_index))); + + tracing::info!("--- shutting down node 0"); + n0.shutdown().await?; + + let res = n0.with_raft_state(|st| st.committed).await; + assert_eq!(Err(Fatal::Stopped), res); + + Ok(()) +} diff --git a/tests/tests/fixtures/mod.rs b/tests/tests/fixtures/mod.rs index 4e673698d..08953caa1 100644 --- a/tests/tests/fixtures/mod.rs +++ b/tests/tests/fixtures/mod.rs @@ -23,6 +23,7 @@ use maplit::btreeset; use openraft::async_trait::async_trait; use openraft::error::CheckIsLeaderError; use openraft::error::ClientWriteError; +use openraft::error::Fatal; use openraft::error::Infallible; use openraft::error::InstallSnapshotError; use openraft::error::NetworkError; @@ -763,17 +764,24 @@ impl TypedRaftRouter { ) } + /// Send external request to the particular node. + pub async fn with_raft_state(&self, target: MemNodeId, func: F) -> Result> + where + F: FnOnce(&RaftState) -> V + Send + 'static, + V: Send + 'static, + { + let r = self.get_raft_handle(&target).unwrap(); + r.with_raft_state(func).await + } + /// Send external request to the particular node. pub fn external_request) + Send + 'static>( &self, target: MemNodeId, req: F, ) { - let rt = self.nodes.lock().unwrap(); - rt.get(&target) - .unwrap_or_else(|| panic!("node '{}' does not exist in routing table", target)) - .0 - .external_request(req) + let r = self.get_raft_handle(&target).unwrap(); + r.external_request(req) } /// Request the current leader from the target node. diff --git a/tests/tests/life_cycle/t10_initialization.rs b/tests/tests/life_cycle/t10_initialization.rs index 5a9b0b43d..ae3f2fa68 100644 --- a/tests/tests/life_cycle/t10_initialization.rs +++ b/tests/tests/life_cycle/t10_initialization.rs @@ -18,7 +18,6 @@ use openraft::Membership; use openraft::ServerState; use openraft::StoredMembership; use openraft::Vote; -use tokio::sync::oneshot; use crate::fixtures::init_default_ut_tracing; use crate::fixtures::RaftRouter; @@ -146,9 +145,8 @@ async fn initialization() -> anyhow::Result<()> { let mut found_leader = false; let mut follower_count = 0; for node in [0, 1, 2] { - let (tx, rx) = oneshot::channel(); - router.external_request(node, |s| tx.send(s.server_state).unwrap()); - match rx.await.unwrap() { + let server_state = router.with_raft_state(node, |s| s.server_state).await?; + match server_state { ServerState::Leader => { assert!(!found_leader); found_leader = true;