Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make TryInsert functions within the packages module use INSERT ... ON CONFLICT #21063

Open
wants to merge 36 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
7f4e851
Make TryInsert functions within the packages module try to insert first
zeripath Sep 4, 2022
b29ab42
partial still broken
zeripath Sep 6, 2022
470064c
more
zeripath Sep 6, 2022
9c83ab8
Merge remote-tracking branch 'origin/main' into fix-19586-insert-firs…
zeripath Sep 7, 2022
bf94b55
Merge remote-tracking branch 'origin/main' into fix-19586-insert-firs…
zeripath Sep 7, 2022
8d3864a
ensure that timestamps are also set
zeripath Sep 7, 2022
71522ea
oops lets try to get the mysql syntax right
zeripath Sep 7, 2022
7843fe9
Merge remote-tracking branch 'origin/main' into fix-19586-insert-firs…
zeripath Feb 2, 2023
abcf334
attempt to fix mysql/mssql
zeripath Feb 2, 2023
430f964
Merge remote-tracking branch 'origin/main' into fix-19586-insert-firs…
zeripath Feb 6, 2023
5dff21a
fix insert on conflict for sqlite
zeripath Feb 7, 2023
f934a98
fix unit-test
zeripath Feb 7, 2023
b3db6da
attempt to fix postgres
zeripath Feb 7, 2023
5bc4924
placate lint
zeripath Feb 7, 2023
8f7987c
fix mssql?
zeripath Feb 7, 2023
fc5d9aa
hopefully fix psql
zeripath Feb 7, 2023
7ec881f
get more info mssql
zeripath Feb 7, 2023
d8aa794
perhaps inserted should be INSERTED?
zeripath Feb 7, 2023
3f39045
fix mssql bug
zeripath Feb 7, 2023
15855df
add comments and slight restructure
zeripath Feb 7, 2023
38d540b
slight adjustments
zeripath Feb 7, 2023
a941cba
placate the linter
zeripath Feb 7, 2023
5ef7902
as per wxiaoguang
zeripath Feb 7, 2023
04efbf9
more comments
zeripath Feb 7, 2023
2283b23
missed setting sql
zeripath Feb 7, 2023
a282e66
add testcases
zeripath Feb 7, 2023
62b1e20
fix fmt
zeripath Feb 7, 2023
f1222e8
slight bug in mysql
zeripath Feb 7, 2023
25abc72
as per wxiaoguang
zeripath Feb 8, 2023
028b5a6
split functions out and just use ignore
zeripath Feb 8, 2023
1c17006
Merge remote-tracking branch 'origin/main' into fix-19586-insert-firs…
zeripath Feb 8, 2023
ac6862a
remove unnecessary mutex
zeripath Feb 9, 2023
dc4638c
Merge remote-tracking branch 'origin/main' into fix-19586-insert-firs…
zeripath Feb 9, 2023
ecd3eea
Merge remote-tracking branch 'origin/main' into fix-19586-insert-firs…
zeripath Mar 12, 2023
bf18cf8
add a few comments to test;
zeripath Mar 12, 2023
1a6f5df
fix broken merge
zeripath Mar 12, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
365 changes: 365 additions & 0 deletions models/db/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,20 @@
package db

import (
"context"
"fmt"
"reflect"
"strings"
"time"

"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/util"

"xorm.io/builder"
"xorm.io/xorm/convert"
"xorm.io/xorm/dialects"
"xorm.io/xorm/schemas"
)

// BuildCaseInsensitiveLike returns a condition to check if the given value is like the given key case-insensitively.
Expand All @@ -20,3 +28,360 @@ func BuildCaseInsensitiveLike(key, value string) builder.Cond {
}
return builder.Like{"UPPER(" + key + ")", strings.ToUpper(value)}
}

