Skip to content

Commit

Permalink
Add right anti join support and support it in HashBuildProbeOrder (ap…
Browse files Browse the repository at this point in the history
…ache#4011)

* Add right anti join support

* Fix
  • Loading branch information
Dandandan committed Nov 5, 2022
1 parent e1cdc9b commit 6356a21
Show file tree
Hide file tree
Showing 11 changed files with 178 additions and 52 deletions.
82 changes: 45 additions & 37 deletions datafusion/core/src/physical_optimizer/hash_build_probe_order.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ fn supports_swap(join_type: JoinType) -> bool {
| JoinType::Right
| JoinType::Full
| JoinType::LeftSemi
| JoinType::RightSemi => true,
JoinType::LeftAnti => false,
| JoinType::RightSemi
| JoinType::LeftAnti
| JoinType::RightAnti => true,
}
}

Expand All @@ -88,7 +89,8 @@ fn swap_join_type(join_type: JoinType) -> JoinType {
JoinType::Right => JoinType::Left,
JoinType::LeftSemi => JoinType::RightSemi,
JoinType::RightSemi => JoinType::LeftSemi,
_ => unreachable!(),
JoinType::LeftAnti => JoinType::RightAnti,
JoinType::RightAnti => JoinType::LeftAnti,
}
}

Expand Down Expand Up @@ -176,7 +178,10 @@ impl PhysicalOptimizerRule for HashBuildProbeOrder {
)?;
if matches!(
hash_join.join_type(),
JoinType::LeftSemi | JoinType::RightSemi
JoinType::LeftSemi
| JoinType::RightSemi
| JoinType::LeftAnti
| JoinType::RightAnti
) {
return Ok(Arc::new(new_join));
}
Expand Down Expand Up @@ -362,45 +367,48 @@ mod tests {
}

