Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(config): add custom json marshaler function #105

Merged
merged 1 commit into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions atreugo.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ func New(cfg Config) *Atreugo {
cfg.Logger = defaultLogger
}

if cfg.JSONMarshalFunc == nil {
cfg.JSONMarshalFunc = defaultJSONMarshalFunc
}

if cfg.ErrorView == nil {
cfg.ErrorView = defaultErrorView
}
Expand Down
41 changes: 27 additions & 14 deletions atreugo_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"io"
"math/rand"
"net"
"os"
Expand Down Expand Up @@ -32,19 +33,24 @@ func Test_New(t *testing.T) { //nolint:funlen,gocognit,gocyclo
network string
gracefulShutdown bool
gracefulShutdownSignals []os.Signal
jsonMarshalFunc JSONMarshalFunc
notFoundView View
methodNotAllowedView View
panicView PanicView
}

type want struct {
gracefulShutdownSignals []os.Signal
jsonMarshalFunc JSONMarshalFunc
notFoundView bool
methodNotAllowedView bool
panicView bool
err bool
}

jsonMarshalFunc := func(w io.Writer, body interface{}) error {
return nil
}
notFoundView := func(ctx *RequestCtx) error {
return nil
}
Expand All @@ -67,8 +73,7 @@ func Test_New(t *testing.T) { //nolint:funlen,gocognit,gocyclo
args: args{},
want: want{
gracefulShutdownSignals: nil,
notFoundView: false,
methodNotAllowedView: false,
jsonMarshalFunc: defaultJSONMarshalFunc,
panicView: false,
},
},
Expand All @@ -79,6 +84,7 @@ func Test_New(t *testing.T) { //nolint:funlen,gocognit,gocyclo
},
want: want{
gracefulShutdownSignals: defaultGracefulShutdownSignals,
jsonMarshalFunc: defaultJSONMarshalFunc,
notFoundView: false,
methodNotAllowedView: false,
panicView: false,
Expand All @@ -90,12 +96,14 @@ func Test_New(t *testing.T) { //nolint:funlen,gocognit,gocyclo
network: "unix",
gracefulShutdown: true,
gracefulShutdownSignals: []os.Signal{syscall.SIGKILL},
jsonMarshalFunc: jsonMarshalFunc,
notFoundView: notFoundView,
methodNotAllowedView: methodNotAllowedView,
panicView: panicView,
},
want: want{
gracefulShutdownSignals: []os.Signal{syscall.SIGKILL},
jsonMarshalFunc: jsonMarshalFunc,
notFoundView: true,
methodNotAllowedView: true,
panicView: true,
Expand Down Expand Up @@ -133,6 +141,7 @@ func Test_New(t *testing.T) { //nolint:funlen,gocognit,gocyclo
Network: tt.args.network,
GracefulShutdown: tt.args.gracefulShutdown,
GracefulShutdownSignals: tt.args.gracefulShutdownSignals,
JSONMarshalFunc: tt.args.jsonMarshalFunc,
NotFoundView: tt.args.notFoundView,
MethodNotAllowedView: tt.args.methodNotAllowedView,
PanicView: tt.args.panicView,
Expand All @@ -151,25 +160,17 @@ func Test_New(t *testing.T) { //nolint:funlen,gocognit,gocyclo
t.Errorf("Logger == %p, want %p", s.cfg.Logger, defaultLogger)
}

if !isEqual(s.cfg.ErrorView, defaultErrorView) {
t.Errorf("Error view == %p, want %p", s.cfg.ErrorView, defaultErrorView)
}

if s.router == nil {
t.Fatal("Atreugo router instance is nil")
}

if s.router.GlobalOPTIONS != nil {
t.Error("GlobalOPTIONS handler is not nil")
}

if !reflect.DeepEqual(tt.want.gracefulShutdownSignals, s.cfg.GracefulShutdownSignals) {
t.Errorf(
"GracefulShutdownSignals = %v, want %v",
s.cfg.GracefulShutdownSignals, tt.want.gracefulShutdownSignals,
)
}

if !isEqual(s.cfg.JSONMarshalFunc, tt.want.jsonMarshalFunc) {
t.Errorf("JSONMarshalFunc == %p, want %p", s.cfg.JSONMarshalFunc, tt.want.jsonMarshalFunc)
}

if tt.want.notFoundView != (s.router.NotFound != nil) {
t.Error("NotFound handler is not setted")
}
Expand All @@ -178,6 +179,10 @@ func Test_New(t *testing.T) { //nolint:funlen,gocognit,gocyclo
t.Error("MethodNotAllowed handler is not setted")
}

if !isEqual(s.cfg.ErrorView, defaultErrorView) {
t.Errorf("Error view == %p, want %p", s.cfg.ErrorView, defaultErrorView)
}

if tt.want.panicView != (s.router.PanicHandler != nil) {
t.Error("PanicHandler handler is not setted")
}
Expand All @@ -190,6 +195,14 @@ func Test_New(t *testing.T) { //nolint:funlen,gocognit,gocyclo
t.Errorf("Panic handler response == %s, want %s", ctx.Response.Body(), panicErr.Error())
}
}

if s.router == nil {
t.Fatal("Atreugo router instance is nil")
}

if s.router.GlobalOPTIONS != nil {
t.Error("GlobalOPTIONS handler is not nil")
}
})
}
}
Expand Down
5 changes: 4 additions & 1 deletion context.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ var (

requestCtxPool = sync.Pool{
New: func() interface{} {
return new(RequestCtx)
ctx := new(RequestCtx)
ctx.jsonMarshalFunc = defaultJSONMarshalFunc

return ctx
},
}
)
Expand Down
16 changes: 10 additions & 6 deletions context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@ func Test_AcquireRequestCtx(t *testing.T) {
ctx := new(fasthttp.RequestCtx)
actx := AcquireRequestCtx(ctx) // nolint:ifshort

if !isEqual(actx.jsonMarshalFunc, defaultJSONMarshalFunc) {
t.Errorf("jsonMarshalFunc = %p, want %p", actx.jsonMarshalFunc, defaultJSONMarshalFunc)
}

if actx.RequestCtx != ctx {
t.Errorf("AcquireRequestCtx() = %p, want %p", actx.RequestCtx, ctx)
t.Errorf("RequestCtx = %p, want %p", actx.RequestCtx, ctx)
}
}