// InsertOnConflictDoNothing will attempt to insert the provided bean but if there is a conflict it will not error out
// This function will update the ID of the provided bean if there is an insertion
// This does not do all of the conversions that xorm would do automatically but it does quite a number of them
// once xorm has a working InsertOnConflictDoNothing this function could be removed.
func InsertOnConflictDoNothing(ctx context.Context, bean interface{}) (bool, error) {
e := GetEngine(ctx)

tableName := x.TableName(bean, true)
table, err := x.TableInfo(bean)
if err != nil {
return false, err
}

autoIncrCol := table.AutoIncrColumn()

columns := table.Columns()

colNames, values, zeroedColNames, zeroedValues, err := getColNamesAndValuesFromBean(bean, columns)
if err != nil {
return false, err
}

if len(colNames) == 0 {
return false, fmt.Errorf("provided bean to insert has all empty values")
}

// MSSQL needs to separately pass in the columns with the unique constraint and we need to
// include empty columns which are in the constraint in the insert for other dbs
uniqueCols, uniqueValues, colNames, values := addInUniqueCols(colNames, values, zeroedColNames, zeroedValues, table)
if len(uniqueCols) == 0 {
return false, fmt.Errorf("provided bean has no unique constraints")
}

var insertArgs []any

switch {
case setting.Database.UseSQLite3 || setting.Database.UsePostgreSQL || setting.Database.UseMySQL:
insertArgs = generateInsertNoConflictSQLAndArgs(tableName, colNames, values, autoIncrCol)
case setting.Database.UseMSSQL:
insertArgs = generateInsertNoConflictSQLAndArgsForMSSQL(tableName, colNames, values, uniqueCols, uniqueValues, autoIncrCol)
default:
return false, fmt.Errorf("database type not supported")
}

if autoIncrCol != nil && (setting.Database.UsePostgreSQL || setting.Database.UseMSSQL) {
// Postgres and MSSQL do not use the LastInsertID mechanism
// Therefore use query rather than exec and read the last provided ID back in

res, err := e.Query(insertArgs...)
if err != nil {
return false, fmt.Errorf("error in query: %s, %w", insertArgs[0], err)
}
if len(res) == 0 {
// this implies there was a conflict
return false, nil
}

aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil {
log.Error("unable to get value for autoincrcol of %#v %v", bean, err)
}

if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
return true, nil
}

id := res[0][autoIncrCol.Name]
err = convert.AssignValue(*aiValue, id)
if err != nil {
return true, fmt.Errorf("error in assignvalue %v %v %w", id, res, err)
}
return true, nil
}

res, err := e.Exec(values...)
zeripath marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return false, err
}

n, err := res.RowsAffected()
if err != nil {
return n != 0, err
}

if n != 0 && autoIncrCol != nil {
id, err := res.LastInsertId()
if err != nil {
return true, err
}
reflect.ValueOf(bean).Elem().FieldByName(autoIncrCol.FieldName).SetInt(id)
}

return n != 0, err
}

