From 860e99ab79310019904df32f719d6ed25b930b61 Mon Sep 17 00:00:00 2001 From: priyanshi-yb Date: Thu, 16 Jan 2025 18:30:25 +0530 Subject: [PATCH 1/5] Fix regex parser issue for parsing functions having SQL body with language sql (PG15 feature) --- yb-voyager/cmd/analyzeSchema.go | 3 + yb-voyager/cmd/analyzeSchema_test.go | 93 +++++++++++++++++++++++++++- 2 files changed, 93 insertions(+), 3 deletions(-) diff --git a/yb-voyager/cmd/analyzeSchema.go b/yb-voyager/cmd/analyzeSchema.go index fe25825e6..6bd80e393 100644 --- a/yb-voyager/cmd/analyzeSchema.go +++ b/yb-voyager/cmd/analyzeSchema.go @@ -921,6 +921,9 @@ sqlParsingLoop: } else if matches := dollarQuoteRegex.FindStringSubmatch(currLine); matches != nil { dollarQuoteFlag = 1 //denotes start of the code/body part codeBlockDelimiter = matches[0] + } else if strings.Contains(currLine, "BEGIN ATOMIC") { + dollarQuoteFlag = 1 //denotes start of the sql body part https://www.postgresql.org/docs/15/sql-createfunction.html#:~:text=a%20new%20session.-,sql_body,-The%20body%20of + codeBlockDelimiter = "END" } case CODE_BLOCK_STARTED: if strings.Contains(currLine, codeBlockDelimiter) { diff --git a/yb-voyager/cmd/analyzeSchema_test.go b/yb-voyager/cmd/analyzeSchema_test.go index d768b94b0..670014c5a 100644 --- a/yb-voyager/cmd/analyzeSchema_test.go +++ b/yb-voyager/cmd/analyzeSchema_test.go @@ -76,13 +76,12 @@ CREATE TABLE another_table ( t.Errorf("Error creating file for the objType %s: %v", objType, err) } - defer os.Remove(sqlFile.Name()) sqlInfoArr := parseSqlFileForObjectType(sqlFile.Name(), objType) // Validate the number of SQL statements found if len(sqlInfoArr) != len(expectedSqlInfoArr) { - t.Errorf("Expected %d SQL statements for %s, got %d",len(expectedSqlInfoArr), objType, len(sqlInfoArr)) + t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr)) } for i, expectedSqlInfo := range expectedSqlInfoArr { @@ -140,9 +139,97 @@ $$ LANGUAGE plpgsql;`, // Validate the number of SQL statements found if len(sqlInfoArr) != len(expectedSqlInfoArr) { - t.Errorf("Expected %d SQL statements for %s, got %d",len(expectedSqlInfoArr), objType, len(sqlInfoArr)) + t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr)) + } + + for i, expectedSqlInfo := range expectedSqlInfoArr { + assert.Equal(t, expectedSqlInfo.objName, sqlInfoArr[i].objName) + assert.Equal(t, expectedSqlInfo.stmt, sqlInfoArr[i].stmt) + assert.Equal(t, expectedSqlInfo.formattedStmt, sqlInfoArr[i].formattedStmt) } +} + +func TestFunctionSQLFile(t *testing.T) { + functionFileContent := `CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text + LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE + BEGIN ATOMIC + SELECT repeat('*'::text, g.g) AS repeat + FROM generate_series(1, asterisks.n) g(g); +END; + +CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$ +DECLARE + temp_salary employees.salary%TYPE; +BEGIN + CREATE TEMP TABLE temp_high_earners AS + SELECT * FROM employees WHERE salary > threshold; + FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP + RAISE NOTICE 'High earner salary: %', temp_salary; + END LOOP; +END; +$$ LANGUAGE plpgsql; + +CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END; + +CREATE FUNCTION public.asterisks1(n integer) RETURNS text + LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE + RETURN repeat('*'::text, n);` + + expectedSqlInfoArr := []sqlInfo{ + sqlInfo{ + objName: "public.asterisks", + stmt: "CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE BEGIN ATOMIC SELECT repeat('*'::text, g.g) AS repeat FROM generate_series(1, asterisks.n) g(g); END; ", + formattedStmt: `CREATE FUNCTION public.asterisks(n integer) RETURNS SETOF text + LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE + BEGIN ATOMIC + SELECT repeat('*'::text, g.g) AS repeat + FROM generate_series(1, asterisks.n) g(g); +END;`, + }, + sqlInfo{ + objName: "copy_high_earners", + stmt: "CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$ DECLARE temp_salary employees.salary%TYPE; BEGIN CREATE TEMP TABLE temp_high_earners AS SELECT * FROM employees WHERE salary > threshold; FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP RAISE NOTICE 'High earner salary: %', temp_salary; END LOOP; END; $$ LANGUAGE plpgsql; ", + formattedStmt: `CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$ +DECLARE + temp_salary employees.salary%TYPE; +BEGIN + CREATE TEMP TABLE temp_high_earners AS + SELECT * FROM employees WHERE salary > threshold; + FOR temp_salary IN SELECT salary FROM temp_high_earners LOOP + RAISE NOTICE 'High earner salary: %', temp_salary; + END LOOP; +END; +$$ LANGUAGE plpgsql;`, + }, + sqlInfo{ + objName: "add", + stmt: "CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END; ", + formattedStmt: `CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END;`, + }, + sqlInfo{ + objName: "public.asterisks1", + stmt: "CREATE FUNCTION public.asterisks1(n integer) RETURNS text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE RETURN repeat('*'::text, n); ", + formattedStmt: `CREATE FUNCTION public.asterisks1(n integer) RETURNS text + LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE + RETURN repeat('*'::text, n);`, + }, + } + objType := "FUNCTION" + sqlFile, err := setupFile(objType, functionFileContent) + if err != nil { + t.Errorf("Error creating file for the objType %s: %v", objType, err) + } + + defer os.Remove(sqlFile.Name()) + + sqlInfoArr := parseSqlFileForObjectType(sqlFile.Name(), objType) + + // Validate the number of SQL statements found + if len(sqlInfoArr) != len(expectedSqlInfoArr) { + t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr)) + } + for i, expectedSqlInfo := range expectedSqlInfoArr { assert.Equal(t, expectedSqlInfo.objName, sqlInfoArr[i].objName) assert.Equal(t, expectedSqlInfo.stmt, sqlInfoArr[i].stmt) From f4e04a988b20dbd8d35ffc570fe12db6fe2e8ebc Mon Sep 17 00:00:00 2001 From: priyanshi-yb Date: Tue, 21 Jan 2025 15:04:14 +0530 Subject: [PATCH 2/5] cover other cases --- yb-voyager/cmd/analyzeSchema.go | 12 ++++++++---- yb-voyager/cmd/analyzeSchema_test.go | 26 ++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/yb-voyager/cmd/analyzeSchema.go b/yb-voyager/cmd/analyzeSchema.go index 6bd80e393..b649459f7 100644 --- a/yb-voyager/cmd/analyzeSchema.go +++ b/yb-voyager/cmd/analyzeSchema.go @@ -134,6 +134,7 @@ var ( parserIssueDetector = queryissue.NewParserIssueDetector() multiRegex = regexp.MustCompile(`([a-zA-Z0-9_\.]+[,|;])`) dollarQuoteRegex = regexp.MustCompile(`(\$.*\$)`) + sqlBodyBeginRegex = re("BEGIN", "ATOMIC") //TODO: optional but replace every possible space or new line char with [\s\n]+ in all regexs viewWithCheckRegex = re("VIEW", capture(ident), anything, "WITH", opt(commonClause), "CHECK", "OPTION") rangeRegex = re("PRECEDING", "and", anything, ":float") @@ -921,12 +922,15 @@ sqlParsingLoop: } else if matches := dollarQuoteRegex.FindStringSubmatch(currLine); matches != nil { dollarQuoteFlag = 1 //denotes start of the code/body part codeBlockDelimiter = matches[0] - } else if strings.Contains(currLine, "BEGIN ATOMIC") { - dollarQuoteFlag = 1 //denotes start of the sql body part https://www.postgresql.org/docs/15/sql-createfunction.html#:~:text=a%20new%20session.-,sql_body,-The%20body%20of - codeBlockDelimiter = "END" + } else if matches := sqlBodyBeginRegex.FindStringSubmatch(currLine); matches != nil { + dollarQuoteFlag = 1 //denotes start of the sql body part https://www.postgresql.org/docs/15/sql-createfunction.html#:~:text=a%20new%20session.-,sql_body,-The%20body%20of + codeBlockDelimiter = "END" //SQL body to determine the end of BEGIN ATOMIC ... END; sql body } case CODE_BLOCK_STARTED: - if strings.Contains(currLine, codeBlockDelimiter) { + if strings.Contains(currLine, codeBlockDelimiter) || + strings.Contains(currLine, strings.ToLower(codeBlockDelimiter)) { + //TODO: anyways we should be using pg-parser: for the END sql body delimiter checking the UPPER and LOWER both + //not using regex as there are some issues while doing that (not debugged that) dollarQuoteFlag = 2 //denotes end of code/body part if isEndOfSqlStmt(currLine) { break sqlParsingLoop diff --git a/yb-voyager/cmd/analyzeSchema_test.go b/yb-voyager/cmd/analyzeSchema_test.go index 670014c5a..0405cc418 100644 --- a/yb-voyager/cmd/analyzeSchema_test.go +++ b/yb-voyager/cmd/analyzeSchema_test.go @@ -172,6 +172,16 @@ $$ LANGUAGE plpgsql; CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END; +CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE +BEGIN ATOMIC; SELECT $1 + $2; END; + +CREATE FUNCTION public.case_sensitive_test(n integer) RETURNS SETOF text + LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE + begin atomic + SELECT repeat('*'::text, g.g) AS repeat + FROM generate_series(1, asterisks.n) g(g); +end; + CREATE FUNCTION public.asterisks1(n integer) RETURNS text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE RETURN repeat('*'::text, n);` @@ -207,6 +217,21 @@ $$ LANGUAGE plpgsql;`, stmt: "CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END; ", formattedStmt: `CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END;`, }, + sqlInfo{ + objName: "add", + stmt: "CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE BEGIN ATOMIC; SELECT $1 + $2; END; ", + formattedStmt: "CREATE FUNCTION add(int, int) RETURNS int IMMUTABLE PARALLEL SAFE\nBEGIN ATOMIC; SELECT $1 + $2; END;", + }, + sqlInfo{ + objName: "public.case_sensitive_test", + stmt: "CREATE FUNCTION public.case_sensitive_test(n integer) RETURNS SETOF text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE begin atomic SELECT repeat('*'::text, g.g) AS repeat FROM generate_series(1, asterisks.n) g(g); end; ", + formattedStmt: `CREATE FUNCTION public.case_sensitive_test(n integer) RETURNS SETOF text + LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE + begin atomic + SELECT repeat('*'::text, g.g) AS repeat + FROM generate_series(1, asterisks.n) g(g); +end;`, + }, sqlInfo{ objName: "public.asterisks1", stmt: "CREATE FUNCTION public.asterisks1(n integer) RETURNS text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE RETURN repeat('*'::text, n); ", @@ -230,6 +255,7 @@ $$ LANGUAGE plpgsql;`, t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr)) } + fmt.Printf("sqlinfoarr - %v", sqlInfoArr) for i, expectedSqlInfo := range expectedSqlInfoArr { assert.Equal(t, expectedSqlInfo.objName, sqlInfoArr[i].objName) assert.Equal(t, expectedSqlInfo.stmt, sqlInfoArr[i].stmt) From 80de2d6eb2bcf29c0c13a6902e11454f3674fcc8 Mon Sep 17 00:00:00 2001 From: priyanshi-yb Date: Tue, 21 Jan 2025 15:05:16 +0530 Subject: [PATCH 3/5] remove print --- yb-voyager/cmd/analyzeSchema_test.go | 1 - 1 file changed, 1 deletion(-) diff --git a/yb-voyager/cmd/analyzeSchema_test.go b/yb-voyager/cmd/analyzeSchema_test.go index 0405cc418..a6c41476b 100644 --- a/yb-voyager/cmd/analyzeSchema_test.go +++ b/yb-voyager/cmd/analyzeSchema_test.go @@ -255,7 +255,6 @@ end;`, t.Errorf("Expected %d SQL statements for %s, got %d", len(expectedSqlInfoArr), objType, len(sqlInfoArr)) } - fmt.Printf("sqlinfoarr - %v", sqlInfoArr) for i, expectedSqlInfo := range expectedSqlInfoArr { assert.Equal(t, expectedSqlInfo.objName, sqlInfoArr[i].objName) assert.Equal(t, expectedSqlInfo.stmt, sqlInfoArr[i].stmt) From ff7fb07cce31ea2eb61360e38a65286ecb3ca2ac Mon Sep 17 00:00:00 2001 From: priyanshi-yb Date: Tue, 21 Jan 2025 18:53:37 +0530 Subject: [PATCH 4/5] add more cases, and handle comment case in IsEndOFSqlStmt --- yb-voyager/cmd/analyzeSchema.go | 7 +-- yb-voyager/cmd/analyzeSchema_test.go | 79 ++++++++++++++++++++++++++-- 2 files changed, 80 insertions(+), 6 deletions(-) diff --git a/yb-voyager/cmd/analyzeSchema.go b/yb-voyager/cmd/analyzeSchema.go index b649459f7..6a07462c1 100644 --- a/yb-voyager/cmd/analyzeSchema.go +++ b/yb-voyager/cmd/analyzeSchema.go @@ -913,7 +913,6 @@ sqlParsingLoop: stmt += currLine + " " formattedStmt += currLine + "\n" - // Assuming that both the dollar quote strings will not be in same line switch dollarQuoteFlag { case CODE_BLOCK_NOT_STARTED: @@ -929,8 +928,7 @@ sqlParsingLoop: case CODE_BLOCK_STARTED: if strings.Contains(currLine, codeBlockDelimiter) || strings.Contains(currLine, strings.ToLower(codeBlockDelimiter)) { - //TODO: anyways we should be using pg-parser: for the END sql body delimiter checking the UPPER and LOWER both - //not using regex as there are some issues while doing that (not debugged that) + //TODO: anyways we should be using pg-parser: but for now for the END sql body delimiter checking the UPPER and LOWER both dollarQuoteFlag = 2 //denotes end of code/body part if isEndOfSqlStmt(currLine) { break sqlParsingLoop @@ -979,6 +977,9 @@ func isEndOfSqlStmt(line string) bool { line = line[0:cmtStartIdx] // ignore comment line = strings.TrimRight(line, " ") } + if len(line) == 0 { + return false + } return line[len(line)-1] == ';' } diff --git a/yb-voyager/cmd/analyzeSchema_test.go b/yb-voyager/cmd/analyzeSchema_test.go index a6c41476b..97bd80a3b 100644 --- a/yb-voyager/cmd/analyzeSchema_test.go +++ b/yb-voyager/cmd/analyzeSchema_test.go @@ -156,7 +156,7 @@ func TestFunctionSQLFile(t *testing.T) { BEGIN ATOMIC SELECT repeat('*'::text, g.g) AS repeat FROM generate_series(1, asterisks.n) g(g); -END; +END; CREATE OR REPLACE FUNCTION copy_high_earners(threshold NUMERIC) RETURNS VOID AS $$ DECLARE @@ -180,11 +180,42 @@ CREATE FUNCTION public.case_sensitive_test(n integer) RETURNS SETOF text begin atomic SELECT repeat('*'::text, g.g) AS repeat FROM generate_series(1, asterisks.n) g(g); -end; +end; CREATE FUNCTION public.asterisks1(n integer) RETURNS text LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE - RETURN repeat('*'::text, n);` + RETURN repeat('*'::text, n); + +CREATE FUNCTION add(integer, integer) RETURNS integer + AS 'select test;' + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT; + +CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; +$$ LANGUAGE plpgsql; + +CREATE FUNCTION public.dup(integer, OUT f1 integer, OUT f2 text) RETURNS record + LANGUAGE sql + AS $_$ SELECT $1, CAST($1 AS text) || ' is text' $_$; + +CREATE FUNCTION check_password(uname TEXT, pass TEXT) +RETURNS BOOLEAN AS $$ +DECLARE passed BOOLEAN; +BEGIN + SELECT (pwd = $2) INTO passed + FROM pwds + WHERE username = $1; + + RETURN passed; +END; +$$ LANGUAGE plpgsql + SECURITY DEFINER + -- Set a secure search_path: trusted schema(s), then 'pg_temp'. + SET search_path = admin, pg_temp;` expectedSqlInfoArr := []sqlInfo{ sqlInfo{ @@ -239,6 +270,48 @@ end;`, LANGUAGE sql IMMUTABLE STRICT PARALLEL SAFE RETURN repeat('*'::text, n);`, }, + sqlInfo{ + objName: "add", + stmt: "CREATE FUNCTION add(integer, integer) RETURNS integer AS 'select test;' LANGUAGE SQL IMMUTABLE RETURNS NULL ON NULL INPUT; ", + formattedStmt: `CREATE FUNCTION add(integer, integer) RETURNS integer + AS 'select test;' + LANGUAGE SQL + IMMUTABLE + RETURNS NULL ON NULL INPUT;`, + }, + sqlInfo{ + objName: "increment", + stmt: "CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$ BEGIN RETURN i + 1; END; $$ LANGUAGE plpgsql; ", + formattedStmt: `CREATE OR REPLACE FUNCTION increment(i integer) RETURNS integer AS $$ + BEGIN + RETURN i + 1; + END; +$$ LANGUAGE plpgsql;`, + }, + sqlInfo{ + objName: "public.dup", + stmt: "CREATE FUNCTION public.dup(integer, OUT f1 integer, OUT f2 text) RETURNS record LANGUAGE sql AS $_$ SELECT $1, CAST($1 AS text) || ' is text' $_$; ", + formattedStmt: `CREATE FUNCTION public.dup(integer, OUT f1 integer, OUT f2 text) RETURNS record + LANGUAGE sql + AS $_$ SELECT $1, CAST($1 AS text) || ' is text' $_$;`, + }, + sqlInfo{ + objName: "check_password", + stmt: "CREATE FUNCTION check_password(uname TEXT, pass TEXT) RETURNS BOOLEAN AS $$ DECLARE passed BOOLEAN; BEGIN SELECT (pwd = $2) INTO passed FROM pwds WHERE username = $1; RETURN passed; END; $$ LANGUAGE plpgsql SECURITY DEFINER -- Set a secure search_path: trusted schema(s), then 'pg_temp'. SET search_path = admin, pg_temp; ", + formattedStmt: `CREATE FUNCTION check_password(uname TEXT, pass TEXT) +RETURNS BOOLEAN AS $$ +DECLARE passed BOOLEAN; +BEGIN + SELECT (pwd = $2) INTO passed + FROM pwds + WHERE username = $1; + RETURN passed; +END; +$$ LANGUAGE plpgsql + SECURITY DEFINER + -- Set a secure search_path: trusted schema(s), then 'pg_temp'. + SET search_path = admin, pg_temp;`, + }, } objType := "FUNCTION" sqlFile, err := setupFile(objType, functionFileContent) From 6df033fb649b338c99b8d1c392507182dac7f96b Mon Sep 17 00:00:00 2001 From: priyanshi-yb Date: Tue, 21 Jan 2025 21:28:26 +0530 Subject: [PATCH 5/5] review comment --- yb-voyager/cmd/analyzeSchema.go | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/yb-voyager/cmd/analyzeSchema.go b/yb-voyager/cmd/analyzeSchema.go index 6a07462c1..7c4f6635a 100644 --- a/yb-voyager/cmd/analyzeSchema.go +++ b/yb-voyager/cmd/analyzeSchema.go @@ -926,14 +926,25 @@ sqlParsingLoop: codeBlockDelimiter = "END" //SQL body to determine the end of BEGIN ATOMIC ... END; sql body } case CODE_BLOCK_STARTED: - if strings.Contains(currLine, codeBlockDelimiter) || - strings.Contains(currLine, strings.ToLower(codeBlockDelimiter)) { - //TODO: anyways we should be using pg-parser: but for now for the END sql body delimiter checking the UPPER and LOWER both - dollarQuoteFlag = 2 //denotes end of code/body part - if isEndOfSqlStmt(currLine) { - break sqlParsingLoop + switch codeBlockDelimiter { + case "END": + if strings.Contains(currLine, codeBlockDelimiter) || + strings.Contains(currLine, strings.ToLower(codeBlockDelimiter)) { + //TODO: anyways we should be using pg-parser: but for now for the END sql body delimiter checking the UPPER and LOWER both + dollarQuoteFlag = 2 //denotes end of code/body part + if isEndOfSqlStmt(currLine) { + break sqlParsingLoop + } + } + default: + if strings.Contains(currLine, codeBlockDelimiter) { + dollarQuoteFlag = 2 //denotes end of code/body part + if isEndOfSqlStmt(currLine) { + break sqlParsingLoop + } } } + case CODE_BLOCK_COMPLETED: if isEndOfSqlStmt(currLine) { break sqlParsingLoop