diff --git a/Cargo.lock b/Cargo.lock index 599621943cda..e0ff4e29a60d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4591,6 +4591,7 @@ dependencies = [ name = "talpid-wireguard" version = "0.0.0" dependencies = [ + "async-trait", "bitflags 1.3.2", "byteorder", "chrono", diff --git a/talpid-wireguard/Cargo.toml b/talpid-wireguard/Cargo.toml index e02bf874d253..3a19f5a70afa 100644 --- a/talpid-wireguard/Cargo.toml +++ b/talpid-wireguard/Cargo.toml @@ -11,6 +11,7 @@ rust-version.workspace = true workspace = true [dependencies] +async-trait = "0.1" thiserror = { workspace = true } futures = { workspace = true } hex = "0.4" diff --git a/talpid-wireguard/src/connectivity/check.rs b/talpid-wireguard/src/connectivity/check.rs index 702ce97f2d4d..d44bd7c6ae21 100644 --- a/talpid-wireguard/src/connectivity/check.rs +++ b/talpid-wireguard/src/connectivity/check.rs @@ -1,7 +1,9 @@ -use std::cmp; use std::net::Ipv4Addr; -use std::sync::mpsc; -use std::time::{Duration, Instant}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::broadcast; +use tokio::time::Instant; use super::constants::*; use super::error::Error; @@ -35,52 +37,69 @@ use pinger::Pinger; /// /// Once a connection established, a connection is only considered broken once the connectivity /// monitor has started pinging and no traffic has been received for a duration of `PING_TIMEOUT`. -pub struct Check { +pub struct Check { conn_state: ConnState, ping_state: PingState, - strategy: Strategy, + cancel_receiver: CancelReceiver, retry_attempt: u32, } -// Define the type state of [Check] -pub(crate) trait Strategy { - fn should_shut_down(&mut self, timeout: Duration) -> bool; +/// A handle that can be used to shut down the connectivity monitor. +#[derive(Debug, Clone)] +pub struct CancelToken { + closed: Arc, + tx: broadcast::Sender<()>, } -/// An uncancellable [Check] that will run [Check::establish_connectivity] until -/// completion or until it times out. -pub struct Timeout; +/// A handle that can be passed to a [Check]. The corresponding [CancelToken] causes the [Check] to +/// be stopped. Any [CancelToken] will cancel all receivers +#[derive(Debug)] +pub struct CancelReceiver { + closed: Arc, + rx: broadcast::Receiver<()>, +} -impl Strategy for Timeout { - /// The Timeout strategy cannot receive shut down signals so this function always returns false. - fn should_shut_down(&mut self, _timeout: Duration) -> bool { - false +impl CancelReceiver { + fn closed(&self) -> bool { + self.closed.load(Ordering::SeqCst) } } -/// A cancellable [Check] may be cancelled before it will time out by sending -/// a signal on the channel returned by [Check::with_cancellation]. Otherwise, -/// it behaves as [Timeout]. -pub struct Cancellable { - close_receiver: mpsc::Receiver<()>, +impl Clone for CancelReceiver { + fn clone(&self) -> Self { + Self { + closed: self.closed.clone(), + rx: self.rx.resubscribe(), + } + } } -impl Strategy for Cancellable { - /// Returns true if monitor should be shut down - fn should_shut_down(&mut self, timeout: Duration) -> bool { - match self.close_receiver.recv_timeout(timeout) { - Ok(()) | Err(mpsc::RecvTimeoutError::Disconnected) => true, - Err(mpsc::RecvTimeoutError::Timeout) => false, - } +impl CancelToken { + pub fn new() -> (Self, CancelReceiver) { + let (tx, rx) = broadcast::channel(1); + let closed = Arc::new(AtomicBool::new(false)); + ( + CancelToken { + closed: closed.clone(), + tx, + }, + CancelReceiver { closed, rx }, + ) + } + + pub fn close(&self) { + self.closed.store(true, Ordering::SeqCst); + let _ = self.tx.send(()); } } -impl Check { +impl Check { pub fn new( addr: Ipv4Addr, #[cfg(any(target_os = "macos", target_os = "linux"))] interface: String, retry_attempt: u32, - ) -> Result, Error> { + cancel_receiver: CancelReceiver, + ) -> Result { Ok(Check { conn_state: ConnState::new(Instant::now(), Default::default()), ping_state: PingState::new( @@ -88,47 +107,37 @@ impl Check { #[cfg(any(target_os = "macos", target_os = "linux"))] interface, )?, - strategy: Timeout, retry_attempt, + cancel_receiver, }) } - /// Cancel a [Check] preemptively by sennding a message on the channel or by dropping - /// the returned channel. - pub fn with_cancellation(self) -> (Check, mpsc::Sender<()>) { - let (cancellation_tx, cancellation_rx) = mpsc::channel(); - let check = Check { - conn_state: self.conn_state, - ping_state: self.ping_state, - strategy: Cancellable { - close_receiver: cancellation_rx, - }, - retry_attempt: self.retry_attempt, - }; - (check, cancellation_tx) - } - #[cfg(test)] - /// Create a new [Check] with a custom initial state. To use the [Cancellable] strategy, - /// see [Check::with_cancellation]. - pub(super) fn mock(conn_state: ConnState, ping_state: PingState) -> Self { - Check { - conn_state, - ping_state, - strategy: Timeout, - retry_attempt: 0, - } + /// Create a new [Check] with a custom initial state. + pub(super) fn mock(conn_state: ConnState, ping_state: PingState) -> (Self, CancelToken) { + let (cancel_token, cancel_receiver) = CancelToken::new(); + ( + Check { + conn_state, + ping_state, + retry_attempt: 0, + cancel_receiver, + }, + cancel_token, + ) } -} -impl Check { // checks if the tunnel has ever worked. Intended to check if a connection to a tunnel is // successful at the start of a connection. - pub fn establish_connectivity(&mut self, tunnel_handle: &TunnelType) -> Result { + pub async fn establish_connectivity( + &mut self, + tunnel_handle: &TunnelType, + ) -> Result { // Send initial ping to prod WireGuard into connecting. self.ping_state .pinger .send_icmp() + .await .map_err(Error::PingError)?; self.establish_connectivity_inner( self.retry_attempt, @@ -137,18 +146,15 @@ impl Check { MAX_ESTABLISH_TIMEOUT, tunnel_handle, ) + .await } - pub(crate) fn reset(&mut self, current_iteration: Instant) { - self.ping_state.reset(); + pub(crate) async fn reset(&mut self, current_iteration: Instant) { + self.ping_state.reset().await; self.conn_state.reset_after_suspension(current_iteration); } - pub(crate) fn should_shut_down(&mut self, timeout: Duration) -> bool { - self.strategy.should_shut_down(timeout) - } - - fn establish_connectivity_inner( + async fn establish_connectivity_inner( &mut self, retry_attempt: u32, timeout_initial: Duration, @@ -160,59 +166,96 @@ impl Check { return Ok(true); } - let check_timeout = cmp::min( - max_timeout, - timeout_initial.saturating_mul(timeout_multiplier.saturating_pow(retry_attempt)), - ); + let check_timeout = max_timeout + .min(timeout_initial.saturating_mul(timeout_multiplier.saturating_pow(retry_attempt))); + + // Begin polling tunnel traffic stats periodically + let poll_check = async { + loop { + if Self::check_connectivity_interval( + &mut self.conn_state, + &mut self.ping_state, + Instant::now(), + check_timeout, + tunnel_handle, + ) + .await? + { + return Ok(true); + } + tokio::time::sleep(Duration::from_millis(20)).await; + } + }; + + let timeout = tokio::time::sleep(check_timeout); - let start = Instant::now(); - while start.elapsed() < check_timeout { - if self.check_connectivity_interval(Instant::now(), check_timeout, tunnel_handle)? { - return Ok(true); + tokio::select! { + // Tunnel status polling returned a result + result = poll_check => { + result } - if self.should_shut_down(DELAY_ON_INITIAL_SETUP) { - return Ok(false); + + // Cancel token signal + _ = self.cancel_receiver.rx.recv() => { + Ok(false) + } + + // Give up if the timeout is hit + _ = timeout => { + Ok(false) } } - Ok(false) + } + + pub(crate) fn should_shut_down(&self) -> bool { + self.cancel_receiver.closed() } /// Returns true if connection is established - pub(crate) fn check_connectivity( + pub(crate) async fn check_connectivity( &mut self, now: Instant, tunnel_handle: &TunnelType, ) -> Result { - self.check_connectivity_interval(now, PING_TIMEOUT, tunnel_handle) + Self::check_connectivity_interval( + &mut self.conn_state, + &mut self.ping_state, + now, + PING_TIMEOUT, + tunnel_handle, + ) + .await } /// Returns true if connection is established - fn check_connectivity_interval( - &mut self, + async fn check_connectivity_interval( + conn_state: &mut ConnState, + ping_state: &mut PingState, now: Instant, timeout: Duration, tunnel_handle: &TunnelType, ) -> Result { - match Self::get_stats(tunnel_handle).map_err(Error::ConfigReadError)? { + match Self::get_stats(tunnel_handle) + .await + .map_err(Error::ConfigReadError)? + { None => Ok(false), Some(new_stats) => { - if self.conn_state.update(now, new_stats) { - self.ping_state.reset(); + if conn_state.update(now, new_stats) { + ping_state.reset().await; return Ok(true); } - self.maybe_send_ping(now)?; - Ok(!self.ping_state.ping_timed_out(timeout) && self.conn_state.connected()) + Self::maybe_send_ping(conn_state, ping_state, now).await?; + Ok(!ping_state.ping_timed_out(timeout) && conn_state.connected()) } } } /// If None is returned, then the underlying tunnel has already been closed and all subsequent /// calls will also return None. - /// - /// NOTE: will panic if called from within a tokio runtime. - fn get_stats(tunnel_handle: &TunnelType) -> Result, TunnelError> { - let stats = tunnel_handle.get_tunnel_stats()?; + async fn get_stats(tunnel_handle: &TunnelType) -> Result, TunnelError> { + let stats = tunnel_handle.get_tunnel_stats().await?; if stats.is_empty() { log::error!("Tunnel unexpectedly shut down"); Ok(None) @@ -221,28 +264,31 @@ impl Check { } } - fn maybe_send_ping(&mut self, now: Instant) -> Result<(), Error> { + async fn maybe_send_ping( + conn_state: &mut ConnState, + ping_state: &mut PingState, + now: Instant, + ) -> Result<(), Error> { // Only send out a ping if we haven't received a byte in a while or no traffic has flowed // in the last 2 minutes, but if a ping already has been sent out, only send one out every // 3 seconds. - if (self.conn_state.rx_timed_out() || self.conn_state.traffic_timed_out()) - && self - .ping_state + if (conn_state.rx_timed_out() || conn_state.traffic_timed_out()) + && ping_state .initial_ping_timestamp .map(|initial_ping_timestamp| { - initial_ping_timestamp.elapsed() / self.ping_state.num_pings_sent - < SECONDS_PER_PING + initial_ping_timestamp.elapsed() / ping_state.num_pings_sent < SECONDS_PER_PING }) .unwrap_or(true) { - self.ping_state + ping_state .pinger .send_icmp() + .await .map_err(Error::PingError)?; - if self.ping_state.initial_ping_timestamp.is_none() { - self.ping_state.initial_ping_timestamp = Some(now); + if ping_state.initial_ping_timestamp.is_none() { + ping_state.initial_ping_timestamp = Some(now); } - self.ping_state.num_pings_sent += 1; + ping_state.num_pings_sent += 1; } Ok(()) } @@ -284,10 +330,10 @@ impl PingState { } /// Reset timeouts - assume that the last time bytes were received is now. - fn reset(&mut self) { + async fn reset(&mut self) { self.initial_ping_timestamp = None; self.num_pings_sent = 0; - self.pinger.reset(); + self.pinger.reset().await; } } @@ -420,6 +466,8 @@ impl ConnState { #[cfg(test)] mod test { + use tokio::sync::mpsc; + use super::*; use crate::connectivity::mock::*; @@ -527,100 +575,115 @@ mod test { assert!(!conn_state.traffic_timed_out()); } - #[test] + #[tokio::test] /// Verify that `check_connectivity()` returns `false` if the tunnel is connected and traffic is /// not flowing after `BYTES_RX_TIMEOUT` and `PING_TIMEOUT`. - fn test_ping_times_out() { + async fn test_ping_times_out() { let tunnel = MockTunnel::never_incrementing().boxed(); let pinger = MockPinger::default(); let now = Instant::now(); let start = now .checked_sub(BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(10)) .unwrap(); - let mut checker = mock_checker(start, Box::new(pinger)); + let (mut checker, _cancel_token) = mock_checker(start, Box::new(pinger)); // Mock the state - connectivity has been established checker.conn_state = connected_state(start); // A ping was sent to verify connectivity - checker.maybe_send_ping(start).unwrap(); - assert!(!checker.check_connectivity(now, &tunnel).unwrap()) + Check::maybe_send_ping(&mut checker.conn_state, &mut checker.ping_state, start) + .await + .unwrap(); + assert!(!checker.check_connectivity(now, &tunnel).await.unwrap()) } - #[test] + #[tokio::test] /// Verify that `check_connectivity()` returns `true` if the tunnel is connected and traffic is /// flowing constantly. - fn test_no_connection_on_start() { + async fn test_no_connection_on_start() { let tunnel = MockTunnel::never_incrementing().boxed(); let pinger = MockPinger::default(); let now = Instant::now(); let start = now.checked_sub(Duration::from_secs(1)).unwrap(); - let mut monitor = mock_checker(start, Box::new(pinger)); + let (mut checker, _cancel_token) = mock_checker(start, Box::new(pinger)); - assert!(!monitor.check_connectivity(now, &tunnel).unwrap()) + assert!(!checker.check_connectivity(now, &tunnel).await.unwrap()) } - #[test] + #[tokio::test] /// Verify that `check_connectivity()` returns `true` if the tunnel is connected and traffic is /// flowing constantly. - fn test_connection_works() { + async fn test_connection_works() { let tunnel = MockTunnel::always_incrementing().boxed(); let pinger = MockPinger::default(); let now = Instant::now(); let start = now.checked_sub(Duration::from_secs(1)).unwrap(); - let mut monitor = mock_checker(start, Box::new(pinger)); + let (mut checker, _cancel_token) = mock_checker(start, Box::new(pinger)); // Mock the state - connectivity has been established - monitor.conn_state = connected_state(start); + checker.conn_state = connected_state(start); - assert!(monitor.check_connectivity(now, &tunnel).unwrap()) + assert!(checker.check_connectivity(now, &tunnel).await.unwrap()) } - #[test] + #[tokio::test(start_paused = true)] /// Verify that the timeout for setting up a tunnel works as expected. - fn test_establish_timeout() { - let pinger = MockPinger::default(); - let tunnel = { - let mut tunnel_stats = StatsMap::new(); - tunnel_stats.insert( - [0u8; 32], - Stats { - tx_bytes: 0, - rx_bytes: 0, - }, - ); - MockTunnel::new(move || Ok(tunnel_stats.clone())).boxed() - }; + async fn test_establish_timeout() { + const ESTABLISH_TIMEOUT_MULTIPLIER: u32 = 2; + const ESTABLISH_TIMEOUT: Duration = Duration::from_millis(500); + const MAX_ESTABLISH_TIMEOUT: Duration = Duration::from_secs(2); - let (result_tx, result_rx) = mpsc::channel(); + let (result_tx, mut result_rx) = mpsc::channel(1); - std::thread::spawn(move || { + tokio::spawn(async move { + let pinger = MockPinger::default(); let now = Instant::now(); let start = now.checked_sub(Duration::from_secs(1)).unwrap(); - let mut monitor = mock_checker(start, Box::new(pinger)); - - const ESTABLISH_TIMEOUT_MULTIPLIER: u32 = 2; - const ESTABLISH_TIMEOUT: Duration = Duration::from_millis(500); - const MAX_ESTABLISH_TIMEOUT: Duration = Duration::from_secs(2); - - for attempt in 0..4 { - result_tx - .send(monitor.establish_connectivity_inner( - attempt, - ESTABLISH_TIMEOUT, - ESTABLISH_TIMEOUT_MULTIPLIER, - MAX_ESTABLISH_TIMEOUT, - &tunnel, - )) - .unwrap(); - } + let (mut monitor, _cancel_token) = mock_checker(start, Box::new(pinger)); + + let tunnel = { + let mut tunnel_stats = StatsMap::new(); + tunnel_stats.insert( + [0u8; 32], + Stats { + tx_bytes: 0, + rx_bytes: 0, + }, + ); + MockTunnel::new(move || Ok(tunnel_stats.clone())).boxed() + }; + + result_tx + .send( + monitor + .establish_connectivity_inner( + 0, + ESTABLISH_TIMEOUT, + ESTABLISH_TIMEOUT_MULTIPLIER, + MAX_ESTABLISH_TIMEOUT, + &tunnel, + ) + .await, + ) + .await + .unwrap(); }); - let err = DELAY_ON_INITIAL_SETUP + Duration::from_millis(350); - let assert_rx = |recv_timeout: Duration| { - assert!(!result_rx.recv_timeout(recv_timeout + err).unwrap().unwrap()); - }; - assert_rx(Duration::from_millis(500)); - assert_rx(Duration::from_secs(1)); - assert_rx(Duration::from_secs(2)); - assert_rx(Duration::from_secs(2)); + + tokio::time::timeout( + ESTABLISH_TIMEOUT - Duration::from_millis(100), + result_rx.recv(), + ) + .await + .expect_err("expected timeout"); + + // Should assume no connectivity after timeout + let connected = tokio::time::timeout( + ESTABLISH_TIMEOUT + Duration::from_millis(100), + result_rx.recv(), + ) + .await + .expect("expected no timeout") + .unwrap() + .unwrap(); + assert!(!connected); } } diff --git a/talpid-wireguard/src/connectivity/constants.rs b/talpid-wireguard/src/connectivity/constants.rs index a8d6752dddd7..28c8acf1a5e0 100644 --- a/talpid-wireguard/src/connectivity/constants.rs +++ b/talpid-wireguard/src/connectivity/constants.rs @@ -1,7 +1,5 @@ use std::time::Duration; -/// Sleep time used when initially establishing connectivity -pub(crate) const DELAY_ON_INITIAL_SETUP: Duration = Duration::from_millis(50); /// Timeout for waiting on receiving traffic after sending outgoing traffic. Once this timeout is /// hit, a ping will be sent every `SECONDS_PER_PING` until `PING_TIMEOUT` is reached, or traffic /// is received. diff --git a/talpid-wireguard/src/connectivity/mock.rs b/talpid-wireguard/src/connectivity/mock.rs index eea3004bfc71..5b7c98b18300 100644 --- a/talpid-wireguard/src/connectivity/mock.rs +++ b/talpid-wireguard/src/connectivity/mock.rs @@ -1,8 +1,8 @@ use std::future::Future; use std::pin::Pin; -use std::time::Instant; +use tokio::time::Instant; -use super::check::{ConnState, PingState, Timeout}; +use super::check::{CancelToken, ConnState, PingState}; use super::pinger; use super::Check; @@ -14,14 +14,14 @@ pub use crate::stats::{Stats, StatsMap}; #[derive(Default)] pub(crate) struct MockPinger { - on_send_ping: Option>, + on_send_ping: Option>, } pub(crate) struct MockTunnel { - on_get_stats: Box Result + Send>, + on_get_stats: Box Result + Send + Sync>, } -pub fn mock_checker(now: Instant, pinger: Box) -> Check { +pub fn mock_checker(now: Instant, pinger: Box) -> (Check, CancelToken) { let conn_state = ConnState::new(now, Default::default()); let ping_state = PingState::new_with(pinger); Check::mock(conn_state, ping_state) @@ -47,7 +47,7 @@ pub fn connected_state(timestamp: Instant) -> ConnState { impl MockTunnel { const PEER: [u8; 32] = [0u8; 32]; - pub fn new Result + Send + 'static>(f: F) -> Self { + pub fn new Result + Send + Sync + 'static>(f: F) -> Self { Self { on_get_stats: Box::new(f), } @@ -97,6 +97,7 @@ impl MockTunnel { } } +#[async_trait::async_trait] impl Tunnel for MockTunnel { fn get_interface_name(&self) -> String { "mock-tunnel".to_string() @@ -106,7 +107,7 @@ impl Tunnel for MockTunnel { Ok(()) } - fn get_tunnel_stats(&self) -> Result { + async fn get_tunnel_stats(&self) -> Result { (self.on_get_stats)() } @@ -126,8 +127,9 @@ impl Tunnel for MockTunnel { } } +#[async_trait::async_trait] impl Pinger for MockPinger { - fn send_icmp(&mut self) -> Result<(), pinger::Error> { + async fn send_icmp(&mut self) -> Result<(), pinger::Error> { if let Some(callback) = self.on_send_ping.as_mut() { (callback)(); } diff --git a/talpid-wireguard/src/connectivity/mod.rs b/talpid-wireguard/src/connectivity/mod.rs index 512d8715f17d..2da555ad4565 100644 --- a/talpid-wireguard/src/connectivity/mod.rs +++ b/talpid-wireguard/src/connectivity/mod.rs @@ -7,7 +7,7 @@ mod monitor; mod pinger; #[cfg(target_os = "android")] -pub use check::Cancellable; -pub use check::Check; +pub use check::CancelReceiver; +pub use check::{CancelToken, Check}; pub use error::Error; pub use monitor::Monitor; diff --git a/talpid-wireguard/src/connectivity/monitor.rs b/talpid-wireguard/src/connectivity/monitor.rs index 583b8d9589db..1272b43f4db2 100644 --- a/talpid-wireguard/src/connectivity/monitor.rs +++ b/talpid-wireguard/src/connectivity/monitor.rs @@ -1,69 +1,73 @@ -use std::{ - sync::Weak, - time::{Duration, Instant}, -}; +use std::{sync::Weak, time::Duration}; use tokio::sync::Mutex; +use tokio::time::{Instant, MissedTickBehavior}; use crate::TunnelType; -use super::check::{Cancellable, Check}; +use super::check::Check; use super::error::Error; /// Sleep time used when checking if an established connection is still working. const REGULAR_LOOP_SLEEP: Duration = Duration::from_secs(1); +/// Reset the checker if the last check occurred this long ago +const SUSPEND_TIMEOUT: Duration = Duration::from_secs(6); + pub struct Monitor { - connectivity_check: Check, + connectivity_check: Check, } impl Monitor { - pub fn init(connectivity_check: Check) -> Self { + pub fn init(connectivity_check: Check) -> Self { Self { connectivity_check } } - pub fn run(self, tunnel_handle: Weak>>) -> Result<(), Error> { - self.wait_loop(REGULAR_LOOP_SLEEP, tunnel_handle) - } - - fn wait_loop( + pub async fn run( mut self, - iter_delay: Duration, tunnel_handle: Weak>>, ) -> Result<(), Error> { - let mut last_iteration = Instant::now(); - while !self.connectivity_check.should_shut_down(iter_delay) { - let mut current_iteration = Instant::now(); - let time_slept = current_iteration - last_iteration; - if time_slept < (iter_delay * 2) { - let Some(tunnel) = tunnel_handle.upgrade() else { - return Ok(()); - }; - let lock = tunnel.blocking_lock(); - let Some(tunnel) = lock.as_ref() else { - return Ok(()); - }; - - if !self - .connectivity_check - .check_connectivity(Instant::now(), tunnel)? - { - return Ok(()); - } - drop(lock); + let mut last_check = Instant::now(); - let end = Instant::now(); - if end - current_iteration > Duration::from_secs(1) { - current_iteration = end; - } - } else { - // Loop was suspended for too long, so it's safer to assume that the host still has - // connectivity. - self.connectivity_check.reset(current_iteration); + let mut interval = tokio::time::interval(REGULAR_LOOP_SLEEP); + interval.set_missed_tick_behavior(MissedTickBehavior::Delay); + + loop { + if self.connectivity_check.should_shut_down() { + return Ok(()); + } + + let now = Instant::now(); + let time_slept = now - last_check; + last_check = now; + + if time_slept >= SUSPEND_TIMEOUT { + self.connectivity_check.reset(now).await; + } else if !self.tunnel_exists_and_is_connected(&tunnel_handle).await? { + return Ok(()); } - last_iteration = current_iteration; + + interval.tick().await; } - Ok(()) + } + + async fn tunnel_exists_and_is_connected( + &mut self, + tunnel_handle: &Weak>>, + ) -> Result { + let Some(tunnel) = tunnel_handle.upgrade() else { + // Tunnel closed + return Ok(false); + }; + let lock = tunnel.lock().await; + let Some(tunnel) = lock.as_ref() else { + // Tunnel closed + return Ok(false); + }; + + self.connectivity_check + .check_connectivity(Instant::now(), tunnel) + .await } } @@ -71,54 +75,52 @@ impl Monitor { mod test { use super::*; - // TODO: Port to async + tokio to reduce cost of testing? use std::sync::atomic::{AtomicBool, Ordering}; - use std::sync::mpsc; use std::sync::Arc; use std::time::Duration; - use std::time::Instant; + use tokio::sync::mpsc; use tokio::sync::Mutex; use crate::connectivity::constants::*; use crate::connectivity::mock::*; - #[test] + #[tokio::test(start_paused = true)] /// Verify that the connectivity monitor doesn't fail if the tunnel constantly sends traffic, /// and it shuts down properly. - fn test_wait_loop() { - use std::sync::mpsc; - let (result_tx, result_rx) = mpsc::channel(); + async fn test_wait_loop() { + let (result_tx, mut result_rx) = mpsc::channel(1); let tunnel = MockTunnel::always_incrementing().boxed(); let pinger = MockPinger::default(); let (mut checker, stop_tx) = { let now = Instant::now(); let start = now.checked_sub(Duration::from_secs(1)).unwrap(); - mock_checker(start, Box::new(pinger)).with_cancellation() + mock_checker(start, Box::new(pinger)) }; - std::thread::spawn(move || { - let start_result = checker.establish_connectivity(&tunnel); - result_tx.send(start_result).unwrap(); + + tokio::spawn(async move { + let start_result = checker.establish_connectivity(&tunnel).await; + result_tx.send(start_result).await.unwrap(); // Pointer dance let tunnel = Arc::new(Mutex::new(Some(tunnel))); let _tunnel = Arc::downgrade(&tunnel); - let result = Monitor::init(checker).run(_tunnel).map(|_| true); - result_tx.send(result).unwrap(); + let result = Monitor::init(checker).run(_tunnel).await.map(|_| true); + result_tx.send(result).await.unwrap(); }); - std::thread::sleep(Duration::from_secs(1)); + tokio::time::sleep(Duration::from_secs(1)).await; assert!(result_rx.try_recv().unwrap().unwrap()); - stop_tx.send(()).unwrap(); - std::thread::sleep(Duration::from_secs(1)); + stop_tx.close(); + tokio::time::sleep(Duration::from_secs(2)).await; assert!(result_rx.try_recv().unwrap().is_ok()); } - #[test] + #[tokio::test(start_paused = true)] /// Verify that the connectivity monitor detects the tunnel timing out after no longer than /// `BYTES_RX_TIMEOUT` and `PING_TIMEOUT` combined. - fn test_wait_loop_timeout() { - let should_stop = Arc::new(AtomicBool::new(false)); - let should_stop_inner = should_stop.clone(); + async fn test_wait_loop_timeout() { + let stop_bytes_rx = Arc::new(AtomicBool::new(false)); + let stop_bytes_rx_inner = stop_bytes_rx.clone(); let mut map = StatsMap::new(); map.insert( @@ -133,7 +135,7 @@ mod test { let pinger = MockPinger::default(); let tunnel = MockTunnel::new(move || { let mut tunnel_stats = tunnel_stats.lock().unwrap(); - if !should_stop_inner.load(Ordering::SeqCst) { + if !stop_bytes_rx_inner.load(Ordering::SeqCst) { for traffic in tunnel_stats.values_mut() { traffic.rx_bytes += 1; } @@ -145,30 +147,41 @@ mod test { }) .boxed(); - let (result_tx, result_rx) = mpsc::channel(); + let (result_tx, mut result_rx) = mpsc::channel(1); - std::thread::spawn(move || { + tokio::spawn(async move { let (mut checker, _cancellation_token) = { let now = Instant::now(); let start = now.checked_sub(Duration::from_secs(1)).unwrap(); - mock_checker(start, Box::new(pinger)).with_cancellation() + mock_checker(start, Box::new(pinger)) }; - let start_result = checker.establish_connectivity(&tunnel); - result_tx.send(start_result).unwrap(); + let start_result = checker.establish_connectivity(&tunnel).await; + result_tx.send(start_result).await.unwrap(); // Pointer dance let _tunnel = Arc::new(Mutex::new(Some(tunnel))); let tunnel = Arc::downgrade(&_tunnel); - let end_result = Monitor::init(checker).run(tunnel).map(|_| true); - result_tx.send(end_result).expect("Failed to send result"); + let end_result = Monitor::init(checker).run(tunnel).await.map(|_| true); + result_tx + .send(end_result) + .await + .expect("Failed to send result"); }); - assert!(result_rx - .recv_timeout(Duration::from_secs(1)) - .unwrap() - .unwrap()); - should_stop.store(true, Ordering::SeqCst); - assert!(result_rx - .recv_timeout(BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(2)) - .unwrap() - .is_ok()); + + assert!( + tokio::time::timeout(Duration::from_secs(1), result_rx.recv()) + .await + .unwrap() + .unwrap() + .unwrap() + ); + stop_bytes_rx.store(true, Ordering::SeqCst); + assert!(tokio::time::timeout( + BYTES_RX_TIMEOUT + PING_TIMEOUT + Duration::from_secs(2), + result_rx.recv() + ) + .await + .unwrap() + .unwrap() + .is_ok()); } } diff --git a/talpid-wireguard/src/connectivity/pinger/android.rs b/talpid-wireguard/src/connectivity/pinger/android.rs index 00ad4d8fd379..34e28f8891f9 100644 --- a/talpid-wireguard/src/connectivity/pinger/android.rs +++ b/talpid-wireguard/src/connectivity/pinger/android.rs @@ -1,4 +1,9 @@ -use std::{io, net::Ipv4Addr}; +use std::net::Ipv4Addr; +use std::process::Stdio; +use std::time::Duration; + +use tokio::io; +use tokio::process::{Child, Command}; /// Pinger errors #[derive(thiserror::Error, Debug)] @@ -15,7 +20,7 @@ pub enum Error { /// A pinger that sends ICMP requests without waiting for responses pub struct Pinger { addr: Ipv4Addr, - processes: Vec, + processes: Vec, } impl Pinger { @@ -28,60 +33,40 @@ impl Pinger { } fn try_deplete_process_list(&mut self) { - self.processes.retain(|child| { - match child.try_wait() { - // child has terminated, doesn't have to be retained - Ok(Some(_)) => false, - _ => true, - } + self.processes.retain_mut(|child| { + // retain non-terminated children + matches!(child.try_wait(), Err(_) | Ok(None)) }); } } +#[async_trait::async_trait] impl super::Pinger for Pinger { // Send an ICMP packet without waiting for a reply - fn send_icmp(&mut self) -> Result<(), Error> { + async fn send_icmp(&mut self) -> Result<(), Error> { self.try_deplete_process_list(); - let cmd = ping_cmd(self.addr, 1); - let handle = cmd.start().map_err(Error::PingError)?; - self.processes.push(handle); + let child = ping_cmd(self.addr, Duration::from_secs(1)).map_err(Error::PingError)?; + self.processes.push(child); Ok(()) } - fn reset(&mut self) { - let processes = std::mem::take(&mut self.processes); - for proc in processes { - if proc - .try_wait() - .map(|maybe_stopped| maybe_stopped.is_none()) - .unwrap_or(false) - { - if let Err(err) = proc.kill() { - log::error!("Failed to kill ping process: {}", err); - } - } - } + async fn reset(&mut self) { + self.processes.clear(); } } -impl Drop for Pinger { - fn drop(&mut self) { - for child in self.processes.iter_mut() { - if let Err(e) = child.kill() { - log::error!("Failed to kill ping process: {}", e); - } - } - } -} +fn ping_cmd(ip: Ipv4Addr, timeout: Duration) -> io::Result { + let mut cmd = Command::new("ping"); -fn ping_cmd(ip: Ipv4Addr, timeout_secs: u16) -> duct::Expression { - let timeout_secs = timeout_secs.to_string(); + let timeout_secs = timeout.as_secs().to_string(); let ip = ip.to_string(); - let args = ["-n", "-i", "1", "-w", &timeout_secs, &ip]; + cmd.args(["-n", "-i", "1", "-w", &timeout_secs, &ip]); + + cmd.stdin(Stdio::null()) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .kill_on_drop(true); - duct::cmd("ping", args) - .stdin_null() - .stdout_null() - .unchecked() + cmd.spawn() } diff --git a/talpid-wireguard/src/connectivity/pinger/icmp.rs b/talpid-wireguard/src/connectivity/pinger/icmp.rs index 0e5d73942527..b17ee7ddc1b7 100644 --- a/talpid-wireguard/src/connectivity/pinger/icmp.rs +++ b/talpid-wireguard/src/connectivity/pinger/icmp.rs @@ -1,11 +1,11 @@ use byteorder::{NetworkEndian, WriteBytesExt}; use rand::Rng; use socket2::{Domain, Protocol, Socket, Type}; +use tokio::net::UdpSocket; use std::{ io::{self, Write}, net::{Ipv4Addr, SocketAddr}, - thread, time::Duration, }; @@ -30,6 +30,10 @@ pub enum Error { #[error("Failed to write to socket")] Write(#[source] io::Error), + /// Failed to convert to tokio socket + #[error("Failed to convert to tokio socket")] + ConvertSocket(#[source] io::Error), + /// Failed to get device index #[cfg(target_os = "macos")] #[error("Failed to obtain device index")] @@ -43,16 +47,12 @@ pub enum Error { /// ICMP buffer too small #[error("ICMP message buffer too small")] BufferTooSmall, - - /// Interface name contains null bytes - #[error("Interface name contains a null byte")] - InterfaceNameContainsNull, } type Result = std::result::Result; pub struct Pinger { - sock: Socket, + sock: UdpSocket, addr: SocketAddr, id: u16, seq: u16, @@ -76,6 +76,9 @@ impl Pinger { #[cfg(target_os = "macos")] Self::set_device_index(&sock, &interface_name)?; + let sock = + UdpSocket::from_std(std::net::UdpSocket::from(sock)).map_err(Error::ConvertSocket)?; + Ok(Self { sock, addr, @@ -96,25 +99,19 @@ impl Pinger { Ok(()) } - fn send_ping_request(&mut self, message: &[u8], destination: SocketAddr) -> Result<()> { + async fn send_ping_request(&mut self, message: &[u8], destination: SocketAddr) -> Result<()> { let mut tries = 0; - let mut result = Ok(()); - while tries < SEND_RETRY_ATTEMPTS { - match self.sock.send_to(message, &destination.into()) { - Ok(_) => { - return Ok(()); - } - Err(err) => { - if Some(10065) != err.raw_os_error() { - return Err(Error::Write(err)); - } - result = Err(Error::Write(err)); - } + loop { + let Err(error) = self.sock.send_to(message, destination).await else { + return Ok(()); + }; + if tries >= SEND_RETRY_ATTEMPTS || !should_retry_send(&error) { + return Err(Error::Write(error)); } - thread::sleep(Duration::from_secs(1)); + + tokio::time::sleep(Duration::from_secs(1)).await; tries += 1; } - result } fn construct_icmpv4_packet(&mut self, buffer: &mut [u8]) -> Result<()> { @@ -125,11 +122,25 @@ impl Pinger { } } +#[cfg(windows)] +fn should_retry_send(err: &io::Error) -> bool { + // Winsock error for when there is no route + // NOTE: It's unclear if we need to check this on Windows anymore, or why specifically on Windows + const WSAEHOSTUNREACH: i32 = 10065; + err.raw_os_error() == Some(WSAEHOSTUNREACH) +} + +#[cfg(unix)] +fn should_retry_send(_err: &io::Error) -> bool { + false +} + +#[async_trait::async_trait] impl super::Pinger for Pinger { - fn send_icmp(&mut self) -> Result<()> { + async fn send_icmp(&mut self) -> Result<()> { let mut message = [0u8; 50]; self.construct_icmpv4_packet(&mut message)?; - self.send_ping_request(&message, self.addr) + self.send_ping_request(&message, self.addr).await } } diff --git a/talpid-wireguard/src/connectivity/pinger/mod.rs b/talpid-wireguard/src/connectivity/pinger/mod.rs index ef2394f1b79d..10875afb8af7 100644 --- a/talpid-wireguard/src/connectivity/pinger/mod.rs +++ b/talpid-wireguard/src/connectivity/pinger/mod.rs @@ -9,11 +9,12 @@ mod imp; pub use imp::Error; /// Trait for sending ICMP requests to get some traffic from a remote server +#[async_trait::async_trait] pub trait Pinger: Send { /// Sends an ICMP packet - fn send_icmp(&mut self) -> Result<(), Error>; + async fn send_icmp(&mut self) -> Result<(), Error>; /// Clears all resources used by the pinger. - fn reset(&mut self) {} + async fn reset(&mut self) {} } /// Create a new pinger diff --git a/talpid-wireguard/src/ephemeral.rs b/talpid-wireguard/src/ephemeral.rs index 31f3957253e9..442be27ccf64 100644 --- a/talpid-wireguard/src/ephemeral.rs +++ b/talpid-wireguard/src/ephemeral.rs @@ -50,7 +50,7 @@ pub async fn config_ephemeral_peers( log::trace!("Resetting tunnel MTU"); try_set_ipv4_mtu(&iface_name, config.mtu); - Ok(()) + Ok(result) } #[cfg(windows)] @@ -226,6 +226,7 @@ async fn reconfigure_tunnel( let updated_tunnel = tunnel .set_config(&config) + .await .map_err(Error::TunnelError) .map_err(CloseMsg::SetupError)?; diff --git a/talpid-wireguard/src/lib.rs b/talpid-wireguard/src/lib.rs index cb19559dcf96..0ab9651c9523 100644 --- a/talpid-wireguard/src/lib.rs +++ b/talpid-wireguard/src/lib.rs @@ -145,7 +145,7 @@ pub struct WireguardMonitor { /// Callback to signal tunnel events event_hook: EventHook, close_msg_receiver: sync_mpsc::Receiver, - pinger_stop_sender: sync_mpsc::Sender<()>, + pinger_stop_sender: connectivity::CancelToken, obfuscator: Arc>>, } @@ -211,21 +211,22 @@ impl WireguardMonitor { let obfuscator = Arc::new(AsyncMutex::new(obfuscator)); let gateway = config.ipv4_gateway; - let (mut connectivity_monitor, pinger_tx) = connectivity::Check::new( + let (cancel_token, cancel_receiver) = connectivity::CancelToken::new(); + let mut connectivity_monitor = connectivity::Check::new( gateway, #[cfg(any(target_os = "macos", target_os = "linux"))] iface_name.clone(), args.retry_attempt, + cancel_receiver, ) - .map_err(Error::ConnectivityMonitorError)? - .with_cancellation(); + .map_err(Error::ConnectivityMonitorError)?; let monitor = WireguardMonitor { runtime: args.runtime.clone(), tunnel: Arc::new(AsyncMutex::new(Some(tunnel))), event_hook: args.event_hook.clone(), close_msg_receiver: close_obfs_listener, - pinger_stop_sender: pinger_tx, + pinger_stop_sender: cancel_token, obfuscator, }; @@ -325,28 +326,26 @@ impl WireguardMonitor { }); } - let cloned_tunnel = Arc::clone(&tunnel); - - let connectivity_check = tokio::task::spawn_blocking(move || { - let lock = cloned_tunnel.blocking_lock(); - let tunnel = lock.as_ref().expect("The tunnel was dropped unexpectedly"); - match connectivity_monitor.establish_connectivity(tunnel) { - Ok(true) => Ok(connectivity_monitor), - Ok(false) => { - log::warn!("Timeout while checking tunnel connection"); - Err(CloseMsg::PingErr) - } - Err(error) => { - log::error!( - "{}", - error.display_chain_with_msg("Failed to check tunnel connection") - ); - Err(CloseMsg::PingErr) - } + let lock = tunnel.lock().await; + let borrowed_tun = lock.as_ref().expect("The tunnel was dropped unexpectedly"); + match connectivity_monitor + .establish_connectivity(borrowed_tun) + .await + { + Ok(true) => Ok(()), + Ok(false) => { + log::warn!("Timeout while checking tunnel connection"); + Err(CloseMsg::PingErr) } - }) - .await - .unwrap()?; + Err(error) => { + log::error!( + "{}", + error.display_chain_with_msg("Failed to check tunnel connection") + ); + Err(CloseMsg::PingErr) + } + }?; + drop(lock); // Add any default route(s) that may exist. args.route_manager @@ -358,19 +357,15 @@ impl WireguardMonitor { let metadata = Self::tunnel_metadata(&iface_name, &config); event_hook.on_event(TunnelEvent::Up(metadata)).await; - let monitored_tunnel = Arc::downgrade(&tunnel); - tokio::task::spawn_blocking(move || { - if let Err(error) = - connectivity::Monitor::init(connectivity_check).run(monitored_tunnel) - { - log::error!( - "{}", - error.display_chain_with_msg("Connectivity monitor failed") - ); - } - }) - .await - .unwrap(); + if let Err(error) = connectivity::Monitor::init(connectivity_monitor) + .run(Arc::downgrade(&tunnel)) + .await + { + log::error!( + "{}", + error.display_chain_with_msg("Connectivity monitor failed") + ); + } Err::(CloseMsg::PingErr) }; @@ -429,12 +424,15 @@ impl WireguardMonitor { let should_negotiate_ephemeral_peer = config.quantum_resistant || config.daita; - let (connectivity_check, pinger_tx) = - connectivity::Check::new(config.ipv4_gateway, args.retry_attempt) - .map_err(Error::ConnectivityMonitorError)? - .with_cancellation(); + let (cancel_token, cancel_receiver) = connectivity::CancelToken::new(); + let connectivity_check = connectivity::Check::new( + config.ipv4_gateway, + args.retry_attempt, + cancel_receiver.clone(), + ) + .map_err(Error::ConnectivityMonitorError)?; - let tunnel = Self::open_wireguard_go_tunnel( + let tunnel = args.runtime.block_on(Self::open_wireguard_go_tunnel( &config, log_path, args.tun_provider.clone(), @@ -442,8 +440,8 @@ impl WireguardMonitor { // that we only allows traffic to/from the gateway. This is only needed on Android // since we lack a firewall there. should_negotiate_ephemeral_peer, - connectivity_check, - )?; + cancel_receiver, + ))?; let iface_name = tunnel.get_interface_name(); let tunnel = Arc::new(AsyncMutex::new(Some(tunnel))); @@ -453,7 +451,7 @@ impl WireguardMonitor { tunnel: Arc::clone(&tunnel), event_hook: event_hook.clone(), close_msg_receiver: close_obfs_listener, - pinger_stop_sender: pinger_tx, + pinger_stop_sender: cancel_token, obfuscator: Arc::new(AsyncMutex::new(obfuscator)), }; @@ -502,29 +500,15 @@ impl WireguardMonitor { let metadata = Self::tunnel_metadata(&iface_name, &config); event_hook.on_event(TunnelEvent::Up(metadata)).await; - // HACK: The tunnel does not need the connectivity::Check anymore, so lets take it - let connectivity_check = { - let mut tunnel_lock = tunnel.lock().await; - let Some(tunnel) = tunnel_lock.as_mut() else { - log::debug!("Tunnel is no longer running"); - return Err::(CloseMsg::PingErr); - }; - tunnel - .take_checker() - .expect("connectivity checker unexpectedly dropped") - }; - - tokio::task::spawn_blocking(move || { - let tunnel = Arc::downgrade(&tunnel); - if let Err(error) = connectivity::Monitor::init(connectivity_check).run(tunnel) { - log::error!( - "{}", - error.display_chain_with_msg("Connectivity monitor failed") - ); - } - }) - .await - .unwrap(); + if let Err(error) = connectivity::Monitor::init(connectivity_check) + .run(Arc::downgrade(&tunnel)) + .await + { + log::error!( + "{}", + error.display_chain_with_msg("Connectivity monitor failed") + ); + } Err::(CloseMsg::PingErr) }; @@ -663,13 +647,18 @@ impl WireguardMonitor { if !*FORCE_USERSPACE_WIREGUARD { // If DAITA is enabled, wireguard-go has to be used. if config.daita { - let tunnel = - Self::open_wireguard_go_tunnel(config, log_path, tun_provider).map(Box::new)?; + let tunnel = runtime + .block_on(Self::open_wireguard_go_tunnel( + config, + log_path, + tun_provider, + )) + .map(Box::new)?; return Ok(tunnel); } if will_nm_manage_dns() { - match wireguard_kernel::NetworkManagerTunnel::new(runtime, config) { + match wireguard_kernel::NetworkManagerTunnel::new(runtime.clone(), config) { Ok(tunnel) => { log::debug!("Using NetworkManager to use kernel WireGuard implementation"); return Ok(Box::new(tunnel)); @@ -684,7 +673,7 @@ impl WireguardMonitor { } }; } else { - match wireguard_kernel::NetlinkTunnel::new(runtime, config) { + match wireguard_kernel::NetlinkTunnel::new(runtime.clone(), config) { Ok(tunnel) => { log::debug!("Using kernel WireGuard implementation"); return Ok(Box::new(tunnel)); @@ -713,28 +702,28 @@ impl WireguardMonitor { #[cfg(target_os = "linux")] log::debug!("Using userspace WireGuard implementation"); - let tunnel = Self::open_wireguard_go_tunnel( - config, - log_path, - tun_provider, - #[cfg(target_os = "android")] - gateway_only, - ) - .map(Box::new)?; + let tunnel = runtime + .block_on(Self::open_wireguard_go_tunnel( + config, + log_path, + tun_provider, + #[cfg(target_os = "android")] + gateway_only, + )) + .map(Box::new)?; Ok(tunnel) } } /// Configure and start a Wireguard-go tunnel. #[cfg(wireguard_go)] - fn open_wireguard_go_tunnel( + #[allow(clippy::unused_async)] + async fn open_wireguard_go_tunnel( config: &Config, log_path: Option<&Path>, tun_provider: Arc>, #[cfg(target_os = "android")] gateway_only: bool, - #[cfg(target_os = "android")] connectivity_check: connectivity::Check< - connectivity::Cancellable, - >, + #[cfg(target_os = "android")] cancel_receiver: connectivity::CancelReceiver, ) -> Result { let routes = config .get_tunnel_destinations() @@ -769,8 +758,9 @@ impl WireguardMonitor { log_path, tun_provider, routes, - connectivity_check, + cancel_receiver, ) + .await .map_err(Error::TunnelError)? } else { WgGoTunnel::start_tunnel( @@ -779,8 +769,9 @@ impl WireguardMonitor { log_path, tun_provider, routes, - connectivity_check, + cancel_receiver, ) + .await .map_err(Error::TunnelError)? }; @@ -799,7 +790,7 @@ impl WireguardMonitor { Err(_) => Ok(()), }; - let _ = self.pinger_stop_sender.send(()); + self.pinger_stop_sender.close(); self.runtime .block_on(self.event_hook.on_event(TunnelEvent::Down)); @@ -984,7 +975,7 @@ impl WireguardMonitor { async fn log_tunnel_data_usage(config: &Config, tunnel: &Arc>>) { let tunnel = tunnel.lock().await; let Some(tunnel) = &*tunnel else { return }; - let Ok(tunnel_stats) = tunnel.get_tunnel_stats() else { + let Ok(tunnel_stats) = tunnel.get_tunnel_stats().await else { return; }; if let Some(stats) = config @@ -1012,10 +1003,11 @@ enum CloseMsg { } #[allow(unused)] -pub(crate) trait Tunnel: Send { +#[async_trait::async_trait] +pub(crate) trait Tunnel: Send + Sync { fn get_interface_name(&self) -> String; fn stop(self: Box) -> std::result::Result<(), TunnelError>; - fn get_tunnel_stats(&self) -> std::result::Result; + async fn get_tunnel_stats(&self) -> std::result::Result; fn set_config<'a>( &'a mut self, _config: Config, diff --git a/talpid-wireguard/src/wireguard_go/mod.rs b/talpid-wireguard/src/wireguard_go/mod.rs index b91f3c464a4c..0b3de775fcad 100644 --- a/talpid-wireguard/src/wireguard_go/mod.rs +++ b/talpid-wireguard/src/wireguard_go/mod.rs @@ -1,7 +1,5 @@ #[cfg(target_os = "android")] use super::config; -#[cfg(target_os = "android")] -use super::Error; use super::{ stats::{Stats, StatsMap}, Config, Tunnel, TunnelError, @@ -106,32 +104,26 @@ impl WgGoTunnel { } } - pub fn set_config(self, config: &Config) -> Result { + pub async fn set_config(self, config: &Config) -> Result { let state = self.as_state(); let log_path = state._logging_context.path.clone(); + let cancel_receiver = state.cancel_receiver.clone(); let tun_provider = Arc::clone(&state.tun_provider); let routes = config.get_tunnel_destinations(); match self { - WgGoTunnel::Multihop(mut state) if !config.is_multihop() => { - let connectivity_checker = state - .connectivity_checker - .take() - .expect("connectivity checker unexpectedly dropped"); + WgGoTunnel::Multihop(state) if !config.is_multihop() => { state.stop()?; Self::start_tunnel( config, log_path.as_deref(), tun_provider, routes, - connectivity_checker, + cancel_receiver, ) + .await } - WgGoTunnel::Singlehop(mut state) if config.is_multihop() => { - let connectivity_checker = state - .connectivity_checker - .take() - .expect("connectivity checker unexpectedly dropped"); + WgGoTunnel::Singlehop(state) if config.is_multihop() => { state.stop()?; Self::start_multihop_tunnel( config, @@ -139,8 +131,9 @@ impl WgGoTunnel { log_path.as_deref(), tun_provider, routes, - connectivity_checker, + cancel_receiver, ) + .await } WgGoTunnel::Singlehop(mut state) => { state.set_config(config.clone())?; @@ -170,13 +163,9 @@ pub(crate) struct WgGoTunnelState { tun_provider: Arc>, #[cfg(daita)] config: Config, - // HACK: Check is not Clone, so we have to pass this around .. - // This is conceptually the connection between this Tunnel and the currently running - // WireguardMonitor, and it is used to allow WireguardMonitor to cancel the setup of - // a new Tunnel during the "ensure_connectivity" phase. This field should be removed - // as soon as we implement a better way to cancel Check asynchronously. + /// This is used to cancel the connectivity checks that occur when toggling multihop #[cfg(target_os = "android")] - connectivity_checker: Option>, + cancel_receiver: connectivity::CancelReceiver, } impl WgGoTunnelState { @@ -293,12 +282,12 @@ impl WgGoTunnel { #[cfg(target_os = "android")] impl WgGoTunnel { - pub fn start_tunnel( + pub async fn start_tunnel( config: &Config, log_path: Option<&Path>, tun_provider: Arc>, routes: impl Iterator, - mut connectivity_check: connectivity::Check, + cancel_receiver: connectivity::CancelReceiver, ) -> Result { let (mut tunnel_device, tunnel_fd) = Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?; @@ -321,7 +310,7 @@ impl WgGoTunnel { Self::bypass_tunnel_sockets(&handle, &mut tunnel_device) .map_err(TunnelError::BypassError)?; - let mut tunnel = WgGoTunnel::Singlehop(WgGoTunnelState { + let tunnel = WgGoTunnel::Singlehop(WgGoTunnelState { interface_name, tunnel_handle: handle, _tunnel_device: tunnel_device, @@ -329,23 +318,22 @@ impl WgGoTunnel { tun_provider, #[cfg(daita)] config: config.clone(), - connectivity_checker: None, + cancel_receiver, }); // HACK: Check if the tunnel is working by sending a ping in the tunnel. - tunnel.ensure_tunnel_is_running(&mut connectivity_check)?; - tunnel.as_state_mut().connectivity_checker = Some(connectivity_check); + tunnel.ensure_tunnel_is_running().await?; Ok(tunnel) } - pub fn start_multihop_tunnel( + pub async fn start_multihop_tunnel( config: &Config, exit_peer: &PeerConfig, log_path: Option<&Path>, tun_provider: Arc>, routes: impl Iterator, - mut connectivity_check: connectivity::Check, + cancel_receiver: connectivity::CancelReceiver, ) -> Result { let (mut tunnel_device, tunnel_fd) = Self::get_tunnel(Arc::clone(&tun_provider), config, routes)?; @@ -384,7 +372,7 @@ impl WgGoTunnel { Self::bypass_tunnel_sockets(&handle, &mut tunnel_device) .map_err(TunnelError::BypassError)?; - let mut tunnel = WgGoTunnel::Multihop(WgGoTunnelState { + let tunnel = WgGoTunnel::Multihop(WgGoTunnelState { interface_name, tunnel_handle: handle, _tunnel_device: tunnel_device, @@ -392,12 +380,11 @@ impl WgGoTunnel { tun_provider, #[cfg(daita)] config: config.clone(), - connectivity_checker: None, + cancel_receiver: cancel_receiver.clone(), }); // HACK: Check if the tunnel is working by sending a ping in the tunnel. - tunnel.ensure_tunnel_is_running(&mut connectivity_check)?; - tunnel.as_state_mut().connectivity_checker = Some(connectivity_check); + tunnel.ensure_tunnel_is_running().await?; Ok(tunnel) } @@ -415,30 +402,34 @@ impl WgGoTunnel { Ok(()) } - pub fn take_checker(&mut self) -> Option> { - self.as_state_mut().connectivity_checker.take() - } - /// There is a brief period of time between setting up a Wireguard-go tunnel and the tunnel being ready to serve /// traffic. This function blocks until the tunnel starts to serve traffic or until [connectivity::Check] times out. - fn ensure_tunnel_is_running( - &self, - checker: &mut connectivity::Check, - ) -> Result<()> { - let connection_established = checker + async fn ensure_tunnel_is_running(&self) -> Result<()> { + let state = self.as_state(); + let addr = state.config.ipv4_gateway; + let cancel_receiver = state.cancel_receiver.clone(); + let mut check = connectivity::Check::new(addr, 0, cancel_receiver) + .map_err(|err| TunnelError::RecoverableStartWireguardError(Box::new(err)))?; + + // TODO: retry attempt? + + let connection_established = check .establish_connectivity(self) + .await .map_err(|e| TunnelError::RecoverableStartWireguardError(Box::new(e)))?; // Timed out if !connection_established { return Err(TunnelError::RecoverableStartWireguardError(Box::new( - Error::TimeoutError, + super::Error::TimeoutError, ))); } + Ok(()) } } +#[async_trait::async_trait] impl Tunnel for WgGoTunnel { fn get_interface_name(&self) -> String { self.as_state().interface_name.clone() @@ -448,14 +439,16 @@ impl Tunnel for WgGoTunnel { self.into_state().stop() } - fn get_tunnel_stats(&self) -> Result { - self.as_state() - .tunnel_handle - .get_config(|cstr| { - Stats::parse_config_str(cstr.to_str().expect("Go strings are always UTF-8")) - }) - .ok_or(TunnelError::GetConfigError)? - .map_err(|error| TunnelError::StatsError(BoxedError::new(error))) + async fn get_tunnel_stats(&self) -> Result { + tokio::task::block_in_place(|| { + self.as_state() + .tunnel_handle + .get_config(|cstr| { + Stats::parse_config_str(cstr.to_str().expect("Go strings are always UTF-8")) + }) + .ok_or(TunnelError::GetConfigError)? + .map_err(|error| TunnelError::StatsError(BoxedError::new(error))) + }) } fn set_config( diff --git a/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs b/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs index 8b84b3769d99..86285d80a2b7 100644 --- a/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs +++ b/talpid-wireguard/src/wireguard_kernel/netlink_tunnel.rs @@ -65,6 +65,7 @@ impl NetlinkTunnel { } } +#[async_trait::async_trait] impl Tunnel for NetlinkTunnel { fn get_interface_name(&self) -> String { let mut wg = self.netlink_connections.wg_handle.clone(); @@ -103,16 +104,14 @@ impl Tunnel for NetlinkTunnel { }) } - fn get_tunnel_stats(&self) -> std::result::Result { - let mut wg = self.netlink_connections.wg_handle.clone(); + async fn get_tunnel_stats(&self) -> std::result::Result { let interface_index = self.interface_index; - self.tokio_handle.block_on(async move { - let device = wg.get_by_index(interface_index).await.map_err(|err| { - log::error!("Failed to fetch WireGuard device config: {}", err); - TunnelError::GetConfigError - })?; - Ok(Stats::parse_device_message(&device)) - }) + let mut wg = self.netlink_connections.wg_handle.clone(); + let device = wg.get_by_index(interface_index).await.map_err(|err| { + log::error!("Failed to fetch WireGuard device config: {}", err); + TunnelError::GetConfigError + })?; + Ok(Stats::parse_device_message(&device)) } fn set_config( diff --git a/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs b/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs index 070e3d1ee9e8..ba3bca14befc 100644 --- a/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs +++ b/talpid-wireguard/src/wireguard_kernel/nm_tunnel.rs @@ -28,7 +28,6 @@ pub struct NetworkManagerTunnel { network_manager: NetworkManager, tunnel: Option, netlink_connections: Handle, - tokio_handle: tokio::runtime::Handle, interface_name: String, } @@ -58,12 +57,12 @@ impl NetworkManagerTunnel { network_manager, tunnel: Some(tunnel), netlink_connections, - tokio_handle, interface_name, }) } } +#[async_trait::async_trait] impl Tunnel for NetworkManagerTunnel { fn get_interface_name(&self) -> String { self.interface_name.clone() @@ -82,18 +81,16 @@ impl Tunnel for NetworkManagerTunnel { } } - fn get_tunnel_stats(&self) -> std::result::Result { + async fn get_tunnel_stats(&self) -> std::result::Result { let mut wg = self.netlink_connections.wg_handle.clone(); - self.tokio_handle.block_on(async move { - let device = wg - .get_by_name(self.interface_name.clone()) - .await - .map_err(|err| { - log::error!("Failed to fetch WireGuard device config: {}", err); - TunnelError::GetConfigError - })?; - Ok(Stats::parse_device_message(&device)) - }) + let device = wg + .get_by_name(self.interface_name.clone()) + .await + .map_err(|err| { + log::error!("Failed to fetch WireGuard device config: {}", err); + TunnelError::GetConfigError + })?; + Ok(Stats::parse_device_message(&device)) } fn set_config( diff --git a/talpid-wireguard/src/wireguard_nt/mod.rs b/talpid-wireguard/src/wireguard_nt/mod.rs index fefb7879e9c0..9243425cde87 100644 --- a/talpid-wireguard/src/wireguard_nt/mod.rs +++ b/talpid-wireguard/src/wireguard_nt/mod.rs @@ -1037,13 +1037,20 @@ unsafe fn deserialize_config( Ok((interface, peers)) } +#[async_trait::async_trait] impl Tunnel for WgNtTunnel { fn get_interface_name(&self) -> String { self.interface_name.clone() } - fn get_tunnel_stats(&self) -> std::result::Result { - if let Some(ref device) = self.device { + async fn get_tunnel_stats(&self) -> std::result::Result { + let Some(ref device) = self.device else { + log::error!("Failed to obtain tunnel stats as device no longer exists"); + return Err(super::TunnelError::GetConfigError); + }; + + let device = device.clone(); + tokio::task::spawn_blocking(move || { let mut map = StatsMap::new(); let (_interface, peers) = device.get_config().map_err(|error| { log::error!( @@ -1062,10 +1069,9 @@ impl Tunnel for WgNtTunnel { ); } Ok(map) - } else { - log::error!("Failed to obtain tunnel stats as device no longer exists"); - Err(super::TunnelError::GetConfigError) - } + }) + .await + .unwrap() } fn stop(mut self: Box) -> std::result::Result<(), super::TunnelError> {