// generateInsertNoConflictSQLAndArgs will create the correct insert code for most of the DBs except MSSQL
//
// The statement is built for tableName with one "?" placeholder per entry in colNames and a
// dialect-specific conflict clause selected from the setting.Database.Use* flags:
//
//   - PostgreSQL: ON CONFLICT DO NOTHING, plus RETURNING <autoIncrCol> when autoIncrCol is non-nil
//   - SQLite3: ON CONFLICT DO NOTHING
//   - MySQL: INSERT IGNORE when autoIncrCol is nil, otherwise ON DUPLICATE KEY UPDATE <autoIncrCol> = <autoIncrCol>
//
// args is treated as 1-based: args[0] is overwritten with the generated SQL and the same slice is
// returned. colNames must not be empty as colNames[1:] is sliced below.
func generateInsertNoConflictSQLAndArgs(tableName string, colNames []string, args []any, autoIncrCol *schemas.Column) (insertArgs []any) {
sb := &strings.Builder{}

quote := x.Dialect().Quoter().Quote
// write appends each of the provided strings to the statement being built
write := func(args ...string) {
for _, arg := range args {
_, _ = sb.WriteString(arg)
}
}
write("INSERT ")
if setting.Database.UseMySQL && autoIncrCol == nil {
write("IGNORE ")
}
write("INTO ", quote(tableName), " (")
_ = x.Dialect().Quoter().JoinWrite(sb, colNames, ",")
// The VALUES list is opened here; its closing ")" is written by the dialect-specific suffix below
write(") VALUES (?")
for range colNames[1:] {
write(",?")
}
switch {
case setting.Database.UsePostgreSQL:
write(") ON CONFLICT DO NOTHING")
if autoIncrCol != nil {
write(" RETURNING ", quote(autoIncrCol.Name))
}
case setting.Database.UseSQLite3:
write(") ON CONFLICT DO NOTHING")
case setting.Database.UseMySQL:
// NOTE(review): when autoIncrCol is nil nothing is written in this case, so the INSERT IGNORE
// form appears to be left without the ")" that closes the VALUES list - confirm and fix.
if autoIncrCol != nil {
write(") ON DUPLICATE KEY UPDATE ", quote(autoIncrCol.Name), " = ", quote(autoIncrCol.Name))
}
Copy link
Contributor

@wxiaoguang wxiaoguang Feb 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry to bother, but why INSERT IGNORE doesn't work for autoIncrCol != nil?

If you meant to do it for LastInsertId, I think it (ON DUPLICATE KEY UPDATE) doesn't work, either.

Otherwise I really can not understand its purpose.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't you try it. When I wrote this code it required the ON DUPLICATE KEY UPDATE and not the IGNORE

Copy link
Contributor

@wxiaoguang wxiaoguang Feb 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because I have tried (update: out-dated)

image

Copy link
Contributor

@wxiaoguang wxiaoguang Feb 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(merged into next comment below)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it really works from XORM or Golang MySQL Driver, there should be some complete test cases for it.

Copy link
Contributor

@wxiaoguang wxiaoguang Feb 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What you really need is this (update: out-dated)

update id=last_insert_id(id)

Table:

CREATE TABLE `t` (
  `id` int NOT NULL AUTO_INCREMENT,
  `k` varchar(100),
  `v` varchar(100) DEFAULT '',
  `i` int DEFAULT NULL,
  PRIMARY KEY (`id`),
  UNIQUE KEY `k` (`k`)
) ENGINE=InnoDB AUTO_INCREMENT=9
mysql> insert into t (k) values ('1') on duplicate key update id=last_insert_id(id);
Query OK, 0 rows affected (0.00 sec)

mysql> select last_insert_id();
+------------------+
| last_insert_id() |
+------------------+
|                1 |
+------------------+
1 row in set (0.00 sec)

Copy link
Contributor Author

@zeripath zeripath Feb 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a) I would have thought that last_insert_id() requires that there are no other id insertions in the meantime id = last_insert_id() would cause updates on conflicts which is completely the wrong thing to do .
b) id = id avoids that.
c) IGNORE is only used when there is no autoincrement ID present in the table where there is no id to update.
d) I've written some tests now.

Copy link
Contributor

@wxiaoguang wxiaoguang Feb 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a) last_insert_id is the last successfully inserted ID in current session, it's always safe across transactions. And it won't be reset if a new insertion fails. you might have seen some incorrect last_insert_id/unrelated during your test.
c) you can also do on duplicate update set col[0] = col[0], then no separate ignore need to be added IMO.
b) & d) the tests are incorrect at the moment. see the comment below. #21063 (comment)


Update: out-dated.

Copy link
Contributor

@wxiaoguang wxiaoguang Feb 8, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MySQL demo (update: out-dated, the demo is for filling the ID)

mysql> insert into t (k) values ('a');
Query OK, 1 row affected (0.00 sec)

mysql> select last_insert_id();
+------------------+
| last_insert_id() |
+------------------+
|                1 |
+------------------+
1 row in set (0.00 sec)

mysql> insert into t (k) values ('b');
Query OK, 1 row affected (0.00 sec)

mysql> select last_insert_id();
+------------------+
| last_insert_id() |
+------------------+
|                2 |
+------------------+
1 row in set (0.00 sec)

