Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor extract_join_keys and move the ExtractEquijoinPredicate rule #4760

Merged
merged 9 commits into from
Jan 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 98 additions & 4 deletions datafusion/core/tests/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,20 @@

use arrow::datatypes::{DataType, Field, Schema};
use arrow::{
array::{Int32Array, StringArray},
array::{Int32Array, StringArray, UInt32Array},
record_batch::RecordBatch,
};
use datafusion::from_slice::FromSlice;
use std::sync::Arc;

use datafusion::assert_batches_eq;
use datafusion::dataframe::DataFrame;
use datafusion::error::Result;
use datafusion::execution::context::SessionContext;
use datafusion::prelude::CsvReadOptions;
use datafusion::prelude::JoinType;
use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::{avg, count, lit, sum};
use datafusion_expr::{col, Expr};
use datafusion_expr::{avg, col, count, lit, sum, Expr, ExprSchemable};

#[tokio::test]
async fn join() -> Result<()> {
Expand Down Expand Up @@ -352,6 +351,62 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn join_with_alias_filter() -> Result<()> {
let join_ctx = create_join_context()?;
let t1 = join_ctx.table("t1")?;
let t2 = join_ctx.table("t2")?;
let t1_schema = t1.schema().clone();
let t2_schema = t2.schema().clone();

// filter: t1.a + CAST(Int64(1), UInt32) = t2.a + CAST(Int64(2), UInt32) as t1.a + 1 = t2.a + 2
let filter = Expr::eq(
col("t1.a") + lit(3i64).cast_to(&DataType::UInt32, &t1_schema)?,
col("t2.a") + lit(1i32).cast_to(&DataType::UInt32, &t2_schema)?,
)
.alias("t1.b + 1 = t2.a + 2");

let df = t1
.join(t2, JoinType::Inner, &[], &[], Some(filter))?
.select(vec![
col("t1.a"),
col("t2.a"),
col("t1.b"),
col("t1.c"),
col("t2.b"),
col("t2.c"),
])?;
let optimized_plan = df.clone().into_optimized_plan()?;

let expected = vec![
"Projection: t1.a, t2.a, t1.b, t1.c, t2.b, t2.c [a:UInt32, a:UInt32, b:Utf8, c:Int32, b:Utf8, c:Int32]",
" Inner Join: t1.a + UInt32(3) = t2.a + UInt32(1) [a:UInt32, b:Utf8, c:Int32, a:UInt32, b:Utf8, c:Int32]",
" TableScan: t1 projection=[a, b, c] [a:UInt32, b:Utf8, c:Int32]",
" TableScan: t2 projection=[a, b, c] [a:UInt32, b:Utf8, c:Int32]",
];

let formatted = optimized_plan.display_indent_schema().to_string();
let actual: Vec<&str> = formatted.trim().lines().collect();
assert_eq!(
expected, actual,
"\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n"
);

let results = df.collect().await?;
let expected: Vec<&str> = vec![
"+----+----+---+----+---+---+",
"| a | a | b | c | b | c |",
"+----+----+---+----+---+---+",
"| 11 | 13 | c | 30 | c | 3 |",
"| 1 | 3 | a | 10 | a | 1 |",
"+----+----+---+----+---+---+",
];

assert_batches_sorted_eq!(expected, &results);

Ok(())
}

fn create_test_table() -> Result<DataFrame> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Utf8, false),
Expand Down Expand Up @@ -388,3 +443,42 @@ async fn aggregates_table(ctx: &SessionContext) -> Result<DataFrame> {
)
.await
}

