From 57d05b259f28f4b358b87eef2ff3cb9ac0308b11 Mon Sep 17 00:00:00 2001 From: Yonas Habteab Date: Wed, 18 Oct 2023 11:29:22 +0200 Subject: [PATCH] Allow to dynamically define type constraint name --- database/contracts.go | 7 +++++++ database/db.go | 18 ++++++++++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) diff --git a/database/contracts.go b/database/contracts.go index 7ab8cd19..223166ea 100644 --- a/database/contracts.go +++ b/database/contracts.go @@ -47,3 +47,10 @@ type TableNamer interface { type Scoper interface { Scope() any } + +// PgsqlOnConflictConstrainter implements the PgsqlOnConflictConstraint method, +// which returns the primary or unique key constraint name of the PostgreSQL table. +type PgsqlOnConflictConstrainter interface { + // PgsqlOnConflictConstraint returns the primary or unique key constraint name of the PostgreSQL table. + PgsqlOnConflictConstraint() string +} diff --git a/database/db.go b/database/db.go index 296da23d..d973136f 100644 --- a/database/db.go +++ b/database/db.go @@ -232,7 +232,14 @@ func (db *DB) BuildInsertIgnoreStmt(into interface{}) (string, int) { // MySQL treats UPDATE id = id as a no-op. clause = fmt.Sprintf(`ON DUPLICATE KEY UPDATE "%s" = "%s"`, columns[0], columns[0]) case driver.PostgreSQL: - clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT pk_%s DO NOTHING", table) + var constraint string + if constrainter, ok := into.(PgsqlOnConflictConstrainter); ok { + constraint = constrainter.PgsqlOnConflictConstraint() + } else { + constraint = "pk_" + table + } + + clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO NOTHING", constraint) } return fmt.Sprintf( @@ -295,7 +302,14 @@ func (db *DB) BuildUpsertStmt(subject interface{}) (stmt string, placeholders in clause = "ON DUPLICATE KEY UPDATE" setFormat = `"%[1]s" = VALUES("%[1]s")` case driver.PostgreSQL: - clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT pk_%s DO UPDATE SET", table) + var constraint string + if constrainter, ok := subject.(PgsqlOnConflictConstrainter); ok { + constraint = constrainter.PgsqlOnConflictConstraint() + } else { + constraint = "pk_" + table + } + + clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO UPDATE SET", constraint) setFormat = `"%[1]s" = EXCLUDED."%[1]s"` }