Skip to content

Commit

Permalink
Merge pull request #13 from go-gorm/feature/read-write-connections
Browse files Browse the repository at this point in the history
  • Loading branch information
huacnlee authored Mar 18, 2022
2 parents cde328a + 1ac1ef1 commit d9ee537
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 27 deletions.
20 changes: 19 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,21 @@ jobs:
env:
DIALECTOR: postgres
DATABASE_URL: postgres://gorm:gorm@localhost:5432/sharding-test
DATABASE_READ_URL: postgres://gorm:gorm@localhost:5432/sharding-read-test
DATABASE_WRITE_URL: postgres://gorm:gorm@localhost:5432/sharding-write-test
steps:
- name: Set up Go
uses: actions/setup-go@v1
with:
go-version: 1.17
id: go

- name: Create Read Database
run: PGPASSWORD=gorm psql -h localhost -U gorm -d sharding-test -c 'CREATE DATABASE "sharding-read-test";'

- name: Create Write Databases
run: PGPASSWORD=gorm psql -h localhost -U gorm -d sharding-test -c 'CREATE DATABASE "sharding-write-test";'

- name: Check out code into the Go module directory
uses: actions/checkout@v1

Expand All @@ -67,7 +75,7 @@ jobs:
MYSQL_DATABASE: sharding-test
MYSQL_USER: gorm
MYSQL_PASSWORD: gorm
MYSQL_RANDOM_ROOT_PASSWORD: "yes"
MYSQL_ROOT_PASSWORD: gorm
ports:
- 3306:3306
options: >-
Expand All @@ -80,13 +88,23 @@ jobs:
env:
DIALECTOR: mysql
DATABASE_URL: gorm:gorm@tcp(127.0.0.1:3306)/sharding-test?charset=utf8mb4&parseTime=True&loc=Local
DATABASE_READ_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-read-test?charset=utf8mb4&parseTime=True&loc=Local
DATABASE_WRITE_URL: root:gorm@tcp(127.0.0.1:3306)/sharding-write-test?charset=utf8mb4&parseTime=True&loc=Local
steps:
- name: Set up Go
uses: actions/setup-go@v1
with:
go-version: 1.17
id: go

- name: Create Read Database
run: mysqladmin -h 127.0.0.1 -uroot -pgorm create sharding-read-test
#run: mysql -e 'CREATE DATABASE sharding-read-test' -ugorm -pgorm

- name: Create Write Database
run: mysqladmin -h 127.0.0.1 -uroot -pgorm create sharding-write-test
#run: mysql -e 'CREATE DATABASE sharding-write-test' -ugorm -pgorm

- name: Check out code into the Go module directory
uses: actions/checkout@v1

