Skip to content

Commit

Permalink
feat: support http queries (#18)
Browse files Browse the repository at this point in the history
  • Loading branch information
nrwiersma authored Jul 15, 2022
1 parent f9cd6cd commit c89bad0
Show file tree
Hide file tree
Showing 2 changed files with 194 additions and 71 deletions.
113 changes: 98 additions & 15 deletions http/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ package http
import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

"github.com/ryanuber/go-glob"
Expand All @@ -39,19 +41,22 @@ const (
type Expectation struct {
method string
path string
qry *url.Values

fn http.HandlerFunc

headers []string
body []byte
status int

times int
times int
called int
}

// Times sets the number of times the request can be made.
func (e *Expectation) Times(times int) *Expectation {
e.times = times
e.called = times

return e
}
Expand Down Expand Up @@ -111,15 +116,9 @@ func (s *Server) URL() string {
return s.srv.URL
}

func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
method := r.Method
path := r.URL.Path
func (s *Server) handler(w http.ResponseWriter, req *http.Request) {
for i, exp := range s.expect {
if exp.method != method && exp.method != Anything {
continue
}

if exp.path != Anything && !glob.Glob(exp.path, path) {
if !requestMatches(req, exp) {
continue
}

Expand All @@ -128,30 +127,68 @@ func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
}

if exp.fn != nil {
exp.fn(w, r)
exp.fn(w, req)
} else {
w.WriteHeader(exp.status)
if len(exp.body) > 0 {
_, _ = w.Write(exp.body)
}
}

exp.times--
if exp.times == 0 {
exp.called--
if exp.called == 0 {
s.expect = append(s.expect[:i], s.expect[i+1:]...)
}
return
}

s.t.Errorf("Unexpected call to %s %s", method, path)
s.t.Errorf("Unexpected call to %s %s", req.Method, req.URL.String())
}

func requestMatches(req *http.Request, exp *Expectation) bool {
if exp.method != req.Method && exp.method != Anything {
return false
}

if exp.path != Anything && !glob.Glob(exp.path, req.URL.Path) {
return false
}

qry := req.URL.Query()
if exp.qry != nil {
found := false
for k, v := range *exp.qry {
if !qry.Has(k) {
break
}
if elementsMatch(v, qry[k]) {
found = true
}
}
if !found {
return false
}
}

return true
}

// On creates an expectation of a request on the server.
func (s *Server) On(method, path string) *Expectation {
var qry *url.Values
if parts := strings.SplitN(path, "?", 2); len(parts) == 2 {
path = parts[0]
if val, err := url.ParseQuery(parts[1]); err == nil {
qry = &val
}
}

exp := &Expectation{
method: method,
path: path,
qry: qry,
times: -1,
called: -1,
status: 200,
}
s.expect = append(s.expect, exp)
Expand All @@ -162,8 +199,29 @@ func (s *Server) On(method, path string) *Expectation {
// AssertExpectations asserts all expectations have been met.
func (s *Server) AssertExpectations() {
for _, exp := range s.expect {
if exp.times > 0 || exp.times == -1 {
s.t.Errorf("mock: server: Expected a call to %s %s but got none", exp.method, exp.path)
var call string
if exp.method != Anything {
call = exp.method
}
if exp.path != Anything {
if call != "" {
call += " "
}
call += exp.path
}
if exp.qry != nil {
if call != "" || exp.path == Anything {
call += " "
}
call += exp.qry.Encode()
}

switch exp.called {
case -1:
s.t.Errorf("Expected a call to %s but got none", call)
case 0:
default:
s.t.Errorf("Expected a call to %s %d times but got called %d times", call, exp.times, exp.times-exp.called)
}
}
}
Expand All @@ -172,3 +230,28 @@ func (s *Server) AssertExpectations() {
func (s *Server) Close() {
s.srv.Close()
}

func elementsMatch(a, b []string) bool {
aLen := len(a)
bLen := len(b)

visited := make([]bool, bLen)
for i := 0; i < aLen; i++ {
found := false
element := a[i]
for j := 0; j < bLen; j++ {
if visited[j] {
continue
}
if element == b[j] {
visited[j] = true
found = true
break
}
}
if !found {
return false
}
}
return true
}
Loading

0 comments on commit c89bad0

Please sign in to comment.