Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tproxy: Connect to multiple hosts #60

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,4 @@ tailf = "0.1.2"
time = "0.3.37"
uuid = { version = "1.11.0", features = ["v4"] }
which = "7.0.0"
smallvec = "1.13.2"
2 changes: 2 additions & 0 deletions tproxy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ tproxy-rpc.workspace = true
certbot.workspace = true
bytes.workspace = true
safe-write.workspace = true
smallvec.workspace = true
futures.workspace = true

[target.'cfg(unix)'.dependencies]
nix = { workspace = true, features = ["resource"] }
Expand Down
8 changes: 6 additions & 2 deletions tproxy/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub struct ProxyConfig {
pub tappd_port: u16,
pub timeouts: Timeouts,
pub buffer_size: usize,
pub connect_top_n: usize,
}

#[derive(Debug, Clone, Deserialize)]
Expand All @@ -39,6 +40,11 @@ pub struct Timeouts {
pub connect: Duration,
#[serde(with = "serde_duration")]
pub handshake: Duration,
#[serde(with = "serde_duration")]
pub total: Duration,

#[serde(with = "serde_duration")]
pub cache_top_n: Duration,

pub data_timeout_enabled: bool,
#[serde(with = "serde_duration")]
Expand All @@ -47,8 +53,6 @@ pub struct Timeouts {
pub write: Duration,
#[serde(with = "serde_duration")]
pub shutdown: Duration,
#[serde(with = "serde_duration")]
pub total: Duration,
}

#[derive(Debug, Clone, Deserialize)]
Expand Down
2 changes: 1 addition & 1 deletion tproxy/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async fn main() -> Result<()> {

let proxy_config = config.proxy.clone();
let pccs_url = config.pccs_url.clone();
let state = main_service::AppState::new(config)?;
let state = main_service::Proxy::new(config)?;
state.lock().reconfigure()?;
proxy::start(proxy_config, state.clone());

Expand Down
92 changes: 72 additions & 20 deletions tproxy/src/main_service.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::{
collections::{BTreeMap, BTreeSet},
net::Ipv4Addr,
process::Command,
process::{Command, Stdio},
sync::{Arc, Mutex, MutexGuard, Weak},
time::{Duration, SystemTime, UNIX_EPOCH},
time::{Duration, Instant, SystemTime, UNIX_EPOCH},
};

use anyhow::{bail, Context, Result};
Expand All @@ -14,38 +14,42 @@ use rand::seq::IteratorRandom;
use rinja::Template as _;
use safe_write::safe_write;
use serde::{Deserialize, Serialize};
use smallvec::{smallvec, SmallVec};
use tproxy_rpc::{
tproxy_server::{TproxyRpc, TproxyServer},
AcmeInfoResponse, GetInfoRequest, GetInfoResponse, HostInfo as PbHostInfo, ListResponse,
RegisterCvmRequest, RegisterCvmResponse, TappdConfig, WireGuardConfig,
};
use tracing::{debug, error, info};
use tracing::{debug, error, info, warn};

use crate::{
config::Config,
models::{InstanceInfo, WgConf},
proxy::AddressGroup,
};

#[derive(Clone)]
pub struct AppState {
pub struct Proxy {
pub(crate) config: Arc<Config>,
inner: Arc<Mutex<AppStateInner>>,
inner: Arc<Mutex<ProxyState>>,
}

#[derive(Debug, Serialize, Deserialize)]
struct State {
struct ProxyStateMut {
apps: BTreeMap<String, BTreeSet<String>>,
instances: BTreeMap<String, InstanceInfo>,
allocated_addresses: BTreeSet<Ipv4Addr>,
#[serde(skip)]
top_n: BTreeMap<String, (AddressGroup, Instant)>,
}

pub(crate) struct AppStateInner {
pub(crate) struct ProxyState {
config: Arc<Config>,
state: State,
state: ProxyStateMut,
}

impl AppState {
pub(crate) fn lock(&self) -> MutexGuard<AppStateInner> {
impl Proxy {
pub(crate) fn lock(&self) -> MutexGuard<ProxyState> {
self.inner.lock().expect("Failed to lock AppState")
}

Expand All @@ -56,13 +60,14 @@ impl AppState {
let state_str = fs::read_to_string(state_path).context("Failed to read state")?;
serde_json::from_str(&state_str).context("Failed to load state")?
} else {
State {
ProxyStateMut {
apps: BTreeMap::new(),
top_n: BTreeMap::new(),
instances: BTreeMap::new(),
allocated_addresses: BTreeSet::new(),
}
};
let inner = Arc::new(Mutex::new(AppStateInner {
let inner = Arc::new(Mutex::new(ProxyState {
config: config.clone(),
state,
}));
Expand All @@ -71,7 +76,7 @@ impl AppState {
}
}

fn start_recycle_thread(state: Weak<Mutex<AppStateInner>>, config: Arc<Config>) {
fn start_recycle_thread(state: Weak<Mutex<ProxyState>>, config: Arc<Config>) {
if !config.recycle.enabled {
info!("recycle is disabled");
return;
Expand All @@ -87,7 +92,7 @@ fn start_recycle_thread(state: Weak<Mutex<AppStateInner>>, config: Arc<Config>)
});
}

impl AppStateInner {
impl ProxyState {
fn alloc_ip(&mut self) -> Option<Ipv4Addr> {
for ip in self.config.wg.client_ip_range.hosts() {
if ip == self.config.wg.ip {
Expand Down Expand Up @@ -166,10 +171,49 @@ impl AppStateInner {
Ok(())
}

pub(crate) fn select_a_host(&self, id: &str) -> Option<InstanceInfo> {
/// Select the addresses a connection for `id` should be attempted against.
///
/// Resolution order:
/// 1. If `id` is an instance id, return that single instance's address.
/// 2. Otherwise `id` must be an app id; if `connect_top_n` is `0`, fall back
///    to picking one host at random.
/// 3. Otherwise return up to `connect_top_n` addresses of the app's instances,
///    ordered by the `elapsed` value reported by `latest_handshakes`
///    (smallest first). The result is cached per app id and reused until it
///    is older than `timeouts.cache_top_n`.
///
/// # Errors
///
/// Returns an error if `id` is neither a known instance nor a known app.
/// A failure to query handshakes is not an error: it is logged and the
/// random selection is used instead.
pub(crate) fn select_top_n_hosts(&mut self, id: &str) -> Result<AddressGroup> {
    let n = self.config.proxy.connect_top_n;
    if let Some(instance) = self.state.instances.get(id) {
        return Ok(smallvec![instance.ip]);
    }
    let app_instances = self.state.apps.get(id).context("app not found")?;
    if n == 0 {
        // fallback to random selection
        return Ok(self.random_select_a_host(id).unwrap_or_default());
    }

    // Serve from the cache while the entry is non-empty and still fresh.
    if let Some((cached, cached_at)) = self.state.top_n.get(id) {
        if !cached.is_empty() && cached_at.elapsed() < self.config.proxy.timeouts.cache_top_n {
            return Ok(cached.clone());
        }
    }

    let handshakes = match self.latest_handshakes(None) {
        Ok(handshakes) => handshakes,
        Err(err) => {
            warn!("Failed to get handshakes, fallback to random selection: {err}");
            return Ok(self.random_select_a_host(id).unwrap_or_default());
        }
    };

    // Pair each instance address with its handshake age; instances without a
    // recorded handshake are skipped.
    let mut instances = app_instances
        .iter()
        .filter_map(|instance_id| {
            let instance = self.state.instances.get(instance_id)?;
            let (_, elapsed) = handshakes.get(&instance.public_key)?;
            Some((instance.ip, *elapsed))
        })
        .collect::<SmallVec<[_; 4]>>();
    instances.sort_by_key(|&(_, elapsed)| elapsed);
    instances.truncate(n);

    let selected: AddressGroup = instances.into_iter().map(|(ip, _)| ip).collect();
    // Remember the result so later calls within `cache_top_n` can skip the
    // handshake query. An empty group is never served from the cache (see the
    // check above), so storing one simply causes a recompute next time.
    self.state
        .top_n
        .insert(id.to_string(), (selected.clone(), Instant::now()));
    Ok(selected)
}

fn random_select_a_host(&self, id: &str) -> Option<AddressGroup> {
// Direct instance lookup first
if let Some(info) = self.state.instances.get(id).cloned() {
return Some(info);
return Some(smallvec![info.ip]);
}

let app_instances = self.state.apps.get(id)?;
Expand All @@ -191,9 +235,15 @@ impl AppStateInner {
});

let selected = healthy_instances.choose(&mut rand::thread_rng())?;
self.state.instances.get(selected).cloned()
self.state
.instances
.get(selected)
.map(|info| smallvec![info.ip])
}

/// Get latest handshakes
///
/// Return a map of public key to (timestamp, elapsed)
fn latest_handshakes(
&self,
stale_timeout: Option<Duration>,
Expand All @@ -211,6 +261,8 @@ impl AppStateInner {
.arg("show")
.arg(&self.config.wg.interface)
.arg("latest-handshakes")
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.output()
.context("failed to execute wg show command")?;

Expand Down Expand Up @@ -304,7 +356,7 @@ impl AppStateInner {

pub struct RpcHandler {
attestation: Option<Attestation>,
state: AppState,
state: Proxy,
}

impl TproxyRpc for RpcHandler {
Expand Down Expand Up @@ -413,14 +465,14 @@ impl TproxyRpc for RpcHandler {
}
}

impl RpcCall<AppState> for RpcHandler {
impl RpcCall<Proxy> for RpcHandler {
type PrpcService = TproxyServer<Self>;

fn into_prpc_service(self) -> Self::PrpcService {
TproxyServer::new(self)
}

fn construct(state: &AppState, attestation: Option<Attestation>) -> Result<Self>
fn construct(state: &Proxy, attestation: Option<Attestation>) -> Result<Self>
where
Self: Sized,
{
Expand Down
4 changes: 2 additions & 2 deletions tproxy/src/main_service/tests.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use super::*;
use crate::config::{load_config_figment, Config};

fn create_test_state() -> AppState {
fn create_test_state() -> Proxy {
let figment = load_config_figment(None);
let config = figment.focus("core").extract::<Config>().unwrap();
AppState::new(config).expect("failed to create app state")
Proxy::new(config).expect("failed to create app state")
}

#[test]
Expand Down
17 changes: 12 additions & 5 deletions tproxy/src/proxy.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{net::Ipv4Addr, sync::Arc};

use anyhow::{bail, Context, Result};
use sni::extract_sni;
Expand All @@ -10,7 +10,9 @@ use tokio::{
};
use tracing::{debug, error, info};

use crate::{config::ProxyConfig, main_service::AppState};
use crate::{config::ProxyConfig, main_service::Proxy};

pub(crate) type AddressGroup = smallvec::SmallVec<[Ipv4Addr; 4]>;

mod io_bridge;
mod sni;
Expand Down Expand Up @@ -89,7 +91,7 @@ fn parse_destination(sni: &str, dotted_base_domain: &str) -> Result<DstInfo> {

async fn handle_connection(
mut inbound: TcpStream,
state: AppState,
state: Proxy,
dotted_base_domain: &str,
tls_terminate_proxy: Arc<TlsTerminateProxy>,
) -> Result<()> {
Expand Down Expand Up @@ -126,7 +128,7 @@ async fn handle_connection(
}
}

pub async fn run(config: &ProxyConfig, app_state: AppState) -> Result<()> {
pub async fn run(config: &ProxyConfig, app_state: Proxy) -> Result<()> {
let dotted_base_domain = {
let base_domain = config.base_domain.as_str();
let base_domain = base_domain.strip_prefix(".").unwrap_or(base_domain);
Expand Down Expand Up @@ -187,7 +189,7 @@ pub async fn run(config: &ProxyConfig, app_state: AppState) -> Result<()> {
}
}

pub fn start(config: ProxyConfig, app_state: AppState) {
pub fn start(config: ProxyConfig, app_state: Proxy) {
tokio::spawn(async move {
if let Err(err) = run(&config, app_state).await {
error!(
Expand All @@ -197,3 +199,8 @@ pub fn start(config: ProxyConfig, app_state: AppState) {
}
});
}

// async fn connect_to_app(state: &AppState, app_id: &str, port: u16) -> Result<TcpStream> {
// let host = state.lock().select_a_host(app_id).context(format!("tapp {app_id} not found"))?;
// TcpStream::connect((host.ip, port))
// }
Loading
Loading