diff --git a/README.md b/README.md index aebd42f..b5f3737 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ A paginator doing cursor-based pagination based on [GORM](https://github.com/go- - GORM `column` tag supported. - Error handling enhancement. - Exporting `cursor` module for advanced usage. +- Implement custom codec for cursor encoding/decoding. ## Installation @@ -100,6 +101,34 @@ We first need to create a `paginator.Paginator` for `User`, here are some useful return p } ``` + +If using the directional cursor functionality provided by `ParseDirectionAndCursor` & `SerialiseDirectionAndCursor`, then one can pass the directional cursor to the `SetCursor` function instead: + +```go +func CreateUserPaginator( + directionalCursor string, + order *paginator.Order, + limit *int, +) *paginator.Paginator { + p := paginator.New( + &paginator.Config{ + Keys: []string{"ID", "JoinedAt"}, + Limit: 10, + Order: paginator.ASC, + }, + ) + if order != nil { + p.SetOrder(*order) + } + if limit != nil { + p.SetLimit(*limit) + } + if directionalCursor != "" { + p.SetCursor(directionalCursor) + } + return p +} +``` 3. Configure by `paginator.Rule` for fine grained setting for each key: @@ -131,6 +160,91 @@ We first need to create a `paginator.Paginator` for `User`, here are some useful } ``` +4. By default the library encodes cursors with `base64`. If a custom encoding/decoding implementation is required, this can be implemented and passed as part of the configuration: + + +First implement your custom codec such that it conforms to the `CursorCodec` interface: + + +```go +type CursorCodec interface { + // Encode encodes model fields into cursor + Encode( + fields []pc.EncoderField, + model interface{}, + ) (string, error) + + // Decode decodes cursor into model fields + Decode( + fields []pc.DecoderField, + cursor string, + model interface{}, + ) ([]interface{}, error) + + // ParseDirectionAndCursor parses the direction and plain cursor. The resulting cursor can be fed to Decode() + ParseDirectionAndCursor( + cursor string, + ) (direction, plainCursor string, err error) + + // SerialiseDirectionAndCursor takes a direction and the plain cursor result from Encode() + // and serialises it into a directional cursor + SerialiseDirectionAndCursor( + direction string, + plainCursor string, + ) (cursor string, err error) +} + +type customCodec struct {} + +func (*CustomCodec) Encode(fields []pc.EncoderField, model interface{}) (string, error) { + ... +} + +func (*CustomCodec) Decode(fields []pc.DecoderField, cursor string, model interface{}) ([]interface{}, error) { + ... +} + +func (*CustomCodec) ParseDirectionAndCursor(fields []pc.EncoderField, model interface{}) (string, error) { + ... +} + +func (*CustomCodec) SerialiseDirectionAndCursor(fields []pc.DecoderField, cursor string, model interface{}) ([]interface{}, error) { + ... +} +``` + +Then pass an instance of your codec during initialisation: + +```go +func CreateUserPaginator(/* ... */) { + codec := &customCodec{} + + p := paginator.New( + &paginator.Config{ + Rules: []paginator.Rule{ + { + Key: "ID", + }, + { + Key: "JoinedAt", + Order: paginator.DESC, + SQLRepr: "users.created_at", + NULLReplacement: "1970-01-01", + }, + }, + Limit: 10, + // supply a custom implementation for the encoder/decoder + CursorCodec: codec, + // Order here will apply to keys without order specified. + // In this example paginator will order by "ID" ASC, "JoinedAt" DESC. + Order: paginator.ASC, + }, + ) + // ... + return p +} +``` + After knowing how to setup the paginator, we can start paginating `User` with GORM: ```go diff --git a/cursor/cursor.go b/cursor/cursor.go index 8072a85..923ea3c 100644 --- a/cursor/cursor.go +++ b/cursor/cursor.go @@ -5,3 +5,6 @@ type Cursor struct { After *string `json:"after" query:"after"` Before *string `json:"before" query:"before"` } + +var BeforePrefix = []byte("<") +var AfterPrefix = []byte(">") diff --git a/cursor/decoder.go b/cursor/decoder.go index 32bf88c..db5de4e 100644 --- a/cursor/decoder.go +++ b/cursor/decoder.go @@ -62,6 +62,28 @@ func (d *Decoder) Decode(cursor string, model interface{}) (fields []interface{} return } +// ParseDirectionAndCursor parses the direction and plain cursor string. The cursor string will then be fed to Decode() +func (d *Decoder) ParseDirectionAndCursor(cursor string) (direction, plainCursor string, err error) { + b, err := base64.StdEncoding.DecodeString(cursor) + if err != nil { + return "", "", ErrInvalidCursor + } + + var cursorBytes []byte + + if bytes.HasPrefix(b, BeforePrefix) { + direction = "before" + cursorBytes = bytes.TrimPrefix(b, BeforePrefix) + } else if bytes.HasPrefix(b, AfterPrefix) { + direction = "after" + cursorBytes = bytes.TrimPrefix(b, AfterPrefix) + } + + plainCursor = base64.StdEncoding.EncodeToString(cursorBytes) + + return +} + // DecodeStruct decodes cursor into model, model must be a pointer to struct or it will panic. func (d *Decoder) DecodeStruct(cursor string, model interface{}) (err error) { fields, err := d.Decode(cursor, model) diff --git a/cursor/decoder_test.go b/cursor/decoder_test.go index 4a1368b..dada959 100644 --- a/cursor/decoder_test.go +++ b/cursor/decoder_test.go @@ -65,3 +65,19 @@ func (s *decoderSuite) TestDecodeStructInvalidCursor() { err := NewDecoder([]DecoderField{{Key: "Value"}}).DecodeStruct("123", struct{ Value string }{}) s.Equal(ErrInvalidCursor, err) } + +func (s *decoderSuite) TestParseDirectionAndCursor() { + e := NewEncoder([]EncoderField{{Key: "Slice"}}) + c, err := e.Encode(struct{ Slice []string }{Slice: []string{"value"}}) + s.Nil(err) + + c, err = e.SerialiseDirectionAndCursor("after", c) + s.Nil(err) + + dec := NewDecoder([]DecoderField{{Key: "Slice"}}) + direction, plainCursor, err := dec.ParseDirectionAndCursor(c) + + s.Nil(err) + s.Equal(direction, "after") + s.NotEmpty(plainCursor) +} diff --git a/cursor/encoder.go b/cursor/encoder.go index 03cac72..94421ed 100644 --- a/cursor/encoder.go +++ b/cursor/encoder.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "encoding/json" "reflect" + "strings" "github.com/pilagod/gorm-cursor-paginator/v2/internal/util" ) @@ -34,6 +35,30 @@ func (e *Encoder) Encode(model interface{}) (string, error) { return base64.StdEncoding.EncodeToString(b), nil } +// SerialiseDirectionAndCursor serialises the direction and plain cursor string. +// This should be called with the result of Encode() +func (e *Encoder) SerialiseDirectionAndCursor(direction, plainCursor string) (string, error) { + b, err := base64.StdEncoding.DecodeString(plainCursor) + if err != nil { + return "", ErrInvalidCursor + } + + var directionPrefix []byte + + switch strings.ToLower(direction) { + case "after": + directionPrefix = AfterPrefix + case "before": + directionPrefix = BeforePrefix + default: + return "", ErrInvalidDirection + } + + cursorBytes := append(directionPrefix, b...) + + return base64.StdEncoding.EncodeToString(cursorBytes), nil +} + func (e *Encoder) marshalJSON(model interface{}) ([]byte, error) { rv := util.ReflectValue(model) fields := make([]interface{}, len(e.fields)) diff --git a/cursor/encoder_test.go b/cursor/encoder_test.go index 60b7696..1509267 100644 --- a/cursor/encoder_test.go +++ b/cursor/encoder_test.go @@ -57,3 +57,12 @@ func (s *encoderSuite) TestSliceValue() { _, err := e.Encode(struct{ Slice []string }{Slice: []string{"value"}}) s.Nil(err) } + +func (s *encoderSuite) TestSerialiseDirectionAndCursor() { + e := NewEncoder([]EncoderField{{Key: "Slice"}}) + c, err := e.Encode(struct{ Slice []string }{Slice: []string{"value"}}) + s.Nil(err) + + _, err = e.SerialiseDirectionAndCursor("after", c) + s.Nil(err) +} diff --git a/cursor/error.go b/cursor/error.go index 7474c45..d042dc0 100644 --- a/cursor/error.go +++ b/cursor/error.go @@ -4,6 +4,7 @@ import "errors" // Errors for encoder var ( - ErrInvalidCursor = errors.New("invalid cursor") - ErrInvalidModel = errors.New("invalid model") + ErrInvalidCursor = errors.New("invalid cursor") + ErrInvalidModel = errors.New("invalid model") + ErrInvalidDirection = errors.New("invalid direction") ) diff --git a/paginator/cursor.go b/paginator/cursor.go index ebb99f3..e261362 100644 --- a/paginator/cursor.go +++ b/paginator/cursor.go @@ -1,6 +1,64 @@ package paginator -import "github.com/pilagod/gorm-cursor-paginator/v2/cursor" +import ( + pc "github.com/pilagod/gorm-cursor-paginator/v2/cursor" +) // Cursor re-exports cursor.Cursor -type Cursor = cursor.Cursor +type Cursor = pc.Cursor + +// CursorCodec encodes/decodes cursor +type CursorCodec interface { + // Encode encodes model fields into cursor + Encode( + fields []pc.EncoderField, + model interface{}, + ) (string, error) + + // Decode decodes cursor into model fields + Decode( + fields []pc.DecoderField, + cursor string, + model interface{}, + ) ([]interface{}, error) + + // ParseDirectionAndCursor parses the direction and plain cursor. The resulting cursor can be fed to Decode() + ParseDirectionAndCursor( + cursor string, + ) (direction, plainCursor string, err error) + + // SerialiseDirectionAndCursor takes a direction and the plain cursor result from Encode() + // and serialises it into a directional cursor + SerialiseDirectionAndCursor( + direction string, + plainCursor string, + ) (cursor string, err error) +} + +// JSONCursorCodec encodes/decodes cursor in JSON format +type JSONCursorCodec struct{} + +// Encode encodes model fields into JSON format cursor +func (*JSONCursorCodec) Encode( + fields []pc.EncoderField, + model interface{}, +) (string, error) { + return pc.NewEncoder(fields).Encode(model) +} + +// Decode decodes JSON format cursor into model fields +func (*JSONCursorCodec) Decode( + fields []pc.DecoderField, + cursor string, + model interface{}, +) ([]interface{}, error) { + return pc.NewDecoder(fields).Decode(cursor, model) +} + +func (*JSONCursorCodec) ParseDirectionAndCursor(cursor string) (direction, plainCursor string, err error) { + return pc.NewDecoder(nil).ParseDirectionAndCursor(cursor) +} + +func (*JSONCursorCodec) SerialiseDirectionAndCursor(direction, plainCursor string) (string, error) { + return pc.NewEncoder(nil).SerialiseDirectionAndCursor(direction, plainCursor) +} diff --git a/paginator/option.go b/paginator/option.go index 47c3443..bba2a88 100644 --- a/paginator/option.go +++ b/paginator/option.go @@ -12,6 +12,7 @@ var defaultConfig = Config{ Limit: 10, Order: DESC, AllowTupleCmp: FALSE, + CursorCodec: &JSONCursorCodec{}, } // Option for paginator @@ -28,6 +29,7 @@ type Config struct { After string Before string AllowTupleCmp Flag + CursorCodec CursorCodec } // Apply applies config to paginator @@ -54,6 +56,9 @@ func (c *Config) Apply(p *Paginator) { if c.AllowTupleCmp != "" { p.SetAllowTupleCmp(c.AllowTupleCmp == TRUE) } + if c.CursorCodec != nil { + p.SetCursorCodec(c.CursorCodec) + } } // WithRules configures rules for paginator @@ -104,3 +109,10 @@ func WithAllowTupleCmp(flag Flag) Option { AllowTupleCmp: flag, } } + +// WithCursorCodec configures custom cursor codec +func WithCursorCodec(codec CursorCodec) Option { + return &Config{ + CursorCodec: codec, + } +} diff --git a/paginator/paginator.go b/paginator/paginator.go index 2eef3c0..3ee5153 100644 --- a/paginator/paginator.go +++ b/paginator/paginator.go @@ -27,6 +27,7 @@ type Paginator struct { limit int order Order allowTupleCmp bool + cursorCodec CursorCodec } // SetRules sets paging rules @@ -66,11 +67,31 @@ func (p *Paginator) SetBeforeCursor(beforeCursor string) { p.cursor.Before = &beforeCursor } +func (p *Paginator) SetCursor(cursor string) error { + direction, cursorString, err := p.cursorCodec.ParseDirectionAndCursor(cursor) + if err != nil { + return err + } + + if direction == "after" { + p.SetAfterCursor(cursorString) + } else if direction == "before" { + p.SetBeforeCursor(cursorString) + } + + return nil +} + // SetAllowTupleCmp enables or disables tuple comparison optimization func (p *Paginator) SetAllowTupleCmp(allow bool) { p.allowTupleCmp = allow } +// SetCursorCodec sets custom cursor codec +func (p *Paginator) SetCursorCodec(codec CursorCodec) { + p.cursorCodec = codec +} + // Paginate paginates data func (p *Paginator) Paginate(db *gorm.DB, dest interface{}) (result *gorm.DB, c Cursor, err error) { if err = p.validate(db, dest); err != nil { @@ -104,16 +125,6 @@ func (p *Paginator) Paginate(db *gorm.DB, dest interface{}) (result *gorm.DB, c return } -// GetCursorEncoder returns cursor encoder based on paginator rules -func (p *Paginator) GetCursorEncoder() *cursor.Encoder { - return cursor.NewEncoder(p.getEncoderFields()) -} - -// GetCursorDecoder returns cursor decoder based on paginator rules -func (p *Paginator) GetCursorDecoder() *cursor.Decoder { - return cursor.NewDecoder(p.getDecoderFields()) -} - /* private */ func (p *Paginator) validate(db *gorm.DB, dest interface{}) (err error) { @@ -189,14 +200,12 @@ func isNil(i interface{}) bool { } func (p *Paginator) decodeCursor(dest interface{}) (result []interface{}, err error) { - decoder := p.GetCursorDecoder() - if p.isForward() { - if result, err = decoder.Decode(*p.cursor.After, dest); err != nil { + if result, err = p.cursorCodec.Decode(p.getDecoderFields(), *p.cursor.After, dest); err != nil { err = ErrInvalidCursor } } else if p.isBackward() { - if result, err = decoder.Decode(*p.cursor.Before, dest); err != nil { + if result, err = p.cursorCodec.Decode(p.getDecoderFields(), *p.cursor.Before, dest); err != nil { err = ErrInvalidCursor } } @@ -312,21 +321,32 @@ func (p *Paginator) buildCursorSQLQueryArgs(fields []interface{}) (args []interf } func (p *Paginator) encodeCursor(elems reflect.Value, hasMore bool) (result Cursor, err error) { - encoder := p.GetCursorEncoder() // encode after cursor if p.isBackward() || hasMore { - c, err := encoder.Encode(elems.Index(elems.Len() - 1)) + c, err := p.cursorCodec.Encode(p.getEncoderFields(), elems.Index(elems.Len()-1)) + if err != nil { + return Cursor{}, err + } + + c, err = p.cursorCodec.SerialiseDirectionAndCursor("after", c) if err != nil { return Cursor{}, err } + result.After = &c } // encode before cursor if p.isForward() || (hasMore && p.isBackward()) { - c, err := encoder.Encode(elems.Index(0)) + c, err := p.cursorCodec.Encode(p.getEncoderFields(), elems.Index(0)) + if err != nil { + return Cursor{}, err + } + + c, err = p.cursorCodec.SerialiseDirectionAndCursor("before", c) if err != nil { return Cursor{}, err } + result.Before = &c } return diff --git a/paginator/paginator_paginate_test.go b/paginator/paginator_paginate_test.go index 156a312..bdec538 100644 --- a/paginator/paginator_paginate_test.go +++ b/paginator/paginator_paginate_test.go @@ -1,11 +1,17 @@ package paginator import ( + "fmt" "reflect" + "strconv" + "strings" "time" - "github.com/pilagod/pointer" "gorm.io/gorm" + + pc "github.com/pilagod/gorm-cursor-paginator/v2/cursor" + "github.com/pilagod/gorm-cursor-paginator/v2/internal/util" + "github.com/pilagod/pointer" ) func (s *paginatorSuite) TestPaginateDefaultOptions() { @@ -644,6 +650,8 @@ func (s *paginatorSuite) TestPaginateReplaceNULL() { s.assertForwardOnly(c) } +/* Custom Type */ + func (s *paginatorSuite) TestPaginateCustomTypeInt() { s.givenOrders(9) @@ -827,6 +835,102 @@ func (s *paginatorSuite) TestPaginateCustomTypeNullable() { s.assertIDs(p5, 1) } +/* Custom Cursor Codec */ + +type idCursorCodec struct{} + +func (c *idCursorCodec) Encode(fields []pc.EncoderField, model interface{}) (string, error) { + if len(fields) != 1 || fields[0].Key != "ID" { + return "", fmt.Errorf("ID field is required") + } + id := util.ReflectValue(model).FieldByName("ID").Interface() + return fmt.Sprintf("%d", id), nil +} + +func (c *idCursorCodec) Decode(fields []pc.DecoderField, cursor string, model interface{}) ([]interface{}, error) { + if len(fields) != 1 || fields[0].Key != "ID" { + return nil, fmt.Errorf("ID field is required") + } + if _, ok := util.ReflectType(model).FieldByName("ID"); !ok { + return nil, fmt.Errorf("ID field is required on model") + } + id, err := strconv.Atoi(cursor) + if err != nil { + return nil, err + } + return []interface{}{id}, nil +} + +func (*idCursorCodec) ParseDirectionAndCursor(cursor string) (direction, plainCursor string, err error) { + if strings.HasPrefix(cursor, ">") { + direction = "after" + plainCursor = cursor[1:] + } else if strings.HasPrefix(cursor, "<") { + direction = "before" + plainCursor = cursor[1:] + } else { + err = ErrInvalidCursor + } + return +} + +func (*idCursorCodec) SerialiseDirectionAndCursor(direction, plainCursor string) (string, error) { + if direction == "after" { + return fmt.Sprintf("%s%s", ">", plainCursor), nil + } else if direction == "before" { + return fmt.Sprintf("%s%s", "<", plainCursor), nil + } else { + return "", ErrInvalidCursor + } +} + +func (s *paginatorSuite) TestPaginateCustomCodec() { + s.givenOrders([]order{ + { + ID: 1, + }, + { + ID: 2, + }, + { + ID: 3, + }, + }) + + cfg := Config{ + Limit: 2, + } + codec := &idCursorCodec{} + + var p1 []order + _, c, _ := New( + &cfg, + WithCursorCodec(codec), + ).Paginate(s.db, &p1) + s.Len(p1, 2) + s.assertForwardOnly(c) + s.assertIDs(p1, 3, 2) + + var p2 []order + _, c, _ = New( + &cfg, + WithCursorCodec(codec), + WithAfter(*c.After), + ).Paginate(s.db, &p2) + s.Len(p2, 1) + s.assertBackwardOnly(c) + s.assertIDs(p2, 1) + + var p3 []order + _, c, _ = New( + &cfg, + WithCursorCodec(codec), + WithBefore(*c.Before), + ).Paginate(s.db, &p3) + s.Len(p3, 2) + s.assertIDs(p3, 3, 2) +} + /* compatibility */ func (s *paginatorSuite) TestPaginateConsistencyBetweenBuilderAndKeyOptions() {