#[tokio::test]
async fn test_join_with_swap_left_semi() {
let (big, small) = create_big_and_small();

let join = HashJoinExec::try_new(
Arc::clone(&big),
Arc::clone(&small),
vec![(
Column::new_with_schema("big_col", &big.schema()).unwrap(),
Column::new_with_schema("small_col", &small.schema()).unwrap(),
)],
None,
&JoinType::LeftSemi,
PartitionMode::CollectLeft,
&false,
)
.unwrap();
async fn test_join_with_swap_semi() {
let join_types = [JoinType::LeftSemi, JoinType::LeftAnti];
for join_type in join_types {
let (big, small) = create_big_and_small();

let join = HashJoinExec::try_new(
Arc::clone(&big),
Arc::clone(&small),
vec![(
Column::new_with_schema("big_col", &big.schema()).unwrap(),
Column::new_with_schema("small_col", &small.schema()).unwrap(),
)],
None,
&join_type,
PartitionMode::CollectLeft,
&false,
)
.unwrap();

let original_schema = join.schema();
let original_schema = join.schema();

let optimized_join = HashBuildProbeOrder::new()
.optimize(Arc::new(join), &SessionConfig::new())
.unwrap();
let optimized_join = HashBuildProbeOrder::new()
.optimize(Arc::new(join), &SessionConfig::new())
.unwrap();

let swapped_join = optimized_join
.as_any()
.downcast_ref::<HashJoinExec>()
.expect(
"A proj is not required to swap columns back to their original order",
);
let swapped_join = optimized_join
.as_any()
.downcast_ref::<HashJoinExec>()
.expect(
"A proj is not required to swap columns back to their original order",
);

assert_eq!(swapped_join.schema().fields().len(), 1);
assert_eq!(swapped_join.schema().fields().len(), 1);

assert_eq!(swapped_join.left().statistics().total_byte_size, Some(10));
assert_eq!(
swapped_join.right().statistics().total_byte_size,
Some(100000)
);
assert_eq!(swapped_join.left().statistics().total_byte_size, Some(10));
assert_eq!(
swapped_join.right().statistics().total_byte_size,
Some(100000)
);

assert_eq!(original_schema, swapped_join.schema());
assert_eq!(original_schema, swapped_join.schema());
}
}

/// Compare the input plan with the plan after running the probe order optimizer.
Expand Down
109 changes: 103 additions & 6 deletions datafusion/core/src/physical_plan/joins/hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,6 @@ fn build_join_indexes(
&keys_values,
*null_equals_null,
)? {
left_indices.append(i);
right_indices.append(row as u32);
break;
}
Expand All @@ -813,6 +812,59 @@ fn build_join_indexes(
PrimitiveArray::<UInt32Type>::from(right),
))
}
JoinType::RightAnti => {
let mut left_indices = UInt64BufferBuilder::new(0);
let mut right_indices = UInt32BufferBuilder::new(0);

// Visit all of the right rows
for (row, hash_value) in hash_values.iter().enumerate() {
// Get the hash and find it in the build index

// For every item on the left and right we check if it doesn't match
// This possibly contains rows with hash collisions,
// So we have to check here whether rows are equal or not
// We only produce one row if there is no match
let matches = left.0.get(*hash_value, |(hash, _)| *hash_value == *hash);
let mut no_match = true;
match matches {
Some((_, indices)) => {
for &i in indices {
// Check hash collisions
if equal_rows(
i as usize,
row,
&left_join_values,
&keys_values,
*null_equals_null,
)? {
no_match = false;
break;
}
}
}
None => no_match = true,
};
if no_match {
right_indices.append(row as u32);
}
}

let left = ArrayData::builder(DataType::UInt64)
.len(left_indices.len())
.add_buffer(left_indices.finish())
.build()
.unwrap();
let right = ArrayData::builder(DataType::UInt32)
.len(right_indices.len())
.add_buffer(right_indices.finish())
.build()
.unwrap();

Ok((
PrimitiveArray::<UInt64Type>::from(left),
PrimitiveArray::<UInt32Type>::from(right),
))
}
JoinType::Left => {
let mut left_indices = UInt64Builder::with_capacity(0);
let mut right_indices = UInt32Builder::with_capacity(0);
Expand Down Expand Up @@ -887,7 +939,7 @@ fn apply_join_filter(
right_indices: UInt32Array,
filter: &JoinFilter,
) -> Result<(UInt64Array, UInt32Array)> {
if left_indices.is_empty() {
if left_indices.is_empty() && right_indices.is_empty() {
return Ok((left_indices, right_indices));
};

Expand All @@ -904,6 +956,7 @@ fn apply_join_filter(
JoinType::Inner
| JoinType::Left
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::LeftSemi
| JoinType::RightSemi => {
// For both INNER and LEFT joins, input arrays contains only indices for matched data.
Expand Down Expand Up @@ -1342,9 +1395,10 @@ impl HashJoinStream {

buffer
}
JoinType::Inner | JoinType::Right | JoinType::RightSemi => {
BooleanBufferBuilder::new(0)
}
JoinType::Inner
| JoinType::Right
| JoinType::RightSemi
| JoinType::RightAnti => BooleanBufferBuilder::new(0),
}
});

Expand Down Expand Up @@ -1381,7 +1435,10 @@ impl HashJoinStream {
visited_left_side.set_bit(x as usize, true);
});
}
JoinType::Inner | JoinType::Right | JoinType::RightSemi => {}
JoinType::Inner
| JoinType::Right
| JoinType::RightSemi
| JoinType::RightAnti => {}
}
}
Some(result.map(|x| x.0))
Expand Down Expand Up @@ -1420,6 +1477,7 @@ impl HashJoinStream {
| JoinType::LeftSemi
| JoinType::RightSemi
| JoinType::LeftAnti
| JoinType::RightAnti
| JoinType::Inner
| JoinType::Right => {}
}
Expand Down Expand Up @@ -2255,6 +2313,45 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn join_right_anti() -> Result<()> {
let session_ctx = SessionContext::new();
let task_ctx = session_ctx.task_ctx();
let right = build_table(
("a1", &vec![1, 2, 2, 3, 5]),
("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right
("c1", &vec![7, 8, 8, 9, 11]),
);
let left = build_table(
("a2", &vec![10, 20, 30, 40]),
("b2", &vec![4, 5, 6, 5]), // 5 is double on the right
("c2", &vec![70, 80, 90, 100]),
);
let on = vec![(
Column::new_with_schema("b2", &left.schema())?,
Column::new_with_schema("b1", &right.schema())?,
)];

let join = join(left, right, on, &JoinType::RightAnti, false)?;

let columns = columns(&join.schema());
assert_eq!(columns, vec!["a1", "b1", "c1"]);

let stream = join.execute(0, task_ctx)?;
let batches = common::collect(stream).await?;

let expected = vec![
"+----+----+----+",
"| a1 | b1 | c1 |",
"+----+----+----+",
"| 3 | 7 | 9 |",
"| 5 | 7 | 11 |",
"+----+----+----+",
];
assert_batches_sorted_eq!(expected, &batches);
Ok(())
}

#[tokio::test]
async fn join_anti_with_filter() -> Result<()> {
let session_ctx = SessionContext::new();
Expand Down
6 changes: 4 additions & 2 deletions datafusion/core/src/physical_plan/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,9 @@ impl ExecutionPlan for SortMergeJoinExec {
| JoinType::Left
| JoinType::LeftSemi
| JoinType::LeftAnti => self.left.output_ordering(),
JoinType::Right | JoinType::RightSemi => self.right.output_ordering(),
JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => {
self.right.output_ordering()
}
JoinType::Full => None,
}
}
Expand Down Expand Up @@ -187,7 +189,7 @@ impl ExecutionPlan for SortMergeJoinExec {
self.on.iter().map(|on| on.0.clone()).collect(),
self.on.iter().map(|on| on.1.clone()).collect(),
),
JoinType::Right | JoinType::RightSemi => (
JoinType::Right | JoinType::RightSemi | JoinType::RightAnti => (
self.right.clone(),
self.left.clone(),
self.on.iter().map(|on| on.1.clone()).collect(),
Expand Down
10 changes: 6 additions & 4 deletions datafusion/core/src/physical_plan/joins/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ fn output_join_field(old_field: &Field, join_type: &JoinType, is_left: bool) ->
JoinType::LeftSemi => false, // doesn't introduce nulls
JoinType::RightSemi => false, // doesn't introduce nulls
JoinType::LeftAnti => false, // doesn't introduce nulls (or can it??)
JoinType::RightAnti => false, // doesn't introduce nulls (or can it??)
};

if force_nullable {
Expand Down Expand Up @@ -237,7 +238,7 @@ pub fn build_join_schema(
)
})
.unzip(),
JoinType::RightSemi => right
JoinType::RightSemi | JoinType::RightAnti => right
.fields()
.iter()
.cloned()
Expand Down Expand Up @@ -410,9 +411,10 @@ fn estimate_join_cardinality(
})
}

