Skip to content

Commit

Permalink
Match less specific routes when appropriate.
Browse files Browse the repository at this point in the history
When a node is found but it does not handle the requested HTTP method,
look for a less specific match in the rest of the tree that does handle
that method.

This also changes in the behavior of HeadCanUseGet, which is now
checked at route add time instead of at request time. This means that
if you are setting HeadCanUseGet to false, you must do so before
adding your routes.

Fixes #27
  • Loading branch information
dimfeld committed Mar 19, 2016
1 parent f061259 commit 837a149
Show file tree
Hide file tree
Showing 7 changed files with 233 additions and 70 deletions.
93 changes: 93 additions & 0 deletions fallthrough_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package httptreemux

import (
"net/http"
"net/http/httptest"
"reflect"
"testing"
)

// When we find a node with a matching path but no handler for a method,
// we should fall through and continue searching the tree for a less specific
// match, i.e. a wildcard or catchall, that does have a handler for that method.
func TestMethodNotAllowedFallthrough(t *testing.T) {
var matchedMethod string
var matchedPath string
var matchedParams map[string]string

router := New()

addRoute := func(method, path string) {
router.Handle(method, path, func(w http.ResponseWriter, r *http.Request, params map[string]string) {
matchedMethod = method
matchedPath = path
matchedParams = params
})
}

checkRoute := func(method, path, expectedMethod, expectedPath string,
expectedCode int, expectedParams map[string]string) {

matchedMethod = ""
matchedPath = ""
matchedParams = nil

w := httptest.NewRecorder()
r, _ := http.NewRequest(method, path, nil)
router.ServeHTTP(w, r)
if expectedCode != w.Code {
t.Errorf("%s %s expected code %d, saw %d", method, path, expectedCode, w.Code)
}

if w.Code == 200 {
if matchedMethod != method || matchedPath != expectedPath {
t.Errorf("%s %s expected %s %s, saw %s %s", method, path,
expectedMethod, expectedPath, matchedMethod, matchedPath)
}

if !reflect.DeepEqual(matchedParams, expectedParams) {
t.Errorf("%s %s expected params %+v, saw %+v", method, path, expectedParams, matchedParams)
}
}
}

addRoute("GET", "/apple/banana/cat")
addRoute("GET", "/apple/potato")
addRoute("POST", "/apple/banana/:abc")
addRoute("POST", "/apple/ban/def")
addRoute("DELETE", "/apple/:seed")
addRoute("DELETE", "/apple/*path")
addRoute("OPTIONS", "/apple/*path")

checkRoute("GET", "/apple/banana/cat", "GET", "/apple/banana/cat", 200, nil)
checkRoute("POST", "/apple/banana/cat", "POST", "/apple/banana/:abc", 200,
map[string]string{"abc": "cat"})
checkRoute("POST", "/apple/banana/dog", "POST", "/apple/banana/:abc", 200,
map[string]string{"abc": "dog"})

// Wildcards should be checked before catchalls
checkRoute("DELETE", "/apple/banana", "DELETE", "/apple/:seed", 200,
map[string]string{"seed": "banana"})
checkRoute("DELETE", "/apple/banana/cat", "DELETE", "/apple/*path", 200,
map[string]string{"path": "banana/cat"})

checkRoute("POST", "/apple/ban/def", "POST", "/apple/ban/def", 200, nil)
checkRoute("OPTIONS", "/apple/ban/def", "OPTIONS", "/apple/*path", 200,
map[string]string{"path": "ban/def"})
checkRoute("GET", "/apple/ban/def", "", "", 405, nil)

// Always fallback to the matching handler no matter how many other
// nodes without proper handlers are found on the way.
checkRoute("OPTIONS", "/apple/banana/cat", "OPTIONS", "/apple/*path", 200,
map[string]string{"path": "banana/cat"})
checkRoute("OPTIONS", "/apple/bbbb", "OPTIONS", "/apple/*path", 200,
map[string]string{"path": "bbbb"})

// Nothing matches on patch
checkRoute("PATCH", "/apple/banana/cat", "", "", 405, nil)
checkRoute("PATCH", "/apple/potato", "", "", 405, nil)

// And some 404 tests for good measure
checkRoute("GET", "/abc", "", "", 404, nil)
checkRoute("OPTIONS", "/apple", "", "", 404, nil)
}
6 changes: 5 additions & 1 deletion group.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,11 @@ func (g *Group) Handle(method string, path string, handler HandlerFunc) {
if addSlash {
node.addSlash = true
}
node.setHandler(method, handler)
node.setHandler(method, handler, false)

if g.mux.HeadCanUseGet && method == "GET" && node.leafHandler["HEAD"] == nil {
node.setHandler("HEAD", handler, true)
}
}

