Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/ws relay #43

Merged
merged 8 commits into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 26 additions & 18 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,34 +1,42 @@
[workspace]
members = ["utils", "utils-aio", "spansy", "serio", "uid-mux", "utils/fuzz"]
members = [
"serio",
"spansy",
"uid-mux",
"utils",
"utils-aio",
"utils/fuzz",
"websocket-relay"
]

[workspace.dependencies]
serio = { path = "serio" }
spansy = { path = "spansy" }
tlsn-utils = { path = "utils" }
tlsn-utils-aio = { path = "utils-aio" }
spansy = { path = "spansy" }
serio = { path = "serio" }
uid-mux = { path = "uid-mux" }

rand = "0.8"
thiserror = "1"
async-std = "1"
async-trait = "0.1"
prost = "0.9"
async-tungstenite = "0.16"
bincode = "1.3"
bytes = "1"
cfg-if = "1"
futures = "0.3"
futures-sink = "0.3"
futures-channel = "0.3"
futures-core = "0.3"
futures-io = "0.3"
futures-channel = "0.3"
futures-sink = "0.3"
futures-util = "0.3"
tokio-util = "0.7"
tokio-serde = "0.8"
tokio = "1.23"
async-tungstenite = "0.16"
pin-project-lite = "0.2"
prost = "0.9"
prost-build = "0.9"
bytes = "1"
async-std = "1"
rand = "0.8"
rayon = "1"
serde = "1"
cfg-if = "1"
bincode = "1.3"
pin-project-lite = "0.2"
thiserror = "1"
tokio = "1.23"
tokio-serde = "0.8"
tokio-util = "0.7"
tracing = "0.1"
tracing-subscriber = "0.3"
tracing-subscriber = "0.3"
1 change: 1 addition & 0 deletions utils-aio/src/expect_msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ mod tests {
use futures_util::StreamExt;

#[derive(Debug)]
#[allow(dead_code)]
enum Msg {
Foo(u8),
Bar(u8),
Expand Down
18 changes: 18 additions & 0 deletions websocket-relay/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
[package]
name = "websocket-relay"
version = "0.1.0"
edition = "2021"
authors = ["TLSNotary Contributors"]
license = "MIT OR Apache-2.0"
repository = "https://github.com/tlsnotary/tlsn-utils"
description = """A relay for websocket clients."""

[dependencies]
anyhow = "1"
form_urlencoded = "1.2"
futures = { workspace = true }
once_cell = "1.19"
tokio = { workspace = true, features = ["full"] }
tokio-tungstenite = { version = "0.23", features = ["url"] }
tracing = { workspace = true }
tracing-subscriber = { workspace = true, features = ["env-filter"] }
203 changes: 203 additions & 0 deletions websocket-relay/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
use std::{
collections::HashMap,
net::SocketAddr,
sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex,
},
};

use anyhow::{anyhow, Result};
use futures::{SinkExt, StreamExt as _};
use once_cell::sync::Lazy;
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
};
use tokio_tungstenite::{
accept_hdr_async,
tungstenite::{http::Request, Message},
WebSocketStream,
};
use tracing::{debug, info, instrument};

