Skip to content

Commit

Permalink
fix(parser): fix parsing nested wildcard struct field access (#7024)
Browse files Browse the repository at this point in the history
Fix #7011: nested wildcard struct field access panics when there are additional parentheses.

Also did some minor style refactoring and added some comments.

Approved-By: st1page
Approved-By: yezizp2012
  • Loading branch information
xxchan authored and lmatz committed Jan 3, 2023
1 parent 68f0fa8 commit 4297ad5
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 79 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -57,25 +57,46 @@
└─LogicalScan { table: t2, columns: [t2.c, t2._row_id] }
- sql: |
create schema s;
create table s.t(a STRUCT<b INTEGER>);
create table s.t(a STRUCT<b INTEGER, c INTEGER>);
select s.t.a from s.t;
logical_plan: |
LogicalProject { exprs: [t.a] }
└─LogicalScan { table: t, columns: [t.a, t._row_id] }
- sql: |
create schema s;
create table s.t(a STRUCT<b INTEGER>);
create table s.t(a STRUCT<b INTEGER, c INTEGER>);
select (s.t.a).b from s.t;
logical_plan: |
LogicalProject { exprs: [Field(t.a, 0:Int32)] }
└─LogicalScan { table: t, columns: [t.a, t._row_id] }
- sql: |
create schema s;
create table s.t(a STRUCT<b INTEGER>);
create table s.t(a STRUCT<b INTEGER, c INTEGER>);
select (s.t).a.b from s.t;
logical_plan: |
LogicalProject { exprs: [Field(t.a, 0:Int32)] }
└─LogicalScan { table: t, columns: [t.a, t._row_id] }
- sql: |
create schema s;
create table s.t(a STRUCT<b INTEGER, c INTEGER>);
select ((s.t).a).b from s.t;
logical_plan: |
LogicalProject { exprs: [Field(t.a, 0:Int32)] }
└─LogicalScan { table: t, columns: [t.a, t._row_id] }
- sql: |
create schema s;
create table s.t(a STRUCT<b INTEGER, c INTEGER>);
select (s.t).a.* from s.t;
logical_plan: |
LogicalProject { exprs: [Field(t.a, 0:Int32), Field(t.a, 1:Int32)] }
└─LogicalScan { table: t, columns: [t.a, t._row_id] }
- sql: |
create schema s;
create table s.t(a STRUCT<b INTEGER, c INTEGER>);
select ((s.t).a).* from s.t;
logical_plan: |
LogicalProject { exprs: [Field(t.a, 0:Int32), Field(t.a, 1:Int32)] }
└─LogicalScan { table: t, columns: [t.a, t._row_id] }
- sql: |
create schema t;
create table t.t(t STRUCT<t INTEGER>);
Expand Down
13 changes: 11 additions & 2 deletions src/sqlparser/src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,17 @@ pub enum Expr {
Identifier(Ident),
/// Multi-part identifier, e.g. `table_alias.column` or `schema.table.col`
CompoundIdentifier(Vec<Ident>),
/// Struct-field identifier, expr is a table or a column struct, ident is field.
/// Struct-field identifier.
/// Expr is an arbitrary expression, returning either a table or a column.
/// Idents are consecutive field accesses.
/// e.g. `(table.v1).v2` or `(table).v1.v2`
///
/// It must contain parentheses to be distinguished from a [`Expr::CompoundIdentifier`].
/// See also <https://www.postgresql.org/docs/current/rowtypes.html#ROWTYPES-ACCESSING>
///
/// The left parentheses must be put at the beginning of the expression.
/// The first parenthesized part is the `expr` part, and the rest are flattened into `idents`.
/// e.g., `((v1).v2.v3).v4` is equivalent to `(v1).v2.v3.v4`.
FieldIdentifier(Box<Expr>, Vec<Ident>),
/// `IS NULL` operator
IsNull(Box<Expr>),
Expand Down Expand Up @@ -1674,7 +1683,7 @@ impl fmt::Display for Assignment {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum FunctionArgExpr {
Expr(Expr),
/// Expr is a table or a struct column.
/// Expr is an arbitrary expression, returning either a table or a column.
/// Idents are the prefix of `*`, which are consecutive field accesses.
/// e.g. `(table.v1).*` or `(table).v1.*`
ExprQualifiedWildcard(Expr, Vec<Ident>),
Expand Down
2 changes: 1 addition & 1 deletion src/sqlparser/src/ast/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ impl fmt::Display for Cte {
pub enum SelectItem {
/// Any expression, not followed by `[ AS ] alias`
UnnamedExpr(Expr),
/// Expr is a table or a struct column.
/// Expr is an arbitrary expression, returning either a table or a column.
/// Idents are the prefix of `*`, which are consecutive field accesses.
/// e.g. `(table.v1).*` or `(table).v1.*`
ExprQualifiedWildcard(Expr, Vec<Ident>),
Expand Down
170 changes: 97 additions & 73 deletions src/sqlparser/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ use IsLateral::*;

pub enum WildcardOrExpr {
Expr(Expr),
/// Expr is a table or a column struct.
/// Expr is an arbitrary expression, returning either a table or a column.
/// Idents are the prefix of `*`, which are consecutive field accesses.
/// e.g. `(table.v1).*` or `(table).v1.*`
///
/// See also [`Expr::FieldIdentifier`] for behaviors of parentheses.
ExprQualifiedWildcard(Expr, Vec<Ident>),
QualifiedWildcard(ObjectName),
Wildcard,
Expand Down Expand Up @@ -232,7 +234,14 @@ impl Parser {
Ok(Statement::Analyze { table_name })
}

/// Parse a new expression including wildcard & qualified wildcard
/// Tries to parse a wildcard expression. If it is not a wildcard, parses an expression.
///
/// A wildcard expression either means:
/// - Selecting all fields from a struct. In this case, it is a
/// [`WildcardOrExpr::ExprQualifiedWildcard`]. Similar to [`Expr::FieldIdentifier`], It must
/// contain parentheses.
/// - Selecting all columns from a table. In this case, it is a
/// [`WildcardOrExpr::QualifiedWildcard`] or a [`WildcardOrExpr::Wildcard`].
pub fn parse_wildcard_or_expr(&mut self) -> Result<WildcardOrExpr, ParserError> {
let index = self.index;

Expand All @@ -246,17 +255,17 @@ impl Parser {
Token::Mul => {
return Ok(WildcardOrExpr::Wildcard);
}
// TODO: support (((table.v1).v2).*)
// parser wildcard field selection expression
// parses wildcard field selection expression.
// Code is similar to `parse_struct_selection`
Token::LParen => {
let mut expr = self.parse_expr()?;
if self.consume_token(&Token::RParen) {
// Cast off nested expression to avoid interface by parenthesis.
while let Expr::Nested(expr1) = expr {
expr = *expr1;
// Unwrap parentheses
while let Expr::Nested(inner) = expr {
expr = *inner;
}
// Now that we have an expr, what follows must be
// dot-delimited identifiers, e.g. `(a).b.c.*`
// dot-delimited identifiers, e.g. `b.c.*` in `(a).b.c.*`
let wildcard_expr = self.parse_simple_wildcard_expr(index)?;
return self.expr_concat_wildcard_expr(expr, wildcard_expr);
}
Expand All @@ -268,47 +277,64 @@ impl Parser {
self.parse_expr().map(WildcardOrExpr::Expr)
}

/// Will return a `WildcardExpr::QualifiedWildcard(ObjectName)` with word concat or
/// `WildcardExpr::Expr`
/// Concats `ident` and `wildcard_expr` in `ident.wildcard_expr`
pub fn word_concat_wildcard_expr(
&mut self,
ident: Ident,
expr: WildcardOrExpr,
simple_wildcard_expr: WildcardOrExpr,
) -> Result<WildcardOrExpr, ParserError> {
if let WildcardOrExpr::QualifiedWildcard(mut idents) = expr {
let mut id_parts = vec![ident];
id_parts.append(&mut idents.0);
Ok(WildcardOrExpr::QualifiedWildcard(ObjectName(id_parts)))
} else if let WildcardOrExpr::Wildcard = expr {
Ok(WildcardOrExpr::QualifiedWildcard(ObjectName(vec![ident])))
} else {
Ok(expr)
let mut idents = vec![ident];
match simple_wildcard_expr {
WildcardOrExpr::QualifiedWildcard(ids) => idents.extend(ids.0),
WildcardOrExpr::Wildcard => {}
WildcardOrExpr::ExprQualifiedWildcard(_, _) => unreachable!(),
WildcardOrExpr::Expr(e) => return Ok(WildcardOrExpr::Expr(e)),
}
Ok(WildcardOrExpr::QualifiedWildcard(ObjectName(idents)))
}

/// Will return a `WildcardExpr::ExprQualifiedWildcard(Expr,ObjectName)` with expr concat or
/// `WildcardExpr::Expr`
/// Concats `expr` and `wildcard_expr` in `(expr).wildcard_expr`.
pub fn expr_concat_wildcard_expr(
&mut self,
expr: Expr,
wildcard_expr: WildcardOrExpr,
simple_wildcard_expr: WildcardOrExpr,
) -> Result<WildcardOrExpr, ParserError> {
if let WildcardOrExpr::QualifiedWildcard(idents) = wildcard_expr {
let mut id_parts = idents.0;
if let Expr::FieldIdentifier(expr, mut idents) = expr {
idents.append(&mut id_parts);
Ok(WildcardOrExpr::ExprQualifiedWildcard(*expr, idents))
} else {
Ok(WildcardOrExpr::ExprQualifiedWildcard(expr, id_parts))
}
} else if let WildcardOrExpr::Wildcard = wildcard_expr {
Ok(WildcardOrExpr::ExprQualifiedWildcard(expr, vec![]))
} else {
Ok(wildcard_expr)
if let WildcardOrExpr::Expr(e) = simple_wildcard_expr {
return Ok(WildcardOrExpr::Expr(e));
}

// similar to `parse_struct_selection`
let mut idents = vec![];
let expr = match expr {
// expr is `(foo)`
Expr::Identifier(_) => expr,
// expr is `(foo.v1)`
Expr::CompoundIdentifier(_) => expr,
// expr is `((1,2,3)::foo)`
Expr::Cast { .. } => expr,
// expr is `((foo.v1).v2)`
Expr::FieldIdentifier(expr, ids) => {
// Put `ids` to the latter part!
idents.extend(ids);
*expr
}
// expr is other things, e.g., `(1+2)`. It will become an unexpected period error at
// upper level.
_ => return Ok(WildcardOrExpr::Expr(expr)),
};

match simple_wildcard_expr {
WildcardOrExpr::QualifiedWildcard(ids) => idents.extend(ids.0),
WildcardOrExpr::Wildcard => {}
WildcardOrExpr::ExprQualifiedWildcard(_, _) => unreachable!(),
WildcardOrExpr::Expr(_) => unreachable!(),
}
Ok(WildcardOrExpr::ExprQualifiedWildcard(expr, idents))
}

/// Will return a `WildcardExpr::QualifiedWildcard(ObjectName)` or `WildcardExpr::Expr`
/// Tries to parses a wildcard expression without any parentheses.
///
/// If wildcard is not found, go back to `index` and parse an expression.
pub fn parse_simple_wildcard_expr(
&mut self,
index: usize,
Expand Down Expand Up @@ -501,7 +527,7 @@ impl Parser {
}
};
self.expect_token(&Token::RParen)?;
if self.peek_token() == Token::Period {
if self.peek_token() == Token::Period && matches!(expr, Expr::Nested(_)) {
self.parse_struct_selection(expr)
} else {
Ok(expr)
Expand All @@ -520,44 +546,42 @@ impl Parser {
}
}

// Parser field selection expression
/// Parses a field selection expression. See also [`Expr::FieldIdentifier`].
pub fn parse_struct_selection(&mut self, expr: Expr) -> Result<Expr, ParserError> {
if let Expr::Nested(compound_expr) = expr.clone() {
let mut nested_expr = *compound_expr;
// Cast off nested expression to avoid interface by parenthesis.
while let Expr::Nested(expr1) = nested_expr {
nested_expr = *expr1;
}
match nested_expr {
// Parser expr like `SELECT (foo).v1 from foo`
Expr::Identifier(ident) => Ok(Expr::FieldIdentifier(
Box::new(Expr::Identifier(ident)),
self.parse_field()?,
)),
// Parser expr like `SELECT (foo.v1).v2 from foo`
Expr::CompoundIdentifier(idents) => Ok(Expr::FieldIdentifier(
Box::new(Expr::CompoundIdentifier(idents)),
self.parse_field()?,
)),
// Parser expr like `SELECT ((1,2,3)::foo).v1`
Expr::Cast { expr, data_type } => Ok(Expr::FieldIdentifier(
Box::new(Expr::Cast { expr, data_type }),
self.parse_field()?,
)),
// Parser expr like `SELECT ((foo.v1).v2).v3 from foo`
Expr::FieldIdentifier(expr, mut idents) => {
idents.extend(self.parse_field()?);
Ok(Expr::FieldIdentifier(expr, idents))
}
_ => Ok(expr),
}
} else {
Ok(expr)
}
}

/// Parser all words after period until not period
pub fn parse_field(&mut self) -> Result<Vec<Ident>, ParserError> {
let mut nested_expr = expr.clone();
// Unwrap parentheses
while let Expr::Nested(inner) = nested_expr {
nested_expr = *inner;
}
match nested_expr {
// expr is `(foo)`
Expr::Identifier(ident) => Ok(Expr::FieldIdentifier(
Box::new(Expr::Identifier(ident)),
self.parse_fields()?,
)),
// expr is `(foo.v1)`
Expr::CompoundIdentifier(idents) => Ok(Expr::FieldIdentifier(
Box::new(Expr::CompoundIdentifier(idents)),
self.parse_fields()?,
)),
// expr is `((1,2,3)::foo)`
Expr::Cast { expr, data_type } => Ok(Expr::FieldIdentifier(
Box::new(Expr::Cast { expr, data_type }),
self.parse_fields()?,
)),
// expr is `((foo.v1).v2)`
Expr::FieldIdentifier(expr, mut idents) => {
idents.extend(self.parse_fields()?);
Ok(Expr::FieldIdentifier(expr, idents))
}
// expr is other things, e.g., `(1+2)`. It will become an unexpected period error at
// upper level.
_ => Ok(expr),
}
}

/// Parses consecutive field identifiers after a period. i.e., `.foo.bar.baz`
pub fn parse_fields(&mut self) -> Result<Vec<Ident>, ParserError> {
let mut idents = vec![];
while self.consume_token(&Token::Period) {
match self.next_token() {
Expand Down

0 comments on commit 4297ad5

Please sign in to comment.