Skip to content

Commit

Permalink
feat(stream,agg): enable distinct agg support in backend (#8100)
Browse files Browse the repository at this point in the history
Previously in #7797, distinct agg support was added (without cache) but not enabled. This PR enables it by disabling the 2-phase rewrite rule for streaming distinct agg calls, and also adds an LRU cache in the deduplicater.

This will close #7682, and possibly resolve or at least mitigate the performance issues in #7350 and #7271.

Approved-By: st1page
  • Loading branch information
stdrc authored Feb 24, 2023
1 parent 26de7bd commit 3c5bf28
Show file tree
Hide file tree
Showing 14 changed files with 304 additions and 318 deletions.
28 changes: 9 additions & 19 deletions src/frontend/planner_test/tests/testdata/agg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -779,13 +779,10 @@
└─BatchScan { table: t, columns: [t.a, t.b, t.c], distribution: SomeShard }
stream_plan: |
StreamMaterialize { columns: [a, distinct_b_num, sum_c], pk_columns: [a], pk_conflict: "no check" }
└─StreamProject { exprs: [t.a, count(t.b), sum(sum(t.c))] }
└─StreamHashAgg { group_key: [t.a], aggs: [count, count(t.b), sum(sum(t.c))] }
└─StreamProject { exprs: [t.a, count(distinct t.b), sum(t.c)] }
└─StreamHashAgg { group_key: [t.a], aggs: [count, count(distinct t.b), sum(t.c)] }
└─StreamExchange { dist: HashShard(t.a) }
└─StreamProject { exprs: [t.a, t.b, sum(t.c)] }
└─StreamHashAgg { group_key: [t.a, t.b], aggs: [count, sum(t.c)] }
└─StreamExchange { dist: HashShard(t.a, t.b) }
└─StreamTableScan { table: t, columns: [t.a, t.b, t.c, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
└─StreamTableScan { table: t, columns: [t.a, t.b, t.c, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
- name: distinct agg and non-distinct agg with intersected argument
sql: |
create table t(a int, b int, c int);
Expand All @@ -805,14 +802,10 @@
└─BatchScan { table: t, columns: [t.a, t.b, t.c], distribution: SomeShard }
stream_plan: |
StreamMaterialize { columns: [a, distinct_b_num, distinct_c_sum, sum_c], pk_columns: [a], pk_conflict: "no check" }
└─StreamProject { exprs: [t.a, count(t.b) filter((flag = 1:Int64)), count(t.c) filter((flag = 0:Int64)), sum(sum(t.c)) filter((flag = 0:Int64))] }
└─StreamHashAgg { group_key: [t.a], aggs: [count, count(t.b) filter((flag = 1:Int64)), count(t.c) filter((flag = 0:Int64)), sum(sum(t.c)) filter((flag = 0:Int64))] }
└─StreamProject { exprs: [t.a, count(distinct t.b), count(distinct t.c), sum(t.c)] }
└─StreamHashAgg { group_key: [t.a], aggs: [count, count(distinct t.b), count(distinct t.c), sum(t.c)] }
└─StreamExchange { dist: HashShard(t.a) }
└─StreamProject { exprs: [t.a, t.b, t.c, flag, sum(t.c)] }
└─StreamHashAgg { group_key: [t.a, t.b, t.c, flag], aggs: [count, sum(t.c)] }
└─StreamExchange { dist: HashShard(t.a, t.b, t.c, flag) }
└─StreamExpand { column_subsets: [[t.a, t.c], [t.a, t.b]] }
└─StreamTableScan { table: t, columns: [t.a, t.b, t.c, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
└─StreamTableScan { table: t, columns: [t.a, t.b, t.c, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
- name: distinct agg with filter
sql: |
create table t(a int, b int, c int);
Expand All @@ -830,13 +823,10 @@
└─BatchScan { table: t, columns: [t.a, t.b, t.c], distribution: SomeShard }
stream_plan: |
StreamMaterialize { columns: [a, count, sum], pk_columns: [a], pk_conflict: "no check" }
└─StreamProject { exprs: [t.a, count(t.b) filter((count filter((t.b < 100:Int32)) > 0:Int64)), sum(sum(t.c))] }
└─StreamHashAgg { group_key: [t.a], aggs: [count, count(t.b) filter((count filter((t.b < 100:Int32)) > 0:Int64)), sum(sum(t.c))] }
└─StreamProject { exprs: [t.a, count(distinct t.b) filter((t.b < 100:Int32)), sum(t.c)] }
└─StreamHashAgg { group_key: [t.a], aggs: [count, count(distinct t.b) filter((t.b < 100:Int32)), sum(t.c)] }
└─StreamExchange { dist: HashShard(t.a) }
└─StreamProject { exprs: [t.a, t.b, count filter((t.b < 100:Int32)), sum(t.c)] }
└─StreamHashAgg { group_key: [t.a, t.b], aggs: [count, count filter((t.b < 100:Int32)), sum(t.c)] }
└─StreamExchange { dist: HashShard(t.a, t.b) }
└─StreamTableScan { table: t, columns: [t.a, t.b, t.c, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
└─StreamTableScan { table: t, columns: [t.a, t.b, t.c, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
- name: non-distinct agg with filter
sql: |
create table t(a int, b int, c int);
Expand Down
71 changes: 23 additions & 48 deletions src/frontend/planner_test/tests/testdata/nexmark.yaml

Large diffs are not rendered by default.

77 changes: 26 additions & 51 deletions src/frontend/planner_test/tests/testdata/nexmark_source.yaml

Large diffs are not rendered by default.

48 changes: 21 additions & 27 deletions src/frontend/planner_test/tests/testdata/tpch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2691,40 +2691,35 @@
└─BatchScan { table: partsupp, columns: [partsupp.ps_partkey, partsupp.ps_suppkey], distribution: UpstreamHashShard(partsupp.ps_partkey, partsupp.ps_suppkey) }
stream_plan: |
StreamMaterialize { columns: [p_brand, p_type, p_size, supplier_cnt], pk_columns: [p_brand, p_type, p_size], order_descs: [supplier_cnt, p_brand, p_type, p_size], pk_conflict: "no check" }
└─StreamProject { exprs: [part.p_brand, part.p_type, part.p_size, count(partsupp.ps_suppkey)] }
└─StreamHashAgg { group_key: [part.p_brand, part.p_type, part.p_size], aggs: [count, count(partsupp.ps_suppkey)] }
└─StreamProject { exprs: [part.p_brand, part.p_type, part.p_size, count(distinct partsupp.ps_suppkey)] }
└─StreamHashAgg { group_key: [part.p_brand, part.p_type, part.p_size], aggs: [count, count(distinct partsupp.ps_suppkey)] }
└─StreamExchange { dist: HashShard(part.p_brand, part.p_type, part.p_size) }
└─StreamProject { exprs: [part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey] }
└─StreamHashAgg { group_key: [part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey], aggs: [count] }
└─StreamHashJoin { type: LeftAnti, predicate: partsupp.ps_suppkey = supplier.s_suppkey, output: [part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey, partsupp.ps_partkey, part.p_partkey] }
├─StreamExchange { dist: HashShard(partsupp.ps_suppkey) }
| └─StreamHashJoin { type: Inner, predicate: partsupp.ps_partkey = part.p_partkey, output: [partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size, partsupp.ps_partkey, part.p_partkey] }
| ├─StreamExchange { dist: HashShard(partsupp.ps_partkey) }
| | └─StreamTableScan { table: partsupp, columns: [partsupp.ps_partkey, partsupp.ps_suppkey], pk: [partsupp.ps_partkey, partsupp.ps_suppkey], dist: UpstreamHashShard(partsupp.ps_partkey, partsupp.ps_suppkey) }
| └─StreamExchange { dist: HashShard(part.p_partkey) }
| └─StreamFilter { predicate: (part.p_brand <> 'Brand#45':Varchar) AND (Not((part.p_type >= 'SMALL PLATED':Varchar)) OR Not((part.p_type < 'SMALL PLATEE':Varchar))) AND In(part.p_size, 19:Int32, 17:Int32, 16:Int32, 23:Int32, 10:Int32, 4:Int32, 38:Int32, 11:Int32) }
| └─StreamTableScan { table: part, columns: [part.p_partkey, part.p_brand, part.p_type, part.p_size], pk: [part.p_partkey], dist: UpstreamHashShard(part.p_partkey) }
└─StreamExchange { dist: HashShard(supplier.s_suppkey) }
└─StreamProject { exprs: [supplier.s_suppkey] }
└─StreamFilter { predicate: Like(supplier.s_comment, '%Customer%Complaints%':Varchar) }
└─StreamTableScan { table: supplier, columns: [supplier.s_suppkey, supplier.s_comment], pk: [supplier.s_suppkey], dist: UpstreamHashShard(supplier.s_suppkey) }
└─StreamHashJoin { type: LeftAnti, predicate: partsupp.ps_suppkey = supplier.s_suppkey, output: [part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey, partsupp.ps_partkey, part.p_partkey] }
├─StreamExchange { dist: HashShard(partsupp.ps_suppkey) }
| └─StreamHashJoin { type: Inner, predicate: partsupp.ps_partkey = part.p_partkey, output: [partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size, partsupp.ps_partkey, part.p_partkey] }
| ├─StreamExchange { dist: HashShard(partsupp.ps_partkey) }
| | └─StreamTableScan { table: partsupp, columns: [partsupp.ps_partkey, partsupp.ps_suppkey], pk: [partsupp.ps_partkey, partsupp.ps_suppkey], dist: UpstreamHashShard(partsupp.ps_partkey, partsupp.ps_suppkey) }
| └─StreamExchange { dist: HashShard(part.p_partkey) }
| └─StreamFilter { predicate: (part.p_brand <> 'Brand#45':Varchar) AND (Not((part.p_type >= 'SMALL PLATED':Varchar)) OR Not((part.p_type < 'SMALL PLATEE':Varchar))) AND In(part.p_size, 19:Int32, 17:Int32, 16:Int32, 23:Int32, 10:Int32, 4:Int32, 38:Int32, 11:Int32) }
| └─StreamTableScan { table: part, columns: [part.p_partkey, part.p_brand, part.p_type, part.p_size], pk: [part.p_partkey], dist: UpstreamHashShard(part.p_partkey) }
└─StreamExchange { dist: HashShard(supplier.s_suppkey) }
└─StreamProject { exprs: [supplier.s_suppkey] }
└─StreamFilter { predicate: Like(supplier.s_comment, '%Customer%Complaints%':Varchar) }
└─StreamTableScan { table: supplier, columns: [supplier.s_suppkey, supplier.s_comment], pk: [supplier.s_suppkey], dist: UpstreamHashShard(supplier.s_suppkey) }
stream_dist_plan: |
Fragment 0
StreamMaterialize { columns: [p_brand, p_type, p_size, supplier_cnt], pk_columns: [p_brand, p_type, p_size], order_descs: [supplier_cnt, p_brand, p_type, p_size], pk_conflict: "no check" }
materialized table: 4294967294
StreamProject { exprs: [part.p_brand, part.p_type, part.p_size, count(partsupp.ps_suppkey)] }
StreamHashAgg { group_key: [part.p_brand, part.p_type, part.p_size], aggs: [count, count(partsupp.ps_suppkey)] }
StreamProject { exprs: [part.p_brand, part.p_type, part.p_size, count(distinct partsupp.ps_suppkey)] }
StreamHashAgg { group_key: [part.p_brand, part.p_type, part.p_size], aggs: [count, count(distinct partsupp.ps_suppkey)] }
result table: 0, state tables: []
StreamExchange Hash([0, 1, 2]) from 1
Fragment 1
StreamProject { exprs: [part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey] }
StreamHashAgg { group_key: [part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey], aggs: [count] }
result table: 1, state tables: []
StreamHashJoin { type: LeftAnti, predicate: partsupp.ps_suppkey = supplier.s_suppkey, output: [part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey, partsupp.ps_partkey, part.p_partkey] }
left table: 2, right table 4, left degree table: 3, right degree table: 5,
StreamExchange Hash([0]) from 2
StreamExchange Hash([0]) from 5
StreamHashJoin { type: LeftAnti, predicate: partsupp.ps_suppkey = supplier.s_suppkey, output: [part.p_brand, part.p_type, part.p_size, partsupp.ps_suppkey, partsupp.ps_partkey, part.p_partkey] }
left table: 2, right table 4, left degree table: 3, right degree table: 5,
StreamExchange Hash([0]) from 2
StreamExchange Hash([0]) from 5
Fragment 2
StreamHashJoin { type: Inner, predicate: partsupp.ps_partkey = part.p_partkey, output: [partsupp.ps_suppkey, part.p_brand, part.p_type, part.p_size, partsupp.ps_partkey, part.p_partkey] }
Expand All @@ -2750,8 +2745,7 @@
Upstream
BatchPlanNode
Table 0 { columns: [part_p_brand, part_p_type, part_p_size, count, count(partsupp_ps_suppkey)], primary key: [$0 ASC, $1 ASC, $2 ASC], value indices: [3, 4], distribution key: [0, 1, 2] }
Table 1 { columns: [part_p_brand, part_p_type, part_p_size, partsupp_ps_suppkey, count], primary key: [$0 ASC, $1 ASC, $2 ASC, $3 ASC], value indices: [4], distribution key: [3] }
Table 0 { columns: [part_p_brand, part_p_type, part_p_size, count, count(distinct partsupp_ps_suppkey)], primary key: [$0 ASC, $1 ASC, $2 ASC], value indices: [3, 4], distribution key: [0, 1, 2] }
Table 2 { columns: [partsupp_ps_suppkey, part_p_brand, part_p_type, part_p_size, partsupp_ps_partkey, part_p_partkey], primary key: [$0 ASC, $4 ASC, $5 ASC], value indices: [0, 1, 2, 3, 4, 5], distribution key: [0] }
Table 3 { columns: [partsupp_ps_suppkey, partsupp_ps_partkey, part_p_partkey, _degree], primary key: [$0 ASC, $1 ASC, $2 ASC], value indices: [3], distribution key: [0] }
Table 4 { columns: [supplier_s_suppkey], primary key: [$0 ASC], value indices: [0], distribution key: [0] }
Expand Down
5 changes: 4 additions & 1 deletion src/frontend/src/optimizer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,10 @@ impl PlanRoot {
plan = self.optimize_by_rules(
plan,
"Convert Distinct Aggregation".to_string(),
vec![UnionToDistinctRule::create(), DistinctAggRule::create()],
vec![
UnionToDistinctRule::create(),
DistinctAggRule::create(for_stream),
],
ApplyOrder::TopDown,
);

Expand Down
19 changes: 14 additions & 5 deletions src/frontend/src/optimizer/rule/distinct_agg_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,22 @@ use crate::optimizer::PlanRef;
use crate::utils::{ColIndexMapping, Condition};

/// Transform distinct aggregates to `LogicalAgg` -> `LogicalAgg` -> `Expand` -> `Input`.
pub struct DistinctAggRule {}
pub struct DistinctAggRule {
for_stream: bool,
}

impl Rule for DistinctAggRule {
fn apply(&self, plan: PlanRef) -> Option<PlanRef> {
let agg: &LogicalAgg = plan.as_logical_agg()?;
let (mut agg_calls, mut agg_group_keys, input) = agg.clone().decompose();
let original_group_keys_len = agg_group_keys.len();

if self.for_stream && !agg_group_keys.is_empty() {
// Due to performance issue, we don't do 2-phase agg for stream distinct agg with group
// by. See https://github.com/risingwavelabs/risingwave/issues/7271 for more.
return None;
}

let original_group_keys_len = agg_group_keys.len();
let (node, flag_values, has_expand) =
Self::build_expand(input, &mut agg_group_keys, &mut agg_calls)?;
let mid_agg = Self::build_middle_agg(node, agg_group_keys, agg_calls.clone(), has_expand);
Expand All @@ -50,8 +59,8 @@ impl Rule for DistinctAggRule {
}

impl DistinctAggRule {
pub fn create() -> BoxedRule {
Box::new(DistinctAggRule {})
pub fn create(for_stream: bool) -> BoxedRule {
Box::new(DistinctAggRule { for_stream })
}

/// Construct `Expand` for distinct aggregates.
Expand Down Expand Up @@ -110,7 +119,7 @@ impl DistinctAggRule {

let n_different_distinct = distinct_aggs
.iter()
.unique_by(|agg_call| agg_call.input_indices())
.unique_by(|agg_call| agg_call.input_indices()[0])
.count();
assert_ne!(n_different_distinct, 0); // since `distinct_aggs` is not empty here
if n_different_distinct == 1 {
Expand Down
2 changes: 1 addition & 1 deletion src/stream/src/executor/agg_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub struct AggExecutorArgs<S: StateStore> {
pub storages: Vec<AggStateStorage<S>>,
pub result_table: StateTable<S>,
pub distinct_dedup_tables: HashMap<usize, StateTable<S>>,
pub watermark_epoch: AtomicU64Ref,

// extra
pub extra: Option<AggExecutorArgsExtra>,
Expand All @@ -53,5 +54,4 @@ pub struct AggExecutorArgsExtra {
// things only used by hash agg currently
pub metrics: Arc<StreamingMetrics>,
pub chunk_size: usize,
pub watermark_epoch: AtomicU64Ref,
}
22 changes: 2 additions & 20 deletions src/stream/src/executor/aggregation/agg_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::HashMap;
use std::fmt::Debug;

use itertools::Itertools;
Expand All @@ -26,7 +25,7 @@ use risingwave_common::util::iter_util::ZipEqFast;
use risingwave_storage::StateStore;

use super::agg_state::{AggState, AggStateStorage};
use super::{AggCall, DistinctDeduplicater};
use super::AggCall;
use crate::common::table::state_table::StateTable;
use crate::executor::error::StreamExecutorResult;
use crate::executor::PkIndices;
Expand All @@ -39,9 +38,6 @@ pub struct AggGroup<S: StateStore> {
/// Current managed states for all [`AggCall`]s.
states: Vec<AggState<S>>,

/// Distinct deduplicater to deduplicate input rows for each distinct agg call.
distinct_dedup: DistinctDeduplicater<S>,

/// Previous outputs of managed states. Initializing with `None`.
prev_outputs: Option<OwnedRow>,
}
Expand Down Expand Up @@ -102,7 +98,6 @@ impl<S: StateStore> AggGroup<S> {
Ok(Self {
group_key,
states,
distinct_dedup: DistinctDeduplicater::new(agg_calls),
prev_outputs,
})
}
Expand All @@ -127,24 +122,13 @@ impl<S: StateStore> AggGroup<S> {

/// Apply input chunk to all managed agg states.
/// `visibilities` contains the row visibility of the input chunk for each agg call.
pub async fn apply_chunk(
pub fn apply_chunk(
&mut self,
storages: &mut [AggStateStorage<S>],
ops: &[Op],
columns: &[Column],
visibilities: Vec<Option<Bitmap>>,
distinct_dedup_tables: &mut HashMap<usize, StateTable<S>>,
) -> StreamExecutorResult<()> {
let visibilities = self
.distinct_dedup
.dedup_chunk(
ops,
columns,
visibilities,
distinct_dedup_tables,
self.group_key.as_ref(),
)
.await?;
let columns = columns.iter().map(|col| col.array_ref()).collect_vec();
for ((state, storage), visibility) in self
.states
Expand All @@ -163,7 +147,6 @@ impl<S: StateStore> AggGroup<S> {
pub async fn flush_state_if_needed(
&self,
storages: &mut [AggStateStorage<S>],
distinct_dedup_tables: &mut HashMap<usize, StateTable<S>>,
) -> StreamExecutorResult<()> {
futures::future::try_join_all(self.states.iter().zip_eq_fast(storages).filter_map(
|(state, storage)| match state {
Expand All @@ -175,7 +158,6 @@ impl<S: StateStore> AggGroup<S> {
},
))
.await?;
self.distinct_dedup.flush(distinct_dedup_tables)?;
Ok(())
}

Expand Down
Loading

0 comments on commit 3c5bf28

Please sign in to comment.