diff --git a/common/cache.go b/common/cache.go index 1a8d290..bf6c9fe 100644 --- a/common/cache.go +++ b/common/cache.go @@ -189,7 +189,7 @@ func (c *Cache[K, V]) delete(e *list.Element) error { entry := e.Value.(*entry[K, V]) delete(c.items, entry.key) - if f, ok := e.Value.(closable); ok { + if f, ok := any(entry.val).(closable); ok { return f.Close() } return nil diff --git a/examples/gorm/main.go b/examples/gorm/main.go index 9bf05be..d4540e7 100644 --- a/examples/gorm/main.go +++ b/examples/gorm/main.go @@ -11,11 +11,11 @@ import ( "github.com/goxiaoy/go-saas/common" "github.com/goxiaoy/go-saas/data" "github.com/goxiaoy/go-saas/gin/saas" - gorm2 "github.com/goxiaoy/go-saas/gorm" + sgorm "github.com/goxiaoy/go-saas/gorm" "github.com/goxiaoy/go-saas/seed" "gorm.io/driver/mysql" "gorm.io/driver/sqlite" - g "gorm.io/gorm" + "gorm.io/gorm" "net/http" ) @@ -38,14 +38,14 @@ func init() { func main() { flag.Parse() - cache := common.NewCache[string, *g.DB]() + cache := common.NewCache[string, *sgorm.DbWrap]() defer cache.Flush() var connStrGen common.ConnStrGenerator switch driver { case sqlite.DriverName: sharedDsn = defaultSqliteSharedDsn - connStrGen = common.NewConnStrGenerator("./example-%s") + connStrGen = common.NewConnStrGenerator("./example-%s.db") case "mysql": if len(sharedDsn) == 0 { sharedDsn = defaultMysqlSharedDsn @@ -92,14 +92,14 @@ func main() { return tenantStore }, data.NewConnStrOption(conn)) - clientProvider := gorm2.ClientProviderFunc(func(ctx context.Context, s string) (*g.DB, error) { - client, _, err := cache.GetOrSet(s, func() (*g.DB, error) { + clientProvider := sgorm.ClientProviderFunc(func(ctx context.Context, s string) (*gorm.DB, error) { + client, _, err := cache.GetOrSet(s, func() (*sgorm.DbWrap, error) { if ensureDbExist != nil { if err := ensureDbExist(s); err != nil { return nil, err } } - var client *g.DB + var client *gorm.DB var err error db, err := sql.Open(driver, s) if err != nil { @@ -111,27 +111,27 @@ func main() { } if driver == sqlite.DriverName { - client, err = g.Open(&sqlite.Dialector{ + client, err = gorm.Open(&sqlite.Dialector{ DriverName: sqlite.DriverName, DSN: s, Conn: db, }) } else if driver == "mysql" { - client, err = g.Open(mysql.New(mysql.Config{ + client, err = gorm.Open(mysql.New(mysql.Config{ Conn: db, })) } - return client, err + return sgorm.NewDbWrap(client), err }) if err != nil { - return client, err + return nil, err } return client.WithContext(ctx).Debug(), err }) - dbProvider := gorm2.NewDbProvider(mr, clientProvider) + dbProvider := sgorm.NewDbProvider(mr, clientProvider) tenantStore = common.NewCachedTenantStore(&TenantStore{dbProvider: dbProvider}) diff --git a/gorm/gorm.go b/gorm/gorm.go index ffa2ff8..40c7a87 100644 --- a/gorm/gorm.go +++ b/gorm/gorm.go @@ -23,3 +23,30 @@ func (c ClientProviderFunc) Get(ctx context.Context, dsn string) (*gorm.DB, erro func NewDbProvider(cs data.ConnStrResolver, cp ClientProvider) DbProvider { return common.NewDbProvider[*gorm.DB](cs, cp) } + +type DbWrap struct { + *gorm.DB +} + +// NewDbWrap wrap gorm.DB into closable +func NewDbWrap(db *gorm.DB) *DbWrap { + return &DbWrap{db} +} + +func (d *DbWrap) Close() error { + return closeDb(d.DB) +} + +func closeDb(d *gorm.DB) error { + sqlDB, err := d.DB() + if err != nil { + return err + } + cErr := sqlDB.Close() + if cErr != nil { + //todo logging + //logger.Errorf("Gorm db close error: %s", err.Error()) + return cErr + } + return nil +}