Skip to content

Commit

Permalink
fix: gorm close
Browse files Browse the repository at this point in the history
  • Loading branch information
goxiaoy committed Jun 20, 2022
1 parent cd7fc4a commit 2a508a1
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 13 deletions.
2 changes: 1 addition & 1 deletion common/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ func (c *Cache[K, V]) delete(e *list.Element) error {
entry := e.Value.(*entry[K, V])
delete(c.items, entry.key)

if f, ok := e.Value.(closable); ok {
if f, ok := any(entry.val).(closable); ok {
return f.Close()
}
return nil
Expand Down
24 changes: 12 additions & 12 deletions examples/gorm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ import (
"github.com/goxiaoy/go-saas/common"
"github.com/goxiaoy/go-saas/data"
"github.com/goxiaoy/go-saas/gin/saas"
gorm2 "github.com/goxiaoy/go-saas/gorm"
sgorm "github.com/goxiaoy/go-saas/gorm"
"github.com/goxiaoy/go-saas/seed"
"gorm.io/driver/mysql"
"gorm.io/driver/sqlite"
g "gorm.io/gorm"
"gorm.io/gorm"
"net/http"
)

Expand All @@ -38,14 +38,14 @@ func init() {
func main() {
flag.Parse()

cache := common.NewCache[string, *g.DB]()
cache := common.NewCache[string, *sgorm.DbWrap]()
defer cache.Flush()

var connStrGen common.ConnStrGenerator
switch driver {
case sqlite.DriverName:
sharedDsn = defaultSqliteSharedDsn
connStrGen = common.NewConnStrGenerator("./example-%s")
connStrGen = common.NewConnStrGenerator("./example-%s.db")
case "mysql":
if len(sharedDsn) == 0 {
sharedDsn = defaultMysqlSharedDsn
Expand Down Expand Up @@ -92,14 +92,14 @@ func main() {
return tenantStore
}, data.NewConnStrOption(conn))

clientProvider := gorm2.ClientProviderFunc(func(ctx context.Context, s string) (*g.DB, error) {
client, _, err := cache.GetOrSet(s, func() (*g.DB, error) {
clientProvider := sgorm.ClientProviderFunc(func(ctx context.Context, s string) (*gorm.DB, error) {
client, _, err := cache.GetOrSet(s, func() (*sgorm.DbWrap, error) {
if ensureDbExist != nil {
if err := ensureDbExist(s); err != nil {
return nil, err
}
}
var client *g.DB
var client *gorm.DB
var err error
db, err := sql.Open(driver, s)
if err != nil {
Expand All @@ -111,27 +111,27 @@ func main() {
}

if driver == sqlite.DriverName {
client, err = g.Open(&sqlite.Dialector{
client, err = gorm.Open(&sqlite.Dialector{
DriverName: sqlite.DriverName,
DSN: s,
Conn: db,
})
} else if driver == "mysql" {
client, err = g.Open(mysql.New(mysql.Config{
client, err = gorm.Open(mysql.New(mysql.Config{
Conn: db,
}))
}
return client, err
return sgorm.NewDbWrap(client), err
})

if err != nil {
return client, err
return nil, err
}
return client.WithContext(ctx).Debug(), err

})

dbProvider := gorm2.NewDbProvider(mr, clientProvider)
dbProvider := sgorm.NewDbProvider(mr, clientProvider)

tenantStore = common.NewCachedTenantStore(&TenantStore{dbProvider: dbProvider})

Expand Down
27 changes: 27 additions & 0 deletions gorm/gorm.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,30 @@ func (c ClientProviderFunc) Get(ctx context.Context, dsn string) (*gorm.DB, erro
func NewDbProvider(cs data.ConnStrResolver, cp ClientProvider) DbProvider {
return common.NewDbProvider[*gorm.DB](cs, cp)
}

type DbWrap struct {
*gorm.DB
}

// NewDbWrap wrap gorm.DB into closable
func NewDbWrap(db *gorm.DB) *DbWrap {
return &DbWrap{db}
}

func (d *DbWrap) Close() error {
return closeDb(d.DB)
}

func closeDb(d *gorm.DB) error {
sqlDB, err := d.DB()
if err != nil {
return err
}
cErr := sqlDB.Close()
if cErr != nil {
//todo logging
//logger.Errorf("Gorm db close error: %s", err.Error())
return cErr
}
return nil
}

0 comments on commit 2a508a1

Please sign in to comment.