From f68c9b6b4eb797594e51629db07c32c927672826 Mon Sep 17 00:00:00 2001 From: Clint Ayres Date: Tue, 6 Feb 2024 12:37:32 +0000 Subject: [PATCH 1/7] update relay to support compressed data --- catcher/catcher.go | 3 +- relay/main/main.go | 6 +- .../content-blocker-plugin.go | 18 ++-- .../traffic/paths-plugin/paths-plugin_test.go | 2 +- relay/traffic/encoding.go | 92 +++++++++++++++++++ relay/traffic/handler.go | 61 +++++++++++- relay/traffic/traffic_test.go | 6 +- 7 files changed, 167 insertions(+), 21 deletions(-) create mode 100644 relay/traffic/encoding.go diff --git a/catcher/catcher.go b/catcher/catcher.go index 0de061a..708c5d8 100644 --- a/catcher/catcher.go +++ b/catcher/catcher.go @@ -6,7 +6,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "log" "net" "net/http" @@ -82,7 +81,7 @@ func (service *Service) LastRequestBody() ([]byte, error) { } defer request.Body.Close() - body, err := ioutil.ReadAll(request.Body) + body, err := io.ReadAll(request.Body) if err != nil { return nil, err } diff --git a/relay/main/main.go b/relay/main/main.go index 95c74aa..ea545e9 100644 --- a/relay/main/main.go +++ b/relay/main/main.go @@ -2,7 +2,7 @@ package main import ( "flag" - "io/ioutil" + "io" "log" "os" "time" @@ -10,14 +10,14 @@ import ( "github.com/fullstorydev/relay-core/relay" "github.com/fullstorydev/relay-core/relay/config" "github.com/fullstorydev/relay-core/relay/environment" - "github.com/fullstorydev/relay-core/relay/traffic/plugin-loader" + plugin_loader "github.com/fullstorydev/relay-core/relay/traffic/plugin-loader" ) var logger = log.New(os.Stdout, "[relay] ", 0) func readConfigFile(path string) (rawConfigFileBytes []byte, err error) { if path == "-" { - rawConfigFileBytes, err = ioutil.ReadAll(os.Stdin) + rawConfigFileBytes, err = io.ReadAll(os.Stdin) return } diff --git a/relay/plugins/traffic/content-blocker-plugin/content-blocker-plugin.go b/relay/plugins/traffic/content-blocker-plugin/content-blocker-plugin.go index bdcda4f..a0b037a 100644 --- a/relay/plugins/traffic/content-blocker-plugin/content-blocker-plugin.go +++ b/relay/plugins/traffic/content-blocker-plugin/content-blocker-plugin.go @@ -25,7 +25,7 @@ package content_blocker_plugin import ( "bytes" "fmt" - "io/ioutil" + "io" "log" "net/http" "os" @@ -78,7 +78,7 @@ func (f contentBlockerPluginFactory) New(configSection *config.Section) (traffic } if regexp, err := regexp.Compile(pattern); err != nil { - return fmt.Errorf(`Could not compile regular expression "%v": %v`, pattern, err) + return fmt.Errorf(`could not compile regular expression "%v": %v`, pattern, err) } else { logger.Printf("Added rule: %s %s content matching \"%s\"", mode, contentKind, regexp) blockers = append(blockers, &contentBlocker{ @@ -94,7 +94,7 @@ func (f contentBlockerPluginFactory) New(configSection *config.Section) (traffic case "header": plugin.headerBlockers = append(plugin.headerBlockers, blockers...) default: - return fmt.Errorf(`Unexpected content kind %s`, contentKind) + return fmt.Errorf(`unexpected content kind %s`, contentKind) } return nil @@ -222,13 +222,12 @@ func (plug contentBlockerPlugin) blockBodyContent(response http.ResponseWriter, return false } - processedBody, err := ioutil.ReadAll(request.Body) + processedBody, err := io.ReadAll(request.Body) if err != nil { http.Error(response, fmt.Sprintf("Error reading request body: %s", err), 500) request.Body = http.NoBody return true } - initialLength := len(processedBody) for _, blocker := range plug.bodyBlockers { processedBody = blocker.Block(processedBody) @@ -236,14 +235,13 @@ func (plug contentBlockerPlugin) blockBodyContent(response http.ResponseWriter, // If the length of the body has changed, we should update the // Content-Length header too. - finalLength := len(processedBody) - if finalLength != initialLength { - contentLength := int64(finalLength) + contentLength := int64(len(processedBody)) + if contentLength != request.ContentLength { request.ContentLength = contentLength request.Header.Set("Content-Length", strconv.FormatInt(contentLength, 10)) } - request.Body = ioutil.NopCloser(bytes.NewBuffer(processedBody)) + request.Body = io.NopCloser(bytes.NewBuffer(processedBody)) return false } @@ -283,7 +281,7 @@ func (b *contentBlocker) Block(content []byte) []byte { case excludeMode: return b.regexp.ReplaceAllLiteral(content, []byte{}) default: - panic(fmt.Errorf("Invalid content blocking mode: %v", b.mode)) + panic(fmt.Errorf("invalid content blocking mode: %v", b.mode)) } } diff --git a/relay/plugins/traffic/paths-plugin/paths-plugin_test.go b/relay/plugins/traffic/paths-plugin/paths-plugin_test.go index ffd7e0c..4138516 100644 --- a/relay/plugins/traffic/paths-plugin/paths-plugin_test.go +++ b/relay/plugins/traffic/paths-plugin/paths-plugin_test.go @@ -279,7 +279,7 @@ func runPathsPluginTest(t *testing.T, testCase pathsPluginTestCase) { lastRequest, err = altCatcherService.LastRequest() } if err != nil { - t.Errorf("Error reading last request from catcher: %v", err) + t.Errorf("Text '%v': Error reading last request from catcher: %v", testCase.desc, err) return } diff --git a/relay/traffic/encoding.go b/relay/traffic/encoding.go new file mode 100644 index 0000000..ed98a57 --- /dev/null +++ b/relay/traffic/encoding.go @@ -0,0 +1,92 @@ +package traffic + +import ( + "bytes" + "compress/gzip" + "io" + "net/http" + "net/url" + "strings" +) + +func GetContentEncoding(request *http.Request) (string, error) { + // NOTE: This is a workaround for a bug in post-Go 1.17. See golang.org/issue/25192. + // Our algorithm differs from the logic of AllowQuerySemicolons by replacing semicolons with encoded semicolons instead + // of with ampersands. This is because we want to preserve the original query string as much as possible. + if strings.Contains(request.URL.RawQuery, ";") { + request.URL.RawQuery = strings.ReplaceAll(request.URL.RawQuery, ";", "%3B") // Replace semicolons with encoded semicolons. + } + + queryParams, err := url.ParseQuery(request.URL.RawQuery) + if err != nil { + return "", err + } + + // request query parameter takes precedence over request header + encoding := queryParams.Get("ContentEncoding") + if encoding == "" { + encoding = request.Header.Get("Content-Encoding") + } + return encoding, nil +} + +// WrapReader checks if the request Content-Encoding or request query parameter indicates gzip compression. +// If so, it returns a gzip.Reader that decompresses the content. +func WrapReader(request *http.Request, encoding string) (io.ReadCloser, error) { + if request.Body == nil { + return nil, nil + } + + switch encoding { + case "gzip": + // Create a new gzip.Reader to decompress the request body + return gzip.NewReader(request.Body) + default: + // If the content is not gzip-compressed, return the original request body + return request.Body, nil + } +} + +func EncodeData(data []byte, encoding string) ([]byte, error) { + switch encoding { + case "gzip": + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + + _, err := gz.Write(data) + if err != nil { + return nil, err + } + + err = gz.Close() + if err != nil { + return nil, err + } + + compressedData := buf.Bytes() + return compressedData, nil + default: + // identity encoding + return data, nil + } +} + +func DecodeData(data []byte, encoding string) ([]byte, error) { + switch encoding { + case "gzip": + reader, err := gzip.NewReader(bytes.NewReader(data)) + if err != nil { + return nil, err + } + + decodedData, err := io.ReadAll(reader) + if err != nil { + return nil, err + } + + return decodedData, nil + default: + // identity encoding + return data, nil + } +} diff --git a/relay/traffic/handler.go b/relay/traffic/handler.go index 0a45466..cdf6b66 100644 --- a/relay/traffic/handler.go +++ b/relay/traffic/handler.go @@ -10,6 +10,7 @@ import ( "net" "net/http" "os" + "strconv" "strings" "time" @@ -57,6 +58,19 @@ func (handler *Handler) ServeHTTP(response http.ResponseWriter, request *http.Re request.URL.Host = handler.config.TargetHost request.Host = handler.config.TargetHost + encoding, err := GetContentEncoding(request) + if err != nil { + logger.Printf("URL %v error getting request content encoding: %v", request.URL, err) + request.Body = http.NoBody + return + } + + if err := handler.prepareRequestBody(request, encoding); err != nil { + http.Error(response, fmt.Sprintf("Error setting up clientRequest body reader: %s", err), 500) + request.Body = http.NoBody + return + } + serviced := false for _, trafficPlugin := range handler.plugins { if trafficPlugin.HandleRequest(response, request, RequestInfo{ @@ -68,7 +82,7 @@ func (handler *Handler) ServeHTTP(response http.ResponseWriter, request *http.Re } } - if handler.HandleRequest(response, request, serviced) { + if handler.HandleRequest(response, request, serviced, encoding) { serviced = true } @@ -80,7 +94,17 @@ func (handler *Handler) ServeHTTP(response http.ResponseWriter, request *http.Re } } -func (handler *Handler) HandleRequest(clientResponse http.ResponseWriter, clientRequest *http.Request, serviced bool) bool { +// prepareRequestBody wraps the request Body with a reader that will decode the content if necessary. +func (handler *Handler) prepareRequestBody(clientRequest *http.Request, encoding string) error { + if reader, err := WrapReader(clientRequest, encoding); err != nil { + return err + } else if reader != nil && reader != http.NoBody { + clientRequest.Body = reader + } + return nil +} + +func (handler *Handler) HandleRequest(clientResponse http.ResponseWriter, clientRequest *http.Request, serviced bool, encoding string) bool { if serviced { return false } @@ -90,6 +114,7 @@ func (handler *Handler) HandleRequest(clientResponse http.ResponseWriter, client return true } + handler.ensureBodyContentEncoding(clientRequest, encoding) handler.addRelayHeaders(clientRequest) if clientRequest.Header.Get("Upgrade") == "websocket" { @@ -99,6 +124,38 @@ func (handler *Handler) HandleRequest(clientResponse http.ResponseWriter, client } } +func (handler *Handler) ensureBodyContentEncoding(clientRequest *http.Request, encoding string) { + if encoding == "" || encoding == "identity" { + return + } + + servicedBody, err := io.ReadAll(clientRequest.Body) + if err != nil { + logger.Printf("Error reading request body: %s", err) + clientRequest.Body = http.NoBody + return + } + + if encodedData, err := EncodeData(servicedBody, encoding); err != nil { + logger.Printf("Error encoding request body: %s", err) + clientRequest.Body = http.NoBody + return + } else { + servicedBody = encodedData + } + + // If the length of the body has changed, we should update the + // Content-Length header too. + contentLength := int64(len(servicedBody)) + if contentLength != clientRequest.ContentLength { + clientRequest.ContentLength = contentLength + clientRequest.Header.Set("Content-Length", strconv.FormatInt(contentLength, 10)) + } + + clientRequest.Body = io.NopCloser(bytes.NewBuffer(servicedBody)) + +} + func (handler *Handler) addRelayHeaders(clientRequest *http.Request) { // Add X-Forwarded-* headers remoteAddrTokens := strings.Split(clientRequest.RemoteAddr, ":") diff --git a/relay/traffic/traffic_test.go b/relay/traffic/traffic_test.go index 7773503..3d56295 100644 --- a/relay/traffic/traffic_test.go +++ b/relay/traffic/traffic_test.go @@ -4,7 +4,7 @@ import ( "bytes" "errors" "fmt" - "io/ioutil" + "io" "net/http" "reflect" "strings" @@ -12,7 +12,7 @@ import ( "github.com/fullstorydev/relay-core/catcher" "github.com/fullstorydev/relay-core/relay" - "github.com/fullstorydev/relay-core/relay/plugins/traffic/test-interceptor-plugin" + test_interceptor_plugin "github.com/fullstorydev/relay-core/relay/plugins/traffic/test-interceptor-plugin" "github.com/fullstorydev/relay-core/relay/test" "github.com/fullstorydev/relay-core/relay/traffic" "github.com/fullstorydev/relay-core/relay/version" @@ -214,7 +214,7 @@ func getBody(url string, t *testing.T) []byte { t.Errorf("Non-200 GET: %v", response) return nil } - body, err := ioutil.ReadAll(response.Body) + body, err := io.ReadAll(response.Body) if err != nil { t.Errorf("Error GETing body: %v", err) return nil From dc153b547843493d4648d2cedf3f0a6829919194 Mon Sep 17 00:00:00 2001 From: Clint Ayres Date: Wed, 7 Feb 2024 13:05:15 +0000 Subject: [PATCH 2/7] add e2e test for gzip support --- relay/traffic/traffic_test.go | 102 ++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/relay/traffic/traffic_test.go b/relay/traffic/traffic_test.go index 3d56295..a0788ef 100644 --- a/relay/traffic/traffic_test.go +++ b/relay/traffic/traffic_test.go @@ -151,6 +151,108 @@ func TestMaxBodySize(t *testing.T) { }) } +type Encoding int + +const ( + Identity Encoding = iota + Gzip +) + +func TestRelaySupportsContentEncoding(t *testing.T) { + testCases := map[string]struct { + encoding Encoding + bodyContentStr string + headers map[string]string + customUrl func(relayServiceURL string) string + }{ + "identity": { + encoding: Identity, + bodyContentStr: "Hello, world!", + }, + "gzip - with header": { + encoding: Gzip, + bodyContentStr: "Hello, world!", + headers: map[string]string{ + "Content-Encoding": "gzip", + }, + }, + "gzip - with query param": { + encoding: Gzip, + bodyContentStr: "Hello, world!", + customUrl: func(relayServiceURL string) string { + return fmt.Sprintf("%v?ContentEncoding=gzip", relayServiceURL) + }, + }, + } + + for desc, testCase := range testCases { + test.WithCatcherAndRelay(t, "", nil, func(catcherService *catcher.Service, relayService *relay.Service) { + // convert the body content to a reader with the proper content encoding applied + var body io.Reader + switch testCase.encoding { + case Gzip: + b, err := traffic.EncodeData([]byte(testCase.bodyContentStr), "gzip") + if err != nil { + t.Errorf("Test %s - Error encoding data: %v", desc, err) + return + } + body = bytes.NewReader(b) + case Identity: + body = strings.NewReader(testCase.bodyContentStr) + } + + requestURL := relayService.HttpUrl() + if testCase.customUrl != nil { + requestURL = testCase.customUrl(requestURL) + } + request, err := http.NewRequest("POST", requestURL, body) + if err != nil { + t.Errorf("Test %s - Error GETing: %v", desc, err) + return + } + + for header, headerValue := range testCase.headers { + request.Header.Set(header, headerValue) + } + + response, err := http.DefaultClient.Do(request) + if err != nil { + t.Errorf("Test %s - Error POSTing: %v", desc, err) + return + } + + defer response.Body.Close() + + if response.StatusCode != 200 { + t.Errorf("Test %s - Expected 200 response: %v", desc, response) + return + } + + lastRequest, err := catcherService.LastRequestBody() + if err != nil { + t.Errorf("Test %s - Error reading last request body from catcher: %v", desc, err) + return + } + + switch testCase.encoding { + case Gzip: + decodedData, err := traffic.DecodeData(lastRequest, "gzip") + if err != nil { + t.Errorf("Test %s - Error decoding data: %v", desc, err) + return + } + if string(decodedData) != testCase.bodyContentStr { + t.Errorf("Test %s - Expected body '%v' but got: %v", desc, testCase.bodyContentStr, string(decodedData)) + } + case Identity: + if string(lastRequest) != testCase.bodyContentStr { + t.Errorf("Test %s - Expected body '%v' but got: %v", desc, testCase.bodyContentStr, string(lastRequest)) + } + } + }) + } +} + func TestRelayNotFound(t *testing.T) { test.WithCatcherAndRelay(t, "", nil, func(catcherService *catcher.Service, relayService *relay.Service) { faviconURL := fmt.Sprintf("%v/favicon.ico", relayService.HttpUrl()) From 94b54c20027c00c4c7bd8eb61de4d48be0c0a3fa Mon Sep 17 00:00:00 2001 From: Clint Ayres Date: Wed, 7 Feb 2024 14:00:27 +0000 Subject: [PATCH 3/7] update blocking tests to asset they work with gzipped data --- .../content-blocker-plugin_test.go | 89 ++++++++++++++----- 1 file changed, 67 insertions(+), 22 deletions(-) diff --git a/relay/plugins/traffic/content-blocker-plugin/content-blocker-plugin_test.go b/relay/plugins/traffic/content-blocker-plugin/content-blocker-plugin_test.go index ab0df43..fe4c344 100644 --- a/relay/plugins/traffic/content-blocker-plugin/content-blocker-plugin_test.go +++ b/relay/plugins/traffic/content-blocker-plugin/content-blocker-plugin_test.go @@ -2,18 +2,26 @@ package content_blocker_plugin_test import ( "bytes" + "fmt" "net/http" "strconv" "testing" "github.com/fullstorydev/relay-core/catcher" "github.com/fullstorydev/relay-core/relay" - "github.com/fullstorydev/relay-core/relay/plugins/traffic/content-blocker-plugin" + content_blocker_plugin "github.com/fullstorydev/relay-core/relay/plugins/traffic/content-blocker-plugin" "github.com/fullstorydev/relay-core/relay/test" "github.com/fullstorydev/relay-core/relay/traffic" "github.com/fullstorydev/relay-core/relay/version" ) +type Encoding int + +const ( + Identity Encoding = iota + Gzip +) + func TestContentBlocking(t *testing.T) { testCases := []contentBlockerTestCase{ { @@ -133,7 +141,8 @@ func TestContentBlocking(t *testing.T) { } for _, testCase := range testCases { - runContentBlockerTest(t, testCase) + runContentBlockerTest(t, testCase, Identity) + runContentBlockerTest(t, testCase, Gzip) } } @@ -185,7 +194,18 @@ type contentBlockerTestCase struct { expectedHeaders map[string]string } -func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase) { +func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encoding Encoding) { + var encodingStr string + switch encoding { + case Gzip: + encodingStr = "gzip" + case Identity: + encodingStr = "" + } + + // Add encoding to the test description + desc := fmt.Sprintf("%s (encoding: %v)", testCase.desc, encodingStr) + plugins := []traffic.PluginFactory{ content_blocker_plugin.Factory, } @@ -203,16 +223,26 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase) { expectedHeaders[content_blocker_plugin.PluginVersionHeaderName] = version.RelayRelease test.WithCatcherAndRelay(t, testCase.config, plugins, func(catcherService *catcher.Service, relayService *relay.Service) { + b, err := traffic.EncodeData([]byte(testCase.originalBody), encodingStr) + if err != nil { + t.Errorf("Test '%v': Error encoding data: %v", desc, err) + return + } + request, err := http.NewRequest( "POST", relayService.HttpUrl(), - bytes.NewBufferString(testCase.originalBody), + bytes.NewBuffer(b), ) if err != nil { - t.Errorf("Test '%v': Error creating request: %v", testCase.desc, err) + t.Errorf("Test '%v': Error creating request: %v", desc, err) return } + if encoding == Gzip { + request.Header.Set("Content-Encoding", "gzip") + } + request.Header.Set("Content-Type", "application/json") for header, headerValue := range originalHeaders { request.Header.Set(header, headerValue) @@ -220,19 +250,19 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase) { response, err := http.DefaultClient.Do(request) if err != nil { - t.Errorf("Test '%v': Error POSTing: %v", testCase.desc, err) + t.Errorf("Test '%v': Error POSTing: %v", desc, err) return } defer response.Body.Close() if response.StatusCode != 200 { - t.Errorf("Test '%v': Expected 200 response: %v", testCase.desc, response) + t.Errorf("Test '%v': Expected 200 response: %v", desc, response) return } lastRequest, err := catcherService.LastRequest() if err != nil { - t.Errorf("Test '%v': Error reading last request from catcher: %v", testCase.desc, err) + t.Errorf("Test '%v': Error reading last request from catcher: %v", desc, err) return } @@ -241,7 +271,7 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase) { if expectedHeaderValue != actualHeaderValue { t.Errorf( "Test '%v': Expected header '%v' with value '%v' but got: %v", - testCase.desc, + desc, expectedHeader, expectedHeaderValue, actualHeaderValue, @@ -249,35 +279,50 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase) { } } + if lastRequest.Header.Get("Content-Encoding") != encodingStr { + t.Errorf( + "Test '%v': Expected Content-Encoding '%v' but got: %v", + desc, + encodingStr, + lastRequest.Header.Get("Content-Encoding"), + ) + } + lastRequestBody, err := catcherService.LastRequestBody() if err != nil { - t.Errorf("Test '%v': Error reading last request body from catcher: %v", testCase.desc, err) + t.Errorf("Test '%v': Error reading last request body from catcher: %v", desc, err) return } - lastRequestBodyStr := string(lastRequestBody) - if testCase.expectedBody != lastRequestBodyStr { - t.Errorf( - "Test '%v': Expected body '%v' but got: %v", - testCase.desc, - testCase.expectedBody, - lastRequestBodyStr, - ) - } - contentLength, err := strconv.Atoi(lastRequest.Header.Get("Content-Length")) if err != nil { - t.Errorf("Test '%v': Error parsing Content-Length: %v", testCase.desc, err) + t.Errorf("Test '%v': Error parsing Content-Length: %v", desc, err) return } if contentLength != len(lastRequestBody) { t.Errorf( "Test '%v': Content-Length is %v but actual body length is %v", - testCase.desc, + desc, contentLength, len(lastRequestBody), ) } + + decodedRequestBody, err := traffic.DecodeData(lastRequestBody, encodingStr) + if err != nil { + t.Errorf("Test '%v': Error decoding data: %v", desc, err) + return + } + + lastRequestBodyStr := string(decodedRequestBody) + if testCase.expectedBody != lastRequestBodyStr { + t.Errorf( + "Test '%v': Expected body '%v' but got: %v", + desc, + testCase.expectedBody, + lastRequestBodyStr, + ) + } }) } From a18311e817c401686635a3211abbce1bb1d24015 Mon Sep 17 00:00:00 2001 From: Clint Ayres Date: Fri, 9 Feb 2024 11:13:18 +0000 Subject: [PATCH 4/7] pr feedback (sethf) --- .../content-blocker-plugin_test.go | 23 +++----- relay/traffic/encoding.go | 49 +++++++++++----- relay/traffic/handler.go | 58 ++++++++++--------- relay/traffic/traffic_test.go | 27 ++++----- 4 files changed, 84 insertions(+), 73 deletions(-) diff --git a/relay/plugins/traffic/content-blocker-plugin/content-blocker-plugin_test.go b/relay/plugins/traffic/content-blocker-plugin/content-blocker-plugin_test.go index fe4c344..501a026 100644 --- a/relay/plugins/traffic/content-blocker-plugin/content-blocker-plugin_test.go +++ b/relay/plugins/traffic/content-blocker-plugin/content-blocker-plugin_test.go @@ -15,13 +15,6 @@ import ( "github.com/fullstorydev/relay-core/relay/version" ) -type Encoding int - -const ( - Identity Encoding = iota - Gzip -) - func TestContentBlocking(t *testing.T) { testCases := []contentBlockerTestCase{ { @@ -141,8 +134,8 @@ func TestContentBlocking(t *testing.T) { } for _, testCase := range testCases { - runContentBlockerTest(t, testCase, Identity) - runContentBlockerTest(t, testCase, Gzip) + runContentBlockerTest(t, testCase, traffic.Identity) + runContentBlockerTest(t, testCase, traffic.Gzip) } } @@ -194,12 +187,12 @@ type contentBlockerTestCase struct { expectedHeaders map[string]string } -func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encoding Encoding) { +func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encoding traffic.Encoding) { var encodingStr string switch encoding { - case Gzip: + case traffic.Gzip: encodingStr = "gzip" - case Identity: + case traffic.Identity: encodingStr = "" } @@ -223,7 +216,7 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encodi expectedHeaders[content_blocker_plugin.PluginVersionHeaderName] = version.RelayRelease test.WithCatcherAndRelay(t, testCase.config, plugins, func(catcherService *catcher.Service, relayService *relay.Service) { - b, err := traffic.EncodeData([]byte(testCase.originalBody), encodingStr) + b, err := traffic.EncodeData([]byte(testCase.originalBody), encoding) if err != nil { t.Errorf("Test '%v': Error encoding data: %v", desc, err) return @@ -239,7 +232,7 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encodi return } - if encoding == Gzip { + if encoding == traffic.Gzip { request.Header.Set("Content-Encoding", "gzip") } @@ -309,7 +302,7 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encodi ) } - decodedRequestBody, err := traffic.DecodeData(lastRequestBody, encodingStr) + decodedRequestBody, err := traffic.DecodeData(lastRequestBody, encoding) if err != nil { t.Errorf("Test '%v': Error decoding data: %v", desc, err) return diff --git a/relay/traffic/encoding.go b/relay/traffic/encoding.go index ed98a57..b8e0e06 100644 --- a/relay/traffic/encoding.go +++ b/relay/traffic/encoding.go @@ -3,13 +3,22 @@ package traffic import ( "bytes" "compress/gzip" + "fmt" "io" "net/http" "net/url" "strings" ) -func GetContentEncoding(request *http.Request) (string, error) { +type Encoding int + +const ( + Unsupported Encoding = iota + Identity + Gzip +) + +func GetContentEncoding(request *http.Request) (Encoding, error) { // NOTE: This is a workaround for a bug in post-Go 1.17. See golang.org/issue/25192. // Our algorithm differs from the logic of AllowQuerySemicolons by replacing semicolons with encoded semicolons instead // of with ampersands. This is because we want to preserve the original query string as much as possible. @@ -19,7 +28,7 @@ func GetContentEncoding(request *http.Request) (string, error) { queryParams, err := url.ParseQuery(request.URL.RawQuery) if err != nil { - return "", err + return Unsupported, err } // request query parameter takes precedence over request header @@ -27,29 +36,39 @@ func GetContentEncoding(request *http.Request) (string, error) { if encoding == "" { encoding = request.Header.Get("Content-Encoding") } - return encoding, nil + + switch encoding { + case "gzip": + return Gzip, nil + case "": + return Identity, nil + default: + return Unsupported, fmt.Errorf("unsupported encoding: %v", encoding) + } } // WrapReader checks if the request Content-Encoding or request query parameter indicates gzip compression. // If so, it returns a gzip.Reader that decompresses the content. -func WrapReader(request *http.Request, encoding string) (io.ReadCloser, error) { +func WrapReader(request *http.Request, encoding Encoding) (io.ReadCloser, error) { if request.Body == nil { return nil, nil } switch encoding { - case "gzip": + case Gzip: // Create a new gzip.Reader to decompress the request body return gzip.NewReader(request.Body) - default: + case Identity: // If the content is not gzip-compressed, return the original request body return request.Body, nil + default: + return nil, fmt.Errorf("unsupported encoding: %v", encoding) } } -func EncodeData(data []byte, encoding string) ([]byte, error) { +func EncodeData(data []byte, encoding Encoding) ([]byte, error) { switch encoding { - case "gzip": + case Gzip: var buf bytes.Buffer gz := gzip.NewWriter(&buf) @@ -65,15 +84,16 @@ func EncodeData(data []byte, encoding string) ([]byte, error) { compressedData := buf.Bytes() return compressedData, nil - default: - // identity encoding + case Identity: return data, nil + default: + return nil, fmt.Errorf("unsupported encoding: %v", encoding) } } -func DecodeData(data []byte, encoding string) ([]byte, error) { +func DecodeData(data []byte, encoding Encoding) ([]byte, error) { switch encoding { - case "gzip": + case Gzip: reader, err := gzip.NewReader(bytes.NewReader(data)) if err != nil { return nil, err @@ -85,8 +105,9 @@ func DecodeData(data []byte, encoding string) ([]byte, error) { } return decodedData, nil - default: - // identity encoding + case Identity: return data, nil + default: + return nil, fmt.Errorf("unsupported encoding: %v", encoding) } } diff --git a/relay/traffic/handler.go b/relay/traffic/handler.go index cdf6b66..ce15a8a 100644 --- a/relay/traffic/handler.go +++ b/relay/traffic/handler.go @@ -60,7 +60,7 @@ func (handler *Handler) ServeHTTP(response http.ResponseWriter, request *http.Re encoding, err := GetContentEncoding(request) if err != nil { - logger.Printf("URL %v error getting request content encoding: %v", request.URL, err) + logger.Printf("URL %v error in request content encoding: %v", request.URL, err) request.Body = http.NoBody return } @@ -95,7 +95,7 @@ func (handler *Handler) ServeHTTP(response http.ResponseWriter, request *http.Re } // prepareRequestBody wraps the request Body with a reader that will decode the content if necessary. -func (handler *Handler) prepareRequestBody(clientRequest *http.Request, encoding string) error { +func (handler *Handler) prepareRequestBody(clientRequest *http.Request, encoding Encoding) error { if reader, err := WrapReader(clientRequest, encoding); err != nil { return err } else if reader != nil && reader != http.NoBody { @@ -104,7 +104,7 @@ func (handler *Handler) prepareRequestBody(clientRequest *http.Request, encoding return nil } -func (handler *Handler) HandleRequest(clientResponse http.ResponseWriter, clientRequest *http.Request, serviced bool, encoding string) bool { +func (handler *Handler) HandleRequest(clientResponse http.ResponseWriter, clientRequest *http.Request, serviced bool, encoding Encoding) bool { if serviced { return false } @@ -124,35 +124,39 @@ func (handler *Handler) HandleRequest(clientResponse http.ResponseWriter, client } } -func (handler *Handler) ensureBodyContentEncoding(clientRequest *http.Request, encoding string) { - if encoding == "" || encoding == "identity" { +func (handler *Handler) ensureBodyContentEncoding(clientRequest *http.Request, encoding Encoding) { + switch encoding { + case Unsupported: + logger.Println("Error unsupported content-encoding") return - } - - servicedBody, err := io.ReadAll(clientRequest.Body) - if err != nil { - logger.Printf("Error reading request body: %s", err) - clientRequest.Body = http.NoBody + case Identity: return - } + case Gzip: + servicedBody, err := io.ReadAll(clientRequest.Body) + if err != nil { + logger.Printf("Error reading request body: %s", err) + clientRequest.Body = http.NoBody + return + } - if encodedData, err := EncodeData(servicedBody, encoding); err != nil { - logger.Printf("Error encoding request body: %s", err) - clientRequest.Body = http.NoBody - return - } else { - servicedBody = encodedData - } + if encodedData, err := EncodeData(servicedBody, encoding); err != nil { + logger.Printf("Error encoding request body: %s", err) + clientRequest.Body = http.NoBody + return + } else { + servicedBody = encodedData + } - // If the length of the body has changed, we should update the - // Content-Length header too. - contentLength := int64(len(servicedBody)) - if contentLength != clientRequest.ContentLength { - clientRequest.ContentLength = contentLength - clientRequest.Header.Set("Content-Length", strconv.FormatInt(contentLength, 10)) - } + // If the length of the body has changed, we should update the + // Content-Length header too. + contentLength := int64(len(servicedBody)) + if contentLength != clientRequest.ContentLength { + clientRequest.ContentLength = contentLength + clientRequest.Header.Set("Content-Length", strconv.FormatInt(contentLength, 10)) + } - clientRequest.Body = io.NopCloser(bytes.NewBuffer(servicedBody)) + clientRequest.Body = io.NopCloser(bytes.NewBuffer(servicedBody)) + } } diff --git a/relay/traffic/traffic_test.go b/relay/traffic/traffic_test.go index a0788ef..a1597ec 100644 --- a/relay/traffic/traffic_test.go +++ b/relay/traffic/traffic_test.go @@ -151,33 +151,26 @@ func TestMaxBodySize(t *testing.T) { }) } -type Encoding int - -const ( - Identity Encoding = iota - Gzip -) - func TestRelaySupportsContentEncoding(t *testing.T) { testCases := map[string]struct { - encoding Encoding + encoding traffic.Encoding bodyContentStr string headers map[string]string customUrl func(relayServiceURL string) string }{ "identity": { - encoding: Identity, + encoding: traffic.Identity, bodyContentStr: "Hello, world!", }, "gzip - with header": { - encoding: Gzip, + encoding: traffic.Gzip, bodyContentStr: "Hello, world!", headers: map[string]string{ "Content-Encoding": "gzip", }, }, "gzip - with query param": { - encoding: Gzip, + encoding: traffic.Gzip, bodyContentStr: "Hello, world!", customUrl: func(relayServiceURL string) string { return fmt.Sprintf("%v?ContentEncoding=gzip", relayServiceURL) @@ -190,14 +183,14 @@ func TestRelaySupportsContentEncoding(t *testing.T) { // convert the body content to a reader with the proper content encoding applied var body io.Reader switch testCase.encoding { - case Gzip: - b, err := traffic.EncodeData([]byte(testCase.bodyContentStr), "gzip") + case traffic.Gzip: + b, err := traffic.EncodeData([]byte(testCase.bodyContentStr), traffic.Gzip) if err != nil { t.Errorf("Test %s - Error encoding data: %v", desc, err) return } body = bytes.NewReader(b) - case Identity: + case traffic.Identity: body = strings.NewReader(testCase.bodyContentStr) } @@ -235,8 +228,8 @@ func TestRelaySupportsContentEncoding(t *testing.T) { } switch testCase.encoding { - case Gzip: - decodedData, err := traffic.DecodeData(lastRequest, "gzip") + case traffic.Gzip: + decodedData, err := traffic.DecodeData(lastRequest, traffic.Gzip) if err != nil { t.Errorf("Test %s - Error decoding data: %v", desc, err) return @@ -244,7 +237,7 @@ func TestRelaySupportsContentEncoding(t *testing.T) { if string(decodedData) != testCase.bodyContentStr { t.Errorf("Test %s - Expected body '%v' but got: %v", desc, testCase.bodyContentStr, string(decodedData)) } - case Identity: + case traffic.Identity: if string(lastRequest) != testCase.bodyContentStr { t.Errorf("Test %s - Expected body '%v' but got: %v", desc, testCase.bodyContentStr, string(lastRequest)) } From 0d112969e3c8c743458da24e0356b2eaf35f6689 Mon Sep 17 00:00:00 2001 From: Clint Ayres Date: Fri, 9 Feb 2024 17:23:59 +0000 Subject: [PATCH 5/7] fail request if back content encoding --- relay/traffic/handler.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relay/traffic/handler.go b/relay/traffic/handler.go index ce15a8a..6bb3948 100644 --- a/relay/traffic/handler.go +++ b/relay/traffic/handler.go @@ -60,7 +60,7 @@ func (handler *Handler) ServeHTTP(response http.ResponseWriter, request *http.Re encoding, err := GetContentEncoding(request) if err != nil { - logger.Printf("URL %v error in request content encoding: %v", request.URL, err) + http.Error(response, fmt.Sprintf("URL %v error in request content encoding: %v", request.URL, err), 500) request.Body = http.NoBody return } From 627b047fe294ec57416176e2e23643379376614a Mon Sep 17 00:00:00 2001 From: Clint Ayres Date: Mon, 12 Feb 2024 09:05:09 +0000 Subject: [PATCH 6/7] pr feedback (eugene) --- relay/traffic/encoding.go | 3 +-- relay/traffic/handler.go | 2 ++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/relay/traffic/encoding.go b/relay/traffic/encoding.go index b8e0e06..e15016e 100644 --- a/relay/traffic/encoding.go +++ b/relay/traffic/encoding.go @@ -47,8 +47,7 @@ func GetContentEncoding(request *http.Request) (Encoding, error) { } } -// WrapReader checks if the request Content-Encoding or request query parameter indicates gzip compression. -// If so, it returns a gzip.Reader that decompresses the content. +// WrapReader returns a wrapped request.Body for the encoding provided. func WrapReader(request *http.Request, encoding Encoding) (io.ReadCloser, error) { if request.Body == nil { return nil, nil diff --git a/relay/traffic/handler.go b/relay/traffic/handler.go index 6bb3948..89acdf7 100644 --- a/relay/traffic/handler.go +++ b/relay/traffic/handler.go @@ -124,6 +124,8 @@ func (handler *Handler) HandleRequest(clientResponse http.ResponseWriter, client } } +// ensureBodyContentEncoding operates on the assumption that the downstream proxy target will be using the same +// encoding as what the relay received and ensures we proxy the content encoded correctly. func (handler *Handler) ensureBodyContentEncoding(clientRequest *http.Request, encoding Encoding) { switch encoding { case Unsupported: From a52dacd80a3b7a25b971c5b76f5ea2b9a71619f8 Mon Sep 17 00:00:00 2001 From: Clint Ayres Date: Mon, 12 Feb 2024 09:28:24 +0000 Subject: [PATCH 7/7] bump version --- relay/version/version.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/relay/version/version.go b/relay/version/version.go index 97bff11..c774670 100644 --- a/relay/version/version.go +++ b/relay/version/version.go @@ -1,3 +1,3 @@ package version -const RelayRelease = "v0.3.2" // TODO set this from tags automatically during git commit +const RelayRelease = "v0.3.3" // TODO set this from tags automatically during git commit