Skip to content

Commit

Permalink
feat(hash join): trivial interval join (close #9228)
Browse files Browse the repository at this point in the history
  • Loading branch information
Liang Zhao committed Apr 17, 2023
1 parent 58980e8 commit af9a661
Show file tree
Hide file tree
Showing 7 changed files with 344 additions and 32 deletions.
8 changes: 8 additions & 0 deletions proto/stream_plan.proto
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,13 @@ message GroupTopNNode {
bool with_ties = 6;
}

// A band join condition: two opposite inequality conjunctions relating one column of the
// left input to one column of the right input.
message BandJoinCondition {
// Index of the band column in the left input.
uint32 left_column_idx = 1;
// Index of the band column in the right input.
uint32 right_column_idx = 2;
// The inequality conjunction whose larger-side key is the left column.
expr.ExprNode left_greater_conjunction = 3;
// The inequality conjunction whose larger-side key is the right column.
expr.ExprNode right_greater_conjunction = 4;
}

message DeltaExpression {
expr.ExprNode.Type delta_type = 1;
expr.ExprNode delta = 2;
Expand Down Expand Up @@ -306,6 +313,7 @@ message HashJoinNode {
// Whether to optimize for append only stream.
// It is true when the input is append-only
bool is_append_only = 14;
BandJoinCondition band_condition = 15;
}

message TemporalJoinNode {
Expand Down
1 change: 1 addition & 0 deletions src/frontend/src/optimizer/plan_node/logical_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,7 @@ impl LogicalJoin {
let stream_hash_join = StreamHashJoin::new(logical_join.core.clone(), predicate.clone());
let pull_filter = self.join_type() == JoinType::Inner
&& stream_hash_join.eq_join_predicate().has_non_eq()
&& stream_hash_join.band_condition().is_none()
&& stream_hash_join.inequality_pairs().is_empty();
if pull_filter {
let default_indices = (0..self.internal_column_num()).collect::<Vec<_>>();
Expand Down
5 changes: 3 additions & 2 deletions src/frontend/src/optimizer/plan_node/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ impl HashJoin {
input: &impl StreamPlanRef,
join_key_indices: Vec<usize>,
dk_indices_in_jk: Vec<usize>,
band_key_index: Option<usize>,
) -> (TableCatalog, TableCatalog, Vec<usize>) {
let schema = input.schema();

Expand All @@ -245,13 +246,13 @@ impl HashJoin {

let degree_table_dist_keys = dk_indices_in_jk.clone();

// The pk of hash join internal and degree table should be join_key + input_pk.
// The pk of hash join internal and degree table should be join_key [+ band_key] + input_pk.
let join_key_len = join_key_indices.len();
let mut pk_indices = join_key_indices;

// dedup the pk in dist key..
let mut deduped_input_pk_indices = vec![];
for input_pk_idx in input.logical_pk() {
for input_pk_idx in band_key_index.iter().chain(input.logical_pk()) {
if !pk_indices.contains(input_pk_idx)
&& !deduped_input_pk_indices.contains(input_pk_idx)
{
Expand Down
179 changes: 168 additions & 11 deletions src/frontend/src/optimizer/plan_node/stream_hash_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::BTreeMap;
use std::fmt;

use fixedbitset::FixedBitSet;
Expand All @@ -20,7 +21,9 @@ use risingwave_common::catalog::{FieldDisplay, Schema};
use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_pb::plan_common::JoinType;
use risingwave_pb::stream_plan::stream_node::NodeBody;
use risingwave_pb::stream_plan::{DeltaExpression, HashJoinNode, PbInequalityPair};
use risingwave_pb::stream_plan::{
DeltaExpression, HashJoinNode, PbBandJoinCondition, PbInequalityPair,
};

use super::{
generic, ExprRewritable, PlanBase, PlanRef, PlanTreeNodeBinary, StreamDeltaJoin, StreamNode,
Expand All @@ -31,7 +34,15 @@ use crate::optimizer::plan_node::utils::IndicesDisplay;
use crate::optimizer::plan_node::{EqJoinPredicate, EqJoinPredicateDisplay};
use crate::optimizer::property::Distribution;
use crate::stream_fragmenter::BuildFragmentGraphState;
use crate::utils::ColIndexMappingRewriteExt;
use crate::utils::{ColIndexMappingRewriteExt, Condition};

/// A band condition found among the non-equal join conditions: two opposite inequalities
/// over the same (left column, right column) pair.
///
/// The `*_conjunction_idx` fields index into the conjunctions of the join predicate's
/// `other_cond()`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BandCondition {
/// Index of the band column within the left input.
left_column_idx: usize,
/// Index of the band column within the right input (the left column count is already
/// subtracted).
right_column_idx: usize,
/// Index of the conjunction whose larger-side key is the left column.
left_greater_conjunction_idx: usize,
/// Index of the conjunction whose larger-side key is the right column.
right_greater_conjunction_idx: usize,
}

/// [`StreamHashJoin`] implements [`super::LogicalJoin`] with hash table. It builds a hash table
/// from inner (right-side) relation and probes with data from outer (left-side) relation to
Expand All @@ -53,6 +64,10 @@ pub struct StreamHashJoin {
/// It is true if input of both side is append-only
is_append_only: bool,

/// A pair of band conditions. We can select entries which match the condition only in
/// `HashJoinExecutor`.
band_condition: Option<BandCondition>,

/// The conjunction index of the inequality which is used to clean left state table in
/// `HashJoinExecutor`. If any equal condition is able to clean state table, this field
/// will always be `None`.
Expand All @@ -79,6 +94,8 @@ impl StreamHashJoin {
&logical,
);

let mut band_condition = None;

let mut inequality_pairs = vec![];
let mut clean_left_state_conjunction_idx = None;
let mut clean_right_state_conjunction_idx = None;
Expand All @@ -97,10 +114,15 @@ impl StreamHashJoin {
let watermark_columns = {
let l2i = logical.l2i_col_mapping();
let r2i = logical.r2i_col_mapping();
let (left_cols_num, original_inequality_pairs) = eq_join_predicate.inequality_pairs();

let mut equal_condition_clean_state = false;
let mut join_key_indices_in_all_columns =
Vec::with_capacity(eq_join_predicate.eq_indexes().len() * 2);
let mut watermark_columns = FixedBitSet::with_capacity(logical.internal_column_num());
for (left_key, right_key) in eq_join_predicate.eq_indexes() {
join_key_indices_in_all_columns.push(left_key);
join_key_indices_in_all_columns.push(right_key + left_cols_num);
if logical.left.watermark_columns().contains(left_key)
&& logical.right.watermark_columns().contains(right_key)
{
Expand All @@ -113,7 +135,82 @@ impl StreamHashJoin {
}
}
}
let (left_cols_num, original_inequality_pairs) = eq_join_predicate.inequality_pairs();
join_key_indices_in_all_columns.sort();
join_key_indices_in_all_columns.dedup();

let mut both_pk_in_jk = true;
for left_pk in logical.left.logical_pk() {
if join_key_indices_in_all_columns
.binary_search(left_pk)
.is_err()
{
both_pk_in_jk = false;
break;
}
}
if both_pk_in_jk {
for right_pk in logical.right.logical_pk() {
if join_key_indices_in_all_columns
.binary_search(&(right_pk + left_cols_num))
.is_err()
{
both_pk_in_jk = false;
break;
}
}
}

if !both_pk_in_jk {
    // Search the inequality pairs for two opposite inequalities over the same pair of
    // columns, neither of which is a join key. The first such pair found becomes the band
    // condition.
    //
    // Maps `(larger key, smaller key)` of each candidate inequality seen so far to the
    // index of its conjunction.
    let mut pair2conjunction_idx = BTreeMap::new();
    for (
        conjunction_idx,
        InequalityInputPair {
            key_required_larger,
            key_required_smaller,
            ..
        },
    ) in &original_inequality_pairs
    {
        // Inequalities that touch a join key column are not band candidates.
        let touches_join_key = join_key_indices_in_all_columns
            .binary_search(key_required_larger)
            .is_ok()
            || join_key_indices_in_all_columns
                .binary_search(key_required_smaller)
                .is_ok();
        if touches_join_key {
            continue;
        }

        // An earlier inequality over the same two columns but in the opposite direction
        // completes the band.
        if let Some(record_conjunction_idx) =
            pair2conjunction_idx.get(&(*key_required_smaller, *key_required_larger))
        {
            // NOTE(review): the smaller of the two column indices is treated as the left
            // column and the other one as a right column (hence `- left_cols_num`). This
            // assumes every inequality pair relates one left column to one right column;
            // confirm against `EqJoinPredicate::inequality_pairs`.
            band_condition = Some(if key_required_larger < key_required_smaller {
                BandCondition {
                    left_column_idx: *key_required_larger,
                    right_column_idx: key_required_smaller - left_cols_num,
                    left_greater_conjunction_idx: *conjunction_idx,
                    right_greater_conjunction_idx: *record_conjunction_idx,
                }
            } else {
                BandCondition {
                    left_column_idx: *key_required_smaller,
                    right_column_idx: key_required_larger - left_cols_num,
                    left_greater_conjunction_idx: *record_conjunction_idx,
                    right_greater_conjunction_idx: *conjunction_idx,
                }
            });
            break;
        }

        // Remember this inequality; if the same pair was already recorded, keep the first
        // conjunction index.
        pair2conjunction_idx
            .entry((*key_required_larger, *key_required_smaller))
            .or_insert(*conjunction_idx);
    }
}

for (
conjunction_idx,
InequalityInputPair {
Expand Down Expand Up @@ -207,6 +304,7 @@ impl StreamHashJoin {
eq_join_predicate,
inequality_pairs,
is_append_only: append_only,
band_condition,
clean_left_state_conjunction_idx,
clean_right_state_conjunction_idx,
}
Expand Down Expand Up @@ -290,6 +388,10 @@ impl StreamHashJoin {
/// Returns the inequality pairs stored on this join node.
pub fn inequality_pairs(&self) -> &Vec<(bool, InequalityInputPair)> {
&self.inequality_pairs
}

/// Returns the band condition detected for this join, or `None` if there is none.
pub fn band_condition(&self) -> &Option<BandCondition> {
&self.band_condition
}
}

impl fmt::Display for StreamHashJoin {
Expand All @@ -305,9 +407,7 @@ impl fmt::Display for StreamHashJoin {
&& self.right().watermark_columns().contains(rjk)
{
f.debug_struct("StreamWindowJoin")
} else if self.clean_left_state_conjunction_idx.is_some()
&& self.clean_right_state_conjunction_idx.is_some()
{
} else if self.band_condition.is_some() {
f.debug_struct("StreamIntervalJoin")
} else if self.is_append_only {
f.debug_struct("StreamAppendOnlyHashJoin")
Expand All @@ -329,6 +429,25 @@ impl fmt::Display for StreamHashJoin {
},
);

if let Some(band_condition) = self.band_condition.as_ref() {
builder.field(
"band_join_condition_left_greater",
&ExprDisplay {
expr: &self.eq_join_predicate().other_cond().conjunctions
[band_condition.left_greater_conjunction_idx],
input_schema: &concat_schema,
},
);
builder.field(
"band_join_condition_right_greater",
&ExprDisplay {
expr: &self.eq_join_predicate().other_cond().conjunctions
[band_condition.right_greater_conjunction_idx],
input_schema: &concat_schema,
},
);
}

if let Some(conjunction_idx) = self.clean_left_state_conjunction_idx {
builder.field(
"conditions_to_clean_left_state_table",
Expand Down Expand Up @@ -418,12 +537,18 @@ impl StreamNode for StreamHashJoin {
self.left().plan_base(),
left_jk_indices,
dk_indices_in_jk.clone(),
self.band_condition
.as_ref()
.map(|band_condition| band_condition.left_column_idx),
);
let (right_table, right_degree_table, right_deduped_input_pk_indices) =
HashJoin::infer_internal_and_degree_table_catalog(
self.right().plan_base(),
right_jk_indices,
dk_indices_in_jk,
self.band_condition
.as_ref()
.map(|band_condition| band_condition.right_column_idx),
);

let left_deduped_input_pk_indices = left_deduped_input_pk_indices
Expand All @@ -447,16 +572,32 @@ impl StreamNode for StreamHashJoin {

let null_safe_prost = self.eq_join_predicate.null_safes().into_iter().collect();

let condition = if let Some(band_condition) = self.band_condition.as_ref() {
Condition {
conjunctions: self
.eq_join_predicate
.other_cond()
.conjunctions
.iter()
.enumerate()
.filter(|(conjunction_idx, _)| {
*conjunction_idx != band_condition.left_greater_conjunction_idx
&& *conjunction_idx != band_condition.right_greater_conjunction_idx
})
.map(|(_, conjunction)| conjunction.clone())
.collect_vec(),
}
.as_expr_unless_true()
} else {
self.eq_join_predicate.other_cond().as_expr_unless_true()
};

NodeBody::HashJoin(HashJoinNode {
join_type: self.logical.join_type as i32,
left_key: left_jk_indices_prost,
right_key: right_jk_indices_prost,
null_safe: null_safe_prost,
condition: self
.eq_join_predicate
.other_cond()
.as_expr_unless_true()
.map(|x| x.to_expr_proto()),
condition: condition.map(|x| x.to_expr_proto()),
inequality_pairs: self
.inequality_pairs
.iter()
Expand Down Expand Up @@ -496,6 +637,22 @@ impl StreamNode for StreamHashJoin {
.map(|&x| x as u32)
.collect(),
is_append_only: self.is_append_only,
band_condition: self.band_condition.as_ref().map(|band_condition| {
PbBandJoinCondition {
left_column_idx: band_condition.left_column_idx as u32,
right_column_idx: band_condition.right_column_idx as u32,
left_greater_conjunction: Some(
self.eq_join_predicate.other_cond().conjunctions
[band_condition.left_greater_conjunction_idx]
.to_expr_proto(),
),
right_greater_conjunction: Some(
self.eq_join_predicate.other_cond().conjunctions
[band_condition.right_greater_conjunction_idx]
.to_expr_proto(),
),
}
}),
})
}
}
Expand Down
Loading

0 comments on commit af9a661

Please sign in to comment.