diff --git a/example/gorm/dialector.go b/example/gorm/dialector.go new file mode 100644 index 000000000..2d37c2899 --- /dev/null +++ b/example/gorm/dialector.go @@ -0,0 +1,176 @@ +package main + +import ( + "database/sql" + "strconv" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/callbacks" + "gorm.io/gorm/clause" + "gorm.io/gorm/logger" + "gorm.io/gorm/migrator" + "gorm.io/gorm/schema" +) + +// Note: code adapted from gorm source: go-gorm/sqlite. + +// Dialector implements the Dialector interface from gorm. +type Dialector struct { + db *sql.DB +} + +// NewDialector constructs a new dialector from a sql store. +func NewDialector(db *sql.DB) gorm.Dialector { + return &Dialector{db: db} +} + +// Name returns the name of the dialector. +func (d *Dialector) Name() string { + return "genjidb" +} + +// Initialize initializes the dialector with a db. +func (d *Dialector) Initialize(db *gorm.DB) (err error) { + // register callbacks + callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ + LastInsertIDReversed: true, + }) + for k, v := range d.ClauseBuilders() { + db.ClauseBuilders[k] = v + } + return nil +} + +// ClauseBuilders returns the set of clause builders. +func (d *Dialector) ClauseBuilders() map[string]clause.ClauseBuilder { + return map[string]clause.ClauseBuilder{ + "INSERT": func(c clause.Clause, builder clause.Builder) { + if insert, ok := c.Expression.(clause.Insert); ok { + if stmt, ok := builder.(*gorm.Statement); ok { + stmt.WriteString("INSERT ") + if insert.Modifier != "" { + stmt.WriteString(insert.Modifier) + stmt.WriteByte(' ') + } + + stmt.WriteString("INTO ") + if insert.Table.Name == "" { + stmt.WriteQuoted(stmt.Table) + } else { + stmt.WriteQuoted(insert.Table) + } + return + } + } + + c.Build(builder) + }, + "LIMIT": func(c clause.Clause, builder clause.Builder) { + if limit, ok := c.Expression.(clause.Limit); ok { + if limit.Limit > 0 { + builder.WriteString("LIMIT ") + builder.WriteString(strconv.Itoa(limit.Limit)) + } + if limit.Offset > 0 { + if limit.Limit > 0 { + builder.WriteString(" ") + } + builder.WriteString("OFFSET ") + builder.WriteString(strconv.Itoa(limit.Offset)) + } + } + }, + "FOR": func(c clause.Clause, builder clause.Builder) { + if _, ok := c.Expression.(clause.Locking); ok { + // SQLite3 does not support row-level locking. + return + } + c.Build(builder) + }, + } +} + +func (d *Dialector) DefaultValueOf(field *schema.Field) clause.Expression { + if field.AutoIncrement { + return clause.Expr{SQL: "NULL"} + } + + // doesn't work, will raise error + return clause.Expr{SQL: "DEFAULT"} +} + +func (d *Dialector) Migrator(db *gorm.DB) gorm.Migrator { + return Migrator{migrator.Migrator{Config: migrator.Config{ + DB: db, + Dialector: d, + CreateIndexAfterCreateTable: true, + }}} +} + +func (d *Dialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { + writer.WriteByte('?') +} + +func (d *Dialector) QuoteTo(writer clause.Writer, str string) { + writer.WriteByte('`') + if strings.Contains(str, ".") { + for idx, str := range strings.Split(str, ".") { + if idx > 0 { + writer.WriteString(".`") + } + writer.WriteString(str) + writer.WriteByte('`') + } + } else { + writer.WriteString(str) + writer.WriteByte('`') + } +} + +func (d *Dialector) Explain(sql string, vars ...interface{}) string { + return logger.ExplainSQL(sql, nil, `"`, vars...) +} + +func (d *Dialector) DataTypeOf(field *schema.Field) string { + switch field.DataType { + case schema.Bool: + return "numeric" + case schema.Int, schema.Uint: + /* + if field.AutoIncrement && !field.PrimaryKey { + // https://www.sqlite.org/autoinc.html + return "integer PRIMARY KEY AUTOINCREMENT" + } else { + } + */ + return "integer" + case schema.Float: + return "real" + case schema.String: + return "text" + case schema.Time: + // GenjiDB does not (yet) support datetime + // return "datetime" + return "text" + case schema.Bytes: + return "blob" + } + + return string(field.DataType) +} + +func (d *Dialector) SavePoint(tx *gorm.DB, name string) error { + // tx.Exec("SAVEPOINT " + name) + // return nil + return gorm.ErrNotImplemented +} + +func (d *Dialector) RollbackTo(tx *gorm.DB, name string) error { + // tx.Exec("ROLLBACK TO SAVEPOINT " + name) + // return nil + return gorm.ErrNotImplemented +} + +// _ is a type assertion +var _ gorm.Dialector = ((*Dialector)(nil)) diff --git a/example/gorm/go.mod b/example/gorm/go.mod new file mode 100644 index 000000000..5a66a7300 --- /dev/null +++ b/example/gorm/go.mod @@ -0,0 +1,13 @@ +module github.com/genjidb/genji/example/gorm + +go 1.15 + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/genjidb/genji v0.12.0 + github.com/kr/pretty v0.1.0 // indirect + gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect + gorm.io/gorm v1.21.7 +) + +replace github.com/genjidb/genji v0.12.0 => ../../ diff --git a/example/gorm/go.sum b/example/gorm/go.sum new file mode 100644 index 000000000..ac2185913 --- /dev/null +++ b/example/gorm/go.sum @@ -0,0 +1,37 @@ +github.com/buger/jsonparser v1.1.1 h1:2PnMjfWD7wBILjqQbt530v576A/cAbQvEW9gGIpYMUs= +github.com/buger/jsonparser v1.1.1/go.mod h1:6RYKKt7H4d4+iWqouImQ9R2FZql3VbhNgx27UK13J/0= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/google/btree v1.0.0 h1:0udJVsspx3VBr5FwtLhQQtuAsVc79tTq0ocGIPAU6qo= +github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.2 h1:eVKgfIdy9b6zbWBMgFpfDPoAMifwSZagU9HmEU6zgiI= +github.com/jinzhu/now v1.1.2/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/vmihailenco/msgpack/v5 v5.1.4 h1:6K44/cU6dMNGkVTGGuu7ef2NdSRFMhAFGGLfE3cqtHM= +github.com/vmihailenco/msgpack/v5 v5.1.4/go.mod h1:C5gboKD0TJPqWDTVTtrQNfRbiBwHZGo8UTqP/9/XvLI= +github.com/vmihailenco/tagparser v0.1.2 h1:gnjoVuB/kljJ5wICEEOpx98oXMWPLj22G67Vbd1qPqc= +github.com/vmihailenco/tagparser v0.1.2/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI= +go.etcd.io/bbolt v1.3.5 h1:XAzx9gjCb0Rxj7EoqcClPD1d5ZBxZJk0jbuoPHenBt0= +go.etcd.io/bbolt v1.3.5/go.mod h1:G5EMThwa9y8QZGBClrRx5EY+Yw9kAhnjy3bSjsnlVTQ= +golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5 h1:LfCXLvNmTYH9kEmVgqbnsWfruoXZIrh4YBgqVHtDvw0= +golang.org/x/sys v0.0.0-20200202164722-d101bd2416d5/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/gorm v1.21.7 h1:MuY8oejVL5l3iT7PfE3z5I4J+KW/Nu2w/uTpLe3vV1Q= +gorm.io/gorm v1.21.7/go.mod h1:F+OptMscr0P2F2qU97WT1WimdH9GaQPoDW7AYd5i2Y0= diff --git a/example/gorm/main.go b/example/gorm/main.go new file mode 100644 index 000000000..8717124ef --- /dev/null +++ b/example/gorm/main.go @@ -0,0 +1,68 @@ +package main + +import ( + "context" + "database/sql" + "errors" + + "github.com/genjidb/genji" + "github.com/genjidb/genji/engine/memoryengine" + gdriver "github.com/genjidb/genji/sql/driver" + "gorm.io/gorm" +) + +func run() error { + ctx := context.Background() + + store := memoryengine.NewEngine() + gjdb, err := genji.New(ctx, store) + if err != nil { + return err + } + /* + driver, ok := gdriver.NewDriver(gjdb).(driver.DriverContext) + if !ok { + return gorm.ErrNotImplemented + } + conn, err := driver.OpenConnector("") + if err != nil { + return err + } + */ + conn := gdriver.NewConnector(gjdb) + sqlDB := sql.OpenDB(conn) + dialector := NewDialector(sqlDB) + conf := &gorm.Config{Dialector: dialector, ConnPool: sqlDB} + db, err := gorm.Open(dialector, conf) + if err != nil { + return err + } + if err := db.AutoMigrate(&Entry{}); err != nil { + return err + } + db.Create(&Entry{Value: 4, ID: 1}) + db.Create(&Entry{Value: 10, ID: 2}) + db.Create(&Entry{Value: 30, ID: 3}) + + var e Entry + out := db.Where("value = ?", 30).Find(&e) + if out.Error != nil { + return out.Error + } + if e.Value != 30 { + return errors.New("value was incorrect") + } + return nil +} + +// Entry is an entry in the database. +type Entry struct { + ID int `gorm:"primaryKey"` + Value int `json:"value"` +} + +func main() { + if err := run(); err != nil { + panic(err) + } +} diff --git a/example/gorm/migrator.go b/example/gorm/migrator.go new file mode 100644 index 000000000..e0d528753 --- /dev/null +++ b/example/gorm/migrator.go @@ -0,0 +1,293 @@ +package main + +import ( + "errors" + "fmt" + "regexp" + "strings" + + "gorm.io/gorm" + "gorm.io/gorm/clause" + "gorm.io/gorm/migrator" + "gorm.io/gorm/schema" +) + +// Migrator migrates the database. +type Migrator struct { + migrator.Migrator +} + +func (m *Migrator) RunWithoutForeignKey(fc func() error) error { + // TODO: we do not support PRAGMA foreign_keys + /* + var enabled int + m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled) + if enabled == 1 { + m.DB.Exec("PRAGMA foreign_keys = OFF") + defer m.DB.Exec("PRAGMA foreign_keys = ON") + } + */ + + return fc() +} + +func (m Migrator) HasTable(value interface{}) bool { + var count int + m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + return m.DB.Raw("SELECT count(*) FROM gorm_meta WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count) + }) + return count > 0 +} + +func (m Migrator) DropTable(values ...interface{}) error { + return m.RunWithoutForeignKey(func() error { + values = m.ReorderModels(values, false) + tx := m.DB.Session(&gorm.Session{}) + + for i := len(values) - 1; i >= 0; i-- { + if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { + return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error + }); err != nil { + return err + } + } + + return nil + }) +} + +func (m Migrator) HasColumn(value interface{}, name string) bool { + var count int + m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(name); field != nil { + name = field.DBName + } + + if name != "" { + m.DB.Raw( + "SELECT count(*) FROM gorm_meta WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)", + "table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", + ).Row().Scan(&count) + } + return nil + }) + return count > 0 +} + +func (m Migrator) AlterColumn(value interface{}, name string) error { + return m.RunWithoutForeignKey(func() error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(name); field != nil { + var ( + createSQL string + newTableName = stmt.Table + "__temp" + ) + + m.DB.Raw("SELECT sql FROM gorm_meta WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL) + + if reg, err := regexp.Compile("(`|'|\"| )" + field.DBName + "(`|'|\"| ) .*?,"); err == nil { + tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ") + if err != nil { + return err + } + + createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) + createSQL = reg.ReplaceAllString(createSQL, fmt.Sprintf("`%v` ?,", field.DBName)) + + var columns []string + columnTypes, _ := m.DB.Migrator().ColumnTypes(value) + for _, columnType := range columnTypes { + columns = append(columns, fmt.Sprintf("`%v`", columnType.Name())) + } + + return m.DB.Transaction(func(tx *gorm.DB) error { + queries := []string{ + createSQL, + fmt.Sprintf("INSERT INTO `%v`(%v) SELECT %v FROM `%v`", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table), + fmt.Sprintf("DROP TABLE `%v`", stmt.Table), + fmt.Sprintf("ALTER TABLE `%v` RENAME TO `%v`", newTableName, stmt.Table), + } + for _, query := range queries { + if err := tx.Exec(query, m.FullDataTypeOf(field)).Error; err != nil { + return err + } + } + return nil + }) + } else { + return err + } + } else { + return fmt.Errorf("failed to alter field with name %v", name) + } + }) + }) +} + +func (m Migrator) DropColumn(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if field := stmt.Schema.LookUpField(name); field != nil { + name = field.DBName + } + + var ( + createSQL string + newTableName = stmt.Table + "__temp" + ) + + m.DB.Raw("SELECT sql FROM gorm_meta WHERE type = ? AND tbl_name = ? AND name = ?", "table", stmt.Table, stmt.Table).Row().Scan(&createSQL) + + if reg, err := regexp.Compile("(`|'|\"| )" + name + "(`|'|\"| ) .*?,"); err == nil { + tableReg, err := regexp.Compile(" ('|`|\"| )" + stmt.Table + "('|`|\"| ) ") + if err != nil { + return err + } + + createSQL = tableReg.ReplaceAllString(createSQL, fmt.Sprintf(" `%v` ", newTableName)) + createSQL = reg.ReplaceAllString(createSQL, "") + + var columns []string + columnTypes, _ := m.DB.Migrator().ColumnTypes(value) + for _, columnType := range columnTypes { + if columnType.Name() != name { + columns = append(columns, fmt.Sprintf("`%v`", columnType.Name())) + } + } + + return m.DB.Transaction(func(tx *gorm.DB) error { + queries := []string{ + createSQL, + fmt.Sprintf("INSERT INTO `%v`(%v) SELECT %v FROM `%v`", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), stmt.Table), + fmt.Sprintf("DROP TABLE `%v`", stmt.Table), + fmt.Sprintf("ALTER TABLE `%v` RENAME TO `%v`", newTableName, stmt.Table), + } + for _, query := range queries { + if err := tx.Exec(query).Error; err != nil { + return err + } + } + return nil + }) + } else { + return err + } + }) +} + +// CreateConstraint creates a constraint. +func (m Migrator) CreateConstraint(interface{}, string) error { + return errors.New("constraints not implemented") +} + +// DropConstraint drops a constraint. +func (m Migrator) DropConstraint(interface{}, string) error { + return errors.New("constraints not implemented") +} + +// HasConstraint checks if a constraint exists. +func (m Migrator) HasConstraint(value interface{}, name string) bool { + /* ErrConstraintsNotImplemented + var count int64 + m.RunWithValue(value, func(stmt *gorm.Statement) error { + m.DB.Raw( + "SELECT count(*) FROM gorm_meta WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ?)", + "table", stmt.Table, `%CONSTRAINT "`+name+`" %`, `%CONSTRAINT `+name+` %`, "%CONSTRAINT `"+name+"`%", + ).Row().Scan(&count) + + return nil + }) + + return count > 0 + */ + + return false +} + +func (m Migrator) CurrentDatabase() (name string) { + return "genjidb" +} + +func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { + for _, opt := range opts { + str := stmt.Quote(opt.DBName) + if opt.Expression != "" { + str = opt.Expression + } + + if opt.Collate != "" { + str += " COLLATE " + opt.Collate + } + + if opt.Sort != "" { + str += " " + opt.Sort + } + results = append(results, clause.Expr{SQL: str}) + } + return +} + +func (m Migrator) CreateIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + opts := m.BuildIndexOptions(idx.Fields, stmt) + values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts} + + createIndexSQL := "CREATE " + if idx.Class != "" { + createIndexSQL += idx.Class + " " + } + createIndexSQL += "INDEX ?" + + if idx.Type != "" { + createIndexSQL += " USING " + idx.Type + } + createIndexSQL += " ON ??" + + if idx.Where != "" { + createIndexSQL += " WHERE " + idx.Where + } + + return m.DB.Exec(createIndexSQL, values...).Error + } + + return fmt.Errorf("failed to create index with name %v", name) + }) +} + +func (m Migrator) HasIndex(value interface{}, name string) bool { + var count int + m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + + if name != "" { + m.DB.Raw( + "SELECT count(*) FROM gorm_meta WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name, + ).Row().Scan(&count) + } + return nil + }) + return count > 0 +} + +func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + var sql string + m.DB.Raw("SELECT sql FROM gorm_meta WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql) + if sql != "" { + return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error + } + return fmt.Errorf("failed to find index with name %v", oldName) + }) +} + +func (m Migrator) DropIndex(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + if idx := stmt.Schema.LookIndex(name); idx != nil { + name = idx.Name + } + + return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error + }) +} diff --git a/sql/driver/driver.go b/sql/driver/driver.go index 27c8cfb6f..2c0c9aba5 100644 --- a/sql/driver/driver.go +++ b/sql/driver/driver.go @@ -64,6 +64,14 @@ type connector struct { closeOnce sync.Once } +// NewConnector constructs a new connector with a db. +func NewConnector(db *genji.DB) driver.Connector { + return &connector{ + driver: &sqlDriver{}, + db: db, + } +} + func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { return &conn{db: c.db}, nil }