diff --git a/README.md b/README.md index 7f1fff2..7277f79 100644 --- a/README.md +++ b/README.md @@ -41,7 +41,6 @@ For config detail info, see [Godoc](https://pkg.go.dev/github.com/longbridge/gor ```go middleware := sharding.Register(map[string]sharding.Resolver{ "orders": { - EnableFullTable: true, ShardingColumn: "user_id", ShardingAlgorithm: func(value interface{}) (suffix string, err error) { switch user_id := value.(type) { @@ -51,9 +50,6 @@ middleware := sharding.Register(map[string]sharding.Resolver{ return "", errors.New("invalid user_id") } }, - ShardingAlgorithmByPrimaryKey: func(id int64) (suffix string) { - return fmt.Sprintf("_%02d", keygen.TableIdx(id)) - }, PrimaryKeyGenerate: func(tableIdx int64) int64 { keygen.Snowflake() keygen.UUID(), diff --git a/README.zh-CN.md b/README.zh-CN.md index fbcd372..43c6321 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -33,7 +33,6 @@ go get -u github.com/longbridgeapp/gorm-sharding ```go middleware := sharding.Register(map[string]sharding.Resolver{ "orders": { - EnableFullTable: true, ShardingColumn: "user_id", ShardingAlgorithm: func(value interface{}) (suffix string, err error) { switch user_id := value.(type) { @@ -43,9 +42,6 @@ middleware := sharding.Register(map[string]sharding.Resolver{ return "", errors.New("invalid user_id") } }, - ShardingAlgorithmByPrimaryKey: func(id int64) (suffix string) { - return fmt.Sprintf("_%02d", keygen.TableIdx(id)) - }, PrimaryKeyGenerate: func(tableIdx int64) int64 { return keygen.Next(tableIdx) } diff --git a/sharding.go b/sharding.go index dcdaa8c..cdb8607 100644 --- a/sharding.go +++ b/sharding.go @@ -2,6 +2,7 @@ package sharding import ( "errors" + "fmt" "strconv" "strings" "sync" @@ -45,9 +46,7 @@ type Resolver struct { ShardingAlgorithm func(columnValue interface{}) (suffix string, err error) // ShardingAlgorithmByPrimaryKey specifies a function to generate the sharding - // table's suffix by the primary key. - // Note, when the record contains an id field, - // ShardingAlgorithmByPrimaryKey will preferred than ShardingAlgorithm. + // table's suffix by the primary key. Used when no sharding key specified. // For example, this function use the Keygen library to generate the suffix. // // func(id int64) (suffix string) { @@ -148,14 +147,15 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, } var value interface{} - var isID bool + var id int64 + var keyFind bool if isInsert { - value, isID, err = s.insertValue(r.ShardingColumn, insertNames, insertValues, args...) + value, id, keyFind, err = s.insertValue(r.ShardingColumn, insertNames, insertValues, args...) if err != nil { return } } else { - value, isID, err = s.nonInsertValue(r.ShardingColumn, condition, args...) + value, id, keyFind, err = s.nonInsertValue(r.ShardingColumn, condition, args...) if err != nil { return } @@ -163,25 +163,17 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, var suffix string - if isID { - if id, ok := value.(int64); ok { - suffix = r.ShardingAlgorithmByPrimaryKey(id) - } else if idUint, ok := value.(uint64); ok { - suffix = r.ShardingAlgorithmByPrimaryKey(int64(idUint)) - } else if idStr, ok := value.(string); ok { - id, err := strconv.ParseInt(idStr, 10, 64) - if err != nil { - return ftQuery, stQuery, tableName, ErrInvalidID - } - suffix = r.ShardingAlgorithmByPrimaryKey(id) - } else { - return ftQuery, stQuery, tableName, ErrInvalidID - } - } else { + if keyFind { suffix, err = r.ShardingAlgorithm(value) if err != nil { return } + } else { + if r.ShardingAlgorithmByPrimaryKey == nil { + err = fmt.Errorf("there is not sharding key and ShardingAlgorithmByPrimaryKey is not configured") + return + } + suffix = r.ShardingAlgorithmByPrimaryKey(id) } newTable := &sqlparser.TableName{Name: &sqlparser.Ident{Name: tableName + suffix}} @@ -232,57 +224,50 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, return } -func (s *Sharding) insertValue(key string, names []*sqlparser.Ident, exprs []sqlparser.Expr, args ...interface{}) (value interface{}, isID bool, err error) { - bind := false - find := false - +func (s *Sharding) insertValue(key string, names []*sqlparser.Ident, exprs []sqlparser.Expr, args ...interface{}) (value interface{}, id int64, keyFind bool, err error) { if len(names) != len(exprs) { - return nil, false, errors.New("column names and expressions mismatch") + return nil, 0, keyFind, errors.New("column names and expressions mismatch") } for i, name := range names { if name.Name == key { switch expr := exprs[i].(type) { case *sqlparser.BindExpr: - bind = true - value = expr.Name + value, err = getBindValue(expr.Name, args) + if err != nil { + return nil, 0, keyFind, err + } case *sqlparser.StringLit: value = expr.Value case *sqlparser.NumberLit: value = expr.Value default: - return nil, false, sqlparser.ErrNotImplemented + return nil, 0, keyFind, sqlparser.ErrNotImplemented } - find = true + keyFind = true break } } - if !find { - return nil, false, ErrMissingShardingKey - } - - if bind { - value, err = getBindValue(value, args) + if !keyFind { + return nil, 0, keyFind, ErrMissingShardingKey } return } -func (s *Sharding) nonInsertValue(key string, condition sqlparser.Expr, args ...interface{}) (value interface{}, isID bool, err error) { - bind := false - find := false +func (s *Sharding) nonInsertValue(key string, condition sqlparser.Expr, args ...interface{}) (value interface{}, id int64, keyFind bool, err error) { err = sqlparser.Walk(sqlparser.VisitFunc(func(node sqlparser.Node) error { if n, ok := node.(*sqlparser.BinaryExpr); ok { if x, ok := n.X.(*sqlparser.Ident); ok { if x.Name == key && n.Op == sqlparser.EQ { - find = true - isID = false - bind = false + keyFind = true switch expr := n.Y.(type) { case *sqlparser.BindExpr: - bind = true - value = expr.Name + value, err = getBindValue(expr.Name, args) + if err != nil { + return err + } case *sqlparser.StringLit: value = expr.Value case *sqlparser.NumberLit: @@ -292,15 +277,21 @@ func (s *Sharding) nonInsertValue(key string, condition sqlparser.Expr, args ... } return nil } else if x.Name == "id" && n.Op == sqlparser.EQ { - find = true - isID = true - bind = false switch expr := n.Y.(type) { case *sqlparser.BindExpr: - bind = true - value = expr.Name + v, err := getBindValue(expr.Name, args) + if err != nil { + return err + } + var ok bool + if id, ok = v.(int64); !ok { + return fmt.Errorf("ID should be int64 type") + } case *sqlparser.NumberLit: - value = expr.Value + id, err = strconv.ParseInt(expr.Value, 10, 64) + if err != nil { + return err + } default: return ErrInvalidID } @@ -314,12 +305,8 @@ func (s *Sharding) nonInsertValue(key string, condition sqlparser.Expr, args ... return } - if !find { - return nil, false, ErrMissingShardingKey - } - - if bind { - value, err = getBindValue(value, args) + if !keyFind && id == 0 { + return nil, 0, keyFind, ErrMissingShardingKey } return