Skip to content

Commit

Permalink
remove QuoteStr() usage
Browse files Browse the repository at this point in the history
  • Loading branch information
xormplus committed Jul 24, 2019
1 parent 4a0e425 commit 7d5610b
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 60 deletions.
40 changes: 24 additions & 16 deletions engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ func (engine *Engine) SupportInsertMany() bool {

// QuoteStr Engine's database use which character as quote.
// mysql, sqlite use ` and postgres use "
// Deprecated, use Quote() instead
func (engine *Engine) QuoteStr() string {
return engine.dialect.QuoteStr()
}
Expand All @@ -199,13 +200,10 @@ func (engine *Engine) Quote(value string) string {
return value
}

if string(value[0]) == engine.dialect.QuoteStr() || value[0] == '`' {
return value
}

value = strings.Replace(value, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1)
buf := builder.StringBuilder{}
engine.QuoteTo(&buf, value)

return engine.dialect.QuoteStr() + value + engine.dialect.QuoteStr()
return buf.String()
}

// QuoteTo quotes string and writes into the buffer
Expand All @@ -219,20 +217,30 @@ func (engine *Engine) QuoteTo(buf *builder.StringBuilder, value string) {
return
}

if string(value[0]) == engine.dialect.QuoteStr() || value[0] == '`' {
buf.WriteString(value)
quotePair := engine.dialect.Quote("")

if value[0] == '`' || len(quotePair) < 2 || value[0] == quotePair[0] { // no quote
_, _ = buf.WriteString(value)
return
} else {
prefix, suffix := quotePair[0], quotePair[1]

_ = buf.WriteByte(prefix)
for i := 0; i < len(value); i++ {
if value[i] == '.' {
_ = buf.WriteByte(suffix)
_ = buf.WriteByte('.')
_ = buf.WriteByte(prefix)
} else {
_ = buf.WriteByte(value[i])
}
}
_ = buf.WriteByte(suffix)
}

value = strings.Replace(value, ".", engine.dialect.QuoteStr()+"."+engine.dialect.QuoteStr(), -1)

buf.WriteString(engine.dialect.QuoteStr())
buf.WriteString(value)
buf.WriteString(engine.dialect.QuoteStr())
}

func (engine *Engine) quote(sql string) string {
return engine.dialect.QuoteStr() + sql + engine.dialect.QuoteStr()
return engine.dialect.Quote(sql)
}

