refactor(plan_node): simplify batch agg-nodes (#9176)
ice1000 authored Apr 14, 2023
1 parent fdd319e commit de0b019
Showing 5 changed files with 80 additions and 97 deletions.
42 changes: 20 additions & 22 deletions src/frontend/src/optimizer/plan_node/batch_hash_agg.rs
@@ -19,9 +19,9 @@ use risingwave_common::error::Result;
use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::HashAggNode;

- use super::generic::{GenericPlanRef, PlanAggCall};
+ use super::generic::{self, GenericPlanRef, PlanAggCall};
use super::{
- ExprRewritable, LogicalAgg, PlanBase, PlanNodeType, PlanRef, PlanTreeNodeUnary, ToBatchPb,
+ ExprRewritable, PlanBase, PlanNodeType, PlanRef, PlanTreeNodeUnary, ToBatchPb,
ToDistributedBatch,
};
use crate::expr::ExprRewriter;
@@ -32,30 +32,31 @@ use crate::utils::ColIndexMappingRewriteExt;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BatchHashAgg {
pub base: PlanBase,
- logical: LogicalAgg,
+ logical: generic::Agg<PlanRef>,
}

impl BatchHashAgg {
- pub fn new(logical: LogicalAgg) -> Self {
- let ctx = logical.base.ctx.clone();
- let input = logical.input();
+ pub fn new(logical: generic::Agg<PlanRef>) -> Self {
+ let base = PlanBase::new_logical_with_core(&logical);
+ let ctx = base.ctx;
+ let input = logical.input.clone();
let input_dist = input.distribution();
let dist = match input_dist {
Distribution::HashShard(_) | Distribution::UpstreamHashShard(_, _) => logical
.i2o_col_mapping()
.rewrite_provided_distribution(input_dist),
d => d.clone(),
};
- let base = PlanBase::new_batch(ctx, logical.schema().clone(), dist, Order::any());
+ let base = PlanBase::new_batch(ctx, base.schema, dist, Order::any());
BatchHashAgg { base, logical }
}

pub fn agg_calls(&self) -> &[PlanAggCall] {
- self.logical.agg_calls()
+ &self.logical.agg_calls
}

pub fn group_key(&self) -> &[usize] {
- self.logical.group_key()
+ &self.logical.group_key
}

fn to_two_phase_agg(&self, dist_input: PlanRef) -> Result<PlanRef> {
@@ -73,14 +74,14 @@ impl BatchHashAgg {
// insert total agg
let total_agg_types = self
.logical
- .agg_calls()
+ .agg_calls
.iter()
.enumerate()
.map(|(partial_output_idx, agg_call)| {
agg_call.partial_to_total_agg_call(partial_output_idx + self.group_key().len())
})
.collect();
- let total_agg_logical = LogicalAgg::new(
+ let total_agg_logical = generic::Agg::new(
total_agg_types,
(0..self.group_key().len()).collect(),
exchange,
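
A note on the index arithmetic in `to_two_phase_agg` above: the partial phase emits the group keys first, then one column per partial aggregate, so total-phase call `i` reads input column `group_key.len() + i`. A minimal sketch of just that offset (the helper name is illustrative, not a real planner API):

```rust
// Sketch of the `partial_to_total_agg_call` input offset: partial output is
// laid out as [group keys..., partial agg results...], so the i-th total
// call reads column `group_key_len + i`. Hypothetical helper, for intuition.
fn total_agg_input_indices(num_agg_calls: usize, group_key_len: usize) -> Vec<usize> {
    (0..num_agg_calls)
        .map(|partial_output_idx| group_key_len + partial_output_idx)
        .collect()
}

fn main() {
    // Two group keys followed by three partial results: totals read columns 2, 3, 4.
    assert_eq!(total_agg_input_indices(3, 2), vec![2, 3, 4]);
}
```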
@@ -104,18 +105,20 @@ impl fmt::Display for BatchHashAgg {

impl PlanTreeNodeUnary for BatchHashAgg {
fn input(&self) -> PlanRef {
- self.logical.input()
+ self.logical.input.clone()
}

fn clone_with_input(&self, input: PlanRef) -> Self {
- Self::new(self.logical.clone_with_input(input))
+ let mut logical = self.logical.clone();
+ logical.input = input;
+ Self::new(logical)
}
}

impl_plan_tree_node_for_unary! { BatchHashAgg }
impl ToDistributedBatch for BatchHashAgg {
fn to_distributed(&self) -> Result<PlanRef> {
- if self.logical.two_phase_agg_forced() && self.logical.can_two_phase_agg() {
+ if self.logical.must_try_two_phase_agg() {
let input = self.input().to_distributed()?;
let input_dist = input.distribution();
if !self
@@ -170,13 +173,8 @@ impl ExprRewritable for BatchHashAgg {
}

fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
- Self::new(
- self.logical
- .rewrite_exprs(r)
- .as_logical_agg()
- .unwrap()
- .clone(),
- )
- .into()
+ let mut logical = self.logical.clone();
+ logical.rewrite_exprs(r);
+ Self::new(logical).into()
}
}
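
The overall pattern in this file: the batch node now embeds the shared `generic::Agg<PlanRef>` core and mutates its fields directly, rather than wrapping a `LogicalAgg` plan node and round-tripping through `as_logical_agg().unwrap()`. A minimal, self-contained sketch of that ownership style, using stand-in types instead of the real RisingWave ones:

```rust
// Stand-in for `generic::Agg<PlanRef>`: a plain struct with public fields.
#[derive(Clone, Debug)]
struct AggCore {
    agg_calls: Vec<String>, // stand-in for `Vec<PlanAggCall>`
    group_key: Vec<usize>,
    input: String,          // stand-in for `PlanRef`
}

// Stand-in for `BatchHashAgg`: owns the core instead of a logical wrapper.
#[derive(Clone, Debug)]
struct BatchHashAggSketch {
    core: AggCore,
}

impl BatchHashAggSketch {
    fn new(core: AggCore) -> Self {
        Self { core }
    }

    // Accessors borrow the core's fields directly, as in the diff above.
    fn agg_calls(&self) -> &[String] {
        &self.core.agg_calls
    }

    fn group_key(&self) -> &[usize] {
        &self.core.group_key
    }

    // `clone_with_input` becomes: clone the core, swap one field, rebuild.
    fn clone_with_input(&self, input: String) -> Self {
        let mut core = self.core.clone();
        core.input = input;
        Self::new(core)
    }
}

fn main() {
    let node = BatchHashAggSketch::new(AggCore {
        agg_calls: vec!["count".into()],
        group_key: vec![0],
        input: "scan".into(),
    });
    let rewritten = node.clone_with_input("exchange".into());
    assert_eq!(rewritten.core.input, "exchange");
    assert_eq!(rewritten.group_key(), &[0]);
    assert_eq!(rewritten.agg_calls().len(), 1);
}
```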
42 changes: 19 additions & 23 deletions src/frontend/src/optimizer/plan_node/batch_sort_agg.rs
@@ -20,10 +20,8 @@ use risingwave_pb::batch_plan::plan_node::NodeBody;
use risingwave_pb::batch_plan::SortAggNode;
use risingwave_pb::expr::ExprNode;

- use super::generic::{GenericPlanRef, PlanAggCall};
- use super::{
- ExprRewritable, LogicalAgg, PlanBase, PlanRef, PlanTreeNodeUnary, ToBatchPb, ToDistributedBatch,
- };
+ use super::generic::{self, GenericPlanRef, PlanAggCall};
+ use super::{ExprRewritable, PlanBase, PlanRef, PlanTreeNodeUnary, ToBatchPb, ToDistributedBatch};
use crate::expr::{Expr, ExprImpl, ExprRewriter, InputRef};
use crate::optimizer::plan_node::ToLocalBatch;
use crate::optimizer::property::{Distribution, Order, RequiredDist};
@@ -32,14 +30,15 @@ use crate::utils::ColIndexMappingRewriteExt;
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct BatchSortAgg {
pub base: PlanBase,
- logical: LogicalAgg,
+ logical: generic::Agg<PlanRef>,
input_order: Order,
}

impl BatchSortAgg {
- pub fn new(logical: LogicalAgg) -> Self {
- let ctx = logical.base.ctx.clone();
- let input = logical.input();
+ pub fn new(logical: generic::Agg<PlanRef>) -> Self {
+ let base = PlanBase::new_logical_with_core(&logical);
+ let ctx = base.ctx;
+ let input = logical.input.clone();
let input_dist = input.distribution();
let dist = match input_dist {
Distribution::HashShard(_) | Distribution::UpstreamHashShard(_, _) => logical
@@ -52,18 +51,18 @@ impl BatchSortAgg {
.order()
.column_orders
.iter()
- .filter(|o| logical.group_key().iter().any(|g_k| *g_k == o.column_index))
+ .filter(|o| logical.group_key.iter().any(|g_k| *g_k == o.column_index))
.cloned()
.collect(),
};

- assert_eq!(input_order.column_orders.len(), logical.group_key().len());
+ assert_eq!(input_order.column_orders.len(), logical.group_key.len());

let order = logical
.i2o_col_mapping()
.rewrite_provided_order(&input_order);

- let base = PlanBase::new_batch(ctx, logical.schema().clone(), dist, order);
+ let base = PlanBase::new_batch(ctx, base.schema, dist, order);

BatchSortAgg {
base,
Expand All @@ -73,11 +72,11 @@ impl BatchSortAgg {
}

pub fn agg_calls(&self) -> &[PlanAggCall] {
- self.logical.agg_calls()
+ &self.logical.agg_calls
}

pub fn group_key(&self) -> &[usize] {
- self.logical.group_key()
+ &self.logical.group_key
}
}

@@ -89,11 +88,13 @@ impl fmt::Display for BatchSortAgg {

impl PlanTreeNodeUnary for BatchSortAgg {
fn input(&self) -> PlanRef {
- self.logical.input()
+ self.logical.input.clone()
}

fn clone_with_input(&self, input: PlanRef) -> Self {
- Self::new(self.logical.clone_with_input(input))
+ let mut logical = self.logical.clone();
+ logical.input = input;
+ Self::new(logical)
}
}
impl_plan_tree_node_for_unary! { BatchSortAgg }
@@ -143,13 +144,8 @@ impl ExprRewritable for BatchSortAgg {
}

fn rewrite_exprs(&self, r: &mut dyn ExprRewriter) -> PlanRef {
- Self::new(
- self.logical
- .rewrite_exprs(r)
- .as_logical_agg()
- .unwrap()
- .clone(),
- )
- .into()
+ let mut new_logical = self.logical.clone();
+ new_logical.rewrite_exprs(r);
+ Self::new(new_logical).into()
}
}
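
`BatchSortAgg::new` above derives its required input order by keeping only the input's column orders that fall on a group key, then asserting that every group key is covered. A simplified model of that filter, with a stand-in `ColumnOrder` type:

```rust
// Stand-in for the planner's `ColumnOrder`; only the column index matters here.
#[derive(Clone, Debug, PartialEq)]
struct ColumnOrder {
    column_index: usize,
    descending: bool,
}

// Keep only the input orders that land on a group key; sort agg is only valid
// when every group key is covered, mirroring the `assert_eq!` in the diff.
fn input_order_on_group_keys(
    input_orders: &[ColumnOrder],
    group_key: &[usize],
) -> Vec<ColumnOrder> {
    let kept: Vec<ColumnOrder> = input_orders
        .iter()
        .filter(|o| group_key.iter().any(|g_k| *g_k == o.column_index))
        .cloned()
        .collect();
    assert_eq!(kept.len(), group_key.len());
    kept
}

fn main() {
    let orders = vec![
        ColumnOrder { column_index: 0, descending: false },
        ColumnOrder { column_index: 2, descending: false },
        ColumnOrder { column_index: 3, descending: true },
    ];
    // Group keys 0 and 2 are both covered; the order on column 3 is dropped.
    let kept = input_order_on_group_keys(&orders, &[0, 2]);
    assert_eq!(kept.len(), 2);
}
```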
27 changes: 25 additions & 2 deletions src/frontend/src/optimizer/plan_node/generic/agg.rs
@@ -27,7 +27,7 @@ use super::super::utils::TableCatalogBuilder;
use super::{stream, GenericPlanNode, GenericPlanRef};
use crate::expr::{Expr, ExprRewriter, InputRef, InputRefDisplay};
use crate::optimizer::optimizer_context::OptimizerContextRef;
- use crate::optimizer::property::FunctionalDependencySet;
+ use crate::optimizer::property::{Distribution, FunctionalDependencySet, RequiredDist};
use crate::stream_fragmenter::BuildFragmentGraphState;
use crate::utils::{
ColIndexMapping, ColIndexMappingRewriteExt, Condition, ConditionDisplay, IndexRewriter,
@@ -74,7 +74,30 @@ impl<PlanRef: GenericPlanRef> Agg<PlanRef> {
}

pub(crate) fn can_two_phase_agg(&self) -> bool {
- self.call_support_two_phase() && !self.is_agg_result_affected_by_order()
+ self.call_support_two_phase()
+ && !self.is_agg_result_affected_by_order()
+ && self.two_phase_agg_enabled()
}

+ /// Must try two phase agg iff we are forced to, and we satisfy the constraints.
+ pub(crate) fn must_try_two_phase_agg(&self) -> bool {
+ self.two_phase_agg_forced() && self.can_two_phase_agg()
+ }
+
+ fn two_phase_agg_forced(&self) -> bool {
+ self.ctx().session_ctx().config().get_force_two_phase_agg()
+ }
+
+ fn two_phase_agg_enabled(&self) -> bool {
+ self.ctx().session_ctx().config().get_enable_two_phase_agg()
+ }
+
+ /// Generally used by two phase hash agg.
+ /// If input dist already satisfies hash agg distribution,
+ /// it will be more expensive to do two phase agg, should just do shuffle agg.
+ pub(crate) fn hash_agg_dist_satisfied_by_input_dist(&self, input_dist: &Distribution) -> bool {
+ let required_dist = RequiredDist::shard_by_key(self.input.schema().len(), &self.group_key);
+ input_dist.satisfies(&required_dist)
+ }

fn call_support_two_phase(&self) -> bool {
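
With these predicates now living on `generic::Agg`, the decision chain reads: two-phase aggregation is *possible* when every call supports it, the result is insensitive to input order, and the session enables it; it becomes *mandatory* only when the session additionally forces it. A hedged sketch of how the predicates compose, with the session-config reads replaced by plain booleans:

```rust
// Simplified model of the relocated predicates; the real methods inspect agg
// calls and read session config, both flattened to booleans here.
struct AggSketch {
    calls_support_two_phase: bool,
    result_affected_by_order: bool,
    enable_two_phase: bool, // stand-in for `enable_two_phase_agg` config
    force_two_phase: bool,  // stand-in for `force_two_phase_agg` config
}

impl AggSketch {
    fn can_two_phase_agg(&self) -> bool {
        self.calls_support_two_phase
            && !self.result_affected_by_order
            && self.enable_two_phase
    }

    // Forced *and* feasible: forcing alone cannot override the constraints.
    fn must_try_two_phase_agg(&self) -> bool {
        self.force_two_phase && self.can_two_phase_agg()
    }
}

fn main() {
    let agg = AggSketch {
        calls_support_two_phase: true,
        result_affected_by_order: false,
        enable_two_phase: true,
        force_two_phase: false,
    };
    assert!(agg.can_two_phase_agg());
    assert!(!agg.must_try_two_phase_agg()); // possible, but not forced
}
```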
58 changes: 12 additions & 46 deletions src/frontend/src/optimizer/plan_node/logical_agg.rs
@@ -203,15 +203,6 @@ impl LogicalAgg {
})
}

- /// Generally used by two phase hash agg.
- /// If input dist already satisfies hash agg distribution,
- /// it will be more expensive to do two phase agg, should just do shuffle agg.
- pub(crate) fn hash_agg_dist_satisfied_by_input_dist(&self, input_dist: &Distribution) -> bool {
- let required_dist =
- RequiredDist::shard_by_key(self.input().schema().len(), self.group_key());
- input_dist.satisfies(&required_dist)
- }

/// Generates distributed stream plan.
fn gen_dist_stream_agg_plan(&self, stream_input: PlanRef) -> Result<PlanRef> {
let input_dist = stream_input.distribution();
@@ -220,20 +211,20 @@ impl LogicalAgg {
// Shuffle agg
// If we have group key, and we won't try two phase agg optimization at all,
// we will always choose shuffle agg over single agg.
- if !self.group_key().is_empty() && !self.must_try_two_phase_agg() {
+ if !self.group_key().is_empty() && !self.core.must_try_two_phase_agg() {
return self.gen_shuffle_plan(stream_input);
}

// Standalone agg
// If no group key, and cannot two phase agg, we have to use single plan.
- if self.group_key().is_empty() && !self.can_two_phase_agg() {
+ if self.group_key().is_empty() && !self.core.can_two_phase_agg() {
return self.gen_single_plan(stream_input);
}

debug_assert!(if !self.group_key().is_empty() {
- self.must_try_two_phase_agg()
+ self.core.must_try_two_phase_agg()
} else {
- self.can_two_phase_agg()
+ self.core.can_two_phase_agg()
});

// Stateless 2-phase simple agg
@@ -251,7 +242,7 @@ impl LogicalAgg {
// We shall first distribute it by PK,
// so it obeys consistent hash strategy via [`Distribution::HashShard`].
let stream_input =
- if *input_dist == Distribution::SomeShard && self.must_try_two_phase_agg() {
+ if *input_dist == Distribution::SomeShard && self.core.must_try_two_phase_agg() {
RequiredDist::shard_by_key(stream_input.schema().len(), stream_input.logical_pk())
.enforce_if_not_satisfies(stream_input, &Order::any())?
} else {
Expand All @@ -264,7 +255,7 @@ impl LogicalAgg {
// with input distributed by dist_key.
match input_dist {
Distribution::HashShard(dist_key) | Distribution::UpstreamHashShard(dist_key, _)
- if (!self.hash_agg_dist_satisfied_by_input_dist(input_dist)
+ if (!self.core.hash_agg_dist_satisfied_by_input_dist(input_dist)
|| self.group_key().is_empty()) =>
{
let dist_key = dist_key.clone();
Expand All @@ -281,31 +272,6 @@ impl LogicalAgg {
}
}

- pub(crate) fn two_phase_agg_forced(&self) -> bool {
- self.base
- .ctx()
- .session_ctx()
- .config()
- .get_force_two_phase_agg()
- }
-
- fn two_phase_agg_enabled(&self) -> bool {
- self.base
- .ctx()
- .session_ctx()
- .config()
- .get_enable_two_phase_agg()
- }
-
- /// Must try two phase agg iff we are forced to, and we satisfy the constraints.
- fn must_try_two_phase_agg(&self) -> bool {
- self.two_phase_agg_forced() && self.can_two_phase_agg()
- }
-
- pub(crate) fn can_two_phase_agg(&self) -> bool {
- self.core.can_two_phase_agg() && self.two_phase_agg_enabled()
- }

// Check if the output of the aggregation needs to be sorted and return ordering req by group
// keys If group key order satisfies required order, push down the sort below the
// aggregation and use sort aggregation. The data type of the columns need to be int32
@@ -337,16 +303,15 @@ impl LogicalAgg {
// Check if the input is already sorted, and hence sort merge aggregation can be used
// It can only be used, if the input is sorted on all group key indices and the
// datatype of the column is int32
- fn input_provides_order_on_group_keys(&self, new_logical: &LogicalAgg) -> bool {
+ fn input_provides_order_on_group_keys(&self, new_logical: &generic::Agg<PlanRef>) -> bool {
self.group_key().iter().all(|group_by_idx| {
- new_logical
- .input()
+ let input = &new_logical.input;
+ input
.order()
.column_orders
.iter()
.any(|order| order.column_index == *group_by_idx)
- && new_logical
- .input()
+ && input
.schema()
.fields()
.get(*group_by_idx)
@@ -1133,7 +1098,8 @@ impl ToBatch for LogicalAgg {
.rewrite_provided_order(&group_key_order);
}
let new_input = self.input().to_batch_with_order_required(&input_order)?;
- let new_logical = self.clone_with_input(new_input);
+ let mut new_logical = self.core.clone();
+ new_logical.input = new_input;
if self
.ctx()
.session_ctx()
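
For reference, `gen_dist_stream_agg_plan` above picks among three plan shapes in a fixed order: shuffle agg when there is a group key and two-phase is not forced, single agg when there is no group key and two-phase is impossible, and some two-phase variant otherwise. A simplified sketch of just that branch order (names are illustrative, not the real planner API):

```rust
#[derive(Debug, PartialEq)]
enum AggPlan {
    Shuffle,
    Single,
    TwoPhase,
}

// Mirrors the branch order in the diff: shuffle first, then single, then
// fall through to a two-phase variant.
fn choose_stream_agg_plan(
    has_group_key: bool,
    must_try_two_phase: bool,
    can_two_phase: bool,
) -> AggPlan {
    // Shuffle agg: with a group key, prefer shuffle unless two-phase is forced.
    if has_group_key && !must_try_two_phase {
        return AggPlan::Shuffle;
    }
    // Standalone agg: no group key and no two-phase option leaves single agg.
    if !has_group_key && !can_two_phase {
        return AggPlan::Single;
    }
    // Everything else attempts some flavor of two-phase aggregation.
    AggPlan::TwoPhase
}

fn main() {
    assert_eq!(choose_stream_agg_plan(true, false, true), AggPlan::Shuffle);
    assert_eq!(choose_stream_agg_plan(false, false, false), AggPlan::Single);
    assert_eq!(choose_stream_agg_plan(false, false, true), AggPlan::TwoPhase);
}
```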
8 changes: 4 additions & 4 deletions src/frontend/src/optimizer/plan_node/logical_union.rs
@@ -23,8 +23,8 @@ use crate::expr::{ExprImpl, InputRef, Literal};
use crate::optimizer::plan_node::generic::GenericPlanRef;
use crate::optimizer::plan_node::stream_union::StreamUnion;
use crate::optimizer::plan_node::{
- generic, BatchHashAgg, BatchUnion, ColumnPruningContext, LogicalAgg, LogicalProject,
- PlanTreeNode, PredicatePushdownContext, RewriteStreamContext, ToStreamContext,
+ generic, BatchHashAgg, BatchUnion, ColumnPruningContext, LogicalProject, PlanTreeNode,
+ PredicatePushdownContext, RewriteStreamContext, ToStreamContext,
};
use crate::optimizer::property::RequiredDist;
use crate::utils::{ColIndexMapping, Condition};
@@ -131,9 +131,9 @@ impl ToBatch for LogicalUnion {
// Convert union to union all + agg
if !self.all() {
let batch_union = BatchUnion::new(new_logical).into();
- Ok(BatchHashAgg::new(LogicalAgg::new(
+ Ok(BatchHashAgg::new(generic::Agg::new(
vec![],
- (0..self.base.schema.len()).collect_vec(),
+ (0..self.base.schema.len()).collect(),
batch_union,
))
.into())
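
This lowering leans on the usual equivalence: `UNION` is `UNION ALL` followed by grouping on every output column with zero aggregate calls, i.e. plain deduplication. A small sketch of that semantics over rows modeled as vectors (engine types are stand-ins):

```rust
use std::collections::HashSet;

type Row = Vec<i64>;

// `UNION ALL`: concatenate both inputs, keeping duplicates.
fn union_all(lhs: Vec<Row>, rhs: Vec<Row>) -> Vec<Row> {
    lhs.into_iter().chain(rhs).collect()
}

// Hash agg with group key = all columns and no agg calls keeps one copy of
// each distinct row, which is exactly `UNION`'s deduplication.
fn group_by_all_columns(rows: Vec<Row>) -> Vec<Row> {
    let mut seen = HashSet::new();
    rows.into_iter().filter(|r| seen.insert(r.clone())).collect()
}

fn main() {
    let out = group_by_all_columns(union_all(
        vec![vec![1, 2], vec![3, 4]],
        vec![vec![1, 2]], // duplicate of the first row
    ));
    assert_eq!(out.len(), 2); // UNION keeps two distinct rows
}
```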
