diff --git a/models/packages/package.go b/models/packages/package.go index 65a25741509e4..e30c6f6867d0a 100644 --- a/models/packages/package.go +++ b/models/packages/package.go @@ -6,9 +6,11 @@ package packages import ( "context" "fmt" + "strconv" "strings" "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/util" "xorm.io/builder" @@ -189,25 +191,65 @@ type Package struct { // TryInsertPackage inserts a package. If a package exists already, ErrDuplicatePackage is returned func TryInsertPackage(ctx context.Context, p *Package) (*Package, error) { - e := db.GetEngine(ctx) + switch { + case setting.Database.Type.IsMySQL(): + if _, err := db.GetEngine(ctx).Exec("INSERT INTO package (owner_id,`type`,lower_name,name,semver_compatible) VALUES (?,?,?,?,?) ON DUPLICATE KEY UPDATE 1=1", + p.OwnerID, p.Type, p.LowerName, p.Name, p.SemverCompatible); err != nil { + return nil, err + } + case setting.Database.Type.IsPostgreSQL(), setting.Database.Type.IsSQLite3(): + if _, err := db.GetEngine(ctx).Exec("INSERT INTO package (owner_id,`type`,lower_name,name,semver_compatible) VALUES (?,?,?,?,?) ON CONFLICT (owner_id,`type`,lower_name) DO UPDATE SET lower_name=lower_name", + p.OwnerID, p.Type, p.LowerName, p.Name, p.SemverCompatible); err != nil { + return nil, err + } + case setting.Database.Type.IsMSSQL(): + r := func(s string) string { + return strings.ReplaceAll(s, "'", "''") + } + sql := fmt.Sprintf(` + MERGE INTO package WITH (HOLDLOCK) AS target USING ( + SELECT + %d AS owner_id, + '%s' AS [type], + '%s' AS lower_name, + '%s' AS name, + %s AS semver_compatible + ) AS source ( + owner_id, [type], lower_name, name, semver_compatible + ) ON ( + target.owner_id = source.owner_id + AND target.[type] = source.[type] + AND target.lower_name = source.lower_name + ) WHEN MATCHED + THEN UPDATE SET 1 = 1 + WHEN NOT MATCHED + THEN INSERT ( + owner_id, [type], lower_name, name, semver_compatible + ) VALUES ( + %d, '%s', '%s', '%s', %s + )`, + p.OwnerID, r(string(p.Type)), r(p.LowerName), r(p.Name), strconv.FormatBool(p.SemverCompatible), + p.OwnerID, r(string(p.Type)), r(p.LowerName), r(p.Name), strconv.FormatBool(p.SemverCompatible), + ) - existing := &Package{} + if _, err := db.GetEngine(ctx).Exec(sql); err != nil { + return nil, err + } + } - has, err := e.Where(builder.Eq{ + var existing Package + has, err := db.GetEngine(ctx).Where(builder.Eq{ "owner_id": p.OwnerID, "type": p.Type, "lower_name": p.LowerName, - }).Get(existing) + }).Get(&existing) if err != nil { return nil, err } - if has { - return existing, ErrDuplicatePackage - } - if _, err = e.Insert(p); err != nil { - return nil, err + if !has { + return nil, util.ErrNotExist } - return p, nil + return &existing, nil } // DeletePackageByID deletes a package by id diff --git a/models/packages/package_file.go b/models/packages/package_file.go index 1bb6b57a34e8e..61581b20450ad 100644 --- a/models/packages/package_file.go +++ b/models/packages/package_file.go @@ -5,11 +5,13 @@ package packages import ( "context" + "fmt" "strconv" "strings" "time" "code.gitea.io/gitea/models/db" + "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/timeutil" "code.gitea.io/gitea/modules/util" @@ -44,11 +46,55 @@ type PackageFile struct { // TryInsertFile inserts a file. If the file exists already ErrDuplicatePackageFile is returned func TryInsertFile(ctx context.Context, pf *PackageFile) (*PackageFile, error) { - e := db.GetEngine(ctx) + switch { + case setting.Database.Type.IsMySQL(): + if _, err := db.GetEngine(ctx).Exec("INSERT INTO package_file (version_id,blob_id,name,lower_name,composite_key,is_lead) VALUES (?,?,?,?,?,?) ON DUPLICATE KEY UPDATE 1=1", + pf.VersionID, pf.BlobID, pf.Name, pf.LowerName, pf.CompositeKey, pf.IsLead); err != nil { + return nil, err + } + case setting.Database.Type.IsPostgreSQL(), setting.Database.Type.IsSQLite3(): + if _, err := db.GetEngine(ctx).Exec("INSERT INTO package_file (version_id,blob_id,name,lower_name,composite_key,is_lead) VALUES (?,?,?,?,?,?) ON CONFLICT (version_id,lower_name,composite_key) DO UPDATE SET lower_name=lower_name", + pf.VersionID, pf.BlobID, pf.Name, pf.LowerName, pf.CompositeKey, pf.IsLead); err != nil { + return nil, err + } + case setting.Database.Type.IsMSSQL(): + r := func(s string) string { + return strings.ReplaceAll(s, "'", "''") + } + sql := fmt.Sprintf(` + MERGE INTO package_file WITH (HOLDLOCK) AS target USING ( + SELECT + %d AS version_id, + %d AS blob_id, + '%s' AS name, + '%s' AS lower_name, + '%s' AS composite_key, + %s AS is_lead + ) AS source ( + version_id, blob_id, name, lower_name, composite_key, is_lead + ) ON ( + target.version_id = source.version_id + AND target.lower_name = source.lower_name + AND target.composite_key = source.composite_key + ) WHEN MATCHED + THEN UPDATE SET 1 = 1 + WHEN NOT MATCHED + THEN INSERT ( + version_id, blob_id, name, lower_name, composite_key, is_lead + ) VALUES ( + %d, %d, '%s', '%s', '%s', %s + )`, + pf.VersionID, pf.BlobID, r(pf.Name), r(pf.LowerName), r(pf.CompositeKey), strconv.FormatBool(pf.IsLead), + pf.VersionID, pf.BlobID, r(pf.Name), r(pf.LowerName), r(pf.CompositeKey), strconv.FormatBool(pf.IsLead), + ) + + if _, err := db.GetEngine(ctx).Exec(sql); err != nil { + return nil, err + } + } existing := &PackageFile{} - - has, err := e.Where(builder.Eq{ + has, err := db.GetEngine(ctx).Where(builder.Eq{ "version_id": pf.VersionID, "lower_name": pf.LowerName, "composite_key": pf.CompositeKey, @@ -56,13 +102,10 @@ func TryInsertFile(ctx context.Context, pf *PackageFile) (*PackageFile, error) { if err != nil { return nil, err } - if has { - return existing, ErrDuplicatePackageFile - } - if _, err = e.Insert(pf); err != nil { - return nil, err + if !has { + return nil, util.ErrNotExist } - return pf, nil + return existing, nil } // GetFilesByVersionID gets all files of a version diff --git a/tests/integration/api_packages_maven_test.go b/tests/integration/api_packages_maven_test.go index c7ed554a9d7f7..cb45b49dedc71 100644 --- a/tests/integration/api_packages_maven_test.go +++ b/tests/integration/api_packages_maven_test.go @@ -8,6 +8,7 @@ import ( "net/http" "strconv" "strings" + "sync" "testing" "code.gitea.io/gitea/models/db" @@ -242,3 +243,35 @@ func TestPackageMaven(t *testing.T) { putFile(t, fmt.Sprintf("/%s/maven-metadata.xml", snapshotVersion), "test-overwrite", http.StatusCreated) }) } + +func TestPackageMavenConcurrent(t *testing.T) { + defer tests.PrepareTestEnv(t)() + + user := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 2}) + + groupID := "com.gitea" + artifactID := "test-project" + packageVersion := "1.0.1" + + root := fmt.Sprintf("/api/packages/%s/maven/%s/%s", user.Name, strings.ReplaceAll(groupID, ".", "/"), artifactID) + + putFile := func(t *testing.T, path, content string, expectedStatus int) { + req := NewRequestWithBody(t, "PUT", root+path, strings.NewReader(content)). + AddBasicAuth(user.Name) + MakeRequest(t, req, expectedStatus) + } + + t.Run("Concurrent Upload", func(t *testing.T) { + defer tests.PrintCurrentTest(t)() + + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func() { + putFile(t, fmt.Sprintf("/%s/%s.jar", packageVersion, strconv.Itoa(i)), "test", http.StatusCreated) + wg.Done() + }() + } + wg.Wait() + }) +}