Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Postgres: Correct parsing of multiline statements #8512

Merged
merged 10 commits into from
Mar 17, 2020
69 changes: 68 additions & 1 deletion plugins/database/postgresql/postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"database/sql"
"errors"
"fmt"
"regexp"
"strings"
"time"

Expand All @@ -29,7 +30,22 @@ ALTER ROLE "{{username}}" WITH PASSWORD '{{password}}';
`
)

var _ dbplugin.Database = &PostgreSQL{}
// Compile-time check that *PostgreSQL satisfies dbplugin.Database.
var _ dbplugin.Database = &PostgreSQL{}

// Patterns used when deciding whether a statement may be split on
// semicolons. They are compiled once at package scope.
var (
	// postgresEndStatement matches the keyword END only when it stands
	// alone as a word, so that identifiers which merely contain those
	// letters (for example "APPEND") do not match. The match is
	// case-sensitive.
	postgresEndStatement = regexp.MustCompile(`\bEND\b`)

	// doubleQuotedPhrases matches the shortest run of text enclosed in
	// double quotes, such as "hello". The quotes are part of the match.
	doubleQuotedPhrases = regexp.MustCompile(`(".*?")`)

	// singleQuotedPhrases matches the shortest run of text enclosed in
	// single quotes, such as 'hello'. The quotes are part of the match.
	singleQuotedPhrases = regexp.MustCompile(`('.*?')`)
)

// New implements builtinplugins.BuiltinFactory
func New() (interface{}, error) {
Expand Down Expand Up @@ -206,6 +222,20 @@ func (p *PostgreSQL) CreateUser(ctx context.Context, statements dbplugin.Stateme

// Execute each query
for _, stmt := range statements.Creation {
if containsMultilineStatement(stmt) {
// Execute it as-is.
m := map[string]string{
"name": username,
"username": username,
"password": password,
"expiration": expirationStr,
}
if err := dbtxn.ExecuteTxQuery(ctx, tx, m, stmt); err != nil {
return "", "", err
}
continue
}
// Otherwise, it's fine to split the statements on the semicolon.
for _, query := range strutil.ParseArbitraryStringSlice(stmt, ";") {
query = strings.TrimSpace(query)
if len(query) == 0 {
Expand Down Expand Up @@ -501,3 +531,40 @@ func (p *PostgreSQL) RotateRootCredentials(ctx context.Context, statements []str
p.RawConfig["password"] = password
return p.RawConfig, nil
}

// containsMultilineStatement is a best effort to determine whether
// a particular statement is multiline, and therefore should not be
// split upon semicolons. If it's unsure, it defaults to false.
//
// The heuristic removes every single- or double-quoted phrase found by
// extractQuotedStrings from stmt and then reports whether the remaining
// text still matches postgresEndStatement. If the quoted phrases cannot
// be extracted, it returns false.
func containsMultilineStatement(stmt string) bool {
	// We're going to look for the word "END", but first let's ignore
	// anything the user provided within single or double quotes since
	// we're looking for an "END" within the Postgres syntax.
	literals, err := extractQuotedStrings(stmt)
	if err != nil {
		return false
	}

	stmtWithoutLiterals := stmt
	for _, literal := range literals {
		// Strip from the accumulated result rather than from stmt;
		// otherwise each pass would discard the previous removals and
		// only the last literal would end up being stripped.
		stmtWithoutLiterals = strings.Replace(stmtWithoutLiterals, literal, "", -1)
	}

	// Now look for the word "END" specifically. The pattern is
	// case-sensitive and requires word boundaries on both sides, so a
	// lowercase "end" or an END embedded in a longer identifier is not
	// detected; that should be easy to change on the user's side.
	return postgresEndStatement.MatchString(stmtWithoutLiterals)
}

// extractQuotedStrings collects every double-quoted and single-quoted
// phrase in s, keeping the surrounding quotes. All double-quoted
// matches come first, followed by all single-quoted matches. Ex:
// `"Hello", silly 'elephant' from the "zoo".`
// returns [ `"Hello"`, `"zoo"`, `'elephant'` ]
//
// The result is nil when nothing is quoted. The returned error is
// currently always nil.
func extractQuotedStrings(s string) ([]string, error) {
	var quoted []string
	for _, pattern := range []*regexp.Regexp{doubleQuotedPhrases, singleQuotedPhrases} {
		matches := pattern.FindAllString(s, -1)
		quoted = append(quoted, matches...)
	}
	return quoted, nil
}
106 changes: 105 additions & 1 deletion plugins/database/postgresql/postgresql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ func TestPostgreSQL_CreateUser_missingArgs(t *testing.T) {

func TestPostgreSQL_CreateUser(t *testing.T) {
type testCase struct {
createStmts []string
createStmts []string
shouldTestCredsExist bool
}

tests := map[string]testCase{
Expand All @@ -126,6 +127,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
VALID UNTIL '{{expiration}}';
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{name}}";`,
},
shouldTestCredsExist: true,
},
"admin username": {
createStmts: []string{`
Expand All @@ -135,6 +137,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
VALID UNTIL '{{expiration}}';
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO "{{username}}";`,
},
shouldTestCredsExist: true,
},
"read only name": {
createStmts: []string{`
Expand All @@ -145,6 +148,7 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{name}}";`,
},
shouldTestCredsExist: true,
},
"read only username": {
createStmts: []string{`
Expand All @@ -155,6 +159,23 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{username}}";
GRANT SELECT ON ALL SEQUENCES IN SCHEMA public TO "{{username}}";`,
},
shouldTestCredsExist: true,
},
// https://github.com/hashicorp/vault/issues/6098
"reproduce GH-6098": {
createStmts: []string{
// NOTE: "rolname" in the following line is not a typo.
"DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$",
tyrannosaurus-becks marked this conversation as resolved.
Show resolved Hide resolved
},
// This test statement doesn't generate creds.
shouldTestCredsExist: false,
},
"reproduce issue with template": {
createStmts: []string{
`DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE "{{username}}"; END IF; END $$`,
},
// This test statement doesn't generate creds.
shouldTestCredsExist: false,
},
}

