From bc94876a536ae7f50aca1d4e4862db1e9500e1cb Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Fri, 20 Jan 2023 09:11:19 +0100 Subject: [PATCH] Move more rewriting to SafeRewrite (#12063) * Update SafeRewrite to allow for replacing a node that we don't want to visit the children of Signed-off-by: Andres Taylor * Clean up ast_rewriting before refactoring it Signed-off-by: Andres Taylor * feat: fix usages of Rewrite in ast_rewriting Signed-off-by: Manan Gupta * feat: fix usages of Rewrite in normalizer Signed-off-by: Manan Gupta * safe rewrite having clause during horizon planning Signed-off-by: Florent Poinsard * move rewriteHavingAndOrderBy to safeRewrite Signed-off-by: Florent Poinsard * refactor: clean up code Signed-off-by: Andres Taylor Signed-off-by: Andres Taylor Signed-off-by: Manan Gupta Signed-off-by: Florent Poinsard Co-authored-by: Manan Gupta Co-authored-by: Florent Poinsard --- go/vt/sqlparser/ast_funcs.go | 16 +- go/vt/sqlparser/ast_rewriting.go | 298 ++++++++++-------- go/vt/sqlparser/normalizer.go | 71 +++-- go/vt/sqlparser/rewriter_api.go | 17 +- go/vt/vtgate/planbuilder/horizon_planning.go | 2 +- .../planbuilder/operator_transformers.go | 6 +- .../planbuilder/operators/queryprojection.go | 55 ++-- go/vt/vtgate/semantics/early_rewriter.go | 72 +++-- 8 files changed, 316 insertions(+), 221 deletions(-) diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 844445a8253..fcced205ddb 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -443,7 +443,7 @@ func NewWhere(typ WhereType, expr Expr) *Where { // and replaces it with to. If from matches root, // then to is returned. func ReplaceExpr(root, from, to Expr) Expr { - tmp := Rewrite(root, replaceExpr(from, to), nil) + tmp := SafeRewrite(root, stopWalking, replaceExpr(from, to)) expr, success := tmp.(Expr) if !success { @@ -454,16 +454,20 @@ func ReplaceExpr(root, from, to Expr) Expr { return expr } +func stopWalking(e SQLNode, _ SQLNode) bool { + switch e.(type) { + case *ExistsExpr, *Literal, *Subquery, *ValuesFuncExpr, *Default: + return false + default: + return true + } +} + func replaceExpr(from, to Expr) func(cursor *Cursor) bool { return func(cursor *Cursor) bool { if cursor.Node() == from { cursor.Replace(to) } - switch cursor.Node().(type) { - case *ExistsExpr, *Literal, *Subquery, *ValuesFuncExpr, *Default: - return false - } - return true } } diff --git a/go/vt/sqlparser/ast_rewriting.go b/go/vt/sqlparser/ast_rewriting.go index 6f997a11f6a..9a66202e7be 100644 --- a/go/vt/sqlparser/ast_rewriting.go +++ b/go/vt/sqlparser/ast_rewriting.go @@ -230,7 +230,10 @@ func RewriteAST( ) (*RewriteASTResult, error) { er := newASTRewriter(keyspace, selectLimit, setVarComment, sysVars, views) er.shouldRewriteDatabaseFunc = shouldRewriteDatabaseFunc(in) - result := Rewrite(in, er.rewrite, nil) + result := SafeRewrite(in, er.rewriteDown, er.rewriteUp) + if er.err != nil { + return nil, er.err + } out, ok := result.(Statement) if !ok { @@ -309,7 +312,7 @@ const ( func (er *astRewriter) rewriteAliasedExpr(node *AliasedExpr) (*BindVarNeeds, error) { inner := newASTRewriter(er.keyspace, er.selectLimit, er.setVarComment, er.sysVars, er.views) inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc - tmp := Rewrite(node.Expr, inner.rewrite, nil) + tmp := SafeRewrite(node.Expr, inner.rewriteDown, inner.rewriteUp) newExpr, ok := tmp.(Expr) if !ok { return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to rewrite AST. function expected to return Expr returned a %s", String(tmp)) @@ -318,7 +321,15 @@ func (er *astRewriter) rewriteAliasedExpr(node *AliasedExpr) (*BindVarNeeds, err return inner.bindVars, nil } -func (er *astRewriter) rewrite(cursor *Cursor) bool { +func (er *astRewriter) rewriteDown(node SQLNode, _ SQLNode) bool { + switch node := node.(type) { + case *Select: + er.visitSelect(node) + } + return true +} + +func (er *astRewriter) rewriteUp(cursor *Cursor) bool { // Add SET_VAR comment to this node if it supports it and is needed if supportOptimizerHint, supportsOptimizerHint := cursor.Node().(SupportOptimizerHint); supportsOptimizerHint && er.setVarComment != "" { newComments, err := supportOptimizerHint.GetParsedComments().AddQueryHint(er.setVarComment) @@ -330,118 +341,145 @@ func (er *astRewriter) rewrite(cursor *Cursor) bool { } switch node := cursor.Node().(type) { - case *Select: - for _, col := range node.SelectExprs { - _, hasStar := col.(*StarExpr) - if hasStar { - er.hasStarInSelect = true - } - - aliasedExpr, ok := col.(*AliasedExpr) - if ok && aliasedExpr.As.IsEmpty() { - buf := NewTrackedBuffer(nil) - aliasedExpr.Expr.Format(buf) - // select last_insert_id() -> select :__lastInsertId as `last_insert_id()` - innerBindVarNeeds, err := er.rewriteAliasedExpr(aliasedExpr) - if err != nil { - er.err = err - return false - } - if innerBindVarNeeds.HasRewrites() { - aliasedExpr.As = NewIdentifierCI(buf.String()) - } - er.bindVars.MergeWith(innerBindVarNeeds) - } - } - // set select limit if explicitly not set when sql_select_limit is set on the connection. - if er.selectLimit > 0 && node.Limit == nil { - node.Limit = &Limit{Rowcount: NewIntLiteral(strconv.Itoa(er.selectLimit))} - } case *Union: - // set select limit if explicitly not set when sql_select_limit is set on the connection. - if er.selectLimit > 0 && node.Limit == nil { - node.Limit = &Limit{Rowcount: NewIntLiteral(strconv.Itoa(er.selectLimit))} - } + er.rewriteUnion(node) case *FuncExpr: er.funcRewrite(cursor, node) case *Variable: - // Iff we are in SET, we want to change the scope of variables if a modifier has been set - // and only on the lhs of the assignment: - // set session sql_mode = @someElse - // here we need to change the scope of `sql_mode` and not of `@someElse` - if v, isSet := cursor.Parent().(*SetExpr); isSet && v.Var == node { - break - } - switch node.Scope { - case VariableScope: - er.udvRewrite(cursor, node) - case GlobalScope, SessionScope, NextTxScope: - er.sysVarRewrite(cursor, node) - } + er.rewriteVariable(cursor, node) case *Subquery: er.unnestSubQueries(cursor, node) case *NotExpr: - switch inner := node.Expr.(type) { - case *ComparisonExpr: - // not col = 42 => col != 42 - // not col > 42 => col <= 42 - // etc - canChange, inverse := inverseOp(inner.Operator) - if canChange { - inner.Operator = inverse - cursor.Replace(inner) - } - case *NotExpr: - // not not true => true - cursor.Replace(inner.Expr) - case BoolVal: - // not true => false - inner = !inner - cursor.Replace(inner) - } + er.rewriteNotExpr(cursor, node) case *AliasedTableExpr: - aliasTableName, ok := node.Expr.(TableName) - if !ok { - break - } - // Qualifier should not be added to dual table - tblName := aliasTableName.Name.String() - if tblName == "dual" { - break + er.rewriteAliasedTable(cursor, node) + case *ShowBasic: + er.rewriteShowBasic(node) + case *ExistsExpr: + er.existsRewrite(cursor, node) + } + return true +} + +func (er *astRewriter) rewriteUnion(node *Union) { + // set select limit if explicitly not set when sql_select_limit is set on the connection. + if er.selectLimit > 0 && node.Limit == nil { + node.Limit = &Limit{Rowcount: NewIntLiteral(strconv.Itoa(er.selectLimit))} + } +} + +func (er *astRewriter) rewriteAliasedTable(cursor *Cursor, node *AliasedTableExpr) { + aliasTableName, ok := node.Expr.(TableName) + if !ok { + return + } + + // Qualifier should not be added to dual table + tblName := aliasTableName.Name.String() + if tblName == "dual" { + return + } + + if SystemSchema(er.keyspace) { + if aliasTableName.Qualifier.IsEmpty() { + aliasTableName.Qualifier = NewIdentifierCS(er.keyspace) + node.Expr = aliasTableName + cursor.Replace(node) } - if SystemSchema(er.keyspace) { - if aliasTableName.Qualifier.IsEmpty() { - aliasTableName.Qualifier = NewIdentifierCS(er.keyspace) - node.Expr = aliasTableName - cursor.Replace(node) - } - break + return + } + + // Could we be dealing with a view? + if er.views == nil { + return + } + view := er.views.FindView(aliasTableName) + if view == nil { + return + } + + // Aha! It's a view. Let's replace it with a derived table + node.Expr = &DerivedTable{Select: CloneSelectStatement(view)} + if node.As.IsEmpty() { + node.As = NewIdentifierCS(tblName) + } +} + +func (er *astRewriter) rewriteShowBasic(node *ShowBasic) { + if node.Command == VariableGlobal || node.Command == VariableSession { + varsToAdd := sysvars.GetInterestingVariables() + for _, sysVar := range varsToAdd { + er.bindVars.AddSysVar(sysVar) } - if er.views == nil { - break + } +} + +func (er *astRewriter) rewriteNotExpr(cursor *Cursor, node *NotExpr) { + switch inner := node.Expr.(type) { + case *ComparisonExpr: + // not col = 42 => col != 42 + // not col > 42 => col <= 42 + // etc + canChange, inverse := inverseOp(inner.Operator) + if canChange { + inner.Operator = inverse + cursor.Replace(inner) } - view := er.views.FindView(aliasTableName) - if view == nil { - break + case *NotExpr: + // not not true => true + cursor.Replace(inner.Expr) + case BoolVal: + // not true => false + inner = !inner + cursor.Replace(inner) + } +} + +func (er *astRewriter) rewriteVariable(cursor *Cursor, node *Variable) { + // Iff we are in SET, we want to change the scope of variables if a modifier has been set + // and only on the lhs of the assignment: + // set session sql_mode = @someElse + // here we need to change the scope of `sql_mode` and not of `@someElse` + if v, isSet := cursor.Parent().(*SetExpr); isSet && v.Var == node { + return + } + switch node.Scope { + case VariableScope: + er.udvRewrite(cursor, node) + case GlobalScope, SessionScope, NextTxScope: + er.sysVarRewrite(cursor, node) + } +} + +func (er *astRewriter) visitSelect(node *Select) { + for _, col := range node.SelectExprs { + if _, hasStar := col.(*StarExpr); hasStar { + er.hasStarInSelect = true + continue } - node.Expr = &DerivedTable{ - Select: CloneSelectStatement(view), + aliasedExpr, ok := col.(*AliasedExpr) + if !ok || !aliasedExpr.As.IsEmpty() { + continue } - if node.As.IsEmpty() { - node.As = NewIdentifierCS(tblName) + buf := NewTrackedBuffer(nil) + aliasedExpr.Expr.Format(buf) + // select last_insert_id() -> select :__lastInsertId as `last_insert_id()` + innerBindVarNeeds, err := er.rewriteAliasedExpr(aliasedExpr) + if err != nil { + er.err = err + return } - case *ShowBasic: - if node.Command == VariableGlobal || node.Command == VariableSession { - varsToAdd := sysvars.GetInterestingVariables() - for _, sysVar := range varsToAdd { - er.bindVars.AddSysVar(sysVar) - } + if innerBindVarNeeds.HasRewrites() { + aliasedExpr.As = NewIdentifierCI(buf.String()) } - case *ExistsExpr: - er.existsRewrite(cursor, node) + er.bindVars.MergeWith(innerBindVarNeeds) + + } + // set select limit if explicitly not set when sql_select_limit is set on the connection. + if er.selectLimit > 0 && node.Limit == nil { + node.Limit = &Limit{Rowcount: NewIntLiteral(strconv.Itoa(er.selectLimit))} } - return true } func inverseOp(i ComparisonExprOperator) (bool, ComparisonExprOperator) { @@ -527,17 +565,15 @@ var funcRewrites = map[string]string{ func (er *astRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) { bindVar, found := funcRewrites[node.Name.Lowered()] - if found { - if bindVar == DBVarName && !er.shouldRewriteDatabaseFunc { - return - } - if len(node.Exprs) > 0 { - er.err = vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Argument to %s() not supported", node.Name.Lowered()) - return - } - cursor.Replace(bindVarExpression(bindVar)) - er.bindVars.AddFuncResult(bindVar) + if !found || (bindVar == DBVarName && !er.shouldRewriteDatabaseFunc) { + return + } + if len(node.Exprs) > 0 { + er.err = vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Argument to %s() not supported", node.Name.Lowered()) + return } + cursor.Replace(bindVarExpression(bindVar)) + er.bindVars.AddFuncResult(bindVar) } func (er *astRewriter) unnestSubQueries(cursor *Cursor, subquery *Subquery) { @@ -582,7 +618,7 @@ func (er *astRewriter) unnestSubQueries(cursor *Cursor, subquery *Subquery) { er.bindVars.NoteRewrite() // we need to make sure that the inner expression also gets rewritten, // so we fire off another rewriter traversal here - rewritten := Rewrite(expr.Expr, er.rewrite, nil) + rewritten := SafeRewrite(expr.Expr, er.rewriteDown, er.rewriteUp) // Here we need to handle the subquery rewrite in case in occurs in an IN clause // For example, SELECT id FROM user WHERE id IN (SELECT 1 FROM DUAL) @@ -604,31 +640,33 @@ func (er *astRewriter) unnestSubQueries(cursor *Cursor, subquery *Subquery) { } func (er *astRewriter) existsRewrite(cursor *Cursor, node *ExistsExpr) { - switch node := node.Subquery.Select.(type) { - case *Select: - if node.Limit == nil { - node.Limit = &Limit{} - } - node.Limit.Rowcount = NewIntLiteral("1") + sel, ok := node.Subquery.Select.(*Select) + if !ok { + return + } - if node.Having != nil { - // If the query has HAVING, we can't take any shortcuts - return - } + if sel.Limit == nil { + sel.Limit = &Limit{} + } + sel.Limit.Rowcount = NewIntLiteral("1") - if len(node.GroupBy) == 0 && node.SelectExprs.AllAggregation() { - // in these situations, we are guaranteed to always get a non-empty result, - // so we can replace the EXISTS with a literal true - cursor.Replace(BoolVal(true)) - } + if sel.Having != nil { + // If the query has HAVING, we can't take any shortcuts + return + } - // If we are not doing HAVING, we can safely replace all select expressions with a - // single `1` and remove any grouping - node.SelectExprs = SelectExprs{ - &AliasedExpr{Expr: NewIntLiteral("1")}, - } - node.GroupBy = nil + if len(sel.GroupBy) == 0 && sel.SelectExprs.AllAggregation() { + // in these situations, we are guaranteed to always get a non-empty result, + // so we can replace the EXISTS with a literal true + cursor.Replace(BoolVal(true)) + } + + // If we are not doing HAVING, we can safely replace all select expressions with a + // single `1` and remove any grouping + sel.SelectExprs = SelectExprs{ + &AliasedExpr{Expr: NewIntLiteral("1")}, } + sel.GroupBy = nil } func bindVarExpression(name string) Expr { diff --git a/go/vt/sqlparser/normalizer.go b/go/vt/sqlparser/normalizer.go index b827cf853c7..6d7ee526bf6 100644 --- a/go/vt/sqlparser/normalizer.go +++ b/go/vt/sqlparser/normalizer.go @@ -37,7 +37,7 @@ type BindVars map[string]struct{} // treated as distinct. func Normalize(stmt Statement, reserved *ReservedVars, bindVars map[string]*querypb.BindVariable) error { nz := newNormalizer(reserved, bindVars) - _ = Rewrite(stmt, nz.WalkStatement, nil) + _ = SafeRewrite(stmt, nz.walkStatementDown, nz.walkStatementUp) return nz.err } @@ -57,24 +57,35 @@ func newNormalizer(reserved *ReservedVars, bindVars map[string]*querypb.BindVari } } -// WalkStatement is the top level walk function. +// walkStatementUp is one half of the top level walk function. +func (nz *normalizer) walkStatementUp(cursor *Cursor) bool { + if nz.err != nil { + return false + } + node, isLiteral := cursor.Node().(*Literal) + if !isLiteral { + return true + } + nz.convertLiteral(node, cursor) + return nz.err == nil // only continue if we haven't found any errors +} + +// walkStatementDown is the top level walk function. // If it encounters a Select, it switches to a mode // where variables are deduped. -func (nz *normalizer) WalkStatement(cursor *Cursor) bool { - switch node := cursor.Node().(type) { +func (nz *normalizer) walkStatementDown(node, parent SQLNode) bool { + switch node := node.(type) { // no need to normalize the statement types case *Set, *Show, *Begin, *Commit, *Rollback, *Savepoint, DDLStatement, *SRollback, *Release, *OtherAdmin, *OtherRead: return false case *Select: - _, isDerived := cursor.Parent().(*DerivedTable) + _, isDerived := parent.(*DerivedTable) var tmp bool tmp, nz.inDerived = nz.inDerived, isDerived - _ = Rewrite(node, nz.WalkSelect, nil) + _ = SafeRewrite(node, nz.walkDownSelect, nz.walkUpSelect) // Don't continue nz.inDerived = tmp return false - case *Literal: - nz.convertLiteral(node, cursor) case *ComparisonExpr: nz.convertComparison(node) case *UpdateExpr: @@ -89,32 +100,25 @@ func (nz *normalizer) WalkStatement(cursor *Cursor) bool { return nz.err == nil // only continue if we haven't found any errors } -// WalkSelect normalizes the AST in Select mode. -func (nz *normalizer) WalkSelect(cursor *Cursor) bool { - switch node := cursor.Node().(type) { +// walkDownSelect normalizes the AST in Select mode. +func (nz *normalizer) walkDownSelect(node, parent SQLNode) bool { + switch node := node.(type) { case *Select: - _, isDerived := cursor.Parent().(*DerivedTable) + _, isDerived := parent.(*DerivedTable) if !isDerived { return true } var tmp bool tmp, nz.inDerived = nz.inDerived, isDerived - _ = Rewrite(node, nz.WalkSelect, nil) + // initiating a new AST walk here means that we might change something while walking down on the tree, + // but since we are only changing literals, we can be safe that we are not changing the SELECT struct, + // only something much further down, and that should be safe + _ = SafeRewrite(node, nz.walkDownSelect, nz.walkUpSelect) // Don't continue nz.inDerived = tmp return false case SelectExprs: return !nz.inDerived - case *Literal: - parent := cursor.Parent() - switch parent.(type) { - case *Order, GroupBy: - return false - case *Limit: - nz.convertLiteral(node, cursor) - default: - nz.convertLiteralDedup(node, cursor) - } case *ComparisonExpr: nz.convertComparison(node) case *FramePoint: @@ -131,6 +135,27 @@ func (nz *normalizer) WalkSelect(cursor *Cursor) bool { return nz.err == nil // only continue if we haven't found any errors } +// walkUpSelect normalizes the Literals in Select mode. +func (nz *normalizer) walkUpSelect(cursor *Cursor) bool { + if nz.err != nil { + return false + } + node, isLiteral := cursor.Node().(*Literal) + if !isLiteral { + return true + } + parent := cursor.Parent() + switch parent.(type) { + case *Order, GroupBy: + return false + case *Limit: + nz.convertLiteral(node, cursor) + default: + nz.convertLiteralDedup(node, cursor) + } + return nz.err == nil // only continue if we haven't found any errors +} + func validateLiteral(node *Literal) (err error) { switch node.Type { case DateVal: diff --git a/go/vt/sqlparser/rewriter_api.go b/go/vt/sqlparser/rewriter_api.go index 7f4a3a3ff82..05d371bad13 100644 --- a/go/vt/sqlparser/rewriter_api.go +++ b/go/vt/sqlparser/rewriter_api.go @@ -53,11 +53,22 @@ func Rewrite(node SQLNode, pre, post ApplyFunc) (result SQLNode) { // SafeRewrite does not allow replacing nodes on the down walk of the tree walking // Long term this is the only Rewrite functionality we want -func SafeRewrite(node SQLNode, down func(SQLNode) bool, up ApplyFunc) (result SQLNode) { +func SafeRewrite( + node SQLNode, + shouldVisitChildren func(node SQLNode, parent SQLNode) bool, + up ApplyFunc, +) SQLNode { var pre func(cursor *Cursor) bool - if down != nil { + if shouldVisitChildren != nil { pre = func(cursor *Cursor) bool { - return down(cursor.Node()) + visitChildren := shouldVisitChildren(cursor.Node(), cursor.Parent()) + if !visitChildren && up != nil { + // this gives the up-function a chance to do work on this node even if we are not visiting the children + // unfortunately, if the `up` function also returns false for this node, we won't abort the rest of the + // tree walking. This is a temporary limitation, and will be fixed when we generated the correct code + up(cursor) + } + return visitChildren } } return Rewrite(node, pre, up) diff --git a/go/vt/vtgate/planbuilder/horizon_planning.go b/go/vt/vtgate/planbuilder/horizon_planning.go index 47e22b72598..29e832faf6d 100644 --- a/go/vt/vtgate/planbuilder/horizon_planning.go +++ b/go/vt/vtgate/planbuilder/horizon_planning.go @@ -267,7 +267,7 @@ func (hp *horizonPlanning) planAggrUsingOA( if hp.sel.Having != nil { rewriter := hp.qp.AggrRewriter(ctx) - sqlparser.Rewrite(hp.sel.Having.Expr, rewriter.Rewrite(), nil) + sqlparser.SafeRewrite(hp.sel.Having.Expr, rewriter.RewriteDown(), rewriter.RewriteUp()) if rewriter.Err != nil { return nil, rewriter.Err } diff --git a/go/vt/vtgate/planbuilder/operator_transformers.go b/go/vt/vtgate/planbuilder/operator_transformers.go index 0806651d6db..9dc2dcd6e0f 100644 --- a/go/vt/vtgate/planbuilder/operator_transformers.go +++ b/go/vt/vtgate/planbuilder/operator_transformers.go @@ -295,11 +295,11 @@ func replaceSubQuery(ctx *plancontext.PlanningContext, sel sqlparser.Statement) return } sqr := &subQReplacer{subqueryToReplace: extractedSubqueries} - sqlparser.Rewrite(sel, sqr.replacer, nil) + sqlparser.SafeRewrite(sel, nil, sqr.replacer) for sqr.replaced { // to handle subqueries inside subqueries, we need to do this again and again until no replacements are left sqr.replaced = false - sqlparser.Rewrite(sel, sqr.replacer, nil) + sqlparser.SafeRewrite(sel, nil, sqr.replacer) } } @@ -646,7 +646,7 @@ func (sqr *subQReplacer) replacer(cursor *sqlparser.Cursor) bool { if ext.GetArgName() == replaceByExpr.GetArgName() { cursor.Replace(ext.Original) sqr.replaced = true - return false + return true } } return true diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index 835fcbe0f37..29e356c6650 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -185,36 +185,49 @@ func CreateQPFromSelect(ctx *plancontext.PlanningContext, sel *sqlparser.Select) return qp, nil } -// Rewrite will go through an expression, add aggregations to the QP, and rewrite them to use column offset -func (ar *AggrRewriter) Rewrite() func(*sqlparser.Cursor) bool { +// RewriteDown stops the walker from entering inside aggregation functions +func (ar *AggrRewriter) RewriteDown() func(sqlparser.SQLNode, sqlparser.SQLNode) bool { + return func(node, _ sqlparser.SQLNode) bool { + if ar.Err != nil { + return true + } + _, ok := node.(sqlparser.AggrFunc) + return !ok + } +} + +// RewriteUp will go through an expression, add aggregations to the QP, and rewrite them to use column offset +func (ar *AggrRewriter) RewriteUp() func(*sqlparser.Cursor) bool { return func(cursor *sqlparser.Cursor) bool { if ar.Err != nil { return false } sqlNode := cursor.Node() - if fExp, ok := sqlNode.(sqlparser.AggrFunc); ok { - for offset, expr := range ar.qp.SelectExprs { - ae, err := expr.GetAliasedExpr() - if err != nil { - ar.Err = err - return false - } - if ar.st.EqualsExpr(ae.Expr, fExp) { - cursor.Replace(sqlparser.NewOffset(offset, fExp)) - return false // no need to visit aggregation children - } + fExp, ok := sqlNode.(sqlparser.AggrFunc) + if !ok { + return true + } + for offset, expr := range ar.qp.SelectExprs { + ae, err := expr.GetAliasedExpr() + if err != nil { + ar.Err = err + return false } - - col := SelectExpr{ - Aggr: true, - Col: &sqlparser.AliasedExpr{Expr: fExp}, + if ar.st.EqualsExpr(ae.Expr, fExp) { + cursor.Replace(sqlparser.NewOffset(offset, fExp)) + return true } - ar.qp.HasAggr = true + } - cursor.Replace(sqlparser.NewOffset(len(ar.qp.SelectExprs), fExp)) - ar.qp.SelectExprs = append(ar.qp.SelectExprs, col) - ar.qp.AddedColumn++ + col := SelectExpr{ + Aggr: true, + Col: &sqlparser.AliasedExpr{Expr: fExp}, } + ar.qp.HasAggr = true + + cursor.Replace(sqlparser.NewOffset(len(ar.qp.SelectExprs), fExp)) + ar.qp.SelectExprs = append(ar.qp.SelectExprs, col) + ar.qp.AddedColumn++ return true } diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index ded398e8b21..e575f78e6ee 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -40,7 +40,7 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { if node.Type != sqlparser.HavingClause { return nil } - rewriteHavingAndOrderBy(cursor, node) + rewriteHavingAndOrderBy(node, cursor.Parent()) case sqlparser.SelectExprs: _, isSel := cursor.Parent().(*sqlparser.Select) if !isSel { @@ -57,7 +57,7 @@ func (r *earlyRewriter) down(cursor *sqlparser.Cursor) error { } case sqlparser.OrderBy: r.clause = "order clause" - rewriteHavingAndOrderBy(cursor, node) + rewriteHavingAndOrderBy(node, cursor.Parent()) case *sqlparser.OrExpr: newNode := rewriteOrFalse(*node) if newNode != nil { @@ -143,48 +143,52 @@ func (r *earlyRewriter) expandStar(cursor *sqlparser.Cursor, node sqlparser.Sele // HAVING/ORDER BY clause is inside an aggregation function // // This is a fucking weird scoping rule, but it's what MySQL seems to do... ¯\_(ツ)_/¯ -func rewriteHavingAndOrderBy(cursor *sqlparser.Cursor, node sqlparser.SQLNode) { - sel, isSel := cursor.Parent().(*sqlparser.Select) +func rewriteHavingAndOrderBy(node, parent sqlparser.SQLNode) { + // TODO - clean up and comment this mess + sel, isSel := parent.(*sqlparser.Select) if !isSel { return } - sqlparser.Rewrite(node, func(inner *sqlparser.Cursor) bool { - switch col := inner.Node().(type) { - case *sqlparser.Subquery: - return false - case *sqlparser.ColName: - if !col.Qualifier.IsEmpty() { + + sqlparser.SafeRewrite(node, func(node, _ sqlparser.SQLNode) bool { + _, isSubQ := node.(*sqlparser.Subquery) + return !isSubQ + }, func(cursor *sqlparser.Cursor) bool { + col, ok := cursor.Node().(*sqlparser.ColName) + if !ok { + return true + } + if !col.Qualifier.IsEmpty() { + return true + } + _, parentIsAggr := cursor.Parent().(sqlparser.AggrFunc) + for _, e := range sel.SelectExprs { + ae, ok := e.(*sqlparser.AliasedExpr) + if !ok || !ae.As.Equal(col.Name) { + continue + } + _, aliasPointsToAggr := ae.Expr.(sqlparser.AggrFunc) + if parentIsAggr && aliasPointsToAggr { return false } - _, parentIsAggr := inner.Parent().(sqlparser.AggrFunc) - for _, e := range sel.SelectExprs { - ae, ok := e.(*sqlparser.AliasedExpr) - if !ok || !ae.As.Equal(col.Name) { - continue - } - _, aliasPointsToAggr := ae.Expr.(sqlparser.AggrFunc) - if parentIsAggr && aliasPointsToAggr { - return false - } - safeToRewrite := true - _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { - switch node.(type) { - case *sqlparser.ColName: - safeToRewrite = false - return false, nil - case sqlparser.AggrFunc: - return false, nil - } - return true, nil - }, ae.Expr) - if safeToRewrite { - inner.Replace(ae.Expr) + safeToRewrite := true + _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { + switch node.(type) { + case *sqlparser.ColName: + safeToRewrite = false + return false, nil + case sqlparser.AggrFunc: + return false, nil } + return true, nil + }, ae.Expr) + if safeToRewrite { + cursor.Replace(ae.Expr) } } return true - }, nil) + }) } func (r *earlyRewriter) rewriteOrderByExpr(node *sqlparser.Literal) (sqlparser.Expr, error) {