-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
485 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,254 @@ | ||
package database | ||
|
||
import ( | ||
"fmt" | ||
"golang.org/x/exp/slices" | ||
"reflect" | ||
"sort" | ||
"strings" | ||
) | ||
|
||
// QueryBuilder is an addon for the [DB] type that takes care of all the database statement building shenanigans. | ||
// The recommended use of QueryBuilder is to only use it to generate a single query at a time and not two different | ||
// ones. If for instance you want to generate `INSERT` and `SELECT` queries, it is best to use two different | ||
// QueryBuilder instances. You can use the DB#QueryBuilder() method to get fully initialised instances each time. | ||
type QueryBuilder struct { | ||
db *DB | ||
|
||
subject any | ||
columns []string | ||
excludedColumns []string | ||
|
||
// Indicates whether the generated columns should be sorted in ascending order before generating the | ||
// actual statements. This is intended for unit tests only and shouldn't be necessary for production code. | ||
sort bool | ||
} | ||
|
||
// SetColumns sets the DB columns to be used when building the statements. | ||
// When you do not want the columns to be extracted dynamically, you can use this method to specify them manually. | ||
// Returns the current *[QueryBuilder] receiver and allows you to chain some method calls. | ||
func (qb *QueryBuilder) SetColumns(columns ...string) *QueryBuilder { | ||
qb.columns = columns | ||
return qb | ||
} | ||
|
||
// SetExcludedColumns excludes the given columns from all the database statements. | ||
// Returns the current *[QueryBuilder] receiver and allows you to chain some method calls. | ||
func (qb *QueryBuilder) SetExcludedColumns(columns ...string) *QueryBuilder { | ||
qb.excludedColumns = columns | ||
return qb | ||
} | ||
|
||
// Delete returns a DELETE statement for the query builders subject filtered by ID. | ||
func (qb *QueryBuilder) Delete() string { | ||
return qb.DeleteBy("id") | ||
} | ||
|
||
// DeleteBy returns a DELETE statement for the query builders subject filtered by the given column. | ||
func (qb *QueryBuilder) DeleteBy(column string) string { | ||
return fmt.Sprintf(`DELETE FROM "%s" WHERE "%s" IN (?)`, TableName(qb.subject), column) | ||
} | ||
|
||
// Insert returns an INSERT INTO statement for the query builders subject. | ||
func (qb *QueryBuilder) Insert() (string, int) { | ||
columns := qb.BuildColumns() | ||
|
||
return fmt.Sprintf( | ||
`INSERT INTO "%s" ("%s") VALUES (%s)`, | ||
TableName(qb.subject), | ||
strings.Join(columns, `", "`), | ||
fmt.Sprintf(":%s", strings.Join(columns, ", :")), | ||
), len(columns) | ||
} | ||
|
||
// InsertIgnore returns an INSERT statement for the query builders subject for | ||
// which the database ignores rows that have already been inserted. | ||
func (qb *QueryBuilder) InsertIgnore() (string, int) { | ||
columns := qb.BuildColumns() | ||
|
||
var clause string | ||
switch qb.db.DriverName() { | ||
case MySQL: | ||
// MySQL treats UPDATE id = id as a no-op. | ||
clause = fmt.Sprintf(`ON DUPLICATE KEY UPDATE "%[1]s" = "%[1]s"`, columns[0]) | ||
case PostgreSQL: | ||
clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO NOTHING", qb.getPgsqlOnConflictConstraint()) | ||
default: | ||
panic("Driver unsupported: " + qb.db.DriverName()) | ||
} | ||
|
||
return fmt.Sprintf( | ||
`INSERT INTO "%s" ("%s") VALUES (%s) %s`, | ||
TableName(qb.subject), | ||
strings.Join(columns, `", "`), | ||
fmt.Sprintf(":%s", strings.Join(columns, ", :")), | ||
clause, | ||
), len(columns) | ||
} | ||
|
||
// Select returns a SELECT statement from the query builders subject and the already set columns. | ||
// If no columns are set, they will be extracted from the query builders subject. | ||
// When the query builders subject is of type Scoper, a WHERE clause is appended to the statement. | ||
func (qb *QueryBuilder) Select() string { | ||
var scoper Scoper | ||
if sc, ok := qb.subject.(Scoper); ok { | ||
scoper = sc | ||
} | ||
|
||
return qb.SelectScoped(scoper) | ||
} | ||
|
||
// SelectScoped returns a SELECT statement from the query builders subject and the already set columns filtered | ||
// by the given scoper/column. When no columns are set, they will be extracted from the query builders subject. | ||
// The argument scoper must either be of type Scoper, string or nil to get SELECT statements without a WHERE clause. | ||
func (qb *QueryBuilder) SelectScoped(scoper any) string { | ||
query := fmt.Sprintf(`SELECT "%s" FROM "%s"`, strings.Join(qb.BuildColumns(), `", "`), TableName(qb.subject)) | ||
where, placeholders := qb.Where(scoper) | ||
if placeholders > 0 { | ||
query += ` WHERE ` + where | ||
} | ||
|
||
return query | ||
} | ||
|
||
// Update returns an UPDATE statement for the query builders subject filter by ID column. | ||
func (qb *QueryBuilder) Update() (string, int) { | ||
return qb.UpdateScoped("id") | ||
} | ||
|
||
// UpdateScoped returns an UPDATE statement for the query builders subject filtered by the given column/scoper. | ||
// The argument scoper must either be of type Scoper, string or nil to get UPDATE statements without a WHERE clause. | ||
func (qb *QueryBuilder) UpdateScoped(scoper any) (string, int) { | ||
columns := qb.BuildColumns() | ||
set := make([]string, 0, len(columns)) | ||
|
||
for _, col := range columns { | ||
set = append(set, fmt.Sprintf(`"%[1]s" = :%[1]s`, col)) | ||
} | ||
|
||
placeholders := len(columns) | ||
query := `UPDATE "%s" SET %s` | ||
if where, count := qb.Where(scoper); count > 0 { | ||
placeholders += count | ||
query += ` WHERE ` + where | ||
} | ||
|
||
return fmt.Sprintf(query, TableName(qb.subject), strings.Join(set, ", ")), placeholders | ||
} | ||
|
||
// Upsert returns an upsert statement for the query builders subject. | ||
func (qb *QueryBuilder) Upsert() (string, int) { | ||
var updateColumns []string | ||
if upserter, ok := qb.subject.(Upserter); ok { | ||
updateColumns = qb.db.columnMap.Columns(upserter.Upsert()) | ||
} else { | ||
updateColumns = qb.BuildColumns() | ||
} | ||
|
||
return qb.UpsertColumns(updateColumns...) | ||
} | ||
|
||
// UpsertColumns returns an upsert statement for the query builders subject and the specified update columns. | ||
func (qb *QueryBuilder) UpsertColumns(updateColumns ...string) (string, int) { | ||
var clause, setFormat string | ||
switch qb.db.DriverName() { | ||
case MySQL: | ||
clause = "ON DUPLICATE KEY UPDATE" | ||
setFormat = `"%[1]s" = VALUES("%[1]s")` | ||
case PostgreSQL: | ||
clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO UPDATE SET", qb.getPgsqlOnConflictConstraint()) | ||
setFormat = `"%[1]s" = EXCLUDED."%[1]s"` | ||
default: | ||
panic("Driver unsupported: " + qb.db.DriverName()) | ||
} | ||
|
||
set := make([]string, 0, len(updateColumns)) | ||
for _, col := range updateColumns { | ||
set = append(set, fmt.Sprintf(setFormat, col)) | ||
} | ||
|
||
insertColumns := qb.BuildColumns() | ||
|
||
return fmt.Sprintf( | ||
`INSERT INTO "%s" ("%s") VALUES (%s) %s %s`, | ||
TableName(qb.subject), | ||
strings.Join(insertColumns, `", "`), | ||
fmt.Sprintf(":%s", strings.Join(insertColumns, ", :")), | ||
clause, | ||
strings.Join(set, ", "), | ||
), len(insertColumns) | ||
} | ||
|
||
// Where returns a WHERE clause with named placeholder conditions built from the | ||
// specified scoper/column combined with the AND operator. | ||
func (qb *QueryBuilder) Where(subject any) (string, int) { | ||
t := reflect.TypeOf(subject) | ||
if t == nil { // Subject is a nil interface value. | ||
return "", 0 | ||
} | ||
|
||
var columns []string | ||
if t.Kind() == reflect.String { | ||
columns = []string{subject.(string)} | ||
} else if t.Kind() == reflect.Struct || t.Kind() == reflect.Pointer { | ||
if scoper, ok := subject.(Scoper); ok { | ||
return qb.Where(scoper.Scope()) | ||
} | ||
|
||
columns = qb.db.columnMap.Columns(subject) | ||
} else { // This should never happen unless someone wants to do some silly things. | ||
panic(fmt.Sprintf("qb.Where: unknown subject type provided: %q", t.Kind().String())) | ||
} | ||
|
||
where := make([]string, 0, len(columns)) | ||
for _, col := range columns { | ||
where = append(where, fmt.Sprintf(`"%[1]s" = :%[1]s`, col)) | ||
} | ||
|
||
return strings.Join(where, ` AND `), len(columns) | ||
} | ||
|
||
// BuildColumns returns all the Query Builder columns (if specified), otherwise they are | ||
// determined dynamically using its subject. Additionally, it checks whether columns need | ||
// to be excluded and proceeds accordingly. | ||
func (qb *QueryBuilder) BuildColumns() []string { | ||
var columns []string | ||
if len(qb.columns) > 0 { | ||
columns = qb.columns | ||
} else { | ||
columns = qb.db.columnMap.Columns(qb.subject) | ||
} | ||
|
||
if len(qb.excludedColumns) > 0 { | ||
columns = slices.DeleteFunc(append([]string(nil), columns...), func(column string) bool { | ||
for _, exclude := range qb.excludedColumns { | ||
if exclude == column { | ||
return true | ||
} | ||
} | ||
|
||
return false | ||
}) | ||
} | ||
|
||
if qb.sort { | ||
// The order in which the columns appear is not guaranteed as we extract the columns dynamically | ||
// from the struct. So, we've to sort them here to be able to test the generated statements. | ||
sort.SliceStable(columns, func(a, b int) bool { | ||
return columns[a] < columns[b] | ||
}) | ||
} | ||
|
||
return columns | ||
} | ||
|
||
// getPgsqlOnConflictConstraint returns the constraint name of the current [QueryBuilder]'s subject. | ||
// If the subject does not implement the PgsqlOnConflictConstrainter interface, it will simply return | ||
// the table name prefixed with `pk_`. | ||
func (qb *QueryBuilder) getPgsqlOnConflictConstraint() string { | ||
if constrainter, ok := qb.subject.(PgsqlOnConflictConstrainter); ok { | ||
return constrainter.PgsqlOnConflictConstraint() | ||
} | ||
|
||
return "pk_" + TableName(qb.subject) | ||
} |
Oops, something went wrong.