diff --git a/datafusion/optimizer/src/push_down_filter.rs b/datafusion/optimizer/src/push_down_filter.rs
index 6d8db2043c07..0d8da557397e 100644
--- a/datafusion/optimizer/src/push_down_filter.rs
+++ b/datafusion/optimizer/src/push_down_filter.rs
@@ -735,6 +735,48 @@ impl OptimizerRule for PushDownFilter {
                     None => new_scan,
                 }
             }
+            LogicalPlan::Extension(extension_plan) => {
+                let prevent_cols =
+                    extension_plan.node.prevent_predicate_push_down_columns();
+
+                let predicates = utils::split_conjunction_owned(filter.predicate.clone());
+
+                let mut keep_predicates = vec![];
+                let mut push_predicates = vec![];
+                for expr in predicates {
+                    let cols = expr.to_columns()?;
+                    if cols.iter().any(|c| prevent_cols.contains(&c.name)) {
+                        keep_predicates.push(expr);
+                    } else {
+                        push_predicates.push(expr);
+                    }
+                }
+
+                let new_children = match conjunction(push_predicates) {
+                    Some(predicate) => extension_plan
+                        .node
+                        .inputs()
+                        .into_iter()
+                        .map(|child| {
+                            Ok(LogicalPlan::Filter(Filter::try_new(
+                                predicate.clone(),
+                                Arc::new(child.clone()),
+                            )?))
+                        })
+                        .collect::<Result<Vec<_>>>()?,
+                    None => extension_plan.node.inputs().into_iter().cloned().collect(),
+                };
+                // extension with new inputs.
+                let new_extension = child_plan.with_new_inputs(&new_children)?;
+
+                match conjunction(keep_predicates) {
+                    Some(predicate) => LogicalPlan::Filter(Filter::try_new(
+                        predicate,
+                        Arc::new(new_extension),
+                    )?),
+                    None => new_extension,
+                }
+            }
             _ => return Ok(None),
         };
         Ok(Some(new_plan))
@@ -774,12 +816,15 @@ mod tests {
     use crate::OptimizerContext;
     use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
     use async_trait::async_trait;
-    use datafusion_common::DFSchema;
+    use datafusion_common::{DFSchema, DFSchemaRef};
     use datafusion_expr::logical_plan::table_scan;
     use datafusion_expr::{
         and, col, in_list, in_subquery, lit, logical_plan::JoinType, or, sum, BinaryExpr,
-        Expr, LogicalPlanBuilder, Operator, TableSource, TableType,
+        Expr, Extension, LogicalPlanBuilder, Operator, TableSource, TableType,
+        UserDefinedLogicalNode,
     };
+    use std::any::Any;
+    use std::fmt::{Debug, Formatter};
     use std::sync::Arc;
 
     fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> {
@@ -1029,6 +1074,131 @@ mod tests {
         assert_optimized_plan_eq(&plan, expected)
     }
 
+    #[derive(Debug)]
+    struct NoopPlan {
+        input: Vec<LogicalPlan>,
+        schema: DFSchemaRef,
+    }
+
+    impl UserDefinedLogicalNode for NoopPlan {
+        fn as_any(&self) -> &dyn Any {
+            self
+        }
+
+        fn inputs(&self) -> Vec<&LogicalPlan> {
+            self.input.iter().collect()
+        }
+
+        fn schema(&self) -> &DFSchemaRef {
+            &self.schema
+        }
+
+        fn expressions(&self) -> Vec<Expr> {
+            self.input
+                .iter()
+                .flat_map(|child| child.expressions())
+                .collect()
+        }
+
+        fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
+            HashSet::from_iter(vec!["c".to_string()])
+        }
+
+        fn fmt_for_explain(&self, f: &mut Formatter) -> std::fmt::Result {
+            write!(f, "NoopPlan")
+        }
+
+        fn from_template(
+            &self,
+            _exprs: &[Expr],
+            inputs: &[LogicalPlan],
+        ) -> Arc<dyn UserDefinedLogicalNode> {
+            Arc::new(Self {
+                input: inputs.to_vec(),
+                schema: self.schema.clone(),
+            })
+        }
+    }
+
+    #[test]
+    fn user_defined_plan() -> Result<()> {
+        let table_scan = test_table_scan()?;
+
+        let custom_plan = LogicalPlan::Extension(Extension {
+            node: Arc::new(NoopPlan {
+                input: vec![table_scan.clone()],
+                schema: table_scan.schema().clone(),
+            }),
+        });
+        let plan = LogicalPlanBuilder::from(custom_plan)
+            .filter(col("a").eq(lit(1i64)))?
+            .build()?;
+
+        // Push filter below NoopPlan
+        let expected = "\
+            NoopPlan\
+            \n  Filter: test.a = Int64(1)\
+            \n    TableScan: test";
+        assert_optimized_plan_eq(&plan, expected)?;
+
+        let custom_plan = LogicalPlan::Extension(Extension {
+            node: Arc::new(NoopPlan {
+                input: vec![table_scan.clone()],
+                schema: table_scan.schema().clone(),
+            }),
+        });
+        let plan = LogicalPlanBuilder::from(custom_plan)
+            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
+            .build()?;
+
+        // Push only predicate on `a` below NoopPlan
+        let expected = "\
+            Filter: test.c = Int64(2)\
+            \n  NoopPlan\
+            \n    Filter: test.a = Int64(1)\
+            \n      TableScan: test";
+        assert_optimized_plan_eq(&plan, expected)?;
+
+        let custom_plan = LogicalPlan::Extension(Extension {
+            node: Arc::new(NoopPlan {
+                input: vec![table_scan.clone(), table_scan.clone()],
+                schema: table_scan.schema().clone(),
+            }),
+        });
+        let plan = LogicalPlanBuilder::from(custom_plan)
+            .filter(col("a").eq(lit(1i64)))?
+            .build()?;
+
+        // Push filter below NoopPlan for each child branch
+        let expected = "\
+            NoopPlan\
+            \n  Filter: test.a = Int64(1)\
+            \n    TableScan: test\
+            \n  Filter: test.a = Int64(1)\
+            \n    TableScan: test";
+        assert_optimized_plan_eq(&plan, expected)?;
+
+        let custom_plan = LogicalPlan::Extension(Extension {
+            node: Arc::new(NoopPlan {
+                input: vec![table_scan.clone(), table_scan.clone()],
+                schema: table_scan.schema().clone(),
+            }),
+        });
+        let plan = LogicalPlanBuilder::from(custom_plan)
+            .filter(col("a").eq(lit(1i64)).and(col("c").eq(lit(2i64))))?
+            .build()?;
+
+        // Push only predicate on `a` below NoopPlan
+        let expected = "\
+            Filter: test.c = Int64(2)\
+            \n  NoopPlan\
+            \n    Filter: test.a = Int64(1)\
+            \n      TableScan: test\
+            \n    Filter: test.a = Int64(1)\
+            \n      TableScan: test";
+        assert_optimized_plan_eq(&plan, expected)
+    }
+
     /// verifies that when two filters apply after an aggregation that only allows one to be pushed, one is pushed
     /// and the other not.
     #[test]