Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Postgres: Correct parsing of multiline statements #8512

Merged
merged 10 commits into from
Mar 17, 2020
69 changes: 68 additions & 1 deletion plugins/database/postgresql/postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"
"errors"
"fmt"
"regexp"
"strings"
"time"

Expand All @@ -29,7 +30,22 @@ ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}';
`
)

var _ dbplugin.Database = &PostgreSQL{}
var (
	// Compile-time check that *PostgreSQL satisfies dbplugin.Database.
	_ dbplugin.Database = (*PostgreSQL)(nil)

	// singleQuotedPhrases matches the shortest run of characters enclosed
	// in single quotes, e.g. 'hello', and captures it with the quotes
	// included. The "." does not match a newline, so a quoted phrase has
	// to open and close on the same line to be found.
	singleQuotedPhrases = regexp.MustCompile(`('.*?')`)

	// doubleQuotedPhrases is the double-quote counterpart of
	// singleQuotedPhrases: it matches substrings such as "hello",
	// quotes included, under the same same-line restriction.
	doubleQuotedPhrases = regexp.MustCompile(`(".*?")`)

	// postgresEndStatement matches the upper-case keyword END only when
	// it stands alone as a word. The word boundaries on either side keep
	// it from matching inside longer words such as "APPEND".
	postgresEndStatement = regexp.MustCompile(`\bEND\b`)
)

// New implements builtinplugins.BuiltinFactory
func New() (interface{}, error) {
Expand Down Expand Up @@ -206,6 +222,20 @@ func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Stateme

// Execute each query
for _, stmt := range statements.Creation {
if containsMultilineStatement(stmt) {
// Execute it as-is.
m := map[string]string{
"name": username,
"username": username,
"password": password,
"expiration": expirationStr,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, stmt); err != nil {
return "", "", err
}
continue
}
// Otherwise, it's fine to split the statements on the semicolon.
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
Expand Down Expand Up @@ -501,3 +531,40 @@ func (p *PostgreSQL) RotateRootCredentials(ctx context.Context, statements []str
p.RawConfig["password"] = password
return p.RawConfig, nil
}

// containsMultilineStatement is a best effort to determine whether
// a particular statement is multiline, and therefore should not be
// split upon semicolons. If it's unsure, it defaults to false.
//
// The check looks for the keyword END outside of quoted text: every
// single- or double-quoted substring reported by extractQuotedStrings is
// removed from stmt first, and whatever remains is matched against
// postgresEndStatement. It returns false if the quoted substrings cannot
// be extracted.
func containsMultilineStatement(stmt string) bool {
	// Anything the user provided within single or double quotes is data,
	// not Postgres syntax, so it must not take part in the search for END.
	literals, err := extractQuotedStrings(stmt)
	if err != nil {
		return false
	}
	stmtWithoutLiterals := stmt
	for _, literal := range literals {
		// Strip from the running result rather than from the original
		// stmt; otherwise each iteration would discard the previous
		// removals and only the last literal would end up stripped.
		stmtWithoutLiterals = strings.Replace(stmtWithoutLiterals, literal, "", -1)
	}
	// Now look for the word "END" specifically. The pattern requires END
	// to stand alone as a word, so occurrences embedded in a longer word
	// (like "APPEND") are intentionally not treated as a match.
	return postgresEndStatement.MatchString(stmtWithoutLiterals)
}

// extractQuotedStrings extracts 0 or many substrings
// that have been single- or double-quoted. Ex:
// `"Hello", silly 'elephant' from the "zoo".`
// returns [ `Hello`, `'elephant'`, `"zoo"` ]
func extractQuotedStrings(s string) ([]string, error) {
var found []string
toFind := []*regexp.Regexp{
doubleQuotedPhrases,
singleQuotedPhrases,
}
for _, typeOfPhrase := range toFind {
found = append(found, typeOfPhrase.FindAllString(s, -1)...)
}
return found, nil
}
81 changes: 81 additions & 0 deletions plugins/database/postgresql/postgresql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,12 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{username}}";`,
},
},
"reproduce https://github.com/hashicorp/vault/issues/6098": {
tyrannosaurus-becks marked this conversation as resolved.
Show resolved Hide resolved
createStmts: []string{
// NOTE: "rolname" in the following line is not a typo.
"DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$",
tyrannosaurus-becks marked this conversation as resolved.
Show resolved Hide resolved
},
},
}

// Shared test container for speed - there should not be any overlap between the tests
Expand Down Expand Up @@ -192,6 +198,11 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
t.Fatalf("err: %s", err)
}

if name == "reproduce https://github.com/hashicorp/vault/issues/6098" {
// This test doesn't create creds, so we don't need to test them.
return
}

if err = testCredsExist(t, connURL, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
Expand Down Expand Up @@ -657,3 +668,73 @@ func createTestPGUser(t *testing.T, connURL string, username, password, query st
t.Fatal(err)
}
}

// TestContainsMultilineStatement feeds known statements to
// containsMultilineStatement and verifies whether each one is reported
// as containing a multiline statement.
func TestContainsMultilineStatement(t *testing.T) {
	scenarios := map[string]struct {
		stmt string
		want bool
	}{
		// A DO block whose body contains semicolons and ends with END.
		"issue 6098 repro": {
			stmt: `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$`,
			want: true,
		},
		// Two plain statements separated by a semicolon; no END keyword.
		"docs example": {
			stmt: `CREATE ROLE \"{{name}}\" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; \
GRANT SELECT ON ALL TABLES IN SCHEMA public TO \"{{name}}\";`,
			want: false,
		},
	}

	for label, sc := range scenarios {
		stmt, want := sc.stmt, sc.want
		t.Run(label, func(t *testing.T) {
			got := containsMultilineStatement(stmt)
			if got == want {
				return
			}
			t.Fatalf("%q should be %t for multiline input", stmt, want)
		})
	}
}

// TestExtractQuotedStrings verifies that extractQuotedStrings returns the
// quoted substrings of its input, quotes included, with the double-quoted
// matches ahead of the single-quoted ones.
func TestExtractQuotedStrings(t *testing.T) {
	type testCase struct {
		Input    string
		Expected []string
	}

	testCases := map[string]*testCase{
		"no quotes": {
			Input:    `Five little monkeys jumping on the bed`,
			Expected: []string{},
		},
		"two of both quote types": {
			Input:    `"Five" little 'monkeys' "jumping on" the' 'bed`,
			Expected: []string{`"Five"`, `"jumping on"`, `'monkeys'`, `' '`},
		},
		// An unterminated quote has no closing partner, so nothing matches.
		"one single quote": {
			Input:    `Five little monkeys 'jumping on the bed`,
			Expected: []string{},
		},
		"empty string": {
			Input:    ``,
			Expected: []string{},
		},
	}

	for tName, tCase := range testCases {
		t.Run(tName, func(t *testing.T) {
			results, err := extractQuotedStrings(tCase.Input)
			if err != nil {
				t.Fatal(err)
			}
			// Compare lengths first so that a nil result and an empty
			// expected slice are treated as equal, and so the indexing
			// below cannot go out of range.
			if len(results) != len(tCase.Expected) {
				t.Fatalf("%s isn't equal to %s", results, tCase.Expected)
			}
			for i := range results {
				if results[i] != tCase.Expected[i] {
					// Report the element that actually mismatched rather
					// than the whole expected slice.
					t.Fatalf(`expected %q but received %q`, tCase.Expected[i], results[i])
				}
			}
		})
	}
}