mysql> insert into t (k) values ('a') on duplicate key update id=id;
Query OK, 0 rows affected (0.00 sec)

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
!!!!!!! HERE You are inserting duplicated `a` and want to see `last_insert_id=1` for 'a', BUT:
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

mysql> select last_insert_id();
+------------------+
| last_insert_id() |
+------------------+
|                2 |
+------------------+
1 row in set (0.00 sec)

mysql> select * from t where id=2;
+----+------+------+------+
| id | k    | v    | i    |
+----+------+------+------+
|  2 | b    |      | NULL |
+----+------+------+------+
1 row in set (0.00 sec)

}
// args[0] is the slot reserved for the SQL text; the bind parameters follow from args[1]
args[0] = sb.String()
return args
}

// generateInsertNoConflictSQLAndArgsForMSSQL builds the MSSQL counterpart of INSERT ... ON CONFLICT DO NOTHING.
//
// MSSQL has no such clause, so the statement is phrased the other way round as a MERGE:
//
//	MERGE <table> WITH (HOLDLOCK) AS target
//	USING (SELECT ? AS <unique col>, ...) AS src ON src.<unique col>= target.<unique col> [AND ...]
//	WHEN NOT MATCHED THEN INSERT (<cols>) VALUES (?, ...) [OUTPUT INSERTED.<autoincr col>];
//
// The unique columns are pre-selected as src and the insert only happens when no target row matches them.
//
// args and uniqueArgs are both 1-based. The generated SQL is stored in uniqueArgs[0] and the returned
// slice is uniqueArgs followed by args[1:], matching the order of the placeholders in the statement.
// uniqueCols and colNames must each hold at least one entry.
func generateInsertNoConflictSQLAndArgsForMSSQL(tableName string, colNames []string, args []any, uniqueCols []string, uniqueArgs []any, autoIncrCol *schemas.Column) (insertArgs []any) {
	quoter := x.Dialect().Quoter()

	var stmt strings.Builder

	// Source side of the merge: the values of the unique columns
	stmt.WriteString("MERGE " + quoter.Quote(tableName) + " WITH (HOLDLOCK) AS target USING (SELECT ? AS ")
	_ = quoter.JoinWrite(&stmt, uniqueCols, ", ? AS ")

	// Match condition: every unique column has to be equal
	first, rest := uniqueCols[0], uniqueCols[1:]
	stmt.WriteString(") AS src ON src." + quoter.Quote(first) + "= target." + quoter.Quote(first))
	for _, col := range rest {
		stmt.WriteString(" AND src." + quoter.Quote(col) + "= target." + quoter.Quote(col))
	}

	// Insert only when there was no match
	stmt.WriteString(" WHEN NOT MATCHED THEN INSERT (")
	_ = quoter.JoinWrite(&stmt, colNames, ",")
	stmt.WriteString(") VALUES (?")
	stmt.WriteString(strings.Repeat(", ?", len(colNames)-1))
	stmt.WriteString(")")

	// Hand the generated ID back to the caller when there is one
	if autoIncrCol != nil {
		stmt.WriteString(" OUTPUT INSERTED." + quoter.Quote(autoIncrCol.Name))
	}
	stmt.WriteString(";")

	uniqueArgs[0] = stmt.String()
	return append(uniqueArgs, args[1:]...)
}

