diff --git a/.config/nats.dic b/.config/nats.dic index 1d5c7f61c..715023e5d 100644 --- a/.config/nats.dic +++ b/.config/nats.dic @@ -133,3 +133,4 @@ ConnectError DNS RequestErrorKind rustls +RttError diff --git a/async-nats/src/client.rs b/async-nats/src/client.rs index 06ef82601..0ab92c916 100644 --- a/async-nats/src/client.rs +++ b/async-nats/src/client.rs @@ -463,6 +463,35 @@ impl Client { Ok(()) } + /// Calculates the round trip time between this client and the server, + /// if the server is currently connected. + /// + /// # Examples + /// + /// ```no_run + /// # #[tokio::main] + /// # async fn main() -> Result<(), async_nats::Error> { + /// let client = async_nats::connect("demo.nats.io").await?; + /// let rtt = client.rtt().await?; + /// println!("server rtt: {:?}", rtt); + /// # Ok(()) + /// # } + /// ``` + pub async fn rtt(&self) -> Result { + let (tx, rx) = tokio::sync::oneshot::channel(); + + self.sender.send(Command::Rtt { result: tx }).await?; + + let rtt = rx + .await + // first handle rx error + .map_err(|err| RttError(Box::new(err)))? + // second handle the actual rtt error + .map_err(|err| RttError(Box::new(err)))?; + + Ok(rtt) + } + /// Returns the current state of the connection. /// /// # Examples @@ -688,3 +717,14 @@ impl From for RequestError { RequestError::with_source(RequestErrorKind::Other, e) } } + +/// Error returned when doing a round-trip time measurement fails. +#[derive(Debug, Error)] +#[error("failed to measure round-trip time: {0}")] +pub struct RttError(#[source] Box); + +impl From> for RttError { + fn from(err: tokio::sync::mpsc::error::SendError) -> Self { + RttError(Box::new(err)) + } +} diff --git a/async-nats/src/lib.rs b/async-nats/src/lib.rs index 58ada931b..4e9b9b578 100644 --- a/async-nats/src/lib.rs +++ b/async-nats/src/lib.rs @@ -105,6 +105,7 @@ use thiserror::Error; use futures::future::FutureExt; use futures::select; use futures::stream::Stream; +use std::time::Instant; use tracing::{debug, error}; use core::fmt; @@ -261,6 +262,9 @@ pub enum Command { }, TryFlush, Connect(ConnectInfo), + Rtt { + result: oneshot::Sender>, + }, } /// `ClientOp` represents all actions of `Client`. @@ -305,6 +309,9 @@ pub(crate) struct ConnectionHandler { info_sender: tokio::sync::watch::Sender, ping_interval: Interval, flush_interval: Interval, + last_ping_time: Option, + last_pong_time: Option, + rtt_senders: Vec>>, } impl ConnectionHandler { @@ -330,6 +337,9 @@ impl ConnectionHandler { info_sender, ping_interval, flush_interval, + last_ping_time: None, + last_pong_time: None, + rtt_senders: Vec::new(), } } @@ -397,6 +407,23 @@ impl ConnectionHandler { } ServerOp::Pong => { debug!("received PONG"); + if self.pending_pings == 1 { + // Do we even need to store the last_pong_time? + self.last_pong_time = Some(Instant::now()); + + while let Some(sender) = self.rtt_senders.pop() { + if let (Some(ping), Some(pong)) = (self.last_ping_time, self.last_pong_time) + { + let rtt = pong.duration_since(ping); + sender.send(Ok(rtt)).map_err(|_| { + io::Error::new( + io::ErrorKind::Other, + "one shot failed to be received", + ) + })?; + } + } + } self.pending_pings = self.pending_pings.saturating_sub(1); } ServerOp::Error(error) => { @@ -509,26 +536,15 @@ impl ConnectionHandler { } } Command::Ping => { - debug!( - "PING command. Pending pings {}, max pings {}", - self.pending_pings, self.max_pings - ); - self.pending_pings += 1; - self.ping_interval.reset(); - - if self.pending_pings > self.max_pings { - debug!( - "pending pings {}, max pings {}. disconnecting", - self.pending_pings, self.max_pings - ); - self.handle_disconnect().await?; - } + self.handle_ping().await?; + } + Command::Rtt { result } => { + self.rtt_senders.push(result); - if let Err(_err) = self.connection.write_op(&ClientOp::Ping).await { - self.handle_disconnect().await?; + if self.pending_pings == 0 { + // do a ping and expect a pong - will calculate rtt when handling the pong + self.handle_ping().await?; } - - self.handle_flush().await?; } Command::Flush { result } => { if let Err(_err) = self.handle_flush().await { @@ -613,8 +629,39 @@ impl ConnectionHandler { Ok(()) } + async fn handle_ping(&mut self) -> Result<(), io::Error> { + debug!( + "PING command. Pending pings {}, max pings {}", + self.pending_pings, self.max_pings + ); + self.pending_pings += 1; + self.ping_interval.reset(); + + if self.pending_pings > self.max_pings { + debug!( + "pending pings {}, max pings {}. disconnecting", + self.pending_pings, self.max_pings + ); + self.handle_disconnect().await?; + } + + if self.pending_pings == 1 { + // start the clock for calculating round trip time + self.last_ping_time = Some(Instant::now()); + } + + if let Err(_err) = self.connection.write_op(&ClientOp::Ping).await { + self.handle_disconnect().await?; + } + + self.handle_flush().await?; + Ok(()) + } + async fn handle_disconnect(&mut self) -> io::Result<()> { self.pending_pings = 0; + self.last_ping_time = None; + self.last_pong_time = None; self.connector.events_tx.try_send(Event::Disconnected).ok(); self.connector.state_tx.send(State::Disconnected).ok(); self.handle_reconnect().await?; diff --git a/async-nats/tests/client_tests.rs b/async-nats/tests/client_tests.rs index 538b78d2a..c4e244fb9 100644 --- a/async-nats/tests/client_tests.rs +++ b/async-nats/tests/client_tests.rs @@ -764,4 +764,15 @@ mod client { drop(servers.remove(0)); rx.recv().await; } + + #[tokio::test] + async fn rtt() { + let server = nats_server::run_basic_server(); + let client = async_nats::connect(server.client_url()).await.unwrap(); + + let rtt = client.rtt().await.unwrap(); + + println!("rtt: {:?}", rtt); + assert!(rtt.as_nanos() > 0); + } }