fn create_join_context() -> Result<SessionContext> {
let t1 = Arc::new(Schema::new(vec![
Field::new("a", DataType::UInt32, false),
Field::new("b", DataType::Utf8, false),
Field::new("c", DataType::Int32, false),
]));
let t2 = Arc::new(Schema::new(vec![
Field::new("a", DataType::UInt32, false),
Field::new("b", DataType::Utf8, false),
Field::new("c", DataType::Int32, false),
]));

// define data.
let batch1 = RecordBatch::try_new(
t1,
vec![
Arc::new(UInt32Array::from_slice([1, 10, 11, 100])),
Arc::new(StringArray::from_slice(["a", "b", "c", "d"])),
Arc::new(Int32Array::from_slice([10, 20, 30, 40])),
],
)?;
// define data.
let batch2 = RecordBatch::try_new(
t2,
vec![
Arc::new(UInt32Array::from_slice([3, 10, 13, 100])),
Arc::new(StringArray::from_slice(["a", "b", "c", "d"])),
Arc::new(Int32Array::from_slice([1, 2, 3, 4])),
],
)?;

let ctx = SessionContext::new();

ctx.register_batch("t1", batch1)?;
ctx.register_batch("t2", batch2)?;

Ok(ctx)
}
129 changes: 64 additions & 65 deletions datafusion/optimizer/src/extract_equijoin_predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@

//! Optimizer rule to extract equijoin expr from filter
use crate::optimizer::ApplyOrder;
use crate::utils::split_conjunction;
use crate::{OptimizerConfig, OptimizerRule};
use datafusion_common::DFSchema;
use datafusion_common::Result;
use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair};
use datafusion_expr::{BinaryExpr, Expr, ExprSchemable, Join, LogicalPlan, Operator};
use std::sync::Arc;

// equijoin predicate
type EquijoinPredicate = (Expr, Expr);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️


