diff --git a/pkg/adapters/beego/doc.go b/pkg/adapters/beego/doc.go new file mode 100644 index 00000000..e1e86c1a --- /dev/null +++ b/pkg/adapters/beego/doc.go @@ -0,0 +1,21 @@ +/* +This package provides Sentinel middleware for Beego. + +Users may register SentinelMiddleware to the Beego server, like. + + import ( + sentinelPlugin "github.com/sentinel-group/sentinel-go-adapters/beego" + "github.com/beego/beego/v2/server/web" + ) + + web.RunWithMiddleWares(":0", middleware) + +The plugin extracts "HttpMethod:FullPath" as the resource name by default (e.g. GET:/foo/:id). +Users may provide customized resource name extractor when creating new +SentinelMiddleware (via options). + +Fallback logic: the plugin will return "429 Too Many Requests" status code +if current request is blocked by Sentinel rules. Users may also +provide customized fallback logic via WithBlockFallback(handler) options. +*/ +package beego diff --git a/pkg/adapters/beego/filter_chain.go b/pkg/adapters/beego/filter_chain.go new file mode 100644 index 00000000..9633ef84 --- /dev/null +++ b/pkg/adapters/beego/filter_chain.go @@ -0,0 +1,43 @@ +package beego + +import ( + sentinel "github.com/alibaba/sentinel-golang/api" + "github.com/alibaba/sentinel-golang/core/base" + "github.com/beego/beego/v2/server/web" + beegoCtx "github.com/beego/beego/v2/server/web/context" + "net/http" +) + +// SentinelFilterChain returns new web.FilterChain. +// Default resource name pattern is {httpMethod}:{apiPath}, such as "GET:/api/:id". +// Default block fallback is to return 429 (Too Many Requests) response. +// +// You may customize your own resource extractor and block handler by setting options. +func SentinelFilterChain(opts ...Option) web.FilterChain { + options := evaluateOptions(opts) + return func(next web.FilterFunc) web.FilterFunc { + return func(ctx *beegoCtx.Context) { + resourceName := ctx.Input.Method() + ":" + ctx.Input.URL() + if options.resourceExtract != nil { + resourceName = options.resourceExtract(ctx.Request) + } + entry, blockErr := sentinel.Entry( + resourceName, + sentinel.WithResourceType(base.ResTypeWeb), + sentinel.WithTrafficType(base.Inbound), + ) + if blockErr != nil { + if options.blockFallback != nil { + status, msg := options.blockFallback(ctx.Request) + http.Error(ctx.ResponseWriter, msg, status) + } else { + // default error response + http.Error(ctx.ResponseWriter, "Blocked by Sentinel", http.StatusTooManyRequests) + } + return + } + defer entry.Exit() + next(ctx) + } + } +} diff --git a/pkg/adapters/beego/filter_chain_test.go b/pkg/adapters/beego/filter_chain_test.go new file mode 100644 index 00000000..f1903894 --- /dev/null +++ b/pkg/adapters/beego/filter_chain_test.go @@ -0,0 +1,106 @@ +package beego + +import ( + "github.com/alibaba/sentinel-golang/core/flow" + "github.com/beego/beego/v2/server/web" + beegoCtx "github.com/beego/beego/v2/server/web/context" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +func TestSentinelFilterChain(t *testing.T) { + type args struct { + opts []Option + method string + path string + reqPath string + handlerFunc web.HandleFunc + body io.Reader + } + type want struct { + code int + } + tests := []struct { + name string + args args + want want + }{ + { + name: "default get", + args: args{ + opts: []Option{}, + method: http.MethodGet, + path: "/test", + reqPath: "/test", + handlerFunc: func(ctx *beegoCtx.Context) { + _ = ctx.Resp([]byte("hello")) + }, + body: nil, + }, + want: want{ + code: http.StatusOK, + }, + }, + { + name: "customize resource extract", + args: args{ + opts: []Option{ + WithResourceExtractor(func(r *http.Request) string { + return "customize_block_fallback" + }), + }, + method: http.MethodPost, + path: "/api/users/:id", + reqPath: "/api/users/123", + handlerFunc: func(ctx *beegoCtx.Context) { + _ = ctx.Resp([]byte("pong")) + }, + body: nil, + }, + want: want{ + code: http.StatusTooManyRequests, + }, + }, + { + name: "customize block fallback", + args: args{ + opts: []Option{ + WithBlockFallback(func(r *http.Request) (int, string) { + return http.StatusInternalServerError, "customize block fallback" + }), + }, + method: http.MethodGet, + path: "/block", + reqPath: "/block", + handlerFunc: func(ctx *beegoCtx.Context) { + _ = ctx.Resp([]byte("pong")) + }, + body: nil, + }, + want: want{ + code: http.StatusInternalServerError, + }, + }, + } + initSentinel(t) + defer func() { + _ = flow.ClearRules() + }() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cr := web.NewControllerRegister() + cr.Get(tt.args.reqPath, tt.args.handlerFunc) + + cr.InsertFilterChain("/*", SentinelFilterChain(tt.args.opts...)) + cr.Init() + + r := httptest.NewRequest(tt.args.method, tt.args.reqPath, tt.args.body) + w := httptest.NewRecorder() + + cr.ServeHTTP(w, r) + }) + } +} diff --git a/pkg/adapters/beego/go.mod b/pkg/adapters/beego/go.mod new file mode 100644 index 00000000..4d887158 --- /dev/null +++ b/pkg/adapters/beego/go.mod @@ -0,0 +1,8 @@ +module github.com/alibaba/sentinel-golang/pkg/adapters/beego + +go 1.13 + +require ( + github.com/alibaba/sentinel-golang v1.0.4 + github.com/beego/beego/v2 v2.1.1 +) diff --git a/pkg/adapters/beego/middleware.go b/pkg/adapters/beego/middleware.go new file mode 100644 index 00000000..c3a9829a --- /dev/null +++ b/pkg/adapters/beego/middleware.go @@ -0,0 +1,42 @@ +package beego + +import ( + sentinel "github.com/alibaba/sentinel-golang/api" + "github.com/alibaba/sentinel-golang/core/base" + "github.com/beego/beego/v2/server/web" + "net/http" +) + +// SentinelMiddleware returns new web.MiddleWare. +// Default resource name pattern is {httpMethod}:{apiPath}, such as "GET:/api/:id". +// Default block fallback is to return 429 (Too Many Requests) response. +// +// You may customize your own resource extractor and block handler by setting options. +func SentinelMiddleware(opts ...Option) web.MiddleWare { + options := evaluateOptions(opts) + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resourceName := r.Method + ":" + r.URL.Path + if options.resourceExtract != nil { + resourceName = options.resourceExtract(r) + } + entry, blockErr := sentinel.Entry( + resourceName, + sentinel.WithResourceType(base.ResTypeWeb), + sentinel.WithTrafficType(base.Inbound), + ) + if blockErr != nil { + if options.blockFallback != nil { + status, msg := options.blockFallback(r) + http.Error(w, msg, status) + } else { + // default error response + http.Error(w, "Blocked by Sentinel", http.StatusTooManyRequests) + } + return + } + defer entry.Exit() + next.ServeHTTP(w, r) + }) + } +} diff --git a/pkg/adapters/beego/middleware_example_test.go b/pkg/adapters/beego/middleware_example_test.go new file mode 100644 index 00000000..3e3224ab --- /dev/null +++ b/pkg/adapters/beego/middleware_example_test.go @@ -0,0 +1,31 @@ +package beego + +import ( + "github.com/beego/beego/v2/server/web" + beegoCtx "github.com/beego/beego/v2/server/web/context" + "net/http" +) + +func Example() { + opts := []Option{ + // customize resource extractor if required + // method_path by default + WithResourceExtractor(func(r *http.Request) string { + return r.Header.Get("X-Real-IP") + }), + // customize block fallback if required + // abort with status 429 by default + WithBlockFallback(func(r *http.Request) (int, string) { + return 400, "too many request; the quota used up" + }), + } + + web.Get("/test", func(ctx *beegoCtx.Context) { + }) + + // Routing filter chain + web.InsertFilterChain("/*", SentinelFilterChain(opts...)) + + // Global middleware + web.RunWithMiddleWares(":0", SentinelMiddleware(opts...)) +} diff --git a/pkg/adapters/beego/middleware_test.go b/pkg/adapters/beego/middleware_test.go new file mode 100644 index 00000000..fa4c2ddf --- /dev/null +++ b/pkg/adapters/beego/middleware_test.go @@ -0,0 +1,161 @@ +package beego + +import ( + "context" + sentinel "github.com/alibaba/sentinel-golang/api" + "github.com/alibaba/sentinel-golang/core/flow" + "github.com/beego/beego/v2/server/web" + beegoCtx "github.com/beego/beego/v2/server/web/context" + "github.com/stretchr/testify/assert" + "io" + "net/http" + "net/http/httptest" + "testing" +) + +func initSentinel(t *testing.T) { + err := sentinel.InitDefault() + if err != nil { + t.Fatalf("Unexpected error: %+v", err) + } + + _, err = flow.LoadRules([]*flow.Rule{ + { + Resource: "GET:/ping", + Threshold: 1.0, + TokenCalculateStrategy: flow.Direct, + ControlBehavior: flow.Reject, + StatIntervalInMs: 1000, + }, + { + Resource: "GET:/test", + Threshold: 1.0, + TokenCalculateStrategy: flow.Direct, + ControlBehavior: flow.Reject, + StatIntervalInMs: 1000, + }, + { + Resource: "GET:/block", + Threshold: 0.0, + TokenCalculateStrategy: flow.Direct, + ControlBehavior: flow.Reject, + StatIntervalInMs: 1000, + }, + { + Resource: "customize_block_fallback", + Threshold: 0.0, + TokenCalculateStrategy: flow.Direct, + ControlBehavior: flow.Reject, + StatIntervalInMs: 1000, + }, + }) + if err != nil { + t.Fatalf("Unexpected error: %+v", err) + return + } +} + +func TestSentinelMiddleware(t *testing.T) { + type args struct { + opts []Option + method string + path string + reqPath string + //handler http.Handler + handlerFunc web.HandleFunc + body io.Reader + } + type want struct { + code int + } + tests := []struct { + name string + args args + want want + }{ + { + name: "default get", + args: args{ + opts: []Option{}, + method: http.MethodGet, + path: "/ping", + reqPath: "/ping", + handlerFunc: func(ctx *beegoCtx.Context) { + _ = ctx.Resp([]byte("pong")) + }, + body: nil, + }, + want: want{ + code: http.StatusOK, + }, + }, + { + name: "customize resource extract", + args: args{ + opts: []Option{ + WithResourceExtractor(func(r *http.Request) string { + return "customize_block_fallback" + }), + }, + method: http.MethodPost, + path: "/ping", + reqPath: "/ping", + handlerFunc: func(ctx *beegoCtx.Context) { + _ = ctx.Resp([]byte("pong")) + }, + body: nil, + }, + want: want{ + code: http.StatusTooManyRequests, + }, + }, + { + name: "customize block fallback", + args: args{ + opts: []Option{ + WithBlockFallback(func(r *http.Request) (int, string) { + return http.StatusInternalServerError, "customize block fallback" + }), + }, + method: http.MethodGet, + path: "/block", + reqPath: "/block", + handlerFunc: func(ctx *beegoCtx.Context) { + _ = ctx.Resp([]byte("pong")) + }, + body: nil, + }, + want: want{ + code: http.StatusInternalServerError, + }, + }, + } + + initSentinel(t) + defer func() { + _ = flow.ClearRules() + }() + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + middleware := SentinelMiddleware(tt.args.opts...) + + server := web.NewHttpSever() + defer func() { + _ = server.Server.Shutdown(context.Background()) + }() + + server.Get(tt.args.reqPath, tt.args.handlerFunc) + server.Handlers.Init() + server.Server.Handler = middleware(server.Handlers) + + r := httptest.NewRequest(tt.args.method, tt.args.reqPath, tt.args.body) + w := httptest.NewRecorder() + + server.Server.Handler.ServeHTTP(w, r) + + assert.Equal(t, tt.want.code, w.Code) + }) + } +} diff --git a/pkg/adapters/beego/option.go b/pkg/adapters/beego/option.go new file mode 100644 index 00000000..ed81b325 --- /dev/null +++ b/pkg/adapters/beego/option.go @@ -0,0 +1,34 @@ +package beego + +import "net/http" + +type ( + Option func(*options) + options struct { + resourceExtract func(r *http.Request) string + blockFallback func(r *http.Request) (int, string) + } +) + +func evaluateOptions(opts []Option) *options { + optCopy := &options{} + for _, opt := range opts { + opt(optCopy) + } + + return optCopy +} + +// WithResourceExtractor sets the resource extractor of the web requests. +func WithResourceExtractor(fn func(r *http.Request) string) Option { + return func(opts *options) { + opts.resourceExtract = fn + } +} + +// WithBlockFallback sets the fallback handler when requests are blocked. +func WithBlockFallback(fn func(r *http.Request) (int, string)) Option { + return func(opts *options) { + opts.blockFallback = fn + } +}