Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement right semi join and support in HashBuildProbeOrder #3958

Merged
merged 15 commits into from
Oct 27, 2022
2 changes: 1 addition & 1 deletion benchmarks/expected-plans/q16.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Sort: supplier_cnt DESC NULLS FIRST, part.p_brand ASC NULLS LAST, part.p_type AS
Projection: group_alias_0 AS p_brand, group_alias_1 AS p_type, group_alias_2 AS p_size, COUNT(alias1) AS COUNT(DISTINCT partsupp.ps_suppkey)
Aggregate: groupBy=[[group_alias_0, group_alias_1, group_alias_2]], aggr=[[COUNT(alias1)]]
Aggregate: groupBy=[[part.p_brand AS group_alias_0, part.p_type AS group_alias_1, part.p_size AS group_alias_2, partsupp.ps_suppkey AS alias1]], aggr=[[]]
Anti Join: partsupp.ps_suppkey = __sq_1.s_suppkey
LeftAnti Join: partsupp.ps_suppkey = __sq_1.s_suppkey
Inner Join: partsupp.ps_partkey = part.p_partkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey]
Filter: part.p_brand != Utf8("Brand#45") AND part.p_type NOT LIKE Utf8("MEDIUM POLISHED%") AND part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)])
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/expected-plans/q18.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Sort: orders.o_totalprice DESC NULLS FIRST, orders.o_orderdate ASC NULLS LAST
Projection: customer.c_name, customer.c_custkey, orders.o_orderkey, orders.o_orderdate, orders.o_totalprice, SUM(lineitem.l_quantity)
Aggregate: groupBy=[[customer.c_name, customer.c_custkey, orders.o_orderkey, orders.o_orderdate, orders.o_totalprice]], aggr=[[SUM(lineitem.l_quantity)]]
Semi Join: orders.o_orderkey = __sq_1.l_orderkey
LeftSemi Join: orders.o_orderkey = __sq_1.l_orderkey
Inner Join: orders.o_orderkey = lineitem.l_orderkey
Inner Join: customer.c_custkey = orders.o_custkey
TableScan: customer projection=[c_custkey, c_name]
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/expected-plans/q20.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
Sort: supplier.s_name ASC NULLS LAST
Projection: supplier.s_name, supplier.s_address
Semi Join: supplier.s_suppkey = __sq_2.ps_suppkey
LeftSemi Join: supplier.s_suppkey = __sq_2.ps_suppkey
Inner Join: supplier.s_nationkey = nation.n_nationkey
TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey]
Filter: nation.n_name = Utf8("CANADA")
TableScan: nation projection=[n_nationkey, n_name]
Projection: partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2
Filter: CAST(partsupp.ps_availqty AS Decimal128(38, 17)) > __sq_3.__value
Inner Join: partsupp.ps_partkey = __sq_3.l_partkey, partsupp.ps_suppkey = __sq_3.l_suppkey
Semi Join: partsupp.ps_partkey = __sq_1.p_partkey
LeftSemi Join: partsupp.ps_partkey = __sq_1.p_partkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty]
Projection: part.p_partkey AS p_partkey, alias=__sq_1
Filter: part.p_name LIKE Utf8("forest%")
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/expected-plans/q21.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Sort: numwait DESC NULLS FIRST, supplier.s_name ASC NULLS LAST
Projection: supplier.s_name, COUNT(UInt8(1)) AS numwait
Aggregate: groupBy=[[supplier.s_name]], aggr=[[COUNT(UInt8(1))]]
Anti Join: l1.l_orderkey = l3.l_orderkey Filter: l3.l_suppkey != l1.l_suppkey
Semi Join: l1.l_orderkey = l2.l_orderkey Filter: l2.l_suppkey != l1.l_suppkey
LeftAnti Join: l1.l_orderkey = l3.l_orderkey Filter: l3.l_suppkey != l1.l_suppkey
LeftSemi Join: l1.l_orderkey = l2.l_orderkey Filter: l2.l_suppkey != l1.l_suppkey
Inner Join: supplier.s_nationkey = nation.n_nationkey
Inner Join: l1.l_orderkey = orders.o_orderkey
Inner Join: supplier.s_suppkey = l1.l_suppkey
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/expected-plans/q4.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Sort: orders.o_orderpriority ASC NULLS LAST
Projection: orders.o_orderpriority, COUNT(UInt8(1)) AS order_count
Aggregate: groupBy=[[orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]]
Semi Join: orders.o_orderkey = lineitem.l_orderkey
LeftSemi Join: orders.o_orderkey = lineitem.l_orderkey
Filter: orders.o_orderdate >= Date32("8582") AND orders.o_orderdate < Date32("8674")
TableScan: orders projection=[o_orderkey, o_orderdate, o_orderpriority]
Filter: lineitem.l_commitdate < lineitem.l_receiptdate
Expand Down
11 changes: 9 additions & 2 deletions datafusion/core/src/physical_optimizer/hash_build_probe_order.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,13 @@ fn should_swap_join_order(left: &dyn ExecutionPlan, right: &dyn ExecutionPlan) -

