Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support scrubbing gzip encoded bundles #36

Merged
merged 7 commits into from
Feb 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions catcher/catcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
Expand Down Expand Up @@ -82,7 +81,7 @@ func (service *Service) LastRequestBody() ([]byte, error) {
}

defer request.Body.Close()
body, err := ioutil.ReadAll(request.Body)
body, err := io.ReadAll(request.Body)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note ioutil is deprecated and internally just calls the io package, so I've updated all usages of ioutil.

if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions relay/main/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,22 @@ package main

import (
"flag"
"io/ioutil"
"io"
"log"
"os"
"time"

"github.com/fullstorydev/relay-core/relay"
"github.com/fullstorydev/relay-core/relay/config"
"github.com/fullstorydev/relay-core/relay/environment"
"github.com/fullstorydev/relay-core/relay/traffic/plugin-loader"
plugin_loader "github.com/fullstorydev/relay-core/relay/traffic/plugin-loader"
)

var logger = log.New(os.Stdout, "[relay] ", 0)

func readConfigFile(path string) (rawConfigFileBytes []byte, err error) {
if path == "-" {
rawConfigFileBytes, err = ioutil.ReadAll(os.Stdin)
rawConfigFileBytes, err = io.ReadAll(os.Stdin)
return
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ package content_blocker_plugin
import (
"bytes"
"fmt"
"io/ioutil"
"io"
"log"
"net/http"
"os"
Expand Down Expand Up @@ -78,7 +78,7 @@ func (f contentBlockerPluginFactory) New(configSection *config.Section) (traffic
}

if regexp, err := regexp.Compile(pattern); err != nil {
return fmt.Errorf(`Could not compile regular expression "%v": %v`, pattern, err)
return fmt.Errorf(`could not compile regular expression "%v": %v`, pattern, err)
} else {
logger.Printf("Added rule: %s %s content matching \"%s\"", mode, contentKind, regexp)
blockers = append(blockers, &contentBlocker{
Expand All @@ -94,7 +94,7 @@ func (f contentBlockerPluginFactory) New(configSection *config.Section) (traffic
case "header":
plugin.headerBlockers = append(plugin.headerBlockers, blockers...)
default:
return fmt.Errorf(`Unexpected content kind %s`, contentKind)
return fmt.Errorf(`unexpected content kind %s`, contentKind)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note capitalization of error messages is a style lint warning for Golang so fixed these as well generally.

https://google.github.io/styleguide/go/decisions.html#error-strings

}

return nil
Expand Down Expand Up @@ -222,28 +222,26 @@ func (plug contentBlockerPlugin) blockBodyContent(response http.ResponseWriter,
return false
}

processedBody, err := ioutil.ReadAll(request.Body)
processedBody, err := io.ReadAll(request.Body)
if err != nil {
http.Error(response, fmt.Sprintf("Error reading request body: %s", err), 500)
request.Body = http.NoBody
return true
}
initialLength := len(processedBody)

for _, blocker := range plug.bodyBlockers {
processedBody = blocker.Block(processedBody)
}

// If the length of the body has changed, we should update the
// Content-Length header too.
finalLength := len(processedBody)
if finalLength != initialLength {
contentLength := int64(finalLength)
contentLength := int64(len(processedBody))
if contentLength != request.ContentLength {
Comment on lines +238 to +239
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small change to directly verify the content length is valid instead of indirectly through the existing body content.

request.ContentLength = contentLength
request.Header.Set("Content-Length", strconv.FormatInt(contentLength, 10))
}

request.Body = ioutil.NopCloser(bytes.NewBuffer(processedBody))
request.Body = io.NopCloser(bytes.NewBuffer(processedBody))
return false
}

Expand Down Expand Up @@ -283,7 +281,7 @@ func (b *contentBlocker) Block(content []byte) []byte {
case excludeMode:
return b.regexp.ReplaceAllLiteral(content, []byte{})
default:
panic(fmt.Errorf("Invalid content blocking mode: %v", b.mode))
panic(fmt.Errorf("invalid content blocking mode: %v", b.mode))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package content_blocker_plugin_test

import (
"bytes"
"fmt"
"net/http"
"strconv"
"testing"

"github.com/fullstorydev/relay-core/catcher"
"github.com/fullstorydev/relay-core/relay"
"github.com/fullstorydev/relay-core/relay/plugins/traffic/content-blocker-plugin"
content_blocker_plugin "github.com/fullstorydev/relay-core/relay/plugins/traffic/content-blocker-plugin"
"github.com/fullstorydev/relay-core/relay/test"
"github.com/fullstorydev/relay-core/relay/traffic"
"github.com/fullstorydev/relay-core/relay/version"
Expand Down Expand Up @@ -133,7 +134,8 @@ func TestContentBlocking(t *testing.T) {
}

for _, testCase := range testCases {
runContentBlockerTest(t, testCase)
runContentBlockerTest(t, testCase, traffic.Identity)
runContentBlockerTest(t, testCase, traffic.Gzip)
}
}

Expand Down Expand Up @@ -185,7 +187,18 @@ type contentBlockerTestCase struct {
expectedHeaders map[string]string
}

func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase) {
func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase, encoding traffic.Encoding) {
var encodingStr string
switch encoding {
case traffic.Gzip:
encodingStr = "gzip"
case traffic.Identity:
encodingStr = ""
}

// Add encoding to the test description
desc := fmt.Sprintf("%s (encoding: %v)", testCase.desc, encodingStr)

plugins := []traffic.PluginFactory{
content_blocker_plugin.Factory,
}
Expand All @@ -203,36 +216,46 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase) {
expectedHeaders[content_blocker_plugin.PluginVersionHeaderName] = version.RelayRelease

test.WithCatcherAndRelay(t, testCase.config, plugins, func(catcherService *catcher.Service, relayService *relay.Service) {
b, err := traffic.EncodeData([]byte(testCase.originalBody), encoding)
if err != nil {
t.Errorf("Test '%v': Error encoding data: %v", desc, err)
return
}

request, err := http.NewRequest(
"POST",
relayService.HttpUrl(),
bytes.NewBufferString(testCase.originalBody),
bytes.NewBuffer(b),
)
if err != nil {
t.Errorf("Test '%v': Error creating request: %v", testCase.desc, err)
t.Errorf("Test '%v': Error creating request: %v", desc, err)
return
}

if encoding == traffic.Gzip {
request.Header.Set("Content-Encoding", "gzip")
}

request.Header.Set("Content-Type", "application/json")
for header, headerValue := range originalHeaders {
request.Header.Set(header, headerValue)
}

response, err := http.DefaultClient.Do(request)
if err != nil {
t.Errorf("Test '%v': Error POSTing: %v", testCase.desc, err)
t.Errorf("Test '%v': Error POSTing: %v", desc, err)
return
}
defer response.Body.Close()

if response.StatusCode != 200 {
t.Errorf("Test '%v': Expected 200 response: %v", testCase.desc, response)
t.Errorf("Test '%v': Expected 200 response: %v", desc, response)
return
}

lastRequest, err := catcherService.LastRequest()
if err != nil {
t.Errorf("Test '%v': Error reading last request from catcher: %v", testCase.desc, err)
t.Errorf("Test '%v': Error reading last request from catcher: %v", desc, err)
return
}

Expand All @@ -241,43 +264,58 @@ func runContentBlockerTest(t *testing.T, testCase contentBlockerTestCase) {
if expectedHeaderValue != actualHeaderValue {
t.Errorf(
"Test '%v': Expected header '%v' with value '%v' but got: %v",
testCase.desc,
desc,
expectedHeader,
expectedHeaderValue,
actualHeaderValue,
)
}
}

if lastRequest.Header.Get("Content-Encoding") != encodingStr {
t.Errorf(
"Test '%v': Expected Content-Encoding '%v' but got: %v",
desc,
encodingStr,
lastRequest.Header.Get("Content-Encoding"),
)
}

lastRequestBody, err := catcherService.LastRequestBody()
if err != nil {
t.Errorf("Test '%v': Error reading last request body from catcher: %v", testCase.desc, err)
t.Errorf("Test '%v': Error reading last request body from catcher: %v", desc, err)
return
}

lastRequestBodyStr := string(lastRequestBody)
if testCase.expectedBody != lastRequestBodyStr {
t.Errorf(
"Test '%v': Expected body '%v' but got: %v",
testCase.desc,
testCase.expectedBody,
lastRequestBodyStr,
)
}

contentLength, err := strconv.Atoi(lastRequest.Header.Get("Content-Length"))
if err != nil {
t.Errorf("Test '%v': Error parsing Content-Length: %v", testCase.desc, err)
t.Errorf("Test '%v': Error parsing Content-Length: %v", desc, err)
return
}

if contentLength != len(lastRequestBody) {
t.Errorf(
"Test '%v': Content-Length is %v but actual body length is %v",
testCase.desc,
desc,
contentLength,
len(lastRequestBody),
)
}

decodedRequestBody, err := traffic.DecodeData(lastRequestBody, encoding)
if err != nil {
t.Errorf("Test '%v': Error decoding data: %v", desc, err)
return
}

lastRequestBodyStr := string(decodedRequestBody)
if testCase.expectedBody != lastRequestBodyStr {
t.Errorf(
"Test '%v': Expected body '%v' but got: %v",
desc,
testCase.expectedBody,
lastRequestBodyStr,
)
}
})
}
2 changes: 1 addition & 1 deletion relay/plugins/traffic/paths-plugin/paths-plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ func runPathsPluginTest(t *testing.T, testCase pathsPluginTestCase) {
lastRequest, err = altCatcherService.LastRequest()
}
if err != nil {
t.Errorf("Error reading last request from catcher: %v", err)
t.Errorf("Text '%v': Error reading last request from catcher: %v", testCase.desc, err)
return
}

Expand Down
Loading
Loading