From e0e2d418e1ceaf12b1fb865c5d2723f27fed0f2d Mon Sep 17 00:00:00 2001 From: xormplus Date: Sat, 18 Nov 2017 13:51:52 +0800 Subject: [PATCH] Postgres dialect parse password with spaces --- dialect_postgres.go | 64 +++++++++++----------------------------- dialect_postgres_test.go | 44 +++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 46 deletions(-) create mode 100644 dialect_postgres_test.go diff --git a/dialect_postgres.go b/dialect_postgres.go index 5f5f20b..6eb1ba1 100644 --- a/dialect_postgres.go +++ b/dialect_postgres.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "net/url" - "sort" "strconv" "strings" @@ -1117,10 +1116,6 @@ func (vs values) Get(k string) (v string) { return vs[k] } -func errorf(s string, args ...interface{}) { - panic(fmt.Errorf("pq: %s", fmt.Sprintf(s, args...))) -} - func parseURL(connstr string) (string, error) { u, err := url.Parse(connstr) if err != nil { @@ -1131,46 +1126,18 @@ func parseURL(connstr string) (string, error) { return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) } - var kvs []string escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) - accrue := func(k, v string) { - if v != "" { - kvs = append(kvs, k+"="+escaper.Replace(v)) - } - } - - if u.User != nil { - v := u.User.Username() - accrue("user", v) - - v, _ = u.User.Password() - accrue("password", v) - } - - i := strings.Index(u.Host, ":") - if i < 0 { - accrue("host", u.Host) - } else { - accrue("host", u.Host[:i]) - accrue("port", u.Host[i+1:]) - } if u.Path != "" { - accrue("dbname", u.Path[1:]) + return escaper.Replace(u.Path[1:]), nil } - q := u.Query() - for k := range q { - accrue(k, q.Get(k)) - } - - sort.Strings(kvs) // Makes testing easier (not a performance concern) - return strings.Join(kvs, " "), nil + return "", nil } -func parseOpts(name string, o values) { +func parseOpts(name string, o values) error { if len(name) == 0 { - return + return fmt.Errorf("invalid options: %s", name) } name = strings.TrimSpace(name) @@ -1179,31 +1146,36 @@ func parseOpts(name string, o values) { for _, p := range ps { kv := strings.Split(p, "=") if len(kv) < 2 { - errorf("invalid option: %q", p) + return fmt.Errorf("invalid option: %q", p) } o.Set(kv[0], kv[1]) } + + return nil } func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) { db := &core.Uri{DbType: core.POSTGRES} - o := make(values) var err error + if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") { - dataSourceName, err = parseURL(dataSourceName) + db.DbName, err = parseURL(dataSourceName) + if err != nil { + return nil, err + } + } else { + o := make(values) + err = parseOpts(dataSourceName, o) if err != nil { return nil, err } + + db.DbName = o.Get("dbname") } - parseOpts(dataSourceName, o) - db.DbName = o.Get("dbname") if db.DbName == "" { return nil, errors.New("dbname is empty") } - /*db.Schema = o.Get("schema") - if len(db.Schema) == 0 { - db.Schema = "public" - }*/ + return db, nil } diff --git a/dialect_postgres_test.go b/dialect_postgres_test.go new file mode 100644 index 0000000..0961ff7 --- /dev/null +++ b/dialect_postgres_test.go @@ -0,0 +1,44 @@ +package xorm + +import ( + "reflect" + "testing" + + "github.com/xormplus/core" +) + +func TestPostgresDialect(t *testing.T) { + TestParse(t) +} + +func TestParse(t *testing.T) { + tests := []struct { + in string + expected string + valid bool + }{ + {"postgres://auser:password@localhost:5432/db?sslmode=disable", "db", true}, + {"postgresql://auser:password@localhost:5432/db?sslmode=disable", "db", true}, + {"postg://auser:password@localhost:5432/db?sslmode=disable", "db", false}, + {"postgres://auser:pass with space@localhost:5432/db?sslmode=disable", "db", true}, + {"postgres:// auser : password@localhost:5432/db?sslmode=disable", "db", true}, + {"postgres://%20auser%20:pass%20with%20space@localhost:5432/db?sslmode=disable", "db", true}, + {"postgres://auser:パスワード@localhost:5432/データベース?sslmode=disable", "データベース", true}, + {"dbname=db sslmode=disable", "db", true}, + {"user=auser password=password dbname=db sslmode=disable", "db", true}, + {"", "db", false}, + {"dbname=db =disable", "db", false}, + } + + driver := core.QueryDriver("postgres") + + for _, test := range tests { + uri, err := driver.Parse("postgres", test.in) + + if err != nil && test.valid { + t.Errorf("%q got unexpected error: %s", test.in, err) + } else if err == nil && !reflect.DeepEqual(test.expected, uri.DbName) { + t.Errorf("%q got: %#v want: %#v", test.in, uri.DbName, test.expected) + } + } +}