diff --git a/database/contracts.go b/database/contracts.go index bf55d320..5f02d043 100644 --- a/database/contracts.go +++ b/database/contracts.go @@ -29,6 +29,11 @@ type IDer interface { // EntityFactoryFunc knows how to create an Entity. type EntityFactoryFunc func() Entity +type EntityConstraint[T any] interface { + Entity + *T +} + // Upserter implements the Upsert method, // which returns a part of the object for ON DUPLICATE KEY UPDATE. type Upserter interface { diff --git a/database/db.go b/database/db.go index 05ebfd9e..f1ac240d 100644 --- a/database/db.go +++ b/database/db.go @@ -39,6 +39,7 @@ type DB struct { Options *Options addr string + queryBuilder QueryBuilder columnMap ColumnMap logger *logging.Logger tableSemaphores map[string]*semaphore.Weighted @@ -256,6 +257,7 @@ func NewDbFromConfig(c *Config, logger *logging.Logger, connectorCallbacks Retry return &DB{ DB: db, Options: &c.Options, + queryBuilder: NewQueryBuilder(db.DriverName()), columnMap: NewColumnMap(db.Mapper), addr: addr, logger: logger, @@ -893,3 +895,7 @@ func (db *DB) Log(ctx context.Context, query string, counter *com.Counter) perio db.logger.Debugf("Finished executing %q with %d rows in %s", query, counter.Total(), tick.Elapsed) })) } + +func (db *DB) QueryBuilder() QueryBuilder { + return db.queryBuilder +} diff --git a/database/delete.go b/database/delete.go new file mode 100644 index 00000000..084c7bef --- /dev/null +++ b/database/delete.go @@ -0,0 +1,213 @@ +package database + +import ( + "context" + "fmt" + "github.com/icinga/icinga-go-library/backoff" + "github.com/icinga/icinga-go-library/com" + "github.com/icinga/icinga-go-library/retry" + "github.com/jmoiron/sqlx" + "github.com/pkg/errors" + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" + "reflect" + "time" +) + +// DeleteStatement is the interface for building DELETE statements. +type DeleteStatement interface { + // From sets the table name for the DELETE statement. + // Overrides the table name provided by the entity. + From(table string) DeleteStatement + + // SetWhere sets the where clause for the DELETE statement. + SetWhere(where string) DeleteStatement + + // Entity returns the entity associated with the DELETE statement. + Entity() Entity + + // Table returns the table name for the DELETE statement. + Table() string + + Where() string +} + +// NewDeleteStatement returns a new deleteStatement for the given entity. +func NewDeleteStatement(entity Entity) DeleteStatement { + return &deleteStatement{ + entity: entity, + } +} + +// deleteStatement is the default implementation of the DeleteStatement interface. +type deleteStatement struct { + entity Entity + table string + where string +} + +func (d *deleteStatement) From(table string) DeleteStatement { + d.table = table + + return d +} + +func (d *deleteStatement) SetWhere(where string) DeleteStatement { + d.where = where + + return d +} + +func (d *deleteStatement) Entity() Entity { + return d.entity +} + +func (d *deleteStatement) Table() string { + return d.table +} + +func (d *deleteStatement) Where() string { + return d.where +} + +// DeleteOption is a functional option for DeleteStreamed(). +type DeleteOption func(opts *deleteOptions) + +// WithDeleteStatement sets the DELETE statement to be used for deleting entities. +func WithDeleteStatement(stmt DeleteStatement) DeleteOption { + return func(opts *deleteOptions) { + opts.stmt = stmt + } +} + +// WithOnDelete sets the callbacks for a successful DELETE operation. +func WithOnDelete(onDelete ...OnSuccess[any]) DeleteOption { + return func(opts *deleteOptions) { + opts.onDelete = append(opts.onDelete, onDelete...) + } +} + +// deleteOptions stores the options for DeleteStreamed. +type deleteOptions struct { + stmt DeleteStatement + onDelete []OnSuccess[any] +} + +// DeleteStreamed deletes entities from the given channel from the database. +func DeleteStreamed( + ctx context.Context, + db *DB, + entityType Entity, + entities <-chan any, + options ...DeleteOption, +) error { + opts := &deleteOptions{} + for _, option := range options { + option(opts) + } + + first, forward, err := com.CopyFirst(ctx, entities) + if err != nil { + return errors.Wrap(err, "can't copy first entity") + } + + sem := db.GetSemaphoreForTable(TableName(entityType)) + + var stmt string + + if opts.stmt != nil { + stmt, err = db.QueryBuilder().DeleteStatement(opts.stmt) + if err != nil { + return err + } + } else { + stmt, err = db.QueryBuilder().DeleteStatement(NewDeleteStatement(entityType)) + if err != nil { + return err + } + } + + switch reflect.TypeOf(first).Kind() { + case reflect.Struct, reflect.Map: + return namedBulkExec(ctx, db, stmt, db.Options.MaxPlaceholdersPerStatement, sem, forward, com.NeverSplit[any], opts.onDelete...) + default: + return bulkExec(ctx, db, stmt, db.Options.MaxPlaceholdersPerStatement, sem, forward, opts.onDelete...) + } +} + +func bulkExec( + ctx context.Context, db *DB, query string, count int, sem *semaphore.Weighted, arg <-chan any, onSuccess ...OnSuccess[any], +) error { + var counter com.Counter + defer db.Log(ctx, query, &counter).Stop() + + g, ctx := errgroup.WithContext(ctx) + // Use context from group. + bulk := com.Bulk(ctx, arg, count, com.NeverSplit[any]) + + g.Go(func() error { + g, ctx := errgroup.WithContext(ctx) + + for b := range bulk { + if err := sem.Acquire(ctx, 1); err != nil { + return errors.Wrap(err, "can't acquire semaphore") + } + + g.Go(func(b []any) func() error { + return func() error { + defer sem.Release(1) + + return retry.WithBackoff( + ctx, + func(context.Context) error { + var valCollection []any + + for _, v := range b { + val := reflect.ValueOf(v) + if val.Kind() == reflect.Slice { + for i := 0; i < val.Len(); i++ { + valCollection = append(valCollection, val.Index(i).Interface()) + } + } else { + valCollection = append(valCollection, val.Interface()) + } + } + + stmt, args, err := sqlx.In(query, valCollection) + if err != nil { + return fmt.Errorf( + "%w: %w", + retry.ErrNotRetryable, + errors.Wrapf(err, "can't build placeholders for %q", query), + ) + } + + stmt = db.Rebind(stmt) + _, err = db.ExecContext(ctx, stmt, args...) + if err != nil { + return CantPerformQuery(err, query) + } + + counter.Add(uint64(len(b))) + + for _, onSuccess := range onSuccess { + if err := onSuccess(ctx, b); err != nil { + return err + } + } + + return nil + }, + retry.Retryable, + backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second), + db.GetDefaultRetrySettings(), + ) + } + }(b)) + } + + return g.Wait() + }) + + return g.Wait() +} diff --git a/database/example_upsert_test.go b/database/example_upsert_test.go new file mode 100644 index 00000000..77b08821 --- /dev/null +++ b/database/example_upsert_test.go @@ -0,0 +1,324 @@ +package database + +import ( + "context" + "fmt" + "github.com/icinga/icinga-go-library/com" + "golang.org/x/sync/errgroup" + "time" +) + +func ExampleUpsertStreamed() { + var ( + testEntites = []User{ + {Id: 1, Name: "test1", Age: 10, Email: "test1@test.com"}, + {Id: 2, Name: "test2", Age: 20, Email: "test2@test.com"}, + } + testSelect = &[]User{} + g, ctx = errgroup.WithContext(context.Background()) + entities = make(chan User, len(testEntites)) + logs = getTestLogging() + db = getTestDb(logs) + log = logs.GetLogger() + err error + ) + initTestDb(db) + + g.Go(func() error { + return UpsertStreamed(ctx, db, entities) + }) + + for _, entity := range testEntites { + entities <- entity + } + + close(entities) + time.Sleep(10 * time.Millisecond) + + if err = db.Select(testSelect, "SELECT * FROM user"); err != nil { + log.Fatalf("cannot select from db: %v", err) + } + + fmt.Println(*testSelect) + + if err = g.Wait(); err != nil { + log.Fatalf("error while upserting entities: %v", err) + } + + _ = db.Close() + + // Output: + // [{1 test1 10 test1@test.com} {2 test2 20 test2@test.com}] +} + +func ExampleUpsertStreamedWithStatement() { + var ( + testEntites = []User{ + {Id: 1, Name: "test1"}, + {Id: 2, Name: "test2"}, + } + testSelect = &[]User{} + g, ctx = errgroup.WithContext(context.Background()) + entities = make(chan User, len(testEntites)) + logs = getTestLogging() + db = getTestDb(logs) + log = logs.GetLogger() + stmt = NewUpsertStatement(User{}).SetColumns("id", "name") + err error + ) + initTestDb(db) + + g.Go(func() error { + return UpsertStreamed(ctx, db, entities, WithUpsertStatement(stmt)) + }) + + for _, entity := range testEntites { + entities <- entity + } + + close(entities) + time.Sleep(10 * time.Millisecond) + + if err = db.Select(testSelect, "SELECT * FROM user"); err != nil { + log.Fatalf("cannot select from db: %v", err) + } + + fmt.Println(*testSelect) + + if err = g.Wait(); err != nil { + log.Fatalf("error while upserting entities: %v", err) + } + + _ = db.Close() + + // Output: + // [{1 test1 0 } {2 test2 0 }] +} + +func ExampleUpsertStreamedWithOnUpsert() { + var ( + testEntites = []User{ + {Id: 1, Name: "test1", Age: 10, Email: "test1@test.com"}, + {Id: 2, Name: "test2", Age: 20, Email: "test2@test.com"}, + } + callback = func(ctx context.Context, affectedRows []any) (err error) { + fmt.Printf("number of affected rows: %d\n", len(affectedRows)) + return nil + } + testSelect = &[]User{} + g, ctx = errgroup.WithContext(context.Background()) + entities = make(chan User, len(testEntites)) + logs = getTestLogging() + db = getTestDb(logs) + log = logs.GetLogger() + err error + ) + initTestDb(db) + + g.Go(func() error { + return UpsertStreamed(ctx, db, entities, WithOnUpsert(callback)) + }) + + for _, entity := range testEntites { + entities <- entity + } + + time.Sleep(1 * time.Second) + close(entities) + + if err = db.Select(testSelect, "SELECT * FROM user"); err != nil { + log.Fatalf("cannot select from db: %v", err) + } + + fmt.Println(*testSelect) + + if err = g.Wait(); err != nil { + log.Fatalf("error while upserting entities: %v", err) + } + + _ = db.Close() + + // Output: + // number of affected rows: 2 + // [{1 test1 10 test1@test.com} {2 test2 20 test2@test.com}] +} + +func ExampleNamedBulkUpsert() { + var ( + testEntites = []User{ + {Id: 1, Name: "test1", Age: 10, Email: "test1@test.com"}, + {Id: 2, Name: "test2", Age: 20, Email: "test2@test.com"}, + } + testSelect = &[]User{} + g, ctx = errgroup.WithContext(context.Background()) + entities = make(chan Entity, len(testEntites)) + logs = getTestLogging() + db = getTestDb(logs) + log = logs.GetLogger() + sem = db.GetSemaphoreForTable(TableName(User{})) + err error + ) + initTestDb(db) + + stmt, placeholders, err := db.QueryBuilder().UpsertStatement(NewUpsertStatement(User{})) + if err != nil { + log.Fatalf("error while building upsert statement: %v", err) + } + + g.Go(func() error { + return db.NamedBulkExec(ctx, stmt, placeholders, sem, entities, com.NeverSplit) + }) + + for _, entity := range testEntites { + entities <- entity + } + + time.Sleep(1 * time.Second) + close(entities) + + if err = db.Select(testSelect, "SELECT * FROM user"); err != nil { + log.Fatalf("cannot select from db: %v", err) + } + + fmt.Println(*testSelect) + + if err = g.Wait(); err != nil { + log.Fatalf("error while upserting entities: %v", err) + } + + _ = db.Close() + + // Output: + // [{1 test1 10 test1@test.com} {2 test2 20 test2@test.com}] +} + +func ExampleNamedBulkUpsertWithOnUpsert() { + var ( + testEntites = []User{ + {Id: 1, Name: "test1", Age: 10, Email: "test1@test.com"}, + {Id: 2, Name: "test2", Age: 20, Email: "test2@test.com"}, + } + testSelect = &[]User{} + callback = func(ctx context.Context, affectedRows []Entity) (err error) { + fmt.Printf("number of affected rows: %d\n", len(affectedRows)) + return nil + } + g, ctx = errgroup.WithContext(context.Background()) + entities = make(chan Entity, len(testEntites)) + logs = getTestLogging() + db = getTestDb(logs) + log = logs.GetLogger() + sem = db.GetSemaphoreForTable(TableName(User{})) + err error + ) + initTestDb(db) + + stmt, placeholders, err := db.QueryBuilder().UpsertStatement(NewUpsertStatement(User{})) + if err != nil { + log.Fatalf("error while building upsert statement: %v", err) + } + + g.Go(func() error { + return db.NamedBulkExec(ctx, stmt, placeholders, sem, entities, com.NeverSplit, callback) + }) + + for _, entity := range testEntites { + entities <- entity + } + + time.Sleep(1 * time.Second) + close(entities) + + if err = db.Select(testSelect, "SELECT * FROM user"); err != nil { + log.Fatalf("cannot select from db: %v", err) + } + + fmt.Println(*testSelect) + + if err = g.Wait(); err != nil { + log.Fatalf("error while upserting entities: %v", err) + } + + _ = db.Close() + + // Output: + // number of affected rows: 2 + // [{1 test1 10 test1@test.com} {2 test2 20 test2@test.com}] +} + +func ExampleNamedExecUpsert() { + var ( + testEntites = []User{ + {Id: 1, Name: "test1", Age: 10, Email: "test1@test.com"}, + {Id: 2, Name: "test2", Age: 20, Email: "test2@test.com"}, + } + testSelect = &[]User{} + ctx = context.Background() + logs = getTestLogging() + db = getTestDb(logs) + log = logs.GetLogger() + err error + ) + initTestDb(db) + + stmt, _, err := db.QueryBuilder().UpsertStatement(NewUpsertStatement(User{})) + if err != nil { + log.Fatalf("error while building upsert statement: %v", err) + } + + for _, entity := range testEntites { + if _, err = db.NamedExecContext(ctx, stmt, entity); err != nil { + log.Fatalf("error while upserting entity: %v", err) + } + } + + if err = db.Select(testSelect, "SELECT * FROM user"); err != nil { + log.Fatalf("cannot select from db: %v", err) + } + + fmt.Println(*testSelect) + + _ = db.Close() + + // Output: + // [{1 test1 10 test1@test.com} {2 test2 20 test2@test.com}] +} + +func ExampleExecUpsert() { + var ( + testEntites = [][]any{ + {1, "test1", 10, "test1@test.com"}, + {2, "test2", 20, "test2@test.com"}, + } + testSelect = &[]User{} + stmt = `INSERT INTO user ("id", "name", "age", "email") VALUES (?, ?, ?, ?) ON CONFLICT DO UPDATE SET "name" = EXCLUDED."name", "age" = EXCLUDED."age", "email" = EXCLUDED."email"` + ctx = context.Background() + logs = getTestLogging() + db = getTestDb(logs) + log = logs.GetLogger() + err error + ) + initTestDb(db) + + //stmt, _, err := db.QueryBuilder().UpsertStatement(NewUpsertStatement(User{})) + //if err != nil { + // log.Fatalf("error while building upsert statement: %v", err) + //} + + for _, entity := range testEntites { + if _, err = db.ExecContext(ctx, stmt, entity...); err != nil { + log.Fatalf("error while upserting entity: %v", err) + } + } + + if err = db.Select(testSelect, "SELECT * FROM user"); err != nil { + log.Fatalf("cannot select from db: %v", err) + } + + fmt.Println(*testSelect) + + _ = db.Close() + + // Output: + // [{1 test1 10 test1@test.com} {2 test2 20 test2@test.com}] +} diff --git a/database/insert.go b/database/insert.go new file mode 100644 index 00000000..2df546e2 --- /dev/null +++ b/database/insert.go @@ -0,0 +1,204 @@ +package database + +import "context" + +// InsertStatement is the interface for building INSERT statements. +type InsertStatement interface { + // Into sets the table name for the INSERT statement. + // Overrides the table name provided by the entity. + Into(table string) InsertStatement + + // SetColumns sets the columns to be inserted. + SetColumns(columns ...string) InsertStatement + + // SetExcludedColumns sets the columns to be excluded from the INSERT statement. + // Excludes also columns set by SetColumns. + SetExcludedColumns(columns ...string) InsertStatement + + // Entity returns the entity associated with the INSERT statement. + Entity() Entity + + // Table returns the table name for the INSERT statement. + Table() string + + // Columns returns the columns to be inserted. + Columns() []string + + // ExcludedColumns returns the columns to be excluded from the INSERT statement. + ExcludedColumns() []string +} + +// NewInsertStatement returns a new insertStatement for the given entity. +func NewInsertStatement(entity Entity) InsertStatement { + return &insertStatement{ + entity: entity, + } +} + +// insertStatement is the default implementation of the InsertStatement interface. +type insertStatement struct { + entity Entity + table string + columns []string + excludedColumns []string +} + +func (i *insertStatement) Into(table string) InsertStatement { + i.table = table + + return i +} + +func (i *insertStatement) SetColumns(columns ...string) InsertStatement { + i.columns = columns + + return i +} + +func (i *insertStatement) SetExcludedColumns(columns ...string) InsertStatement { + i.excludedColumns = columns + + return i +} + +func (i *insertStatement) Entity() Entity { + return i.entity +} + +func (i *insertStatement) Table() string { + return i.table +} + +func (i *insertStatement) Columns() []string { + return i.columns +} + +func (i *insertStatement) ExcludedColumns() []string { + return i.excludedColumns +} + +// InsertSelectStatement is the interface for building INSERT SELECT statements. +type InsertSelectStatement interface { + // Into sets the table name for the INSERT SELECT statement. + // Overrides the table name provided by the entity. + Into(table string) InsertSelectStatement + + // SetColumns sets the columns to be inserted. + SetColumns(columns ...string) InsertSelectStatement + + // SetExcludedColumns sets the columns to be excluded from the INSERT SELECT statement. + // Excludes also columns set by SetColumns. + SetExcludedColumns(columns ...string) InsertSelectStatement + + // SetSelect sets the SELECT statement for the INSERT SELECT statement. + SetSelect(stmt SelectStatement) InsertSelectStatement + + // Entity returns the entity associated with the INSERT SELECT statement. + Entity() Entity + + // Table returns the table name for the INSERT SELECT statement. + Table() string + + // Columns returns the columns to be inserted. + Columns() []string + + // ExcludedColumns returns the columns to be excluded from the INSERT statement. + ExcludedColumns() []string + + // Select returns the SELECT statement for the INSERT SELECT statement. + Select() SelectStatement +} + +// NewInsertSelectStatement returns a new insertSelectStatement for the given entity. +func NewInsertSelectStatement(entity Entity) InsertSelectStatement { + return &insertSelectStatement{ + entity: entity, + } +} + +// insertSelectStatement is the default implementation of the InsertSelectStatement interface. +type insertSelectStatement struct { + entity Entity + table string + columns []string + excludedColumns []string + selectStmt SelectStatement +} + +func (i *insertSelectStatement) Into(table string) InsertSelectStatement { + i.table = table + + return i +} + +func (i *insertSelectStatement) SetColumns(columns ...string) InsertSelectStatement { + i.columns = columns + + return i +} + +func (i *insertSelectStatement) SetExcludedColumns(columns ...string) InsertSelectStatement { + i.excludedColumns = columns + + return i +} + +func (i *insertSelectStatement) SetSelect(stmt SelectStatement) InsertSelectStatement { + i.selectStmt = stmt + + return i +} + +func (i *insertSelectStatement) Entity() Entity { + return i.entity +} + +func (i *insertSelectStatement) Table() string { + return i.table +} + +func (i *insertSelectStatement) Columns() []string { + return i.columns +} + +func (i *insertSelectStatement) ExcludedColumns() []string { + return i.excludedColumns +} + +func (i *insertSelectStatement) Select() SelectStatement { + return i.selectStmt +} + +// InsertOption is a functional option for InsertStreamed(). +type InsertOption func(opts *insertOptions) + +// WithInsertStatement sets the INSERT statement to be used for inserting entities. +func WithInsertStatement(stmt InsertStatement) InsertOption { + return func(opts *insertOptions) { + opts.stmt = stmt + } +} + +// WithOnInsert sets the onInsert callbacks for a successful INSERT statement. +func WithOnInsert(onInsert ...OnSuccess[any]) InsertOption { + return func(opts *insertOptions) { + opts.onInsert = append(opts.onInsert, onInsert...) + } +} + +// insertOptions stores the options for InsertStreamed. +type insertOptions struct { + stmt InsertStatement + onInsert []OnSuccess[any] +} + +// InsertStreamed inserts entities from the given channel into the database. +func InsertStreamed[T any, V EntityConstraint[T]]( + ctx context.Context, + db *DB, + entities <-chan T, + options ...InsertOption, +) error { + // TODO (jr): implement + return nil +} diff --git a/database/query_builder.go b/database/query_builder.go new file mode 100644 index 00000000..5b3f596e --- /dev/null +++ b/database/query_builder.go @@ -0,0 +1,307 @@ +package database + +import ( + "errors" + "fmt" + "github.com/icinga/icinga-go-library/strcase" + "github.com/jmoiron/sqlx/reflectx" + "slices" + "sort" + "strings" +) + +var ( + ErrUnsupportedDriver = errors.New("unsupported database driver") + ErrMissingStatementPart = errors.New("missing statement part") +) + +type QueryBuilder interface { + UpsertStatement(stmt UpsertStatement) (string, int, error) + + InsertStatement(stmt InsertStatement) string + + InsertIgnoreStatement(stmt InsertStatement) (string, error) + + InsertSelectStatement(stmt InsertSelectStatement) (string, error) + + SelectStatement(stmt SelectStatement) string + + UpdateStatement(stmt UpdateStatement) (string, error) + + UpdateAllStatement(stmt UpdateStatement) (string, error) + + DeleteStatement(stmt DeleteStatement) (string, error) + + DeleteAllStatement(stmt DeleteStatement) (string, error) + + BuildColumns(entity Entity, columns []string, excludedColumns []string) []string +} + +func NewQueryBuilder(driver string) QueryBuilder { + return &queryBuilder{ + dbDriver: driver, + columnMap: NewColumnMap(reflectx.NewMapperFunc("db", strcase.Snake)), + } +} + +func NewTestQueryBuilder(driver string) QueryBuilder { + return &queryBuilder{ + dbDriver: driver, + columnMap: NewColumnMap(reflectx.NewMapperFunc("db", strcase.Snake)), + sort: true, + } +} + +type queryBuilder struct { + dbDriver string + columnMap ColumnMap + + // Indicates whether the generated columns should be sorted in ascending order before generating the + // actual statements. This is intended for unit tests only and shouldn't be necessary for production code. + sort bool +} + +func (qb *queryBuilder) UpsertStatement(stmt UpsertStatement) (string, int, error) { + columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns()) + into := stmt.Table() + if into == "" { + into = TableName(stmt.Entity()) + } + var setFormat, clause string + switch qb.dbDriver { + case MySQL: + clause = "ON DUPLICATE KEY UPDATE" + setFormat = `"%[1]s" = VALUES("%[1]s")` + case PostgreSQL: + var constraint string + if constrainter, ok := stmt.Entity().(PgsqlOnConflictConstrainter); ok { + constraint = constrainter.PgsqlOnConflictConstraint() + } else { + constraint = "pk_" + into + } + + clause = fmt.Sprintf("ON CONFLICT ON CONSTRAINT %s DO UPDATE SET", constraint) + setFormat = `"%[1]s" = EXCLUDED."%[1]s"` + case SQLite: + clause = "ON CONFLICT DO UPDATE SET" + setFormat = `"%[1]s" = EXCLUDED."%[1]s"` + default: + return "", 0, fmt.Errorf("%w: %s", ErrUnsupportedDriver, qb.dbDriver) + } + + set := make([]string, 0, len(columns)) + for _, column := range columns { + set = append(set, fmt.Sprintf(setFormat, column)) + } + + return fmt.Sprintf( + `INSERT INTO "%s" ("%s") VALUES (%s) %s %s`, + into, + strings.Join(columns, `", "`), + fmt.Sprintf(":%s", strings.Join(columns, ", :")), + clause, + strings.Join(set, ", "), + ), len(columns), nil +} + +func (qb *queryBuilder) InsertStatement(stmt InsertStatement) string { + columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns()) + into := stmt.Table() + if into == "" { + into = TableName(stmt.Entity()) + } + + return fmt.Sprintf( + `INSERT INTO "%s" ("%s") VALUES (%s)`, + into, + strings.Join(columns, `", "`), + fmt.Sprintf(":%s", strings.Join(columns, ", :")), + ) +} + +func (qb *queryBuilder) InsertIgnoreStatement(stmt InsertStatement) (string, error) { + columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns()) + into := stmt.Table() + if into == "" { + into = TableName(stmt.Entity()) + } + + switch qb.dbDriver { + case MySQL: + return fmt.Sprintf( + `INSERT IGNORE INTO "%s" ("%s") VALUES (%s)`, + into, + strings.Join(columns, `", "`), + fmt.Sprintf(":%s", strings.Join(columns, ", :")), + ), nil + case PostgreSQL, SQLite: + return fmt.Sprintf( + `INSERT INTO "%s" ("%s") VALUES (%s) ON CONFLICT DO NOTHING`, + into, + strings.Join(columns, `", "`), + fmt.Sprintf(":%s", strings.Join(columns, ", :")), + ), nil + default: + return "", fmt.Errorf("%w: %s", ErrUnsupportedDriver, qb.dbDriver) + } +} + +func (qb *queryBuilder) InsertSelectStatement(stmt InsertSelectStatement) (string, error) { + columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns()) + + sel := stmt.Select() + if sel == nil { + return "", fmt.Errorf("%w: %s", ErrMissingStatementPart, "select statement") + } + selectStmt := qb.SelectStatement(sel) + + into := stmt.Table() + if into == "" { + into = TableName(stmt.Entity()) + } + + return fmt.Sprintf( + `INSERT INTO "%s" ("%s") %s`, + into, + strings.Join(columns, `", "`), + selectStmt, + ), nil +} + +func (qb *queryBuilder) SelectStatement(stmt SelectStatement) string { + columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns()) + + from := stmt.Table() + if from == "" { + from = TableName(stmt.Entity()) + } + + where := stmt.Where() + if where != "" { + where = fmt.Sprintf(" WHERE %s", where) + } + + return fmt.Sprintf( + `SELECT "%s" FROM "%s"%s`, + strings.Join(columns, `", "`), + from, + where, + ) +} + +func (qb *queryBuilder) UpdateStatement(stmt UpdateStatement) (string, error) { + columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns()) + + table := stmt.Table() + if table == "" { + table = TableName(stmt.Entity()) + } + + where := stmt.Where() + if where == "" { + return "", fmt.Errorf("%w: %s", ErrMissingStatementPart, "where statement - use UpdateAllStatement() instead") + } + + var set []string + + for _, col := range columns { + set = append(set, fmt.Sprintf(`"%[1]s" = :%[1]s`, col)) + } + + return fmt.Sprintf( + `UPDATE "%s" SET %s WHERE %s`, + table, + strings.Join(set, ", "), + where, + ), nil +} + +func (qb *queryBuilder) UpdateAllStatement(stmt UpdateStatement) (string, error) { + columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns()) + + table := stmt.Table() + if table == "" { + table = TableName(stmt.Entity()) + } + + where := stmt.Where() + if where != "" { + return "", errors.New("cannot use UpdateAllStatement() with where statement - use UpdateStatement() instead") + } + + var set []string + + for _, col := range columns { + set = append(set, fmt.Sprintf(`"%[1]s" = :%[1]s`, col)) + } + + return fmt.Sprintf( + `UPDATE "%s" SET %s`, + table, + set, + ), nil +} + +func (qb *queryBuilder) DeleteStatement(stmt DeleteStatement) (string, error) { + from := stmt.Table() + if from == "" { + from = TableName(stmt.Entity()) + } + where := stmt.Where() + if where != "" { + where = fmt.Sprintf(" WHERE %s", where) + } else { + return "", fmt.Errorf("%w: %s", ErrMissingStatementPart, "cannot use DeleteStatement() without where statement - use DeleteAllStatement() instead") + } + + return fmt.Sprintf( + `DELETE FROM "%s"%s`, + from, + where, + ), nil +} + +func (qb *queryBuilder) DeleteAllStatement(stmt DeleteStatement) (string, error) { + from := stmt.Table() + if from == "" { + from = TableName(stmt.Entity()) + } + where := stmt.Where() + if where != "" { + return "", errors.New("cannot use DeleteAllStatement() with where statement - use DeleteStatement() instead") + } + + return fmt.Sprintf( + `DELETE FROM "%s"`, + from, + ), nil +} + +func (qb *queryBuilder) BuildColumns(entity Entity, columns []string, excludedColumns []string) []string { + var entityColumns []string + + if len(columns) > 0 { + entityColumns = columns + } else { + tempColumns := qb.columnMap.Columns(entity) + entityColumns = make([]string, len(tempColumns)) + copy(entityColumns, tempColumns) + } + + if len(excludedColumns) > 0 { + entityColumns = slices.DeleteFunc( + entityColumns, + func(column string) bool { + return slices.Contains(excludedColumns, column) + }, + ) + } + + if qb.sort { + // The order in which the columns appear is not guaranteed as we extract the columns dynamically + // from the struct. So, we've to sort them here to be able to test the generated statements. + sort.Strings(entityColumns) + } + + return entityColumns[:len(entityColumns):len(entityColumns)] +} diff --git a/database/query_builder_test.go b/database/query_builder_test.go new file mode 100644 index 00000000..b7c92980 --- /dev/null +++ b/database/query_builder_test.go @@ -0,0 +1,638 @@ +package database + +import ( + "github.com/icinga/icinga-go-library/testutils" + "testing" +) + +type InsertStatementTestData struct { + Table string + Columns []string + ExcludedColumns []string +} + +type InsertIgnoreStatementTestData struct { + Driver string + Table string + Columns []string + ExcludedColumns []string +} + +type InsertSelectStatementTestData struct { + Table string + Columns []string + ExcludedColumns []string + Select SelectStatement +} + +type UpdateStatementTestData struct { + Table string + Columns []string + ExcludedColumns []string + Where string +} + +type UpsertStatementTestData struct { + Driver string + Table string + Columns []string + ExcludedColumns []string +} + +type DeleteStatementTestData struct { + Table string + Where string +} + +type DeleteAllStatementTestData struct { + Table string +} + +type SelectStatementTestData struct { + Table string + Columns []string + ExcludedColumns []string + Where string +} + +func TestInsertStatement(t *testing.T) { + tests := []testutils.TestCase[string, InsertStatementTestData]{ + { + Name: "NoColumnsSet", + Expected: `INSERT INTO "user" ("age", "email", "id", "name") VALUES (:age, :email, :id, :name)`, + }, + { + Name: "ColumnsSet", + Expected: `INSERT INTO "user" ("email", "id", "name") VALUES (:email, :id, :name)`, + Data: InsertStatementTestData{ + Columns: []string{"id", "name", "email"}, + }, + }, + { + Name: "ExcludedColumnsSet", + Expected: `INSERT INTO "user" ("age", "id", "name") VALUES (:age, :id, :name)`, + Data: InsertStatementTestData{ + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "ColumnsAndExcludedColumnsSet", + Expected: `INSERT INTO "user" ("id", "name") VALUES (:id, :name)`, + Data: InsertStatementTestData{ + Columns: []string{"id", "name", "email"}, + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "OverrideTableName", + Expected: `INSERT INTO "custom_table_name" ("email", "id", "name") VALUES (:email, :id, :name)`, + Data: InsertStatementTestData{ + Table: "custom_table_name", + Columns: []string{"id", "name", "email"}, + }, + }, + } + + for _, tst := range tests { + t.Run(tst.Name, tst.F(func(data InsertStatementTestData) (string, error) { + var actual string + var err error + + stmt := NewInsertStatement(&User{}). + SetColumns(data.Columns...). + SetExcludedColumns(data.ExcludedColumns...) + + if data.Table != "" { + stmt.Into(data.Table) + } + + qb := NewTestQueryBuilder(MySQL) + actual = qb.InsertStatement(stmt) + + return actual, err + + })) + } +} + +func TestInsertIgnoreStatement(t *testing.T) { + tests := []testutils.TestCase[string, InsertIgnoreStatementTestData]{ + { + Name: "NoColumnsSet_MySQL", + Expected: `INSERT IGNORE INTO "user" ("age", "email", "id", "name") VALUES (:age, :email, :id, :name)`, + Data: InsertIgnoreStatementTestData{ + Driver: MySQL, + }, + }, + { + Name: "ColumnsSet_MySQL", + Expected: `INSERT IGNORE INTO "user" ("email", "id", "name") VALUES (:email, :id, :name)`, + Data: InsertIgnoreStatementTestData{ + Driver: MySQL, + Columns: []string{"id", "name", "email"}, + }, + }, + { + Name: "ExcludedColumnsSet_MySQL", + Expected: `INSERT IGNORE INTO "user" ("age", "id", "name") VALUES (:age, :id, :name)`, + Data: InsertIgnoreStatementTestData{ + Driver: MySQL, + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "ColumnsAndExcludedColumnsSet_MySQL", + Expected: `INSERT IGNORE INTO "user" ("id", "name") VALUES (:id, :name)`, + Data: InsertIgnoreStatementTestData{ + Driver: MySQL, + Columns: []string{"id", "name", "email"}, + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "OverrideTableName_MySQL", + Expected: `INSERT IGNORE INTO "custom_table_name" ("email", "id", "name") VALUES (:email, :id, :name)`, + Data: InsertIgnoreStatementTestData{ + Driver: MySQL, + Table: "custom_table_name", + Columns: []string{"id", "name", "email"}, + }, + }, + { + Name: "NoColumnsSet_PostgreSQL", + Expected: `INSERT INTO "user" ("age", "email", "id", "name") VALUES (:age, :email, :id, :name) ON CONFLICT DO NOTHING`, + Data: InsertIgnoreStatementTestData{ + Driver: PostgreSQL, + }, + }, + { + Name: "ColumnsSet_PostgreSQL", + Expected: `INSERT INTO "user" ("email", "id", "name") VALUES (:email, :id, :name) ON CONFLICT DO NOTHING`, + Data: InsertIgnoreStatementTestData{ + Driver: PostgreSQL, + Columns: []string{"id", "name", "email"}, + }, + }, + { + Name: "ExcludedColumnsSet_PostgreSQL", + Expected: `INSERT INTO "user" ("age", "id", "name") VALUES (:age, :id, :name) ON CONFLICT DO NOTHING`, + Data: InsertIgnoreStatementTestData{ + Driver: PostgreSQL, + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "ColumnsAndExcludedColumnsSet_PostgreSQL", + Expected: `INSERT INTO "user" ("id", "name") VALUES (:id, :name) ON CONFLICT DO NOTHING`, + Data: InsertIgnoreStatementTestData{ + Driver: PostgreSQL, + Columns: []string{"id", "name", "email"}, + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "OverrideTableName_PostgreSQL", + Expected: `INSERT INTO "custom_table_name" ("email", "id", "name") VALUES (:email, :id, :name) ON CONFLICT DO NOTHING`, + Data: InsertIgnoreStatementTestData{ + Driver: PostgreSQL, + Table: "custom_table_name", + Columns: []string{"id", "name", "email"}, + ExcludedColumns: nil, + }, + }, + { + Name: "UnsupportedDriver", + Error: testutils.ErrorIs(ErrUnsupportedDriver), + Data: InsertIgnoreStatementTestData{ + Driver: "abcxyz", // Unsupported driver + Columns: []string{"id", "name", "email"}, + ExcludedColumns: nil, + }, + }, + } + + for _, tst := range tests { + t.Run(tst.Name, tst.F(func(data InsertIgnoreStatementTestData) (string, error) { + var actual string + var err error + + stmt := NewInsertStatement(&User{}). + SetColumns(data.Columns...). + SetExcludedColumns(data.ExcludedColumns...) + + if data.Table != "" { + stmt.Into(data.Table) + } + + qb := NewTestQueryBuilder(data.Driver) + actual, err = qb.InsertIgnoreStatement(stmt) + + return actual, err + + })) + } +} + +func TestInsertSelectStatement(t *testing.T) { + tests := []testutils.TestCase[string, InsertSelectStatementTestData]{ + { + Name: "ColumnsSet", + Expected: `INSERT INTO "user" ("email", "id", "name") SELECT "email", "id", "name" FROM "user" WHERE id = :id`, + Data: InsertSelectStatementTestData{ + Columns: []string{"id", "name", "email"}, + Select: NewSelectStatement(&User{}).SetColumns("id", "name", "email").SetWhere("id = :id"), + }, + }, + { + Name: "ExcludedColumnsSet", + Expected: `INSERT INTO "user" ("age", "id", "name") SELECT "age", "id", "name" FROM "user" WHERE id = :id`, + Data: InsertSelectStatementTestData{ + ExcludedColumns: []string{"email"}, + Select: NewSelectStatement(&User{}).SetExcludedColumns("email").SetWhere("id = :id"), + }, + }, + { + Name: "ColumnsAndExcludedColumnsSet", + Expected: `INSERT INTO "user" ("id", "name") SELECT "id", "name" FROM "user" WHERE id = :id`, + Data: InsertSelectStatementTestData{ + Columns: []string{"id", "name", "email"}, + ExcludedColumns: []string{"email"}, + Select: NewSelectStatement(&User{}).SetColumns("id", "name", "email").SetExcludedColumns("email").SetWhere("id = :id"), + }, + }, + { + Name: "OverrideTableName", + Expected: `INSERT INTO "custom_table_name" ("email", "id", "name") SELECT "email", "id", "name" FROM "user" WHERE id = :id`, + Data: InsertSelectStatementTestData{ + Table: "custom_table_name", + Columns: []string{"id", "name", "email"}, + Select: NewSelectStatement(&User{}).SetColumns("id", "name", "email").SetWhere("id = :id"), + }, + }, + { + Name: "SelectStatementMissing", + Error: testutils.ErrorIs(ErrMissingStatementPart), + Data: InsertSelectStatementTestData{}, + }, + //{ + // Name: "InvalidColumnName", + // Data: InsertStatementTestData{ + // Columns: []string{"id", "name", "email", "invalid_column"}, + // ExcludedColumns: nil, + // }, + // Error: testutils.ErrorIs(ErrInvalidColumnName), + //}, + } + + for _, tst := range tests { + t.Run(tst.Name, tst.F(func(data InsertSelectStatementTestData) (string, error) { + var actual string + var err error + + stmt := NewInsertSelectStatement(&User{}). + SetColumns(data.Columns...). + SetExcludedColumns(data.ExcludedColumns...) + + if data.Select != nil { + stmt.SetSelect(data.Select.(SelectStatement)) + } + + if data.Table != "" { + stmt.Into(data.Table) + } + + qb := NewTestQueryBuilder(MySQL) + actual, err = qb.InsertSelectStatement(stmt) + + return actual, err + + })) + } +} + +func TestUpdateStatement(t *testing.T) { + tests := []testutils.TestCase[string, UpdateStatementTestData]{ + { + Name: "NoWhereSet", + Error: testutils.ErrorIs(ErrMissingStatementPart), + }, + { + Name: "ColumnsSet", + Expected: `UPDATE "user" SET "email" = :email, "name" = :name WHERE id = :id`, + Data: UpdateStatementTestData{ + Columns: []string{"name", "email"}, + Where: "id = :id", + }, + }, + { + Name: "ExcludedColumnsSet", + Expected: `UPDATE "user" SET "email" = :email, "name" = :name WHERE id = :id`, + Data: UpdateStatementTestData{ + ExcludedColumns: []string{"id", "age"}, + Where: "id = :id", + }, + }, + { + Name: "OverrideTableName", + Expected: `UPDATE "custom_table_name" SET "email" = :email, "id" = :id, "name" = :name WHERE id = :id`, + Data: UpdateStatementTestData{ + Table: "custom_table_name", + Columns: []string{"id", "name", "email"}, + Where: "id = :id", + }, + }, + } + + for _, tst := range tests { + t.Run(tst.Name, tst.F(func(data UpdateStatementTestData) (string, error) { + var actual string + var err error + + stmt := NewUpdateStatement(&User{}). + SetColumns(data.Columns...). + SetExcludedColumns(data.ExcludedColumns...) + + if data.Where != "" { + stmt.SetWhere(data.Where) + } + + if data.Table != "" { + stmt.SetTable(data.Table) + } + + qb := NewTestQueryBuilder(MySQL) + actual, err = qb.UpdateStatement(stmt) + + return actual, err + + })) + } +} + +func TestUpsertStatement(t *testing.T) { + tests := []testutils.TestCase[string, UpsertStatementTestData]{ + { + Name: "NoColumnsSet_MySQL", + Expected: `INSERT INTO "user" ("age", "email", "id", "name") VALUES (:age, :email, :id, :name) ON DUPLICATE KEY UPDATE "age" = VALUES("age"), "email" = VALUES("email"), "id" = VALUES("id"), "name" = VALUES("name")`, + Data: UpsertStatementTestData{ + Driver: MySQL, + }, + }, + { + Name: "ColumnsSet_MySQL", + Expected: `INSERT INTO "user" ("email", "id", "name") VALUES (:email, :id, :name) ON DUPLICATE KEY UPDATE "email" = VALUES("email"), "id" = VALUES("id"), "name" = VALUES("name")`, + Data: UpsertStatementTestData{ + Driver: MySQL, + Columns: []string{"id", "name", "email"}, + }, + }, + { + Name: "ExcludedColumnsSet_MySQL", + Expected: `INSERT INTO "user" ("age", "id", "name") VALUES (:age, :id, :name) ON DUPLICATE KEY UPDATE "age" = VALUES("age"), "id" = VALUES("id"), "name" = VALUES("name")`, + Data: UpsertStatementTestData{ + Driver: MySQL, + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "ColumnsAndExcludedColumnsSet_MySQL", + Expected: `INSERT INTO "user" ("id", "name") VALUES (:id, :name) ON DUPLICATE KEY UPDATE "id" = VALUES("id"), "name" = VALUES("name")`, + Data: UpsertStatementTestData{ + Driver: MySQL, + Columns: []string{"id", "name", "email"}, + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "OverrideTableName_MySQL", + Expected: `INSERT INTO "custom_table_name" ("email", "id", "name") VALUES (:email, :id, :name) ON DUPLICATE KEY UPDATE "email" = VALUES("email"), "id" = VALUES("id"), "name" = VALUES("name")`, + Data: UpsertStatementTestData{ + Driver: MySQL, + Table: "custom_table_name", + Columns: []string{"id", "name", "email"}, + }, + }, + { + Name: "NoColumnsSet_PostgreSQL", + Expected: `INSERT INTO "user" ("age", "email", "id", "name") VALUES (:age, :email, :id, :name) ON CONFLICT ON CONSTRAINT pk_user DO UPDATE SET "age" = EXCLUDED."age", "email" = EXCLUDED."email", "id" = EXCLUDED."id", "name" = EXCLUDED."name"`, + Data: UpsertStatementTestData{ + Driver: PostgreSQL, + }, + }, + { + Name: "ColumnsSet_PostgreSQL", + Expected: `INSERT INTO "user" ("email", "id", "name") VALUES (:email, :id, :name) ON CONFLICT ON CONSTRAINT pk_user DO UPDATE SET "email" = EXCLUDED."email", "id" = EXCLUDED."id", "name" = EXCLUDED."name"`, + Data: UpsertStatementTestData{ + Driver: PostgreSQL, + Columns: []string{"id", "name", "email"}, + }, + }, + { + Name: "ExcludedColumnsSet_PostgreSQL", + Expected: `INSERT INTO "user" ("age", "id", "name") VALUES (:age, :id, :name) ON CONFLICT ON CONSTRAINT pk_user DO UPDATE SET "age" = EXCLUDED."age", "id" = EXCLUDED."id", "name" = EXCLUDED."name"`, + Data: UpsertStatementTestData{ + Driver: PostgreSQL, + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "ColumnsAndExcludedColumnsSet_PostgreSQL", + Expected: `INSERT INTO "user" ("id", "name") VALUES (:id, :name) ON CONFLICT ON CONSTRAINT pk_user DO UPDATE SET "id" = EXCLUDED."id", "name" = EXCLUDED."name"`, + Data: UpsertStatementTestData{ + Driver: PostgreSQL, + Columns: []string{"id", "name", "email"}, + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "OverrideTableName_PostgreSQL", + Expected: `INSERT INTO "custom_table_name" ("email", "id", "name") VALUES (:email, :id, :name) ON CONFLICT ON CONSTRAINT pk_custom_table_name DO UPDATE SET "email" = EXCLUDED."email", "id" = EXCLUDED."id", "name" = EXCLUDED."name"`, + Data: UpsertStatementTestData{ + Driver: PostgreSQL, + Table: "custom_table_name", + Columns: []string{"id", "name", "email"}, + }, + }, + } + + for _, tst := range tests { + t.Run(tst.Name, tst.F(func(data UpsertStatementTestData) (string, error) { + var actual string + var err error + + stmt := NewUpsertStatement(&User{}). + SetColumns(data.Columns...). + SetExcludedColumns(data.ExcludedColumns...) + + if data.Table != "" { + stmt.Into(data.Table) + } + + qb := NewTestQueryBuilder(data.Driver) + actual, _, err = qb.UpsertStatement(stmt) + + return actual, err + + })) + } +} + +func TestDeleteStatement(t *testing.T) { + tests := []testutils.TestCase[string, DeleteStatementTestData]{ + { + Name: "NoWhereSet", + Error: testutils.ErrorIs(ErrMissingStatementPart), + }, + { + Name: "WhereSet", + Expected: `DELETE FROM "user" WHERE id = :id`, + Data: DeleteStatementTestData{ + Where: "id = :id", + }, + }, + { + Name: "OverrideTableName", + Expected: `DELETE FROM "custom_table_name" WHERE id = :id`, + Data: DeleteStatementTestData{ + Table: "custom_table_name", + Where: "id = :id", + }, + }, + } + + for _, tst := range tests { + t.Run(tst.Name, tst.F(func(data DeleteStatementTestData) (string, error) { + var actual string + var err error + + stmt := NewDeleteStatement(&User{}) + + if data.Where != "" { + stmt.SetWhere(data.Where) + } + + if data.Table != "" { + stmt.From(data.Table) + } + + qb := NewTestQueryBuilder(MySQL) + actual, err = qb.DeleteStatement(stmt) + + return actual, err + + })) + } +} + +func TestDeleteAllStatement(t *testing.T) { + tests := []testutils.TestCase[string, DeleteAllStatementTestData]{ + { + Name: "AutoTableName", + Expected: `DELETE FROM "user"`, + }, + { + Name: "OverrideTableName", + Expected: `DELETE FROM "custom_table_name"`, + Data: DeleteAllStatementTestData{ + Table: "custom_table_name", + }, + }, + } + + for _, tst := range tests { + t.Run(tst.Name, tst.F(func(data DeleteAllStatementTestData) (string, error) { + var actual string + var err error + + stmt := NewDeleteStatement(&User{}) + + if data.Table != "" { + stmt.From(data.Table) + } + + qb := NewTestQueryBuilder(MySQL) + actual, err = qb.DeleteAllStatement(stmt) + + return actual, err + + })) + } +} + +func TestSelectStatement(t *testing.T) { + tests := []testutils.TestCase[string, SelectStatementTestData]{ + { + Name: "NoColumnsSet", + Expected: `SELECT "age", "email", "id", "name" FROM "user"`, + }, + { + Name: "ColumnsSet", + Expected: `SELECT "email", "id", "name" FROM "user"`, + Data: SelectStatementTestData{ + Columns: []string{"id", "name", "email"}, + }, + }, + { + Name: "ExcludedColumnsSet", + Expected: `SELECT "age", "id", "name" FROM "user"`, + Data: SelectStatementTestData{ + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "ColumnsAndExcludedColumnsSet", + Expected: `SELECT "id", "name" FROM "user"`, + Data: SelectStatementTestData{ + Columns: []string{"id", "name", "email"}, + ExcludedColumns: []string{"email"}, + }, + }, + { + Name: "OverrideTableName", + Expected: `SELECT "email", "id", "name" FROM "custom_table_name"`, + Data: SelectStatementTestData{ + Table: "custom_table_name", + Columns: []string{"id", "name", "email"}, + }, + }, + { + Name: "WhereSet", + Expected: `SELECT "age", "email", "id", "name" FROM "user" WHERE id = :id`, + Data: SelectStatementTestData{ + Where: "id = :id", + }, + }, + { + Name: "MultipleConditionsWhereSet", + Expected: `SELECT "age", "email", "id", "name" FROM "user" WHERE id = :id AND name = :name AND email = :email`, + Data: SelectStatementTestData{ + Where: "id = :id AND name = :name AND email = :email", + }, + }, + } + + for _, tst := range tests { + t.Run(tst.Name, tst.F(func(data SelectStatementTestData) (string, error) { + var actual string + var err error + + stmt := NewSelectStatement(&User{}). + SetColumns(data.Columns...). + SetExcludedColumns(data.ExcludedColumns...) + + if data.Table != "" { + stmt.From(data.Table) + } + + if data.Where != "" { + stmt.SetWhere(data.Where) + } + + qb := NewTestQueryBuilder(MySQL) + actual = qb.SelectStatement(stmt) + + return actual, err + + })) + } +} diff --git a/database/select.go b/database/select.go new file mode 100644 index 00000000..9ced365e --- /dev/null +++ b/database/select.go @@ -0,0 +1,93 @@ +package database + +// SelectStatement is the interface for building SELECT statements. +type SelectStatement interface { + // From sets the table name for the SELECT statement. + // Overrides the table name provided by the entity. + From(table string) SelectStatement + + // SetColumns sets the columns to be selected. + SetColumns(columns ...string) SelectStatement + + // SetExcludedColumns sets the columns to be excluded from the SELECT statement. + // Excludes also columns set by SetColumns. + SetExcludedColumns(columns ...string) SelectStatement + + // SetWhere sets the where clause for the SELECT statement. + SetWhere(where string) SelectStatement + + // Entity returns the entity associated with the SELECT statement. + Entity() Entity + + // Table returns the table name for the SELECT statement. + Table() string + + // Columns returns the columns to be selected. + Columns() []string + + // ExcludedColumns returns the columns to be excluded from the SELECT statement. + ExcludedColumns() []string + + // Where returns the where clause for the SELECT statement. + Where() string +} + +// NewSelectStatement returns a new selectStatement for the given entity. +func NewSelectStatement(entity Entity) SelectStatement { + return &selectStatement{ + entity: entity, + } +} + +// selectStatement is the default implementation of the SelectStatement interface. +type selectStatement struct { + entity Entity + table string + columns []string + excludedColumns []string + where string +} + +func (s *selectStatement) From(table string) SelectStatement { + s.table = table + + return s +} + +func (s *selectStatement) SetColumns(columns ...string) SelectStatement { + s.columns = columns + + return s +} + +func (s *selectStatement) SetExcludedColumns(columns ...string) SelectStatement { + s.excludedColumns = columns + + return s +} + +func (s *selectStatement) SetWhere(where string) SelectStatement { + s.where = where + + return s +} + +func (s *selectStatement) Entity() Entity { + return s.entity +} + +func (s *selectStatement) Table() string { + return s.table +} + +func (s *selectStatement) Columns() []string { + return s.columns +} + +func (s *selectStatement) ExcludedColumns() []string { + return s.excludedColumns +} + +func (s *selectStatement) Where() string { + return s.where +} diff --git a/database/testutils.go b/database/testutils.go new file mode 100644 index 00000000..d8bb830b --- /dev/null +++ b/database/testutils.go @@ -0,0 +1,95 @@ +package database + +import ( + "fmt" + "github.com/creasty/defaults" + "github.com/icinga/icinga-go-library/logging" + "github.com/icinga/icinga-go-library/utils" + "go.uber.org/zap/zapcore" + "math/rand" + "strconv" + "time" +) + +type User struct { + Id Id + Name string + Age int + Email string +} + +type Id int + +func (i Id) String() string { + return strconv.Itoa(int(i)) +} + +func (m User) ID() ID { + return m.Id +} + +func (m User) SetID(id ID) { + m.Id = id.(Id) +} + +func (m User) Fingerprint() Fingerprinter { + return m +} + +func getTestLogging() *logging.Logging { + logs, err := logging.NewLoggingFromConfig( + "Icinga Go Library", + logging.Config{Level: zapcore.DebugLevel, Output: "console", Interval: time.Second * 10}, + ) + if err != nil { + utils.PrintErrorThenExit(err, 1) + } + + return logs +} + +func getTestDb(logs *logging.Logging) *DB { + var defaultOptions Options + + if err := defaults.Set(&defaultOptions); err != nil { + utils.PrintErrorThenExit(err, 1) + } + + randomName := strconv.Itoa(rand.Int()) + + db, err := NewDbFromConfig( + &Config{Type: "sqlite", Database: fmt.Sprintf(":memory:%s", randomName), Options: defaultOptions}, + logs.GetChildLogger("database"), + RetryConnectorCallbacks{}, + ) + if err != nil { + utils.PrintErrorThenExit(err, 1) + } + + return db +} + +func initTestDb(db *DB) { + if _, err := db.Query("DROP TABLE IF EXISTS user"); err != nil { + utils.PrintErrorThenExit(err, 1) + } + + if _, err := db.Query(`CREATE TABLE user ("id" INTEGER PRIMARY KEY, "name" VARCHAR(255) DEFAULT '', "age" INTEGER DEFAULT 0, "email" VARCHAR(255) DEFAULT '')`); err != nil { + utils.PrintErrorThenExit(err, 1) + } +} + +func prefillTestDb(db *DB) { + entities := []User{ + {Id: 1, Name: "Alice Johnson", Age: 25, Email: "alice.johnson@example.com"}, + {Id: 2, Name: "Bob Smith", Age: 30, Email: "bob.smith@example.com"}, + {Id: 3, Name: "Charlie Brown", Age: 22, Email: "charlie.brown@example.com"}, + {Id: 4, Name: "Diana Prince", Age: 28, Email: "diana.prince@example.com"}, + } + + for _, entity := range entities { + if _, err := db.NamedExec(`INSERT INTO user ("id", "name", "age", "email") VALUES (:id, :name, :age, :email)`, entity); err != nil { + utils.PrintErrorThenExit(err, 1) + } + } +} diff --git a/database/update.go b/database/update.go new file mode 100644 index 00000000..d6a6946f --- /dev/null +++ b/database/update.go @@ -0,0 +1,128 @@ +package database + +import "context" + +// UpdateStatement is the interface for building UPDATE statements. +type UpdateStatement interface { + // SetTable sets the table name for the UPDATE statement. + // Overrides the table name provided by the entity. + SetTable(table string) UpdateStatement + + // SetColumns sets the columns to be updated. + SetColumns(columns ...string) UpdateStatement + + // SetExcludedColumns sets the columns to be excluded from the UPDATE statement. + // Excludes also columns set by SetColumns. + SetExcludedColumns(columns ...string) UpdateStatement + + // SetWhere sets the where clause for the UPDATE statement. + SetWhere(where string) UpdateStatement + + // Entity returns the entity associated with the UPDATE statement. + Entity() Entity + + // Table returns the table name for the UPDATE statement. + Table() string + + // Columns returns the columns to be updated. + Columns() []string + + // ExcludedColumns returns the columns to be excluded from the UPDATE statement. + ExcludedColumns() []string + + // Where returns the where clause for the UPDATE statement. + Where() string +} + +// NewUpdateStatement returns a new updateStatement for the given entity. +func NewUpdateStatement(entity Entity) UpdateStatement { + return &updateStatement{ + entity: entity, + } +} + +// updateStatement is the default implementation of the UpdateStatement interface. +type updateStatement struct { + entity Entity + table string + columns []string + excludedColumns []string + where string +} + +func (u *updateStatement) SetTable(table string) UpdateStatement { + u.table = table + + return u +} + +func (u *updateStatement) SetColumns(columns ...string) UpdateStatement { + u.columns = columns + + return u +} + +func (u *updateStatement) SetExcludedColumns(columns ...string) UpdateStatement { + u.excludedColumns = columns + + return u +} + +func (u *updateStatement) SetWhere(where string) UpdateStatement { + u.where = where + + return u +} + +func (u *updateStatement) Entity() Entity { + return u.entity +} + +func (u *updateStatement) Table() string { + return u.table +} + +func (u *updateStatement) Columns() []string { + return u.columns +} + +func (u *updateStatement) ExcludedColumns() []string { + return u.excludedColumns +} + +func (u *updateStatement) Where() string { + return u.where +} + +// UpdateOption is a functional option for UpdateStreamed(). +type UpdateOption func(opts *updateOptions) + +// WithUpdateStatement sets the UPDATE statement to be used for updating entities. +func WithUpdateStatement(stmt UpdateStatement) UpdateOption { + return func(opts *updateOptions) { + opts.stmt = stmt + } +} + +// WithOnUpdate sets the callback functions to be called after a successful UPDATE. +func WithOnUpdate(onUpdate ...OnSuccess[any]) UpdateOption { + return func(opts *updateOptions) { + opts.onUpdate = append(opts.onUpdate, onUpdate...) + } +} + +// updateOptions stores the options for UpdateStreamed. +type updateOptions struct { + stmt UpdateStatement + onUpdate []OnSuccess[any] +} + +func UpdateStreamed[T any, V EntityConstraint[T]]( + ctx context.Context, + db *DB, + entities <-chan T, + options ...UpdateOption, +) error { + // TODO (jr): implement + return nil +} diff --git a/database/upsert.go b/database/upsert.go new file mode 100644 index 00000000..b55f9316 --- /dev/null +++ b/database/upsert.go @@ -0,0 +1,242 @@ +package database + +import ( + "context" + "github.com/icinga/icinga-go-library/backoff" + "github.com/icinga/icinga-go-library/com" + "github.com/icinga/icinga-go-library/retry" + "github.com/pkg/errors" + "golang.org/x/sync/errgroup" + "golang.org/x/sync/semaphore" + "time" +) + +// UpsertStatement is the interface for building UPSERT statements. +type UpsertStatement interface { + // Into sets the table name for the UPSERT statement. + // Overrides the table name provided by the entity. + Into(table string) UpsertStatement + + // SetColumns sets the columns to be inserted or updated. + SetColumns(columns ...string) UpsertStatement + + // SetExcludedColumns sets the columns to be excluded from the UPSERT statement. + // Excludes also columns set by SetColumns. + SetExcludedColumns(columns ...string) UpsertStatement + + // Entity returns the entity associated with the UPSERT statement. + Entity() Entity + + // Table returns the table name for the UPSERT statement. + Table() string + + // Columns returns the columns to be inserted or updated. + Columns() []string + + // ExcludedColumns returns the columns to be excluded from the UPSERT statement. + ExcludedColumns() []string +} + +// NewUpsertStatement returns a new upsertStatement for the given entity. +func NewUpsertStatement(entity Entity) UpsertStatement { + return &upsertStatement{ + entity: entity, + } +} + +// upsertStatement is the default implementation of the UpsertStatement interface. +type upsertStatement struct { + entity Entity + table string + columns []string + excludedColumns []string +} + +func (u *upsertStatement) Into(table string) UpsertStatement { + u.table = table + + return u +} + +func (u *upsertStatement) SetColumns(columns ...string) UpsertStatement { + u.columns = columns + + return u +} + +func (u *upsertStatement) SetExcludedColumns(columns ...string) UpsertStatement { + u.excludedColumns = columns + + return u +} + +func (u *upsertStatement) Entity() Entity { + return u.entity +} + +func (u *upsertStatement) Table() string { + return u.table +} + +func (u *upsertStatement) Columns() []string { + return u.columns +} + +func (u *upsertStatement) ExcludedColumns() []string { + return u.excludedColumns +} + +// UpsertOption is a functional option for UpsertStreamed(). +type UpsertOption func(opts *upsertOptions) + +// WithUpsertStatement sets the UPSERT statement to be used for upserting entities. +func WithUpsertStatement(stmt UpsertStatement) UpsertOption { + return func(opts *upsertOptions) { + opts.stmt = stmt + } +} + +// WithOnUpsert sets the callback functions to be called after a successful UPSERT. +func WithOnUpsert(onUpsert ...OnSuccess[any]) UpsertOption { + return func(opts *upsertOptions) { + opts.onUpsert = append(opts.onUpsert, onUpsert...) + } +} + +// upsertOptions stores the options for UpsertStreamed. +type upsertOptions struct { + stmt UpsertStatement + onUpsert []OnSuccess[any] +} + +// UpsertStreamed upserts entities from the given channel into the database. +func UpsertStreamed[T any, V EntityConstraint[T]]( + ctx context.Context, + db *DB, + entities <-chan T, + options ...UpsertOption, +) error { + var ( + opts = &upsertOptions{} + entityType = V(new(T)) + sem = db.GetSemaphoreForTable(TableName(entityType)) + stmt string + placeholders int + err error + ) + + for _, option := range options { + option(opts) + } + + if opts.stmt != nil { + stmt, placeholders, err = db.QueryBuilder().UpsertStatement(opts.stmt) + if err != nil { + return err + } + } else { + stmt, placeholders, err = db.QueryBuilder().UpsertStatement(NewUpsertStatement(entityType)) + if err != nil { + return err + } + } + + return namedBulkExec[T]( + ctx, db, stmt, db.BatchSizeByPlaceholders(placeholders), sem, + entities, splitOnDupId[T], opts.onUpsert..., + ) +} + +func namedBulkExec[T any]( + ctx context.Context, + db *DB, + query string, + count int, + sem *semaphore.Weighted, + arg <-chan T, + splitPolicyFactory com.BulkChunkSplitPolicyFactory[T], + onSuccess ...OnSuccess[any], +) error { + var counter com.Counter + defer db.Log(ctx, query, &counter).Stop() + + g, ctx := errgroup.WithContext(ctx) + bulk := com.Bulk(ctx, arg, count, splitPolicyFactory) + + g.Go(func() error { + for { + select { + case b, ok := <-bulk: + if !ok { + return nil + } + + if err := sem.Acquire(ctx, 1); err != nil { + return errors.Wrap(err, "can't acquire semaphore") + } + + g.Go(func(b []T) func() error { + return func() error { + defer sem.Release(1) + + return retry.WithBackoff( + ctx, + func(ctx context.Context) error { + _, err := db.NamedExecContext(ctx, query, b) + if err != nil { + return CantPerformQuery(err, query) + } + + counter.Add(uint64(len(b))) + + for _, onSuccess := range onSuccess { + // TODO (jr): remove -> workaround vvvv + var items []any + for _, item := range b { + items = append(items, any(item)) + } + // TODO ---- workaround end ---- ^^^^ + + if err := onSuccess(ctx, items); err != nil { + return err + } + } + + return nil + }, + retry.Retryable, + backoff.NewExponentialWithJitter(1*time.Millisecond, 1*time.Second), + db.GetDefaultRetrySettings(), + ) + } + }(b)) + case <-ctx.Done(): + return ctx.Err() + } + } + }) + + return g.Wait() +} + +func splitOnDupId[T any]() com.BulkChunkSplitPolicy[T] { + seenIds := map[string]struct{}{} + + return func(ider T) bool { + entity, ok := any(ider).(IDer) + if !ok { + panic("Type T does not implement IDer") + } + + id := entity.ID().String() + + _, ok = seenIds[id] + if ok { + seenIds = map[string]struct{}{id: {}} + } else { + seenIds[id] = struct{}{} + } + + return ok + } +} diff --git a/database/upsert_test.go b/database/upsert_test.go new file mode 100644 index 00000000..3dada2b4 --- /dev/null +++ b/database/upsert_test.go @@ -0,0 +1,163 @@ +package database + +import ( + "context" + "github.com/icinga/icinga-go-library/testutils" + "testing" + "time" +) + +type UpsertStreamedTestData struct { + Entities []User + Statement UpsertStatement + Callbacks []OnSuccess[User] +} + +func TestUpsertStreamed(t *testing.T) { + tests := []testutils.TestCase[[]User, UpsertStreamedTestData]{ + { + Name: "Insert", + Expected: []User{ + {Id: 1, Name: "Alice Johnson", Age: 25, Email: "alice.johnson@example.com"}, + {Id: 2, Name: "Bob Smith", Age: 30, Email: "bob.smith@example.com"}, + {Id: 3, Name: "Charlie Brown", Age: 22, Email: "charlie.brown@example.com"}, + {Id: 4, Name: "Diana Prince", Age: 28, Email: "diana.prince@example.com"}, + {Id: 5, Name: "Evan Davis", Age: 35, Email: "evan.davis@example.com"}, + {Id: 6, Name: "Fiona White", Age: 27, Email: "fiona.white@example.com"}, + {Id: 7, Name: "George King", Age: 29, Email: "george.king@example.com"}, + {Id: 8, Name: "Hannah Moore", Age: 31, Email: "hannah.moore@example.com"}, + }, + Data: UpsertStreamedTestData{ + Entities: []User{ + {Id: 5, Name: "Evan Davis", Age: 35, Email: "evan.davis@example.com"}, + {Id: 6, Name: "Fiona White", Age: 27, Email: "fiona.white@example.com"}, + {Id: 7, Name: "George King", Age: 29, Email: "george.king@example.com"}, + {Id: 8, Name: "Hannah Moore", Age: 31, Email: "hannah.moore@example.com"}, + }, + }, + }, + { + Name: "Update", + Expected: []User{ + {Id: 1, Name: "Alice Johnson", Age: 25, Email: "alice.johnson@example.com"}, + {Id: 2, Name: "Bob Smith", Age: 30, Email: "bob.smith@example.com"}, + {Id: 3, Name: "Evan Davis", Age: 35, Email: "evan.davis@example.com"}, + {Id: 4, Name: "Fiona White", Age: 27, Email: "fiona.white@example.com"}, + }, + Data: UpsertStreamedTestData{ + Entities: []User{ + {Id: 3, Name: "Evan Davis", Age: 35, Email: "evan.davis@example.com"}, + {Id: 4, Name: "Fiona White", Age: 27, Email: "fiona.white@example.com"}, + }, + }, + }, + { + Name: "InsertAndUpdate", + Expected: []User{ + {Id: 1, Name: "Alice Johnson", Age: 25, Email: "alice.johnson@example.com"}, + {Id: 2, Name: "Bob Smith", Age: 30, Email: "bob.smith@example.com"}, + {Id: 3, Name: "Charlie Brown", Age: 22, Email: "charlie.brown@example.com"}, + {Id: 4, Name: "George King", Age: 29, Email: "george.king@example.com"}, + {Id: 5, Name: "Hannah Moore", Age: 31, Email: "hannah.moore@example.com"}, + {Id: 6, Name: "Fiona White", Age: 27, Email: "fiona.white@example.com"}, + }, + Data: UpsertStreamedTestData{ + Entities: []User{ + {Id: 5, Name: "Evan Davis", Age: 35, Email: "evan.davis@example.com"}, + {Id: 6, Name: "Fiona White", Age: 27, Email: "fiona.white@example.com"}, + {Id: 4, Name: "George King", Age: 29, Email: "george.king@example.com"}, + {Id: 5, Name: "Hannah Moore", Age: 31, Email: "hannah.moore@example.com"}, + }, + }, + }, + { + Name: "WithStatement", + Expected: []User{ + {Id: 1, Name: "Alice Johnson", Age: 25, Email: "alice.johnson@example.com"}, + {Id: 2, Name: "Bob Smith", Age: 30, Email: "bob.smith@example.com"}, + {Id: 3, Name: "Charlie Brown", Age: 22, Email: "charlie.brown@example.com"}, + {Id: 4, Name: "Diana Prince", Age: 28, Email: "diana.prince@example.com"}, + {Id: 5, Name: "Evan Davis", Age: 35, Email: "evan.davis@example.com"}, + {Id: 6, Name: "Fiona White", Age: 27, Email: "fiona.white@example.com"}, + }, + Data: UpsertStreamedTestData{ + Entities: []User{ + {Id: 5, Name: "Evan Davis", Age: 35, Email: "evan.davis@example.com"}, + {Id: 6, Name: "Fiona White", Age: 27, Email: "fiona.white@example.com"}, + }, + Statement: NewUpsertStatement(&User{}), + }, + }, + { + Name: "WithFalseStatement", + Error: testutils.ErrorContains("can't perform"), // TODO (jr): is it the right way? + Data: UpsertStreamedTestData{ + Entities: []User{ + {Id: 5, Name: "test5", Age: 50, Email: "test5@test.com"}, + }, + Statement: NewUpsertStatement(&User{}).Into("false_table"), + }, + }, + } + + for _, tst := range tests { + t.Run(tst.Name, tst.F(func(data UpsertStreamedTestData) ([]User, error) { + var ( + upsertError error + ctx, cancel = context.WithCancel(context.Background()) + entities = make(chan User) + logs = getTestLogging() + db = getTestDb(logs) + ) + + go func() { + if tst.Data.Statement != nil { + upsertError = UpsertStreamed(ctx, db, entities, WithUpsertStatement(tst.Data.Statement)) + } else { + upsertError = UpsertStreamed(ctx, db, entities) + } + }() + + initTestDb(db) + prefillTestDb(db) + + for _, entity := range tst.Data.Entities { + entities <- entity + } + + var actual []User + + time.Sleep(time.Second) + + if err := db.Select(&actual, "SELECT * FROM user"); err != nil { + t.Fatalf("cannot select from database: %v", err) + } + + cancel() + _ = db.Close() + + return actual, upsertError + })) + } +} + +// TODO (jr) +//func TestUpsertStreamedCallback(t *testing.T) { +// tests := []testutils.TestCase[any, UpsertStreamedTestData]{ +// { +// Name: "OneCallback", +// Data: UpsertStreamedTestData{ +// Callbacks: []OnSuccess[User]{ +// func(ctx context.Context, affectedRows []User) error { +// +// }, +// }, +// }, +// }, +// } +//} + +// TODO (jr) +// func TestUpsertStreamedEarlyDbClose(t *testing.T) { +// +// } diff --git a/retry/retry.go b/retry/retry.go index fc1648cf..6bd7fde8 100644 --- a/retry/retry.go +++ b/retry/retry.go @@ -3,6 +3,7 @@ package retry import ( "context" "database/sql/driver" + stderrors "errors" "github.com/go-sql-driver/mysql" "github.com/icinga/icinga-go-library/backoff" "github.com/lib/pq" @@ -17,6 +18,8 @@ import ( // DefaultTimeout is our opinionated default timeout for retrying database and Redis operations. const DefaultTimeout = 5 * time.Minute +var ErrNotRetryable = stderrors.New("error not retryable") + // RetryableFunc is a retryable function. type RetryableFunc func(context.Context) error @@ -137,6 +140,10 @@ func ResetTimeout(t *time.Timer, d time.Duration) { // i.e. temporary, timeout, DNS, connection refused and reset, host down and unreachable and // network down and unreachable errors. In addition, any database error is considered retryable. func Retryable(err error) bool { + if errors.Is(err, ErrNotRetryable) { + return false + } + var temporary interface { Temporary() bool }