/// Reports, for a given `join_type`, whether the arm it matches below yields
/// `true` or `false`.
///
/// Presumably `true` means the join's two inputs may be exchanged by this
/// optimizer pass (the name and the neighbouring `swap_join_type` suggest so)
/// — TODO confirm against the caller, which is not visible here.
///
/// NOTE(review): this span appears to be a rendered diff hunk whose +/- markers
/// were lost. The one-line `Inner | Left | Right | Full` arm and the
/// `Semi | Anti` arm look like the removed lines; the multi-line arm ending in
/// `RightSemi => true` and the `LeftAnti => false` arm look like their
/// replacements. Verify against the real file before relying on this text.
fn supports_swap(join_type: JoinType) -> bool {
match join_type {
JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => true,
JoinType::Semi | JoinType::Anti => false,
JoinType::Inner
| JoinType::Left
| JoinType::Right
| JoinType::Full
| JoinType::LeftSemi
| JoinType::RightSemi => true,
JoinType::LeftAnti => false,
}
}

Expand All @@ -81,6 +86,8 @@ fn swap_join_type(join_type: JoinType) -> JoinType {
JoinType::Full => JoinType::Full,
JoinType::Left => JoinType::Right,
JoinType::Right => JoinType::Left,
JoinType::LeftSemi => JoinType::RightSemi,
JoinType::RightSemi => JoinType::LeftSemi,
_ => unreachable!(),
}
}
Expand Down
132 changes: 116 additions & 16 deletions datafusion/core/src/physical_plan/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ fn build_batch(
(left_indices, right_indices)
};

