From e2daee92c5b1c24481ac5903c82aa5bbed1395ef Mon Sep 17 00:00:00 2001 From: ygf11 Date: Sun, 15 Jan 2023 20:23:45 +0800 Subject: [PATCH] Support non-tuple expression for in-subquery to join (#4826) * Support non-tuple expression for in-subquery to join * add tests * add comment and fix cargo fmt * fix comment * clean unused comment * Update datafusion/optimizer/src/decorrelate_where_in.rs Co-authored-by: Andrew Lamb * Update datafusion/optimizer/src/decorrelate_where_in.rs Co-authored-by: Andrew Lamb * Update datafusion/optimizer/src/decorrelate_where_in.rs Co-authored-by: Andrew Lamb * fix comment * fix cargo fmt * add tests * fix cargo fmt Co-authored-by: Andrew Lamb --- datafusion/core/tests/sql/joins.rs | 275 +++++++++++ datafusion/core/tests/sql/subqueries.rs | 13 +- datafusion/expr/src/utils.rs | 5 +- .../optimizer/src/decorrelate_where_in.rs | 456 ++++++++++++++---- 4 files changed, 643 insertions(+), 106 deletions(-) diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index db5c706d3353..c20c66e1016f 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -2868,3 +2868,278 @@ async fn test_cross_join_to_groupby_with_different_key_ordering() -> Result<()> Ok(()) } + +#[tokio::test] +async fn subquery_to_join_with_both_side_expr() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", false)?; + + let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in (select t2.t2_id + 1 from t2)"; + + // assert logical plan + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) 
+ Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]", + " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 11 | a | 1 |", + "| 33 | c | 3 |", + "| 44 | d | 4 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn subquery_to_join_with_muti_filter() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", false)?; + + let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in + (select t2.t2_id + 1 from t2 where t1.t1_int <= t2.t2_int and t2.t2_int > 0)"; + + // assert logical plan + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <= __correlated_sq_1.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, 
t1_int:UInt32;N]", + " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N]", + " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1), t2.t2_int [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N]", + " Filter: t2.t2_int > UInt32(0) [t2_id:UInt32;N, t2_int:UInt32;N]", + " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 11 | a | 1 |", + "| 33 | c | 3 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn three_projection_exprs_subquery_to_join() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", false)?; + + let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in + (select t2.t2_id + 1 from t2 where t1.t1_int <= t2.t2_int and t1.t1_name != t2.t2_name and t2.t2_int > 0)"; + + // assert logical plan + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <= __correlated_sq_1.t2_int AND t1.t1_name != __correlated_sq_1.t2_name [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, 
t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]", + " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1), t2.t2_int, t2.t2_name [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]", + " Filter: t2.t2_int > UInt32(0) [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 11 | a | 1 |", + "| 33 | c | 3 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn in_subquery_to_join_with_correlated_outer_filter() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", false)?; + + let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in + (select t2.t2_id + 1 from t2 where t1.t1_int > 0)"; + + // assert logical plan + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + + // The `t1.t1_int > UInt32(0)` should be pushdown by `filter push down rule`. 
+ let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]", + " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 11 | a | 1 |", + "| 33 | c | 3 |", + "| 44 | d | 4 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn in_subquery_to_join_with_outer_filter() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", false)?; + + let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in + (select t2.t2_id + 1 from t2 where t1.t1_int <= t2.t2_int and t1.t1_name != t2.t2_name) and t1.t1_id > 0"; + + // assert logical plan + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, 
t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) Filter: t1.t1_int <= __correlated_sq_1.t2_int AND t1.t1_name != __correlated_sq_1.t2_name [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Filter: t1.t1_id > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]", + " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1), t2.t2_int, t2.t2_name [CAST(t2_id AS Int64) + Int64(1):Int64;N, t2_int:UInt32;N, t2_name:Utf8;N]", + " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 11 | a | 1 |", + "| 33 | c | 3 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn two_in_subquery_to_join_with_outer_filter() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", false)?; + + let sql = "select t1.t1_id, t1.t1_name, t1.t1_int from t1 where t1.t1_id + 12 in + (select t2.t2_id + 1 from t2) + and t1.t1_int in(select t2.t2_int + 1 from t2) + and t1.t1_id > 0"; + + // assert logical plan + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + + let 
expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t1.t1_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: CAST(t1.t1_int AS Int64) = __correlated_sq_2.CAST(t2_int AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " LeftSemi Join: CAST(t1.t1_id AS Int64) + Int64(12) = __correlated_sq_1.CAST(t2_id AS Int64) + Int64(1) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Filter: t1.t1_id > UInt32(0) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " SubqueryAlias: __correlated_sq_1 [CAST(t2_id AS Int64) + Int64(1):Int64;N]", + " Projection: CAST(t2.t2_id AS Int64) + Int64(1) AS CAST(t2_id AS Int64) + Int64(1) [CAST(t2_id AS Int64) + Int64(1):Int64;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + " SubqueryAlias: __correlated_sq_2 [CAST(t2_int AS Int64) + Int64(1):Int64;N]", + " Projection: CAST(t2.t2_int AS Int64) + Int64(1) AS CAST(t2_int AS Int64) + Int64(1) [CAST(t2_int AS Int64) + Int64(1):Int64;N]", + " TableScan: t2 projection=[t2_int] [t2_int:UInt32;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + let expected = vec![ + "+-------+---------+--------+", + "| t1_id | t1_name | t1_int |", + "+-------+---------+--------+", + "| 44 | d | 4 |", + "+-------+---------+--------+", + ]; + + let results = execute_to_batches(&ctx, sql).await; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 2627a2db07b9..6928e98b789b 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -94,12 +94,13 @@ where o_orderstatus 
in ( let dataframe = ctx.sql(sql).await.unwrap(); let plan = dataframe.into_optimized_plan().unwrap(); let actual = format!("{}", plan.display_indent()); - let expected = r#"Projection: orders.o_orderkey - LeftSemi Join: orders.o_orderstatus = __correlated_sq_1.l_linestatus, orders.o_orderkey = __correlated_sq_1.l_orderkey - TableScan: orders projection=[o_orderkey, o_orderstatus] - SubqueryAlias: __correlated_sq_1 - Projection: lineitem.l_linestatus AS l_linestatus, lineitem.l_orderkey AS l_orderkey - TableScan: lineitem projection=[l_orderkey, l_linestatus]"#; + + let expected = "Projection: orders.o_orderkey\ + \n LeftSemi Join: orders.o_orderstatus = __correlated_sq_1.l_linestatus, orders.o_orderkey = __correlated_sq_1.l_orderkey\ + \n TableScan: orders projection=[o_orderkey, o_orderstatus]\ + \n SubqueryAlias: __correlated_sq_1\ + \n Projection: lineitem.l_linestatus AS l_linestatus, lineitem.l_orderkey\ + \n TableScan: lineitem projection=[l_orderkey, l_linestatus]"; assert_eq!(actual, expected); // assert data diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 682d1332143c..e84ba0b6f8dd 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -965,7 +965,10 @@ pub fn can_hash(data_type: &DataType) -> bool { } /// Check whether all columns are from the schema. 
-fn check_all_column_from_schema(columns: &HashSet, schema: DFSchemaRef) -> bool { +pub fn check_all_column_from_schema( + columns: &HashSet, + schema: DFSchemaRef, +) -> bool { columns .iter() .all(|column| schema.index_of_column(column).is_ok()) diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index 1aa976ce8ca7..13e3acf78876 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -17,15 +17,15 @@ use crate::alias::AliasGenerator; use crate::optimizer::ApplyOrder; -use crate::utils::{ - alias_cols, conjunction, exprs_to_join_cols, find_join_exprs, merge_cols, - only_or_err, split_conjunction, swap_table, verify_not_disjunction, -}; +use crate::utils::{conjunction, only_or_err, split_conjunction}; use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{context, Result}; +use datafusion_common::{context, Column, Result}; +use datafusion_expr::expr_rewriter::{replace_col, unnormalize_col}; use datafusion_expr::logical_plan::{JoinType, Projection, Subquery}; +use datafusion_expr::utils::check_all_column_from_schema; use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; use log::debug; +use std::collections::{BTreeSet, HashMap}; use std::sync::Arc; #[derive(Default)] @@ -96,6 +96,7 @@ impl OptimizerRule for DecorrelateWhereIn { return Ok(None); } + // iterate through all exists clauses in predicate, turning each into a join // iterate through all exists clauses in predicate, turning each into a join let mut cur_input = filter.input.as_ref().clone(); for subquery in subqueries { @@ -121,81 +122,98 @@ impl OptimizerRule for DecorrelateWhereIn { } } +/// Optimize the where in subquery to left-anti/left-semi join. +/// If the subquery is a correlated subquery, we need extract the join predicate from the subquery. 
+/// +/// For example, given a query like: +/// `select t1.a, t1.b from t1 where t1 in (select t2.a from t2 where t1.b = t2.b and t1.c > t2.c)` +/// +/// The optimized plan will be: +/// +/// ```text +/// Projection: t1.a, t1.b +/// LeftSemi Join: Filter: t1.a = __correlated_sq_1.a AND t1.b = __correlated_sq_1.b AND t1.c > __correlated_sq_1.c +/// TableScan: t1 +/// SubqueryAlias: __correlated_sq_1 +/// Projection: t2.a AS a, t2.b, t2.c +/// TableScan: t2 +/// ``` fn optimize_where_in( query_info: &SubqueryInfo, - outer_input: &LogicalPlan, + left: &LogicalPlan, outer_other_exprs: &[Expr], alias: &AliasGenerator, ) -> Result { - let proj = Projection::try_from_plan(&query_info.query.subquery) + let projection = Projection::try_from_plan(&query_info.query.subquery) .map_err(|e| context!("a projection is required", e))?; - let mut subqry_input = proj.input.clone(); - let proj = only_or_err(proj.expr.as_slice()) + let subquery_input = projection.input.clone(); + let subquery_expr = only_or_err(projection.expr.as_slice()) .map_err(|e| context!("single expression projection required", e))?; - let subquery_col = proj - .try_into_col() - .map_err(|e| context!("single column projection required", e))?; - let outer_col = query_info - .where_in_expr - .try_into_col() - .map_err(|e| context!("column comparison required", e))?; - - // If subquery is correlated, grab necessary information - let mut subqry_cols = vec![]; - let mut outer_cols = vec![]; - let mut join_filters = None; - let mut other_subqry_exprs = vec![]; - if let LogicalPlan::Filter(subqry_filter) = (*subqry_input).clone() { - // split into filters - let subqry_filter_exprs = split_conjunction(&subqry_filter.predicate); - verify_not_disjunction(&subqry_filter_exprs)?; - - // Grab column names to join on - let (col_exprs, other_exprs) = - find_join_exprs(subqry_filter_exprs, subqry_filter.input.schema()) - .map_err(|e| context!("column correlation not found", e))?; - if !col_exprs.is_empty() { - // it's correlated 
- subqry_input = subqry_filter.input.clone(); - (outer_cols, subqry_cols, join_filters) = - exprs_to_join_cols(&col_exprs, subqry_filter.input.schema(), false) - .map_err(|e| context!("column correlation not found", e))?; - other_subqry_exprs = other_exprs; - } - } - let (subqry_cols, outer_cols) = - merge_cols((&[subquery_col], &subqry_cols), (&[outer_col], &outer_cols)); - - // build subquery side of join - the thing the subquery was querying - let subqry_alias = alias.next("__correlated_sq"); - let mut subqry_plan = LogicalPlanBuilder::from((*subqry_input).clone()); - if let Some(expr) = conjunction(other_subqry_exprs) { - // if the subquery had additional expressions, restore them - subqry_plan = subqry_plan.filter(expr)? + // extract join filters + let (join_filters, subquery_input) = extract_join_filters(subquery_input.as_ref())?; + + // in_predicate may be also include in the join filters, remove it from the join filters. + let in_predicate = Expr::eq(query_info.where_in_expr.clone(), subquery_expr.clone()); + let join_filters = remove_duplicated_filter(join_filters, in_predicate); + + // replace qualified name with subquery alias. + let subquery_alias = alias.next("__correlated_sq"); + let input_schema = subquery_input.schema(); + let mut subquery_cols: BTreeSet = + join_filters + .iter() + .try_fold(BTreeSet::new(), |mut cols, expr| { + let using_cols: Vec = expr + .to_columns()? + .into_iter() + .filter(|col| input_schema.field_from_column(col).is_ok()) + .collect::<_>(); + + cols.extend(using_cols); + Result::Ok(cols) + })?; + let join_filter = conjunction(join_filters).map_or(Ok(None), |filter| { + replace_qualified_name(filter, &subquery_cols, &subquery_alias).map(Option::Some) + })?; + + // add projection + if let Expr::Column(col) = subquery_expr { + subquery_cols.remove(col); } - let projection = alias_cols(&subqry_cols); - let subqry_plan = subqry_plan - .project(projection)? - .alias(&subqry_alias)? 
+ let subquery_expr_name = format!("{:?}", unnormalize_col(subquery_expr.clone())); + let first_expr = subquery_expr.clone().alias(subquery_expr_name.clone()); + let projection_exprs: Vec = [first_expr] + .into_iter() + .chain(subquery_cols.into_iter().map(Expr::Column)) + .collect(); + + let right = LogicalPlanBuilder::from(subquery_input) + .project(projection_exprs)? + .alias(&subquery_alias)? .build()?; - debug!("subquery plan:\n{}", subqry_plan.display_indent()); - - // qualify the join columns for outside the subquery - let subqry_cols = swap_table(&subqry_alias, &subqry_cols); - let join_keys = (outer_cols, subqry_cols); // join our sub query into the main plan let join_type = match query_info.negated { true => JoinType::LeftAnti, false => JoinType::LeftSemi, }; - let mut new_plan = LogicalPlanBuilder::from(outer_input.clone()).join( - subqry_plan, + let right_join_col = Column::new(Some(subquery_alias), subquery_expr_name); + let in_predicate = Expr::eq( + query_info.where_in_expr.clone(), + Expr::Column(right_join_col), + ); + let join_filter = join_filter + .map(|filter| in_predicate.clone().and(filter)) + .unwrap_or_else(|| in_predicate); + + let mut new_plan = LogicalPlanBuilder::from(left.clone()).join( + right, join_type, - join_keys, - join_filters, + (Vec::::new(), Vec::::new()), + Some(join_filter), )?; + if let Some(expr) = conjunction(outer_other_exprs.to_vec()) { new_plan = new_plan.filter(expr)? 
// if the main query had additional expressions, restore them
     }
@@ -205,6 +223,93 @@ fn optimize_where_in(
     Ok(new_plan)
 }
 
+/// Splits the conjuncts of a subquery's top-level `Filter` into join filters
+/// and filters that stay inside the subquery.
+///
+/// A conjunct stays in the subquery when all of its columns come from the
+/// schema of the filter's input; every other conjunct becomes a join filter.
+/// Returns the join filters and the filter's input, wrapped in a new `Filter`
+/// when any conjunct stayed behind. A plan that is not a `Filter` is returned
+/// unchanged together with an empty list of join filters.
+fn extract_join_filters(maybe_filter: &LogicalPlan) -> Result<(Vec<Expr>, LogicalPlan)> {
+    if let LogicalPlan::Filter(plan_filter) = maybe_filter {
+        let input_schema = plan_filter.input.schema();
+        let subquery_filter_exprs = split_conjunction(&plan_filter.predicate);
+
+        let mut join_filters: Vec<Expr> = vec![];
+        let mut subquery_filters: Vec<Expr> = vec![];
+        for expr in subquery_filter_exprs {
+            let cols = expr.to_columns()?;
+            if check_all_column_from_schema(&cols, input_schema.clone()) {
+                subquery_filters.push(expr.clone());
+            } else {
+                join_filters.push(expr.clone())
+            }
+        }
+
+        // if the subquery still has filter expressions, restore them.
+        let mut plan = LogicalPlanBuilder::from((*plan_filter.input).clone());
+        if let Some(expr) = conjunction(subquery_filters) {
+            plan = plan.filter(expr)?
+        }
+
+        Ok((join_filters, plan.build()?))
+    } else {
+        Ok((vec![], maybe_filter.clone()))
+    }
+}
+
+/// Removes every filter that duplicates `in_predicate`: one that is equal to
+/// it, or the same binary expression with its two operands swapped.
+///
+/// Swapping operands only preserves meaning for a commutative operator; the
+/// caller builds `in_predicate` with `Expr::eq`.
+fn remove_duplicated_filter(filters: Vec<Expr>, in_predicate: Expr) -> Vec<Expr> {
+    filters
+        .into_iter()
+        .filter(|filter| {
+            if filter == &in_predicate {
+                return false;
+            }
+
+            // Ignore the operand order, but still require the same operator.
+            // The `||` needs its own parentheses: `&&` binds tighter, so
+            // without them a swapped comparison with a different operator
+            // (e.g. `b < a` against `a = b`) was dropped as a duplicate.
+            !match (filter, &in_predicate) {
+                (Expr::BinaryExpr(a_expr), Expr::BinaryExpr(b_expr)) => {
+                    (a_expr.op == b_expr.op)
+                        && ((a_expr.left == b_expr.left && a_expr.right == b_expr.right)
+                            || (a_expr.left == b_expr.right
+                                && a_expr.right == b_expr.left))
+                }
+                _ => false,
+            }
+        })
+        .collect::<Vec<_>>()
+}
+
+/// Requalifies the columns of `cols` inside `expr` with `subquery_alias`,
+/// keeping each column's name, by passing a column-to-column map to
+/// `replace_col`.
+fn replace_qualified_name(
+    expr: Expr,
+    cols: &BTreeSet<Column>,
+    subquery_alias: &str,
+) -> Result<Expr> {
+    // Build each aliased column from its parts, as `optimize_where_in` does
+    // for the join column, rather than formatting "alias.name" and parsing it
+    // back, which depends on the name surviving that round trip.
+    let alias_cols: Vec<Column> = cols
+        .iter()
+        .map(|col| Column::new(Some(subquery_alias.to_string()), col.name.clone()))
+        .collect();
+    let replace_map: HashMap<&Column, &Column> =
+        cols.iter().zip(alias_cols.iter()).collect();
+
+    replace_col(expr, &replace_map)
+}
+
 struct SubqueryInfo {
     query: Subquery,
     where_in_expr: Expr,
@@ -263,8 +347,8 @@ mod tests {
.build()?; let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: test.b = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.b = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ \n Projection: sq_1.c AS c [c:UInt32]\ @@ -272,7 +356,6 @@ mod tests { \n SubqueryAlias: __correlated_sq_2 [c:UInt32]\ \n Projection: sq_2.c AS c [c:UInt32]\ \n TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_equal(&plan, expected) } @@ -293,7 +376,7 @@ mod tests { let expected = "Projection: test.b [b:UInt32]\ \n Filter: test.a = UInt32(1) AND test.b < UInt32(30) [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ \n Projection: sq.c AS c [c:UInt32]\ @@ -347,7 +430,7 @@ mod tests { \n Subquery: [c:UInt32]\ \n Projection: sq1.c [c:UInt32]\ \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ - \n LeftSemi Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ \n Projection: sq2.c AS c [c:UInt32]\ @@ -372,11 +455,11 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: test.b = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.b = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n 
SubqueryAlias: __correlated_sq_1 [a:UInt32]\ \n Projection: sq.a AS a [a:UInt32]\ - \n LeftSemi Join: sq.a = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: sq.a = __correlated_sq_2.c [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: __correlated_sq_2 [c:UInt32]\ \n Projection: sq_nested.c AS c [c:UInt32]\ @@ -401,14 +484,14 @@ mod tests { .project(vec![col("b")])? .build()?; - let expected = "Projection: wrapped.b [b:UInt32]\ + let expected = "Projection: wrapped.b [b:UInt32]\ \n Filter: wrapped.b < UInt32(30) OR wrapped.c IN () [b:UInt32, c:UInt32]\ \n Subquery: [c:UInt32]\ \n Projection: sq_outer.c [c:UInt32]\ \n TableScan: sq_outer [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: wrapped [b:UInt32, c:UInt32]\ \n Projection: test.b, test.c [b:UInt32, c:UInt32]\ - \n LeftSemi Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ \n Projection: sq_inner.c AS c [c:UInt32]\ @@ -443,14 +526,16 @@ mod tests { debug!("plan to optimize:\n{}", plan.display_indent()); let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, 
o_totalprice:Float64;N]\ - \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ + \n SubqueryAlias: __correlated_sq_2 [o_custkey:Int64]\ + \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + assert_optimized_plan_eq_display_indent( Arc::new(DecorrelateWhereIn::new()), &plan, @@ -486,11 +571,11 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ - \n LeftSemi Join: orders.o_orderkey = __correlated_sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n LeftSemi Join: Filter: orders.o_orderkey = __correlated_sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n SubqueryAlias: __correlated_sq_2 [l_orderkey:Int64]\ \n Projection: lineitem.l_orderkey AS l_orderkey [l_orderkey:Int64]\ @@ -524,7 +609,7 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ @@ -554,14 +639,12 
@@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - // Query will fail, but we can still transform the plan let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey = customer.c_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ - \n Filter: customer.c_custkey = customer.c_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; assert_optimized_plan_eq_display_indent( Arc::new(DecorrelateWhereIn::new()), @@ -587,7 +670,7 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ @@ -618,7 +701,7 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey Filter: customer.c_custkey != orders.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey != __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, 
c_name:Utf8]\ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ @@ -647,11 +730,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - // can't optimize on arbitrary expressions (yet) - assert_optimizer_err( + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND customer.c_custkey < __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ + \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + + assert_optimized_plan_eq_display_indent( Arc::new(DecorrelateWhereIn::new()), &plan, - "column correlation not found", + expected, ); Ok(()) } @@ -675,11 +764,19 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - assert_optimizer_err( + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey AND (customer.c_custkey = __correlated_sq_1.o_custkey OR __correlated_sq_1.o_orderkey = Int32(1)) [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64, o_orderkey:Int64]\ + \n Projection: orders.o_custkey AS o_custkey, orders.o_orderkey [o_custkey:Int64, o_orderkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + + assert_optimized_plan_eq_display_indent( Arc::new(DecorrelateWhereIn::new()), &plan, - "Optimizing disjunctions not supported!", + expected, ); + Ok(()) } @@ -721,11 +818,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? 
.build()?; - // TODO: support join on expression - assert_optimizer_err( + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey + Int32(1) = __correlated_sq_1.o_custkey AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ + \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + + assert_optimized_plan_eq_display_indent( Arc::new(DecorrelateWhereIn::new()), &plan, - "column comparison required", + expected, ); Ok(()) } @@ -745,11 +848,17 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - // TODO: support join on expressions? - assert_optimizer_err( + let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ + \n LeftSemi Join: Filter: customer.c_custkey = __correlated_sq_1.o_custkey + Int32(1) AND customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ + \n SubqueryAlias: __correlated_sq_1 [o_custkey + Int32(1):Int64, o_custkey:Int64]\ + \n Projection: orders.o_custkey + Int32(1) AS o_custkey + Int32(1), orders.o_custkey [o_custkey + Int32(1):Int64, o_custkey:Int64]\ + \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; + + assert_optimized_plan_eq_display_indent( Arc::new(DecorrelateWhereIn::new()), &plan, - "single column projection required", + expected, ); Ok(()) } @@ -800,7 +909,7 @@ mod tests { let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ \n Filter: customer.c_custkey = Int32(1) [c_custkey:Int64, c_name:Utf8]\ - \n LeftSemi Join: customer.c_custkey = __correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: Filter: customer.c_custkey = 
__correlated_sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ \n SubqueryAlias: __correlated_sq_1 [o_custkey:Int64]\ \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ @@ -865,10 +974,10 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: test.c = __correlated_sq_1.c, test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: __correlated_sq_1 [c:UInt32, a:UInt32]\ - \n Projection: sq.c AS c, sq.a AS a [c:UInt32, a:UInt32]\ + \n Projection: sq.c AS c, sq.a [c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_eq_display_indent( @@ -889,7 +998,7 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n LeftSemi Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ \n Projection: sq.c AS c [c:UInt32]\ @@ -913,7 +1022,7 @@ mod tests { .build()?; let expected = "Projection: test.b [b:UInt32]\ - \n LeftAnti Join: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftAnti Join: Filter: test.c = __correlated_sq_1.c [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ \n SubqueryAlias: __correlated_sq_1 [c:UInt32]\ \n Projection: sq.c AS c [c:UInt32]\ @@ -926,4 +1035,153 @@ mod tests { ); Ok(()) } + + #[test] + fn in_subquery_both_side_expr() -> Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan = test_table_scan_with_name("sq")?; + + let subquery = LogicalPlanBuilder::from(subquery_scan) + .project(vec![col("c") * lit(2u32)])? 
+ .build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))? + .project(vec![col("test.b")])? + .build()?; + + let expected = "Projection: test.b [b:UInt32]\ + \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32]\ + \n Projection: sq.c * UInt32(2) AS c * UInt32(2) [c * UInt32(2):UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereIn::new()), + &plan, + expected, + ); + Ok(()) + } + + #[test] + fn in_subquery_join_filter_and_inner_filter() -> Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan = test_table_scan_with_name("sq")?; + + let subquery = LogicalPlanBuilder::from(subquery_scan) + .filter( + col("test.a") + .eq(col("sq.a")) + .and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))), + )? + .project(vec![col("c") * lit(2u32)])? + .build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))? + .project(vec![col("test.b")])? 
+ .build()?; + + let expected = "Projection: test.b [b:UInt32]\ + \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) AND test.a = __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, a:UInt32]\ + \n Projection: sq.c * UInt32(2) AS c * UInt32(2), sq.a [c * UInt32(2):UInt32, a:UInt32]\ + \n Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereIn::new()), + &plan, + expected, + ); + Ok(()) + } + + #[test] + fn in_subquery_muti_project_subquery_cols() -> Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan = test_table_scan_with_name("sq")?; + + let subquery = LogicalPlanBuilder::from(subquery_scan) + .filter( + col("test.a") + .add(col("test.b")) + .eq(col("sq.a").add(col("sq.b"))) + .and(col("sq.a").add(lit(1u32)).eq(col("sq.b"))), + )? + .project(vec![col("c") * lit(2u32)])? + .build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter(in_subquery(col("c") + lit(1u32), Arc::new(subquery)))? + .project(vec![col("test.b")])? 
+ .build()?; + + let expected = "Projection: test.b [b:UInt32]\ + \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) AND test.a + test.b = __correlated_sq_1.a + __correlated_sq_1.b [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, a:UInt32, b:UInt32]\ + \n Projection: sq.c * UInt32(2) AS c * UInt32(2), sq.a, sq.b [c * UInt32(2):UInt32, a:UInt32, b:UInt32]\ + \n Filter: sq.a + UInt32(1) = sq.b [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereIn::new()), + &plan, + expected, + ); + Ok(()) + } + + #[test] + fn two_in_subquery_with_outer_filter() -> Result<()> { + let table_scan = test_table_scan()?; + let subquery_scan1 = test_table_scan_with_name("sq1")?; + let subquery_scan2 = test_table_scan_with_name("sq2")?; + + let subquery1 = LogicalPlanBuilder::from(subquery_scan1) + .filter(col("test.a").gt(col("sq1.a")))? + .project(vec![col("c") * lit(2u32)])? + .build()?; + + let subquery2 = LogicalPlanBuilder::from(subquery_scan2) + .filter(col("test.a").gt(col("sq2.a")))? + .project(vec![col("c") * lit(2u32)])? + .build()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .filter( + in_subquery(col("c") + lit(1u32), Arc::new(subquery1)).and( + in_subquery(col("c") * lit(2u32), Arc::new(subquery2)) + .and(col("test.c").gt(lit(1u32))), + ), + )? + .project(vec![col("test.b")])? + .build()?; + + // Filter: test.c > UInt32(1) happen twice. 
+ // issue: https://github.com/apache/arrow-datafusion/issues/4914 + let expected = "Projection: test.b [b:UInt32]\ + \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c * UInt32(2) = __correlated_sq_2.c * UInt32(2) AND test.a > __correlated_sq_2.a [a:UInt32, b:UInt32, c:UInt32]\ + \n Filter: test.c > UInt32(1) [a:UInt32, b:UInt32, c:UInt32]\ + \n LeftSemi Join: Filter: test.c + UInt32(1) = __correlated_sq_1.c * UInt32(2) AND test.a > __correlated_sq_1.a [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]\ + \n SubqueryAlias: __correlated_sq_1 [c * UInt32(2):UInt32, a:UInt32]\ + \n Projection: sq1.c * UInt32(2) AS c * UInt32(2), sq1.a [c * UInt32(2):UInt32, a:UInt32]\ + \n TableScan: sq1 [a:UInt32, b:UInt32, c:UInt32]\ + \n SubqueryAlias: __correlated_sq_2 [c * UInt32(2):UInt32, a:UInt32]\ + \n Projection: sq2.c * UInt32(2) AS c * UInt32(2), sq2.a [c * UInt32(2):UInt32, a:UInt32]\ + \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]"; + + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereIn::new()), + &plan, + expected, + ); + Ok(()) + } }