Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make sure no AST types are bare slices #17674

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions go/test/endtoend/vtgate/queries/random/query_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ func (sg *selectGenerator) randomSelect() {
}

// make sure we have at least one select expression
for isRandomExpr || len(sg.sel.SelectExprs) == 0 {
for isRandomExpr || len(sg.sel.SelectExprs.Exprs) == 0 {
// TODO: if the random expression is an int literal,
// TODO: and if the query is (potentially) an aggregate query,
// TODO: then we must group by the random expression,
Expand Down Expand Up @@ -395,7 +395,7 @@ func (sg *selectGenerator) createJoin(tables []tableT) {

// returns 1-3 random expressions based on the last two elements of tables
// tables should have at least two elements
func (sg *selectGenerator) createJoinPredicates(tables []tableT) sqlparser.Exprs {
func (sg *selectGenerator) createJoinPredicates(tables []tableT) []sqlparser.Expr {
if len(tables) < 2 {
log.Fatalf("tables has %d elements, needs at least 2", len(tables))
}
Expand Down Expand Up @@ -427,7 +427,7 @@ func (sg *selectGenerator) createGroupBy(tables []tableT) (grouping []column) {

// add to select
if rand.IntN(2) < 1 {
sg.sel.SelectExprs = append(sg.sel.SelectExprs, newAliasedColumn(col, ""))
sg.sel.AddSelectExpr(newAliasedColumn(col, ""))
grouping = append(grouping, col)
}
}
Expand All @@ -437,13 +437,13 @@ func (sg *selectGenerator) createGroupBy(tables []tableT) (grouping []column) {

// aliasGroupingColumns randomly aliases the grouping columns in the SelectExprs
func (sg *selectGenerator) aliasGroupingColumns(grouping []column) []column {
if len(grouping) != len(sg.sel.SelectExprs) {
log.Fatalf("grouping (length: %d) and sg.sel.SelectExprs (length: %d) should have the same length at this point", len(grouping), len(sg.sel.SelectExprs))
if len(grouping) != len(sg.sel.SelectExprs.Exprs) {
log.Fatalf("grouping (length: %d) and sg.sel.SelectExprs (length: %d) should have the same length at this point", len(grouping), len(sg.sel.SelectExprs.Exprs))
}

for i := range grouping {
if rand.IntN(2) < 1 {
if aliasedExpr, ok := sg.sel.SelectExprs[i].(*sqlparser.AliasedExpr); ok {
if aliasedExpr, ok := sg.sel.SelectExprs.Exprs[i].(*sqlparser.AliasedExpr); ok {
alias := fmt.Sprintf("cgroup%d", i)
aliasedExpr.SetAlias(alias)
grouping[i].name = alias
Expand All @@ -454,7 +454,7 @@ func (sg *selectGenerator) aliasGroupingColumns(grouping []column) []column {
return grouping
}

// returns the aggregation columns as three types: sqlparser.SelectExprs, []column
// returns the aggregation columns as three types: *sqlparser.SelectExprs, []column
func (sg *selectGenerator) createAggregations(tables []tableT) (aggregates []column) {
exprGenerators := slice.Map(tables, func(t tableT) sqlparser.ExprGenerator { return &t })
// add scalar subqueries
Expand Down Expand Up @@ -485,7 +485,7 @@ func (sg *selectGenerator) createOrderBy() {
}

// randomly order on SelectExprs
for _, selExpr := range sg.sel.SelectExprs {
for _, selExpr := range sg.sel.SelectExprs.Exprs {
if aliasedExpr, ok := selExpr.(*sqlparser.AliasedExpr); ok && rand.IntN(2) < 1 {
literal, ok := aliasedExpr.Expr.(*sqlparser.Literal)
isIntLiteral := ok && literal.Type == sqlparser.IntVal
Expand Down Expand Up @@ -527,7 +527,7 @@ func (sg *selectGenerator) createHavingPredicates(grouping []column) {
}

// returns between minExprs and maxExprs random expressions using generators
func (sg *selectGenerator) createRandomExprs(minExprs, maxExprs int, generators ...sqlparser.ExprGenerator) (predicates sqlparser.Exprs) {
func (sg *selectGenerator) createRandomExprs(minExprs, maxExprs int, generators ...sqlparser.ExprGenerator) (predicates []sqlparser.Expr) {
if minExprs > maxExprs {
log.Fatalf("minExprs is greater than maxExprs; minExprs: %d, maxExprs: %d\n", minExprs, maxExprs)
} else if maxExprs <= 0 {
Expand Down Expand Up @@ -578,28 +578,28 @@ func (sg *selectGenerator) randomlyAlias(expr sqlparser.Expr, alias string) colu
} else {
col.name = alias
}
sg.sel.SelectExprs = append(sg.sel.SelectExprs, sqlparser.NewAliasedExpr(expr, alias))
sg.sel.AddSelectExpr(sqlparser.NewAliasedExpr(expr, alias))

return col
}

// matchNumCols makes sure sg.sel.SelectExprs and newTable both have the same number of cols: sg.genConfig.NumCols
func (sg *selectGenerator) matchNumCols(tables []tableT, newTable tableT, canAggregate bool) tableT {
// remove SelectExprs and newTable.cols randomly until there are sg.genConfig.NumCols amount
for len(sg.sel.SelectExprs) > sg.genConfig.NumCols && sg.genConfig.NumCols > 0 {
for len(sg.sel.SelectExprs.Exprs) > sg.genConfig.NumCols && sg.genConfig.NumCols > 0 {
// select a random index and remove it from SelectExprs and newTable
idx := rand.IntN(len(sg.sel.SelectExprs))
idx := rand.IntN(len(sg.sel.SelectExprs.Exprs))

sg.sel.SelectExprs[idx] = sg.sel.SelectExprs[len(sg.sel.SelectExprs)-1]
sg.sel.SelectExprs = sg.sel.SelectExprs[:len(sg.sel.SelectExprs)-1]
sg.sel.SelectExprs.Exprs[idx] = sg.sel.SelectExprs.Exprs[len(sg.sel.SelectExprs.Exprs)-1]
sg.sel.SelectExprs.Exprs = sg.sel.SelectExprs.Exprs[:len(sg.sel.SelectExprs.Exprs)-1]

newTable.cols[idx] = newTable.cols[len(newTable.cols)-1]
newTable.cols = newTable.cols[:len(newTable.cols)-1]
}

// alternatively, add random expressions until there are sg.genConfig.NumCols amount
if sg.genConfig.NumCols > len(sg.sel.SelectExprs) {
diff := sg.genConfig.NumCols - len(sg.sel.SelectExprs)
if sg.genConfig.NumCols > len(sg.sel.SelectExprs.Exprs) {
diff := sg.genConfig.NumCols - len(sg.sel.SelectExprs.Exprs)
exprs := sg.createRandomExprs(diff, diff,
slice.Map(tables, func(t tableT) sqlparser.ExprGenerator { return &t })...)

Expand Down
50 changes: 48 additions & 2 deletions go/tools/astfmtgen/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,6 @@ func (r *Rewriter) rewriteAstPrintf(cursor *astutil.Cursor, expr *ast.CallExpr)

token := format[i]
switch token {
case 'c':
cursor.InsertBefore(r.rewriteLiteral(callexpr.X, "WriteByte", expr.Args[2+fieldnum]))
case 's':
cursor.InsertBefore(r.rewriteLiteral(callexpr.X, "WriteString", expr.Args[2+fieldnum]))
case 'l', 'r', 'v':
Expand Down Expand Up @@ -249,6 +247,50 @@ func (r *Rewriter) rewriteAstPrintf(cursor *astutil.Cursor, expr *ast.CallExpr)
Args: []ast.Expr{&ast.BasicLit{Value: `"%d"`, Kind: gotoken.STRING}, expr.Args[2+fieldnum]},
}
cursor.InsertBefore(r.rewriteLiteral(callexpr.X, "WriteString", call))
case 'n': // new directive for slices of AST nodes checked at code generation time
inputExpr := expr.Args[2+fieldnum]
inputType := r.pkg.TypesInfo.Types[inputExpr].Type
sliceType, ok := inputType.(*types.Slice)
if !ok {
panic("'%n' directive requires a slice")
}
if types.Implements(sliceType.Elem(), r.astExpr) {
// Fast path: input is []Expr
call := &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: callexpr.X,
Sel: &ast.Ident{Name: "formatExprs"},
},
Args: []ast.Expr{inputExpr},
}
cursor.InsertBefore(&ast.ExprStmt{X: call})
break
}
log.Printf("slow path for `n` directive with type %T", types.TypeString(inputType, noQualifier))
// Slow path: slice elements do not implement Expr
cursor.InsertBefore(&ast.ExprStmt{
X: &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: &ast.Ident{Name: "log"},
Sel: &ast.Ident{Name: "Printf"},
},
Args: []ast.Expr{
&ast.BasicLit{
Kind: gotoken.STRING,
Value: strconv.Quote("slow path for %n with type %T"),
},
inputExpr,
},
},
})
systay marked this conversation as resolved.
Show resolved Hide resolved
call := &ast.CallExpr{
Fun: &ast.SelectorExpr{
X: callexpr.X,
Sel: &ast.Ident{Name: "formatNodes"},
},
Args: []ast.Expr{inputExpr},
}
cursor.InsertBefore(&ast.ExprStmt{X: call})
default:
panic(fmt.Sprintf("unsupported escape %q", token))
}
Expand All @@ -259,3 +301,7 @@ func (r *Rewriter) rewriteAstPrintf(cursor *astutil.Cursor, expr *ast.CallExpr)
cursor.Delete()
return true
}

var noQualifier = func(p *types.Package) string {
return ""
}
11 changes: 1 addition & 10 deletions go/tools/asthelpergen/asthelpergen.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,6 @@ type (
}
)

// exprInterfacePath is the path of the sqlparser.Expr interface.
const exprInterfacePath = "vitess.io/vitess/go/vt/sqlparser.Expr"

func (gen *astHelperGen) iface() *types.Interface {
return gen._iface
}
Expand Down Expand Up @@ -207,19 +204,13 @@ func GenerateASTHelpers(options *Options) (map[string]*jen.File, error) {
return nil, err
}

exprType, _ := findTypeObject(exprInterfacePath, scopes)
var exprInterface *types.Interface
if exprType != nil {
exprInterface = exprType.Type().(*types.Named).Underlying().(*types.Interface)
}

nt := tt.Type().(*types.Named)
pName := nt.Obj().Pkg().Name()
generator := newGenerator(loaded[0].Module, loaded[0].TypesSizes, nt,
newEqualsGen(pName, &options.Equals),
newCloneGen(pName, &options.Clone),
newVisitGen(pName),
newRewriterGen(pName, types.TypeString(nt, noQualifier), exprInterface),
newRewriterGen(pName, types.TypeString(nt, noQualifier)),
newCOWGen(pName, nt),
)

Expand Down
19 changes: 19 additions & 0 deletions go/tools/asthelpergen/asthelpergen_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"strings"
"testing"

"vitess.io/vitess/go/tools/codegen"

"github.com/stretchr/testify/require"
)

Expand All @@ -45,3 +47,20 @@ func TestFullGeneration(t *testing.T) {
require.False(t, applyIdx == 0 && cloneIdx == 0, "file doesn't contain expected contents")
}
}

func TestRecreateAllFiles(t *testing.T) {
// t.Skip("This test recreates all files in the integration directory. It should only be run when the ASTHelperGen code has changed.")
result, err := GenerateASTHelpers(&Options{
Packages: []string{"./integration/..."},
RootInterface: "vitess.io/vitess/go/tools/asthelpergen/integration.AST",
Clone: CloneOptions{
Exclude: []string{"*NoCloneType"},
},
})
require.NoError(t, err)

for fullPath, file := range result {
err := codegen.SaveJenFile(fullPath, file)
require.NoError(t, err)
}
}
Loading
Loading