diff --git a/engine.go b/engine.go index 7a5210073..d959d01f2 100644 --- a/engine.go +++ b/engine.go @@ -177,6 +177,14 @@ func (engine *Engine) QuoteStr() string { return engine.dialect.QuoteStr() } +func (engine *Engine) quoteColumns(columnStr string) string { + columns := strings.Split(columnStr, ",") + for i := 0; i < len(columns); i++ { + columns[i] = engine.Quote(strings.TrimSpace(columns[i])) + } + return strings.Join(columns, ",") +} + // Quote Use QuoteStr quote the string sql func (engine *Engine) Quote(value string) string { value = strings.TrimSpace(value) diff --git a/session_find.go b/session_find.go index 46bbf26c9..b75f83479 100644 --- a/session_find.go +++ b/session_find.go @@ -135,7 +135,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) if session.statement.JoinStr == "" { if columnStr == "" { if session.statement.GroupByStr != "" { - columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1)) + columnStr = session.engine.quoteColumns(session.statement.GroupByStr) } else { columnStr = session.statement.genColumnStr() } @@ -143,7 +143,7 @@ func (session *Session) find(rowsSlicePtr interface{}, condiBean ...interface{}) } else { if columnStr == "" { if session.statement.GroupByStr != "" { - columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1)) + columnStr = session.engine.quoteColumns(session.statement.GroupByStr) } else { columnStr = "*" } diff --git a/session_find_test.go b/session_find_test.go index d0ec339cf..f9ebdc913 100644 --- a/session_find_test.go +++ b/session_find_test.go @@ -268,6 +268,15 @@ func TestOrder(t *testing.T) { fmt.Println(users2) } +func TestGroupBy(t *testing.T) { + assert.NoError(t, prepareEngine()) + assertSync(t, new(Userinfo)) + + users := make([]Userinfo, 0) + err := testEngine.GroupBy("id, username").Find(&users) + assert.NoError(t, err) +} + func TestHaving(t *testing.T) { assert.NoError(t, prepareEngine()) assertSync(t, new(Userinfo)) diff --git a/session_query.go b/session_query.go index 7ddd08f1e..1d0b156bc 100644 --- a/session_query.go +++ b/session_query.go @@ -35,7 +35,7 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa if session.statement.JoinStr == "" { if columnStr == "" { if session.statement.GroupByStr != "" { - columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1)) + columnStr = session.engine.quoteColumns(session.statement.GroupByStr) } else { columnStr = session.statement.genColumnStr() } @@ -43,7 +43,7 @@ func (session *Session) genQuerySQL(sqlorArgs ...interface{}) (string, []interfa } else { if columnStr == "" { if session.statement.GroupByStr != "" { - columnStr = session.statement.Engine.Quote(strings.Replace(session.statement.GroupByStr, ",", session.engine.Quote(","), -1)) + columnStr = session.engine.quoteColumns(session.statement.GroupByStr) } else { columnStr = "*" } diff --git a/statement.go b/statement.go index 7856936f5..56644036a 100644 --- a/statement.go +++ b/statement.go @@ -933,7 +933,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, if len(statement.JoinStr) == 0 { if len(columnStr) == 0 { if len(statement.GroupByStr) > 0 { - columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) + columnStr = statement.Engine.quoteColumns(statement.GroupByStr) } else { columnStr = statement.genColumnStr() } @@ -941,7 +941,7 @@ func (statement *Statement) genGetSQL(bean interface{}) (string, []interface{}, } else { if len(columnStr) == 0 { if len(statement.GroupByStr) > 0 { - columnStr = statement.Engine.Quote(strings.Replace(statement.GroupByStr, ",", statement.Engine.Quote(","), -1)) + columnStr = statement.Engine.quoteColumns(statement.GroupByStr) } } }