From 5cf3b40310a0d9176b2084af1db623aa0bca63eb Mon Sep 17 00:00:00 2001 From: Kevin Wang Date: Mon, 16 Dec 2024 10:34:44 +0000 Subject: [PATCH] tproxy: Connect to multiple hosts --- Cargo.lock | 2 + Cargo.toml | 1 + tproxy/Cargo.toml | 2 + tproxy/src/config.rs | 8 ++- tproxy/src/main.rs | 2 +- tproxy/src/main_service.rs | 93 +++++++++++++++++++++------- tproxy/src/main_service/tests.rs | 4 +- tproxy/src/proxy.rs | 17 +++-- tproxy/src/proxy/tls_passthough.rs | 44 +++++++++---- tproxy/src/proxy/tls_terminate.rs | 21 +++---- tproxy/src/web_routes.rs | 8 +-- tproxy/src/web_routes/route_index.rs | 4 +- tproxy/tproxy.toml | 5 ++ 13 files changed, 150 insertions(+), 61 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2e42a9c..8dec051 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5075,6 +5075,7 @@ dependencies = [ "certbot", "clap", "fs-err", + "futures", "git-version", "hex", "hickory-resolver", @@ -5092,6 +5093,7 @@ dependencies = [ "serde", "serde_json", "shared_child", + "smallvec", "tokio", "tokio-rustls 0.26.0", "tproxy-rpc", diff --git a/Cargo.toml b/Cargo.toml index b31466b..ab66e67 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -140,3 +140,4 @@ tailf = "0.1.2" time = "0.3.37" uuid = { version = "1.11.0", features = ["v4"] } which = "7.0.0" +smallvec = "1.13.2" diff --git a/tproxy/Cargo.toml b/tproxy/Cargo.toml index 9990fa9..5625bff 100644 --- a/tproxy/Cargo.toml +++ b/tproxy/Cargo.toml @@ -32,6 +32,8 @@ tproxy-rpc.workspace = true certbot.workspace = true bytes.workspace = true safe-write.workspace = true +smallvec.workspace = true +futures.workspace = true [target.'cfg(unix)'.dependencies] nix = { workspace = true, features = ["resource"] } diff --git a/tproxy/src/config.rs b/tproxy/src/config.rs index d1913f2..8b225bf 100644 --- a/tproxy/src/config.rs +++ b/tproxy/src/config.rs @@ -31,6 +31,7 @@ pub struct ProxyConfig { pub tappd_port: u16, pub timeouts: Timeouts, pub buffer_size: usize, + pub connect_top_n: usize, } #[derive(Debug, Clone, Deserialize)] @@ -39,6 +40,11 @@ pub struct Timeouts { pub connect: Duration, #[serde(with = "serde_duration")] pub handshake: Duration, + #[serde(with = "serde_duration")] + pub total: Duration, + + #[serde(with = "serde_duration")] + pub cache_top_n: Duration, pub data_timeout_enabled: bool, #[serde(with = "serde_duration")] @@ -47,8 +53,6 @@ pub struct Timeouts { pub write: Duration, #[serde(with = "serde_duration")] pub shutdown: Duration, - #[serde(with = "serde_duration")] - pub total: Duration, } #[derive(Debug, Clone, Deserialize)] diff --git a/tproxy/src/main.rs b/tproxy/src/main.rs index 9103de8..4974b0f 100644 --- a/tproxy/src/main.rs +++ b/tproxy/src/main.rs @@ -61,7 +61,7 @@ async fn main() -> Result<()> { let proxy_config = config.proxy.clone(); let pccs_url = config.pccs_url.clone(); - let state = main_service::AppState::new(config)?; + let state = main_service::Proxy::new(config)?; state.lock().reconfigure()?; proxy::start(proxy_config, state.clone()); diff --git a/tproxy/src/main_service.rs b/tproxy/src/main_service.rs index b543018..6db64b7 100644 --- a/tproxy/src/main_service.rs +++ b/tproxy/src/main_service.rs @@ -1,9 +1,9 @@ use std::{ collections::{BTreeMap, BTreeSet}, net::Ipv4Addr, - process::Command, + process::{Command, Stdio}, sync::{Arc, Mutex, MutexGuard, Weak}, - time::{Duration, SystemTime, UNIX_EPOCH}, + time::{Duration, Instant, SystemTime, UNIX_EPOCH}, }; use anyhow::{bail, Context, Result}; @@ -14,38 +14,41 @@ use rand::seq::IteratorRandom; use rinja::Template as _; use safe_write::safe_write; use serde::{Deserialize, Serialize}; +use smallvec::{smallvec, SmallVec}; use tproxy_rpc::{ tproxy_server::{TproxyRpc, TproxyServer}, AcmeInfoResponse, GetInfoRequest, GetInfoResponse, HostInfo as PbHostInfo, ListResponse, RegisterCvmRequest, RegisterCvmResponse, TappdConfig, WireGuardConfig, }; -use tracing::{debug, error, info}; +use tracing::{debug, error, info, warn}; use crate::{ config::Config, - models::{InstanceInfo, WgConf}, + models::{InstanceInfo, WgConf}, proxy::AddressGroup, }; #[derive(Clone)] -pub struct AppState { +pub struct Proxy { pub(crate) config: Arc, - inner: Arc>, + inner: Arc>, } #[derive(Debug, Serialize, Deserialize)] -struct State { +struct ProxyStateMut { apps: BTreeMap>, instances: BTreeMap, allocated_addresses: BTreeSet, + #[serde(skip)] + top_n: BTreeMap, } -pub(crate) struct AppStateInner { +pub(crate) struct ProxyState { config: Arc, - state: State, + state: ProxyStateMut, } -impl AppState { - pub(crate) fn lock(&self) -> MutexGuard { +impl Proxy { + pub(crate) fn lock(&self) -> MutexGuard { self.inner.lock().expect("Failed to lock AppState") } @@ -56,13 +59,14 @@ impl AppState { let state_str = fs::read_to_string(state_path).context("Failed to read state")?; serde_json::from_str(&state_str).context("Failed to load state")? } else { - State { + ProxyStateMut { apps: BTreeMap::new(), + top_n: BTreeMap::new(), instances: BTreeMap::new(), allocated_addresses: BTreeSet::new(), } }; - let inner = Arc::new(Mutex::new(AppStateInner { + let inner = Arc::new(Mutex::new(ProxyState { config: config.clone(), state, })); @@ -71,7 +75,7 @@ impl AppState { } } -fn start_recycle_thread(state: Weak>, config: Arc) { +fn start_recycle_thread(state: Weak>, config: Arc) { if !config.recycle.enabled { info!("recycle is disabled"); return; @@ -87,7 +91,7 @@ fn start_recycle_thread(state: Weak>, config: Arc) }); } -impl AppStateInner { +impl ProxyState { fn alloc_ip(&mut self) -> Option { for ip in self.config.wg.client_ip_range.hosts() { if ip == self.config.wg.ip { @@ -166,10 +170,49 @@ impl AppStateInner { Ok(()) } - pub(crate) fn select_a_host(&self, id: &str) -> Option { + pub(crate) fn select_top_n_hosts<'a>(&'a mut self, id: &str) -> Result { + let n = self.config.proxy.connect_top_n; + if let Some(instance) = self.state.instances.get(id) { + return Ok(smallvec![instance.ip]); + }; + let app_instances = self.state.apps.get(id).context("app not found")?; + if n == 0 { + // fallback to random selection + return Ok(self.random_select_a_host(id).unwrap_or_default()); + } + let (top_n, insert_time) = self + .state + .top_n + .entry(id.to_string()) + .or_insert((SmallVec::new(), Instant::now())); + if !top_n.is_empty() && insert_time.elapsed() < self.config.proxy.timeouts.cache_top_n { + return Ok(top_n.clone()); + } + + let handshakes = self.latest_handshakes(None); + let mut instances = match handshakes { + Err(err) => { + warn!("Failed to get handshakes, fallback to random selection: {err}"); + return Ok(self.random_select_a_host(id).unwrap_or_default()); + } + Ok(handshakes) => app_instances + .iter() + .filter_map(|instance_id| { + let instance = self.state.instances.get(instance_id)?; + let (_, elapsed) = handshakes.get(&instance.public_key)?; + Some((instance.ip, *elapsed)) + }) + .collect::>(), + }; + instances.sort_by(|a, b| a.1.cmp(&b.1)); + instances.truncate(n); + Ok(instances.into_iter().map(|(ip, _)| ip).collect()) + } + + fn random_select_a_host(&self, id: &str) -> Option { // Direct instance lookup first if let Some(info) = self.state.instances.get(id).cloned() { - return Some(info); + return Some(smallvec![info.ip]); } let app_instances = self.state.apps.get(id)?; @@ -191,9 +234,15 @@ impl AppStateInner { }); let selected = healthy_instances.choose(&mut rand::thread_rng())?; - self.state.instances.get(selected).cloned() + self.state + .instances + .get(selected) + .map(|info| smallvec![info.ip]) } + /// Get latest handshakes + /// + /// Return a map of public key to (timestamp, elapsed) fn latest_handshakes( &self, stale_timeout: Option, @@ -211,6 +260,8 @@ impl AppStateInner { .arg("show") .arg(&self.config.wg.interface) .arg("latest-handshakes") + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) .output() .context("failed to execute wg show command")?; @@ -304,7 +355,7 @@ impl AppStateInner { pub struct RpcHandler { attestation: Option, - state: AppState, + state: Proxy, } impl TproxyRpc for RpcHandler { @@ -413,14 +464,14 @@ impl TproxyRpc for RpcHandler { } } -impl RpcCall for RpcHandler { +impl RpcCall for RpcHandler { type PrpcService = TproxyServer; fn into_prpc_service(self) -> Self::PrpcService { TproxyServer::new(self) } - fn construct(state: &AppState, attestation: Option) -> Result + fn construct(state: &Proxy, attestation: Option) -> Result where Self: Sized, { diff --git a/tproxy/src/main_service/tests.rs b/tproxy/src/main_service/tests.rs index 6a91726..b29eb7d 100644 --- a/tproxy/src/main_service/tests.rs +++ b/tproxy/src/main_service/tests.rs @@ -1,10 +1,10 @@ use super::*; use crate::config::{load_config_figment, Config}; -fn create_test_state() -> AppState { +fn create_test_state() -> Proxy { let figment = load_config_figment(None); let config = figment.focus("core").extract::().unwrap(); - AppState::new(config).expect("failed to create app state") + Proxy::new(config).expect("failed to create app state") } #[test] diff --git a/tproxy/src/proxy.rs b/tproxy/src/proxy.rs index b5fd2e1..c14eade 100644 --- a/tproxy/src/proxy.rs +++ b/tproxy/src/proxy.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::{net::Ipv4Addr, sync::Arc}; use anyhow::{bail, Context, Result}; use sni::extract_sni; @@ -10,7 +10,9 @@ use tokio::{ }; use tracing::{debug, error, info}; -use crate::{config::ProxyConfig, main_service::AppState}; +use crate::{config::ProxyConfig, main_service::Proxy}; + +pub(crate) type AddressGroup = smallvec::SmallVec<[Ipv4Addr; 4]>; mod io_bridge; mod sni; @@ -89,7 +91,7 @@ fn parse_destination(sni: &str, dotted_base_domain: &str) -> Result { async fn handle_connection( mut inbound: TcpStream, - state: AppState, + state: Proxy, dotted_base_domain: &str, tls_terminate_proxy: Arc, ) -> Result<()> { @@ -126,7 +128,7 @@ async fn handle_connection( } } -pub async fn run(config: &ProxyConfig, app_state: AppState) -> Result<()> { +pub async fn run(config: &ProxyConfig, app_state: Proxy) -> Result<()> { let dotted_base_domain = { let base_domain = config.base_domain.as_str(); let base_domain = base_domain.strip_prefix(".").unwrap_or(base_domain); @@ -187,7 +189,7 @@ pub async fn run(config: &ProxyConfig, app_state: AppState) -> Result<()> { } } -pub fn start(config: ProxyConfig, app_state: AppState) { +pub fn start(config: ProxyConfig, app_state: Proxy) { tokio::spawn(async move { if let Err(err) = run(&config, app_state).await { error!( @@ -197,3 +199,8 @@ pub fn start(config: ProxyConfig, app_state: AppState) { } }); } + +// async fn connect_to_app(state: &AppState, app_id: &str, port: u16) -> Result { +// let host = state.lock().select_a_host(app_id).context(format!("tapp {app_id} not found"))?; +// TcpStream::connect((host.ip, port)) +// } diff --git a/tproxy/src/proxy/tls_passthough.rs b/tproxy/src/proxy/tls_passthough.rs index dcd3255..41dcf3c 100644 --- a/tproxy/src/proxy/tls_passthough.rs +++ b/tproxy/src/proxy/tls_passthough.rs @@ -1,11 +1,11 @@ use anyhow::{Context, Result}; use std::fmt::Debug; -use tokio::{io::AsyncWriteExt, net::TcpStream, time::timeout}; +use tokio::{io::AsyncWriteExt, net::TcpStream, task::JoinSet, time::timeout}; use tracing::debug; -use crate::main_service::AppState; +use crate::main_service::Proxy; -use super::io_bridge::bridge; +use super::{io_bridge::bridge, AddressGroup}; #[derive(Debug)] struct TappAddress { @@ -43,7 +43,7 @@ async fn resolve_tapp_address(sni: &str) -> Result { } pub(crate) async fn proxy_with_sni( - state: AppState, + state: Proxy, inbound: TcpStream, buffer: Vec, sni: &str, @@ -55,25 +55,43 @@ pub(crate) async fn proxy_with_sni( proxy_to_app(state, inbound, buffer, &tapp_addr.app_id, tapp_addr.port).await } +/// connect to multiple hosts simultaneously and return the first successful connection +pub(crate) async fn connect_multiple_hosts( + addresses: AddressGroup, + port: u16, +) -> Result { + let mut join_set = JoinSet::new(); + for addr in addresses { + debug!("connecting to {addr}:{port}"); + let future = TcpStream::connect((addr, port)); + join_set.spawn(future); + } + // select the first successful connection + let connection = join_set + .join_next() + .await + .context("No app address available")? + .context("Failed to join the connect task")? + .context("Failed to connect to tapp")?; + debug!("connected to {:?}", connection.peer_addr()); + Ok(connection) +} + pub(crate) async fn proxy_to_app( - state: AppState, + state: Proxy, inbound: TcpStream, buffer: Vec, app_id: &str, port: u16, ) -> Result<()> { - let target_ip = state - .lock() - .select_a_host(app_id) - .context("tapp not found")? - .ip; + let addresses = state.lock().select_top_n_hosts(app_id)?; let mut outbound = timeout( state.config.proxy.timeouts.connect, - TcpStream::connect((target_ip, port)), + connect_multiple_hosts(addresses.clone(), port), ) .await - .context("connecting timeout")? - .context("failed to connect to tapp")?; + .with_context(|| format!("connecting timeout to tapp {app_id}: {addresses:?}:{port}"))? + .with_context(|| format!("failed to connect to tapp {app_id}: {addresses:?}:{port}"))?; outbound .write_all(&buffer) .await diff --git a/tproxy/src/proxy/tls_terminate.rs b/tproxy/src/proxy/tls_terminate.rs index 5b45722..1d86b78 100644 --- a/tproxy/src/proxy/tls_terminate.rs +++ b/tproxy/src/proxy/tls_terminate.rs @@ -12,10 +12,12 @@ use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; use tokio::net::TcpStream; use tokio::time::timeout; use tokio_rustls::{rustls, TlsAcceptor}; +use tracing::debug; -use crate::main_service::AppState; +use crate::main_service::Proxy; use super::io_bridge::bridge; +use super::tls_passthough::connect_multiple_hosts; #[pin_project::pin_project] struct IgnoreUnexpectedEofStream { @@ -85,16 +87,12 @@ where } pub struct TlsTerminateProxy { - app_state: AppState, + app_state: Proxy, acceptor: TlsAcceptor, } impl TlsTerminateProxy { - pub fn new( - app_state: &AppState, - cert: impl AsRef, - key: impl AsRef, - ) -> Result { + pub fn new(app_state: &Proxy, cert: impl AsRef, key: impl AsRef) -> Result { let cert_pem = fs::read(cert.as_ref()).context("failed to read certificate")?; let key_pem = fs::read(key.as_ref()).context("failed to read private key")?; let certs = CertificateDer::pem_slice_iter(cert_pem.as_slice()) @@ -123,11 +121,12 @@ impl TlsTerminateProxy { port: Option, ) -> Result<()> { let port = port.unwrap_or(80); - let host = self + let addresses = self .app_state .lock() - .select_a_host(app_id) - .context(format!("tapp {app_id} not found"))?; + .select_top_n_hosts(app_id) + .with_context(|| format!("tapp {app_id} not found"))?; + debug!("selected top n hosts: {addresses:?}"); let stream = MergedStream { buffer, buffer_cursor: 0, @@ -142,7 +141,7 @@ impl TlsTerminateProxy { .context("failed to accept tls connection")?; let outbound = timeout( self.app_state.config.proxy.timeouts.connect, - TcpStream::connect((host.ip, port)), + connect_multiple_hosts(addresses, port), ) .await .map_err(|_| anyhow::anyhow!("connecting timeout"))? diff --git a/tproxy/src/web_routes.rs b/tproxy/src/web_routes.rs index 5ecbf7e..469072b 100644 --- a/tproxy/src/web_routes.rs +++ b/tproxy/src/web_routes.rs @@ -1,4 +1,4 @@ -use crate::main_service::{AppState, RpcHandler}; +use crate::main_service::{Proxy, RpcHandler}; use anyhow::Result; use ra_rpc::rocket_helper::{handle_prpc, QuoteVerifier}; use rocket::{ @@ -14,14 +14,14 @@ use rocket::{ mod route_index; #[get("/")] -async fn index(state: &State) -> Result, String> { +async fn index(state: &State) -> Result, String> { route_index::index(state).await.map_err(|e| format!("{e}")) } #[post("/prpc/?", data = "")] #[allow(clippy::too_many_arguments)] async fn prpc_post( - state: &State, + state: &State, cert: Option>, quote_verifier: Option<&State>, method: &str, @@ -45,7 +45,7 @@ async fn prpc_post( #[get("/prpc/")] async fn prpc_get( - state: &State, + state: &State, cert: Option>, quote_verifier: Option<&State>, method: &str, diff --git a/tproxy/src/web_routes/route_index.rs b/tproxy/src/web_routes/route_index.rs index b691134..4948d4c 100644 --- a/tproxy/src/web_routes/route_index.rs +++ b/tproxy/src/web_routes/route_index.rs @@ -1,5 +1,5 @@ use crate::{ - main_service::{AppState, RpcHandler}, + main_service::{Proxy, RpcHandler}, models::CvmList, }; use anyhow::Context; @@ -8,7 +8,7 @@ use rinja::Template as _; use rocket::{response::content::RawHtml as Html, State}; use tproxy_rpc::tproxy_server::TproxyRpc; -pub async fn index(state: &State) -> anyhow::Result> { +pub async fn index(state: &State) -> anyhow::Result> { let rpc_handler = RpcHandler::construct(state, None).context("Failed to construct RpcHandler")?; let response = rpc_handler.list().await.context("Failed to list hosts")?; diff --git a/tproxy/tproxy.toml b/tproxy/tproxy.toml index 40783e8..979f43a 100644 --- a/tproxy/tproxy.toml +++ b/tproxy/tproxy.toml @@ -33,6 +33,8 @@ listen_addr = "0.0.0.0" listen_port = 8443 tappd_port = 8090 buffer_size = 8192 +# number of hosts to try to connect to +connect_top_n = 3 [core.proxy.timeouts] # Timeout for establishing a connection to the target app. @@ -40,6 +42,9 @@ connect = "5s" # TLS-termination handshake timeout or SNI extraction timeout. handshake = "5s" +# Timeout for top n hosts selection +cache_top_n = "30s" + # Enable data transfer timeouts below. This might impact performance. Turn off if # bad performance is observed. data_timeout_enabled = true