/// Optimization rule that extract equijoin expr from the filter
#[derive(Default)]
pub struct ExtractEquijoinPredicate;
Expand Down Expand Up @@ -56,27 +60,22 @@ impl OptimizerRule for ExtractEquijoinPredicate {
let right_schema = right.schema();

filter.as_ref().map_or(Result::Ok(None), |expr| {
let mut accum: Vec<(Expr, Expr)> = vec![];
let mut accum_filter: Vec<Expr> = vec![];
// TODO: avoding clone with split_conjunction
extract_join_keys(
expr.clone(),
&mut accum,
&mut accum_filter,
left_schema,
right_schema,
)?;

let optimized_plan = (!accum.is_empty()).then(|| {
let (equijoin_predicates, non_equijoin_expr) =
split_eq_and_noneq_join_predicate(
expr,
left_schema,
right_schema,
)?;

let optimized_plan = (!equijoin_predicates.is_empty()).then(|| {
let mut new_on = on.clone();
new_on.extend(accum);
new_on.extend(equijoin_predicates);

let new_filter = accum_filter.into_iter().reduce(Expr::and);
LogicalPlan::Join(Join {
left: left.clone(),
right: right.clone(),
on: new_on,
filter: new_filter,
filter: non_equijoin_expr,
join_type: *join_type,
join_constraint: *join_constraint,
schema: schema.clone(),
Expand All @@ -100,30 +99,22 @@ impl OptimizerRule for ExtractEquijoinPredicate {
}
}

/// Extracts equijoin ON condition be a single Eq or multiple conjunctive Eqs
/// Filters matching this pattern are added to `accum`
/// Filters that don't match this pattern are added to `accum_filter`
/// Examples:
/// ```text
/// foo = bar => accum=[(foo, bar)] accum_filter=[]
/// foo = bar AND bar = baz => accum=[(foo, bar), (bar, baz)] accum_filter=[]
/// foo = bar AND baz > 1 => accum=[(foo, bar)] accum_filter=[baz > 1]
///
/// For equijoin join key, assume we have tables -- a(c0, c1 c2) and b(c0, c1, c2):
/// (a.c0 = 10) => accum=[], accum_filter=[a.c0 = 10]
/// (a.c0 + 1 = b.c0 * 2) => accum=[(a.c0 + 1, b.c0 * 2)], accum_filter=[]
/// (a.c0 + b.c0 = 10) => accum=[], accum_filter=[a.c0 + b.c0 = 10]
/// ```
fn extract_join_keys(
expr: Expr,
accum: &mut Vec<(Expr, Expr)>,
accum_filter: &mut Vec<Expr>,
fn split_eq_and_noneq_join_predicate(
filter: &Expr,
left_schema: &Arc<DFSchema>,
right_schema: &Arc<DFSchema>,
) -> Result<()> {
match &expr {
Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op {
Operator::Eq => {
) -> Result<(Vec<EquijoinPredicate>, Option<Expr>)> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a much nicer interface 👍 for being self documenting

let exprs = split_conjunction(filter);

let mut accum_join_keys: Vec<(Expr, Expr)> = vec![];
let mut accum_filters: Vec<Expr> = vec![];
for expr in exprs {
match expr {
Expr::BinaryExpr(BinaryExpr {
left,
op: Operator::Eq,
right,
}) => {
let left = left.as_ref();
let right = right.as_ref();

Expand All @@ -139,48 +130,27 @@ fn extract_join_keys(
let right_expr_type = right_expr.get_type(right_schema)?;

if can_hash(&left_expr_type) && can_hash(&right_expr_type) {
accum.push((left_expr, right_expr));
accum_join_keys.push((left_expr, right_expr));
} else {
accum_filter.push(expr);
accum_filters.push(expr.clone());
}
} else {
accum_filter.push(expr);
}
}
Operator::And => {
if let Expr::BinaryExpr(BinaryExpr { left, op: _, right }) = expr {
extract_join_keys(
*left,
accum,
accum_filter,
left_schema,
right_schema,
)?;
extract_join_keys(
*right,
accum,
accum_filter,
left_schema,
right_schema,
)?;
accum_filters.push(expr.clone());
}
}
_other => {
accum_filter.push(expr);
}
},
_other => {
accum_filter.push(expr);
_ => accum_filters.push(expr.clone()),
}
}

Ok(())
let result_filter = accum_filters.into_iter().reduce(Expr::and);
Ok((accum_join_keys, result_filter))
}

#[cfg(test)]
mod tests {
use super::*;
use crate::test::*;
use arrow::datatypes::DataType;
use datafusion_common::Column;
use datafusion_expr::{
col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType,
Expand Down Expand Up @@ -387,4 +357,33 @@ mod tests {

assert_plan_eq(&plan, expected)
}

#[test]
fn join_with_alias_filter() -> Result<()> {
Comment on lines +361 to +362
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, I recommend add a integration-test to show the plan after all rule optimize it.

Copy link
Contributor Author

@ygf11 ygf11 Jan 1, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we can't create a join whose condition is an alias a sql, I add a integration-test with dataframe api.

let t1 = test_table_scan_with_name("t1")?;
let t2 = test_table_scan_with_name("t2")?;

let t1_schema = t1.schema().clone();
let t2_schema = t2.schema().clone();

// filter: t1.a + CAST(Int64(1), UInt32) = t2.a + CAST(Int64(2), UInt32) as t1.a + 1 = t2.a + 2
let filter = Expr::eq(
col("t1.a") + lit(1i64).cast_to(&DataType::UInt32, &t1_schema)?,
col("t2.a") + lit(2i32).cast_to(&DataType::UInt32, &t2_schema)?,
)
.alias("t1.a + 1 = t2.a + 2");
let plan = LogicalPlanBuilder::from(t1)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The split_conjunction will unalias the expr.

.join(
t2,
JoinType::Left,
(Vec::<Column>::new(), Vec::<Column>::new()),
Some(filter),
)?
.build()?;
let expected = "Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";

assert_plan_eq(&plan, expected)
}
}
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,12 +237,12 @@ impl Optimizer {
let rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
Arc::new(InlineTableScan::new()),
Arc::new(TypeCoercion::new()),
Arc::new(ExtractEquijoinPredicate::new()),
Arc::new(SimplifyExpressions::new()),
Arc::new(UnwrapCastInComparison::new()),
Arc::new(DecorrelateWhereExists::new()),
Arc::new(DecorrelateWhereIn::new()),
Arc::new(ScalarSubqueryToJoin::new()),
Arc::new(ExtractEquijoinPredicate::new()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 related comments from #4711 (comment)

// simplify expressions does not simplify expressions in subqueries, so we
// run it again after running the optimizations that potentially converted
// subqueries to joins
Expand Down