if matches!(join_type, JoinType::Semi | JoinType::Anti) {
if matches!(join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
return Ok((
RecordBatch::new_empty(Arc::new(schema.clone())),
left_filtered_indices,
Expand Down Expand Up @@ -719,7 +719,7 @@ fn build_join_indexes(
let left = &left_data.0;

match join_type {
JoinType::Inner | JoinType::Semi | JoinType::Anti => {
JoinType::Inner | JoinType::LeftSemi | JoinType::LeftAnti => {
// Using a buffer builder to avoid slower normal builder
let mut left_indices = UInt64BufferBuilder::new(0);
let mut right_indices = UInt32BufferBuilder::new(0);
Expand Down Expand Up @@ -765,6 +765,54 @@ fn build_join_indexes(
PrimitiveArray::<UInt32Type>::from(right),
))
}
JoinType::RightSemi => {
let mut left_indices = UInt64BufferBuilder::new(0);
let mut right_indices = UInt32BufferBuilder::new(0);

// Visit all of the right rows
for (row, hash_value) in hash_values.iter().enumerate() {
// Get the hash and find it in the build index

// For every item on the left and right we check if it matches
// This possibly contains rows with hash collisions,
// So we have to check here whether rows are equal or not
// We only produce one row if there is a match
if let Some((_, indices)) =
left.0.get(*hash_value, |(hash, _)| *hash_value == *hash)
{
for &i in indices {
// Check hash collisions
if equal_rows(
i as usize,
row,
&left_join_values,
&keys_values,
*null_equals_null,
)? {
left_indices.append(i);
right_indices.append(row as u32);
break;
}
}
}
}

let left = ArrayData::builder(DataType::UInt64)
.len(left_indices.len())
.add_buffer(left_indices.finish())
.build()
.unwrap();
let right = ArrayData::builder(DataType::UInt32)
.len(right_indices.len())
.add_buffer(right_indices.finish())
.build()
.unwrap();

Ok((
PrimitiveArray::<UInt64Type>::from(left),
PrimitiveArray::<UInt32Type>::from(right),
))
}
JoinType::Left => {
let mut left_indices = UInt64Builder::with_capacity(0);
let mut right_indices = UInt32Builder::with_capacity(0);
Expand Down Expand Up @@ -853,7 +901,11 @@ fn apply_join_filter(
)?;

match join_type {
JoinType::Inner | JoinType::Left | JoinType::Anti | JoinType::Semi => {
JoinType::Inner
| JoinType::Left
| JoinType::LeftAnti
| JoinType::LeftSemi
| JoinType::RightSemi => {
// For both INNER and LEFT joins, input arrays contain only indices for matched data.
// Due to this fact it's correct to simply apply filter to intermediate batch and return
// indices for left/right rows satisfying filter predicate
Expand Down Expand Up @@ -1280,14 +1332,19 @@ impl HashJoinStream {
let visited_left_side = self.visited_left_side.get_or_insert_with(|| {
let num_rows = left_data.1.num_rows();
match self.join_type {
JoinType::Left | JoinType::Full | JoinType::Semi | JoinType::Anti => {
JoinType::Left
| JoinType::Full
| JoinType::LeftSemi
| JoinType::LeftAnti => {
let mut buffer = BooleanBufferBuilder::new(num_rows);

buffer.append_n(num_rows, false);

buffer
}
JoinType::Inner | JoinType::Right => BooleanBufferBuilder::new(0),
JoinType::Inner | JoinType::Right | JoinType::RightSemi => {
BooleanBufferBuilder::new(0)
}
}
});

Expand Down Expand Up @@ -1318,13 +1375,13 @@ impl HashJoinStream {
match self.join_type {
JoinType::Left
| JoinType::Full
| JoinType::Semi
| JoinType::Anti => {
| JoinType::LeftSemi
| JoinType::LeftAnti => {
left_side.iter().flatten().for_each(|x| {
visited_left_side.set_bit(x as usize, true);
});
}
JoinType::Inner | JoinType::Right => {}
JoinType::Inner | JoinType::Right | JoinType::RightSemi => {}
}
}
Some(result.map(|x| x.0))
Expand All @@ -1335,16 +1392,16 @@ impl HashJoinStream {
match self.join_type {
JoinType::Left
| JoinType::Full
| JoinType::Semi
| JoinType::Anti
| JoinType::LeftSemi
| JoinType::LeftAnti
if !self.is_exhausted =>
{
let result = produce_from_matched(
visited_left_side,
&self.schema,
&self.column_indices,
left_data,
self.join_type != JoinType::Semi,
self.join_type != JoinType::LeftSemi,
);
if let Ok(ref batch) = result {
self.join_metrics.input_batches.add(1);
Expand All @@ -1360,8 +1417,9 @@ impl HashJoinStream {
}
JoinType::Left
| JoinType::Full
| JoinType::Semi
| JoinType::Anti
| JoinType::LeftSemi
| JoinType::RightSemi
| JoinType::LeftAnti
| JoinType::Inner
| JoinType::Right => {}
}
Expand Down Expand Up @@ -2094,7 +2152,49 @@ mod tests {
Column::new_with_schema("b1", &right.schema())?,
)];

let join = join(left, right, on, &JoinType::Semi, false)?;
let join = join(left, right, on, &JoinType::LeftSemi, false)?;

let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "b1", "c1"]);

let stream = join.execute(0, task_ctx)?;
let batches = common::collect(stream).await?;

let expected = vec![
"+----+----+----+",
"| a1 | b1 | c1 |",
"+----+----+----+",
"| 1 | 4 | 7 |",
"| 2 | 5 | 8 |",
"| 2 | 5 | 8 |",
"+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);

Ok(())
}

#[tokio::test]
async fn join_right_semi() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let left = build_table(
("a2", &vec![10, 20, 30, 40]),
("b2", &vec![4, 5, 6, 5]), // 5 is double on the left
("c2", &vec![70, 80, 90, 100]),
);
let right = build_table(
("a1", &vec![1, 2, 2, 3]),
("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the left
("c1", &vec![7, 8, 8, 9]),
);

let on = vec![(
Column::new_with_schema("b2", &left.schema())?,
Column::new_with_schema("b1", &right.schema())?,
)];

let join = join(left, right, on, &JoinType::RightSemi, false)?;

let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "b1", "c1"]);
Expand Down Expand Up @@ -2135,7 +2235,7 @@ mod tests {
Column::new_with_schema("b1", &right.schema())?,
)];

let join = join(left, right, on, &JoinType::Anti, false)?;
let join = join(left, right, on, &JoinType::LeftAnti, false)?;

let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "b1", "c1"]);
Expand Down Expand Up @@ -2196,7 +2296,7 @@ mod tests {
let filter =
JoinFilter::new(filter_expression, column_indices, intermediate_schema);

let join = join_with_filter(left, right, on, filter, &JoinType::Anti, false)?;
let join = join_with_filter(left, right, on, filter, &JoinType::LeftAnti, false)?;

let columns = columns(&join.schema());
assert_eq!(columns, vec!["col1", "col2", "col3"]);
Expand Down
35 changes: 23 additions & 12 deletions datafusion/core/src/physical_plan/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,12 @@ impl SortMergeJoinExec {
let left_schema = left.schema();
let right_schema = right.schema();

if join_type == JoinType::RightSemi {
return Err(DataFusionError::NotImplemented(
"SortMergeJoinExec does not support JoinType::RightSemi".to_string(),
));
}

check_join_is_valid(&left_schema, &right_schema, &on)?;
if sort_options.len() != on.len() {
return Err(DataFusionError::Plan(format!(
Expand Down Expand Up @@ -129,10 +135,11 @@ impl ExecutionPlan for SortMergeJoinExec {

fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
match self.join_type {
JoinType::Inner | JoinType::Left | JoinType::Semi | JoinType::Anti => {
self.left.output_ordering()
}
JoinType::Right => self.right.output_ordering(),
JoinType::Inner
| JoinType::Left
| JoinType::LeftSemi
| JoinType::LeftAnti => self.left.output_ordering(),
JoinType::Right | JoinType::RightSemi => self.right.output_ordering(),
JoinType::Full => None,
}
}
Expand Down Expand Up @@ -173,14 +180,14 @@ impl ExecutionPlan for SortMergeJoinExec {
JoinType::Inner
| JoinType::Left
| JoinType::Full
| JoinType::Anti
| JoinType::Semi => (
| JoinType::LeftAnti
| JoinType::LeftSemi => (
self.left.clone(),
self.right.clone(),
self.on.iter().map(|on| on.0.clone()).collect(),
self.on.iter().map(|on| on.1.clone()).collect(),
),
JoinType::Right => (
JoinType::Right | JoinType::RightSemi => (
self.right.clone(),
self.left.clone(),
self.on.iter().map(|on| on.1.clone()).collect(),
Expand Down Expand Up @@ -767,13 +774,17 @@ impl SMJStream {
Ordering::Less => {
if matches!(
self.join_type,
JoinType::Left | JoinType::Right | JoinType::Full | JoinType::Anti
JoinType::Left
| JoinType::Right
| JoinType::RightSemi
| JoinType::Full
| JoinType::LeftAnti
) {
join_streamed = !self.streamed_joined;
}
}
Ordering::Equal => {
if matches!(self.join_type, JoinType::Semi) {
if matches!(self.join_type, JoinType::LeftSemi) {
join_streamed = !self.streamed_joined;
}
if matches!(
Expand Down Expand Up @@ -915,7 +926,7 @@ impl SMJStream {
let buffered_indices: UInt64Array = chunk.buffered_indices.finish();

let mut buffered_columns =
if matches!(self.join_type, JoinType::Semi | JoinType::Anti) {
if matches!(self.join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
vec![]
} else if let Some(buffered_idx) = chunk.buffered_batch_idx {
self.buffered_data.batches[buffered_idx]
Expand Down Expand Up @@ -1732,7 +1743,7 @@ mod tests {
Column::new_with_schema("b1", &right.schema())?,
)];

let (_, batches) = join_collect(left, right, on, JoinType::Anti).await?;
let (_, batches) = join_collect(left, right, on, JoinType::LeftAnti).await?;
let expected = vec![
"+----+----+----+",
"| a1 | b1 | c1 |",
Expand Down Expand Up @@ -1763,7 +1774,7 @@ mod tests {
Column::new_with_schema("b1", &right.schema())?,
)];

let (_, batches) = join_collect(left, right, on, JoinType::Semi).await?;
let (_, batches) = join_collect(left, right, on, JoinType::LeftSemi).await?;
let expected = vec![
"+----+----+----+",
"| a1 | b1 | c1 |",
Expand Down
Loading