From 7bab199c2bf6c9625f44f8770d5c2166a1194c0b Mon Sep 17 00:00:00 2001 From: Bryan Gillespie Date: Sat, 26 Oct 2024 21:31:53 -0600 Subject: [PATCH 1/2] Prefer zstd for compressions over gzip when possible - Add zstd support to compress middleware with github.com/klauspost/compress/zstd - Prefer zstd over gzip or flate when possible to leverage its superior compression ratios and speed - Add unit tests for zstd --- compress.go | 38 +++++++++++++++++++++++++++++--------- compress_test.go | 30 ++++++++++++++++++++++++++++++ go.mod | 4 +++- go.sum | 2 ++ 4 files changed, 64 insertions(+), 10 deletions(-) diff --git a/compress.go b/compress.go index d6f5895..1e986b7 100644 --- a/compress.go +++ b/compress.go @@ -9,9 +9,11 @@ import ( "compress/gzip" "io" "net/http" + "slices" "strings" "github.com/felixge/httpsnoop" + "github.com/klauspost/compress/zstd" ) const acceptEncoding string = "Accept-Encoding" @@ -55,8 +57,9 @@ func (cw *compressResponseWriter) Flush() { } } -// CompressHandler gzip compresses HTTP responses for clients that support it -// via the 'Accept-Encoding' header. +// CompressHandler zstd compresses HTTP responses for clients that support it +// via the 'Accept-Encoding' header. If zstd is not supported, it will fall back +// to gzip or flate compression. // // Compressing TLS traffic may leak the page contents to an attacker if the // page contains user input: http://security.stackexchange.com/a/102015/12208 @@ -78,17 +81,24 @@ func CompressHandlerLevel(h http.Handler, level int) http.Handler { const ( gzipEncoding = "gzip" flateEncoding = "deflate" + zstdEncoding = "zstd" ) return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // detect what encoding to use var encoding string - for _, curEnc := range strings.Split(r.Header.Get(acceptEncoding), ",") { - curEnc = strings.TrimSpace(curEnc) - if curEnc == gzipEncoding || curEnc == flateEncoding { - encoding = curEnc - break - } + encodings := strings.Split(r.Header.Get(acceptEncoding), ",") + for i := range encodings { + encodings[i] = strings.TrimSpace(encodings[i]) + } + + // zstd > gzip > flate + if slices.Contains(encodings, zstdEncoding) { + encoding = zstdEncoding + } else if slices.Contains(encodings, gzipEncoding) { + encoding = gzipEncoding + } else if slices.Contains(encodings, flateEncoding) { + encoding = flateEncoding } // always add Accept-Encoding to Vary to prevent intermediate caches corruption @@ -108,7 +118,17 @@ func CompressHandlerLevel(h http.Handler, level int) http.Handler { // wrap the ResponseWriter with the writer for the chosen encoding var encWriter io.WriteCloser - if encoding == gzipEncoding { + if encoding == zstdEncoding { + // Map gzip compression levels to zstd compression levels + zstdSpeed := zstd.SpeedDefault + if level == gzip.BestSpeed { + zstdSpeed = zstd.SpeedFastest + } else if level == gzip.BestCompression { + zstdSpeed = zstd.SpeedBestCompression + } + + encWriter, _ = zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.EncoderLevel(zstdSpeed))) + } else if encoding == gzipEncoding { encWriter, _ = gzip.NewWriterLevel(w, level) } else if encoding == flateEncoding { encWriter, _ = flate.NewWriter(w, level) diff --git a/compress_test.go b/compress_test.go index 25fdbe9..3a97d1a 100644 --- a/compress_test.go +++ b/compress_test.go @@ -126,6 +126,24 @@ func TestAcceptEncodingIsDropped(t *testing.T) { } } +func TestCompressHandlerZstd(t *testing.T) { + w := httptest.NewRecorder() + compressedRequest(w, "zstd") + resp := w.Result() + if resp.Header.Get("Content-Encoding") != "zstd" { + t.Fatalf("wrong content encoding, got %q want %q", resp.Header.Get("Content-Encoding"), "zstd") + } + if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" { + t.Fatalf("wrong content type, got %s want %s", resp.Header.Get("Content-Type"), "text/plain; charset=utf-8") + } + if w.Body.Len() != 32 { + t.Fatalf("wrong len, got %d want %d", w.Body.Len(), 32) + } + if l := resp.Header.Get("Content-Length"); l != "" { + t.Fatalf("wrong content-length. got %q expected %q", l, "") + } +} + func TestCompressHandlerGzip(t *testing.T) { w := httptest.NewRecorder() compressedRequest(w, "gzip") @@ -171,6 +189,18 @@ func TestCompressHandlerGzipDeflate(t *testing.T) { } } +func TestCompressHandlerGzipDeflateBrZstd(t *testing.T) { + w := httptest.NewRecorder() + compressedRequest(w, "gzip, deflate, br, zstd") + resp := w.Result() + if resp.Header.Get("Content-Encoding") != "zstd" { + t.Fatalf("wrong content encoding, got %q want %q", resp.Header.Get("Content-Encoding"), "zstd") + } + if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" { + t.Fatalf("wrong content type, got %s want %s", resp.Header.Get("Content-Type"), "text/plain; charset=utf-8") + } +} + // Make sure we can compress and serve an *os.File properly. We need // to use a real http server to trigger the net/http sendfile special // case. diff --git a/go.mod b/go.mod index c9558d5..0406a66 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,7 @@ module github.com/gorilla/handlers -go 1.20 +go 1.21 require github.com/felixge/httpsnoop v1.0.3 + +require github.com/klauspost/compress v1.17.11 diff --git a/go.sum b/go.sum index e05d02e..eb3dd75 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,4 @@ github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk= github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= From f9045ba40ea1f3fe79653fa27249324f64327de7 Mon Sep 17 00:00:00 2001 From: Bryan Gillespie Date: Sat, 26 Oct 2024 21:42:41 -0600 Subject: [PATCH 2/2] Fix govet failure flagged by `make verify` --- logging_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/logging_test.go b/logging_test.go index d33847a..0e9fe91 100644 --- a/logging_test.go +++ b/logging_test.go @@ -148,7 +148,7 @@ func TestLogUser(t *testing.T) { func BenchmarkWriteLog(b *testing.B) { loc, err := time.LoadLocation("Europe/Warsaw") if err != nil { - b.Fatalf(err.Error()) + b.Fatal(err.Error()) } ts := time.Date(1983, 0o5, 26, 3, 30, 45, 0, loc)