// SqlType will be deprecated, please use SQLType instead
Expand Down Expand Up @@ -1605,7 +1613,7 @@ func (engine *Engine) formatColTime(col *core.Column, t time.Time) (v interface{
func (engine *Engine) formatTime(sqlTypeName string, t time.Time) (v interface{}) {
switch sqlTypeName {
case core.Time:
s := t.Format("2006-01-02 15:04:05") //time.RFC3339
s := t.Format("2006-01-02 15:04:05") // time.RFC3339
v = s[11:19]
case core.Date:
v = t.Format("2006-01-02")
Expand Down
23 changes: 22 additions & 1 deletion helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ func rValue(bean interface{}) reflect.Value {

func rType(bean interface{}) reflect.Type {
sliceValue := reflect.Indirect(reflect.ValueOf(bean))
//return reflect.TypeOf(sliceValue.Interface())
// return reflect.TypeOf(sliceValue.Interface())
return sliceValue.Type()
}

Expand Down Expand Up @@ -309,3 +309,24 @@ func sliceEq(left, right []string) bool {
func indexName(tableName, idxName string) string {
return fmt.Sprintf("IDX_%v_%v", tableName, idxName)
}

func eraseAny(value string, strToErase ...string) string {
if len(strToErase) == 0 {
return value
}
var replaceSeq []string
for _, s := range strToErase {
replaceSeq = append(replaceSeq, s, "")
}

replacer := strings.NewReplacer(replaceSeq...)

return replacer.Replace(value)
}

func quoteColumns(cols []string, quoteFunc func(string) string, sep string) string {
for i := range cols {
cols[i] = quoteFunc(cols[i])
}
return strings.Join(cols, sep+" ")
}
22 changes: 21 additions & 1 deletion helpers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

package xorm

import "testing"
import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestSplitTag(t *testing.T) {
var cases = []struct {
Expand All @@ -24,3 +28,19 @@ func TestSplitTag(t *testing.T) {
}
}
}

func TestEraseAny(t *testing.T) {
raw := "SELECT * FROM `table`.[table_name]"
assert.EqualValues(t, raw, eraseAny(raw))
assert.EqualValues(t, "SELECT * FROM table.[table_name]", eraseAny(raw, "`"))
assert.EqualValues(t, "SELECT * FROM table.table_name", eraseAny(raw, "`", "[", "]"))
}

func TestQuoteColumns(t *testing.T) {
cols := []string{"f1", "f2", "f3"}
quoteFunc := func(value string) string {
return "[" + value + "]"
}

assert.EqualValues(t, "[f1], [f2], [f3]", quoteColumns(cols, quoteFunc, ","))
}
24 changes: 8 additions & 16 deletions session_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,23 +242,17 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error

var sql string
if session.engine.dialect.DBType() == core.ORACLE {
temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (",
temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
session.engine.Quote(tableName),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr())
sql = fmt.Sprintf("INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL",
quoteColumns(colNames, session.engine.Quote, ","))
sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL",
session.engine.Quote(tableName),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr(),
quoteColumns(colNames, session.engine.Quote, ","),
strings.Join(colMultiPlaces, temp))
} else {
sql = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)",
sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)",
session.engine.Quote(tableName),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()),
session.engine.QuoteStr(),
quoteColumns(colNames, session.engine.Quote, ","),
strings.Join(colMultiPlaces, "),("))
}
res, err := session.exec(sql, args...)
Expand Down Expand Up @@ -379,11 +373,9 @@ func (session *Session) innerInsert(bean interface{}) (int64, error) {
output = fmt.Sprintf(" OUTPUT Inserted.%s", table.AutoIncrement)
}
if len(colPlaces) > 0 {
sqlStr = fmt.Sprintf("INSERT INTO %s (%v%v%v)%s VALUES (%v)",
sqlStr = fmt.Sprintf("INSERT INTO %s (%v)%s VALUES (%v)",
session.engine.Quote(tableName),
session.engine.QuoteStr(),
strings.Join(colNames, session.engine.Quote(", ")),
session.engine.QuoteStr(),
quoteColumns(colNames, session.engine.Quote, ","),
output,
colPlaces)
} else {
Expand Down
15 changes: 8 additions & 7 deletions session_update.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,15 @@ func (session *Session) cacheUpdate(table *core.Table, tableName, sqlStr string,
return ErrCacheFailed
}
kvs := strings.Split(strings.TrimSpace(sqls[1]), ",")

for idx, kv := range kvs {
sps := strings.SplitN(kv, "=", 2)
sps2 := strings.Split(sps[0], ".")
colName := sps2[len(sps2)-1]
if strings.Contains(colName, "`") {
colName = strings.TrimSpace(strings.Replace(colName, "`", "", -1))
} else if strings.Contains(colName, session.engine.QuoteStr()) {
colName = strings.TrimSpace(strings.Replace(colName, session.engine.QuoteStr(), "", -1))
// treat quote prefix, suffix and '`' as quotes
quotes := append(strings.Split(session.engine.Quote(""), ""), "`")
if strings.ContainsAny(colName, strings.Join(quotes, "")) {
colName = strings.TrimSpace(eraseAny(colName, quotes...))
} else {
session.engine.logger.Debug("[cacheUpdate] cannot find column", tableName, colName)
return ErrCacheFailed
Expand Down Expand Up @@ -221,19 +222,19 @@ func (session *Session) Update(bean interface{}, condiBean ...interface{}) (int6
}
}

//for update action to like "column = column + ?"
// for update action to like "column = column + ?"
incColumns := session.statement.getInc()
for _, v := range incColumns {
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" + ?")
args = append(args, v.arg)
}
//for update action to like "column = column - ?"
// for update action to like "column = column - ?"
decColumns := session.statement.getDec()
for _, v := range decColumns {
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+session.engine.Quote(v.colName)+" - ?")
args = append(args, v.arg)
}
//for update action to like "column = expression"
// for update action to like "column = expression"
exprColumns := session.statement.getExpr()
for _, v := range exprColumns {
colNames = append(colNames, session.engine.Quote(v.colName)+" = "+v.expr)
Expand Down
29 changes: 10 additions & 19 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ package xorm

import (
"database/sql/driver"
"errors"
"fmt"
"reflect"
"strings"
Expand Down Expand Up @@ -426,7 +425,7 @@ func (statement *Statement) buildUpdates(bean interface{},
continue
}
} else {
//TODO: how to handler?
// TODO: how to handler?
panic("not supported")
}
} else {
Expand Down Expand Up @@ -607,21 +606,9 @@ func (statement *Statement) getExpr() map[string]exprParam {

func (statement *Statement) col2NewColsWithQuote(columns ...string) []string {
newColumns := make([]string, 0)
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")
for _, col := range columns {
col = strings.Replace(col, "`", "", -1)
col = strings.Replace(col, statement.Engine.QuoteStr(), "", -1)
ccols := strings.Split(col, ",")
for _, c := range ccols {
fields := strings.Split(strings.TrimSpace(c), ".")
if len(fields) == 1 {
newColumns = append(newColumns, statement.Engine.quote(fields[0]))
} else if len(fields) == 2 {
newColumns = append(newColumns, statement.Engine.quote(fields[0])+"."+
statement.Engine.quote(fields[1]))
} else {
panic(errors.New("unwanted colnames"))
}
}
newColumns = append(newColumns, statement.Engine.Quote(eraseAny(col, quotes...)))
}
return newColumns
}
Expand Down Expand Up @@ -792,7 +779,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
return statement
}
tbs := strings.Split(tp.TableName(), ".")
var aliasName = strings.Trim(tbs[len(tbs)-1], statement.Engine.QuoteStr())
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")

var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
case *builder.Builder:
Expand All @@ -802,7 +791,9 @@ func (statement *Statement) Join(joinOP string, tablename interface{}, condition
return statement
}
tbs := strings.Split(tp.TableName(), ".")
var aliasName = strings.Trim(tbs[len(tbs)-1], statement.Engine.QuoteStr())
quotes := append(strings.Split(statement.Engine.Quote(""), ""), "`")

var aliasName = strings.Trim(tbs[len(tbs)-1], strings.Join(quotes, ""))
fmt.Fprintf(&buf, "(%s) %s ON %v", subSQL, aliasName, condition)
statement.joinArgs = append(statement.joinArgs, subQueryArgs...)
default:
Expand Down Expand Up @@ -1272,7 +1263,7 @@ func (statement *Statement) convertUpdateSQL(sqlStr string) (string, string) {

var whereStr = sqls[1]

//TODO: for postgres only, if any other database?
// TODO: for postgres only, if any other database?
var paraStr string
if statement.Engine.dialect.DBType() == core.POSTGRES {
paraStr = "$"
Expand Down
9 changes: 9 additions & 0 deletions statement_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,3 +237,12 @@ func TestUpdateIgnoreOnlyFromDBFields(t *testing.T) {
testEngine.Update(record)
assertGetRecord()
}

func TestCol2NewColsWithQuote(t *testing.T) {
cols := []string{"f1", "f2", "t3.f3"}

statement := createTestStatement()

quotedCols := statement.col2NewColsWithQuote(cols...)
assert.EqualValues(t, []string{statement.Engine.Quote("f1"), statement.Engine.Quote("f2"), statement.Engine.Quote("t3.f3")}, quotedCols)
}

0 comments on commit 7d5610b

Please sign in to comment.