Skip to content

Commit

Permalink
add tests in dns_client
Browse files Browse the repository at this point in the history
  • Loading branch information
shiroedev2024 committed May 6, 2024
1 parent f030031 commit 1df3ff2
Show file tree
Hide file tree
Showing 3 changed files with 266 additions and 5 deletions.
2 changes: 1 addition & 1 deletion leaf/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ all-endpoints = [
"outbound-select",
"outbound-fragment",
"outbound-h2",
"outbound-http"
"outbound-http",
]

# Ring-related
Expand Down
266 changes: 265 additions & 1 deletion leaf/src/app/dns_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -555,4 +555,268 @@ impl DnsClient {

impl UdpConnector for DnsClient {}

impl TcpConnector for DnsClient {}
impl TcpConnector for DnsClient {}

#[cfg(test)]
mod tests {
use tokio::net::TcpSocket;
use std::str::FromStr;
use std::sync::Arc;
use anyhow::anyhow;
use bytes::{Bytes, BytesMut};
use rand::prelude::StdRng;
use rand::{Rng, SeedableRng};
use tokio_rustls::rustls::{ClientConfig, RootCertStore};
use tokio_rustls::rustls::pki_types::ServerName;
use tokio_rustls::TlsConnector;
use trust_dns_proto::op::{Message, MessageType, OpCode, Query};
use trust_dns_proto::rr::{Name, RecordType};
use http::Request;
use h2::RecvStream;
use http::response::Parts;
use tokio::io::{AsyncReadExt, AsyncWriteExt};

pub const MAXIMUM_DNS_PACKET_SIZE: usize = 65536;

#[tokio::test]
async fn test_dns_over_https() -> anyhow::Result<()> {
let host = "www.google.com";

let mut fqdn = host.to_owned();
fqdn.push('.');
let name = match Name::from_str(&fqdn) {
Ok(n) => n,
Err(e) => return Err(anyhow!("invalid domain name [{}]: {}", host, e)),
};

let mut msg = Message::new();
msg.add_query(Query::query(name, RecordType::A));
let mut rng = StdRng::from_entropy();
let id: u16 = rng.gen();
msg.set_id(id);
msg.set_op_code(OpCode::Query);
msg.set_message_type(MessageType::Query);
msg.set_recursion_desired(true);
let msg_buf = match msg.to_vec() {
Ok(b) => b,
Err(e) => return Err(anyhow!("encode message to buffer failed: {}", e)),
};

// create a tcp socket
let socket = TcpSocket::new_v4()?;
let mut stream = socket.connect("1.1.1.1:443".parse()?).await?;
// use tokio-rustls to make a tls handshake
let mut client_config = ClientConfig::builder()
.with_root_certificates(RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec()
})
.with_no_client_auth();
client_config.alpn_protocols.push(b"h2".to_vec());
let connector = TlsConnector::from(Arc::new(client_config));
let stream = connector.connect(ServerName::try_from("cloudflare-dns.com")?, stream).await?;

// handshake h2
let (send_request, connection) = h2::client::handshake(stream).await?;
tokio::spawn(async move {
connection.await.unwrap();
});

let mut send_request = send_request.ready().await?;

for i in 0..2 {
let request = Request::builder()
.method("POST")
.uri("https://cloudflare-dns.com/dns-query")
.header("accept", "application/dns-message")
.header("content-type", "application/dns-message")
.header("content-length", msg_buf.len().to_string())
.header("host", "cloudflare-dns.com")
.version(http::Version::HTTP_2)
.body(())
.unwrap();

let (mut response, mut req) = send_request.send_request(request, false)?;
req.send_data(Bytes::from(msg_buf.clone()), true)?;

let (header, mut recv_stream) = response.await?.into_parts();

check_header_status(&header)?;
check_header_content_type(&header)?;

let body = get_body(&mut recv_stream).await?;
let dns_response = Message::from_vec(&body)?;

println!("response for {}: {:?}",i, dns_response);
}

Ok(())
}

#[tokio::test]
async fn test_dns_over_tls() -> anyhow::Result<()> {
let host = "www.google.com";

let mut fqdn = host.to_owned();
fqdn.push('.');
let name = match Name::from_str(&fqdn) {
Ok(n) => n,
Err(e) => return Err(anyhow!("invalid domain name [{}]: {}", host, e)),
};

let mut msg = Message::new();
msg.add_query(Query::query(name, RecordType::A));
let mut rng = StdRng::from_entropy();
let id: u16 = rng.gen();
msg.set_id(id);
msg.set_op_code(OpCode::Query);
msg.set_message_type(MessageType::Query);
msg.set_recursion_desired(true);
let msg_buf = match msg.to_vec() {
Ok(b) => b,
Err(e) => return Err(anyhow!("encode message to buffer failed: {}", e)),
};

// create a tcp socket
let socket = TcpSocket::new_v4()?;
let mut stream = socket.connect("1.1.1.1:853".parse()?).await?;
// use tokio-rustls to make a tls handshake
let mut client_config = ClientConfig::builder()
.with_root_certificates(RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec()
})
.with_no_client_auth();
let connector = TlsConnector::from(Arc::new(client_config));
let mut stream = connector.connect(ServerName::try_from("cloudflare-dns.com")?, stream).await?;

let len = msg_buf.len();
let mut req = vec![0u8; 2 + len];

req[0] = (len >> 8) as u8;
req[1] = (len & 0xff) as u8;
req[2..].copy_from_slice(&msg_buf);

stream.write_all(&req).await?;

// read 2 bytes as length
let mut buf = [0u8; 2];
stream.read_exact(&mut buf).await?;
let len = (buf[0] as u16) << 8 | (buf[1] as u16);
let mut res = vec![0u8; len as usize];
stream.read_exact(&mut res).await?;

let dns_response = Message::from_vec(&res)?;

println!("response {:?}", dns_response);

Ok(())
}

