Skip to content

Commit

Permalink
Merge pull request #82 from dimfeld/route-info-at-start-of-handler-chain
Browse files Browse the repository at this point in the history
Add ContextData before middleware runs
  • Loading branch information
dimfeld authored Mar 30, 2021
2 parents ebfe087 + 0b81076 commit b61bfc4
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 20 deletions.
41 changes: 28 additions & 13 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,32 +55,47 @@ func (cg *ContextGroup) NewGroup(path string) *ContextGroup {
return cg.NewContextGroup(path)
}

// Handle allows handling HTTP requests via an http.HandlerFunc, as opposed to an httptreemux.HandlerFunc.
// Any parameters from the request URL are stored in a map[string]string in the request's context.
func (cg *ContextGroup) Handle(method, path string, handler http.HandlerFunc) {
func (cg *ContextGroup) wrapHandler(path string, handler HandlerFunc) HandlerFunc {
if len(cg.group.stack) > 0 {
handler = handlerWithMiddlewares(handler, cg.group.stack)
}

//add the context data after adding all middleware
fullPath := cg.group.path + path
cg.group.Handle(method, path, func(w http.ResponseWriter, r *http.Request, params map[string]string) {
return func(writer http.ResponseWriter, request *http.Request, m map[string]string) {
routeData := &contextData{
route: fullPath,
params: params,
params: m,
}
r = r.WithContext(AddRouteDataToContext(r.Context(), routeData))
request = request.WithContext(AddRouteDataToContext(request.Context(), routeData))
handler(writer, request, m)
}
}

// Handle allows handling HTTP requests via an http.HandlerFunc, as opposed to an httptreemux.HandlerFunc.
// Any parameters from the request URL are stored in a map[string]string in the request's context.
func (cg *ContextGroup) Handle(method, path string, handler http.HandlerFunc) {
cg.group.mux.mutex.Lock()
defer cg.group.mux.mutex.Unlock()

wrapped := cg.wrapHandler(path, func(w http.ResponseWriter, r *http.Request, params map[string]string) {
handler(w, r)
})

cg.group.addFullStackHandler(method, path, wrapped)
}

// Handler allows handling HTTP requests via an http.Handler interface, as opposed to an httptreemux.HandlerFunc.
// Any parameters from the request URL are stored in a map[string]string in the request's context.
func (cg *ContextGroup) Handler(method, path string, handler http.Handler) {
fullPath := cg.group.path + path
cg.group.Handle(method, path, func(w http.ResponseWriter, r *http.Request, params map[string]string) {
routeData := &contextData{
route: fullPath,
params: params,
}
r = r.WithContext(AddRouteDataToContext(r.Context(), routeData))
cg.group.mux.mutex.Lock()
defer cg.group.mux.mutex.Unlock()

wrapped := cg.wrapHandler(path, func(w http.ResponseWriter, r *http.Request, params map[string]string) {
handler.ServeHTTP(w, r)
})

cg.group.addFullStackHandler(method, path, wrapped)
}

// GET is convenience method for handling GET requests on a context group.
Expand Down
64 changes: 57 additions & 7 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,24 @@ func TestContextParams(t *testing.T) {
}

func TestContextRoute(t *testing.T) {
tests := []struct{
tests := []struct {
name,
expectedRoute string
} {
}{
{
name: "basic",
name: "basic",
expectedRoute: "/base/path",
},
{
name: "params",
name: "params",
expectedRoute: "/base/path/:id/items/:itemid",
},
{
name: "catch-all",
name: "catch-all",
expectedRoute: "/base/*path",
},
{
name: "empty",
name: "empty",
expectedRoute: "",
},
}
Expand Down Expand Up @@ -140,6 +140,10 @@ func testContextGroupMethods(t *testing.T, reqGen RequestCreator, headCanUseGet
}

ctxData := ContextData(r.Context())
if ctxData == nil {
t.Fatal("context did not contain ContextData")
}

v, ok = ctxData.Params()["param"]
if hasParam && !ok {
t.Error("missing key 'param' in context from ContextData")
Expand Down Expand Up @@ -371,7 +375,7 @@ func TestAddDataToContext(t *testing.T) {
}

ctx := AddRouteDataToContext(context.Background(), &contextData{
route: expectedRoute,
route: expectedRoute,
params: expectedParams,
})

Expand Down Expand Up @@ -416,3 +420,49 @@ func TestAddRouteToContext(t *testing.T) {
t.Error("failed to retrieve context data")
}
}

func TestContextDataWithMiddleware(t *testing.T) {
wantRoute := "/foo/:id/bar"
wantParams := map[string]string{
"id": "15",
}

validateRequestAndParams := func(request *http.Request, params map[string]string, location string) {
data := ContextData(request.Context())
if data == nil {
t.Fatalf("ContextData returned nil in %s", location)
}
if data.Route() != wantRoute {
t.Errorf("Unexpected route in %s. Got %s", location, data.Route())
}
if !reflect.DeepEqual(data.Params(), wantParams) {
t.Errorf("Unexpected context params in %s. Got %+v", location, data.Params())
}
if !reflect.DeepEqual(params, wantParams) {
t.Errorf("Unexpected handler params in %s. Got %+v", location, params)
}
}

router := NewContextMux()
router.Use(func(next HandlerFunc) HandlerFunc {
return func(writer http.ResponseWriter, request *http.Request, m map[string]string) {
t.Log("Testing Middleware")
validateRequestAndParams(request, m, "middleware")
next(writer, request, m)
}
})

router.GET(wantRoute, func(writer http.ResponseWriter, request *http.Request) {
t.Log("Testing handler")
validateRequestAndParams(request, ContextParams(request.Context()), "handler")
writer.WriteHeader(http.StatusOK)
})

w := httptest.NewRecorder()
r, _ := http.NewRequest(http.MethodGet, "/foo/15/bar", nil)
router.ServeHTTP(w, r)

if w.Code != http.StatusOK {
t.Fatalf("unexpected status code. got %d", w.Code)
}
}
5 changes: 5 additions & 0 deletions group.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,10 @@ func (g *Group) Handle(method string, path string, handler HandlerFunc) {
handler = handlerWithMiddlewares(handler, g.stack)
}

g.addFullStackHandler(method, path, handler)
}

func (g *Group) addFullStackHandler(method string, path string, handler HandlerFunc) {
addSlash := false
addOne := func(thePath string) {
node := g.mux.root.addPath(thePath[1:], nil, false)
Expand Down Expand Up @@ -175,6 +179,7 @@ func (g *Group) Handle(method string, path string, handler HandlerFunc) {
}

addOne(path)

}

// Syntactic sugar for Handle("GET", path, handler)
Expand Down

0 comments on commit b61bfc4

Please sign in to comment.