diff --git a/database/query_builder.go b/database/query_builder.go new file mode 100644 index 00000000..261f667d --- /dev/null +++ b/database/query_builder.go @@ -0,0 +1,254 @@ +package database + +import ( + "fmt" + "golang.org/x/exp/slices" + "reflect" + "sort" + "strings" +) + +// QueryBuilder is an addon for the [DB] type that takes care of all the database statement building shenanigans. +// The recommended use of QueryBuilder is to only use it to generate a single query at a time and not two different +// ones. If for instance you want to generate `INSERT` and `SELECT` queries, it is best to use two different +// QueryBuilder instances. You can use the DB#QueryBuilder() method to get fully initialised instances each time. +type QueryBuilder struct { + db *DB + + subject any + columns []string + excludedColumns []string + + // Indicates whether the generated columns should be sorted in ascending order before generating the + // actual statements. This is intended for unit tests only and shouldn't be necessary for production code. + sort bool +} + +// SetColumns sets the DB columns to be used when building the statements. +// When you do not want the columns to be extracted dynamically, you can use this method to specify them manually. +// Returns the current *[QueryBuilder] receiver and allows you to chain some method calls. +func (qb *QueryBuilder) SetColumns(columns ...string) *QueryBuilder { + qb.columns = columns + return qb +} + +// SetExcludedColumns excludes the given columns from all the database statements. +// Returns the current *[QueryBuilder] receiver and allows you to chain some method calls. +func (qb *QueryBuilder) SetExcludedColumns(columns ...string) *QueryBuilder { + qb.excludedColumns = columns + return qb +} + +// Delete returns a DELETE statement for the query builders subject filtered by ID. +func (qb *QueryBuilder) Delete() string { + return qb.DeleteBy("id") +} + +// DeleteBy returns a DELETE statement for the query builders subject filtered by the given column. +func (qb *QueryBuilder) DeleteBy(column string) string { + return fmt.Sprintf(`DELETE FROM "%s" WHERE "%s" IN (?)`, TableName(qb.subject), column) +} + +// Insert returns an INSERT INTO statement for the query builders subject. +func (qb *QueryBuilder) Insert() (string, int) { + columns := qb.BuildColumns() + + return fmt.Sprintf( + `INSERT INTO "%s" ("%s") VALUES (%s)`, + TableName(qb.subject), + strings.Join(columns, `", "`), + fmt.Sprintf(":%s", strings.Join(columns, ", :")), + ), len(columns) +} + +// InsertIgnore returns an INSERT statement for the query builders subject for +// which the database ignores rows that have already been inserted. +func (qb *QueryBuilder) InsertIgnore() (string, int) { + columns := qb.BuildColumns() + + var clause string + switch qb.db.DriverName() { + case MySQL: + // MySQL treats UPDATE id = id as a no-op. + clause = fmt.Sprintf(`ON DUPLICATE KEY UPDATE "%[1]s" = "%[1]s"`, columns[0]) + case PostgreSQL: + clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO NOTHING", qb.getPgsqlOnConflictConstraint()) + default: + panic("Driver unsupported: " + qb.db.DriverName()) + } + + return fmt.Sprintf( + `INSERT INTO "%s" ("%s") VALUES (%s) %s`, + TableName(qb.subject), + strings.Join(columns, `", "`), + fmt.Sprintf(":%s", strings.Join(columns, ", :")), + clause, + ), len(columns) +} + +// Select returns a SELECT statement from the query builders subject and the already set columns. +// If no columns are set, they will be extracted from the query builders subject. +// When the query builders subject is of type Scoper, a WHERE clause is appended to the statement. +func (qb *QueryBuilder) Select() string { + var scoper Scoper + if sc, ok := qb.subject.(Scoper); ok { + scoper = sc + } + + return qb.SelectScoped(scoper) +} + +// SelectScoped returns a SELECT statement from the query builders subject and the already set columns filtered +// by the given scoper/column. When no columns are set, they will be extracted from the query builders subject. +// The argument scoper must either be of type Scoper, string or nil to get SELECT statements without a WHERE clause. +func (qb *QueryBuilder) SelectScoped(scoper any) string { + query := fmt.Sprintf(`SELECT "%s" FROM "%s"`, strings.Join(qb.BuildColumns(), `", "`), TableName(qb.subject)) + where, placeholders := qb.Where(scoper) + if placeholders > 0 { + query += ` WHERE ` + where + } + + return query +} + +// Update returns an UPDATE statement for the query builders subject filter by ID column. +func (qb *QueryBuilder) Update() (string, int) { + return qb.UpdateScoped("id") +} + +// UpdateScoped returns an UPDATE statement for the query builders subject filtered by the given column/scoper. +// The argument scoper must either be of type Scoper, string or nil to get UPDATE statements without a WHERE clause. +func (qb *QueryBuilder) UpdateScoped(scoper any) (string, int) { + columns := qb.BuildColumns() + set := make([]string, 0, len(columns)) + + for _, col := range columns { + set = append(set, fmt.Sprintf(`"%[1]s" = :%[1]s`, col)) + } + + placeholders := len(columns) + query := `UPDATE "%s" SET %s` + if where, count := qb.Where(scoper); count > 0 { + placeholders += count + query += ` WHERE ` + where + } + + return fmt.Sprintf(query, TableName(qb.subject), strings.Join(set, ", ")), placeholders +} + +// Upsert returns an upsert statement for the query builders subject. +func (qb *QueryBuilder) Upsert() (string, int) { + var updateColumns []string + if upserter, ok := qb.subject.(Upserter); ok { + updateColumns = qb.db.columnMap.Columns(upserter.Upsert()) + } else { + updateColumns = qb.BuildColumns() + } + + return qb.UpsertColumns(updateColumns...) +} + +// UpsertColumns returns an upsert statement for the query builders subject and the specified update columns. +func (qb *QueryBuilder) UpsertColumns(updateColumns ...string) (string, int) { + var clause, setFormat string + switch qb.db.DriverName() { + case MySQL: + clause = "ON DUPLICATE KEY UPDATE" + setFormat = `"%[1]s" = VALUES("%[1]s")` + case PostgreSQL: + clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO UPDATE SET", qb.getPgsqlOnConflictConstraint()) + setFormat = `"%[1]s" = EXCLUDED."%[1]s"` + default: + panic("Driver unsupported: " + qb.db.DriverName()) + } + + set := make([]string, 0, len(updateColumns)) + for _, col := range updateColumns { + set = append(set, fmt.Sprintf(setFormat, col)) + } + + insertColumns := qb.BuildColumns() + + return fmt.Sprintf( + `INSERT INTO "%s" ("%s") VALUES (%s) %s %s`, + TableName(qb.subject), + strings.Join(insertColumns, `", "`), + fmt.Sprintf(":%s", strings.Join(insertColumns, ", :")), + clause, + strings.Join(set, ", "), + ), len(insertColumns) +} + +// Where returns a WHERE clause with named placeholder conditions built from the +// specified scoper/column combined with the AND operator. +func (qb *QueryBuilder) Where(subject any) (string, int) { + t := reflect.TypeOf(subject) + if t == nil { // Subject is a nil interface value. + return "", 0 + } + + var columns []string + if t.Kind() == reflect.String { + columns = []string{subject.(string)} + } else if t.Kind() == reflect.Struct || t.Kind() == reflect.Pointer { + if scoper, ok := subject.(Scoper); ok { + return qb.Where(scoper.Scope()) + } + + columns = qb.db.columnMap.Columns(subject) + } else { // This should never happen unless someone wants to do some silly things. + panic(fmt.Sprintf("qb.Where: unknown subject type provided: %q", t.Kind().String())) + } + + where := make([]string, 0, len(columns)) + for _, col := range columns { + where = append(where, fmt.Sprintf(`"%[1]s" = :%[1]s`, col)) + } + + return strings.Join(where, ` AND `), len(columns) +} + +// BuildColumns returns all the Query Builder columns (if specified), otherwise they are +// determined dynamically using its subject. Additionally, it checks whether columns need +// to be excluded and proceeds accordingly. +func (qb *QueryBuilder) BuildColumns() []string { + var columns []string + if len(qb.columns) > 0 { + columns = qb.columns + } else { + columns = qb.db.columnMap.Columns(qb.subject) + } + + if len(qb.excludedColumns) > 0 { + columns = slices.DeleteFunc(append([]string(nil), columns...), func(column string) bool { + for _, exclude := range qb.excludedColumns { + if exclude == column { + return true + } + } + + return false + }) + } + + if qb.sort { + // The order in which the columns appear is not guaranteed as we extract the columns dynamically + // from the struct. So, we've to sort them here to be able to test the generated statements. + sort.SliceStable(columns, func(a, b int) bool { + return columns[a] < columns[b] + }) + } + + return columns +} + +// getPgsqlOnConflictConstraint returns the constraint name of the current [QueryBuilder]'s subject. +// If the subject does not implement the PgsqlOnConflictConstrainter interface, it will simply return +// the table name prefixed with `pk_`. +func (qb *QueryBuilder) getPgsqlOnConflictConstraint() string { + if constrainter, ok := qb.subject.(PgsqlOnConflictConstrainter); ok { + return constrainter.PgsqlOnConflictConstraint() + } + + return "pk_" + TableName(qb.subject) +} diff --git a/database/query_builder_test.go b/database/query_builder_test.go new file mode 100644 index 00000000..4c637621 --- /dev/null +++ b/database/query_builder_test.go @@ -0,0 +1,231 @@ +package database + +import ( + "github.com/creasty/defaults" + "github.com/icinga/icinga-go-library/logging" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.uber.org/zap/zaptest" + "testing" + "time" +) + +func TestQueryBuilder(t *testing.T) { + t.Parallel() + + c := &Config{} + require.NoError(t, defaults.Set(c), "applying config default should not fail") + + db, err := NewDbFromConfig(c, logging.NewLogger(zaptest.NewLogger(t).Sugar(), time.Hour), RetryConnectorCallbacks{}) + require.NoError(t, err) + require.Equal(t, MySQL, db.DriverName()) + runTests(t, db) // Tests with MySQL driver + + c.Type = "pgsql" + db, err = NewDbFromConfig(c, logging.NewLogger(zaptest.NewLogger(t).Sugar(), time.Hour), RetryConnectorCallbacks{}) + require.NoError(t, err) + require.Equal(t, PostgreSQL, db.DriverName()) + runTests(t, db) // Tests with PostgreSQL driver + + t.Run("PgsqlOnConflictConstrainter", func(t *testing.T) { + qb := newQB(db, &pgsqlConstraintName{}) + qb.SetColumns("a") + + stmt, columns := qb.Upsert() + assert.Equal(t, 1, columns) + assert.Equal(t, `INSERT INTO "test" ("a") VALUES (:a) ON CONFLICT ON CONSTRAINT idx_custom_constraint_name DO UPDATE SET "a" = EXCLUDED."a"`, stmt) + + stmt, columns = qb.InsertIgnore() + assert.Equal(t, 1, columns) + assert.Equal(t, `INSERT INTO "test" ("a") VALUES (:a) ON CONFLICT ON CONSTRAINT idx_custom_constraint_name DO NOTHING`, stmt) + }) +} + +func runTests(t *testing.T, db *DB) { + t.Run("SetColumns", func(t *testing.T) { + qb := &QueryBuilder{subject: "test"} + qb.SetColumns("a", "b") + assert.Equal(t, []string{"a", "b"}, qb.columns) + }) + + t.Run("ExcludeColumns", func(t *testing.T) { + qb := &QueryBuilder{subject: &test{}} + qb.SetExcludedColumns("a", "b") + assert.Equal(t, []string{"a", "b"}, qb.excludedColumns) + }) + + t.Run("DeleteStatements", func(t *testing.T) { + qb := &QueryBuilder{subject: &test{}} + assert.Equal(t, `DELETE FROM "test" WHERE "id" IN (?)`, qb.Delete()) + assert.Equal(t, `DELETE FROM "test" WHERE "foo" IN (?)`, qb.DeleteBy("foo")) + }) + + t.Run("WhereClause", func(t *testing.T) { + qb := newQB(db, 1) + // Is invalid column (1) + assert.PanicsWithValue(t, "qb.Where: unknown subject type provided: \"int\"", func() { _, _ = qb.Where(1) }) + + var nilPtr Scoper // Interface nil value + qb = &QueryBuilder{subject: nilPtr} + clause, placeholder := qb.Where(nilPtr) + assert.Equal(t, 0, placeholder) + assert.Empty(t, clause) + + clause, placeholder = qb.Where("id") + assert.Equal(t, 1, placeholder) + assert.Equal(t, "\"id\" = :id", clause) + + assertScoperID := func(clause string, placeholder int) { + assert.Equal(t, 1, placeholder) + assert.Equal(t, "\"scoper_id\" = :scoper_id", clause) + } + + var reference test + qb = newQB(db, &reference) + clause, placeholder = qb.Where(&reference) + assertScoperID(clause, placeholder) + + nonNilPtr := new(test) + qb = newQB(db, nonNilPtr) + clause, placeholder = qb.Where(nonNilPtr) + assertScoperID(clause, placeholder) + }) + + t.Run("InsertStatements", func(t *testing.T) { + t.Parallel() + + qb := newQB(db, &test{}) + qb.sort = true + qb.SetExcludedColumns("random") + + stmt, columns := qb.Insert() + assert.Equal(t, 2, columns) + assert.Equal(t, `INSERT INTO "test" ("name", "value") VALUES (:name, :value)`, stmt) + + qb.SetExcludedColumns("a", "b") + qb.SetColumns("a", "b", "c", "d") + + stmt, columns = qb.Insert() + assert.Equal(t, 2, columns) + assert.Equal(t, `INSERT INTO "test" ("c", "d") VALUES (:c, :d)`, stmt) + + stmt, columns = qb.InsertIgnore() + assert.Equal(t, 2, columns) + if db.DriverName() == MySQL { + assert.Equal(t, `INSERT INTO "test" ("c", "d") VALUES (:c, :d) ON DUPLICATE KEY UPDATE "c" = "c"`, stmt) + } else { + assert.Equal(t, `INSERT INTO "test" ("c", "d") VALUES (:c, :d) ON CONFLICT ON CONSTRAINT pk_test DO NOTHING`, stmt) + } + }) + + t.Run("SelectStatements", func(t *testing.T) { + t.Parallel() + + qb := newQB(db, &test{}) + qb.sort = true + + stmt := qb.Select() + expected := `SELECT "name", "random", "value" FROM "test" WHERE "scoper_id" = :scoper_id` + assert.Equal(t, expected, stmt) + + qb.SetColumns("name", "random", "value") + + stmt = qb.SelectScoped("name") + assert.Equal(t, `SELECT "name", "random", "value" FROM "test" WHERE "name" = :name`, stmt) + }) + + t.Run("UpdateStatements", func(t *testing.T) { + t.Parallel() + + qb := newQB(db, &test{}) + qb.sort = true + qb.SetExcludedColumns("random") + + stmt, placeholders := qb.Update() + assert.Equal(t, 3, placeholders) + + expected := `UPDATE "test" SET "name" = :name, "value" = :value WHERE "id" = :id` + assert.Equal(t, expected, stmt) + + stmt, placeholders = qb.UpdateScoped((&test{}).Scope()) + assert.Equal(t, 3, placeholders) + assert.Equal(t, `UPDATE "test" SET "name" = :name, "value" = :value WHERE "scoper_id" = :scoper_id`, stmt) + + qb.SetExcludedColumns("a", "b") + qb.SetColumns("a", "b", "c", "d") + + stmt, placeholders = qb.UpdateScoped("c") + assert.Equal(t, 3, placeholders) + assert.Equal(t, 3, placeholders) + assert.Equal(t, `UPDATE "test" SET "c" = :c, "d" = :d WHERE "c" = :c`, stmt) + }) + + t.Run("UpsertStatements", func(t *testing.T) { + t.Parallel() + + qb := newQB(db, &test{}) + qb.sort = true + qb.SetExcludedColumns("random") + + stmt, columns := qb.Upsert() + assert.Equal(t, 2, columns) + + expected := `INSERT INTO "test" ("name", "value") VALUES (:name, :value)` + if db.DriverName() == MySQL { + assert.Equal(t, expected+` ON DUPLICATE KEY UPDATE "name" = VALUES("name"), "value" = VALUES("value")`, stmt) + } else { + assert.Equal(t, expected+` ON CONFLICT ON CONSTRAINT pk_test DO UPDATE SET "name" = EXCLUDED."name", "value" = EXCLUDED."value"`, stmt) + } + + qb.SetExcludedColumns("a", "b") + qb.SetColumns("a", "b", "c", "d") + + expected = `INSERT INTO "test" ("c", "d") VALUES (:c, :d)` + stmt, columns = qb.Upsert() + assert.Equal(t, 2, columns) + if db.DriverName() == MySQL { + assert.Equal(t, expected+` ON DUPLICATE KEY UPDATE "c" = VALUES("c"), "d" = VALUES("d")`, stmt) + } else { + assert.Equal(t, expected+` ON CONFLICT ON CONSTRAINT pk_test DO UPDATE SET "c" = EXCLUDED."c", "d" = EXCLUDED."d"`, stmt) + } + + qb.SetExcludedColumns("a") + + expected = `INSERT INTO "test" ("b", "c", "d") VALUES (:b, :c, :d)` + stmt, columns = qb.UpsertColumns("b", "c") + assert.Equal(t, 3, columns) + if db.DriverName() == MySQL { + assert.Equal(t, expected+` ON DUPLICATE KEY UPDATE "b" = VALUES("b"), "c" = VALUES("c")`, stmt) + } else { + assert.Equal(t, expected+` ON CONFLICT ON CONSTRAINT pk_test DO UPDATE SET "b" = EXCLUDED."b", "c" = EXCLUDED."c"`, stmt) + } + }) +} + +func newQB(db *DB, subject any) *QueryBuilder { + return &QueryBuilder{subject: subject, db: db} +} + +type test struct { + Name string + Value string + Random string +} + +func (t *test) Scope() any { + return struct { + ScoperID string + }{} +} + +type pgsqlConstraintName struct { + *test +} + +func (p *pgsqlConstraintName) PgsqlOnConflictConstraint() string { + return "idx_custom_constraint_name" +} + +func (p *pgsqlConstraintName) TableName() string { + return "test" +}