Skip to content

Commit

Permalink
contextjson: Add context marshaler/unmarshaler
Browse files Browse the repository at this point in the history
  • Loading branch information
nekohasekai committed Oct 31, 2024
1 parent 94f0582 commit e07d30f
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 16 deletions.
23 changes: 23 additions & 0 deletions common/json/context_ext.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package json

import (
"context"

"github.com/sagernet/sing/common/json/internal/contextjson"
)

var (
MarshalContext = json.MarshalContext
UnmarshalContext = json.UnmarshalContext
NewEncoderContext = json.NewEncoderContext
NewDecoderContext = json.NewDecoderContext
UnmarshalContextDisallowUnknownFields = json.UnmarshalContextDisallowUnknownFields
)

type ContextMarshaler interface {
MarshalJSONContext(ctx context.Context) ([]byte, error)
}

type ContextUnmarshaler interface {
UnmarshalJSONContext(ctx context.Context, content []byte) error
}
11 changes: 11 additions & 0 deletions common/json/internal/contextjson/context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package json

import "context"

type ContextMarshaler interface {
MarshalJSONContext(ctx context.Context) ([]byte, error)
}

type ContextUnmarshaler interface {
UnmarshalJSONContext(ctx context.Context, content []byte) error
}
43 changes: 43 additions & 0 deletions common/json/internal/contextjson/context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package json_test

import (
"context"
"testing"

"github.com/sagernet/sing/common/json/internal/contextjson"

"github.com/stretchr/testify/require"
)

type myStruct struct {
value string
}

func (m *myStruct) MarshalJSONContext(ctx context.Context) ([]byte, error) {
return json.Marshal(ctx.Value("key").(string))
}

func (m *myStruct) UnmarshalJSONContext(ctx context.Context, content []byte) error {
m.value = ctx.Value("key").(string)
return nil
}

//nolint:staticcheck
func TestMarshalContext(t *testing.T) {
t.Parallel()
ctx := context.WithValue(context.Background(), "key", "value")
var s myStruct
b, err := json.MarshalContext(ctx, &s)
require.NoError(t, err)
require.Equal(t, []byte(`"value"`), b)
}

//nolint:staticcheck
func TestUnmarshalContext(t *testing.T) {
t.Parallel()
ctx := context.WithValue(context.Background(), "key", "value")
var s myStruct
err := json.UnmarshalContext(ctx, []byte(`{}`), &s)
require.NoError(t, err)
require.Equal(t, "value", s.value)
}
49 changes: 42 additions & 7 deletions common/json/internal/contextjson/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package json

