diff --git a/engine.go b/engine.go index 08e2ace5c..846889c36 100644 --- a/engine.go +++ b/engine.go @@ -194,7 +194,7 @@ func (engine *Engine) Quote(value string) string { } // QuoteTo quotes string and writes into the buffer -func (engine *Engine) QuoteTo(buf *bytes.Buffer, value string) { +func (engine *Engine) QuoteTo(buf *builder.StringBuilder, value string) { if buf == nil { return } diff --git a/session_insert.go b/session_insert.go index c1182fe64..2ea58fdaf 100644 --- a/session_insert.go +++ b/session_insert.go @@ -204,30 +204,28 @@ func (session *Session) innerInsertMulti(rowsSlicePtr interface{}) (int64, error } cleanupProcessorsClosures(&session.beforeClosures) - var sql = "INSERT INTO %s (%v%v%v) VALUES (%v)" - var statement string + var sql string if session.engine.dialect.DBType() == core.ORACLE { - sql = "INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL" temp := fmt.Sprintf(") INTO %s (%v%v%v) VALUES (", session.engine.Quote(tableName), session.engine.QuoteStr(), strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), session.engine.QuoteStr()) - statement = fmt.Sprintf(sql, + sql = fmt.Sprintf("INSERT ALL INTO %s (%v%v%v) VALUES (%v) SELECT 1 FROM DUAL", session.engine.Quote(tableName), session.engine.QuoteStr(), strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), session.engine.QuoteStr(), strings.Join(colMultiPlaces, temp)) } else { - statement = fmt.Sprintf(sql, + sql = fmt.Sprintf("INSERT INTO %s (%v%v%v) VALUES (%v)", session.engine.Quote(tableName), session.engine.QuoteStr(), strings.Join(colNames, session.engine.QuoteStr()+", "+session.engine.QuoteStr()), session.engine.QuoteStr(), strings.Join(colMultiPlaces, "),(")) } - res, err := session.exec(statement, args...) + res, err := session.exec(sql, args...) if err != nil { return 0, err } diff --git a/statement.go b/statement.go index 54c6006b7..7856936f5 100644 --- a/statement.go +++ b/statement.go @@ -5,7 +5,6 @@ package xorm import ( - "bytes" "database/sql/driver" "encoding/json" "errors" @@ -706,10 +705,9 @@ func (statement *Statement) OrderBy(order string) *Statement { // Desc generate `ORDER BY xx DESC` func (statement *Statement) Desc(colNames ...string) *Statement { - var buf bytes.Buffer - fmt.Fprintf(&buf, statement.OrderStr) + var buf builder.StringBuilder if len(statement.OrderStr) > 0 { - fmt.Fprint(&buf, ", ") + fmt.Fprint(&buf, statement.OrderStr, ", ") } newColNames := statement.col2NewColsWithQuote(colNames...) fmt.Fprintf(&buf, "%v DESC", strings.Join(newColNames, " DESC, ")) @@ -719,10 +717,9 @@ func (statement *Statement) Desc(colNames ...string) *Statement { // Asc provide asc order by query condition, the input parameters are columns. func (statement *Statement) Asc(colNames ...string) *Statement { - var buf bytes.Buffer - fmt.Fprintf(&buf, statement.OrderStr) + var buf builder.StringBuilder if len(statement.OrderStr) > 0 { - fmt.Fprint(&buf, ", ") + fmt.Fprint(&buf, statement.OrderStr, ", ") } newColNames := statement.col2NewColsWithQuote(colNames...) fmt.Fprintf(&buf, "%v ASC", strings.Join(newColNames, " ASC, ")) @@ -749,7 +746,7 @@ func (statement *Statement) Table(tableNameOrBean interface{}) *Statement { // Join The joinOP should be one of INNER, LEFT OUTER, CROSS etc - this will be prepended to JOIN func (statement *Statement) Join(joinOP string, tablename interface{}, condition string, args ...interface{}) *Statement { - var buf bytes.Buffer + var buf builder.StringBuilder if len(statement.JoinStr) > 0 { fmt.Fprintf(&buf, "%v %v JOIN ", statement.JoinStr, joinOP) } else { @@ -783,11 +780,11 @@ func (statement *Statement) Unscoped() *Statement { } func (statement *Statement) genColumnStr() string { - var buf bytes.Buffer if statement.RefTable == nil { return "" } + var buf builder.StringBuilder columns := statement.RefTable.Columns() for _, col := range columns { @@ -1031,23 +1028,20 @@ func (statement *Statement) genSumSQL(bean interface{}, columns ...string) (stri return sqlStr, append(statement.joinArgs, condArgs...), nil } -func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (a string, err error) { - var distinct string +func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, needOrderBy bool) (string, error) { + var ( + distinct string + dialect = statement.Engine.Dialect() + quote = statement.Engine.Quote + fromStr = " FROM " + top, mssqlCondi, whereStr string + ) if statement.IsDistinct && !strings.HasPrefix(columnStr, "count") { distinct = "DISTINCT " } - - var dialect = statement.Engine.Dialect() - var quote = statement.Engine.Quote - var top string - var mssqlCondi string - - var buf bytes.Buffer if len(condSQL) > 0 { - fmt.Fprintf(&buf, " WHERE %v", condSQL) + whereStr = " WHERE " + condSQL } - var whereStr = buf.String() - var fromStr = " FROM " if dialect.DBType() == core.MSSQL && strings.Contains(statement.TableName(), "..") { fromStr += statement.TableName() @@ -1107,43 +1101,46 @@ func (statement *Statement) genSelectSQL(columnStr, condSQL string, needLimit, n } } - // !nashtsai! REVIEW Sprintf is considered slowest mean of string concatnation, better to work with builder pattern - a = fmt.Sprintf("SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) + var buf builder.StringBuilder + fmt.Fprintf(&buf, "SELECT %v%v%v%v%v", distinct, top, columnStr, fromStr, whereStr) if len(mssqlCondi) > 0 { if len(whereStr) > 0 { - a += " AND " + mssqlCondi + fmt.Fprint(&buf, " AND ", mssqlCondi) } else { - a += " WHERE " + mssqlCondi + fmt.Fprint(&buf, " WHERE ", mssqlCondi) } } if statement.GroupByStr != "" { - a = fmt.Sprintf("%v GROUP BY %v", a, statement.GroupByStr) + fmt.Fprint(&buf, " GROUP BY ", statement.GroupByStr) } if statement.HavingStr != "" { - a = fmt.Sprintf("%v %v", a, statement.HavingStr) + fmt.Fprint(&buf, " ", statement.HavingStr) } if needOrderBy && statement.OrderStr != "" { - a = fmt.Sprintf("%v ORDER BY %v", a, statement.OrderStr) + fmt.Fprint(&buf, " ORDER BY ", statement.OrderStr) } if needLimit { if dialect.DBType() != core.MSSQL && dialect.DBType() != core.ORACLE { if statement.Start > 0 { - a = fmt.Sprintf("%v LIMIT %v OFFSET %v", a, statement.LimitN, statement.Start) + fmt.Fprintf(&buf, " LIMIT %v OFFSET %v", statement.LimitN, statement.Start) } else if statement.LimitN > 0 { - a = fmt.Sprintf("%v LIMIT %v", a, statement.LimitN) + fmt.Fprint(&buf, " LIMIT ", statement.LimitN) } } else if dialect.DBType() == core.ORACLE { if statement.Start != 0 || statement.LimitN != 0 { - a = fmt.Sprintf("SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", columnStr, columnStr, a, statement.Start+statement.LimitN, statement.Start) + oldString := buf.String() + buf.Reset() + fmt.Fprintf(&buf, "SELECT %v FROM (SELECT %v,ROWNUM RN FROM (%v) at WHERE ROWNUM <= %d) aat WHERE RN > %d", + columnStr, columnStr, oldString, statement.Start+statement.LimitN, statement.Start) } } } if statement.IsForUpdate { - a = dialect.ForUpdateSql(a) + return dialect.ForUpdateSql(buf.String()), nil } - return + return buf.String(), nil } func (statement *Statement) processIDParam() error {