Expand All @@ -31,19 +35,19 @@ func Test_ReleaseRequestCtx(t *testing.T) {
ReleaseRequestCtx(actx)

if actx.next {
t.Errorf("reset() next is not 'false'")
t.Errorf("next is not 'false'")
}

if actx.skipView {
t.Errorf("reset() skipView is not 'false'")
t.Errorf("skipView is not 'false'")
}

if actx.RequestCtx != nil {
t.Errorf("reset() *fasthttp.RequestCtx = %p, want %v", actx.RequestCtx, nil)
if actx.jsonMarshalFunc == nil {
t.Errorf("jsonMarshalFunc is nil")
}

if actx.RequestCtx != nil {
t.Errorf("ReleaseRequestCtx() *fasthttp.RequestCtx = %p, want %v", actx.RequestCtx, nil)
t.Errorf("*fasthttp.RequestCtx = %p, want %v", actx.RequestCtx, nil)
}
}

Expand Down
9 changes: 7 additions & 2 deletions response.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@ package atreugo

import (
"encoding/json"
"io"

"github.com/valyala/bytebufferpool"
"github.com/valyala/fasthttp"
)

func defaultJSONMarshalFunc(w io.Writer, body interface{}) error {
return json.NewEncoder(w).Encode(body) // nolint:wrapcheck
}

// JSONResponse return response with body in json format.
func (ctx *RequestCtx) JSONResponse(body interface{}, statusCode ...int) (err error) {
func (ctx *RequestCtx) JSONResponse(body interface{}, statusCode ...int) error {
ctx.Response.Header.SetContentType("application/json")

if len(statusCode) > 0 {
Expand All @@ -17,7 +22,7 @@ func (ctx *RequestCtx) JSONResponse(body interface{}, statusCode ...int) (err er

w := ctx.Response.BodyWriter()

return json.NewEncoder(w).Encode(body) // nolint:wrapcheck
return ctx.jsonMarshalFunc(w, body) // nolint:wrapcheck
}

// HTTPResponse return response with body in html format.
Expand Down
38 changes: 31 additions & 7 deletions response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"errors"
"fmt"
"io"
"os"
"path"
"testing"
Expand All @@ -23,8 +24,9 @@ func (cj customJSON) MarshalJSON() ([]byte, error) {

func TestJSONResponse(t *testing.T) { //nolint:funlen
type args struct {
body interface{}
statusCode int
body interface{}
statusCode int
jsonMarshalFunc JSONMarshalFunc
}

type want struct {
Expand All @@ -42,8 +44,9 @@ func TestJSONResponse(t *testing.T) { //nolint:funlen
{
name: "ValidBody",
args: args{
body: JSON{"test": true},
statusCode: 200,
body: JSON{"test": true},
statusCode: 200,
jsonMarshalFunc: defaultJSONMarshalFunc,
},
want: want{
body: "{\"test\":true}",
Expand All @@ -55,8 +58,9 @@ func TestJSONResponse(t *testing.T) { //nolint:funlen
{
name: "BodyAsJsonMarshaler",
args: args{
body: customJSON{Value: "test"},
statusCode: 200,
body: customJSON{Value: "test"},
statusCode: 200,
jsonMarshalFunc: defaultJSONMarshalFunc,
},
want: want{
body: "{\"Value\":\"test\"}",
Expand All @@ -66,10 +70,29 @@ func TestJSONResponse(t *testing.T) { //nolint:funlen
},
},
{
name: "InvalidBody",
name: "CustomJSONMarshalFunc",
args: args{
body: make(chan int),
statusCode: 200,
jsonMarshalFunc: func(w io.Writer, value interface{}) error {
_, err := w.Write([]byte("my custom response"))

return err // nolint:wrapcheck
},
},
want: want{
body: "my custom response",
statusCode: 200,
contentType: "application/json",
err: false,
},
},
{
name: "InvalidBody",
args: args{
body: make(chan int),
statusCode: 200,
jsonMarshalFunc: defaultJSONMarshalFunc,
},
want: want{
body: "",
Expand All @@ -88,6 +111,7 @@ func TestJSONResponse(t *testing.T) { //nolint:funlen

ctx := new(fasthttp.RequestCtx)
actx := AcquireRequestCtx(ctx)
actx.jsonMarshalFunc = tt.args.jsonMarshalFunc

err := actx.JSONResponse(tt.args.body, tt.args.statusCode)
if tt.want.err && (err == nil) {
Expand Down
11 changes: 9 additions & 2 deletions types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package atreugo

import (
"crypto/tls"
"io"
"net"
"os"
"time"
Expand All @@ -21,6 +22,8 @@ type preforkServer interface {
ListenAndServe(addr string) error
}

type JSONMarshalFunc func(w io.Writer, value interface{}) error

// Atreugo implements high performance HTTP server
//
// It is prohibited copying Atreugo values. Create new values instead.
Expand Down Expand Up @@ -111,6 +114,9 @@ type Config struct { // nolint:maligned
// in 'Accept-Encoding' header.
Compress bool

// JSONMarshalFunc is used to marshal the response body to writer as JSON.
JSONMarshalFunc JSONMarshalFunc

// Configurable view which is called when no matching route is
// found. If it is not set, http.NotFound is used.
NotFoundView View
Expand Down Expand Up @@ -484,8 +490,9 @@ type StaticFS struct {
type RequestCtx struct {
noCopy nocopy.NoCopy // nolint:structcheck,unused

next bool
skipView bool
next bool
skipView bool
jsonMarshalFunc JSONMarshalFunc

// Flag to avoid stack overflow when this context has been embedded in the attached context
searchingOnAttachedCtx int32
Expand Down