Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add zstd support for compression middleware #258

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 29 additions & 9 deletions compress.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ import (
"compress/gzip"
"io"
"net/http"
"slices"
"strings"

"github.com/felixge/httpsnoop"
"github.com/klauspost/compress/zstd"
)

const acceptEncoding string = "Accept-Encoding"
Expand Down Expand Up @@ -55,8 +57,9 @@ func (cw *compressResponseWriter) Flush() {
}
}

// CompressHandler gzip compresses HTTP responses for clients that support it
// via the 'Accept-Encoding' header.
// CompressHandler zstd compresses HTTP responses for clients that support it
// via the 'Accept-Encoding' header. If zstd is not supported, it will fall back
// to gzip or flate compression.
//
// Compressing TLS traffic may leak the page contents to an attacker if the
// page contains user input: http://security.stackexchange.com/a/102015/12208
Expand All @@ -78,17 +81,24 @@ func CompressHandlerLevel(h http.Handler, level int) http.Handler {
const (
gzipEncoding = "gzip"
flateEncoding = "deflate"
zstdEncoding = "zstd"
)

return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// detect what encoding to use
var encoding string
for _, curEnc := range strings.Split(r.Header.Get(acceptEncoding), ",") {
curEnc = strings.TrimSpace(curEnc)
if curEnc == gzipEncoding || curEnc == flateEncoding {
encoding = curEnc
break
}
encodings := strings.Split(r.Header.Get(acceptEncoding), ",")
for i := range encodings {
encodings[i] = strings.TrimSpace(encodings[i])
}

// zstd > gzip > flate
if slices.Contains(encodings, zstdEncoding) {
encoding = zstdEncoding
} else if slices.Contains(encodings, gzipEncoding) {
encoding = gzipEncoding
} else if slices.Contains(encodings, flateEncoding) {
encoding = flateEncoding
}

// always add Accept-Encoding to Vary to prevent intermediate caches corruption
Expand All @@ -108,7 +118,17 @@ func CompressHandlerLevel(h http.Handler, level int) http.Handler {

// wrap the ResponseWriter with the writer for the chosen encoding
var encWriter io.WriteCloser
if encoding == gzipEncoding {
if encoding == zstdEncoding {
// Map gzip compression levels to zstd compression levels
zstdSpeed := zstd.SpeedDefault
if level == gzip.BestSpeed {
zstdSpeed = zstd.SpeedFastest
} else if level == gzip.BestCompression {
zstdSpeed = zstd.SpeedBestCompression
}

encWriter, _ = zstd.NewWriter(w, zstd.WithEncoderLevel(zstd.EncoderLevel(zstdSpeed)))
} else if encoding == gzipEncoding {
encWriter, _ = gzip.NewWriterLevel(w, level)
} else if encoding == flateEncoding {
encWriter, _ = flate.NewWriter(w, level)
Expand Down
30 changes: 30 additions & 0 deletions compress_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,24 @@ func TestAcceptEncodingIsDropped(t *testing.T) {
}
}

func TestCompressHandlerZstd(t *testing.T) {
w := httptest.NewRecorder()
compressedRequest(w, "zstd")
resp := w.Result()
if resp.Header.Get("Content-Encoding") != "zstd" {
t.Fatalf("wrong content encoding, got %q want %q", resp.Header.Get("Content-Encoding"), "zstd")
}
if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" {
t.Fatalf("wrong content type, got %s want %s", resp.Header.Get("Content-Type"), "text/plain; charset=utf-8")
}
if w.Body.Len() != 32 {
t.Fatalf("wrong len, got %d want %d", w.Body.Len(), 32)
}
if l := resp.Header.Get("Content-Length"); l != "" {
t.Fatalf("wrong content-length. got %q expected %q", l, "")
}
}

func TestCompressHandlerGzip(t *testing.T) {
w := httptest.NewRecorder()
compressedRequest(w, "gzip")
Expand Down Expand Up @@ -171,6 +189,18 @@ func TestCompressHandlerGzipDeflate(t *testing.T) {
}
}

func TestCompressHandlerGzipDeflateBrZstd(t *testing.T) {
w := httptest.NewRecorder()
compressedRequest(w, "gzip, deflate, br, zstd")
resp := w.Result()
if resp.Header.Get("Content-Encoding") != "zstd" {
t.Fatalf("wrong content encoding, got %q want %q", resp.Header.Get("Content-Encoding"), "zstd")
}
if resp.Header.Get("Content-Type") != "text/plain; charset=utf-8" {
t.Fatalf("wrong content type, got %s want %s", resp.Header.Get("Content-Type"), "text/plain; charset=utf-8")
}
}

// Make sure we can compress and serve an *os.File properly. We need
// to use a real http server to trigger the net/http sendfile special
// case.
Expand Down
4 changes: 3 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
module github.com/gorilla/handlers

go 1.20
go 1.21

require github.com/felixge/httpsnoop v1.0.3

require github.com/klauspost/compress v1.17.11
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk=
github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U=
github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc=
github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0=
2 changes: 1 addition & 1 deletion logging_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func TestLogUser(t *testing.T) {
func BenchmarkWriteLog(b *testing.B) {
loc, err := time.LoadLocation("Europe/Warsaw")
if err != nil {
b.Fatalf(err.Error())
b.Fatal(err.Error())
}
ts := time.Date(1983, 0o5, 26, 3, 30, 45, 0, loc)

Expand Down
Loading