diff --git a/languages/go/go-server/decode_json.go b/languages/go/go-server/decode_json.go index a76ed74f..00dd6520 100644 --- a/languages/go/go-server/decode_json.go +++ b/languages/go/go-server/decode_json.go @@ -156,12 +156,19 @@ func typeFromJSON(data *gjson.Result, target reflect.Value, context *ValidationC case reflect.Bool: return boolFromJSON(data, target, context) case reflect.Struct: - if target.Type().Name() == "Time" { + t := target.Type() + if t.Name() == "Time" { return timestampFromJSON(data, target, context) } - if target.Type().Implements(reflect.TypeFor[JsonDecoder]()) { + if t.Implements(reflect.TypeFor[JsonDecoder]()) { return target.Interface().(JsonDecoder).DecodeJSON(data, target, context) } + if isOptionalType(t) { + return optionFromJson(data, target, context) + } + if isNullableType(t) { + return nullableFromJson(data, target, context) + } return structFromJSON(data, target, context) case reflect.Slice, reflect.Array: return arrayFromJSON(data, target, context) @@ -751,8 +758,8 @@ func optionFromJson(data *gjson.Result, target reflect.Value, context *Validatio } func nullableFromJson(data *gjson.Result, target reflect.Value, context *ValidationContext) bool { - if data.Type == gjson.Null { - return false + if data.Type == gjson.Null || !data.Exists() { + return true } val := target.FieldByName("Value") isSet := target.FieldByName("IsSet") diff --git a/languages/go/go-server/decode_json_test.go b/languages/go/go-server/decode_json_test.go index c23318db..9caeeaca 100644 --- a/languages/go/go-server/decode_json_test.go +++ b/languages/go/go-server/decode_json_test.go @@ -124,6 +124,70 @@ func TestDecodeObjectWithOptionalFields(t *testing.T) { } } +func TestDecodeObjectWithNullableFieldsAllNull(t *testing.T) { + input, inputErr := os.ReadFile("../../../tests/test-files/ObjectWithNullableFields_AllNull.json") + if inputErr != nil { + t.Errorf(inputErr.Error()) + return + } + result := objectWithNullableFields{} + expectedResult := objectWithNullableFields{} + err := arri.DecodeJSON(input, &result, arri.KeyCasingCamelCase) + if err != nil { + t.Errorf(err.Error()) + return + } + if !reflect.DeepEqual(result, expectedResult) { + t.Errorf(deepEqualErrString(result, expectedResult)) + return + } +} + +func TestDecodeObjectWithNullableFieldsNoNull(t *testing.T) { + input, inputErr := os.ReadFile("../../../tests/test-files/ObjectWithNullableFields_NoNull.json") + if inputErr != nil { + t.Errorf(inputErr.Error()) + return + } + result := objectWithNullableFields{} + expectedResult := objectWithNullableFields{ + String: arri.NotNull(""), + Boolean: arri.NotNull(true), + Timestamp: arri.NotNull(testDate), + Float32: arri.NotNull[float32](1.5), + Float64: arri.NotNull(1.5), + Int8: arri.NotNull[int8](1), + Uint8: arri.NotNull[uint8](1), + Int16: arri.NotNull[int16](10), + Uint16: arri.NotNull[uint16](10), + Int32: arri.NotNull[int32](100), + Uint32: arri.NotNull[uint32](100), + Int64: arri.NotNull[int64](1000), + Uint64: arri.NotNull[uint64](1000), + Enum: arri.NotNull("BAZ"), + Object: arri.NotNull(nestedObject{Id: "", Content: ""}), + Array: arri.NotNull([]bool{true, false, false}), + Record: arri.NotNull(arri.OrderedMapWithData(arri.Pair("A", true), arri.Pair("B", false))), + Discriminator: arri.NotNull(discriminator{C: &discriminatorC{Id: "", Name: "", Date: testDate}}), + Any: arri.NotNull[any](map[string]any{ + "message": "hello world", + }), + } + err := arri.DecodeJSON(input, &result, arri.KeyCasingCamelCase) + if err != nil { + t.Errorf(err.Error()) + return + } + if !reflect.DeepEqual(result, expectedResult) { + fmt.Println("RESULT_ANY", reflect.TypeOf(result.Any.Value)) + fmt.Printf("RESULT:\n%+v\n\n", result) + fmt.Printf("EXPECTED:\n%+v\n\n", expectedResult) + t.Errorf(deepEqualErrString(result, expectedResult)) + return + } + +} + type userWithPrivateFields struct { Id string Name string diff --git a/languages/go/go-server/procedures.go b/languages/go/go-server/procedures.go index 902faa40..13533664 100644 --- a/languages/go/go-server/procedures.go +++ b/languages/go/go-server/procedures.go @@ -39,17 +39,17 @@ func rpc[TParams, TResponse any, TContext Context](app *App[TContext], serviceNa rpcSchema.Http.Path = app.Options.RpcRoutePrefix + options.Path } if len(options.Description) > 0 { - rpcSchema.Http.Description = Some(options.Description) + rpcSchema.Http.Description.Set(options.Description) } if options.IsDeprecated { - rpcSchema.Http.IsDeprecated = Some(options.IsDeprecated) + rpcSchema.Http.IsDeprecated.Set(options.IsDeprecated) } params := handlerType.In(0) if params.Kind() != reflect.Struct { panic("rpc params must be a struct. pointers and other types are not allowed.") } paramsName := getModelName(rpcName, params.Name(), "Params") - hasParams := paramsName != "EmptyMessage" + hasParams := !isEmptyMessage(params) if hasParams { paramsDefContext := _NewTypeDefContext(app.Options.KeyCasing) paramsSchema, paramsSchemaErr := typeToTypeDef(params, paramsDefContext) @@ -59,15 +59,17 @@ func rpc[TParams, TResponse any, TContext Context](app *App[TContext], serviceNa if paramsSchema.Metadata.IsNone() { panic("Procedures cannot accept anonymous structs") } - rpcSchema.Http.Params = Some(paramsName) + rpcSchema.Http.Params.Set(paramsName) app.Definitions.Set(paramsName, *paramsSchema) + } else { + rpcSchema.Http.Params.Unset() } response := handlerType.Out(0) if response.Kind() == reflect.Ptr { response = response.Elem() } responseName := getModelName(rpcName, response.Name(), "Response") - hasResponse := !(responseName == "EmptyMessage" && response.PkgPath() == "arrirpc.com/arri") + hasResponse := !isEmptyMessage(response) if hasResponse { responseDefContext := _NewTypeDefContext(app.Options.KeyCasing) responseSchema, responseSchemaErr := typeToTypeDef(response, responseDefContext) @@ -77,8 +79,10 @@ func rpc[TParams, TResponse any, TContext Context](app *App[TContext], serviceNa if responseSchema.Metadata.IsNone() { panic("Procedures cannot return anonymous structs") } - rpcSchema.Http.Response = Some(responseName) + rpcSchema.Http.Response.Set(responseName) app.Definitions.Set(responseName, *responseSchema) + } else { + rpcSchema.Http.Response.Unset() } app.Procedures.Set(rpcName, *rpcSchema) onRequest := app.Options.OnRequest diff --git a/languages/go/go-server/procedures_sse.go b/languages/go/go-server/procedures_sse.go index 543a56cd..3bcbe42e 100644 --- a/languages/go/go-server/procedures_sse.go +++ b/languages/go/go-server/procedures_sse.go @@ -123,7 +123,7 @@ func eventStreamRpc[TParams, TResponse any, TContext Context](app *App[TContext] panic("rpc params must be a struct. pointers and other types are not allowed.") } paramName := getModelName(rpcName, params.Name(), "Params") - hasParams := !(paramName == "EmptyMessage" && params.PkgPath() == "arrirpc.com/arri") + hasParams := !(paramName == "EmptyMessage" && params.PkgPath() == "github.com/modiimedia/arri") if hasParams { paramsDefContext := _NewTypeDefContext(app.Options.KeyCasing) paramsSchema, paramsSchemaErr := typeToTypeDef(params, paramsDefContext) @@ -133,15 +133,17 @@ func eventStreamRpc[TParams, TResponse any, TContext Context](app *App[TContext] if paramsSchema.Metadata.IsNone() { panic("Procedures cannot accept anonymous structs") } - rpcSchema.Http.Params = Some(paramName) + rpcSchema.Http.Params.Set(paramName) app.Definitions.Set(paramName, *paramsSchema) + } else { + rpcSchema.Http.Params.Unset() } response := reflect.TypeFor[TResponse]() if response.Kind() == reflect.Ptr { response = response.Elem() } responseName := getModelName(rpcName, response.Name(), "Response") - hasResponse := !(responseName == "EmptyMessage" && response.PkgPath() == "arrirpc.com/arri") + hasResponse := !(responseName == "EmptyMessage" && response.PkgPath() == "github.com/modiimedia/arri") if hasResponse { responseDefContext := _NewTypeDefContext(app.Options.KeyCasing) responseSchema, responseSchemaErr := typeToTypeDef(response, responseDefContext) @@ -151,8 +153,10 @@ func eventStreamRpc[TParams, TResponse any, TContext Context](app *App[TContext] if responseSchema.Metadata.IsNone() { panic("Procedures cannot return anonymous structs") } - rpcSchema.Http.Response = Some(responseName) + rpcSchema.Http.Response.Set(responseName) app.Definitions.Set(responseName, *responseSchema) + } else { + rpcSchema.Http.Response.Unset() } app.Procedures.Set(rpcName, *rpcSchema) onRequest, _, onAfterResponse, onError := getHooks(app) diff --git a/languages/go/go-server/reflect_helpers.go b/languages/go/go-server/reflect_helpers.go index 3276fd14..47afe518 100644 --- a/languages/go/go-server/reflect_helpers.go +++ b/languages/go/go-server/reflect_helpers.go @@ -64,3 +64,7 @@ func getSerialKey(field *reflect.StructField, keyCasing KeyCasing) string { } return strcase.ToLowerCamel(field.Name) } + +func isEmptyMessage(t reflect.Type) bool { + return t.Name() == "EmptyMessage" && t.PkgPath() == "github.com/modiimedia/arri" +} diff --git a/languages/go/go-server/type_helpers.go b/languages/go/go-server/type_helpers.go index 76f98af6..90ef5bb7 100644 --- a/languages/go/go-server/type_helpers.go +++ b/languages/go/go-server/type_helpers.go @@ -4,6 +4,8 @@ import ( "encoding/json" "fmt" "reflect" + + "github.com/tidwall/gjson" ) type DiscriminatorKey struct{} @@ -279,6 +281,56 @@ func (m OrderedMap[T]) EncodeJSON(keyCasing KeyCasing) ([]byte, error) { return result, nil } +func (m OrderedMap[T]) String() string { + result := "OrderedMap[" + for i := 0; i < len(m.keys); i++ { + if i > 0 { + result += " " + } + key := m.keys[i] + val := m.values[i] + result += key + result += ":" + result += fmt.Sprintf("%+v", val) + } + result += "]" + return result +} + +func (m OrderedMap[T]) DecodeJSON(data *gjson.Result, target reflect.Value, context *ValidationContext) bool { + switch data.Type { + case gjson.Null, gjson.False, gjson.String, gjson.Number: + *context.Errors = append(*context.Errors, newValidationErrorItem("expected object", context.InstancePath, context.SchemaPath)) + return false + } + valuesResult := []T{} + keysResult := []string{} + gjsonMap := data.Map() + instancePath := context.InstancePath + schemaPath := context.SchemaPath + context.SchemaPath = context.SchemaPath + "/values" + context.CurrentDepth++ + for key, value := range gjsonMap { + valueTarget := reflect.New(reflect.TypeFor[T]()) + context.InstancePath = instancePath + "/" + key + valueResult := typeFromJSON(&value, valueTarget, context) + if !valueResult { + return false + } + valuesResult = append(valuesResult, *valueTarget.Interface().(*T)) + keysResult = append(keysResult, key) + } + context.InstancePath = instancePath + context.SchemaPath = schemaPath + context.CurrentDepth-- + result := OrderedMap[T]{ + keys: keysResult, + values: valuesResult, + } + target.Set(reflect.ValueOf(result)) + return true +} + func (m OrderedMap[T]) ToTypeDef(keyCasing KeyCasing) (*TypeDef, error) { subDef, subDefErr := TypeToTypeDef(reflect.TypeFor[T](), keyCasing) if subDefErr != nil { diff --git a/tests/servers/go/main.go b/tests/servers/go/main.go index 9560db01..3e8b3439 100644 --- a/tests/servers/go/main.go +++ b/tests/servers/go/main.go @@ -504,7 +504,9 @@ func StreamTenEventsThenEnd(_ arri.EmptyMessage, controller arri.SseController[C select { case <-t.C: msgCount++ - controller.Push(ChatMessage{}) + controller.Push(ChatMessage{ + ChatMessageText: &ChatMessageText{}, + }) if msgCount > 10 { panic("Message count exceeded 10. This means the ticker was not properly cleaned up.") }