Skip to content

Commit

Permalink
reading code
Browse files Browse the repository at this point in the history
  • Loading branch information
ssrlive committed Jul 9, 2023
1 parent 5fc8bc4 commit bd9c228
Show file tree
Hide file tree
Showing 9 changed files with 116 additions and 110 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
examples/
nginx_signing.key
overtls-daemon.sh
project.xcworkspace/
Expand Down
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,15 @@ rustls = "0.21"
rustls-pemfile = "1.0"
serde = { version = "1.0", features = [ "derive" ] }
serde_json = "1.0"
socks5-impl = "0.2"
socks5-impl = "0.3"
thiserror = "1.0"
tokio = { version = "1.28", features = [ "full" ] }
tokio-rustls = "0.24"
tokio-tungstenite = { version = "0.19", features = [ "rustls-tls-webpki-roots" ] }
tungstenite = { version = "0.19", features = [ "rustls-tls-webpki-roots" ] }
url = "2.3"
webpki = { package = "rustls-webpki", version = "0.100", features = ["alloc", "std"] }
webpki-roots = "0.23"
webpki-roots = "0.24"

[target.'cfg(target_family="unix")'.dependencies]
daemonize = "0.5"
Expand Down
16 changes: 12 additions & 4 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async fn handle_incoming(
) -> Result<()> {
let peer_addr = conn.peer_addr()?;
match conn.handshake().await? {
Connection::Associate(asso, _) => {
Connection::UdpAssociate(asso, _) => {
if let Some(udp_tx) = udp_tx {
if let Err(e) = udprelay::handle_s5_upd_associate(asso, udp_tx, incomings).await {
log::debug!("{peer_addr} handle_s5_upd_associate \"{e}\"");
Expand Down Expand Up @@ -116,7 +116,7 @@ async fn handle_socks5_cmd_connection(connect: Connect<NeedReply>, target_addr:

let client = config.client.as_ref().ok_or("client not exist")?;
let (ip_addr, port) = (client.server_host.as_str(), client.server_port);
let addr = &SocketAddr::new(ip_addr.parse()?, port);
let addr = SocketAddr::new(ip_addr.parse()?, port);

if !config.disable_tls() {
let ws_stream = create_tls_ws_stream(addr, Some(target_addr.clone()), &config, None).await?;
Expand All @@ -134,6 +134,7 @@ async fn client_traffic_loop<T: AsyncRead + AsyncWrite + Unpin, S: AsyncRead + A
peer_addr: SocketAddr,
target_addr: Address,
) -> Result<()> {
let mut timer = tokio::time::interval(std::time::Duration::from_secs(30));
loop {
let mut buf = BytesMut::with_capacity(crate::STREAM_BUFFER_SIZE);
tokio::select! {
Expand Down Expand Up @@ -171,9 +172,16 @@ async fn client_traffic_loop<T: AsyncRead + AsyncWrite + Unpin, S: AsyncRead + A
log::trace!("{} <- {} ws closed", peer_addr, target_addr);
break;
}
Message::Pong(_) => {
log::trace!("{} <- {} Websocket pong from remote", peer_addr, target_addr);
},
_ => {}
}
}
_ = timer.tick() => {
ws_stream.send(Message::Ping(vec![])).await?;
log::trace!("{} -> {} Websocket ping from local", peer_addr, target_addr);
}
}
}
Ok(())
Expand All @@ -182,7 +190,7 @@ async fn client_traffic_loop<T: AsyncRead + AsyncWrite + Unpin, S: AsyncRead + A
type WsTlsStream = WebSocketStream<TlsStream<TcpStream>>;

pub(crate) async fn create_tls_ws_stream(
svr_addr: &SocketAddr,
svr_addr: SocketAddr,
dst_addr: Option<Address>,
config: &Config,
udp_tunnel: Option<bool>,
Expand All @@ -199,7 +207,7 @@ pub(crate) async fn create_tls_ws_stream(
}

pub(crate) async fn create_plaintext_ws_stream(
server_addr: &SocketAddr,
server_addr: SocketAddr,
dst_addr: Option<Address>,
config: &Config,
udp_tunnel: Option<bool>,
Expand Down
10 changes: 2 additions & 8 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,19 +56,13 @@ pub enum Error {
#[error("std::str::Utf8Error {0}")]
Utf8(#[from] std::str::Utf8Error),

#[error("&str error: {0}")]
Str(String),

#[error("String error: {0}")]
String(String),

#[error("&String error: {0}")]
RefString(String),
}

impl From<&str> for Error {
fn from(s: &str) -> Self {
Error::Str(s.to_string())
Error::String(s.to_string())
}
}

Expand All @@ -80,7 +74,7 @@ impl From<String> for Error {

impl From<&String> for Error {
fn from(s: &String) -> Self {
Error::RefString(s.to_string())
Error::String(s.to_string())
}
}

Expand Down
149 changes: 69 additions & 80 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use crate::{
};
use bytes::{BufMut, BytesMut};
use futures_util::{SinkExt, StreamExt};
use socks5_impl::protocol::Address;
use socks5_impl::protocol::{Address, AddressType};
use std::{
collections::HashMap,
net::{SocketAddr, ToSocketAddrs},
Expand Down Expand Up @@ -43,31 +43,33 @@ pub async fn run_server(config: &Config, exiting_flag: Option<Arc<AtomicBool>>)
let p = server.listen_port;
let addr: SocketAddr = (h, p).to_socket_addrs()?.next().ok_or("Invalid server address")?;

let certs = if let Some(ref cert) = server.certfile {
let certs = server.certfile.as_ref().and_then(|cert| {
if !config.disable_tls() {
server_load_certs(cert).ok()
} else {
None
}
} else {
None
};
});

let keys = if let Some(ref key) = server.keyfile {
let keys = server.keyfile.as_ref().and_then(|key| {
if !config.disable_tls() {
server_load_keys(key).ok()
let keys = server_load_keys(key).ok();
if keys.as_ref().map(|keys| keys.len()).unwrap_or(0) > 0 {
keys
} else {
None
}
} else {
None
}
} else {
None
};
});

let svr_cfg = if let (Some(certs), Some(mut keys)) = (certs, keys) {
let svr_cfg = if let (Some(certs), Some(keys)) = (certs, keys) {
let key = keys.get(0).ok_or("no keys")?.clone();
rustls::ServerConfig::builder()
.with_safe_defaults()
.with_no_client_auth()
.with_single_cert(certs, keys.remove(0))
.with_single_cert(certs, key)
.ok()
} else {
None
Expand Down Expand Up @@ -99,20 +101,16 @@ pub async fn run_server(config: &Config, exiting_flag: Option<Arc<AtomicBool>>)
let incoming_task = async move {
if let Some(acceptor) = acceptor {
let stream = acceptor.accept(stream).await?;
if let Err(e) = handle_incoming(stream, peer_addr, config, traffic_audit).await {
log::debug!("{}: {}", peer_addr, e);
}
} else if let Err(e) = handle_incoming(stream, peer_addr, config, traffic_audit).await {
log::debug!("{}: {}", peer_addr, e);
handle_incoming(stream, peer_addr, config, traffic_audit).await?;
} else {
log::debug!("some unknown error with {}", peer_addr);
handle_incoming(stream, peer_addr, config, traffic_audit).await?;
}
Ok::<_, Error>(())
};

tokio::spawn(async move {
if let Err(err) = incoming_task.await {
log::debug!("{:?}", err);
if let Err(e) = incoming_task.await {
log::debug!("{peer_addr}: {e}");
}
});
}
Expand Down Expand Up @@ -189,7 +187,7 @@ where
let tls_enable = scheme == "https";
let host = url.host_str().ok_or("url host not exist")?;
let port = url.port_or_known_default().ok_or("port not exist")?;
let forward_addr = &SocketAddr::new(host.parse()?, port);
let forward_addr = SocketAddr::new(host.parse()?, port);

if tls_enable {
let cert_store = retrieve_root_cert_store_for_client(&None)?;
Expand Down Expand Up @@ -291,7 +289,7 @@ async fn websocket_traffic_handler<S: AsyncRead + AsyncWrite + Unpin>(
let addr_str = b64str_to_address(&target_address, false)?.to_string();
let dst_addr = addr_str.to_socket_addrs()?.next().ok_or("addr string parse failed")?;
log::trace!("{} -> {} {client_id:?} uri path: \"{}\"", peer, dst_addr, uri_path);
result = normal_tunnel(ws_stream, config, traffic_audit, &client_id, &dst_addr).await;
result = normal_tunnel(ws_stream, peer, config, traffic_audit, &client_id, dst_addr).await;
if let Err(ref e) = result {
log::debug!("{} <> {} connection closed error: {}", peer, dst_addr, e);
} else {
Expand All @@ -303,79 +301,56 @@ async fn websocket_traffic_handler<S: AsyncRead + AsyncWrite + Unpin>(

async fn normal_tunnel<S: AsyncRead + AsyncWrite + Unpin>(
mut ws_stream: WebSocketStream<S>,
peer: SocketAddr,
_config: Config,
traffic_audit: TrafficAuditPtr,
client_id: &Option<String>,
dst_addr: &SocketAddr,
dst_addr: SocketAddr,
) -> Result<()> {
let mut outgoing = crate::tcp_stream::create(dst_addr).await?;

let (ws_stream_tx, mut ws_stream_rx) = tokio::sync::mpsc::channel(1024);
let (outgoing_tx, mut outgoing_rx) = tokio::sync::mpsc::channel(1024);

let ws_stream_to_outgoing = async move {
loop {
tokio::select! {
Some(msg) = ws_stream.next() => {
let msg = msg?;
if let Some(client_id) = &client_id {
let len = (msg.len() + WS_MSG_HEADER_LEN) as u64;
traffic_audit.lock().await.add_upstream_traffic_of(client_id, len);
}
if msg.is_close() {
let mut buffer = [0; crate::STREAM_BUFFER_SIZE];
loop {
tokio::select! {
msg = ws_stream.next() => {
let msg = msg.ok_or(format!("{peer} -> {dst_addr} no Websocket message"))??;
if let Some(client_id) = &client_id {
let len = (msg.len() + WS_MSG_HEADER_LEN) as u64;
traffic_audit.lock().await.add_upstream_traffic_of(client_id, len);
}
match msg {
Message::Close(_) => {
log::trace!("{peer} <> {dst_addr} incoming connection closed normally");
break;
}
if msg.is_text() || msg.is_binary() {
outgoing_tx.send(msg.into_data()).await?;
}
}
Some(data) = ws_stream_rx.recv() => {
let msg = Message::binary(data);
if let Some(client_id) = &client_id {
let len = (msg.len() + WS_MSG_HEADER_LEN) as u64;
traffic_audit.lock().await.add_downstream_traffic_of(client_id, len);
Message::Text(_) | Message::Binary(_) => {
outgoing.write_all(&msg.into_data()).await?;
}
ws_stream.send(msg).await?;
}
else => {
break;
_ => {}
}
}
}
Ok::<_, Error>(())
};

let outgoing_to_ws_stream = async move {
loop {
tokio::select! {
Ok(data) = async {
let mut b2 = [0; crate::STREAM_BUFFER_SIZE];
let n = outgoing.read(&mut b2).await?;
Ok::<_, Error>(Some(b2[..n].to_vec()))
} => {
if let Some(data) = data {
if data.is_empty() {
break;
len = outgoing.read(&mut buffer) => {
match len {
Ok(0) => {
ws_stream.send(Message::Close(None)).await?;
log::trace!("{} <> {} outgoing connection reached EOF", peer, dst_addr);
break;
}
Ok(n) => {
let msg = Message::Binary(buffer[..n].to_vec());
if let Some(client_id) = &client_id {
let len = (msg.len() + WS_MSG_HEADER_LEN) as u64;
traffic_audit.lock().await.add_downstream_traffic_of(client_id, len);
}
ws_stream_tx.send(data).await?;
} else {
ws_stream.send(msg).await?;
}
Err(e) => {
ws_stream.send(Message::Close(None)).await?;
log::debug!("{} <> {} outgoing connection closed \"{}\"", peer, dst_addr, e);
break;
}
}
Some(msg) = outgoing_rx.recv() => {
outgoing.write_all(&msg).await?;
}
else => {
break;
}
}
}
Ok::<_, Error>(())
};

tokio::select! {
r = ws_stream_to_outgoing => { if let Err(e) = r { log::debug!("{} ws_stream_to_outgoing \"{}\"", dst_addr, e); } }
r = outgoing_to_ws_stream => { if let Err(e) = r { log::debug!("{} outgoing_to_ws_stream \"{}\"", dst_addr, e); } }
}
Ok(())
}
Expand Down Expand Up @@ -417,7 +392,21 @@ async fn create_udp_tunnel<S: AsyncRead + AsyncWrite + Unpin>(

dst_src_pairs.lock().await.insert(dst_addr.clone(), src_addr);

let dst_addr = SocketAddr::try_from(dst_addr)?;
let dst_addr = if dst_addr.port() == 53 {
match dst_addr.get_type() {
AddressType::IPv4 => {
"8.8.8.8:53".parse::<SocketAddr>()?
}
AddressType::IPv6 => {
"[2001:4860:4860::8888]:53".parse::<SocketAddr>()?
}
_ => {
return Err(Error::from("invalid address type"));
}
}
} else {
dst_addr.to_socket_addrs()?.next().ok_or("invalid address")?
};

if dst_addr.is_ipv4() {
udp_socket.send_to(&pkt, &dst_addr).await?;
Expand Down
4 changes: 2 additions & 2 deletions src/tcp_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::error::Result;
use std::net::SocketAddr;
use tokio::net::TcpStream;

pub(crate) async fn create(addr: &SocketAddr) -> Result<TcpStream> {
pub(crate) async fn create(addr: SocketAddr) -> Result<TcpStream> {
#[cfg(target_os = "android")]
{
let socket = if addr.is_ipv4() {
Expand All @@ -16,7 +16,7 @@ pub(crate) async fn create(addr: &SocketAddr) -> Result<TcpStream> {
use std::os::unix::io::AsRawFd;
crate::android::tun_callbacks::on_socket_created(socket.as_raw_fd());

Ok(socket.connect(*addr).await?)
Ok(socket.connect(addr).await?)
}

#[cfg(not(target_os = "android"))]
Expand Down
2 changes: 1 addition & 1 deletion src/tls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ pub(crate) fn retrieve_root_cert_store_for_client(cafile: &Option<PathBuf>) -> R

pub(crate) async fn create_tls_client_stream(
root_cert_store: RootCertStore,
addr: &SocketAddr,
addr: SocketAddr,
domain: &str,
) -> Result<TlsStream<TcpStream>> {
let config = rustls::ClientConfig::builder()
Expand Down
Loading

0 comments on commit bd9c228

Please sign in to comment.