feat: apply relational refactor for hash agg (max, min) (#2999)
* feat: two closures cannot get a mut ref to the same variable (see the sketch after this list)

* use Arc::Mutex to wrap the state table

* roll back string agg

* add StateTable to get_output

* finish basic coding (unit test failed)

* finish basic coding

* fix bug

* showcase

* use empty Row for scan

* tweak
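
The first bullet names the borrow-checker obstacle that motivated wrapping the state tables in `Arc` (first `Arc<Mutex>`, later `Arc<RwLock>`) in this refactor. A minimal illustration of the problem and the shared-ownership escape hatch — toy code, not taken from the PR:

```rust
use std::sync::{Arc, Mutex};

fn main() {
    let table = vec![1, 2, 3];

    // Rejected by the compiler (E0499) if both closures are alive at once:
    // each one captures `table` by unique (&mut) borrow.
    // let mut write_a = || table.push(4);
    // let mut write_b = || table.push(5);
    // write_a();
    // write_b();

    // Compiles: each closure owns its own clone of the Arc and takes the
    // lock only while it actually mutates the shared value.
    let table = Arc::new(Mutex::new(table));
    let t1 = Arc::clone(&table);
    let t2 = Arc::clone(&table);
    let write_a = move || t1.lock().unwrap().push(4);
    let write_b = move || t2.lock().unwrap().push(5);
    write_a();
    write_b();

    assert_eq!(*table.lock().unwrap(), vec![1, 2, 3, 4, 5]);
}
```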
BowenXiao1999 authored Jun 20, 2022
1 parent 35bb16a commit 9169436
Showing 10 changed files with 473 additions and 203 deletions.
8 changes: 4 additions & 4 deletions src/storage/src/table/state_table.rs
```diff
@@ -201,11 +201,11 @@ impl<S: StateStore> StateTable<S> {
     }
 
     /// This function scans rows from the relational table with specific `pk_prefix`.
-    pub async fn iter_with_pk_prefix<'a>(
-        &'a self,
-        pk_prefix: &'a Row,
+    pub async fn iter_with_pk_prefix(
+        &self,
+        pk_prefix: &Row,
         epoch: u64,
-    ) -> StorageResult<RowStream<'a, S>> {
+    ) -> StorageResult<RowStream<'_, S>> {
         let order_types = &self.pk_serializer.clone().into_order_types()[0..pk_prefix.size()];
         let prefix_serializer = OrderedRowSerializer::new(order_types.into());
         let encoded_prefix = serialize_pk(pk_prefix, &prefix_serializer);
```
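
The change above leans on Rust's lifetime elision: with `&self` among the parameters, the elided `'_` in the return type is tied to `self`, so judging by the signature alone, `pk_prefix` no longer has to outlive the returned stream. A minimal sketch of the two shapes, with a hypothetical `Iter`/`Table` pair standing in for `RowStream`/`StateTable`:

```rust
// Hypothetical stand-ins for RowStream/StateTable, for illustration only.
struct Iter<'a>(&'a [i64]);

struct Table {
    data: Vec<i64>,
}

impl Table {
    // Old shape: a single 'a ties BOTH the prefix and the result to self,
    // so the prefix must live as long as the returned iterator.
    fn iter_explicit<'a>(&'a self, _prefix: &'a [i64]) -> Iter<'a> {
        Iter(&self.data)
    }

    // New shape: by the elision rules, '_ is the lifetime of &self, so the
    // prefix borrow may end as soon as the call returns.
    fn iter_elided(&self, _prefix: &[i64]) -> Iter<'_> {
        Iter(&self.data)
    }
}

fn main() {
    let table = Table { data: vec![1, 2, 3] };
    let iter = {
        let prefix = vec![1];
        // Fine: the result borrows only from `table`. Calling
        // `iter_explicit(&prefix)` here would be rejected, since `prefix`
        // is dropped at the end of this block while `iter` escapes it.
        table.iter_elided(&prefix)
    };
    println!("{:?}", iter.0);
}
```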
37 changes: 27 additions & 10 deletions src/stream/src/executor/aggregation/agg_state.rs
```diff
@@ -17,6 +17,7 @@ use std::fmt::Debug;
 use itertools::Itertools;
 use risingwave_common::array::{ArrayBuilderImpl, Op};
 use risingwave_common::types::Datum;
+use risingwave_storage::table::state_table::StateTable;
 use risingwave_storage::StateStore;
 
 use crate::executor::error::StreamExecutorResult;
@@ -44,9 +45,13 @@ impl<S: StateStore> Debug for AggState<S> {
 pub const ROW_COUNT_COLUMN: usize = 0;
 
 impl<S: StateStore> AggState<S> {
-    pub async fn row_count(&mut self, epoch: u64) -> StreamExecutorResult<i64> {
+    pub async fn row_count(
+        &mut self,
+        epoch: u64,
+        state_table: &StateTable<S>,
+    ) -> StreamExecutorResult<i64> {
         Ok(self.managed_states[ROW_COUNT_COLUMN]
-            .get_output(epoch)
+            .get_output(epoch, state_table)
             .await?
             .map(|x| *x.as_int64())
             .unwrap_or(0))
@@ -71,14 +76,18 @@ impl<S: StateStore> AggState<S> {
     /// changes to the state. If the state is already marked dirty in this epoch, this function does
     /// no-op.
     /// After calling this function, `self.is_dirty()` will return `true`.
-    pub async fn may_mark_as_dirty(&mut self, epoch: u64) -> StreamExecutorResult<()> {
+    pub async fn may_mark_as_dirty(
+        &mut self,
+        epoch: u64,
+        state_tables: &[StateTable<S>],
+    ) -> StreamExecutorResult<()> {
         if self.is_dirty() {
             return Ok(());
         }
 
         let mut outputs = vec![];
-        for state in &mut self.managed_states {
-            outputs.push(state.get_output(epoch).await?);
+        for (state, state_table) in self.managed_states.iter_mut().zip_eq(state_tables.iter()) {
+            outputs.push(state.get_output(epoch, state_table).await?);
         }
         self.prev_states = Some(outputs);
         Ok(())
@@ -93,12 +102,15 @@ impl<S: StateStore> AggState<S> {
         builders: &mut [ArrayBuilderImpl],
         new_ops: &mut Vec<Op>,
         epoch: u64,
+        state_tables: &[StateTable<S>],
     ) -> StreamExecutorResult<usize> {
         if !self.is_dirty() {
             return Ok(0);
         }
 
-        let row_count = self.row_count(epoch).await?;
+        let row_count = self
+            .row_count(epoch, &state_tables[ROW_COUNT_COLUMN])
+            .await?;
         let prev_row_count = self.prev_row_count();
 
         trace!(
@@ -120,8 +132,12 @@ impl<S: StateStore> AggState<S> {
             // previous state is empty, current state is not empty, insert one `Insert` op.
             new_ops.push(Op::Insert);
 
-            for (builder, state) in builders.iter_mut().zip_eq(self.managed_states.iter_mut()) {
-                let data = state.get_output(epoch).await?;
+            for ((builder, state), state_table) in builders
+                .iter_mut()
+                .zip_eq(self.managed_states.iter_mut())
+                .zip_eq(state_tables.iter())
+            {
+                let data = state.get_output(epoch, state_table).await?;
                 trace!("append_datum (0 -> N): {:?}", &data);
                 builder.append_datum(&data)?;
             }
@@ -149,12 +165,13 @@ impl<S: StateStore> AggState<S> {
             new_ops.push(Op::UpdateDelete);
             new_ops.push(Op::UpdateInsert);
 
-            for (builder, prev_state, cur_state) in itertools::multizip((
+            for (builder, prev_state, cur_state, state_table) in itertools::multizip((
                 builders.iter_mut(),
                 self.prev_states.as_ref().unwrap().iter(),
                 self.managed_states.iter_mut(),
+                state_tables.iter(),
             )) {
-                let cur_state = cur_state.get_output(epoch).await?;
+                let cur_state = cur_state.get_output(epoch, state_table).await?;
                 trace!(
                     "append_datum (N -> N): prev = {:?}, cur = {:?}",
                     prev_state,
```
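
A pattern worth noting in these hunks: the per-call `StateTable` is threaded alongside each managed state by zipping parallel slices, using `Itertools::zip_eq` (which panics on a length mismatch, turning a mis-wired aggregation into a loud failure) and `itertools::multizip` for three or more sides. A toy sketch of the pairing, with hypothetical `ManagedState`/`StateTable` stand-ins:

```rust
use itertools::Itertools;

// Toy stand-ins for the executor's managed state and its backing table.
struct ManagedState(i64);
struct StateTable(&'static str);

impl ManagedState {
    fn get_output(&mut self, table: &StateTable) -> i64 {
        println!("reading state {} via table {}", self.0, table.0);
        self.0
    }
}

fn main() {
    let mut states = vec![ManagedState(1), ManagedState(2)];
    let tables = vec![StateTable("t0"), StateTable("t1")];

    // One state table per aggregation state, in the same order; zip_eq
    // panics instead of silently truncating when the lengths differ.
    for (state, table) in states.iter_mut().zip_eq(tables.iter()) {
        state.get_output(table);
    }

    // multizip generalizes the pairing to three (or more) parallel iterators.
    for (idx, state, table) in itertools::multizip((0.., states.iter_mut(), tables.iter())) {
        println!("column {}: {}", idx, state.get_output(table));
    }
}
```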
2 changes: 1 addition & 1 deletion src/stream/src/executor/aggregation/mod.rs
```diff
@@ -470,7 +470,7 @@ pub async fn generate_managed_agg_state<S: StateStore>(
 
         if idx == ROW_COUNT_COLUMN {
             // For the rowcount state, we should record the rowcount.
-            let output = managed_state.get_output(epoch).await?;
+            let output = managed_state.get_output(epoch, &state_tables[idx]).await?;
             row_count = Some(output.as_ref().map(|x| *x.as_int64() as usize).unwrap_or(0));
         }
 
```
19 changes: 12 additions & 7 deletions src/stream/src/executor/global_simple_agg.rs
```diff
@@ -110,7 +110,7 @@ impl<S: StateStore> SimpleAggExecutor<S> {
                 ks.clone(),
                 agg_call,
                 &key_indices,
-                &pk_indices,
+                &input_info.pk_indices,
                 &schema,
                 input.as_ref(),
             ));
@@ -142,7 +142,7 @@ impl<S: StateStore> SimpleAggExecutor<S> {
         keyspace: &[Keyspace<S>],
         chunk: StreamChunk,
         epoch: u64,
-        state_tables: &[StateTable<S>],
+        state_tables: &mut [StateTable<S>],
     ) -> StreamExecutorResult<()> {
         let (ops, columns, visibility) = chunk.into_inner();
 
@@ -181,12 +181,17 @@ impl<S: StateStore> SimpleAggExecutor<S> {
         let states = states.as_mut().unwrap();
 
         // 2. Mark the state as dirty by filling prev states
-        states.may_mark_as_dirty(epoch).await?;
+        states.may_mark_as_dirty(epoch, state_tables).await?;
 
         // 3. Apply batch to each of the state (per agg_call)
-        for (agg_state, data) in states.managed_states.iter_mut().zip_eq(all_agg_data.iter()) {
+        for ((agg_state, data), state_table) in states
+            .managed_states
+            .iter_mut()
+            .zip_eq(all_agg_data.iter())
+            .zip_eq(state_tables.iter_mut())
+        {
             agg_state
-                .apply_batch(&ops, visibility.as_ref(), data, epoch)
+                .apply_batch(&ops, visibility.as_ref(), data, epoch, state_table)
                 .await?;
         }
 
@@ -234,7 +239,7 @@ impl<S: StateStore> SimpleAggExecutor<S> {
 
         // --- Retrieve modified states and put the changes into the builders ---
         states
-            .build_changes(&mut builders, &mut new_ops, epoch)
+            .build_changes(&mut builders, &mut new_ops, epoch, state_tables)
             .await?;
 
         let columns: Vec<Column> = builders
@@ -279,7 +284,7 @@ impl<S: StateStore> SimpleAggExecutor<S> {
                         &keyspace,
                         chunk,
                         epoch,
-                        &state_tables,
+                        &mut state_tables,
                     )
                     .await?;
                 }
```
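
In the simple (non-hash) aggregator the executor owns exactly one set of states, so no lock is needed: the tables are threaded down as a plain `&mut [StateTable<S>]`, and `iter_mut()` hands each aggregation state a disjoint exclusive borrow of its own table. A toy sketch of that ownership shape (the types are illustrative, not the RisingWave API):

```rust
// Toy stand-ins: each AggState writes to exactly one StateTable.
struct StateTable {
    rows: Vec<i64>,
}

struct AggState {
    delta: i64,
}

impl AggState {
    // Applying a batch now mutates the state's backing table as well.
    fn apply_batch(&mut self, table: &mut StateTable) {
        table.rows.push(self.delta);
    }
}

// `iter_mut()` splits one slice borrow into disjoint `&mut` elements, so
// every state gets exclusive access to its own table without conflicts.
fn apply_all(states: &mut [AggState], tables: &mut [StateTable]) {
    assert_eq!(states.len(), tables.len());
    for (state, table) in states.iter_mut().zip(tables.iter_mut()) {
        state.apply_batch(table);
    }
}

fn main() {
    let mut states = vec![AggState { delta: 1 }, AggState { delta: 2 }];
    let mut tables = vec![StateTable { rows: vec![] }, StateTable { rows: vec![] }];
    apply_all(&mut states, &mut tables);
    assert_eq!(tables[1].rows, vec![2]);
}
```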
56 changes: 33 additions & 23 deletions src/stream/src/executor/hash_agg.rs
```diff
@@ -30,6 +30,7 @@ use risingwave_common::hash::{HashCode, HashKey};
 use risingwave_common::util::hash_util::CRC32FastBuilder;
 use risingwave_storage::table::state_table::StateTable;
 use risingwave_storage::{Keyspace, StateStore};
+use tokio::sync::RwLock;
 
 use super::{
     expect_first_barrier, pk_input_arrays, Executor, PkDataTypes, PkIndicesRef,
@@ -86,7 +87,7 @@ struct HashAggExecutorExtra<S: StateStore> {
     /// all of the aggregation functions in this executor should depend on same group of keys
     key_indices: Vec<usize>,
 
-    state_tables: Vec<StateTable<S>>,
+    state_tables: Arc<RwLock<Vec<StateTable<S>>>>,
 }
 
 impl<K: HashKey, S: StateStore> Executor for HashAggExecutor<K, S> {
@@ -125,7 +126,7 @@ impl<K: HashKey, S: StateStore> HashAggExecutor<K, S> {
                 ks.clone(),
                 agg_call,
                 &key_indices,
-                &pk_indices,
+                &input_info.pk_indices,
                 &schema,
                 input.as_ref(),
             ));
@@ -142,7 +143,7 @@ impl<K: HashKey, S: StateStore> HashAggExecutor<K, S> {
                 keyspace,
                 agg_calls,
                 key_indices,
-                state_tables,
+                state_tables: Arc::new(RwLock::new(state_tables)),
             },
             _phantom: PhantomData,
         })
@@ -200,16 +201,16 @@ impl<K: HashKey, S: StateStore> HashAggExecutor<K, S> {
     }
 
     async fn apply_chunk(
-        &HashAggExecutorExtra::<S> {
+        &mut HashAggExecutorExtra::<S> {
             ref key_indices,
             ref agg_calls,
             ref input_pk_indices,
             ref input_schema,
             ref keyspace,
             ref schema,
-            ref state_tables,
+            ref mut state_tables,
             ..
-        }: &HashAggExecutorExtra<S>,
+        }: &mut HashAggExecutorExtra<S>,
         state_map: &mut EvictableHashMap<K, Option<Box<AggState<S>>>>,
         chunk: StreamChunk,
         epoch: u64,
@@ -279,23 +280,29 @@ impl<K: HashKey, S: StateStore> HashAggExecutor<K, S> {
                         input_pk_data_types.clone(),
                         epoch,
                         Some(hash_code),
-                        state_tables,
+                        &*state_tables.read().await,
                     )
                     .await?,
                 ),
             }
         };
 
         // 2. Mark the state as dirty by filling prev states
-        states.may_mark_as_dirty(epoch).await?;
+        states
+            .may_mark_as_dirty(epoch, &*state_tables.read().await)
+            .await?;
 
+        let mut state_tables = state_tables.write().await;
         // 3. Apply batch to each of the state (per agg_call)
-        for (agg_state, data) in
-            states.managed_states.iter_mut().zip_eq(all_agg_data.iter())
+        for ((agg_state, data), state_table) in states
+            .managed_states
+            .iter_mut()
+            .zip_eq(all_agg_data.iter())
+            .zip_eq(state_tables.iter_mut())
         {
             let data = data.iter().map(|d| &**d).collect_vec();
             agg_state
-                .apply_batch(&ops, Some(&vis_map), &data, epoch)
+                .apply_batch(&ops, Some(&vis_map), &data, epoch, state_table)
                 .await?;
         }
 
@@ -329,10 +336,10 @@ impl<K: HashKey, S: StateStore> HashAggExecutor<K, S> {
         // --- Flush states to the state store ---
         // Some state will have the correct output only after their internal states have been
         // fully flushed.
-        let (write_batch, dirty_cnt) = {
+        let dirty_cnt = {
             let mut write_batch = store.start_write_batch();
             let mut dirty_cnt = 0;
-
+            let mut state_tables = state_tables.write().await;
             for states in state_map.values_mut() {
                 if states.as_ref().unwrap().is_dirty() {
                     dirty_cnt += 1;
@@ -348,21 +355,19 @@ impl<K: HashKey, S: StateStore> HashAggExecutor<K, S> {
                 }
             }
 
-            // Batch commit state table.
-            for state_table in state_tables.iter_mut() {
-                state_table.commit(epoch).await?;
-            }
-
-            (write_batch, dirty_cnt)
+            dirty_cnt
         };
 
         if dirty_cnt == 0 {
             // Nothing to flush.
-            assert!(write_batch.is_empty());
             return Ok(());
         } else {
-            write_batch.ingest(epoch).await?;
+            // Batch commit data.
+            for state_table in state_tables.write().await.iter_mut() {
+                state_table.commit(epoch).await?;
+            }
 
+            let state_tables = state_tables.read().await;
             // --- Produce the stream chunk ---
             let mut batches = IterChunks::chunks(state_map.iter_mut(), PROCESSING_WINDOW_SIZE);
             while let Some(batch) = batches.next() {
@@ -379,7 +384,12 @@ impl<K: HashKey, S: StateStore> HashAggExecutor<K, S> {
                     let appended = states
                         .as_mut()
                         .unwrap()
-                        .build_changes(&mut builders[key_indices.len()..], &mut new_ops, epoch)
+                        .build_changes(
+                            &mut builders[key_indices.len()..],
+                            &mut new_ops,
+                            epoch,
+                            &state_tables,
+                        )
                         .await?;
 
                     for _ in 0..appended {
@@ -431,7 +441,7 @@ impl<K: HashKey, S: StateStore> HashAggExecutor<K, S> {
                 let msg = msg?;
                 match msg {
                     Message::Chunk(chunk) => {
-                        Self::apply_chunk(&extra, &mut state_map, chunk, epoch).await?;
+                        Self::apply_chunk(&mut extra, &mut state_map, chunk, epoch).await?;
                     }
                     Message::Barrier(barrier) => {
                         let next_epoch = barrier.epoch.curr;
```
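
Unlike the simple aggregator, the hash aggregator cannot conveniently thread a `&mut` to the table vector through every call path (the "two closures cannot get a mut ref to the same variable" problem from the commit message), so it wraps the tables in `Arc<tokio::sync::RwLock<Vec<StateTable<S>>>>`: output-building paths take a shared `read()` guard, while apply-batch and commit paths take an exclusive `write()` guard. A minimal sketch of that locking discipline, with a toy table type in place of the RisingWave one:

```rust
use std::sync::Arc;

use tokio::sync::RwLock;

// Toy stand-in for the executor's StateTable.
#[derive(Debug, Default)]
struct StateTable {
    rows: Vec<i64>,
}

#[tokio::main]
async fn main() {
    let tables = Arc::new(RwLock::new(vec![StateTable::default(), StateTable::default()]));

    // Exclusive access for the "apply batch / commit" phase.
    {
        let mut guard = tables.write().await;
        for (i, table) in guard.iter_mut().enumerate() {
            table.rows.push(i as i64);
        }
    } // The write guard must be dropped before taking a read guard below,
      // or this task would block on itself.

    // Shared access for the "produce output" phase: any number of readers
    // may hold the lock concurrently.
    let guard = tables.read().await;
    println!("{:?}", &*guard);
}
```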
(Diffs for the remaining 5 changed files are not shown.)
