Skip to content

Commit

Permalink
draft: make API loading consistent
Browse files Browse the repository at this point in the history
  • Loading branch information
luantranminh committed May 9, 2024
1 parent 28183e9 commit 8328d67
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 59 deletions.
105 changes: 60 additions & 45 deletions gormschema/gorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"slices"

"ariga.io/atlas-go-sdk/recordriver"
"github.com/go-openapi/inflect"
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlite"
Expand All @@ -33,7 +32,6 @@ type (
dialect string
config *gorm.Config
beforeAutoMigrate []func(*gorm.DB) error
afterAutoMigrate []func(*gorm.DB) error
}
// Option configures the Loader.
Option func(*Loader)
Expand All @@ -46,8 +44,30 @@ func WithConfig(cfg *gorm.Config) Option {
}
}

type (
viewDefiner interface {
ViewDef(*gorm.DB) gorm.ViewOption
}

ViewDef struct {
Def string
}
)

// Load loads the models and returns the DDL statements representing the schema.
func (l *Loader) Load(models ...any) (string, error) {
func (l *Loader) Load(objs ...any) (string, error) {
var (
views []viewDefiner
tables []any
)
for _, obj := range objs {
switch view := obj.(type) {
case viewDefiner:
views = append(views, view)
default:
tables = append(tables, obj)
}
}
var di gorm.Dialector
switch l.dialect {
case "sqlite":
Expand Down Expand Up @@ -94,26 +114,24 @@ func (l *Loader) Load(models ...any) (string, error) {
return "", err
}
}
if err = db.AutoMigrate(models...); err != nil {
if err = db.AutoMigrate(tables...); err != nil {
return "", err
}
for _, cb := range l.afterAutoMigrate {
if err = cb(db); err != nil {
return "", err
}
db, err = gorm.Open(dialector{
Dialector: di,
}, l.config)
if err != nil {
return "", err
}
cm, ok := db.Migrator().(*migrator)
if !ok {
return "", fmt.Errorf("unexpected migrator type: %T", db.Migrator())
}
if err = cm.CreateViews(views); err != nil {
return "", err
}
if !l.config.DisableForeignKeyConstraintWhenMigrating && l.dialect != "sqlite" {
db, err = gorm.Open(dialector{
Dialector: di,
}, l.config)
if err != nil {
return "", err
}
cm, ok := db.Migrator().(*migrator)
if !ok {
return "", err
}
if err = cm.CreateConstraints(models); err != nil {
if err = cm.CreateConstraints(tables); err != nil {
return "", err
}
}
Expand All @@ -133,8 +151,8 @@ type dialector struct {
gorm.Dialector
}

// Migrator returns a new gorm.Migrator which can be used to automatically create all Constraints
// on existing tables.
// Migrator returns a new gorm.Migrator, which can be used to extend the default migrator,
// helping to create constraints and views ...
func (d dialector) Migrator(db *gorm.DB) gorm.Migrator {
return &migrator{
Migrator: gormig.Migrator{
Expand Down Expand Up @@ -187,6 +205,27 @@ func (m *migrator) CreateConstraints(models []any) error {
return nil
}

// CreateViews creates the given "view-based" models
func (m *migrator) CreateViews(views []viewDefiner) error {
for _, view := range views {
viewDef := view.ViewDef(m.DB)
viewName := m.DB.Config.NamingStrategy.TableName(indirect(reflect.TypeOf(view)).Name())
if namer, ok := view.(interface {
TableName() string
}); ok {
viewName = namer.TableName()
}
if err := m.DB.Migrator().CreateView(viewName, gorm.ViewOption{
Replace: viewDef.Replace,
CheckOption: viewDef.CheckOption,
Query: viewDef.Query,
}); err != nil {
return err
}
}
return nil
}

// WithJoinTable sets up a join table for the given model and field.
func WithJoinTable(model any, field string, jointable any) Option {
return func(l *Loader) {
Expand All @@ -196,30 +235,6 @@ func WithJoinTable(model any, field string, jointable any) Option {
}
}

type (
view interface {
ViewDef(*gorm.DB) gorm.ViewOption
}
)

// WithViews sets up callbacks to create views for the given "view-based" models.
func WithViews(models ...any) Option {
return func(l *Loader) {
for _, model := range models {
if view, ok := model.(view); ok {
l.afterAutoMigrate = append(l.afterAutoMigrate, func(db *gorm.DB) error {
viewDef := view.ViewDef(db)
return db.Migrator().CreateView(inflect.Underscore(indirect(reflect.TypeOf(view)).Name()), gorm.ViewOption{
Replace: viewDef.Replace,
CheckOption: viewDef.CheckOption,
Query: viewDef.Query,
})
})
}
}
}
}

func indirect(t reflect.Type) reflect.Type {
for t.Kind() == reflect.Ptr {
t = t.Elem()
Expand Down
23 changes: 10 additions & 13 deletions gormschema/gorm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ import (

func TestSQLiteConfig(t *testing.T) {
resetSession()
l := gormschema.New("sqlite", gormschema.WithViews(models.TopPetOwner{}))
sql, err := l.Load(models.Pet{}, models.User{}, ckmodels.Event{}, ckmodels.Location{})
l := gormschema.New("sqlite")
sql, err := l.Load(models.Pet{}, models.User{}, ckmodels.Event{}, ckmodels.Location{}, models.TopPetOwner{})
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/sqlite_default")
resetSession()
Expand All @@ -31,8 +31,8 @@ func TestSQLiteConfig(t *testing.T) {

func TestPostgreSQLConfig(t *testing.T) {
resetSession()
l := gormschema.New("postgres", gormschema.WithViews(models.TopPetOwner{}))
sql, err := l.Load(ckmodels.Location{}, ckmodels.Event{}, models.User{}, models.Pet{})
l := gormschema.New("postgres")
sql, err := l.Load(ckmodels.Location{}, ckmodels.Event{}, models.User{}, models.Pet{}, models.TopPetOwner{})
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/postgresql_default")
resetSession()
Expand All @@ -47,8 +47,8 @@ func TestPostgreSQLConfig(t *testing.T) {

func TestMySQLConfig(t *testing.T) {
resetSession()
l := gormschema.New("mysql", gormschema.WithViews(models.TopPetOwner{}))
sql, err := l.Load(ckmodels.Location{}, ckmodels.Event{}, models.User{}, models.Pet{})
l := gormschema.New("mysql")
sql, err := l.Load(ckmodels.Location{}, ckmodels.Event{}, models.User{}, models.Pet{}, models.TopPetOwner{})
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/mysql_default")
resetSession()
Expand All @@ -61,19 +61,16 @@ func TestMySQLConfig(t *testing.T) {
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/mysql_no_fk")
resetSession()
l = gormschema.New("mysql",
gormschema.WithViews(customjointable.TopCrowdedAddresses{}),
gormschema.WithJoinTable(&customjointable.Person{}, "Addresses", &customjointable.PersonAddress{}),
)
sql, err = l.Load(customjointable.Address{}, customjointable.Person{})
l = gormschema.New("mysql", gormschema.WithJoinTable(&customjointable.Person{}, "Addresses", &customjointable.PersonAddress{}))
sql, err = l.Load(customjointable.Address{}, customjointable.Person{}, customjointable.TopCrowdedAddresses{})
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/mysql_custom_join_table")
}

func TestSQLServerConfig(t *testing.T) {
resetSession()
l := gormschema.New("sqlserver", gormschema.WithViews(models.TopPetOwner{}))
sql, err := l.Load(ckmodels.Location{}, ckmodels.Event{}, models.User{}, models.Pet{})
l := gormschema.New("sqlserver")
sql, err := l.Load(ckmodels.Location{}, ckmodels.Event{}, models.User{}, models.Pet{}, models.TopPetOwner{})
require.NoError(t, err)
requireEqualContent(t, sql, "testdata/sqlserver_default")
resetSession()
Expand Down
9 changes: 8 additions & 1 deletion internal/testdata/models/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,14 @@ type User struct {
Pets []Pet
}

type TopPetOwner struct{}
type TopPetOwner struct {
ID uint
PetCount int
}

func (TopPetOwner) TableName() string {
return "top_pet_owner_custom_name"
}

func (TopPetOwner) ViewDef(db *gorm.DB) gorm.ViewOption {
return gorm.ViewOption{
Expand Down

0 comments on commit 8328d67

Please sign in to comment.