Skip to content

Commit

Permalink
full table and pk algorithm is optional
Browse files Browse the repository at this point in the history
  • Loading branch information
hyperphoton committed Jan 18, 2022
1 parent 00accc1 commit 1512490
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 64 deletions.
4 changes: 0 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ For config detail info, see [Godoc](https://pkg.go.dev/github.com/longbridge/gor
```go
middleware := sharding.Register(map[string]sharding.Resolver{
"orders": {
EnableFullTable: true,
ShardingColumn: "user_id",
ShardingAlgorithm: func(value interface{}) (suffix string, err error) {
switch user_id := value.(type) {
Expand All @@ -51,9 +50,6 @@ middleware := sharding.Register(map[string]sharding.Resolver{
return "", errors.New("invalid user_id")
}
},
ShardingAlgorithmByPrimaryKey: func(id int64) (suffix string) {
return fmt.Sprintf("_%02d", keygen.TableIdx(id))
},
PrimaryKeyGenerate: func(tableIdx int64) int64 {
keygen.Snowflake()
keygen.UUID(),
Expand Down
4 changes: 0 additions & 4 deletions README.zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ go get -u github.com/longbridgeapp/gorm-sharding
```go
middleware := sharding.Register(map[string]sharding.Resolver{
"orders": {
EnableFullTable: true,
ShardingColumn: "user_id",
ShardingAlgorithm: func(value interface{}) (suffix string, err error) {
switch user_id := value.(type) {
Expand All @@ -43,9 +42,6 @@ middleware := sharding.Register(map[string]sharding.Resolver{
return "", errors.New("invalid user_id")
}
},
ShardingAlgorithmByPrimaryKey: func(id int64) (suffix string) {
return fmt.Sprintf("_%02d", keygen.TableIdx(id))
},
PrimaryKeyGenerate: func(tableIdx int64) int64 {
return keygen.Next(tableIdx)
}
Expand Down
99 changes: 43 additions & 56 deletions sharding.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package sharding

import (
"errors"
"fmt"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -45,9 +46,7 @@ type Resolver struct {
ShardingAlgorithm func(columnValue interface{}) (suffix string, err error)

// ShardingAlgorithmByPrimaryKey specifies a function to generate the sharding
// table's suffix by the primary key.
// Note, when the record contains an id field,
// ShardingAlgorithmByPrimaryKey will preferred than ShardingAlgorithm.
// table's suffix by the primary key. Used when no sharding key specified.
// For example, this function use the Keygen library to generate the suffix.
//
// func(id int64) (suffix string) {
Expand Down Expand Up @@ -148,40 +147,33 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery,
}

var value interface{}
var isID bool
var id int64
var keyFind bool
if isInsert {
value, isID, err = s.insertValue(r.ShardingColumn, insertNames, insertValues, args...)
value, id, keyFind, err = s.insertValue(r.ShardingColumn, insertNames, insertValues, args...)
if err != nil {
return
}
} else {
value, isID, err = s.nonInsertValue(r.ShardingColumn, condition, args...)
value, id, keyFind, err = s.nonInsertValue(r.ShardingColumn, condition, args...)
if err != nil {
return
}
}

var suffix string

if isID {
if id, ok := value.(int64); ok {
suffix = r.ShardingAlgorithmByPrimaryKey(id)
} else if idUint, ok := value.(uint64); ok {
suffix = r.ShardingAlgorithmByPrimaryKey(int64(idUint))
} else if idStr, ok := value.(string); ok {
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
return ftQuery, stQuery, tableName, ErrInvalidID
}
suffix = r.ShardingAlgorithmByPrimaryKey(id)
} else {
return ftQuery, stQuery, tableName, ErrInvalidID
}
} else {
if keyFind {
suffix, err = r.ShardingAlgorithm(value)
if err != nil {
return
}
} else {
if r.ShardingAlgorithmByPrimaryKey == nil {
err = fmt.Errorf("there is not sharding key and ShardingAlgorithmByPrimaryKey is not configured")
return
}
suffix = r.ShardingAlgorithmByPrimaryKey(id)
}

newTable := &sqlparser.TableName{Name: &sqlparser.Ident{Name: tableName + suffix}}
Expand Down Expand Up @@ -232,57 +224,50 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery,
return
}

func (s *Sharding) insertValue(key string, names []*sqlparser.Ident, exprs []sqlparser.Expr, args ...interface{}) (value interface{}, isID bool, err error) {
bind := false
find := false

func (s *Sharding) insertValue(key string, names []*sqlparser.Ident, exprs []sqlparser.Expr, args ...interface{}) (value interface{}, id int64, keyFind bool, err error) {
if len(names) != len(exprs) {
return nil, false, errors.New("column names and expressions mismatch")
return nil, 0, keyFind, errors.New("column names and expressions mismatch")
}

for i, name := range names {
if name.Name == key {
switch expr := exprs[i].(type) {
case *sqlparser.BindExpr:
bind = true
value = expr.Name
value, err = getBindValue(expr.Name, args)
if err != nil {
return nil, 0, keyFind, err
}
case *sqlparser.StringLit:
value = expr.Value
case *sqlparser.NumberLit:
value = expr.Value
default:
return nil, false, sqlparser.ErrNotImplemented
return nil, 0, keyFind, sqlparser.ErrNotImplemented
}
find = true
keyFind = true
break
}
}
if !find {
return nil, false, ErrMissingShardingKey
}

if bind {
value, err = getBindValue(value, args)
if !keyFind {
return nil, 0, keyFind, ErrMissingShardingKey
}

return
}

func (s *Sharding) nonInsertValue(key string, condition sqlparser.Expr, args ...interface{}) (value interface{}, isID bool, err error) {
bind := false
find := false
func (s *Sharding) nonInsertValue(key string, condition sqlparser.Expr, args ...interface{}) (value interface{}, id int64, keyFind bool, err error) {

err = sqlparser.Walk(sqlparser.VisitFunc(func(node sqlparser.Node) error {
if n, ok := node.(*sqlparser.BinaryExpr); ok {
if x, ok := n.X.(*sqlparser.Ident); ok {
if x.Name == key && n.Op == sqlparser.EQ {
find = true
isID = false
bind = false
keyFind = true
switch expr := n.Y.(type) {
case *sqlparser.BindExpr:
bind = true
value = expr.Name
value, err = getBindValue(expr.Name, args)
if err != nil {
return err
}
case *sqlparser.StringLit:
value = expr.Value
case *sqlparser.NumberLit:
Expand All @@ -292,15 +277,21 @@ func (s *Sharding) nonInsertValue(key string, condition sqlparser.Expr, args ...
}
return nil
} else if x.Name == "id" && n.Op == sqlparser.EQ {
find = true
isID = true
bind = false
switch expr := n.Y.(type) {
case *sqlparser.BindExpr:
bind = true
value = expr.Name
v, err := getBindValue(expr.Name, args)
if err != nil {
return err
}
var ok bool
if id, ok = v.(int64); !ok {
return fmt.Errorf("ID should be int64 type")
}
case *sqlparser.NumberLit:
value = expr.Value
id, err = strconv.ParseInt(expr.Value, 10, 64)
if err != nil {
return err
}
default:
return ErrInvalidID
}
Expand All @@ -314,12 +305,8 @@ func (s *Sharding) nonInsertValue(key string, condition sqlparser.Expr, args ...
return
}

if !find {
return nil, false, ErrMissingShardingKey
}

if bind {
value, err = getBindValue(value, args)
if !keyFind && id == 0 {
return nil, 0, keyFind, ErrMissingShardingKey
}

return
Expand Down

0 comments on commit 1512490

Please sign in to comment.