diff --git a/server/ast/aliased_table_expr.go b/server/ast/aliased_table_expr.go index 8ea1e10427..a12af43f2c 100644 --- a/server/ast/aliased_table_expr.go +++ b/server/ast/aliased_table_expr.go @@ -72,16 +72,18 @@ func nodeAliasedTableExpr(ctx *Context, node *tree.AliasedTableExpr) (*vitess.Al } if inSelect, ok := innerSelect.(*vitess.Select); ok { if len(inSelect.From) == 1 { - if valuesStmt, ok := inSelect.From[0].(*vitess.ValuesStatement); ok { - if len(node.As.Cols) > 0 { - columns := make([]vitess.ColIdent, len(node.As.Cols)) - for i := range node.As.Cols { - columns[i] = vitess.NewColIdent(string(node.As.Cols[i])) + if aliasedTblExpr, ok := inSelect.From[0].(*vitess.AliasedTableExpr); ok { + if valuesStmt, ok := aliasedTblExpr.Expr.(*vitess.ValuesStatement); ok { + if len(node.As.Cols) > 0 { + columns := make([]vitess.ColIdent, len(node.As.Cols)) + for i := range node.As.Cols { + columns[i] = vitess.NewColIdent(string(node.As.Cols[i])) + } + valuesStmt.Columns = columns } - valuesStmt.Columns = columns + aliasExpr = valuesStmt + break } - aliasExpr = valuesStmt - break } } } diff --git a/server/ast/insert.go b/server/ast/insert.go index 2f7ba88893..3b87ae7d9d 100644 --- a/server/ast/insert.go +++ b/server/ast/insert.go @@ -87,9 +87,11 @@ func nodeInsert(ctx *Context, node *tree.Insert) (*vitess.Insert, error) { // For a ValuesStatement with simple rows, GMS expects AliasedValues if vSelect, ok := rows.(*vitess.Select); ok && len(vSelect.From) == 1 { - if valsStmt, ok := vSelect.From[0].(*vitess.ValuesStatement); ok { - rows = &vitess.AliasedValues{ - Values: valsStmt.Rows, + if aliasedStmt, ok := vSelect.From[0].(*vitess.AliasedTableExpr); ok { + if valsStmt, ok := aliasedStmt.Expr.(*vitess.ValuesStatement); ok { + rows = &vitess.AliasedValues{ + Values: valsStmt.Rows, + } } } } diff --git a/server/ast/values_clause.go b/server/ast/values_clause.go index 8bfe36f8ad..7c150295f0 100644 --- a/server/ast/values_clause.go +++ b/server/ast/values_clause.go @@ -33,19 +33,15 @@ func nodeValuesClause(ctx *Context, node *tree.ValuesClause) (*vitess.Select, er } valTuples[i] = vitess.ValTuple(exprs) } - //TODO: ValuesStatement might need to be aliased - //TODO: is the SelectExprs necessary? return &vitess.Select{ SelectExprs: vitess.SelectExprs{ - &vitess.StarExpr{ - TableName: vitess.TableName{ - Name: vitess.NewTableIdent("*"), - }, - }, + &vitess.StarExpr{}, }, From: vitess.TableExprs{ - &vitess.ValuesStatement{ - Rows: valTuples, + &vitess.AliasedTableExpr{ + Expr: &vitess.ValuesStatement{ + Rows: valTuples, + }, }, }, }, nil diff --git a/server/ast/with.go b/server/ast/with.go index 165f0a03dc..8b4551d756 100644 --- a/server/ast/with.go +++ b/server/ast/with.go @@ -17,15 +17,64 @@ package ast import ( "fmt" - vitess "github.com/dolthub/vitess/go/vt/sqlparser" - "github.com/dolthub/doltgresql/postgres/parser/sem/tree" + + vitess "github.com/dolthub/vitess/go/vt/sqlparser" ) +// nodeWith handles *tree.CTE nodes. +func nodeCTE(ctx *Context, node *tree.CTE) (*vitess.CommonTableExpr, error) { + if node == nil { + return nil, nil + } + + alias := vitess.NewTableIdent(string(node.Name.Alias)) + cols := make([]vitess.ColIdent, len(node.Name.Cols)) + for i, col := range node.Name.Cols { + cols[i] = vitess.NewColIdent(string(col)) + } + + subSelect, ok := node.Stmt.(*tree.Select) + if !ok { + return nil, fmt.Errorf("unsupported CTE statement type: %T", node.Stmt) + } + + selectStmt, err := nodeSelect(ctx, subSelect) + if err != nil { + return nil, err + } + + subQuery := &vitess.Subquery{ + Select: selectStmt, + } + + return &vitess.CommonTableExpr{ + AliasedTableExpr: &vitess.AliasedTableExpr{ + Expr: subQuery, + As: alias, + Auth: vitess.AuthInformation{AuthType: vitess.AuthType_IGNORE}, + }, + Columns: cols, + }, nil +} + // nodeWith handles *tree.With nodes. func nodeWith(ctx *Context, node *tree.With) (*vitess.With, error) { if node == nil { return nil, nil } - return &vitess.With{}, fmt.Errorf("WITH is not yet supported") + + ctes := make([]vitess.TableExpr, len(node.CTEList)) + for i, cte := range node.CTEList { + var err error + ctes[i], err = nodeCTE(ctx, cte) + if err != nil { + return nil, err + } + } + + return &vitess.With{ + Recursive: node.Recursive, + Ctes: ctes, + }, nil } diff --git a/server/auth/auth_handler.go b/server/auth/auth_handler.go index c06469df48..a8f364069e 100644 --- a/server/auth/auth_handler.go +++ b/server/auth/auth_handler.go @@ -97,6 +97,9 @@ func (h *AuthorizationHandler) HandleAuth(ctx *sql.Context, aqs sql.Authorizatio var privileges []Privilege switch auth.AuthType { + case AuthType_IGNORE: + // This means that authorization is being handled elsewhere (such as a child or parent), and should be ignored here + return nil case AuthType_DELETE: privileges = []Privilege{Privilege_DELETE} case AuthType_INSERT: diff --git a/testing/go/auth_test.go b/testing/go/auth_test.go index 7f9b31463b..38b28b1dff 100644 --- a/testing/go/auth_test.go +++ b/testing/go/auth_test.go @@ -429,6 +429,12 @@ func TestAuthTests(t *testing.T) { Password: `a`, ExpectedErr: `denied`, }, + { + Query: `WITH cte AS (SELECT * FROM test ORDER BY pk) SELECT * FROM cte;`, + Username: `user1`, + Password: `a`, + ExpectedErr: `denied`, + }, { Query: `INSERT INTO test VALUES (10);`, Username: `user1`, @@ -459,6 +465,13 @@ func TestAuthTests(t *testing.T) { Password: `a`, Expected: []sql.Row{{1}, {6}, {7}}, }, + { + Skip: true, // CTEs are seen as different tables + Query: `WITH cte AS (SELECT * FROM test ORDER BY pk) SELECT * FROM cte;`, + Username: `user1`, + Password: `a`, + Expected: []sql.Row{{1}, {6}, {7}}, + }, { Query: `INSERT INTO test VALUES (10);`, Username: `user1`, diff --git a/testing/go/dolt_tables_test.go b/testing/go/dolt_tables_test.go index 5a3ffa9d76..b1533be026 100755 --- a/testing/go/dolt_tables_test.go +++ b/testing/go/dolt_tables_test.go @@ -118,7 +118,6 @@ func TestUserSpaceDoltTables(t *testing.T) { }, Assertions: []ScriptTestAssertion{ { - // TODO: WITH is not yet supported Query: `WITH sorted_diffs_by_pk AS (SELECT "to_id", diff --git a/testing/go/with_test.go b/testing/go/with_test.go new file mode 100644 index 0000000000..3fdd03e3a4 --- /dev/null +++ b/testing/go/with_test.go @@ -0,0 +1,80 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package _go + +import ( + "testing" + + "github.com/dolthub/go-mysql-server/sql" +) + +func TestWithStatements(t *testing.T) { + RunScripts(t, WithStatementTests) +} + +var WithStatementTests = []ScriptTest{ + { + Name: "basic values statements", + SetUpScript: []string{ + "create table t (i int primary key);", + "insert into t values (1), (2), (3);", + }, + Assertions: []ScriptTestAssertion{ + { + Query: "with cte as (select 1) select * from cte;", + Expected: []sql.Row{ + {1}, + }, + }, + { + Query: "with cte as (select 1, 2, 3 union select 4, 5, 6) select * from cte;", + Expected: []sql.Row{ + {1, 2, 3}, + {4, 5, 6}, + }, + }, + { + Query: "with cte as (values (1)) select * from cte;", + Expected: []sql.Row{ + {1}, + }, + }, + { + Query: "with cte as (values (1, 2, 3) union values (4, 5, 6)) select * from cte;", + Expected: []sql.Row{ + {1, 2, 3}, + {4, 5, 6}, + }, + }, + { + Query: "with cte as (select 1, 2, 3 union values (4, 5, 6)) select * from cte;", + Expected: []sql.Row{ + {1, 2, 3}, + {4, 5, 6}, + }, + }, + { + Query: "with recursive cte(x) as (select 1 union all select x + 1 from cte) select * from cte limit 5;", + Expected: []sql.Row{ + {1}, + {2}, + {3}, + {4}, + {5}, + }, + }, + }, + }, +}