Skip to content

Commit

Permalink
Merge pull request #967 from drmingdrmer/55-wait
Browse files Browse the repository at this point in the history
Feature: add `Wait::eq()` and `ge()` to await a metics
  • Loading branch information
drmingdrmer authored Dec 11, 2023
2 parents 212a19c + 9f5a695 commit f0576e5
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 52 deletions.
83 changes: 83 additions & 0 deletions openraft/src/metrics/metric.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use std::cmp::Ordering;

use crate::metrics::metric_display::MetricDisplay;
use crate::LogId;
use crate::LogIdOptionExt;
use crate::Node;
use crate::NodeId;
use crate::RaftMetrics;
use crate::Vote;

/// A metric entry of a Raft node.
///
/// This is used to specify which metric to observe.
#[derive(Debug)]
pub enum Metric<NID>
where NID: NodeId
{
Term(u64),
Vote(Vote<NID>),
LastLogIndex(Option<u64>),
Applied(Option<LogId<NID>>),
AppliedIndex(Option<u64>),
Snapshot(Option<LogId<NID>>),
Purged(Option<LogId<NID>>),
}

impl<NID> Metric<NID>
where NID: NodeId
{
pub(crate) fn name(&self) -> &'static str {
match self {
Metric::Term(_) => "term",
Metric::Vote(_) => "vote",
Metric::LastLogIndex(_) => "last_log_index",
Metric::Applied(_) => "applied",
Metric::AppliedIndex(_) => "applied_index",
Metric::Snapshot(_) => "snapshot",
Metric::Purged(_) => "purged",
}
}

pub(crate) fn value(&self) -> MetricDisplay<'_, NID> {
MetricDisplay { metric: self }
}
}

/// Metric can be compared with RaftMetrics by comparing the corresponding field of RaftMetrics.
impl<NID, N> PartialEq<Metric<NID>> for RaftMetrics<NID, N>
where
NID: NodeId,
N: Node,
{
fn eq(&self, other: &Metric<NID>) -> bool {
match other {
Metric::Term(v) => self.current_term == *v,
Metric::Vote(v) => &self.vote == v,
Metric::LastLogIndex(v) => self.last_log_index == *v,
Metric::Applied(v) => &self.last_applied == v,
Metric::AppliedIndex(v) => self.last_applied.index() == *v,
Metric::Snapshot(v) => &self.snapshot == v,
Metric::Purged(v) => &self.purged == v,
}
}
}

/// Metric can be compared with RaftMetrics by comparing the corresponding field of RaftMetrics.
impl<NID, N> PartialOrd<Metric<NID>> for RaftMetrics<NID, N>
where
NID: NodeId,
N: Node,
{
fn partial_cmp(&self, other: &Metric<NID>) -> Option<Ordering> {
match other {
Metric::Term(v) => Some(self.current_term.cmp(v)),
Metric::Vote(v) => self.vote.partial_cmp(v),
Metric::LastLogIndex(v) => Some(self.last_log_index.cmp(v)),
Metric::Applied(v) => Some(self.last_applied.cmp(v)),
Metric::AppliedIndex(v) => Some(self.last_applied.index().cmp(v)),
Metric::Snapshot(v) => Some(self.snapshot.cmp(v)),
Metric::Purged(v) => Some(self.purged.cmp(v)),
}
}
}
29 changes: 29 additions & 0 deletions openraft/src/metrics/metric_display.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use std::fmt;
use std::fmt::Formatter;

use crate::display_ext::DisplayOption;
use crate::metrics::Metric;
use crate::NodeId;

/// Display the value of a metric.
pub(crate) struct MetricDisplay<'a, NID>
where NID: NodeId
{
pub(crate) metric: &'a Metric<NID>,
}

impl<'a, NID> fmt::Display for MetricDisplay<'a, NID>
where NID: NodeId
{
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
match self.metric {
Metric::Term(v) => write!(f, "{}", v),
Metric::Vote(v) => write!(f, "{}", v),
Metric::LastLogIndex(v) => write!(f, "{}", DisplayOption(v)),
Metric::Applied(v) => write!(f, "{}", DisplayOption(v)),
Metric::AppliedIndex(v) => write!(f, "{}", DisplayOption(v)),
Metric::Snapshot(v) => write!(f, "{}", DisplayOption(v)),
Metric::Purged(v) => write!(f, "{}", DisplayOption(v)),
}
}
}
5 changes: 5 additions & 0 deletions openraft/src/metrics/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,21 @@
//! not every change of the state.
//! Because internally, `watch::channel()` only stores one last state.
mod metric;
mod raft_metrics;
mod wait;

