diff --git a/README.md b/README.md index 96ec37b..eb644c8 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,8 @@ If TreeMux.HeadCanUseGet is set to true, the router will call the GET handler fo Go's http.ServeContent and related functions already handle the HEAD method correctly by sending only the header, so in most cases your handlers will not need any special cases for it. +By default TreeMux.OptionsHandler is a null handler that doesn't affect your routing. If you set the handler, it will be called on OPTIONS requests to a path already registered by another method. If you set a path specific handler by using `router.OPTIONS`, it will override the global Options Handler for that path. + ### Trailing Slashes The router has special handling for paths with trailing slashes. If a pattern is added to the router with a trailing slash, any matches on that pattern without a trailing slash will be redirected to the version with the slash. If a pattern does not have a trailing slash, matches on that pattern with a trailing slash will be redirected to the version without. @@ -83,7 +85,7 @@ POST /posts will redirect to /posts/, because the GET method used a trailing sla ### Custom Redirects -RedirectBehavior sets the behavior when the router redirects the request to the canonical version of the requested URL using RedirectTrailingSlash or RedirectClean. The default behavior is to return a 301 status, redirecting the browser to the version of the URL that matches the given pattern. +RedirectBehavior sets the behavior when the router redirects the request to the canonical version of the requested URL using RedirectTrailingSlash or RedirectClean. The default behavior is to return a 301 status, redirecting the browser to the version of the URL that matches the given pattern. These are the values accepted for RedirectBehavior. You may also add these values to the RedirectMethodBehavior map to define custom per-method redirect behavior. diff --git a/router.go b/router.go index 042e282..b3b4d58 100644 --- a/router.go +++ b/router.go @@ -6,9 +6,10 @@ package httptreemux import ( "fmt" - "github.com/dimfeld/httppath" "net/http" "net/url" + + "github.com/dimfeld/httppath" ) // The params argument contains the parameters parsed from wildcards and catch-alls in the URL. @@ -51,8 +52,14 @@ type TreeMux struct { // The default PanicHandler just returns a 500 code. PanicHandler PanicHandler + // The default NotFoundHandler is http.NotFound. NotFoundHandler func(w http.ResponseWriter, r *http.Request) + + // Any OPTIONS request that matches a path without its own OPTIONS handler will use this handler, + // if set, instead of calling MethodNotAllowedHandler. + OptionsHandler HandlerFunc + // MethodNotAllowedHandler is called when a pattern matches, but that // pattern does not have a handler for the requested method. The default // handler just writes the status code http.StatusMethodNotAllowed and adds @@ -61,6 +68,7 @@ type TreeMux struct { // handler function. MethodNotAllowedHandler func(w http.ResponseWriter, r *http.Request, methods map[string]HandlerFunc) + // HeadCanUseGet allows the router to use the GET handler to respond to // HEAD requests if no explicit HEAD handler has been added for the // matching pattern. This is true by default. @@ -311,6 +319,9 @@ func (t *TreeMux) ServeHTTP(w http.ResponseWriter, r *http.Request) { if !ok { if r.Method == "HEAD" && t.HeadCanUseGet { handler, ok = n.leafHandler["GET"] + } else if r.Method == "OPTIONS" && t.OptionsHandler != nil { + handler = t.OptionsHandler + ok = true } if !ok { diff --git a/router_test.go b/router_test.go index 12a373b..3393b3b 100644 --- a/router_test.go +++ b/router_test.go @@ -205,6 +205,81 @@ func TestMethodNotAllowedHandler(t *testing.T) { } } +func TestOptionsHandler(t *testing.T) { + optionsHandler := func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.WriteHeader(http.StatusNoContent) + } + + customOptionsHandler := func(w http.ResponseWriter, r *http.Request, pathParams map[string]string) { + w.Header().Set("Access-Control-Allow-Origin", "httptreemux.com") + w.WriteHeader(http.StatusUnauthorized) + } + + router := New() + router.GET("/user/abc", simpleHandler) + router.PUT("/user/abc", simpleHandler) + router.DELETE("/user/abc", simpleHandler) + router.OPTIONS("/user/abc/options", customOptionsHandler) + + // test without an OPTIONS handler + w := httptest.NewRecorder() + r, _ := newRequest("OPTIONS", "/user/abc", nil) + router.ServeHTTP(w, r) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("Expected error %d from built-in not found handler but saw %d", + http.StatusMethodNotAllowed, w.Code) + } + + // Now try with a global options handler. + router.OptionsHandler = optionsHandler + + w = httptest.NewRecorder() + router.ServeHTTP(w, r) + if !(w.Code == http.StatusNoContent && w.Header()["Access-Control-Allow-Origin"][0] == "*") { + t.Error("global options handler was not called") + } + + // Now see if a custom handler overwrites the global options handler. + w = httptest.NewRecorder() + r, _ = newRequest("OPTIONS", "/user/abc/options", nil) + router.ServeHTTP(w, r) + if !(w.Code == http.StatusUnauthorized && w.Header()["Access-Control-Allow-Origin"][0] == "httptreemux.com") { + t.Error("custom options handler did not overwrite global handler") + } + + // Now see if a custom handler works with the global options handler set to nil. + router.OptionsHandler = nil + w = httptest.NewRecorder() + r, _ = newRequest("OPTIONS", "/user/abc/options", nil) + router.ServeHTTP(w, r) + if !(w.Code == http.StatusUnauthorized && w.Header()["Access-Control-Allow-Origin"][0] == "httptreemux.com") { + t.Error("custom options handler did not overwrite global handler") + } + + // Make sure that the MethodNotAllowedHandler works when OptionsHandler is set + router.OptionsHandler = optionsHandler + w = httptest.NewRecorder() + r, _ = newRequest("POST", "/user/abc", nil) + router.ServeHTTP(w, r) + + if w.Code != http.StatusMethodNotAllowed { + t.Errorf("Expected error %d from built-in not found handler but saw %d", + http.StatusMethodNotAllowed, w.Code) + } + + allowed := w.Header()["Allow"] + sort.Strings(allowed) + expected := []string{"DELETE", "GET", "PUT"} + sort.Strings(expected) + + if !reflect.DeepEqual(allowed, expected) { + t.Errorf("Expected Allow header %v, saw %v", + expected, allowed) + } +} + func TestPanic(t *testing.T) { router := New()