Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rewriting stored procedures #2851

Open
wants to merge 19 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ type Engine struct {
Parser sql.Parser
}

var _ analyzer.StatementRunner = (*Engine)(nil)
var _ sql.StatementRunner = (*Engine)(nil)

type ColumnWithRawDefault struct {
SqlColumn *sql.Column
Expand Down
28 changes: 23 additions & 5 deletions enginetest/memory_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,24 +198,42 @@ func TestSingleQueryPrepared(t *testing.T) {

// Convenience test for debugging a single query. Unskip and set to the desired query.
func TestSingleScript(t *testing.T) {
t.Skip()
//t.Skip()
var scripts = []queries.ScriptTest{
{
Name: "test script",
Name: "Simple SELECT",
SetUpScript: []string{
"create table t (i int);",
`
create procedure proc(i int)
begin
set @x = 0;
repeat set @x = @x + 1;
until @x > i
end repeat;
end`,
//"create procedure proc(x int) select x > 1;",
},
Assertions: []queries.ScriptTestAssertion{
{
Query: "select 1 into @a",
Expected: []sql.Row{},
Query: "call proc(10);",
Expected: []sql.Row{
{types.NewOkResult(0)},
},
},
{
Query: "select @x;",
Expected: []sql.Row{
{11},
},
},
},
},
}

for _, test := range scripts {
harness := enginetest.NewMemoryHarness("", 1, testNumPartitions, true, nil)
// TODO: fix this
//harness.UseServer()
engine, err := harness.NewEngine(t)
if err != nil {
panic(err)
Expand Down
112 changes: 101 additions & 11 deletions enginetest/queries/procedure_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"time"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/types"
)

Expand Down Expand Up @@ -75,17 +76,13 @@ var ProcedureLogicTests = []ScriptTest{
{
Query: "CALL testabc(2, 3)",
Expected: []sql.Row{
{
6.0,
},
{6.0},
},
},
{
Query: "CALL testabc(9, 9.5)",
Expected: []sql.Row{
{
85.5,
},
{85.5},
},
},
},
Expand Down Expand Up @@ -2831,19 +2828,112 @@ var ProcedureCreateInSubroutineTests = []ScriptTest{
},
},
},

{
Name: "procedure must not contain CREATE TABLE",
Name: "table ddl statements in stored procedures",
Assertions: []ScriptTestAssertion{
{
Query: "create procedure p() create table t (pk int);",
ExpectedErrStr: "CREATE statements in CREATE PROCEDURE not yet supported",
Query: "create procedure create_proc() create table t (i int primary key, j int);",
Expected: []sql.Row{
{types.NewOkResult(0)},
},
},
{
Query: "create procedure p() begin create table t (pk int); end;",
ExpectedErrStr: "CREATE statements in CREATE PROCEDURE not yet supported",
Query: "call create_proc()",
Expected: []sql.Row{
{types.NewOkResult(0)},
},
},
{
Query: "show create table t;",
Expected: []sql.Row{
{"t", "CREATE TABLE `t` (\n" +
" `i` int NOT NULL,\n" +
" `j` int,\n" +
" PRIMARY KEY (`i`)\n" +
") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"},
},
},
{
Query: "call create_proc()",
ExpectedErrStr: "table with name t already exists",
},

{
Query: "create procedure insert_proc() insert into t values (1, 1), (2, 2), (3, 3);",
Expected: []sql.Row{
{types.NewOkResult(0)},
},
},
{
Query: "call insert_proc()",
Expected: []sql.Row{
{types.NewOkResult(3)},
},
},
{
Query: "select * from t",
Expected: []sql.Row{
{1, 1},
{2, 2},
{3, 3},
},
},
{
Query: "call insert_proc()",
ExpectedErrStr: "duplicate primary key given: [1]",
},

{
Query: "create procedure update_proc() update t set j = 999 where i > 1;",
Expected: []sql.Row{
{types.NewOkResult(0)},
},
},
{
Query: "call update_proc()",
Expected: []sql.Row{
{types.OkResult{RowsAffected: 2, Info: plan.UpdateInfo{Matched: 2, Updated: 2}}},
},
},
{
Query: "select * from t",
Expected: []sql.Row{
{1, 1},
{2, 999},
{3, 999},
},
},
{
Query: "call update_proc()",
Expected: []sql.Row{
{types.OkResult{RowsAffected: 0, Info: plan.UpdateInfo{Matched: 2}}},
},
},

{
Query: "create procedure drop_proc() drop table t;",
Expected: []sql.Row{
{types.NewOkResult(0)},
},
},
{
Query: "call drop_proc()",
Expected: []sql.Row{
{types.NewOkResult(0)},
},
},
{
Query: "show tables like 't'",
Expected: []sql.Row{},
},
{
Query: "call drop_proc()",
ExpectedErrStr: "Unknown table 't'",
},
},
},

{
Name: "procedure must not contain CREATE TRIGGER",
SetUpScript: []string{
Expand Down
2 changes: 1 addition & 1 deletion sql/analyzer/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ type Analyzer struct {
// ExecBuilder converts a sql.Node tree into an executable iterator.
ExecBuilder sql.NodeExecBuilder
// Runner represents the engine, which is represented as a separate interface to work around circular dependencies
Runner StatementRunner
Runner sql.StatementRunner
}

// NewDefault creates a default Analyzer instance with all default Rules and configuration.
Expand Down
38 changes: 19 additions & 19 deletions sql/analyzer/interpreter.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,33 @@
package analyzer

import (
"github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/go-mysql-server/sql/transform"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/procedures"
"github.com/dolthub/go-mysql-server/sql/transform"
)

// Interpreter is an interface that implements an interpreter. These are typically used for functions (which may be
// implemented as a set of operations that are interpreted during runtime).
type Interpreter interface {
SetStatementRunner(ctx *sql.Context, runner StatementRunner) sql.Expression
}

// StatementRunner is essentially an interface that the engine will implement. We cannot directly reference the engine
// here as it will cause an import cycle, so this may be updated to suit any function changes that the engine
// experiences.
type StatementRunner interface {
QueryWithBindings(ctx *sql.Context, query string, parsed sqlparser.Statement, bindings map[string]sqlparser.Expr, qFlags *sql.QueryFlags) (sql.Schema, sql.RowIter, *sql.QueryFlags, error)
}

// interpreter hands the engine to any interpreter expressions.
func interpreter(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
return transform.NodeExprs(n, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
if interp, ok := expr.(Interpreter); ok {
newNode, sameNode, err := transform.Node(n, func(node sql.Node) (sql.Node, transform.TreeIdentity, error) {
if interp, ok := node.(procedures.InterpreterNode); ok {
return interp.SetStatementRunner(ctx, a.Runner), transform.NewTree, nil
}
return node, transform.SameTree, nil
})
if err != nil {
return nil, transform.SameTree, err
}

newNode, sameExpr, err := transform.NodeExprs(newNode, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
if interp, ok := expr.(sql.Interpreter); ok {
return interp.SetStatementRunner(ctx, a.Runner), transform.NewTree, nil
}
return expr, transform.SameTree, nil
})
if err != nil {
return nil, transform.SameTree, err
}

return newNode, sameNode && sameExpr, err
}
2 changes: 2 additions & 0 deletions sql/expression/procedurereference.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ import (
"github.com/dolthub/go-mysql-server/sql/types"
)

// TODO: instead of procedure reference, copy stack from doltgres

// ProcedureReference contains the state for a single CALL statement of a stored procedure.
type ProcedureReference struct {
InnermostScope *procedureScope
Expand Down
49 changes: 43 additions & 6 deletions sql/plan/call.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ package plan
import (
"fmt"

"github.com/dolthub/go-mysql-server/sql/types"

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/procedures"
"github.com/dolthub/go-mysql-server/sql/types"
)

// TODO: we need different types of calls: one for external procedures one for stored procedures

type Call struct {
db sql.Database
Name string
Expand All @@ -32,22 +34,28 @@ type Call struct {
Pref *expression.ProcedureReference
cat sql.Catalog
Analyzed bool

// this will have list of parsed operations to run
Runner sql.StatementRunner
Ops []procedures.InterpreterOperation
}

var _ sql.Node = (*Call)(nil)
var _ sql.CollationCoercible = (*Call)(nil)
var _ sql.Expressioner = (*Call)(nil)
var _ procedures.InterpreterNode = (*Call)(nil)
var _ Versionable = (*Call)(nil)

// NewCall returns a *Call node.
func NewCall(db sql.Database, name string, params []sql.Expression, proc *Procedure, asOf sql.Expression, catalog sql.Catalog) *Call {
func NewCall(db sql.Database, name string, params []sql.Expression, proc *Procedure, asOf sql.Expression, catalog sql.Catalog, ops []procedures.InterpreterOperation) *Call {
return &Call{
db: db,
Name: name,
Params: params,
Procedure: proc,
asOf: asOf,
cat: catalog,
Ops: ops,
}
}

Expand Down Expand Up @@ -170,9 +178,6 @@ func (c *Call) DebugString() string {
} else {
tp.WriteNode("CALL %s.%s(%s)", c.db.Name(), c.Name, paramStr)
}
if c.Procedure != nil {
tp.WriteChildren(sql.DebugString(c.Procedure.Body))
}

return tp.String()
}
Expand All @@ -197,3 +202,35 @@ func (c *Call) Dispose() {
disposeNode(c.Procedure)
}
}

// SetStatementRunner implements the sql.InterpreterNode interface.
func (c *Call) SetStatementRunner(ctx *sql.Context, runner sql.StatementRunner) sql.Node {
nc := *c
nc.Runner = runner
return &nc
}

// GetRunner implements the sql.InterpreterNode interface.
func (c *Call) GetRunner() sql.StatementRunner {
return c.Runner
}

// GetParameters implements the sql.InterpreterNode interface.
func (c *Call) GetParameters() []sql.Type {
return nil
}

// GetParameterNames implements the sql.InterpreterNode interface.
func (c *Call) GetParameterNames() []string {
return nil
}

// GetStatements implements the sql.InterpreterNode interface.
func (c *Call) GetStatements() []*procedures.InterpreterOperation {
return c.Procedure.Ops
}

// GetReturn implements the sql.InterpreterNode interface.
func (c *Call) GetReturn() sql.Type {
return nil
}
2 changes: 1 addition & 1 deletion sql/plan/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ func NodeRepresentsSelect(s sql.Node) bool {
case *Call:
return NodeRepresentsSelect(node.Procedure)
case *Procedure:
return NodeRepresentsSelect(node.Body)
return NodeRepresentsSelect(node.ExternalProc)
case *Block:
for _, stmt := range node.statements {
if NodeRepresentsSelect(stmt) {
Expand Down
Loading
Loading