diff --git a/paginator/paginator.go b/paginator/paginator.go index c55ec37..5c32fab 100644 --- a/paginator/paginator.go +++ b/paginator/paginator.go @@ -184,7 +184,12 @@ func (p *Paginator) decodeCursor(dest interface{}) (result []interface{}, err er } // replace null values for i := range result { - if isNil(result[i]) { + value := result[i] + // for custom types, evaluate isNil on the underlying value + if ct, ok := result[i].(cursor.CustomType); ok && p.rules[i].CustomType != nil { + value, err = ct.GetCustomTypeValue(p.rules[i].CustomType.Meta) + } + if isNil(value) { result[i] = p.rules[i].NULLReplacement } } diff --git a/paginator/paginator_paginate_test.go b/paginator/paginator_paginate_test.go index 543393f..a3216d5 100644 --- a/paginator/paginator_paginate_test.go +++ b/paginator/paginator_paginate_test.go @@ -669,6 +669,87 @@ func (s *paginatorSuite) TestPaginateCustomTypeString() { s.assertIDs(p1, 9, 8, 7) } +func (s *paginatorSuite) TestPaginateCustomTypeNullable() { + s.givenOrders([]order{ + { + ID: 1, + Remark: ptrStr("1"), + NullableCustomData: NullString{ + String: "A", + Valid: true, + }, + }, + { + ID: 2, + Remark: ptrStr("2"), + NullableCustomData: NullString{ + String: "", + Valid: false, + }, + }, + { + ID: 3, + Remark: ptrStr("2"), + NullableCustomData: NullString{ + String: "B", + Valid: true, + }, + }, + }) + + text := "text" + cfg := &Config{ + Limit: 1, + Rules: []Rule{ + { + Key: "Remark", + Order: ASC, + }, + { + Key: "NullableCustomData", + Order: ASC, + SQLType: &text, + CustomType: &CustomType{ + Type: reflect.TypeOf(NullString{}), + }, + NULLReplacement: "", + }, + }, + } + + var p1 []order + _, c, _ := New(cfg).Paginate(s.db, &p1) + s.Len(p1, 1) + s.assertForwardOnly(c) + s.assertIDs(p1, 1) + + var p2 []order + _, c, _ = New(cfg, WithAfter(*c.After)).Paginate(s.db, &p2) + s.Len(p2, 1) + s.assertBothDirections(c) + s.assertIDs(p2, 2) + + var p3 []order + _, c, _ = New(cfg, WithAfter(*c.After)).Paginate(s.db, &p3) + s.Len(p3, 1) + s.assertBackwardOnly(c) + s.assertIDs(p3, 3) + + // back to page 2 + var p4 []order + _, c, _ = New(cfg, WithBefore(*c.Before)).Paginate(s.db, &p4) + s.Len(p4, 1) + s.assertBothDirections(c) + s.assertIDs(p4, 2) + + // back to page 1 + var p5 []order + _, c, _ = New(cfg, WithBefore(*c.Before)).Paginate(s.db, &p5) + s.Len(p5, 1) + s.assertForwardOnly(c) + s.assertIDs(p5, 1) +} + /* compatibility */ func (s *paginatorSuite) TestPaginateConsistencyBetweenBuilderAndKeyOptions() { diff --git a/paginator/paginator_test.go b/paginator/paginator_test.go index 1548b58..3945753 100644 --- a/paginator/paginator_test.go +++ b/paginator/paginator_test.go @@ -1,6 +1,8 @@ package paginator import ( + "bytes" + "database/sql" "database/sql/driver" "encoding/json" "errors" @@ -24,10 +26,11 @@ func TestPaginator(t *testing.T) { /* models */ type order struct { - ID int `gorm:"primaryKey"` - Remark *string `gorm:"type:varchar(30)"` - CreatedAt time.Time `gorm:"type:timestamp;not null"` - Data JSON `gorm:"type:jsonb"` + ID int `gorm:"primaryKey"` + Remark *string `gorm:"type:varchar(30)"` + CreatedAt time.Time `gorm:"type:timestamp;not null"` + Data JSON `gorm:"type:jsonb"` + NullableCustomData NullString `gorm:"type:varchar(30)"` } type item struct { @@ -80,6 +83,55 @@ func (j JSON) GetCustomTypeValue(meta interface{}) (interface{}, error) { return i, nil } +/* NullString type used for testing custom types with nullable values */ + +type NullString sql.NullString + +func (ns NullString) MarshalJSON() ([]byte, error) { + if !ns.Valid { + return []byte("null"), nil + } + return json.Marshal(ns.String) +} + +func (ns *NullString) UnmarshalJSON(b []byte) error { + isNull := bytes.Equal(b, []byte("null")) + ns.Valid = !isNull + + if isNull { + ns.String = "null" + return nil + } + return json.Unmarshal(b, &ns.String) +} + +func (ns NullString) Value() (driver.Value, error) { + if !ns.Valid { + return nil, nil + } + return ns.String, nil +} + +func (ns *NullString) Scan(value interface{}) error { + if value == nil { + ns.String = "" + ns.Valid = false + } else if strValue, ok := value.(string); ok { + ns.Valid = true + ns.String = strValue + } else { + return errors.New("unsupported type") + } + return nil +} + +func (ns NullString) GetCustomTypeValue(meta interface{}) (interface{}, error) { + if ns.Valid { + return ns.String, nil + } + return nil, nil +} + /* paginator suite */ type paginatorSuite struct {