Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gormschema: supports trigger #50

Merged
merged 11 commits into from
Jul 11, 2024
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ env "gorm" {
> Note: Views are available for logged-in users, run `atlas login` if you haven't already. To learn more about logged-in features for Atlas, visit [Feature Availability](https://atlasgo.io/features#database-features).

To define a Go struct as a database `VIEW`, implement the `ViewDef` method as follow:

```go
// User is a regular gorm.Model stored in the "users" table.
type User struct {
Expand All @@ -158,7 +159,9 @@ func (WorkingAgedUsers) ViewDef(dialect string) []gormschema.ViewOption {
}
}
```

In order to pass a plain `CREATE VIEW` statement, use the `CreateStmt` as follows:

```go
type BotlTracker struct {
ID uint
Expand All @@ -176,19 +179,48 @@ func (BotlTracker) ViewDef(dialect string) []gormschema.ViewOption {
}
}
```

To include both VIEWs and TABLEs in the migration generation, pass all models to the `Load` function:

```go
stmts, err := gormschema.New("mysql").Load(
&models.User{}, // Table-based model.
&models.WorkingAgedUsers{}, // View-based model.
)
```

The view-based model works just like a regular models in GORM queries. However, make sure the view name is identical to the struct name, and in case they are differ, configure the name using the `TableName` method:

```go
func (WorkingAgedUsers) TableName() string {
return "working_aged_users_custom_name" // View name is different than pluralized struct name.
}
```

#### Trigger

> Note: Trigger feature is only available for logged-in users, run `atlas login` if you haven't already. To learn more about logged-in features for Atlas, visit [Feature Availability](https://atlasgo.io/features#database-features).

To attach triggers to a table, use the `Triggers` method as follows:

```go
type Pet struct {
gorm.Model
Name string
}

func (Pet) Triggers(dialect string) []gormschema.Trigger {
var stmt string
switch dialect {
case "mysql":
stmt = "CREATE TRIGGER pet_insert BEFORE INSERT ON pets FOR EACH ROW SET NEW.name = UPPER(NEW.name)"
}
return []gormschema.Trigger{
gormschema.NewTrigger(gormschema.CreateStmt(stmt)),
}
}
```

### Additional Configuration

To supply custom `gorm.Config{}` object to the provider use the [Go Program Mode](#as-go-file) with
Expand Down
183 changes: 116 additions & 67 deletions gormschema/gorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,6 @@ import (
gormig "gorm.io/gorm/migrator"
)

// New returns a new Loader.
func New(dialect string, opts ...Option) *Loader {
l := &Loader{dialect: dialect, config: &gorm.Config{}}
for _, opt := range opts {
opt(l)
}
return l
}

type (
// Loader is a Loader for gorm schema.
Loader struct {
Expand All @@ -35,6 +26,33 @@ type (
}
// Option configures the Loader.
Option func(*Loader)
// ViewOption implemented by VIEW's related options
ViewOption interface {
isViewOption()
apply(*schemaBuilder)
}
// TriggerOption implemented by TRIGGER's related options
TriggerOption interface {
isTriggerOption()
apply(*schemaBuilder)
}
// Trigger defines a trigger.
Trigger struct {
opts []TriggerOption
}
// ViewDefiner defines a view.
ViewDefiner interface {
ViewDef(dialect string) []ViewOption
}
// schemaOption configures the schemaBuilder.
schemaOption func(*schemaBuilder)
schemaBuilder struct {
db *gorm.DB
createStmt string
// viewName is only used for the BuildStmt option.
// BuildStmt returns only a subquery; viewName helps to create a full CREATE VIEW statement.
viewName string
}
)

// WithConfig sets the gorm config.
Expand All @@ -44,6 +62,60 @@ func WithConfig(cfg *gorm.Config) Option {
}
}

// WithJoinTable sets up a join table for the given model and field.
// Deprecated: put the join tables alongside the models in the Load call.
func WithJoinTable(model any, field string, jointable any) Option {
return func(l *Loader) {
l.beforeAutoMigrate = append(l.beforeAutoMigrate, func(db *gorm.DB) error {
return db.SetupJoinTable(model, field, jointable)
})
}
}

// New returns a new Loader.
func New(dialect string, opts ...Option) *Loader {
l := &Loader{dialect: dialect, config: &gorm.Config{}}
for _, opt := range opts {
opt(l)
}
return l
}

// NewTrigger receives a list of TriggerOption to build a Trigger.
func NewTrigger(opts ...TriggerOption) Trigger {
return Trigger{opts: opts}
}

func (s schemaOption) apply(b *schemaBuilder) {
s(b)
}

func (schemaOption) isViewOption() {}
func (schemaOption) isTriggerOption() {}

// CreateStmt accepts raw SQL to create a view or trigger
func CreateStmt(stmt string) interface {
ViewOption
TriggerOption
} {
return schemaOption(func(b *schemaBuilder) {
b.createStmt = stmt
})
}

// BuildStmt accepts a function with gorm query builder to create a CREATE VIEW statement.
// With this option, the view's name will be the same as the model's table name
func BuildStmt(fn func(db *gorm.DB) *gorm.DB) ViewOption {
return schemaOption(func(b *schemaBuilder) {
vd := b.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return fn(tx).
Unscoped(). // Skip gorm deleted_at filtering.
Find(nil) // Execute the query and convert it to SQL.
})
b.createStmt = fmt.Sprintf("CREATE VIEW %s AS %s", b.viewName, vd)
})
}