Expand Down
25 changes: 25 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,31 @@ Recommend options:
- [Snowflake](https://github.com/bwmarrin/snowflake)
- [Database sequence by manully](https://www.postgresql.org/docs/current/sql-createsequence.html)
## Combining with dbresolver
> 🚨 NOTE: Use dbresolver first.
```go
dsn := "host=localhost user=gorm password=gorm dbname=gorm port=5432 sslmode=disable"
dsnRead := "host=localhost user=gorm password=gorm dbname=gorm-slave port=5432 sslmode=disable"

conn := postgres.Open(dsn)
connRead := postgres.Open(dsnRead)

db, err := gorm.Open(conn, &gorm.Config{})
dbRead, err := gorm.Open(conn, &gorm.Config{})

db.Use(dbresolver.Register(dbresolver.Config{
Replicas: []gorm.Dialector{dbRead.Dialector},
}))

db.Use(sharding.Register(sharding.Config{
ShardingKey: "user_id",
NumberOfShards: 64,
PrimaryKeyGenerator: sharding.PKSnowflake,
}))
```
## Sharding process
This graph show up how Gorm Sharding works.
Expand Down
13 changes: 0 additions & 13 deletions conn_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,6 @@ type ConnPool struct {
gorm.ConnPool
}

// registerConnPool replace Gorm db.ConnPool as custom
func (s *Sharding) registerConnPool(db *gorm.DB) {
// Avoid assign loop
basePool := db.ConnPool
if _, ok := basePool.(ConnPool); ok {
return
}

s.ConnPool = &ConnPool{ConnPool: basePool, sharding: s}
db.ConnPool = s.ConnPool
db.Statement.ConnPool = s.ConnPool
}

func (pool *ConnPool) String() string {
return "gorm:sharding:conn_pool"
}
Expand Down
3 changes: 2 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ require (
github.com/longbridgeapp/sqlparser v0.3.1
gorm.io/driver/mysql v1.3.2
gorm.io/driver/postgres v1.3.1
gorm.io/gorm v1.23.1
gorm.io/gorm v1.23.2
gorm.io/hints v1.1.0
gorm.io/plugin/dbresolver v1.1.0
)

require (
Expand Down
10 changes: 9 additions & 1 deletion go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
Expand Down Expand Up @@ -76,6 +77,7 @@ github.com/jackc/puddle v1.2.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dv
github.com/jackc/puddle v1.2.1/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jinzhu/now v1.1.2/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas=
github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
Expand Down Expand Up @@ -205,16 +207,22 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/mysql v1.0.3/go.mod h1:twGxftLBlFgNVNakL7F+P/x9oYqoymG3YYT8cAfI9oI=
gorm.io/driver/mysql v1.3.2 h1:QJryWiqQ91EvZ0jZL48NOpdlPdMjdip1hQ8bTgo4H7I=
gorm.io/driver/mysql v1.3.2/go.mod h1:ChK6AHbHgDCFZyJp0F+BmVGb06PSIoh9uVYKAlRbb2U=
gorm.io/driver/postgres v1.3.1 h1:Pyv+gg1Gq1IgsLYytj/S2k7ebII3CzEdpqQkPOdH24g=
gorm.io/driver/postgres v1.3.1/go.mod h1:WwvWOuR9unCLpGWCL6Y3JOeBWvbKi6JLhayiVclSZZU=
gorm.io/driver/sqlite v1.1.6 h1:p3U8WXkVFTOLPED4JjrZExfndjOtya3db8w9/vEMNyI=
gorm.io/driver/sqlite v1.1.6/go.mod h1:W8LmC/6UvVbHKah0+QOC7Ja66EaZXHwUTjgXY8YNWX8=
gorm.io/gorm v1.20.4/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw=
gorm.io/gorm v1.20.11/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw=
gorm.io/gorm v1.21.15/go.mod h1:F+OptMscr0P2F2qU97WT1WimdH9GaQPoDW7AYd5i2Y0=
gorm.io/gorm v1.22.2/go.mod h1:F+OptMscr0P2F2qU97WT1WimdH9GaQPoDW7AYd5i2Y0=
gorm.io/gorm v1.23.1 h1:aj5IlhDzEPsoIyOPtTRVI+SyaN1u6k613sbt4pwbxG0=
gorm.io/gorm v1.23.1/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/gorm v1.23.2 h1:xmq9QRMWL8HTJyhAUBXy8FqIIQCYESeKfJL4DoGKiWQ=
gorm.io/gorm v1.23.2/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk=
gorm.io/hints v1.1.0 h1:Lp4z3rxREufSdxn4qmkK3TLDltrM10FLTHiuqwDPvXw=
gorm.io/hints v1.1.0/go.mod h1:lKQ0JjySsPBj3uslFzY3JhYDtqEwzm+G1hv8rWujB6Y=
gorm.io/plugin/dbresolver v1.1.0 h1:cegr4DeprR6SkLIQlKhJLYxH8muFbJ4SmnojXvoeb00=
gorm.io/plugin/dbresolver v1.1.0/go.mod h1:tpImigFAEejCALOttyhWqsy4vfa2Uh/vAUVnL5IRF7Y=
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
17 changes: 15 additions & 2 deletions sharding.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ func (s *Sharding) LastQuery() string {
// Initialize implement for Gorm plugin interface
func (s *Sharding) Initialize(db *gorm.DB) error {
s.DB = db
s.registerConnPool(db)
s.registerCallbacks(db)

for t, c := range s.configs {
if c.PrimaryKeyGenerator == PKPGSequence {
Expand All @@ -218,6 +218,20 @@ func (s *Sharding) Initialize(db *gorm.DB) error {
return s.compile()
}

func (s *Sharding) registerCallbacks(db *gorm.DB) {
s.Callback().Create().Before("*").Register("gorm:sharding", s.switchConn)
s.Callback().Query().Before("*").Register("gorm:sharding", s.switchConn)
s.Callback().Update().Before("*").Register("gorm:sharding", s.switchConn)
s.Callback().Delete().Before("*").Register("gorm:sharding", s.switchConn)
s.Callback().Row().Before("*").Register("gorm:sharding", s.switchConn)
s.Callback().Raw().Before("*").Register("gorm:sharding", s.switchConn)
}

func (s *Sharding) switchConn(db *gorm.DB) {
s.ConnPool = &ConnPool{ConnPool: db.Statement.ConnPool, sharding: s}
db.Statement.ConnPool = s.ConnPool
}

// resolve split the old query to full table query and sharding table query
func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery, tableName string, err error) {
ftQuery = query
Expand Down Expand Up @@ -248,7 +262,6 @@ func (s *Sharding) resolve(query string, args ...interface{}) (ftQuery, stQuery,
}
table = tbl
condition = stmt.Condition

case *sqlparser.InsertStatement:
table = stmt.TableName
isInsert = true
Expand Down
109 changes: 100 additions & 9 deletions sharding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/hints"
"gorm.io/plugin/dbresolver"
)

type Order struct {
Expand All @@ -37,35 +38,80 @@ func databaseURL() string {
return databaseURL
}

func databaseReadURL() string {
databaseURL := os.Getenv("DATABASE_READ_URL")
if len(databaseURL) == 0 {
databaseURL = "postgres://localhost:5432/sharding-read-test?sslmode=disable"
if os.Getenv("DIALECTOR") == "mysql" {
databaseURL = "root@tcp(127.0.0.1:3306)/sharding-read-test?charset=utf8mb4"
}
}
return databaseURL
}

func databaseWriteURL() string {
databaseURL := os.Getenv("DATABASE_WRITE_URL")
if len(databaseURL) == 0 {
databaseURL = "postgres://localhost:5432/sharding-write-test?sslmode=disable"
if os.Getenv("DIALECTOR") == "mysql" {
databaseURL = "root@tcp(127.0.0.1:3306)/sharding-write-test?charset=utf8mb4"
}
}
return databaseURL
}

var (
dbConfig = postgres.Config{
DSN: databaseURL(),
PreferSimpleProtocol: true,
}
db *gorm.DB

shardingConfig = Config{
DoubleWrite: true,
ShardingKey: "user_id",
NumberOfShards: 4,
PrimaryKeyGenerator: PKSnowflake,
dbReadConfig = postgres.Config{
DSN: databaseReadURL(),
PreferSimpleProtocol: true,
}
dbWriteConfig = postgres.Config{
DSN: databaseWriteURL(),
PreferSimpleProtocol: true,
}
db, dbRead, dbWrite *gorm.DB

middleware = Register(shardingConfig, &Order{})
node, _ = snowflake.NewNode(1)
shardingConfig Config
middleware *Sharding
node, _ = snowflake.NewNode(1)
)

func init() {
if os.Getenv("DIALECTOR") == "mysql" {
db, _ = gorm.Open(mysql.Open(databaseURL()), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
})
dbRead, _ = gorm.Open(mysql.Open(databaseReadURL()), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
})
dbWrite, _ = gorm.Open(mysql.Open(databaseWriteURL()), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
})
} else {
db, _ = gorm.Open(postgres.New(dbConfig), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
})
dbRead, _ = gorm.Open(postgres.New(dbReadConfig), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
})
dbWrite, _ = gorm.Open(postgres.New(dbWriteConfig), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
})
}

shardingConfig = Config{
DoubleWrite: true,
ShardingKey: "user_id",
NumberOfShards: 4,
PrimaryKeyGenerator: PKSnowflake,
}

middleware = Register(shardingConfig, &Order{})

fmt.Println("Clean only tables ...")
dropTables()
fmt.Println("AutoMigrate tables ...")
Expand All @@ -80,6 +126,16 @@ func init() {
user_id bigint,
product text
)`)
dbRead.Exec(`CREATE TABLE ` + table + ` (
id bigint PRIMARY KEY,
user_id bigint,
product text
)`)
dbWrite.Exec(`CREATE TABLE ` + table + ` (
id bigint PRIMARY KEY,
user_id bigint,
product text
)`)
}

db.Use(middleware)
Expand All @@ -89,6 +145,8 @@ func dropTables() {
tables := []string{"orders", "orders_0", "orders_1", "orders_2", "orders_3", "categories"}
for _, table := range tables {
db.Exec("DROP TABLE IF EXISTS " + table)
dbRead.Exec("DROP TABLE IF EXISTS " + table)
dbWrite.Exec("DROP TABLE IF EXISTS " + table)
db.Exec(("DROP SEQUENCE IF EXISTS gorm_sharding_" + table + "_id_seq"))
}
}
Expand Down Expand Up @@ -264,6 +322,39 @@ func TestPKPGSequence(t *testing.T) {
assert.Equal(t, expected, middleware.LastQuery())
}

func TestReadWriteSplitting(t *testing.T) {
dbRead.Exec("INSERT INTO orders_0 (id, product, user_id) VALUES(1, 'iPad', 100)")
dbWrite.Exec("INSERT INTO orders_0 (id, product, user_id) VALUES(1, 'iPad', 100)")

var db *gorm.DB
if os.Getenv("DIALECTOR") == "mysql" {
db, _ = gorm.Open(mysql.Open(databaseURL()), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
})
} else {
db, _ = gorm.Open(postgres.New(dbConfig), &gorm.Config{
DisableForeignKeyConstraintWhenMigrating: true,
})
}

db.Use(dbresolver.Register(dbresolver.Config{
Sources: []gorm.Dialector{dbWrite.Dialector},
Replicas: []gorm.Dialector{dbRead.Dialector},
}))
db.Use(middleware)

var order Order
db.Model(&Order{}).Where("user_id", 100).Find(&order)
assert.Equal(t, "iPad", order.Product)

db.Model(&Order{}).Where("user_id", 100).Update("product", "iPhone")
db.Table("orders_0").Where("user_id", 100).Find(&order)
assert.Equal(t, "iPad", order.Product)

dbWrite.Table("orders_0").Where("user_id", 100).Find(&order)
assert.Equal(t, "iPhone", order.Product)
}

func assertQueryResult(t *testing.T, expected string, tx *gorm.DB) {
t.Helper()
assert.Equal(t, toDialect(expected), middleware.LastQuery())
Expand Down

0 comments on commit d9ee537

Please sign in to comment.