Skip to content

Commit

Permalink
Postgres dialect parse password with spaces
Browse files Browse the repository at this point in the history
  • Loading branch information
xormplus committed Nov 18, 2017
1 parent 1ba0d4c commit e0e2d41
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 46 deletions.
64 changes: 18 additions & 46 deletions dialect_postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"errors"
"fmt"
"net/url"
"sort"
"strconv"
"strings"

Expand Down Expand Up @@ -1117,10 +1116,6 @@ func (vs values) Get(k string) (v string) {
return vs[k]
}

func errorf(s string, args ...interface{}) {
panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...)))
}

func parseURL(connstr string) (string, error) {
u, err := url.Parse(connstr)
if err != nil {
Expand All @@ -1131,46 +1126,18 @@ func parseURL(connstr string) (string, error) {
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
}

var kvs []string
escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`)
accrue := func(k, v string) {
if v != "" {
kvs = append(kvs, k+"="+escaper.Replace(v))
}
}

if u.User != nil {
v := u.User.Username()
accrue("user", v)

v, _ = u.User.Password()
accrue("password", v)
}

i := strings.Index(u.Host, ":")
if i < 0 {
accrue("host", u.Host)
} else {
accrue("host", u.Host[:i])
accrue("port", u.Host[i+1:])
}

if u.Path != "" {
accrue("dbname", u.Path[1:])
return escaper.Replace(u.Path[1:]), nil
}

q := u.Query()
for k := range q {
accrue(k, q.Get(k))
}

sort.Strings(kvs) // Makes testing easier (not a performance concern)
return strings.Join(kvs, " "), nil
return "", nil
}

func parseOpts(name string, o values) {
func parseOpts(name string, o values) error {
if len(name) == 0 {
return
return fmt.Errorf("invalid options: %s", name)
}

name = strings.TrimSpace(name)
Expand All @@ -1179,31 +1146,36 @@ func parseOpts(name string, o values) {
for _, p := range ps {
kv := strings.Split(p, "=")
if len(kv) < 2 {
errorf("invalid option: %q", p)
return fmt.Errorf("invalid option: %q", p)
}
o.Set(kv[0], kv[1])
}

return nil
}

func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
db := &core.Uri{DbType: core.POSTGRES}
o := make(values)
var err error

if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") {
dataSourceName, err = parseURL(dataSourceName)
db.DbName, err = parseURL(dataSourceName)
if err != nil {
return nil, err
}
} else {
o := make(values)
err = parseOpts(dataSourceName, o)
if err != nil {
return nil, err
}

db.DbName = o.Get("dbname")
}
parseOpts(dataSourceName, o)

db.DbName = o.Get("dbname")
if db.DbName == "" {
return nil, errors.New("dbname is empty")
}
/*db.Schema = o.Get("schema")
if len(db.Schema) == 0 {
db.Schema = "public"
}*/

return db, nil
}
44 changes: 44 additions & 0 deletions dialect_postgres_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package xorm

import (
"reflect"
"testing"

"github.com/xormplus/core"
)

func TestPostgresDialect(t *testing.T) {
TestParse(t)
}

func TestParse(t *testing.T) {
tests := []struct {
in string
expected string
valid bool
}{
{"postgres://auser:password@localhost:5432/db?sslmode=disable", "db", true},
{"postgresql://auser:password@localhost:5432/db?sslmode=disable", "db", true},
{"postg://auser:password@localhost:5432/db?sslmode=disable", "db", false},
{"postgres://auser:pass with space@localhost:5432/db?sslmode=disable", "db", true},
{"postgres:// auser : password@localhost:5432/db?sslmode=disable", "db", true},
{"postgres://%20auser%20:pass%20with%20space@localhost:5432/db?sslmode=disable", "db", true},
{"postgres://auser:パスワード@localhost:5432/データベース?sslmode=disable", "データベース", true},
{"dbname=db sslmode=disable", "db", true},
{"user=auser password=password dbname=db sslmode=disable", "db", true},
{"", "db", false},
{"dbname=db =disable", "db", false},
}

driver := core.QueryDriver("postgres")

for _, test := range tests {
uri, err := driver.Parse("postgres", test.in)

if err != nil && test.valid {
t.Errorf("%q got unexpected error: %s", test.in, err)
} else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) {
t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected)
}
}
}

0 comments on commit e0e2d41

Please sign in to comment.