diff --git a/languages/go/go-server/encode_json_test.go b/languages/go/go-server/encode_json_test.go index fc664b24..0a08b5d4 100644 --- a/languages/go/go-server/encode_json_test.go +++ b/languages/go/go-server/encode_json_test.go @@ -3,9 +3,11 @@ package arri_test import ( arri "arri/languages/go/go-server" "encoding/json" + "fmt" "os" "testing" "time" + "unsafe" ) var testDate = time.Date(2001, time.January, 01, 16, 0, 0, 0, time.UTC) @@ -59,12 +61,24 @@ var userV2Input userV2 = userV2{ } func BenchmarkV2Encoding(b *testing.B) { + e, err := arri.CompileJSONEncoder(userV2Input, arri.KeyCasingCamelCase) + if err != nil { + b.Fatalf(err.Error()) + } + if e == nil { + b.Fatalf("Encoder is nil") + } + ptr := unsafe.Pointer(&userV2Input) + example, _ := e(ptr, arri.NewEncodingContext(arri.KeyCasingCamelCase)) + fmt.Println("RESULT", string(example)) for i := 0; i < b.N; i++ { - arri.EncodeJSONCompiled(userV2Input, arri.KeyCasingCamelCase) + e(ptr, arri.NewEncodingContext(arri.KeyCasingCamelCase)) } } func BenchmarkStdEncodingAgainstV2(b *testing.B) { + result, _ := json.Marshal(userV2Input) + fmt.Println("RESULT", string(result)) for i := 0; i < b.N; i++ { _, err := json.Marshal(userV2Input) if err != nil { diff --git a/languages/go/go-server/encode_json_v2.go b/languages/go/go-server/encode_json_v2.go index de5366a6..36220233 100644 --- a/languages/go/go-server/encode_json_v2.go +++ b/languages/go/go-server/encode_json_v2.go @@ -7,15 +7,21 @@ import ( "unsafe" "github.com/iancoleman/strcase" - "github.com/viant/xunsafe" ) +type encoder = func(input unsafe.Pointer, context *Ctx) ([]byte, error) + +var encoderPool = sync.Pool{} + type Ctx struct { target []byte currentDepth uint32 maxDepth uint32 enumValues []string keyCasing KeyCasing + instancePath string + schemaPath string + hasKeys bool } func (c *Ctx) Reset() { @@ -24,69 +30,131 @@ func (c *Ctx) Reset() { c.maxDepth = 0 c.enumValues = []string{} c.keyCasing = KeyCasingCamelCase + c.hasKeys = false } -var encodingJsonV2SyncPool = sync.Pool{} +func (c *Ctx) SetEnumValues(vals []string) { + c.enumValues = vals +} -func newEncodingContext(keyCasing KeyCasing) *Ctx { - if v := encodingJsonV2SyncPool.Get(); v != nil { - e := v.(*Ctx) - e.Reset() - e.keyCasing = keyCasing - return e - } +func (c *Ctx) SetDepth(depth uint32) { + c.currentDepth = depth + +} + +func (c *Ctx) SetInstancePath(path string) { + c.instancePath = path +} + +func (c *Ctx) SetSchemaPath(path string) { + c.schemaPath = path +} + +func NewEncodingContext(keyCasing KeyCasing) *Ctx { return &Ctx{ target: []byte{}, currentDepth: 0, - maxDepth: 0, + maxDepth: 10000, enumValues: []string{}, keyCasing: keyCasing, } } -func EncodeJSONCompiled(input interface{}, keyCasing KeyCasing) ([]byte, error) { - target := []byte{} - ctx := newEncodingContext(keyCasing) - defer encodingJsonV2SyncPool.Put(ctx) - typ := reflect.TypeOf(input) - ptr := xunsafe.EnsurePointer(input) - structToJSON(ptr, typ, ctx) - return target, nil +func CompileJSONEncoder(input interface{}, keyCasing KeyCasing) (encoder, error) { + ctx := NewEncodingContext(keyCasing) + t := reflect.TypeOf(input) + return typeToJSONEncoder(t, ctx) } -func fieldToJSON(input unsafe.Pointer, field *xunsafe.Field, ctx *Ctx) error { - switch field.Type.Kind() { - case reflect.Int: - return intToJSON(input, field, ctx) +func typeToJSONEncoder(t reflect.Type, ctx *Ctx) (encoder, error) { + if ctx.currentDepth > ctx.maxDepth { + return nil, fmt.Errorf("max depth exceeded: %v", ctx.instancePath) + } + switch t.Kind() { case reflect.Int8: + return int8ToJSONEncoder(t, ctx) case reflect.Int16: case reflect.Int32: case reflect.Int64: - case reflect.Uint: + case reflect.Int: + return intToJSONEncoder(t, ctx) case reflect.Uint8: case reflect.Uint16: case reflect.Uint32: case reflect.Uint64: + case reflect.Uint: case reflect.Struct: + return structToJSONEncoder(t, ctx) } - return nil + return nil, nil +} + +func intToJSONEncoder(_ reflect.Type, _ *Ctx) (encoder, error) { + return func(input unsafe.Pointer, context *Ctx) ([]byte, error) { + context.target = append(context.target, fmt.Sprintf("\"%v\"", *(*int)(input))...) + return context.target, nil + }, nil } -func intToJSON(input unsafe.Pointer, field *xunsafe.Field, ctx *Ctx) error { - val := field.Value(input).(int) - ctx.target = append(ctx.target, fmt.Sprint(val)...) - return nil +func int8ToJSONEncoder(_ reflect.Type, _ *Ctx) (encoder, error) { + return func(input unsafe.Pointer, context *Ctx) ([]byte, error) { + context.target = append(context.target, fmt.Sprint(*(*int8)(input))...) + return context.target, nil + }, nil } -func structToJSON(input unsafe.Pointer, t reflect.Type, ctx *Ctx) error { - ctx.target = append(ctx.target, '{') +func structToJSONEncoder(t reflect.Type, ctx *Ctx) (encoder, error) { + encoders := []encoder{} + instancePath := ctx.instancePath + schemaPath := ctx.schemaPath + currentDepth := ctx.currentDepth + ctx.hasKeys = false + ctx.currentDepth++ for i := 0; i < t.NumField(); i++ { - field := xunsafe.FieldByIndex(t, i) - fieldPtr := field.ValuePointer(input) + field := t.Field(i) fieldName := strcase.ToLowerCamel(field.Name) - ctx.target = append(ctx.target, "\""+fieldName+"\":"...) - fieldToJSON(fieldPtr, field, ctx) + fmt.Println("FIELD", fieldName, "TYPE", field.Type.Kind()) + switch ctx.keyCasing { + case KeyCasingCamelCase: + case KeyCasingPascalCase: + case KeyCasingSnakeCase: + default: + msg := fmt.Sprintf("Unsupported key casing at %v expected one of [%v, %v, %v]", ctx.instancePath, KeyCasingCamelCase, KeyCasingPascalCase, KeyCasingSnakeCase) + panic(msg) + } + ctx.target = append(ctx.target, `"`+fieldName+`":`...) + ctx.instancePath = instancePath + "/" + fieldName + ctx.schemaPath = schemaPath + "/properties/" + fieldName + e, err := typeToJSONEncoder(field.Type, ctx) + if e == nil { + continue + } + if err != nil { + panic(err) + } + encoders = append(encoders, func(input unsafe.Pointer, context *Ctx) ([]byte, error) { + if context.hasKeys { + context.target = append(context.target, ',') + } + context.target = appendString(context.target, "\""+fieldName+"\":", false) + result, err := e(unsafe.Pointer(uintptr(input)+field.Offset), context) + if err == nil { + context.hasKeys = true + return nil, err + } + return result, err + }) } - ctx.target = append(ctx.target, '}') - return nil + ctx.currentDepth = currentDepth + ctx.schemaPath = schemaPath + ctx.instancePath = instancePath + ctx.hasKeys = false + return func(input unsafe.Pointer, context *Ctx) ([]byte, error) { + context.target = append(context.target, '{') + for _, e := range encoders { + e(input, context) + } + context.target = append(context.target, '}') + return context.target, nil + }, nil }