Skip to content

Commit

Permalink
support FILTER clause in window functions
Browse files Browse the repository at this point in the history
fixes #1006
  • Loading branch information
lovasoa committed Oct 18, 2023
1 parent b9b8f7f commit 9af63ba
Show file tree
Hide file tree
Showing 15 changed files with 93 additions and 2 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ println!("AST: {:?}", ast);
This outputs

```rust
AST: [Query(Query { ctes: [], body: Select(Select { distinct: false, projection: [UnnamedExpr(Identifier("a")), UnnamedExpr(Identifier("b")), UnnamedExpr(Value(Long(123))), UnnamedExpr(Function(Function { name: ObjectName(["myfunc"]), args: [Identifier("b")], over: None, distinct: false }))], from: [TableWithJoins { relation: Table { name: ObjectName(["table_1"]), alias: None, args: [], with_hints: [] }, joins: [] }], selection: Some(BinaryOp { left: BinaryOp { left: Identifier("a"), op: Gt, right: Identifier("b") }, op: And, right: BinaryOp { left: Identifier("b"), op: Lt, right: Value(Long(100)) } }), group_by: [], having: None }), order_by: [OrderByExpr { expr: Identifier("a"), asc: Some(false) }, OrderByExpr { expr: Identifier("b"), asc: None }], limit: None, offset: None, fetch: None })]
AST: [Query(Query { ctes: [], body: Select(Select { distinct: false, projection: [UnnamedExpr(Identifier("a")), UnnamedExpr(Identifier("b")), UnnamedExpr(Value(Long(123))), UnnamedExpr(Function(Function { name: ObjectName(["myfunc"]), args: [Identifier("b")], filter: None, over: None, distinct: false }))], from: [TableWithJoins { relation: Table { name: ObjectName(["table_1"]), alias: None, args: [], with_hints: [] }, joins: [] }], selection: Some(BinaryOp { left: BinaryOp { left: Identifier("a"), op: Gt, right: Identifier("b") }, op: And, right: BinaryOp { left: Identifier("b"), op: Lt, right: Value(Long(100)) } }), group_by: [], having: None }), order_by: [OrderByExpr { expr: Identifier("a"), asc: Some(false) }, OrderByExpr { expr: Identifier("b"), asc: None }], limit: None, offset: None, fetch: None })]
```


