Skip to content

Commit

Permalink
Merge pull request #2861 from dolthub/zachmu/quote-identifiers
Browse files Browse the repository at this point in the history
Added configurable options for how to quote identifiers in column defaults
  • Loading branch information
zachmu authored Feb 27, 2025
2 parents 5b478d4 + a65e489 commit b9d44c4
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 24 deletions.
2 changes: 1 addition & 1 deletion memory/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ func stripTblNames(e sql.Expression) (sql.Expression, transform.TreeIdentity, er
case *expression.GetField:
// strip table names
ne := expression.NewGetField(e.Index(), e.Type(), e.Name(), e.IsNullable())
ne = ne.WithBackTickNames(e.IsBackTickNames())
ne = ne.WithQuotedNames(sql.GlobalParser, e.IsQuotedIdentifier())
return ne, transform.NewTree, nil
default:
}
Expand Down
3 changes: 3 additions & 0 deletions sql/analyzer/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ func (ab *Builder) Build() *Analyzer {
Catalog: NewCatalog(ab.provider),
Coster: memo.NewDefaultCoster(),
ExecBuilder: rowexec.DefaultBuilder,
Parser: sql.GlobalParser,
}
}

Expand All @@ -288,6 +289,8 @@ type Analyzer struct {
ExecBuilder sql.NodeExecBuilder
// Runner represents the engine, which is represented as a separate interface to work around circular dependencies
Runner StatementRunner
// Parser is the parser used to parse SQL statements.
Parser sql.Parser
}

// NewDefault creates a default Analyzer instance with all default Rules and configuration.
Expand Down
14 changes: 7 additions & 7 deletions sql/analyzer/resolve_column_defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -307,15 +307,15 @@ func stripTableNamesFromDefault(e *expression.Wrapper) (sql.Expression, transfor
return expression.WrapExpression(&nd), transform.NewTree, nil
}

func backtickDefaultColumnValueNames(ctx *sql.Context, _ *Analyzer, n sql.Node, _ *plan.Scope, _ RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
span, ctx := ctx.Span("backtickDefaultColumnValueNames")
func quoteDefaultColumnValueNames(ctx *sql.Context, a *Analyzer, n sql.Node, _ *plan.Scope, _ RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
span, ctx := ctx.Span("quoteDefaultColumnValueNames")
defer span.End()

return transform.Node(n, func(n sql.Node) (sql.Node, transform.TreeIdentity, error) {
switch node := n.(type) {
case *plan.AlterDefaultSet:
eWrapper := expression.WrapExpression(node.Default)
newExpr, same, err := backtickDefault(eWrapper)
newExpr, same, err := quoteIdentifiers(a.Parser, eWrapper)
if err != nil {
return node, transform.SameTree, err
}
Expand All @@ -335,7 +335,7 @@ func backtickDefaultColumnValueNames(ctx *sql.Context, _ *Analyzer, n sql.Node,
return e, transform.SameTree, nil
}

return backtickDefault(eWrapper)
return quoteIdentifiers(a.Parser, eWrapper)
})
case *plan.ResolvedTable:
ct, ok := node.Table.(*information_schema.ColumnsTable)
Expand All @@ -354,7 +354,7 @@ func backtickDefaultColumnValueNames(ctx *sql.Context, _ *Analyzer, n sql.Node,
return e, transform.SameTree, nil
}

return backtickDefault(eWrapper)
return quoteIdentifiers(a.Parser, eWrapper)
})

if err != nil {
Expand All @@ -376,7 +376,7 @@ func backtickDefaultColumnValueNames(ctx *sql.Context, _ *Analyzer, n sql.Node,
})
}

