From d69b128a52efeb334b9433947030e7e62e230703 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 13 May 2024 12:16:40 -0400 Subject: [PATCH] Stop copying LogicalPlan and Exprs in `ScalarSubqueryToJoin` --- .../optimizer/src/scalar_subquery_to_join.rs | 83 ++++++++++++------- 1 file changed, 54 insertions(+), 29 deletions(-) diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index b7fce68fb3cc..71692b934543 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -29,7 +29,7 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::tree_node::{ Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, }; -use datafusion_common::{plan_err, Column, Result, ScalarValue}; +use datafusion_common::{internal_err, plan_err, Column, Result, ScalarValue}; use datafusion_expr::expr_rewriter::create_col_from_scalar_expr; use datafusion_expr::logical_plan::{JoinType, Subquery}; use datafusion_expr::utils::conjunction; @@ -50,7 +50,7 @@ impl ScalarSubqueryToJoin { /// # Arguments /// * `predicate` - A conjunction to split and search /// - /// Returns a tuple (subqueries, rewrite expression) + /// Returns a tuple (subqueries, alias) fn extract_subquery_exprs( &self, predicate: &Expr, @@ -71,19 +71,36 @@ impl ScalarSubqueryToJoin { impl OptimizerRule for ScalarSubqueryToJoin { fn try_optimize( &self, - plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _plan: &LogicalPlan, + _config: &dyn OptimizerConfig, ) -> Result> { + internal_err!("Should have called ScalarSubqueryToJoin::rewrite") + } + + fn supports_rewrite(&self) -> bool { + true + } + + fn rewrite( + &self, + plan: LogicalPlan, + config: &dyn OptimizerConfig, + ) -> Result> { match plan { LogicalPlan::Filter(filter) => { + // Optimization: skip the rest of the rule and its copies if + // there are no scalar subqueries + if !contains_scalar_subquery(&filter.predicate) { + return Ok(Transformed::no(LogicalPlan::Filter(filter))); + } + let (subqueries, mut rewrite_expr) = self.extract_subquery_exprs( &filter.predicate, config.alias_generator(), )?; if subqueries.is_empty() { - // regular filter, no subquery exists clause here - return Ok(None); + return internal_err!("Expected subqueries not found in filter"); } // iterate through all subqueries in predicate, turning each into a left join @@ -94,16 +111,13 @@ impl OptimizerRule for ScalarSubqueryToJoin { { if !expr_check_map.is_empty() { rewrite_expr = rewrite_expr - .clone() .transform_up(|expr| { - if let Expr::Column(col) = &expr { - if let Some(map_expr) = - expr_check_map.get(&col.name) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } + // replace column references with entry in map, if it exists + if let Some(map_expr) = expr + .try_as_col() + .and_then(|col| expr_check_map.get(&col.name)) + { + Ok(Transformed::yes(map_expr.clone())) } else { Ok(Transformed::no(expr)) } @@ -113,15 +127,21 @@ impl OptimizerRule for ScalarSubqueryToJoin { cur_input = optimized_subquery; } else { // if we can't handle all of the subqueries then bail for now - return Ok(None); + return Ok(Transformed::no(LogicalPlan::Filter(filter))); } } let new_plan = LogicalPlanBuilder::from(cur_input) .filter(rewrite_expr)? .build()?; - Ok(Some(new_plan)) + Ok(Transformed::yes(new_plan)) } LogicalPlan::Projection(projection) => { + // Optimization: skip the rest of the rule and its copies if + // there are no scalar subqueries + if !projection.expr.iter().any(contains_scalar_subquery) { + return Ok(Transformed::no(LogicalPlan::Projection(projection))); + } + let mut all_subqueryies = vec![]; let mut expr_to_rewrite_expr_map = HashMap::new(); let mut subquery_to_expr_map = HashMap::new(); @@ -135,8 +155,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { expr_to_rewrite_expr_map.insert(expr, rewrite_exprs); } if all_subqueryies.is_empty() { - // regular projection, no subquery exists clause here - return Ok(None); + return internal_err!("Expected subqueries not found in projection"); } // iterate through all subqueries in predicate, turning each into a left join let mut cur_input = projection.input.as_ref().clone(); @@ -153,14 +172,13 @@ impl OptimizerRule for ScalarSubqueryToJoin { let new_expr = rewrite_expr .clone() .transform_up(|expr| { - if let Expr::Column(col) = &expr { - if let Some(map_expr) = + // replace column references with entry in map, if it exists + if let Some(map_expr) = + expr.try_as_col().and_then(|col| { expr_check_map.get(&col.name) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } + }) + { + Ok(Transformed::yes(map_expr.clone())) } else { Ok(Transformed::no(expr)) } @@ -172,7 +190,7 @@ impl OptimizerRule for ScalarSubqueryToJoin { } } else { // if we can't handle all of the subqueries then bail for now - return Ok(None); + return Ok(Transformed::no(LogicalPlan::Projection(projection))); } } @@ -190,10 +208,10 @@ impl OptimizerRule for ScalarSubqueryToJoin { let new_plan = LogicalPlanBuilder::from(cur_input) .project(proj_exprs)? .build()?; - Ok(Some(new_plan)) + Ok(Transformed::yes(new_plan)) } - _ => Ok(None), + plan => Ok(Transformed::no(plan)), } } @@ -206,6 +224,13 @@ impl OptimizerRule for ScalarSubqueryToJoin { } } +/// Returns true if the expression has a scalar subquery somewhere in it +/// false otherwise +fn contains_scalar_subquery(expr: &Expr) -> bool { + expr.exists(|expr| Ok(matches!(expr, Expr::ScalarSubquery(_)))) + .expect("Inner is always Ok") +} + struct ExtractScalarSubQuery { sub_query_info: Vec<(Subquery, String)>, alias_gen: Arc,