Skip to content

Commit

Permalink
Introduce QueryBuilder type
Browse files Browse the repository at this point in the history
  • Loading branch information
yhabteab committed May 27, 2024
1 parent 2d47d95 commit 97b850a
Show file tree
Hide file tree
Showing 2 changed files with 485 additions and 0 deletions.
254 changes: 254 additions & 0 deletions database/query_builder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
package database

import (
"fmt"
"golang.org/x/exp/slices"
"reflect"
"sort"
"strings"
)

// QueryBuilder is an addon for the [DB] type that takes care of all the database statement building shenanigans.
// The recommended use of QueryBuilder is to only use it to generate a single query at a time and not two different
// ones. If for instance you want to generate `INSERT` and `SELECT` queries, it is best to use two different
// QueryBuilder instances. You can use the DB#QueryBuilder() method to get fully initialised instances each time.
type QueryBuilder struct {
db *DB

subject any
columns []string
excludedColumns []string

// Indicates whether the generated columns should be sorted in ascending order before generating the
// actual statements. This is intended for unit tests only and shouldn't be necessary for production code.
sort bool
}

// SetColumns sets the DB columns to be used when building the statements.
// When you do not want the columns to be extracted dynamically, you can use this method to specify them manually.
// Returns the current *[QueryBuilder] receiver and allows you to chain some method calls.
func (qb *QueryBuilder) SetColumns(columns ...string) *QueryBuilder {
qb.columns = columns
return qb
}

// SetExcludedColumns excludes the given columns from all the database statements.
// Returns the current *[QueryBuilder] receiver and allows you to chain some method calls.
func (qb *QueryBuilder) SetExcludedColumns(columns ...string) *QueryBuilder {
qb.excludedColumns = columns
return qb
}

// Delete returns a DELETE statement for the query builders subject filtered by ID.
func (qb *QueryBuilder) Delete() string {
return qb.DeleteBy("id")
}

// DeleteBy returns a DELETE statement for the query builders subject filtered by the given column.
func (qb *QueryBuilder) DeleteBy(column string) string {
return fmt.Sprintf(`DELETE FROM "%s" WHERE "%s" IN (?)`, TableName(qb.subject), column)
}

// Insert returns an INSERT INTO statement for the query builders subject.
func (qb *QueryBuilder) Insert() (string, int) {
columns := qb.BuildColumns()

return fmt.Sprintf(
`INSERT INTO "%s" ("%s") VALUES (%s)`,
TableName(qb.subject),
strings.Join(columns, `", "`),
fmt.Sprintf(":%s", strings.Join(columns, ", :")),
), len(columns)
}

// InsertIgnore returns an INSERT statement for the query builders subject for
// which the database ignores rows that have already been inserted.
func (qb *QueryBuilder) InsertIgnore() (string, int) {
columns := qb.BuildColumns()

var clause string
switch qb.db.DriverName() {
case MySQL:
// MySQL treats UPDATE id = id as a no-op.
clause = fmt.Sprintf(`ON DUPLICATE KEY UPDATE "%[1]s" = "%[1]s"`, columns[0])
case PostgreSQL:
clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO NOTHING", qb.getPgsqlOnConflictConstraint())
default:
panic("Driver unsupported: " + qb.db.DriverName())
}

return fmt.Sprintf(
`INSERT INTO "%s" ("%s") VALUES (%s) %s`,
TableName(qb.subject),
strings.Join(columns, `", "`),
fmt.Sprintf(":%s", strings.Join(columns, ", :")),
clause,
), len(columns)
}

// Select returns a SELECT statement from the query builders subject and the already set columns.
// If no columns are set, they will be extracted from the query builders subject.
// When the query builders subject is of type Scoper, a WHERE clause is appended to the statement.
func (qb *QueryBuilder) Select() string {
var scoper Scoper
if sc, ok := qb.subject.(Scoper); ok {
scoper = sc
}

return qb.SelectScoped(scoper)
}

// SelectScoped returns a SELECT statement from the query builders subject and the already set columns filtered
// by the given scoper/column. When no columns are set, they will be extracted from the query builders subject.
// The argument scoper must either be of type Scoper, string or nil to get SELECT statements without a WHERE clause.
func (qb *QueryBuilder) SelectScoped(scoper any) string {
query := fmt.Sprintf(`SELECT "%s" FROM "%s"`, strings.Join(qb.BuildColumns(), `", "`), TableName(qb.subject))
where, placeholders := qb.Where(scoper)
if placeholders > 0 {
query += ` WHERE ` + where
}

return query
}

// Update returns an UPDATE statement for the query builders subject filter by ID column.
func (qb *QueryBuilder) Update() (string, int) {
return qb.UpdateScoped("id")
}

