From 4a143d16abea881408c7ec074a86d0765c3eef37 Mon Sep 17 00:00:00 2001 From: sinu <65924192+sinui0@users.noreply.github.com> Date: Tue, 9 Jul 2024 08:14:13 -0700 Subject: [PATCH 1/7] feat: websocket relay --- Cargo.toml | 2 +- websocket-relay/Cargo.toml | 14 +++++ websocket-relay/src/main.rs | 121 ++++++++++++++++++++++++++++++++++++ 3 files changed, 136 insertions(+), 1 deletion(-) create mode 100644 websocket-relay/Cargo.toml create mode 100644 websocket-relay/src/main.rs diff --git a/Cargo.toml b/Cargo.toml index 40f1827..c79f02a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["utils", "utils-aio", "spansy", "serio", "uid-mux"] +members = ["utils", "utils-aio", "spansy", "serio", "uid-mux", "websocket-relay"] [workspace.dependencies] tlsn-utils = { path = "utils" } diff --git a/websocket-relay/Cargo.toml b/websocket-relay/Cargo.toml new file mode 100644 index 0000000..8bb7605 --- /dev/null +++ b/websocket-relay/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "websocket-relay" +version = "0.1.0" +edition = "2021" + +[dependencies] +tokio = { version = "1", features = ["full"] } +tokio-tungstenite = { version = "0.23", features = ["url"] } +anyhow = "1" +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +form_urlencoded = "1.2" +once_cell = "1.19" +futures = "0.3" diff --git a/websocket-relay/src/main.rs b/websocket-relay/src/main.rs new file mode 100644 index 0000000..f3d9795 --- /dev/null +++ b/websocket-relay/src/main.rs @@ -0,0 +1,121 @@ +use std::{ + collections::HashMap, + env, + net::{IpAddr, SocketAddr}, + sync::{Arc, Mutex}, +}; + +use anyhow::{anyhow, Context, Result}; +use futures::StreamExt as _; +use once_cell::sync::Lazy; +use tokio::net::{TcpListener, TcpStream}; +use tokio_tungstenite::{accept_hdr_async, tungstenite::http::Request, WebSocketStream}; +use tracing::{debug, info, instrument}; +use tracing_subscriber::EnvFilter; + +fn init_global_subscriber() { + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .init(); +} + +#[derive(Debug, Default)] +struct State { + connections: HashMap>, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct ConnectionId(String); + +static STATE: Lazy>> = Lazy::new(|| Default::default()); + +#[tokio::main] +#[instrument] +async fn main() -> Result<()> { + init_global_subscriber(); + + let port: u16 = env::var("PORT") + .map(|port| port.parse().expect("port should be valid integer")) + .unwrap_or(8080); + let addr: IpAddr = env::var("ADDR") + .map(|addr| addr.parse().expect("addr should be valid IP address")) + .unwrap_or(IpAddr::V4("127.0.0.1".parse().unwrap())); + + let listener = TcpListener::bind((addr, port)) + .await + .context("failed to bind to address")?; + + info!("listening on: {}", listener.local_addr()?); + + loop { + let (socket, addr) = listener.accept().await?; + info!("accepted connection from: {}", addr); + + tokio::spawn(handle_connection(addr, socket)); + } +} + +#[instrument(skip(io), err)] +async fn handle_connection(addr: SocketAddr, io: TcpStream) -> Result<()> { + let (id, ws) = accept_ws(io).await?; + + let mut state = STATE.lock().unwrap(); + if let Some(peer_connection) = state.connections.remove(&id) { + tokio::spawn(relay(id, ws, peer_connection)); + } else { + state.connections.insert(id, ws); + } + + Ok(()) +} + +#[instrument(level = "debug", skip_all, err)] +async fn accept_ws(io: TcpStream) -> Result<(ConnectionId, WebSocketStream)> { + let mut query = None; + + let mut ws = accept_hdr_async(io, |req: &Request<()>, res| { + query = Some( + req.uri() + .query() + .map(ToString::to_string) + .unwrap_or_default(), + ); + + Ok(res) + }) + .await?; + + let query = query.ok_or_else(|| anyhow!("no query parameters provided in websocket url"))?; + let params = form_urlencoded::parse(query.as_bytes()); + + for (key, value) in params { + if key == "id" { + return Ok((ConnectionId(value.into_owned()), ws)); + } + } + + ws.close(None).await?; + + Err(anyhow!("id query parameter not provided in websocket url")) +} + +#[instrument(level = "debug", skip(left, right), err)] +async fn relay( + id: ConnectionId, + left: WebSocketStream, + right: WebSocketStream, +) -> Result<()> { + debug!("starting relay"); + + let (left_sink, left_stream) = left.split(); + let (right_sink, right_stream) = right.split(); + + tokio::try_join!( + left_stream.forward(right_sink), + right_stream.forward(left_sink), + )?; + + debug!("connection closed cleanly: {:?}", id); + + Ok(()) +} From e92544fd550082e2293e205956b6a8be7822686e Mon Sep 17 00:00:00 2001 From: sinui0 <65924192+sinui0@users.noreply.github.com> Date: Thu, 18 Jul 2024 15:44:18 +0900 Subject: [PATCH 2/7] support both TCP proxy and ws clients --- websocket-relay/src/lib.rs | 194 ++++++++++++++++++++++++++++++++++++ websocket-relay/src/main.rs | 114 ++------------------- 2 files changed, 205 insertions(+), 103 deletions(-) create mode 100644 websocket-relay/src/lib.rs diff --git a/websocket-relay/src/lib.rs b/websocket-relay/src/lib.rs new file mode 100644 index 0000000..8a5016c --- /dev/null +++ b/websocket-relay/src/lib.rs @@ -0,0 +1,194 @@ +use std::{ + collections::HashMap, + net::SocketAddr, + sync::{Arc, Mutex}, +}; + +use anyhow::{anyhow, Result}; +use futures::{SinkExt, StreamExt as _}; +use once_cell::sync::Lazy; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + net::{TcpListener, TcpStream}, +}; +use tokio_tungstenite::{ + accept_hdr_async, + tungstenite::{http::Request, Message}, + WebSocketStream, +}; +use tracing::{debug, info, instrument}; + +#[derive(Debug, Default)] +struct State { + waiting: HashMap>, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +struct ConnectionId(String); + +static STATE: Lazy>> = Lazy::new(|| Default::default()); + +enum Mode { + /// Acts a proxy between two websocket clients. + Ws { + id: ConnectionId, + ws: WebSocketStream, + }, + /// Acts as a proxy between a websocket client and a TCP server. + Tcp { + addr: String, + ws: WebSocketStream, + }, +} + +/// Runs the websocket relay server with the given TCP listener. +#[instrument] +pub async fn run(listener: TcpListener) -> Result<()> { + loop { + let (socket, addr) = listener.accept().await?; + info!("accepted connection from: {}", addr); + + tokio::spawn(handle_connection(addr, socket)); + } +} + +#[instrument(skip(io), err)] +async fn handle_connection(addr: SocketAddr, io: TcpStream) -> Result<()> { + match accept_ws(io).await? { + Mode::Ws { id, ws } => { + tokio::spawn(handle_ws(id, ws)); + } + Mode::Tcp { addr, ws } => { + tokio::spawn(handle_tcp(addr, ws)); + } + } + + Ok(()) +} + +#[instrument(level = "debug", skip_all, err)] +async fn accept_ws(io: TcpStream) -> Result { + let mut uri = None; + + let mut ws = accept_hdr_async(io, |req: &Request<()>, res| { + uri = Some(req.uri().clone()); + + Ok(res) + }) + .await?; + + let uri = uri.expect("uri should be set"); + let query = uri + .query() + .ok_or_else(|| anyhow!("query string not provided"))?; + let mut params = form_urlencoded::parse(query.as_bytes()) + .map(|(k, v)| (k.into_owned(), v.into_owned())) + .collect::>(); + + match uri.path() { + "/tcp" => { + let addr = params + .remove("addr") + .ok_or_else(|| anyhow!("addr query parameter not provided"))?; + + return Ok(Mode::Tcp { addr, ws }); + } + "/ws" => { + let id = params + .remove("id") + .ok_or_else(|| anyhow!("id query parameter not provided"))?; + + return Ok(Mode::Ws { + id: ConnectionId(id), + ws, + }); + } + _ => { + ws.close(None).await?; + + return Err(anyhow!("invalid path: {:?}", uri.path())); + } + } +} + +/// Relays messages between two websocket clients. +#[instrument(level = "debug", skip(ws), err)] +async fn handle_ws(id: ConnectionId, ws: WebSocketStream) -> Result<()> { + let peer = { + let mut state = STATE.lock().unwrap(); + if let Some(peer) = state.waiting.remove(&id) { + peer + } else { + state.waiting.insert(id.clone(), ws); + + debug!("connection waiting"); + + return Ok(()); + } + }; + + debug!("started"); + + let (left_sink, left_stream) = ws.split(); + let (right_sink, right_stream) = peer.split(); + + tokio::try_join!( + left_stream.forward(right_sink), + right_stream.forward(left_sink), + )?; + + debug!("connection closed cleanly"); + + Ok(()) +} + +/// Relays data between a websocket client and a TCP server. +#[instrument(level = "debug", skip(ws), err)] +async fn handle_tcp(addr: String, ws: WebSocketStream) -> Result<()> { + let mut tcp = TcpStream::connect(addr).await?; + + let (mut sink, mut stream) = ws.split(); + let (mut rx, mut tx) = tcp.split(); + + let fut_tx = async { + while let Some(msg) = stream.next().await.transpose()? { + let data = match msg { + Message::Binary(data) => data, + Message::Close(_) => { + break; + } + _ => { + return Err(anyhow!("websocket client sent non-binary message")); + } + }; + + tx.write_all(&data).await?; + } + + debug!("websocket client closed"); + + tx.shutdown().await?; + + Ok(()) + }; + + let fut_rx = async { + // 16KB buffer + let mut buf = [0; 16 * 1024]; + loop { + let n = rx.read(&mut buf).await?; + + if n == 0 { + debug!("tcp server closed"); + sink.close().await?; + return Ok(()); + } + + sink.send(Message::Binary(buf[..n].to_vec())).await?; + } + }; + + tokio::try_join!(fut_tx, fut_rx)?; + + Ok(()) +} diff --git a/websocket-relay/src/main.rs b/websocket-relay/src/main.rs index f3d9795..bd6a58f 100644 --- a/websocket-relay/src/main.rs +++ b/websocket-relay/src/main.rs @@ -1,44 +1,20 @@ -use std::{ - collections::HashMap, - env, - net::{IpAddr, SocketAddr}, - sync::{Arc, Mutex}, -}; +use std::{env, net::IpAddr}; -use anyhow::{anyhow, Context, Result}; -use futures::StreamExt as _; -use once_cell::sync::Lazy; -use tokio::net::{TcpListener, TcpStream}; -use tokio_tungstenite::{accept_hdr_async, tungstenite::http::Request, WebSocketStream}; -use tracing::{debug, info, instrument}; -use tracing_subscriber::EnvFilter; - -fn init_global_subscriber() { - tracing_subscriber::fmt() - .with_env_filter(EnvFilter::from_default_env()) - .init(); -} - -#[derive(Debug, Default)] -struct State { - connections: HashMap>, -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -struct ConnectionId(String); - -static STATE: Lazy>> = Lazy::new(|| Default::default()); +use anyhow::{Context, Result}; +use tokio::net::TcpListener; +use tracing::info; #[tokio::main] -#[instrument] async fn main() -> Result<()> { - init_global_subscriber(); + tracing_subscriber::fmt() + .with_env_filter(tracing_subscriber::EnvFilter::from_default_env()) + .init(); - let port: u16 = env::var("PORT") + let port: u16 = env::var("PROXY_PORT") .map(|port| port.parse().expect("port should be valid integer")) .unwrap_or(8080); - let addr: IpAddr = env::var("ADDR") - .map(|addr| addr.parse().expect("addr should be valid IP address")) + let addr: IpAddr = env::var("PROXY_IP") + .map(|addr| addr.parse().expect("should be valid IP address")) .unwrap_or(IpAddr::V4("127.0.0.1".parse().unwrap())); let listener = TcpListener::bind((addr, port)) @@ -47,75 +23,7 @@ async fn main() -> Result<()> { info!("listening on: {}", listener.local_addr()?); - loop { - let (socket, addr) = listener.accept().await?; - info!("accepted connection from: {}", addr); - - tokio::spawn(handle_connection(addr, socket)); - } -} - -#[instrument(skip(io), err)] -async fn handle_connection(addr: SocketAddr, io: TcpStream) -> Result<()> { - let (id, ws) = accept_ws(io).await?; - - let mut state = STATE.lock().unwrap(); - if let Some(peer_connection) = state.connections.remove(&id) { - tokio::spawn(relay(id, ws, peer_connection)); - } else { - state.connections.insert(id, ws); - } - - Ok(()) -} - -#[instrument(level = "debug", skip_all, err)] -async fn accept_ws(io: TcpStream) -> Result<(ConnectionId, WebSocketStream)> { - let mut query = None; - - let mut ws = accept_hdr_async(io, |req: &Request<()>, res| { - query = Some( - req.uri() - .query() - .map(ToString::to_string) - .unwrap_or_default(), - ); - - Ok(res) - }) - .await?; - - let query = query.ok_or_else(|| anyhow!("no query parameters provided in websocket url"))?; - let params = form_urlencoded::parse(query.as_bytes()); - - for (key, value) in params { - if key == "id" { - return Ok((ConnectionId(value.into_owned()), ws)); - } - } - - ws.close(None).await?; - - Err(anyhow!("id query parameter not provided in websocket url")) -} - -#[instrument(level = "debug", skip(left, right), err)] -async fn relay( - id: ConnectionId, - left: WebSocketStream, - right: WebSocketStream, -) -> Result<()> { - debug!("starting relay"); - - let (left_sink, left_stream) = left.split(); - let (right_sink, right_stream) = right.split(); - - tokio::try_join!( - left_stream.forward(right_sink), - right_stream.forward(left_sink), - )?; - - debug!("connection closed cleanly: {:?}", id); + websocket_relay::run(listener).await?; Ok(()) } From 73a6be13081b5bc1a4224089a766e6a097d3869a Mon Sep 17 00:00:00 2001 From: sinui0 <65924192+sinui0@users.noreply.github.com> Date: Thu, 18 Jul 2024 16:52:05 +0900 Subject: [PATCH 3/7] clippy --- websocket-relay/src/lib.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/websocket-relay/src/lib.rs b/websocket-relay/src/lib.rs index 8a5016c..882d14d 100644 --- a/websocket-relay/src/lib.rs +++ b/websocket-relay/src/lib.rs @@ -26,7 +26,7 @@ struct State { #[derive(Debug, Clone, PartialEq, Eq, Hash)] struct ConnectionId(String); -static STATE: Lazy>> = Lazy::new(|| Default::default()); +static STATE: Lazy>> = Lazy::new(Default::default); enum Mode { /// Acts a proxy between two websocket clients. From 76c9de0587f8063138095f12802101e58d8bb7ed Mon Sep 17 00:00:00 2001 From: dan Date: Tue, 10 Sep 2024 06:41:37 +0000 Subject: [PATCH 4/7] fix: prevent sending after ws client closed (#38) --- websocket-relay/src/lib.rs | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/websocket-relay/src/lib.rs b/websocket-relay/src/lib.rs index 882d14d..fc4de61 100644 --- a/websocket-relay/src/lib.rs +++ b/websocket-relay/src/lib.rs @@ -1,7 +1,10 @@ use std::{ collections::HashMap, net::SocketAddr, - sync::{Arc, Mutex}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, Mutex, + }, }; use anyhow::{anyhow, Result}; @@ -150,6 +153,8 @@ async fn handle_tcp(addr: String, ws: WebSocketStream) -> Result<()> let (mut sink, mut stream) = ws.split(); let (mut rx, mut tx) = tcp.split(); + let is_client_closed = AtomicBool::new(false); + let fut_tx = async { while let Some(msg) = stream.next().await.transpose()? { let data = match msg { @@ -166,6 +171,7 @@ async fn handle_tcp(addr: String, ws: WebSocketStream) -> Result<()> } debug!("websocket client closed"); + is_client_closed.store(true, Ordering::Relaxed); tx.shutdown().await?; @@ -184,7 +190,10 @@ async fn handle_tcp(addr: String, ws: WebSocketStream) -> Result<()> return Ok(()); } - sink.send(Message::Binary(buf[..n].to_vec())).await?; + // Only send to client if it hasn't closed. + if !is_client_closed.load(Ordering::Relaxed) { + sink.send(Message::Binary(buf[..n].to_vec())).await?; + } } }; From 2a2c75e5cafaab696e98ae8cbaf222dea2e81943 Mon Sep 17 00:00:00 2001 From: dan Date: Fri, 18 Oct 2024 11:42:38 +0200 Subject: [PATCH 5/7] clean up cargo.toml --- Cargo.toml | 44 ++++++++++++++++++++++---------------- websocket-relay/Cargo.toml | 14 +++++++----- 2 files changed, 35 insertions(+), 23 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c79f02a..463d239 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,34 +1,42 @@ [workspace] -members = ["utils", "utils-aio", "spansy", "serio", "uid-mux", "websocket-relay"] +members = [ + "serio", + "spansy", + "uid-mux", + "utils", + "utils-aio", + "utils/fuzz", + "websocket-relay" +] [workspace.dependencies] +serio = { path = "serio" } +spansy = { path = "spansy" } tlsn-utils = { path = "utils" } tlsn-utils-aio = { path = "utils-aio" } -spansy = { path = "spansy" } -serio = { path = "serio" } uid-mux = { path = "uid-mux" } -rand = "0.8" -thiserror = "1" +async-std = "1" async-trait = "0.1" -prost = "0.9" +async-tungstenite = "0.16" +bincode = "1.3" +bytes = "1" +cfg-if = "1" futures = "0.3" -futures-sink = "0.3" +futures-channel = "0.3" futures-core = "0.3" futures-io = "0.3" -futures-channel = "0.3" +futures-sink = "0.3" futures-util = "0.3" -tokio-util = "0.7" -tokio-serde = "0.8" -tokio = "1.23" -async-tungstenite = "0.16" +pin-project-lite = "0.2" +prost = "0.9" prost-build = "0.9" -bytes = "1" -async-std = "1" +rand = "0.8" rayon = "1" serde = "1" -cfg-if = "1" -bincode = "1.3" -pin-project-lite = "0.2" +thiserror = "1" +tokio = "1.23" +tokio-serde = "0.8" +tokio-util = "0.7" tracing = "0.1" -tracing-subscriber = "0.3" +tracing-subscriber = "0.3" \ No newline at end of file diff --git a/websocket-relay/Cargo.toml b/websocket-relay/Cargo.toml index 8bb7605..db080ba 100644 --- a/websocket-relay/Cargo.toml +++ b/websocket-relay/Cargo.toml @@ -2,13 +2,17 @@ name = "websocket-relay" version = "0.1.0" edition = "2021" +authors = ["TLSNotary Contributors"] +license = "MIT OR Apache-2.0" +repository = "https://github.com/tlsnotary/tlsn-utils" +description = """A relay for websocket clients.""" [dependencies] -tokio = { version = "1", features = ["full"] } -tokio-tungstenite = { version = "0.23", features = ["url"] } anyhow = "1" -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } form_urlencoded = "1.2" +futures = { workspace = true } once_cell = "1.19" -futures = "0.3" +tokio = { workspace = true, features = ["full"] } +tokio-tungstenite = { version = "0.23", features = ["url"] } +tracing = { workspace = true } +tracing-subscriber = { workspace = true, features = ["env-filter"] } From 4805164d7d438e585219c6c88ddc198c5e8e748f Mon Sep 17 00:00:00 2001 From: dan Date: Fri, 18 Oct 2024 12:01:25 +0200 Subject: [PATCH 6/7] fix multiple workspace error --- utils/fuzz/Cargo.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/utils/fuzz/Cargo.toml b/utils/fuzz/Cargo.toml index e58b4ef..3dec072 100644 --- a/utils/fuzz/Cargo.toml +++ b/utils/fuzz/Cargo.toml @@ -13,10 +13,6 @@ libfuzzer-sys = { version = "0.4", features = ["arbitrary-derive"] } [dependencies.tlsn-utils] path = ".." -# Prevent this from interfering with workspaces -[workspace] -members = ["."] - [profile.release] debug = 1 From 522b735352ca8ae26d22e9712c138965812fe198 Mon Sep 17 00:00:00 2001 From: dan Date: Fri, 18 Oct 2024 12:04:17 +0200 Subject: [PATCH 7/7] fix lints --- utils-aio/src/expect_msg.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/utils-aio/src/expect_msg.rs b/utils-aio/src/expect_msg.rs index 208c7f1..63fc4f9 100644 --- a/utils-aio/src/expect_msg.rs +++ b/utils-aio/src/expect_msg.rs @@ -27,6 +27,7 @@ mod tests { use futures_util::StreamExt; #[derive(Debug)] + #[allow(dead_code)] enum Msg { Foo(u8), Bar(u8),