Skip to content

Commit

Permalink
refactor and actually separate http from router
Browse files Browse the repository at this point in the history
  • Loading branch information
TheKhanj committed Sep 14, 2023
1 parent eea8172 commit cfe6427
Show file tree
Hide file tree
Showing 6 changed files with 167 additions and 200 deletions.
140 changes: 66 additions & 74 deletions http_router.go → http/http_router.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
package drouter
package http

import (
"context"
"net/http"
"strings"
"sync"

"github.com/thekhanj/drouter"
)

// Handle is a function that can be registered to a route to handle HTTP
// requests. Like http.HandlerFunc, but has a third parameter for the values of
// wildcards (path variables).
type HttpHandle func(http.ResponseWriter, *http.Request, Params)
type HttpHandle func(http.ResponseWriter, *http.Request, drouter.Params)

// Router is a http.Handler which can be used to dispatch requests to different
// handler functions via configurable routes
type HttpRouter struct {
Router
routers map[string]*drouter.Router

methods []string
paramsPool sync.Pool
maxParams uint16

// If enabled, adds the matched route path onto the http.Request context
// before invoking the handle.
Expand Down Expand Up @@ -88,71 +92,71 @@ type httpHandle struct {
handle HttpHandle
}

func (h *httpHandle) Handle(params Params) {
func (h *httpHandle) Handle(params drouter.Params) {
h.handle(h.w, h.req, params)
}

func getHttpRoutingPath(method string, path string) string {
return method + " " + path
}

func httpRoutingPathToPath(path string) string {
for i := range path {
if path[i] == ' ' {
return path[i+1:]
}
}
panic("added a route which is not http")
}

// NewHttpRouter returns a new initialized Router.
// New returns a new initialized Router.
// Path auto-correction, including trailing slashes, is enabled by default.
func NewHttpRouter() *HttpRouter {
func New() *HttpRouter {
return &HttpRouter{
methods: []string{},

RedirectTrailingSlash: true,
RedirectFixedPath: true,
HandleMethodNotAllowed: true,
HandleOPTIONS: true,
}
}

func (r *HttpRouter) getParams() *drouter.Params {
ps, _ := r.paramsPool.Get().(*drouter.Params)
*ps = (*ps)[0:0] // reset slice
return ps
}

func (r *HttpRouter) putParams(ps *drouter.Params) {
if ps != nil {
r.paramsPool.Put(ps)
}
}

func (r *HttpRouter) lazyInitParamsPool() {
if !(r.paramsPool.New == nil) {
return
}

r.paramsPool.New = func() interface{} {
ps := make(drouter.Params, 0, r.maxParams)
return &ps
}
}

func (r *HttpRouter) updateMaxParams(path string, varsCount uint16) {
if paramsCount := drouter.CountParams(path); paramsCount+varsCount > r.maxParams {
r.maxParams = paramsCount + varsCount
}
}

func (r *HttpRouter) saveMatchedRoutePath(path string, handle HttpHandle) HttpHandle {
return func(w http.ResponseWriter, req *http.Request, ps Params) {
return func(w http.ResponseWriter, req *http.Request, ps drouter.Params) {
if ps == nil {
psp := r.getParams()
ps = (*psp)[0:1]
ps[0] = Param{
Key: MatchedRoutePathParam,
ps[0] = drouter.Param{
Key: drouter.MatchedRoutePathParam,
Value: path,
}
handle(w, req, ps)
r.putParams(psp)
} else {
ps = append(ps, Param{
Key: MatchedRoutePathParam,
ps = append(ps, drouter.Param{
Key: drouter.MatchedRoutePathParam,
Value: path,
})
handle(w, req, ps)
}
}
}

func (r *HttpRouter) methodExists(method string) bool {
for _, match := range r.methods {
if match == method {
return true
}
}

return false
}

func (r *HttpRouter) addMethod(method string) {
r.methods = append(r.methods, method)
}

// GET is a shortcut for router.Handle(http.MethodGet, path, handle)
func (r *HttpRouter) GET(path string, handle HttpHandle) {
r.Handle(http.MethodGet, path, handle)
Expand Down Expand Up @@ -209,49 +213,38 @@ func (r *HttpRouter) Handle(method, path string, handle HttpHandle) {
panic("handle must not be nil")
}

httpRoutingPath := getHttpRoutingPath(method, path)

if r.SaveMatchedRoutePath {
varsCount++
handle = r.saveMatchedRoutePath(path, handle)
}

if r.root == nil {
r.root = new(node)
if r.routers == nil {
r.routers = make(map[string]*drouter.Router)
}

root := r.root
if !r.methodExists(method) {
r.addMethod(method)
router := r.routers[method]
if router == nil {
router = drouter.New()
r.routers[method] = router

r.globalAllowed = r.allowed("*", "")
}

root.addRoute(httpRoutingPath, handle)

// Update maxParams
if paramsCount := countParams(path); paramsCount+varsCount > r.maxParams {
r.maxParams = paramsCount + varsCount
}
router.AddRoute(path, handle)

// Lazy-init paramsPool alloc func
if r.paramsPool.New == nil && r.maxParams > 0 {
r.paramsPool.New = func() interface{} {
ps := make(Params, 0, r.maxParams)
return &ps
}
}
r.updateMaxParams(path, varsCount)
r.lazyInitParamsPool()
}

// Handler is an adapter which allows the usage of an http.Handler as a
// request handle.
// The Params are available in the request context under ParamsKey.
func (r *HttpRouter) Handler(method, path string, handler http.Handler) {
r.Handle(method, path,
func(w http.ResponseWriter, req *http.Request, p Params) {
func(w http.ResponseWriter, req *http.Request, p drouter.Params) {
if len(p) > 0 {
ctx := req.Context()
ctx = context.WithValue(ctx, ParamsKey, p)
ctx = context.WithValue(ctx, drouter.ParamsKey, p)
req = req.WithContext(ctx)
}
handler.ServeHTTP(w, req)
Expand Down Expand Up @@ -282,7 +275,7 @@ func (r *HttpRouter) ServeFiles(path string, root http.FileSystem) {

fileServer := http.FileServer(root)

r.GET(path, func(w http.ResponseWriter, req *http.Request, ps Params) {
r.GET(path, func(w http.ResponseWriter, req *http.Request, ps drouter.Params) {
req.URL.Path = ps.ByName("filepath")
fileServer.ServeHTTP(w, req)
})
Expand All @@ -300,7 +293,7 @@ func (r *HttpRouter) allowed(path, reqMethod string) (allow string) {
if path == "*" { // server-wide
// empty method is used for internal calls to refresh the cache
if reqMethod == "" {
for _, method := range r.methods {
for method := range r.routers {
if method == http.MethodOptions {
continue
}
Expand All @@ -311,13 +304,13 @@ func (r *HttpRouter) allowed(path, reqMethod string) (allow string) {
return r.globalAllowed
}
} else { // specific path
for _, method := range r.methods {
for method := range r.routers {
// Skip the requested method - we already tried this one
if method == reqMethod || method == http.MethodOptions {
continue
}

handler, _, _ := r.root.getValue(getHttpRoutingPath(method, path), nil)
handler, _ := r.routers[method].Lookup(path, nil)
if handler != nil {
// Add request method to list of allowed methods
allowed = append(allowed, method)
Expand Down Expand Up @@ -353,10 +346,9 @@ func (r *HttpRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {

path := req.URL.Path

if root := r.root; root != nil {
if handle, ps, tsr := root.getValue(
getHttpRoutingPath(req.Method, path), r.getParams,
); handle != nil {
if router := r.routers[req.Method]; router != nil {
ps := r.getParams()
if handle, tsr := router.Lookup(path, ps); handle != nil {
if ps != nil {
handle.(HttpHandle)(w, req, *ps)
r.putParams(ps)
Expand Down Expand Up @@ -384,12 +376,12 @@ func (r *HttpRouter) ServeHTTP(w http.ResponseWriter, req *http.Request) {

// Try to fix the request path
if r.RedirectFixedPath {
fixedPath, found := root.findCaseInsensitivePath(
getHttpRoutingPath(req.Method, CleanPath(path)),
fixedPath, found := router.FindCaseInsensitivePath(
drouter.CleanPath(path),
r.RedirectTrailingSlash,
)
if found {
req.URL.Path = httpRoutingPathToPath(fixedPath)
req.URL.Path = fixedPath
http.Redirect(w, req, req.URL.String(), code)
return
}
Expand Down
Loading

0 comments on commit cfe6427

Please sign in to comment.