Skip to content

Commit

Permalink
add request body to into key in post request
Browse files Browse the repository at this point in the history
  • Loading branch information
Victor Springer committed Jul 21, 2019
1 parent 9c541ee commit f0c859b
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 5 deletions.
18 changes: 18 additions & 0 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"errors"
"fmt"
"hash/fnv"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
Expand Down Expand Up @@ -89,6 +90,15 @@ func (c *Client) Middleware(next http.Handler) http.Handler {
if c.cacheableMethod(r.Method) {
sortURLParams(r.URL)
key := generateKey(r.URL.String())
if r.Method == http.MethodPost && r.Body != nil {
body, err := ioutil.ReadAll(r.Body)
defer r.Body.Close()
if err != nil {
next.ServeHTTP(w, r)
return
}
key = generateKeyWithBody(r.URL.String(), body)
}

params := r.URL.Query()
if _, ok := params[c.refreshKey]; ok {
Expand Down Expand Up @@ -197,6 +207,14 @@ func generateKey(URL string) uint64 {
return hash.Sum64()
}

func generateKeyWithBody(URL string, body []byte) uint64 {
hash := fnv.New64a()
body = append([]byte(URL), body...)
hash.Write(body)

return hash.Sum64()
}

// NewClient initializes the cache HTTP middleware client with the given
// options.
func NewClient(opts ...ClientOption) (*Client, error) {
Expand Down
110 changes: 105 additions & 5 deletions cache_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package cache

import (
"bytes"
"errors"
"fmt"
"net/http"
"net/http/httptest"
Expand All @@ -16,6 +18,8 @@ type adapterMock struct {
store map[uint64][]byte
}

type errReader int

func (a *adapterMock) Get(key uint64) ([]byte, bool) {
a.Lock()
defer a.Unlock()
Expand All @@ -37,6 +41,10 @@ func (a *adapterMock) Release(key uint64) {
delete(a.store, key)
}

func (errReader) Read(p []byte) (n int, err error) {
return 0, errors.New("readAll error")
}

func TestMiddleware(t *testing.T) {
counter := 0
httpTestHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -57,13 +65,18 @@ func TestMiddleware(t *testing.T) {
Value: []byte("value 3"),
Expiration: time.Now().Add(-1 * time.Minute),
}.Bytes(),
10956846073361780255: Response{
Value: []byte("value 4"),
Expiration: time.Now().Add(-1 * time.Minute),
}.Bytes(),
},
}

client, _ := NewClient(
ClientWithAdapter(adapter),
ClientWithTTL(1*time.Minute),
ClientWithRefreshKey("rk"),
ClientWithMethods([]string{http.MethodGet, http.MethodPost}),
)

handler := client.Middleware(httpTestHandler)
Expand All @@ -72,74 +85,126 @@ func TestMiddleware(t *testing.T) {
name string
url string
method string
body []byte
wantBody string
wantCode int
}{
{
"returns cached response",
"http://foo.bar/test-1",
"GET",
nil,
"value 1",
200,
},
{
"returns new response",
"http://foo.bar/test-2",
"POST",
"PUT",
nil,
"new value 2",
200,
},
{
"returns cached response",
"http://foo.bar/test-2",
"GET",
nil,
"value 2",
200,
},
{
"returns new response",
"http://foo.bar/test-3?zaz=baz&baz=zaz",
"GET",
nil,
"new value 4",
200,
},
{
"returns cached response",
"http://foo.bar/test-3?baz=zaz&zaz=baz",
"GET",
nil,
"new value 4",
200,
},
{
"cache expired",
"http://foo.bar/test-3",
"GET",
nil,
"new value 6",
200,
},
{
"releases cached response and returns new response",
"http://foo.bar/test-2?rk=true",
"GET",
nil,
"new value 7",
200,
},
{
"returns new cached response",
"http://foo.bar/test-2",
"GET",
nil,
"new value 7",
200,
},
{
"returns new cached response",
"http://foo.bar/test-2",
"POST",
[]byte(`{"foo": "bar"}`),
"new value 9",
200,
},
{
"returns new cached response",
"http://foo.bar/test-2",
"POST",
[]byte(`{"foo": "bar"}`),
"new value 9",
200,
},
{
"ignores request body",
"http://foo.bar/test-2",
"GET",
[]byte(`{"foo": "bar"}`),
"new value 7",
200,
},
{
"returns new response",
"http://foo.bar/test-2",
"POST",
[]byte(`{"foo": "bar"}`),
"new value 12",
200,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
counter++
var r *http.Request
var err error

r, err := http.NewRequest(tt.method, tt.url, nil)
if err != nil {
t.Error(err)
return
if counter != 12 {
reader := bytes.NewReader(tt.body)
r, err = http.NewRequest(tt.method, tt.url, reader)
if err != nil {
t.Error(err)
return
}
} else {
r, err = http.NewRequest(tt.method, tt.url, errReader(0))
if err != nil {
t.Error(err)
return
}
}

w := httptest.NewRecorder()
Expand Down Expand Up @@ -289,6 +354,41 @@ func TestGenerateKey(t *testing.T) {
}
}

func TestGenerateKeyWithBody(t *testing.T) {
tests := []struct {
name string
URL string
body []byte
want uint64
}{
{
"get POST checksum",
"http://foo.bar/test-1",
[]byte(`{"foo": "bar"}`),
16224051135567554746,
},
{
"get POST 2 checksum",
"http://foo.bar/test-1",
[]byte(`{"bar": "foo"}`),
3604153880186288164,
},
{
"get POST 3 checksum",
"http://foo.bar/test-2",
[]byte(`{"foo": "bar"}`),
10956846073361780255,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := generateKeyWithBody(tt.URL, tt.body); got != tt.want {
t.Errorf("generateKeyWithBody() = %v, want %v", got, tt.want)
}
})
}
}

func TestNewClient(t *testing.T) {
adapter := &adapterMock{}

Expand Down

0 comments on commit f0c859b

Please sign in to comment.