Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement right semi join and support it in HashBuildProbeOrder #3958

Merged
merged 15 commits into from
Oct 27, 2022
11 changes: 9 additions & 2 deletions datafusion/core/src/physical_optimizer/hash_build_probe_order.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,13 @@ fn should_swap_join_order(left: &dyn ExecutionPlan, right: &dyn ExecutionPlan) -

fn supports_swap(join_type: JoinType) -> bool {
match join_type {
JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => true,
JoinType::Semi | JoinType::Anti => false,
JoinType::Inner
| JoinType::Left
| JoinType::Right
| JoinType::Full
| JoinType::Semi
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be better to explicitly rename JoinType::Semi to JoinType::LeftSemi?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That might be better for clarity, but it would also be a breaking change.
I think we should also change Anti to LeftAnti if we want to do this.

FYI @alamb @andygrove

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could add the new variants and mark the old ones as deprecated?

| JoinType::RightSemi => true,
JoinType::Anti => false,
}
}

Expand All @@ -81,6 +86,8 @@ fn swap_join_type(join_type: JoinType) -> JoinType {
JoinType::Full => JoinType::Full,
JoinType::Left => JoinType::Right,
JoinType::Right => JoinType::Left,
JoinType::Semi => JoinType::RightSemi,
JoinType::RightSemi => JoinType::Semi,
_ => unreachable!(),
}
}
Expand Down
103 changes: 100 additions & 3 deletions datafusion/core/src/physical_plan/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,54 @@ fn build_join_indexes(
PrimitiveArray::<UInt32Type>::from(right),
))
}
JoinType::RightSemi => {
let mut left_indices = UInt64BufferBuilder::new(0);
let mut right_indices = UInt32BufferBuilder::new(0);

// Visit all of the right rows
for (row, hash_value) in hash_values.iter().enumerate() {
// Get the hash and find it in the build index

// For every item on the left and right we check if it matches
// This possibly contains rows with hash collisions,
// So we have to check here whether rows are equal or not
// We only produce one row if there is a match
if let Some((_, indices)) =
left.0.get(*hash_value, |(hash, _)| *hash_value == *hash)
{
for &i in indices {
// Check hash collisions
if equal_rows(
i as usize,
row,
&left_join_values,
&keys_values,
*null_equals_null,
)? {
left_indices.append(i);
right_indices.append(row as u32);
break;
}
}
}
}

let left = ArrayData::builder(DataType::UInt64)
.len(left_indices.len())
.add_buffer(left_indices.finish())
.build()
.unwrap();
let right = ArrayData::builder(DataType::UInt32)
.len(right_indices.len())
.add_buffer(right_indices.finish())
.build()
.unwrap();

Ok((
PrimitiveArray::<UInt64Type>::from(left),
PrimitiveArray::<UInt32Type>::from(right),
))
}
JoinType::Left => {
let mut left_indices = UInt64Builder::with_capacity(0);
let mut right_indices = UInt32Builder::with_capacity(0);
Expand Down Expand Up @@ -853,7 +901,11 @@ fn apply_join_filter(
)?;

match join_type {
JoinType::Inner | JoinType::Left | JoinType::Anti | JoinType::Semi => {
JoinType::Inner
| JoinType::Left
| JoinType::Anti
| JoinType::Semi
| JoinType::RightSemi => {
// For both INNER and LEFT joins, the input arrays contain only indices for matched data.
// Because of this, it is correct to simply apply the filter to the intermediate batch and
// return the indices of the left/right rows satisfying the filter predicate
Expand Down Expand Up @@ -1287,7 +1339,9 @@ impl HashJoinStream {

buffer
}
JoinType::Inner | JoinType::Right => BooleanBufferBuilder::new(0),
JoinType::Inner | JoinType::Right | JoinType::RightSemi => {
BooleanBufferBuilder::new(0)
}
}
});

Expand Down Expand Up @@ -1324,7 +1378,7 @@ impl HashJoinStream {
visited_left_side.set_bit(x as usize, true);
});
}
JoinType::Inner | JoinType::Right => {}
JoinType::Inner | JoinType::Right | JoinType::RightSemi => {}
}
}
Some(result.map(|x| x.0))
Expand Down Expand Up @@ -1361,6 +1415,7 @@ impl HashJoinStream {
JoinType::Left
| JoinType::Full
| JoinType::Semi
| JoinType::RightSemi
| JoinType::Anti
| JoinType::Inner
| JoinType::Right => {}
Expand Down Expand Up @@ -2116,6 +2171,48 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn join_right_semi() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let left = build_table(
("a2", &vec![10, 20, 30, 40]),
("b1", &vec![4, 5, 6, 5]), // 5 is double on the left
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

recommend calling this b2 not b1 for clarity

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense

("c2", &vec![70, 80, 90, 100]),
);
let right = build_table(
("a1", &vec![1, 2, 2, 3]),
("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the left
("c1", &vec![7, 8, 8, 9]),
);

let on = vec![(
Column::new_with_schema("b1", &left.schema())?,
Column::new_with_schema("b1", &right.schema())?,
)];

let join = join(left, right, on, &JoinType::RightSemi, false)?;

let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "b1", "c1"]);

