ARROW-10760: [Rust] [DataFusion] Fixed error in filter push down over joins

This fixes a bug in which a predicate depending on columns from both sides of a join was pushed down through the join, producing incorrect plans.

With this change, each filter is pushed independently to every side of the join on which all of its columns are available, while any predicate that cannot be pushed (e.g. because it depends on columns from both sides of the join) is kept above the join.
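As a minimal sketch of the new behaviour (modelled on the `filter_join_on_one_side` and `filter_join_on_common_dependent` tests added in this diff, and assuming the test helpers `test_table_scan`, `col` and `lit` inside a function returning `Result`), a predicate that only needs columns from one side is pushed below the join on that side, while a cross-side predicate stays above the join:

    // left exposes columns (a, b); right exposes columns (a, c)
    let table_scan = test_table_scan()?;
    let left = LogicalPlanBuilder::from(&table_scan)
        .project(vec![col("a"), col("b")])?
        .build()?;
    let right = LogicalPlanBuilder::from(&table_scan)
        .project(vec![col("a"), col("c")])?
        .build()?;

    // `b <= 1` only needs the left side, so it is pushed below the join:
    let plan = LogicalPlanBuilder::from(&left)
        .join(&right, JoinType::Inner, &["a"], &["a"])?
        .filter(col("b").lt_eq(lit(1i64)))?
        .build()?;
    // optimized plan:
    //   Join: a = a
    //     Projection: #a, #b
    //       Filter: #b LtEq Int64(1)
    //         TableScan: test projection=None
    //     Projection: #a, #c
    //       TableScan: test projection=None

    // a predicate such as col("c").lt_eq(col("b")) needs columns from both
    // sides, so it is kept above the join and the plan is left unchanged.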

Closes #8797 from jorgecarleitao/fix_push

Authored-by: Jorge C. Leitao <[email protected]>
Signed-off-by: Andy Grove <[email protected]>
jorgecarleitao authored and andygrove committed Nov 29, 2020
1 parent 71e2cb2 commit 322cd01
Showing 1 changed file with 230 additions and 38 deletions.
268 changes: 230 additions & 38 deletions rust/datafusion/src/optimizer/filter_push_down.rs
@@ -14,6 +14,8 @@

//! Filter Push Down optimizer rule ensures that filters are applied as early as possible in the plan
use arrow::datatypes::Schema;

use crate::error::Result;
use crate::logical_plan::Expr;
use crate::logical_plan::{and, LogicalPlan};
@@ -57,38 +59,104 @@ struct State {
filters: Vec<(Expr, HashSet<String>)>,
}

/// builds a new [LogicalPlan] from `plan` by issuing new [LogicalPlan::Filter] if any of the filters
/// in `state` depend on the columns `used_columns`.
fn issue_filters(
mut state: State,
used_columns: HashSet<String>,
plan: &LogicalPlan,
) -> Result<LogicalPlan> {
// pick all filters in the current state that depend on any of `used_columns`
let (predicates, predicate_columns): (Vec<_>, Vec<_>) = state
type Predicates<'a> = (Vec<&'a Expr>, Vec<&'a HashSet<String>>);

/// returns all predicates in `state` that depend on any of `used_columns`
fn get_predicates<'a>(
state: &'a State,
used_columns: &HashSet<String>,
) -> Predicates<'a> {
state
.filters
.iter()
.filter(|(_, columns)| {
columns
.intersection(&used_columns)
.intersection(used_columns)
.collect::<HashSet<_>>()
.len()
> 0
})
.map(|&(ref a, ref b)| (a, b))
.unzip()
}

// returns 3 (potentially overlapping) sets of predicates:
// * pushable to left: all of its columns are on the left
// * pushable to right: all of its columns are on the right
// * keep: its columns are not all on one single side of the join
// Note that a predicate can be both pushed to the left and to the right.
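// For example (an illustrative sketch): with left columns {a, c}, right
// columns {a, b} and predicates [`c <= b`, `b <= 1`]:
// * `b <= 1` uses only {b}, which lies entirely on the right, so it is pushable to the right
// * `c <= b` uses {b, c}, which is contained in neither side alone, so it must be kept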
fn get_join_predicates<'a>(
state: &'a State,
left: &Schema,
right: &Schema,
) -> (
Vec<&'a HashSet<String>>,
Vec<&'a HashSet<String>>,
Predicates<'a>,
) {
let left_columns = &left
.fields()
.iter()
.map(|f| f.name().clone())
.collect::<HashSet<_>>();
let right_columns = &right
.fields()
.iter()
.map(|f| f.name().clone())
.collect::<HashSet<_>>();

let filters = state
.filters
.iter()
.map(|(predicate, columns)| {
(
(predicate, columns),
(
columns,
left_columns.intersection(columns).collect::<HashSet<_>>(),
right_columns.intersection(columns).collect::<HashSet<_>>(),
),
)
})
.collect::<Vec<_>>();

let pushable_to_left = filters
.iter()
.filter(|(_, (columns, left, _))| left.len() == columns.len())
.map(|((_, b), _)| *b)
.collect();
let pushable_to_right = filters
.iter()
.filter(|(_, (columns, _, right))| right.len() == columns.len())
.map(|((_, b), _)| *b)
.collect();
let keep = filters
.iter()
.filter(|(_, (columns, left, right))| {
// predicates whose columns are not in only one side of the join need to remain
let all_in_left = left.len() == columns.len();
let all_in_right = right.len() == columns.len();
!all_in_left && !all_in_right
})
.map(|((ref a, ref b), _)| (a, b))
.unzip();
(pushable_to_left, pushable_to_right, keep)
}

if predicates.is_empty() {
// all filters can be pushed down => optimize inputs and return new plan
let new_inputs = utils::inputs(&plan)
.iter()
.map(|input| optimize(input, state.clone()))
.collect::<Result<Vec<_>>>()?;
/// Optimizes all inputs of `plan` with the filters in `state` and rebuilds the plan from the optimized inputs
fn push_down(state: &State, plan: &LogicalPlan) -> Result<LogicalPlan> {
let new_inputs = utils::inputs(&plan)
.iter()
.map(|input| optimize(input, state.clone()))
.collect::<Result<Vec<_>>>()?;

let expr = utils::expressions(&plan);
return utils::from_plan(&plan, &expr, &new_inputs);
}
let expr = utils::expressions(&plan);
utils::from_plan(&plan, &expr, &new_inputs)
}

/// returns a new [LogicalPlan] that wraps `plan` in a [LogicalPlan::Filter] whose
/// predicate is the conjunction (AND) of all `predicates`.
fn add_filter(plan: LogicalPlan, predicates: &[&Expr]) -> LogicalPlan {
// reduce filters to a single filter with an AND
let predicate = predicates
.iter()
@@ -97,28 +165,56 @@ fn issue_filters(
and(acc, (*predicate).to_owned())
});

// add a new filter node with the predicates
let plan = LogicalPlan::Filter {
LogicalPlan::Filter {
predicate,
input: Arc::new(plan.clone()),
};
input: Arc::new(plan),
}
}

// remove all filters from the state that cannot be pushed further down
state.filters = state
.filters
// remove all filters from `filters` that are in `predicate_columns`
fn remove_filters(
filters: &[(Expr, HashSet<String>)],
predicate_columns: &[&HashSet<String>],
) -> Vec<(Expr, HashSet<String>)> {
filters
.iter()
.filter(|(_, columns)| !predicate_columns.contains(&columns))
.cloned()
.collect::<Vec<_>>();
.collect::<Vec<_>>()
}

// continue optimization over all input nodes by cloning the current state (i.e. each node is independent)
let new_inputs = utils::inputs(&plan)
// keeps all filters from `filters` that are in `predicate_columns`
fn keep_filters(
filters: &[(Expr, HashSet<String>)],
predicate_columns: &[&HashSet<String>],
) -> Vec<(Expr, HashSet<String>)> {
filters
.iter()
.map(|input| optimize(input, state.clone()))
.collect::<Result<Vec<_>>>()?;
.filter(|(_, columns)| predicate_columns.contains(&columns))
.cloned()
.collect::<Vec<_>>()
}

let expr = utils::expressions(&plan);
utils::from_plan(&plan, &expr, &new_inputs)
/// builds a new [LogicalPlan] from `plan` by issuing new [LogicalPlan::Filter] if any of the filters
/// in `state` depend on the columns `used_columns`.
fn issue_filters(
mut state: State,
used_columns: HashSet<String>,
plan: &LogicalPlan,
) -> Result<LogicalPlan> {
let (predicates, predicate_columns) = get_predicates(&state, &used_columns);

if predicates.is_empty() {
// all filters can be pushed down => optimize inputs and return new plan
return push_down(&state, plan);
}

let plan = add_filter(plan.clone(), &predicates);

state.filters = remove_filters(&state.filters, &predicate_columns);

// continue optimization over all input nodes by cloning the current state (i.e. each node is independent)
push_down(&state, &plan)
}

fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
@@ -183,7 +279,7 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
}
LogicalPlan::Sort { .. } => {
// sort is filter-commutable
issue_filters(state, HashSet::new(), plan)
push_down(&state, plan)
}
LogicalPlan::Limit { input, .. } => {
// limit is _not_ filter-commutable => collect all columns from its input
@@ -195,9 +291,31 @@ fn optimize(plan: &LogicalPlan, mut state: State) -> Result<LogicalPlan> {
.collect::<HashSet<_>>();
issue_filters(state, used_columns, plan)
}
LogicalPlan::Join { .. } => {
// join is filter-commutable
issue_filters(state, HashSet::new(), plan)
LogicalPlan::Join { left, right, .. } => {
let (pushable_to_left, pushable_to_right, keep) =
get_join_predicates(&state, &left.schema(), &right.schema());

let mut left_state = state.clone();
left_state.filters = keep_filters(&left_state.filters, &pushable_to_left);
let left = optimize(left, left_state)?;

let mut right_state = state.clone();
right_state.filters = keep_filters(&right_state.filters, &pushable_to_right);
let right = optimize(right, right_state)?;

// create a new Join with the new `left` and `right`
let expr = utils::expressions(&plan);
let plan = utils::from_plan(&plan, &expr, &vec![left, right])?;

if keep.0.is_empty() {
Ok(plan)
} else {
// wrap the join on the filter whose predicates must be kept
let plan = add_filter(plan, &keep.0);
state.filters = remove_filters(&state.filters, &keep.1);

Ok(plan)
}
}
_ => {
// all other plans are _not_ filter-commutable
@@ -594,8 +712,9 @@ mod tests {
Ok(())
}

/// post-join predicates on a column common to both sides are pushed to both sides
#[test]
fn filters_join() -> Result<()> {
fn filter_join_on_common_independent() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(&table_scan).build()?;
let right = LogicalPlanBuilder::from(&table_scan)
@@ -628,4 +747,77 @@ mod tests {
assert_optimized_plan_eq(&plan, expected);
Ok(())
}

/// post-join predicates with columns from both sides are not pushed
#[test]
fn filter_join_on_common_dependent() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(&table_scan)
.project(vec![col("a"), col("c")])?
.build()?;
let right = LogicalPlanBuilder::from(&table_scan)
.project(vec![col("a"), col("b")])?
.build()?;
let plan = LogicalPlanBuilder::from(&left)
.join(&right, JoinType::Inner, &["a"], &["a"])?
// "b" and "c" are not shared by either side: they are only available together after the join
.filter(col("c").lt_eq(col("b")))?
.build()?;

// not part of the test, just good to know:
assert_eq!(
format!("{:?}", plan),
"\
Filter: #c LtEq #b\
\n Join: a = a\
\n Projection: #a, #c\
\n TableScan: test projection=None\
\n Projection: #a, #b\
\n TableScan: test projection=None"
);

// expected is equal: no push-down
let expected = &format!("{:?}", plan);
assert_optimized_plan_eq(&plan, expected);
Ok(())
}

/// post-join predicates with columns from one side of a join are pushed only to that side
#[test]
fn filter_join_on_one_side() -> Result<()> {
let table_scan = test_table_scan()?;
let left = LogicalPlanBuilder::from(&table_scan)
.project(vec![col("a"), col("b")])?
.build()?;
let right = LogicalPlanBuilder::from(&table_scan)
.project(vec![col("a"), col("c")])?
.build()?;
let plan = LogicalPlanBuilder::from(&left)
.join(&right, JoinType::Inner, &["a"], &["a"])?
.filter(col("b").lt_eq(lit(1i64)))?
.build()?;

// not part of the test, just good to know:
assert_eq!(
format!("{:?}", plan),
"\
Filter: #b LtEq Int64(1)\
\n Join: a = a\
\n Projection: #a, #b\
\n TableScan: test projection=None\
\n Projection: #a, #c\
\n TableScan: test projection=None"
);

let expected = "\
Join: a = a\
\n Projection: #a, #b\
\n Filter: #b LtEq Int64(1)\
\n TableScan: test projection=None\
\n Projection: #a, #c\
\n TableScan: test projection=None";

assert_optimized_plan_eq(&plan, expected);
Ok(())
}
}