mod metric_display;
mod wait_condition;
#[cfg(test)] mod wait_test;

use std::collections::BTreeMap;

pub use metric::Metric;
pub use raft_metrics::RaftMetrics;
pub use wait::Wait;
pub use wait::WaitError;
pub(crate) use wait_condition::Condition;

use crate::LogId;

Expand Down
99 changes: 47 additions & 52 deletions openraft/src/metrics/wait.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ use std::collections::BTreeSet;
use tokio::sync::watch;

use crate::core::ServerState;
use crate::display_ext::DisplayOption;
use crate::metrics::Condition;
use crate::metrics::Metric;
use crate::metrics::RaftMetrics;
use crate::node::Node;
use crate::AsyncRuntime;
use crate::Instant;
use crate::LogId;
use crate::LogIdOptionExt;
use crate::MessageSummary;
use crate::NodeId;
use crate::OptionalSend;
Expand Down Expand Up @@ -114,7 +114,7 @@ where
/// Wait for `vote` to become `want` or timeout.
#[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
pub async fn vote(&self, want: Vote<NID>, msg: impl ToString) -> Result<RaftMetrics<NID, N>, WaitError> {
self.metrics(|m| m.vote == want, &format!("{} .vote -> {}", msg.to_string(), want)).await
self.eq(Metric::Vote(want), msg).await
}

/// Wait for `current_leader` to become `Some(leader_id)` until timeout.
Expand All @@ -131,17 +131,8 @@ where
#[deprecated(note = "use `log_index()` and `applied_index()` instead, deprecated since 0.9.0")]
#[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
pub async fn log(&self, want_log_index: Option<u64>, msg: impl ToString) -> Result<RaftMetrics<NID, N>, WaitError> {
self.metrics(
|x| x.last_log_index == want_log_index,
&format!("{} .last_log_index == {:?}", msg.to_string(), want_log_index),
)
.await?;

self.metrics(
|x| x.last_applied.index() == want_log_index,
&format!("{} .last_applied == {:?}", msg.to_string(), want_log_index),
)
.await
self.eq(Metric::LastLogIndex(want_log_index), msg.to_string()).await?;
self.eq(Metric::AppliedIndex(want_log_index), msg.to_string()).await
}

/// Wait until applied at least `want_log`(inclusive) logs or timeout.
Expand All @@ -152,27 +143,14 @@ where
want_log: Option<u64>,
msg: impl ToString,
) -> Result<RaftMetrics<NID, N>, WaitError> {
self.metrics(
|x| x.last_log_index >= want_log,
&format!("{} .last_log_index >= {:?}", msg.to_string(), want_log),
)
.await?;

self.metrics(
|x| x.last_applied.index() >= want_log,
&format!("{} .last_applied >= {:?}", msg.to_string(), want_log),
)
.await
self.ge(Metric::LastLogIndex(want_log), msg.to_string()).await?;
self.ge(Metric::AppliedIndex(want_log), msg.to_string()).await
}

/// Block until the last log index becomes exactly `index`(inclusive) or timeout.
#[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
pub async fn log_index(&self, index: Option<u64>, msg: impl ToString) -> Result<RaftMetrics<NID, N>, WaitError> {
self.metrics(
|m| m.last_log_index == index,
&format!("{} .last_log_index == {:?}", msg.to_string(), index),
)
.await
self.eq(Metric::LastLogIndex(index), msg).await
}

/// Block until the last log index becomes at least `index`(inclusive) or timeout.
Expand All @@ -182,11 +160,7 @@ where
index: Option<u64>,
msg: impl ToString,
) -> Result<RaftMetrics<NID, N>, WaitError> {
self.metrics(
|m| m.last_log_index >= index,
&format!("{} .last_log_index >= {:?}", msg.to_string(), index),
)
.await
self.ge(Metric::LastLogIndex(index), msg).await
}

/// Block until the applied index becomes exactly `index`(inclusive) or timeout.
Expand All @@ -196,11 +170,7 @@ where
index: Option<u64>,
msg: impl ToString,
) -> Result<RaftMetrics<NID, N>, WaitError> {
self.metrics(
|m| m.last_applied.index() == index,
&format!("{} .last_applied.index == {:?}", msg.to_string(), index),
)
.await
self.eq(Metric::AppliedIndex(index), msg).await
}

