-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.go
152 lines (123 loc) · 3.57 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
package migrent
import (
"context"
"database/sql"
"fmt"
"sort"
"time"
entsql "entgo.io/ent/dialect/sql"
"github.com/dakimura/migrent/ent/migration"
"github.com/dakimura/migrent/ent"
)
// MigrationName is the name under which a migration is recorded in the
// internal migration table. Names are compared as plain strings to decide the
// order in which migrations are processed.
type MigrationName string
// Migration is a single schema or data change supplied by the caller.
type Migration interface {
// Up is called by Client.Up for a migration that has no record in the
// internal migration table yet. A non-nil error aborts the run.
Up(ctx context.Context) error
// Down is called by Client.Down for a migration that has a record in the
// internal migration table. A non-nil error aborts the run.
Down(ctx context.Context) error
}
// Client runs migrations and keeps track of which ones have been applied.
type Client struct {
// entclient is the ent client used to create the internal migration table
// and to query, insert, and delete migration records.
entclient *ent.Client
}
// NewClient wraps an already constructed ent client in a Client.
func NewClient(entclient *ent.Client) *Client {
	c := new(Client)
	c.entclient = entclient
	return c
}
// Open creates an ent client for the given driver and data source via
// ent.Open and wraps it in a Client. The error from ent.Open is returned
// unchanged, together with a nil Client.
func Open(driverName, dataSourceName string, options ...ent.Option) (*Client, error) {
	entClient, openErr := ent.Open(driverName, dataSourceName, options...)
	if openErr != nil {
		return nil, openErr
	}

	wrapped := &Client{entclient: entClient}
	return wrapped, nil
}
// OpenByMySQLDB builds a Client on top of an existing *sql.DB by calling
// OpenByDB with the "mysql" driver name.
func OpenByMySQLDB(db *sql.DB) *Client {
	const driver = "mysql"
	return OpenByDB(db, driver)
}
// OpenByDB builds a Client on top of an existing *sql.DB. The handle is
// turned into an ent SQL driver for the given driver name, and a new ent
// client is created on that driver.
func OpenByDB(db *sql.DB, driver string) *Client {
	return &Client{
		entclient: ent.NewClient(ent.Driver(entsql.OpenDB(driver, db))),
	}
}
// Up applies every migration in migs that has not been recorded yet.
//
// The internal migration table is created first if it does not exist.
// Migrations are then visited in ascending order of their names. A migration
// whose name already has a record in the table is skipped; otherwise its Up
// method is called and, on success, a record with the current time is saved.
//
// Processing stops at the first error. Migrations applied before the failure
// stay applied and recorded. A nil entry in migs that would have to be applied
// is reported as an error rather than causing a nil-pointer panic.
func (c *Client) Up(ctx context.Context, migs map[MigrationName]Migration) error {
	// Create the internal table if it does not exist.
	if err := c.createMigrationTable(ctx); err != nil {
		return err
	}

	for _, name := range sortedMigrationNames(migs) {
		// Check whether the migration has already been applied.
		applied, err := c.entclient.Migration.Query().
			Where(migration.NameEQ(string(name))).
			All(ctx)
		if err != nil {
			return fmt.Errorf("querying migration history for %s: %w", name, err)
		}
		if len(applied) > 0 {
			continue
		}

		// Guard only here, so an already-applied nil entry is still skipped
		// silently exactly as before.
		mig := migs[name]
		if mig == nil {
			return fmt.Errorf("migration(Up) for %s: migration is nil", name)
		}

		// Apply the migration.
		if err := mig.Up(ctx); err != nil {
			return fmt.Errorf("migration(Up) for %s: %w", name, err)
		}

		// Record the migration in the internal table.
		_, err = c.entclient.Migration.Create().
			SetName(string(name)).
			SetAppliedAt(time.Now()).
			Save(ctx)
		if err != nil {
			return fmt.Errorf("record migration(Up) of %s: %w", name, err)
		}
	}
	return nil
}
// Down reverts every migration in migs that is currently recorded as applied.
//
// The internal migration table is created first if it does not exist.
// Migrations are then visited in descending order of their names, the reverse
// of the order used by Up, so that a later migration is reverted before the
// earlier ones it may depend on. A migration without a record in the table is
// skipped; otherwise its Down method is called and, on success, its record is
// deleted.
//
// Processing stops at the first error. Migrations reverted before the failure
// stay reverted. A nil entry in migs that would have to be reverted is
// reported as an error rather than causing a nil-pointer panic.
func (c *Client) Down(ctx context.Context, migs map[MigrationName]Migration) error {
	// Create the internal table if it does not exist.
	if err := c.createMigrationTable(ctx); err != nil {
		return err
	}

	names := sortedMigrationNames(migs)
	// Walk from the last name to the first: roll back in reverse apply order.
	for i := len(names) - 1; i >= 0; i-- {
		name := names[i]

		// Check whether the migration is currently applied.
		applied, err := c.entclient.Migration.Query().
			Where(migration.NameEQ(string(name))).
			All(ctx)
		if err != nil {
			return fmt.Errorf("querying migration history for %s: %w", name, err)
		}
		if len(applied) == 0 {
			continue
		}

		// Guard only here, so an unapplied nil entry is still skipped
		// silently exactly as before.
		mig := migs[name]
		if mig == nil {
			return fmt.Errorf("migration(Down) for %s: migration is nil", name)
		}

		// Revert the migration.
		if err := mig.Down(ctx); err != nil {
			return fmt.Errorf("migration(Down) for %s: %w", name, err)
		}

		// Remove the migration's record from the internal table.
		_, err = c.entclient.Migration.Delete().
			Where(migration.NameEQ(string(name))).
			Exec(ctx)
		if err != nil {
			return fmt.Errorf("record migration(Down) of %s: %w", name, err)
		}
	}
	return nil
}
// TODO: decide where the underlying ent client should be closed ("defer client.Close()").
// createMigrationTable runs ent's schema creation so that the internal
// migration table exists. Any failure is wrapped with context.
func (c *Client) createMigrationTable(ctx context.Context) error {
	if err := c.entclient.Schema.Create(ctx); err != nil {
		return fmt.Errorf("create the internal migration table: %w", err)
	}
	return nil
}
// sortedMigrationNames collects the keys of m and returns them in ascending
// string order. The result is nil when m is empty.
func sortedMigrationNames(m map[MigrationName]Migration) []MigrationName {
	var sorted []MigrationName
	for key := range m {
		sorted = append(sorted, key)
	}

	// Map keys are unique, so no two elements compare equal and a plain
	// (non-stable) sort produces a fully determined order.
	byName := func(a, b int) bool { return sorted[a] < sorted[b] }
	sort.Slice(sorted, byName)

	return sorted
}