Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hantmac committed Jan 14, 2025
1 parent 7fb56d1 commit 11f56a8
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 55 deletions.
104 changes: 53 additions & 51 deletions cmd/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"
"fmt"
"log"
"net/url"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -92,7 +93,7 @@ func TestOracleMultiTableWorkflow(t *testing.T) {
assert.NoError(t, err)
}

func TestWorkFlow(t *testing.T) {
func TestMySQLWorkFlow(t *testing.T) {
{
fmt.Println("=== TEST MYSQL SOURCE ===")
prepareMysql()
Expand Down Expand Up @@ -139,56 +140,56 @@ func TestWorkFlow(t *testing.T) {
err = checkTargetTable("test_table", 20)
assert.NoError(t, err)
}
}

{
fmt.Println("=== TEST MSSQL SOURCE ===")
prepareSQLServer()
truncateDatabend("test_table", "http://databend:databend@localhost:8000")
prepareDatabend("test_table", "http://databend:databend@localhost:8000")
testConfig := prepareSqlServerTestConfig()
startTime := time.Now()
func TestMssqlWorkflow(t *testing.T) {
fmt.Println("=== TEST MSSQL SOURCE ===")
prepareSQLServer()
truncateDatabend("test_table", "http://databend:databend@localhost:8000")
prepareDatabend("test_table", "http://databend:databend@localhost:8000")
testConfig := prepareSqlServerTestConfig()
startTime := time.Now()

src, err := source.NewSource(testConfig)
if err != nil {
panic(err)
}
wg := sync.WaitGroup{}
dbs, err := src.GetDatabasesAccordingToSourceDbRegex(testConfig.SourceDB)
if err != nil {
panic(err)
}
log.Printf("dbs: %v", dbs)
dbTables, err := src.GetTablesAccordingToSourceTableRegex(testConfig.SourceTable, dbs)
if err != nil {
panic(err)
}
log.Printf("dbTables: %v", dbTables)
for db, tables := range dbTables {
for _, table := range tables {
wg.Add(1)
db := db
table := table
go func(cfg *cfg.Config, db, table string) {
cfgCopy := *testConfig
cfgCopy.SourceTable = table
cfgCopy.SourceDB = db
ig := ingester.NewDatabendIngester(&cfgCopy)
src, err := source.NewSource(&cfgCopy)
assert.NoError(t, err)
w := worker.NewWorker(&cfgCopy, fmt.Sprintf("%s.%s", db, table), ig, src)
w.Run(context.Background())
wg.Done()
}(testConfig, db, table)
}
src, err := source.NewSource(testConfig)
if err != nil {
panic(err)
}
wg := sync.WaitGroup{}
dbs, err := src.GetDatabasesAccordingToSourceDbRegex(testConfig.SourceDB)
if err != nil {
panic(err)
}
log.Printf("dbs: %v", dbs)
dbTables, err := src.GetTablesAccordingToSourceTableRegex(testConfig.SourceTable, dbs)
if err != nil {
panic(err)
}
log.Printf("dbTables: %v", dbTables)
for db, tables := range dbTables {
for _, table := range tables {
wg.Add(1)
db := db
table := table
go func(cfg *cfg.Config, db, table string) {
cfgCopy := *testConfig
cfgCopy.SourceTable = table
cfgCopy.SourceDB = db
ig := ingester.NewDatabendIngester(&cfgCopy)
src, err := source.NewSource(&cfgCopy)
assert.NoError(t, err)
w := worker.NewWorker(&cfgCopy, fmt.Sprintf("%s.%s", db, table), ig, src)
w.Run(context.Background())
wg.Done()
}(testConfig, db, table)
}
wg.Wait()
endTime := fmt.Sprintf("end time: %s", time.Now().Format("2006-01-02 15:04:05"))
fmt.Println(endTime)
fmt.Println(fmt.Sprintf("total time: %s", time.Since(startTime)))

err = checkTargetTable("test_table", 25)
assert.NoError(t, err)
}
wg.Wait()
endTime := fmt.Sprintf("end time: %s", time.Now().Format("2006-01-02 15:04:05"))
fmt.Println(endTime)
fmt.Println(fmt.Sprintf("total time: %s", time.Since(startTime)))

err = checkTargetTable("test_table", 10)
assert.NoError(t, err)
}

func TestSimpleOracleWorkflow(t *testing.T) {
Expand Down Expand Up @@ -587,8 +588,9 @@ func prepareTestConfig() *cfg.Config {

func prepareSQLServer() {
log.Println("===prepareSQLServer===")
encodedPassword := url.QueryEscape("Passw@rd")
// sqlserver://username:password@host:port?database=dbname
db, err := sql.Open("mssql", "sqlserver://sa:Password1234!@localhost:1433?encrypt=disable")
db, err := sql.Open("mssql", fmt.Sprintf("sqlserver://sa:%s@localhost:1433?encrypt=disable", encodedPassword))
if err != nil {
log.Fatal(err)
}
Expand All @@ -612,11 +614,11 @@ func prepareSQLServer() {
log.Fatal(err)
}

// 切换到新创建的数据库
_, err = db.Exec("USE mydb")
db, err = sql.Open("mssql", fmt.Sprintf("sqlserver://sa:%s@localhost:1433?database=mydb&encrypt=disable", encodedPassword))
if err != nil {
log.Fatal(err)
}
defer db.Close()

// 创建表
_, err = db.Exec(`
Expand Down Expand Up @@ -691,7 +693,7 @@ func prepareSqlServerTestConfig() *cfg.Config {
SourceHost: "127.0.0.1",
SourcePort: 1433,
SourceUser: "sa",
SourcePass: "Password1234!",
SourcePass: "Passw@rd",
SourceTable: "test_table",
SourceWhereCondition: "id > 0",
SourceSplitKey: "id",
Expand Down
2 changes: 1 addition & 1 deletion config/config_test_mssql.json
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"sourceSplitKey": "id",
"sourceSplitTimeKey": "",
"timeSplitUnit": "minute",
"databendDSN": "https://user:password@host.databend.com:443",
"databendDSN": "https://user:pass@host.databend.com:443",
"databendTable": "testSync.test1",
"batchSize": 2,
"batchMaxInterval": 30,
Expand Down
14 changes: 11 additions & 3 deletions source/sql_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,9 +318,9 @@ func (s *SQLServerSource) QueryTableData(threadNum int, conditionSql string) ([]
batchSize)

// 设置查询超时
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
rows, err := s.db.QueryContext(ctx, query)
cancel()

if err != nil {
return nil, nil, fmt.Errorf("executing batch query at offset %d: %w", offset, err)
Expand Down Expand Up @@ -357,7 +357,11 @@ func (s *SQLServerSource) QueryTableData(threadNum int, conditionSql string) ([]
}
case *sql.NullBool:
if v.Valid {
row[i] = v.Bool
if v.Bool {
row[i] = 1 // target databend bool is int8
} else {
row[i] = 0
}
} else {
row[i] = nil
}
Expand Down Expand Up @@ -480,6 +484,7 @@ func (s *SQLServerSource) GetTablesAccordingToSourceTableRegex(sourceTablePatter

// 构建完整的表名(包含schema)
fullTableName := fmt.Sprintf("%s.%s", schemaName, tableName)
fmt.Println("full name table:", fullTableName)

match, err := regexp.MatchString(sourceTablePattern, fullTableName)
if err != nil {
Expand All @@ -488,6 +493,7 @@ func (s *SQLServerSource) GetTablesAccordingToSourceTableRegex(sourceTablePatter
}

if match {
fmt.Println("match table:", fullTableName)
tables = append(tables, fullTableName)
}
}
Expand All @@ -500,6 +506,8 @@ func (s *SQLServerSource) GetTablesAccordingToSourceTableRegex(sourceTablePatter
dbTables[database] = tables
}

fmt.Println("dbTables:", dbTables)

return dbTables, nil
}

Expand Down

0 comments on commit 11f56a8

Please sign in to comment.