Skip to content

Commit

Permalink
Support HTTP host sniffing
Browse files Browse the repository at this point in the history
Credits #288

Closes #288
  • Loading branch information
eycorsican committed Oct 5, 2024
1 parent 415dab3 commit 1ab748d
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 120 deletions.
11 changes: 6 additions & 5 deletions leaf/src/app/dispatcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,19 @@ impl Dispatcher {
T: 'static + AsyncRead + AsyncWrite + Unpin + Send + Sync,
{
debug!("dispatching {}:{}", &sess.network, &sess.destination);
let mut lhs: Box<dyn ProxyStream> = if *option::DOMAIN_SNIFFING
&& !sess.destination.is_domain()
&& sess.destination.port() == 443
{
let mut lhs: Box<dyn ProxyStream> = if sniff::should_sniff(&sess) {
let mut lhs = sniff::SniffingStream::new(lhs);
match lhs.sniff().await {
match lhs.sniff(&sess).await {
Ok(res) => {
if let Some(domain) = res {
debug!(
"sniffed domain {} for tcp link {} <-> {}",
&domain, &sess.source, &sess.destination,
);
// TODO Add an option to use the sniffed domain for routing only
//
// TODO Add DNS sniff, sniff domain name from DNS response, keep
// an IP -> domain mapping, use this info for routing only.
sess.destination =
match SocksAddr::try_from((&domain, sess.destination.port())) {
Ok(a) => a,
Expand Down
319 changes: 206 additions & 113 deletions leaf/src/common/sniff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,46 @@ use bytes::BytesMut;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, ReadBuf};
use tokio::time::timeout;

use crate::option;
use crate::session::Session;

fn should_sniff_tls(sess: &Session) -> bool {
if *option::TLS_DOMAIN_SNIFFING {
if !*option::TLS_DOMAIN_SNIFFING_ALL && sess.destination.port() != 443 {
return false;
}
true
} else {
false
}
}

fn should_sniff_http(sess: &Session) -> bool {
if *option::HTTP_DOMAIN_SNIFFING {
if !*option::HTTP_DOMAIN_SNIFFING_ALL && sess.destination.port() != 80 {
return false;
}
true
} else {
false
}
}

pub fn should_sniff(sess: &Session) -> bool {
!sess.destination.is_domain() && (should_sniff_tls(sess) || should_sniff_http(sess))
}

pub struct SniffingStream<T> {
inner: T,
buf: BytesMut,
}

enum SniffResult {
NotMatch,
NotEnoughData,
Domain(String),
}

impl<T> SniffingStream<T>
where
T: AsyncRead + AsyncWrite + Unpin,
Expand All @@ -24,128 +59,186 @@ where
}
}

pub async fn sniff(&mut self) -> io::Result<Option<String>> {
fn sniff_http_host(&self, buf: &[u8]) -> SniffResult {
// Credits https://github.com/eycorsican/leaf/pull/288

let bytes_str = String::from_utf8_lossy(buf);
let parts: Vec<&str> = bytes_str.split("\r\n").collect();

if parts.len() == 0 {
return SniffResult::NotMatch;
}

let http_methods = [
"get", "post", "head", "put", "delete", "options", "connect", "patch", "trace",
];
let method_str = parts[0];

let matched_method = http_methods
.into_iter()
.filter(|item| method_str.to_lowercase().contains(item))
.count();

if matched_method == 0 {
return SniffResult::NotMatch;
}

for (idx, &el) in parts.iter().enumerate() {
if idx == 0 || el == "" {
continue;
}
let inner_parts: Vec<&str> = el.split(":").collect();
if inner_parts.len() != 2 {
continue;
}
if inner_parts[0].to_lowercase() == "host" {
return SniffResult::Domain(inner_parts[1].trim().to_string());
}
}

SniffResult::NotMatch
}

fn sniff_tls_sni(&self, buf: &[u8]) -> SniffResult {
// https://tls.ulfheim.net/

let sbuf = &buf[..];
if sbuf.len() < 5 {
return SniffResult::NotEnoughData;
}
// handshake record type
if sbuf[0] != 0x16 {
return SniffResult::NotMatch;
}
// protocol version
if sbuf[1] != 0x3 {
return SniffResult::NotMatch;
}
let header_len = u16::from_be_bytes(sbuf[3..5].try_into().unwrap()) as usize;
if sbuf.len() < 5 + header_len {
return SniffResult::NotEnoughData;
}
let sbuf = &sbuf[5..5 + header_len];
// ?
if sbuf.len() < 42 {
return SniffResult::NotEnoughData;
}
let session_id_len = sbuf[38] as usize;
if session_id_len > 32 || sbuf.len() < 39 + session_id_len {
return SniffResult::NotEnoughData;
}
let sbuf = &sbuf[39 + session_id_len..];
if sbuf.len() < 2 {
return SniffResult::NotEnoughData;
}
let cipher_suite_bytes = u16::from_be_bytes(sbuf[..2].try_into().unwrap()) as usize;
if sbuf.len() < 2 + cipher_suite_bytes {
return SniffResult::NotEnoughData;
}
let sbuf = &sbuf[2 + cipher_suite_bytes..];
if sbuf.is_empty() {
return SniffResult::NotEnoughData;
}
let compression_method_bytes = sbuf[0] as usize;
if sbuf.len() < 1 + compression_method_bytes {
return SniffResult::NotEnoughData;
}
let sbuf = &sbuf[1 + compression_method_bytes..];
if sbuf.len() < 2 {
return SniffResult::NotEnoughData;
}
let extensions_bytes = u16::from_be_bytes(sbuf[..2].try_into().unwrap()) as usize;
if sbuf.len() < 2 + extensions_bytes {
return SniffResult::NotEnoughData;
}
let mut sbuf = &sbuf[2..2 + extensions_bytes];
while !sbuf.is_empty() {
// extension + extension-specific-len
if sbuf.len() < 4 {
return SniffResult::NotEnoughData;
}
let extension = u16::from_be_bytes(sbuf[..2].try_into().unwrap());
let extension_len = u16::from_be_bytes(sbuf[2..4].try_into().unwrap()) as usize;
sbuf = &sbuf[4..];
if sbuf.len() < extension_len {
return SniffResult::NotEnoughData;
}
// extension "server name"
if extension == 0x0 {
let mut ebuf = &sbuf[..extension_len];
if ebuf.len() < 2 {
return SniffResult::NotEnoughData;
}
let entry_len = u16::from_be_bytes(ebuf[..2].try_into().unwrap()) as usize;
ebuf = &ebuf[2..];
if ebuf.len() < entry_len {
return SniffResult::NotEnoughData;
}
// just make sure no oob
if ebuf.is_empty() {
return SniffResult::NotEnoughData;
}
let entry_type = ebuf[0];
// type "DNS hostname"
if entry_type == 0x0 {
ebuf = &ebuf[1..];
// just make sure no oob
if ebuf.len() < 2 {
return SniffResult::NotEnoughData;
}
let hostname_len = u16::from_be_bytes(ebuf[..2].try_into().unwrap()) as usize;
ebuf = &ebuf[2..];
if ebuf.len() < hostname_len {
return SniffResult::NotEnoughData;
}
return SniffResult::Domain(
String::from_utf8_lossy(&ebuf[..hostname_len]).into(),
);
} else {
// TODO
// I assume there's only "DNS hostname" type
// in the the "server name" extension, should
// check if this is true later.
//
// I also assume there's only one entry in the
// "server name" extension list.
return SniffResult::NotMatch;
}
} else {
sbuf = &sbuf[extension_len..];
}
}
SniffResult::NotEnoughData
}

pub async fn sniff(&mut self, sess: &Session) -> io::Result<Option<String>> {
let mut buf = vec![0u8; 2 * 1024];
'outer: for _ in 0..2 {
for _ in 0..2 {
match timeout(Duration::from_millis(100), self.inner.read(&mut buf)).await {
Ok(res) => match res {
Ok(n) => {
self.buf.extend_from_slice(&buf[..n]);

// https://tls.ulfheim.net/

let sbuf = &self.buf[..];
if sbuf.len() < 5 {
continue;
let mut tls_not_match = true;
let mut http_not_match = true;
if should_sniff_tls(sess) {
tls_not_match = false;
match self.sniff_tls_sni(&buf[..n]) {
SniffResult::NotEnoughData => (),
SniffResult::NotMatch => tls_not_match = true,
SniffResult::Domain(domain) => return Ok(Some(domain)),
}
}
// handshake record type
if sbuf[0] != 0x16 {
return Ok(None);
if should_sniff_http(sess) {
http_not_match = false;
match self.sniff_http_host(&buf[..n]) {
SniffResult::NotEnoughData => (),
SniffResult::NotMatch => http_not_match = true,
SniffResult::Domain(domain) => return Ok(Some(domain)),
}
}
// protocol version
if sbuf[1] != 0x3 {
if tls_not_match && http_not_match {
return Ok(None);
}
let header_len =
u16::from_be_bytes(sbuf[3..5].try_into().unwrap()) as usize;
if sbuf.len() < 5 + header_len {
continue;
}
let sbuf = &sbuf[5..5 + header_len];
// ?
if sbuf.len() < 42 {
continue;
}
let session_id_len = sbuf[38] as usize;
if session_id_len > 32 || sbuf.len() < 39 + session_id_len {
continue;
}
let sbuf = &sbuf[39 + session_id_len..];
if sbuf.len() < 2 {
continue;
}
let cipher_suite_bytes =
u16::from_be_bytes(sbuf[..2].try_into().unwrap()) as usize;
if sbuf.len() < 2 + cipher_suite_bytes {
continue;
}
let sbuf = &sbuf[2 + cipher_suite_bytes..];
if sbuf.is_empty() {
continue;
}
let compression_method_bytes = sbuf[0] as usize;
if sbuf.len() < 1 + compression_method_bytes {
continue;
}
let sbuf = &sbuf[1 + compression_method_bytes..];
if sbuf.len() < 2 {
continue;
}
let extensions_bytes =
u16::from_be_bytes(sbuf[..2].try_into().unwrap()) as usize;
if sbuf.len() < 2 + extensions_bytes {
continue;
}
let mut sbuf = &sbuf[2..2 + extensions_bytes];
while !sbuf.is_empty() {
// extension + extension-specific-len
if sbuf.len() < 4 {
continue 'outer;
}
let extension = u16::from_be_bytes(sbuf[..2].try_into().unwrap());
let extension_len =
u16::from_be_bytes(sbuf[2..4].try_into().unwrap()) as usize;
sbuf = &sbuf[4..];
if sbuf.len() < extension_len {
continue 'outer;
}
// extension "server name"
if extension == 0x0 {
let mut ebuf = &sbuf[..extension_len];
if ebuf.len() < 2 {
continue 'outer;
}
let entry_len =
u16::from_be_bytes(ebuf[..2].try_into().unwrap()) as usize;
ebuf = &ebuf[2..];
if ebuf.len() < entry_len {
continue 'outer;
}
// just make sure no oob
if ebuf.is_empty() {
continue 'outer;
}
let entry_type = ebuf[0];
// type "DNS hostname"
if entry_type == 0x0 {
ebuf = &ebuf[1..];
// just make sure no oob
if ebuf.len() < 2 {
continue 'outer;
}
let hostname_len =
u16::from_be_bytes(ebuf[..2].try_into().unwrap()) as usize;
ebuf = &ebuf[2..];
if ebuf.len() < hostname_len {
continue 'outer;
}
return Ok(Some(
String::from_utf8_lossy(&ebuf[..hostname_len]).into(),
));
} else {
// TODO
// I assume there's only "DNS hostname" type
// in the the "server name" extension, should
// check if this is true later.
//
// I also assume there's only one entry in the
// "server name" extension list.
return Ok(None);
}
} else {
sbuf = &sbuf[extension_len..];
}
}
}
Err(e) => {
return Err(e);
Expand Down
Loading

0 comments on commit 1ab748d

Please sign in to comment.