#[derive(Debug, Default)]
struct State {
waiting: HashMap<ConnectionId, WebSocketStream<TcpStream>>,
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct ConnectionId(String);

static STATE: Lazy<Arc<Mutex<State>>> = Lazy::new(Default::default);

enum Mode {
/// Acts a proxy between two websocket clients.
Ws {
id: ConnectionId,
ws: WebSocketStream<TcpStream>,
},
/// Acts as a proxy between a websocket client and a TCP server.
Tcp {
addr: String,
ws: WebSocketStream<TcpStream>,
},
}

/// Runs the websocket relay server with the given TCP listener.
#[instrument]
pub async fn run(listener: TcpListener) -> Result<()> {
loop {
let (socket, addr) = listener.accept().await?;
info!("accepted connection from: {}", addr);

tokio::spawn(handle_connection(addr, socket));
}
}

#[instrument(skip(io), err)]
async fn handle_connection(addr: SocketAddr, io: TcpStream) -> Result<()> {
match accept_ws(io).await? {
Mode::Ws { id, ws } => {
tokio::spawn(handle_ws(id, ws));
}
Mode::Tcp { addr, ws } => {
tokio::spawn(handle_tcp(addr, ws));
}
}

Ok(())
}

#[instrument(level = "debug", skip_all, err)]
async fn accept_ws(io: TcpStream) -> Result<Mode> {
let mut uri = None;

let mut ws = accept_hdr_async(io, |req: &Request<()>, res| {
uri = Some(req.uri().clone());

Ok(res)
})
.await?;

let uri = uri.expect("uri should be set");
let query = uri
.query()
.ok_or_else(|| anyhow!("query string not provided"))?;
let mut params = form_urlencoded::parse(query.as_bytes())
.map(|(k, v)| (k.into_owned(), v.into_owned()))
.collect::<HashMap<String, String>>();

match uri.path() {
"/tcp" => {
let addr = params
.remove("addr")
.ok_or_else(|| anyhow!("addr query parameter not provided"))?;

return Ok(Mode::Tcp { addr, ws });
}
"/ws" => {
let id = params
.remove("id")
.ok_or_else(|| anyhow!("id query parameter not provided"))?;

return Ok(Mode::Ws {
id: ConnectionId(id),
ws,
});
}
_ => {
ws.close(None).await?;

return Err(anyhow!("invalid path: {:?}", uri.path()));
}
}
}

/// Relays messages between two websocket clients.
#[instrument(level = "debug", skip(ws), err)]
async fn handle_ws(id: ConnectionId, ws: WebSocketStream<TcpStream>) -> Result<()> {
let peer = {
let mut state = STATE.lock().unwrap();
if let Some(peer) = state.waiting.remove(&id) {
peer
} else {
state.waiting.insert(id.clone(), ws);

debug!("connection waiting");

return Ok(());
}
};

debug!("started");

let (left_sink, left_stream) = ws.split();
let (right_sink, right_stream) = peer.split();

tokio::try_join!(
left_stream.forward(right_sink),
right_stream.forward(left_sink),
)?;

debug!("connection closed cleanly");

Ok(())
}

/// Relays data between a websocket client and a TCP server.
#[instrument(level = "debug", skip(ws), err)]
async fn handle_tcp(addr: String, ws: WebSocketStream<TcpStream>) -> Result<()> {
let mut tcp = TcpStream::connect(addr).await?;

let (mut sink, mut stream) = ws.split();
let (mut rx, mut tx) = tcp.split();

let is_client_closed = AtomicBool::new(false);

let fut_tx = async {
while let Some(msg) = stream.next().await.transpose()? {
let data = match msg {
Message::Binary(data) => data,
Message::Close(_) => {
break;
}
_ => {
return Err(anyhow!("websocket client sent non-binary message"));
}
};

tx.write_all(&data).await?;
}

debug!("websocket client closed");
is_client_closed.store(true, Ordering::Relaxed);

tx.shutdown().await?;

Ok(())
};

let fut_rx = async {
// 16KB buffer
let mut buf = [0; 16 * 1024];
loop {
let n = rx.read(&mut buf).await?;

if n == 0 {
debug!("tcp server closed");
sink.close().await?;
return Ok(());
}

// Only send to client if it hasn't closed.
if !is_client_closed.load(Ordering::Relaxed) {
sink.send(Message::Binary(buf[..n].to_vec())).await?;
}
}
};

tokio::try_join!(fut_tx, fut_rx)?;

Ok(())
}
29 changes: 29 additions & 0 deletions websocket-relay/src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use std::{env, net::IpAddr};

use anyhow::{Context, Result};
use tokio::net::TcpListener;
use tracing::info;

#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt()
.with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
.init();

let port: u16 = env::var("PROXY_PORT")
.map(|port| port.parse().expect("port should be valid integer"))
.unwrap_or(8080);
let addr: IpAddr = env::var("PROXY_IP")
.map(|addr| addr.parse().expect("should be valid IP address"))
.unwrap_or(IpAddr::V4("127.0.0.1".parse().unwrap()));

let listener = TcpListener::bind((addr, port))
.await
.context("failed to bind to address")?;

info!("listening on: {}", listener.local_addr()?);

websocket_relay::run(listener).await?;

Ok(())
}
Loading