diff --git a/atreugo.go b/atreugo.go index fdbd0d1..adfd20e 100644 --- a/atreugo.go +++ b/atreugo.go @@ -41,6 +41,10 @@ func New(cfg Config) *Atreugo { cfg.Logger = defaultLogger } + if cfg.JSONMarshalFunc == nil { + cfg.JSONMarshalFunc = defaultJSONMarshalFunc + } + if cfg.ErrorView == nil { cfg.ErrorView = defaultErrorView } diff --git a/atreugo_test.go b/atreugo_test.go index bb5869f..0c4dec6 100644 --- a/atreugo_test.go +++ b/atreugo_test.go @@ -5,6 +5,7 @@ import ( "crypto/tls" "errors" "fmt" + "io" "math/rand" "net" "os" @@ -32,6 +33,7 @@ func Test_New(t *testing.T) { //nolint:funlen,gocognit,gocyclo network string gracefulShutdown bool gracefulShutdownSignals []os.Signal + jsonMarshalFunc JSONMarshalFunc notFoundView View methodNotAllowedView View panicView PanicView @@ -39,12 +41,16 @@ func Test_New(t *testing.T) { //nolint:funlen,gocognit,gocyclo type want struct { gracefulShutdownSignals []os.Signal + jsonMarshalFunc JSONMarshalFunc notFoundView bool methodNotAllowedView bool panicView bool err bool } + jsonMarshalFunc := func(w io.Writer, body interface{}) error { + return nil + } notFoundView := func(ctx *RequestCtx) error { return nil } @@ -67,8 +73,7 @@ func Test_New(t *testing.T) { //nolint:funlen,gocognit,gocyclo args: args{}, want: want{ gracefulShutdownSignals: nil, - notFoundView: false, - methodNotAllowedView: false, + jsonMarshalFunc: defaultJSONMarshalFunc, panicView: false, }, }, @@ -79,6 +84,7 @@ func Test_New(t *testing.T) { //nolint:funlen,gocognit,gocyclo }, want: want{ gracefulShutdownSignals: defaultGracefulShutdownSignals, + jsonMarshalFunc: defaultJSONMarshalFunc, notFoundView: false, methodNotAllowedView: false, panicView: false, @@ -90,12 +96,14 @@ func Test_New(t *testing.T) { //nolint:funlen,gocognit,gocyclo network: "unix", gracefulShutdown: true, gracefulShutdownSignals: []os.Signal{syscall.SIGKILL}, + jsonMarshalFunc: jsonMarshalFunc, notFoundView: notFoundView, methodNotAllowedView: methodNotAllowedView, panicView: panicView, }, want: want{ gracefulShutdownSignals: []os.Signal{syscall.SIGKILL}, + jsonMarshalFunc: jsonMarshalFunc, notFoundView: true, methodNotAllowedView: true, panicView: true, @@ -133,6 +141,7 @@ func Test_New(t *testing.T) { //nolint:funlen,gocognit,gocyclo Network: tt.args.network, GracefulShutdown: tt.args.gracefulShutdown, GracefulShutdownSignals: tt.args.gracefulShutdownSignals, + JSONMarshalFunc: tt.args.jsonMarshalFunc, NotFoundView: tt.args.notFoundView, MethodNotAllowedView: tt.args.methodNotAllowedView, PanicView: tt.args.panicView, @@ -151,18 +160,6 @@ func Test_New(t *testing.T) { //nolint:funlen,gocognit,gocyclo t.Errorf("Logger == %p, want %p", s.cfg.Logger, defaultLogger) } - if !isEqual(s.cfg.ErrorView, defaultErrorView) { - t.Errorf("Error view == %p, want %p", s.cfg.ErrorView, defaultErrorView) - } - - if s.router == nil { - t.Fatal("Atreugo router instance is nil") - } - - if s.router.GlobalOPTIONS != nil { - t.Error("GlobalOPTIONS handler is not nil") - } - if !reflect.DeepEqual(tt.want.gracefulShutdownSignals, s.cfg.GracefulShutdownSignals) { t.Errorf( "GracefulShutdownSignals = %v, want %v", @@ -170,6 +167,10 @@ func Test_New(t *testing.T) { //nolint:funlen,gocognit,gocyclo ) } + if !isEqual(s.cfg.JSONMarshalFunc, tt.want.jsonMarshalFunc) { + t.Errorf("JSONMarshalFunc == %p, want %p", s.cfg.JSONMarshalFunc, tt.want.jsonMarshalFunc) + } + if tt.want.notFoundView != (s.router.NotFound != nil) { t.Error("NotFound handler is not setted") } @@ -178,6 +179,10 @@ func Test_New(t *testing.T) { //nolint:funlen,gocognit,gocyclo t.Error("MethodNotAllowed handler is not setted") } + if !isEqual(s.cfg.ErrorView, defaultErrorView) { + t.Errorf("Error view == %p, want %p", s.cfg.ErrorView, defaultErrorView) + } + if tt.want.panicView != (s.router.PanicHandler != nil) { t.Error("PanicHandler handler is not setted") } @@ -190,6 +195,14 @@ func Test_New(t *testing.T) { //nolint:funlen,gocognit,gocyclo t.Errorf("Panic handler response == %s, want %s", ctx.Response.Body(), panicErr.Error()) } } + + if s.router == nil { + t.Fatal("Atreugo router instance is nil") + } + + if s.router.GlobalOPTIONS != nil { + t.Error("GlobalOPTIONS handler is not nil") + } }) } } diff --git a/context.go b/context.go index feff465..c6cac76 100644 --- a/context.go +++ b/context.go @@ -17,7 +17,10 @@ var ( requestCtxPool = sync.Pool{ New: func() interface{} { - return new(RequestCtx) + ctx := new(RequestCtx) + ctx.jsonMarshalFunc = defaultJSONMarshalFunc + + return ctx }, } ) diff --git a/context_test.go b/context_test.go index c650d48..8b6f371 100644 --- a/context_test.go +++ b/context_test.go @@ -13,8 +13,12 @@ func Test_AcquireRequestCtx(t *testing.T) { ctx := new(fasthttp.RequestCtx) actx := AcquireRequestCtx(ctx) // nolint:ifshort + if !isEqual(actx.jsonMarshalFunc, defaultJSONMarshalFunc) { + t.Errorf("jsonMarshalFunc = %p, want %p", actx.jsonMarshalFunc, defaultJSONMarshalFunc) + } + if actx.RequestCtx != ctx { - t.Errorf("AcquireRequestCtx() = %p, want %p", actx.RequestCtx, ctx) + t.Errorf("RequestCtx = %p, want %p", actx.RequestCtx, ctx) } } @@ -31,19 +35,19 @@ func Test_ReleaseRequestCtx(t *testing.T) { ReleaseRequestCtx(actx) if actx.next { - t.Errorf("reset() next is not 'false'") + t.Errorf("next is not 'false'") } if actx.skipView { - t.Errorf("reset() skipView is not 'false'") + t.Errorf("skipView is not 'false'") } - if actx.RequestCtx != nil { - t.Errorf("reset() *fasthttp.RequestCtx = %p, want %v", actx.RequestCtx, nil) + if actx.jsonMarshalFunc == nil { + t.Errorf("jsonMarshalFunc is nil") } if actx.RequestCtx != nil { - t.Errorf("ReleaseRequestCtx() *fasthttp.RequestCtx = %p, want %v", actx.RequestCtx, nil) + t.Errorf("*fasthttp.RequestCtx = %p, want %v", actx.RequestCtx, nil) } } diff --git a/response.go b/response.go index 2d9e1ba..e8fd7ff 100644 --- a/response.go +++ b/response.go @@ -2,13 +2,18 @@ package atreugo import ( "encoding/json" + "io" "github.com/valyala/bytebufferpool" "github.com/valyala/fasthttp" ) +func defaultJSONMarshalFunc(w io.Writer, body interface{}) error { + return json.NewEncoder(w).Encode(body) // nolint:wrapcheck +} + // JSONResponse return response with body in json format. -func (ctx *RequestCtx) JSONResponse(body interface{}, statusCode ...int) (err error) { +func (ctx *RequestCtx) JSONResponse(body interface{}, statusCode ...int) error { ctx.Response.Header.SetContentType("application/json") if len(statusCode) > 0 { @@ -17,7 +22,7 @@ func (ctx *RequestCtx) JSONResponse(body interface{}, statusCode ...int) (err er w := ctx.Response.BodyWriter() - return json.NewEncoder(w).Encode(body) // nolint:wrapcheck + return ctx.jsonMarshalFunc(w, body) // nolint:wrapcheck } // HTTPResponse return response with body in html format. diff --git a/response_test.go b/response_test.go index 4672c82..600a770 100644 --- a/response_test.go +++ b/response_test.go @@ -4,6 +4,7 @@ import ( "bytes" "errors" "fmt" + "io" "os" "path" "testing" @@ -23,8 +24,9 @@ func (cj customJSON) MarshalJSON() ([]byte, error) { func TestJSONResponse(t *testing.T) { //nolint:funlen type args struct { - body interface{} - statusCode int + body interface{} + statusCode int + jsonMarshalFunc JSONMarshalFunc } type want struct { @@ -42,8 +44,9 @@ func TestJSONResponse(t *testing.T) { //nolint:funlen { name: "ValidBody", args: args{ - body: JSON{"test": true}, - statusCode: 200, + body: JSON{"test": true}, + statusCode: 200, + jsonMarshalFunc: defaultJSONMarshalFunc, }, want: want{ body: "{\"test\":true}", @@ -55,8 +58,9 @@ func TestJSONResponse(t *testing.T) { //nolint:funlen { name: "BodyAsJsonMarshaler", args: args{ - body: customJSON{Value: "test"}, - statusCode: 200, + body: customJSON{Value: "test"}, + statusCode: 200, + jsonMarshalFunc: defaultJSONMarshalFunc, }, want: want{ body: "{\"Value\":\"test\"}", @@ -66,10 +70,29 @@ func TestJSONResponse(t *testing.T) { //nolint:funlen }, }, { - name: "InvalidBody", + name: "CustomJSONMarshalFunc", args: args{ body: make(chan int), statusCode: 200, + jsonMarshalFunc: func(w io.Writer, value interface{}) error { + _, err := w.Write([]byte("my custom response")) + + return err // nolint:wrapcheck + }, + }, + want: want{ + body: "my custom response", + statusCode: 200, + contentType: "application/json", + err: false, + }, + }, + { + name: "InvalidBody", + args: args{ + body: make(chan int), + statusCode: 200, + jsonMarshalFunc: defaultJSONMarshalFunc, }, want: want{ body: "", @@ -88,6 +111,7 @@ func TestJSONResponse(t *testing.T) { //nolint:funlen ctx := new(fasthttp.RequestCtx) actx := AcquireRequestCtx(ctx) + actx.jsonMarshalFunc = tt.args.jsonMarshalFunc err := actx.JSONResponse(tt.args.body, tt.args.statusCode) if tt.want.err && (err == nil) { diff --git a/types.go b/types.go index 3d6a98a..6588677 100644 --- a/types.go +++ b/types.go @@ -2,6 +2,7 @@ package atreugo import ( "crypto/tls" + "io" "net" "os" "time" @@ -21,6 +22,8 @@ type preforkServer interface { ListenAndServe(addr string) error } +type JSONMarshalFunc func(w io.Writer, value interface{}) error + // Atreugo implements high performance HTTP server // // It is prohibited copying Atreugo values. Create new values instead. @@ -111,6 +114,9 @@ type Config struct { // nolint:maligned // in 'Accept-Encoding' header. Compress bool + // JSONMarshalFunc is used to marshal the response body to writer as JSON. + JSONMarshalFunc JSONMarshalFunc + // Configurable view which is called when no matching route is // found. If it is not set, http.NotFound is used. NotFoundView View @@ -484,8 +490,9 @@ type StaticFS struct { type RequestCtx struct { noCopy nocopy.NoCopy // nolint:structcheck,unused - next bool - skipView bool + next bool + skipView bool + jsonMarshalFunc JSONMarshalFunc // Flag to avoid stack overflow when this context has been embedded in the attached context searchingOnAttachedCtx int32