diff --git a/README.md b/README.md index 39d4bd8..9e4eb09 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,7 @@ func main() { cache.ClientWithAdapter(memcached), cache.ClientWithTTL(10 * time.Minute), cache.ClientWithRefreshKey("opn"), + cache.ClientWithNonCachedHeaders([]string{"geo-country"}), ) if err != nil { fmt.Println(err) @@ -137,4 +138,4 @@ http-cache memory adapter takes way less GC pause time, that means smaller GC ov - [Redis adapter](https://godoc.org/github.com/victorspringer/http-cache/adapter/redis) ## License -http-cache is released under the [MIT License](https://github.com/victorspringer/http-cache/blob/master/LICENSE). \ No newline at end of file +http-cache is released under the [MIT License](https://github.com/victorspringer/http-cache/blob/master/LICENSE). diff --git a/cache.go b/cache.go index 6cdb8ad..fe11c9f 100644 --- a/cache.go +++ b/cache.go @@ -62,10 +62,11 @@ type Response struct { // Client data structure for HTTP cache middleware. type Client struct { - adapter Adapter - ttl time.Duration - refreshKey string - methods []string + adapter Adapter + ttl time.Duration + refreshKey string + methods []string + nonCacheableHeaders []string } // ClientOption is used to set Client settings. @@ -89,7 +90,9 @@ func (c *Client) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if c.cacheableMethod(r.Method) { sortURLParams(r.URL) - key := generateKey(r.URL.String()) + headerValues := extractHeaders(c.nonCacheableHeaders, r.Header) + + key := generateKey(r.URL.String(), headerValues) if r.Method == http.MethodPost && r.Body != nil { body, err := ioutil.ReadAll(r.Body) defer r.Body.Close() @@ -98,7 +101,7 @@ func (c *Client) Middleware(next http.Handler) http.Handler { return } reader := ioutil.NopCloser(bytes.NewBuffer(body)) - key = generateKeyWithBody(r.URL.String(), body) + key = generateKeyWithBody(r.URL.String(), headerValues, body) r.Body = reader } @@ -107,7 +110,7 @@ func (c *Client) Middleware(next http.Handler) http.Handler { delete(params, c.refreshKey) r.URL.RawQuery = params.Encode() - key = generateKey(r.URL.String()) + key = generateKey(r.URL.String(), headerValues) c.adapter.Release(key) } else { @@ -202,21 +205,49 @@ func KeyAsString(key uint64) string { return strconv.FormatUint(key, 36) } -func generateKey(URL string) uint64 { +func generateKey(URL string, headerValues []string) uint64 { + buffer := bytes.Buffer{} + buffer.WriteString(URL) + + for _, value := range headerValues { + buffer.WriteString(value) + } + hash := fnv.New64a() - hash.Write([]byte(URL)) + hash.Write(buffer.Bytes()) return hash.Sum64() } -func generateKeyWithBody(URL string, body []byte) uint64 { +func generateKeyWithBody(URL string, headerValues []string, body []byte) uint64 { + buffer := bytes.Buffer{} + buffer.WriteString(URL) + + for _, value := range headerValues { + buffer.WriteString(value) + } + + buffer.Write(body) + hash := fnv.New64a() - body = append([]byte(URL), body...) - hash.Write(body) + hash.Write([]byte(buffer.String())) return hash.Sum64() } +func extractHeaders(nonCachedHeaders []string, headers http.Header) []string { + var headerValues []string + + for _, nonCachedHeader := range nonCachedHeaders { + headerValue, ok := headers[nonCachedHeader] + if ok { + headerValues = append(headerValues, headerValue...) + } + } + + return headerValues +} + // NewClient initializes the cache HTTP middleware client with the given // options. func NewClient(opts ...ClientOption) (*Client, error) { @@ -285,3 +316,12 @@ func ClientWithMethods(methods []string) ClientOption { return nil } } + +// ClientWithNonCacheableHeaders sets the un-cacheable headers. +// If you provide []string{"geo-country"} cache will be missed with different "geo-country" header but same URL. +func ClientWithNonCacheableHeaders(headers []string) ClientOption { + return func(c *Client) error { + c.nonCacheableHeaders = headers + return nil + } +} diff --git a/cache_test.go b/cache_test.go index 8ebca53..81f0b35 100644 --- a/cache_test.go +++ b/cache_test.go @@ -53,22 +53,26 @@ func TestMiddleware(t *testing.T) { adapter := &adapterMock{ store: map[uint64][]byte{ - 14974843192121052621: Response{ + generateKey("http://foo.bar/test-1", nil): Response{ Value: []byte("value 1"), Expiration: time.Now().Add(1 * time.Minute), }.Bytes(), - 14974839893586167988: Response{ + generateKey("http://foo.bar/test-2", nil): Response{ Value: []byte("value 2"), Expiration: time.Now().Add(1 * time.Minute), }.Bytes(), - 14974840993097796199: Response{ + generateKey("http://foo.bar/test-3", nil): Response{ Value: []byte("value 3"), Expiration: time.Now().Add(-1 * time.Minute), }.Bytes(), - 10956846073361780255: Response{ + generateKey("http://foo.bar/test-4", nil): Response{ Value: []byte("value 4"), Expiration: time.Now().Add(-1 * time.Minute), }.Bytes(), + generateKey("http://foo.bar/test-5", []string{"test5"}): Response{ + Value: []byte("value 5"), + Expiration: time.Now().Add(1 * time.Minute), + }.Bytes(), }, } @@ -77,6 +81,7 @@ func TestMiddleware(t *testing.T) { ClientWithTTL(1*time.Minute), ClientWithRefreshKey("rk"), ClientWithMethods([]string{http.MethodGet, http.MethodPost}), + ClientWithNonCacheableHeaders([]string{"country"}), ) handler := client.Middleware(httpTestHandler) @@ -86,6 +91,7 @@ func TestMiddleware(t *testing.T) { url string method string body []byte + headers http.Header wantBody string wantCode int }{ @@ -94,6 +100,7 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-1", "GET", nil, + http.Header{}, "value 1", 200, }, @@ -102,6 +109,7 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-2", "PUT", nil, + http.Header{}, "new value 2", 200, }, @@ -110,6 +118,7 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-2", "GET", nil, + http.Header{}, "value 2", 200, }, @@ -118,6 +127,7 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-3?zaz=baz&baz=zaz", "GET", nil, + http.Header{}, "new value 4", 200, }, @@ -126,6 +136,7 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-3?baz=zaz&zaz=baz", "GET", nil, + http.Header{}, "new value 4", 200, }, @@ -134,15 +145,26 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-3", "GET", nil, + http.Header{}, "new value 6", 200, }, + { + "returns cached response", + "http://foo.bar/test-5", + "GET", + []byte(``), + http.Header{"country": {"test5"}}, + "value 5", + 200, + }, { "releases cached response and returns new response", "http://foo.bar/test-2?rk=true", "GET", nil, - "new value 7", + http.Header{}, + "new value 8", 200, }, { @@ -150,7 +172,8 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-2", "GET", nil, - "new value 7", + http.Header{}, + "new value 8", 200, }, { @@ -158,7 +181,8 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-2", "POST", []byte(`{"foo": "bar"}`), - "new value 9", + http.Header{}, + "new value 10", 200, }, { @@ -166,7 +190,8 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-2", "POST", []byte(`{"foo": "bar"}`), - "new value 9", + http.Header{}, + "new value 10", 200, }, { @@ -174,7 +199,8 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-2", "GET", []byte(`{"foo": "bar"}`), - "new value 7", + http.Header{}, + "new value 8", 200, }, { @@ -182,7 +208,8 @@ func TestMiddleware(t *testing.T) { "http://foo.bar/test-2", "POST", []byte(`{"foo": "bar"}`), - "new value 12", + http.Header{}, + "new value 13", 200, }, } @@ -192,19 +219,22 @@ func TestMiddleware(t *testing.T) { var r *http.Request var err error - if counter != 12 { + if counter != 13 { reader := bytes.NewReader(tt.body) r, err = http.NewRequest(tt.method, tt.url, reader) if err != nil { t.Error(err) return } + r.Header = tt.headers } else { r, err = http.NewRequest(tt.method, tt.url, errReader(0)) if err != nil { t.Error(err) return } + + r.Header = tt.headers } w := httptest.NewRecorder() @@ -309,11 +339,21 @@ func TestGenerateKeyString(t *testing.T) { "http://localhost:8080/category", "http://localhost:8080/category/morisco", "http://localhost:8080/category/mourisquinho", + "http://localhost:8080/category/mourisquinho", + "http://localhost:8080/category/mourisquinho", + } + + headers := [][]string{ + {}, + {}, + {}, + {"test1"}, + {"test1", "test2"}, } keys := make(map[string]string, len(urls)) - for _, u := range urls { - rawKey := generateKey(u) + for i, u := range urls { + rawKey := generateKey(u, headers[i]) key := KeyAsString(rawKey) if otherURL, found := keys[key]; found { @@ -325,29 +365,39 @@ func TestGenerateKeyString(t *testing.T) { func TestGenerateKey(t *testing.T) { tests := []struct { - name string - URL string - want uint64 + name string + URL string + nonCachedHeaderValues []string + want uint64 }{ { "get url checksum", "http://foo.bar/test-1", + []string{}, 14974843192121052621, }, { "get url 2 checksum", "http://foo.bar/test-2", + []string{}, 14974839893586167988, }, { "get url 3 checksum", "http://foo.bar/test-3", + []string{}, 14974840993097796199, }, + { + "get url checksum with non-cached headers", + "http://foo.bar/test-3", + []string{"value1", "value2"}, + 6093834678676844634, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := generateKey(tt.URL); got != tt.want { + if got := generateKey(tt.URL, tt.nonCachedHeaderValues); got != tt.want { t.Errorf("generateKey() = %v, want %v", got, tt.want) } }) @@ -356,33 +406,44 @@ func TestGenerateKey(t *testing.T) { func TestGenerateKeyWithBody(t *testing.T) { tests := []struct { - name string - URL string - body []byte - want uint64 + name string + URL string + nonCachedHeaderValues []string + body []byte + want uint64 }{ { "get POST checksum", "http://foo.bar/test-1", + []string{}, []byte(`{"foo": "bar"}`), 16224051135567554746, }, { "get POST 2 checksum", "http://foo.bar/test-1", + []string{}, []byte(`{"bar": "foo"}`), 3604153880186288164, }, { "get POST 3 checksum", "http://foo.bar/test-2", + []string{}, []byte(`{"foo": "bar"}`), 10956846073361780255, }, + { + "get POST 3 checksum with cached headers", + "http://foo.bar/test-2", + []string{"value1", "value2"}, + []byte(`{"foo": "bar"}`), + 16634781976963392442, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := generateKeyWithBody(tt.URL, tt.body); got != tt.want { + if got := generateKeyWithBody(tt.URL, tt.nonCachedHeaderValues, tt.body); got != tt.want { t.Errorf("generateKeyWithBody() = %v, want %v", got, tt.want) } }) @@ -479,3 +540,35 @@ func TestNewClient(t *testing.T) { }) } } + +func Test_extractHeaders(t *testing.T) { + type args struct { + nonCachedHeaders []string + headers http.Header + } + tests := []struct { + name string + args args + want []string + }{ + { + "general", + args{ + []string{"test1", "test2"}, + http.Header{ + "test1": []string{"test1Value1", "test1Value2"}, + "test2": []string{"test2Value1"}, + "test3": []string{"test3Value1"}, + }, + }, + []string{"test1Value1", "test1Value2", "test2Value1"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := extractHeaders(tt.args.nonCachedHeaders, tt.args.headers); !reflect.DeepEqual(got, tt.want) { + t.Errorf("extractHeaders() = %v, want %v", got, tt.want) + } + }) + } +}