Skip to content

Commit

Permalink
planbuilder now use the new type
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <[email protected]>
  • Loading branch information
systay committed Feb 3, 2025
1 parent 4e4c29c commit 8c3064e
Show file tree
Hide file tree
Showing 12 changed files with 65 additions and 35 deletions.
12 changes: 11 additions & 1 deletion go/vt/sqlparser/ast_funcs.go
Original file line number Diff line number Diff line change
Expand Up @@ -1195,10 +1195,20 @@ func compliantName(in string) string {
return buf.String()
}

func (node *Select) AddSelectExprs(selectExprs SelectExprs2) {
func (node *Select) AddSelectExprs(selectExprs *SelectExprs2) {
if node.SelectExprs == nil {
node.SelectExprs = &SelectExprs2{}
}
node.SelectExprs.Exprs = append(node.SelectExprs.Exprs, selectExprs.Exprs...)
}

func (node *Select) AddSelectExpr(expr SelectExpr) {
if node.SelectExprs == nil {
node.SelectExprs = &SelectExprs2{}
}
node.SelectExprs.Exprs = append(node.SelectExprs.Exprs, expr)
}

// AddOrder adds an order by element
func (node *Select) AddOrder(order *Order) {
node.OrderBy = append(node.OrderBy, order)
Expand Down
4 changes: 2 additions & 2 deletions go/vt/vtgate/planbuilder/ddl.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ func buildCreateViewCommon(
}

// because we don't trust the schema tracker to have up-to-date info, we don't want to expand any SELECT * here
var expressions []sqlparser.SelectExprs
var expressions []*sqlparser.SelectExprs2
_ = sqlparser.VisitAllSelects(ddlSelect, func(p *sqlparser.Select, idx int) error {
expressions = append(expressions, sqlparser.Clone(p.SelectExprs))
return nil
Expand Down Expand Up @@ -252,7 +252,7 @@ func createViewEnabled(vschema plancontext.VSchema, reservedVars *sqlparser.Rese

// views definition with `select *` should not be expanded as schema tracker might not be up-to-date
// We copy the expressions and restore them after the planning context is created
var expressions []sqlparser.SelectExprs
var expressions []*sqlparser.SelectExprs2
_ = sqlparser.VisitAllSelects(ddlSelect, func(p *sqlparser.Select, idx int) error {
expressions = append(expressions, sqlparser.Clone(p.SelectExprs))
return nil
Expand Down
12 changes: 6 additions & 6 deletions go/vt/vtgate/planbuilder/expression_converter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ package planbuilder
import (
"testing"

"vitess.io/vitess/go/slice"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/vt/sqlparser"
Expand Down Expand Up @@ -66,10 +68,8 @@ func TestConversion(t *testing.T) {
}
}

func extract(in sqlparser.SelectExprs) []sqlparser.Expr {
var result []sqlparser.Expr
for _, expr := range in {
result = append(result, expr.(*sqlparser.AliasedExpr).Expr)
}
return result
func extract(in *sqlparser.SelectExprs2) []sqlparser.Expr {
return slice.Map(in.Exprs, func(i sqlparser.SelectExpr) sqlparser.Expr {
return i.(*sqlparser.AliasedExpr).Expr
})
}
3 changes: 2 additions & 1 deletion go/vt/vtgate/planbuilder/operator_transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -688,9 +688,10 @@ func autoIncGenerate(gen *operators.Generate) *engine.Generate {
if gen == nil {
return nil
}
exprs := &sqlparser.SelectExprs2{Exprs: []sqlparser.SelectExpr{&sqlparser.Nextval{Expr: &sqlparser.Argument{Name: "n", Type: sqltypes.Int64}}}}
selNext := &sqlparser.Select{
From: []sqlparser.TableExpr{&sqlparser.AliasedTableExpr{Expr: gen.TableName}},
SelectExprs: sqlparser.SelectExprs{&sqlparser.Nextval{Expr: &sqlparser.Argument{Name: "n", Type: sqltypes.Int64}}},
SelectExprs: exprs,
}
return &engine.Generate{
Keyspace: gen.Keyspace,
Expand Down
10 changes: 9 additions & 1 deletion go/vt/vtgate/planbuilder/operators/SQL_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ func (qb *queryBuilder) setWithRollup() {
func (qb *queryBuilder) addProjection(projection sqlparser.SelectExpr) {
switch stmt := qb.stmt.(type) {
case *sqlparser.Select:
if stmt.SelectExprs == nil {
stmt.SelectExprs = &sqlparser.SelectExprs2{}
}
stmt.SelectExprs.Exprs = append(stmt.SelectExprs.Exprs, projection)
return
case *sqlparser.Union:
Expand Down Expand Up @@ -284,7 +287,12 @@ func (qb *queryBuilder) joinWith(other *queryBuilder, onCondition sqlparser.Expr

if sel, isSel := stmt.(*sqlparser.Select); isSel {
otherSel := otherStmt.(*sqlparser.Select)
sel.SelectExprs.Exprs = append(sel.SelectExprs.Exprs, otherSel.SelectExprs.Exprs...)
if sel.SelectExprs == nil {
sel.SelectExprs = &sqlparser.SelectExprs2{}
}
if otherSel.SelectExprs != nil {
sel.SelectExprs.Exprs = append(sel.SelectExprs.Exprs, otherSel.SelectExprs.Exprs...)
}
}

qb.mergeWhereClauses(stmt, otherStmt)
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ func createDeleteWithInputOp(ctx *plancontext.PlanningContext, del *sqlparser.De
dmls := slice.Map(delOps, func(from dmlOp) Operator {
colsList = append(colsList, from.cols)
for _, col := range from.cols {
selectStmt.SelectExprs.Exprs = append(selectStmt.SelectExprs.Exprs, aeWrap(col))
selectStmt.AddSelectExpr(aeWrap(col))
}
return from.op
})
Expand Down
31 changes: 22 additions & 9 deletions go/vt/vtgate/planbuilder/operators/info_schema_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package operators

import (
"maps"
"slices"
"strings"

"vitess.io/vitess/go/mysql/collations"
Expand All @@ -34,14 +35,14 @@ import (
// They are special because we usually don't know at plan-time
// what keyspace the query go to, because we don't see normalized literal values
type InfoSchemaRouting struct {
SysTableTableSchema *sqlparser.Exprs
SysTableTableSchema []sqlparser.Expr
SysTableTableName map[string]sqlparser.Expr
Table *QueryTable
}

func (isr *InfoSchemaRouting) UpdateRoutingParams(ctx *plancontext.PlanningContext, rp *engine.RoutingParameters) {
rp.SysTableTableSchema = nil
for _, expr := range isr.SysTableTableSchema.Exprs {
for _, expr := range isr.SysTableTableSchema {
eexpr, err := evalengine.Translate(expr, &evalengine.Config{
Collation: collations.SystemCollation.Collation,
ResolveColumn: NotImplementedSchemaInfoResolver,
Expand Down Expand Up @@ -70,7 +71,7 @@ func (isr *InfoSchemaRouting) UpdateRoutingParams(ctx *plancontext.PlanningConte

func (isr *InfoSchemaRouting) Clone() Routing {
return &InfoSchemaRouting{
SysTableTableSchema: sqlparser.Clone(isr.SysTableTableSchema),
SysTableTableSchema: slices.Clone(isr.SysTableTableSchema),
SysTableTableName: maps.Clone(isr.SysTableTableName),
Table: isr.Table,
}
Expand All @@ -87,14 +88,14 @@ func (isr *InfoSchemaRouting) updateRoutingLogic(ctx *plancontext.PlanningContex
}

if isTableSchema {
for _, s := range isr.SysTableTableSchema.Exprs {
for _, s := range isr.SysTableTableSchema {
if sqlparser.Equals.Expr(out, s) {
// we already have this expression in the list
// stating it again does not add value
return isr
}
}
isr.SysTableTableSchema.Exprs = append(isr.SysTableTableSchema.Exprs, out)
isr.SysTableTableSchema = append(isr.SysTableTableSchema, out)
} else {
isr.SysTableTableName[bvName] = out
}
Expand Down Expand Up @@ -173,8 +174,8 @@ func tryMergeInfoSchemaRoutings(ctx *plancontext.PlanningContext, routingA, rout
// we have already checked type earlier, so this should always be safe
isrA := routingA.(*InfoSchemaRouting)
isrB := routingB.(*InfoSchemaRouting)
emptyA := len(isrA.SysTableTableName) == 0 && len(isrA.SysTableTableSchema.Exprs) == 0
emptyB := len(isrB.SysTableTableName) == 0 && len(isrB.SysTableTableSchema.Exprs) == 0
emptyA := len(isrA.SysTableTableName) == 0 && len(isrA.SysTableTableSchema) == 0
emptyB := len(isrB.SysTableTableName) == 0 && len(isrB.SysTableTableSchema) == 0

switch {
// if either side has no predicates to help us route, we can merge them
Expand All @@ -184,7 +185,7 @@ func tryMergeInfoSchemaRoutings(ctx *plancontext.PlanningContext, routingA, rout
return m.merge(ctx, lhsRoute, rhsRoute, isrA)

// if we have no schema predicates on either side, we can merge if the table info is the same
case len(isrA.SysTableTableSchema.Exprs) == 0 && len(isrB.SysTableTableSchema.Exprs) == 0:
case len(isrA.SysTableTableSchema) == 0 && len(isrB.SysTableTableSchema) == 0:
for k, expr := range isrB.SysTableTableName {
if e, found := isrA.SysTableTableName[k]; found && !sqlparser.Equals.Expr(expr, e) {
// schema names are the same, but we have contradicting table names, so we give up
Expand All @@ -195,7 +196,7 @@ func tryMergeInfoSchemaRoutings(ctx *plancontext.PlanningContext, routingA, rout
return m.merge(ctx, lhsRoute, rhsRoute, isrA)

// if both sides have the same schema predicate, we can safely merge them
case sqlparser.Equals.RefOfExprs(isrA.SysTableTableSchema, isrB.SysTableTableSchema):
case equalExprs(isrA.SysTableTableSchema, isrB.SysTableTableSchema):
for k, expr := range isrB.SysTableTableName {
isrA.SysTableTableName[k] = expr
}
Expand All @@ -207,6 +208,18 @@ func tryMergeInfoSchemaRoutings(ctx *plancontext.PlanningContext, routingA, rout
}
}

func equalExprs(a, b []sqlparser.Expr) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if !sqlparser.Equals.Expr(a[i], b[i]) {
return false
}
}
return true
}

var (
// these are filled in by the init() function below
schemaColumns57 = map[string]any{}
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ func createUpdateWithInputOp(ctx *plancontext.PlanningContext, upd *sqlparser.Up
colsList = append(colsList, from.cols)
uList = append(uList, from.updList)
for _, col := range from.cols {
selectStmt.SelectExprs.Exprs = append(selectStmt.SelectExprs.Exprs, aeWrap(col))
selectStmt.AddSelectExpr(aeWrap(col))
}
return from.op
})
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/planner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,5 +85,5 @@ func TestBindingSubquery(t *testing.T) {
}

func extractExpr(in *sqlparser.Select, idx int) sqlparser.Expr {
return in.SelectExprs[idx].(*sqlparser.AliasedExpr).Expr
return in.SelectExprs.Exprs[idx].(*sqlparser.AliasedExpr).Expr
}
8 changes: 3 additions & 5 deletions go/vt/vtgate/planbuilder/route.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,9 @@ func prepareTheAST(sel sqlparser.SelectStatement) {
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (bool, error) {
switch node := node.(type) {
case *sqlparser.Select:
if len(node.SelectExprs) == 0 {
node.SelectExprs = []sqlparser.SelectExpr{
&sqlparser.AliasedExpr{
Expr: sqlparser.NewIntLiteral("1"),
},
if node.SelectExprs == nil || len(node.SelectExprs.Exprs) == 0 {
node.SelectExprs = &sqlparser.SelectExprs2{
Exprs: []sqlparser.SelectExpr{&sqlparser.AliasedExpr{Expr: sqlparser.NewIntLiteral("1")}},
}
}
case *sqlparser.ComparisonExpr:
Expand Down
12 changes: 6 additions & 6 deletions go/vt/vtgate/planbuilder/select.go
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,9 @@ func buildSQLCalcFoundRowsPlan(
sel2.OrderBy = nil
sel2.Limit = nil

countStartExpr := []sqlparser.SelectExpr{&sqlparser.AliasedExpr{
Expr: &sqlparser.CountStar{},
}}
countStartExpr := &sqlparser.SelectExprs2{
Exprs: []sqlparser.SelectExpr{&sqlparser.AliasedExpr{Expr: &sqlparser.CountStar{}}},
}
if sel2.GroupBy == nil && sel2.Having == nil {
// if there is no grouping, we can use the same query and
// just replace the SELECT sub-clause to have a single count(*)
Expand Down Expand Up @@ -292,10 +292,10 @@ func handleDualSelects(sel *sqlparser.Select, vschema plancontext.VSchema) (engi
return nil, nil
}

exprs := make([]evalengine.Expr, len(sel.SelectExprs))
cols := make([]string, len(sel.SelectExprs))
exprs := make([]evalengine.Expr, len(sel.SelectExprs.Exprs))
cols := make([]string, len(sel.SelectExprs.Exprs))
var lockFunctions []*engine.LockFunc
for i, e := range sel.SelectExprs {
for i, e := range sel.SelectExprs.Exprs {
expr, ok := e.(*sqlparser.AliasedExpr)
if !ok {
return nil, nil
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/system_variables.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func (pc *sysvarPlanCache) parseAndBuildDefaultValue(sysvar sysvars.SystemVariab
panic(fmt.Sprintf("bug in set plan init - default value for %s not parsable: %s", sysvar.Name, sysvar.Default))
}
sel := stmt.(*sqlparser.Select)
aliasedExpr := sel.SelectExprs[0].(*sqlparser.AliasedExpr)
aliasedExpr := sel.SelectExprs.Exprs[0].(*sqlparser.AliasedExpr)
def, err := evalengine.Translate(aliasedExpr.Expr, &evalengine.Config{
Collation: pc.env.CollationEnv().DefaultConnectionCharset(),
Environment: pc.env,
Expand Down

0 comments on commit 8c3064e

Please sign in to comment.