Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: sql planner creates cross join instead of inner join from select predicates #1566

Merged
merged 1 commit into from
Jan 21, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 90 additions & 19 deletions datafusion/src/sql/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,6 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
alias: Option<String>,
) -> Result<LogicalPlan> {
let plans = self.plan_from_tables(&select.from, ctes)?;

let plan = match &select.selection {
Some(predicate_expr) => {
// build join schema
Expand All @@ -714,33 +713,80 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
extract_possible_join_keys(&filter_expr, &mut possible_join_keys)?;

let mut all_join_keys = HashSet::new();
let mut left = plans[0].clone();
for right in plans.iter().skip(1) {
let left_schema = left.schema();
let right_schema = right.schema();

let mut plans = plans.into_iter();
let mut left = plans.next().unwrap(); // have at least one plan

// List of the plans that have not yet been joined
let mut remaining_plans: Vec<Option<LogicalPlan>> =
plans.into_iter().map(Some).collect();

// Repeatedly take the next plan to join from the list of remaining plans
loop {
let mut join_keys = vec![];
for (l, r) in &possible_join_keys {
if left_schema.field_from_column(l).is_ok()
&& right_schema.field_from_column(r).is_ok()
{
join_keys.push((l.clone(), r.clone()));
} else if left_schema.field_from_column(r).is_ok()
&& right_schema.field_from_column(l).is_ok()
{
join_keys.push((r.clone(), l.clone()));
}
}

// Search all remaining plans for the next to
// join. Prefer the first one that has a join
// predicate in the predicate lists
let plan_with_idx =
remaining_plans.iter().enumerate().find(|(_idx, plan)| {
// skip plans that have been joined already
let plan = if let Some(plan) = plan {
plan
} else {
return false;
};

// can we find a match?
let left_schema = left.schema();
let right_schema = plan.schema();
for (l, r) in &possible_join_keys {
if left_schema.field_from_column(l).is_ok()
&& right_schema.field_from_column(r).is_ok()
{
join_keys.push((l.clone(), r.clone()));
} else if left_schema.field_from_column(r).is_ok()
&& right_schema.field_from_column(l).is_ok()
{
join_keys.push((r.clone(), l.clone()));
}
}
// stop if we found join keys
!join_keys.is_empty()
});

// If we did not find join keys, either there are
// no more plans, or we can't find any plans that
// can be joined with predicates
if join_keys.is_empty() {
left =
LogicalPlanBuilder::from(left).cross_join(right)?.build()?;
assert!(plan_with_idx.is_none());
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very tight!


// pick the first plan that has not been joined yet (first `Some` entry)
let plan_with_idx = remaining_plans
.iter()
.enumerate()
.find(|(_idx, plan)| plan.is_some());
if let Some((idx, _)) = plan_with_idx {
let plan = std::mem::take(&mut remaining_plans[idx]).unwrap();
left = LogicalPlanBuilder::from(left)
.cross_join(&plan)?
.build()?;
} else {
// no more plans to join
break;
}
} else {
// have a plan
let (idx, _) = plan_with_idx.expect("found plan node");
let plan = std::mem::take(&mut remaining_plans[idx]).unwrap();

let left_keys: Vec<Column> =
join_keys.iter().map(|(l, _)| l.clone()).collect();
let right_keys: Vec<Column> =
join_keys.iter().map(|(_, r)| r.clone()).collect();
let builder = LogicalPlanBuilder::from(left);
left = builder
.join(right, JoinType::Inner, (left_keys, right_keys))?
.join(&plan, JoinType::Inner, (left_keys, right_keys))?
.build()?;
}

Expand Down Expand Up @@ -3818,6 +3864,31 @@ mod tests {
\n TableScan: public.person projection=None";
quick_test(sql, expected);
}

// Comma-separated tables whose WHERE clause equates columns from different
// tables are expected to plan as nested `Join` nodes. Per the expected plan,
// `person` is joined with `lineitem` first and `orders` second, even though
// `orders` appears before `lineitem` in the FROM list.
#[test]
fn cross_join_to_inner_join() {
    let expected_plan = "Projection: #person.id\
        \n Join: #lineitem.l_description = #orders.o_item_id\
        \n Join: #person.id = #lineitem.l_item_id\
        \n TableScan: person projection=None\
        \n TableScan: lineitem projection=None\
        \n TableScan: orders projection=None";
    let query = "select person.id from person, orders, lineitem where person.id = lineitem.l_item_id and orders.o_item_id = lineitem.l_description;";
    quick_test(query, expected_plan);
}

// The only predicate here compares two columns of the same table
// (`person.id = person.age`), so it links no pair of tables. The expected
// plan keeps both `CrossJoin` nodes (tables in FROM-list order) and applies
// the predicate as a `Filter` above them.
#[test]
fn cross_join_not_to_inner_join() {
let sql = "select person.id from person, orders, lineitem where person.id = person.age;";
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here @pjmore

let expected = "Projection: #person.id\
\n Filter: #person.id = #person.age\
\n CrossJoin:\
\n CrossJoin:\
\n TableScan: person projection=None\
\n TableScan: orders projection=None\
\n TableScan: lineitem projection=None";
quick_test(sql, expected);
}
}

fn parse_sql_number(n: &str) -> Result<Expr> {
Expand Down