diff --git a/builder/builder.go b/builder/builder.go index 6fdb667..de63aa5 100644 --- a/builder/builder.go +++ b/builder/builder.go @@ -21,9 +21,18 @@ var ( errHavingUnsupportedOperator = errors.New(`[builder] "_having" contains unsupported operator`) errLockModeValueType = errors.New(`[builder] the value of "_lockMode" must be of string type`) errNotAllowedLockMode = errors.New(`[builder] the value of "_lockMode" is not allowed`) + errUpdateLimitType = errors.New(`[builder] the value of "_limit" in update query must be one of int,uint,int64,uint64`) errWhereInterfaceSliceType = `[builder] the value of "xxx %s" must be of []interface{} type` errEmptySliceCondition = `[builder] the value of "%s" must contain at least one element` + + defaultIgnoreKeys = map[string]struct{}{ + "_orderby": struct{}{}, + "_groupby": struct{}{}, + "_having": struct{}{}, + "_limit": struct{}{}, + "_lockMode": struct{}{}, + } ) type whereMapSet struct { @@ -59,26 +68,23 @@ func BuildSelect(table string, where map[string]interface{}, selectField []strin var groupBy string var having map[string]interface{} var lockMode string - copiedWhere := copyWhere(where) - if val, ok := copiedWhere["_orderby"]; ok { + if val, ok := where["_orderby"]; ok { s, ok := val.(string) if !ok { err = errOrderByValueType return } orderBy = strings.TrimSpace(s) - delete(copiedWhere, "_orderby") } - if val, ok := copiedWhere["_groupby"]; ok { + if val, ok := where["_groupby"]; ok { s, ok := val.(string) if !ok { err = errGroupByValueType return } groupBy = strings.TrimSpace(s) - delete(copiedWhere, "_groupby") if "" != groupBy { - if h, ok := copiedWhere["_having"]; ok { + if h, ok := where["_having"]; ok { having, err = resolveHaving(h) if nil != err { return @@ -86,10 +92,7 @@ func BuildSelect(table string, where map[string]interface{}, selectField []strin } } } - if _, ok := copiedWhere["_having"]; ok { - delete(copiedWhere, "_having") - } - if val, ok := copiedWhere["_limit"]; ok { + if val, ok := where["_limit"]; ok { arr, ok := val.([]uint) if !ok { err = errLimitValueType @@ -108,9 +111,8 @@ func BuildSelect(table string, where map[string]interface{}, selectField []strin begin: begin, step: step, } - delete(copiedWhere, "_limit") } - if val, ok := copiedWhere["_lockMode"]; ok { + if val, ok := where["_lockMode"]; ok { s, ok := val.(string) if !ok { err = errLockModeValueType @@ -121,14 +123,13 @@ func BuildSelect(table string, where map[string]interface{}, selectField []strin err = errNotAllowedLockMode return } - delete(copiedWhere, "_lockMode") } - conditions, err := getWhereConditions(copiedWhere) + conditions, err := getWhereConditions(where, defaultIgnoreKeys) if nil != err { return } if having != nil { - havingCondition, err1 := getWhereConditions(having) + havingCondition, err1 := getWhereConditions(having, defaultIgnoreKeys) if nil != err1 { err = err1 return @@ -169,16 +170,31 @@ func resolveHaving(having interface{}) (map[string]interface{}, error) { // BuildUpdate work as its name says func BuildUpdate(table string, where map[string]interface{}, update map[string]interface{}) (string, []interface{}, error) { - conditions, err := getWhereConditions(where) + var limit uint + if v, ok := where["_limit"]; ok { + switch val := v.(type) { + case int: + limit = uint(val) + case uint: + limit = val + case int64: + limit = uint(val) + case uint64: + limit = uint(val) + default: + return "", nil, errUpdateLimitType + } + } + conditions, err := getWhereConditions(where, defaultIgnoreKeys) if nil != err { return "", nil, err } - return buildUpdate(table, update, conditions...) + return buildUpdate(table, update, limit, conditions...) } // BuildDelete work as its name says func BuildDelete(table string, where map[string]interface{}) (string, []interface{}, error) { - conditions, err := getWhereConditions(where) + conditions, err := getWhereConditions(where, defaultIgnoreKeys) if nil != err { return "", nil, err } @@ -209,7 +225,7 @@ func isStringInSlice(str string, arr []string) bool { return false } -func getWhereConditions(where map[string]interface{}) ([]Comparable, error) { +func getWhereConditions(where map[string]interface{}, ignoreKeys map[string]struct{}) ([]Comparable, error) { if len(where) == 0 { return nil, nil } @@ -218,6 +234,9 @@ func getWhereConditions(where map[string]interface{}) ([]Comparable, error) { var field, operator string var err error for key, val := range where { + if _, ok := ignoreKeys[key]; ok { + continue + } if key == "_or" { var ( orWheres []map[string]interface{} @@ -231,7 +250,7 @@ func getWhereConditions(where map[string]interface{}) ([]Comparable, error) { if orWhere == nil { continue } - orNestWhere, err := getWhereConditions(orWhere) + orNestWhere, err := getWhereConditions(orWhere, ignoreKeys) if nil != err { return nil, err } diff --git a/builder/builder_test.go b/builder/builder_test.go index e14c038..0ee3235 100644 --- a/builder/builder_test.go +++ b/builder/builder_test.go @@ -447,6 +447,46 @@ func Test_BuildUpdate(t *testing.T) { err: nil, }, }, + { + in: inStruct{ + table: "tb", + where: map[string]interface{}{ + "foo": "bar", + "age >=": 23, + "sex in": []interface{}{"male", "female"}, + "_limit": 10, + }, + setData: map[string]interface{}{ + "score": 50, + "district": "010", + }, + }, + out: outStruct{ + cond: "UPDATE tb SET district=?,score=? WHERE (foo=? AND sex IN (?,?) AND age>=?) LIMIT ?", + vals: []interface{}{"010", 50, "bar", "male", "female", 23, 10}, + err: nil, + }, + }, + { + in: inStruct{ + table: "tb", + where: map[string]interface{}{ + "foo": "bar", + "age >=": 23, + "sex in": []interface{}{"male", "female"}, + "_limit": 5.5, + }, + setData: map[string]interface{}{ + "score": 50, + "district": "010", + }, + }, + out: outStruct{ + cond: "", + vals: nil, + err: errUpdateLimitType, + }, + }, } ass := assert.New(t) for _, tc := range data { @@ -1245,13 +1285,13 @@ func TestNotLike_1(t *testing.T) { func TestFixBug_insert_quote_field(t *testing.T) { cond, vals, err := BuildInsert("tb", []map[string]interface{}{ { - "id": 1, + "id": 1, "`order`": 2, - "`id`": 3, // I know this is forbidden, but just for test + "`id`": 3, // I know this is forbidden, but just for test }, }) ass := assert.New(t) ass.NoError(err) ass.Equal("INSERT INTO tb (`id`,`order`,id) VALUES (?,?,?)", cond) - ass.Equal([]interface{}{3,2,1}, vals) -} \ No newline at end of file + ass.Equal([]interface{}{3, 2, 1}, vals) +} diff --git a/builder/dao.go b/builder/dao.go index 81fe7ec..9c04d29 100644 --- a/builder/dao.go +++ b/builder/dao.go @@ -397,7 +397,7 @@ func buildInsert(table string, setMap []map[string]interface{}, insertType inser return fmt.Sprintf(format, insertType, quoteField(table), strings.Join(fields, ","), strings.Join(sets, ",")), vals, nil } -func buildUpdate(table string, update map[string]interface{}, conditions ...Comparable) (string, []interface{}, error) { +func buildUpdate(table string, update map[string]interface{}, limit uint, conditions ...Comparable) (string, []interface{}, error) { format := "UPDATE %s SET %s" keys, vals := resolveKV(update) var sets string @@ -411,6 +411,10 @@ func buildUpdate(table string, update map[string]interface{}, conditions ...Comp cond = fmt.Sprintf("%s WHERE %s", cond, whereString) vals = append(vals, whereVals...) } + if limit > 0 { + cond += " LIMIT ?" + vals = append(vals, int(limit)) + } return cond, vals, nil } diff --git a/builder/dao_test.go b/builder/dao_test.go index eb3a080..c647589 100644 --- a/builder/dao_test.go +++ b/builder/dao_test.go @@ -285,7 +285,7 @@ func TestBuildUpdate(t *testing.T) { } ass := assert.New(t) for _, tc := range data { - cond, vals, err := buildUpdate(tc.table, tc.data, tc.conditions...) + cond, vals, err := buildUpdate(tc.table, tc.data, 0, tc.conditions...) ass.Equal(tc.outErr, err) ass.Equal(tc.outStr, cond) ass.Equal(tc.outVals, vals)