// Syntactic sugar for Handle("GET", path, handler)
Expand Down
51 changes: 39 additions & 12 deletions group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ func TestSubGroupEmptyMapping(t *testing.T) {
func TestGroupMethods(t *testing.T) {
for _, scenario := range scenarios {
t.Log(scenario.description)
testGroupMethods(t, scenario.RequestCreator)
testGroupMethods(t, scenario.RequestCreator, false)
testGroupMethods(t, scenario.RequestCreator, true)
}
}

Expand Down Expand Up @@ -88,14 +89,15 @@ func TestInvalidPath(t *testing.T) {
}

//Liberally borrowed from router_test
func testGroupMethods(t *testing.T, reqGen RequestCreator) {
func testGroupMethods(t *testing.T, reqGen RequestCreator, headCanUseGet bool) {
var result string
makeHandler := func(method string) HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, params map[string]string) {
result = method
}
}
router := New()
router.HeadCanUseGet = headCanUseGet
// Testing with a sub-group of a group as that will test everything at once
g := router.NewGroup("/base").NewGroup("/user")
g.GET("/:param", makeHandler("GET"))
Expand All @@ -110,7 +112,7 @@ func testGroupMethods(t *testing.T, reqGen RequestCreator) {
r, _ := reqGen(method, "/base/user/"+method, nil)
router.ServeHTTP(w, r)
if expect == "" && w.Code != http.StatusMethodNotAllowed {
t.Errorf("Method %s not expected to match but saw code %d", w.Code)
t.Errorf("Method %s not expected to match but saw code %d", method, w.Code)
}

if result != expect {
Expand All @@ -123,18 +125,43 @@ func testGroupMethods(t *testing.T, reqGen RequestCreator) {
testMethod("PATCH", "PATCH")
testMethod("PUT", "PUT")
testMethod("DELETE", "DELETE")
t.Log("Test HeadCanUseGet = true")
testMethod("HEAD", "GET")

router.HeadCanUseGet = false
t.Log("Test HeadCanUseGet = false")
testMethod("HEAD", "")
if headCanUseGet {
t.Log("Test implicit HEAD with HeadCanUseGet = true")
testMethod("HEAD", "GET")
} else {
t.Log("Test implicit HEAD with HeadCanUseGet = false")
testMethod("HEAD", "")
}

router.HEAD("/base/user/:param", makeHandler("HEAD"))

t.Log("Test HeadCanUseGet = false with explicit HEAD handler")
testMethod("HEAD", "HEAD")
}

// Ensure that setting a GET handler doesn't overwrite an explciit HEAD handler.
func TestSetGetAfterHead(t *testing.T) {
var result string
makeHandler := func(method string) HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, params map[string]string) {
result = method
}
}

router := New()
router.HeadCanUseGet = true
t.Log("Test HeadCanUseGet = true with explicit HEAD handler")
router.HEAD("/abc", makeHandler("HEAD"))
router.GET("/abc", makeHandler("GET"))

testMethod := func(method, expect string) {
result = ""
w := httptest.NewRecorder()
r, _ := http.NewRequest(method, "/abc", nil)
router.ServeHTTP(w, r)

if result != expect {
t.Errorf("Method %s got result %s", method, result)
}
}

testMethod("HEAD", "HEAD")
testMethod("GET", "GET")
}
14 changes: 5 additions & 9 deletions router.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,13 @@ func (t *TreeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if trailingSlash && t.RedirectTrailingSlash {
path = path[:pathLen-1]
}
n, params := t.root.search(path[1:])
n, handler, params := t.root.search(r.Method, path[1:])
if n == nil {
if t.RedirectCleanPath {
// Path was not found. Try cleaning it up and search again.
// TODO Test this
cleanPath := httppath.Clean(path)
n, params = t.root.search(cleanPath[1:])
n, handler, params = t.root.search(r.Method, cleanPath[1:])
if n == nil {
// Still nothing found.
t.NotFoundHandler(w, r)
Expand All @@ -202,16 +202,12 @@ func (t *TreeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}

handler, ok := n.leafHandler[r.Method]
if !ok {
if r.Method == "HEAD" && t.HeadCanUseGet {
handler, ok = n.leafHandler["GET"]
} else if r.Method == "OPTIONS" && t.OptionsHandler != nil {
if handler == nil {
if r.Method == "OPTIONS" && t.OptionsHandler != nil {
handler = t.OptionsHandler
ok = true
}

if !ok {
if handler == nil {
t.MethodNotAllowedHandler(w, r, n.leafHandler)
return
}
Expand Down
32 changes: 15 additions & 17 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,12 @@ func benchRequest(b *testing.B, router http.Handler, r *http.Request) {
func TestMethods(t *testing.T) {
for _, scenario := range scenarios {
t.Log(scenario.description)
testMethods(t, scenario.RequestCreator)
testMethods(t, scenario.RequestCreator, true)
testMethods(t, scenario.RequestCreator, false)
}
}

func testMethods(t *testing.T, newRequest RequestCreator) {
func testMethods(t *testing.T, newRequest RequestCreator, headCanUseGet bool) {
var result string

makeHandler := func(method string) HandlerFunc {
Expand All @@ -81,6 +82,7 @@ func testMethods(t *testing.T, newRequest RequestCreator) {
}

router := New()
router.HeadCanUseGet = headCanUseGet
router.GET("/user/:param", makeHandler("GET"))
router.POST("/user/:param", makeHandler("POST"))
router.PATCH("/user/:param", makeHandler("PATCH"))
Expand All @@ -93,7 +95,7 @@ func testMethods(t *testing.T, newRequest RequestCreator) {
r, _ := newRequest(method, "/user/"+method, nil)
router.ServeHTTP(w, r)
if expect == "" && w.Code != http.StatusMethodNotAllowed {
t.Errorf("Method %s not expected to match but saw code %d", w.Code)
t.Errorf("Method %s not expected to match but saw code %d", method, w.Code)
}

if result != expect {
Expand All @@ -106,19 +108,15 @@ func testMethods(t *testing.T, newRequest RequestCreator) {
testMethod("PATCH", "PATCH")
testMethod("PUT", "PUT")
testMethod("DELETE", "DELETE")
t.Log("Test HeadCanUseGet = true")
testMethod("HEAD", "GET")

router.HeadCanUseGet = false
t.Log("Test HeadCanUseGet = false")
testMethod("HEAD", "")
if headCanUseGet {
t.Log("Test implicit HEAD with HeadCanUseGet = true")
testMethod("HEAD", "GET")
} else {
t.Log("Test implicit HEAD with HeadCanUseGet = false")
testMethod("HEAD", "")
}

router.HEAD("/user/:param", makeHandler("HEAD"))

t.Log("Test HeadCanUseGet = false with explicit HEAD handler")
testMethod("HEAD", "HEAD")
router.HeadCanUseGet = true
t.Log("Test HeadCanUseGet = true with explicit HEAD handler")
testMethod("HEAD", "HEAD")
}

Expand Down Expand Up @@ -157,7 +155,7 @@ func TestMethodNotAllowedHandler(t *testing.T) {

calledNotAllowed = true

expected := []string{"GET", "PUT", "DELETE"}
expected := []string{"GET", "PUT", "DELETE", "HEAD"}
allowed := make([]string, 0)
for m := range methods {
allowed = append(allowed, m)
Expand Down Expand Up @@ -188,7 +186,7 @@ func TestMethodNotAllowedHandler(t *testing.T) {

allowed := w.Header()["Allow"]
sort.Strings(allowed)
expected := []string{"DELETE", "GET", "PUT"}
expected := []string{"DELETE", "GET", "PUT", "HEAD"}
sort.Strings(expected)

if !reflect.DeepEqual(allowed, expected) {
Expand Down Expand Up @@ -271,7 +269,7 @@ func TestOptionsHandler(t *testing.T) {

allowed := w.Header()["Allow"]
sort.Strings(allowed)
expected := []string{"DELETE", "GET", "PUT"}
expected := []string{"DELETE", "GET", "PUT", "HEAD"}
sort.Strings(expected)

if !reflect.DeepEqual(allowed, expected) {
Expand Down
Loading

0 comments on commit 837a149

Please sign in to comment.