diff --git a/dialect_postgres.go b/dialect_postgres.go index d907c68c0..1f74bd312 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -1223,3 +1223,15 @@ func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { return db, nil } + +type pqDriverPgx struct { + pqDriver +} + +func (pgx *pqDriverPgx) Parse(driverName, dataSourceName string) (*core.Uri, error) { + // Remove the leading characters for driver to work + if len(dataSourceName) >= 9 && dataSourceName[0] == 0 { + dataSourceName = dataSourceName[9:] + } + return pgx.pqDriver.Parse(driverName, dataSourceName) +} diff --git a/dialect_postgres_test.go b/dialect_postgres_test.go index 2ee1e2f38..6e6c44bbe 100644 --- a/dialect_postgres_test.go +++ b/dialect_postgres_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/go-xorm/core" + "github.com/jackc/pgx/stdlib" ) func TestParsePostgres(t *testing.T) { @@ -38,3 +39,48 @@ func TestParsePostgres(t *testing.T) { } } } + +func TestParsePgx(t *testing.T) { + tests := []struct { + in string + expected string + valid bool + }{ + {"postgres://auser:password@localhost:5432/db?sslmode=disable", "db", true}, + {"postgresql://auser:password@localhost:5432/db?sslmode=disable", "db", true}, + {"postg://auser:password@localhost:5432/db?sslmode=disable", "db", false}, + //{"postgres://auser:pass with space@localhost:5432/db?sslmode=disable", "db", true}, + //{"postgres:// auser : password@localhost:5432/db?sslmode=disable", "db", true}, + {"postgres://%20auser%20:pass%20with%20space@localhost:5432/db?sslmode=disable", "db", true}, + //{"postgres://auser:パスワード@localhost:5432/データベース?sslmode=disable", "データベース", true}, + {"dbname=db sslmode=disable", "db", true}, + {"user=auser password=password dbname=db sslmode=disable", "db", true}, + {"", "db", false}, + {"dbname=db =disable", "db", false}, + } + + driver := core.QueryDriver("pgx") + + for _, test := range tests { + uri, err := driver.Parse("pgx", test.in) + + if err != nil && test.valid { + t.Errorf("%q got unexpected error: %s", test.in, err) + } else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) { + t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected) + } + + // Register DriverConfig + drvierConfig := stdlib.DriverConfig{} + stdlib.RegisterDriverConfig(&drvierConfig) + uri, err = driver.Parse("pgx", + drvierConfig.ConnectionString(test.in)) + if err != nil && test.valid { + t.Errorf("%q got unexpected error: %s", test.in, err) + } else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) { + t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected) + } + + } + +} diff --git a/xorm.go b/xorm.go index 13d0951ba..141c4897d 100644 --- a/xorm.go +++ b/xorm.go @@ -31,7 +31,7 @@ func regDrvsNDialects() bool { "mysql": {"mysql", func() core.Driver { return &mysqlDriver{} }, func() core.Dialect { return &mysql{} }}, "mymysql": {"mysql", func() core.Driver { return &mymysqlDriver{} }, func() core.Dialect { return &mysql{} }}, "postgres": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }}, - "pgx": {"postgres", func() core.Driver { return &pqDriver{} }, func() core.Dialect { return &postgres{} }}, + "pgx": {"postgres", func() core.Driver { return &pqDriverPgx{} }, func() core.Dialect { return &postgres{} }}, "sqlite3": {"sqlite3", func() core.Driver { return &sqlite3Driver{} }, func() core.Dialect { return &sqlite3{} }}, "oci8": {"oracle", func() core.Driver { return &oci8Driver{} }, func() core.Dialect { return &oracle{} }}, "goracle": {"oracle", func() core.Driver { return &goracleDriver{} }, func() core.Dialect { return &oracle{} }},