diff --git a/src/extension/deflate.rs b/src/extension/deflate.rs index 259c9083..96c56363 100644 --- a/src/extension/deflate.rs +++ b/src/extension/deflate.rs @@ -31,6 +31,10 @@ const SERVER_MAX_WINDOW_BITS: &str = "server_max_window_bits"; const CLIENT_NO_CONTEXT_TAKEOVER: &str = "client_no_context_takeover"; const CLIENT_MAX_WINDOW_BITS: &str = "client_max_window_bits"; +const DEFAULT_GROWTH: usize = 4096; +const DEFAULT_DECOMPRESS_SIZE: usize = 256 * 1024 * 1024; +const TRAILER: [u8; 4] = [0, 0, 0xff, 0xff]; + /// The deflate extension type. /// #[derive(Debug)] @@ -39,9 +43,12 @@ pub struct Deflate { enabled: bool, buffer: Vec, params: Vec>, + zlib_compression_level: Compression, our_max_window_bits: u8, their_max_window_bits: u8, await_last_fragment: bool, + max_buffer_size: usize, + grow_buffer_size: usize, decoder_no_context_takeover: bool, decoder: Decompress, encoder_no_context_takeover: bool, @@ -64,9 +71,12 @@ impl Deflate { enabled: false, buffer: Vec::new(), params, + zlib_compression_level: Compression::fast(), our_max_window_bits: 15, their_max_window_bits: 15, await_last_fragment: false, + max_buffer_size: DEFAULT_DECOMPRESS_SIZE, + grow_buffer_size: DEFAULT_GROWTH, decoder_no_context_takeover: false, decoder: Decompress::new(false), encoder_no_context_takeover: false, @@ -149,6 +159,33 @@ impl Deflate { } } + /// Set the maximum size of the internal buffer used for decompression. + /// + /// Messages that decompress to a size larger than this will fail to decode. + /// + /// The default size is 256 MiB. + pub fn set_max_buffer_size(&mut self, size: usize) { + self.max_buffer_size = size; + } + + /// Set the size by which the internal buffer grows when it runs out of space. + /// The underlying Rust implementation may grow by more than this. + /// + /// The default grow size is 4096. + pub fn set_grow_buffer_size(&mut self, size: usize) { + self.grow_buffer_size = size; + } + + /// Set the zlib compression level to use. The range is from 0 (no compression) to 9 (best compression). + /// + /// The default is 1 (fastest compression). + pub fn set_compression_level(&mut self, level: u32) { + self.zlib_compression_level = match level { + 0..=9 => Compression::new(level), + _ => panic!("invalid compression level: {}", level), + }; + } + fn set_their_max_window_bits(&mut self, p: &Param, expected: Option) -> Result<(), ()> { if let Some(Ok(v)) = p.value().map(|s| s.parse::()) { if v < 8 || v > 15 { @@ -323,7 +360,7 @@ impl Extension for Deflate { // Restore LEN and NLEN: data.extend_from_slice(&[0, 0, 0xff, 0xff]); // cf. RFC 7692, 7.2.2 - let buffer_block_size: usize = (2 << (self.their_max_window_bits - 3)) + data.len(); + let buffer_block_size: usize = self.grow_buffer_size + data.len(); self.buffer.clear(); // Guess at an initial buffer size needed. self.buffer.reserve(buffer_block_size); @@ -348,13 +385,18 @@ impl Extension for Deflate { } } + if self.buffer.len() >= self.max_buffer_size { + return Err( + io::Error::new(io::ErrorKind::Other, "decompressed message too large").into() + ); + } if decoder.total_in() == t_in && decoder.total_out() == t_out { return Err(io::Error::new(io::ErrorKind::Other, "decompression stalled").into()); } if decoder.total_in() > t_in { input = &input[(decoder.total_in() - t_in) as usize..]; } - self.buffer.reserve(buffer_block_size); + self.buffer.reserve(self.grow_buffer_size); } log::trace!( "decompressed {}->{} bytes, total {}/{} ratio {:.2}", @@ -393,7 +435,11 @@ impl Extension for Deflate { if self.encoder.is_none() { self.encoder = Some( - Compress::new_with_window_bits(Compression::fast(), false, self.our_max_window_bits) + Compress::new_with_window_bits( + self.zlib_compression_level, + false, + self.our_max_window_bits + ) ); } let encoder = self.encoder.as_mut().unwrap(); @@ -424,11 +470,11 @@ impl Extension for Deflate { if encoder.total_in() > t_in { input = &input[(encoder.total_in() - t_in) as usize..]; } - self.buffer.reserve(4096); + self.buffer.reserve(self.grow_buffer_size); } // We need to append an empty deflate block if not there yet (RFC 7692, 7.2.1). - while !self.buffer.ends_with(&[0, 0, 0xff, 0xff]) { + while !self.buffer.ends_with(&TRAILER) { self.buffer.reserve(5); // Make sure there is room for the trailing end bytes. match encoder.compress_vec(&[], &mut self.buffer, FlushCompress::Sync)? { Status::Ok => { @@ -444,8 +490,8 @@ impl Extension for Deflate { } // If we still have not seen the empty deflate block appended, something is wrong. - if !self.buffer.ends_with(&[0, 0, 0xff, 0xff]) { - log::error!("missing 00 00 FF FF"); + if !self.buffer.ends_with(&TRAILER) { + log::error!("missing trailer {:?}", TRAILER); return Err(io::Error::new(io::ErrorKind::Other, "missing 00 00 FF FF").into()); }