Expand Down
9 changes: 9 additions & 0 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -985,8 +985,11 @@ impl Display for WindowType {
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct WindowSpec {
// OVER (PARTITION BY ...)
pub partition_by: Vec<Expr>,
// OVER (ORDER BY ...)
pub order_by: Vec<OrderByExpr>,
// OVER (window frame)
pub window_frame: Option<WindowFrame>,
}

Expand Down Expand Up @@ -3617,6 +3620,8 @@ impl fmt::Display for CloseCursor {
pub struct Function {
pub name: ObjectName,
pub args: Vec<FunctionArg>,
// e.g. `x > 5` in `COUNT(x) FILTER (WHERE x > 5)`
pub filter: Option<Box<Expr>>,
pub over: Option<WindowType>,
// aggregate functions may specify eg `COUNT(DISTINCT x)`
pub distinct: bool,
Expand Down Expand Up @@ -3665,6 +3670,10 @@ impl fmt::Display for Function {
display_comma_separated(&self.order_by),
)?;

if let Some(filter_cond) = &self.filter {
write!(f, " FILTER (WHERE {filter_cond})")?;
}

if let Some(o) = &self.over {
write!(f, " OVER {o}")?;
}
Expand Down
2 changes: 1 addition & 1 deletion src/ast/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ where
/// *expr = Expr::Function(Function {
/// name: ObjectName(vec![Ident::new("f")]),
/// args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(old_expr))],
/// over: None, distinct: false, special: false, order_by: vec![],
/// filter: None, over: None, distinct: false, special: false, order_by: vec![],
/// });
/// }
/// ControlFlow::<()>::Continue(())
Expand Down
4 changes: 4 additions & 0 deletions src/dialect/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ impl Dialect for SQLiteDialect {
|| ('\u{007f}'..='\u{ffff}').contains(&ch)
}

fn supports_filter_during_aggregation(&self) -> bool {
true
}

fn is_identifier_part(&self, ch: char) -> bool {
self.is_identifier_start(ch) || ch.is_ascii_digit()
}
Expand Down
14 changes: 14 additions & 0 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,7 @@ impl<'a> Parser<'a> {
Ok(Expr::Function(Function {
name: ObjectName(vec![w.to_ident()]),
args: vec![],
filter: None,
over: None,
distinct: false,
special: true,
Expand Down Expand Up @@ -957,6 +958,17 @@ impl<'a> Parser<'a> {
self.expect_token(&Token::LParen)?;
let distinct = self.parse_all_or_distinct()?.is_some();
let (args, order_by) = self.parse_optional_args_with_orderby()?;
let filter = if self.dialect.supports_filter_during_aggregation()
&& self.parse_keyword(Keyword::FILTER)
&& self.consume_token(&Token::LParen)
&& self.parse_keyword(Keyword::WHERE)
{
let filter = Some(Box::new(self.parse_expr()?));
self.expect_token(&Token::RParen)?;
filter
} else {
None
};
let over = if self.parse_keyword(Keyword::OVER) {
if self.consume_token(&Token::LParen) {
let window_spec = self.parse_window_spec()?;
Expand All @@ -970,6 +982,7 @@ impl<'a> Parser<'a> {
Ok(Expr::Function(Function {
name,
args,
filter,
over,
distinct,
special: false,
Expand All @@ -987,6 +1000,7 @@ impl<'a> Parser<'a> {
Ok(Expr::Function(Function {
name,
args,
filter: None,
over: None,
distinct: false,
special,
Expand Down
1 change: 1 addition & 0 deletions tests/sqlparser_bigquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ fn parse_map_access_offset() {
args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(
number("0")
))),],
filter: None,
over: None,
distinct: false,
special: false,
Expand Down
4 changes: 4 additions & 0 deletions tests/sqlparser_clickhouse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ fn parse_map_access_expr() {
Value::SingleQuotedString("endpoint".to_string())
))),
],
filter: None,
over: None,
distinct: false,
special: false,
Expand Down Expand Up @@ -88,6 +89,7 @@ fn parse_map_access_expr() {
Value::SingleQuotedString("app".to_string())
))),
],
filter: None,
over: None,
distinct: false,
special: false,
Expand Down Expand Up @@ -137,6 +139,7 @@ fn parse_array_fn() {
FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier(Ident::new("x1")))),
FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier(Ident::new("x2")))),
],
filter: None,
over: None,
distinct: false,
special: false,
Expand Down Expand Up @@ -195,6 +198,7 @@ fn parse_delimited_identifiers() {
&Expr::Function(Function {
name: ObjectName(vec![Ident::with_quote('"', "myfun")]),
args: vec![],
filter: None,
over: None,
distinct: false,
special: false,
Expand Down
18 changes: 18 additions & 0 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,7 @@ fn parse_select_count_wildcard() {
&Expr::Function(Function {
name: ObjectName(vec![Ident::new("COUNT")]),
args: vec![FunctionArg::Unnamed(FunctionArgExpr::Wildcard)],
filter: None,
over: None,
distinct: false,
special: false,
Expand All @@ -892,6 +893,7 @@ fn parse_select_count_distinct() {
op: UnaryOperator::Plus,
expr: Box::new(Expr::Identifier(Ident::new("x"))),
}))],
filter: None,
over: None,
distinct: true,
special: false,
Expand Down Expand Up @@ -1859,6 +1861,7 @@ fn parse_select_having() {
left: Box::new(Expr::Function(Function {
name: ObjectName(vec![Ident::new("COUNT")]),
args: vec![FunctionArg::Unnamed(FunctionArgExpr::Wildcard)],
filter: None,
over: None,
distinct: false,
special: false,
Expand All @@ -1884,6 +1887,7 @@ fn parse_select_qualify() {
left: Box::new(Expr::Function(Function {
name: ObjectName(vec![Ident::new("ROW_NUMBER")]),
args: vec![],
filter: None,
over: Some(WindowType::WindowSpec(WindowSpec {
partition_by: vec![Expr::Identifier(Ident::new("p"))],
order_by: vec![OrderByExpr {
Expand Down Expand Up @@ -3323,6 +3327,7 @@ fn parse_scalar_function_in_projection() {
args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(
Expr::Identifier(Ident::new("id"))
))],
filter: None,
over: None,
distinct: false,
special: false,
Expand Down Expand Up @@ -3442,6 +3447,7 @@ fn parse_named_argument_function() {
))),
},
],
filter: None,
over: None,
distinct: false,
special: false,
Expand Down Expand Up @@ -3473,6 +3479,7 @@ fn parse_window_functions() {
&Expr::Function(Function {
name: ObjectName(vec![Ident::new("row_number")]),
args: vec![],
filter: None,
over: Some(WindowType::WindowSpec(WindowSpec {
partition_by: vec![],
order_by: vec![OrderByExpr {
Expand Down Expand Up @@ -3516,6 +3523,7 @@ fn test_parse_named_window() {
quote_style: None,
}),
))],
filter: None,
over: Some(WindowType::NamedWindow(Ident {
value: "window1".to_string(),
quote_style: None,
Expand All @@ -3541,6 +3549,7 @@ fn test_parse_named_window() {
quote_style: None,
}),
))],
filter: None,
over: Some(WindowType::NamedWindow(Ident {
value: "window2".to_string(),
quote_style: None,
Expand Down Expand Up @@ -4009,6 +4018,7 @@ fn parse_at_timezone() {
quote_style: None,
}]),
args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(zero.clone()))],
filter: None,
over: None,
distinct: false,
special: false,
Expand Down Expand Up @@ -4036,6 +4046,7 @@ fn parse_at_timezone() {
quote_style: None,
},],),
args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(zero))],
filter: None,
over: None,
distinct: false,
special: false,
Expand All @@ -4047,6 +4058,7 @@ fn parse_at_timezone() {
Value::SingleQuotedString("%Y-%m-%dT%H".to_string()),
),),),
],
filter: None,
over: None,
distinct: false,
special: false,
Expand Down Expand Up @@ -4205,6 +4217,7 @@ fn parse_table_function() {
args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(
Value::SingleQuotedString("1".to_owned()),
)))],
filter: None,
over: None,
distinct: false,
special: false,
Expand Down Expand Up @@ -4356,6 +4369,7 @@ fn parse_unnest_in_from_clause() {
FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(number("2")))),
FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(number("3")))),
],
filter: None,
over: None,
distinct: false,
special: false,
Expand Down Expand Up @@ -4385,6 +4399,7 @@ fn parse_unnest_in_from_clause() {
FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(number("2")))),
FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(number("3")))),
],
filter: None,
over: None,
distinct: false,
special: false,
Expand All @@ -4396,6 +4411,7 @@ fn parse_unnest_in_from_clause() {
FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(number("5")))),
FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(number("6")))),
],
filter: None,
over: None,
distinct: false,
special: false,
Expand Down Expand Up @@ -6776,6 +6792,7 @@ fn parse_time_functions() {
let select_localtime_func_call_ast = Function {
name: ObjectName(vec![Ident::new(func_name)]),
args: vec![],
filter: None,
over: None,
distinct: false,
special: false,
Expand Down Expand Up @@ -7256,6 +7273,7 @@ fn parse_pivot_table() {
args: (vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(
Expr::CompoundIdentifier(vec![Ident::new("a"), Ident::new("amount"),])
))]),
filter: None,
over: None,
distinct: false,
special: false,
Expand Down
1 change: 1 addition & 0 deletions tests/sqlparser_hive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ fn parse_delimited_identifiers() {
&Expr::Function(Function {
name: ObjectName(vec![Ident::with_quote('"', "myfun")]),
args: vec![],
filter: None,
over: None,
distinct: false,
special: false,
Expand Down
1 change: 1 addition & 0 deletions tests/sqlparser_mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ fn parse_delimited_identifiers() {
&Expr::Function(Function {
name: ObjectName(vec![Ident::with_quote('"', "myfun")]),
args: vec![],
filter: None,
over: None,
distinct: false,
special: false,
Expand Down
6 changes: 6 additions & 0 deletions tests/sqlparser_mysql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,7 @@ fn parse_insert_with_on_duplicate_update() {
args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(
Expr::Identifier(Ident::new("description"))
))],
filter: None,
over: None,
distinct: false,
special: false,
Expand All @@ -1077,6 +1078,7 @@ fn parse_insert_with_on_duplicate_update() {
args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(
Expr::Identifier(Ident::new("perm_create"))
))],
filter: None,
over: None,
distinct: false,
special: false,
Expand All @@ -1090,6 +1092,7 @@ fn parse_insert_with_on_duplicate_update() {
args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(
Expr::Identifier(Ident::new("perm_read"))
))],
filter: None,
over: None,
distinct: false,
special: false,
Expand All @@ -1103,6 +1106,7 @@ fn parse_insert_with_on_duplicate_update() {
args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(
Expr::Identifier(Ident::new("perm_update"))
))],
filter: None,
over: None,
distinct: false,
special: false,
Expand All @@ -1116,6 +1120,7 @@ fn parse_insert_with_on_duplicate_update() {
args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(
Expr::Identifier(Ident::new("perm_delete"))
))],
filter: None,
over: None,
distinct: false,
special: false,
Expand Down Expand Up @@ -1460,6 +1465,7 @@ fn parse_table_colum_option_on_update() {
option: ColumnOption::OnUpdate(Expr::Function(Function {
name: ObjectName(vec![Ident::new("CURRENT_TIMESTAMP")]),
args: vec![],
filter: None,
over: None,
distinct: false,
special: false,
Expand Down
6 changes: 6 additions & 0 deletions tests/sqlparser_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2272,6 +2272,7 @@ fn test_composite_value() {
named: true
}
)))],
filter: None,
over: None,
distinct: false,
special: false,
Expand Down Expand Up @@ -2433,6 +2434,7 @@ fn parse_current_functions() {
&Expr::Function(Function {
name: ObjectName(vec![Ident::new("CURRENT_CATALOG")]),
args: vec![],
filter: None,
over: None,
distinct: false,
special: true,
Expand All @@ -2444,6 +2446,7 @@ fn parse_current_functions() {
&Expr::Function(Function {
name: ObjectName(vec![Ident::new("CURRENT_USER")]),
args: vec![],
filter: None,
over: None,
distinct: false,
special: true,
Expand All @@ -2455,6 +2458,7 @@ fn parse_current_functions() {
&Expr::Function(Function {
name: ObjectName(vec![Ident::new("SESSION_USER")]),
args: vec![],
filter: None,
over: None,
distinct: false,
special: true,
Expand All @@ -2466,6 +2470,7 @@ fn parse_current_functions() {
&Expr::Function(Function {
name: ObjectName(vec![Ident::new("USER")]),
args: vec![],
filter: None,
over: None,
distinct: false,
special: true,
Expand Down Expand Up @@ -2916,6 +2921,7 @@ fn parse_delimited_identifiers() {
&Expr::Function(Function {
name: ObjectName(vec![Ident::with_quote('"', "myfun")]),
args: vec![],
filter: None,
over: None,
distinct: false,
special: false,
Expand Down
Loading

0 comments on commit 9af63ba

Please sign in to comment.