// addInUniqueCols determines the columns that refer to unique constraints and creates slices for these
// as they're needed by MSSQL. In addition, any columns which are zero-valued but are part of a constraint
// are added back in to the colNames and args.
//
// colNames/args describe the non-zero columns of the bean and zeroedColNames/emptyArgs the zero-valued ones.
// args is 1-based (args[0] is reserved for the SQL) so the value for colNames[i] is args[i+1], whereas
// emptyArgs[i] is the value for zeroedColNames[i].
//
// It returns:
//   - uniqueCols: each column that appears in a unique index of table, without duplicates
//   - uniqueArgs: the values for uniqueCols, 1-based with uniqueArgs[0] left empty for the SQL
//   - insertCols, insertArgs: colNames and args extended with the zero-valued unique columns
func addInUniqueCols(colNames []string, args []any, zeroedColNames []string, emptyArgs []any, table *schemas.Table) (uniqueCols []string, uniqueArgs []any, insertCols []string, insertArgs []any) {
	// colNames/zeroedColNames (and args/emptyArgs) can be adjacent windows onto the same backing arrays.
	// Clip the capacities so that the appends below always copy rather than overwrite zeroed entries
	// that have not been examined yet.
	colNames = colNames[:len(colNames):len(colNames)]
	args = args[:len(args):len(args)]

	uniqueCols = make([]string, 0, len(table.Columns()))
	uniqueArgs = make([]any, 1, cap(uniqueCols)+1) // leave uniqueArgs[0] empty to put the SQL in

	// Iterate across the indexes in the provided table
	for _, index := range table.Indexes {
		if index.Type != schemas.UniqueType {
			continue
		}

		// index is a Unique constraint
	indexCol:
		for _, iCol := range index.Cols {
			for _, uCol := range uniqueCols {
				if uCol == iCol {
					// column is already included in uniqueCols so we don't need to add it again
					continue indexCol
				}
			}

			// Now iterate across colNames and add to the uniqueCols
			for i, col := range colNames {
				if col == iCol {
					uniqueCols = append(uniqueCols, col)
					uniqueArgs = append(uniqueArgs, args[i+1])
					continue indexCol
				}
			}

			// If we still haven't found the column we need to look in the zeroed columns and add
			// it back into colNames and args as well as uniqueCols/uniqueArgs
			for i, col := range zeroedColNames {
				if col == iCol {
					// Always include empty unique columns in the insert statement as otherwise the insert no conflict will pass
					colNames = append(colNames, col)
					args = append(args, emptyArgs[i])
					uniqueCols = append(uniqueCols, col)
					uniqueArgs = append(uniqueArgs, emptyArgs[i])
					continue indexCol
				}
			}
		}
	}
	return uniqueCols, uniqueArgs, colNames, args
}

// getColNamesAndValuesFromBean reads the provided bean, providing two pairs of linked slices:
//
// - colNames and values
// - zeroedColNames and zeroedValues
//
// colNames contains the names of the columns that have non-zero values in the provided bean
// values contains the values - with one exception - values is 1-based so that values[0] is deliberately left zero
//
// zeroedColNames and zeroedValues account for the other columns - with zeroedValues containing the zero values.
// These two are 0-based: zeroedValues[i] is the value for zeroedColNames[i].
//
// Columns marked as created or updated are always treated as non-zero: the matching field of the bean is set
// to the current time and the formatted time is used as the value. Columns without a FieldIndex are skipped.
//
// bean must be a pointer to a struct. The returned slices have their capacity clipped to their length so that
// appending to colNames or values cannot overwrite the entries of zeroedColNames or zeroedValues.
func getColNamesAndValuesFromBean(bean interface{}, cols []*schemas.Column) (colNames []string, values []any, zeroedColNames []string, zeroedValues []any, err error) {
	// Non-zero columns are filled in from the front and zero-valued columns from the back of the same arrays.
	colNames = make([]string, len(cols))
	values = make([]any, len(cols)+1) // Leave values[0] to put the SQL in
	maxNonEmpty := 0
	minEmpty := len(cols)

	elem := reflect.ValueOf(bean).Elem()
	for _, col := range cols {
		fieldIdx := col.FieldIndex
		if fieldIdx == nil {
			continue
		}

		fieldVal := elem.FieldByIndex(fieldIdx)
		if col.IsCreated || col.IsUpdated {
			result, err := setCurrentTime(fieldVal, col)
			if err != nil {
				return nil, nil, nil, nil, err
			}

			colNames[maxNonEmpty] = col.Name
			maxNonEmpty++
			values[maxNonEmpty] = result
			continue
		}

		fieldValue, err := getValueFromField(fieldVal, col)
		if err != nil {
			return nil, nil, nil, nil, err
		}
		if fieldVal.IsZero() {
			values[minEmpty] = fieldValue // remember values is 1-based not 0-based
			minEmpty--
			colNames[minEmpty] = col.Name
			continue
		}
		colNames[maxNonEmpty] = col.Name
		maxNonEmpty++
		values[maxNonEmpty] = fieldValue
	}

	// Full slice expressions clip the capacity of the non-zero halves so that they do not alias the zeroed halves.
	return colNames[:maxNonEmpty:maxNonEmpty], values[: maxNonEmpty+1 : maxNonEmpty+1], colNames[maxNonEmpty:], values[maxNonEmpty+1:], nil
}

