diff --git a/Cargo.lock b/Cargo.lock index 5ff0f5354..fcc1f145f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -541,9 +541,9 @@ dependencies = [ [[package]] name = "mock_instant" -version = "0.4.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c356644192565524790740e4075307c2cfc26d04d2543fb8e3ab9ef43a115ec" +checksum = "cdcebb6db83796481097dedc7747809243cc81d9ed83e6a938b76d4ea0b249cf" [[package]] name = "moka" diff --git a/Cargo.toml b/Cargo.toml index 87a6f550b..381fed4f5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,8 +60,8 @@ validate = ["bytes", "std", "ring"] zonefile = ["bytes", "serde", "std"] # Unstable features -unstable-client-transport = [ "moka", "net", "tracing" ] -unstable-server-transport = ["arc-swap", "chrono/clock", "libc", "net", "tracing"] +unstable-client-transport = ["moka", "net", "tracing"] +unstable-server-transport = ["arc-swap", "chrono/clock", "libc", "net", "siphasher", "tracing"] unstable-stelline = ["tokio/test-util", "tracing", "tracing-subscriber", "unstable-server-transport", "zonefile"] unstable-validator = ["validate", "zonefile", "unstable-client-transport"] unstable-zonetree = ["futures", "parking_lot", "serde", "tokio", "tracing"] @@ -84,7 +84,7 @@ webpki-roots = { version = "0.26" } #sqlx = { version = "0.6", features = [ "runtime-tokio-native-tls", "mysql" ] } # For testing in integration tests: -mock_instant = { version = "0.4.0" } +mock_instant = { version = "0.5.1" } [package.metadata.docs.rs] all-features = true diff --git a/examples/server-transports.rs b/examples/server-transports.rs index 3adc6d416..091bdf46e 100644 --- a/examples/server-transports.rs +++ b/examples/server-transports.rs @@ -33,7 +33,6 @@ use domain::net::server::dgram::DgramServer; use domain::net::server::message::Request; use domain::net::server::middleware::builder::MiddlewareBuilder; use domain::net::server::middleware::processor::MiddlewareProcessor; -#[cfg(feature = "siphasher")] use domain::net::server::middleware::processors::cookies::CookiesMiddlewareProcessor; use domain::net::server::middleware::processors::mandatory::MandatoryMiddlewareProcessor; use domain::net::server::service::{ @@ -688,12 +687,9 @@ async fn main() { let mut fn_svc_middleware = MiddlewareBuilder::new(); fn_svc_middleware.push(MandatoryMiddlewareProcessor::new().into()); - #[cfg(feature = "siphasher")] - { - let server_secret = "server12secret34".as_bytes().try_into().unwrap(); - fn_svc_middleware - .push(CookiesMiddlewareProcessor::new(server_secret).into()); - } + let server_secret = "server12secret34".as_bytes().try_into().unwrap(); + fn_svc_middleware + .push(CookiesMiddlewareProcessor::new(server_secret).into()); let fn_svc_middleware = fn_svc_middleware.build(); diff --git a/src/base/serial.rs b/src/base/serial.rs index b325057db..9bb91b46e 100644 --- a/src/base/serial.rs +++ b/src/base/serial.rs @@ -13,7 +13,7 @@ use chrono::{DateTime, TimeZone}; use core::cmp::Ordering; use core::{cmp, fmt, str}; #[cfg(all(feature = "std", test))] -use mock_instant::{SystemTime, UNIX_EPOCH}; +use mock_instant::thread_local::{SystemTime, UNIX_EPOCH}; use octseq::parse::Parser; #[cfg(all(feature = "std", not(test)))] use std::time::{SystemTime, UNIX_EPOCH}; diff --git a/src/net/client/validator_test.rs b/src/net/client/validator_test.rs index 4c000dba9..6b8a587e9 100644 --- a/src/net/client/validator_test.rs +++ b/src/net/client/validator_test.rs @@ -14,7 +14,7 @@ use crate::stelline::connect::Connect; use crate::stelline::parse_stelline::parse_file; use crate::stelline::parse_stelline::Config; -use mock_instant::MockClock; +use mock_instant::thread_local::MockClock; use rstest::rstest; use tracing::instrument; diff --git a/src/net/server/dgram.rs b/src/net/server/dgram.rs index 59356c36d..f2c8f5422 100644 --- a/src/net/server/dgram.rs +++ b/src/net/server/dgram.rs @@ -28,6 +28,7 @@ use tokio::time::interval; use tokio::time::timeout; use tokio::time::Instant; use tokio::time::MissedTickBehavior; +use tracing::warn; use tracing::Level; use tracing::{enabled, error, trace}; @@ -720,13 +721,16 @@ where // Actually write the DNS response message bytes to the UDP // socket. - let _ = Self::send_to( + if let Err(err) = Self::send_to( &state.sock, bytes, &client_addr, state.write_timeout, ) - .await; + .await + { + warn!(%client_addr, "Failed to send response: {err}"); + } metrics.dec_num_pending_writes(); metrics.inc_num_sent_responses(); diff --git a/src/net/server/middleware/processors/cookies.rs b/src/net/server/middleware/processors/cookies.rs index f02630c45..fefa61645 100644 --- a/src/net/server/middleware/processors/cookies.rs +++ b/src/net/server/middleware/processors/cookies.rs @@ -1,7 +1,6 @@ //! DNS Cookies related message processing. use core::ops::ControlFlow; -use std::net::IpAddr; use std::vec::Vec; use octseq::Octets; @@ -10,6 +9,7 @@ use tracing::{debug, trace, warn}; use crate::base::iana::{OptRcode, Rcode}; use crate::base::message_builder::AdditionalBuilder; +use crate::base::net::IpAddr; use crate::base::opt; use crate::base::wire::{Composer, ParseError}; use crate::base::{Serial, StreamTarget}; @@ -18,6 +18,8 @@ use crate::net::server::middleware::processor::MiddlewareProcessor; use crate::net::server::util::add_edns_options; use crate::net::server::util::{mk_builder_for_target, start_reply}; +//----------- Constants ------------------------------------------------------- + /// The five minute period referred to by /// https://www.rfc-editor.org/rfc/rfc9018.html#section-4.3. const FIVE_MINUTES_AS_SECS: u32 = 5 * 60; @@ -26,6 +28,8 @@ const FIVE_MINUTES_AS_SECS: u32 = 5 * 60; /// https://www.rfc-editor.org/rfc/rfc9018.html#section-4.3. const ONE_HOUR_AS_SECS: u32 = 60 * 60; +//----------- CookiesMiddlewareProcessor -------------------------------------- + /// A DNS Cookies [`MiddlewareProcessor`]. /// /// Standards covered by ths implementation: @@ -71,19 +75,20 @@ impl CookiesMiddlewareProcessor { } impl CookiesMiddlewareProcessor { - /// Get the DNS COOKIE, if any, for the given message. + /// Get the DNS cookie, if any, for the given message. /// - /// https://datatracker.ietf.org/doc/html/rfc7873#section-5.2: Responding - /// to a Request: "In all cases of multiple COOKIE options in a request, - /// only the first (the one closest to the DNS header) is considered. - /// All others are ignored." + /// https://datatracker.ietf.org/doc/html/rfc7873#section-5.2 + /// 5.2 Responding to a Request + /// "In all cases of multiple COOKIE options in a request, only the + /// first (the one closest to the DNS header) is considered. All others + /// are ignored." /// /// Returns: - /// - `None` if the request has no cookie, - /// - Some(Ok(cookie)) if the request has a cookie in the correct - /// format, - /// - Some(Err(err)) if the request has a cookie that we could not - /// parse. + /// - None if the request has no cookie, + /// - Some(Ok(cookie)) if the first cookie in the request could be + /// parsed. + /// - Some(Err(err)) if the first cookie in the request could not be + /// parsed. #[must_use] fn cookie( request: &Request, @@ -117,7 +122,15 @@ impl CookiesMiddlewareProcessor { let now = Serial::now(); let too_new_at = now.add(FIVE_MINUTES_AS_SECS); let expires_at = serial.add(ONE_HOUR_AS_SECS); - now <= expires_at && serial <= too_new_at + if now > expires_at { + trace!("Invalid server cookie: cookie has expired ({now} > {expires_at})"); + false + } else if serial > too_new_at { + trace!("Invalid server cookie: cookie is too new ({serial} > {too_new_at})"); + false + } else { + true + } } /// Create a DNS response message for the given request, including cookie. @@ -230,6 +243,7 @@ where RequestOctets: Octets, Target: Composer + Default, { + #[tracing::instrument(skip_all, fields(request_ip = %request.client_addr().ip()))] fn preprocess( &self, request: &Request, @@ -245,24 +259,31 @@ where // the request as if the server doesn't implement the // COOKIE option." - // For clients on the IP deny list they MUST authenticate - // themselves to the server, either with a cookie or by - // re-connecting over TCP, so we REFUSE them and reply with - // TC=1 to prompt them to reconnect via TCP. + // https://datatracker.ietf.org/doc/html/rfc7873#section-1 + // 1. Introduction + // "The protection provided by DNS Cookies is similar to + // that provided by using TCP for DNS transactions. + // ... + // Where DNS Cookies are not available but TCP is, falling + // back to using TCP is reasonable." + + // While not required by RFC 7873, like Unbound the caller can + // configure this middleware processor to require clients + // contacting it from certain IP addresses to authenticate + // themselves or be refused with TC=1 to signal that they + // should resubmit their request via TCP. if request.transport_ctx().is_udp() && self.ip_deny_list.contains(&request.client_addr().ip()) { - debug!( - "Rejecting cookie-less non-TCP request due to matching IP deny list entry" - ); + debug!("Rejecting cookie-less non-TCP request due to matching deny list entry"); let builder = mk_builder_for_target(); let mut additional = builder.additional(); additional.header_mut().set_rcode(Rcode::REFUSED); additional.header_mut().set_tc(true); return ControlFlow::Break(additional); - } else { - trace!("Permitting cookie-less request to flow due to use of TCP transport"); } + + // Continue as if we we don't implement the COOKIE option. } Some(Err(err)) => { @@ -305,6 +326,8 @@ where ); if !server_cookie_is_valid { + trace!("Request has an invalid DNS server cookie"); + // https://datatracker.ietf.org/doc/html/rfc7873#section-5.2.3 // Only a Client Cookie: // "Based on server policy, including rate limiting, the @@ -379,10 +402,13 @@ where self.bad_cookie_response(request) }; return ControlFlow::Break(additional); - } else if request.transport_ctx().is_udp() { + } else if request.transport_ctx().is_udp() + && self + .ip_deny_list + .contains(&request.client_addr().ip()) + { let additional = self.bad_cookie_response(request); - debug!( - "Rejecting non-TCP request due to invalid server cookie"); + debug!("Rejecting non-TCP request with invalid server cookie due to matching deny list entry"); return ControlFlow::Break(additional); } } else if request.message().header_counts().qdcount() == 0 { @@ -480,13 +506,15 @@ mod tests { let request = Request::new(client_addr, Instant::now(), message, ctx.into()); - // And pass the query through the middleware processor - let server_secret: [u8; 16] = - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]; - let processor = CookiesMiddlewareProcessor::new(server_secret); + // Setup the cookie middleware processor such that it requires + // the mock client to provide a valid cookie. + let server_secret: [u8; 16] = [1u8; 16]; + let processor = CookiesMiddlewareProcessor::new(server_secret) + .with_denied_ips(["127.0.0.1".parse().unwrap()]); let processor: &dyn MiddlewareProcessor, Vec> = &processor; + // And pass the query through the middleware processor let ControlFlow::Break(mut response) = processor.preprocess(&request) else { unreachable!() diff --git a/src/net/server/middleware/processors/mod.rs b/src/net/server/middleware/processors/mod.rs index 18635c239..a2df774f6 100644 --- a/src/net/server/middleware/processors/mod.rs +++ b/src/net/server/middleware/processors/mod.rs @@ -1,7 +1,6 @@ //! Pre-supplied [`MiddlewareProcessor`] implementations. //! //! [`MiddlewareProcessor`]: super::processor::MiddlewareProcessor -#[cfg(feature = "siphasher")] pub mod cookies; pub mod edns; pub mod mandatory; diff --git a/tests/net-server.rs b/src/net/server/tests/integration.rs similarity index 85% rename from tests/net-server.rs rename to src/net/server/tests/integration.rs index d9959ed80..b4dcfa08d 100644 --- a/tests/net-server.rs +++ b/src/net/server/tests/integration.rs @@ -1,47 +1,46 @@ -#![cfg(feature = "net")] - +use std::boxed::Box; use std::collections::VecDeque; use std::fs::File; use std::future::Future; -use std::net::IpAddr; +use std::net::SocketAddr; use std::path::PathBuf; use std::sync::Arc; use std::time::Duration; +use std::vec::Vec; use octseq::Octets; use rstest::rstest; use tracing::instrument; use tracing::{trace, warn}; -use domain::base::iana::Rcode; -use domain::base::name::{Name, ToName}; -use domain::base::wire::Composer; -use domain::net::client::{dgram, stream}; -use domain::net::server::buf::VecBufSource; -use domain::net::server::dgram::DgramServer; -use domain::net::server::message::Request; -use domain::net::server::middleware::builder::MiddlewareBuilder; -#[cfg(feature = "siphasher")] -use domain::net::server::middleware::processors::cookies::CookiesMiddlewareProcessor; -use domain::net::server::middleware::processors::edns::EdnsMiddlewareProcessor; -use domain::net::server::service::{ +use crate::base::iana::Rcode; +use crate::base::name::{Name, ToName}; +use crate::base::net::IpAddr; +use crate::base::wire::Composer; +use crate::net::client::{dgram, stream}; +use crate::net::server::buf::VecBufSource; +use crate::net::server::dgram::DgramServer; +use crate::net::server::message::Request; +use crate::net::server::middleware::builder::MiddlewareBuilder; +use crate::net::server::middleware::processors::cookies::CookiesMiddlewareProcessor; +use crate::net::server::middleware::processors::edns::EdnsMiddlewareProcessor; +use crate::net::server::service::{ CallResult, Service, ServiceError, Transaction, }; -use domain::net::server::stream::StreamServer; -use domain::net::server::util::{mk_builder_for_target, service_fn}; -use domain::utils::base16; -use domain::zonefile::inplace::{Entry, ScannedRecord, Zonefile}; - -use domain::stelline::channel::ClientServerChannel; -use domain::stelline::client::do_client; -use domain::stelline::client::ClientFactory; -use domain::stelline::client::{ +use crate::net::server::stream::StreamServer; +use crate::net::server::util::{mk_builder_for_target, service_fn}; +use crate::stelline::channel::ClientServerChannel; +use crate::stelline::client::do_client; +use crate::stelline::client::ClientFactory; +use crate::stelline::client::{ CurrStepValue, PerClientAddressClientFactory, QueryTailoredClientFactory, }; -use domain::stelline::parse_stelline; -use domain::stelline::parse_stelline::parse_file; -use domain::stelline::parse_stelline::Config; -use domain::stelline::parse_stelline::Matches; +use crate::stelline::parse_stelline; +use crate::stelline::parse_stelline::parse_file; +use crate::stelline::parse_stelline::Config; +use crate::stelline::parse_stelline::Matches; +use crate::utils::base16; +use crate::zonefile::inplace::{Entry, ScannedRecord, Zonefile}; //----------- Tests ---------------------------------------------------------- @@ -59,6 +58,16 @@ async fn server_tests(#[files("test-data/server/*.rpl")] rpl_file: PathBuf) { // and which responses will be expected, and how the server that // answers them should be configured. + // Initialize tracing based logging. Override with env var RUST_LOG, e.g. + // RUST_LOG=trace. DEBUG level will show the .rpl file name, Stelline step + // numbers and types as they are being executed. + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .with_thread_ids(true) + .without_time() + .try_init() + .ok(); + let file = File::open(&rpl_file).unwrap(); let stelline = parse_file(&file, rpl_file.to_str().unwrap()); let server_config = parse_server_config(&stelline.config); @@ -155,8 +164,9 @@ fn mk_client_factory( }; let tcp_client_factory = PerClientAddressClientFactory::new( - move |_source_addr| { - let stream = stream_server_conn.connect(); + move |source_addr| { + let stream = stream_server_conn + .connect(Some(SocketAddr::new(*source_addr, 0))); let (conn, transport) = stream::Connection::new(stream); tokio::spawn(transport.run()); Box::new(conn) @@ -169,7 +179,12 @@ fn mk_client_factory( let for_all_other_queries = |_: &_| true; let udp_client_factory = PerClientAddressClientFactory::new( - move |_| Box::new(dgram::Connection::new(dgram_server_conn.clone())), + move |source_addr| { + Box::new(dgram::Connection::new( + dgram_server_conn + .new_client(Some(SocketAddr::new(*source_addr, 0))), + )) + }, for_all_other_queries, ); @@ -185,8 +200,8 @@ fn mk_client_factory( fn mk_server_configs( config: &ServerConfig, ) -> ( - domain::net::server::dgram::Config, - domain::net::server::stream::Config, + crate::net::server::dgram::Config, + crate::net::server::stream::Config, ) where RequestOctets: Octets, @@ -195,7 +210,6 @@ where let mut middleware = MiddlewareBuilder::minimal(); if config.cookies.enabled { - #[cfg(feature = "siphasher")] if let Some(secret) = config.cookies.secret { let secret = base16::decode_vec(secret).unwrap(); let secret = <[u8; 16]>::try_from(secret).unwrap(); @@ -204,9 +218,6 @@ where .with_denied_ips(config.cookies.ip_deny_list.clone()); middleware.push(processor.into()); } - - #[cfg(not(feature = "siphasher"))] - panic!("The test uses cookies but the required 'siphasher' feature is not enabled."); } if config.edns_tcp_keepalive { @@ -216,13 +227,13 @@ where let middleware = middleware.build(); - let mut dgram_config = domain::net::server::dgram::Config::default(); + let mut dgram_config = crate::net::server::dgram::Config::default(); dgram_config.set_middleware_chain(middleware.clone()); - let mut stream_config = domain::net::server::stream::Config::default(); + let mut stream_config = crate::net::server::stream::Config::default(); if let Some(idle_timeout) = config.idle_timeout { let mut connection_config = - domain::net::server::ConnectionConfig::default(); + crate::net::server::ConnectionConfig::default(); connection_config.set_idle_timeout(idle_timeout); connection_config.set_middleware_chain(middleware); stream_config.set_connection_config(connection_config); @@ -263,7 +274,7 @@ fn test_service( } fn as_records( - e: Result, + e: Result, ) -> Option { match e { Ok(Entry::Record(r)) => Some(r), diff --git a/src/net/server/tests/mod.rs b/src/net/server/tests/mod.rs new file mode 100644 index 000000000..f4e60adc5 --- /dev/null +++ b/src/net/server/tests/mod.rs @@ -0,0 +1,3 @@ +#![cfg(all(feature = "net", test))] +mod integration; +mod unit; diff --git a/src/net/server/tests.rs b/src/net/server/tests/unit.rs similarity index 98% rename from src/net/server/tests.rs rename to src/net/server/tests/unit.rs index 6ac63da98..e728abcf1 100644 --- a/src/net/server/tests.rs +++ b/src/net/server/tests/unit.rs @@ -20,14 +20,13 @@ use crate::base::Name; use crate::base::Rtype; use crate::base::StaticCompressor; use crate::base::StreamTarget; - -use super::buf::BufSource; -use super::message::Request; -use super::service::{ +use crate::net::server::buf::BufSource; +use crate::net::server::message::Request; +use crate::net::server::service::{ CallResult, Service, ServiceError, ServiceFeedback, Transaction, }; -use super::sock::AsyncAccept; -use super::stream::StreamServer; +use crate::net::server::sock::AsyncAccept; +use crate::net::server::stream::StreamServer; /// Mock I/O which supplies a sequence of mock messages to the server at a /// defined rate. diff --git a/src/stelline/channel.rs b/src/stelline/channel.rs index 5076c3843..66f79c085 100644 --- a/src/stelline/channel.rs +++ b/src/stelline/channel.rs @@ -22,13 +22,14 @@ use crate::net::client::protocol::{ AsyncConnect, AsyncDgramRecv, AsyncDgramSend, }; use crate::net::server::sock::{AsyncAccept, AsyncDgramSock}; +use core::sync::atomic::{AtomicU16, Ordering}; // If MSRV gets bumped to 1.69.0 we can replace these with a const SocketAddr. pub const DEF_CLIENT_ADDR: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST); pub const DEF_CLIENT_PORT: u16 = 0; enum Data { - DgramRequest(Vec), + DgramRequest(SocketAddr, Vec), StreamAccept(ClientServerChannel), StreamRequest(Vec), } @@ -124,7 +125,7 @@ struct ServerSocket { /// Senders for the server to send responses to clients. /// /// One per client to which responses must be sent. - response_txs: HashMap<(), mpsc::Sender>>, + response_txs: HashMap>>, /// Buffer for received bytes that overflowed the server read buffer. unread_buf: ReadBufBuffer, @@ -152,7 +153,6 @@ impl ServerSocket { } } -#[derive(Default)] pub struct ClientServerChannel { /// Details of the server end of the connection. server: Arc>, @@ -160,10 +160,30 @@ pub struct ClientServerChannel { /// Details of the client end of the connection, if connected. client: Option, + /// Simulated client address. + client_addr: SocketAddr, + + /// Next mock client port number to use. + next_client_port: Arc, + /// Type of connection. is_stream: bool, } +impl Default for ClientServerChannel { + fn default() -> Self { + let client_addr = SocketAddr::new("127.0.0.1".parse().unwrap(), 0); + + Self { + server: Default::default(), + client: Default::default(), + client_addr, + next_client_port: Arc::new(AtomicU16::new(1)), + is_stream: Default::default(), + } + } +} + impl Clone for ClientServerChannel { /// Clones only the server half, the client half cannot be cloned. The /// result can be used to connect a new client to an existing server. @@ -171,6 +191,8 @@ impl Clone for ClientServerChannel { Self { server: self.server.clone(), client: None, + client_addr: self.client_addr, + next_client_port: self.next_client_port.clone(), is_stream: self.is_stream, } } @@ -191,30 +213,61 @@ impl ClientServerChannel { } } - pub fn connect(&self) -> Self { - fn setup_client(server_socket: &mut ServerSocket) -> ClientSocket { + pub fn new_client(&self, client_addr: Option) -> Self { + let mut client_addr = client_addr.unwrap_or_else(|| { + SocketAddr::new("127.0.0.1".parse().unwrap(), 0) + }); + + if client_addr.port() == 0 { + let client_port = + self.next_client_port.fetch_add(1, Ordering::SeqCst); + client_addr.set_port(client_port); + } + + Self { + server: self.server.clone(), + client: None, + client_addr, + next_client_port: self.next_client_port.clone(), + is_stream: self.is_stream, + } + } + + pub fn connect(&self, client_addr: Option) -> Self { + fn setup_client( + server_socket: &mut ServerSocket, + client_addr: SocketAddr, + ) -> ClientSocket { // Create a client socket for sending requests to the server. let (client, response_tx) = ClientSocket::new(server_socket.sender()); // Tell the server how to respond to the client. - server_socket.response_txs.insert((), response_tx); + server_socket.response_txs.insert(client_addr, response_tx); // Return the created client socket client } + let client_addr = client_addr.unwrap_or_else(|| { + let client_port = + self.next_client_port.fetch_add(1, Ordering::SeqCst); + SocketAddr::new("127.0.0.1".parse().unwrap(), client_port) + }); + match self.is_stream { false => { // For dgram connections all clients communicate with the same // single server socket. let server_socket = &mut self.server.lock().unwrap(); - let client = setup_client(server_socket); + let client = setup_client(server_socket, client_addr); // Tell the client how to contact the server. Self { server: self.server.clone(), client: Some(client), + client_addr: self.client_addr, + next_client_port: self.next_client_port.clone(), is_stream: false, } } @@ -223,12 +276,14 @@ impl ClientServerChannel { // But for stream connections each new client communicates // with a new server-side connection handler socket. let mut server_socket = ServerSocket::default(); - let client = setup_client(&mut server_socket); + let client = setup_client(&mut server_socket, client_addr); // Tell the client how to contact the new server connection handler. let channel = Self { server: Arc::new(Mutex::new(server_socket)), client: Some(client), + client_addr: self.client_addr, + next_client_port: self.next_client_port.clone(), is_stream: true, }; @@ -236,9 +291,10 @@ impl ClientServerChannel { // by unblocking AsyncAccept::poll_accept() which is being polled // by the server. let sender = self.server.lock().unwrap().tx.clone(); - let cloned_channel = channel.clone(); + let channel_for_client = + channel.new_client(Some(client_addr)); tokio::spawn(async move { - sender.send(Data::StreamAccept(cloned_channel)).await + sender.send(Data::StreamAccept(channel_for_client)).await }); channel @@ -263,8 +319,7 @@ impl AsyncConnect for ClientServerChannel { >; fn connect(&self) -> Self::Fut { - let conn = self.connect(); - + let conn = self.connect(Some(self.client_addr)); Box::pin(async move { Ok(conn) }) } } @@ -290,7 +345,7 @@ impl AsyncDgramRecv for ClientServerChannel { Poll::Ready(Ok(())) } Poll::Ready(None) => { - trace!("Broken pipe while reading in dgram client channel"); + trace!("Broken pipe while reading in dgram client channel (is_closed={})", rx.is_closed()); Poll::Ready(Err(io::Error::from(io::ErrorKind::BrokenPipe))) } Poll::Pending => { @@ -309,7 +364,7 @@ impl AsyncDgramSend for ClientServerChannel { ) -> Poll> { match &self.client { Some(client) => { - let msg = Data::DgramRequest(data.into()); + let msg = Data::DgramRequest(self.client_addr, data.into()); // TODO: Can Stelline scripts mix and match fake responses with // responses from a real server? Do we need to first try @@ -358,10 +413,10 @@ impl AsyncDgramSock for ClientServerChannel { &self, cx: &mut Context, data: &[u8], - dest: &std::net::SocketAddr, + dest: &SocketAddr, ) -> Poll> { let server_socket = self.server.lock().unwrap(); - let tx = server_socket.response_txs.get(&()); + let tx = server_socket.response_txs.get(dest); if let Some(server_tx) = tx { let mut fut = Box::pin(server_tx.send(data.to_vec())); match fut.poll_unpin(cx) { @@ -407,12 +462,11 @@ impl AsyncDgramSock for ClientServerChannel { let mut server_socket = self.server.lock().unwrap(); let rx = &mut server_socket.rx; match rx.try_recv() { - Ok(Data::DgramRequest(data)) => { + Ok(Data::DgramRequest(addr, data)) => { // TODO: use unread buf here to prevent overflow of given buf. - trace!("Reading {} bytes into buffer of len {} in dgram server channel", data.len(), buf.remaining()); + trace!("Reading {} bytes from {addr} into buffer of len {} in dgram server channel", data.len(), buf.remaining()); buf.put_slice(&data); - let socket_addr = SocketAddr::new("::".parse().unwrap(), 0); - Ok((data.len(), socket_addr)) + Ok((data.len(), addr)) } Ok(Data::StreamAccept(..)) => unreachable!(), Ok(Data::StreamRequest(..)) => unreachable!(), @@ -439,17 +493,16 @@ impl Future for ClientServerChannelReadableFut { ) -> Poll { let server_socket = self.0.lock().unwrap(); let rx = &server_socket.rx; - trace!("ReadableFut {} in dgram server channel", !rx.is_empty()); - match !rx.is_empty() { - true => Poll::Ready(Ok(())), - false => { - let waker = cx.waker().clone(); - std::thread::spawn(move || { - std::thread::yield_now(); - waker.wake(); - }); - Poll::Pending - } + if !rx.is_empty() { + trace!("Server socket is now readable"); + Poll::Ready(Ok(())) + } else { + trace!("Server socket is not yet readable"); + let waker = cx.waker().clone(); + tokio::task::spawn(async move { + waker.wake(); + }); + Poll::Pending } } } diff --git a/src/stelline/client.rs b/src/stelline/client.rs index d32386dab..6d8f85ba2 100644 --- a/src/stelline/client.rs +++ b/src/stelline/client.rs @@ -10,10 +10,8 @@ use std::time::Duration; use std::vec::Vec; use bytes::Bytes; -/* -#[cfg(feature = "mock-time")] -use mock_instant::MockClock; -*/ +#[cfg(all(feature = "std", test))] +use mock_instant::thread_local::MockClock; use tracing::{debug, info_span, trace}; use tracing_subscriber::EnvFilter; @@ -377,6 +375,12 @@ pub async fn do_client<'a, T: ClientFactory>( ) -> Result<(), StellineErrorCause> { let mut resp: Option> = None; + #[cfg(all(feature = "std", test))] + { + trace!("Setting mock system time to zero."); + MockClock::set_system_time(Duration::ZERO); + } + // Assume steps are in order. Maybe we need to define that. for step in &stelline.scenario.steps { let span = @@ -410,6 +414,8 @@ pub async fn do_client<'a, T: ClientFactory>( .await; } + trace!("Receive result: {res:?}"); + resp = res?; trace!(?resp); @@ -430,10 +436,14 @@ pub async fn do_client<'a, T: ClientFactory>( let duration = Duration::from_secs(step.time_passes.unwrap()); tokio::time::advance(duration).await; - /* - #[cfg(feature = "mock-time")] - MockClock::advance_system_time(duration); - */ + #[cfg(all(feature = "std", test))] + { + trace!( + "Advancing mock system time by {} seconds...", + duration.as_secs() + ); + MockClock::advance_system_time(duration); + } } StepType::Traffic | StepType::CheckTempfile diff --git a/test-data/server/edns_downstream_cookies.rpl.not b/test-data/server/edns_downstream_cookies.rpl similarity index 100% rename from test-data/server/edns_downstream_cookies.rpl.not rename to test-data/server/edns_downstream_cookies.rpl