Skip to content

Commit

Permalink
Add rtt to Client
Browse files Browse the repository at this point in the history
  • Loading branch information
n1ghtmare committed Apr 13, 2023
1 parent ce7f825 commit 00b98c1
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 18 deletions.
1 change: 1 addition & 0 deletions .config/nats.dic
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,4 @@ ConnectError
DNS
RequestErrorKind
rustls
RttError
40 changes: 40 additions & 0 deletions async-nats/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,35 @@ impl Client {
Ok(())
}

/// Calculates the round trip time between this client and the server,
/// if the server is currently connected.
///
/// # Examples
///
/// ```no_run
/// # #[tokio::main]
/// # async fn main() -> Result<(), async_nats::Error> {
/// let client = async_nats::connect("demo.nats.io").await?;
/// let rtt = client.rtt().await?;
/// println!("server rtt: {:?}", rtt);
/// # Ok(())
/// # }
/// ```
pub async fn rtt(&self) -> Result<Duration, RttError> {
let (tx, rx) = tokio::sync::oneshot::channel();

self.sender.send(Command::Rtt { result: tx }).await?;

let rtt = rx
.await
// first handle rx error
.map_err(|err| RttError(Box::new(err)))?
// second handle the actual rtt error
.map_err(|err| RttError(Box::new(err)))?;

Ok(rtt)
}

/// Returns the current state of the connection.
///
/// # Examples
Expand Down Expand Up @@ -688,3 +717,14 @@ impl From<SubscribeError> for RequestError {
RequestError::with_source(RequestErrorKind::Other, e)
}
}

/// Error returned when doing a round-trip time measurement fails.
#[derive(Debug, Error)]
#[error("failed to measure round-trip time: {0}")]
pub struct RttError(#[source] Box<dyn std::error::Error + Send + Sync>);

impl From<tokio::sync::mpsc::error::SendError<Command>> for RttError {
fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
RttError(Box::new(err))
}
}
83 changes: 65 additions & 18 deletions async-nats/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ use thiserror::Error;
use futures::future::FutureExt;
use futures::select;
use futures::stream::Stream;
use std::time::Instant;
use tracing::{debug, error};

use core::fmt;
Expand Down Expand Up @@ -261,6 +262,9 @@ pub enum Command {
},
TryFlush,
Connect(ConnectInfo),
Rtt {
result: oneshot::Sender<Result<Duration, io::Error>>,
},
}

/// `ClientOp` represents all actions of `Client`.
Expand Down Expand Up @@ -305,6 +309,9 @@ pub(crate) struct ConnectionHandler {
info_sender: tokio::sync::watch::Sender<ServerInfo>,
ping_interval: Interval,
flush_interval: Interval,
last_ping_time: Option<Instant>,
last_pong_time: Option<Instant>,
rtt_senders: Vec<oneshot::Sender<Result<Duration, io::Error>>>,
}

impl ConnectionHandler {
Expand All @@ -330,6 +337,9 @@ impl ConnectionHandler {
info_sender,
ping_interval,
flush_interval,
last_ping_time: None,
last_pong_time: None,
rtt_senders: Vec::new(),
}
}

Expand Down Expand Up @@ -397,6 +407,23 @@ impl ConnectionHandler {
}
ServerOp::Pong => {
debug!("received PONG");
if self.pending_pings == 1 {
// Do we even need to store the last_pong_time?
self.last_pong_time = Some(Instant::now());

while let Some(sender) = self.rtt_senders.pop() {
if let (Some(ping), Some(pong)) = (self.last_ping_time, self.last_pong_time)
{
let rtt = pong.duration_since(ping);
sender.send(Ok(rtt)).map_err(|_| {
io::Error::new(
io::ErrorKind::Other,
"one shot failed to be received",
)
})?;
}
}
}
self.pending_pings = self.pending_pings.saturating_sub(1);
}
ServerOp::Error(error) => {
Expand Down Expand Up @@ -509,26 +536,15 @@ impl ConnectionHandler {
}
}
Command::Ping => {
debug!(
"PING command. Pending pings {}, max pings {}",
self.pending_pings, self.max_pings
);
self.pending_pings += 1;
self.ping_interval.reset();

if self.pending_pings > self.max_pings {
debug!(
"pending pings {}, max pings {}. disconnecting",
self.pending_pings, self.max_pings
);
self.handle_disconnect().await?;
}
self.handle_ping().await?;
}
Command::Rtt { result } => {
self.rtt_senders.push(result);

if let Err(_err) = self.connection.write_op(&ClientOp::Ping).await {
self.handle_disconnect().await?;
if self.pending_pings == 0 {
// do a ping and expect a pong - will calculate rtt when handling the pong
self.handle_ping().await?;
}

self.handle_flush().await?;
}
Command::Flush { result } => {
if let Err(_err) = self.handle_flush().await {
Expand Down Expand Up @@ -613,8 +629,39 @@ impl ConnectionHandler {
Ok(())
}

async fn handle_ping(&mut self) -> Result<(), io::Error> {
debug!(
"PING command. Pending pings {}, max pings {}",
self.pending_pings, self.max_pings
);
self.pending_pings += 1;
self.ping_interval.reset();

if self.pending_pings > self.max_pings {
debug!(
"pending pings {}, max pings {}. disconnecting",
self.pending_pings, self.max_pings
);
self.handle_disconnect().await?;
}

if self.pending_pings == 1 {
// start the clock for calculating round trip time
self.last_ping_time = Some(Instant::now());
}

if let Err(_err) = self.connection.write_op(&ClientOp::Ping).await {
self.handle_disconnect().await?;
}

self.handle_flush().await?;
Ok(())
}

async fn handle_disconnect(&mut self) -> io::Result<()> {
self.pending_pings = 0;
self.last_ping_time = None;
self.last_pong_time = None;
self.connector.events_tx.try_send(Event::Disconnected).ok();
self.connector.state_tx.send(State::Disconnected).ok();
self.handle_reconnect().await?;
Expand Down
11 changes: 11 additions & 0 deletions async-nats/tests/client_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -764,4 +764,15 @@ mod client {
drop(servers.remove(0));
rx.recv().await;
}

#[tokio::test]
async fn rtt() {
let server = nats_server::run_basic_server();
let client = async_nats::connect(server.client_url()).await.unwrap();

let rtt = client.rtt().await.unwrap();

println!("rtt: {:?}", rtt);
assert!(rtt.as_nanos() > 0);
}
}

0 comments on commit 00b98c1

Please sign in to comment.