diff --git a/constants.go b/constants.go index 23288d3..0b1a33c 100644 --- a/constants.go +++ b/constants.go @@ -9,7 +9,7 @@ const ( annotationRelation = "relation" annotationOmitEmpty = "omitempty" annotationISO8601 = "iso8601" - annotationSeperator = "," + annotationSeparator = "," iso8601TimeFormat = "2006-01-02T15:04:05Z" diff --git a/models_test.go b/models_test.go index 2d4aae4..c86fcb3 100644 --- a/models_test.go +++ b/models_test.go @@ -1,6 +1,7 @@ package jsonapi import ( + "database/sql" "fmt" "time" ) @@ -26,9 +27,42 @@ type WithPointer struct { } type Timestamp struct { - ID int `jsonapi:"primary,timestamps"` - Time time.Time `jsonapi:"attr,timestamp,iso8601"` - Next *time.Time `jsonapi:"attr,next,iso8601"` + ID int `jsonapi:"primary,timestamps"` + Time time.Time `jsonapi:"attr,timestamp,iso8601"` + Next *time.Time `jsonapi:"attr,next,iso8601"` + Null sql.NullTime `jsonapi:"attr,null,iso8601"` +} + +type NullStringID struct { + ID sql.NullString `jsonapi:"primary,null-string-id"` + Periodic sql.NullBool `jsonapi:"attr,periodic,omitempty"` + Name sql.NullString `jsonapi:"attr,name,omitempty"` + Value sql.NullFloat64 `jsonapi:"attr,value,omitempty"` + Decimal sql.NullInt32 `jsonapi:"attr,decimal,omitempty"` + Fractional sql.NullInt64 `jsonapi:"attr,fractional,omitempty"` + ComputedAt sql.NullTime `jsonapi:"attr,computed_at,omitempty,iso8601"` +} + +type NullInt32ID struct { + ID sql.NullInt32 `jsonapi:"primary,null-int32-id"` +} + +type NullInt64ID struct { + ID sql.NullInt64 `jsonapi:"primary,null-int64-id"` +} + +type NullFloat64ID struct { + ID sql.NullFloat64 `jsonapi:"primary,null-float64-id"` +} + +type Float struct { + ID sql.NullString `jsonapi:"primary,float"` + Periodic sql.NullBool `jsonapi:"attr,periodic"` + Name sql.NullString `jsonapi:"attr,name"` + Value sql.NullFloat64 `jsonapi:"attr,value"` + Decimal sql.NullInt32 `jsonapi:"attr,decimal"` + Fractional sql.NullInt64 `jsonapi:"attr,fractional"` + ComputedAt sql.NullTime `jsonapi:"attr,computed_at"` } type Car struct { diff --git a/request.go b/request.go index b2fa477..9e15c8a 100644 --- a/request.go +++ b/request.go @@ -2,6 +2,7 @@ package jsonapi import ( "bytes" + "database/sql" "encoding/json" "errors" "fmt" @@ -149,90 +150,53 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) modelValue := model.Elem() modelType := modelValue.Type() - var er error - for i := 0; i < modelValue.NumField(); i++ { fieldType := modelType.Field(i) - tag := fieldType.Tag.Get("jsonapi") + tag := fieldType.Tag.Get(annotationJSONAPI) if tag == "" { continue } fieldValue := modelValue.Field(i) - args := strings.Split(tag, ",") + args := strings.Split(tag, annotationSeparator) + if len(args) < 1 { - er = ErrBadJSONAPIStructTag - break + return ErrBadJSONAPIStructTag } annotation := args[0] if (annotation == annotationClientID && len(args) != 1) || (annotation != annotationClientID && len(args) < 2) { - er = ErrBadJSONAPIStructTag - break + return ErrBadJSONAPIStructTag } - if annotation == annotationPrimary { + switch annotation { + case annotationPrimary: // Check the JSON API Type if data.Type != args[1] { - er = fmt.Errorf( + return fmt.Errorf( "Trying to Unmarshal an object of type %#v, but %#v does not match", data.Type, args[1], ) - break - } - - if data.ID == "" { - continue - } - - // ID will have to be transmitted as astring per the JSON API spec - v := reflect.ValueOf(data.ID) - - // Deal with PTRS - var kind reflect.Kind - if fieldValue.Kind() == reflect.Ptr { - kind = fieldType.Type.Elem().Kind() - } else { - kind = fieldType.Type.Kind() } - // Handle String case - if kind == reflect.String { - assign(fieldValue, v) - continue - } - - // Value was not a string... only other supported type was a numeric, - // which would have been sent as a float value. - floatValue, err := strconv.ParseFloat(data.ID, 64) - if err != nil { - // Could not convert the value in the "id" attr to a float - er = ErrBadJSONAPIID - break - } + data, err = unmarshallID(data, fieldValue, fieldType) - // Convert the numeric float to one of the supported ID numeric types - // (int[8,16,32,64] or uint[8,16,32,64]) - idValue, err := handleNumeric(floatValue, fieldType.Type, fieldValue) if err != nil { - // We had a JSON float (numeric), but our field was not one of the - // allowed numeric types - er = ErrBadJSONAPIID - break + return } - assign(fieldValue, idValue) - } else if annotation == annotationClientID { + case annotationClientID: if data.ClientID == "" { continue } fieldValue.Set(reflect.ValueOf(data.ClientID)) - } else if annotation == annotationAttribute { + + case annotationAttribute: attributes := data.Attributes if attributes == nil || len(data.Attributes) == 0 { @@ -249,87 +213,148 @@ func unmarshalNode(data *Node, model reflect.Value, included *map[string]*Node) structField := fieldType value, err := unmarshalAttribute(attribute, args, structField, fieldValue) if err != nil { - er = err - break + return err } assign(fieldValue, value) - } else if annotation == annotationRelation { - isSlice := fieldValue.Type().Kind() == reflect.Slice - if data.Relationships == nil || data.Relationships[args[1]] == nil { - continue + case annotationRelation: + data, err = unmarshallRelation(data, fieldValue, included, args) + + if err != nil { + return } - if isSlice { - // to-many relationship - relationship := new(RelationshipManyNode) + default: + return fmt.Errorf(unsupportedStructTagMsg, annotation) + } + } - buf := bytes.NewBuffer(nil) + return nil +} - json.NewEncoder(buf).Encode(data.Relationships[args[1]]) - json.NewDecoder(buf).Decode(relationship) +func unmarshallID(node *Node, fieldValue reflect.Value, structField reflect.StructField) (*Node, error) { + if node.ID == "" { + return node, nil + } - data := relationship.Data - models := reflect.New(fieldValue.Type()).Elem() + // ID will have to be transmitted as a string per the JSON API spec + v := reflect.ValueOf(node.ID) - for _, n := range data { - m := reflect.New(fieldValue.Type().Elem().Elem()) + // Deal with PTRS + var kind reflect.Kind + if fieldValue.Kind() == reflect.Ptr { + kind = structField.Type.Elem().Kind() + } else { + kind = structField.Type.Kind() + } - if err := unmarshalNode( - fullNode(n, included), - m, - included, - ); err != nil { - er = err - break - } + // Handle String case + if kind == reflect.String { + assign(fieldValue, v) - models = reflect.Append(models, m) - } + return node, nil + } - fieldValue.Set(models) - } else { - // to-one relationships - relationship := new(RelationshipOneNode) + // Handle sql.NullString case + if structField.Type == reflect.TypeOf(sql.NullString{}) { + if str, ok := v.Interface().(string); ok { + assign(fieldValue, reflect.ValueOf(sql.NullString{String: str, Valid: true})) - buf := bytes.NewBuffer(nil) + return node, nil + } + } - json.NewEncoder(buf).Encode( - data.Relationships[args[1]], - ) - json.NewDecoder(buf).Decode(relationship) - - /* - http://jsonapi.org/format/#document-resource-object-relationships - http://jsonapi.org/format/#document-resource-object-linkage - relationship can have a data node set to null (e.g. to disassociate the relationship) - so unmarshal and set fieldValue only if data obj is not null - */ - if relationship.Data == nil { - continue - } - - m := reflect.New(fieldValue.Type().Elem()) - if err := unmarshalNode( - fullNode(relationship.Data, included), - m, - included, - ); err != nil { - er = err - break - } - - fieldValue.Set(m) + // Value was not a string... only other supported type was a numeric, + // which would have been sent as a float value. + floatValue, err := strconv.ParseFloat(node.ID, 64) + if err != nil { + // Could not convert the value in the "id" attr to a float + return nil, ErrBadJSONAPIID + } + + // Convert the numeric float to one of the supported ID numeric types + // (int[8,16,32,64], uint[8,16,32,64] or sql.Null[Int32, Int64, Float64]) + idValue, err := handleNumeric(floatValue, structField.Type, fieldValue) + if err != nil { + // We had a JSON float (numeric), but our field was not one of the + // allowed numeric types + return nil, ErrBadJSONAPIID + } + + assign(fieldValue, idValue) + + return node, nil +} +func unmarshallRelation(node *Node, fieldValue reflect.Value, included *map[string]*Node, args []string) (*Node, error) { + isSlice := fieldValue.Type().Kind() == reflect.Slice + + if node.Relationships == nil || node.Relationships[args[1]] == nil { + return node, nil + } + + if isSlice { + // to-many relationship + relationship := new(RelationshipManyNode) + + buf := bytes.NewBuffer(nil) + + json.NewEncoder(buf).Encode(node.Relationships[args[1]]) + json.NewDecoder(buf).Decode(relationship) + + data := relationship.Data + models := reflect.New(fieldValue.Type()).Elem() + + for _, n := range data { + m := reflect.New(fieldValue.Type().Elem().Elem()) + + if err := unmarshalNode( + fullNode(n, included), + m, + included, + ); err != nil { + return nil, err } - } else { - er = fmt.Errorf(unsupportedStructTagMsg, annotation) + models = reflect.Append(models, m) + } + + fieldValue.Set(models) + } else { + // to-one relationships + relationship := new(RelationshipOneNode) + + buf := bytes.NewBuffer(nil) + + json.NewEncoder(buf).Encode( + node.Relationships[args[1]], + ) + json.NewDecoder(buf).Decode(relationship) + + /* + http://jsonapi.org/format/#document-resource-object-relationships + http://jsonapi.org/format/#document-resource-object-linkage + relationship can have a data node set to null (e.g. to disassociate the relationship) + so unmarshal and set fieldValue only if data obj is not null + */ + if relationship.Data == nil { + return node, nil } + + m := reflect.New(fieldValue.Type().Elem()) + if err := unmarshalNode( + fullNode(relationship.Data, included), + m, + included, + ); err != nil { + return nil, err + } + + fieldValue.Set(m) } - return er + return node, nil } func fullNode(n *Node, included *map[string]*Node) *Node { @@ -393,6 +418,12 @@ func unmarshalAttribute( return } + // Handle field of sql.Null* type + if isSQLNullType(fieldType) { + value, err = handleSQLNullType(attribute, args, fieldType, fieldValue) + return + } + // Handle field of type time.Time if fieldValue.Type() == reflect.TypeOf(time.Time{}) || fieldValue.Type() == reflect.TypeOf(new(time.Time)) { @@ -457,18 +488,19 @@ func handleTime(attribute interface{}, args []string, fieldValue reflect.Value) } if isIso8601 { - var tm string - if v.Kind() == reflect.String { - tm = v.Interface().(string) - } else { + if v.Kind() != reflect.String { return reflect.ValueOf(time.Now()), ErrInvalidISO8601 } - t, err := time.Parse(iso8601TimeFormat, tm) + t, err := time.Parse(iso8601TimeFormat, v.Interface().(string)) if err != nil { return reflect.ValueOf(time.Now()), ErrInvalidISO8601 } + if _, ok := fieldValue.Interface().(sql.NullTime); ok { + return reflect.ValueOf(sql.NullTime{Time: t, Valid: true}), nil + } + if fieldValue.Kind() == reflect.Ptr { return reflect.ValueOf(&t), nil } @@ -476,17 +508,19 @@ func handleTime(attribute interface{}, args []string, fieldValue reflect.Value) return reflect.ValueOf(t), nil } - var at int64 + var t time.Time if v.Kind() == reflect.Float64 { - at = int64(v.Interface().(float64)) + t = time.Unix(int64(v.Float()), 0) } else if v.Kind() == reflect.Int { - at = v.Int() + t = time.Unix(v.Int(), 0) } else { return reflect.ValueOf(time.Now()), ErrInvalidTime } - t := time.Unix(at, 0) + if _, ok := fieldValue.Interface().(sql.NullTime); ok { + return reflect.ValueOf(sql.NullTime{Time: t, Valid: true}), nil + } return reflect.ValueOf(t), nil } @@ -544,6 +578,23 @@ func handleNumeric( case reflect.Float64: n := floatValue numericValue = reflect.ValueOf(&n) + case reflect.Struct: + if _, ok := fieldValue.Interface().(sql.NullInt32); ok { + numericValue = reflect.ValueOf(sql.NullInt32{Int32: int32(floatValue), Valid: true}) + break + } + + if _, ok := fieldValue.Interface().(sql.NullInt64); ok { + numericValue = reflect.ValueOf(sql.NullInt64{Int64: int64(floatValue), Valid: true}) + break + } + + if _, ok := fieldValue.Interface().(sql.NullFloat64); ok { + numericValue = reflect.ValueOf(sql.NullFloat64{Float64: floatValue, Valid: true}) + break + } + + fallthrough default: return reflect.Value{}, ErrUnknownFieldNumberType } @@ -588,6 +639,32 @@ func handlePointer( return concreteVal, nil } +func isSQLNullType(fieldType reflect.Type) bool { + switch fieldType { + case reflect.TypeOf(sql.NullString{}), reflect.TypeOf(sql.NullBool{}), reflect.TypeOf(sql.NullInt32{}), + reflect.TypeOf(sql.NullInt64{}), reflect.TypeOf(sql.NullFloat64{}), reflect.TypeOf(sql.NullTime{}): + return true + } + + return false +} + +func handleSQLNullType(attribute interface{}, args []string, fieldType reflect.Type, + fieldValue reflect.Value) (reflect.Value, error) { + switch fieldType { + case reflect.TypeOf(sql.NullString{}): + return reflect.ValueOf(sql.NullString{String: attribute.(string), Valid: true}), nil + case reflect.TypeOf(sql.NullBool{}): + return reflect.ValueOf(sql.NullBool{Bool: attribute.(bool), Valid: true}), nil + case reflect.TypeOf(sql.NullInt32{}), reflect.TypeOf(sql.NullInt64{}), reflect.TypeOf(sql.NullFloat64{}): + return handleNumeric(attribute, fieldType, fieldValue) + case reflect.TypeOf(sql.NullTime{}): + return handleTime(attribute, args, fieldValue) + } + + return reflect.Value{}, fmt.Errorf("expected sql.Null* type, got: %v", fieldType) +} + func handleStruct( attribute interface{}, fieldValue reflect.Value) (reflect.Value, error) { diff --git a/request_test.go b/request_test.go index daa2159..2f9092c 100644 --- a/request_test.go +++ b/request_test.go @@ -212,6 +212,120 @@ func TestUnmarshalToStructWithPointerAttr_BadType_IntSlice(t *testing.T) { } } +func TestUnmarshalToStructNullStringID(t *testing.T) { + data := map[string]interface{}{ + "data": map[string]interface{}{ + "type": "null-string-id", + "id": "314", + "attributes": map[string]interface{}{ + "periodic": false, + "name": "Pi", + "value": 3.1415926535897932, + "decimal": 3, + "fractional": 1415926535897932, + "computed_at": "2021-03-14T15:00:00Z", + }, + }, + } + payload, err := json.Marshal(data) + if err != nil { + t.Fatal(err) + } + + pi := new(NullStringID) + if err = UnmarshalPayload(bytes.NewReader(payload), pi); err != nil { + t.Fatal(err) + } + + if pi.ID.String != "314" { + t.Fatalf("Error unmarshalling to sql.NullString") + } + if pi.Name.String != "Pi" { + t.Fatalf("Error unmarshalling to sql.NullString") + } + if pi.Periodic.Bool { + t.Fatalf("Error unmarshalling to sql.NullBool") + } + if pi.Value.Float64 != 3.1415926535897932 { + t.Fatalf("Error unmarshalling to sql.NullFloat64") + } + if pi.Decimal.Int32 != 3 { + t.Fatalf("Error unmarshalling to sql.NullInt32") + } + if pi.Fractional.Int64 != 1415926535897932 { + t.Fatalf("Error unmarshalling to sql.NullInt64") + } + if !pi.ComputedAt.Time.Equal(time.Unix(1615734000, 0)) { + t.Fatalf("Error unmarshalling to sql.NullTime") + } +} + +func TestUnmarshalToStructNullInt32ID(t *testing.T) { + data := map[string]interface{}{ + "data": map[string]interface{}{ + "type": "null-int32-id", + "id": "123", + }, + } + payload, err := json.Marshal(data) + if err != nil { + t.Fatal(err) + } + + i32 := new(NullInt32ID) + if err = UnmarshalPayload(bytes.NewReader(payload), i32); err != nil { + t.Fatal(err) + } + + if i32.ID.Int32 != 123 { + t.Fatalf("Error unmarshalling to sql.NullInt32") + } +} + +func TestUnmarshalToStructNullInt64ID(t *testing.T) { + data := map[string]interface{}{ + "data": map[string]interface{}{ + "type": "null-int64-id", + "id": "456", + }, + } + payload, err := json.Marshal(data) + if err != nil { + t.Fatal(err) + } + + i64 := new(NullInt64ID) + if err = UnmarshalPayload(bytes.NewReader(payload), i64); err != nil { + t.Fatal(err) + } + + if i64.ID.Int64 != 456 { + t.Fatalf("Error unmarshalling to sql.NullInt64") + } +} + +func TestUnmarshalToStructNullFloat64ID(t *testing.T) { + data := map[string]interface{}{ + "data": map[string]interface{}{ + "type": "null-float64-id", + "id": "789", + }, + } + payload, err := json.Marshal(data) + if err != nil { + t.Fatal(err) + } + + f64 := new(NullFloat64ID) + if err = UnmarshalPayload(bytes.NewReader(payload), f64); err != nil { + t.Fatal(err) + } + + if f64.ID.Float64 != 789 { + t.Fatalf("Error unmarshalling to sql.NullFloat64") + } +} + func TestStringPointerField(t *testing.T) { // Build Book payload description := "Hello World!" @@ -393,6 +507,32 @@ func TestUnmarshalParsesISO8601TimePointer(t *testing.T) { } } +func TestUnmarshalParsesISO8601NullTime(t *testing.T) { + payload := &OnePayload{ + Data: &Node{ + Type: "timestamps", + Attributes: map[string]interface{}{ + "null": "2016-08-17T08:27:12Z", + }, + }, + } + + in := bytes.NewBuffer(nil) + json.NewEncoder(in).Encode(payload) + + out := new(Timestamp) + + if err := UnmarshalPayload(in, out); err != nil { + t.Fatal(err) + } + + expected := time.Date(2016, 8, 17, 8, 27, 12, 0, time.UTC) + + if !out.Null.Time.Equal(expected) { + t.Fatal("Parsing the ISO8601 timestamp failed") + } +} + func TestUnmarshalInvalidISO8601(t *testing.T) { payload := &OnePayload{ Data: &Node{ diff --git a/response.go b/response.go index 3f8ab73..e9c9164 100644 --- a/response.go +++ b/response.go @@ -1,6 +1,7 @@ package jsonapi import ( + "database/sql" "encoding/json" "errors" "fmt" @@ -18,7 +19,7 @@ var ( // ErrBadJSONAPIID is returned when the Struct JSON API annotated "id" field // was not a valid numeric type. ErrBadJSONAPIID = errors.New( - "id should be either string, int(8,16,32,64) or uint(8,16,32,64)") + "id should be either string, int(8,16,32,64), uint(8,16,32,64) or sql.Null(Int32, Int64, Float64)") // ErrExpectedSlice is returned when a variable or argument was expected to // be a slice of *Structs; MarshalMany will return this error when its // interface{} argument is invalid. @@ -173,7 +174,7 @@ func marshalMany(models []interface{}) (*ManyPayload, error) { // related records. This method will serialize a single struct // pointer into an embedded json response. In other words, there // will be no, "included", array in the json all relationships will -// be serailized inline in the data. +// be serialized inline in the data. // // However, in tests, you may want to construct payloads to post // to create methods that are embedded to most closely resemble @@ -196,7 +197,6 @@ func visitModelNode(model interface{}, included *map[string]*Node, sideload bool) (*Node, error) { node := new(Node) - var er error value := reflect.ValueOf(model) if value.IsNil() { return nil, nil @@ -215,253 +215,363 @@ func visitModelNode(model interface{}, included *map[string]*Node, fieldValue := modelValue.Field(i) fieldType := modelType.Field(i) - args := strings.Split(tag, annotationSeperator) + args := strings.Split(tag, annotationSeparator) if len(args) < 1 { - er = ErrBadJSONAPIStructTag - break + return nil, ErrBadJSONAPIStructTag } annotation := args[0] if (annotation == annotationClientID && len(args) != 1) || (annotation != annotationClientID && len(args) < 2) { - er = ErrBadJSONAPIStructTag - break + return nil, ErrBadJSONAPIStructTag } - if annotation == annotationPrimary { - v := fieldValue - - // Deal with PTRS - var kind reflect.Kind - if fieldValue.Kind() == reflect.Ptr { - kind = fieldType.Type.Elem().Kind() - v = reflect.Indirect(fieldValue) - } else { - kind = fieldType.Type.Kind() - } + var err error - // Handle allowed types - switch kind { - case reflect.String: - node.ID = v.Interface().(string) - case reflect.Int: - node.ID = strconv.FormatInt(int64(v.Interface().(int)), 10) - case reflect.Int8: - node.ID = strconv.FormatInt(int64(v.Interface().(int8)), 10) - case reflect.Int16: - node.ID = strconv.FormatInt(int64(v.Interface().(int16)), 10) - case reflect.Int32: - node.ID = strconv.FormatInt(int64(v.Interface().(int32)), 10) - case reflect.Int64: - node.ID = strconv.FormatInt(v.Interface().(int64), 10) - case reflect.Uint: - node.ID = strconv.FormatUint(uint64(v.Interface().(uint)), 10) - case reflect.Uint8: - node.ID = strconv.FormatUint(uint64(v.Interface().(uint8)), 10) - case reflect.Uint16: - node.ID = strconv.FormatUint(uint64(v.Interface().(uint16)), 10) - case reflect.Uint32: - node.ID = strconv.FormatUint(uint64(v.Interface().(uint32)), 10) - case reflect.Uint64: - node.ID = strconv.FormatUint(v.Interface().(uint64), 10) - default: - // We had a JSON float (numeric), but our field was not one of the - // allowed numeric types - er = ErrBadJSONAPIID - } + switch annotation { + case annotationPrimary: + node, err = resolveNodeID(node, fieldValue, fieldType) - if er != nil { - break + if err != nil { + return nil, err } node.Type = args[1] - } else if annotation == annotationClientID { + case annotationClientID: clientID := fieldValue.String() if clientID != "" { node.ClientID = clientID } - } else if annotation == annotationAttribute { - var omitEmpty, iso8601 bool - - if len(args) > 2 { - for _, arg := range args[2:] { - switch arg { - case annotationOmitEmpty: - omitEmpty = true - case annotationISO8601: - iso8601 = true - } - } + case annotationAttribute: + node = resolveNodeAttribute(node, fieldValue, args) + case annotationRelation: + node, err = resolveNodeRelation(node, fieldValue, args, model, included, sideload) + + if err != nil { + return nil, err } + default: + return nil, ErrBadJSONAPIStructTag + } + } + + if linkableModel, isLinkable := model.(Linkable); isLinkable { + jl := linkableModel.JSONAPILinks() + if er := jl.validate(); er != nil { + return nil, er + } + node.Links = linkableModel.JSONAPILinks() + } + + if metableModel, ok := model.(Metable); ok { + node.Meta = metableModel.JSONAPIMeta() + } + + return node, nil +} - if node.Attributes == nil { - node.Attributes = make(map[string]interface{}) +func resolveNodeID(node *Node, fieldValue reflect.Value, structField reflect.StructField) (*Node, error) { + v := fieldValue + + // Deal with PTRS + var kind reflect.Kind + if fieldValue.Kind() == reflect.Ptr { + kind = structField.Type.Elem().Kind() + v = reflect.Indirect(fieldValue) + } else { + kind = structField.Type.Kind() + } + + // Handle allowed types + switch kind { + case reflect.String: + node.ID = v.Interface().(string) + case reflect.Int: + node.ID = strconv.FormatInt(int64(v.Interface().(int)), 10) + case reflect.Int8: + node.ID = strconv.FormatInt(int64(v.Interface().(int8)), 10) + case reflect.Int16: + node.ID = strconv.FormatInt(int64(v.Interface().(int16)), 10) + case reflect.Int32: + node.ID = strconv.FormatInt(int64(v.Interface().(int32)), 10) + case reflect.Int64: + node.ID = strconv.FormatInt(v.Interface().(int64), 10) + case reflect.Uint: + node.ID = strconv.FormatUint(uint64(v.Interface().(uint)), 10) + case reflect.Uint8: + node.ID = strconv.FormatUint(uint64(v.Interface().(uint8)), 10) + case reflect.Uint16: + node.ID = strconv.FormatUint(uint64(v.Interface().(uint16)), 10) + case reflect.Uint32: + node.ID = strconv.FormatUint(uint64(v.Interface().(uint32)), 10) + case reflect.Uint64: + node.ID = strconv.FormatUint(v.Interface().(uint64), 10) + case reflect.Struct: + if str, ok := v.Interface().(sql.NullString); ok { + node.ID = str.String + break + } + + if i32, ok := v.Interface().(sql.NullInt32); ok { + node.ID = strconv.FormatInt(int64(i32.Int32), 10) + break + } + + if i64, ok := v.Interface().(sql.NullInt64); ok { + node.ID = strconv.FormatInt(i64.Int64, 10) + break + } + + if f64, ok := v.Interface().(sql.NullFloat64); ok { + node.ID = strconv.FormatFloat(f64.Float64, 'f', -1, 64) + break + } + + fallthrough + default: + // We had a JSON float (numeric), but our field was not one of the + // allowed numeric types + return nil, ErrBadJSONAPIID + } + + return node, nil +} + +func resolveNodeAttribute(node *Node, fieldValue reflect.Value, args []string) *Node { + var omitEmpty, iso8601 bool + + if len(args) > 2 { + for _, arg := range args[2:] { + switch arg { + case annotationOmitEmpty: + omitEmpty = true + case annotationISO8601: + iso8601 = true } + } + } - if fieldValue.Type() == reflect.TypeOf(time.Time{}) { - t := fieldValue.Interface().(time.Time) - - if t.IsZero() { - continue - } - - if iso8601 { - node.Attributes[args[1]] = t.UTC().Format(iso8601TimeFormat) - } else { - node.Attributes[args[1]] = t.Unix() - } - } else if fieldValue.Type() == reflect.TypeOf(new(time.Time)) { - // A time pointer may be nil - if fieldValue.IsNil() { - if omitEmpty { - continue - } - - node.Attributes[args[1]] = nil - } else { - tm := fieldValue.Interface().(*time.Time) - - if tm.IsZero() && omitEmpty { - continue - } - - if iso8601 { - node.Attributes[args[1]] = tm.UTC().Format(iso8601TimeFormat) - } else { - node.Attributes[args[1]] = tm.Unix() - } - } + if node.Attributes == nil { + node.Attributes = make(map[string]interface{}) + } + + switch fieldValue.Type() { + case reflect.TypeOf(time.Time{}): + t := fieldValue.Interface().(time.Time) + + if t.IsZero() { + return node + } + + if iso8601 { + node.Attributes[args[1]] = t.UTC().Format(iso8601TimeFormat) + } else { + node.Attributes[args[1]] = t.Unix() + } + case reflect.TypeOf(new(time.Time)): + // A time pointer may be nil + if fieldValue.IsNil() { + if omitEmpty { + return node + } + + node.Attributes[args[1]] = nil + } else { + t := fieldValue.Interface().(*time.Time) + + if t.IsZero() && omitEmpty { + return node + } + + if iso8601 { + node.Attributes[args[1]] = t.UTC().Format(iso8601TimeFormat) } else { - // Dealing with a fieldValue that is not a time - emptyValue := reflect.Zero(fieldValue.Type()) - - // See if we need to omit this field - if omitEmpty && reflect.DeepEqual(fieldValue.Interface(), emptyValue.Interface()) { - continue - } - - strAttr, ok := fieldValue.Interface().(string) - if ok { - node.Attributes[args[1]] = strAttr - } else { - node.Attributes[args[1]] = fieldValue.Interface() - } + node.Attributes[args[1]] = t.Unix() + } + } + case reflect.TypeOf(sql.NullTime{}): + nt := fieldValue.Interface().(sql.NullTime) + + // Time is NULL + if !nt.Valid { + if omitEmpty { + return node + } + + node.Attributes[args[1]] = nil + } else { + if nt.Time.IsZero() { + return node } - } else if annotation == annotationRelation { - var omitEmpty bool - //add support for 'omitempty' struct tag for marshaling as absent - if len(args) > 2 { - omitEmpty = args[2] == annotationOmitEmpty + if iso8601 { + node.Attributes[args[1]] = nt.Time.UTC().Format(iso8601TimeFormat) + } else { + node.Attributes[args[1]] = nt.Time.Unix() } + } + default: + // Dealing with a fieldValue that is not a time + emptyValue := reflect.Zero(fieldValue.Type()) + + // See if we need to omit this field + if omitEmpty && reflect.DeepEqual(fieldValue.Interface(), emptyValue.Interface()) { + break + } - isSlice := fieldValue.Type().Kind() == reflect.Slice - if omitEmpty && - (isSlice && fieldValue.Len() < 1 || - (!isSlice && fieldValue.IsNil())) { - continue + // Handle remaining sql.Null* types + if boo, ok := fieldValue.Interface().(sql.NullBool); ok { + if boo.Valid { + node.Attributes[args[1]] = boo.Bool + } else { + node.Attributes[args[1]] = nil } + break + } - if node.Relationships == nil { - node.Relationships = make(map[string]interface{}) + if str, ok := fieldValue.Interface().(sql.NullString); ok { + if str.Valid { + node.Attributes[args[1]] = str.String + } else { + node.Attributes[args[1]] = nil } + break + } - var relLinks *Links - if linkableModel, ok := model.(RelationshipLinkable); ok { - relLinks = linkableModel.JSONAPIRelationshipLinks(args[1]) + if f64, ok := fieldValue.Interface().(sql.NullFloat64); ok { + if f64.Valid { + node.Attributes[args[1]] = f64.Float64 + } else { + node.Attributes[args[1]] = nil } + break + } - var relMeta *Meta - if metableModel, ok := model.(RelationshipMetable); ok { - relMeta = metableModel.JSONAPIRelationshipMeta(args[1]) + if i32, ok := fieldValue.Interface().(sql.NullInt32); ok { + if i32.Valid { + node.Attributes[args[1]] = i32.Int32 + } else { + node.Attributes[args[1]] = nil } + break + } - if isSlice { - // to-many relationship - relationship, err := visitModelNodeRelationships( - fieldValue, - included, - sideload, - ) - if err != nil { - er = err - break - } - relationship.Links = relLinks - relationship.Meta = relMeta - - if sideload { - shallowNodes := []*Node{} - for _, n := range relationship.Data { - appendIncluded(included, n) - shallowNodes = append(shallowNodes, toShallowNode(n)) - } - - node.Relationships[args[1]] = &RelationshipManyNode{ - Data: shallowNodes, - Links: relationship.Links, - Meta: relationship.Meta, - } - } else { - node.Relationships[args[1]] = relationship - } + if i64, ok := fieldValue.Interface().(sql.NullInt64); ok { + if i64.Valid { + node.Attributes[args[1]] = i64.Int64 } else { - // to-one relationships - - // Handle null relationship case - if fieldValue.IsNil() { - node.Relationships[args[1]] = &RelationshipOneNode{Data: nil} - continue - } - - relationship, err := visitModelNode( - fieldValue.Interface(), - included, - sideload, - ) - if err != nil { - er = err - break - } - - if sideload { - appendIncluded(included, relationship) - node.Relationships[args[1]] = &RelationshipOneNode{ - Data: toShallowNode(relationship), - Links: relLinks, - Meta: relMeta, - } - } else { - node.Relationships[args[1]] = &RelationshipOneNode{ - Data: relationship, - Links: relLinks, - Meta: relMeta, - } - } + node.Attributes[args[1]] = nil } + break + } + // Handle string and remaining types + if str, ok := fieldValue.Interface().(string); ok { + node.Attributes[args[1]] = str } else { - er = ErrBadJSONAPIStructTag - break + node.Attributes[args[1]] = fieldValue.Interface() } } - if er != nil { - return nil, er + return node +} + +func resolveNodeRelation(node *Node, fieldValue reflect.Value, args []string, + model interface{}, included *map[string]*Node, sideload bool) (*Node, error) { + var omitEmpty bool + + // add support for 'omitempty' struct tag for marshaling as absent + if len(args) > 2 { + omitEmpty = args[2] == annotationOmitEmpty } - if linkableModel, isLinkable := model.(Linkable); isLinkable { - jl := linkableModel.JSONAPILinks() - if er := jl.validate(); er != nil { - return nil, er + isSlice := fieldValue.Type().Kind() == reflect.Slice + if omitEmpty && + (isSlice && fieldValue.Len() < 1 || + (!isSlice && fieldValue.IsNil())) { + return node, nil + } + + if node.Relationships == nil { + node.Relationships = make(map[string]interface{}) + } + + var relLinks *Links + if linkableModel, ok := model.(RelationshipLinkable); ok { + relLinks = linkableModel.JSONAPIRelationshipLinks(args[1]) + } + + var relMeta *Meta + if metableModel, ok := model.(RelationshipMetable); ok { + relMeta = metableModel.JSONAPIRelationshipMeta(args[1]) + } + + if isSlice { + // to-many relationship + relationship, err := visitModelNodeRelationships( + fieldValue, + included, + sideload, + ) + if err != nil { + return nil, err } - node.Links = linkableModel.JSONAPILinks() + + relationship.Links = relLinks + relationship.Meta = relMeta + + if sideload { + shallowNodes := []*Node{} + for _, n := range relationship.Data { + appendIncluded(included, n) + shallowNodes = append(shallowNodes, toShallowNode(n)) + } + + node.Relationships[args[1]] = &RelationshipManyNode{ + Data: shallowNodes, + Links: relationship.Links, + Meta: relationship.Meta, + } + } else { + node.Relationships[args[1]] = relationship + } + + return node, nil } - if metableModel, ok := model.(Metable); ok { - node.Meta = metableModel.JSONAPIMeta() + // to-one relationships + + // Handle null relationship case + if fieldValue.IsNil() { + node.Relationships[args[1]] = &RelationshipOneNode{Data: nil} + + return node, nil + } + + relationship, err := visitModelNode( + fieldValue.Interface(), + included, + sideload, + ) + if err != nil { + return nil, err + } + + if sideload { + appendIncluded(included, relationship) + node.Relationships[args[1]] = &RelationshipOneNode{ + Data: toShallowNode(relationship), + Links: relLinks, + Meta: relMeta, + } + } else { + node.Relationships[args[1]] = &RelationshipOneNode{ + Data: relationship, + Links: relLinks, + Meta: relMeta, + } } return node, nil diff --git a/response_test.go b/response_test.go index 5b42595..f6ee4b9 100644 --- a/response_test.go +++ b/response_test.go @@ -2,6 +2,7 @@ package jsonapi import ( "bytes" + "database/sql" "encoding/json" "reflect" "sort" @@ -116,7 +117,7 @@ func TestWithoutOmitsEmptyAnnotationOnRelation(t *testing.T) { } relationships := jsonData["data"].(map[string]interface{})["relationships"].(map[string]interface{}) - // Verifiy the "posts" relation was an empty array + // Verify the "posts" relation was an empty array posts, ok := relationships["posts"] if !ok { t.Fatal("Was expecting the data.relationships.posts key/value to have been present") @@ -137,7 +138,7 @@ func TestWithoutOmitsEmptyAnnotationOnRelation(t *testing.T) { t.Fatal("Was expecting the data.relationships.posts.data value to have been an empty array []") } - // Verifiy the "current_post" was a null + // Verify the "current_post" was a null currentPost, postExists := relationships["current_post"] if !postExists { t.Fatal("Was expecting the data.relationships.current_post key/value to have NOT been omitted") @@ -525,6 +526,456 @@ func TestMarshalISO8601TimePointer(t *testing.T) { } } +func TestMarshalISO8601NullTime(t *testing.T) { + testModel := &Timestamp{ + ID: 5, + Null: sql.NullTime{ + Time: time.Date(2016, 8, 17, 8, 27, 12, 23849, time.UTC), + Valid: true, + }, + } + + out := bytes.NewBuffer(nil) + if err := MarshalPayload(out, testModel); err != nil { + t.Fatal(err) + } + + resp := new(OnePayload) + if err := json.NewDecoder(out).Decode(resp); err != nil { + t.Fatal(err) + } + + data := resp.Data + + if data.Attributes == nil { + t.Fatalf("Expected attributes") + } + + if data.Attributes["null"] != "2016-08-17T08:27:12Z" { + t.Fatal("Null was not serialised into ISO8601 correctly") + } +} + +func TestMarshalISO8601NullTime_Zero(t *testing.T) { + testModel := &Timestamp{ + ID: 5, + Null: sql.NullTime{Valid: true}, + } + + out := bytes.NewBuffer(nil) + if err := MarshalPayload(out, testModel); err != nil { + t.Fatal(err) + } + + resp := new(OnePayload) + if err := json.NewDecoder(out).Decode(resp); err != nil { + t.Fatal(err) + } + + data := resp.Data + + if data.Attributes == nil { + t.Fatal("Expected attributes") + } + + if data.Attributes["null"] != nil { + t.Fatalf("Null should not have been serialised") + } +} + +func TestMarshalStructNullStringID_Zero_Invalid(t *testing.T) { + pi := new(NullStringID) + + out := bytes.NewBuffer(nil) + if err := MarshalPayload(out, pi); err != nil { + t.Fatal(err) + } + + var jsonData map[string]interface{} + + if err := json.Unmarshal(out.Bytes(), &jsonData); err != nil { + t.Fatal(err) + } + data := jsonData["data"].(map[string]interface{}) + + if data["type"] != "null-string-id" { + t.Fatalf("Error marshalling type") + } + + if _, ok := data["attributes"]; ok { + t.Fatal("Was expecting data.attributes to be omitted") + } +} + +func TestMarshalStructNullStringID_Zero_Valid(t *testing.T) { + pi := &NullStringID{ + ID: sql.NullString{Valid: true}, + Periodic: sql.NullBool{Valid: true}, + Name: sql.NullString{Valid: true}, + Value: sql.NullFloat64{Valid: true}, + Decimal: sql.NullInt32{Valid: true}, + Fractional: sql.NullInt64{Valid: true}, + ComputedAt: sql.NullTime{Valid: true}, + } + + out := bytes.NewBuffer(nil) + if err := MarshalPayload(out, pi); err != nil { + t.Fatal(err) + } + + var jsonData map[string]interface{} + + if err := json.Unmarshal(out.Bytes(), &jsonData); err != nil { + t.Fatal(err) + } + data := jsonData["data"].(map[string]interface{}) + + if _, ok := data["id"]; ok { + t.Fatal("Was expecting data.id to be omitted") + } + + if data["type"] != "null-string-id" { + t.Fatalf("Error marshalling type") + } + + attributes := data["attributes"].(map[string]interface{}) + + if attributes["periodic"] != false { + t.Fatalf("Error marshalling to sql.NullBool: %v", attributes["periodic"]) + } + + if attributes["name"] != "" { + t.Fatal("Error marshalling to sql.NullString") + } + + if attributes["value"] != 0.0 { + t.Fatal("Error marshalling to sql.NullFloat64") + } + + if attributes["decimal"] != 0.0 { + t.Fatalf("Error marshalling to sql.NullInt32") + } + + if attributes["fractional"] != 0.0 { + t.Fatalf("Error marshalling to sql.NullInt64") + } + + if _, ok := attributes["computed_at"]; ok { + t.Fatal("Was expecting data.attributes.computed_at to be omitted") + } +} + +func TestMarshalStructNullStringID(t *testing.T) { + pi := &NullStringID{ + ID: sql.NullString{ + String: "314", + Valid: true, + }, + Periodic: sql.NullBool{ + Bool: false, + Valid: true, + }, + Name: sql.NullString{ + String: "Pi", + Valid: true, + }, + Value: sql.NullFloat64{ + Float64: 3.1415926535897932, + Valid: true, + }, + Decimal: sql.NullInt32{ + Int32: 3, + Valid: true, + }, + Fractional: sql.NullInt64{ + Int64: 1415926535897932, + Valid: true, + }, + ComputedAt: sql.NullTime{ + Time: time.Date(2021, 3, 14, 15, 0, 0, 0, time.UTC), + Valid: true, + }, + } + + out := bytes.NewBuffer(nil) + if err := MarshalPayload(out, pi); err != nil { + t.Fatal(err) + } + + resp := new(OnePayload) + if err := json.Unmarshal(out.Bytes(), resp); err != nil { + t.Fatal(err) + } + + data := resp.Data + + if data.ID != "314" { + t.Fatal("Error marshalling id") + } + + if data.Type != "null-string-id" { + t.Fatal("Error marshalling type") + } + + if data.Attributes["periodic"] != false { + t.Fatal("Error marshalling to sql.NullBool") + } + + if data.Attributes["name"] != "Pi" { + t.Fatal("Error marshalling to sql.NullString") + } + + if data.Attributes["value"] != 3.1415926535897932 { + t.Fatal("Error marshalling to sql.NullFloat64") + } + + if data.Attributes["decimal"] != 3.0 { + t.Fatalf("Error marshalling to sql.NullInt32") + } + + if data.Attributes["computed_at"] != "2021-03-14T15:00:00Z" { + t.Fatalf("Error marshalling to sql.NullTime") + } +} + +func TestMarshalStructNullInt32ID(t *testing.T) { + i32 := &NullInt32ID{ + ID: sql.NullInt32{ + Int32: 123, + Valid: true, + }, + } + + out := bytes.NewBuffer(nil) + if err := MarshalPayload(out, i32); err != nil { + t.Fatal(err) + } + + resp := new(OnePayload) + if err := json.Unmarshal(out.Bytes(), resp); err != nil { + t.Fatal(err) + } + + data := resp.Data + + if data.ID != "123" { + t.Fatalf("Error marshalling id") + } + + if data.Type != "null-int32-id" { + t.Fatal("Error marshalling type") + } +} + +func TestMarshalStructNullInt64ID(t *testing.T) { + i32 := &NullInt64ID{ + ID: sql.NullInt64{ + Int64: 456, + Valid: true, + }, + } + + out := bytes.NewBuffer(nil) + if err := MarshalPayload(out, i32); err != nil { + t.Fatal(err) + } + + resp := new(OnePayload) + if err := json.Unmarshal(out.Bytes(), resp); err != nil { + t.Fatal(err) + } + + data := resp.Data + + if data.ID != "456" { + t.Fatalf("Error marshalling id") + } + + if data.Type != "null-int64-id" { + t.Fatal("Error marshalling type") + } +} + +func TestMarshalStructNullFloat64ID(t *testing.T) { + i32 := &NullFloat64ID{ + ID: sql.NullFloat64{ + Float64: 12345678.12345678, + Valid: true, + }, + } + + out := bytes.NewBuffer(nil) + if err := MarshalPayload(out, i32); err != nil { + t.Fatal(err) + } + + resp := new(OnePayload) + if err := json.Unmarshal(out.Bytes(), resp); err != nil { + t.Fatal(err) + } + + data := resp.Data + + if data.ID != "12345678.12345678" { + t.Fatal("Error marshalling id") + } + + if data.Type != "null-float64-id" { + t.Fatal("Error marshalling type") + } +} + +func TestMarshalStructPi_Zero_Invalid(t *testing.T) { + pi := new(Float) + + out := bytes.NewBuffer(nil) + if err := MarshalPayload(out, pi); err != nil { + t.Fatal(err) + } + + var jsonData map[string]interface{} + + if err := json.Unmarshal(out.Bytes(), &jsonData); err != nil { + t.Fatal(err) + } + data := jsonData["data"].(map[string]interface{}) + + if data["type"] != "float" { + t.Fatalf("Error marshalling type") + } + + if _, ok := data["attributes"]; !ok { + t.Fatal("Was expecting data.attributes to NOT be omitted") + } +} + +func TestMarshalStructPi_Zero_Valid(t *testing.T) { + pi := &Float{ + ID: sql.NullString{Valid: true}, + Periodic: sql.NullBool{Valid: true}, + Name: sql.NullString{Valid: true}, + Value: sql.NullFloat64{Valid: true}, + Decimal: sql.NullInt32{Valid: true}, + Fractional: sql.NullInt64{Valid: true}, + ComputedAt: sql.NullTime{Valid: true}, + } + + out := bytes.NewBuffer(nil) + if err := MarshalPayload(out, pi); err != nil { + t.Fatal(err) + } + + var jsonData map[string]interface{} + + if err := json.Unmarshal(out.Bytes(), &jsonData); err != nil { + t.Fatal(err) + } + data := jsonData["data"].(map[string]interface{}) + + if _, ok := data["id"]; ok { + t.Fatal("Was expecting data.id to be omitted") + } + + if data["type"] != "float" { + t.Fatalf("Error marshalling type") + } + + attributes := data["attributes"].(map[string]interface{}) + + if attributes["periodic"] != false { + t.Fatalf("Error marshalling to sql.NullBool: %v", attributes["periodic"]) + } + + if attributes["name"] != "" { + t.Fatal("Error marshalling to sql.NullString") + } + + if attributes["value"] != 0.0 { + t.Fatal("Error marshalling to sql.NullFloat64") + } + + if attributes["decimal"] != 0.0 { + t.Fatalf("Error marshalling to sql.NullInt32") + } + + if attributes["fractional"] != 0.0 { + t.Fatalf("Error marshalling to sql.NullInt64") + } + + if _, ok := attributes["computed_at"]; ok { + t.Fatal("Was expecting data.attributes.computed_at to be omitted") + } +} + +func TestMarshalStructPi(t *testing.T) { + pi := &Float{ + ID: sql.NullString{ + String: "314", + Valid: true, + }, + Periodic: sql.NullBool{ + Bool: false, + Valid: true, + }, + Name: sql.NullString{ + String: "Float", + Valid: true, + }, + Value: sql.NullFloat64{ + Float64: 3.1415926535897932, + Valid: true, + }, + Decimal: sql.NullInt32{ + Int32: 3, + Valid: true, + }, + Fractional: sql.NullInt64{ + Int64: 1415926535897932, + Valid: true, + }, + ComputedAt: sql.NullTime{ + Time: time.Date(2021, 3, 14, 15, 0, 0, 0, time.UTC), + Valid: true, + }, + } + + out := bytes.NewBuffer(nil) + if err := MarshalPayload(out, pi); err != nil { + t.Fatal(err) + } + + resp := new(OnePayload) + if err := json.Unmarshal(out.Bytes(), resp); err != nil { + t.Fatal(err) + } + + data := resp.Data + + if data.ID != "314" { + t.Fatal("Error marshalling id") + } + + if data.Type != "float" { + t.Fatal("Error marshalling type") + } + + if data.Attributes["periodic"] != false { + t.Fatal("Error marshalling to sql.NullBool") + } + + if data.Attributes["name"] != "Float" { + t.Fatal("Error marshalling to sql.NullString") + } + + if data.Attributes["value"] != 3.1415926535897932 { + t.Fatal("Error marshalling to sql.NullFloat64") + } + + if data.Attributes["decimal"] != 3.0 { + t.Fatalf("Error marshalling to sql.NullInt32") + } +} + func TestSupportsLinkable(t *testing.T) { testModel := &Blog{ ID: 5,