Skip to content

Commit

Permalink
sqlparser: escape keyword names correctly
Browse files Browse the repository at this point in the history
Fix for issue #797.
Names that used keywords were not always getting back-quoted
correctly during codegen. This is a comprehensive fix that
covers all such possible cases.
  • Loading branch information
sougou committed Aug 4, 2015
1 parent 0074e44 commit 1c59d2e
Show file tree
Hide file tree
Showing 9 changed files with 580 additions and 534 deletions.
12 changes: 11 additions & 1 deletion data/test/sqlparser_test/parse_pass.sql
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ select -1 from t where b = -2
select 1 from t // aa#select 1 from t
select 1 from t -- aa#select 1 from t
select /* simplest */ 1 from t
select /* keyword col */ `By` from t#select /* keyword col */ `by` from t
select /* double star **/ 1 from t
select /* double */ /* comment */ 1 from t
select /* back-quote */ 1 from `t`#select /* back-quote */ 1 from t
select /* back-quote keyword */ 1 from `from`#select /* back-quote keyword */ 1 from `from`
select /* back-quote keyword */ 1 from `By`#select /* back-quote keyword */ 1 from `By`
select /* @ */ @@a from b
select /* \0 */ '\0' from a
select 1 /* drop this comment */ from t#select 1 from t
Expand All @@ -28,7 +29,9 @@ select /* select list */ 1, 2 from t
select /* * */ * from t
select /* column alias */ a b from t#select /* column alias */ a as b from t
select /* column alias with as */ a as b from t
select /* keyword column alias */ a as `By` from t#select /* keyword column alias */ a as `by` from t
select /* a.* */ a.* from t
select /* `By`.* */ `By`.* from t
select /* select with bool expr */ a = b from t
select /* case_when */ case when a = b then c end from t
select /* case_when_else */ case when a = b then c else d end from t
Expand All @@ -39,11 +42,13 @@ select /* table list */ 1 from t1, t2
select /* parenthessis in table list 1 */ 1 from (t1), t2
select /* parenthessis in table list 2 */ 1 from t1, (t2)
select /* use */ 1 from t1 use index (a) where b = 1
select /* keyword index */ 1 from t1 use index (`By`) where b = 1#select /* keyword index */ 1 from t1 use index (`by`) where b = 1
select /* ignore */ 1 from t1 as t2 ignore index (a), t3 use index (b) where b = 1
select /* use */ 1 from t1 as t2 use index (a), t3 use index (b) where b = 1
select /* force */ 1 from t1 as t2 force index (a), t3 force index (b) where b = 1
select /* table alias */ 1 from t t1#select /* table alias */ 1 from t as t1
select /* table alias with as */ 1 from t as t1
select /* keyword table alias */ 1 from t as `By`
select /* join */ 1 from t1 join t2
select /* straight_join */ 1 from t1 straight_join t2
select /* left join */ 1 from t1 left join t2
Expand All @@ -55,6 +60,7 @@ select /* cross join */ 1 from t1 cross join t2
select /* natural join */ 1 from t1 natural join t2
select /* join on */ 1 from t1 join t2 on a = b
select /* s.t */ 1 from s.t
select /* keyword schema & table name */ 1 from `By`.`bY`
select /* select in from */ 1 from (select 1 from t)
select /* where */ 1 from t where a = b
select /* and */ 1 from t where a = b and a = c
Expand Down Expand Up @@ -103,6 +109,7 @@ select /* if as func */ 1 from t where a = if(b)
select /* function with distinct */ count(distinct a) from t
select /* a */ a from t
select /* a.b */ a.b from t
select /* keyword a.b */ `By`.`bY` from t#select /* keyword a.b */ `By`.`by` from t
select /* string */ 'a' from t
select /* double quoted string */ "a" from t#select /* double quoted string */ 'a' from t
select /* quote quote in string */ 'a''a' from t#select /* quote quote in string */ 'a\'a' from t
Expand Down Expand Up @@ -161,6 +168,7 @@ set /* simple */ a = 3
set /* list */ a = 3, b = 4
alter ignore table a add foo#alter table a
alter table a add foo#alter table a
alter table `By` add foo#alter table `By`
alter table a alter foo#alter table a
alter table a change foo#alter table a
alter table a modify foo#alter table a
Expand All @@ -172,8 +180,10 @@ alter table a default foo#alter table a
alter table a discard foo#alter table a
alter table a import foo#alter table a
alter table a rename b#rename table a b
alter table `By` rename `bY`#rename table `By` `bY`
alter table a rename to b#rename table a b
create table a
create table `by`
create table if not exists a#create table a
create index a on b#alter table b
create unique index a on b#alter table b
Expand Down
4 changes: 2 additions & 2 deletions go/vt/sqlparser/analyzer.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
// GetTableName returns the table name from the SimpleTableExpr
// only if it's a simple expression. Otherwise, it returns "".
func GetTableName(node SimpleTableExpr) string {
if n, ok := node.(*TableName); ok && n.Qualifier == nil {
if n, ok := node.(*TableName); ok && n.Qualifier == "" {
return string(n.Name)
}
// sub-select or '.' expression
Expand Down Expand Up @@ -47,7 +47,7 @@ func IsValue(node ValExpr) bool {
return false
}

// HasINCaluse returns true if any of the conditions has an IN clause.
// HasINClause returns true if any of the conditions has an IN clause.
func HasINClause(conditions []BoolExpr) bool {
for _, node := range conditions {
if c, ok := node.(*ComparisonExpr); ok && c.Operator == AST_IN {
Expand Down
82 changes: 47 additions & 35 deletions go/vt/sqlparser/ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"strconv"
"strings"

"github.com/youtube/vitess/go/sqltypes"
)
Expand Down Expand Up @@ -196,8 +197,8 @@ func (node *Set) Format(buf *TrackedBuffer) {
// NewName is set for AST_ALTER, AST_CREATE, AST_RENAME.
type DDL struct {
Action string
Table []byte
NewName []byte
Table TableID
NewName TableID
}

const (
Expand All @@ -210,11 +211,11 @@ const (
func (node *DDL) Format(buf *TrackedBuffer) {
switch node.Action {
case AST_CREATE:
buf.Myprintf("%s table %s", node.Action, node.NewName)
buf.Myprintf("%s table %v", node.Action, node.NewName)
case AST_RENAME:
buf.Myprintf("%s table %s %s", node.Action, node.Table, node.NewName)
buf.Myprintf("%s table %v %v", node.Action, node.Table, node.NewName)
default:
buf.Myprintf("%s table %s", node.Action, node.Table)
buf.Myprintf("%s table %v", node.Action, node.Table)
}
}

Expand Down Expand Up @@ -258,26 +259,26 @@ func (*NonStarExpr) ISelectExpr() {}

// StarExpr defines a '*' or 'table.*' expression.
type StarExpr struct {
TableName []byte
TableName TableID
}

func (node *StarExpr) Format(buf *TrackedBuffer) {
if node.TableName != nil {
buf.Myprintf("%s.", node.TableName)
if node.TableName != "" {
buf.Myprintf("%v.", node.TableName)
}
buf.Myprintf("*")
}

// NonStarExpr defines a non-'*' select expr.
type NonStarExpr struct {
Expr Expr
As []byte
As SQLName
}

func (node *NonStarExpr) Format(buf *TrackedBuffer) {
buf.Myprintf("%v", node.Expr)
if node.As != nil {
buf.Myprintf(" as %s", node.As)
if node.As != "" {
buf.Myprintf(" as %v", node.As)
}
}

Expand Down Expand Up @@ -319,14 +320,14 @@ func (*JoinTableExpr) ITableExpr() {}
// coupled with an optional alias or index hint.
type AliasedTableExpr struct {
Expr SimpleTableExpr
As []byte
As TableID
Hints *IndexHints
}

func (node *AliasedTableExpr) Format(buf *TrackedBuffer) {
buf.Myprintf("%v", node.Expr)
if node.As != nil {
buf.Myprintf(" as %s", node.As)
if node.As != "" {
buf.Myprintf(" as %v", node.As)
}
if node.Hints != nil {
// Hint node provides the space padding.
Expand All @@ -345,15 +346,14 @@ func (*Subquery) ISimpleTableExpr() {}

// TableName represents a table name.
type TableName struct {
Name, Qualifier []byte
Name, Qualifier TableID
}

func (node *TableName) Format(buf *TrackedBuffer) {
if node.Qualifier != nil {
escape(buf, node.Qualifier)
buf.Myprintf(".")
if node.Qualifier != "" {
buf.Myprintf("%v.", node.Qualifier)
}
escape(buf, node.Name)
buf.Myprintf("%v", node.Name)
}

// ParenTableExpr represents a parenthesized TableExpr.
Expand Down Expand Up @@ -393,7 +393,7 @@ func (node *JoinTableExpr) Format(buf *TrackedBuffer) {
// IndexHints represents a list of index hints.
type IndexHints struct {
Type string
Indexes [][]byte
Indexes []SQLName
}

const (
Expand All @@ -406,7 +406,7 @@ func (node *IndexHints) Format(buf *TrackedBuffer) {
buf.Myprintf(" %s index ", node.Type)
prefix := "("
for _, n := range node.Indexes {
buf.Myprintf("%s%s", prefix, n)
buf.Myprintf("%s%v", prefix, n)
prefix = ", "
}
buf.Myprintf(")")
Expand Down Expand Up @@ -646,23 +646,15 @@ func (node *NullVal) Format(buf *TrackedBuffer) {

// ColName represents a column name.
type ColName struct {
Name, Qualifier []byte
Name SQLName
Qualifier TableID
}

func (node *ColName) Format(buf *TrackedBuffer) {
if node.Qualifier != nil {
escape(buf, node.Qualifier)
buf.Myprintf(".")
}
escape(buf, node.Name)
}

func escape(buf *TrackedBuffer, name []byte) {
if _, ok := keywords[string(name)]; ok {
buf.Myprintf("`%s`", name)
} else {
buf.Myprintf("%s", name)
if node.Qualifier != "" {
buf.Myprintf("%v.", node.Qualifier)
}
buf.Myprintf("%v", node.Name)
}

// ColTuple represents a list of column values.
Expand Down Expand Up @@ -756,7 +748,7 @@ func (node *UnaryExpr) Format(buf *TrackedBuffer) {

// FuncExpr represents a function call.
type FuncExpr struct {
Name []byte
Name string
Distinct bool
Exprs SelectExprs
}
Expand Down Expand Up @@ -971,3 +963,23 @@ func (node OnDup) Format(buf *TrackedBuffer) {
}
buf.Myprintf(" on duplicate key update %v", UpdateExprs(node))
}

type TableID string

func (node TableID) Format(buf *TrackedBuffer) {
if _, ok := keywords[strings.ToLower(string(node))]; ok {
buf.Myprintf("`%s`", string(node))
return
}
buf.Myprintf("%s", string(node))
}

type SQLName string

func (node SQLName) Format(buf *TrackedBuffer) {
if _, ok := keywords[string(node)]; ok {
buf.Myprintf("`%s`", string(node))
return
}
buf.Myprintf("%s", string(node))
}
Loading

0 comments on commit 1c59d2e

Please sign in to comment.