Skip to content

Commit

Permalink
Merge pull request #946 from drmingdrmer/29-refine-test-fixture
Browse files Browse the repository at this point in the history
Refactor: rpc error simulation in tests/
  • Loading branch information
drmingdrmer authored Nov 23, 2023
2 parents 29995d8 + 7525ea0 commit 751056e
Show file tree
Hide file tree
Showing 14 changed files with 132 additions and 103 deletions.
2 changes: 1 addition & 1 deletion openraft/src/network/rpc_type.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::fmt;

#[derive(Debug, Clone, PartialEq, Eq)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub enum RPCTypes {
Vote,
Expand Down
2 changes: 1 addition & 1 deletion tests/tests/append_entries/t11_append_inconsistent_log.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ async fn append_inconsistent_log() -> Result<()> {
);
{
router.new_raft_node_with_sto(1, sto1.clone(), sm1.clone()).await;
router.set_node_network_failure(1, true);
router.set_network_error(1, true);
}

tracing::info!(log_index, "--- restart node 0 and 2");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async fn replication_1_voter_to_isolated_learner() -> Result<()> {

tracing::info!(log_index, "--- stop replication to node 1");
{
router.set_node_network_failure(1, true);
router.set_network_error(1, true);

router.client_request_many(0, "0", (10 - log_index) as usize).await?;
log_index = 10;
Expand All @@ -47,7 +47,7 @@ async fn replication_1_voter_to_isolated_learner() -> Result<()> {

tracing::info!(log_index, "--- restore replication to node 1");
{
router.set_node_network_failure(1, false);
router.set_network_error(1, false);

router.client_request_many(0, "0", (10 - log_index) as usize).await?;
log_index = 10;
Expand Down
4 changes: 2 additions & 2 deletions tests/tests/client_api/t11_client_reads.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,12 @@ async fn client_reads() -> Result<()> {

tracing::info!(log_index, "--- isolate node 1 then is_leader should work");

router.set_node_network_failure(1, true);
router.set_network_error(1, true);
router.is_leader(leader).await?;

tracing::info!(log_index, "--- isolate node 2 then is_leader should fail");

router.set_node_network_failure(2, true);
router.set_network_error(2, true);
let rst = router.is_leader(leader).await;
tracing::debug!(?rst, "is_leader with majority down");

Expand Down
183 changes: 106 additions & 77 deletions tests/tests/fixtures/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
use std::collections::BTreeMap;
use std::collections::BTreeSet;
use std::collections::HashMap;
use std::collections::HashSet;
use std::env;
use std::fmt;
use std::panic::PanicInfo;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
Expand Down Expand Up @@ -46,6 +46,9 @@ use openraft::Config;
use openraft::LogId;
use openraft::LogIdOptionExt;
use openraft::MessageSummary;
use openraft::Node;
use openraft::NodeId;
use openraft::RPCTypes;
use openraft::Raft;
use openraft::RaftLogId;
use openraft::RaftMetrics;
Expand Down Expand Up @@ -134,6 +137,50 @@ pub fn log_panic(panic: &PanicInfo) {
eprintln!("{}", backtrace);
}

#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
#[derive(Hash)]
enum Direction {
NetSend,
NetRecv,
}

impl fmt::Display for Direction {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
NetSend => write!(f, "sending from"),
NetRecv => write!(f, "receiving by"),
}
}
}

use Direction::NetRecv;
use Direction::NetSend;

#[derive(Debug, Clone, Copy)]
enum RPCErrorType {
/// Returns [`Unreachable`](`openraft::error::Unreachable`).
Unreachable,
/// Returns [`NetworkError`](`openraft::error::NetworkError`).
NetworkError,
}

impl RPCErrorType {
fn make_error<NID, N, E>(&self, id: NID, dir: Direction) -> RPCError<NID, N, RaftError<NID, E>>
where
NID: NodeId,
N: Node,
E: std::error::Error,
{
let msg = format!("error {} id={}", dir, id);

match self {
RPCErrorType::Unreachable => Unreachable::new(&AnyError::error(msg)).into(),
RPCErrorType::NetworkError => NetworkError::new(&AnyError::error(msg)).into(),
}
}
}

/// A type which emulates a network transport and implements the `RaftNetworkFactory` trait.
#[derive(Clone)]
pub struct TypedRaftRouter {
Expand All @@ -144,11 +191,9 @@ pub struct TypedRaftRouter {
#[allow(clippy::type_complexity)]
routing_table: Arc<Mutex<BTreeMap<MemNodeId, (MemRaft, MemLogStore, MemStateMachine)>>>,

/// Nodes that can neither send nor receive frames, and will return an `NetworkError`.
network_failure_nodes: Arc<Mutex<HashSet<MemNodeId>>>,

/// Nodes to which an RPC is sent return an `Unreachable` error.
unreachable_nodes: Arc<Mutex<HashSet<MemNodeId>>>,
/// Whether to fail a network RPC that is sent from/to a node.
/// And it defines what kind of error to return.
fail_rpc: Arc<Mutex<HashMap<(MemNodeId, Direction), RPCErrorType>>>,

/// To emulate network delay for sending, in milliseconds.
/// 0 means no delay.
Expand All @@ -161,14 +206,7 @@ pub struct TypedRaftRouter {
append_entries_quota: Arc<Mutex<Option<u64>>>,

/// Count of RPCs sent.
rpc_count: Arc<Mutex<HashMap<RPCType, u64>>>,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RPCType {
AppendEntries,
InstallSnapshot,
Vote,
rpc_count: Arc<Mutex<HashMap<RPCTypes, u64>>>,
}

/// Default `RaftRouter` for memstore.
Expand Down Expand Up @@ -199,8 +237,7 @@ impl Builder {
TypedRaftRouter {
config: self.config,
routing_table: Default::default(),
network_failure_nodes: Default::default(),
unreachable_nodes: Default::default(),
fail_rpc: Default::default(),
send_delay: Arc::new(AtomicU64::new(send_delay)),
append_entries_quota: Arc::new(Mutex::new(None)),
rpc_count: Default::default(),
Expand Down Expand Up @@ -239,13 +276,13 @@ impl TypedRaftRouter {
*append_entries_quota = quota;
}

fn count_rpc(&self, rpc_type: RPCType) {
fn count_rpc(&self, rpc_type: RPCTypes) {
let mut rpc_count = self.rpc_count.lock().unwrap();
let count = rpc_count.entry(rpc_type).or_insert(0);
*count += 1;
}

pub fn get_rpc_count(&self) -> HashMap<RPCType, u64> {
pub fn get_rpc_count(&self) -> HashMap<RPCTypes, u64> {
self.rpc_count.lock().unwrap().clone()
}

Expand Down Expand Up @@ -375,10 +412,8 @@ impl TypedRaftRouter {
rt.remove(&id)
};

{
let mut isolated = self.network_failure_nodes.lock().unwrap();
isolated.remove(&id);
}
self.set_network_error(id, false);
self.set_unreachable(id, false);

opt_handles
}
Expand All @@ -403,23 +438,36 @@ impl TypedRaftRouter {

/// Isolate the network of the specified node.
#[tracing::instrument(level = "debug", skip(self))]
pub fn set_node_network_failure(&self, id: MemNodeId, emit_failure: bool) {
let mut nodes = self.network_failure_nodes.lock().unwrap();
if emit_failure {
nodes.insert(id);
pub fn set_network_error(&self, id: MemNodeId, emit_failure: bool) {
let v = if emit_failure {
Some(RPCErrorType::NetworkError)
} else {
nodes.remove(&id);
}
None
};

self.set_rpc_failure(id, NetRecv, v);
self.set_rpc_failure(id, NetSend, v);
}

/// Set to `true` to return [`Unreachable`](`openraft::errors::Unreachable`) when sending RPC to
/// a node.
pub fn set_unreachable(&self, id: MemNodeId, unreachable: bool) {
let mut u = self.unreachable_nodes.lock().unwrap();
if unreachable {
u.insert(id);
let v = if unreachable {
Some(RPCErrorType::Unreachable)
} else {
None
};
self.set_rpc_failure(id, NetRecv, v);
self.set_rpc_failure(id, NetSend, v);
}

/// Set whether to emit a specified rpc error when sending to/receiving from a node.
fn set_rpc_failure(&self, id: MemNodeId, dir: Direction, rpc_error_type: Option<RPCErrorType>) {
let mut fails = self.fail_rpc.lock().unwrap();
if let Some(rpc_error_type) = rpc_error_type {
fails.insert((id, dir), rpc_error_type);
} else {
u.remove(&id);
fails.remove(&(id, dir));
}
}

Expand Down Expand Up @@ -551,33 +599,15 @@ impl TypedRaftRouter {

/// Get the ID of the current leader.
pub fn leader(&self) -> Option<MemNodeId> {
let isolated = {
let isolated = self.network_failure_nodes.lock().unwrap();
isolated.clone()
};

tracing::debug!("router::leader: isolated: {:?}", isolated);

self.latest_metrics().into_iter().find_map(|node| {
if node.current_leader == Some(node.id) {
if isolated.contains(&node.id) {
None
} else {
Some(node.id)
}
Some(node.id)
} else {
None
}
})
}

/// Restore the network of the specified node.
#[tracing::instrument(level = "debug", skip(self))]
pub fn restore_node(&self, id: MemNodeId) {
let mut nodes = self.network_failure_nodes.lock().unwrap();
nodes.remove(&id);
}

/// Bring up a new learner and add it to the leader's membership.
pub async fn add_learner(
&self,
Expand Down Expand Up @@ -815,24 +845,20 @@ impl TypedRaftRouter {
}

#[tracing::instrument(level = "debug", skip(self))]
pub fn check_network_error(&self, id: MemNodeId, target: MemNodeId) -> Result<(), NetworkError> {
let isolated = self.network_failure_nodes.lock().unwrap();

if isolated.contains(&target) || isolated.contains(&id) {
let network_err = NetworkError::new(&AnyError::error(format!("isolated:{} -> {}", id, target)));
return Err(network_err);
}

Ok(())
}

#[tracing::instrument(level = "debug", skip(self))]
pub fn check_unreachable(&self, id: MemNodeId, target: MemNodeId) -> Result<(), Unreachable> {
let unreachable = self.unreachable_nodes.lock().unwrap();
pub fn emit_rpc_error<E>(
&self,
id: MemNodeId,
target: MemNodeId,
) -> Result<(), RPCError<MemNodeId, (), RaftError<MemNodeId, E>>>
where
E: std::error::Error,
{
let fails = self.fail_rpc.lock().unwrap();

if unreachable.contains(&target) || unreachable.contains(&id) {
let err = Unreachable::new(&AnyError::error(format!("unreachable:{} -> {}", id, target)));
return Err(err);
for key in [(id, NetSend), (target, NetRecv)] {
if let Some(err_type) = fails.get(&key) {
return Err(err_type.make_error(key.0, key.1));
}
}

Ok(())
Expand Down Expand Up @@ -864,10 +890,11 @@ impl RaftNetwork<MemConfig> for RaftRouterNetwork {
mut rpc: AppendEntriesRequest<MemConfig>,
) -> Result<AppendEntriesResponse<MemNodeId>, RPCError<MemNodeId, (), RaftError<MemNodeId>>> {
tracing::debug!("append_entries to id={} {}", self.target, rpc.summary());
self.owner.count_rpc(RPCType::AppendEntries);
self.owner.count_rpc(RPCTypes::AppendEntries);

let from_id = rpc.vote.leader_id().voted_for().unwrap();
self.owner.emit_rpc_error(from_id, self.target)?;

self.owner.check_network_error(rpc.vote.leader_id().voted_for().unwrap(), self.target)?;
self.owner.check_unreachable(rpc.vote.leader_id().voted_for().unwrap(), self.target)?;
self.owner.rand_send_delay().await;

// decrease quota if quota is set
Expand Down Expand Up @@ -926,10 +953,11 @@ impl RaftNetwork<MemConfig> for RaftRouterNetwork {
rpc: InstallSnapshotRequest<MemConfig>,
) -> Result<InstallSnapshotResponse<MemNodeId>, RPCError<MemNodeId, (), RaftError<MemNodeId, InstallSnapshotError>>>
{
self.owner.count_rpc(RPCType::InstallSnapshot);
self.owner.count_rpc(RPCTypes::InstallSnapshot);

let from_id = rpc.vote.leader_id().voted_for().unwrap();
self.owner.emit_rpc_error(from_id, self.target)?;

self.owner.check_network_error(rpc.vote.leader_id().voted_for().unwrap(), self.target)?;
self.owner.check_unreachable(rpc.vote.leader_id().voted_for().unwrap(), self.target)?;
self.owner.rand_send_delay().await;

let node = self.owner.get_raft_handle(&self.target)?;
Expand All @@ -945,10 +973,11 @@ impl RaftNetwork<MemConfig> for RaftRouterNetwork {
&mut self,
rpc: VoteRequest<MemNodeId>,
) -> Result<VoteResponse<MemNodeId>, RPCError<MemNodeId, (), RaftError<MemNodeId>>> {
self.owner.count_rpc(RPCType::Vote);
self.owner.count_rpc(RPCTypes::Vote);

let from_id = rpc.vote.leader_id().voted_for().unwrap();
self.owner.emit_rpc_error(from_id, self.target)?;

self.owner.check_network_error(rpc.vote.leader_id().voted_for().unwrap(), self.target)?;
self.owner.check_unreachable(rpc.vote.leader_id().voted_for().unwrap(), self.target)?;
self.owner.rand_send_delay().await;

let node = self.owner.get_raft_handle(&self.target)?;
Expand Down
4 changes: 2 additions & 2 deletions tests/tests/membership/t11_add_learner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async fn add_learner_non_blocking() -> Result<()> {
router.new_raft_node(1).await;

// Replication problem should not block adding-learner in non-blocking mode.
router.set_node_network_failure(1, true);
router.set_network_error(1, true);

let raft = router.get_raft_handle(&0)?;
raft.add_learner(1, (), false).await?;
Expand Down Expand Up @@ -208,7 +208,7 @@ async fn add_learner_when_previous_membership_not_committed() -> Result<()> {

tracing::info!(log_index, "--- block replication to prevent committing any log");
{
router.set_node_network_failure(1, true);
router.set_network_error(1, true);

let node = router.get_raft_handle(&0)?;
tokio::spawn(async move {
Expand Down
8 changes: 4 additions & 4 deletions tests/tests/membership/t30_commit_joint_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ async fn commit_joint_config_during_0_to_012() -> Result<()> {
"--- isolate node 1,2, so that membership [0,1,2] wont commit"
);

router.set_node_network_failure(1, true);
router.set_node_network_failure(2, true);
router.set_network_error(1, true);
router.set_network_error(2, true);

tracing::info!(log_index, "--- changing cluster config, should timeout");

Expand Down Expand Up @@ -110,8 +110,8 @@ async fn commit_joint_config_during_012_to_234() -> Result<()> {

tracing::info!(log_index, "--- isolate 3,4");

router.set_node_network_failure(3, true);
router.set_node_network_failure(4, true);
router.set_network_error(3, true);
router.set_network_error(4, true);

tracing::info!(log_index, "--- changing config to 0,1,2");
let node = router.get_raft_handle(&0)?;
Expand Down
Loading

0 comments on commit 751056e

Please sign in to comment.