Skip to content

Commit

Permalink
support _limit in update
Browse files Browse the repository at this point in the history
  • Loading branch information
caibirdme committed Dec 12, 2020
1 parent b74c7dc commit 1965d0a
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 26 deletions.
59 changes: 39 additions & 20 deletions builder/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,18 @@ var (
errHavingUnsupportedOperator = errors.New(`[builder] "_having" contains unsupported operator`)
errLockModeValueType = errors.New(`[builder] the value of "_lockMode" must be of string type`)
errNotAllowedLockMode = errors.New(`[builder] the value of "_lockMode" is not allowed`)
errUpdateLimitType = errors.New(`[builder] the value of "_limit" in update query must be one of int,uint,int64,uint64`)

errWhereInterfaceSliceType = `[builder] the value of "xxx %s" must be of []interface{} type`
errEmptySliceCondition = `[builder] the value of "%s" must contain at least one element`

defaultIgnoreKeys = map[string]struct{}{
"_orderby": struct{}{},
"_groupby": struct{}{},
"_having": struct{}{},
"_limit": struct{}{},
"_lockMode": struct{}{},
}
)

type whereMapSet struct {
Expand Down Expand Up @@ -59,37 +68,31 @@ func BuildSelect(table string, where map[string]interface{}, selectField []strin
var groupBy string
var having map[string]interface{}
var lockMode string
copiedWhere := copyWhere(where)
if val, ok := copiedWhere["_orderby"]; ok {
if val, ok := where["_orderby"]; ok {
s, ok := val.(string)
if !ok {
err = errOrderByValueType
return
}
orderBy = strings.TrimSpace(s)
delete(copiedWhere, "_orderby")
}
if val, ok := copiedWhere["_groupby"]; ok {
if val, ok := where["_groupby"]; ok {
s, ok := val.(string)
if !ok {
err = errGroupByValueType
return
}
groupBy = strings.TrimSpace(s)
delete(copiedWhere, "_groupby")
if "" != groupBy {
if h, ok := copiedWhere["_having"]; ok {
if h, ok := where["_having"]; ok {
having, err = resolveHaving(h)
if nil != err {
return
}
}
}
}
if _, ok := copiedWhere["_having"]; ok {
delete(copiedWhere, "_having")
}
if val, ok := copiedWhere["_limit"]; ok {
if val, ok := where["_limit"]; ok {
arr, ok := val.([]uint)
if !ok {
err = errLimitValueType
Expand All @@ -108,9 +111,8 @@ func BuildSelect(table string, where map[string]interface{}, selectField []strin
begin: begin,
step: step,
}
delete(copiedWhere, "_limit")
}
if val, ok := copiedWhere["_lockMode"]; ok {
if val, ok := where["_lockMode"]; ok {
s, ok := val.(string)
if !ok {
err = errLockModeValueType
Expand All @@ -121,14 +123,13 @@ func BuildSelect(table string, where map[string]interface{}, selectField []strin
err = errNotAllowedLockMode
return
}
delete(copiedWhere, "_lockMode")
}
conditions, err := getWhereConditions(copiedWhere)
conditions, err := getWhereConditions(where, defaultIgnoreKeys)
if nil != err {
return
}
if having != nil {
havingCondition, err1 := getWhereConditions(having)
havingCondition, err1 := getWhereConditions(having, defaultIgnoreKeys)
if nil != err1 {
err = err1
return
Expand Down Expand Up @@ -169,16 +170,31 @@ func resolveHaving(having interface{}) (map[string]interface{}, error) {

// BuildUpdate work as its name says
func BuildUpdate(table string, where map[string]interface{}, update map[string]interface{}) (string, []interface{}, error) {
conditions, err := getWhereConditions(where)
var limit uint
if v, ok := where["_limit"]; ok {
switch val := v.(type) {
case int:
limit = uint(val)
case uint:
limit = val
case int64:
limit = uint(val)
case uint64:
limit = uint(val)
default:
return "", nil, errUpdateLimitType
}
}
conditions, err := getWhereConditions(where, defaultIgnoreKeys)
if nil != err {
return "", nil, err
}
return buildUpdate(table, update, conditions...)
return buildUpdate(table, update, limit, conditions...)
}

// BuildDelete work as its name says
func BuildDelete(table string, where map[string]interface{}) (string, []interface{}, error) {
conditions, err := getWhereConditions(where)
conditions, err := getWhereConditions(where, defaultIgnoreKeys)
if nil != err {
return "", nil, err
}
Expand Down Expand Up @@ -209,7 +225,7 @@ func isStringInSlice(str string, arr []string) bool {
return false
}

func getWhereConditions(where map[string]interface{}) ([]Comparable, error) {
func getWhereConditions(where map[string]interface{}, ignoreKeys map[string]struct{}) ([]Comparable, error) {
if len(where) == 0 {
return nil, nil
}
Expand All @@ -218,6 +234,9 @@ func getWhereConditions(where map[string]interface{}) ([]Comparable, error) {
var field, operator string
var err error
for key, val := range where {
if _, ok := ignoreKeys[key]; ok {
continue
}
if key == "_or" {
var (
orWheres []map[string]interface{}
Expand All @@ -231,7 +250,7 @@ func getWhereConditions(where map[string]interface{}) ([]Comparable, error) {
if orWhere == nil {
continue
}
orNestWhere, err := getWhereConditions(orWhere)
orNestWhere, err := getWhereConditions(orWhere, ignoreKeys)
if nil != err {
return nil, err
}
Expand Down
48 changes: 44 additions & 4 deletions builder/builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,46 @@ func Test_BuildUpdate(t *testing.T) {
err: nil,
},
},
{
in: inStruct{
table: "tb",
where: map[string]interface{}{
"foo": "bar",
"age >=": 23,
"sex in": []interface{}{"male", "female"},
"_limit": 10,
},
setData: map[string]interface{}{
"score": 50,
"district": "010",
},
},
out: outStruct{
cond: "UPDATE tb SET district=?,score=? WHERE (foo=? AND sex IN (?,?) AND age>=?) LIMIT ?",
vals: []interface{}{"010", 50, "bar", "male", "female", 23, 10},
err: nil,
},
},
{
in: inStruct{
table: "tb",
where: map[string]interface{}{
"foo": "bar",
"age >=": 23,
"sex in": []interface{}{"male", "female"},
"_limit": 5.5,
},
setData: map[string]interface{}{
"score": 50,
"district": "010",
},
},
out: outStruct{
cond: "",
vals: nil,
err: errUpdateLimitType,
},
},
}
ass := assert.New(t)
for _, tc := range data {
Expand Down Expand Up @@ -1245,13 +1285,13 @@ func TestNotLike_1(t *testing.T) {
func TestFixBug_insert_quote_field(t *testing.T) {
cond, vals, err := BuildInsert("tb", []map[string]interface{}{
{
"id": 1,
"id": 1,
"`order`": 2,
"`id`": 3, // I know this is forbidden, but just for test
"`id`": 3, // I know this is forbidden, but just for test
},
})
ass := assert.New(t)
ass.NoError(err)
ass.Equal("INSERT INTO tb (`id`,`order`,id) VALUES (?,?,?)", cond)
ass.Equal([]interface{}{3,2,1}, vals)
}
ass.Equal([]interface{}{3, 2, 1}, vals)
}
6 changes: 5 additions & 1 deletion builder/dao.go
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ func buildInsert(table string, setMap []map[string]interface{}, insertType inser
return fmt.Sprintf(format, insertType, quoteField(table), strings.Join(fields, ","), strings.Join(sets, ",")), vals, nil
}

func buildUpdate(table string, update map[string]interface{}, conditions ...Comparable) (string, []interface{}, error) {
func buildUpdate(table string, update map[string]interface{}, limit uint, conditions ...Comparable) (string, []interface{}, error) {
format := "UPDATE %s SET %s"
keys, vals := resolveKV(update)
var sets string
Expand All @@ -411,6 +411,10 @@ func buildUpdate(table string, update map[string]interface{}, conditions ...Comp
cond = fmt.Sprintf("%s WHERE %s", cond, whereString)
vals = append(vals, whereVals...)
}
if limit > 0 {
cond += " LIMIT ?"
vals = append(vals, int(limit))
}
return cond, vals, nil
}

Expand Down
2 changes: 1 addition & 1 deletion builder/dao_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ func TestBuildUpdate(t *testing.T) {
}
ass := assert.New(t)
for _, tc := range data {
cond, vals, err := buildUpdate(tc.table, tc.data, tc.conditions...)
cond, vals, err := buildUpdate(tc.table, tc.data, 0, tc.conditions...)
ass.Equal(tc.outErr, err)
ass.Equal(tc.outStr, cond)
ass.Equal(tc.outVals, vals)
Expand Down

0 comments on commit 1965d0a

Please sign in to comment.