From f631dd6171676dc3c764fcba49d00d6470288dc4 Mon Sep 17 00:00:00 2001 From: Zhang Jingqiang Date: Sun, 7 Apr 2024 17:04:04 +0800 Subject: [PATCH] update to quinn 0.11 --- .github/workflows/linux-musl.yml | 2 +- Cargo.lock | 131 ++---- Cargo.toml | 5 +- g3bench/Cargo.toml | 4 +- g3bench/src/target/dns/opts.rs | 4 +- g3bench/src/target/h3/opts.rs | 13 +- g3proxy/Cargo.toml | 4 +- g3proxy/src/serve/plain_quic_port/mod.rs | 10 +- g3tiles/Cargo.toml | 6 +- g3tiles/src/backend/keyless_quic/connect.rs | 6 +- g3tiles/src/serve/plain_quic_port/mod.rs | 10 +- lib/g3-daemon/src/listen/quic.rs | 30 +- lib/g3-hickory-client/Cargo.toml | 2 +- lib/g3-hickory-client/src/connect/quinn.rs | 7 +- lib/g3-hickory-client/src/io/quic.rs | 1 - .../src/limit/fixed_window/datagram.rs | 2 +- lib/g3-io-ext/src/quic/limited_socket.rs | 372 ++++++++++++++++++ lib/g3-io-ext/src/quic/mod.rs | 311 +-------------- lib/g3-io-ext/src/quic/udp_poller.rs | 73 ++++ lib/g3-io-ext/src/udp/ext.rs | 29 ++ lib/g3-resolver/src/driver/hickory/client.rs | 4 +- lib/g3-socks/Cargo.toml | 1 + lib/g3-socks/src/v5/quic.rs | 100 +---- lib/g3-types/Cargo.toml | 4 +- lib/g3-types/src/net/rustls/client.rs | 40 +- lib/g3-types/src/net/rustls/mod.rs | 4 + lib/g3-types/src/net/rustls/server.rs | 38 +- 27 files changed, 675 insertions(+), 538 deletions(-) create mode 100644 lib/g3-io-ext/src/quic/limited_socket.rs create mode 100644 lib/g3-io-ext/src/quic/udp_poller.rs diff --git a/.github/workflows/linux-musl.yml b/.github/workflows/linux-musl.yml index a4dd8114d..d6820f03e 100644 --- a/.github/workflows/linux-musl.yml +++ b/.github/workflows/linux-musl.yml @@ -20,7 +20,7 @@ on: env: CARGO_TERM_COLOR: always MUSL_TARGET: x86_64-unknown-linux-musl - MUSL_FEATURES: --no-default-features --features vendored-openssl,vendored-c-ares,hickory + MUSL_FEATURES: --no-default-features --features vendored-openssl,quic,vendored-c-ares,hickory jobs: build: diff --git a/Cargo.lock b/Cargo.lock index 79c791cd1..0deb1f5c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -248,12 +248,6 @@ dependencies = [ "rustc-demangle", ] -[[package]] -name = "base64" -version = "0.21.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" - [[package]] name = "base64" version = "0.22.1" @@ -1166,7 +1160,7 @@ dependencies = [ "hickory-proto", "http", "quinn", - "rustls 0.23.5", + "rustls", "rustls-pki-types", "thiserror", "tokio", @@ -1190,7 +1184,7 @@ version = "0.2.0" dependencies = [ "ahash", "atoi", - "base64 0.22.1", + "base64", "bytes", "g3-io-ext", "g3-types", @@ -1210,7 +1204,7 @@ version = "0.2.0" dependencies = [ "anyhow", "atoi", - "base64 0.22.1", + "base64", "bytes", "flume", "g3-h2", @@ -1299,7 +1293,7 @@ dependencies = [ "ip_network", "rand", "regex", - "rustls-pemfile 2.1.2", + "rustls-pemfile", "rustls-pki-types", "serde_json", "variant-ssl", @@ -1316,8 +1310,8 @@ dependencies = [ "g3-types", "ip_network", "rmpv", - "rustls 0.23.5", - "rustls-pemfile 2.1.2", + "rustls", + "rustls-pemfile", "rustls-pki-types", "uuid", "variant-ssl", @@ -1353,7 +1347,7 @@ dependencies = [ "hickory-proto", "indexmap", "log", - "rustls 0.23.5", + "rustls", "rustls-pki-types", "thiserror", "tokio", @@ -1419,6 +1413,7 @@ dependencies = [ "bytes", "g3-io-ext", "g3-types", + "pin-project-lite", "quinn", "smallvec", "thiserror", @@ -1492,7 +1487,7 @@ version = "0.4.0" dependencies = [ "ahash", "anyhow", - "base64 0.22.1", + "base64", "blake3", "brotli", "bytes", @@ -1517,13 +1512,14 @@ dependencies = [ "num-traits", "once_cell", "percent-encoding", + "quinn", "radix_trie", "rand", "regex", "rustc-hash", - "rustls 0.23.5", - "rustls-native-certs 0.7.0", - "rustls-pemfile 2.1.2", + "rustls", + "rustls-native-certs", + "rustls-pemfile", "rustls-pki-types", "sha-1", "slog", @@ -1586,7 +1582,7 @@ dependencies = [ "ip_network", "rand", "regex", - "rustls-pemfile 2.1.2", + "rustls-pemfile", "rustls-pki-types", "url", "variant-ssl", @@ -1633,8 +1629,8 @@ dependencies = [ "quinn", "rustc-hash", "rustc_version", - "rustls 0.23.5", - "rustls-pemfile 2.1.2", + "rustls", + "rustls-pemfile", "rustls-pki-types", "thiserror", "tokio", @@ -1794,7 +1790,7 @@ dependencies = [ "ascii", "async-recursion", "async-trait", - "base64 0.22.1", + "base64", "bitflags 2.5.0", "bytes", "capnp", @@ -1856,7 +1852,7 @@ dependencies = [ "rmpv", "rustc-hash", "rustc_version", - "rustls 0.23.5", + "rustls", "serde_json", "slog", "thiserror", @@ -1959,7 +1955,7 @@ dependencies = [ "rand", "rustc-hash", "rustc_version", - "rustls 0.23.5", + "rustls", "rustls-pki-types", "serde_json", "slog", @@ -2071,8 +2067,7 @@ dependencies = [ [[package]] name = "h3" version = "0.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1c8886b9e6e93e7ed93d9433f3779e8d07e3ff96bc67b977d14c7b20c849411" +source = "git+https://github.com/djc/h3.git?branch=quinn-0.11#8df1a403160124a5ba3568c47fde412b0bd02cb4" dependencies = [ "bytes", "fastrand", @@ -2086,14 +2081,12 @@ dependencies = [ [[package]] name = "h3-quinn" version = "0.0.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73786bcc0e4c2692ba62c650f7b950ac236e5300c5de3b1d26330555e2322046" +source = "git+https://github.com/djc/h3.git?branch=quinn-0.11#8df1a403160124a5ba3568c47fde412b0bd02cb4" dependencies = [ "bytes", "futures", "h3", "quinn", - "quinn-proto", "tokio", "tokio-util", ] @@ -2804,9 +2797,9 @@ dependencies = [ [[package]] name = "quinn" -version = "0.10.2" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8cc2c5017e4b43d5995dcea317bc46c1e09404c0a9664d2908f7f02dfe943d75" +checksum = "4bb80dc034523335a9fcc34271931dd97e9132d1fb078695db500339eb72e712" dependencies = [ "bytes", "futures-io", @@ -2814,7 +2807,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash", - "rustls 0.21.12", + "rustls", "thiserror", "tokio", "tracing", @@ -2822,15 +2815,15 @@ dependencies = [ [[package]] name = "quinn-proto" -version = "0.10.6" -source = "git+https://github.com/zh-jq/quinn.git?branch=ring-0.17#4a6978bdee67b244ae66241cc35abd0542765544" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a063a47a1aaee4b3b1c2dd44edb7867c10107a2ef171f3543ac40ec5e9092002" dependencies = [ "bytes", "rand", "ring", "rustc-hash", - "rustls 0.21.12", - "rustls-native-certs 0.6.3", + "rustls", "slab", "thiserror", "tinyvec", @@ -2839,15 +2832,15 @@ dependencies = [ [[package]] name = "quinn-udp" -version = "0.4.1" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "055b4e778e8feb9f93c4e439f71dc2156ef13360b432b799e179a8c4cdf0b1d7" +checksum = "cb7ad7bc932e4968523fa7d9c320ee135ff779de720e9350fee8728838551764" dependencies = [ - "bytes", "libc", + "once_cell", "socket2", "tracing", - "windows-sys 0.48.0", + "windows-sys 0.52.0", ] [[package]] @@ -3041,17 +3034,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "rustls" -version = "0.21.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" -dependencies = [ - "ring", - "rustls-webpki 0.101.7", - "sct", -] - [[package]] name = "rustls" version = "0.23.5" @@ -3061,23 +3043,11 @@ dependencies = [ "once_cell", "ring", "rustls-pki-types", - "rustls-webpki 0.102.3", + "rustls-webpki", "subtle", "zeroize", ] -[[package]] -name = "rustls-native-certs" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" -dependencies = [ - "openssl-probe", - "rustls-pemfile 1.0.4", - "schannel", - "security-framework", -] - [[package]] name = "rustls-native-certs" version = "0.7.0" @@ -3085,28 +3055,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792" dependencies = [ "openssl-probe", - "rustls-pemfile 2.1.2", + "rustls-pemfile", "rustls-pki-types", "schannel", "security-framework", ] -[[package]] -name = "rustls-pemfile" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" -dependencies = [ - "base64 0.21.7", -] - [[package]] name = "rustls-pemfile" version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" dependencies = [ - "base64 0.22.1", + "base64", "rustls-pki-types", ] @@ -3116,16 +3077,6 @@ version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" -[[package]] -name = "rustls-webpki" -version = "0.101.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" -dependencies = [ - "ring", - "untrusted", -] - [[package]] name = "rustls-webpki" version = "0.102.3" @@ -3158,16 +3109,6 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" -[[package]] -name = "sct" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" -dependencies = [ - "ring", - "untrusted", -] - [[package]] name = "security-framework" version = "2.11.0" @@ -3517,7 +3458,7 @@ version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls 0.23.5", + "rustls", "rustls-pki-types", "tokio", ] diff --git a/Cargo.toml b/Cargo.toml index a77bc4362..339e00c66 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -152,7 +152,7 @@ rustls = { version = "0.23.5", default-features = false, features = ["std", "tls rustls-pki-types = "1" rustls-pemfile = "2" tokio-rustls = { version = "0.26", default-features = false, features = ["tls12", "ring"] } -quinn = { version = "0.10", default-features = false, features = ["native-certs"] } +quinn = { version = "0.11", default-features = false, features = ["runtime-tokio"] } # openssl = { package = "variant-ssl", version = "0.14.2" } openssl-sys = { package = "variant-ssl-sys", version = "0.13.0" } @@ -237,4 +237,5 @@ debug = 1 debug-assertions = false [patch.crates-io] -quinn-proto = { version = "0.10.6", git = "https://github.com/zh-jq/quinn.git", branch = "ring-0.17" } +h3 = { version = "0.0.4", git = "https://github.com/djc/h3.git", branch = "quinn-0.11" } +h3-quinn = { version = "0.0.5", git = "https://github.com/djc/h3.git", branch = "quinn-0.11" } diff --git a/g3bench/Cargo.toml b/g3bench/Cargo.toml index 2050c032f..9b1fda44f 100644 --- a/g3bench/Cargo.toml +++ b/g3bench/Cargo.toml @@ -20,7 +20,7 @@ url.workspace = true h2.workspace = true h3 = { workspace = true, optional = true } h3-quinn = { workspace = true, optional = true } -quinn = { workspace = true, optional = true, features = ["tls-rustls", "runtime-tokio"] } +quinn = { workspace = true, optional = true, features = ["rustls"] } bytes.workspace = true futures-util.workspace = true atomic-waker.workspace = true @@ -57,7 +57,7 @@ g3-hickory-client.workspace = true rustc_version.workspace = true [features] -default = [] +default = ["quic"] quic = ["g3-types/quic", "g3-socks/quic", "g3-io-ext/quic", "g3-hickory-client/quic", "dep:quinn", "dep:h3", "dep:h3-quinn"] vendored-openssl = ["openssl/vendored", "openssl-probe"] vendored-tongsuo = ["openssl/tongsuo", "openssl-probe", "g3-types/tongsuo"] diff --git a/g3bench/src/target/dns/opts.rs b/g3bench/src/target/dns/opts.rs index 7f37ddd88..ac187a4bc 100644 --- a/g3bench/src/target/dns/opts.rs +++ b/g3bench/src/target/dns/opts.rs @@ -251,7 +251,7 @@ impl BenchDnsArgs { ) -> anyhow::Result { let tls_name = match &self.tls.tls_name { Some(ServerName::DnsName(domain)) => domain.as_ref().to_string(), - Some(ServerName::IpAddress(ip)) => ip.to_string(), + Some(ServerName::IpAddress(ip)) => IpAddr::from(*ip).to_string(), Some(_) => return Err(anyhow!("unsupported tls server name type")), None => self.target.ip().to_string(), }; @@ -279,7 +279,7 @@ impl BenchDnsArgs { ) -> anyhow::Result { let tls_name = match &self.tls.tls_name { Some(ServerName::DnsName(domain)) => domain.as_ref().to_string(), - Some(ServerName::IpAddress(ip)) => ip.to_string(), + Some(ServerName::IpAddress(ip)) => IpAddr::from(*ip).to_string(), Some(_) => return Err(anyhow!("unsupported tls server name type")), None => self.target.ip().to_string(), }; diff --git a/g3bench/src/target/h3/opts.rs b/g3bench/src/target/h3/opts.rs index 940abc4ce..16852a435 100644 --- a/g3bench/src/target/h3/opts.rs +++ b/g3bench/src/target/h3/opts.rs @@ -25,8 +25,9 @@ use clap::{value_parser, Arg, ArgAction, ArgMatches, Command}; use h3::client::SendRequest; use h3_quinn::OpenStreams; use http::{HeaderValue, Method, StatusCode}; -use quinn::{Endpoint, TokioRuntime}; -use rustls::ServerName; +use quinn::crypto::rustls::QuicClientConfig; +use quinn::{ClientConfig, Endpoint, TokioRuntime, TransportConfig, VarInt}; +use rustls_pki_types::ServerName; use tokio::net::TcpStream; use url::Url; @@ -220,8 +221,6 @@ impl BenchH3Args { stats: &Arc, proc_args: &ProcArgs, ) -> anyhow::Result { - use quinn::{ClientConfig, TransportConfig, VarInt}; - let addrs = self .quic_peer_addrs .as_ref() @@ -239,12 +238,14 @@ impl BenchH3Args { // https://http3-explained.haxx.se/en/h3/h3-streams // transport.max_concurrent_uni_streams(VarInt::from_u32(0)); // TODO add more transport settings - let mut client_config = ClientConfig::new(tls_client.driver.clone()); + let quic_config = QuicClientConfig::try_from(tls_client.driver.as_ref().clone()) + .map_err(|e| anyhow!("invalid quic tls config: {e}"))?; + let mut client_config = ClientConfig::new(Arc::new(quic_config)); client_config.transport_config(Arc::new(transport)); let tls_name = match &self.target_tls.tls_name { Some(ServerName::DnsName(domain)) => domain.as_ref().to_string(), - Some(ServerName::IpAddress(ip)) => ip.to_string(), + Some(ServerName::IpAddress(ip)) => IpAddr::from(*ip).to_string(), Some(_) => return Err(anyhow!("unsupported tls server name type")), None => self.target.host().to_string(), }; diff --git a/g3proxy/Cargo.toml b/g3proxy/Cargo.toml index 704df4815..fbc7a232d 100644 --- a/g3proxy/Cargo.toml +++ b/g3proxy/Cargo.toml @@ -24,7 +24,7 @@ tokio = { workspace = true, features = ["rt-multi-thread", "rt", "signal", "sync tokio-util = { workspace = true, features = ["time"] } tokio-rustls.workspace = true rustls.workspace = true -quinn = { workspace = true, optional = true, features = ["tls-rustls", "runtime-tokio"] } +quinn = { workspace = true, optional = true, features = ["rustls"] } openssl.workspace = true openssl-probe = { workspace = true, optional = true } indexmap.workspace = true @@ -96,7 +96,7 @@ tokio-util = { workspace = true, features = ["io"] } rustc_version.workspace = true [features] -default = ["lua54", "python", "c-ares", "hickory"] +default = ["lua54", "python", "c-ares", "hickory", "quic"] lua = ["mlua"] luajit = ["lua", "mlua/luajit"] lua51 = ["lua", "mlua/lua51"] diff --git a/g3proxy/src/serve/plain_quic_port/mod.rs b/g3proxy/src/serve/plain_quic_port/mod.rs index d7e420d72..f86e37c70 100644 --- a/g3proxy/src/serve/plain_quic_port/mod.rs +++ b/g3proxy/src/serve/plain_quic_port/mod.rs @@ -96,7 +96,7 @@ impl PlainQuicPort { ) -> anyhow::Result { let reload_sender = crate::serve::new_reload_notify_channel(); - let tls_server = config.tls_server.build()?; + let quic_server = config.tls_server.build_quic()?; let ingress_net_filter = config .ingress_net_filter @@ -109,7 +109,7 @@ impl PlainQuicPort { ingress_net_filter, listen_config: None, quinn_config: None, - accept_timeout: tls_server.accept_timeout, + accept_timeout: quic_server.accept_timeout, offline_rebind_port: config.offline_rebind_port, }; let (cfg_sender, _cfg_receiver) = watch::channel(aux_config); @@ -117,7 +117,7 @@ impl PlainQuicPort { Ok(PlainQuicPort { name: config.name().clone(), config: ArcSwap::new(config), - quinn_config: quinn::ServerConfig::with_crypto(tls_server.driver), + quinn_config: quinn::ServerConfig::with_crypto(quic_server.driver), listen_stats, reload_sender, cfg_sender, @@ -164,8 +164,8 @@ impl ServerInternal for PlainQuicPort { }; let quinn_config = if flags.contains(PlainQuicPortUpdateFlags::QUINN) { - let tls_config = config.tls_server.build()?; - Some(quinn::ServerConfig::with_crypto(tls_config.driver)) + let quic_config = config.tls_server.build_quic()?; + Some(quinn::ServerConfig::with_crypto(quic_config.driver)) } else { None }; diff --git a/g3tiles/Cargo.toml b/g3tiles/Cargo.toml index 171519fb3..ec90c2d64 100644 --- a/g3tiles/Cargo.toml +++ b/g3tiles/Cargo.toml @@ -33,7 +33,7 @@ openssl.workspace = true openssl-probe = { workspace = true, optional = true } rustls.workspace = true rustls-pki-types.workspace = true -quinn = { workspace = true, optional = true, features = ["tls-rustls", "runtime-tokio"] } +quinn = { workspace = true, optional = true, features = ["rustls"] } tokio-rustls.workspace = true governor = { workspace = true, features = ["std", "jitter"] } chrono = { workspace = true, features = ["clock"] } @@ -57,8 +57,8 @@ g3tiles-proto = { path = "proto" } rustc_version.workspace = true [features] -default = [] -quic = ["g3-daemon/quic", "dep:quinn"] +default = ["quic"] +quic = ["g3-daemon/quic", "g3-types/quinn", "dep:quinn"] vendored-openssl = ["openssl/vendored", "openssl-probe"] vendored-tongsuo = ["openssl/tongsuo", "openssl-probe", "g3-yaml/tongsuo", "g3-types/tongsuo"] vendored-aws-lc = ["openssl/aws-lc", "openssl-probe", "g3-types/aws-lc", "g3-openssl/aws-lc"] diff --git a/g3tiles/src/backend/keyless_quic/connect.rs b/g3tiles/src/backend/keyless_quic/connect.rs index 9b37672f9..c62e23433 100644 --- a/g3tiles/src/backend/keyless_quic/connect.rs +++ b/g3tiles/src/backend/keyless_quic/connect.rs @@ -28,7 +28,7 @@ use tokio::time::Instant; use g3_types::collection::{SelectiveVec, WeightedValue}; use g3_types::ext::DurationExt; -use g3_types::net::RustlsClientConfig; +use g3_types::net::RustlsQuicClientConfig; use crate::config::backend::keyless_quic::KeylessQuicBackendConfig; use crate::module::keyless::{ @@ -41,7 +41,7 @@ pub(super) struct KeylessQuicUpstreamConnector { stats: Arc, duration_recorder: Arc, peer_addrs: Arc>>>, - tls_client: RustlsClientConfig, + tls_client: RustlsQuicClientConfig, } impl KeylessQuicUpstreamConnector { @@ -51,7 +51,7 @@ impl KeylessQuicUpstreamConnector { duration_recorder: Arc, peer_addrs_container: Arc>>>, ) -> anyhow::Result { - let tls_client = config.tls_client.build()?; + let tls_client = config.tls_client.build_quic()?; Ok(KeylessQuicUpstreamConnector { config, stats, diff --git a/g3tiles/src/serve/plain_quic_port/mod.rs b/g3tiles/src/serve/plain_quic_port/mod.rs index 0a8a6a008..330e2b506 100644 --- a/g3tiles/src/serve/plain_quic_port/mod.rs +++ b/g3tiles/src/serve/plain_quic_port/mod.rs @@ -94,7 +94,7 @@ impl PlainQuicPort { ) -> anyhow::Result { let reload_sender = crate::serve::new_reload_notify_channel(); - let tls_server = config.tls_server.build()?; + let quic_server = config.tls_server.build_quic()?; let ingress_net_filter = config .ingress_net_filter @@ -107,7 +107,7 @@ impl PlainQuicPort { ingress_net_filter, listen_config: None, quinn_config: None, - accept_timeout: tls_server.accept_timeout, + accept_timeout: quic_server.accept_timeout, offline_rebind_port: config.offline_rebind_port, }; let (cfg_sender, _cfg_receiver) = watch::channel(aux_config); @@ -115,7 +115,7 @@ impl PlainQuicPort { Ok(PlainQuicPort { name: config.name().clone(), config: ArcSwap::new(config), - quinn_config: quinn::ServerConfig::with_crypto(tls_server.driver), + quinn_config: quinn::ServerConfig::with_crypto(quic_server.driver), listen_stats, reload_sender, cfg_sender, @@ -162,8 +162,8 @@ impl ServerInternal for PlainQuicPort { }; let quinn_config = if flags.contains(PlainQuicPortUpdateFlags::QUINN) { - let tls_config = config.tls_server.build()?; - Some(quinn::ServerConfig::with_crypto(tls_config.driver)) + let quic_config = config.tls_server.build_quic()?; + Some(quinn::ServerConfig::with_crypto(quic_config.driver)) } else { None }; diff --git a/lib/g3-daemon/src/listen/quic.rs b/lib/g3-daemon/src/listen/quic.rs index 73b8feb1f..99aa73dba 100644 --- a/lib/g3-daemon/src/listen/quic.rs +++ b/lib/g3-daemon/src/listen/quic.rs @@ -22,7 +22,7 @@ use std::time::Duration; use async_trait::async_trait; use log::{info, warn}; -use quinn::{Connecting, Connection, Endpoint}; +use quinn::{Connection, Endpoint, Incoming}; use tokio::runtime::Handle; use tokio::sync::{broadcast, watch}; @@ -182,22 +182,22 @@ where } } result = listener.accept() => { - let Some(connecting) = result else { + let Some(incoming) = result else { continue; }; self.listen_stats.add_accepted(); - self.run_task(connecting, listen_addr, &aux_config); + self.run_task(incoming, listen_addr, &aux_config); } } } self.post_stop(); } - fn run_task(&self, connecting: Connecting, listen_addr: SocketAddr, aux_config: &C) + fn run_task(&self, incoming: Incoming, listen_addr: SocketAddr, aux_config: &C) where C: ListenQuicConf + Send + Clone + 'static, { - let peer_addr = connecting.remote_address(); + let peer_addr = incoming.remote_address(); if let Some(filter) = aux_config.ingress_network_acl() { let (_, action) = filter.check(peer_addr.ip()); match action { @@ -209,7 +209,7 @@ where } } - let local_addr = connecting + let local_addr = incoming .local_ip() .map(|ip| SocketAddr::new(ip, listen_addr.port())) .unwrap_or(listen_addr); @@ -226,7 +226,7 @@ where tokio::spawn(async move { Self::accept_connection_and_run( server, - connecting, + incoming, cc_info, accept_timeout, listen_stats, @@ -238,7 +238,7 @@ where rt.handle.spawn(async move { Self::accept_connection_and_run( server, - connecting, + incoming, cc_info, accept_timeout, listen_stats, @@ -249,7 +249,7 @@ where tokio::spawn(async move { Self::accept_connection_and_run( server, - connecting, + incoming, cc_info, accept_timeout, listen_stats, @@ -261,11 +261,19 @@ where async fn accept_connection_and_run( server: S, - connecting: Connecting, + incoming: Incoming, cc_info: ClientConnectionInfo, timeout: Duration, listen_stats: Arc, ) { + let connecting = match incoming.accept() { + Ok(c) => c, + Err(_e) => { + listen_stats.add_failed(); + // TODO may be attack + return; + } + }; match tokio::time::timeout(timeout, connecting).await { Ok(Ok(c)) => { listen_stats.add_accepted(); @@ -349,7 +357,7 @@ where self.server_version, self.instance_id ); - listener.reject_new_connections(); + // listener.reject_new_connections(); tokio::spawn(async move { listener.wait_idle().await }); return; } diff --git a/lib/g3-hickory-client/Cargo.toml b/lib/g3-hickory-client/Cargo.toml index 3e9fa15fe..0153dcae4 100644 --- a/lib/g3-hickory-client/Cargo.toml +++ b/lib/g3-hickory-client/Cargo.toml @@ -18,7 +18,7 @@ tokio = { workspace = true, features = ["net", "time"] } rustls.workspace = true rustls-pki-types.workspace = true tokio-rustls.workspace = true -quinn = { workspace = true, optional = true, features = ["runtime-tokio", "tls-rustls"] } +quinn = { workspace = true, optional = true, features = ["rustls"] } h3 = { workspace = true, optional = true } h3-quinn = { workspace = true, optional = true } diff --git a/lib/g3-hickory-client/src/connect/quinn.rs b/lib/g3-hickory-client/src/connect/quinn.rs index 82f96d6d1..ca43ccf62 100644 --- a/lib/g3-hickory-client/src/connect/quinn.rs +++ b/lib/g3-hickory-client/src/connect/quinn.rs @@ -18,6 +18,7 @@ use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}; use std::sync::Arc; use hickory_proto::error::ProtoError; +use quinn::crypto::rustls::QuicClientConfig; use quinn::{Connection, Endpoint, EndpointConfig, TokioRuntime}; use rustls::ClientConfig; @@ -41,9 +42,11 @@ pub(crate) async fn quic_connect( if tls_config.alpn_protocols.is_empty() { tls_config.alpn_protocols = vec![alpn_protocol.to_vec()]; } - let quinn_config = quinn::ClientConfig::new(Arc::new(tls_config)); + let quic_config = QuicClientConfig::try_from(tls_config) + .map_err(|e| format!("invalid quic tls config: {e}"))?; + let client_config = quinn::ClientConfig::new(Arc::new(quic_config)); // TODO set transport config - endpoint.set_default_client_config(quinn_config); + endpoint.set_default_client_config(client_config); let connection = endpoint .connect(name_server, tls_name) diff --git a/lib/g3-hickory-client/src/io/quic.rs b/lib/g3-hickory-client/src/io/quic.rs index ac32887a6..b35cf2ae5 100644 --- a/lib/g3-hickory-client/src/io/quic.rs +++ b/lib/g3-hickory-client/src/io/quic.rs @@ -139,7 +139,6 @@ async fn quic_send_recv( // and MUST indicate through the STREAM FIN mechanism that no further data will be sent on that stream. send_stream .finish() - .await .map_err(|e| format!("quic mark finish error: {e}"))?; quic_recv(recv_stream).await diff --git a/lib/g3-io-ext/src/limit/fixed_window/datagram.rs b/lib/g3-io-ext/src/limit/fixed_window/datagram.rs index f11314c35..81fe2e1e7 100644 --- a/lib/g3-io-ext/src/limit/fixed_window/datagram.rs +++ b/lib/g3-io-ext/src/limit/fixed_window/datagram.rs @@ -42,7 +42,7 @@ impl<'a, const C: usize> HasPacketSize for RecvMsgHdr<'a, C> { } #[cfg(feature = "quic")] -impl HasPacketSize for quinn::udp::Transmit { +impl<'a> HasPacketSize for quinn::udp::Transmit<'a> { fn packet_size(&self) -> usize { self.contents.len() } diff --git a/lib/g3-io-ext/src/quic/limited_socket.rs b/lib/g3-io-ext/src/quic/limited_socket.rs new file mode 100644 index 000000000..3d970e195 --- /dev/null +++ b/lib/g3-io-ext/src/quic/limited_socket.rs @@ -0,0 +1,372 @@ +/* + * Copyright 2024 ByteDance and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use std::cell::UnsafeCell; +use std::fmt; +use std::future::Future; +use std::io::{self, IoSliceMut}; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use std::task::{ready, Context, Poll}; +use std::time::Duration; + +use futures_util::FutureExt; +use quinn::udp; +use quinn::{AsyncTimer, AsyncUdpSocket, Runtime, UdpPoller}; +use tokio::time::{Instant, Sleep}; + +use crate::limit::{DatagramLimitInfo, DatagramLimitResult}; +use crate::{ArcLimitedRecvStats, ArcLimitedSendStats, LimitedRecvStats, LimitedSendStats}; + +struct LimitConf { + shift_millis: u8, + max_send_packets: usize, + max_send_bytes: usize, + max_recv_packets: usize, + max_recv_bytes: usize, +} + +pub struct LimitedTokioRuntime { + inner: R, + limit: Option, + stats: Arc, +} + +impl LimitedTokioRuntime { + pub fn new( + inner: R, + shift_millis: u8, + max_send_packets: usize, + max_send_bytes: usize, + max_recv_packets: usize, + max_recv_bytes: usize, + stats: Arc, + ) -> Self { + let limit = LimitConf { + shift_millis, + max_send_packets, + max_send_bytes, + max_recv_packets, + max_recv_bytes, + }; + LimitedTokioRuntime { + inner, + limit: Some(limit), + stats, + } + } + + pub fn new_unlimited(inner: R, stats: Arc) -> Self { + LimitedTokioRuntime { + inner, + limit: None, + stats, + } + } +} + +impl fmt::Debug for LimitedTokioRuntime { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.inner.fmt(f) + } +} + +impl Runtime for LimitedTokioRuntime +where + ST: LimitedSendStats + LimitedRecvStats + Send + Sync + 'static, +{ + fn new_timer(&self, t: std::time::Instant) -> Pin> { + self.inner.new_timer(t) + } + + fn spawn(&self, future: Pin + Send>>) { + self.inner.spawn(future); + } + + fn wrap_udp_socket(&self, sock: std::net::UdpSocket) -> io::Result> { + let inner = self.inner.wrap_udp_socket(sock)?; + if let Some(limit) = &self.limit { + Ok(Arc::new(LimitedUdpSocket::new( + inner, + limit.shift_millis, + limit.max_send_packets, + limit.max_send_bytes, + limit.max_recv_packets, + limit.max_recv_bytes, + self.stats.clone(), + ))) + } else { + Ok(Arc::new(LimitedUdpSocket::new_unlimited( + inner, + self.stats.clone(), + ))) + } + } +} + +struct LimitedSendLimitState { + delay: Pin>, + poll_delay: bool, + limit: DatagramLimitInfo, +} + +struct LimitedSendState { + started: Instant, + limit: Option>>, + stats: ArcLimitedSendStats, +} + +impl LimitedSendState { + fn new( + started: Instant, + shift_millis: u8, + max_packets: usize, + max_bytes: usize, + stats: ArcLimitedSendStats, + ) -> Self { + let limit = LimitedSendLimitState { + delay: Box::pin(tokio::time::sleep(Duration::from_millis(0))), + poll_delay: false, + limit: DatagramLimitInfo::new(shift_millis, max_packets, max_bytes), + }; + LimitedSendState { + started, + limit: Some(Arc::new(Mutex::new(limit))), + stats, + } + } + + fn new_unlimited(started: Instant, stats: ArcLimitedSendStats) -> Self { + LimitedSendState { + started, + limit: None, + stats, + } + } +} + +struct LimitedRecvState { + delay: Pin>, + started: Instant, + limit: DatagramLimitInfo, + stats: ArcLimitedRecvStats, +} + +impl LimitedRecvState { + fn new( + started: Instant, + shift_millis: u8, + max_packets: usize, + max_bytes: usize, + stats: ArcLimitedRecvStats, + ) -> Self { + LimitedRecvState { + delay: Box::pin(tokio::time::sleep(Duration::from_millis(0))), + started, + limit: DatagramLimitInfo::new(shift_millis, max_packets, max_bytes), + stats, + } + } + + fn new_unlimited(started: Instant, stats: ArcLimitedRecvStats) -> Self { + LimitedRecvState { + delay: Box::pin(tokio::time::sleep(Duration::from_millis(0))), + started, + limit: DatagramLimitInfo::default(), + stats, + } + } +} + +struct LimitedUdpPoller { + inner: Pin>, + limit: Option>>, +} + +impl fmt::Debug for LimitedUdpPoller { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.inner.fmt(f) + } +} + +impl UdpPoller for LimitedUdpPoller { + fn poll_writable(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + if let Some(l) = &self.limit { + let mut l = l.lock().unwrap(); + if l.poll_delay { + ready!(Future::poll(l.delay.as_mut(), cx)); + l.poll_delay = false; + return Poll::Ready(Ok(())); + } + } + self.inner.as_mut().poll_writable(cx) + } +} + +pub struct LimitedUdpSocket { + inner: Arc, + send_state: LimitedSendState, + recv_state: UnsafeCell, +} + +unsafe impl Sync for LimitedUdpSocket {} + +impl LimitedUdpSocket { + fn new( + inner: Arc, + shift_millis: u8, + max_send_packets: usize, + max_send_bytes: usize, + max_recv_packets: usize, + max_recv_bytes: usize, + stats: Arc, + ) -> Self + where + ST: LimitedSendStats + LimitedRecvStats + Send + Sync + 'static, + { + let started = Instant::now(); + let send_state = LimitedSendState::new( + started, + shift_millis, + max_send_packets, + max_send_bytes, + stats.clone() as _, + ); + let recv_state = LimitedRecvState::new( + started, + shift_millis, + max_recv_packets, + max_recv_bytes, + stats as _, + ); + LimitedUdpSocket { + inner, + send_state, + recv_state: UnsafeCell::new(recv_state), + } + } + + fn new_unlimited(inner: Arc, stats: Arc) -> Self + where + ST: LimitedSendStats + LimitedRecvStats + Send + Sync + 'static, + { + let started = Instant::now(); + let send_state = LimitedSendState::new_unlimited(started, stats.clone() as _); + let recv_state = LimitedRecvState::new_unlimited(started, stats as _); + LimitedUdpSocket { + inner, + send_state, + recv_state: UnsafeCell::new(recv_state), + } + } +} + +impl fmt::Debug for LimitedUdpSocket { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.inner.fmt(f) + } +} + +impl AsyncUdpSocket for LimitedUdpSocket { + fn create_io_poller(self: Arc) -> Pin> { + Box::pin(LimitedUdpPoller { + inner: self.inner.clone().create_io_poller(), + limit: self.send_state.limit.clone(), + }) + } + + fn try_send(&self, transmit: &udp::Transmit) -> io::Result<()> { + let len = transmit.contents.len(); + if let Some(l) = &self.send_state.limit { + let dur_millis = self.send_state.started.elapsed().as_millis() as u64; + let mut l = l.lock().unwrap(); + match l.limit.check_packet(dur_millis, len) { + DatagramLimitResult::Advance(_) => { + self.inner.try_send(transmit)?; + l.limit.set_advance(1, len); + self.send_state.stats.add_send_packet(); + self.send_state.stats.add_send_bytes(len); + Ok(()) + } + DatagramLimitResult::DelayFor(ms) => { + l.delay + .as_mut() + .reset(self.send_state.started + Duration::from_millis(dur_millis + ms)); + l.poll_delay = true; + Err(io::Error::new( + io::ErrorKind::WouldBlock, + "delayed by rate limiter", + )) + } + } + } else { + self.inner.try_send(transmit)?; + self.send_state.stats.add_send_packet(); + self.send_state.stats.add_send_bytes(len); + Ok(()) + } + } + + fn poll_recv( + &self, + cx: &mut Context, + bufs: &mut [IoSliceMut<'_>], + meta: &mut [udp::RecvMeta], + ) -> Poll> { + let l = unsafe { &mut *self.recv_state.get() }; + if l.limit.is_set() { + let dur_millis = l.started.elapsed().as_millis() as u64; + match l.limit.check_packets(dur_millis, bufs) { + DatagramLimitResult::Advance(n) => { + let nr = ready!(self.inner.poll_recv(cx, &mut bufs[0..n], &mut meta[0..n]))?; + let len = bufs.iter().take(nr).map(|v| v.len()).sum(); + l.limit.set_advance(nr, len); + l.stats.add_recv_packets(nr); + l.stats.add_recv_bytes(len); + Poll::Ready(Ok(nr)) + } + DatagramLimitResult::DelayFor(ms) => { + l.delay + .as_mut() + .reset(l.started + Duration::from_millis(dur_millis + ms)); + l.delay.poll_unpin(cx).map(|_| Ok(0)) + } + } + } else { + let nr = ready!(self.inner.poll_recv(cx, bufs, meta))?; + let len = bufs.iter().take(nr).map(|v| v.len()).sum(); + l.stats.add_recv_packets(nr); + l.stats.add_recv_bytes(len); + Poll::Ready(Ok(nr)) + } + } + + fn local_addr(&self) -> io::Result { + self.inner.local_addr() + } + + fn max_transmit_segments(&self) -> usize { + self.inner.max_transmit_segments() + } + + fn max_receive_segments(&self) -> usize { + self.inner.max_receive_segments() + } + + fn may_fragment(&self) -> bool { + self.inner.may_fragment() + } +} diff --git a/lib/g3-io-ext/src/quic/mod.rs b/lib/g3-io-ext/src/quic/mod.rs index 9e3b34e03..ada635677 100644 --- a/lib/g3-io-ext/src/quic/mod.rs +++ b/lib/g3-io-ext/src/quic/mod.rs @@ -14,311 +14,8 @@ * limitations under the License. */ -use std::cell::UnsafeCell; -use std::fmt; -use std::future::Future; -use std::io::{self, IoSliceMut}; -use std::pin::Pin; -use std::sync::Arc; -use std::task::{ready, Context, Poll}; -use std::time::Duration; +mod udp_poller; +pub use udp_poller::QuinnUdpPollHelper; -use futures_util::FutureExt; -use quinn::udp; -use quinn::{AsyncTimer, AsyncUdpSocket, Runtime}; -use tokio::time::{Instant, Sleep}; - -use crate::limit::{DatagramLimitInfo, DatagramLimitResult}; -use crate::{ArcLimitedRecvStats, ArcLimitedSendStats, LimitedRecvStats, LimitedSendStats}; - -struct LimitConf { - shift_millis: u8, - max_send_packets: usize, - max_send_bytes: usize, - max_recv_packets: usize, - max_recv_bytes: usize, -} - -pub struct LimitedTokioRuntime { - inner: R, - limit: Option, - stats: Arc, -} - -impl LimitedTokioRuntime { - pub fn new( - inner: R, - shift_millis: u8, - max_send_packets: usize, - max_send_bytes: usize, - max_recv_packets: usize, - max_recv_bytes: usize, - stats: Arc, - ) -> Self { - let limit = LimitConf { - shift_millis, - max_send_packets, - max_send_bytes, - max_recv_packets, - max_recv_bytes, - }; - LimitedTokioRuntime { - inner, - limit: Some(limit), - stats, - } - } - - pub fn new_unlimited(inner: R, stats: Arc) -> Self { - LimitedTokioRuntime { - inner, - limit: None, - stats, - } - } -} - -impl fmt::Debug for LimitedTokioRuntime { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.inner.fmt(f) - } -} - -impl Runtime for LimitedTokioRuntime -where - ST: LimitedSendStats + LimitedRecvStats + Send + Sync + 'static, -{ - fn new_timer(&self, t: std::time::Instant) -> Pin> { - self.inner.new_timer(t) - } - - fn spawn(&self, future: Pin + Send>>) { - self.inner.spawn(future); - } - - fn wrap_udp_socket(&self, sock: std::net::UdpSocket) -> io::Result> { - let inner = self.inner.wrap_udp_socket(sock)?; - if let Some(limit) = &self.limit { - Ok(Box::new(LimitedUdpSocket::new( - inner, - limit.shift_millis, - limit.max_send_packets, - limit.max_send_bytes, - limit.max_recv_packets, - limit.max_recv_bytes, - self.stats.clone(), - ))) - } else { - Ok(Box::new(LimitedUdpSocket::new_unlimited( - inner, - self.stats.clone(), - ))) - } - } -} - -struct LimitedSendState { - delay: Pin>, - started: Instant, - limit: DatagramLimitInfo, - stats: ArcLimitedSendStats, -} - -impl LimitedSendState { - fn new( - started: Instant, - shift_millis: u8, - max_packets: usize, - max_bytes: usize, - stats: ArcLimitedSendStats, - ) -> Self { - LimitedSendState { - delay: Box::pin(tokio::time::sleep(Duration::from_millis(0))), - started, - limit: DatagramLimitInfo::new(shift_millis, max_packets, max_bytes), - stats, - } - } - - fn new_unlimited(started: Instant, stats: ArcLimitedSendStats) -> Self { - LimitedSendState { - delay: Box::pin(tokio::time::sleep(Duration::from_millis(0))), - started, - limit: DatagramLimitInfo::default(), - stats, - } - } -} - -struct LimitedRecvState { - delay: Pin>, - started: Instant, - limit: DatagramLimitInfo, - stats: ArcLimitedRecvStats, -} - -impl LimitedRecvState { - fn new( - started: Instant, - shift_millis: u8, - max_packets: usize, - max_bytes: usize, - stats: ArcLimitedRecvStats, - ) -> Self { - LimitedRecvState { - delay: Box::pin(tokio::time::sleep(Duration::from_millis(0))), - started, - limit: DatagramLimitInfo::new(shift_millis, max_packets, max_bytes), - stats, - } - } - - fn new_unlimited(started: Instant, stats: ArcLimitedRecvStats) -> Self { - LimitedRecvState { - delay: Box::pin(tokio::time::sleep(Duration::from_millis(0))), - started, - limit: DatagramLimitInfo::default(), - stats, - } - } -} - -pub struct LimitedUdpSocket { - inner: Box, - send_state: UnsafeCell, - recv_state: UnsafeCell, -} - -impl LimitedUdpSocket { - fn new( - inner: Box, - shift_millis: u8, - max_send_packets: usize, - max_send_bytes: usize, - max_recv_packets: usize, - max_recv_bytes: usize, - stats: Arc, - ) -> Self - where - ST: LimitedSendStats + LimitedRecvStats + Send + Sync + 'static, - { - let started = Instant::now(); - let send_state = LimitedSendState::new( - started, - shift_millis, - max_send_packets, - max_send_bytes, - stats.clone() as _, - ); - let recv_state = LimitedRecvState::new( - started, - shift_millis, - max_recv_packets, - max_recv_bytes, - stats as _, - ); - LimitedUdpSocket { - inner, - send_state: UnsafeCell::new(send_state), - recv_state: UnsafeCell::new(recv_state), - } - } - - fn new_unlimited(inner: Box, stats: Arc) -> Self - where - ST: LimitedSendStats + LimitedRecvStats + Send + Sync + 'static, - { - let started = Instant::now(); - let send_state = LimitedSendState::new_unlimited(started, stats.clone() as _); - let recv_state = LimitedRecvState::new_unlimited(started, stats as _); - LimitedUdpSocket { - inner, - send_state: UnsafeCell::new(send_state), - recv_state: UnsafeCell::new(recv_state), - } - } -} - -impl fmt::Debug for LimitedUdpSocket { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.inner.fmt(f) - } -} - -impl AsyncUdpSocket for LimitedUdpSocket { - fn poll_send( - &self, - state: &udp::UdpState, - cx: &mut Context, - transmits: &[udp::Transmit], - ) -> Poll> { - let l = unsafe { &mut *self.send_state.get() }; - if l.limit.is_set() { - let dur_millis = l.started.elapsed().as_millis() as u64; - match l.limit.check_packets(dur_millis, transmits) { - DatagramLimitResult::Advance(n) => { - let nw = ready!(self.inner.poll_send(state, cx, &transmits[0..n]))?; - let len = transmits.iter().take(nw).map(|v| v.contents.len()).sum(); - l.limit.set_advance(nw, len); - l.stats.add_send_packets(nw); - l.stats.add_send_bytes(len); - Poll::Ready(Ok(nw)) - } - DatagramLimitResult::DelayFor(ms) => { - l.delay - .as_mut() - .reset(l.started + Duration::from_millis(dur_millis + ms)); - l.delay.poll_unpin(cx).map(|_| Ok(0)) - } - } - } else { - let nw = ready!(self.inner.poll_send(state, cx, transmits))?; - let len = transmits.iter().take(nw).map(|v| v.contents.len()).sum(); - l.stats.add_send_packets(nw); - l.stats.add_send_bytes(len); - Poll::Ready(Ok(nw)) - } - } - - fn poll_recv( - &self, - cx: &mut Context, - bufs: &mut [IoSliceMut<'_>], - meta: &mut [udp::RecvMeta], - ) -> Poll> { - let l = unsafe { &mut *self.recv_state.get() }; - if l.limit.is_set() { - let dur_millis = l.started.elapsed().as_millis() as u64; - match l.limit.check_packets(dur_millis, bufs) { - DatagramLimitResult::Advance(n) => { - let nr = ready!(self.inner.poll_recv(cx, &mut bufs[0..n], &mut meta[0..n]))?; - let len = bufs.iter().take(nr).map(|v| v.len()).sum(); - l.limit.set_advance(nr, len); - l.stats.add_recv_packets(nr); - l.stats.add_recv_bytes(len); - Poll::Ready(Ok(nr)) - } - DatagramLimitResult::DelayFor(ms) => { - l.delay - .as_mut() - .reset(l.started + Duration::from_millis(dur_millis + ms)); - l.delay.poll_unpin(cx).map(|_| Ok(0)) - } - } - } else { - let nr = ready!(self.inner.poll_recv(cx, bufs, meta))?; - let len = bufs.iter().take(nr).map(|v| v.len()).sum(); - l.stats.add_recv_packets(nr); - l.stats.add_recv_bytes(len); - Poll::Ready(Ok(nr)) - } - } - - fn local_addr(&self) -> io::Result { - self.inner.local_addr() - } - - fn may_fragment(&self) -> bool { - self.inner.may_fragment() - } -} +mod limited_socket; +pub use limited_socket::{LimitedTokioRuntime, LimitedUdpSocket}; diff --git a/lib/g3-io-ext/src/quic/udp_poller.rs b/lib/g3-io-ext/src/quic/udp_poller.rs new file mode 100644 index 000000000..cca3720be --- /dev/null +++ b/lib/g3-io-ext/src/quic/udp_poller.rs @@ -0,0 +1,73 @@ +/* + * Copyright 2024 ByteDance and/or its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use std::fmt; +use std::future::Future; +use std::io; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pin_project_lite::pin_project! { + /// Helper adapting a function `MakeFut` that constructs a single-use future `Fut` into a + /// [`quinn::UdpPoller`] that may be reused indefinitely + pub struct QuinnUdpPollHelper { + make_fut: MakeFut, + #[pin] + fut: Option, + } +} + +impl QuinnUdpPollHelper { + /// Construct a [`quinn::UdpPoller`] that calls `make_fut` to get the future to poll, storing it until + /// it yields [`Poll::Ready`], then creating a new one on the next + /// [`poll_writable`](quinn::UdpPoller::poll_writable) + pub fn new(make_fut: MakeFut) -> Self { + Self { + make_fut, + fut: None, + } + } +} + +impl quinn::UdpPoller for QuinnUdpPollHelper +where + MakeFut: Fn() -> Fut + Send + Sync + 'static, + Fut: Future> + Send + Sync + 'static, +{ + fn poll_writable(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let mut this = self.project(); + if this.fut.is_none() { + this.fut.set(Some((this.make_fut)())); + } + // We're forced to `unwrap` here because `Fut` may be `!Unpin`, which means we can't safely + // obtain an `&mut Fut` after storing it in `self.fut` when `self` is already behind `Pin`, + // and if we didn't store it then we wouldn't be able to keep it alive between + // `poll_writable` calls. + let result = this.fut.as_mut().as_pin_mut().unwrap().poll(cx); + if result.is_ready() { + // Polling an arbitrary `Future` after it becomes ready is a logic error, so arrange for + // a new `Future` to be created on the next call. + this.fut.set(None); + } + result + } +} + +impl fmt::Debug for QuinnUdpPollHelper { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("UdpPollHelper").finish_non_exhaustive() + } +} diff --git a/lib/g3-io-ext/src/udp/ext.rs b/lib/g3-io-ext/src/udp/ext.rs index 007ff9c70..41a738669 100644 --- a/lib/g3-io-ext/src/udp/ext.rs +++ b/lib/g3-io-ext/src/udp/ext.rs @@ -190,6 +190,8 @@ pub trait UdpSocketExt { target: Option, ) -> Poll>; + fn try_sendmsg(&self, iov: &[IoSlice<'_>], target: Option) -> io::Result; + fn poll_recvmsg( &self, cx: &mut Context<'_>, @@ -267,6 +269,33 @@ impl UdpSocketExt for UdpSocket { } } + fn try_sendmsg(&self, iov: &[IoSlice<'_>], target: Option) -> io::Result { + #[cfg(any( + target_os = "linux", + target_os = "android", + target_os = "freebsd", + target_os = "dragonfly", + target_os = "netbsd", + target_os = "openbsd", + ))] + let flags: SendFlags = SendFlags::DONTWAIT | SendFlags::NOSIGNAL; + #[cfg(target_os = "macos")] + let flags: SendFlags = SendFlags::DONTWAIT; + + let fd = self.as_fd(); + let mut control = SendAncillaryBuffer::default(); + + self.try_io(Interest::WRITABLE, || match target { + Some(SocketAddr::V4(a4)) => { + sendmsg_v4(fd, &a4, iov, &mut control, flags).map_err(io::Error::from) + } + Some(SocketAddr::V6(a6)) => { + sendmsg_v6(fd, &a6, iov, &mut control, flags).map_err(io::Error::from) + } + None => sendmsg(fd, iov, &mut control, flags).map_err(io::Error::from), + }) + } + fn poll_recvmsg( &self, cx: &mut Context<'_>, diff --git a/lib/g3-resolver/src/driver/hickory/client.rs b/lib/g3-resolver/src/driver/hickory/client.rs index 6eed801ed..9958c7ed1 100644 --- a/lib/g3-resolver/src/driver/hickory/client.rs +++ b/lib/g3-resolver/src/driver/hickory/client.rs @@ -364,7 +364,7 @@ impl HickoryClientConfig { ) -> anyhow::Result { let tls_name = match tls_name { ServerName::DnsName(domain) => domain.as_ref().to_string(), - ServerName::IpAddress(ip) => ip.to_string(), + ServerName::IpAddress(ip) => IpAddr::from(*ip).to_string(), _ => return Err(anyhow!("unsupported tls server name type")), }; @@ -392,7 +392,7 @@ impl HickoryClientConfig { ) -> anyhow::Result { let tls_name = match tls_name { ServerName::DnsName(domain) => domain.as_ref().to_string(), - ServerName::IpAddress(ip) => ip.to_string(), + ServerName::IpAddress(ip) => IpAddr::from(*ip).to_string(), _ => return Err(anyhow!("unsupported tls server name type")), }; diff --git a/lib/g3-socks/Cargo.toml b/lib/g3-socks/Cargo.toml index 630414ce6..022b14428 100644 --- a/lib/g3-socks/Cargo.toml +++ b/lib/g3-socks/Cargo.toml @@ -13,6 +13,7 @@ bytes.workspace = true smallvec.workspace = true tokio.workspace = true quinn = { workspace = true, optional = true, features = ["runtime-tokio"] } +pin-project-lite.workspace = true g3-types.workspace = true g3-io-ext.workspace = true diff --git a/lib/g3-socks/src/v5/quic.rs b/lib/g3-socks/src/v5/quic.rs index b3f0335f9..ec9f7648c 100644 --- a/lib/g3-socks/src/v5/quic.rs +++ b/lib/g3-socks/src/v5/quic.rs @@ -23,13 +23,13 @@ use std::sync::Arc; use std::task::{ready, Context, Poll}; use std::time::Instant; -use quinn::udp::{RecvMeta, Transmit, UdpState}; -use quinn::{AsyncTimer, AsyncUdpSocket, Runtime}; +use quinn::udp::{RecvMeta, Transmit}; +use quinn::{AsyncTimer, AsyncUdpSocket, Runtime, UdpPoller}; use tokio::io::{AsyncRead, AsyncReadExt}; use tokio::sync::{broadcast, oneshot}; use tokio::time::sleep_until; -use g3_io_ext::UdpSocketExt; +use g3_io_ext::{QuinnUdpPollHelper, UdpSocketExt}; use g3_types::net::Host; use super::udp_io::{UDP_HEADER_LEN_IPV4, UDP_HEADER_LEN_IPV6}; @@ -94,7 +94,7 @@ impl Runtime for Socks5UdpTokioRuntime { tokio::spawn(future); } - fn wrap_udp_socket(&self, t: UdpSocket) -> io::Result> { + fn wrap_udp_socket(&self, t: UdpSocket) -> io::Result> { let (sender, receiver) = oneshot::channel(); let mut ctl_close_receiver = self.ctl_close_receiver.resubscribe(); tokio::spawn(async move { @@ -105,7 +105,7 @@ impl Runtime for Socks5UdpTokioRuntime { } }); let io = tokio::net::UdpSocket::from_std(t)?; - Ok(Box::new(Socks5UdpSocket { + Ok(Arc::new(Socks5UdpSocket { io, quic_peer_addr: self.quic_peer_addr, ctl_close_receiver: UnsafeCell::new(receiver), @@ -170,92 +170,28 @@ pub struct Socks5UdpSocket { send_socks_header: SocksHeaderBuffer, } -impl AsyncUdpSocket for Socks5UdpSocket { - #[cfg(any( - target_os = "linux", - target_os = "android", - target_os = "freebsd", - target_os = "netbsd", - target_os = "openbsd", - ))] - fn poll_send( - &self, - _state: &UdpState, - cx: &mut Context, - transmits: &[Transmit], - ) -> Poll> { - use g3_io_ext::SendMsgHdr; - - let mut msgs = Vec::with_capacity(transmits.len()); +unsafe impl Sync for Socks5UdpSocket {} - for transmit in transmits { - assert_eq!(self.quic_peer_addr, transmit.destination); - - msgs.push(SendMsgHdr::new( - [ - IoSlice::new(self.send_socks_header.as_ref()), - IoSlice::new(&transmit.contents), - ], - None, - )) - } - - self.io.poll_batch_sendmsg(cx, &mut msgs) +impl AsyncUdpSocket for Socks5UdpSocket { + fn create_io_poller(self: Arc) -> Pin> { + Box::pin(QuinnUdpPollHelper::new(move || { + let socket = self.clone(); + async move { socket.io.writable().await } + })) } - #[cfg(target_os = "macos")] - fn poll_send( - &self, - _state: &UdpState, - cx: &mut Context, - transmits: &[Transmit], - ) -> Poll> { - // logics from quinn-udp::fallback.rs - let mut sent = 0; - for transmit in transmits { - assert_eq!(self.quic_peer_addr, transmit.destination); + fn try_send(&self, transmit: &Transmit) -> io::Result<()> { + assert_eq!(self.quic_peer_addr, transmit.destination); - match self.io.poll_sendmsg( - cx, + self.io + .try_sendmsg( &[ IoSlice::new(self.send_socks_header.as_ref()), IoSlice::new(&transmit.contents), ], None, - ) { - Poll::Ready(ready) => match ready { - Ok(_) => { - sent += 1; - } - // We need to report that some packets were sent in this case, so we rely on - // errors being either harmlessly transient (in the case of WouldBlock) or - // recurring on the next call. - Err(_) if sent != 0 => return Poll::Ready(Ok(sent)), - Err(e) => { - if e.kind() == io::ErrorKind::WouldBlock { - return Poll::Ready(Err(e)); - } - - // Other errors are ignored, since they will ususally be handled - // by higher level retransmits and timeouts. - // - PermissionDenied errors have been observed due to iptable rules. - // Those are not fatal errors, since the - // configuration can be dynamically changed. - // - Destination unreachable errors have been observed for other - // log_sendmsg_error(&mut self.last_send_error, e, transmit); - sent += 1; - } - }, - Poll::Pending => { - return if sent == 0 { - Poll::Pending - } else { - Poll::Ready(Ok(sent)) - } - } - } - } - Poll::Ready(Ok(sent)) + ) + .map(|_| ()) } #[cfg(any( diff --git a/lib/g3-types/Cargo.toml b/lib/g3-types/Cargo.toml index f270e44b4..8e7b6c64f 100644 --- a/lib/g3-types/Cargo.toml +++ b/lib/g3-types/Cargo.toml @@ -39,6 +39,7 @@ regex = { workspace = true, optional = true } radix_trie = { workspace = true, optional = true } rustls = { workspace = true, optional = true } rustls-pki-types = { workspace = true, optional = true } +quinn = { workspace = true, optional = true } webpki-roots = { version = "0.26", optional = true } rustls-native-certs = { version = "0.7", optional = true } rustls-pemfile = { workspace = true, optional = true } @@ -56,7 +57,8 @@ brotli = { version = "6.0", optional = true , default-features = false, features default = [] auth-crypt = ["dep:digest", "dep:md-5", "dep:sha-1", "dep:blake3", "dep:hex"] resolve = ["dep:ahash", "dep:radix_trie", "dep:fastrand"] -rustls = ["dep:rustls", "dep:rustls-pki-types", "dep:webpki-roots", "dep:rustls-pemfile", "dep:rustls-native-certs", "dep:ahash", "dep:lru"] +quinn = ["dep:quinn"] +rustls = ["dep:rustls", "dep:rustls-pki-types", "dep:webpki-roots", "dep:rustls-pemfile", "dep:rustls-native-certs", "dep:ahash", "dep:lru", "quinn?/rustls"] openssl = ["dep:openssl", "dep:ahash", "dep:lru", "dep:bytes"] tongsuo = ["openssl", "openssl/tongsuo", "dep:brotli"] aws-lc = ["openssl", "openssl/aws-lc", "dep:brotli"] diff --git a/lib/g3-types/src/net/rustls/client.rs b/lib/g3-types/src/net/rustls/client.rs index 5d25de3f1..cde7eeecd 100644 --- a/lib/g3-types/src/net/rustls/client.rs +++ b/lib/g3-types/src/net/rustls/client.rs @@ -18,6 +18,8 @@ use std::sync::Arc; use std::time::Duration; use anyhow::anyhow; +#[cfg(feature = "quinn")] +use quinn::crypto::rustls::QuicClientConfig; use rustls::client::Resumption; use rustls::{ClientConfig, RootCertStore}; use rustls_pki_types::CertificateDer; @@ -34,6 +36,13 @@ pub struct RustlsClientConfig { pub handshake_timeout: Duration, } +#[cfg(feature = "quinn")] +#[derive(Clone)] +pub struct RustlsQuicClientConfig { + pub driver: Arc, + pub handshake_timeout: Duration, +} + #[derive(Clone, Debug, Eq, PartialEq)] pub struct RustlsClientConfigBuilder { no_session_cache: bool, @@ -102,10 +111,10 @@ impl RustlsClientConfigBuilder { self.use_builtin_ca_certs = true; } - pub fn build_with_alpn_protocols( + fn build_client_config( &self, alpn_protocols: Option>, - ) -> anyhow::Result { + ) -> anyhow::Result { let config_builder = ClientConfig::builder(); let mut root_store = RootCertStore::empty(); @@ -153,6 +162,14 @@ impl RustlsClientConfigBuilder { config.enable_sni = false; } + Ok(config) + } + + pub fn build_with_alpn_protocols( + &self, + alpn_protocols: Option>, + ) -> anyhow::Result { + let config = self.build_client_config(alpn_protocols)?; Ok(RustlsClientConfig { driver: Arc::new(config), handshake_timeout: self.handshake_timeout, @@ -162,4 +179,23 @@ impl RustlsClientConfigBuilder { pub fn build(&self) -> anyhow::Result { self.build_with_alpn_protocols(None) } + + #[cfg(feature = "quinn")] + pub fn build_quic_with_alpn_protocols( + &self, + alpn_protocols: Option>, + ) -> anyhow::Result { + let config = self.build_client_config(alpn_protocols)?; + let quic_config = QuicClientConfig::try_from(config) + .map_err(|e| anyhow!("invalid quic tls config: {e}"))?; + Ok(RustlsQuicClientConfig { + driver: Arc::new(quic_config), + handshake_timeout: self.handshake_timeout, + }) + } + + #[cfg(feature = "quinn")] + pub fn build_quic(&self) -> anyhow::Result { + self.build_quic_with_alpn_protocols(None) + } } diff --git a/lib/g3-types/src/net/rustls/mod.rs b/lib/g3-types/src/net/rustls/mod.rs index c70747575..36a99a3dd 100644 --- a/lib/g3-types/src/net/rustls/mod.rs +++ b/lib/g3-types/src/net/rustls/mod.rs @@ -15,9 +15,13 @@ */ mod client; +#[cfg(feature = "quinn")] +pub use client::RustlsQuicClientConfig; pub use client::{RustlsClientConfig, RustlsClientConfigBuilder}; mod server; +#[cfg(feature = "quinn")] +pub use server::RustlsQuicServerConfig; pub use server::{RustlsServerConfig, RustlsServerConfigBuilder}; mod cache; diff --git a/lib/g3-types/src/net/rustls/server.rs b/lib/g3-types/src/net/rustls/server.rs index e346ccdd4..4eb00ae30 100644 --- a/lib/g3-types/src/net/rustls/server.rs +++ b/lib/g3-types/src/net/rustls/server.rs @@ -18,6 +18,8 @@ use std::sync::Arc; use std::time::Duration; use anyhow::{anyhow, Context}; +#[cfg(feature = "quinn")] +use quinn::crypto::rustls::QuicServerConfig; use rustls::crypto::ring::Ticketer; use rustls::server::WebPkiClientVerifier; use rustls::{RootCertStore, ServerConfig}; @@ -32,6 +34,13 @@ pub struct RustlsServerConfig { pub accept_timeout: Duration, } +#[cfg(feature = "quinn")] +#[derive(Clone)] +pub struct RustlsQuicServerConfig { + pub driver: Arc, + pub accept_timeout: Duration, +} + #[derive(Clone, Debug, Eq, PartialEq)] pub struct RustlsServerConfigBuilder { cert_pairs: Vec, @@ -86,10 +95,10 @@ impl RustlsServerConfigBuilder { self.accept_timeout } - pub fn build_with_alpn_protocols( + fn build_server_config( &self, alpn_protocols: Option>, - ) -> anyhow::Result { + ) -> anyhow::Result { let config_builder = ServerConfig::builder(); let config_builder = if self.client_auth { let mut root_store = RootCertStore::empty(); @@ -148,6 +157,14 @@ impl RustlsServerConfigBuilder { } } + Ok(config) + } + + pub fn build_with_alpn_protocols( + &self, + alpn_protocols: Option>, + ) -> anyhow::Result { + let config = self.build_server_config(alpn_protocols)?; Ok(RustlsServerConfig { driver: Arc::new(config), accept_timeout: self.accept_timeout, @@ -157,4 +174,21 @@ impl RustlsServerConfigBuilder { pub fn build(&self) -> anyhow::Result { self.build_with_alpn_protocols(None) } + + pub fn build_quic_with_alpn_protocols( + &self, + alpn_protocols: Option>, + ) -> anyhow::Result { + let config = self.build_server_config(alpn_protocols)?; + let quic_config = QuicServerConfig::try_from(config) + .map_err(|e| anyhow!("invalid quic tls config: {e}"))?; + Ok(RustlsQuicServerConfig { + driver: Arc::new(quic_config), + accept_timeout: self.accept_timeout, + }) + } + + pub fn build_quic(&self) -> anyhow::Result { + self.build_quic_with_alpn_protocols(None) + } }