diff --git a/http/server.go b/http/server.go index 139bb63..a959d2b 100755 --- a/http/server.go +++ b/http/server.go @@ -25,6 +25,8 @@ package http import ( "net/http" "net/http/httptest" + "net/url" + "strings" "testing" "github.com/ryanuber/go-glob" @@ -39,6 +41,7 @@ const ( type Expectation struct { method string path string + qry *url.Values fn http.HandlerFunc @@ -46,12 +49,14 @@ type Expectation struct { body []byte status int - times int + times int + called int } // Times sets the number of times the request can be made. func (e *Expectation) Times(times int) *Expectation { e.times = times + e.called = times return e } @@ -111,15 +116,9 @@ func (s *Server) URL() string { return s.srv.URL } -func (s *Server) handler(w http.ResponseWriter, r *http.Request) { - method := r.Method - path := r.URL.Path +func (s *Server) handler(w http.ResponseWriter, req *http.Request) { for i, exp := range s.expect { - if exp.method != method && exp.method != Anything { - continue - } - - if exp.path != Anything && !glob.Glob(exp.path, path) { + if !requestMatches(req, exp) { continue } @@ -128,7 +127,7 @@ func (s *Server) handler(w http.ResponseWriter, r *http.Request) { } if exp.fn != nil { - exp.fn(w, r) + exp.fn(w, req) } else { w.WriteHeader(exp.status) if len(exp.body) > 0 { @@ -136,22 +135,60 @@ func (s *Server) handler(w http.ResponseWriter, r *http.Request) { } } - exp.times-- - if exp.times == 0 { + exp.called-- + if exp.called == 0 { s.expect = append(s.expect[:i], s.expect[i+1:]...) } return } - s.t.Errorf("Unexpected call to %s %s", method, path) + s.t.Errorf("Unexpected call to %s %s", req.Method, req.URL.String()) +} + +func requestMatches(req *http.Request, exp *Expectation) bool { + if exp.method != req.Method && exp.method != Anything { + return false + } + + if exp.path != Anything && !glob.Glob(exp.path, req.URL.Path) { + return false + } + + qry := req.URL.Query() + if exp.qry != nil { + found := false + for k, v := range *exp.qry { + if !qry.Has(k) { + break + } + if elementsMatch(v, qry[k]) { + found = true + } + } + if !found { + return false + } + } + + return true } // On creates an expectation of a request on the server. func (s *Server) On(method, path string) *Expectation { + var qry *url.Values + if parts := strings.SplitN(path, "?", 2); len(parts) == 2 { + path = parts[0] + if val, err := url.ParseQuery(parts[1]); err == nil { + qry = &val + } + } + exp := &Expectation{ method: method, path: path, + qry: qry, times: -1, + called: -1, status: 200, } s.expect = append(s.expect, exp) @@ -162,8 +199,29 @@ func (s *Server) On(method, path string) *Expectation { // AssertExpectations asserts all expectations have been met. func (s *Server) AssertExpectations() { for _, exp := range s.expect { - if exp.times > 0 || exp.times == -1 { - s.t.Errorf("mock: server: Expected a call to %s %s but got none", exp.method, exp.path) + var call string + if exp.method != Anything { + call = exp.method + } + if exp.path != Anything { + if call != "" { + call += " " + } + call += exp.path + } + if exp.qry != nil { + if call != "" || exp.path == Anything { + call += " " + } + call += exp.qry.Encode() + } + + switch exp.called { + case -1: + s.t.Errorf("Expected a call to %s but got none", call) + case 0: + default: + s.t.Errorf("Expected a call to %s %d times but got called %d times", call, exp.times, exp.times-exp.called) } } } @@ -172,3 +230,28 @@ func (s *Server) AssertExpectations() { func (s *Server) Close() { s.srv.Close() } + +func elementsMatch(a, b []string) bool { + aLen := len(a) + bLen := len(b) + + visited := make([]bool, bLen) + for i := 0; i < aLen; i++ { + found := false + element := a[i] + for j := 0; j < bLen; j++ { + if visited[j] { + continue + } + if element == b[j] { + visited[j] = true + found = true + break + } + } + if !found { + return false + } + } + return true +} diff --git a/http/server_test.go b/http/server_test.go index e532800..5167417 100755 --- a/http/server_test.go +++ b/http/server_test.go @@ -8,63 +8,74 @@ import ( httptest "github.com/hamba/testutils/http" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestServer_HandlesExpectation(t *testing.T) { s := httptest.NewServer(t) - defer s.Close() + t.Cleanup(s.Close) - s.On("GET", "/test/path") + s.On(http.MethodGet, "/test/path") res, err := http.Get(s.URL() + "/test/path") - assert.NoError(t, err) + require.NoError(t, err) + assert.Equal(t, 200, res.StatusCode) +} + +func TestServer_HandlesExpectationWithQuery(t *testing.T) { + s := httptest.NewServer(t) + t.Cleanup(s.Close) + + s.On(http.MethodGet, "/test/path?p=some%2Fpath") + + res, err := http.Get(s.URL() + "/test/path?p=some%2Fpath") + require.NoError(t, err) assert.Equal(t, 200, res.StatusCode) } func TestServer_HandlesAnythingMethodExpectation(t *testing.T) { s := httptest.NewServer(t) - defer s.Close() + t.Cleanup(s.Close) s.On(httptest.Anything, "/test/path") res, err := http.Post(s.URL()+"/test/path", "text/plain", bytes.NewReader([]byte{})) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 200, res.StatusCode) } func TestServer_HandlesAnythingPathExpectation(t *testing.T) { s := httptest.NewServer(t) - defer s.Close() + t.Cleanup(s.Close) - s.On("GET", httptest.Anything) + s.On(http.MethodGet, httptest.Anything) res, err := http.Get(s.URL() + "/test/path") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 200, res.StatusCode) } func TestServer_HandlesWildcardPathExpectation(t *testing.T) { s := httptest.NewServer(t) - defer s.Close() + t.Cleanup(s.Close) - s.On("GET", "/test/*") + s.On(http.MethodGet, "/test/*") res, err := http.Get(s.URL() + "/test/path") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 200, res.StatusCode) } func TestServer_HandlesUnexpectedMethodRequest(t *testing.T) { mockT := new(testing.T) - defer func() { + t.Cleanup(func() { if !mockT.Failed() { t.Error("Expected error when no expectation on request") } - - }() + }) s := httptest.NewServer(mockT) - defer s.Close() + t.Cleanup(s.Close) s.On("POST", "/") @@ -73,34 +84,48 @@ func TestServer_HandlesUnexpectedMethodRequest(t *testing.T) { func TestServer_HandlesUnexpectedPathRequest(t *testing.T) { mockT := new(testing.T) - defer func() { + t.Cleanup(func() { if !mockT.Failed() { t.Error("Expected error when no expectation on request") } - - }() + }) s := httptest.NewServer(mockT) - defer s.Close() - s.On("GET", "/foobar") + t.Cleanup(s.Close) + s.On(http.MethodGet, "/foobar") - s.On("GET", "/") + s.On(http.MethodGet, "/") _, _ = http.Get(s.URL() + "/test/path") } +func TestServer_HandlesUnexpectedPathQueryRequest(t *testing.T) { + mockT := new(testing.T) + t.Cleanup(func() { + if !mockT.Failed() { + t.Error("Expected error when no expectation on request") + } + }) + + s := httptest.NewServer(mockT) + t.Cleanup(s.Close) + s.On(http.MethodGet, "/test/path?a=other") + s.On(http.MethodGet, "/test/path?p=something") + + _, _ = http.Get(s.URL() + "/test/path?p=somethingelse") +} + func TestServer_HandlesExpectationNTimes(t *testing.T) { mockT := new(testing.T) - defer func() { + t.Cleanup(func() { if !mockT.Failed() { t.Error("Expected error when expectation times used") } - - }() + }) s := httptest.NewServer(mockT) - defer s.Close() - s.On("GET", "/test/path").Times(2) + t.Cleanup(s.Close) + s.On(http.MethodGet, "/test/path").Times(2) _, _ = http.Get(s.URL() + "/test/path") _, _ = http.Get(s.URL() + "/test/path") @@ -109,16 +134,15 @@ func TestServer_HandlesExpectationNTimes(t *testing.T) { func TestServer_HandlesExpectationUnlimitedTimes(t *testing.T) { mockT := new(testing.T) - defer func() { + t.Cleanup(func() { if mockT.Failed() { t.Error("Unexpected error on request") } - - }() + }) s := httptest.NewServer(mockT) - defer s.Close() - s.On("GET", "/test/path") + t.Cleanup(s.Close) + s.On(http.MethodGet, "/test/path") _, _ = http.Get(s.URL() + "/test/path") _, _ = http.Get(s.URL() + "/test/path") @@ -126,12 +150,12 @@ func TestServer_HandlesExpectationUnlimitedTimes(t *testing.T) { func TestServer_ExpectationReturnsBodyBytes(t *testing.T) { s := httptest.NewServer(t) - defer s.Close() + t.Cleanup(s.Close) - s.On("GET", "/test/path").Returns(400, []byte("test")) + s.On(http.MethodGet, "/test/path").Returns(400, []byte("test")) res, err := http.Get(s.URL() + "/test/path") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 400, res.StatusCode) b, _ := ioutil.ReadAll(res.Body) assert.Equal(t, []byte("test"), b) @@ -141,12 +165,12 @@ func TestServer_ExpectationReturnsBodyBytes(t *testing.T) { func TestServer_ExpectationReturnsBodyString(t *testing.T) { s := httptest.NewServer(t) - defer s.Close() + t.Cleanup(s.Close) - s.On("GET", "/test/path").ReturnsString(400, "test") + s.On(http.MethodGet, "/test/path").ReturnsString(400, "test") res, err := http.Get(s.URL() + "/test/path") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 400, res.StatusCode) b, _ := ioutil.ReadAll(res.Body) assert.Equal(t, []byte("test"), b) @@ -156,12 +180,12 @@ func TestServer_ExpectationReturnsBodyString(t *testing.T) { func TestServer_ExpectationReturnsStatusCode(t *testing.T) { s := httptest.NewServer(t) - defer s.Close() + t.Cleanup(s.Close) - s.On("GET", "/test/path").ReturnsStatus(400) + s.On(http.MethodGet, "/test/path").ReturnsStatus(400) res, err := http.Get(s.URL() + "/test/path") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 400, res.StatusCode) b, _ := ioutil.ReadAll(res.Body) assert.Len(t, b, 0) @@ -171,12 +195,12 @@ func TestServer_ExpectationReturnsStatusCode(t *testing.T) { func TestServer_ExpectationReturnsHeaders(t *testing.T) { s := httptest.NewServer(t) - defer s.Close() + t.Cleanup(s.Close) - s.On("GET", "/test/path").Header("foo", "bar").ReturnsStatus(200) + s.On(http.MethodGet, "/test/path").Header("foo", "bar").ReturnsStatus(200) res, err := http.Get(s.URL() + "/test/path") - assert.NoError(t, err) + require.NoError(t, err) v := res.Header.Get("foo") assert.Equal(t, "bar", v) @@ -185,45 +209,61 @@ func TestServer_ExpectationReturnsHeaders(t *testing.T) { func TestServer_ExpectationUsesHandleFunc(t *testing.T) { s := httptest.NewServer(t) - defer s.Close() + t.Cleanup(s.Close) - s.On("GET", "/test/path").Handle(func(w http.ResponseWriter, r *http.Request) { + s.On(http.MethodGet, "/test/path").Handle(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(400) }) res, err := http.Get(s.URL() + "/test/path") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, 400, res.StatusCode) } +func TestServer_AssertExpectations(t *testing.T) { + mockT := new(testing.T) + t.Cleanup(func() { + if mockT.Failed() { + t.Error("Expected no error when asserting expectations") + } + }) + + s := httptest.NewServer(mockT) + t.Cleanup(s.Close) + s.On(http.MethodGet, "/").Times(1) + + _, err := http.Get(s.URL() + "/") + assert.NoError(t, err) + + s.AssertExpectations() +} + func TestServer_AssertExpectationsOnUnlimited(t *testing.T) { mockT := new(testing.T) - defer func() { + t.Cleanup(func() { if !mockT.Failed() { t.Error("Expected error when asserting expectations") } - - }() + }) s := httptest.NewServer(mockT) - defer s.Close() - s.On("POST", "/") + t.Cleanup(s.Close) + s.On(http.MethodPost, "/") s.AssertExpectations() } func TestServer_AssertExpectationsOnNTimes(t *testing.T) { mockT := new(testing.T) - defer func() { + t.Cleanup(func() { if !mockT.Failed() { t.Error("Expected error when asserting expectations") } - - }() + }) s := httptest.NewServer(mockT) - defer s.Close() - s.On("POST", "/").Times(1) + t.Cleanup(s.Close) + s.On(http.MethodPost, "/").Times(1) s.AssertExpectations() }