Fix regex parser for parsing functions having SQL body with language sql (PG 15 feature) #2201

Merged
merged 5 commits into from
Jan 22, 2025
3 changes: 3 additions & 0 deletions yb-voyager/cmd/analyzeSchema.go
@@ -921,6 +921,9 @@ sqlParsingLoop:
} else if matches := dollarQuoteRegex.FindStringSubmatch(currLine); matches != nil {
dollarQuoteFlag = 1 //denotes start of the code/body part
codeBlockDelimiter = matches[0]
} else if strings.Contains(currLine, "BEGIN ATOMIC") {
dollarQuoteFlag = 1 //denotes start of the sql body part https://www.postgresql.org/docs/15/sql-createfunction.html#:~:text=a%20new%20session.-,sql_body,-The%20body%20of
codeBlockDelimiter = "END"
@sanyamsinghal (Collaborator), Jan 20, 2025

I think the current code does not cover the following cases:

  1. BEGIN ATOMIC and END may be on the same line, on different lines, or split so that part of the body is on one line and the rest on the next. For example:
CREATE OR REPLACE FUNCTION subtract_and_log(a INT, b INT)
RETURNS INT
LANGUAGE SQL
BEGIN ATOMIC INSERT INTO test_log(action) VALUES ('Subtracting numbers'); RETURN a - b; END;
  2. BEGIN ATOMIC ... END; may be written in different casing (lowercase or mixed case).

  3. Using a regexp would be the right way to go here, given the number of possible variations.
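
As a rough illustration of the regexp suggestion, here is a minimal sketch of a case-insensitive check for the opening keywords. The names `beginAtomicRegex` and `hasBeginAtomic` are hypothetical and are not taken from the yb-voyager code; the check is line-based and would also match the keywords inside a string literal or comment.

```go
package main

import (
	"fmt"
	"regexp"
)

// beginAtomicRegex (hypothetical name) matches BEGIN ATOMIC in any casing,
// with any run of whitespace between the two keywords.
var beginAtomicRegex = regexp.MustCompile(`(?i)\bBEGIN\s+ATOMIC\b`)

// hasBeginAtomic reports whether the line contains the opening keywords
// of a SQL-standard function body.
func hasBeginAtomic(line string) bool {
	return beginAtomicRegex.MatchString(line)
}

func main() {
	lines := []string{
		"BEGIN ATOMIC",
		"begin atomic INSERT INTO test_log(action) VALUES ('x'); RETURN a - b; END;",
		"Begin   Atomic",
		"BEGIN", // plain PL/pgSQL block opener, not a SQL body
	}
	for _, line := range lines {
		fmt.Printf("%-5v %q\n", hasBeginAtomic(line), line)
	}
}
```

The `(?i)` flag handles the casing case, and `\s+` tolerates extra spaces or tabs between the keywords.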

@priyanshi-yb (Contributor Author), Jan 21, 2025

> There are chances of BEGIN ATOMIC or END being on same line, different line, something on same line and rest on the next line.

This already works with the current code; I added a unit test case for it.

> Possibility of different casing (lowercase, mixed case) used for BEGIN ATOMIC ... END;

Fixed the casing by matching BEGIN ATOMIC with a regexp, while END is matched as a plain string in both lowercase and uppercase.

I am not trying to cover a lot of cases, given that we should use the pg-parser anyway, but this is a stopgap to at least detect the SQL body case for this PR #2165
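
To make the trade-off for END concrete, the sketch below compares a plain substring check in both casings with a whole-word, case-insensitive regexp. It assumes the substring check looks roughly like `containsEndSubstring`; all names here are hypothetical and not taken from the PR.

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// endKeywordRe (hypothetical name) matches END only as a whole word,
// in any casing.
var endKeywordRe = regexp.MustCompile(`(?i)\bEND\b`)

// containsEndSubstring mimics a plain substring check against the
// uppercase and lowercase spellings only.
func containsEndSubstring(line string) bool {
	return strings.Contains(line, "END") || strings.Contains(line, "end")
}

// containsEndKeyword reports whether END appears as a standalone keyword.
func containsEndKeyword(line string) bool {
	return endKeywordRe.MatchString(line)
}

func main() {
	lines := []string{
		"END;",
		"End;", // mixed case: missed by the substring check
		"UPDATE orders SET pending = false;", // "end" inside an identifier
	}
	for _, line := range lines {
		fmt.Printf("substring=%-5v keyword=%-5v %q\n",
			containsEndSubstring(line), containsEndKeyword(line), line)
	}
}
```

Neither variant can tell the closing END of the body apart from the END of a CASE expression inside it; that would need real parsing.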

@sanyamsinghal (Collaborator), Jan 21, 2025

I doubt we will have pg-parser soon for splitting the SQL file into SQL statements.
Also, for Oracle cases where the syntax is off (not fully converted to PG), the parser will fail, so we could use either the current approach or a mix of both.

I would say let's cover most of it if we can with small effort, since this SQL body construct is part of DDL and our import schema can also break.

@priyanshi-yb (Contributor Author), Jan 21, 2025

For the SQL body, I have tried to cover most of the cases in the PG docs https://www.postgresql.org/docs/11/sql-createfunction.html (and fixed one panic in the IsEndOFSQL function). The only TODO is that I am using ToLower and END as the code block delimiter for the SQL body, where we could use a regex, but that should be okay, I think.
Do you have any other case in mind that I might be missing?
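
The cases discussed in this thread (body on one line, body spread over several lines, and the RETURN form with no BEGIN ATOMIC at all) can be sketched with a small line-based splitter. This is an illustrative simplification, not the project's `parseSqlFileForObjectType`; `splitStatements` and both regexps are hypothetical, and dollar-quoted bodies are deliberately left out.

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

var (
	// Opening keywords of a SQL-standard body, any casing.
	beginAtomicRe = regexp.MustCompile(`(?i)\bBEGIN\s+ATOMIC\b`)
	// END followed by a semicolon at the end of the inspected text.
	endStmtRe = regexp.MustCompile(`(?i)\bEND\s*;\s*$`)
)

// splitStatements (hypothetical helper) joins lines into statements and
// treats BEGIN ATOMIC ... END; as part of a single statement.
func splitStatements(sql string) []string {
	var stmts, cur []string
	inBody := false
	for _, line := range strings.Split(sql, "\n") {
		trimmed := strings.TrimSpace(line)
		if trimmed == "" {
			continue
		}
		cur = append(cur, trimmed)
		rest := trimmed
		if !inBody {
			if loc := beginAtomicRe.FindStringIndex(trimmed); loc != nil {
				inBody = true
				rest = trimmed[loc[1]:] // look for END only after BEGIN ATOMIC
			}
		}
		done := false
		if inBody {
			done = endStmtRe.MatchString(rest)
		} else {
			done = strings.HasSuffix(trimmed, ";")
		}
		if done {
			stmts = append(stmts, strings.Join(cur, " "))
			cur, inBody = nil, false
		}
	}
	if len(cur) > 0 { // keep any unterminated trailing text
		stmts = append(stmts, strings.Join(cur, " "))
	}
	return stmts
}

func main() {
	sql := `CREATE FUNCTION f(n int) RETURNS int
LANGUAGE sql
begin atomic
  SELECT n + 1;
END;

CREATE FUNCTION add(int, int) RETURNS int BEGIN ATOMIC; SELECT $1 + $2; END;

CREATE FUNCTION g(n int) RETURNS int
RETURN n;`
	for i, stmt := range splitStatements(sql) {
		fmt.Printf("%d: %s\n", i+1, stmt)
	}
}
```

A body line that ends with a CASE expression (`... ELSE 2 END;`) would close the statement too early in this sketch, which is the kind of case a real parser would handle.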

}
case CODE_BLOCK_STARTED:
if strings.Contains(currLine, codeBlockDelimiter) {
93 changes: 90 additions & 3 deletions yb-voyager/cmd/analyzeSchema_test.go
@@ -76,13 +76,12 @@ CREATE TABLE another_table (
t.Errorf("Error creating file for the objType %s: %v", objType, err)
}


defer os.Remove(sqlFile.Name())

sqlInfoArr := parseSqlFileForObjectType(sqlFile.Name(), objType)
// Validate the number of SQL statements found
if len(sqlInfoArr) != len(expectedSqlInfoArr) {
t.Errorf("Expected %d SQL statements for %s, got %d",len(expectedSqlInfoArr), objType, len(sqlInfoArr))
t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr))
}

for i, expectedSqlInfo := range expectedSqlInfoArr {
@@ -140,9 +139,97 @@ $$ LANGUAGE plpgsql;`,

// Validate the number of SQL statements found
if len(sqlInfoArr) != len(expectedSqlInfoArr) {
t.Errorf("Expected %d SQL statements for %s, got %d",len(expectedSqlInfoArr), objType, len(sqlInfoArr))
t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr))
}

for i, expectedSqlInfo := range expectedSqlInfoArr {
assert.Equal(t, expectedSqlInfo.objName, sqlInfoArr[i].objName)
assert.Equal(t, expectedSqlInfo.stmt, sqlInfoArr[i].stmt)
assert.Equal(t, expectedSqlInfo.formattedStmt, sqlInfoArr[i].formattedStmt)
}

}

func TestFunctionSQLFile(t *testing.T) {
functionFileContent := `CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
BEGIN ATOMIC
SELECT repeat('*'::text, g.g) AS repeat
FROM generate_series(1, asterisks.n) g(g);
END;

CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$
DECLARE
temp_salary employees.salary%TYPE;
BEGIN
CREATE TEMP TABLE temp_high_earners AS
SELECT * FROM employees WHERE salary > threshold;
FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP
RAISE NOTICE 'High earner salary: %', temp_salary;
END LOOP;
END;
$$ LANGUAGE plpgsql;

CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END;

CREATE FUNCTION public.asterisks1(n integer) RETURNS text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
RETURN repeat('*'::text, n);`

expectedSqlInfoArr := []sqlInfo{
sqlInfo{
objName: "public.asterisks",
stmt: "CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE BEGIN ATOMIC SELECT repeat('*'::text, g.g) AS repeat FROM generate_series(1, asterisks.n) g(g); END; ",
formattedStmt: `CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
BEGIN ATOMIC
SELECT repeat('*'::text, g.g) AS repeat
FROM generate_series(1, asterisks.n) g(g);
END;`,
},
sqlInfo{
objName: "copy_high_earners",
stmt: "CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$ DECLARE temp_salary employees.salary%TYPE; BEGIN CREATE TEMP TABLE temp_high_earners AS SELECT * FROM employees WHERE salary > threshold; FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP RAISE NOTICE 'High earner salary: %', temp_salary; END LOOP; END; $$ LANGUAGE plpgsql; ",
formattedStmt: `CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$
DECLARE
temp_salary employees.salary%TYPE;
BEGIN
CREATE TEMP TABLE temp_high_earners AS
SELECT * FROM employees WHERE salary > threshold;
FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP
RAISE NOTICE 'High earner salary: %', temp_salary;
END LOOP;
END;
$$ LANGUAGE plpgsql;`,
},
sqlInfo{
objName: "add",
stmt: "CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END; ",
formattedStmt: `CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END;`,
},
sqlInfo{
objName: "public.asterisks1",
stmt: "CREATE FUNCTION public.asterisks1(n integer) RETURNS text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE RETURN repeat('*'::text, n); ",
formattedStmt: `CREATE FUNCTION public.asterisks1(n integer) RETURNS text
LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE
RETURN repeat('*'::text, n);`,
},
}
objType := "FUNCTION"
sqlFile, err := setupFile(objType, functionFileContent)
if err != nil {
t.Errorf("Error creating file for the objType %s: %v", objType, err)
}

defer os.Remove(sqlFile.Name())

sqlInfoArr := parseSqlFileForObjectType(sqlFile.Name(), objType)

// Validate the number of SQL statements found
if len(sqlInfoArr) != len(expectedSqlInfoArr) {
t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr))
}

for i, expectedSqlInfo := range expectedSqlInfoArr {
assert.Equal(t, expectedSqlInfo.objName, sqlInfoArr[i].objName)
assert.Equal(t, expectedSqlInfo.stmt, sqlInfoArr[i].stmt)