Skip to content

Commit

Permalink
improve tripperware request id (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
instabledesign authored Sep 27, 2019
1 parent 642b4df commit 3c74fbd
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 2 deletions.
3 changes: 2 additions & 1 deletion tripperware/request_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ func RequestId(config *request_id.Config) httpware.Tripperware {
}
if id == "" {
id = config.IdGenerator(req)
// add requestId header to current request
req.Header.Add(config.HeaderName, id)
}
r := req.WithContext(context.WithValue(req.Context(), config.HeaderName, id))
r.Header.Add(config.HeaderName, id)
return next.RoundTrip(r)
})
}
Expand Down
14 changes: 13 additions & 1 deletion tripperware/request_id_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package tripperware_test

import (
"fmt"
"math/rand"
"net/http"
"net/http/httptest"
"os"
"testing"

"github.com/stretchr/testify/assert"
Expand All @@ -15,6 +17,14 @@ import (
"github.com/gol4ng/httpware/tripperware"
)

func TestMain(m *testing.M){
request_id.DefaultIdGenerator = request_id.NewRandomIdGenerator(
rand.New(request_id.NewLockedSource(rand.NewSource(1))),
10,
)
os.Exit(m.Run())
}

func TestRequestId(t *testing.T) {
roundTripperMock := &mocks.RoundTripper{}
req := httptest.NewRequest(http.MethodGet, "http://fake-addr", nil)
Expand All @@ -27,11 +37,13 @@ func TestRequestId(t *testing.T) {
roundTripperMock.On("RoundTrip", mock.AnythingOfType("*http.Request")).Return(resp, nil).Run(func(args mock.Arguments) {
innerReq := args.Get(0).(*http.Request)
assert.True(t, len(innerReq.Header.Get(request_id.HeaderName)) == 10)
assert.Equal(t, req.Header.Get(request_id.HeaderName), innerReq.Header.Get(request_id.HeaderName))
})

resp2, err := tripperware.RequestId(request_id.NewConfig())(roundTripperMock).RoundTrip(req)
assert.Nil(t, err)
assert.Equal(t, resp, resp2)
assert.Equal(t, "p1LGIehp1s", req.Header.Get(request_id.HeaderName))
}

func TestRequestIdCustom(t *testing.T) {
Expand Down Expand Up @@ -95,7 +107,7 @@ func ExampleRequestId() {
}
}()

_, _ = client.Get("http://localhost"+port+"/")
_, _ = client.Get("http://localhost" + port + "/")

// Output: server receive request with request id: my-generated-id
}

0 comments on commit 3c74fbd

Please sign in to comment.