#[tokio::test]
async fn test_dns_over_tcp() -> anyhow::Result<()> {
let host = "www.google.com";

let mut fqdn = host.to_owned();
fqdn.push('.');
let name = match Name::from_str(&fqdn) {
Ok(n) => n,
Err(e) => return Err(anyhow!("invalid domain name [{}]: {}", host, e)),
};

let mut msg = Message::new();
msg.add_query(Query::query(name, RecordType::A));
let mut rng = StdRng::from_entropy();
let id: u16 = rng.gen();
msg.set_id(id);
msg.set_op_code(OpCode::Query);
msg.set_message_type(MessageType::Query);
msg.set_recursion_desired(true);
let msg_buf = match msg.to_vec() {
Ok(b) => b,
Err(e) => return Err(anyhow!("encode message to buffer failed: {}", e)),
};

// create a tcp socket
let socket = TcpSocket::new_v4()?;
let mut stream = socket.connect("1.1.1.1:53".parse()?).await?;

let len = msg_buf.len();
let mut req = vec![0u8; 2 + len];

req[0] = (len >> 8) as u8;
req[1] = (len & 0xff) as u8;
req[2..].copy_from_slice(&msg_buf);

// write request
stream.write_all(&req).await?;

// Read 2 bytes as length
let len = stream.read_u16().await?;
println!("length: {}", len);

// read remaining bytes
let mut res = vec![0u8; len as usize];
stream.read_exact(&mut res).await?;
println!("received {} bytes", len);

let dns_response = Message::from_vec(&res)?;

println!("response {:?}", dns_response);

Ok(())
}

fn check_header_status(header: &Parts) -> anyhow::Result<()> {
if header.status.is_success() {
Ok(())
} else {
Err(anyhow!("response error {}", header.status))
}
}

fn check_header_content_type(header: &Parts) -> anyhow::Result<()> {
match header.headers.get("content-type") {
Some(value) => {
if value == "application/dns-message" {
Ok(())
} else {
Err(anyhow!("invalid content-type: {}", value.to_str()?))
}
}
None => Err(anyhow!("missing content-type header")),
}
}

async fn get_body(recv_stream: &mut RecvStream) -> anyhow::Result<Bytes> {
let mut body = BytesMut::new();
while let Some(result) = recv_stream.data().await {
match result {
Ok(b) => {
let body_len = body.len();
let b_len = b.len();

recv_stream.flow_control().release_capacity(b_len)?;

if body_len < MAXIMUM_DNS_PACKET_SIZE {
if body_len + b_len < MAXIMUM_DNS_PACKET_SIZE {
body.extend(b);
} else {
body.extend(b.slice(0..MAXIMUM_DNS_PACKET_SIZE - body_len));
break;
}
} else {
break;
}
}
Err(e) => {
// If we get a reset and already received any bytes then use as a response.
if e.is_reset() && !body.is_empty() {
break;
} else {
return Err(anyhow!("recv failed: {}", e));
}
}
}
}
Ok(body.freeze())
}
}
3 changes: 0 additions & 3 deletions leaf/src/proxy/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,6 @@ pub mod ws;
#[cfg(feature = "outbound-fragment")]
pub mod fragment;

#[cfg(feature = "inbound-dns")]
pub mod dns;

#[cfg(feature = "outbound-h2")]
pub mod h2;

Expand Down

0 comments on commit 1df3ff2

Please sign in to comment.