Skip to content

Commit

Permalink
Implement right semi join and support in HashBuildProbeorder (#3958)
Browse files Browse the repository at this point in the history
* Implement right semi join

* Change error a bit

* protobuf

* protobuf

* protobuf

* Change column name to b2

* Rename everything

* Rename & fmt

* Change display to leftanti

* Fix last expected plan

* Commit generated file

* generated
  • Loading branch information
Dandandan authored Oct 27, 2022
1 parent e73a43c commit 002165b
Show file tree
Hide file tree
Showing 23 changed files with 266 additions and 116 deletions.
2 changes: 1 addition & 1 deletion benchmarks/expected-plans/q16.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Sort: supplier_cnt DESC NULLS FIRST, part.p_brand ASC NULLS LAST, part.p_type AS
Projection: group_alias_0 AS p_brand, group_alias_1 AS p_type, group_alias_2 AS p_size, COUNT(alias1) AS COUNT(DISTINCT partsupp.ps_suppkey)
Aggregate: groupBy=[[group_alias_0, group_alias_1, group_alias_2]], aggr=[[COUNT(alias1)]]
Aggregate: groupBy=[[part.p_brand AS group_alias_0, part.p_type AS group_alias_1, part.p_size AS group_alias_2, partsupp.ps_suppkey AS alias1]], aggr=[[]]
Anti Join: partsupp.ps_suppkey = __sq_1.s_suppkey
LeftAnti Join: partsupp.ps_suppkey = __sq_1.s_suppkey
Inner Join: partsupp.ps_partkey = part.p_partkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey]
Filter: part.p_brand != Utf8("Brand#45") AND part.p_type NOT LIKE Utf8("MEDIUM POLISHED%") AND part.p_size IN ([Int32(49), Int32(14), Int32(23), Int32(45), Int32(19), Int32(3), Int32(36), Int32(9)])
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/expected-plans/q18.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Sort: orders.o_totalprice DESC NULLS FIRST, orders.o_orderdate ASC NULLS LAST
Projection: customer.c_name, customer.c_custkey, orders.o_orderkey, orders.o_orderdate, orders.o_totalprice, SUM(lineitem.l_quantity)
Aggregate: groupBy=[[customer.c_name, customer.c_custkey, orders.o_orderkey, orders.o_orderdate, orders.o_totalprice]], aggr=[[SUM(lineitem.l_quantity)]]
Semi Join: orders.o_orderkey = __sq_1.l_orderkey
LeftSemi Join: orders.o_orderkey = __sq_1.l_orderkey
Inner Join: orders.o_orderkey = lineitem.l_orderkey
Inner Join: customer.c_custkey = orders.o_custkey
TableScan: customer projection=[c_custkey, c_name]
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/expected-plans/q20.txt
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
Sort: supplier.s_name ASC NULLS LAST
Projection: supplier.s_name, supplier.s_address
Semi Join: supplier.s_suppkey = __sq_2.ps_suppkey
LeftSemi Join: supplier.s_suppkey = __sq_2.ps_suppkey
Inner Join: supplier.s_nationkey = nation.n_nationkey
TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey]
Filter: nation.n_name = Utf8("CANADA")
TableScan: nation projection=[n_nationkey, n_name]
Projection: partsupp.ps_suppkey AS ps_suppkey, alias=__sq_2
Filter: CAST(partsupp.ps_availqty AS Decimal128(38, 17)) > __sq_3.__value
Inner Join: partsupp.ps_partkey = __sq_3.l_partkey, partsupp.ps_suppkey = __sq_3.l_suppkey
Semi Join: partsupp.ps_partkey = __sq_1.p_partkey
LeftSemi Join: partsupp.ps_partkey = __sq_1.p_partkey
TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty]
Projection: part.p_partkey AS p_partkey, alias=__sq_1
Filter: part.p_name LIKE Utf8("forest%")
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/expected-plans/q21.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
Sort: numwait DESC NULLS FIRST, supplier.s_name ASC NULLS LAST
Projection: supplier.s_name, COUNT(UInt8(1)) AS numwait
Aggregate: groupBy=[[supplier.s_name]], aggr=[[COUNT(UInt8(1))]]
Anti Join: l1.l_orderkey = l3.l_orderkey Filter: l3.l_suppkey != l1.l_suppkey
Semi Join: l1.l_orderkey = l2.l_orderkey Filter: l2.l_suppkey != l1.l_suppkey
LeftAnti Join: l1.l_orderkey = l3.l_orderkey Filter: l3.l_suppkey != l1.l_suppkey
LeftSemi Join: l1.l_orderkey = l2.l_orderkey Filter: l2.l_suppkey != l1.l_suppkey
Inner Join: supplier.s_nationkey = nation.n_nationkey
Inner Join: l1.l_orderkey = orders.o_orderkey
Inner Join: supplier.s_suppkey = l1.l_suppkey
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/expected-plans/q22.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ Sort: custsale.cntrycode ASC NULLS LAST
Projection: substr(customer.c_phone, Int64(1), Int64(2)) AS cntrycode, customer.c_acctbal, alias=custsale
Filter: CAST(customer.c_acctbal AS Decimal128(19, 6)) > __sq_1.__value
CrossJoin:
Anti Join: customer.c_custkey = orders.o_custkey
LeftAnti Join: customer.c_custkey = orders.o_custkey
Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
TableScan: customer projection=[c_custkey, c_phone, c_acctbal]
TableScan: orders projection=[o_custkey]
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/expected-plans/q4.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Sort: orders.o_orderpriority ASC NULLS LAST
Projection: orders.o_orderpriority, COUNT(UInt8(1)) AS order_count
Aggregate: groupBy=[[orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]]
Semi Join: orders.o_orderkey = lineitem.l_orderkey
LeftSemi Join: orders.o_orderkey = lineitem.l_orderkey
Filter: orders.o_orderdate >= Date32("8582") AND orders.o_orderdate < Date32("8674")
TableScan: orders projection=[o_orderkey, o_orderdate, o_orderpriority]
Filter: lineitem.l_commitdate < lineitem.l_receiptdate
Expand Down
11 changes: 9 additions & 2 deletions datafusion/core/src/physical_optimizer/hash_build_probe_order.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,13 @@ fn should_swap_join_order(left: &dyn ExecutionPlan, right: &dyn ExecutionPlan) -

