diff --git a/internal/trace/transport.go b/internal/trace/transport.go index 2ab827bfc..13a088b7c 100644 --- a/internal/trace/transport.go +++ b/internal/trace/transport.go @@ -38,7 +38,7 @@ var ( ) // payloadSizeLimit limits the maximum size of the response body to be printed. -const payloadSizeLimit int64 = 4 * 1024 * 1024 // 4 MiB +const payloadSizeLimit int64 = 4 * 1024 // 4 KiB // Transport is an http.RoundTripper that keeps track of the in-flight // request and add hooks to report HTTP tracing events. @@ -113,8 +113,13 @@ func logResponseBody(resp *http.Response) string { if err != nil { return fmt.Sprintf(" Error reading response body: %v", err) } + // restore the body by concatenating the read body with the remaining body - resp.Body = io.NopCloser(io.MultiReader(bytes.NewReader(readBody), resp.Body)) + closeFunc := resp.Body.Close + resp.Body = &readCloser{ + Reader: io.MultiReader(bytes.NewReader(readBody), resp.Body), + closeFunc: closeFunc, + } if len(readBody) == 0 { return " Response body is empty" @@ -142,3 +147,15 @@ func isPrintableContentType(contentType string) bool { } return false } + +// readCloser returns an io.ReadCloser that wraps an io.Reader and a +// close function. +type readCloser struct { + io.Reader + closeFunc func() error +} + +// Close closes the readCloser. +func (rc *readCloser) Close() error { + return rc.closeFunc() +} diff --git a/internal/trace/transport_test.go b/internal/trace/transport_test.go index 5f267d9ba..3fc348ec4 100644 --- a/internal/trace/transport_test.go +++ b/internal/trace/transport_test.go @@ -17,12 +17,17 @@ package trace import ( "bytes" - "fmt" + "errors" "io" "net/http" "testing" ) +var ( + mockReadErr = errors.New("mock read error") + mockCloseErr = errors.New("mock close error") +) + func Test_isPrintableContentType(t *testing.T) { tests := []struct { name string @@ -117,9 +122,10 @@ func Test_isPrintableContentType(t *testing.T) { func Test_logResponseBody(t *testing.T) { tests := []struct { - name string - resp *http.Response - want string + name string + resp *http.Response + want string + wantData []byte }{ { name: "Nil body", @@ -130,7 +136,8 @@ func Test_logResponseBody(t *testing.T) { want: " No response body to print", }, { - name: "No body", + name: "No body", + wantData: nil, resp: &http.Response{ Body: http.NoBody, ContentLength: 100, // in case of HEAD response, the content length is set but the body is empty @@ -139,7 +146,8 @@ func Test_logResponseBody(t *testing.T) { want: " No response body to print", }, { - name: "Empty body", + name: "Empty body", + wantData: []byte(""), resp: &http.Response{ Body: io.NopCloser(bytes.NewReader([]byte(""))), ContentLength: 0, @@ -148,7 +156,8 @@ func Test_logResponseBody(t *testing.T) { want: " Response body is empty", }, { - name: "Unknown content length", + name: "Unknown content length", + wantData: []byte("whatever"), resp: &http.Response{ Body: io.NopCloser(bytes.NewReader([]byte("whatever"))), ContentLength: -1, @@ -157,7 +166,8 @@ func Test_logResponseBody(t *testing.T) { want: "whatever", }, { - name: "Non-printable content type", + name: "Non-printable content type", + wantData: []byte("binary data"), resp: &http.Response{ Body: io.NopCloser(bytes.NewReader([]byte("binary data"))), ContentLength: 11, @@ -166,7 +176,8 @@ func Test_logResponseBody(t *testing.T) { want: " Response body of content type \"application/octet-stream\" is not printed", }, { - name: "Body at the limit", + name: "Body at the limit", + wantData: bytes.Repeat([]byte("a"), int(payloadSizeLimit)), resp: &http.Response{ Body: io.NopCloser(bytes.NewReader(bytes.Repeat([]byte("a"), int(payloadSizeLimit)))), ContentLength: payloadSizeLimit, @@ -175,7 +186,8 @@ func Test_logResponseBody(t *testing.T) { want: string(bytes.Repeat([]byte("a"), int(payloadSizeLimit))), }, { - name: "Body larger than limit", + name: "Body larger than limit", + wantData: bytes.Repeat([]byte("a"), int(payloadSizeLimit+1)), resp: &http.Response{ Body: io.NopCloser(bytes.NewReader(bytes.Repeat([]byte("a"), int(payloadSizeLimit+1)))), // 1 byte larger than limit ContentLength: payloadSizeLimit + 1, @@ -184,7 +196,8 @@ func Test_logResponseBody(t *testing.T) { want: string(bytes.Repeat([]byte("a"), int(payloadSizeLimit))) + "\n...(truncated)", }, { - name: "Printable content type within limit", + name: "Printable content type within limit", + wantData: []byte("data"), resp: &http.Response{ Body: io.NopCloser(bytes.NewReader([]byte("data"))), ContentLength: 4, @@ -193,7 +206,8 @@ func Test_logResponseBody(t *testing.T) { want: "data", }, { - name: "Actual body size is larger than content length", + name: "Actual body size is larger than content length", + wantData: []byte("data"), resp: &http.Response{ Body: io.NopCloser(bytes.NewReader([]byte("data"))), ContentLength: 3, // mismatched content length @@ -202,7 +216,8 @@ func Test_logResponseBody(t *testing.T) { want: "data", }, { - name: "Actual body size is larger than content length and exceeds limit", + name: "Actual body size is larger than content length and exceeds limit", + wantData: bytes.Repeat([]byte("a"), int(payloadSizeLimit+1)), resp: &http.Response{ Body: io.NopCloser(bytes.NewReader(bytes.Repeat([]byte("a"), int(payloadSizeLimit+1)))), // 1 byte larger than limit ContentLength: 1, // mismatched content length @@ -211,7 +226,8 @@ func Test_logResponseBody(t *testing.T) { want: string(bytes.Repeat([]byte("a"), int(payloadSizeLimit))) + "\n...(truncated)", }, { - name: "Actual body size is smaller than content length", + name: "Actual body size is smaller than content length", + wantData: []byte("data"), resp: &http.Response{ Body: io.NopCloser(bytes.NewReader([]byte("data"))), ContentLength: 5, // mismatched content length @@ -219,6 +235,36 @@ func Test_logResponseBody(t *testing.T) { }, want: "data", }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := logResponseBody(tt.resp); got != tt.want { + t.Errorf("logResponseBody() = %v, want %v", got, tt.want) + } + // validate the response body + if tt.resp.Body != nil { + readBytes, err := io.ReadAll(tt.resp.Body) + if err != nil { + t.Errorf("failed to read body after logResponseBody(), err= %v", err) + } + if !bytes.Equal(readBytes, tt.wantData) { + t.Errorf("resp.Body after logResponseBody() = %v, want %v", readBytes, tt.wantData) + } + if closeErr := tt.resp.Body.Close(); closeErr != nil { + t.Errorf("failed to close body after logResponseBody(), err= %v", closeErr) + } + } + }) + } +} + +func Test_logResponseBody_error(t *testing.T) { + tests := []struct { + name string + resp *http.Response + want string + }{ { name: "Error reading body", resp: &http.Response{ @@ -235,6 +281,71 @@ func Test_logResponseBody(t *testing.T) { if got := logResponseBody(tt.resp); got != tt.want { t.Errorf("logResponseBody() = %v, want %v", got, tt.want) } + if closeErr := tt.resp.Body.Close(); closeErr != nil { + t.Errorf("failed to close body after logResponseBody(), err= %v", closeErr) + } + }) + } +} + +func Test_readCloser_Close(t *testing.T) { + + tests := []struct { + name string + reader io.Reader + closeFunc func() error + wantData []byte + wantReadErr error + wantCloseErr error + }{ + { + name: "successfully read and close", + wantData: []byte("data"), + reader: bytes.NewReader([]byte("data")), + closeFunc: func() error { + return nil + }, + wantReadErr: nil, + wantCloseErr: nil, + }, + { + name: "error reading", + wantData: nil, + reader: &errorReader{}, + closeFunc: func() error { + return nil + }, + wantReadErr: mockReadErr, + wantCloseErr: nil, + }, + { + name: "error closing", + wantData: []byte("data"), + reader: bytes.NewReader([]byte("data")), + closeFunc: func() error { + return mockCloseErr + }, + wantReadErr: nil, + wantCloseErr: mockCloseErr, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rc := &readCloser{ + Reader: tt.reader, + closeFunc: tt.closeFunc, + } + got, err := io.ReadAll(rc) + if err != tt.wantReadErr { + t.Errorf("readCloser.ReadAll() error = %v, wantErr %v", err, tt.wantReadErr) + } + if !bytes.Equal(got, tt.wantData) { + t.Errorf("readCloser.ReadAll() = %v, want %v", got, tt.wantData) + } + if err := rc.Close(); err != tt.wantCloseErr { + t.Errorf("readCloser.Close() error = %v, wantErr %v", err, tt.wantCloseErr) + } }) } } @@ -242,5 +353,5 @@ func Test_logResponseBody(t *testing.T) { type errorReader struct{} func (e *errorReader) Read(p []byte) (n int, err error) { - return 0, fmt.Errorf("mock error") + return 0, mockReadErr }