diff --git a/znet/inject.go b/znet/inject.go index 9ca58b2..f6bf4a6 100644 --- a/znet/inject.go +++ b/znet/inject.go @@ -40,6 +40,8 @@ func invokeHandler(c *Context, v []reflect.Value) (err error) { c.render = &renderString{Format: vv} case error: err = vv + case Renderer: + c.render = vv case []byte: c.render = &renderByte{Data: vv} case ApiData: diff --git a/znet/inject_test.go b/znet/inject_test.go index b432396..0e496e4 100644 --- a/znet/inject_test.go +++ b/znet/inject_test.go @@ -139,6 +139,37 @@ func TestInjectMiddleware(t *testing.T) { tt.Equal("middleware", w.Body.String()) } +type customRenderer struct { + Text string + Err error +} + +func (c *customRenderer) Content(ctx *Context) (content []byte) { + if c.Err != nil { + ctx.SetStatus(500) + return []byte(c.Err.Error()) + } + ctx.SetStatus(200) + return []byte(c.Text) +} + +func TestCustomRenderer(t *testing.T) { + tt := zlsgo.NewTest(t) + r := newServer() + + w := newRequest(r, "GET", "/TestCustomRenderer", "/TestCustomRenderer", func(c *Context) *customRenderer { + return &customRenderer{Text: "test custom renderer"} + }) + tt.Equal(200, w.Code) + tt.Equal("test custom renderer", w.Body.String()) + + w = newRequest(r, "GET", "/TestCustomRendererError", "/TestCustomRendererError", func(c *Context) *customRenderer { + return &customRenderer{Err: errors.New("test custom renderer error")} + }) + tt.Equal(500, w.Code) + tt.Equal("test custom renderer error", w.Body.String()) +} + func BenchmarkInjectNo(b *testing.B) { r := newServer() path := "/BenchmarkInjectNo" diff --git a/znet/render.go b/znet/render.go index 9e7fb66..ada4c33 100644 --- a/znet/render.go +++ b/znet/render.go @@ -18,7 +18,7 @@ import ( ) type ( - render interface { + Renderer interface { Content(c *Context) (content []byte) } renderByte struct { @@ -71,7 +71,7 @@ var ( ContentTypeJSON = "application/json; charset=utf-8" ) -func (c *Context) renderProcessing(code int32, r render) { +func (c *Context) renderProcessing(code int32, r Renderer) { // if c.stopHandle.Load() && c.prevData.Code.Load() != 0 { // return // } diff --git a/znet/web.go b/znet/web.go index 1b2fa4b..6e624b7 100644 --- a/znet/web.go +++ b/znet/web.go @@ -28,7 +28,7 @@ type ( // Context context Context struct { startTime time.Time - render render + render Renderer Writer http.ResponseWriter injector zdi.Injector stopHandle *zutil.Bool