let stream = join.execute(0, task_ctx)?;
let batches = common::collect(stream).await?;

let expected = vec![
"+----+----+----+",
"| a1 | b1 | c1 |",
"+----+----+----+",
"| 1 | 4 | 7 |",
"| 2 | 5 | 8 |",
"| 2 | 5 | 8 |",
"+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);

Ok(())
}

#[tokio::test]
async fn join_anti() -> Result<()> {
let session_ctx = SessionContext::new();
Expand Down
16 changes: 13 additions & 3 deletions datafusion/core/src/physical_plan/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ impl SortMergeJoinExec {
let left_schema = left.schema();
let right_schema = right.schema();

if join_type == JoinType::RightSemi {
return Err(DataFusionError::Plan(
"RightSemi not yet supported in SortMergeJoinExec".to_string(),
));
}

check_join_is_valid(&left_schema, &right_schema, &on)?;
if sort_options.len() != on.len() {
return Err(DataFusionError::Plan(format!(
Expand Down Expand Up @@ -132,7 +138,7 @@ impl ExecutionPlan for SortMergeJoinExec {
JoinType::Inner | JoinType::Left | JoinType::Semi | JoinType::Anti => {
self.left.output_ordering()
}
JoinType::Right => self.right.output_ordering(),
JoinType::Right | JoinType::RightSemi => self.right.output_ordering(),
JoinType::Full => None,
}
}
Expand Down Expand Up @@ -180,7 +186,7 @@ impl ExecutionPlan for SortMergeJoinExec {
self.on.iter().map(|on| on.0.clone()).collect(),
self.on.iter().map(|on| on.1.clone()).collect(),
),
JoinType::Right => (
JoinType::Right | JoinType::RightSemi => (
self.right.clone(),
self.left.clone(),
self.on.iter().map(|on| on.1.clone()).collect(),
Expand Down Expand Up @@ -767,7 +773,11 @@ impl SMJStream {
Ordering::Less => {
if matches!(
self.join_type,
JoinType::Left | JoinType::Right | JoinType::Full | JoinType::Anti
JoinType::Left
| JoinType::Right
| JoinType::RightSemi
| JoinType::Full
| JoinType::Anti
) {
join_streamed = !self.streamed_joined;
}
Expand Down
17 changes: 17 additions & 0 deletions datafusion/core/src/physical_plan/joins/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) ->
JoinType::Right => is_left, // left input is padded with nulls
JoinType::Full => true, // both inputs can be padded with nulls
JoinType::Semi => false, // doesn't introduce nulls
JoinType::RightSemi => false, // doesn't introduce nulls
JoinType::Anti => false, // doesn't introduce nulls (or can it??)
};

Expand Down Expand Up @@ -236,6 +237,21 @@ pub fn build_join_schema(
)
})
.unzip(),
JoinType::RightSemi => right
.fields()
.iter()
.cloned()
.enumerate()
.map(|(index, f)| {
(
f,
ColumnIndex {
index,
side: JoinSide::Right,
},
)
})
.unzip(),
};

(Schema::new(fields), column_indices)
Expand Down Expand Up @@ -396,6 +412,7 @@ fn estimate_join_cardinality(

JoinType::Semi => None,
JoinType::Anti => None,
JoinType::RightSemi => None,
}
}

Expand Down
1 change: 1 addition & 0 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,7 @@ pub fn build_join_schema(
// Only use the left side for the schema
left.fields().clone()
}
JoinType::RightSemi => right.fields().clone(),
};

let mut metadata = left.metadata().clone();
Expand Down
5 changes: 4 additions & 1 deletion datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -980,8 +980,10 @@ pub enum JoinType {
Right,
/// Full Join
Full,
/// Semi Join
/// Left Semi Join
Semi,
/// Right Semi Join
RightSemi,
/// Anti Join
Anti,
}
Expand All @@ -994,6 +996,7 @@ impl Display for JoinType {
JoinType::Right => "Right",
JoinType::Full => "Full",
JoinType::Semi => "Semi",
JoinType::RightSemi => "RightSemi",
JoinType::Anti => "Anti",
};
write!(f, "{}", join_type)
Expand Down
5 changes: 4 additions & 1 deletion datafusion/optimizer/src/filter_push_down.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ fn lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> {
// No columns from the right side of the join can be referenced in output
// predicates for semi/anti joins, so whether we specify t/f doesn't matter.
JoinType::Semi | JoinType::Anti => Ok((true, false)),
// No columns from the left side of the join can be referenced in output
// predicates for right semi joins, so the value chosen for the left side doesn't matter.
JoinType::RightSemi => Ok((false, true)),
},
LogicalPlan::CrossJoin(_) => Ok((true, true)),
_ => Err(DataFusionError::Internal(
Expand All @@ -195,7 +198,7 @@ fn on_lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> {
JoinType::Left => Ok((false, true)),
JoinType::Right => Ok((true, false)),
JoinType::Full => Ok((false, false)),
JoinType::Semi | JoinType::Anti => {
JoinType::Semi | JoinType::Anti | JoinType::RightSemi => {
// filter_push_down does not yet support SEMI/ANTI joins with join conditions
Ok((false, false))
}
Expand Down