import (
"context"
"encoding"
"encoding/base64"
"fmt"
Expand Down Expand Up @@ -95,10 +96,15 @@ import (
// Instead, they are replaced by the Unicode replacement
// character U+FFFD.
func Unmarshal(data []byte, v any) error {
return UnmarshalContext(context.Background(), data, v)
}

func UnmarshalContext(ctx context.Context, data []byte, v any) error {
// Check for well-formedness.
// Avoids filling out half a data structure
// before discovering a JSON syntax error.
var d decodeState
d.ctx = ctx
err := checkValid(data, &d.scan)
if err != nil {
return err
Expand Down Expand Up @@ -209,6 +215,7 @@ type errorContext struct {

// decodeState represents the state while decoding a JSON value.
type decodeState struct {
ctx context.Context
data []byte
off int // next read offset in data
opcode int // last read result
Expand Down Expand Up @@ -428,7 +435,7 @@ func (d *decodeState) valueQuoted() any {
// If it encounters an Unmarshaler, indirect stops and returns that.
// If decodingNull is true, indirect stops at the first settable pointer so it
// can be set to nil.
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) {
func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ContextUnmarshaler, encoding.TextUnmarshaler, reflect.Value) {
// Issue #24153 indicates that it is generally not a guaranteed property
// that you may round-trip a reflect.Value by calling Value.Addr().Elem()
// and expect the value to still be settable for values derived from
Expand Down Expand Up @@ -482,11 +489,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm
}
if v.Type().NumMethod() > 0 && v.CanInterface() {
if u, ok := v.Interface().(Unmarshaler); ok {
return u, nil, reflect.Value{}
return u, nil, nil, reflect.Value{}
}
if cu, ok := v.Interface().(ContextUnmarshaler); ok {
return nil, cu, nil, reflect.Value{}
}
if !decodingNull {
if u, ok := v.Interface().(encoding.TextUnmarshaler); ok {
return nil, u, reflect.Value{}
return nil, nil, u, reflect.Value{}
}
}
}
Expand All @@ -498,14 +508,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm
v = v.Elem()
}
}
return nil, nil, v
return nil, nil, nil, v
}

// array consumes an array from d.data[d.off-1:], decoding into v.
// The first byte of the array ('[') has been read already.
func (d *decodeState) array(v reflect.Value) error {
// Check for unmarshaler.
u, ut, pv := indirect(v, false)
u, cu, ut, pv := indirect(v, false)
if u != nil {
start := d.readIndex()
d.skip()
Expand All @@ -515,6 +525,15 @@ func (d *decodeState) array(v reflect.Value) error {
}
return nil
}
if cu != nil {
start := d.readIndex()
d.skip()
err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off])
if err != nil {
d.saveError(err)
}
return nil
}
if ut != nil {
d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)})
d.skip()
Expand Down Expand Up @@ -612,7 +631,7 @@ var (
// The first byte ('{') of the object has been read already.
func (d *decodeState) object(v reflect.Value) error {
// Check for unmarshaler.
u, ut, pv := indirect(v, false)
u, cu, ut, pv := indirect(v, false)
if u != nil {
start := d.readIndex()
d.skip()
Expand All @@ -622,6 +641,15 @@ func (d *decodeState) object(v reflect.Value) error {
}
return nil
}
if cu != nil {
start := d.readIndex()
d.skip()
err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off])
if err != nil {
d.saveError(err)
}
return nil
}
if ut != nil {
d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)})
d.skip()
Expand Down Expand Up @@ -870,14 +898,21 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool
return nil
}
isNull := item[0] == 'n' // null
u, ut, pv := indirect(v, isNull)
u, cu, ut, pv := indirect(v, isNull)
if u != nil {
err := u.UnmarshalJSON(item)
if err != nil {
d.saveError(err)
}
return nil
}
if cu != nil {
err := cu.UnmarshalJSONContext(d.ctx, item)
if err != nil {
d.saveError(err)
}
return nil
}
if ut != nil {
if item[0] != '"' {
if fromQuoted {
Expand Down
66 changes: 60 additions & 6 deletions common/json/internal/contextjson/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ package json

import (
"bytes"
"context"
"encoding"
"encoding/base64"
"fmt"
Expand Down Expand Up @@ -156,7 +157,11 @@ import (
// handle them. Passing cyclic structures to Marshal will result in
// an error.
func Marshal(v any) ([]byte, error) {
e := newEncodeState()
return MarshalContext(context.Background(), v)
}

func MarshalContext(ctx context.Context, v any) ([]byte, error) {
e := newEncodeState(ctx)
defer encodeStatePool.Put(e)

err := e.marshal(v, encOpts{escapeHTML: true})
Expand Down Expand Up @@ -251,6 +256,7 @@ var hex = "0123456789abcdef"
type encodeState struct {
bytes.Buffer // accumulated output

ctx context.Context
// Keep track of what pointers we've seen in the current recursive call
// path, to avoid cycles that could lead to a stack overflow. Only do
// the relatively expensive map operations if ptrLevel is larger than
Expand All @@ -264,7 +270,7 @@ const startDetectingCyclesAfter = 1000

var encodeStatePool sync.Pool

func newEncodeState() *encodeState {
func newEncodeState(ctx context.Context) *encodeState {
if v := encodeStatePool.Get(); v != nil {
e := v.(*encodeState)
e.Reset()
Expand All @@ -274,7 +280,7 @@ func newEncodeState() *encodeState {
e.ptrLevel = 0
return e
}
return &encodeState{ptrSeen: make(map[any]struct{})}
return &encodeState{ctx: ctx, ptrSeen: make(map[any]struct{})}
}

// jsonError is an error wrapper type for internal use only.
Expand Down Expand Up @@ -371,8 +377,9 @@ func typeEncoder(t reflect.Type) encoderFunc {
}

var (
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem()
contextMarshalerType = reflect.TypeOf((*ContextMarshaler)(nil)).Elem()
textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
)

// newTypeEncoder constructs an encoderFunc for a type.
Expand All @@ -385,9 +392,15 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc {
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) {
return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false))
}
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(contextMarshalerType) {
return newCondAddrEncoder(addrContextMarshalerEncoder, newTypeEncoder(t, false))
}
if t.Implements(marshalerType) {
return marshalerEncoder
}
if t.Implements(contextMarshalerType) {
return contextMarshalerEncoder
}
if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(textMarshalerType) {
return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false))
}
Expand Down Expand Up @@ -470,6 +483,47 @@ func addrMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
}
}

func contextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if v.Kind() == reflect.Pointer && v.IsNil() {
e.WriteString("null")
return
}
m, ok := v.Interface().(ContextMarshaler)
if !ok {
e.WriteString("null")
return
}
b, err := m.MarshalJSONContext(e.ctx)
if err == nil {
e.Grow(len(b))
out := availableBuffer(&e.Buffer)
out, err = appendCompact(out, b, opts.escapeHTML)
e.Buffer.Write(out)
}
if err != nil {
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
}
}

func addrContextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
va := v.Addr()
if va.IsNil() {
e.WriteString("null")
return
}
m := va.Interface().(ContextMarshaler)
b, err := m.MarshalJSONContext(e.ctx)
if err == nil {
e.Grow(len(b))
out := availableBuffer(&e.Buffer)
out, err = appendCompact(out, b, opts.escapeHTML)
e.Buffer.Write(out)
}
if err != nil {
e.error(&MarshalerError{v.Type(), err, "MarshalJSON"})
}
}

func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) {
if v.Kind() == reflect.Pointer && v.IsNil() {
e.WriteString("null")
Expand Down Expand Up @@ -827,7 +881,7 @@ func newSliceEncoder(t reflect.Type) encoderFunc {
// Byte slices get special treatment; arrays don't.
if t.Elem().Kind() == reflect.Uint8 {
p := reflect.PointerTo(t.Elem())
if !p.Implements(marshalerType) && !p.Implements(textMarshalerType) {
if !p.Implements(marshalerType) && !p.Implements(contextMarshalerType) && !p.Implements(textMarshalerType) {
return encodeByteSlice
}
}
Expand Down
Loading

0 comments on commit e07d30f

Please sign in to comment.