Skip to content

Commit

Permalink
Merge pull request #1228 from dolthub/fulghum/case
Browse files Browse the repository at this point in the history
Add support for `CASE ... WHEN` in PL/pgSQL
  • Loading branch information
fulghum authored Mar 3, 2025
2 parents ec5c68d + f6b16a8 commit c4d9aab
Show file tree
Hide file tree
Showing 5 changed files with 291 additions and 1 deletion.
1 change: 1 addition & 0 deletions postgres/parser/types/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -2553,6 +2553,7 @@ var unreservedTypeTokens = map[string]*T{
"float4": Float,
"float8": Float,
"inet": INet,
"integer": Int4,
"int2": Int2,
"int4": Int4,
"int8": Int,
Expand Down
23 changes: 22 additions & 1 deletion server/plpgsql/interpreter_logic.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ package plpgsql

import (
"fmt"
"strings"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/analyzer"

"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/core/typecollection"
"github.com/dolthub/doltgresql/postgres/parser/types"
pgtypes "github.com/dolthub/doltgresql/server/types"
)

Expand Down Expand Up @@ -92,7 +94,26 @@ func Call(ctx *sql.Context, iFunc InterpretedFunction, runner analyzer.Statement
if err != nil {
return nil, err
}
resolvedType, exists := typeCollection.GetType(id.NewType("pg_catalog", operation.PrimaryData))

// pg_query_go sets PrimaryData for implicit CASE statement variables to
// `pg_catalog."integer"`, so we remove double-quotes and extract the schema name.
typeName := operation.PrimaryData
typeName = strings.ReplaceAll(typeName, `"`, "")
schemaName := "pg_catalog"
if strings.Contains(typeName, ".") {
parts := strings.Split(typeName, ".")
schemaName = parts[0]
typeName = parts[1]
// Check the NonKeyword type names to see if we're looking at
// an alias of a type if we're in the pg_catalog schema.
if schemaName == "pg_catalog" {
typ, ok, _ := types.TypeForNonKeywordTypeName(typeName)
if ok && typ != nil {
typeName = typ.Name()
}
}
}
resolvedType, exists := typeCollection.GetType(id.NewType(schemaName, typeName))
if !exists {
return nil, pgtypes.ErrTypeDoesNotExist.New(operation.PrimaryData)
}
Expand Down
109 changes: 109 additions & 0 deletions server/plpgsql/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
package plpgsql

import (
"fmt"
"strings"

"github.com/cockroachdb/errors"
Expand Down Expand Up @@ -102,6 +103,25 @@ type plpgSQL_stmt_block struct {
LineNumber int32 `json:"lineno"`
}

// plpgSQL_stmt_case exists to match the expected JSON format. It is the decoded form of a
// PL/pgSQL CASE statement, and its Convert method turns it into a Block of statements.
type plpgSQL_stmt_case struct {
// LineNumber is decoded from the "lineno" key (presumably the source line of the CASE
// statement; it is not read by Convert).
LineNumber int32 `json:"lineno"`
// Expression is the main expression of the CASE statement. Convert treats an empty query
// string as "no main expression" and skips the up-front assignment in that case.
Expression expr `json:"t_expr"`
// VarNo indicates the ID for the __Case__Variable_N__ variable that holds the evaluated
// value of the case expression.
VarNo int32 `json:"t_varno"`
// WhenList holds one entry per WHEN clause. Convert reads only the When field of each
// entry and returns an error if that field is nil.
WhenList []statement `json:"case_when_list"`
// HasElse reports whether an ELSE block is present. Convert only converts Else when this
// is true.
HasElse bool `json:"have_else"`
// Else holds the statements of the ELSE block.
Else []statement `json:"else_stmts"`
}

// plpgSQL_case_when exists to match the expected JSON format. It is the decoded form of a
// single WHEN clause inside a CASE statement, and is consumed by plpgSQL_stmt_case.Convert.
type plpgSQL_case_when struct {
// LineNumber is decoded from the "lineno" key (presumably the source line of the WHEN
// clause; it is not read during conversion).
LineNumber int32 `json:"lineno"`
// Expression is the condition for this WHEN clause. Its query string becomes the
// condition of the generated If statement after all double quotes are stripped from it.
Expression expr `json:"expr"`
// Body holds the statements that run when the condition matches.
Body []statement `json:"stmts"`
}

// plpgSQL_stmt_execsql exists to match the expected JSON format.
type plpgSQL_stmt_execsql struct {
SQLStmt sqlstmt `json:"sqlstmt"`
Expand Down Expand Up @@ -175,12 +195,14 @@ type sqlstmt struct {
// having a singular expected implementation.
type statement struct {
Assignment *plpgSQL_stmt_assign `json:"PLpgSQL_stmt_assign"`
// Case is set for a CASE statement; jsonConvertStatement dispatches to its Convert method.
Case *plpgSQL_stmt_case `json:"PLpgSQL_stmt_case"`
ExecSQL *plpgSQL_stmt_execsql `json:"PLpgSQL_stmt_execsql"`
Exit *plpgSQL_stmt_exit `json:"PLpgSQL_stmt_exit"`
If *plpgSQL_stmt_if `json:"PLpgSQL_stmt_if"`
Loop *plpgSQL_stmt_loop `json:"PLpgSQL_stmt_loop"`
Perform *plpgSQL_stmt_perform `json:"PLpgSQL_stmt_perform"`
Return *plpgSQL_stmt_return `json:"PLpgSQL_stmt_return"`
// When is set for a single WHEN clause of a CASE statement. It is read directly by
// plpgSQL_stmt_case.Convert from the entries of WhenList.
When *plpgSQL_case_when `json:"PLpgSQL_case_when"`
While *plpgSQL_stmt_while `json:"PLpgSQL_stmt_while"`
}

Expand All @@ -204,6 +226,93 @@ func (stmt *plpgSQL_stmt_assign) Convert() (Assignment, error) {
}, nil
}

// Convert converts the JSON statement into its output form.
//
// The CASE statement is lowered into a flat Block laid out as:
//
//	[Assignment]            only when the CASE has a main expression
//	If / Goto / body / Goto repeated once per WHEN clause
//	[else body]             only when HasElse is set
//
// Each trailing Goto of a WHEN clause jumps just past the end of the block, so at most one
// WHEN body (or the ELSE body) runs. On any error an empty Block is returned.
//
// NOTE(review): all offsets are counted in entries of block.Body. This assumes that the
// interpreter measures Goto offsets in the same unit — confirm for WHEN bodies that contain
// statements which expand to more than one entry.
func (stmt *plpgSQL_stmt_case) Convert() (block Block, err error) {
	// Evaluate the main expression (if any) exactly once by storing it in the implicit
	// case variable.
	//
	// TODO: pg_query_go creates the definitions for these variables, and ideally users
	//       shouldn't be able to reference them. We could update all the references to
	//       them (i.e. declaration, assignment, and WHEN block exprs) to change the name
	//       to include a \0 char to prevent users from referencing or colliding with them.
	if mainExpr := stmt.Expression.Expression.Query; mainExpr != "" {
		block.Body = append(block.Body, Assignment{
			VariableName: fmt.Sprintf("__Case__Variable_%d__", stmt.VarNo),
			Expression:   mainExpr,
		})
	}

	// Positions in block.Body of the placeholder GOTOs that must jump past the whole CASE
	// block. Their offsets can only be filled in once the final size is known.
	var endJumpPositions []int

	for _, entry := range stmt.WhenList {
		clause := entry.When
		if clause == nil {
			return Block{}, fmt.Errorf("case statement WHEN clause is nil")
		}

		// TODO: The generated expressions from pg_query_go use double quotes around the
		//       variable name, which is valid for Postgres, but our engine doesn't
		//       currently resolve double-quoted strings to variables, so for now we
		//       strip every double quote from the expression.
		condition := strings.ReplaceAll(clause.Expression.Expression.Query, `"`, "")

		clauseBody, convertErr := jsonConvertStatements(clause.Body)
		if convertErr != nil {
			return Block{}, convertErr
		}

		// When the condition holds, the If skips over the Goto that follows it and lands
		// on the first statement of the WHEN body.
		block.Body = append(block.Body, If{
			Condition:  condition,
			GotoOffset: 2,
		})
		// Otherwise this Goto moves on to the next WHEN clause: step over the body, plus 1
		// for the end-of-CASE Goto that terminates it, plus 1 more to reach the next
		// statement.
		block.Body = append(block.Body, Goto{
			Offset: int32(len(clauseBody) + 1 + 1),
		})
		block.Body = append(block.Body, clauseBody...)

		// Placeholder for the jump to the end of the CASE block; patched below.
		endJumpPositions = append(endJumpPositions, len(block.Body))
		block.Body = append(block.Body, Goto{})
	}

	if stmt.HasElse {
		elseBody, convertErr := jsonConvertStatements(stmt.Else)
		if convertErr != nil {
			return Block{}, convertErr
		}
		block.Body = append(block.Body, elseBody...)
	}
	// TODO: If no cases match and there is no ELSE block, then add a RAISE statement to
	//       return an error. Sample PostgreSQL error response:
	//         ERROR:  case not found
	//         HINT:  CASE statement is missing ELSE part.
	//         CONTEXT:  PL/pgSQL function interpreted_case(integer) line 5 at CASE

	// The block is complete, so every placeholder can now be pointed just past its end.
	blockSize := len(block.Body)
	for _, position := range endJumpPositions {
		// Sanity check that the recorded position really holds a Goto.
		if _, isGoto := block.Body[position].(Goto); !isGoto {
			return Block{}, fmt.Errorf("expected Goto statement, got %T", block.Body[position])
		}
		block.Body[position] = Goto{
			Offset: int32(blockSize - position),
		}
	}

	return block, nil
}

// Convert converts the JSON statement into its output form.
func (stmt *plpgSQL_stmt_execsql) Convert() (ExecuteSQL, error) {
var target string
Expand Down
2 changes: 2 additions & 0 deletions server/plpgsql/json_convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ func jsonConvertStatement(stmt statement) (Statement, error) {
switch {
case stmt.Assignment != nil:
return stmt.Assignment.Convert()
case stmt.Case != nil:
return stmt.Case.Convert()
case stmt.ExecSQL != nil:
return stmt.ExecSQL.Convert()
case stmt.Exit != nil:
Expand Down
157 changes: 157 additions & 0 deletions testing/go/create_function_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,163 @@ $$ LANGUAGE plpgsql;`},
},
},
},
{
Name: "CASE, with ELSE",
SetUpScript: []string{`
CREATE FUNCTION interpreted_case(x INT) RETURNS TEXT AS $$
DECLARE
msg TEXT;
BEGIN
CASE x
WHEN 1, 2 THEN
msg := 'one';
msg := msg || ' or two';
ELSE
msg := 'other';
msg := msg || ' value than one or two';
END CASE;
RETURN msg;
END;
$$ LANGUAGE plpgsql;`},
Assertions: []ScriptTestAssertion{
{
Query: "SELECT interpreted_case(1);",
Expected: []sql.Row{{"one or two"}},
},
{
Query: "SELECT interpreted_case(2);",
Expected: []sql.Row{{"one or two"}},
},
{
Query: "SELECT interpreted_case(0);",
Expected: []sql.Row{{"other value than one or two"}},
},
},
},
{
// TODO: When no CASE statements match, and there is no ELSE block,
// Postgres raises an exception. Unskip this test after we
// add support for raising exceptions from functions.
Skip: true,
Name: "CASE, without ELSE",
SetUpScript: []string{`
CREATE FUNCTION interpreted_case(x INT) RETURNS TEXT AS $$
DECLARE
msg TEXT;
BEGIN
CASE x
WHEN 1, 2 THEN
msg := 'one or two';
END CASE;
RETURN msg;
END;
$$ LANGUAGE plpgsql;`},
Assertions: []ScriptTestAssertion{
{
Query: "SELECT interpreted_case(1);",
Expected: []sql.Row{{"one or two"}},
},
{
Query: "SELECT interpreted_case(2);",
Expected: []sql.Row{{"one or two"}},
},
{
Query: "SELECT interpreted_case(0);",
ExpectedErr: "case not found",
},
},
},
{
Name: "Searched CASE, with ELSE",
SetUpScript: []string{`
CREATE FUNCTION interpreted_case(x INT) RETURNS TEXT AS $$
DECLARE
msg TEXT;
BEGIN
CASE
WHEN x BETWEEN 0 AND 10 THEN
msg := 'value is between zero';
msg := msg || ' and ten';
WHEN x BETWEEN 11 AND 20 THEN
msg := 'value is between eleven and twenty';
ELSE
msg := 'value';
msg := msg || ' is';
msg := msg || ' out of';
msg := msg || ' bounds';
END CASE;
RETURN msg;
END;
$$ LANGUAGE plpgsql;`},
Assertions: []ScriptTestAssertion{
{
Query: "SELECT interpreted_case(0);",
Expected: []sql.Row{{"value is between zero and ten"}},
},
{
Query: "SELECT interpreted_case(1);",
Expected: []sql.Row{{"value is between zero and ten"}},
},
{
Query: "SELECT interpreted_case(10);",
Expected: []sql.Row{{"value is between zero and ten"}},
},
{
Query: "SELECT interpreted_case(11);",
Expected: []sql.Row{{"value is between eleven and twenty"}},
},
{
Query: "SELECT interpreted_case(21);",
Expected: []sql.Row{{"value is out of bounds"}},
},
},
},
{
// TODO: When no CASE statements match, and there is no ELSE block,
// Postgres raises an exception. Unskip this test after we
// add support for raising exceptions from functions.
Skip: true,
Name: "Searched CASE, without ELSE",
SetUpScript: []string{`
CREATE FUNCTION interpreted_case(x INT) RETURNS TEXT AS $$
DECLARE
msg TEXT;
BEGIN
CASE
WHEN x BETWEEN 0 AND 10 THEN
msg := 'value is between zero and ten';
WHEN x BETWEEN 11 AND 20 THEN
msg := 'value';
msg := msg || ' is between';
msg := msg || ' eleven and';
msg := msg || ' twenty';
END CASE;
RETURN msg;
END;
$$ LANGUAGE plpgsql;`},
Assertions: []ScriptTestAssertion{
{
Query: "SELECT interpreted_case(0);",
Expected: []sql.Row{{"value is between zero and ten"}},
},
{
Query: "SELECT interpreted_case(1);",
Expected: []sql.Row{{"value is between zero and ten"}},
},
{
Query: "SELECT interpreted_case(10);",
Expected: []sql.Row{{"value is between zero and ten"}},
},
{
Query: "SELECT interpreted_case(11);",
Expected: []sql.Row{{"value is between eleven and twenty"}},
},
{
Query: "SELECT interpreted_case(21);",
ExpectedErr: "case not found",
},
},
},
{
Name: "CONTINUE",
SetUpScript: []string{`CREATE FUNCTION interpreted_continue() RETURNS INT4 AS $$
Expand Down

0 comments on commit c4d9aab

Please sign in to comment.