Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hyperphoton committed Jan 17, 2022
1 parent d9def5a commit 6d8dd1e
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 26 deletions.
2 changes: 1 addition & 1 deletion conn_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func (pool ConnPool) QueryContext(ctx context.Context, query string, args ...int
return nil, err
}

pool.sharding.querys.Store("last_query", query)
pool.sharding.querys.Store("last_query", stQuery)

if table != "" {
if r, ok := pool.sharding.Resolvers[table]; ok {
Expand Down
58 changes: 33 additions & 25 deletions sharding_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,10 @@ var (
default:
return "", err
}
return fmt.Sprintf("_%03d", userId%8), nil
return fmt.Sprintf("_%02d", userId%4), nil
},
ShardingAlgorithmByPrimaryKey: func(id int64) (suffix string) {
return fmt.Sprintf("_%03d", keygen.TableIdx(id))
return fmt.Sprintf("_%02d", keygen.TableIdx(id))
},
PrimaryKeyGenerate: func(tableIdx int64) int64 {
return keygen.Next(tableIdx)
Expand All @@ -72,77 +72,85 @@ var (
})
)

var tables = []string{"orders", "categories"}

func init() {
dropTables()
err := db.AutoMigrate(&Order{}, &Category{})
if err != nil {
panic(err)
}
stables := []string{"orders_00", "orders_01", "orders_02", "orders_03"}
for _, table := range stables {
db.Exec(`CREATE TABLE ` + table + ` (
id BIGSERIAL PRIMARY KEY,
user_id bigint,
product text
)`)
}

db.Use(&sharding)
}

func dropTables() {
for _, tableName := range tables {
db.Exec("DROP TABLE IF EXISTS " + tableName)
tables := []string{"orders", "orders_00", "orders_01", "orders_02", "orders_03", "categories"}
for _, table := range tables {
db.Exec("DROP TABLE IF EXISTS " + table)
}
}

func TestInsert(t *testing.T) {
tx := db.Create(&Order{ID: 100, UserID: 100, Product: "iPhone"})
assertQueryResult(t, `INSERT INTO "orders_004" ("user_id", "product", "id") VALUES ($1, $2, $3) RETURNING "id"`, tx)
assertQueryResult(t, `INSERT INTO "orders_00" ("user_id", "product", "id") VALUES ($1, $2, $3) RETURNING "id"`, tx)
}

func TestFillID(t *testing.T) {
db.Create(&Order{UserID: 100, Product: "iPhone"})
lastQuery := sharding.LastQuery()
assert.Equal(t, `INSERT INTO "orders_004" ("user_id", "product", "id") VALUES`, lastQuery[0:57])
assert.Equal(t, `INSERT INTO "orders_00" ("user_id", "product", "id") VALUES`, lastQuery[0:59])
}

func TestSelect1(t *testing.T) {
tx := db.Model(&Order{}).Where("user_id", 101).Where("id", keygen.Next(24)).Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_024" WHERE "user_id" = $1 AND "id" = $2`, tx)
tx := db.Model(&Order{}).Where("user_id", 101).Where("id", keygen.Next(1)).Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "user_id" = $1 AND "id" = $2`, tx)
}

func TestSelect2(t *testing.T) {
tx := db.Model(&Order{}).Where("id", keygen.Next(24)).Where("user_id", 101).Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_005" WHERE "id" = $1 AND "user_id" = $2`, tx)
tx := db.Model(&Order{}).Where("id", keygen.Next(1)).Where("user_id", 101).Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "id" = $1 AND "user_id" = $2`, tx)
}

func TestSelect3(t *testing.T) {
tx := db.Model(&Order{}).Where("id", keygen.Next(24)).Where("user_id = 101").Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_005" WHERE "id" = $1 AND "user_id" = 101`, tx)
tx := db.Model(&Order{}).Where("id", keygen.Next(1)).Where("user_id = 101").Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "id" = $1 AND "user_id" = 101`, tx)
}

func TestSelect4(t *testing.T) {
tx := db.Model(&Order{}).Where("product", "iPad").Where("user_id", 100).Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_004" WHERE "product" = $1 AND "user_id" = $2`, tx)
assertQueryResult(t, `SELECT * FROM "orders_00" WHERE "product" = $1 AND "user_id" = $2`, tx)
}

func TestSelect5(t *testing.T) {
tx := db.Model(&Order{}).Where("user_id = 101").Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_005" WHERE "user_id" = 101`, tx)
assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "user_id" = 101`, tx)
}

func TestSelect6(t *testing.T) {
tx := db.Model(&Order{}).Where("id", keygen.Next(24)).Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_024" WHERE "id" = $1`, tx)
tx := db.Model(&Order{}).Where("id", keygen.Next(2)).Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_02" WHERE "id" = $1`, tx)
}

func TestSelect7(t *testing.T) {
tx := db.Model(&Order{}).Where("user_id", 101).Where("id > ?", keygen.Next(24)).Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_005" WHERE "user_id" = $1 AND "id" > $2`, tx)
tx := db.Model(&Order{}).Where("user_id", 101).Where("id > ?", keygen.Next(1)).Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "user_id" = $1 AND "id" > $2`, tx)
}

func TestSelect8(t *testing.T) {
tx := db.Model(&Order{}).Where("id > ?", keygen.Next(24)).Where("user_id", 101).Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_005" WHERE "id" > $1 AND "user_id" = $2`, tx)
tx := db.Model(&Order{}).Where("id > ?", keygen.Next(1)).Where("user_id", 101).Find(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "id" > $1 AND "user_id" = $2`, tx)
}

func TestSelect9(t *testing.T) {
tx := db.Model(&Order{}).Where("user_id = 101").First(&[]Order{})
assertQueryResult(t, `SELECT * FROM "orders_005" WHERE "user_id" = 101 ORDER BY "orders_005"."id" LIMIT 1`, tx)
assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "user_id" = 101 ORDER BY "orders_01"."id" LIMIT 1`, tx)
}

func TestSelect10(t *testing.T) {
Expand All @@ -167,12 +175,12 @@ func TestSelect13(t *testing.T) {

func TestUpdate(t *testing.T) {
tx := db.Model(&Order{}).Where("user_id = ?", 100).Update("product", "new title")
assertQueryResult(t, `UPDATE "orders_004" SET "product" = $1 WHERE "user_id" = $2`, tx)
assertQueryResult(t, `UPDATE "orders_00" SET "product" = $1 WHERE "user_id" = $2`, tx)
}

func TestDelete(t *testing.T) {
tx := db.Where("user_id = ?", 100).Delete(&Order{})
assertQueryResult(t, `DELETE FROM "orders_004" WHERE "user_id" = $1`, tx)
assertQueryResult(t, `DELETE FROM "orders_00" WHERE "user_id" = $1`, tx)
}

func TestInsertMissingShardingKey(t *testing.T) {
Expand Down

0 comments on commit 6d8dd1e

Please sign in to comment.