diff --git a/fallthrough_test.go b/fallthrough_test.go new file mode 100644 index 0000000..4868189 --- /dev/null +++ b/fallthrough_test.go @@ -0,0 +1,93 @@ +package httptreemux + +import ( + "net/http" + "net/http/httptest" + "reflect" + "testing" +) + +// When we find a node with a matching path but no handler for a method, +// we should fall through and continue searching the tree for a less specific +// match, i.e. a wildcard or catchall, that does have a handler for that method. +func TestMethodNotAllowedFallthrough(t *testing.T) { + var matchedMethod string + var matchedPath string + var matchedParams map[string]string + + router := New() + + addRoute := func(method, path string) { + router.Handle(method, path, func(w http.ResponseWriter, r *http.Request, params map[string]string) { + matchedMethod = method + matchedPath = path + matchedParams = params + }) + } + + checkRoute := func(method, path, expectedMethod, expectedPath string, + expectedCode int, expectedParams map[string]string) { + + matchedMethod = "" + matchedPath = "" + matchedParams = nil + + w := httptest.NewRecorder() + r, _ := http.NewRequest(method, path, nil) + router.ServeHTTP(w, r) + if expectedCode != w.Code { + t.Errorf("%s %s expected code %d, saw %d", method, path, expectedCode, w.Code) + } + + if w.Code == 200 { + if matchedMethod != method || matchedPath != expectedPath { + t.Errorf("%s %s expected %s %s, saw %s %s", method, path, + expectedMethod, expectedPath, matchedMethod, matchedPath) + } + + if !reflect.DeepEqual(matchedParams, expectedParams) { + t.Errorf("%s %s expected params %+v, saw %+v", method, path, expectedParams, matchedParams) + } + } + } + + addRoute("GET", "/apple/banana/cat") + addRoute("GET", "/apple/potato") + addRoute("POST", "/apple/banana/:abc") + addRoute("POST", "/apple/ban/def") + addRoute("DELETE", "/apple/:seed") + addRoute("DELETE", "/apple/*path") + addRoute("OPTIONS", "/apple/*path") + + checkRoute("GET", "/apple/banana/cat", "GET", "/apple/banana/cat", 200, nil) + checkRoute("POST", "/apple/banana/cat", "POST", "/apple/banana/:abc", 200, + map[string]string{"abc": "cat"}) + checkRoute("POST", "/apple/banana/dog", "POST", "/apple/banana/:abc", 200, + map[string]string{"abc": "dog"}) + + // Wildcards should be checked before catchalls + checkRoute("DELETE", "/apple/banana", "DELETE", "/apple/:seed", 200, + map[string]string{"seed": "banana"}) + checkRoute("DELETE", "/apple/banana/cat", "DELETE", "/apple/*path", 200, + map[string]string{"path": "banana/cat"}) + + checkRoute("POST", "/apple/ban/def", "POST", "/apple/ban/def", 200, nil) + checkRoute("OPTIONS", "/apple/ban/def", "OPTIONS", "/apple/*path", 200, + map[string]string{"path": "ban/def"}) + checkRoute("GET", "/apple/ban/def", "", "", 405, nil) + + // Always fallback to the matching handler no matter how many other + // nodes without proper handlers are found on the way. + checkRoute("OPTIONS", "/apple/banana/cat", "OPTIONS", "/apple/*path", 200, + map[string]string{"path": "banana/cat"}) + checkRoute("OPTIONS", "/apple/bbbb", "OPTIONS", "/apple/*path", 200, + map[string]string{"path": "bbbb"}) + + // Nothing matches on patch + checkRoute("PATCH", "/apple/banana/cat", "", "", 405, nil) + checkRoute("PATCH", "/apple/potato", "", "", 405, nil) + + // And some 404 tests for good measure + checkRoute("GET", "/abc", "", "", 404, nil) + checkRoute("OPTIONS", "/apple", "", "", 404, nil) +} diff --git a/group.go b/group.go index 6f8c076..aea7715 100644 --- a/group.go +++ b/group.go @@ -102,7 +102,11 @@ func (g *Group) Handle(method string, path string, handler HandlerFunc) { if addSlash { node.addSlash = true } - node.setHandler(method, handler) + node.setHandler(method, handler, false) + + if g.mux.HeadCanUseGet && method == "GET" && node.leafHandler["HEAD"] == nil { + node.setHandler("HEAD", handler, true) + } } // Syntactic sugar for Handle("GET", path, handler) diff --git a/group_test.go b/group_test.go index 0736201..9065d5c 100644 --- a/group_test.go +++ b/group_test.go @@ -56,7 +56,8 @@ func TestSubGroupEmptyMapping(t *testing.T) { func TestGroupMethods(t *testing.T) { for _, scenario := range scenarios { t.Log(scenario.description) - testGroupMethods(t, scenario.RequestCreator) + testGroupMethods(t, scenario.RequestCreator, false) + testGroupMethods(t, scenario.RequestCreator, true) } } @@ -88,7 +89,7 @@ func TestInvalidPath(t *testing.T) { } //Liberally borrowed from router_test -func testGroupMethods(t *testing.T, reqGen RequestCreator) { +func testGroupMethods(t *testing.T, reqGen RequestCreator, headCanUseGet bool) { var result string makeHandler := func(method string) HandlerFunc { return func(w http.ResponseWriter, r *http.Request, params map[string]string) { @@ -96,6 +97,7 @@ func testGroupMethods(t *testing.T, reqGen RequestCreator) { } } router := New() + router.HeadCanUseGet = headCanUseGet // Testing with a sub-group of a group as that will test everything at once g := router.NewGroup("/base").NewGroup("/user") g.GET("/:param", makeHandler("GET")) @@ -110,7 +112,7 @@ func testGroupMethods(t *testing.T, reqGen RequestCreator) { r, _ := reqGen(method, "/base/user/"+method, nil) router.ServeHTTP(w, r) if expect == "" && w.Code != http.StatusMethodNotAllowed { - t.Errorf("Method %s not expected to match but saw code %d", w.Code) + t.Errorf("Method %s not expected to match but saw code %d", method, w.Code) } if result != expect { @@ -123,18 +125,43 @@ func testGroupMethods(t *testing.T, reqGen RequestCreator) { testMethod("PATCH", "PATCH") testMethod("PUT", "PUT") testMethod("DELETE", "DELETE") - t.Log("Test HeadCanUseGet = true") - testMethod("HEAD", "GET") - - router.HeadCanUseGet = false - t.Log("Test HeadCanUseGet = false") - testMethod("HEAD", "") + if headCanUseGet { + t.Log("Test implicit HEAD with HeadCanUseGet = true") + testMethod("HEAD", "GET") + } else { + t.Log("Test implicit HEAD with HeadCanUseGet = false") + testMethod("HEAD", "") + } router.HEAD("/base/user/:param", makeHandler("HEAD")) - - t.Log("Test HeadCanUseGet = false with explicit HEAD handler") testMethod("HEAD", "HEAD") +} + +// Ensure that setting a GET handler doesn't overwrite an explciit HEAD handler. +func TestSetGetAfterHead(t *testing.T) { + var result string + makeHandler := func(method string) HandlerFunc { + return func(w http.ResponseWriter, r *http.Request, params map[string]string) { + result = method + } + } + + router := New() router.HeadCanUseGet = true - t.Log("Test HeadCanUseGet = true with explicit HEAD handler") + router.HEAD("/abc", makeHandler("HEAD")) + router.GET("/abc", makeHandler("GET")) + + testMethod := func(method, expect string) { + result = "" + w := httptest.NewRecorder() + r, _ := http.NewRequest(method, "/abc", nil) + router.ServeHTTP(w, r) + + if result != expect { + t.Errorf("Method %s got result %s", method, result) + } + } + testMethod("HEAD", "HEAD") + testMethod("GET", "GET") } diff --git a/router.go b/router.go index 90cd76e..2337ec6 100644 --- a/router.go +++ b/router.go @@ -178,13 +178,13 @@ func (t *TreeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { if trailingSlash && t.RedirectTrailingSlash { path = path[:pathLen-1] } - n, params := t.root.search(path[1:]) + n, handler, params := t.root.search(r.Method, path[1:]) if n == nil { if t.RedirectCleanPath { // Path was not found. Try cleaning it up and search again. // TODO Test this cleanPath := httppath.Clean(path) - n, params = t.root.search(cleanPath[1:]) + n, handler, params = t.root.search(r.Method, cleanPath[1:]) if n == nil { // Still nothing found. t.NotFoundHandler(w, r) @@ -202,16 +202,12 @@ func (t *TreeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } - handler, ok := n.leafHandler[r.Method] - if !ok { - if r.Method == "HEAD" && t.HeadCanUseGet { - handler, ok = n.leafHandler["GET"] - } else if r.Method == "OPTIONS" && t.OptionsHandler != nil { + if handler == nil { + if r.Method == "OPTIONS" && t.OptionsHandler != nil { handler = t.OptionsHandler - ok = true } - if !ok { + if handler == nil { t.MethodNotAllowedHandler(w, r, n.leafHandler) return } diff --git a/router_test.go b/router_test.go index 3393b3b..8c1fde0 100644 --- a/router_test.go +++ b/router_test.go @@ -67,11 +67,12 @@ func benchRequest(b *testing.B, router http.Handler, r *http.Request) { func TestMethods(t *testing.T) { for _, scenario := range scenarios { t.Log(scenario.description) - testMethods(t, scenario.RequestCreator) + testMethods(t, scenario.RequestCreator, true) + testMethods(t, scenario.RequestCreator, false) } } -func testMethods(t *testing.T, newRequest RequestCreator) { +func testMethods(t *testing.T, newRequest RequestCreator, headCanUseGet bool) { var result string makeHandler := func(method string) HandlerFunc { @@ -81,6 +82,7 @@ func testMethods(t *testing.T, newRequest RequestCreator) { } router := New() + router.HeadCanUseGet = headCanUseGet router.GET("/user/:param", makeHandler("GET")) router.POST("/user/:param", makeHandler("POST")) router.PATCH("/user/:param", makeHandler("PATCH")) @@ -93,7 +95,7 @@ func testMethods(t *testing.T, newRequest RequestCreator) { r, _ := newRequest(method, "/user/"+method, nil) router.ServeHTTP(w, r) if expect == "" && w.Code != http.StatusMethodNotAllowed { - t.Errorf("Method %s not expected to match but saw code %d", w.Code) + t.Errorf("Method %s not expected to match but saw code %d", method, w.Code) } if result != expect { @@ -106,19 +108,15 @@ func testMethods(t *testing.T, newRequest RequestCreator) { testMethod("PATCH", "PATCH") testMethod("PUT", "PUT") testMethod("DELETE", "DELETE") - t.Log("Test HeadCanUseGet = true") - testMethod("HEAD", "GET") - - router.HeadCanUseGet = false - t.Log("Test HeadCanUseGet = false") - testMethod("HEAD", "") + if headCanUseGet { + t.Log("Test implicit HEAD with HeadCanUseGet = true") + testMethod("HEAD", "GET") + } else { + t.Log("Test implicit HEAD with HeadCanUseGet = false") + testMethod("HEAD", "") + } router.HEAD("/user/:param", makeHandler("HEAD")) - - t.Log("Test HeadCanUseGet = false with explicit HEAD handler") - testMethod("HEAD", "HEAD") - router.HeadCanUseGet = true - t.Log("Test HeadCanUseGet = true with explicit HEAD handler") testMethod("HEAD", "HEAD") } @@ -157,7 +155,7 @@ func TestMethodNotAllowedHandler(t *testing.T) { calledNotAllowed = true - expected := []string{"GET", "PUT", "DELETE"} + expected := []string{"GET", "PUT", "DELETE", "HEAD"} allowed := make([]string, 0) for m := range methods { allowed = append(allowed, m) @@ -188,7 +186,7 @@ func TestMethodNotAllowedHandler(t *testing.T) { allowed := w.Header()["Allow"] sort.Strings(allowed) - expected := []string{"DELETE", "GET", "PUT"} + expected := []string{"DELETE", "GET", "PUT", "HEAD"} sort.Strings(expected) if !reflect.DeepEqual(allowed, expected) { @@ -271,7 +269,7 @@ func TestOptionsHandler(t *testing.T) { allowed := w.Header()["Allow"] sort.Strings(allowed) - expected := []string{"DELETE", "GET", "PUT"} + expected := []string{"DELETE", "GET", "PUT", "HEAD"} sort.Strings(expected) if !reflect.DeepEqual(allowed, expected) { diff --git a/tree.go b/tree.go index 64e22ed..5b06158 100644 --- a/tree.go +++ b/tree.go @@ -25,6 +25,8 @@ type node struct { addSlash bool isCatchAll bool + // If true, the head handler was set implicitly, so let it also be set explicitly. + implicitHead bool // If this node is the end of the URL, then call the handler, if applicable. leafHandler map[string]HandlerFunc @@ -40,15 +42,19 @@ func (n *node) sortStaticChild(i int) { } } -func (n *node) setHandler(verb string, handler HandlerFunc) { +func (n *node) setHandler(verb string, handler HandlerFunc, implicitHead bool) { if n.leafHandler == nil { n.leafHandler = make(map[string]HandlerFunc) } _, ok := n.leafHandler[verb] - if ok { + if ok && (verb != "HEAD" || !n.implicitHead) { panic(fmt.Sprintf("%s already handles %s", n.path, verb)) } n.leafHandler[verb] = handler + + if verb == "HEAD" { + n.implicitHead = implicitHead + } } func (n *node) addPath(path string, wildcards []string) *node { @@ -103,7 +109,7 @@ func (n *node) addPath(path string, wildcards []string) *node { } if path[1:] != n.catchAllChild.path { - panic(fmt.Sprintf("Catch-all name in %s doesn't match %s", + panic(fmt.Sprintf("Catch-all name in %s doesn't match %s. You probably tried to define overlapping catchalls", path, n.catchAllChild.path)) } @@ -206,16 +212,16 @@ func (n *node) splitCommonPrefix(existingNodeIndex int, path string) (*node, int return newNode, i } -func (n *node) search(path string) (found *node, params []string) { +func (n *node) search(method, path string) (found *node, handler HandlerFunc, params []string) { // if test != nil { // test.Logf("Searching for %s in %s", path, n.dumpTree("", "")) // } pathLen := len(path) if pathLen == 0 { if len(n.leafHandler) == 0 { - return nil, nil + return nil, nil, nil } else { - return n, nil + return n, n.leafHandler[method], nil } } @@ -227,13 +233,15 @@ func (n *node) search(path string) (found *node, params []string) { childPathLen := len(child.path) if pathLen >= childPathLen && child.path == path[:childPathLen] { nextPath := path[childPathLen:] - found, params = child.search(nextPath) + found, handler, params = child.search(method, nextPath) } break } } - if found != nil { + // If we found a node and it had a valid handler, then return here. Otherwise + // let's remember that we found this one, but look for a better match. + if handler != nil { return } @@ -248,36 +256,52 @@ func (n *node) search(path string) (found *node, params []string) { nextToken := path[nextSlash:] if len(thisToken) > 0 { // Don't match on empty tokens. - found, params = n.wildcardChild.search(nextToken) - if found != nil { + wcNode, wcHandler, wcParams := n.wildcardChild.search(method, nextToken) + if wcHandler != nil || (found == nil && wcNode != nil) { unescaped, err := url.QueryUnescape(thisToken) if err != nil { unescaped = thisToken } - if params == nil { - params = []string{unescaped} + if wcParams == nil { + wcParams = []string{unescaped} } else { - params = append(params, unescaped) + wcParams = append(wcParams, unescaped) + } + + if wcHandler != nil { + return wcNode, wcHandler, wcParams } - return + // Didn't actually find a handler here, so remember that we + // found a node but also see if we can fall through to the + // catchall. + found = wcNode + handler = wcHandler + params = wcParams } } } catchAllChild := n.catchAllChild if catchAllChild != nil { - // Hit the catchall, so just assign the whole remaining path. - unescaped, err := url.QueryUnescape(path) - if err != nil { - unescaped = path + // Hit the catchall, so just assign the whole remaining path if it + // has a matching handler. + handler = catchAllChild.leafHandler[method] + // Found a handler, or we found a catchall node without a handler. + // Either way, return it since there's nothing left to check after this. + if handler != nil || found == nil { + unescaped, err := url.QueryUnescape(path) + if err != nil { + unescaped = path + } + + return catchAllChild, handler, []string{unescaped} } - return catchAllChild, []string{unescaped} } - return nil, nil + return found, handler, params } func (n *node) dumpTree(prefix, nodeType string) string { diff --git a/tree_test.go b/tree_test.go index f068b0b..7935ebb 100644 --- a/tree_test.go +++ b/tree_test.go @@ -16,7 +16,7 @@ func addPath(t *testing.T, tree *node, path string) { handler := func(w http.ResponseWriter, r *http.Request, urlParams map[string]string) { urlParams["path"] = path } - n.setHandler("GET", handler) + n.setHandler("GET", handler, false) } var test *testing.T @@ -30,7 +30,7 @@ func testPath(t *testing.T, tree *node, path string, expectPath string, expected expectCatchAll := strings.Contains(expectPath, "/*") t.Log("Testing", path) - n, paramList := tree.search(path[1:]) + n, foundHandler, paramList := tree.search("GET", path[1:]) if expectPath != "" && n == nil { t.Errorf("No match for %s, expected %s", path, expectPath) return @@ -55,6 +55,12 @@ func testPath(t *testing.T, tree *node, path string, expectPath string, expected return } + if foundHandler == nil { + t.Errorf("Path %s returned valid node but foundHandler was false", path) + t.Error("Node and subtree was\n" + n.dumpTree("", " ")) + return + } + pathMap := make(map[string]string) handler(nil, nil, pathMap) matchedPath := pathMap["path"] @@ -66,7 +72,7 @@ func testPath(t *testing.T, tree *node, path string, expectPath string, expected if expectedParams == nil { if len(paramList) != 0 { - t.Errorf("Path %p expected no parameters, saw %v", path, paramList) + t.Errorf("Path %s expected no parameters, saw %v", path, paramList) } } else { if len(paramList) != len(n.leafWildcardNames) { @@ -264,8 +270,8 @@ func TestPanics(t *testing.T) { sawPanic = false defer panicHandler() tree := &node{path: "/"} - tree.setHandler("GET", dummyHandler) - tree.setHandler("GET", dummyHandler) + tree.setHandler("GET", dummyHandler, false) + tree.setHandler("GET", dummyHandler, false) }() if !sawPanic { t.Error("Expected panic when adding a duplicate handler for a pattern") @@ -306,32 +312,47 @@ func TestPanics(t *testing.T) { func BenchmarkTreeNullRequest(b *testing.B) { b.ReportAllocs() - tree := &node{path: "/"} + tree := &node{ + path: "/", + leafHandler: map[string]HandlerFunc{ + "GET": dummyHandler, + }, + } b.ResetTimer() for i := 0; i < b.N; i++ { - tree.search("") + tree.search("GET", "") } } func BenchmarkTreeOneStatic(b *testing.B) { b.ReportAllocs() - tree := &node{path: "/"} + tree := &node{ + path: "/", + leafHandler: map[string]HandlerFunc{ + "GET": dummyHandler, + }, + } tree.addPath("abc", nil) b.ResetTimer() for i := 0; i < b.N; i++ { - tree.search("abc") + tree.search("GET", "abc") } } func BenchmarkTreeOneParam(b *testing.B) { + tree := &node{ + path: "/", + leafHandler: map[string]HandlerFunc{ + "GET": dummyHandler, + }, + } b.ReportAllocs() - tree := &node{path: "/"} tree.addPath(":abc", nil) b.ResetTimer() for i := 0; i < b.N; i++ { - tree.search("abc") + tree.search("GET", "abc") } }