Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix regex parser for parsing functions having SQL body with language sql (PG 15 feature) #2201

Merged
merged 5 commits into from
Jan 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions yb-voyager/cmd/analyzeSchema.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ var (
parserIssueDetector = queryissue.NewParserIssueDetector()
multiRegex = regexp.MustCompile(`([a-zA-Z0-9_\.]+[,|;])`)
dollarQuoteRegex = regexp.MustCompile(`(\$.*\$)`)
sqlBodyBeginRegex = re("BEGIN", "ATOMIC")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we put a ^ in the regex to ensure that BEGIN ATOMIC is at the start of line?
Do we need to ensure that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need to ensure that, as BEGIN ATOMIC can appear anywhere in the line, not necessarily at the start, e.g.

CREATE FUNCTION public.asterisksdfsf(n integer) RETURNS SETOF text
    LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE BEGIN ATOMIC
 SELECT repeat('*'::text, g.g) AS repeat
    FROM generate_series(1, asterisksdfsf.n) g(g);
END;

//TODO: optional but replace every possible space or new line char with [\s\n]+ in all regexs
viewWithCheckRegex = re("VIEW", capture(ident), anything, "WITH", opt(commonClause), "CHECK", "OPTION")
rangeRegex = re("PRECEDING", "and", anything, ":float")
Expand Down Expand Up @@ -912,7 +913,6 @@ sqlParsingLoop:

stmt += currLine + " "
formattedStmt += currLine + "\n"

// Assuming that both the dollar quote strings will not be in same line
switch dollarQuoteFlag {
case CODE_BLOCK_NOT_STARTED:
Expand All @@ -921,14 +921,30 @@ sqlParsingLoop:
} else if matches := dollarQuoteRegex.FindStringSubmatch(currLine); matches != nil {
dollarQuoteFlag = 1 //denotes start of the code/body part
codeBlockDelimiter = matches[0]
} else if matches := sqlBodyBeginRegex.FindStringSubmatch(currLine); matches != nil {
dollarQuoteFlag = 1 //denotes start of the sql body part https://www.postgresql.org/docs/15/sql-createfunction.html#:~:text=a%20new%20session.-,sql_body,-The%20body%20of
codeBlockDelimiter = "END" //SQL body to determine the end of BEGIN ATOMIC ... END; sql body
}
case CODE_BLOCK_STARTED:
if strings.Contains(currLine, codeBlockDelimiter) {
dollarQuoteFlag = 2 //denotes end of code/body part
if isEndOfSqlStmt(currLine) {
break sqlParsingLoop
switch codeBlockDelimiter {
case "END":
if strings.Contains(currLine, codeBlockDelimiter) ||
strings.Contains(currLine, strings.ToLower(codeBlockDelimiter)) {
Comment on lines +931 to +932
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: can we use strings.EqualFold() here to compare both without caring about casing (this would also cover the mixed-case scenario)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strings.EqualFold() checks whether two strings are equal ignoring case, but here we need to check whether the line contains the delimiter as a substring, so we can't use it.

//TODO: anyways we should be using pg-parser: but for now for the END sql body delimiter checking the UPPER and LOWER both
dollarQuoteFlag = 2 //denotes end of code/body part
if isEndOfSqlStmt(currLine) {
break sqlParsingLoop
}
}
default:
if strings.Contains(currLine, codeBlockDelimiter) {
dollarQuoteFlag = 2 //denotes end of code/body part
if isEndOfSqlStmt(currLine) {
break sqlParsingLoop
}
}
}

case CODE_BLOCK_COMPLETED:
if isEndOfSqlStmt(currLine) {
break sqlParsingLoop
Expand Down Expand Up @@ -972,6 +988,9 @@ func isEndOfSqlStmt(line string) bool {
line = line[0:cmtStartIdx] // ignore comment
line = strings.TrimRight(line, " ")
}
if len(line) == 0 {
return false
}
return line[len(line)-1] == ';'
}

Expand Down
191 changes: 188 additions & 3 deletions yb-voyager/cmd/analyzeSchema_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,12 @@ CREATE TABLE another_table (
t.Errorf("Error creating file for the objType %s: %v", objType, err)
}


defer os.Remove(sqlFile.Name())

sqlInfoArr := parseSqlFileForObjectType(sqlFile.Name(), objType)
// Validate the number of SQL statements found
if len(sqlInfoArr) != len(expectedSqlInfoArr) {
t.Errorf("Expected %d SQL statements for %s, got %d",len(expectedSqlInfoArr), objType, len(sqlInfoArr))
t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr))
}

for i, expectedSqlInfo := range expectedSqlInfoArr {
Expand Down Expand Up @@ -140,9 +139,195 @@ $$ LANGUAGE plpgsql;`,

// Validate the number of SQL statements found
if len(sqlInfoArr) != len(expectedSqlInfoArr) {
t.Errorf("Expected %d SQL statements for %s, got %d",len(expectedSqlInfoArr), objType, len(sqlInfoArr))
t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr))
}

for i, expectedSqlInfo := range expectedSqlInfoArr {
assert.Equal(t, expectedSqlInfo.objName, sqlInfoArr[i].objName)
assert.Equal(t, expectedSqlInfo.stmt, sqlInfoArr[i].stmt)
assert.Equal(t, expectedSqlInfo.formattedStmt, sqlInfoArr[i].formattedStmt)
}

}

func TestFunctionSQLFile(t *testing.T) {
functionFileContent := `CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
BEGIN ATOMIC
SELECT repeat('*'::text, g.g) AS repeat
FROM generate_series(1, asterisks.n) g(g);
END;

CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$
DECLARE
temp_salary employees.salary%TYPE;
BEGIN
CREATE TEMP TABLE temp_high_earners AS
SELECT * FROM employees WHERE salary > threshold;
FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP
RAISE NOTICE 'High earner salary: %', temp_salary;
END LOOP;
END;
$$ LANGUAGE plpgsql;

CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END;

CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE
BEGIN ATOMIC; SELECT $1 + $2; END;

CREATE FUNCTION public.case_sensitive_test(n integer) RETURNS SETOF text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
begin atomic
SELECT repeat('*'::text, g.g) AS repeat
FROM generate_series(1, asterisks.n) g(g);
end;

CREATE FUNCTION public.asterisks1(n integer) RETURNS text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
RETURN repeat('*'::text, n);

CREATE FUNCTION add(integer, integer) RETURNS integer
AS 'select test;'
LANGUAGE SQL
IMMUTABLE
RETURNS NULL ON NULL INPUT;

CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$
BEGIN
RETURN i + 1;
END;
$$ LANGUAGE plpgsql;

CREATE FUNCTION public.dup(integer, OUT f1 integer, OUT f2 text) RETURNS record
LANGUAGE sql
AS $_$ SELECT $1, CAST($1 AS text) || ' is text' $_$;

CREATE FUNCTION check_password(uname TEXT, pass TEXT)
RETURNS BOOLEAN AS $$
DECLARE passed BOOLEAN;
BEGIN
SELECT (pwd = $2) INTO passed
FROM pwds
WHERE username = $1;

RETURN passed;
END;
$$ LANGUAGE plpgsql
SECURITY DEFINER
-- Set a secure search_path: trusted schema(s), then 'pg_temp'.
SET search_path = admin, pg_temp;`

expectedSqlInfoArr := []sqlInfo{
sqlInfo{
objName: "public.asterisks",
stmt: "CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE BEGIN ATOMIC SELECT repeat('*'::text, g.g) AS repeat FROM generate_series(1, asterisks.n) g(g); END; ",
formattedStmt: `CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
BEGIN ATOMIC
SELECT repeat('*'::text, g.g) AS repeat
FROM generate_series(1, asterisks.n) g(g);
END;`,
},
sqlInfo{
objName: "copy_high_earners",
stmt: "CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$ DECLARE temp_salary employees.salary%TYPE; BEGIN CREATE TEMP TABLE temp_high_earners AS SELECT * FROM employees WHERE salary > threshold; FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP RAISE NOTICE 'High earner salary: %', temp_salary; END LOOP; END; $$ LANGUAGE plpgsql; ",
formattedStmt: `CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$
DECLARE
temp_salary employees.salary%TYPE;
BEGIN
CREATE TEMP TABLE temp_high_earners AS
SELECT * FROM employees WHERE salary > threshold;
FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP
RAISE NOTICE 'High earner salary: %', temp_salary;
END LOOP;
END;
$$ LANGUAGE plpgsql;`,
},
sqlInfo{
objName: "add",
stmt: "CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END; ",
formattedStmt: `CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END;`,
},
sqlInfo{
objName: "add",
stmt: "CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END; ",
formattedStmt: "CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE\nBEGIN ATOMIC; SELECT $1 + $2; END;",
},
sqlInfo{
objName: "public.case_sensitive_test",
stmt: "CREATE FUNCTION public.case_sensitive_test(n integer) RETURNS SETOF text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE begin atomic SELECT repeat('*'::text, g.g) AS repeat FROM generate_series(1, asterisks.n) g(g); end; ",
formattedStmt: `CREATE FUNCTION public.case_sensitive_test(n integer) RETURNS SETOF text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
begin atomic
SELECT repeat('*'::text, g.g) AS repeat
FROM generate_series(1, asterisks.n) g(g);
end;`,
},
sqlInfo{
objName: "public.asterisks1",
stmt: "CREATE FUNCTION public.asterisks1(n integer) RETURNS text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE RETURN repeat('*'::text, n); ",
formattedStmt: `CREATE FUNCTION public.asterisks1(n integer) RETURNS text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
RETURN repeat('*'::text, n);`,
},
sqlInfo{
objName: "add",
stmt: "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select test;' LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT; ",
formattedStmt: `CREATE FUNCTION add(integer, integer) RETURNS integer
AS 'select test;'
LANGUAGE SQL
IMMUTABLE
RETURNS NULL ON NULL INPUT;`,
},
sqlInfo{
objName: "increment",
stmt: "CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$ BEGIN RETURN i + 1; END; $$ LANGUAGE plpgsql; ",
formattedStmt: `CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$
BEGIN
RETURN i + 1;
END;
$$ LANGUAGE plpgsql;`,
},
sqlInfo{
objName: "public.dup",
stmt: "CREATE FUNCTION public.dup(integer, OUT f1 integer, OUT f2 text) RETURNS record LANGUAGE sql AS $_$ SELECT $1, CAST($1 AS text) || ' is text' $_$; ",
formattedStmt: `CREATE FUNCTION public.dup(integer, OUT f1 integer, OUT f2 text) RETURNS record
LANGUAGE sql
AS $_$ SELECT $1, CAST($1 AS text) || ' is text' $_$;`,
},
sqlInfo{
objName: "check_password",
stmt: "CREATE FUNCTION check_password(uname TEXT, pass TEXT) RETURNS BOOLEAN AS $$ DECLARE passed BOOLEAN; BEGIN SELECT (pwd = $2) INTO passed FROM pwds WHERE username = $1; RETURN passed; END; $$ LANGUAGE plpgsql SECURITY DEFINER -- Set a secure search_path: trusted schema(s), then 'pg_temp'. SET search_path = admin, pg_temp; ",
formattedStmt: `CREATE FUNCTION check_password(uname TEXT, pass TEXT)
RETURNS BOOLEAN AS $$
DECLARE passed BOOLEAN;
BEGIN
SELECT (pwd = $2) INTO passed
FROM pwds
WHERE username = $1;
RETURN passed;
END;
$$ LANGUAGE plpgsql
SECURITY DEFINER
-- Set a secure search_path: trusted schema(s), then 'pg_temp'.
SET search_path = admin, pg_temp;`,
},
}
objType := "FUNCTION"
sqlFile, err := setupFile(objType, functionFileContent)
if err != nil {
t.Errorf("Error creating file for the objType %s: %v", objType, err)
}

defer os.Remove(sqlFile.Name())

sqlInfoArr := parseSqlFileForObjectType(sqlFile.Name(), objType)

// Validate the number of SQL statements found
if len(sqlInfoArr) != len(expectedSqlInfoArr) {
t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr))
}

for i, expectedSqlInfo := range expectedSqlInfoArr {
assert.Equal(t, expectedSqlInfo.objName, sqlInfoArr[i].objName)
assert.Equal(t, expectedSqlInfo.stmt, sqlInfoArr[i].stmt)
Expand Down