/// Block until the last applied log index become at least `index`(inclusive) or timeout.
Expand All @@ -211,11 +181,7 @@ where
index: Option<u64>,
msg: impl ToString,
) -> Result<RaftMetrics<NID, N>, WaitError> {
self.metrics(
|m| m.last_log_index >= index && m.last_applied.index() >= index,
&format!("{} .last_applied.index >= {:?}", msg.to_string(), index),
)
.await
self.ge(Metric::AppliedIndex(index), msg).await
}

/// Wait for `state` to become `want_state` or timeout.
Expand Down Expand Up @@ -274,19 +240,48 @@ where
snapshot_last_log_id: LogId<NID>,
msg: impl ToString,
) -> Result<RaftMetrics<NID, N>, WaitError> {
self.metrics(
|m| m.snapshot == Some(snapshot_last_log_id),
&format!("{} .snapshot == {}", msg.to_string(), snapshot_last_log_id),
)
.await
self.eq(Metric::Snapshot(Some(snapshot_last_log_id)), msg).await
}

/// Wait for `purged` to become `want` or timeout.
#[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))]
pub async fn purged(&self, want: Option<LogId<NID>>, msg: impl ToString) -> Result<RaftMetrics<NID, N>, WaitError> {
self.eq(Metric::Purged(want), msg).await
}

/// Block until a metric becomes greater than or equal the specified value or timeout.
///
/// For example, to await until the term becomes 2 or greater:
/// ```ignore
/// my_raft.wait(None).ge(Metric::Term(2), "become term 2").await?;
/// ```
pub async fn ge(&self, metric: Metric<NID>, msg: impl ToString) -> Result<RaftMetrics<NID, N>, WaitError> {
self.until(Condition::ge(metric), msg).await
}

/// Block until a metric becomes equal to the specified value or timeout.
///
/// For example, to await until the term becomes exact 2:
/// ```ignore
/// my_raft.wait(None).eq(Metric::Term(2), "become term 2").await?;
/// ```
pub async fn eq(&self, metric: Metric<NID>, msg: impl ToString) -> Result<RaftMetrics<NID, N>, WaitError> {
self.until(Condition::eq(metric), msg).await
}

/// Block until a metric satisfies the specified condition or timeout.
#[tracing::instrument(level = "trace", skip_all, fields(cond=cond.to_string(), msg=msg.to_string().as_str()))]
pub(crate) async fn until(
&self,
cond: Condition<NID>,
msg: impl ToString,
) -> Result<RaftMetrics<NID, N>, WaitError> {
self.metrics(
|m| m.purged == want,
&format!("{} .purged == {}", msg.to_string(), DisplayOption(&want)),
|raft_metrics| match &cond {
Condition::GE(expect) => raft_metrics >= expect,
Condition::EQ(expect) => raft_metrics == expect,
},
&format!("{} .{}", msg.to_string(), cond),
)
.await
}
Expand Down
57 changes: 57 additions & 0 deletions openraft/src/metrics/wait_condition.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
use std::fmt;

use crate::metrics::metric_display::MetricDisplay;
use crate::metrics::Metric;
use crate::NodeId;

/// A condition that the application wait for.
#[derive(Debug)]
pub(crate) enum Condition<NID>
where NID: NodeId
{
GE(Metric<NID>),
EQ(Metric<NID>),
}

impl<NID> Condition<NID>
where NID: NodeId
{
/// Build a new condition which the application will await to meet or exceed.
pub(crate) fn ge(v: Metric<NID>) -> Self {
Self::GE(v)
}

/// Build a new condition which the application will await to meet.
pub(crate) fn eq(v: Metric<NID>) -> Self {
Self::EQ(v)
}

pub(crate) fn name(&self) -> &'static str {
match self {
Condition::GE(v) => v.name(),
Condition::EQ(v) => v.name(),
}
}

pub(crate) fn op(&self) -> &'static str {
match self {
Condition::GE(_) => ">=",
Condition::EQ(_) => "==",
}
}

pub(crate) fn value(&self) -> MetricDisplay<'_, NID> {
match self {
Condition::GE(v) => v.value(),
Condition::EQ(v) => v.value(),
}
}
}

impl<NID> fmt::Display for Condition<NID>
where NID: NodeId
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}{}{}", self.name(), self.op(), self.value())
}
}

0 comments on commit f0576e5

Please sign in to comment.