diff --git a/migrator.go b/migrator.go index 75a288c..34e1d22 100644 --- a/migrator.go +++ b/migrator.go @@ -16,6 +16,65 @@ type Migrator struct { migrator.Migrator } +// AutoMigrate auto migrate values +// +// // Migrating and setting a comment for a single table +// db.Set("gorm:table_comments", "用户信息表").AutoMigrate(&User{}) +// +// // Migrating and setting comments for multiple tables +// db.Set("gorm:table_comments", []string{"用户信息表", "公司信息表"}).AutoMigrate(&User{}, &Company{}) +func (m Migrator) AutoMigrate(values ...interface{}) error { + if err := m.Migrator.AutoMigrate(values...); err != nil { + return err + } + + if tableComments, ok := m.DB.Get("gorm:table_comments"); ok { + var comments []string + switch c := tableComments.(type) { + case string: + comments = append(comments, c) + case []string: + comments = c + default: + return nil + } + for i := 0; i < len(values) && i < len(comments); i++ { + value := values[i] + comment := comments[i] + tx := m.DB.Session(&gorm.Session{}) + if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + schemaName := getTableSchemaName(stmt.Schema) + if schemaName == "" { + schemaName = "dbo" + } + tableName := m.CurrentTable(stmt) + + var setCommentSql string + if m.HasTableComment(stmt) { + setCommentSql = "EXEC sp_updateextendedproperty 'MS_Description', ?, 'SCHEMA', ?, 'TABLE', ?" + } else { + setCommentSql = "EXEC sp_addextendedproperty 'MS_Description', ?, 'SCHEMA', ?, 'TABLE', ?" + } + return tx.Exec(setCommentSql, comment, schemaName, tableName).Error + }); err != nil { + return err + } + } + } + return nil +} + +func (m Migrator) HasTableComment(stmt *gorm.Statement) bool { + var count int + if err := m.DB.Raw( + "SELECT count(*) FROM sys.objects obj LEFT JOIN sys.extended_properties ep ON ep.major_id = obj.object_id WHERE ep.minor_id = ? AND ep.class = ? AND obj.type = ? AND obj.name = ?", + 0, 1, "U", stmt.Table, + ).Row().Scan(&count); err != nil { + return false + } + return count > 0 +} + func (m Migrator) GetTables() (tableList []string, err error) { return tableList, m.DB.Raw("SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_CATALOG = ?", m.CurrentDatabase()).Scan(&tableList).Error }