func backtickDefault(wrap *expression.Wrapper) (sql.Expression, transform.TreeIdentity, error) {
func quoteIdentifiers(parser sql.Parser, wrap *expression.Wrapper) (sql.Expression, transform.TreeIdentity, error) {
newDefault, ok := wrap.Unwrap().(*sql.ColumnDefaultValue)
if !ok {
return wrap, transform.SameTree, nil
Expand All @@ -388,7 +388,7 @@ func backtickDefault(wrap *expression.Wrapper) (sql.Expression, transform.TreeId

newExpr, same, err := transform.Expr(newDefault.Expr, func(expr sql.Expression) (sql.Expression, transform.TreeIdentity, error) {
if e, isGf := expr.(*expression.GetField); isGf {
return e.WithBackTickNames(true), transform.NewTree, nil
return e.WithQuotedNames(parser, true), transform.NewTree, nil
}
return expr, transform.SameTree, nil
})
Expand Down
6 changes: 3 additions & 3 deletions sql/analyzer/rule_ids.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ const (
validateDeleteFromId // validateDeleteFrom

// after all
cacheSubqueryAliasesInJoinsId // cacheSubqueryAliasesInJoins
BacktickDefaulColumnValueNamesId // backtickDefaultColumnValueNames
TrackProcessId // trackProcess
cacheSubqueryAliasesInJoinsId // cacheSubqueryAliasesInJoins
QuoteDefaultColumnValueNamesId // quoteDefaultColumnValueNames
TrackProcessId // trackProcess
)
6 changes: 3 additions & 3 deletions sql/analyzer/ruleid_string.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion sql/analyzer/rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func init() {
{applyProceduresId, applyProcedures},
{inlineSubqueryAliasRefsId, inlineSubqueryAliasRefs},
{cacheSubqueryAliasesInJoinsId, cacheSubqueryAliasesInJoins},
{BacktickDefaulColumnValueNamesId, backtickDefaultColumnValueNames},
{QuoteDefaultColumnValueNamesId, quoteDefaultColumnValueNames},
{TrackProcessId, trackProcess},
}
}
Expand Down
27 changes: 18 additions & 9 deletions sql/expression/get_field.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ type GetField struct {
fieldType2 sql.Type2
nullable bool

backTickNames bool
// parser is the parser used to parse the expression and print it
parser sql.Parser

// quoteName indicates whether the field name should be quoted when printed with String()
quoteName bool
}

var _ sql.Expression = (*GetField)(nil)
Expand Down Expand Up @@ -161,12 +165,16 @@ func (p *GetField) WithChildren(children ...sql.Expression) (sql.Expression, err
}

func (p *GetField) String() string {
// We never quote anything if the table identifier is present. Quoting the field name is a very narrow use case
// used only for serializing column default values and related fields, in which case the table name will always be
// stripped away. The output of this method is load-bearing in many places of analysis and execution.
if p.table == "" {
if p.backTickNames {
return "`" + p.name + "`"
if p.quoteName {
return p.parser.QuoteIdentifier(p.name)
}
return p.name
}

return p.table + "." + p.name
}

Expand All @@ -188,16 +196,17 @@ func (p *GetField) WithIndex(n int) sql.Expression {
return &p2
}

// WithBackTickNames returns a copy of this expression with the backtick names flag set to the given value.
func (p *GetField) WithBackTickNames(backtick bool) *GetField {
// WithQuotedNames returns a copy of this expression with the quoteName flag set to the given value, recording the given parser so that String() can use it to quote the field name.
func (p *GetField) WithQuotedNames(parser sql.Parser, quoteNames bool) *GetField {
p2 := *p
p2.backTickNames = backtick
p2.quoteName = quoteNames
p2.parser = parser
return &p2
}

// IsBackTickNames returns whether the field name should be quoted with backticks.
func (p *GetField) IsBackTickNames() bool {
return p.backTickNames
// IsQuotedIdentifier returns whether the field name should be quoted.
func (p *GetField) IsQuotedIdentifier() bool {
return p.quoteName
}

// CollationCoercibility implements the interface sql.CollationCoercible.
Expand Down
9 changes: 9 additions & 0 deletions sql/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package sql

import (
"context"
"fmt"
trace2 "runtime/trace"
"strings"
"unicode"
Expand Down Expand Up @@ -44,6 +45,10 @@ type Parser interface {
// the index of the start of the next query. If |query| represents a no-op statement, such as ";" or "-- comment",
// then implementations must return Vitess' ErrEmpty error.
ParseOneWithOptions(context.Context, string, ast.ParserOptions) (ast.Statement, int, error)
// QuoteIdentifier returns the given identifier quoted according to this parser's dialect. This is used to
// standardize identifiers that cannot be parsed without quoting because they break the normal identifier naming
// rules (for example, by containing spaces).
QuoteIdentifier(identifier string) string
}

var _ Parser = &MysqlParser{}
Expand Down Expand Up @@ -99,3 +104,7 @@ func RemoveSpaceAndDelimiter(query string, d rune) string {
return r == d || unicode.IsSpace(r)
})
}

// QuoteIdentifier implements the Parser interface. The identifier is wrapped in backticks, and every backtick
// already present in it is doubled so that it cannot terminate the quoted name early.
func (m *MysqlParser) QuoteIdentifier(identifier string) string {
	// Escape embedded backticks first; the surrounding pair is added afterwards.
	escaped := strings.ReplaceAll(identifier, "`", "``")
	return fmt.Sprintf("`%s`", escaped)
}

0 comments on commit b9d44c4

Please sign in to comment.