From f699972518c5a64a637ff381859576b997338c41 Mon Sep 17 00:00:00 2001 From: kazk Date: Sun, 12 Sep 2021 21:54:45 -0700 Subject: [PATCH] Add WebSocket `permessage-deflate` extension support --- Cargo.toml | 7 +++--- src/filters/ws.rs | 63 ++++++++++++++++++++++++++++++++++++++++++++--- 2 files changed, 63 insertions(+), 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 62ee5666b..7c452ace7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ async-compression = { version = "0.3.7", features = ["tokio"], optional = true } bytes = "1.0" futures-util = { version = "0.3", default-features = false, features = ["sink"] } futures-channel = { version = "0.3.17", features = ["sink"]} -headers = "0.3" +headers = { version = "0.3", git = "https://github.com/kazk/headers", branch = "sec-websocket-extensions" } http = "0.2" hyper = { version = "0.14", features = ["stream", "server", "http1", "http2", "tcp", "client"] } log = "0.4" @@ -37,7 +37,7 @@ tokio-stream = "0.1.1" tokio-util = { version = "0.7", features = ["io"] } tracing = { version = "0.1.21", default-features = false, features = ["log", "std"] } tower-service = "0.3" -tokio-tungstenite = { version = "0.17", optional = true } +tokio-tungstenite = { git = "https://github.com/kazk/tokio-tungstenite", branch = "feature/permessage-deflate", optional = true } percent-encoding = "2.1" pin-project = "1.0" tokio-rustls = { version = "0.23", optional = true } @@ -55,7 +55,8 @@ listenfd = "0.3" [features] default = ["multipart", "websocket"] -websocket = ["tokio-tungstenite"] +# TODO Separate feature for permessage-deflate? +websocket = ["tokio-tungstenite/deflate"] tls = ["tokio-rustls"] # Enable compression-related filters diff --git a/src/filters/ws.rs b/src/filters/ws.rs index 1e953c7a3..e2dde6ef1 100644 --- a/src/filters/ws.rs +++ b/src/filters/ws.rs @@ -11,11 +11,16 @@ use crate::filter::{filter_fn_one, Filter, One}; use crate::reject::Rejection; use crate::reply::{Reply, Response}; use futures_util::{future, ready, FutureExt, Sink, Stream, TryFutureExt}; -use headers::{Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketKey, Upgrade}; +use headers::{ + Connection, HeaderMapExt, SecWebsocketAccept, SecWebsocketExtensions, SecWebsocketKey, Upgrade, +}; use http; use hyper::upgrade::OnUpgrade; use tokio_tungstenite::{ - tungstenite::protocol::{self, WebSocketConfig}, + tungstenite::{ + self, + protocol::{self, WebSocketConfig}, + }, WebSocketStream, }; @@ -58,11 +63,13 @@ pub fn ws() -> impl Filter, Error = Rejection> + Copy { //.and(header::exact2(Upgrade::websocket())) //.and(header::exact2(SecWebsocketVersion::V13)) .and(header::header2::()) + .and(header::optional2::()) .and(on_upgrade()) .map( - move |key: SecWebsocketKey, on_upgrade: Option| Ws { + move |key: SecWebsocketKey, extensions, on_upgrade: Option| Ws { config: None, key, + extensions, on_upgrade, }, ) @@ -72,6 +79,7 @@ pub fn ws() -> impl Filter, Error = Rejection> + Copy { pub struct Ws { config: Option, key: SecWebsocketKey, + extensions: Option, on_upgrade: Option, } @@ -115,6 +123,14 @@ impl Ws { .max_frame_size = Some(max); self } + + /// Enable `permessage-deflate` support. + pub fn with_compression(mut self) -> Self { + self.config + .get_or_insert_with(WebSocketConfig::default) + .compression = Some(tungstenite::extensions::DeflateConfig::default()); + self + } } impl fmt::Debug for Ws { @@ -129,19 +145,44 @@ struct WsReply { on_upgrade: F, } +impl WsReply { + // Accept extension negotiation offers + fn accept_offers( + &self, + ) -> Option<(SecWebsocketExtensions, tungstenite::extensions::Extensions)> { + if let Some(extensions) = &self.ws.extensions { + self.ws.config.and_then(|c| c.accept_offers(extensions)) + } else { + None + } + } +} + impl Reply for WsReply where F: FnOnce(WebSocket) -> U + Send + 'static, U: Future + Send + 'static, { fn into_response(self) -> Response { + let (agreed_params, extensions) = if let Some((agreed, extensions)) = self.accept_offers() { + (Some(agreed), Some(extensions)) + } else { + (None, None) + }; + if let Some(on_upgrade) = self.ws.on_upgrade { let on_upgrade_cb = self.on_upgrade; let config = self.ws.config; let fut = on_upgrade .and_then(move |upgraded| { tracing::trace!("websocket upgrade complete"); - WebSocket::from_raw_socket(upgraded, protocol::Role::Server, config).map(Ok) + WebSocket::from_raw_socket_with_extensions( + upgraded, + protocol::Role::Server, + config, + extensions, + ) + .map(Ok) }) .and_then(move |socket| on_upgrade_cb(socket).map(Ok)) .map(|result| { @@ -162,6 +203,9 @@ where res.headers_mut().typed_insert(Upgrade::websocket()); res.headers_mut() .typed_insert(SecWebsocketAccept::from(self.ws.key)); + if let Some(agreed) = agreed_params { + res.headers_mut().typed_insert(agreed); + } res } @@ -196,6 +240,17 @@ impl WebSocket { .await } + pub(crate) async fn from_raw_socket_with_extensions( + upgraded: hyper::upgrade::Upgraded, + role: protocol::Role, + config: Option, + extensions: Option, + ) -> Self { + WebSocketStream::from_raw_socket_with_extensions(upgraded, role, config, extensions) + .map(|inner| WebSocket { inner }) + .await + } + /// Gracefully close this websocket. pub async fn close(mut self) -> Result<(), crate::Error> { future::poll_fn(|cx| Pin::new(&mut self).poll_close(cx)).await