// Load loads the models and returns the DDL statements representing the schema.
func (l *Loader) Load(models ...any) (string, error) {
var (
Expand Down Expand Up @@ -125,6 +197,9 @@ func (l *Loader) Load(models ...any) (string, error) {
if err = cm.CreateViews(views); err != nil {
return "", err
}
if err = cm.CreateTriggers(models); err != nil {
return "", err
}
if !l.config.DisableForeignKeyConstraintWhenMigrating && l.dialect != "sqlite" {
if err = cm.CreateConstraints(tables); err != nil {
return "", err
Expand Down Expand Up @@ -242,75 +317,20 @@ func (m *migrator) CreateViews(views []ViewDefiner) error {
}); ok {
viewName = namer.TableName()
}
viewBuilder := &viewBuilder{
schemaBuilder := &schemaBuilder{
db: m.DB,
viewName: viewName,
}
for _, opt := range view.ViewDef(m.Dialector.Name()) {
opt(viewBuilder)
opt.apply(schemaBuilder)
}
if err := m.DB.Exec(viewBuilder.createStmt).Error; err != nil {
if err := m.DB.Exec(schemaBuilder.createStmt).Error; err != nil {
return err
}
}
return nil
}

// WithJoinTable sets up a join table for the given model and field.
// Deprecated: put the join tables alongside the models in the Load call.
func WithJoinTable(model any, field string, jointable any) Option {
return func(l *Loader) {
l.beforeAutoMigrate = append(l.beforeAutoMigrate, func(db *gorm.DB) error {
return db.SetupJoinTable(model, field, jointable)
})
}
}

func indirect(t reflect.Type) reflect.Type {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t
}

type (
// ViewOption configures a viewBuilder.
ViewOption func(*viewBuilder)
// ViewDefiner defines a view.
ViewDefiner interface {
ViewDef(dialect string) []ViewOption
}
viewBuilder struct {
db *gorm.DB
createStmt string
// viewName is only used for the BuildStmt option.
// BuildStmt returns only a subquery; viewName helps to create a full CREATE VIEW statement.
viewName string
}
)

// CreateStmt accepts raw SQL to create a CREATE VIEW statement.
func CreateStmt(stmt string) ViewOption {
return func(b *viewBuilder) {
b.createStmt = b.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return tx.Exec(stmt)
})
}
}

// BuildStmt accepts a function with gorm query builder to create a CREATE VIEW statement.
// With this option, the view's name will be the same as the model's table name
func BuildStmt(fn func(db *gorm.DB) *gorm.DB) ViewOption {
return func(b *viewBuilder) {
vd := b.db.ToSQL(func(tx *gorm.DB) *gorm.DB {
return fn(tx).
Unscoped(). // Skip gorm deleted_at filtering.
Find(nil) // Execute the query and convert it to SQL.
})
b.createStmt = fmt.Sprintf("CREATE VIEW %s AS %s", b.viewName, vd)
}
}

// orderModels places join tables at the end of the list of models (if any),
// which helps GORM resolve m2m relationships correctly.
func (m *migrator) orderModels(models ...any) ([]any, error) {
Expand Down Expand Up @@ -348,3 +368,32 @@ func (m *migrator) orderModels(models ...any) ([]any, error) {
}
return append(otherTables, joinTables...), nil
}

// CreateTriggers creates the triggers for the given models.
func (m *migrator) CreateTriggers(models []any) error {
for _, model := range models {
if md, ok := model.(interface {
Triggers(string) []Trigger
}); ok {
for _, trigger := range md.Triggers(m.Dialector.Name()) {
schemaBuilder := &schemaBuilder{
db: m.DB,
}
for _, opt := range trigger.opts {
opt.apply(schemaBuilder)
if err := m.DB.Exec(schemaBuilder.createStmt).Error; err != nil {
return err
}
}
}
}
}
return nil
}

func indirect(t reflect.Type) reflect.Type {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t
}
41 changes: 36 additions & 5 deletions gormschema/gorm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,21 @@ import (
func TestSQLiteConfig(t *testing.T) {
resetSession()
l := gormschema.New("sqlite")
sql, err := l.Load(models.WorkingAgedUsers{}, models.Pet{}, ckmodels.Event{}, ckmodels.Location{}, models.TopPetOwner{})
sql, err := l.Load(
models.WorkingAgedUsers{},
models.Pet{},
models.UserPetHistory{},
ckmodels.Event{},
ckmodels.Location{},
models.TopPetOwner{},
)
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/sqlite_default")
resetSession()
l = gormschema.New("sqlite", gormschema.WithConfig(&gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
}))
sql, err = l.Load(models.Pet{}, models.User{})
sql, err = l.Load(models.UserPetHistory{}, models.Pet{}, models.User{})
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/sqlite_no_fk")
resetSession()
Expand All @@ -32,7 +39,15 @@ func TestSQLiteConfig(t *testing.T) {
func TestPostgreSQLConfig(t *testing.T) {
resetSession()
l := gormschema.New("postgres")
sql, err := l.Load(models.WorkingAgedUsers{}, ckmodels.Location{}, ckmodels.Event{}, models.User{}, models.Pet{}, models.TopPetOwner{})
sql, err := l.Load(
models.WorkingAgedUsers{},
ckmodels.Location{},
ckmodels.Event{},
models.UserPetHistory{},
models.User{},
models.Pet{},
models.TopPetOwner{},
)
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/postgresql_default")
resetSession()
Expand All @@ -48,7 +63,15 @@ func TestPostgreSQLConfig(t *testing.T) {
func TestMySQLConfig(t *testing.T) {
resetSession()
l := gormschema.New("mysql")
sql, err := l.Load(models.WorkingAgedUsers{}, ckmodels.Location{}, ckmodels.Event{}, models.User{}, models.Pet{}, models.TopPetOwner{})
sql, err := l.Load(
models.WorkingAgedUsers{},
ckmodels.Location{},
ckmodels.Event{},
models.UserPetHistory{},
models.User{},
models.Pet{},
models.TopPetOwner{},
)
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/mysql_default")
resetSession()
Expand Down Expand Up @@ -80,7 +103,15 @@ func TestMySQLConfig(t *testing.T) {
func TestSQLServerConfig(t *testing.T) {
resetSession()
l := gormschema.New("sqlserver")
sql, err := l.Load(models.WorkingAgedUsers{}, ckmodels.Location{}, ckmodels.Event{}, models.User{}, models.Pet{}, models.TopPetOwner{})
sql, err := l.Load(
models.WorkingAgedUsers{},
ckmodels.Location{},
ckmodels.Event{},
models.UserPetHistory{},
models.User{},
models.Pet{},
models.TopPetOwner{},
)
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/sqlserver_default")
resetSession()
Expand Down
Loading
Loading