From fc56bc8979825675f7aafe591360dcf549e301cc Mon Sep 17 00:00:00 2001 From: Huan Du Date: Sun, 26 Jan 2025 15:35:02 +0800 Subject: [PATCH 1/2] ignore empty values and expressions to prevent syntax error --- cond.go | 14 +++++++++----- cond_test.go | 5 +++++ delete.go | 1 - select.go | 8 ++++---- stringbuilder.go | 26 ++++++++++++++++++++++++++ update.go | 4 ++-- whereclause.go | 8 +++++++- whereclause_test.go | 8 ++++---- 8 files changed, 57 insertions(+), 17 deletions(-) diff --git a/cond.go b/cond.go index 01754a0..5991c3d 100644 --- a/cond.go +++ b/cond.go @@ -186,7 +186,7 @@ func (c *Cond) LTE(field string, value interface{}) string { // In is used to construct the expression "field IN (value...)". func (c *Cond) In(field string, values ...interface{}) string { - if len(field) == 0 { + if len(field) == 0 || len(values) == 0 { return "" } @@ -202,7 +202,7 @@ func (c *Cond) In(field string, values ...interface{}) string { // NotIn is used to construct the expression "field NOT IN (value...)". func (c *Cond) NotIn(field string, values ...interface{}) string { - if len(field) == 0 { + if len(field) == 0 || len(values) == 0 { return "" } @@ -369,6 +369,8 @@ func (c *Cond) NotBetween(field string, lower, upper interface{}) string { // Or is used to construct the expression OR logic like "expr1 OR expr2 OR expr3". func (c *Cond) Or(orExpr ...string) string { + orExpr = filterEmptyStrings(orExpr) + if len(orExpr) == 0 { return "" } @@ -392,6 +394,8 @@ func (c *Cond) Or(orExpr ...string) string { // And is used to construct the expression AND logic like "expr1 AND expr2 AND expr3". func (c *Cond) And(andExpr ...string) string { + andExpr = filterEmptyStrings(andExpr) + if len(andExpr) == 0 { return "" } @@ -453,7 +457,7 @@ func (c *Cond) NotExists(subquery interface{}) string { // Any is used to construct the expression "field op ANY (value...)". func (c *Cond) Any(field, op string, values ...interface{}) string { - if len(field) == 0 || len(op) == 0 { + if len(field) == 0 || len(op) == 0 || len(values) == 0 { return "" } @@ -471,7 +475,7 @@ func (c *Cond) Any(field, op string, values ...interface{}) string { // All is used to construct the expression "field op ALL (value...)". func (c *Cond) All(field, op string, values ...interface{}) string { - if len(field) == 0 || len(op) == 0 { + if len(field) == 0 || len(op) == 0 || len(values) == 0 { return "" } @@ -489,7 +493,7 @@ func (c *Cond) All(field, op string, values ...interface{}) string { // Some is used to construct the expression "field op SOME (value...)". func (c *Cond) Some(field, op string, values ...interface{}) string { - if len(field) == 0 || len(op) == 0 { + if len(field) == 0 || len(op) == 0 || len(values) == 0 { return "" } diff --git a/cond_test.go b/cond_test.go index ec2b4e2..885286e 100644 --- a/cond_test.go +++ b/cond_test.go @@ -123,7 +123,9 @@ func TestEmptyCond(t *testing.T) { func(cond *Cond) string { return cond.LessThan("", 123) }, func(cond *Cond) string { return cond.LessEqualThan("", 123) }, func(cond *Cond) string { return cond.In("", 1, 2, 3) }, + func(cond *Cond) string { return cond.In("a") }, func(cond *Cond) string { return cond.NotIn("", 1, 2, 3) }, + func(cond *Cond) string { return cond.NotIn("a") }, func(cond *Cond) string { return cond.Like("", "%Huan%") }, func(cond *Cond) string { return cond.ILike("", "%Huan%") }, func(cond *Cond) string { return cond.NotLike("", "%Huan%") }, @@ -137,14 +139,17 @@ func TestEmptyCond(t *testing.T) { func(cond *Cond) string { return cond.Any("", "", 1, 2) }, func(cond *Cond) string { return cond.Any("", ">", 1, 2) }, func(cond *Cond) string { return cond.Any("$a", "", 1, 2) }, + func(cond *Cond) string { return cond.Any("$a", ">") }, func(cond *Cond) string { return cond.All("", "", 1) }, func(cond *Cond) string { return cond.All("", ">", 1) }, func(cond *Cond) string { return cond.All("$a", "", 1) }, + func(cond *Cond) string { return cond.All("$a", ">") }, func(cond *Cond) string { return cond.Some("", "", 1, 2, 3) }, func(cond *Cond) string { return cond.Some("", ">", 1, 2, 3) }, func(cond *Cond) string { return cond.Some("$a", "", 1, 2, 3) }, + func(cond *Cond) string { return cond.Some("$a", ">") }, func(cond *Cond) string { return cond.IsDistinctFrom("", 1) }, func(cond *Cond) string { return cond.IsNotDistinctFrom("", 1) }, diff --git a/delete.go b/delete.go index 67b4c2f..c85c2b9 100644 --- a/delete.go +++ b/delete.go @@ -172,7 +172,6 @@ func (db *DeleteBuilder) Build() (sql string, args []interface{}) { // BuildWithFlavor returns compiled DELETE string and args with flavor and initial args. // They can be used in `DB#Query` of package `database/sql` directly. func (db *DeleteBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) { - buf := newStringBuilder() db.injection.WriteTo(buf, deleteMarkerInit) diff --git a/select.go b/select.go index 5047324..aac9dba 100644 --- a/select.go +++ b/select.go @@ -390,9 +390,9 @@ func (sb *SelectBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{ buf.WriteLeadingString("JOIN ") buf.WriteString(sb.joinTables[i]) - if exprs := sb.joinExprs[i]; len(exprs) > 0 { + if exprs := filterEmptyStrings(sb.joinExprs[i]); len(exprs) > 0 { buf.WriteString(" ON ") - buf.WriteStrings(sb.joinExprs[i], " AND ") + buf.WriteStrings(exprs, " AND ") } } @@ -414,9 +414,9 @@ func (sb *SelectBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{ buf.WriteLeadingString("GROUP BY ") buf.WriteStrings(sb.groupByCols, ", ") - if len(sb.havingExprs) > 0 { + if havingExprs := filterEmptyStrings(sb.havingExprs); len(havingExprs) > 0 { buf.WriteString(" HAVING ") - buf.WriteStrings(sb.havingExprs, " AND ") + buf.WriteStrings(havingExprs, " AND ") } sb.injection.WriteTo(buf, selectMarkerAfterGroupBy) diff --git a/stringbuilder.go b/stringbuilder.go index 4c2c7a2..6fd37df 100644 --- a/stringbuilder.go +++ b/stringbuilder.go @@ -75,3 +75,29 @@ func (sb *stringBuilder) Reset() { func (sb *stringBuilder) Grow(n int) { sb.builder.Grow(n) } + +// filterEmptyStrings removes empty strings from ss. +// As ss rarely contains empty strings, filterEmptyStrings tries to avoid allocation if possible. +func filterEmptyStrings(ss []string) []string { + emptyStrings := 0 + + for _, s := range ss { + if len(s) == 0 { + emptyStrings++ + } + } + + if emptyStrings == 0 { + return ss + } + + filtered := make([]string, 0, len(ss)-emptyStrings) + + for _, s := range ss { + if len(s) != 0 { + filtered = append(filtered, s) + } + } + + return filtered +} diff --git a/update.go b/update.go index 973c0c4..7fc66d8 100644 --- a/update.go +++ b/update.go @@ -260,9 +260,9 @@ func (ub *UpdateBuilder) BuildWithFlavor(flavor Flavor, initialArg ...interface{ ub.injection.WriteTo(buf, updateMarkerAfterUpdate) - if len(ub.assignments) > 0 { + if assignments := filterEmptyStrings(ub.assignments); len(assignments) > 0 { buf.WriteLeadingString("SET ") - buf.WriteStrings(ub.assignments, ", ") + buf.WriteStrings(assignments, ", ") } ub.injection.WriteTo(buf, updateMarkerAfterSet) diff --git a/whereclause.go b/whereclause.go index f06ff90..57931da 100644 --- a/whereclause.go +++ b/whereclause.go @@ -38,8 +38,14 @@ type clause struct { } func (c *clause) Build(flavor Flavor, initialArg ...interface{}) (sql string, args []interface{}) { + exprs := filterEmptyStrings(c.andExprs) + + if len(exprs) == 0 { + return + } + buf := newStringBuilder() - buf.WriteStrings(c.andExprs, " AND ") + buf.WriteStrings(exprs, " AND ") sql, args = c.args.CompileWithFlavor(buf.String(), flavor, initialArg...) return } diff --git a/whereclause_test.go b/whereclause_test.go index cc2478c..5368a0d 100644 --- a/whereclause_test.go +++ b/whereclause_test.go @@ -245,10 +245,10 @@ func TestWhereClauseSharedInstances(t *testing.T) { func TestEmptyWhereExpr(t *testing.T) { a := assert.New(t) - var emptyExpr []string - sb := Select("*").From("t").Where(emptyExpr...) - ub := Update("t").Set("foo = 1").Where(emptyExpr...) - db := DeleteFrom("t").Where(emptyExpr...) + blankExprs := []string{"", ""} + sb := Select("*").From("t").Where(blankExprs...) + ub := Update("t").Set("foo = 1").Where(blankExprs...) + db := DeleteFrom("t").Where(blankExprs...) a.Equal(sb.String(), "SELECT * FROM t") a.Equal(ub.String(), "UPDATE t SET foo = 1") From 06070c6df0efb31e5076b0b884d520699f712427 Mon Sep 17 00:00:00 2001 From: Huan Du Date: Tue, 11 Feb 2025 15:09:46 +0800 Subject: [PATCH 2/2] generate FALSE if values is empty in IN/ALL/SOME/ANY --- cond.go | 28 +++++- cond_test.go | 246 ++++++++++++++++++++++++++------------------------- 2 files changed, 149 insertions(+), 125 deletions(-) diff --git a/cond.go b/cond.go index 5991c3d..794a9d3 100644 --- a/cond.go +++ b/cond.go @@ -186,10 +186,15 @@ func (c *Cond) LTE(field string, value interface{}) string { // In is used to construct the expression "field IN (value...)". func (c *Cond) In(field string, values ...interface{}) string { - if len(field) == 0 || len(values) == 0 { + if len(field) == 0 { return "" } + // Empty values means "false". + if len(values) == 0 { + return "0 = 1" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) @@ -457,10 +462,15 @@ func (c *Cond) NotExists(subquery interface{}) string { // Any is used to construct the expression "field op ANY (value...)". func (c *Cond) Any(field, op string, values ...interface{}) string { - if len(field) == 0 || len(op) == 0 || len(values) == 0 { + if len(field) == 0 || len(op) == 0 { return "" } + // Empty values means "false". + if len(values) == 0 { + return "0 = 1" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) @@ -475,10 +485,15 @@ func (c *Cond) Any(field, op string, values ...interface{}) string { // All is used to construct the expression "field op ALL (value...)". func (c *Cond) All(field, op string, values ...interface{}) string { - if len(field) == 0 || len(op) == 0 || len(values) == 0 { + if len(field) == 0 || len(op) == 0 { return "" } + // Empty values means "false". + if len(values) == 0 { + return "0 = 1" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) @@ -493,10 +508,15 @@ func (c *Cond) All(field, op string, values ...interface{}) string { // Some is used to construct the expression "field op SOME (value...)". func (c *Cond) Some(field, op string, values ...interface{}) string { - if len(field) == 0 || len(op) == 0 || len(values) == 0 { + if len(field) == 0 || len(op) == 0 { return "" } + // Empty values means "false". + if len(values) == 0 { + return "0 = 1" + } + return c.Var(condBuilder{ Builder: func(ctx *argsCompileContext) { ctx.WriteString(field) diff --git a/cond_test.go b/cond_test.go index 885286e..c484602 100644 --- a/cond_test.go +++ b/cond_test.go @@ -12,165 +12,162 @@ import ( type TestPair struct { Expected string - Actual func(cond *Cond) string + Actual string +} + +func newTestPair(expected string, fn func(c *Cond) string) *TestPair { + cond := newCond() + format := fn(cond) + sql, _ := cond.Args.CompileWithFlavor(format, PostgreSQL) + return &TestPair{ + Expected: expected, + Actual: sql, + } } func TestCond(t *testing.T) { a := assert.New(t) - cases := map[string]func(cond *Cond) string{ - "$a = $1": func(cond *Cond) string { return cond.Equal("$a", 123) }, - "$b = $1": func(cond *Cond) string { return cond.E("$b", 123) }, - "$c = $1": func(cond *Cond) string { return cond.EQ("$c", 123) }, - "$a <> $1": func(cond *Cond) string { return cond.NotEqual("$a", 123) }, - "$b <> $1": func(cond *Cond) string { return cond.NE("$b", 123) }, - "$c <> $1": func(cond *Cond) string { return cond.NEQ("$c", 123) }, - "$a > $1": func(cond *Cond) string { return cond.GreaterThan("$a", 123) }, - "$b > $1": func(cond *Cond) string { return cond.G("$b", 123) }, - "$c > $1": func(cond *Cond) string { return cond.GT("$c", 123) }, - "$a >= $1": func(cond *Cond) string { return cond.GreaterEqualThan("$a", 123) }, - "$b >= $1": func(cond *Cond) string { return cond.GE("$b", 123) }, - "$c >= $1": func(cond *Cond) string { return cond.GTE("$c", 123) }, - "$a < $1": func(cond *Cond) string { return cond.LessThan("$a", 123) }, - "$b < $1": func(cond *Cond) string { return cond.L("$b", 123) }, - "$c < $1": func(cond *Cond) string { return cond.LT("$c", 123) }, - "$a <= $1": func(cond *Cond) string { return cond.LessEqualThan("$a", 123) }, - "$b <= $1": func(cond *Cond) string { return cond.LE("$b", 123) }, - "$c <= $1": func(cond *Cond) string { return cond.LTE("$c", 123) }, - "$a IN ($1, $2, $3)": func(cond *Cond) string { return cond.In("$a", 1, 2, 3) }, - "$a NOT IN ($1, $2, $3)": func(cond *Cond) string { return cond.NotIn("$a", 1, 2, 3) }, - "$a LIKE $1": func(cond *Cond) string { return cond.Like("$a", "%Huan%") }, - "$a ILIKE $1": func(cond *Cond) string { return cond.ILike("$a", "%Huan%") }, - "$a NOT LIKE $1": func(cond *Cond) string { return cond.NotLike("$a", "%Huan%") }, - "$a NOT ILIKE $1": func(cond *Cond) string { return cond.NotILike("$a", "%Huan%") }, - "$a IS NULL": func(cond *Cond) string { return cond.IsNull("$a") }, - "$a IS NOT NULL": func(cond *Cond) string { return cond.IsNotNull("$a") }, - "$a BETWEEN $1 AND $2": func(cond *Cond) string { return cond.Between("$a", 123, 456) }, - "$a NOT BETWEEN $1 AND $2": func(cond *Cond) string { return cond.NotBetween("$a", 123, 456) }, - "NOT 1 = 1": func(cond *Cond) string { return cond.Not("1 = 1") }, - "EXISTS ($1)": func(cond *Cond) string { return cond.Exists(1) }, - "NOT EXISTS ($1)": func(cond *Cond) string { return cond.NotExists(1) }, - "$a > ANY ($1, $2)": func(cond *Cond) string { return cond.Any("$a", ">", 1, 2) }, - "$a < ALL ($1)": func(cond *Cond) string { return cond.All("$a", "<", 1) }, - "$a > SOME ($1, $2, $3)": func(cond *Cond) string { return cond.Some("$a", ">", 1, 2, 3) }, - "$a IS DISTINCT FROM $1": func(cond *Cond) string { return cond.IsDistinctFrom("$a", 1) }, - "$a IS NOT DISTINCT FROM $1": func(cond *Cond) string { return cond.IsNotDistinctFrom("$a", 1) }, - "$1": func(cond *Cond) string { return cond.Var(123) }, - } - - for expected, f := range cases { - actual := callCond(f) - a.Equal(actual, expected) + cases := []*TestPair{ + newTestPair("$a = $1", func(c *Cond) string { return c.Equal("$a", 123) }), + newTestPair("$b = $1", func(c *Cond) string { return c.E("$b", 123) }), + newTestPair("$c = $1", func(c *Cond) string { return c.EQ("$c", 123) }), + newTestPair("$a <> $1", func(c *Cond) string { return c.NotEqual("$a", 123) }), + newTestPair("$b <> $1", func(c *Cond) string { return c.NE("$b", 123) }), + newTestPair("$c <> $1", func(c *Cond) string { return c.NEQ("$c", 123) }), + newTestPair("$a > $1", func(c *Cond) string { return c.GreaterThan("$a", 123) }), + newTestPair("$b > $1", func(c *Cond) string { return c.G("$b", 123) }), + newTestPair("$c > $1", func(c *Cond) string { return c.GT("$c", 123) }), + newTestPair("$a >= $1", func(c *Cond) string { return c.GreaterEqualThan("$a", 123) }), + newTestPair("$b >= $1", func(c *Cond) string { return c.GE("$b", 123) }), + newTestPair("$c >= $1", func(c *Cond) string { return c.GTE("$c", 123) }), + newTestPair("$a < $1", func(c *Cond) string { return c.LessThan("$a", 123) }), + newTestPair("$b < $1", func(c *Cond) string { return c.L("$b", 123) }), + newTestPair("$c < $1", func(c *Cond) string { return c.LT("$c", 123) }), + newTestPair("$a <= $1", func(c *Cond) string { return c.LessEqualThan("$a", 123) }), + newTestPair("$b <= $1", func(c *Cond) string { return c.LE("$b", 123) }), + newTestPair("$c <= $1", func(c *Cond) string { return c.LTE("$c", 123) }), + newTestPair("$a IN ($1, $2, $3)", func(c *Cond) string { return c.In("$a", 1, 2, 3) }), + newTestPair("0 = 1", func(c *Cond) string { return c.In("$a") }), + newTestPair("$a NOT IN ($1, $2, $3)", func(c *Cond) string { return c.NotIn("$a", 1, 2, 3) }), + newTestPair("$a LIKE $1", func(c *Cond) string { return c.Like("$a", "%Huan%") }), + newTestPair("$a ILIKE $1", func(c *Cond) string { return c.ILike("$a", "%Huan%") }), + newTestPair("$a NOT LIKE $1", func(c *Cond) string { return c.NotLike("$a", "%Huan%") }), + newTestPair("$a NOT ILIKE $1", func(c *Cond) string { return c.NotILike("$a", "%Huan%") }), + newTestPair("$a IS NULL", func(c *Cond) string { return c.IsNull("$a") }), + newTestPair("$a IS NOT NULL", func(c *Cond) string { return c.IsNotNull("$a") }), + newTestPair("$a BETWEEN $1 AND $2", func(c *Cond) string { return c.Between("$a", 123, 456) }), + newTestPair("$a NOT BETWEEN $1 AND $2", func(c *Cond) string { return c.NotBetween("$a", 123, 456) }), + newTestPair("NOT 1 = 1", func(c *Cond) string { return c.Not("1 = 1") }), + newTestPair("EXISTS ($1)", func(c *Cond) string { return c.Exists(1) }), + newTestPair("NOT EXISTS ($1)", func(c *Cond) string { return c.NotExists(1) }), + newTestPair("$a > ANY ($1, $2)", func(c *Cond) string { return c.Any("$a", ">", 1, 2) }), + newTestPair("0 = 1", func(c *Cond) string { return c.Any("$a", ">") }), + newTestPair("$a < ALL ($1)", func(c *Cond) string { return c.All("$a", "<", 1) }), + newTestPair("0 = 1", func(c *Cond) string { return c.All("$a", "<") }), + newTestPair("$a > SOME ($1, $2, $3)", func(c *Cond) string { return c.Some("$a", ">", 1, 2, 3) }), + newTestPair("0 = 1", func(c *Cond) string { return c.Some("$a", ">") }), + newTestPair("$a IS DISTINCT FROM $1", func(c *Cond) string { return c.IsDistinctFrom("$a", 1) }), + newTestPair("$a IS NOT DISTINCT FROM $1", func(c *Cond) string { return c.IsNotDistinctFrom("$a", 1) }), + newTestPair("$1", func(c *Cond) string { return c.Var(123) }), + } + + for _, f := range cases { + a.Equal(f.Actual, f.Expected) } } func TestOrCond(t *testing.T) { a := assert.New(t) - cases := []TestPair{ - {Expected: "(1 = 1 OR 2 = 2 OR 3 = 3)", Actual: func(cond *Cond) string { return cond.Or("1 = 1", "2 = 2", "3 = 3") }}, + cases := []*TestPair{ + newTestPair("(1 = 1 OR 2 = 2 OR 3 = 3)", func(c *Cond) string { return c.Or("1 = 1", "2 = 2", "3 = 3") }), - {Expected: "(1 = 1 OR 2 = 2)", Actual: func(cond *Cond) string { return cond.Or("", "1 = 1", "2 = 2") }}, - {Expected: "(1 = 1 OR 2 = 2)", Actual: func(cond *Cond) string { return cond.Or("1 = 1", "2 = 2", "") }}, - {Expected: "(1 = 1 OR 2 = 2)", Actual: func(cond *Cond) string { return cond.Or("1 = 1", "", "2 = 2") }}, + newTestPair("(1 = 1 OR 2 = 2)", func(c *Cond) string { return c.Or("", "1 = 1", "2 = 2") }), + newTestPair("(1 = 1 OR 2 = 2)", func(c *Cond) string { return c.Or("1 = 1", "2 = 2", "") }), + newTestPair("(1 = 1 OR 2 = 2)", func(c *Cond) string { return c.Or("1 = 1", "", "2 = 2") }), - {Expected: "(1 = 1)", Actual: func(cond *Cond) string { return cond.Or("1 = 1", "", "") }}, - {Expected: "(1 = 1)", Actual: func(cond *Cond) string { return cond.Or("", "1 = 1", "") }}, - {Expected: "(1 = 1)", Actual: func(cond *Cond) string { return cond.Or("", "", "1 = 1") }}, - {Expected: "(1 = 1)", Actual: func(cond *Cond) string { return cond.Or("1 = 1") }}, + newTestPair("(1 = 1)", func(c *Cond) string { return c.Or("1 = 1", "", "") }), + newTestPair("(1 = 1)", func(c *Cond) string { return c.Or("", "1 = 1", "") }), + newTestPair("(1 = 1)", func(c *Cond) string { return c.Or("", "", "1 = 1") }), + newTestPair("(1 = 1)", func(c *Cond) string { return c.Or("1 = 1") }), - {Expected: "", Actual: func(cond *Cond) string { return cond.Or("") }}, - {Expected: "", Actual: func(cond *Cond) string { return cond.Or() }}, - {Expected: "", Actual: func(cond *Cond) string { return cond.Or("", "", "") }}, + {Expected: "", Actual: newCond().Or("")}, + {Expected: "", Actual: newCond().Or()}, + {Expected: "", Actual: newCond().Or("", "", "")}, } for _, f := range cases { - actual := callCond(f.Actual) - a.Equal(actual, f.Expected) + a.Equal(f.Actual, f.Expected) } } func TestAndCond(t *testing.T) { a := assert.New(t) - cases := []TestPair{ - {Expected: "(1 = 1 AND 2 = 2 AND 3 = 3)", Actual: func(cond *Cond) string { return cond.And("1 = 1", "2 = 2", "3 = 3") }}, + cases := []*TestPair{ + newTestPair("(1 = 1 AND 2 = 2 AND 3 = 3)", func(c *Cond) string { return c.And("1 = 1", "2 = 2", "3 = 3") }), - {Expected: "(1 = 1 AND 2 = 2)", Actual: func(cond *Cond) string { return cond.And("", "1 = 1", "2 = 2") }}, - {Expected: "(1 = 1 AND 2 = 2)", Actual: func(cond *Cond) string { return cond.And("1 = 1", "2 = 2", "") }}, - {Expected: "(1 = 1 AND 2 = 2)", Actual: func(cond *Cond) string { return cond.And("1 = 1", "", "2 = 2") }}, + newTestPair("(1 = 1 AND 2 = 2)", func(c *Cond) string { return c.And("", "1 = 1", "2 = 2") }), + newTestPair("(1 = 1 AND 2 = 2)", func(c *Cond) string { return c.And("1 = 1", "2 = 2", "") }), + newTestPair("(1 = 1 AND 2 = 2)", func(c *Cond) string { return c.And("1 = 1", "", "2 = 2") }), - {Expected: "(1 = 1)", Actual: func(cond *Cond) string { return cond.And("1 = 1", "", "") }}, - {Expected: "(1 = 1)", Actual: func(cond *Cond) string { return cond.And("", "1 = 1", "") }}, - {Expected: "(1 = 1)", Actual: func(cond *Cond) string { return cond.And("", "", "1 = 1") }}, - {Expected: "(1 = 1)", Actual: func(cond *Cond) string { return cond.And("1 = 1") }}, + newTestPair("(1 = 1)", func(c *Cond) string { return c.And("1 = 1", "", "") }), + newTestPair("(1 = 1)", func(c *Cond) string { return c.And("", "1 = 1", "") }), + newTestPair("(1 = 1)", func(c *Cond) string { return c.And("", "", "1 = 1") }), + newTestPair("(1 = 1)", func(c *Cond) string { return c.And("1 = 1") }), - {Expected: "", Actual: func(cond *Cond) string { return cond.And("") }}, - {Expected: "", Actual: func(cond *Cond) string { return cond.And() }}, - {Expected: "", Actual: func(cond *Cond) string { return cond.And("", "", "") }}, + {Expected: "", Actual: newCond().And("")}, + {Expected: "", Actual: newCond().And()}, + {Expected: "", Actual: newCond().And("", "", "")}, } for _, f := range cases { - actual := callCond(f.Actual) - a.Equal(actual, f.Expected) + a.Equal(f.Actual, f.Expected) } } func TestEmptyCond(t *testing.T) { a := assert.New(t) - cases := []func(cond *Cond) string{ - func(cond *Cond) string { return cond.Equal("", 123) }, - func(cond *Cond) string { return cond.NotEqual("", 123) }, - func(cond *Cond) string { return cond.GreaterThan("", 123) }, - func(cond *Cond) string { return cond.GreaterEqualThan("", 123) }, - func(cond *Cond) string { return cond.LessThan("", 123) }, - func(cond *Cond) string { return cond.LessEqualThan("", 123) }, - func(cond *Cond) string { return cond.In("", 1, 2, 3) }, - func(cond *Cond) string { return cond.In("a") }, - func(cond *Cond) string { return cond.NotIn("", 1, 2, 3) }, - func(cond *Cond) string { return cond.NotIn("a") }, - func(cond *Cond) string { return cond.Like("", "%Huan%") }, - func(cond *Cond) string { return cond.ILike("", "%Huan%") }, - func(cond *Cond) string { return cond.NotLike("", "%Huan%") }, - func(cond *Cond) string { return cond.NotILike("", "%Huan%") }, - func(cond *Cond) string { return cond.IsNull("") }, - func(cond *Cond) string { return cond.IsNotNull("") }, - func(cond *Cond) string { return cond.Between("", 123, 456) }, - func(cond *Cond) string { return cond.NotBetween("", 123, 456) }, - func(cond *Cond) string { return cond.Not("") }, - - func(cond *Cond) string { return cond.Any("", "", 1, 2) }, - func(cond *Cond) string { return cond.Any("", ">", 1, 2) }, - func(cond *Cond) string { return cond.Any("$a", "", 1, 2) }, - func(cond *Cond) string { return cond.Any("$a", ">") }, - - func(cond *Cond) string { return cond.All("", "", 1) }, - func(cond *Cond) string { return cond.All("", ">", 1) }, - func(cond *Cond) string { return cond.All("$a", "", 1) }, - func(cond *Cond) string { return cond.All("$a", ">") }, - - func(cond *Cond) string { return cond.Some("", "", 1, 2, 3) }, - func(cond *Cond) string { return cond.Some("", ">", 1, 2, 3) }, - func(cond *Cond) string { return cond.Some("$a", "", 1, 2, 3) }, - func(cond *Cond) string { return cond.Some("$a", ">") }, - - func(cond *Cond) string { return cond.IsDistinctFrom("", 1) }, - func(cond *Cond) string { return cond.IsNotDistinctFrom("", 1) }, + cases := []string{ + newCond().Equal("", 123), + newCond().NotEqual("", 123), + newCond().GreaterThan("", 123), + newCond().GreaterEqualThan("", 123), + newCond().LessThan("", 123), + newCond().LessEqualThan("", 123), + newCond().In("", 1, 2, 3), + newCond().NotIn("", 1, 2, 3), + newCond().NotIn("a"), + newCond().Like("", "%Huan%"), + newCond().ILike("", "%Huan%"), + newCond().NotLike("", "%Huan%"), + newCond().NotILike("", "%Huan%"), + newCond().IsNull(""), + newCond().IsNotNull(""), + newCond().Between("", 123, 456), + newCond().NotBetween("", 123, 456), + newCond().Not(""), + + newCond().Any("", "", 1, 2), + newCond().Any("", ">", 1, 2), + newCond().Any("$a", "", 1, 2), + + newCond().All("", "", 1), + newCond().All("", ">", 1), + newCond().All("$a", "", 1), + + newCond().Some("", "", 1, 2, 3), + newCond().Some("", ">", 1, 2, 3), + newCond().Some("$a", "", 1, 2, 3), + + newCond().IsDistinctFrom("", 1), + newCond().IsNotDistinctFrom("", 1), } expected := "" - for _, f := range cases { - actual := callCond(f) + for _, actual := range cases { a.Equal(actual, expected) } } -func callCond(fn func(cond *Cond) string) (actual string) { - cond := &Cond{ - Args: &Args{}, - } - format := fn(cond) - actual, _ = cond.Args.CompileWithFlavor(format, PostgreSQL) - return -} - func TestCondWithFlavor(t *testing.T) { a := assert.New(t) cond := &Cond{ @@ -248,3 +245,10 @@ func TestCondMisuse(t *testing.T) { a.Equal(sql, "SELECT * FROM t1 WHERE /* INVALID ARG $256 */") a.Equal(args, nil) } + +func newCond() *Cond { + args := &Args{} + return &Cond{ + Args: args, + } +}