diff --git a/models/db/common.go b/models/db/common.go index 1a59a8b5c697f..a31c066ef11db 100644 --- a/models/db/common.go +++ b/models/db/common.go @@ -5,12 +5,16 @@ package db import ( + "context" + "fmt" + "reflect" "strings" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/util" "xorm.io/builder" + "xorm.io/xorm/schemas" ) // BuildCaseInsensitiveLike returns a condition to check if the given value is like the given key case-insensitively. @@ -21,3 +25,149 @@ func BuildCaseInsensitiveLike(key, value string) builder.Cond { } return builder.Like{"UPPER(" + key + ")", strings.ToUpper(value)} } + +func InsertOnConflictDoNothing(ctx context.Context, bean interface{}) (int64, error) { + e := GetEngine(ctx) + + tableName := x.TableName(bean, true) + table, err := x.TableInfo(bean) + if err != nil { + return 0, err + } + + autoIncrCol := table.AutoIncrColumn() + if autoIncrCol == nil { + return 0, fmt.Errorf("this function requires an autoincrement column") + } + + cols := table.Columns() + colNames := make([]string, 0, len(cols)) + args := make([]interface{}, 1, len(cols)) + + val := reflect.ValueOf(bean) + elem := val.Elem() + for _, col := range cols { + if fieldIdx := col.FieldIndex; fieldIdx != nil { + fieldVal := elem.FieldByIndex(fieldIdx) + if fieldVal.IsZero() { + continue + } + colNames = append(colNames, col.Name) + args = append(args, fieldVal.Interface()) + } + } + + if len(colNames) == 0 { + return 0, fmt.Errorf("empty bean") + } + + uniqueCols := make([]string, 0, len(cols)) + uniqueArgs := make([]interface{}, 1, len(cols)) + for _, index := range table.Indexes { + if index.Type != schemas.UniqueType { + continue + } + indexCol: + for _, iCol := range index.Cols { + for _, uCol := range uniqueCols { + if uCol == iCol { + continue indexCol + } + } + for i, col := range colNames { + if col == iCol { + uniqueCols = append(uniqueCols, col) + uniqueArgs = append(uniqueArgs, args[i+1]) + continue indexCol + } + } + } + } + + if len(uniqueCols) == 0 { + return 0, fmt.Errorf("empty bean") + } + + sb := &strings.Builder{} + switch { + case setting.Database.UseSQLite3 || setting.Database.UsePostgreSQL || setting.Database.UseMySQL: + _, _ = sb.WriteString("INSERT INTO ") + _, _ = sb.WriteString(x.Dialect().Quoter().Quote(tableName)) + _, _ = sb.WriteString(" (") + _, _ = sb.WriteString(colNames[0]) + for _, colName := range colNames[1:] { + _, _ = sb.WriteString(",") + _, _ = sb.WriteString(colName) + } + _, _ = sb.WriteString(") VALUES (") + _, _ = sb.WriteString("?") + for range colNames[1:] { + _, _ = sb.WriteString(",?") + } + switch { + case setting.Database.UseSQLite3 || setting.Database.UsePostgreSQL: + _, _ = sb.WriteString(") ON CONFLICT DO NOTHING") + case setting.Database.UseMySQL: + _, _ = sb.WriteString(") ON CONFLICT DO DUPLICATE KEY ") + _, _ = sb.WriteString(autoIncrCol.Name) + _, _ = sb.WriteString(" = ") + _, _ = sb.WriteString(autoIncrCol.Name) + } + case setting.Database.UseMSSQL: + _, _ = sb.WriteString("MERGE ") + _, _ = sb.WriteString(x.Dialect().Quoter().Quote(tableName)) + _, _ = sb.WriteString(" WITH (HOLDLOCK) AS target USING (SELECT ") + + _, _ = sb.WriteString("? AS ") + _, _ = sb.WriteString(uniqueCols[0]) + for _, uniqueCol := range uniqueCols[1:] { + _, _ = sb.WriteString(", ? AS ") + _, _ = sb.WriteString(uniqueCol) + } + _, _ = sb.WriteString(") AS src ON src.") + _, _ = sb.WriteString(uniqueCols[0]) + _, _ = sb.WriteString("= target.") + _, _ = sb.WriteString(uniqueCols[0]) + for _, uniqueCol := range uniqueCols[1:] { + _, _ = sb.WriteString(" AND src.") + _, _ = sb.WriteString(uniqueCol) + _, _ = sb.WriteString("= target.") + _, _ = sb.WriteString(uniqueCols[0]) + } + _, _ = sb.WriteString(" WHEN NOT MATCHED THEN INSERT (") + _, _ = sb.WriteString(colNames[0]) + for _, colName := range colNames[1:] { + _, _ = sb.WriteString(",") + _, _ = sb.WriteString(colName) + } + _, _ = sb.WriteString(") VALUES (") + _, _ = sb.WriteString("?") + for range colNames[1:] { + _, _ = sb.WriteString(",?") + } + _, _ = sb.WriteString(")") + args = append(uniqueArgs, args[1:]...) + default: + return 0, fmt.Errorf("database type not supported") + } + args[0] = sb.String() + res, err := e.Exec(args...) + if err != nil { + return 0, err + } + + n, err := res.RowsAffected() + if err != nil { + return n, err + } + + if n != 0 { + id, err := res.LastInsertId() + if err != nil { + return n, err + } + elem.FieldByName(autoIncrCol.FieldName).SetInt(id) + } + + return res.RowsAffected() +} diff --git a/models/packages/package.go b/models/packages/package.go index f9bd6c7657653..8c8b6621c5d94 100644 --- a/models/packages/package.go +++ b/models/packages/package.go @@ -120,7 +120,13 @@ type Package struct { // TryInsertPackage inserts a package. If a package exists already, ErrDuplicatePackage is returned func TryInsertPackage(ctx context.Context, p *Package) (*Package, error) { - e := db.GetEngine(ctx) + n, err := db.InsertOnConflictDoNothing(ctx, p) + if err != nil { + return nil, err + } + if n != 0 { + return p, nil + } key := &Package{ OwnerID: p.OwnerID, @@ -128,14 +134,16 @@ func TryInsertPackage(ctx context.Context, p *Package) (*Package, error) { LowerName: p.LowerName, } - if _, err := e.Insert(p); err != nil { - // Try to get the key again - if has, _ := e.Get(key); has { - return key, ErrDuplicatePackage + has, err := db.GetEngine(ctx).Get(key) + if has { + if n == 0 { + err = ErrDuplicatePackage } - return nil, err + return key, err + } else if err == nil { + return TryInsertPackage(ctx, p) } - return p, nil + return nil, err } // DeletePackageByID deletes a package by id diff --git a/models/packages/package_file.go b/models/packages/package_file.go index a2e007c7dc160..162df3d2396c1 100644 --- a/models/packages/package_file.go +++ b/models/packages/package_file.go @@ -45,21 +45,30 @@ type PackageFile struct { // TryInsertFile inserts a file. If the file exists already ErrDuplicatePackageFile is returned func TryInsertFile(ctx context.Context, pf *PackageFile) (*PackageFile, error) { - e := db.GetEngine(ctx) + n, err := db.InsertOnConflictDoNothing(ctx, pf) + if err != nil { + return nil, err + } + if n != 0 { + return pf, nil + } key := &PackageFile{ VersionID: pf.VersionID, LowerName: pf.LowerName, CompositeKey: pf.CompositeKey, } - - if _, err := e.Insert(pf); err != nil { - if has, _ := e.Get(key); has { - return pf, ErrDuplicatePackageFile + has, err := db.GetEngine(ctx).Get(key) + if has { + if n == 0 { + err = ErrDuplicatePackageFile } - return nil, err + return key, err + } else if err == nil { + return TryInsertFile(ctx, pf) } - return pf, nil + + return key, err } // GetFilesByVersionID gets all files of a version diff --git a/models/packages/package_version.go b/models/packages/package_version.go index 8bb72f6e1f685..a0ef1da185d18 100644 --- a/models/packages/package_version.go +++ b/models/packages/package_version.go @@ -39,20 +39,29 @@ type PackageVersion struct { // GetOrInsertVersion inserts a version. If the same version exist already ErrDuplicatePackageVersion is returned func GetOrInsertVersion(ctx context.Context, pv *PackageVersion) (*PackageVersion, error) { - e := db.GetEngine(ctx) - key := &PackageVersion{ PackageID: pv.PackageID, LowerVersion: pv.LowerVersion, } - if _, err := e.Insert(pv); err != nil { - if has, _ := e.Get(key); has { - return key, ErrDuplicatePackageVersion - } + n, err := db.InsertOnConflictDoNothing(ctx, pv) + if err != nil { return nil, err } - return pv, nil + if n != 0 { + return pv, nil + } + + has, err := db.GetEngine(ctx).Get(key) + if has { + if n == 0 { + err = ErrDuplicatePackageVersion + } + return key, err + } else if err == nil { + return GetOrInsertVersion(ctx, pv) + } + return nil, err } // UpdateVersion updates a version