Skip to content

Commit

Permalink
Add middleware support (#71)
Browse files Browse the repository at this point in the history
* Add middleware support

* Add UseHandlerFunc

* Add UseHandler

* Prevent stack sharing between groups

* Remove UseHandlerFunc for simplicity
  • Loading branch information
vmihailenco authored Apr 11, 2020
1 parent 2a407fa commit 5614850
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 3 deletions.
10 changes: 10 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,16 @@ type ContextGroup struct {
group *Group
}

// Use appends a middleware handler to the Group middleware stack.
func (cg *ContextGroup) Use(fn MiddlewareFunc) {
cg.group.Use(fn)
}

// UseHandler is like Use but accepts http.Handler middleware.
func (cg *ContextGroup) UseHandler(middleware func(http.Handler) http.Handler) {
cg.group.UseHandler(middleware)
}

// UsingContext wraps the receiver to return a new instance of a ContextGroup.
// The returned ContextGroup is a sibling to its wrapped Group, within the parent TreeMux.
// The choice of using a *Group as the receiver, as opposed to a function parameter, allows chaining
Expand Down
52 changes: 49 additions & 3 deletions group.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,24 @@ package httptreemux

import (
"fmt"
"net/http"
"net/url"
"strings"
)

type MiddlewareFunc func(next HandlerFunc) HandlerFunc

func handlerWithMiddlewares(handler HandlerFunc, stack []MiddlewareFunc) HandlerFunc {
for i := len(stack) - 1; i >= 0; i-- {
handler = stack[i](handler)
}
return handler
}

type Group struct {
path string
mux *TreeMux
path string
mux *TreeMux
stack []MiddlewareFunc
}

// Add a sub-group to this group
Expand All @@ -23,7 +34,38 @@ func (g *Group) NewGroup(path string) *Group {
if path[len(path)-1] == '/' {
path = path[:len(path)-1]
}
return &Group{path, g.mux}
return &Group{
path: path,
mux: g.mux,
stack: g.stack[:len(g.stack):len(g.stack)],
}
}

// Use appends a middleware handler to the Group middleware stack.
func (g *Group) Use(fn MiddlewareFunc) {
g.stack = append(g.stack, fn)
}

type handlerWithParams struct {
handler HandlerFunc
params map[string]string
}

func (h handlerWithParams) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.handler(w, r, h.params)
}

// UseHandler is like Use but accepts http.Handler middleware.
func (g *Group) UseHandler(middleware func(http.Handler) http.Handler) {
g.stack = append(g.stack, func(next HandlerFunc) HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, params map[string]string) {
nextHandler := handlerWithParams{
handler: next,
params: params,
}
middleware(nextHandler).ServeHTTP(w, r)
}
})
}

// Path elements starting with : indicate a wildcard in the path. A wildcard will only match on a
Expand Down Expand Up @@ -92,6 +134,10 @@ func (g *Group) Handle(method string, path string, handler HandlerFunc) {
g.mux.mutex.Lock()
defer g.mux.mutex.Unlock()

if len(g.stack) > 0 {
handler = handlerWithMiddlewares(handler, g.stack)
}

addSlash := false
addOne := func(thePath string) {
node := g.mux.root.addPath(thePath[1:], nil, false)
Expand Down
124 changes: 124 additions & 0 deletions router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1073,6 +1073,130 @@ func TestRedirectEscapedPath(t *testing.T) {
}
}

func TestMiddleware(t *testing.T) {
var execLog []string

record := func(s string) {
execLog = append(execLog, s)
}

assertExecLog := func(wanted []string) {
if !reflect.DeepEqual(execLog, wanted) {
t.Fatalf("got %v, wanted %v", execLog, wanted)
}
}

newHandler := func(name string) HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, params map[string]string) {
record(name)
}
}

newMiddleware := func(name string) MiddlewareFunc {
return func(next HandlerFunc) HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, params map[string]string) {
record(name)
next(w, r, params)
}
}
}

router := New()
w := httptest.NewRecorder()

// No middlewares.
{
router.GET("/h1", newHandler("h1"))

req, _ := newRequest("GET", "/h1", nil)
router.ServeHTTP(w, req)

assertExecLog([]string{"h1"})
}

// Test route with and without middleware.
{
execLog = nil
router.Use(newMiddleware("m1"))
router.GET("/h2", newHandler("h2"))

req, _ := newRequest("GET", "/h1", nil)
router.ServeHTTP(w, req)

req, _ = newRequest("GET", "/h2", nil)
router.ServeHTTP(w, req)

assertExecLog([]string{"h1", "m1", "h2"})
}

// NewGroup inherits middlewares but has its own stack.
{
execLog = nil
g := router.NewGroup("/g1")
g.Use(newMiddleware("m2"))
g.GET("/h3", newHandler("h3"))

req, _ := newRequest("GET", "/h2", nil)
router.ServeHTTP(w, req)

req, _ = newRequest("GET", "/g1/h3", nil)
router.ServeHTTP(w, req)

assertExecLog([]string{"m1", "h2", "m1", "m2", "h3"})
}

// Middleware can modify params.
{
execLog = nil
g := router.NewGroup("/g2")
g.Use(func(next HandlerFunc) HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, params map[string]string) {
record("m4")
if params == nil {
params = make(map[string]string)
}
params["foo"] = "bar"
next(w, r, params)
}
})
g.GET("/h6", func(w http.ResponseWriter, r *http.Request, params map[string]string) {
record("h6")
if params["foo"] != "bar" {
t.Fatalf("got %q, wanted %q", params["foo"], "bar")
}
})

req, _ := newRequest("GET", "/g2/h6", nil)
router.ServeHTTP(w, req)

assertExecLog([]string{"m1", "m4", "h6"})
}

// Middleware can serve request without calling next.
{
execLog = nil
router.Use(func(_ HandlerFunc) HandlerFunc {
return func(w http.ResponseWriter, r *http.Request, params map[string]string) {
record("m3")
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("pong"))
}
})
router.GET("/h5", newHandler("h5"))

req, _ := newRequest("GET", "/h5", nil)
router.ServeHTTP(w, req)

assertExecLog([]string{"m1", "m3"})
if w.Code != http.StatusBadRequest {
t.Fatalf("got %d, wanted %d", w.Code, http.StatusBadRequest)
}
if w.Body.String() != "pong" {
t.Fatalf("got %s, wanted %s", w.Body.String(), "pong")
}
}
}

func BenchmarkRouterSimple(b *testing.B) {
router := New()

Expand Down

0 comments on commit 5614850

Please sign in to comment.