From 40aef96012b0ed5761738dfdd0ebc73b2edd85b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janosch=20Gr=C3=A4f?= Date: Tue, 9 Jul 2024 06:26:37 +0200 Subject: [PATCH] make WebSocket Send + Sync --- .gitignore | 1 + Cargo.toml | 8 +- examples/wasm/src/app.rs | 2 + src/lib.rs | 40 +++- src/wasm.rs | 413 +++++++++++++++++++++++---------------- 5 files changed, 283 insertions(+), 181 deletions(-) diff --git a/.gitignore b/.gitignore index 60d3e8a..1ed5d60 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ Thumbs.db # Rust /Cargo.lock /target +/.cargo # wasm example /examples/wasm/dist diff --git a/Cargo.toml b/Cargo.toml index d3c652c..2fc5656 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,8 @@ json = ["dep:serde", "dep:serde_json"] [dependencies] # pin version, see https://github.com/jgraef/reqwest-websocket/pull/33 -futures-util = { version = ">=0.3.31", default-features = false, features = ["sink"] } +futures-util = { version = ">=0.3.31", default-features = false, features = ["sink", "async-await-macro"] } +futures-channel = { version = "0.3", default-features = false, features = ["sink", "std"] } reqwest = { version = "0.12", default-features = false } thiserror = "2" tracing = "0.1" @@ -38,14 +39,13 @@ tungstenite = { version = "0.24", default-features = false, features = ["handsha [target.'cfg(target_arch = "wasm32")'.dependencies] web-sys = { version = "0.3", features = ["WebSocket", "CloseEvent", "ErrorEvent", "Event", "MessageEvent", "BinaryType"] } -tokio = { version = "1", default-features = false, features = ["sync", "macros"] } +wasm-bindgen-futures = "0.4" [dev-dependencies] tokio = { version = "1", features = ["macros", "rt"] } reqwest = { version = "0.12", features = ["default-tls"] } serde = { version = "1.0", features = ["derive"] } -futures-util = { version = "0.3", default-features = false, features = ["sink", "alloc"] } +futures-util = "0.3" [target.'cfg(target_arch = "wasm32")'.dev-dependencies] wasm-bindgen-test = "0.3" -wasm-bindgen-futures = "0.4" diff --git a/examples/wasm/src/app.rs b/examples/wasm/src/app.rs index 1165b71..0984a66 100644 --- a/examples/wasm/src/app.rs +++ b/examples/wasm/src/app.rs @@ -17,6 +17,8 @@ pub fn App() -> impl IntoView { spawn_local(async move { let websocket = reqwest_websocket::websocket("https://echo.websocket.org/").await.unwrap(); + tracing::info!("WebSocket connected"); + let (mut sender, mut receiver) = websocket.split(); futures::join!( diff --git a/src/lib.rs b/src/lib.rs index a80b77d..cb16526 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -85,7 +85,7 @@ pub enum Error { #[cfg(target_arch = "wasm32")] #[cfg_attr(docsrs, doc(cfg(target_arch = "wasm32")))] #[error("web_sys error")] - WebSys(#[from] wasm::WebSysError), + WebSys(#[from] wasm::Error), /// Error during serialization/deserialization. #[error("serde_json error")] @@ -178,7 +178,7 @@ impl UpgradedRequestBuilder { let inner = native::send_request(self.inner, &self.protocols).await?; #[cfg(target_arch = "wasm32")] - let inner = wasm::WebSysWebSocketStream::new(self.inner.build()?, &self.protocols).await?; + let inner = wasm::WebSocket::new(self.inner.build()?, &self.protocols).await?; Ok(UpgradeResponse { inner, @@ -198,7 +198,7 @@ pub struct UpgradeResponse { inner: native::WebSocketResponse, #[cfg(target_arch = "wasm32")] - inner: wasm::WebSysWebSocketStream, + inner: wasm::WebSocket, #[allow(dead_code)] protocols: Vec, @@ -229,7 +229,7 @@ impl UpgradeResponse { #[cfg(target_arch = "wasm32")] let (inner, protocol) = { - let protocol = self.inner.protocol(); + let protocol = self.inner.protocol().to_owned(); (self.inner, Some(protocol)) }; @@ -252,7 +252,7 @@ pub struct WebSocket { inner: native::WebSocketStream, #[cfg(target_arch = "wasm32")] - inner: wasm::WebSysWebSocketStream, + inner: wasm::WebSocket, protocol: Option, } @@ -283,7 +283,15 @@ impl WebSocket { } #[cfg(target_arch = "wasm32")] - self.inner.close(code.into(), reason.unwrap_or_default())?; + { + let mut inner = self.inner; + inner + .send(Message::Close { + code, + reason: reason.unwrap_or_default().to_owned(), + }) + .await?; + } Ok(()) } @@ -344,8 +352,22 @@ pub mod tests { #[cfg(target_arch = "wasm32")] wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser); + use crate::{UpgradeResponse, UpgradedRequestBuilder}; + use super::{websocket, CloseCode, Message, RequestBuilderExt, WebSocket}; + macro_rules! assert_send { + ($ty:ty) => { + const _: () = { + struct Assert(std::marker::PhantomData); + Assert::<$ty>(std::marker::PhantomData); + }; + }; + } + + // unfortunately hyper IO is not sync + assert_send!(WebSocket); + async fn test_websocket(mut websocket: WebSocket) { let text = "Hello, World!"; websocket @@ -467,4 +489,10 @@ pub mod tests { assert_eq!(byte, 1001u16); assert_eq!(u16::from(text), 1001u16); } + + // assert that our types are Send + Sync + trait AssertSendSync: Send + Sync {} + impl AssertSendSync for UpgradedRequestBuilder {} + impl AssertSendSync for UpgradeResponse {} + impl AssertSendSync for WebSocket {} } diff --git a/src/wasm.rs b/src/wasm.rs index adabb49..1138b4a 100644 --- a/src/wasm.rs +++ b/src/wasm.rs @@ -3,251 +3,322 @@ use std::{ task::{Context, Poll}, }; -use futures_util::{Sink, Stream}; +use futures_channel::{mpsc, oneshot}; +use futures_util::{ready, select_biased, FutureExt, Sink, SinkExt, Stream, StreamExt}; use reqwest::{Request, Url}; -use tokio::sync::{mpsc, oneshot}; +use tracing::Instrument; use web_sys::{ js_sys::{Array, ArrayBuffer, JsString, Uint8Array}, wasm_bindgen::{closure::Closure, JsCast, JsValue}, CloseEvent, ErrorEvent, Event, MessageEvent, }; -use crate::protocol::{CloseCode, Message}; +use crate::protocol::Message; #[derive(Debug, thiserror::Error)] -pub enum WebSysError { - #[error("invalid url: {0}")] +#[non_exhaustive] +pub enum Error { + #[error("Invalid URL: {0}")] InvalidUrl(Url), - #[error("connection failed")] + #[error("Connection failed")] ConnectionFailed, #[error("{0}")] ErrorEvent(String), - #[error("unknown error")] + #[error("Unknown error")] Unknown, + + #[error("Send error")] + SendError, } -impl From for WebSysError { +impl From for Error { fn from(event: ErrorEvent) -> Self { Self::ErrorEvent(event.message()) } } -impl From for WebSysError { +impl From for Error { fn from(_value: JsValue) -> Self { Self::Unknown } } -#[derive(Debug)] -pub struct WebSysWebSocketStream { - inner: web_sys::WebSocket, - - rx: mpsc::UnboundedReceiver>>, - - #[allow(dead_code)] - on_message_callback: Closure, - - #[allow(dead_code)] - on_error_callback: Closure, +struct Outgoing { + message: Message, + ack_tx: oneshot::Sender>, +} - #[allow(dead_code)] - on_close_callback: Closure, +#[derive(Debug)] +pub struct WebSocket { + outgoing_tx: mpsc::Sender, + incoming_rx: mpsc::UnboundedReceiver>, + ack_rx: Option>>, + protocol: String, } -impl WebSysWebSocketStream { - pub async fn new(request: Request, protocols: &Vec) -> Result { +impl WebSocket { + pub async fn new(request: Request, protocols: &Vec) -> Result { + // get websocket URL from request. + // this contains query parameters. everything else is ignored, as web_sys only accepts an URL. let mut url = request.url().clone(); let scheme = match url.scheme() { "http" | "ws" => "ws", "https" | "wss" => "wss", - _ => return Err(WebSysError::InvalidUrl(url)), + _ => return Err(Error::InvalidUrl(url)), }; if let Err(_) = url.set_scheme(scheme) { - return Err(WebSysError::InvalidUrl(url)); + return Err(Error::InvalidUrl(url)); } - // the channel for messages and errors - let (tx, rx) = mpsc::unbounded_channel(); - - // channel to signal when the websocket has been opened - let (open_success_tx, open_success_rx) = oneshot::channel(); - let mut open_success_tx = Some(open_success_tx); - - // channel to signal an error while opening the channel - let (open_error_tx, open_error_rx) = oneshot::channel(); - let mut open_error_tx = Some(open_error_tx); - - // create websocket - let inner = web_sys::WebSocket::new_with_str_sequence( + // create the websocket + let websocket = web_sys::WebSocket::new_with_str_sequence( &url.to_string(), &protocols .into_iter() .map(|s| JsString::from(s.to_owned())) .collect::(), ) - .map_err(|_| WebSysError::ConnectionFailed)?; - - inner.set_binary_type(web_sys::BinaryType::Arraybuffer); - - // register message handler - let on_message_callback = { - let tx = tx.clone(); - Closure::::new(move |event: MessageEvent| { - tracing::debug!(event = ?event.data(), "message event"); - - if let Ok(abuf) = event.data().dyn_into::() { - let array = Uint8Array::new(&abuf); - let data = array.to_vec(); - let _ = tx.send(Some(Ok(Message::Binary(data)))); - } else if let Ok(text) = event.data().dyn_into::() { - let _ = tx.send(Some(Ok(Message::Text(text.into())))); - } else { - tracing::debug!(event = ?event.data(), "received unknown message event"); - } - }) - }; - inner.set_onmessage(Some(on_message_callback.as_ref().unchecked_ref())); - - // register error handler - // this will try to put the first error into a oneshot channel for errors that - // happen during opening. once that has been used, or the oneshot - // channel is dropped, this uses the regular message channel - let on_error_callback = { - let tx = tx.clone(); - Closure::::new(move |event: Event| { - let error = match event.dyn_into::() { - Ok(error) => WebSysError::from(error), - Err(_event) => WebSysError::Unknown, - }; - tracing::debug!("received error event: {error}"); - - let error = if let Some(open_error_tx) = open_error_tx.take() { - match open_error_tx.send(error) { - Ok(()) => return, - Err(error) => error, - } - } else { - error - }; + .map_err(|_| Error::ConnectionFailed)?; + websocket.set_binary_type(web_sys::BinaryType::Arraybuffer); - let _ = tx.send(Some(Err(error))); - }) - }; - inner.set_onerror(Some(on_error_callback.as_ref().unchecked_ref())); + // outgoing channel. only needs a capacity of 1, as we wait for acks anyway + let (outgoing_tx, outgoing_rx) = mpsc::channel(1); - // register close handler - let on_close_callback = { - let tx = tx.clone(); - Closure::::new(move |event: CloseEvent| { - tracing::debug!("received close event"); + // note: this needs to be unbounded, because we can't block in the event handlers + let (incoming_tx, incoming_rx) = mpsc::unbounded(); - let _ = tx.send(Some(Ok(Message::Close { - code: event.code().into(), - reason: event.reason(), - }))); - let _ = tx.send(None); - }) - }; - inner.set_onclose(Some(on_close_callback.as_ref().unchecked_ref())); + // channel for connect acks. message type: `Result`, where `String` is the protocol reported by the websocket + let (connect_ack_tx, connect_ack_rx) = oneshot::channel(); - // register open handler - let on_open_callback = Closure::::new(move |_event: Event| { - tracing::debug!("received open event"); - if let Some(tx) = open_success_tx.take() { - let _ = tx.send(()); - } - }); - inner.set_onopen(Some(on_open_callback.as_ref().unchecked_ref())); - - // wait for either the open event or an error - tokio::select! { - Ok(()) = open_success_rx => {}, - Ok(error) = open_error_rx => { - // cleanup - let _result = inner.close(); - inner.set_onopen(None); - inner.set_onmessage(None); - inner.set_onclose(None); - inner.set_onerror(None); - return Err(error); - }, - else => { - tracing::warn!("open sender dropped"); - } - }; + // spawn a task for the websocket locally. this way our `WebSocket` struct is `Send + Sync`, while the code that has the + // `web_sys::Websocket` (which is not `Send + Sync`) stays on the same thread. + tracing::debug!("spawning websocket task"); + let task_span = tracing::info_span!("websocket"); + wasm_bindgen_futures::spawn_local( + run_websocket(websocket, connect_ack_tx, outgoing_rx, incoming_tx).instrument(task_span), + ); - // remove open handler - inner.set_onopen(None); + // wait for connection ack, or error + tracing::debug!("waiting for ack"); + let protocol = connect_ack_rx + .await + .expect("websocket handler dropped ack sender")?; + tracing::debug!("ack received"); Ok(Self { - inner, - on_message_callback, - on_error_callback, - on_close_callback, - rx, + outgoing_tx, + incoming_rx, + ack_rx: None, + protocol, }) } - pub fn protocol(&self) -> String { - self.inner.protocol() + fn poll_ack(&mut self, cx: &mut Context) -> Poll> { + if let Some(ack_rx) = &mut self.ack_rx { + let result = ready!(ack_rx.poll_unpin(cx)).unwrap_or(Ok(())); + self.ack_rx = None; + Poll::Ready(result) + } else { + Poll::Ready(Ok(())) + } } - pub fn close(self, code: CloseCode, reason: &str) -> Result<(), WebSysError> { - self.inner.close_with_code_and_reason(code.into(), reason)?; - Ok(()) + pub fn protocol(&self) -> &str { + &self.protocol } } -impl Drop for WebSysWebSocketStream { - fn drop(&mut self) { - tracing::debug!("websocket stream dropped"); - let _result = self.inner.close(); - self.inner.set_onmessage(None); - self.inner.set_onclose(None); - self.inner.set_onerror(None); +impl Stream for WebSocket { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.incoming_rx + .poll_next_unpin(cx) + .map_ok(|message| { + tracing::debug!("message: {message:?}"); + message + }) + .map_err(|e| { + tracing::error!("receive error: {e}"); + e + }) } } -impl Stream for WebSysWebSocketStream { - type Item = Result; +impl Sink for WebSocket { + type Error = Error; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.rx - .poll_recv(cx) - .map(|ready_value| ready_value.flatten()) + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.outgoing_tx.poll_ready(cx)).map_err(|_| Error::SendError)?; + self.poll_ack(cx) } -} -impl Sink for WebSysWebSocketStream { - type Error = WebSysError; + fn start_send(mut self: Pin<&mut Self>, message: Message) -> Result<(), Self::Error> { + let (ack_tx, ack_rx) = oneshot::channel(); + self.ack_rx = Some(ack_rx); + self.outgoing_tx + .start_send(Outgoing { message, ack_tx }) + .map_err(|_| Error::SendError) + } - fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.outgoing_tx.poll_flush_unpin(cx)).map_err(|_| Error::SendError)?; + self.poll_ack(cx) } - fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { - match item { - Message::Text(text) => self.inner.send_with_str(&text)?, - Message::Binary(data) => self.inner.send_with_u8_array(&data)?, - Message::Close { code, reason } => self - .inner - .close_with_code_and_reason(code.into(), &reason)?, - #[allow(deprecated)] - Message::Ping(_) | Message::Pong(_) => { - // ignored! + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ready!(self.outgoing_tx.poll_close_unpin(cx)).map_err(|_| Error::SendError)?; + self.poll_ack(cx) + } +} + +async fn run_websocket( + websocket: web_sys::WebSocket, + connect_ack_tx: oneshot::Sender>, + mut outgoing_rx: mpsc::Receiver, + incoming_tx: mpsc::UnboundedSender>, +) { + let (mut error_tx, mut error_rx) = mpsc::unbounded(); + let (close_tx, mut close_rx) = oneshot::channel(); + let (open_tx, mut open_rx) = oneshot::channel(); + + // register error handler + // this will try to put the first error into a oneshot channel for errors that + // happen during opening. once that has been used, or the oneshot + // channel is dropped, this uses the regular message channel + let on_error_callback = { + tracing::debug!("error event"); + Closure::::new(move |event: Event| { + let error = match event.dyn_into::() { + Ok(error) => Error::from(error), + Err(_event) => Error::Unknown, + }; + let _ = error_tx.send(error); + }) + }; + websocket.set_onerror(Some(on_error_callback.as_ref().unchecked_ref())); + + // register close handler + let on_close_callback = { + let mut close_tx = Some(close_tx); + let incoming_tx = incoming_tx.clone(); + + Closure::::new(move |event: CloseEvent| { + tracing::debug!("close event"); + if let Some(close_tx) = close_tx.take() { + let _ = incoming_tx.unbounded_send(Ok(Message::Close { + code: event.code().into(), + reason: event.reason(), + })); + let _ = close_tx.send(()); + } + }) + }; + websocket.set_onclose(Some(on_close_callback.as_ref().unchecked_ref())); + + // register open handler + let on_open_callback = { + let mut open_tx = Some(open_tx); + + Closure::::new(move |_event: Event| { + tracing::debug!("open event"); + if let Some(open_tx) = open_tx.take() { + let _ = open_tx.send(()); + } + }) + }; + websocket.set_onopen(Some(on_open_callback.as_ref().unchecked_ref())); + + // register message handler + let on_message_callback = { + let incoming_tx = incoming_tx.clone(); + + Closure::::new(move |event: MessageEvent| { + tracing::debug!("message event"); + if let Ok(abuf) = event.data().dyn_into::() { + let array = Uint8Array::new(&abuf); + let data = array.to_vec(); + let _ = incoming_tx.unbounded_send(Ok(Message::Binary(data))); + } else if let Ok(text) = event.data().dyn_into::() { + let _ = incoming_tx.unbounded_send(Ok(Message::Text(text.into()))); + } else { + tracing::debug!(event = ?event.data(), "received unknown message event"); + } + }) + }; + websocket.set_onmessage(Some(on_message_callback.as_ref().unchecked_ref())); + + // first wait for open/close/error and send connect ack + let mut run_socket = false; + select_biased! { + _ = open_rx => { + let _ = connect_ack_tx.send(Ok(websocket.protocol())); + run_socket = true; + } + _ = &mut close_rx => { + let _ = connect_ack_tx.send(Err(Error::ConnectionFailed)); + } + error_opt = error_rx.next() => { + if let Some(error) = error_opt { + let _ = connect_ack_tx.send(Err(error)); } } - Ok(()) } - fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } + // we can remove the open handler + websocket.set_onopen(None); - fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(self.inner.close().map_err(Into::into)) + // connection established. listen for close/error events and outgoing messages + while run_socket { + select_biased! { + _ = &mut close_rx => { + // close event received + // the event handler takes care of sending the close frame into incoming_tx + run_socket = false; + } + error_opt = error_rx.next() => { + // error event received + if let Some(error) = error_opt { + if incoming_tx.unbounded_send(Err(error)).is_err() { + // receiver half dropped + run_socket = false; + } + } + } + message_opt = outgoing_rx.next() => { + if let Some(Outgoing { message, ack_tx }) = message_opt { + let result = send_message(&websocket, message); + let _ = ack_tx.send(result); + } + else { + // sender half dropped + run_socket = false; + } + } + } } + + // cleanup + let _ = websocket.close(); + websocket.set_onmessage(None); + websocket.set_onclose(None); + websocket.set_onerror(None); } + +fn send_message(websocket: &web_sys::WebSocket, message: Message) -> Result<(), Error> { + match message { + Message::Text(text) => websocket.send_with_str(&text)?, + Message::Binary(data) => websocket.send_with_u8_array(&data)?, + Message::Close { code, reason } => { + websocket.close_with_code_and_reason(code.into(), &reason)? + } + #[allow(deprecated)] + Message::Ping(_) | Message::Pong(_) => { + // ignored! + } + } + Ok::<(), Error>(()) +} \ No newline at end of file