Skip to content

Commit

Permalink
get all tests passing
Browse files Browse the repository at this point in the history
  • Loading branch information
joshmossas committed Nov 11, 2024
1 parent a067d61 commit 89badea
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 15 deletions.
15 changes: 11 additions & 4 deletions languages/go/go-server/decode_json.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,12 +156,19 @@ func typeFromJSON(data *gjson.Result, target reflect.Value, context *ValidationC
case reflect.Bool:
return boolFromJSON(data, target, context)
case reflect.Struct:
if target.Type().Name() == "Time" {
t := target.Type()
if t.Name() == "Time" {
return timestampFromJSON(data, target, context)
}
if target.Type().Implements(reflect.TypeFor[JsonDecoder]()) {
if t.Implements(reflect.TypeFor[JsonDecoder]()) {
return target.Interface().(JsonDecoder).DecodeJSON(data, target, context)
}
if isOptionalType(t) {
return optionFromJson(data, target, context)
}
if isNullableType(t) {
return nullableFromJson(data, target, context)
}
return structFromJSON(data, target, context)
case reflect.Slice, reflect.Array:
return arrayFromJSON(data, target, context)
Expand Down Expand Up @@ -751,8 +758,8 @@ func optionFromJson(data *gjson.Result, target reflect.Value, context *Validatio
}

func nullableFromJson(data *gjson.Result, target reflect.Value, context *ValidationContext) bool {
if data.Type == gjson.Null {
return false
if data.Type == gjson.Null || !data.Exists() {
return true
}
val := target.FieldByName("Value")
isSet := target.FieldByName("IsSet")
Expand Down
64 changes: 64 additions & 0 deletions languages/go/go-server/decode_json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,70 @@ func TestDecodeObjectWithOptionalFields(t *testing.T) {
}
}

func TestDecodeObjectWithNullableFieldsAllNull(t *testing.T) {
input, inputErr := os.ReadFile("../../../tests/test-files/ObjectWithNullableFields_AllNull.json")
if inputErr != nil {
t.Errorf(inputErr.Error())
return
}
result := objectWithNullableFields{}
expectedResult := objectWithNullableFields{}
err := arri.DecodeJSON(input, &result, arri.KeyCasingCamelCase)
if err != nil {
t.Errorf(err.Error())
return
}
if !reflect.DeepEqual(result, expectedResult) {
t.Errorf(deepEqualErrString(result, expectedResult))
return
}
}

func TestDecodeObjectWithNullableFieldsNoNull(t *testing.T) {
input, inputErr := os.ReadFile("../../../tests/test-files/ObjectWithNullableFields_NoNull.json")
if inputErr != nil {
t.Errorf(inputErr.Error())
return
}
result := objectWithNullableFields{}
expectedResult := objectWithNullableFields{
String: arri.NotNull(""),
Boolean: arri.NotNull(true),
Timestamp: arri.NotNull(testDate),
Float32: arri.NotNull[float32](1.5),
Float64: arri.NotNull(1.5),
Int8: arri.NotNull[int8](1),
Uint8: arri.NotNull[uint8](1),
Int16: arri.NotNull[int16](10),
Uint16: arri.NotNull[uint16](10),
Int32: arri.NotNull[int32](100),
Uint32: arri.NotNull[uint32](100),
Int64: arri.NotNull[int64](1000),
Uint64: arri.NotNull[uint64](1000),
Enum: arri.NotNull("BAZ"),
Object: arri.NotNull(nestedObject{Id: "", Content: ""}),
Array: arri.NotNull([]bool{true, false, false}),
Record: arri.NotNull(arri.OrderedMapWithData(arri.Pair("A", true), arri.Pair("B", false))),
Discriminator: arri.NotNull(discriminator{C: &discriminatorC{Id: "", Name: "", Date: testDate}}),
Any: arri.NotNull[any](map[string]any{
"message": "hello world",
}),
}
err := arri.DecodeJSON(input, &result, arri.KeyCasingCamelCase)
if err != nil {
t.Errorf(err.Error())
return
}
if !reflect.DeepEqual(result, expectedResult) {
fmt.Println("RESULT_ANY", reflect.TypeOf(result.Any.Value))
fmt.Printf("RESULT:\n%+v\n\n", result)
fmt.Printf("EXPECTED:\n%+v\n\n", expectedResult)
t.Errorf(deepEqualErrString(result, expectedResult))
return
}

}

type userWithPrivateFields struct {
Id string
Name string
Expand Down
16 changes: 10 additions & 6 deletions languages/go/go-server/procedures.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,17 @@ func rpc[TParams, TResponse any, TContext Context](app *App[TContext], serviceNa
rpcSchema.Http.Path = app.Options.RpcRoutePrefix + options.Path
}
if len(options.Description) > 0 {
rpcSchema.Http.Description = Some(options.Description)
rpcSchema.Http.Description.Set(options.Description)
}
if options.IsDeprecated {
rpcSchema.Http.IsDeprecated = Some(options.IsDeprecated)
rpcSchema.Http.IsDeprecated.Set(options.IsDeprecated)
}
params := handlerType.In(0)
if params.Kind() != reflect.Struct {
panic("rpc params must be a struct. pointers and other types are not allowed.")
}
paramsName := getModelName(rpcName, params.Name(), "Params")
hasParams := paramsName != "EmptyMessage"
hasParams := !isEmptyMessage(params)
if hasParams {
paramsDefContext := _NewTypeDefContext(app.Options.KeyCasing)
paramsSchema, paramsSchemaErr := typeToTypeDef(params, paramsDefContext)
Expand All @@ -59,15 +59,17 @@ func rpc[TParams, TResponse any, TContext Context](app *App[TContext], serviceNa
if paramsSchema.Metadata.IsNone() {
panic("Procedures cannot accept anonymous structs")
}
rpcSchema.Http.Params = Some(paramsName)
rpcSchema.Http.Params.Set(paramsName)
app.Definitions.Set(paramsName, *paramsSchema)
} else {
rpcSchema.Http.Params.Unset()
}
response := handlerType.Out(0)
if response.Kind() == reflect.Ptr {
response = response.Elem()
}
responseName := getModelName(rpcName, response.Name(), "Response")
hasResponse := !(responseName == "EmptyMessage" && response.PkgPath() == "arrirpc.com/arri")
hasResponse := !isEmptyMessage(response)
if hasResponse {
responseDefContext := _NewTypeDefContext(app.Options.KeyCasing)
responseSchema, responseSchemaErr := typeToTypeDef(response, responseDefContext)
Expand All @@ -77,8 +79,10 @@ func rpc[TParams, TResponse any, TContext Context](app *App[TContext], serviceNa
if responseSchema.Metadata.IsNone() {
panic("Procedures cannot return anonymous structs")
}
rpcSchema.Http.Response = Some(responseName)
rpcSchema.Http.Response.Set(responseName)
app.Definitions.Set(responseName, *responseSchema)
} else {
rpcSchema.Http.Response.Unset()
}
app.Procedures.Set(rpcName, *rpcSchema)
onRequest := app.Options.OnRequest
Expand Down
12 changes: 8 additions & 4 deletions languages/go/go-server/procedures_sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ func eventStreamRpc[TParams, TResponse any, TContext Context](app *App[TContext]
panic("rpc params must be a struct. pointers and other types are not allowed.")
}
paramName := getModelName(rpcName, params.Name(), "Params")
hasParams := !(paramName == "EmptyMessage" && params.PkgPath() == "arrirpc.com/arri")
hasParams := !(paramName == "EmptyMessage" && params.PkgPath() == "github.com/modiimedia/arri")
if hasParams {
paramsDefContext := _NewTypeDefContext(app.Options.KeyCasing)
paramsSchema, paramsSchemaErr := typeToTypeDef(params, paramsDefContext)
Expand All @@ -133,15 +133,17 @@ func eventStreamRpc[TParams, TResponse any, TContext Context](app *App[TContext]
if paramsSchema.Metadata.IsNone() {
panic("Procedures cannot accept anonymous structs")
}
rpcSchema.Http.Params = Some(paramName)
rpcSchema.Http.Params.Set(paramName)
app.Definitions.Set(paramName, *paramsSchema)
} else {
rpcSchema.Http.Params.Unset()
}
response := reflect.TypeFor[TResponse]()
if response.Kind() == reflect.Ptr {
response = response.Elem()
}
responseName := getModelName(rpcName, response.Name(), "Response")
hasResponse := !(responseName == "EmptyMessage" && response.PkgPath() == "arrirpc.com/arri")
hasResponse := !(responseName == "EmptyMessage" && response.PkgPath() == "github.com/modiimedia/arri")
if hasResponse {
responseDefContext := _NewTypeDefContext(app.Options.KeyCasing)
responseSchema, responseSchemaErr := typeToTypeDef(response, responseDefContext)
Expand All @@ -151,8 +153,10 @@ func eventStreamRpc[TParams, TResponse any, TContext Context](app *App[TContext]
if responseSchema.Metadata.IsNone() {
panic("Procedures cannot return anonymous structs")
}
rpcSchema.Http.Response = Some(responseName)
rpcSchema.Http.Response.Set(responseName)
app.Definitions.Set(responseName, *responseSchema)
} else {
rpcSchema.Http.Response.Unset()
}
app.Procedures.Set(rpcName, *rpcSchema)
onRequest, _, onAfterResponse, onError := getHooks(app)
Expand Down
4 changes: 4 additions & 0 deletions languages/go/go-server/reflect_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,7 @@ func getSerialKey(field *reflect.StructField, keyCasing KeyCasing) string {
}
return strcase.ToLowerCamel(field.Name)
}

func isEmptyMessage(t reflect.Type) bool {
return t.Name() == "EmptyMessage" && t.PkgPath() == "github.com/modiimedia/arri"
}
52 changes: 52 additions & 0 deletions languages/go/go-server/type_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import (
"encoding/json"
"fmt"
"reflect"

"github.com/tidwall/gjson"
)

type DiscriminatorKey struct{}
Expand Down Expand Up @@ -279,6 +281,56 @@ func (m OrderedMap[T]) EncodeJSON(keyCasing KeyCasing) ([]byte, error) {
return result, nil
}

func (m OrderedMap[T]) String() string {
result := "OrderedMap["
for i := 0; i < len(m.keys); i++ {
if i > 0 {
result += " "
}
key := m.keys[i]
val := m.values[i]
result += key
result += ":"
result += fmt.Sprintf("%+v", val)
}
result += "]"
return result
}

func (m OrderedMap[T]) DecodeJSON(data *gjson.Result, target reflect.Value, context *ValidationContext) bool {
switch data.Type {
case gjson.Null, gjson.False, gjson.String, gjson.Number:
*context.Errors = append(*context.Errors, newValidationErrorItem("expected object", context.InstancePath, context.SchemaPath))
return false
}
valuesResult := []T{}
keysResult := []string{}
gjsonMap := data.Map()
instancePath := context.InstancePath
schemaPath := context.SchemaPath
context.SchemaPath = context.SchemaPath + "/values"
context.CurrentDepth++
for key, value := range gjsonMap {
valueTarget := reflect.New(reflect.TypeFor[T]())
context.InstancePath = instancePath + "/" + key
valueResult := typeFromJSON(&value, valueTarget, context)
if !valueResult {
return false
}
valuesResult = append(valuesResult, *valueTarget.Interface().(*T))
keysResult = append(keysResult, key)
}
context.InstancePath = instancePath
context.SchemaPath = schemaPath
context.CurrentDepth--
result := OrderedMap[T]{
keys: keysResult,
values: valuesResult,
}
target.Set(reflect.ValueOf(result))
return true
}

func (m OrderedMap[T]) ToTypeDef(keyCasing KeyCasing) (*TypeDef, error) {
subDef, subDefErr := TypeToTypeDef(reflect.TypeFor[T](), keyCasing)
if subDefErr != nil {
Expand Down
4 changes: 3 additions & 1 deletion tests/servers/go/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,9 @@ func StreamTenEventsThenEnd(_ arri.EmptyMessage, controller arri.SseController[C
select {
case <-t.C:
msgCount++
controller.Push(ChatMessage{})
controller.Push(ChatMessage{
ChatMessageText: &ChatMessageText{},
})
if msgCount > 10 {
panic("Message count exceeded 10. This means the ticker was not properly cleaned up.")
}
Expand Down

0 comments on commit 89badea

Please sign in to comment.