Skip to content

Commit

Permalink
support CTE clause (#1207)
Browse files Browse the repository at this point in the history
  • Loading branch information
wjhuang2016 authored Apr 15, 2021
1 parent 7366a94 commit 48e7f46
Show file tree
Hide file tree
Showing 5 changed files with 9,382 additions and 9,009 deletions.
125 changes: 123 additions & 2 deletions ast/dml.go
Original file line number Diff line number Diff line change
Expand Up @@ -991,7 +991,22 @@ func (s *SelectStmtKind) String() string {
return ""
}

// SelectStmt represents a select/table/values query node.
type CommonTableExpression struct {
node

Name model.CIStr
Query *SubqueryExpr
ColNameList []model.CIStr
}

type WithClause struct {
node

IsRecursive bool
CTEs []*CommonTableExpression
}

// SelectStmt represents the select query node.
// See https://dev.mysql.com/doc/refman/5.7/en/select.html
type SelectStmt struct {
dmlNode
Expand Down Expand Up @@ -1032,6 +1047,53 @@ type SelectStmt struct {
Kind SelectStmtKind
// Lists is filled only when Kind == SelectStmtKindValues
Lists []*RowExpr
With *WithClause
}

func (n *WithClause) Restore(ctx *format.RestoreCtx) error {
ctx.WriteKeyWord("WITH ")
if n.IsRecursive {
ctx.WriteKeyWord("RECURSIVE ")
}
for i, cte := range n.CTEs {
if i != 0 {
ctx.WritePlain(", ")
}
ctx.WriteName(cte.Name.String())
if len(cte.ColNameList) > 0 {
ctx.WritePlain(" (")
for j, name := range cte.ColNameList {
if j != 0 {
ctx.WritePlain(", ")
}
ctx.WriteName(name.String())
}
ctx.WritePlain(")")
}
ctx.WriteKeyWord(" AS ")
err := cte.Query.Restore(ctx)
if err != nil {
return err
}
}
ctx.WritePlain(" ")
return nil
}

func (n *WithClause) Accept(v Visitor) (Node, bool) {
newNode, skipChildren := v.Enter(n)
if skipChildren {
return v.Leave(newNode)
}

for _, cte := range n.CTEs {
node, ok := cte.Query.Accept(v)
if !ok {
return n, false
}
cte.Query = node.(*SubqueryExpr)
}
return v.Leave(n)
}

// Restore implements Node interface.
Expand All @@ -1042,6 +1104,13 @@ func (n *SelectStmt) Restore(ctx *format.RestoreCtx) error {
ctx.WritePlain(")")
}()
}
if n.With != nil {
err := n.With.Restore(ctx)
if err != nil {
return err
}
}

ctx.WriteKeyWord(n.Kind.String())
ctx.WritePlain(" ")
switch n.Kind {
Expand Down Expand Up @@ -1204,6 +1273,15 @@ func (n *SelectStmt) Accept(v Visitor) (Node, bool) {
}

n = newNode.(*SelectStmt)

if n.With != nil {
node, ok := n.With.Accept(v)
if !ok {
return n, false
}
n.With = node.(*WithClause)
}

if n.TableHints != nil && len(n.TableHints) != 0 {
newHints := make([]*TableOptimizerHint, len(n.TableHints))
for i, hint := range n.TableHints {
Expand Down Expand Up @@ -1381,10 +1459,17 @@ type SetOprStmt struct {
SelectList *SetOprSelectList
OrderBy *OrderByClause
Limit *Limit
With *WithClause
}

// Restore implements Node interface.
func (n *SetOprStmt) Restore(ctx *format.RestoreCtx) error {
if n.With != nil {
if err := n.With.Restore(ctx); err != nil {
return errors.Annotate(err, "An error occurred while restore UnionStmt.With")
}
}

if err := n.SelectList.Restore(ctx); err != nil {
return errors.Annotate(err, "An error occurred while restore SetOprStmt.SelectList")
}
Expand All @@ -1411,7 +1496,13 @@ func (n *SetOprStmt) Accept(v Visitor) (Node, bool) {
if skipChildren {
return v.Leave(newNode)
}
n = newNode.(*SetOprStmt)
if n.With != nil {
node, ok := n.With.Accept(v)
if !ok {
return n, false
}
n.With = node.(*WithClause)
}
if n.SelectList != nil {
node, ok := n.SelectList.Accept(v)
if !ok {
Expand Down Expand Up @@ -1943,10 +2034,18 @@ type DeleteStmt struct {
BeforeFrom bool
// TableHints represents the table level Optimizer Hint for join type.
TableHints []*TableOptimizerHint
With *WithClause
}

// Restore implements Node interface.
func (n *DeleteStmt) Restore(ctx *format.RestoreCtx) error {
if n.With != nil {
err := n.With.Restore(ctx)
if err != nil {
return err
}
}

ctx.WriteKeyWord("DELETE ")

if n.TableHints != nil && len(n.TableHints) != 0 {
Expand Down Expand Up @@ -2036,6 +2135,13 @@ func (n *DeleteStmt) Accept(v Visitor) (Node, bool) {
}

n = newNode.(*DeleteStmt)
if n.With != nil {
node, ok := n.With.Accept(v)
if !ok {
return n, false
}
n.With = node.(*WithClause)
}
node, ok := n.TableRefs.Accept(v)
if !ok {
return n, false
Expand Down Expand Up @@ -2088,10 +2194,18 @@ type UpdateStmt struct {
IgnoreErr bool
MultipleTable bool
TableHints []*TableOptimizerHint
With *WithClause
}

// Restore implements Node interface.
func (n *UpdateStmt) Restore(ctx *format.RestoreCtx) error {
if n.With != nil {
err := n.With.Restore(ctx)
if err != nil {
return err
}
}

ctx.WriteKeyWord("UPDATE ")

if n.TableHints != nil && len(n.TableHints) != 0 {
Expand Down Expand Up @@ -2169,6 +2283,13 @@ func (n *UpdateStmt) Accept(v Visitor) (Node, bool) {
return v.Leave(newNode)
}
n = newNode.(*UpdateStmt)
if n.With != nil {
node, ok := n.With.Accept(v)
if !ok {
return n, false
}
n.With = node.(*WithClause)
}
node, ok := n.TableRefs.Accept(v)
if !ok {
return n, false
Expand Down
1 change: 1 addition & 0 deletions misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,7 @@ var tokenMap = map[string]int{
"REBUILD": rebuild,
"RECENT": recent,
"RECOVER": recover,
"RECURSIVE": recursive,
"REDUNDANT": redundant,
"REFERENCES": references,
"REGEXP": regexpKwd,
Expand Down
Loading

0 comments on commit 48e7f46

Please sign in to comment.