fn supports_swap(join_type: JoinType) -> bool {
match join_type {
JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => true,
JoinType::Semi | JoinType::Anti => false,
JoinType::Inner
| JoinType::Left
| JoinType::Right
| JoinType::Full
| JoinType::LeftSemi
| JoinType::RightSemi => true,
JoinType::LeftAnti => false,
}
}

Expand All @@ -81,6 +86,8 @@ fn swap_join_type(join_type: JoinType) -> JoinType {
JoinType::Full => JoinType::Full,
JoinType::Left => JoinType::Right,
JoinType::Right => JoinType::Left,
JoinType::LeftSemi => JoinType::RightSemi,
JoinType::RightSemi => JoinType::LeftSemi,
_ => unreachable!(),
}
}
Expand Down
132 changes: 116 additions & 16 deletions datafusion/core/src/physical_plan/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -653,7 +653,7 @@ fn build_batch(
(left_indices, right_indices)
};

if matches!(join_type, JoinType::Semi | JoinType::Anti) {
if matches!(join_type, JoinType::LeftSemi | JoinType::LeftAnti) {
return Ok((
RecordBatch::new_empty(Arc::new(schema.clone())),
left_filtered_indices,
Expand Down Expand Up @@ -719,7 +719,7 @@ fn build_join_indexes(
let left = &left_data.0;

match join_type {
JoinType::Inner | JoinType::Semi | JoinType::Anti => {
JoinType::Inner | JoinType::LeftSemi | JoinType::LeftAnti => {
// Using a buffer builder to avoid slower normal builder
let mut left_indices = UInt64BufferBuilder::new(0);
let mut right_indices = UInt32BufferBuilder::new(0);
Expand Down Expand Up @@ -765,6 +765,54 @@ fn build_join_indexes(
PrimitiveArray::<UInt32Type>::from(right),
))
}
JoinType::RightSemi => {
let mut left_indices = UInt64BufferBuilder::new(0);
let mut right_indices = UInt32BufferBuilder::new(0);

// Visit all of the right rows
for (row, hash_value) in hash_values.iter().enumerate() {
// Get the hash and find it in the build index

// For every item on the left and right we check if it matches
// This possibly contains rows with hash collisions,
// So we have to check here whether rows are equal or not
// We only produce one row if there is a match
if let Some((_, indices)) =
left.0.get(*hash_value, |(hash, _)| *hash_value == *hash)
{
for &i in indices {
// Check hash collisions
if equal_rows(
i as usize,
row,
&left_join_values,
&keys_values,
*null_equals_null,
)? {
left_indices.append(i);
right_indices.append(row as u32);
break;
}
}
}
}

let left = ArrayData::builder(DataType::UInt64)
.len(left_indices.len())
.add_buffer(left_indices.finish())
.build()
.unwrap();
let right = ArrayData::builder(DataType::UInt32)
.len(right_indices.len())
.add_buffer(right_indices.finish())
.build()
.unwrap();

Ok((
PrimitiveArray::<UInt64Type>::from(left),
PrimitiveArray::<UInt32Type>::from(right),
))
}
JoinType::Left => {
let mut left_indices = UInt64Builder::with_capacity(0);
let mut right_indices = UInt32Builder::with_capacity(0);
Expand Down Expand Up @@ -853,7 +901,11 @@ fn apply_join_filter(
)?;

match join_type {
JoinType::Inner | JoinType::Left | JoinType::Anti | JoinType::Semi => {
JoinType::Inner
| JoinType::Left
| JoinType::LeftAnti
| JoinType::LeftSemi
| JoinType::RightSemi => {
// For both INNER and LEFT joins, input arrays contains only indices for matched data.
// Due to this fact it's correct to simply apply filter to intermediate batch and return
// indices for left/right rows satisfying filter predicate
Expand Down Expand Up @@ -1280,14 +1332,19 @@ impl HashJoinStream {
let visited_left_side = self.visited_left_side.get_or_insert_with(|| {
let num_rows = left_data.1.num_rows();
match self.join_type {
JoinType::Left | JoinType::Full | JoinType::Semi | JoinType::Anti => {
JoinType::Left
| JoinType::Full
| JoinType::LeftSemi
| JoinType::LeftAnti => {
let mut buffer = BooleanBufferBuilder::new(num_rows);

buffer.append_n(num_rows, false);

buffer
}
JoinType::Inner | JoinType::Right => BooleanBufferBuilder::new(0),
JoinType::Inner | JoinType::Right | JoinType::RightSemi => {
BooleanBufferBuilder::new(0)
}
}
});

Expand Down Expand Up @@ -1318,13 +1375,13 @@ impl HashJoinStream {
match self.join_type {
JoinType::Left
| JoinType::Full
| JoinType::Semi
| JoinType::Anti => {
| JoinType::LeftSemi
| JoinType::LeftAnti => {
left_side.iter().flatten().for_each(|x| {
visited_left_side.set_bit(x as usize, true);
});
}
JoinType::Inner | JoinType::Right => {}
JoinType::Inner | JoinType::Right | JoinType::RightSemi => {}
}
}
Some(result.map(|x| x.0))
Expand All @@ -1335,16 +1392,16 @@ impl HashJoinStream {
match self.join_type {
JoinType::Left
| JoinType::Full
| JoinType::Semi
| JoinType::Anti
| JoinType::LeftSemi
| JoinType::LeftAnti
if !self.is_exhausted =>
{
let result = produce_from_matched(
visited_left_side,
&self.schema,
&self.column_indices,
left_data,
self.join_type != JoinType::Semi,
self.join_type != JoinType::LeftSemi,
);
if let Ok(ref batch) = result {
self.join_metrics.input_batches.add(1);
Expand All @@ -1360,8 +1417,9 @@ impl HashJoinStream {
}
JoinType::Left
| JoinType::Full
| JoinType::Semi
| JoinType::Anti
| JoinType::LeftSemi
| JoinType::RightSemi
| JoinType::LeftAnti
| JoinType::Inner
| JoinType::Right => {}
}
Expand Down Expand Up @@ -2094,7 +2152,49 @@ mod tests {
Column::new_with_schema("b1", &right.schema())?,
)];

let join = join(left, right, on, &JoinType::Semi, false)?;
let join = join(left, right, on, &JoinType::LeftSemi, false)?;

let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "b1", "c1"]);

let stream = join.execute(0, task_ctx)?;
let batches = common::collect(stream).await?;

let expected = vec![
"+----+----+----+",
"| a1 | b1 | c1 |",
"+----+----+----+",
"| 1 | 4 | 7 |",
"| 2 | 5 | 8 |",
"| 2 | 5 | 8 |",
"+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);

Ok(())
}

#[tokio::test]
async fn join_right_semi() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let left = build_table(
("a2", &vec![10, 20, 30, 40]),
("b2", &vec![4, 5, 6, 5]), // 5 is double on the left
("c2", &vec![70, 80, 90, 100]),
);
let right = build_table(
("a1", &vec![1, 2, 2, 3]),
("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the left
("c1", &vec![7, 8, 8, 9]),
);

let on = vec![(
Column::new_with_schema("b2", &left.schema())?,
Column::new_with_schema("b1", &right.schema())?,
)];

let join = join(left, right, on, &JoinType::RightSemi, false)?;

let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "b1", "c1"]);
Expand Down Expand Up @@ -2135,7 +2235,7 @@ mod tests {
Column::new_with_schema("b1", &right.schema())?,
)];

let join = join(left, right, on, &JoinType::Anti, false)?;
let join = join(left, right, on, &JoinType::LeftAnti, false)?;

let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "b1", "c1"]);
Expand Down Expand Up @@ -2196,7 +2296,7 @@ mod tests {
let filter =
JoinFilter::new(filter_expression, column_indices, intermediate_schema);

let join = join_with_filter(left, right, on, filter, &JoinType::Anti, false)?;
let join = join_with_filter(left, right, on, filter, &JoinType::LeftAnti, false)?;

let columns = columns(&join.schema());
assert_eq!(columns, vec!["col1", "col2", "col3"]);
Expand Down
Loading

0 comments on commit 002165b

Please sign in to comment.