// setCurrentTime stamps fieldVal with the current time and returns that same instant formatted
// for col by dialects.FormatColumnTime.
//
// Struct fields receive the time converted to the field's type; int, int32, int64, uint, uint32 and
// uint64 fields receive the Unix timestamp in seconds. Fields of any other kind are left untouched.
// If formatting fails the field is not modified and the error is returned.
func setCurrentTime(fieldVal reflect.Value, col *schemas.Column) (interface{}, error) {
	now := time.Now()

	formatted, err := dialects.FormatColumnTime(x.Dialect(), x.DatabaseTZ, col, now)
	if err != nil {
		return formatted, err
	}

	fieldType := fieldVal.Type()
	switch fieldType.Kind() {
	case reflect.Int, reflect.Int32, reflect.Int64:
		fieldVal.SetInt(now.Unix())
	case reflect.Uint, reflect.Uint32, reflect.Uint64:
		fieldVal.SetUint(uint64(now.Unix()))
	case reflect.Struct:
		fieldVal.Set(reflect.ValueOf(now).Convert(fieldType))
	}

	return formatted, nil
}

// getValueFromField turns the reflected struct field fieldVal into a value that can be passed
// as a bind parameter for col, keeping its type so that zero values work in the SQL insert.
//
// Fields implementing convert.Conversion (either on the field's address or on the field itself)
// are serialised through ToDB: the bytes are returned as-is for blob columns and as a string
// otherwise; a nil result becomes nil for nullable columns and an empty value for the rest.
// Integers, floats, complex numbers, strings and bools are returned as their basic Go value,
// except that on MSSQL bools are returned as 1 or 0. Anything else is returned via Interface().
func getValueFromField(fieldVal reflect.Value, col *schemas.Column) (any, error) {
	// fromConversion serialises a convert.Conversion according to the column definition
	fromConversion := func(conv convert.Conversion) (any, error) {
		raw, err := conv.ToDB()
		if err != nil {
			return nil, err
		}
		if raw == nil {
			if col.Nullable {
				return nil, nil
			}
			raw = []byte{}
		}
		if col.SQLType.IsBlob() {
			return raw, nil
		}
		return string(raw), nil
	}

	// The conversion may be implemented on the pointer receiver of an addressable field
	if fieldVal.CanAddr() {
		if conv, ok := fieldVal.Addr().Interface().(convert.Conversion); ok {
			return fromConversion(conv)
		}
	}

	// ... or on the field itself - but a nil pointer must not be asked to convert
	if nilPtr := fieldVal.Kind() == reflect.Ptr && fieldVal.IsNil(); !nilPtr {
		if conv, ok := fieldVal.Interface().(convert.Conversion); ok {
			return fromConversion(conv)
		}
	}

	// Common primitive kinds
	switch fieldVal.Kind() {
	case reflect.Bool:
		flag := fieldVal.Bool()
		if !setting.Database.UseMSSQL {
			return flag, nil
		}
		if flag {
			return 1, nil
		}
		return 0, nil
	case reflect.String:
		return fieldVal.String(), nil
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		return fieldVal.Int(), nil
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		return fieldVal.Uint(), nil
	case reflect.Float32, reflect.Float64:
		return fieldVal.Float(), nil
	case reflect.Complex64, reflect.Complex128:
		return fieldVal.Complex(), nil
	}

	// Everything else is handed over unchanged
	return fieldVal.Interface(), nil
}
Loading