Skip to content

Commit

Permalink
Refactor: testing fixture: add RPC pre-hook to let user define what t…
Browse files Browse the repository at this point in the history
…o do before sending an RPC
  • Loading branch information
drmingdrmer committed Nov 23, 2023
1 parent 6fdd794 commit 7d4ebdc
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 13 deletions.
77 changes: 65 additions & 12 deletions tests/tests/fixtures/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use maplit::btreeset;
use openraft::async_trait::async_trait;
use openraft::error::CheckIsLeaderError;
use openraft::error::ClientWriteError;
use openraft::error::Infallible;
use openraft::error::InstallSnapshotError;
use openraft::error::NetworkError;
use openraft::error::RPCError;
Expand Down Expand Up @@ -140,7 +141,7 @@ pub fn log_panic(panic: &PanicInfo) {
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
#[derive(Hash)]
enum Direction {
pub enum Direction {
NetSend,
NetRecv,
}
Expand All @@ -158,7 +159,7 @@ use Direction::NetRecv;
use Direction::NetSend;

#[derive(Debug, Clone, Copy)]
enum RPCErrorType {
pub enum RPCErrorType {
/// Returns [`Unreachable`](`openraft::error::Unreachable`).
Unreachable,
/// Returns [`NetworkError`](`openraft::error::NetworkError`).
Expand All @@ -181,6 +182,11 @@ impl RPCErrorType {
}
}

/// Pre-hook result, which does not return remote Error.
pub type PreHookResult = Result<(), RPCError<MemNodeId, (), Infallible>>;

pub type RPCPreHook = Box<dyn Fn(&TypedRaftRouter, RPCTypes, MemNodeId, MemNodeId) -> PreHookResult + Send + 'static>;

/// A type which emulates a network transport and implements the `RaftNetworkFactory` trait.
#[derive(Clone)]
pub struct TypedRaftRouter {
Expand All @@ -207,6 +213,9 @@ pub struct TypedRaftRouter {

/// Count of RPCs sent.
rpc_count: Arc<Mutex<HashMap<RPCTypes, u64>>>,

/// A hook function to be called when before an RPC is sent to target node.
rpc_pre_hook: Arc<Mutex<HashMap<RPCTypes, RPCPreHook>>>,
}

/// Default `RaftRouter` for memstore.
Expand Down Expand Up @@ -241,6 +250,7 @@ impl Builder {
send_delay: Arc::new(AtomicU64::new(send_delay)),
append_entries_quota: Arc::new(Mutex::new(None)),
rpc_count: Default::default(),
rpc_pre_hook: Default::default(),
}
}
}
Expand Down Expand Up @@ -462,7 +472,7 @@ impl TypedRaftRouter {
}

/// Set whether to emit a specified rpc error when sending to/receiving from a node.
fn set_rpc_failure(&self, id: MemNodeId, dir: Direction, rpc_error_type: Option<RPCErrorType>) {
pub fn set_rpc_failure(&self, id: MemNodeId, dir: Direction, rpc_error_type: Option<RPCErrorType>) {
let mut fails = self.fail_rpc.lock().unwrap();
if let Some(rpc_error_type) = rpc_error_type {
fails.insert((id, dir), rpc_error_type);
Expand All @@ -471,6 +481,49 @@ impl TypedRaftRouter {
}
}

/// Set a hook function to be called when before an RPC is sent to target node.
pub fn set_rpc_pre_hook(&self, rpc_type: RPCTypes, hook: Option<RPCPreHook>) {
let mut rpc_pre_hook = self.rpc_pre_hook.lock().unwrap();
if let Some(hook) = hook {
rpc_pre_hook.insert(rpc_type, hook);
} else {
rpc_pre_hook.remove(&rpc_type);
}
}

/// Call pre-hook before an RPC is sent.
fn call_rpc_pre_hook<E>(
&self,
rpc_type: RPCTypes,
from: MemNodeId,
to: MemNodeId,
) -> Result<(), RPCError<MemNodeId, (), E>>
where
E: std::error::Error,
{
let rpc_pre_hook = self.rpc_pre_hook.lock().unwrap();
if let Some(hook) = rpc_pre_hook.get(&rpc_type) {
let res = hook(self, rpc_type, from, to);
match res {
Ok(()) => Ok(()),
Err(err) => {
// The pre-hook should only return RPCError variants
let rpc_err = match err {
RPCError::Timeout(e) => e.into(),
RPCError::Unreachable(e) => e.into(),
RPCError::Network(e) => e.into(),
RPCError::RemoteError(e) => {
unreachable!("unexpected RemoteError: {:?}", e);
}
};
Err(rpc_err)
}
}
} else {
Ok(())
}
}

/// Get a payload of the latest metrics from each node in the cluster.
#[allow(clippy::significant_drop_in_scrutinee)]
pub fn latest_metrics(&self) -> Vec<RaftMetrics<MemNodeId, ()>> {
Expand Down Expand Up @@ -889,12 +942,12 @@ impl RaftNetwork<MemConfig> for RaftRouterNetwork {
&mut self,
mut rpc: AppendEntriesRequest<MemConfig>,
) -> Result<AppendEntriesResponse<MemNodeId>, RPCError<MemNodeId, (), RaftError<MemNodeId>>> {
let from_id = rpc.vote.leader_id().voted_for().unwrap();

tracing::debug!("append_entries to id={} {}", self.target, rpc.summary());
self.owner.count_rpc(RPCTypes::AppendEntries);

let from_id = rpc.vote.leader_id().voted_for().unwrap();
self.owner.call_rpc_pre_hook(RPCTypes::AppendEntries, from_id, self.target)?;
self.owner.emit_rpc_error(from_id, self.target)?;

self.owner.rand_send_delay().await;

// decrease quota if quota is set
Expand Down Expand Up @@ -953,11 +1006,11 @@ impl RaftNetwork<MemConfig> for RaftRouterNetwork {
rpc: InstallSnapshotRequest<MemConfig>,
) -> Result<InstallSnapshotResponse<MemNodeId>, RPCError<MemNodeId, (), RaftError<MemNodeId, InstallSnapshotError>>>
{
self.owner.count_rpc(RPCTypes::InstallSnapshot);

let from_id = rpc.vote.leader_id().voted_for().unwrap();
self.owner.emit_rpc_error(from_id, self.target)?;

self.owner.count_rpc(RPCTypes::InstallSnapshot);
self.owner.call_rpc_pre_hook(RPCTypes::InstallSnapshot, from_id, self.target)?;
self.owner.emit_rpc_error(from_id, self.target)?;
self.owner.rand_send_delay().await;

let node = self.owner.get_raft_handle(&self.target)?;
Expand All @@ -973,11 +1026,11 @@ impl RaftNetwork<MemConfig> for RaftRouterNetwork {
&mut self,
rpc: VoteRequest<MemNodeId>,
) -> Result<VoteResponse<MemNodeId>, RPCError<MemNodeId, (), RaftError<MemNodeId>>> {
self.owner.count_rpc(RPCTypes::Vote);

let from_id = rpc.vote.leader_id().voted_for().unwrap();
self.owner.emit_rpc_error(from_id, self.target)?;

self.owner.count_rpc(RPCTypes::Vote);
self.owner.call_rpc_pre_hook(RPCTypes::Vote, from_id, self.target)?;
self.owner.emit_rpc_error(from_id, self.target)?;
self.owner.rand_send_delay().await;

let node = self.owner.get_raft_handle(&self.target)?;
Expand Down
17 changes: 16 additions & 1 deletion tests/tests/replication/t50_append_entries_backoff.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use std::sync::Arc;
use std::time::Duration;

use anyerror::AnyError;
use anyhow::Result;
use maplit::btreeset;
use openraft::error::RPCError;
use openraft::error::Unreachable;
use openraft::Config;
use openraft::RPCTypes;

Expand Down Expand Up @@ -32,7 +35,19 @@ async fn append_entries_backoff() -> Result<()> {

tracing::info!(log_index, "--- set node 2 to unreachable, and write 10 entries");
{
router.set_unreachable(2, true);
router.set_rpc_pre_hook(
RPCTypes::AppendEntries,
Some(Box::new(|_router, _t, _id, target| {
if target == 2 {
let any_err = AnyError::error("unreachable");
Err(RPCError::Unreachable(Unreachable::new(&any_err)))
} else {
Ok(())
}
})),
);
// The above is equivalent to the following:
// router.set_unreachable(2, true);

router.client_request_many(0, "0", n as usize).await?;
log_index += n;
Expand Down

0 comments on commit 7d4ebdc

Please sign in to comment.