Expand Down Expand Up @@ -192,6 +213,11 @@ func TestPostgreSQL_CreateUser(t *testing.T) {
t.Fatalf("err: %s", err)
}

if !test.shouldTestCredsExist {
// We're done here.
return
}

if err = testCredsExist(t, connURL, username, password); err != nil {
t.Fatalf("Could not connect with new credentials: %s", err)
}
Expand Down Expand Up @@ -657,3 +683,81 @@ func createTestPGUser(t *testing.T, connURL string, username, password, query st
t.Fatal(err)
}
}

// TestContainsMultilineStatement runs containsMultilineStatement over a
// table of statements and checks the reported result for each one.
func TestContainsMultilineStatement(t *testing.T) {
	type scenario struct {
		Input    string
		Expected bool
	}

	scenarios := map[string]scenario{
		// A DO block containing END keywords must be kept whole.
		"issue 6098 repro": {
			Input:    `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname='my_role') THEN CREATE ROLE my_role; END IF; END $$`,
			Expected: true,
		},
		// Same shape, but with template placeholders in the body.
		"multiline with template fields": {
			Input:    `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname="{{name}}") THEN CREATE ROLE {{name}}; END IF; END $$`,
			Expected: true,
		},
		// Plain semicolon-separated statements are not multiline.
		"docs example": {
			Input: `CREATE ROLE "{{name}}" WITH LOGIN PASSWORD '{{password}}' VALID UNTIL '{{expiration}}'; \
GRANT SELECT ON ALL TABLES IN SCHEMA public TO "{{name}}";`,
			Expected: false,
		},
	}

	for name, sc := range scenarios {
		t.Run(name, func(t *testing.T) {
			got := containsMultilineStatement(sc.Input)
			if got == sc.Expected {
				return
			}
			t.Fatalf("%q should be %t for multiline input", sc.Input, sc.Expected)
		})
	}
}

// TestExtractQuotedStrings checks that extractQuotedStrings returns the
// quoted phrases of each input, quotes included, with double-quoted
// matches listed before single-quoted ones.
func TestExtractQuotedStrings(t *testing.T) {
	type testCase struct {
		Input    string
		Expected []string
	}

	testCases := map[string]*testCase{
		"no quotes": {
			Input:    `Five little monkeys jumping on the bed`,
			Expected: []string{},
		},
		"two of both quote types": {
			Input:    `"Five" little 'monkeys' "jumping on" the' 'bed`,
			Expected: []string{`"Five"`, `"jumping on"`, `'monkeys'`, `' '`},
		},
		// An unterminated quote is not a phrase and yields nothing.
		"one single quote": {
			Input:    `Five little monkeys 'jumping on the bed`,
			Expected: []string{},
		},
		"empty string": {
			Input:    ``,
			Expected: []string{},
		},
		"templated field": {
			Input:    `DO $$ BEGIN IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname="{{name}}") THEN CREATE ROLE {{name}}; END IF; END $$`,
			Expected: []string{`"{{name}}"`},
		},
	}

	for tName, tCase := range testCases {
		t.Run(tName, func(t *testing.T) {
			results, err := extractQuotedStrings(tCase.Input)
			if err != nil {
				t.Fatal(err)
			}
			// Compare lengths first so the element loop below cannot
			// index past the end of Expected.
			if len(results) != len(tCase.Expected) {
				t.Fatalf("%s isn't equal to %s", results, tCase.Expected)
			}
			for i := range results {
				if results[i] != tCase.Expected[i] {
					// Report the single mismatching element, not the
					// whole expected slice.
					t.Fatalf(`expected %q but received %q`, tCase.Expected[i], results[i])
				}
			}
		})
	}
}