Skip to content

Commit

Permalink
Refactor: pre-hook callback within tests/fixtures
Browse files Browse the repository at this point in the history
Simplify the API to set pre-hook callback.

The first argument is changed from `RPCTypes` to real RPC data, enabling
the hook callback to interact directly with the RPC request.
  • Loading branch information
drmingdrmer committed Nov 24, 2023
1 parent 20832ee commit d32d977
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 19 deletions.
1 change: 1 addition & 0 deletions tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ openraft-memstore = { path="../memstore" }
anyerror = { workspace = true }
anyhow = { workspace = true }
async-entry = { workspace = true }
derive_more = { workspace = true }
futures = { workspace = true }
lazy_static = { workspace = true }
maplit = { workspace = true }
Expand Down
47 changes: 39 additions & 8 deletions tests/tests/fixtures/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ use openraft::Raft;
use openraft::RaftLogId;
use openraft::RaftMetrics;
use openraft::RaftState;
use openraft::RaftTypeConfig;
use openraft::ServerState;
use openraft::TokioInstant;
use openraft::TokioRuntime;
Expand All @@ -63,6 +64,7 @@ use openraft_memstore::ClientResponse;
use openraft_memstore::IntoMemClientRequest;
use openraft_memstore::MemNodeId;
use openraft_memstore::MemStore;
use openraft_memstore::TypeConfig;
use openraft_memstore::TypeConfig as MemConfig;
#[allow(unused_imports)] use pretty_assertions::assert_eq;
#[allow(unused_imports)] use pretty_assertions::assert_ne;
Expand Down Expand Up @@ -185,7 +187,26 @@ impl RPCErrorType {
/// Pre-hook result, which does not return remote Error.
pub type PreHookResult = Result<(), RPCError<MemNodeId, (), Infallible>>;

pub type RPCPreHook = Box<dyn Fn(&TypedRaftRouter, RPCTypes, MemNodeId, MemNodeId) -> PreHookResult + Send + 'static>;
#[derive(derive_more::From, derive_more::TryInto)]
pub enum RPCRequest<C: RaftTypeConfig> {
AppendEntries(AppendEntriesRequest<C>),
InstallSnapshot(InstallSnapshotRequest<C>),
Vote(VoteRequest<C::NodeId>),
}

impl<C: RaftTypeConfig> RPCRequest<C> {
pub fn get_type(&self) -> RPCTypes {
match self {
RPCRequest::AppendEntries(_) => RPCTypes::AppendEntries,
RPCRequest::InstallSnapshot(_) => RPCTypes::InstallSnapshot,
RPCRequest::Vote(_) => RPCTypes::Vote,
}
}
}

/// Arguments: `(router, rpc, from_id, to_id)`
pub type RPCPreHook =
Box<dyn Fn(&TypedRaftRouter, RPCRequest<TypeConfig>, MemNodeId, MemNodeId) -> PreHookResult + Send + 'static>;

/// A type which emulates a network transport and implements the `RaftNetworkFactory` trait.
#[derive(Clone)]
Expand Down Expand Up @@ -482,7 +503,13 @@ impl TypedRaftRouter {
}

/// Set a hook function to be called when before an RPC is sent to target node.
pub fn set_rpc_pre_hook(&self, rpc_type: RPCTypes, hook: Option<RPCPreHook>) {
pub fn set_rpc_pre_hook<F>(&self, rpc_type: RPCTypes, hook: F)
where F: Fn(&TypedRaftRouter, RPCRequest<TypeConfig>, MemNodeId, MemNodeId) -> PreHookResult + Send + 'static {
self.rpc_pre_hook(rpc_type, Some(Box::new(hook)));
}

/// Set or unset a hook function to be called when before an RPC is sent to target node.
pub fn rpc_pre_hook(&self, rpc_type: RPCTypes, hook: Option<RPCPreHook>) {
let mut rpc_pre_hook = self.rpc_pre_hook.lock().unwrap();
if let Some(hook) = hook {
rpc_pre_hook.insert(rpc_type, hook);
Expand All @@ -494,16 +521,20 @@ impl TypedRaftRouter {
/// Call pre-hook before an RPC is sent.
fn call_rpc_pre_hook<E>(
&self,
rpc_type: RPCTypes,
request: impl Into<RPCRequest<TypeConfig>>,
from: MemNodeId,
to: MemNodeId,
) -> Result<(), RPCError<MemNodeId, (), E>>
where
E: std::error::Error,
{
let request = request.into();
let typ = request.get_type();

let rpc_pre_hook = self.rpc_pre_hook.lock().unwrap();
if let Some(hook) = rpc_pre_hook.get(&rpc_type) {
let res = hook(self, rpc_type, from, to);

if let Some(hook) = rpc_pre_hook.get(&typ) {
let res = hook(self, request, from, to);
match res {
Ok(()) => Ok(()),
Err(err) => {
Expand Down Expand Up @@ -946,7 +977,7 @@ impl RaftNetwork<MemConfig> for RaftRouterNetwork {

tracing::debug!("append_entries to id={} {}", self.target, rpc.summary());
self.owner.count_rpc(RPCTypes::AppendEntries);
self.owner.call_rpc_pre_hook(RPCTypes::AppendEntries, from_id, self.target)?;
self.owner.call_rpc_pre_hook(rpc.clone(), from_id, self.target)?;
self.owner.emit_rpc_error(from_id, self.target)?;
self.owner.rand_send_delay().await;

Expand Down Expand Up @@ -1009,7 +1040,7 @@ impl RaftNetwork<MemConfig> for RaftRouterNetwork {
let from_id = rpc.vote.leader_id().voted_for().unwrap();

self.owner.count_rpc(RPCTypes::InstallSnapshot);
self.owner.call_rpc_pre_hook(RPCTypes::InstallSnapshot, from_id, self.target)?;
self.owner.call_rpc_pre_hook(rpc.clone(), from_id, self.target)?;
self.owner.emit_rpc_error(from_id, self.target)?;
self.owner.rand_send_delay().await;

Expand All @@ -1029,7 +1060,7 @@ impl RaftNetwork<MemConfig> for RaftRouterNetwork {
let from_id = rpc.vote.leader_id().voted_for().unwrap();

self.owner.count_rpc(RPCTypes::Vote);
self.owner.call_rpc_pre_hook(RPCTypes::Vote, from_id, self.target)?;
self.owner.call_rpc_pre_hook(rpc.clone(), from_id, self.target)?;
self.owner.emit_rpc_error(from_id, self.target)?;
self.owner.rand_send_delay().await;

Expand Down
19 changes: 8 additions & 11 deletions tests/tests/replication/t50_append_entries_backoff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,14 @@ async fn append_entries_backoff() -> Result<()> {

tracing::info!(log_index, "--- set node 2 to unreachable, and write 10 entries");
{
router.set_rpc_pre_hook(
RPCTypes::AppendEntries,
Some(Box::new(|_router, _t, _id, target| {
if target == 2 {
let any_err = AnyError::error("unreachable");
Err(RPCError::Unreachable(Unreachable::new(&any_err)))
} else {
Ok(())
}
})),
);
router.set_rpc_pre_hook(RPCTypes::AppendEntries, |_router, _req, _id, target| {
if target == 2 {
let any_err = AnyError::error("unreachable");
Err(RPCError::Unreachable(Unreachable::new(&any_err)))
} else {
Ok(())
}
});
// The above is equivalent to the following:
// router.set_unreachable(2, true);

Expand Down

0 comments on commit d32d977

Please sign in to comment.