From 96800046d49ef30bf2b597038f683aea7910f8ad Mon Sep 17 00:00:00 2001 From: Alan Colon Date: Tue, 28 Jul 2020 01:33:40 +0000 Subject: [PATCH] Further flush out ability to auto-bind to functions and types, including recursion. Fix code coverage, support additional functions. Add bind examples Avoid errors where value IsZero Add type tag, nested resolvers on Bind Extend type Don't call .Type() on zero values --- bind.go | 251 ++++++++++++++++++++++++++++++++++ bind_test.go | 241 ++++++++++++++++++++++++++++++++ examples/bind-complex/main.go | 85 ++++++++++++ examples/bind-simple/main.go | 43 ++++++ executor.go | 2 +- util.go | 129 +++++++++++++---- 6 files changed, 724 insertions(+), 27 deletions(-) create mode 100644 bind.go create mode 100644 bind_test.go create mode 100644 examples/bind-complex/main.go create mode 100644 examples/bind-simple/main.go diff --git a/bind.go b/bind.go new file mode 100644 index 00000000..d8c8dfcd --- /dev/null +++ b/bind.go @@ -0,0 +1,251 @@ +package graphql + +import ( + "context" + "encoding/json" + "fmt" + "reflect" +) + +var ctxType = reflect.TypeOf((*context.Context)(nil)).Elem() +var errType = reflect.TypeOf((*error)(nil)).Elem() + +/* + Bind will create a Field around a function formatted a certain way, or any value. + + The input parameters can be, in any order, + - context.Context, or *context.Context (optional) + - An input struct, or pointer (optional) + + The output parameters can be, in any order, + - A primitive, an output struct, or pointer (required for use in schema) + - error (optional) + + Input or output types provided will be automatically bound using BindType. +*/ +func Bind(bindTo interface{}, additionalFields ...Fields) *Field { + combinedAdditionalFields := MergeFields(additionalFields...) + val := reflect.ValueOf(bindTo) + tipe := reflect.TypeOf(bindTo) + if tipe.Kind() == reflect.Func { + in := tipe.NumIn() + out := tipe.NumOut() + + var ctxIn *int + var inputIn *int + + var errOut *int + var outputOut *int + + queryArgs := FieldConfigArgument{} + + if in > 2 { + panic(fmt.Sprintf("Mismatch on number of inputs. Expected 0, 1, or 2. got %d.", tipe.NumIn())) + } + + if out > 2 { + panic(fmt.Sprintf("Mismatch on number of outputs. Expected 0, 1, or 2, got %d.", tipe.NumOut())) + } + + // inTypes := make([]reflect.Type, in) + // outTypes := make([]reflect.Type, out) + + for i := 0; i < in; i++ { + t := tipe.In(i) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + switch t { + case ctxType: + if ctxIn != nil { + panic(fmt.Sprintf("Unexpected multiple *context.Context inputs.")) + } + ctxIn = intP(i) + default: + if inputIn != nil { + panic(fmt.Sprintf("Unexpected multiple inputs.")) + } + inputType := tipe.In(i) + if inputType.Kind() == reflect.Ptr { + inputType = inputType.Elem() + } + inputFields := BindFields(reflect.New(inputType).Interface()) + for key, inputField := range inputFields { + queryArgs[key] = &ArgumentConfig{ + Type: inputField.Type, + } + } + + inputIn = intP(i) + } + } + + for i := 0; i < out; i++ { + t := tipe.Out(i) + switch t.String() { + case errType.String(): + if errOut != nil { + panic(fmt.Sprintf("Unexpected multiple error outputs")) + } + errOut = intP(i) + default: + if outputOut != nil { + panic(fmt.Sprintf("Unexpected multiple outputs")) + } + outputOut = intP(i) + } + } + + resolve := func(p ResolveParams) (output interface{}, err error) { + inputs := make([]reflect.Value, in) + if ctxIn != nil { + isPtr := tipe.In(*ctxIn).Kind() == reflect.Ptr + if isPtr { + if p.Context == nil { + inputs[*ctxIn] = reflect.New(ctxType) + } else { + inputs[*ctxIn] = reflect.ValueOf(&p.Context) + } + } else { + if p.Context == nil { + inputs[*ctxIn] = reflect.New(ctxType).Elem() + } else { + inputs[*ctxIn] = reflect.ValueOf(p.Context).Convert(ctxType).Elem() + } + } + } + if inputIn != nil { + var inputType, inputBaseType, sourceType, sourceBaseType reflect.Type + sourceVal := reflect.ValueOf(p.Source) + sourceExists := !sourceVal.IsZero() + if sourceExists { + sourceType = sourceVal.Type() + if sourceType.Kind() == reflect.Ptr { + sourceBaseType = sourceType.Elem() + } else { + sourceBaseType = sourceType + } + } + inputType = tipe.In(*inputIn) + isPtr := tipe.In(*inputIn).Kind() == reflect.Ptr + if isPtr { + inputBaseType = inputType.Elem() + } else { + inputBaseType = inputType + } + var input interface{} + if sourceExists && sourceBaseType.AssignableTo(inputBaseType) { + input = sourceVal.Interface() + } else { + input = reflect.New(inputBaseType).Interface() + j, err := json.Marshal(p.Args) + if err == nil { + err = json.Unmarshal(j, &input) + } + if err != nil { + return nil, err + } + } + + inputs[*inputIn], err = convertValue(reflect.ValueOf(input), inputType) + if err != nil { + return nil, err + } + } + results := val.Call(inputs) + if errOut != nil { + val := results[*errOut].Interface() + if val != nil { + err = val.(error) + } + if err != nil { + return output, err + } + } + if outputOut != nil { + var val reflect.Value + val, err = convertValue(results[*outputOut], tipe.Out(*outputOut)) + if err != nil { + return nil, err + } + if !val.IsZero() { + output = val.Interface() + } + } + return output, err + } + + var outputType Output + if outputOut != nil { + outputType = BindType(tipe.Out(*outputOut)) + extendType(outputType, combinedAdditionalFields) + } + + field := &Field{ + Type: outputType, + Resolve: resolve, + Args: queryArgs, + } + + return field + } else if tipe.Kind() == reflect.Struct { + fieldType := BindType(reflect.TypeOf(bindTo)) + extendType(fieldType, combinedAdditionalFields) + field := &Field{ + Type: fieldType, + Resolve: func(p ResolveParams) (data interface{}, err error) { + return bindTo, nil + }, + } + return field + } else { + if len(additionalFields) > 0 { + panic("Cannot add field resolvers to a scalar type.") + } + return &Field{ + Type: getGraphType(tipe), + Resolve: func(p ResolveParams) (data interface{}, err error) { + return bindTo, nil + }, + } + } +} + +func extendType(t Type, fields Fields) { + switch t.(type) { + case *Object: + object := t.(*Object) + for fieldName, fieldConfig := range fields { + object.AddFieldConfig(fieldName, fieldConfig) + } + return + case *List: + list := t.(*List) + extendType(list.OfType, fields) + return + } +} + +func convertValue(value reflect.Value, targetType reflect.Type) (ret reflect.Value, err error) { + if !value.IsValid() || value.IsZero() { + return reflect.Zero(targetType), nil + } + if value.Type().Kind() == reflect.Ptr { + if targetType.Kind() == reflect.Ptr { + return value, nil + } else { + return value.Elem(), nil + } + } else { + if targetType.Kind() == reflect.Ptr { + // Will throw an informative error + return value.Convert(targetType), nil + } else { + return value, nil + } + } +} + +func intP(i int) *int { + return &i +} diff --git a/bind_test.go b/bind_test.go new file mode 100644 index 00000000..f2fa89cb --- /dev/null +++ b/bind_test.go @@ -0,0 +1,241 @@ +package graphql_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "strings" + "testing" + "time" + + "github.com/graphql-go/graphql" +) + +type HelloOutput struct { + Message string `json:"message"` +} + +func Hello(ctx *context.Context) (output *HelloOutput, err error) { + output = &HelloOutput{ + Message: "Hello World", + } + return output, nil +} + +func Hellos() []HelloOutput { + return []HelloOutput{ + { + Message: "Hello One", + }, + { + Message: "Hello Two", + }, + } +} + +func Upper(ctx *context.Context, source HelloOutput) string { + return strings.ToUpper(source.Message) +} + +type GreetingInput struct { + Name string `json:"name"` +} + +type GreetingOutput struct { + Message string `json:"message"` + Timestamp time.Time `json:"timestamp"` +} + +func GreetingPtr(ctx *context.Context, input *GreetingInput) (output *GreetingOutput, err error) { + return &GreetingOutput{ + Message: fmt.Sprintf("Hello %s.", input.Name), + Timestamp: time.Now(), + }, nil +} + +func Greeting(ctx context.Context, input GreetingInput) (output GreetingOutput, err error) { + return GreetingOutput{ + Message: fmt.Sprintf("Hello %s.", input.Name), + Timestamp: time.Now(), + }, nil +} + +type FriendRecur struct { + Name string `json:"name"` + Friends []FriendRecur `json:"friends"` +} + +func friends(ctx *context.Context) (output *FriendRecur) { + recursiveFriendRecur := FriendRecur{ + Name: "Recursion", + } + recursiveFriendRecur.Friends = make([]FriendRecur, 2) + recursiveFriendRecur.Friends[0] = recursiveFriendRecur + recursiveFriendRecur.Friends[1] = recursiveFriendRecur + + return &FriendRecur{ + Name: "Alan", + Friends: []FriendRecur{ + recursiveFriendRecur, + { + Name: "Samantha", + Friends: []FriendRecur{ + { + Name: "Olivia", + }, + { + Name: "Eric", + }, + }, + }, + { + Name: "Brian", + Friends: []FriendRecur{ + { + Name: "Windy", + }, + { + Name: "Kevin", + }, + }, + }, + { + Name: "Kevin", + Friends: []FriendRecur{ + { + Name: "Sergei", + }, + { + Name: "Michael", + }, + }, + }, + }, + } +} + +func TestBindHappyPath(t *testing.T) { + // Schema + fields := graphql.Fields{ + "hello": graphql.Bind(Hello), + "hellos": graphql.Bind(Hellos, graphql.Fields{ + "upper": graphql.Bind(Upper), + }), + "greeting": graphql.Bind(Greeting), + "greetingPtr": graphql.Bind(GreetingPtr), + "friends": graphql.Bind(friends), + "string": graphql.Bind("Hello World"), + "number": graphql.Bind(12345), + "float": graphql.Bind(123.45), + "anonymous": graphql.Bind(struct { + SomeField string `json:"someField"` + }{ + SomeField: "Some Value", + }), + "simpleFunc": graphql.Bind(func() string { + return "Hello World" + }), + } + rootQuery := graphql.ObjectConfig{Name: "RootQuery", Fields: fields} + schemaConfig := graphql.SchemaConfig{Query: graphql.NewObject(rootQuery)} + schema, err := graphql.NewSchema(schemaConfig) + if err != nil { + log.Fatalf("failed to create new schema, error: %v", err) + } + + // Query + query := ` + { + hello { + message + upper + } + hellos { + message + upper + } + greeting(name:"Alan") { + message + timestamp + } + greetingPtr(name:"Alan") { + message + timestamp + } + friends { + name + friends { + name + friends { + name + friends { + name + friends { + name + } + } + } + } + } + string + number + float + anonymous { + someField + } + simpleFunc + } + ` + params := graphql.Params{Schema: schema, RequestString: query} + r := graphql.Do(params) + if len(r.Errors) > 0 { + t.Errorf("failed to execute graphql operation, errors: %+v", r.Errors) + } + json, err := json.MarshalIndent(r.Data, "", " ") + fmt.Println(string(json)) +} + +func TestBindPanicImproperInput(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected Bind to panic due to improper function signature") + } + }() + graphql.Bind(func(a, b, c string) {}) +} + +func TestBindPanicImproperOutput(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected Bind to panic due to improper function signature") + } + }() + graphql.Bind(func() (string, string) { return "Hello", "World" }) +} + +func TestBindWithRuntimeError(t *testing.T) { + rootQuery := graphql.ObjectConfig{Name: "RootQuery", Fields: graphql.Fields{ + "throwError": graphql.Bind(func() (string, error) { + return "", errors.New("Some Error") + }), + }} + schemaConfig := graphql.SchemaConfig{Query: graphql.NewObject(rootQuery)} + schema, err := graphql.NewSchema(schemaConfig) + if err != nil { + log.Fatalf("failed to create new schema, error: %v", err) + } + + // Query + query := ` + { + throwError + } + ` + params := graphql.Params{Schema: schema, RequestString: query} + r := graphql.Do(params) + if len(r.Errors) == 0 { + t.Error("Expected error") + } +} diff --git a/examples/bind-complex/main.go b/examples/bind-complex/main.go new file mode 100644 index 00000000..0bed337b --- /dev/null +++ b/examples/bind-complex/main.go @@ -0,0 +1,85 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + + "github.com/graphql-go/graphql" +) + +var people = []Person{ + { + Name: "Alan", + Friends: []Person{ + { + Name: "Nadeem", + Friends: []Person{ + { + Name: "Heidi", + }, + }, + }, + }, + }, +} + +type Person struct { + Name string `json:"name"` + Friends []Person `json:"friends"` +} + +type GetPersonInput struct { + Name string `json:"name"` +} + +type GetPersonOutput struct { + Person +} + +func GetPerson(ctx context.Context, input GetPersonInput) (*GetPersonOutput, error) { + for _, person := range people { + if person.Name == input.Name { + return &GetPersonOutput{ + Person: person, + }, nil + } + } + return nil, errors.New("Could not find person.") +} + +func main() { + rootQuery := graphql.ObjectConfig{Name: "RootQuery", Fields: graphql.Fields{ + "person": graphql.Bind(GetPerson), + }} + + schemaConfig := graphql.SchemaConfig{Query: graphql.NewObject(rootQuery)} + schema, err := graphql.NewSchema(schemaConfig) + if err != nil { + log.Fatalf("failed to create new schema, error: %v", err) + } + + // Query + query := ` + { + person(name: "Alan") { + name + friends { + name + friends { + name + } + } + } + } + ` + params := graphql.Params{Schema: schema, RequestString: query} + r := graphql.Do(params) + if len(r.Errors) > 0 { + log.Fatalf("failed to execute graphql operation, errors: %+v", r.Errors) + } + rJSON, _ := json.Marshal(r) + fmt.Printf("%s \n", rJSON) +} diff --git a/examples/bind-simple/main.go b/examples/bind-simple/main.go new file mode 100644 index 00000000..ada9aee5 --- /dev/null +++ b/examples/bind-simple/main.go @@ -0,0 +1,43 @@ +package main + +import ( + "encoding/json" + "fmt" + "log" + + "github.com/graphql-go/graphql" +) + +type GreetingInput struct { + Name string `json:"name"` +} + +func Greeting(input GreetingInput) string { + return fmt.Sprintf("Hello %s", input.Name) +} + +func main() { + rootQuery := graphql.ObjectConfig{Name: "RootQuery", Fields: graphql.Fields{ + "greeting": graphql.Bind(Greeting), + }} + + schemaConfig := graphql.SchemaConfig{Query: graphql.NewObject(rootQuery)} + schema, err := graphql.NewSchema(schemaConfig) + if err != nil { + log.Fatalf("failed to create new schema, error: %v", err) + } + + // Query + query := ` + { + greeting(name: "Alan") + } + ` + params := graphql.Params{Schema: schema, RequestString: query} + r := graphql.Do(params) + if len(r.Errors) > 0 { + log.Fatalf("failed to execute graphql operation, errors: %+v", r.Errors) + } + rJSON, _ := json.Marshal(r) + fmt.Printf("%s \n", rJSON) +} diff --git a/executor.go b/executor.go index 7440ae21..096f6fbc 100644 --- a/executor.go +++ b/executor.go @@ -943,7 +943,7 @@ func DefaultResolveFn(p ResolveParams) (interface{}, error) { } // try to resolve p.Source as a struct - if sourceVal.IsValid() && sourceVal.Type().Kind() == reflect.Ptr { + if sourceVal.IsValid() && !sourceVal.IsZero() && sourceVal.Type().Kind() == reflect.Ptr { sourceVal = sourceVal.Elem() } if !sourceVal.IsValid() { diff --git a/util.go b/util.go index ae374c33..a48ebca2 100644 --- a/util.go +++ b/util.go @@ -8,13 +8,88 @@ import ( ) const TAG = "json" +const TYPETAG = "graphql" + +var boundTypes = map[string]*Object{} +var anonTypes = 0 + +func MergeFields(fieldses ...Fields) (ret Fields) { + ret = Fields{} + for _, fields := range fieldses { + for key, field := range fields { + if _, ok := ret[key]; ok { + panic(fmt.Sprintf("Dupliate field: %s", key)) + } + ret[key] = field + } + } + return ret +} + +func BindType(tipe reflect.Type) Type { + if tipe.Kind() == reflect.Ptr { + tipe = tipe.Elem() + } + + kind := tipe.Kind() + switch kind { + case reflect.String: + return String + case reflect.Int, reflect.Int8, reflect.Int32, reflect.Int64: + return Int + case reflect.Float32, reflect.Float64: + return Float + case reflect.Bool: + return Boolean + case reflect.Slice: + return getGraphList(tipe) + } + + typeName := safeName(tipe) + object, ok := boundTypes[typeName] + if !ok { + // Allows for recursion + object = &Object{} + boundTypes[typeName] = object + *object = *NewObject(ObjectConfig{ + Name: typeName, + Fields: BindFields(reflect.New(tipe).Interface()), + }) + } + + return object +} + +func safeName(tipe reflect.Type) string { + name := fmt.Sprint(tipe) + if strings.HasPrefix(name, "struct ") { + anonTypes++ + name = fmt.Sprintf("Anon%d", anonTypes) + } else { + name = strings.Replace(fmt.Sprint(tipe), ".", "_", -1) + } + return name +} + +func getType(typeTag string) Output { + switch strings.ToLower(typeTag) { + case "int": + return Int + case "float": + return Float + case "string": + return String + case "boolean": + return Boolean + case "id": + return ID + case "datetime": + return DateTime + default: + panic(fmt.Sprintf("Unsupported graphql type: %s", typeTag)) + } +} -// can't take recursive slice type -// e.g -// type Person struct{ -// Friends []Person -// } -// it will throw panic stack-overflow func BindFields(obj interface{}) Fields { t := reflect.TypeOf(obj) v := reflect.ValueOf(obj) @@ -33,14 +108,17 @@ func BindFields(obj interface{}) Fields { continue } + typeTag := field.Tag.Get(TYPETAG) + fieldType := field.Type if fieldType.Kind() == reflect.Ptr { fieldType = fieldType.Elem() } - var graphType Output - if fieldType.Kind() == reflect.Struct { + if typeTag != "" { + graphType = getType(typeTag) + } else if fieldType.Kind() == reflect.Struct { itf := v.Field(i).Interface() if _, ok := itf.(encoding.TextMarshaler); ok { fieldType = reflect.TypeOf("") @@ -53,10 +131,7 @@ func BindFields(obj interface{}) Fields { fields = appendFields(fields, structFields) continue } else { - graphType = NewObject(ObjectConfig{ - Name: tag, - Fields: structFields, - }) + graphType = BindType(fieldType) } } @@ -110,11 +185,7 @@ func getGraphList(tipe reflect.Type) *List { } // finally bind object t := reflect.New(tipe.Elem()) - name := strings.Replace(fmt.Sprint(tipe.Elem()), ".", "_", -1) - obj := NewObject(ObjectConfig{ - Name: name, - Fields: BindFields(t.Elem().Interface()), - }) + obj := BindType(t.Elem().Type()) return NewList(obj) } @@ -132,21 +203,27 @@ func extractValue(originTag string, obj interface{}) interface{} { field := val.Type().Field(j) found := originTag == extractTag(field.Tag) if field.Type.Kind() == reflect.Struct { - itf := val.Field(j).Interface() + fieldVal := val.Field(j) + if !fieldVal.IsZero() { + itf := fieldVal.Interface() - if str, ok := itf.(encoding.TextMarshaler); ok && found { - byt, _ := str.MarshalText() - return string(byt) - } + if str, ok := itf.(encoding.TextMarshaler); ok && found { + byt, _ := str.MarshalText() + return string(byt) + } - res := extractValue(originTag, itf) - if res != nil { - return res + res := extractValue(originTag, itf) + if res != nil { + return res + } } } if found { - return reflect.Indirect(val.Field(j)).Interface() + fieldVal := val.Field(j) + if !fieldVal.IsZero() { + return reflect.Indirect(fieldVal).Interface() + } } } return nil