// UpdateScoped returns an UPDATE statement for the query builders subject filtered by the given column/scoper.
// The argument scoper must either be of type Scoper, string or nil to get UPDATE statements without a WHERE clause.
func (qb *QueryBuilder) UpdateScoped(scoper any) (string, int) {
columns := qb.BuildColumns()
set := make([]string, 0, len(columns))

for _, col := range columns {
set = append(set, fmt.Sprintf(`"%[1]s" = :%[1]s`, col))
}

placeholders := len(columns)
query := `UPDATE "%s" SET %s`
if where, count := qb.Where(scoper); count > 0 {
placeholders += count
query += ` WHERE ` + where
}

return fmt.Sprintf(query, TableName(qb.subject), strings.Join(set, ", ")), placeholders
}

// Upsert returns an upsert statement for the query builders subject.
func (qb *QueryBuilder) Upsert() (string, int) {
var updateColumns []string
if upserter, ok := qb.subject.(Upserter); ok {
updateColumns = qb.db.columnMap.Columns(upserter.Upsert())
} else {
updateColumns = qb.BuildColumns()
}

return qb.UpsertColumns(updateColumns...)
}

// UpsertColumns returns an upsert statement for the query builders subject and the specified update columns.
func (qb *QueryBuilder) UpsertColumns(updateColumns ...string) (string, int) {
var clause, setFormat string
switch qb.db.DriverName() {
case MySQL:
clause = "ON DUPLICATE KEY UPDATE"
setFormat = `"%[1]s" = VALUES("%[1]s")`
case PostgreSQL:
clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO UPDATE SET", qb.getPgsqlOnConflictConstraint())
setFormat = `"%[1]s" = EXCLUDED."%[1]s"`
default:
panic("Driver unsupported: " + qb.db.DriverName())
}

set := make([]string, 0, len(updateColumns))
for _, col := range updateColumns {
set = append(set, fmt.Sprintf(setFormat, col))
}

insertColumns := qb.BuildColumns()

return fmt.Sprintf(
`INSERT INTO "%s" ("%s") VALUES (%s) %s %s`,
TableName(qb.subject),
strings.Join(insertColumns, `", "`),
fmt.Sprintf(":%s", strings.Join(insertColumns, ", :")),
clause,
strings.Join(set, ", "),
), len(insertColumns)
}

// Where returns a WHERE clause with named placeholder conditions built from the
// specified scoper/column combined with the AND operator.
func (qb *QueryBuilder) Where(subject any) (string, int) {
t := reflect.TypeOf(subject)
if t == nil { // Subject is a nil interface value.
return "", 0
}

var columns []string
if t.Kind() == reflect.String {
columns = []string{subject.(string)}
} else if t.Kind() == reflect.Struct || t.Kind() == reflect.Pointer {
if scoper, ok := subject.(Scoper); ok {
return qb.Where(scoper.Scope())
}

columns = qb.db.columnMap.Columns(subject)
} else { // This should never happen unless someone wants to do some silly things.
panic(fmt.Sprintf("qb.Where: unknown subject type provided: %q", t.Kind().String()))
}

where := make([]string, 0, len(columns))
for _, col := range columns {
where = append(where, fmt.Sprintf(`"%[1]s" = :%[1]s`, col))
}

return strings.Join(where, ` AND `), len(columns)
}

// BuildColumns returns all the Query Builder columns (if specified), otherwise they are
// determined dynamically using its subject. Additionally, it checks whether columns need
// to be excluded and proceeds accordingly.
func (qb *QueryBuilder) BuildColumns() []string {
var columns []string
if len(qb.columns) > 0 {
columns = qb.columns
} else {
columns = qb.db.columnMap.Columns(qb.subject)
}

if len(qb.excludedColumns) > 0 {
columns = slices.DeleteFunc(append([]string(nil), columns...), func(column string) bool {
for _, exclude := range qb.excludedColumns {
if exclude == column {
return true
}
}

return false
})
}

if qb.sort {
// The order in which the columns appear is not guaranteed as we extract the columns dynamically
// from the struct. So, we've to sort them here to be able to test the generated statements.
sort.SliceStable(columns, func(a, b int) bool {
return columns[a] < columns[b]
})
}

return columns
}

// getPgsqlOnConflictConstraint returns the constraint name of the current [QueryBuilder]'s subject.
// If the subject does not implement the PgsqlOnConflictConstrainter interface, it will simply return
// the table name prefixed with `pk_`.
func (qb *QueryBuilder) getPgsqlOnConflictConstraint() string {
if constrainter, ok := qb.subject.(PgsqlOnConflictConstrainter); ok {
return constrainter.PgsqlOnConflictConstraint()
}

return "pk_" + TableName(qb.subject)
}
Loading

0 comments on commit 97b850a

Please sign in to comment.