diff --git a/common/json/context_ext.go b/common/json/context_ext.go new file mode 100644 index 000000000..aec149a22 --- /dev/null +++ b/common/json/context_ext.go @@ -0,0 +1,23 @@ +package json + +import ( + "context" + + "github.com/sagernet/sing/common/json/internal/contextjson" +) + +var ( + MarshalContext = json.MarshalContext + UnmarshalContext = json.UnmarshalContext + NewEncoderContext = json.NewEncoderContext + NewDecoderContext = json.NewDecoderContext + UnmarshalContextDisallowUnknownFields = json.UnmarshalContextDisallowUnknownFields +) + +type ContextMarshaler interface { + MarshalJSONContext(ctx context.Context) ([]byte, error) +} + +type ContextUnmarshaler interface { + UnmarshalJSONContext(ctx context.Context, content []byte) error +} diff --git a/common/json/internal/contextjson/context.go b/common/json/internal/contextjson/context.go new file mode 100644 index 000000000..ded69d7da --- /dev/null +++ b/common/json/internal/contextjson/context.go @@ -0,0 +1,11 @@ +package json + +import "context" + +type ContextMarshaler interface { + MarshalJSONContext(ctx context.Context) ([]byte, error) +} + +type ContextUnmarshaler interface { + UnmarshalJSONContext(ctx context.Context, content []byte) error +} diff --git a/common/json/internal/contextjson/context_test.go b/common/json/internal/contextjson/context_test.go new file mode 100644 index 000000000..cffecbb09 --- /dev/null +++ b/common/json/internal/contextjson/context_test.go @@ -0,0 +1,43 @@ +package json_test + +import ( + "context" + "testing" + + "github.com/sagernet/sing/common/json/internal/contextjson" + + "github.com/stretchr/testify/require" +) + +type myStruct struct { + value string +} + +func (m *myStruct) MarshalJSONContext(ctx context.Context) ([]byte, error) { + return json.Marshal(ctx.Value("key").(string)) +} + +func (m *myStruct) UnmarshalJSONContext(ctx context.Context, content []byte) error { + m.value = ctx.Value("key").(string) + return nil +} + +//nolint:staticcheck +func TestMarshalContext(t *testing.T) { + t.Parallel() + ctx := context.WithValue(context.Background(), "key", "value") + var s myStruct + b, err := json.MarshalContext(ctx, &s) + require.NoError(t, err) + require.Equal(t, []byte(`"value"`), b) +} + +//nolint:staticcheck +func TestUnmarshalContext(t *testing.T) { + t.Parallel() + ctx := context.WithValue(context.Background(), "key", "value") + var s myStruct + err := json.UnmarshalContext(ctx, []byte(`{}`), &s) + require.NoError(t, err) + require.Equal(t, "value", s.value) +} diff --git a/common/json/internal/contextjson/decode.go b/common/json/internal/contextjson/decode.go index 8457171e0..20c7ac680 100644 --- a/common/json/internal/contextjson/decode.go +++ b/common/json/internal/contextjson/decode.go @@ -8,6 +8,7 @@ package json import ( + "context" "encoding" "encoding/base64" "fmt" @@ -95,10 +96,15 @@ import ( // Instead, they are replaced by the Unicode replacement // character U+FFFD. func Unmarshal(data []byte, v any) error { + return UnmarshalContext(context.Background(), data, v) +} + +func UnmarshalContext(ctx context.Context, data []byte, v any) error { // Check for well-formedness. // Avoids filling out half a data structure // before discovering a JSON syntax error. var d decodeState + d.ctx = ctx err := checkValid(data, &d.scan) if err != nil { return err @@ -209,6 +215,7 @@ type errorContext struct { // decodeState represents the state while decoding a JSON value. type decodeState struct { + ctx context.Context data []byte off int // next read offset in data opcode int // last read result @@ -428,7 +435,7 @@ func (d *decodeState) valueQuoted() any { // If it encounters an Unmarshaler, indirect stops and returns that. // If decodingNull is true, indirect stops at the first settable pointer so it // can be set to nil. -func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnmarshaler, reflect.Value) { +func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, ContextUnmarshaler, encoding.TextUnmarshaler, reflect.Value) { // Issue #24153 indicates that it is generally not a guaranteed property // that you may round-trip a reflect.Value by calling Value.Addr().Elem() // and expect the value to still be settable for values derived from @@ -482,11 +489,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm } if v.Type().NumMethod() > 0 && v.CanInterface() { if u, ok := v.Interface().(Unmarshaler); ok { - return u, nil, reflect.Value{} + return u, nil, nil, reflect.Value{} + } + if cu, ok := v.Interface().(ContextUnmarshaler); ok { + return nil, cu, nil, reflect.Value{} } if !decodingNull { if u, ok := v.Interface().(encoding.TextUnmarshaler); ok { - return nil, u, reflect.Value{} + return nil, nil, u, reflect.Value{} } } } @@ -498,14 +508,14 @@ func indirect(v reflect.Value, decodingNull bool) (Unmarshaler, encoding.TextUnm v = v.Elem() } } - return nil, nil, v + return nil, nil, nil, v } // array consumes an array from d.data[d.off-1:], decoding into v. // The first byte of the array ('[') has been read already. func (d *decodeState) array(v reflect.Value) error { // Check for unmarshaler. - u, ut, pv := indirect(v, false) + u, cu, ut, pv := indirect(v, false) if u != nil { start := d.readIndex() d.skip() @@ -515,6 +525,15 @@ func (d *decodeState) array(v reflect.Value) error { } return nil } + if cu != nil { + start := d.readIndex() + d.skip() + err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off]) + if err != nil { + d.saveError(err) + } + return nil + } if ut != nil { d.saveError(&UnmarshalTypeError{Value: "array", Type: v.Type(), Offset: int64(d.off)}) d.skip() @@ -612,7 +631,7 @@ var ( // The first byte ('{') of the object has been read already. func (d *decodeState) object(v reflect.Value) error { // Check for unmarshaler. - u, ut, pv := indirect(v, false) + u, cu, ut, pv := indirect(v, false) if u != nil { start := d.readIndex() d.skip() @@ -622,6 +641,15 @@ func (d *decodeState) object(v reflect.Value) error { } return nil } + if cu != nil { + start := d.readIndex() + d.skip() + err := cu.UnmarshalJSONContext(d.ctx, d.data[start:d.off]) + if err != nil { + d.saveError(err) + } + return nil + } if ut != nil { d.saveError(&UnmarshalTypeError{Value: "object", Type: v.Type(), Offset: int64(d.off)}) d.skip() @@ -870,7 +898,7 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool return nil } isNull := item[0] == 'n' // null - u, ut, pv := indirect(v, isNull) + u, cu, ut, pv := indirect(v, isNull) if u != nil { err := u.UnmarshalJSON(item) if err != nil { @@ -878,6 +906,13 @@ func (d *decodeState) literalStore(item []byte, v reflect.Value, fromQuoted bool } return nil } + if cu != nil { + err := cu.UnmarshalJSONContext(d.ctx, item) + if err != nil { + d.saveError(err) + } + return nil + } if ut != nil { if item[0] != '"' { if fromQuoted { diff --git a/common/json/internal/contextjson/encode.go b/common/json/internal/contextjson/encode.go index 296177a53..27f901be0 100644 --- a/common/json/internal/contextjson/encode.go +++ b/common/json/internal/contextjson/encode.go @@ -12,6 +12,7 @@ package json import ( "bytes" + "context" "encoding" "encoding/base64" "fmt" @@ -156,7 +157,11 @@ import ( // handle them. Passing cyclic structures to Marshal will result in // an error. func Marshal(v any) ([]byte, error) { - e := newEncodeState() + return MarshalContext(context.Background(), v) +} + +func MarshalContext(ctx context.Context, v any) ([]byte, error) { + e := newEncodeState(ctx) defer encodeStatePool.Put(e) err := e.marshal(v, encOpts{escapeHTML: true}) @@ -251,6 +256,7 @@ var hex = "0123456789abcdef" type encodeState struct { bytes.Buffer // accumulated output + ctx context.Context // Keep track of what pointers we've seen in the current recursive call // path, to avoid cycles that could lead to a stack overflow. Only do // the relatively expensive map operations if ptrLevel is larger than @@ -264,7 +270,7 @@ const startDetectingCyclesAfter = 1000 var encodeStatePool sync.Pool -func newEncodeState() *encodeState { +func newEncodeState(ctx context.Context) *encodeState { if v := encodeStatePool.Get(); v != nil { e := v.(*encodeState) e.Reset() @@ -274,7 +280,7 @@ func newEncodeState() *encodeState { e.ptrLevel = 0 return e } - return &encodeState{ptrSeen: make(map[any]struct{})} + return &encodeState{ctx: ctx, ptrSeen: make(map[any]struct{})} } // jsonError is an error wrapper type for internal use only. @@ -371,8 +377,9 @@ func typeEncoder(t reflect.Type) encoderFunc { } var ( - marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem() - textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() + marshalerType = reflect.TypeOf((*Marshaler)(nil)).Elem() + contextMarshalerType = reflect.TypeOf((*ContextMarshaler)(nil)).Elem() + textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem() ) // newTypeEncoder constructs an encoderFunc for a type. @@ -385,9 +392,15 @@ func newTypeEncoder(t reflect.Type, allowAddr bool) encoderFunc { if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(marshalerType) { return newCondAddrEncoder(addrMarshalerEncoder, newTypeEncoder(t, false)) } + if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(contextMarshalerType) { + return newCondAddrEncoder(addrContextMarshalerEncoder, newTypeEncoder(t, false)) + } if t.Implements(marshalerType) { return marshalerEncoder } + if t.Implements(contextMarshalerType) { + return contextMarshalerEncoder + } if t.Kind() != reflect.Pointer && allowAddr && reflect.PointerTo(t).Implements(textMarshalerType) { return newCondAddrEncoder(addrTextMarshalerEncoder, newTypeEncoder(t, false)) } @@ -470,6 +483,47 @@ func addrMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { } } +func contextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { + if v.Kind() == reflect.Pointer && v.IsNil() { + e.WriteString("null") + return + } + m, ok := v.Interface().(ContextMarshaler) + if !ok { + e.WriteString("null") + return + } + b, err := m.MarshalJSONContext(e.ctx) + if err == nil { + e.Grow(len(b)) + out := availableBuffer(&e.Buffer) + out, err = appendCompact(out, b, opts.escapeHTML) + e.Buffer.Write(out) + } + if err != nil { + e.error(&MarshalerError{v.Type(), err, "MarshalJSON"}) + } +} + +func addrContextMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { + va := v.Addr() + if va.IsNil() { + e.WriteString("null") + return + } + m := va.Interface().(ContextMarshaler) + b, err := m.MarshalJSONContext(e.ctx) + if err == nil { + e.Grow(len(b)) + out := availableBuffer(&e.Buffer) + out, err = appendCompact(out, b, opts.escapeHTML) + e.Buffer.Write(out) + } + if err != nil { + e.error(&MarshalerError{v.Type(), err, "MarshalJSON"}) + } +} + func textMarshalerEncoder(e *encodeState, v reflect.Value, opts encOpts) { if v.Kind() == reflect.Pointer && v.IsNil() { e.WriteString("null") @@ -827,7 +881,7 @@ func newSliceEncoder(t reflect.Type) encoderFunc { // Byte slices get special treatment; arrays don't. if t.Elem().Kind() == reflect.Uint8 { p := reflect.PointerTo(t.Elem()) - if !p.Implements(marshalerType) && !p.Implements(textMarshalerType) { + if !p.Implements(marshalerType) && !p.Implements(contextMarshalerType) && !p.Implements(textMarshalerType) { return encodeByteSlice } } diff --git a/common/json/internal/contextjson/stream.go b/common/json/internal/contextjson/stream.go index a670ab14f..2849dbf97 100644 --- a/common/json/internal/contextjson/stream.go +++ b/common/json/internal/contextjson/stream.go @@ -6,6 +6,7 @@ package json import ( "bytes" + "context" "errors" "io" ) @@ -29,7 +30,11 @@ type Decoder struct { // The decoder introduces its own buffering and may // read data from r beyond the JSON values requested. func NewDecoder(r io.Reader) *Decoder { - return &Decoder{r: r} + return NewDecoderContext(context.Background(), r) +} + +func NewDecoderContext(ctx context.Context, r io.Reader) *Decoder { + return &Decoder{r: r, d: decodeState{ctx: ctx}} } // UseNumber causes the Decoder to unmarshal a number into an interface{} as a @@ -183,6 +188,7 @@ func nonSpace(b []byte) bool { // An Encoder writes JSON values to an output stream. type Encoder struct { + ctx context.Context w io.Writer err error escapeHTML bool @@ -194,7 +200,11 @@ type Encoder struct { // NewEncoder returns a new encoder that writes to w. func NewEncoder(w io.Writer) *Encoder { - return &Encoder{w: w, escapeHTML: true} + return NewEncoderContext(context.Background(), w) +} + +func NewEncoderContext(ctx context.Context, w io.Writer) *Encoder { + return &Encoder{ctx: ctx, w: w, escapeHTML: true} } // Encode writes the JSON encoding of v to the stream, @@ -207,7 +217,7 @@ func (enc *Encoder) Encode(v any) error { return enc.err } - e := newEncodeState() + e := newEncodeState(enc.ctx) defer encodeStatePool.Put(e) err := e.marshal(v, encOpts{escapeHTML: enc.escapeHTML}) diff --git a/common/json/internal/contextjson/unmarshal.go b/common/json/internal/contextjson/unmarshal.go index 29405395d..04c13cbe9 100644 --- a/common/json/internal/contextjson/unmarshal.go +++ b/common/json/internal/contextjson/unmarshal.go @@ -1,5 +1,7 @@ package json +import "context" + func UnmarshalDisallowUnknownFields(data []byte, v any) error { var d decodeState d.disallowUnknownFields = true @@ -10,3 +12,15 @@ func UnmarshalDisallowUnknownFields(data []byte, v any) error { d.init(data) return d.unmarshal(v) } + +func UnmarshalContextDisallowUnknownFields(ctx context.Context, data []byte, v any) error { + var d decodeState + d.ctx = ctx + d.disallowUnknownFields = true + err := checkValid(data, &d.scan) + if err != nil { + return err + } + d.init(data) + return d.unmarshal(v) +}