diff --git a/actix-http/CHANGES.md b/actix-http/CHANGES.md index 982add26a23..3d924a241cb 100644 --- a/actix-http/CHANGES.md +++ b/actix-http/CHANGES.md @@ -5,6 +5,7 @@ ### Added - Add `header::CLEAR_SITE_DATA` constant. +- Add DEFLATE compression support for WebSocket. ### Changed diff --git a/actix-http/Cargo.toml b/actix-http/Cargo.toml index 3f81ea9f000..704f06c1e06 100644 --- a/actix-http/Cargo.toml +++ b/actix-http/Cargo.toml @@ -32,6 +32,7 @@ features = [ "compress-brotli", "compress-gzip", "compress-zstd", + "compress-ws-deflate", ] [package.metadata.cargo_check_external_types] @@ -91,6 +92,7 @@ rustls-0_23 = ["__tls", "actix-tls/accept", "actix-tls/rustls-0_23"] compress-brotli = ["__compress", "dep:brotli"] compress-gzip = ["__compress", "dep:flate2"] compress-zstd = ["__compress", "dep:zstd"] +compress-ws-deflate = ["dep:flate2", "flate2/zlib-default"] # Internal (PRIVATE!) features used to aid testing and checking feature status. # Don't rely on these whatsoever. They are semver-exempt and may disappear at anytime. diff --git a/actix-http/src/lib.rs b/actix-http/src/lib.rs index 734e6e1e159..a1b218f26aa 100644 --- a/actix-http/src/lib.rs +++ b/actix-http/src/lib.rs @@ -2,18 +2,19 @@ //! //! ## Crate Features //! -//! | Feature | Functionality | -//! | ------------------- | ------------------------------------------- | -//! | `http2` | HTTP/2 support via [h2]. | -//! | `openssl` | TLS support via [OpenSSL]. | -//! | `rustls-0_20` | TLS support via rustls 0.20. | -//! | `rustls-0_21` | TLS support via rustls 0.21. | -//! | `rustls-0_22` | TLS support via rustls 0.22. | -//! | `rustls-0_23` | TLS support via [rustls] 0.23. | -//! | `compress-brotli` | Payload compression support: Brotli. | -//! | `compress-gzip` | Payload compression support: Deflate, Gzip. | -//! | `compress-zstd` | Payload compression support: Zstd. | -//! | `trust-dns` | Use [trust-dns] as the client DNS resolver. | +//! | Feature | Functionality | +//! | --------------------- | ------------------------------------------- | +//! | `http2` | HTTP/2 support via [h2]. | +//! | `openssl` | TLS support via [OpenSSL]. | +//! | `rustls-0_20` | TLS support via rustls 0.20. | +//! | `rustls-0_21` | TLS support via rustls 0.21. | +//! | `rustls-0_22` | TLS support via rustls 0.22. | +//! | `rustls-0_23` | TLS support via [rustls] 0.23. | +//! | `compress-brotli` | Payload compression support: Brotli. | +//! | `compress-gzip` | Payload compression support: Deflate, Gzip. | +//! | `compress-zstd` | Payload compression support: Zstd. | +//! | `compress-ws-deflate` | WebSocket DEFLATE compression support. | +//! | `trust-dns` | Use [trust-dns] as the client DNS resolver. | //! //! [h2]: https://crates.io/crates/h2 //! [OpenSSL]: https://crates.io/crates/openssl diff --git a/actix-http/src/ws/codec.rs b/actix-http/src/ws/codec.rs index ad487e400fb..526ce23bdc6 100644 --- a/actix-http/src/ws/codec.rs +++ b/actix-http/src/ws/codec.rs @@ -1,12 +1,16 @@ use bitflags::bitflags; use bytes::{Bytes, BytesMut}; use bytestring::ByteString; -use tokio_util::codec::{Decoder, Encoder}; +use tokio_util::codec; use tracing::error; +#[cfg(feature = "compress-ws-deflate")] +use super::deflate::{ + DeflateCompressionContext, DeflateDecompressionContext, RSV_BIT_DEFLATE_FLAG, +}; use super::{ frame::Parser, - proto::{CloseReason, OpCode}, + proto::{CloseReason, OpCode, RsvBits}, ProtocolError, }; @@ -66,13 +70,6 @@ pub enum Item { Last(Bytes), } -/// WebSocket protocol codec. -#[derive(Debug, Clone)] -pub struct Codec { - flags: Flags, - max_size: usize, -} - bitflags! { #[derive(Debug, Clone, Copy)] struct Flags: u8 { @@ -82,63 +79,122 @@ bitflags! { } } -impl Codec { - /// Create new WebSocket frames decoder. - pub const fn new() -> Codec { - Codec { - max_size: 65_536, +/// WebSocket message encoder. +#[derive(Debug)] +pub struct Encoder { + flags: Flags, + + #[cfg(feature = "compress-ws-deflate")] + deflate_compress: Option, +} + +impl Encoder { + /// Create new WebSocket frames encoder. + pub const fn new() -> Encoder { + Encoder { flags: Flags::SERVER, + + #[cfg(feature = "compress-ws-deflate")] + deflate_compress: None, } } - /// Set max frame size. - /// - /// By default max size is set to 64KiB. - #[must_use = "This returns the a new Codec, without modifying the original."] - pub fn max_size(mut self, size: usize) -> Self { - self.max_size = size; - self + /// Create new WebSocket frames encoder with `permessage-deflate` extension support. + /// Compression context can be made from + /// [`DeflateSessionParameters::create_context`](super::DeflateSessionParameters::create_context). + #[cfg(feature = "compress-ws-deflate")] + pub fn new_deflate(compress: DeflateCompressionContext) -> Encoder { + Encoder { + flags: Flags::SERVER, + + deflate_compress: Some(compress), + } } - /// Set decoder to client mode. + /// Set encoder to client mode. /// - /// By default decoder works in server mode. - #[must_use = "This returns the a new Codec, without modifying the original."] + /// By default encoder works in server mode. + #[must_use = "This returns the a new Encoder, without modifying the original."] pub fn client_mode(mut self) -> Self { - self.flags.remove(Flags::SERVER); + self.flags = Flags::empty(); + self + } + + #[cfg(feature = "compress-ws-deflate")] + fn set_client_mode_deflate( + mut self, + remote_no_context_takeover: bool, + remote_max_window_bits: u8, + ) -> Self { + self.deflate_compress = self + .deflate_compress + .map(|c| c.reset_with(remote_no_context_takeover, remote_max_window_bits)); self } + + #[cfg(feature = "compress-ws-deflate")] + fn process_payload( + &mut self, + fin: bool, + bytes: Bytes, + ) -> Result<(Bytes, RsvBits), ProtocolError> { + if let Some(compress) = &mut self.deflate_compress { + Ok((compress.compress(fin, bytes)?, RSV_BIT_DEFLATE_FLAG)) + } else { + Ok((bytes, RsvBits::empty())) + } + } + + #[cfg(not(feature = "compress-ws-deflate"))] + fn process_payload( + &mut self, + _fin: bool, + bytes: Bytes, + ) -> Result<(Bytes, RsvBits), ProtocolError> { + Ok((bytes, RsvBits::empty())) + } } -impl Default for Codec { +impl Default for Encoder { fn default() -> Self { Self::new() } } -impl Encoder for Codec { +impl codec::Encoder for Encoder { type Error = ProtocolError; fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> { match item { - Message::Text(txt) => Parser::write_message( - dst, - txt, - OpCode::Text, - true, - !self.flags.contains(Flags::SERVER), - ), - Message::Binary(bin) => Parser::write_message( - dst, - bin, - OpCode::Binary, - true, - !self.flags.contains(Flags::SERVER), - ), + Message::Text(txt) => { + let (bytes, rsv_bits) = self.process_payload(true, txt.into_bytes())?; + + Parser::write_message( + dst, + bytes, + OpCode::Text, + rsv_bits, + true, + !self.flags.contains(Flags::SERVER), + ) + } + Message::Binary(bin) => { + let (bin, rsv_bits) = self.process_payload(true, bin)?; + + Parser::write_message( + dst, + bin, + OpCode::Binary, + rsv_bits, + true, + !self.flags.contains(Flags::SERVER), + ) + } Message::Ping(txt) => Parser::write_message( dst, txt, OpCode::Ping, + RsvBits::empty(), true, !self.flags.contains(Flags::SERVER), ), @@ -146,22 +202,29 @@ impl Encoder for Codec { dst, txt, OpCode::Pong, + RsvBits::empty(), true, !self.flags.contains(Flags::SERVER), ), - Message::Close(reason) => { - Parser::write_close(dst, reason, !self.flags.contains(Flags::SERVER)) - } + Message::Close(reason) => Parser::write_close( + dst, + reason, + RsvBits::empty(), + !self.flags.contains(Flags::SERVER), + ), Message::Continuation(cont) => match cont { Item::FirstText(data) => { if self.flags.contains(Flags::W_CONTINUATION) { return Err(ProtocolError::ContinuationStarted); } else { + let (data, rsv_bits) = self.process_payload(false, data)?; + self.flags.insert(Flags::W_CONTINUATION); Parser::write_message( dst, - &data[..], + data, OpCode::Text, + rsv_bits, false, !self.flags.contains(Flags::SERVER), ) @@ -171,11 +234,14 @@ impl Encoder for Codec { if self.flags.contains(Flags::W_CONTINUATION) { return Err(ProtocolError::ContinuationStarted); } else { + let (data, rsv_bits) = self.process_payload(false, data)?; + self.flags.insert(Flags::W_CONTINUATION); Parser::write_message( dst, - &data[..], + data, OpCode::Binary, + rsv_bits, false, !self.flags.contains(Flags::SERVER), ) @@ -183,10 +249,13 @@ impl Encoder for Codec { } Item::Continue(data) => { if self.flags.contains(Flags::W_CONTINUATION) { + let (data, rsv_bits) = self.process_payload(false, data)?; + Parser::write_message( dst, - &data[..], + data, OpCode::Continue, + rsv_bits, false, !self.flags.contains(Flags::SERVER), ) @@ -197,10 +266,14 @@ impl Encoder for Codec { Item::Last(data) => { if self.flags.contains(Flags::W_CONTINUATION) { self.flags.remove(Flags::W_CONTINUATION); + + let (data, rsv_bits) = self.process_payload(true, data)?; + Parser::write_message( dst, - &data[..], + data, OpCode::Continue, + rsv_bits, true, !self.flags.contains(Flags::SERVER), ) @@ -215,20 +288,130 @@ impl Encoder for Codec { } } -impl Decoder for Codec { +/// WebSocket message decoder. +#[derive(Debug)] +pub struct Decoder { + flags: Flags, + max_size: usize, + + #[cfg(feature = "compress-ws-deflate")] + deflate_decompress: Option, +} + +impl Decoder { + /// Create new WebSocket frames decoder. + pub const fn new() -> Decoder { + Decoder { + flags: Flags::SERVER, + max_size: 65_536, + + #[cfg(feature = "compress-ws-deflate")] + deflate_decompress: None, + } + } + + /// Create new WebSocket frames decoder with `permessage-deflate` extension support. + /// Decompression context can be made from + /// [`DeflateSessionParameters::create_context`](super::DeflateSessionParameters::create_context). + #[cfg(feature = "compress-ws-deflate")] + pub fn new_deflate(decompress: DeflateDecompressionContext) -> Decoder { + Decoder { + flags: Flags::SERVER, + max_size: 65_536, + + deflate_decompress: Some(decompress), + } + } + + /// Set max frame size. + /// + /// By default max size is set to 64KiB. + #[must_use = "This returns the a new Decoder, without modifying the original."] + pub fn max_size(mut self, size: usize) -> Self { + self.max_size = size; + self + } + + /// Set decoder to client mode. + /// + /// By default decoder works in server mode. + #[must_use = "This returns the a new Decoder, without modifying the original."] + pub fn client_mode(mut self) -> Self { + self.flags = Flags::empty(); + self + } + + #[cfg(feature = "compress-ws-deflate")] + fn set_client_mode_deflate( + mut self, + local_no_context_takeover: bool, + local_max_window_bits: u8, + ) -> Self { + if let Some(decompress) = &mut self.deflate_decompress { + decompress.reset_with(local_no_context_takeover, local_max_window_bits); + } + + self + } + + #[cfg(feature = "compress-ws-deflate")] + fn process_payload( + &mut self, + fin: bool, + opcode: OpCode, + rsv_bits: RsvBits, + bytes: Option, + ) -> Result, ProtocolError> { + if let Some(bytes) = bytes { + if let Some(decompress) = &mut self.deflate_decompress { + Ok(Some(decompress.decompress(fin, opcode, rsv_bits, bytes)?)) + } else { + Ok(Some(bytes)) + } + } else { + Ok(None) + } + } + + #[cfg(not(feature = "compress-ws-deflate"))] + fn process_payload( + &mut self, + _fin: bool, + _opcode: OpCode, + _rsv_bits: RsvBits, + bytes: Option, + ) -> Result, ProtocolError> { + Ok(bytes) + } +} + +impl Default for Decoder { + fn default() -> Self { + Self::new() + } +} + +impl codec::Decoder for Decoder { type Item = Frame; type Error = ProtocolError; fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { match Parser::parse(src, self.flags.contains(Flags::SERVER), self.max_size) { - Ok(Some((finished, opcode, payload))) => { + Ok(Some((finished, opcode, rsv_bits, payload))) => { + let payload = self.process_payload( + finished, + opcode, + rsv_bits, + payload.map(BytesMut::freeze), + )?; + // continuation is not supported if !finished { return match opcode { OpCode::Continue => { if self.flags.contains(Flags::CONTINUATION) { Ok(Some(Frame::Continuation(Item::Continue( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + payload.unwrap_or_else(Bytes::new), )))) } else { Err(ProtocolError::ContinuationNotStarted) @@ -238,7 +421,7 @@ impl Decoder for Codec { if !self.flags.contains(Flags::CONTINUATION) { self.flags.insert(Flags::CONTINUATION); Ok(Some(Frame::Continuation(Item::FirstBinary( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + payload.unwrap_or_else(Bytes::new), )))) } else { Err(ProtocolError::ContinuationStarted) @@ -248,7 +431,7 @@ impl Decoder for Codec { if !self.flags.contains(Flags::CONTINUATION) { self.flags.insert(Flags::CONTINUATION); Ok(Some(Frame::Continuation(Item::FirstText( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + payload.unwrap_or_else(Bytes::new), )))) } else { Err(ProtocolError::ContinuationStarted) @@ -266,7 +449,7 @@ impl Decoder for Codec { if self.flags.contains(Flags::CONTINUATION) { self.flags.remove(Flags::CONTINUATION); Ok(Some(Frame::Continuation(Item::Last( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), + payload.unwrap_or_else(Bytes::new), )))) } else { Err(ProtocolError::ContinuationNotStarted) @@ -281,18 +464,10 @@ impl Decoder for Codec { Ok(Some(Frame::Close(None))) } } - OpCode::Ping => Ok(Some(Frame::Ping( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), - ))), - OpCode::Pong => Ok(Some(Frame::Pong( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), - ))), - OpCode::Binary => Ok(Some(Frame::Binary( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), - ))), - OpCode::Text => Ok(Some(Frame::Text( - payload.map(|pl| pl.freeze()).unwrap_or_else(Bytes::new), - ))), + OpCode::Ping => Ok(Some(Frame::Ping(payload.unwrap_or_else(Bytes::new)))), + OpCode::Pong => Ok(Some(Frame::Pong(payload.unwrap_or_else(Bytes::new)))), + OpCode::Binary => Ok(Some(Frame::Binary(payload.unwrap_or_else(Bytes::new)))), + OpCode::Text => Ok(Some(Frame::Text(payload.unwrap_or_else(Bytes::new)))), } } Ok(None) => Ok(None), @@ -300,3 +475,130 @@ impl Decoder for Codec { } } } + +/// WebSocket protocol codec. +/// This is essentially a combination of [`Encoder`] and [`Decoder`] and +/// actual conversion behaviors are defined in both structs respectively. +/// +/// # Note +/// Cloning [`Codec`] creates a new codec with existing configurations +/// and will not preserve the context information. +#[derive(Debug, Default)] +pub struct Codec { + encoder: Encoder, + decoder: Decoder, +} + +impl Clone for Codec { + fn clone(&self) -> Self { + Self { + encoder: Encoder { + flags: self.encoder.flags & Flags::SERVER, + #[cfg(feature = "compress-ws-deflate")] + deflate_compress: self.encoder.deflate_compress.as_ref().map(|c| { + DeflateCompressionContext::new( + Some(c.compression_level), + c.remote_no_context_takeover, + c.remote_max_window_bits, + ) + }), + }, + decoder: Decoder { + flags: self.decoder.flags & Flags::SERVER, + max_size: self.decoder.max_size, + #[cfg(feature = "compress-ws-deflate")] + deflate_decompress: self.decoder.deflate_decompress.as_ref().map(|d| { + DeflateDecompressionContext::new( + d.local_no_context_takeover, + d.local_max_window_bits, + ) + }), + }, + } + } +} + +impl Codec { + /// Create new WebSocket frames codec. + pub fn new() -> Codec { + Codec { + encoder: Encoder::new(), + decoder: Decoder::new(), + } + } + + /// Create new WebSocket frames codec with DEFLATE compression. + /// Both compression and decompression contexts can be made from + /// [`DeflateSessionParameters::create_context`](super::DeflateSessionParameters::create_context). + #[cfg(feature = "compress-ws-deflate")] + pub fn new_deflate( + compress: DeflateCompressionContext, + decompress: DeflateDecompressionContext, + ) -> Codec { + Codec { + encoder: Encoder::new_deflate(compress), + decoder: Decoder::new_deflate(decompress), + } + } + + /// Set max frame size. + /// + /// By default max size is set to 64KiB. + #[must_use = "This returns the a new Codec, without modifying the original."] + pub fn max_size(self, size: usize) -> Self { + let Self { encoder, decoder } = self; + + Codec { + encoder, + decoder: decoder.max_size(size), + } + } + + /// Set codec to client mode. + /// + /// By default codec works in server mode. + #[must_use = "This returns the a new Codec, without modifying the original."] + pub fn client_mode(self) -> Self { + let Self { + mut encoder, + mut decoder, + } = self; + + encoder = encoder.client_mode(); + decoder = decoder.client_mode(); + #[cfg(feature = "compress-ws-deflate")] + { + if let Some(decoder) = &decoder.deflate_decompress { + encoder = encoder.set_client_mode_deflate( + decoder.local_no_context_takeover, + decoder.local_max_window_bits, + ); + } + if let Some(encoder) = &encoder.deflate_compress { + decoder = decoder.set_client_mode_deflate( + encoder.remote_no_context_takeover, + encoder.remote_max_window_bits, + ); + } + } + + Self { encoder, decoder } + } +} + +impl codec::Decoder for Codec { + type Item = Frame; + type Error = ProtocolError; + + fn decode(&mut self, src: &mut BytesMut) -> Result, Self::Error> { + self.decoder.decode(src) + } +} + +impl codec::Encoder for Codec { + type Error = ProtocolError; + + fn encode(&mut self, item: Message, dst: &mut BytesMut) -> Result<(), Self::Error> { + self.encoder.encode(item, dst) + } +} diff --git a/actix-http/src/ws/deflate.rs b/actix-http/src/ws/deflate.rs new file mode 100644 index 00000000000..76bfebb350b --- /dev/null +++ b/actix-http/src/ws/deflate.rs @@ -0,0 +1,846 @@ +//! WebSocket permessage-deflate compression implementation. + +use std::convert::Infallible; + +use bytes::Bytes; +pub use flate2::Compression as DeflateCompressionLevel; + +use super::{OpCode, ProtocolError, RsvBits}; +use crate::header::{HeaderName, HeaderValue, TryIntoHeaderPair, SEC_WEBSOCKET_EXTENSIONS}; + +// NOTE: according to [RFC 7692 §7.1.2.1] window bit size should be within 8..=15 +// but we have to limit the range to 9..=15 because [flate2] only supports window bit within 9..=15. +// +// [RFC 6792 §7.1.2.1]: https://datatracker.ietf.org/doc/html/rfc7692#section-7.1.2.1 +// [flate2]: https://docs.rs/flate2/latest/flate2/struct.Compress.html#method.new_with_window_bits +const MAX_WINDOW_BITS_RANGE: std::ops::RangeInclusive = 9..=15; +const DEFAULT_WINDOW_BITS: u8 = 15; + +const BUF_SIZE: usize = 2048; + +pub(super) const RSV_BIT_DEFLATE_FLAG: RsvBits = RsvBits::RSV1; + +/// DEFLATE compression related handshake errors. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum DeflateHandshakeError { + /// Unknown extension parameter given. + UnknownWebSocketParameters, + + /// Duplicate parameter found in single extension statement. + DuplicateParameter(&'static str), + + /// Max window bits size out of range. Should be in 9..=15 + MaxWindowBitsOutOfRange, + + /// Multiple `permessage-deflate` statements found but failed to negotiate any. + NoSuitableConfigurationFound, +} + +impl std::fmt::Display for DeflateHandshakeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::UnknownWebSocketParameters => { + write!(f, "Unknown WebSocket `permessage-deflate` parameters.") + } + Self::DuplicateParameter(p) => { + write!(f, "Duplicate WebSocket `permessage-deflate` parameter: {p}") + } + Self::MaxWindowBitsOutOfRange => write!( + f, + "Max window bits out of range. ({} to {} expected)", + MAX_WINDOW_BITS_RANGE.start(), + MAX_WINDOW_BITS_RANGE.end() + ), + Self::NoSuitableConfigurationFound => write!( + f, + "No suitable WebSocket `permedia-deflate` parameter configurations found." + ), + } + } +} + +impl std::error::Error for DeflateHandshakeError {} + +/// Maximum size of client's DEFLATE sliding window. +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum ClientMaxWindowBits { + /// Unspecified. Indicates client will follow server configuration. + NotSpecified, + /// Specified size of client's DEFLATE sliding window size in bits, between 9 and 15. + Specified(u8), +} + +/// Per-session DEFLATE configuration parameter. +/// +/// It can be used both client and server side. +/// At client side, it can be used to pass desired configuration to server. +/// At server side, negotiated parameter will be sent to client with this. +/// This can be represented in HTTP header form as it implements [`TryIntoHeaderPair`] trait. +#[derive(Debug, Clone, Default, Eq, PartialEq)] +pub struct DeflateSessionParameters { + /// Disallow server from take over context. + pub server_no_context_takeover: bool, + /// Disallow client from take over context. + pub client_no_context_takeover: bool, + /// Maximum size of server's DEFLATE sliding window in bits, between 9 and 15. + pub server_max_window_bits: Option, + /// Maximum size of client's DEFLATE sliding window. + pub client_max_window_bits: Option, +} + +impl TryIntoHeaderPair for DeflateSessionParameters { + type Error = Infallible; + + fn try_into_pair(self) -> Result<(HeaderName, HeaderValue), Self::Error> { + let mut response_extension = vec!["permessage-deflate".to_owned()]; + + if self.server_no_context_takeover { + response_extension.push("server_no_context_takeover".to_owned()); + } + if self.client_no_context_takeover { + response_extension.push("client_no_context_takeover".to_owned()); + } + if let Some(server_max_window_bits) = self.server_max_window_bits { + response_extension.push(format!("server_max_window_bits={server_max_window_bits}")); + } + if let Some(client_max_window_bits) = self.client_max_window_bits { + match client_max_window_bits { + ClientMaxWindowBits::NotSpecified => { + response_extension.push("client_max_window_bits".to_string()); + } + ClientMaxWindowBits::Specified(bits) => { + response_extension.push(format!("client_max_window_bits={bits}")); + } + } + } + + Ok(( + SEC_WEBSOCKET_EXTENSIONS, + HeaderValue::from_str(&response_extension.join("; ")).unwrap(), + )) + } +} + +impl DeflateSessionParameters { + fn parse<'a>( + extension_frags: impl Iterator, + ) -> Result { + let mut client_max_window_bits = None; + let mut server_max_window_bits = None; + let mut client_no_context_takeover = None; + let mut server_no_context_takeover = None; + + let mut unknown_parameters = vec![]; + + for fragment in extension_frags { + if fragment.is_empty() { + continue; + } else if fragment == "client_max_window_bits" { + if client_max_window_bits.is_some() { + return Err(DeflateHandshakeError::DuplicateParameter( + "client_max_window_bits", + )); + } + client_max_window_bits = Some(ClientMaxWindowBits::NotSpecified); + } else if let Some(value) = fragment.strip_prefix("client_max_window_bits=") { + if client_max_window_bits.is_some() { + return Err(DeflateHandshakeError::DuplicateParameter( + "client_max_window_bits", + )); + } + let bits = value + .parse::() + .map_err(|_| DeflateHandshakeError::MaxWindowBitsOutOfRange)?; + if !MAX_WINDOW_BITS_RANGE.contains(&bits) { + return Err(DeflateHandshakeError::MaxWindowBitsOutOfRange); + } + client_max_window_bits = Some(ClientMaxWindowBits::Specified(bits)); + } else if let Some(value) = fragment.strip_prefix("server_max_window_bits=") { + if server_max_window_bits.is_some() { + return Err(DeflateHandshakeError::DuplicateParameter( + "server_max_window_bits", + )); + } + let bits = value + .parse::() + .map_err(|_| DeflateHandshakeError::MaxWindowBitsOutOfRange)?; + if !MAX_WINDOW_BITS_RANGE.contains(&bits) { + return Err(DeflateHandshakeError::MaxWindowBitsOutOfRange); + } + server_max_window_bits = Some(bits); + } else if fragment == "server_no_context_takeover" { + if server_no_context_takeover.is_some() { + return Err(DeflateHandshakeError::DuplicateParameter( + "server_no_context_takeover", + )); + } + server_no_context_takeover = Some(true); + } else if fragment == "client_no_context_takeover" { + if client_no_context_takeover.is_some() { + return Err(DeflateHandshakeError::DuplicateParameter( + "client_no_context_takeover", + )); + } + client_no_context_takeover = Some(true); + } else { + unknown_parameters.push(fragment.to_owned()); + } + } + + if !unknown_parameters.is_empty() { + Err(DeflateHandshakeError::UnknownWebSocketParameters) + } else { + Ok(DeflateSessionParameters { + server_no_context_takeover: server_no_context_takeover.unwrap_or(false), + client_no_context_takeover: client_no_context_takeover.unwrap_or(false), + server_max_window_bits, + client_max_window_bits, + }) + } + } + + /// Parse desired parameters from `Sec-WebSocket-Extensions` header. + /// The result may contain multiple values as it's possible to pass multiple parameters + /// separated with comma. + pub fn from_extension_header(header_value: &str) -> Vec> { + let mut results = vec![]; + for extension in header_value.split(',').map(str::trim) { + let mut fragments = extension.split(';').map(str::trim); + if fragments.next() == Some("permessage-deflate") { + results.push(Self::parse(fragments)); + } + } + + results + } + + /// Create compression and decompression context based on the parameter. + pub fn create_context( + &self, + compression_level: Option, + is_client_mode: bool, + ) -> (DeflateCompressionContext, DeflateDecompressionContext) { + let client_max_window_bits = + if let Some(ClientMaxWindowBits::Specified(value)) = self.client_max_window_bits { + value + } else { + DEFAULT_WINDOW_BITS + }; + let server_max_window_bits = self.server_max_window_bits.unwrap_or(DEFAULT_WINDOW_BITS); + + let (remote_no_context_takeover, remote_max_window_bits) = if is_client_mode { + (self.server_no_context_takeover, server_max_window_bits) + } else { + (self.client_no_context_takeover, client_max_window_bits) + }; + + let (local_no_context_takeover, local_max_window_bits) = if is_client_mode { + (self.client_no_context_takeover, client_max_window_bits) + } else { + (self.server_no_context_takeover, server_max_window_bits) + }; + + ( + DeflateCompressionContext::new( + compression_level, + remote_no_context_takeover, + remote_max_window_bits, + ), + DeflateDecompressionContext::new(local_no_context_takeover, local_max_window_bits), + ) + } +} + +/// Server-side DEFLATE configuration. +#[derive(Clone, Debug, Default, Eq, PartialEq)] +pub struct DeflateServerConfig { + /// DEFLATE compression level. See [`flate2::Compression`] for details. + pub compression_level: Option, + /// Disallow server from take over context. Default is false. + pub server_no_context_takeover: bool, + /// Disallow client from take over context. Default is false. + pub client_no_context_takeover: bool, + /// Maximum size of server's DEFLATE sliding window in bits, between 9 and 15. Default is 15. + pub server_max_window_bits: Option, + /// Maximum size of client's DEFLATE sliding window in bits, between 9 and 15. Default is 15. + pub client_max_window_bits: Option, +} + +impl DeflateServerConfig { + /// Negotiate context parameters. + /// Since parameters from the client may be incompatible with the server configuration, + /// actual parameters could be adjusted here. Conversion rules are as follows: + /// + /// ## server_no_context_takeover + /// + /// | Config | Request | Response | + /// | ------ | ------- | --------- | + /// | false | false | false | + /// | false | true | true | + /// | true | false | true | + /// | true | true | true | + /// + /// ## client_no_context_takeover + /// + /// | Config | Request | Response | + /// | ------ | ------- | --------- | + /// | false | false | false | + /// | false | true | true | + /// | true | false | true | + /// | true | true | true | + /// + /// ## server_max_window_bits + /// + /// | Config | Request | Response | + /// | ------------ | ------------ | -------- | + /// | None | None | None | + /// | None | 9 <= R <= 15 | R | + /// | 9 <= C <= 15 | None | C | + /// | 9 <= C <= 15 | 9 <= R <= C | R | + /// | 9 <= C <= 15 | C <= R <= 15 | C | + /// + /// ## client_max_window_bits + /// + /// | Config | Request | Response | + /// | ------------ | ------------ | -------- | + /// | None | None | None | + /// | None | Unspecified | None | + /// | None | 9 <= R <= 15 | R | + /// | 9 <= C <= 15 | None | None | + /// | 9 <= C <= 15 | Unspecified | C | + /// | 9 <= C <= 15 | 9 <= R <= C | R | + /// | 9 <= C <= 15 | C <= R <= 15 | C | + pub fn negotiate(&self, params: DeflateSessionParameters) -> DeflateSessionParameters { + let server_no_context_takeover = + if self.server_no_context_takeover && !params.server_no_context_takeover { + true + } else { + params.server_no_context_takeover + }; + + let client_no_context_takeover = + if self.client_no_context_takeover && !params.client_no_context_takeover { + true + } else { + params.client_no_context_takeover + }; + + let server_max_window_bits = + match (self.server_max_window_bits, params.server_max_window_bits) { + (None, value) => value, + (Some(config_value), None) => Some(config_value), + (Some(config_value), Some(value)) => { + if value > config_value { + Some(config_value) + } else { + Some(value) + } + } + }; + + let client_max_window_bits = + match (self.client_max_window_bits, params.client_max_window_bits) { + (None, None | Some(ClientMaxWindowBits::NotSpecified)) => None, + (None, Some(ClientMaxWindowBits::Specified(value))) => Some(value), + (Some(_), None) => None, + (Some(config_value), Some(ClientMaxWindowBits::NotSpecified)) => Some(config_value), + (Some(config_value), Some(ClientMaxWindowBits::Specified(value))) => { + if value > config_value { + Some(config_value) + } else { + Some(value) + } + } + }; + + DeflateSessionParameters { + server_no_context_takeover, + client_no_context_takeover, + server_max_window_bits, + client_max_window_bits: client_max_window_bits.map(ClientMaxWindowBits::Specified), + } + } +} + +/// DEFLATE decompression context. +#[derive(Debug)] +pub struct DeflateDecompressionContext { + pub(super) local_no_context_takeover: bool, + pub(super) local_max_window_bits: u8, + + decompress: flate2::Decompress, + + decode_continuation: bool, + total_bytes_written: u64, + total_bytes_read: u64, +} + +impl DeflateDecompressionContext { + pub(super) fn new(local_no_context_takeover: bool, local_max_window_bits: u8) -> Self { + Self { + local_no_context_takeover, + local_max_window_bits, + + decompress: flate2::Decompress::new_with_window_bits(false, local_max_window_bits), + + decode_continuation: false, + total_bytes_written: 0, + total_bytes_read: 0, + } + } + + pub(super) fn reset_with( + &mut self, + local_no_context_takeover: bool, + local_max_window_bits: u8, + ) { + *self = Self::new(local_no_context_takeover, local_max_window_bits); + } + + pub(super) fn decompress( + &mut self, + fin: bool, + opcode: OpCode, + rsv: RsvBits, + payload: Bytes, + ) -> Result { + if !matches!(opcode, OpCode::Text | OpCode::Binary | OpCode::Continue) + || !rsv.contains(RSV_BIT_DEFLATE_FLAG) + { + return Ok(payload); + } + + if opcode == OpCode::Continue { + if !self.decode_continuation { + return Ok(payload); + } + } else { + self.decode_continuation = true; + } + + let mut output: Vec = vec![]; + let mut buf = [0u8; BUF_SIZE]; + + let mut offset: usize = 0; + loop { + let res = if offset >= payload.len() { + self.decompress + .decompress( + &[0x00, 0x00, 0xff, 0xff], + &mut buf, + flate2::FlushDecompress::Finish, + ) + .map_err(|err| { + self.reset(); + ProtocolError::Io(err.into()) + })? + } else { + self.decompress + .decompress(&payload[offset..], &mut buf, flate2::FlushDecompress::None) + .map_err(|err| { + self.reset(); + ProtocolError::Io(err.into()) + })? + }; + + let read = self.decompress.total_in() - self.total_bytes_read; + let written = self.decompress.total_out() - self.total_bytes_written; + + offset += read as usize; + self.total_bytes_read += read; + if written > 0 { + output.extend(buf.iter().take(written as usize)); + self.total_bytes_written += written; + } + + if res != flate2::Status::Ok { + break; + } + } + + if fin { + self.decode_continuation = false; + if self.local_no_context_takeover { + self.reset(); + } + } + + Ok(output.into()) + } + + fn reset(&mut self) { + self.decompress.reset(false); + self.total_bytes_read = 0; + self.total_bytes_written = 0; + } +} + +/// DEFLATE compression context. +#[derive(Debug)] +pub struct DeflateCompressionContext { + pub(super) compression_level: flate2::Compression, + pub(super) remote_no_context_takeover: bool, + pub(super) remote_max_window_bits: u8, + + compress: flate2::Compress, + total_bytes_written: u64, + total_bytes_read: u64, +} + +impl DeflateCompressionContext { + pub(super) fn new( + compression_level: Option, + remote_no_context_takeover: bool, + remote_max_window_bits: u8, + ) -> Self { + let compression_level = compression_level.unwrap_or_default(); + + Self { + compression_level, + remote_no_context_takeover, + remote_max_window_bits, + + compress: flate2::Compress::new_with_window_bits( + compression_level, + false, + remote_max_window_bits, + ), + + total_bytes_written: 0, + total_bytes_read: 0, + } + } + + pub(super) fn reset_with( + mut self, + remote_no_context_takeover: bool, + remote_max_window_bits: u8, + ) -> Self { + self = Self::new( + Some(self.compression_level), + remote_no_context_takeover, + remote_max_window_bits, + ); + + self + } + + pub(super) fn compress(&mut self, fin: bool, payload: Bytes) -> Result { + let mut output = vec![]; + let mut buf = [0u8; BUF_SIZE]; + + loop { + let total_in = self.compress.total_in() - self.total_bytes_read; + let res = if total_in >= payload.len() as u64 { + self.compress + .compress(&[], &mut buf, flate2::FlushCompress::Sync) + .map_err(|err| { + self.reset(); + ProtocolError::Io(err.into()) + })? + } else { + self.compress + .compress(&payload, &mut buf, flate2::FlushCompress::None) + .map_err(|err| { + self.reset(); + ProtocolError::Io(err.into()) + })? + }; + + let written = self.compress.total_out() - self.total_bytes_written; + if written > 0 { + output.extend(buf.iter().take(written as usize)); + self.total_bytes_written += written; + } + + if res != flate2::Status::Ok { + break; + } + } + self.total_bytes_read = self.compress.total_in(); + + if output.iter().rev().take(4).eq(&[0xff, 0xff, 0x00, 0x00]) { + output.drain(output.len() - 4..); + } + + if fin && self.remote_no_context_takeover { + self.reset(); + } + + Ok(output.into()) + } + + fn reset(&mut self) { + self.compress.reset(); + self.total_bytes_read = 0; + self.total_bytes_written = 0; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::body::MessageBody; + + #[test] + fn test_session_parameters() { + let extension = "abc, def, permessage-deflate"; + assert_eq!( + DeflateSessionParameters::from_extension_header(extension), + vec![Ok(DeflateSessionParameters::default())] + ); + + let extension = "permessage-deflate; unknown_parameter"; + assert_eq!( + DeflateSessionParameters::from_extension_header(extension), + vec![Err(DeflateHandshakeError::UnknownWebSocketParameters)] + ); + + let extension = "permessage-deflate; client_max_window_bits=9; client_max_window_bits=10"; + assert_eq!( + DeflateSessionParameters::from_extension_header(extension), + vec![Err(DeflateHandshakeError::DuplicateParameter( + "client_max_window_bits" + ))] + ); + + let extension = "permessage-deflate; server_max_window_bits=8"; + assert_eq!( + DeflateSessionParameters::from_extension_header(extension), + vec![Err(DeflateHandshakeError::MaxWindowBitsOutOfRange)] + ); + + let extension = "permessage-deflate; server_max_window_bits=16"; + assert_eq!( + DeflateSessionParameters::from_extension_header(extension), + vec![Err(DeflateHandshakeError::MaxWindowBitsOutOfRange)] + ); + + let extension = "permessage-deflate; client_max_window_bits; server_max_window_bits=15; \ + client_no_context_takeover; server_no_context_takeover, \ + permessage-deflate; client_max_window_bits=10"; + assert_eq!( + DeflateSessionParameters::from_extension_header(extension), + vec![ + Ok(DeflateSessionParameters { + server_no_context_takeover: true, + client_no_context_takeover: true, + server_max_window_bits: Some(15), + client_max_window_bits: Some(ClientMaxWindowBits::NotSpecified) + }), + Ok(DeflateSessionParameters { + server_no_context_takeover: false, + client_no_context_takeover: false, + server_max_window_bits: None, + client_max_window_bits: Some(ClientMaxWindowBits::Specified(10)) + }) + ] + ); + } + + #[test] + fn test_compress() { + // With context takeover + + let mut compress = DeflateCompressionContext::new(None, false, 15); + assert_eq!( + compress + .compress(true, "Hello World".try_into_bytes().unwrap()) + .unwrap(), + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ); + assert_eq!( + compress + .compress(true, "Hello World".try_into_bytes().unwrap()) + .unwrap(), + Bytes::from_static(b"\xf2@0\x01\0") + ); + + // Without context takeover + + let mut compress = DeflateCompressionContext::new(None, true, 15); + assert_eq!( + compress + .compress(true, "Hello World".try_into_bytes().unwrap()) + .unwrap(), + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ); + assert_eq!( + compress + .compress(true, "Hello World".try_into_bytes().unwrap()) + .unwrap(), + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ); + + // With continuation + assert_eq!( + compress + .compress(false, "Hello World".try_into_bytes().unwrap()) + .unwrap(), + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ); + // Continuation keeps context. + assert_eq!( + compress + .compress(true, "Hello World".try_into_bytes().unwrap()) + .unwrap(), + Bytes::from_static(b"\xf2@0\x01\0") + ); + // after continuation, context resets + assert_eq!( + compress + .compress(true, "Hello World".try_into_bytes().unwrap()) + .unwrap(), + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ); + } + + #[test] + fn test_decompress() { + // With context takeover + + let mut decompress = DeflateDecompressionContext::new(false, 15); + + // Without RSV1 bit, decompression does not happen. + assert_eq!( + decompress + .decompress( + true, + OpCode::Text, + RsvBits::empty(), + Bytes::from_static(b"Hello World") + ) + .unwrap(), + Bytes::from_static(b"Hello World") + ); + + // Control frames (such as ping/pong) are not decompressed + assert_eq!( + decompress + .decompress( + true, + OpCode::Ping, + RsvBits::RSV1, + Bytes::from_static(b"Hello World") + ) + .unwrap(), + Bytes::from_static(b"Hello World") + ); + + // Successful decompression + assert_eq!( + decompress + .decompress( + true, + OpCode::Text, + RsvBits::RSV1, + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ) + .unwrap(), + Bytes::from_static(b"Hello World") + ); + + // Success subsequent decompression + assert_eq!( + decompress + .decompress( + true, + OpCode::Text, + RsvBits::RSV1, + Bytes::from_static(b"\xf2@0\x01\0") + ) + .unwrap(), + Bytes::from_static(b"Hello World") + ); + + // Invalid compression payload + assert!(decompress + .decompress( + true, + OpCode::Text, + RsvBits::RSV1, + Bytes::from_static(b"Hello World") + ) + .is_err()); + + // When there was error, context is reset. + assert_eq!( + decompress + .decompress( + true, + OpCode::Text, + RsvBits::RSV1, + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ) + .unwrap(), + Bytes::from_static(b"Hello World") + ); + + // Without context takeover + + let mut decompress = DeflateDecompressionContext::new(true, 15); + + // Successful decompression + assert_eq!( + decompress + .decompress( + true, + OpCode::Text, + RsvBits::RSV1, + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ) + .unwrap(), + Bytes::from_static(b"Hello World") + ); + + // Context has been reset. + assert_eq!( + decompress + .decompress( + true, + OpCode::Text, + RsvBits::RSV1, + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ) + .unwrap(), + Bytes::from_static(b"Hello World") + ); + + // With continuation + assert_eq!( + decompress + .decompress( + false, + OpCode::Text, + RsvBits::RSV1, + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ) + .unwrap(), + Bytes::from_static(b"Hello World") + ); + // Continuation keeps context. + assert_eq!( + decompress + .decompress( + true, + OpCode::Text, + RsvBits::RSV1, + Bytes::from_static(b"\xf2@0\x01\0") + ) + .unwrap(), + Bytes::from_static(b"Hello World") + ); + // When continuation has finished, context is reset. + assert_eq!( + decompress + .decompress( + false, + OpCode::Text, + RsvBits::RSV1, + Bytes::from_static(b"\xf2H\xcd\xc9\xc9W\x08\xcf/\xcaI\x01\0") + ) + .unwrap(), + Bytes::from_static(b"Hello World") + ); + } +} diff --git a/actix-http/src/ws/frame.rs b/actix-http/src/ws/frame.rs index 35b3f8e668e..0bd64a46522 100644 --- a/actix-http/src/ws/frame.rs +++ b/actix-http/src/ws/frame.rs @@ -5,7 +5,7 @@ use tracing::debug; use super::{ mask::apply_mask, - proto::{CloseCode, CloseReason, OpCode}, + proto::{CloseCode, CloseReason, OpCode, RsvBits}, ProtocolError, }; @@ -17,7 +17,7 @@ impl Parser { fn parse_metadata( src: &[u8], server: bool, - ) -> Result)>, ProtocolError> { + ) -> Result)>, ProtocolError> { let chunk_len = src.len(); let mut idx = 2; @@ -37,6 +37,9 @@ impl Parser { return Err(ProtocolError::MaskedFrame); } + // RSV bits + let rsv_bits = RsvBits::from_bits((first & 0x70) >> 4).unwrap_or(RsvBits::empty()); + // Op code let opcode = OpCode::from(first & 0x0F); @@ -79,7 +82,7 @@ impl Parser { None }; - Ok(Some((idx, finished, opcode, length, mask))) + Ok(Some((idx, finished, opcode, rsv_bits, length, mask))) } /// Parse the input stream into a frame. @@ -87,12 +90,13 @@ impl Parser { src: &mut BytesMut, server: bool, max_size: usize, - ) -> Result)>, ProtocolError> { + ) -> Result)>, ProtocolError> { // try to parse ws frame metadata - let (idx, finished, opcode, length, mask) = match Parser::parse_metadata(src, server)? { - None => return Ok(None), - Some(res) => res, - }; + let (idx, finished, opcode, rsv_bits, length, mask) = + match Parser::parse_metadata(src, server)? { + None => return Ok(None), + Some(res) => res, + }; // not enough data if src.len() < idx + length { @@ -115,7 +119,7 @@ impl Parser { // no need for body if length == 0 { - return Ok(Some((finished, opcode, None))); + return Ok(Some((finished, opcode, rsv_bits, None))); } let mut data = src.split_to(length); @@ -127,7 +131,7 @@ impl Parser { } OpCode::Close if length > 125 => { debug!("Received close frame with payload length exceeding 125. Morphing to protocol close frame."); - return Ok(Some((true, OpCode::Close, None))); + return Ok(Some((true, OpCode::Close, rsv_bits, None))); } _ => {} } @@ -137,7 +141,7 @@ impl Parser { apply_mask(&mut data, mask); } - Ok(Some((finished, opcode, Some(data)))) + Ok(Some((finished, opcode, rsv_bits, Some(data)))) } /// Parse the payload of a close frame. @@ -161,15 +165,15 @@ impl Parser { dst: &mut BytesMut, pl: B, op: OpCode, + rsv_bits: RsvBits, fin: bool, mask: bool, ) { let payload = pl.as_ref(); - let one: u8 = if fin { - 0x80 | Into::::into(op) - } else { - op.into() - }; + let fin_bits = if fin { 0x80 } else { 0x00 }; + let rsv_bits = rsv_bits.bits() << 4; + + let one: u8 = fin_bits | rsv_bits | Into::::into(op); let payload_len = payload.len(); let (two, p_len) = if mask { (0x80, payload_len + 4) @@ -203,7 +207,12 @@ impl Parser { /// Create a new Close control frame. #[inline] - pub fn write_close(dst: &mut BytesMut, reason: Option, mask: bool) { + pub fn write_close( + dst: &mut BytesMut, + reason: Option, + rsv_bits: RsvBits, + mask: bool, + ) { let payload = match reason { None => Vec::new(), Some(reason) => { @@ -215,7 +224,7 @@ impl Parser { } }; - Parser::write_message(dst, payload, OpCode::Close, true, mask) + Parser::write_message(dst, payload, OpCode::Close, rsv_bits, true, mask) } } @@ -228,18 +237,22 @@ mod tests { struct F { finished: bool, opcode: OpCode, + rsv_bits: RsvBits, payload: Bytes, } - fn is_none(frm: &Result)>, ProtocolError>) -> bool { + fn is_none( + frm: &Result)>, ProtocolError>, + ) -> bool { matches!(*frm, Ok(None)) } - fn extract(frm: Result)>, ProtocolError>) -> F { + fn extract(frm: Result)>, ProtocolError>) -> F { match frm { - Ok(Some((finished, opcode, payload))) => F { + Ok(Some((finished, opcode, rsv_bits, payload))) => F { finished, opcode, + rsv_bits, payload: payload .map(|b| b.freeze()) .unwrap_or_else(|| Bytes::from("")), @@ -260,6 +273,17 @@ mod tests { assert!(!frame.finished); assert_eq!(frame.opcode, OpCode::Text); assert_eq!(frame.payload.as_ref(), &b"1"[..]); + + let mut buf = BytesMut::from(&[0b1111_0001u8, 0b0000_0001u8][..]); + buf.extend(b"2"); + + let frame = extract(Parser::parse(&mut buf, false, 1024)); + assert!(frame.finished); + assert_eq!(frame.opcode, OpCode::Text); + assert_eq!(frame.payload.as_ref(), &b"2"[..]); + assert!(frame.rsv_bits.contains(RsvBits::RSV1)); + assert!(frame.rsv_bits.contains(RsvBits::RSV2)); + assert!(frame.rsv_bits.contains(RsvBits::RSV3)); } #[test] @@ -368,7 +392,14 @@ mod tests { #[test] fn test_ping_frame() { let mut buf = BytesMut::new(); - Parser::write_message(&mut buf, Vec::from("data"), OpCode::Ping, true, false); + Parser::write_message( + &mut buf, + Vec::from("data"), + OpCode::Ping, + RsvBits::empty(), + true, + false, + ); let mut v = vec![137u8, 4u8]; v.extend(b"data"); @@ -378,7 +409,14 @@ mod tests { #[test] fn test_pong_frame() { let mut buf = BytesMut::new(); - Parser::write_message(&mut buf, Vec::from("data"), OpCode::Pong, true, false); + Parser::write_message( + &mut buf, + Vec::from("data"), + OpCode::Pong, + RsvBits::empty(), + true, + false, + ); let mut v = vec![138u8, 4u8]; v.extend(b"data"); @@ -389,7 +427,7 @@ mod tests { fn test_close_frame() { let mut buf = BytesMut::new(); let reason = (CloseCode::Normal, "data"); - Parser::write_close(&mut buf, Some(reason.into()), false); + Parser::write_close(&mut buf, Some(reason.into()), RsvBits::empty(), false); let mut v = vec![136u8, 6u8, 3u8, 232u8]; v.extend(b"data"); @@ -399,7 +437,7 @@ mod tests { #[test] fn test_empty_close_frame() { let mut buf = BytesMut::new(); - Parser::write_close(&mut buf, None, false); + Parser::write_close(&mut buf, None, RsvBits::empty(), false); assert_eq!(&buf[..], &vec![0x88, 0x00][..]); } } diff --git a/actix-http/src/ws/mod.rs b/actix-http/src/ws/mod.rs index 88053b254d5..f08012573b6 100644 --- a/actix-http/src/ws/mod.rs +++ b/actix-http/src/ws/mod.rs @@ -11,16 +11,20 @@ use http::{header, Method, StatusCode}; use crate::{body::BoxBody, header::HeaderValue, RequestHead, Response, ResponseBuilder}; mod codec; +#[cfg(feature = "compress-ws-deflate")] +mod deflate; mod dispatcher; mod frame; mod mask; mod proto; +#[cfg(feature = "compress-ws-deflate")] +pub use self::deflate::{DeflateCompressionLevel, DeflateServerConfig, DeflateSessionParameters}; pub use self::{ - codec::{Codec, Frame, Item, Message}, + codec::{Codec, Decoder, Encoder, Frame, Item, Message}, dispatcher::Dispatcher, frame::Parser, - proto::{hash_key, CloseCode, CloseReason, OpCode}, + proto::{hash_key, CloseCode, CloseReason, OpCode, RsvBits}, }; /// WebSocket protocol errors. @@ -93,6 +97,11 @@ pub enum HandshakeError { /// WebSocket key is not set or wrong. #[display("unknown WebSocket key")] BadWebsocketKey, + + /// Invalid `permessage-deflate` request. + #[cfg(feature = "compress-ws-deflate")] + #[display("invalid WebSocket `permessage-deflate` extension request")] + BadDeflateRequest(deflate::DeflateHandshakeError), } impl From for Response { @@ -135,6 +144,13 @@ impl From for Response { res.head_mut().reason = Some("Handshake error"); res } + + #[cfg(feature = "compress-ws-deflate")] + HandshakeError::BadDeflateRequest(_) => { + let mut res = Response::bad_request(); + res.head_mut().reason = Some("Invalid permessage-deflate request"); + res + } } } } @@ -151,6 +167,69 @@ pub fn handshake(req: &RequestHead) -> Result { Ok(handshake_response(req)) } +/// Verify WebSocket handshake request with DEFLATE compression configurations. +#[cfg(feature = "compress-ws-deflate")] +pub fn handshake_deflate( + config: &deflate::DeflateServerConfig, + req: &RequestHead, +) -> Result< + ( + ResponseBuilder, + Option<( + deflate::DeflateCompressionContext, + deflate::DeflateDecompressionContext, + )>, + ), + HandshakeError, +> { + verify_handshake(req)?; + + let mut available_configurations = vec![]; + for header in req.headers().get_all(header::SEC_WEBSOCKET_EXTENSIONS) { + let Ok(header_str) = header.to_str() else { + continue; + }; + + available_configurations.extend(deflate::DeflateSessionParameters::from_extension_header( + header_str, + )); + } + + let mut selected_config = None; + let mut selected_error = None; + for config in available_configurations { + match config { + Ok(config) => { + selected_config = Some(config); + break; + } + Err(err) => { + if selected_error.is_none() { + selected_error = Some(err); + } else { + selected_error = + Some(deflate::DeflateHandshakeError::NoSuitableConfigurationFound); + } + } + } + } + + if let Some(selected_error) = selected_error { + Err(HandshakeError::BadDeflateRequest(selected_error)) + } else { + let mut response = handshake_response(req); + + if let Some(selected_config) = selected_config { + let param = config.negotiate(selected_config); + let contexts = param.create_context(config.compression_level, false); + response.insert_header(param); + Ok((response, Some(contexts))) + } else { + Ok((response, None)) + } + } +} + /// Verify WebSocket handshake request. pub fn verify_handshake(req: &RequestHead) -> Result<(), HandshakeError> { // WebSocket accepts only GET @@ -196,6 +275,7 @@ pub fn verify_handshake(req: &RequestHead) -> Result<(), HandshakeError> { if !req.headers().contains_key(header::SEC_WEBSOCKET_KEY) { return Err(HandshakeError::BadWebsocketKey); } + Ok(()) } diff --git a/actix-http/src/ws/proto.rs b/actix-http/src/ws/proto.rs index 27815eaf248..1bdfcf8f7f6 100644 --- a/actix-http/src/ws/proto.rs +++ b/actix-http/src/ws/proto.rs @@ -222,6 +222,25 @@ impl> From<(CloseCode, T)> for CloseReason { } } +bitflags::bitflags! { + /// RSV bits defined in [RFC 6455 §5.2]. + /// Reserved for extensions and should be set to zero if no extensions are applicable. + /// + /// [RFC 6455 §5.2]: https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 + #[derive(Debug, Eq, PartialEq, Clone, Copy)] + pub struct RsvBits: u8 { + const RSV1 = 0b0000_0100; + const RSV2 = 0b0000_0010; + const RSV3 = 0b0000_0001; + } +} + +impl Default for RsvBits { + fn default() -> Self { + Self::empty() + } +} + /// The WebSocket GUID as stated in the spec. /// See . static WS_GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; diff --git a/awc/Cargo.toml b/awc/Cargo.toml index c09f32ac862..4ef3c52397d 100644 --- a/awc/Cargo.toml +++ b/awc/Cargo.toml @@ -79,6 +79,8 @@ compress-brotli = ["actix-http/compress-brotli", "__compress"] compress-gzip = ["actix-http/compress-gzip", "__compress"] # Zstd algorithm content-encoding support compress-zstd = ["actix-http/compress-zstd", "__compress"] +# Deflate compression for WebSocket +compress-ws-deflate = ["actix-http/compress-ws-deflate"] # Cookie parsing and cookie jar cookies = ["dep:cookie"] @@ -112,7 +114,7 @@ futures-util = { version = "0.3.17", default-features = false, features = ["allo h2 = "0.3.26" http = "0.2.7" itoa = "1" -log =" 0.4" +log = "0.4" mime = "0.3" percent-encoding = "2.1" pin-project-lite = "0.2" diff --git a/awc/src/ws.rs b/awc/src/ws.rs index 760331e9d6a..507bcd3c484 100644 --- a/awc/src/ws.rs +++ b/awc/src/ws.rs @@ -30,6 +30,8 @@ use std::{fmt, net::SocketAddr, str}; use actix_codec::Framed; pub use actix_http::ws::{CloseCode, CloseReason, Codec, Frame, Message}; +#[cfg(feature = "compress-ws-deflate")] +pub use actix_http::ws::{DeflateCompressionLevel, DeflateSessionParameters}; use actix_http::{ws, Payload, RequestHead}; use actix_rt::time::timeout; use actix_service::Service as _; @@ -59,6 +61,9 @@ pub struct WebsocketsRequest { server_mode: bool, config: ClientConfig, + #[cfg(feature = "compress-ws-deflate")] + deflate_compression_level: Option, + #[cfg(feature = "cookies")] cookies: Option, } @@ -94,6 +99,8 @@ impl WebsocketsRequest { protocols: None, max_size: 65_536, server_mode: false, + #[cfg(feature = "compress-ws-deflate")] + deflate_compression_level: None, #[cfg(feature = "cookies")] cookies: None, } @@ -249,6 +256,22 @@ impl WebsocketsRequest { self.header(AUTHORIZATION, format!("Bearer {}", token)) } + /// Enable DEFLATE compression + #[cfg(feature = "compress-ws-deflate")] + pub fn deflate( + mut self, + compression_level: Option, + params: DeflateSessionParameters, + ) -> Self { + use actix_http::header::TryIntoHeaderPair; + // Assume session parameters are always valid. + let (key, value) = params.try_into_pair().unwrap(); + + self.deflate_compression_level = compression_level; + + self.header(key, value) + } + /// Complete request construction and connect to a WebSocket server. pub async fn connect( mut self, @@ -409,17 +432,52 @@ impl WebsocketsRequest { return Err(WsClientError::MissingWebSocketAcceptHeader); }; - // response and ws framed - Ok(( - ClientResponse::new(head, Payload::None), - framed.into_map_codec(|_| { + #[cfg(feature = "compress-ws-deflate")] + let framed = { + let selected_parameter = head + .headers + .get_all(header::SEC_WEBSOCKET_EXTENSIONS) + .filter_map(|header| { + if let Ok(header_str) = header.to_str() { + Some(DeflateSessionParameters::from_extension_header(header_str)) + } else { + None + } + }) + .flatten() + .filter_map(Result::ok) + .next(); + + framed.into_map_codec(move |_| { + let codec = if let Some(parameter) = selected_parameter.clone() { + let (compress, decompress) = + parameter.create_context(self.deflate_compression_level, false); + Codec::new_deflate(compress, decompress) + } else { + Codec::new() + } + .max_size(max_size); + if server_mode { - ws::Codec::new().max_size(max_size) + codec } else { - ws::Codec::new().max_size(max_size).client_mode() + codec.client_mode() } - }), - )) + }) + }; + #[cfg(not(feature = "compress-ws-deflate"))] + let framed = framed.into_map_codec(move |_| { + let codec = Codec::new().max_size(max_size); + + if server_mode { + codec + } else { + codec.client_mode() + } + }); + + // response and ws framed + Ok((ClientResponse::new(head, Payload::None), framed)) } }