JoinType::LeftSemi => None,
JoinType::LeftAnti => None,
JoinType::RightSemi => None,
JoinType::LeftSemi
| JoinType::RightSemi
| JoinType::LeftAnti
| JoinType::RightAnti => None,
}
}

Expand Down
5 changes: 4 additions & 1 deletion datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,10 @@ pub fn build_join_schema(
// Only use the left side for the schema
left.fields().clone()
}
JoinType::RightSemi => right.fields().clone(),
JoinType::RightSemi | JoinType::RightAnti => {
// Only use the right side for the schema
right.fields().clone()
}
};

let mut metadata = left.metadata().clone();
Expand Down
3 changes: 3 additions & 0 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -982,6 +982,8 @@ pub enum JoinType {
RightSemi,
/// Left Anti Join
LeftAnti,
/// Right Anti Join
RightAnti,
}

impl Display for JoinType {
Expand All @@ -994,6 +996,7 @@ impl Display for JoinType {
JoinType::LeftSemi => "LeftSemi",
JoinType::RightSemi => "RightSemi",
JoinType::LeftAnti => "LeftAnti",
JoinType::RightAnti => "RightAnti",
};
write!(f, "{}", join_type)
}
Expand Down
7 changes: 5 additions & 2 deletions datafusion/optimizer/src/filter_push_down.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ fn lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> {
JoinType::LeftSemi | JoinType::LeftAnti => Ok((true, false)),
// No columns from the left side of the join can be referenced in output
// predicates for semi/anti joins, so whether we specify t/f doesn't matter.
JoinType::RightSemi => Ok((false, true)),
JoinType::RightSemi | JoinType::RightAnti => Ok((false, true)),
},
LogicalPlan::CrossJoin(_) => Ok((true, true)),
_ => Err(DataFusionError::Internal(
Expand All @@ -198,7 +198,10 @@ fn on_lr_is_preserved(plan: &LogicalPlan) -> Result<(bool, bool)> {
JoinType::Left => Ok((false, true)),
JoinType::Right => Ok((true, false)),
JoinType::Full => Ok((false, false)),
JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightSemi => {
JoinType::LeftSemi
| JoinType::LeftAnti
| JoinType::RightSemi
| JoinType::RightAnti => {
// filter_push_down does not yet support SEMI/ANTI joins with join conditions
Ok((false, false))
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ enum JoinType {
LEFTSEMI = 4;
LEFTANTI = 5;
RIGHTSEMI = 6;
RIGHTANTI = 7;
}

enum JoinConstraint {
Expand Down
Loading

0 comments on commit 6356a21

Please sign in to comment.