diff --git a/database/db.go b/database/db.go index 5434839..aafde42 100644 --- a/database/db.go +++ b/database/db.go @@ -868,7 +868,7 @@ func (db *DB) Log(ctx context.Context, query string, counter *com.Counter) perio })) } -func BuildUpsertStatement(db *DB, stmt UpsertStatement) (string, error) { +func BuildUpsertStatement(db *DB, stmt UpsertStatement) (string, int, error) { return NewQueryBuilder(db.DriverName()).UpsertStatement(stmt) } diff --git a/database/query_builder.go b/database/query_builder.go index a209148..d4888c4 100644 --- a/database/query_builder.go +++ b/database/query_builder.go @@ -10,7 +10,7 @@ import ( ) type QueryBuilder interface { - UpsertStatement(stmt UpsertStatement) (string, error) + UpsertStatement(stmt UpsertStatement) (string, int, error) InsertStatement(stmt InsertStatement) string @@ -43,7 +43,7 @@ type queryBuilder struct { columnMap ColumnMap } -func (qb *queryBuilder) UpsertStatement(stmt UpsertStatement) (string, error) { +func (qb *queryBuilder) UpsertStatement(stmt UpsertStatement) (string, int, error) { columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns()) into := stmt.Table() if into == "" { @@ -61,7 +61,7 @@ func (qb *queryBuilder) UpsertStatement(stmt UpsertStatement) (string, error) { ) setFormat = `"%[1]s" = EXCLUDED."%[1]s"` default: - return "", errors.New(fmt.Sprintf("unsupported driver: %s", qb.dbDriver)) + return "", 0, errors.New(fmt.Sprintf("unsupported driver: %s", qb.dbDriver)) } set := make([]string, 0, len(columns)) @@ -76,7 +76,7 @@ func (qb *queryBuilder) UpsertStatement(stmt UpsertStatement) (string, error) { fmt.Sprintf(":%s", strings.Join(columns, ", :")), clause, strings.Join(set, ", "), - ), nil + ), len(columns), nil } func (qb *queryBuilder) InsertStatement(stmt InsertStatement) string {