diff --git a/datafusion/physical-plan/src/aggregates/group_values/mod.rs b/datafusion/physical-plan/src/aggregates/group_values/mod.rs
index eca53c889b06..866349ab3f79 100644
--- a/datafusion/physical-plan/src/aggregates/group_values/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/group_values/mod.rs
@@ -162,9 +162,16 @@ pub trait GroupValues: Send {
     fn clear_shrink(&mut self, batch: &RecordBatch);
 }
 
-pub fn new_group_values(schema: SchemaRef, partitioning_group_values: bool, num_partitions: usize) -> Result<GroupValuesLike> {
+pub fn new_group_values(
+    schema: SchemaRef,
+    partitioning_group_values: bool,
+    num_partitions: usize,
+) -> Result<GroupValuesLike> {
     let group_values = if partitioning_group_values && schema.fields.len() > 1 {
-        GroupValuesLike::Partitioned(Box::new(PartitionedGroupValuesRows::try_new(schema, num_partitions)?))
+        GroupValuesLike::Partitioned(Box::new(PartitionedGroupValuesRows::try_new(
+            schema,
+            num_partitions,
+        )?))
     } else {
         GroupValuesLike::Single(new_single_group_values(schema)?)
     };
diff --git a/datafusion/physical-plan/src/aggregates/row_hash.rs b/datafusion/physical-plan/src/aggregates/row_hash.rs
index cdd6ac8e59e3..e9362fdd05e4 100644
--- a/datafusion/physical-plan/src/aggregates/row_hash.rs
+++ b/datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -64,6 +64,8 @@ pub(crate) enum ExecutionState {
     /// When producing output, the remaining rows to output are stored
     /// here and are sliced off as needed in batch_size chunks
     ProducingOutput(RecordBatch),
+
+    ProducingPartitionedOutput(Vec<Option<RecordBatch>>),
     /// Produce intermediate aggregate state for each input row without
     /// aggregation.
     ///
@@ -76,6 +78,29 @@
 use super::order::GroupOrdering;
 use super::AggregateExec;
 
+struct PartitionedOutput {
+    batches: Vec<Option<RecordBatch>>,
+    current_idx: usize,
+    exhausted: bool
+}
+
+impl PartitionedOutput {
+    pub fn new(batches: Vec<RecordBatch>) -> Self {
+        let batches = batches.into_iter().map(|batch| Some(batch)).collect();
+
+        Self {
+            batches,
+            current_idx: 0,
+            exhausted: false,
+        }
+    }
+
+    pub fn next_batch(&mut self) -> Option<RecordBatch> {
+
+    }
+}
+
+
 /// This encapsulates the spilling state
 struct SpillState {
     // ========================================================================
@@ -677,7 +702,9 @@ impl Stream for GroupedHashAggregateStream {
                    }
 
                    if let Some(to_emit) = self.group_ordering.emit_to() {
-                        let batch = extract_ok!(self.emit(to_emit, false));
+                        let mut batch = extract_ok!(self.emit(to_emit, false));
+                        assert_eq!(batch.len(), 1);
+                        let batch = batch.pop().unwrap();
                         self.exec_state = ExecutionState::ProducingOutput(batch);
                         timer.done();
                         // make sure the exec_state just set is not overwritten below
@@ -759,6 +786,8 @@
                     let _ = self.update_memory_reservation();
                     return Poll::Ready(None);
                 }
+
+                ExecutionState::ProducingPartitionedOutput(_) => todo!(),
             }
         }
     }
@@ -1101,7 +1130,9 @@ impl GroupedHashAggregateStream {
 
     /// Emit all rows, sort them, and store them on disk.
     fn spill(&mut self) -> Result<()> {
-        let emit = self.emit(EmitTo::All, true)?;
+        let mut emit = self.emit(EmitTo::All, true)?;
+        assert_eq!(emit.len(), 1);
+        let emit = emit.pop().unwrap();
         let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?;
         let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?;
         // TODO: slice large `sorted` and write to multiple files in parallel
@@ -1138,9 +1169,16 @@
             && matches!(self.mode, AggregateMode::Partial)
             && self.update_memory_reservation().is_err()
         {
-            let n = self.group_values.len() / self.batch_size * self.batch_size;
-            let batch = self.emit(EmitTo::First(n), false)?;
-            self.exec_state = ExecutionState::ProducingOutput(batch);
+            if !self.group_values.is_partitioned() {
+                let n = self.group_values.len() / self.batch_size * self.batch_size;
+                let mut batch = self.emit(EmitTo::First(n), false)?;
+                let batch = batch.pop().unwrap();
+                self.exec_state = ExecutionState::ProducingOutput(batch);
+            } else {
+                let batches = self.emit(EmitTo::All, false)?;
+                let batches = batches.into_iter().map(|batch| Some(batch)).collect();
+                self.exec_state = ExecutionState::ProducingPartitionedOutput(batches);
+            }
         }
         Ok(())
     }
@@ -1150,7 +1188,9 @@ impl GroupedHashAggregateStream {
     /// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully
     /// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`].
     fn update_merged_stream(&mut self) -> Result<()> {
-        let batch = self.emit(EmitTo::All, true)?;
+        let mut batch = self.emit(EmitTo::All, true)?;
+        assert_eq!(batch.len(), 1);
+        let batch = batch.pop().unwrap();
         // clear up memory for streaming_merge
         self.clear_all();
         self.update_memory_reservation()?;
@@ -1198,8 +1238,15 @@
         let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
         let timer = elapsed_compute.timer();
         self.exec_state = if self.spill_state.spills.is_empty() {
-            let batch = self.emit(EmitTo::All, false)?;
-            ExecutionState::ProducingOutput(batch)
+            if !self.group_values.is_partitioned() {
+                let mut batch = self.emit(EmitTo::All, false)?;
+                let batch = batch.pop().unwrap();
+                ExecutionState::ProducingOutput(batch)
+            } else {
+                let batches = self.emit(EmitTo::All, false)?;
+                let batches = batches.into_iter().map(|batch| Some(batch)).collect();
+                ExecutionState::ProducingPartitionedOutput(batches)
+            }
         } else {
             // If spill files exist, stream-merge them.
             self.update_merged_stream()?;
@@ -1231,8 +1278,15 @@
     fn switch_to_skip_aggregation(&mut self) -> Result<()> {
         if let Some(probe) = self.skip_aggregation_probe.as_mut() {
             if probe.should_skip() {
-                let batch = self.emit(EmitTo::All, false)?;
-                self.exec_state = ExecutionState::ProducingOutput(batch);
+                if !self.group_values.is_partitioned() {
+                    let mut batch = self.emit(EmitTo::All, false)?;
+                    let batch = batch.pop().unwrap();
+                    self.exec_state = ExecutionState::ProducingOutput(batch);
+                } else {
+                    let batches = self.emit(EmitTo::All, false)?;
+                    let batches = batches.into_iter().map(|batch| Some(batch)).collect();
+                    self.exec_state = ExecutionState::ProducingPartitionedOutput(batches);
+                }
             }
         }
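Note that `PartitionedOutput::next_batch` is still left with an empty body in this draft, and the `ProducingPartitionedOutput` arm of `poll_next` is a `todo!()`. Purely as an illustration (not part of the patch), a minimal sketch of the draining logic the struct's fields suggest, assuming each stored batch is handed out exactly once, front to back, before the output is marked exhausted:

```rust
use arrow::record_batch::RecordBatch;

struct PartitionedOutput {
    batches: Vec<Option<RecordBatch>>,
    current_idx: usize,
    exhausted: bool,
}

impl PartitionedOutput {
    /// Hypothetical draining logic (assumed, not from the patch): hand out
    /// each stored batch once, front to back, then report exhaustion.
    pub fn next_batch(&mut self) -> Option<RecordBatch> {
        if self.exhausted {
            return None;
        }
        while self.current_idx < self.batches.len() {
            // `take` leaves `None` behind so a batch is returned only once.
            if let Some(batch) = self.batches[self.current_idx].take() {
                self.current_idx += 1;
                return Some(batch);
            }
            self.current_idx += 1;
        }
        self.exhausted = true;
        None
    }
}
```

Wrapping each batch in `Option` lets a batch be moved out with `take()` without shifting or cloning the rest of the vector, which is presumably why the `new` constructor in the patch maps every batch into `Some(batch)`.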