Skip to content

Commit

Permalink
fix: sql planner creates cross join instead of inner join from select predicates
Browse files Browse the repository at this point in the history
  • Loading branch information
xudong963 committed Jan 19, 2022
1 parent f027e5f commit c330dc0
Showing 1 changed file with 88 additions and 18 deletions.
106 changes: 88 additions & 18 deletions datafusion/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
alias: Option<String>,
) -> Result<LogicalPlan> {
let plans = self.plan_from_tables(&select.from, ctes)?;

let plan = match &select.selection {
Some(predicate_expr) => {
// build join schema
Expand All @@ -714,33 +713,79 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
extract_possible_join_keys(&filter_expr, &mut possible_join_keys)?;

let mut all_join_keys = HashSet::new();
let mut left = plans[0].clone();
for right in plans.iter().skip(1) {
let left_schema = left.schema();
let right_schema = right.schema();

let mut plans = plans.into_iter();
let mut left = plans.next().unwrap(); // have at least one plan

// List of the plans that have not yet been joined
let mut remaining_plans: Vec<Option<LogicalPlan>> =
plans.into_iter().map(Some).collect();

// Take from the list of remaining plans,
loop {
let mut join_keys = vec![];
for (l, r) in &possible_join_keys {
if left_schema.field_from_column(l).is_ok()
&& right_schema.field_from_column(r).is_ok()
{
join_keys.push((l.clone(), r.clone()));
} else if left_schema.field_from_column(r).is_ok()
&& right_schema.field_from_column(l).is_ok()
{
join_keys.push((r.clone(), l.clone()));

// Search all remaining plans for the next to
// join. Prefer the first one that has a join
// predicate in the predicate lists
let plan_with_idx = remaining_plans.iter().enumerate().find(|(_idx, plan)| {
// skip plans that have been joined already
let plan = if let Some(plan) = plan {
plan
} else {
return false;
};

// can we find a match?
let left_schema = left.schema();
let right_schema = plan.schema();
for (l, r) in &possible_join_keys {
if left_schema.field_from_column(l).is_ok()
&& right_schema.field_from_column(r).is_ok()
{
join_keys.push((l.clone(), r.clone()));
} else if left_schema.field_from_column(r).is_ok()
&& right_schema.field_from_column(l).is_ok()
{
join_keys.push((r.clone(), l.clone()));
}
}
}
// stop if we found join keys
!join_keys.is_empty()
});

// If we did not find join keys, either there are
// no more plans, or we can't find any plans that
// can be joined with predicates
if join_keys.is_empty() {
left =
LogicalPlanBuilder::from(left).cross_join(right)?.build()?;
assert!(plan_with_idx.is_none());

// pick the first non null plan to join
let plan_with_idx = remaining_plans
.iter()
.enumerate()
.find(|(_idx, plan)| plan.is_some());
if let Some((idx, _)) = plan_with_idx {
let plan = std::mem::take(&mut remaining_plans[idx]).unwrap();
left = LogicalPlanBuilder::from(left)
.cross_join(&plan)?
.build()?;
} else {
// no more plans to join
break;
}
} else {
// have a plan
let (idx, _) = plan_with_idx.expect("found plan node");
let plan = std::mem::take(&mut remaining_plans[idx]).unwrap();

let left_keys: Vec<Column> =
join_keys.iter().map(|(l, _)| l.clone()).collect();
let right_keys: Vec<Column> =
join_keys.iter().map(|(_, r)| r.clone()).collect();
let builder = LogicalPlanBuilder::from(left);
left = builder
.join(right, JoinType::Inner, (left_keys, right_keys))?
.join(&plan, JoinType::Inner, (left_keys, right_keys))?
.build()?;
}

Expand Down Expand Up @@ -3818,6 +3863,31 @@ mod tests {
\n TableScan: public.person projection=None";
quick_test(sql, expected);
}

/// Comma-separated tables whose WHERE clause holds equality predicates
/// between them should be planned as inner joins, not cross joins.
///
/// `orders` appears between `person` and `lineitem` in the FROM list, but the
/// predicates only link `person` to `lineitem` and `orders` to `lineitem`.
/// The expected plan therefore joins `person` with `lineitem` first and
/// brings `orders` in afterwards.
#[test]
fn cross_join_to_inner_join() {
    // NOTE(review): each nested plan line below is indented by a single space;
    // the indentation may have been collapsed when this text was captured —
    // confirm against the planner's actual display output.
    let expected_plan = "Projection: #person.id\
        \n Join: #lineitem.l_description = #orders.o_item_id\
        \n Join: #person.id = #lineitem.l_item_id\
        \n TableScan: person projection=None\
        \n TableScan: lineitem projection=None\
        \n TableScan: orders projection=None";
    let query = "select person.id from person, orders, lineitem where person.id = lineitem.l_item_id and orders.o_item_id = lineitem.l_description;";

    quick_test(query, expected_plan);
}

/// When the WHERE clause holds no predicate relating two different tables,
/// the comma-separated tables must stay cross joins.
///
/// The only predicate here compares two columns of `person`, so the expected
/// plan is the cross joins in FROM-list order with the predicate applied as a
/// `Filter` on top.
#[test]
fn cross_join_not_to_inner_join() {
    // NOTE(review): each nested plan line below is indented by a single space;
    // the indentation may have been collapsed when this text was captured —
    // confirm against the planner's actual display output.
    let expected_plan = "Projection: #person.id\
        \n Filter: #person.id = #person.age\
        \n CrossJoin:\
        \n CrossJoin:\
        \n TableScan: person projection=None\
        \n TableScan: orders projection=None\
        \n TableScan: lineitem projection=None";
    let query = "select person.id from person, orders, lineitem where person.id = person.age;";

    quick_test(query, expected_plan);
}
}

fn parse_sql_number(n: &str) -> Result<Expr> {
Expand Down

0 comments on commit c330dc0

Please sign in to comment.