From 9f5a6955d69aedb5f6635807e274f83955acdc83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E7=82=8E=E6=B3=BC?= Date: Sat, 9 Dec 2023 22:32:35 +0800 Subject: [PATCH] Feature: add `Wait::eq()` and `ge()` to await a metics `Wait` does not need many method for each metric. In this commit, it provides method `eq()` and `ge()` to specify waiting condition in a general way. The metric to await is specified by `Metric` as the first argument. ```rust my_raft.wait(None).ge(Metric::Term(2), "").await? ``` --- openraft/src/metrics/metric.rs | 83 +++++++++++++++++++++ openraft/src/metrics/metric_display.rs | 29 ++++++++ openraft/src/metrics/mod.rs | 5 ++ openraft/src/metrics/wait.rs | 99 ++++++++++++-------------- openraft/src/metrics/wait_condition.rs | 57 +++++++++++++++ 5 files changed, 221 insertions(+), 52 deletions(-) create mode 100644 openraft/src/metrics/metric.rs create mode 100644 openraft/src/metrics/metric_display.rs create mode 100644 openraft/src/metrics/wait_condition.rs diff --git a/openraft/src/metrics/metric.rs b/openraft/src/metrics/metric.rs new file mode 100644 index 000000000..48a9a7c0a --- /dev/null +++ b/openraft/src/metrics/metric.rs @@ -0,0 +1,83 @@ +use std::cmp::Ordering; + +use crate::metrics::metric_display::MetricDisplay; +use crate::LogId; +use crate::LogIdOptionExt; +use crate::Node; +use crate::NodeId; +use crate::RaftMetrics; +use crate::Vote; + +/// A metric entry of a Raft node. +/// +/// This is used to specify which metric to observe. +#[derive(Debug)] +pub enum Metric +where NID: NodeId +{ + Term(u64), + Vote(Vote), + LastLogIndex(Option), + Applied(Option>), + AppliedIndex(Option), + Snapshot(Option>), + Purged(Option>), +} + +impl Metric +where NID: NodeId +{ + pub(crate) fn name(&self) -> &'static str { + match self { + Metric::Term(_) => "term", + Metric::Vote(_) => "vote", + Metric::LastLogIndex(_) => "last_log_index", + Metric::Applied(_) => "applied", + Metric::AppliedIndex(_) => "applied_index", + Metric::Snapshot(_) => "snapshot", + Metric::Purged(_) => "purged", + } + } + + pub(crate) fn value(&self) -> MetricDisplay<'_, NID> { + MetricDisplay { metric: self } + } +} + +/// Metric can be compared with RaftMetrics by comparing the corresponding field of RaftMetrics. +impl PartialEq> for RaftMetrics +where + NID: NodeId, + N: Node, +{ + fn eq(&self, other: &Metric) -> bool { + match other { + Metric::Term(v) => self.current_term == *v, + Metric::Vote(v) => &self.vote == v, + Metric::LastLogIndex(v) => self.last_log_index == *v, + Metric::Applied(v) => &self.last_applied == v, + Metric::AppliedIndex(v) => self.last_applied.index() == *v, + Metric::Snapshot(v) => &self.snapshot == v, + Metric::Purged(v) => &self.purged == v, + } + } +} + +/// Metric can be compared with RaftMetrics by comparing the corresponding field of RaftMetrics. +impl PartialOrd> for RaftMetrics +where + NID: NodeId, + N: Node, +{ + fn partial_cmp(&self, other: &Metric) -> Option { + match other { + Metric::Term(v) => Some(self.current_term.cmp(v)), + Metric::Vote(v) => self.vote.partial_cmp(v), + Metric::LastLogIndex(v) => Some(self.last_log_index.cmp(v)), + Metric::Applied(v) => Some(self.last_applied.cmp(v)), + Metric::AppliedIndex(v) => Some(self.last_applied.index().cmp(v)), + Metric::Snapshot(v) => Some(self.snapshot.cmp(v)), + Metric::Purged(v) => Some(self.purged.cmp(v)), + } + } +} diff --git a/openraft/src/metrics/metric_display.rs b/openraft/src/metrics/metric_display.rs new file mode 100644 index 000000000..7468c11b5 --- /dev/null +++ b/openraft/src/metrics/metric_display.rs @@ -0,0 +1,29 @@ +use std::fmt; +use std::fmt::Formatter; + +use crate::display_ext::DisplayOption; +use crate::metrics::Metric; +use crate::NodeId; + +/// Display the value of a metric. +pub(crate) struct MetricDisplay<'a, NID> +where NID: NodeId +{ + pub(crate) metric: &'a Metric, +} + +impl<'a, NID> fmt::Display for MetricDisplay<'a, NID> +where NID: NodeId +{ + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + match self.metric { + Metric::Term(v) => write!(f, "{}", v), + Metric::Vote(v) => write!(f, "{}", v), + Metric::LastLogIndex(v) => write!(f, "{}", DisplayOption(v)), + Metric::Applied(v) => write!(f, "{}", DisplayOption(v)), + Metric::AppliedIndex(v) => write!(f, "{}", DisplayOption(v)), + Metric::Snapshot(v) => write!(f, "{}", DisplayOption(v)), + Metric::Purged(v) => write!(f, "{}", DisplayOption(v)), + } + } +} diff --git a/openraft/src/metrics/mod.rs b/openraft/src/metrics/mod.rs index 841eaf5f7..f67f3e7a5 100644 --- a/openraft/src/metrics/mod.rs +++ b/openraft/src/metrics/mod.rs @@ -27,16 +27,21 @@ //! not every change of the state. //! Because internally, `watch::channel()` only stores one last state. +mod metric; mod raft_metrics; mod wait; +mod metric_display; +mod wait_condition; #[cfg(test)] mod wait_test; use std::collections::BTreeMap; +pub use metric::Metric; pub use raft_metrics::RaftMetrics; pub use wait::Wait; pub use wait::WaitError; +pub(crate) use wait_condition::Condition; use crate::LogId; diff --git a/openraft/src/metrics/wait.rs b/openraft/src/metrics/wait.rs index 7eab434be..4dc90afbd 100644 --- a/openraft/src/metrics/wait.rs +++ b/openraft/src/metrics/wait.rs @@ -4,13 +4,13 @@ use std::collections::BTreeSet; use tokio::sync::watch; use crate::core::ServerState; -use crate::display_ext::DisplayOption; +use crate::metrics::Condition; +use crate::metrics::Metric; use crate::metrics::RaftMetrics; use crate::node::Node; use crate::AsyncRuntime; use crate::Instant; use crate::LogId; -use crate::LogIdOptionExt; use crate::MessageSummary; use crate::NodeId; use crate::OptionalSend; @@ -114,7 +114,7 @@ where /// Wait for `vote` to become `want` or timeout. #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))] pub async fn vote(&self, want: Vote, msg: impl ToString) -> Result, WaitError> { - self.metrics(|m| m.vote == want, &format!("{} .vote -> {}", msg.to_string(), want)).await + self.eq(Metric::Vote(want), msg).await } /// Wait for `current_leader` to become `Some(leader_id)` until timeout. @@ -131,17 +131,8 @@ where #[deprecated(note = "use `log_index()` and `applied_index()` instead, deprecated since 0.9.0")] #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))] pub async fn log(&self, want_log_index: Option, msg: impl ToString) -> Result, WaitError> { - self.metrics( - |x| x.last_log_index == want_log_index, - &format!("{} .last_log_index == {:?}", msg.to_string(), want_log_index), - ) - .await?; - - self.metrics( - |x| x.last_applied.index() == want_log_index, - &format!("{} .last_applied == {:?}", msg.to_string(), want_log_index), - ) - .await + self.eq(Metric::LastLogIndex(want_log_index), msg.to_string()).await?; + self.eq(Metric::AppliedIndex(want_log_index), msg.to_string()).await } /// Wait until applied at least `want_log`(inclusive) logs or timeout. @@ -152,27 +143,14 @@ where want_log: Option, msg: impl ToString, ) -> Result, WaitError> { - self.metrics( - |x| x.last_log_index >= want_log, - &format!("{} .last_log_index >= {:?}", msg.to_string(), want_log), - ) - .await?; - - self.metrics( - |x| x.last_applied.index() >= want_log, - &format!("{} .last_applied >= {:?}", msg.to_string(), want_log), - ) - .await + self.ge(Metric::LastLogIndex(want_log), msg.to_string()).await?; + self.ge(Metric::AppliedIndex(want_log), msg.to_string()).await } /// Block until the last log index becomes exactly `index`(inclusive) or timeout. #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))] pub async fn log_index(&self, index: Option, msg: impl ToString) -> Result, WaitError> { - self.metrics( - |m| m.last_log_index == index, - &format!("{} .last_log_index == {:?}", msg.to_string(), index), - ) - .await + self.eq(Metric::LastLogIndex(index), msg).await } /// Block until the last log index becomes at least `index`(inclusive) or timeout. @@ -182,11 +160,7 @@ where index: Option, msg: impl ToString, ) -> Result, WaitError> { - self.metrics( - |m| m.last_log_index >= index, - &format!("{} .last_log_index >= {:?}", msg.to_string(), index), - ) - .await + self.ge(Metric::LastLogIndex(index), msg).await } /// Block until the applied index becomes exactly `index`(inclusive) or timeout. @@ -196,11 +170,7 @@ where index: Option, msg: impl ToString, ) -> Result, WaitError> { - self.metrics( - |m| m.last_applied.index() == index, - &format!("{} .last_applied.index == {:?}", msg.to_string(), index), - ) - .await + self.eq(Metric::AppliedIndex(index), msg).await } /// Block until the last applied log index become at least `index`(inclusive) or timeout. @@ -211,11 +181,7 @@ where index: Option, msg: impl ToString, ) -> Result, WaitError> { - self.metrics( - |m| m.last_log_index >= index && m.last_applied.index() >= index, - &format!("{} .last_applied.index >= {:?}", msg.to_string(), index), - ) - .await + self.ge(Metric::AppliedIndex(index), msg).await } /// Wait for `state` to become `want_state` or timeout. @@ -274,19 +240,48 @@ where snapshot_last_log_id: LogId, msg: impl ToString, ) -> Result, WaitError> { - self.metrics( - |m| m.snapshot == Some(snapshot_last_log_id), - &format!("{} .snapshot == {}", msg.to_string(), snapshot_last_log_id), - ) - .await + self.eq(Metric::Snapshot(Some(snapshot_last_log_id)), msg).await } /// Wait for `purged` to become `want` or timeout. #[tracing::instrument(level = "trace", skip(self), fields(msg=msg.to_string().as_str()))] pub async fn purged(&self, want: Option>, msg: impl ToString) -> Result, WaitError> { + self.eq(Metric::Purged(want), msg).await + } + + /// Block until a metric becomes greater than or equal the specified value or timeout. + /// + /// For example, to await until the term becomes 2 or greater: + /// ```ignore + /// my_raft.wait(None).ge(Metric::Term(2), "become term 2").await?; + /// ``` + pub async fn ge(&self, metric: Metric, msg: impl ToString) -> Result, WaitError> { + self.until(Condition::ge(metric), msg).await + } + + /// Block until a metric becomes equal to the specified value or timeout. + /// + /// For example, to await until the term becomes exact 2: + /// ```ignore + /// my_raft.wait(None).eq(Metric::Term(2), "become term 2").await?; + /// ``` + pub async fn eq(&self, metric: Metric, msg: impl ToString) -> Result, WaitError> { + self.until(Condition::eq(metric), msg).await + } + + /// Block until a metric satisfies the specified condition or timeout. + #[tracing::instrument(level = "trace", skip_all, fields(cond=cond.to_string(), msg=msg.to_string().as_str()))] + pub(crate) async fn until( + &self, + cond: Condition, + msg: impl ToString, + ) -> Result, WaitError> { self.metrics( - |m| m.purged == want, - &format!("{} .purged == {}", msg.to_string(), DisplayOption(&want)), + |raft_metrics| match &cond { + Condition::GE(expect) => raft_metrics >= expect, + Condition::EQ(expect) => raft_metrics == expect, + }, + &format!("{} .{}", msg.to_string(), cond), ) .await } diff --git a/openraft/src/metrics/wait_condition.rs b/openraft/src/metrics/wait_condition.rs new file mode 100644 index 000000000..b488811d9 --- /dev/null +++ b/openraft/src/metrics/wait_condition.rs @@ -0,0 +1,57 @@ +use std::fmt; + +use crate::metrics::metric_display::MetricDisplay; +use crate::metrics::Metric; +use crate::NodeId; + +/// A condition that the application wait for. +#[derive(Debug)] +pub(crate) enum Condition +where NID: NodeId +{ + GE(Metric), + EQ(Metric), +} + +impl Condition +where NID: NodeId +{ + /// Build a new condition which the application will await to meet or exceed. + pub(crate) fn ge(v: Metric) -> Self { + Self::GE(v) + } + + /// Build a new condition which the application will await to meet. + pub(crate) fn eq(v: Metric) -> Self { + Self::EQ(v) + } + + pub(crate) fn name(&self) -> &'static str { + match self { + Condition::GE(v) => v.name(), + Condition::EQ(v) => v.name(), + } + } + + pub(crate) fn op(&self) -> &'static str { + match self { + Condition::GE(_) => ">=", + Condition::EQ(_) => "==", + } + } + + pub(crate) fn value(&self) -> MetricDisplay<'_, NID> { + match self { + Condition::GE(v) => v.value(), + Condition::EQ(v) => v.value(), + } + } +} + +impl fmt::Display for Condition +where NID: NodeId +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}{}{}", self.name(), self.op(), self.value()) + } +}