diff --git a/conn_pool.go b/conn_pool.go index ca4bd30..dd45cec 100644 --- a/conn_pool.go +++ b/conn_pool.go @@ -61,7 +61,7 @@ func (pool ConnPool) QueryContext(ctx context.Context, query string, args ...int return nil, err } - pool.sharding.querys.Store("last_query", query) + pool.sharding.querys.Store("last_query", stQuery) if table != "" { if r, ok := pool.sharding.Resolvers[table]; ok { diff --git a/sharding_test.go b/sharding_test.go index e7caa24..b5c6f9e 100644 --- a/sharding_test.go +++ b/sharding_test.go @@ -60,10 +60,10 @@ var ( default: return "", err } - return fmt.Sprintf("_%03d", userId%8), nil + return fmt.Sprintf("_%02d", userId%4), nil }, ShardingAlgorithmByPrimaryKey: func(id int64) (suffix string) { - return fmt.Sprintf("_%03d", keygen.TableIdx(id)) + return fmt.Sprintf("_%02d", keygen.TableIdx(id)) }, PrimaryKeyGenerate: func(tableIdx int64) int64 { return keygen.Next(tableIdx) @@ -72,77 +72,85 @@ var ( }) ) -var tables = []string{"orders", "categories"} - func init() { dropTables() err := db.AutoMigrate(&Order{}, &Category{}) if err != nil { panic(err) } + stables := []string{"orders_00", "orders_01", "orders_02", "orders_03"} + for _, table := range stables { + db.Exec(`CREATE TABLE ` + table + ` ( + id BIGSERIAL PRIMARY KEY, + user_id bigint, + product text + )`) + } + db.Use(&sharding) } func dropTables() { - for _, tableName := range tables { - db.Exec("DROP TABLE IF EXISTS " + tableName) + tables := []string{"orders", "orders_00", "orders_01", "orders_02", "orders_03", "categories"} + for _, table := range tables { + db.Exec("DROP TABLE IF EXISTS " + table) } } func TestInsert(t *testing.T) { tx := db.Create(&Order{ID: 100, UserID: 100, Product: "iPhone"}) - assertQueryResult(t, `INSERT INTO "orders_004" ("user_id", "product", "id") VALUES ($1, $2, $3) RETURNING "id"`, tx) + assertQueryResult(t, `INSERT INTO "orders_00" ("user_id", "product", "id") VALUES ($1, $2, $3) RETURNING "id"`, tx) } func TestFillID(t *testing.T) { db.Create(&Order{UserID: 100, Product: "iPhone"}) lastQuery := sharding.LastQuery() - assert.Equal(t, `INSERT INTO "orders_004" ("user_id", "product", "id") VALUES`, lastQuery[0:57]) + assert.Equal(t, `INSERT INTO "orders_00" ("user_id", "product", "id") VALUES`, lastQuery[0:59]) } func TestSelect1(t *testing.T) { - tx := db.Model(&Order{}).Where("user_id", 101).Where("id", keygen.Next(24)).Find(&[]Order{}) - assertQueryResult(t, `SELECT * FROM "orders_024" WHERE "user_id" = $1 AND "id" = $2`, tx) + tx := db.Model(&Order{}).Where("user_id", 101).Where("id", keygen.Next(1)).Find(&[]Order{}) + assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "user_id" = $1 AND "id" = $2`, tx) } func TestSelect2(t *testing.T) { - tx := db.Model(&Order{}).Where("id", keygen.Next(24)).Where("user_id", 101).Find(&[]Order{}) - assertQueryResult(t, `SELECT * FROM "orders_005" WHERE "id" = $1 AND "user_id" = $2`, tx) + tx := db.Model(&Order{}).Where("id", keygen.Next(1)).Where("user_id", 101).Find(&[]Order{}) + assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "id" = $1 AND "user_id" = $2`, tx) } func TestSelect3(t *testing.T) { - tx := db.Model(&Order{}).Where("id", keygen.Next(24)).Where("user_id = 101").Find(&[]Order{}) - assertQueryResult(t, `SELECT * FROM "orders_005" WHERE "id" = $1 AND "user_id" = 101`, tx) + tx := db.Model(&Order{}).Where("id", keygen.Next(1)).Where("user_id = 101").Find(&[]Order{}) + assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "id" = $1 AND "user_id" = 101`, tx) } func TestSelect4(t *testing.T) { tx := db.Model(&Order{}).Where("product", "iPad").Where("user_id", 100).Find(&[]Order{}) - assertQueryResult(t, `SELECT * FROM "orders_004" WHERE "product" = $1 AND "user_id" = $2`, tx) + assertQueryResult(t, `SELECT * FROM "orders_00" WHERE "product" = $1 AND "user_id" = $2`, tx) } func TestSelect5(t *testing.T) { tx := db.Model(&Order{}).Where("user_id = 101").Find(&[]Order{}) - assertQueryResult(t, `SELECT * FROM "orders_005" WHERE "user_id" = 101`, tx) + assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "user_id" = 101`, tx) } func TestSelect6(t *testing.T) { - tx := db.Model(&Order{}).Where("id", keygen.Next(24)).Find(&[]Order{}) - assertQueryResult(t, `SELECT * FROM "orders_024" WHERE "id" = $1`, tx) + tx := db.Model(&Order{}).Where("id", keygen.Next(2)).Find(&[]Order{}) + assertQueryResult(t, `SELECT * FROM "orders_02" WHERE "id" = $1`, tx) } func TestSelect7(t *testing.T) { - tx := db.Model(&Order{}).Where("user_id", 101).Where("id > ?", keygen.Next(24)).Find(&[]Order{}) - assertQueryResult(t, `SELECT * FROM "orders_005" WHERE "user_id" = $1 AND "id" > $2`, tx) + tx := db.Model(&Order{}).Where("user_id", 101).Where("id > ?", keygen.Next(1)).Find(&[]Order{}) + assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "user_id" = $1 AND "id" > $2`, tx) } func TestSelect8(t *testing.T) { - tx := db.Model(&Order{}).Where("id > ?", keygen.Next(24)).Where("user_id", 101).Find(&[]Order{}) - assertQueryResult(t, `SELECT * FROM "orders_005" WHERE "id" > $1 AND "user_id" = $2`, tx) + tx := db.Model(&Order{}).Where("id > ?", keygen.Next(1)).Where("user_id", 101).Find(&[]Order{}) + assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "id" > $1 AND "user_id" = $2`, tx) } func TestSelect9(t *testing.T) { tx := db.Model(&Order{}).Where("user_id = 101").First(&[]Order{}) - assertQueryResult(t, `SELECT * FROM "orders_005" WHERE "user_id" = 101 ORDER BY "orders_005"."id" LIMIT 1`, tx) + assertQueryResult(t, `SELECT * FROM "orders_01" WHERE "user_id" = 101 ORDER BY "orders_01"."id" LIMIT 1`, tx) } func TestSelect10(t *testing.T) { @@ -167,12 +175,12 @@ func TestSelect13(t *testing.T) { func TestUpdate(t *testing.T) { tx := db.Model(&Order{}).Where("user_id = ?", 100).Update("product", "new title") - assertQueryResult(t, `UPDATE "orders_004" SET "product" = $1 WHERE "user_id" = $2`, tx) + assertQueryResult(t, `UPDATE "orders_00" SET "product" = $1 WHERE "user_id" = $2`, tx) } func TestDelete(t *testing.T) { tx := db.Where("user_id = ?", 100).Delete(&Order{}) - assertQueryResult(t, `DELETE FROM "orders_004" WHERE "user_id" = $1`, tx) + assertQueryResult(t, `DELETE FROM "orders_00" WHERE "user_id" = $1`, tx) } func TestInsertMissingShardingKey(t *testing.T) {