From 7d4ebdcb7a9120fc08c8e610ef15c815ac64f914 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E7=82=8E=E6=B3=BC?= Date: Thu, 23 Nov 2023 21:42:06 +0800 Subject: [PATCH] Refactor: testing fixture: add RPC pre-hook to let user define what to do before sending an RPC --- tests/tests/fixtures/mod.rs | 77 ++++++++++++++++--- .../replication/t50_append_entries_backoff.rs | 17 +++- 2 files changed, 81 insertions(+), 13 deletions(-) diff --git a/tests/tests/fixtures/mod.rs b/tests/tests/fixtures/mod.rs index 4f2bf35a8..26fe9c512 100644 --- a/tests/tests/fixtures/mod.rs +++ b/tests/tests/fixtures/mod.rs @@ -23,6 +23,7 @@ use maplit::btreeset; use openraft::async_trait::async_trait; use openraft::error::CheckIsLeaderError; use openraft::error::ClientWriteError; +use openraft::error::Infallible; use openraft::error::InstallSnapshotError; use openraft::error::NetworkError; use openraft::error::RPCError; @@ -140,7 +141,7 @@ pub fn log_panic(panic: &PanicInfo) { #[derive(Debug, Clone, Copy)] #[derive(PartialEq, Eq)] #[derive(Hash)] -enum Direction { +pub enum Direction { NetSend, NetRecv, } @@ -158,7 +159,7 @@ use Direction::NetRecv; use Direction::NetSend; #[derive(Debug, Clone, Copy)] -enum RPCErrorType { +pub enum RPCErrorType { /// Returns [`Unreachable`](`openraft::error::Unreachable`). Unreachable, /// Returns [`NetworkError`](`openraft::error::NetworkError`). @@ -181,6 +182,11 @@ impl RPCErrorType { } } +/// Pre-hook result, which does not return remote Error. +pub type PreHookResult = Result<(), RPCError>; + +pub type RPCPreHook = Box PreHookResult + Send + 'static>; + /// A type which emulates a network transport and implements the `RaftNetworkFactory` trait. #[derive(Clone)] pub struct TypedRaftRouter { @@ -207,6 +213,9 @@ pub struct TypedRaftRouter { /// Count of RPCs sent. rpc_count: Arc>>, + + /// A hook function to be called when before an RPC is sent to target node. + rpc_pre_hook: Arc>>, } /// Default `RaftRouter` for memstore. @@ -241,6 +250,7 @@ impl Builder { send_delay: Arc::new(AtomicU64::new(send_delay)), append_entries_quota: Arc::new(Mutex::new(None)), rpc_count: Default::default(), + rpc_pre_hook: Default::default(), } } } @@ -462,7 +472,7 @@ impl TypedRaftRouter { } /// Set whether to emit a specified rpc error when sending to/receiving from a node. - fn set_rpc_failure(&self, id: MemNodeId, dir: Direction, rpc_error_type: Option) { + pub fn set_rpc_failure(&self, id: MemNodeId, dir: Direction, rpc_error_type: Option) { let mut fails = self.fail_rpc.lock().unwrap(); if let Some(rpc_error_type) = rpc_error_type { fails.insert((id, dir), rpc_error_type); @@ -471,6 +481,49 @@ impl TypedRaftRouter { } } + /// Set a hook function to be called when before an RPC is sent to target node. + pub fn set_rpc_pre_hook(&self, rpc_type: RPCTypes, hook: Option) { + let mut rpc_pre_hook = self.rpc_pre_hook.lock().unwrap(); + if let Some(hook) = hook { + rpc_pre_hook.insert(rpc_type, hook); + } else { + rpc_pre_hook.remove(&rpc_type); + } + } + + /// Call pre-hook before an RPC is sent. + fn call_rpc_pre_hook( + &self, + rpc_type: RPCTypes, + from: MemNodeId, + to: MemNodeId, + ) -> Result<(), RPCError> + where + E: std::error::Error, + { + let rpc_pre_hook = self.rpc_pre_hook.lock().unwrap(); + if let Some(hook) = rpc_pre_hook.get(&rpc_type) { + let res = hook(self, rpc_type, from, to); + match res { + Ok(()) => Ok(()), + Err(err) => { + // The pre-hook should only return RPCError variants + let rpc_err = match err { + RPCError::Timeout(e) => e.into(), + RPCError::Unreachable(e) => e.into(), + RPCError::Network(e) => e.into(), + RPCError::RemoteError(e) => { + unreachable!("unexpected RemoteError: {:?}", e); + } + }; + Err(rpc_err) + } + } + } else { + Ok(()) + } + } + /// Get a payload of the latest metrics from each node in the cluster. #[allow(clippy::significant_drop_in_scrutinee)] pub fn latest_metrics(&self) -> Vec> { @@ -889,12 +942,12 @@ impl RaftNetwork for RaftRouterNetwork { &mut self, mut rpc: AppendEntriesRequest, ) -> Result, RPCError>> { + let from_id = rpc.vote.leader_id().voted_for().unwrap(); + tracing::debug!("append_entries to id={} {}", self.target, rpc.summary()); self.owner.count_rpc(RPCTypes::AppendEntries); - - let from_id = rpc.vote.leader_id().voted_for().unwrap(); + self.owner.call_rpc_pre_hook(RPCTypes::AppendEntries, from_id, self.target)?; self.owner.emit_rpc_error(from_id, self.target)?; - self.owner.rand_send_delay().await; // decrease quota if quota is set @@ -953,11 +1006,11 @@ impl RaftNetwork for RaftRouterNetwork { rpc: InstallSnapshotRequest, ) -> Result, RPCError>> { - self.owner.count_rpc(RPCTypes::InstallSnapshot); - let from_id = rpc.vote.leader_id().voted_for().unwrap(); - self.owner.emit_rpc_error(from_id, self.target)?; + self.owner.count_rpc(RPCTypes::InstallSnapshot); + self.owner.call_rpc_pre_hook(RPCTypes::InstallSnapshot, from_id, self.target)?; + self.owner.emit_rpc_error(from_id, self.target)?; self.owner.rand_send_delay().await; let node = self.owner.get_raft_handle(&self.target)?; @@ -973,11 +1026,11 @@ impl RaftNetwork for RaftRouterNetwork { &mut self, rpc: VoteRequest, ) -> Result, RPCError>> { - self.owner.count_rpc(RPCTypes::Vote); - let from_id = rpc.vote.leader_id().voted_for().unwrap(); - self.owner.emit_rpc_error(from_id, self.target)?; + self.owner.count_rpc(RPCTypes::Vote); + self.owner.call_rpc_pre_hook(RPCTypes::Vote, from_id, self.target)?; + self.owner.emit_rpc_error(from_id, self.target)?; self.owner.rand_send_delay().await; let node = self.owner.get_raft_handle(&self.target)?; diff --git a/tests/tests/replication/t50_append_entries_backoff.rs b/tests/tests/replication/t50_append_entries_backoff.rs index 9e6da0ee8..4dfc4e1a8 100644 --- a/tests/tests/replication/t50_append_entries_backoff.rs +++ b/tests/tests/replication/t50_append_entries_backoff.rs @@ -1,8 +1,11 @@ use std::sync::Arc; use std::time::Duration; +use anyerror::AnyError; use anyhow::Result; use maplit::btreeset; +use openraft::error::RPCError; +use openraft::error::Unreachable; use openraft::Config; use openraft::RPCTypes; @@ -32,7 +35,19 @@ async fn append_entries_backoff() -> Result<()> { tracing::info!(log_index, "--- set node 2 to unreachable, and write 10 entries"); { - router.set_unreachable(2, true); + router.set_rpc_pre_hook( + RPCTypes::AppendEntries, + Some(Box::new(|_router, _t, _id, target| { + if target == 2 { + let any_err = AnyError::error("unreachable"); + Err(RPCError::Unreachable(Unreachable::new(&any_err))) + } else { + Ok(()) + } + })), + ); + // The above is equivalent to the following: + // router.set_unreachable(2, true); router.client_request_many(0, "0", n as usize).await?; log_index += n;