Fix regex parser for parsing functions having SQL body with language sql (PG 15 feature) #2201

Merged (5 commits) on Jan 22, 2025
Changes from 4 commits
12 changes: 10 additions & 2 deletions yb-voyager/cmd/analyzeSchema.go
@@ -134,6 +134,7 @@ var (
parserIssueDetector = queryissue.NewParserIssueDetector()
multiRegex = regexp.MustCompile(`([a-zA-Z0-9_\.]+[,|;])`)
dollarQuoteRegex = regexp.MustCompile(`(\$.*\$)`)
sqlBodyBeginRegex = re("BEGIN", "ATOMIC")
Collaborator

Should we put a ^ in the regex to ensure that BEGIN ATOMIC is at the start of the line?
Do we need to ensure that?

Contributor Author

I don't think we need to ensure that, since BEGIN ATOMIC can appear anywhere and not necessarily at the start of a line, e.g.:

CREATE FUNCTION public.asterisksdfsf(n integer) RETURNS SETOF text
    LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE BEGIN ATOMIC
 SELECT repeat('*'::text, g.g) AS repeat
    FROM generate_series(1, asterisksdfsf.n) g(g);
END;
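
The point above can be checked with a small Go sketch. It assumes that `re("BEGIN", "ATOMIC")` compiles to roughly a case-insensitive `BEGIN\s+ATOMIC` pattern. The actual output of the `re` helper is not shown in this PR, so both patterns below are illustrative stand-ins and the variable names are hypothetical.

```go
package main

import (
	"fmt"
	"regexp"
)

var (
	// Assumed approximation of re("BEGIN", "ATOMIC"): matches anywhere in the line.
	unanchoredBeginAtomic = regexp.MustCompile(`(?i)\bBEGIN\s+ATOMIC\b`)
	// The variant raised in review: matches only at the start of the line
	// (leading whitespace allowed).
	anchoredBeginAtomic = regexp.MustCompile(`(?i)^\s*BEGIN\s+ATOMIC\b`)
)

func main() {
	lines := []string{
		// BEGIN ATOMIC trailing the function options, as in the example above.
		"    LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE BEGIN ATOMIC",
		// BEGIN ATOMIC on its own line.
		"BEGIN ATOMIC",
	}
	for _, line := range lines {
		fmt.Printf("%q unanchored=%v anchored=%v\n", line,
			unanchoredBeginAtomic.MatchString(line),
			anchoredBeginAtomic.MatchString(line))
	}
}
```

Under these assumptions, the anchored pattern would not recognize the first line as the start of a SQL body, which is why the unanchored form is kept.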

//TODO: optional but replace every possible space or new line char with [\s\n]+ in all regexs
viewWithCheckRegex = re("VIEW", capture(ident), anything, "WITH", opt(commonClause), "CHECK", "OPTION")
rangeRegex = re("PRECEDING", "and", anything, ":float")
@@ -912,7 +913,6 @@ sqlParsingLoop:

stmt += currLine + " "
formattedStmt += currLine + "\n"

// Assuming that both the dollar quote strings will not be in same line
switch dollarQuoteFlag {
case CODE_BLOCK_NOT_STARTED:
@@ -921,9 +921,14 @@ sqlParsingLoop:
} else if matches := dollarQuoteRegex.FindStringSubmatch(currLine); matches != nil {
dollarQuoteFlag = 1 //denotes start of the code/body part
codeBlockDelimiter = matches[0]
} else if matches := sqlBodyBeginRegex.FindStringSubmatch(currLine); matches != nil {
dollarQuoteFlag = 1 //denotes start of the sql body part https://www.postgresql.org/docs/15/sql-createfunction.html#:~:text=a%20new%20session.-,sql_body,-The%20body%20of
codeBlockDelimiter = "END" //SQL body to determine the end of BEGIN ATOMIC ... END; sql body
}
case CODE_BLOCK_STARTED:
if strings.Contains(currLine, codeBlockDelimiter) {
if strings.Contains(currLine, codeBlockDelimiter) ||
strings.Contains(currLine, strings.ToLower(codeBlockDelimiter)) {
Collaborator

The concern here is that we don't know the variety of cases possible with plpgsql, since the body can contain almost anything. Here the condition is getting changed to match with either same or lowercase variant of it.

Should we handle the SQL-body function case separately (that is, with its own if condition here) and leave the existing check unchanged?
I just want to avoid any unknowns.

Contributor Author
@priyanshi-yb Jan 21, 2025

> Here the condition is getting changed to match with either same or lowercase variant of it.

The same-or-lowercase check applies only to the codeBlockDelimiter variable, which holds either the match of dollarQuoteRegex = regexp.MustCompile(`(\$.*\$)`) or the END string for the SQL body.

But okay, I get your point about unknowns with the lowercase variant in the dollarQuoteRegex case. I can handle that in a separate if condition for just the END case.

Contributor Author

done
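
The follow-up commit is not shown in this view, so the following is only a sketch of what a separate condition for the END case might look like. The helper `isCodeBlockEnd` and the constant `sqlBodyEndDelimiter` are hypothetical names introduced for illustration. Unlike the diff, which checks only the exact and all-lowercase forms, this sketch uppercases the line so that mixed-case `End` is also caught.

```go
package main

import (
	"fmt"
	"strings"
)

// sqlBodyEndDelimiter is a hypothetical constant for the delimiter that
// closes a BEGIN ATOMIC ... END; SQL body.
const sqlBodyEndDelimiter = "END"

// isCodeBlockEnd reports whether currLine closes the current code block.
// Only the SQL-body END delimiter is compared case-insensitively;
// dollar-quote delimiters keep the original exact-match behavior.
func isCodeBlockEnd(currLine, codeBlockDelimiter string) bool {
	if codeBlockDelimiter == sqlBodyEndDelimiter {
		return strings.Contains(strings.ToUpper(currLine), sqlBodyEndDelimiter)
	}
	return strings.Contains(currLine, codeBlockDelimiter)
}

func main() {
	fmt.Println(isCodeBlockEnd("end;", "END"))                       // SQL body closed in lowercase
	fmt.Println(isCodeBlockEnd("$$ LANGUAGE plpgsql;", "$$"))        // dollar-quoted body closed
	fmt.Println(isCodeBlockEnd("$body$ LANGUAGE plpgsql;", "$BODY$")) // case differs, so no match
}
```

Like the substring check in the diff, this still treats any line containing the letters END as the end of the SQL body (for example, a column named `pending`), which is the limitation the TODO about using pg-parser refers to.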

//TODO: anyways we should be using pg-parser: but for now for the END sql body delimiter checking the UPPER and LOWER both
dollarQuoteFlag = 2 //denotes end of code/body part
if isEndOfSqlStmt(currLine) {
break sqlParsingLoop
@@ -972,6 +977,9 @@ func isEndOfSqlStmt(line string) bool {
line = line[0:cmtStartIdx] // ignore comment
line = strings.TrimRight(line, " ")
}
if len(line) == 0 {
return false
}
return line[len(line)-1] == ';'
}
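
The new length check guards the final index expression: if the line is empty after the comment is stripped, `line[len(line)-1]` would panic with an index out of range. Only the tail of `isEndOfSqlStmt` is visible in this hunk, so the sketch below is a reconstruction under assumptions: it uses a hypothetical name, and it assumes the comment start is located with `strings.Index(line, "--")`, which may differ from the real code and does not account for `--` inside string literals.

```go
package main

import (
	"fmt"
	"strings"
)

// isEndOfSqlStmtSketch is a hypothetical stand-in for isEndOfSqlStmt.
// It reports whether a line, ignoring any trailing "--" comment,
// ends with a semicolon.
func isEndOfSqlStmtSketch(line string) bool {
	line = strings.TrimRight(line, " ")
	// Assumption: comments are detected by the first "--" in the line.
	if cmtStartIdx := strings.Index(line, "--"); cmtStartIdx != -1 {
		line = line[0:cmtStartIdx] // ignore comment
		line = strings.TrimRight(line, " ")
	}
	// Guard added in this PR: a blank or comment-only line is never
	// the end of a statement, and must not be indexed.
	if len(line) == 0 {
		return false
	}
	return line[len(line)-1] == ';'
}

func main() {
	for _, line := range []string{
		"END;",
		"SELECT 1; -- trailing comment",
		"-- Set a secure search_path: trusted schema(s), then 'pg_temp'.",
		"",
	} {
		fmt.Printf("%q -> %v\n", line, isEndOfSqlStmtSketch(line))
	}
}
```

A comment-only line inside a statement, like the one in the `check_password` test case, reduces to an empty string after stripping, which is the situation the guard handles.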

191 changes: 188 additions & 3 deletions yb-voyager/cmd/analyzeSchema_test.go
@@ -76,13 +76,12 @@ CREATE TABLE another_table (
t.Errorf("Error creating file for the objType %s: %v", objType, err)
}


defer os.Remove(sqlFile.Name())

sqlInfoArr := parseSqlFileForObjectType(sqlFile.Name(), objType)
// Validate the number of SQL statements found
if len(sqlInfoArr) != len(expectedSqlInfoArr) {
t.Errorf("Expected %d SQL statements for %s, got %d",len(expectedSqlInfoArr), objType, len(sqlInfoArr))
t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr))
}

for i, expectedSqlInfo := range expectedSqlInfoArr {
@@ -140,9 +139,195 @@ $$ LANGUAGE plpgsql;`,

// Validate the number of SQL statements found
if len(sqlInfoArr) != len(expectedSqlInfoArr) {
t.Errorf("Expected %d SQL statements for %s, got %d",len(expectedSqlInfoArr), objType, len(sqlInfoArr))
t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr))
}

for i, expectedSqlInfo := range expectedSqlInfoArr {
assert.Equal(t, expectedSqlInfo.objName, sqlInfoArr[i].objName)
assert.Equal(t, expectedSqlInfo.stmt, sqlInfoArr[i].stmt)
assert.Equal(t, expectedSqlInfo.formattedStmt, sqlInfoArr[i].formattedStmt)
}

}

func TestFunctionSQLFile(t *testing.T) {
functionFileContent := `CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
BEGIN ATOMIC
SELECT repeat('*'::text, g.g) AS repeat
FROM generate_series(1, asterisks.n) g(g);
END;

CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$
DECLARE
temp_salary employees.salary%TYPE;
BEGIN
CREATE TEMP TABLE temp_high_earners AS
SELECT * FROM employees WHERE salary > threshold;
FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP
RAISE NOTICE 'High earner salary: %', temp_salary;
END LOOP;
END;
$$ LANGUAGE plpgsql;

CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END;

CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE
BEGIN ATOMIC; SELECT $1 + $2; END;

CREATE FUNCTION public.case_sensitive_test(n integer) RETURNS SETOF text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
begin atomic
SELECT repeat('*'::text, g.g) AS repeat
FROM generate_series(1, asterisks.n) g(g);
end;

CREATE FUNCTION public.asterisks1(n integer) RETURNS text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
RETURN repeat('*'::text, n);

CREATE FUNCTION add(integer, integer) RETURNS integer
AS 'select test;'
LANGUAGE SQL
IMMUTABLE
RETURNS NULL ON NULL INPUT;

CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$
BEGIN
RETURN i + 1;
END;
$$ LANGUAGE plpgsql;

CREATE FUNCTION public.dup(integer, OUT f1 integer, OUT f2 text) RETURNS record
LANGUAGE sql
AS $_$ SELECT $1, CAST($1 AS text) || ' is text' $_$;

CREATE FUNCTION check_password(uname TEXT, pass TEXT)
RETURNS BOOLEAN AS $$
DECLARE passed BOOLEAN;
BEGIN
SELECT (pwd = $2) INTO passed
FROM pwds
WHERE username = $1;

RETURN passed;
END;
$$ LANGUAGE plpgsql
SECURITY DEFINER
-- Set a secure search_path: trusted schema(s), then 'pg_temp'.
SET search_path = admin, pg_temp;`

expectedSqlInfoArr := []sqlInfo{
sqlInfo{
objName: "public.asterisks",
stmt: "CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE BEGIN ATOMIC SELECT repeat('*'::text, g.g) AS repeat FROM generate_series(1, asterisks.n) g(g); END; ",
formattedStmt: `CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
BEGIN ATOMIC
SELECT repeat('*'::text, g.g) AS repeat
FROM generate_series(1, asterisks.n) g(g);
END;`,
},
sqlInfo{
objName: "copy_high_earners",
stmt: "CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$ DECLARE temp_salary employees.salary%TYPE; BEGIN CREATE TEMP TABLE temp_high_earners AS SELECT * FROM employees WHERE salary > threshold; FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP RAISE NOTICE 'High earner salary: %', temp_salary; END LOOP; END; $$ LANGUAGE plpgsql; ",
formattedStmt: `CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$
DECLARE
temp_salary employees.salary%TYPE;
BEGIN
CREATE TEMP TABLE temp_high_earners AS
SELECT * FROM employees WHERE salary > threshold;
FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP
RAISE NOTICE 'High earner salary: %', temp_salary;
END LOOP;
END;
$$ LANGUAGE plpgsql;`,
},
sqlInfo{
objName: "add",
stmt: "CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END; ",
formattedStmt: `CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END;`,
},
sqlInfo{
objName: "add",
stmt: "CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END; ",
formattedStmt: "CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE\nBEGIN ATOMIC; SELECT $1 + $2; END;",
},
sqlInfo{
objName: "public.case_sensitive_test",
stmt: "CREATE FUNCTION public.case_sensitive_test(n integer) RETURNS SETOF text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE begin atomic SELECT repeat('*'::text, g.g) AS repeat FROM generate_series(1, asterisks.n) g(g); end; ",
formattedStmt: `CREATE FUNCTION public.case_sensitive_test(n integer) RETURNS SETOF text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
begin atomic
SELECT repeat('*'::text, g.g) AS repeat
FROM generate_series(1, asterisks.n) g(g);
end;`,
},
sqlInfo{
objName: "public.asterisks1",
stmt: "CREATE FUNCTION public.asterisks1(n integer) RETURNS text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE RETURN repeat('*'::text, n); ",
formattedStmt: `CREATE FUNCTION public.asterisks1(n integer) RETURNS text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
RETURN repeat('*'::text, n);`,
},
sqlInfo{
objName: "add",
stmt: "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select test;' LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT; ",
formattedStmt: `CREATE FUNCTION add(integer, integer) RETURNS integer
AS 'select test;'
LANGUAGE SQL
IMMUTABLE
RETURNS NULL ON NULL INPUT;`,
},
sqlInfo{
objName: "increment",
stmt: "CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$ BEGIN RETURN i + 1; END; $$ LANGUAGE plpgsql; ",
formattedStmt: `CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$
BEGIN
RETURN i + 1;
END;
$$ LANGUAGE plpgsql;`,
},
sqlInfo{
objName: "public.dup",
stmt: "CREATE FUNCTION public.dup(integer, OUT f1 integer, OUT f2 text) RETURNS record LANGUAGE sql AS $_$ SELECT $1, CAST($1 AS text) || ' is text' $_$; ",
formattedStmt: `CREATE FUNCTION public.dup(integer, OUT f1 integer, OUT f2 text) RETURNS record
LANGUAGE sql
AS $_$ SELECT $1, CAST($1 AS text) || ' is text' $_$;`,
},
sqlInfo{
objName: "check_password",
stmt: "CREATE FUNCTION check_password(uname TEXT, pass TEXT) RETURNS BOOLEAN AS $$ DECLARE passed BOOLEAN; BEGIN SELECT (pwd = $2) INTO passed FROM pwds WHERE username = $1; RETURN passed; END; $$ LANGUAGE plpgsql SECURITY DEFINER -- Set a secure search_path: trusted schema(s), then 'pg_temp'. SET search_path = admin, pg_temp; ",
formattedStmt: `CREATE FUNCTION check_password(uname TEXT, pass TEXT)
RETURNS BOOLEAN AS $$
DECLARE passed BOOLEAN;
BEGIN
SELECT (pwd = $2) INTO passed
FROM pwds
WHERE username = $1;
RETURN passed;
END;
$$ LANGUAGE plpgsql
SECURITY DEFINER
-- Set a secure search_path: trusted schema(s), then 'pg_temp'.
SET search_path = admin, pg_temp;`,
},
}
objType := "FUNCTION"
sqlFile, err := setupFile(objType, functionFileContent)
if err != nil {
t.Errorf("Error creating file for the objType %s: %v", objType, err)
}

defer os.Remove(sqlFile.Name())

sqlInfoArr := parseSqlFileForObjectType(sqlFile.Name(), objType)

// Validate the number of SQL statements found
if len(sqlInfoArr) != len(expectedSqlInfoArr) {
t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr))
}

for i, expectedSqlInfo := range expectedSqlInfoArr {
assert.Equal(t, expectedSqlInfo.objName, sqlInfoArr[i].objName)
assert.Equal(t, expectedSqlInfo.stmt, sqlInfoArr[i].stmt)