impl emit_partitioned.
Rachelint committed Sep 18, 2024
1 parent f24f3c3 commit bec7a3a
Showing 1 changed file with 40 additions and 15 deletions.
55 changes: 40 additions & 15 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -39,7 +39,9 @@ use arrow::array::*;
 use arrow::datatypes::SchemaRef;
 use arrow_schema::SortOptions;
 use datafusion_common::utils::get_arrayref_at_indices;
-use datafusion_common::{internal_datafusion_err, DataFusionError, Result};
+use datafusion_common::{
+    arrow_datafusion_err, internal_datafusion_err, DataFusionError, Result,
+};
 use datafusion_execution::disk_manager::RefCountedTempFile;
 use datafusion_execution::memory_pool::proxy::VecAllocExt;
 use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
@@ -868,7 +870,7 @@ impl GroupedHashAggregateStream {
                 get_filter_at_indices(opt_filter, &batch_indices)
             })
             .collect::<Result<Vec<_>>>()?;
-
+        // Update the accumulators of each partition
        for (part_idx, part_start_end) in offsets.windows(2).enumerate() {
            let (offset, length) =
@@ -886,7 +888,8 @@
                 .map(|array| array.slice(offset, length))
                 .collect::<Vec<_>>();
 
-            let part_opt_filter = opt_filter.as_ref().map(|f| f.slice(offset, length));
+            let part_opt_filter =
+                opt_filter.as_ref().map(|f| f.slice(offset, length));
             let part_opt_filter =
                 part_opt_filter.as_ref().map(|filter| filter.as_boolean());
 
@@ -980,21 +983,29 @@
 
     /// Create an output RecordBatch with the group keys and
     /// accumulator states/values specified in emit_to
-    fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<RecordBatch> {
-        self.emit_single(emit_to, spilling)
+    fn emit(&mut self, emit_to: EmitTo, spilling: bool) -> Result<Vec<RecordBatch>> {
+        if !self.group_values.is_partitioned() {
+            self.emit_single(emit_to, spilling)
+        } else {
+            self.emit_partitioned(emit_to)
+        }
     }
 
     /// Create an output RecordBatch with the group keys and
     /// accumulator states/values specified in emit_to
-    fn emit_single(&mut self, emit_to: EmitTo, spilling: bool) -> Result<RecordBatch> {
+    fn emit_single(
+        &mut self,
+        emit_to: EmitTo,
+        spilling: bool,
+    ) -> Result<Vec<RecordBatch>> {
         let schema = if spilling {
             Arc::clone(&self.spill_state.spill_schema)
         } else {
             self.schema()
         };
 
         if self.group_values.is_empty() {
-            return Ok(RecordBatch::new_empty(schema));
+            return Ok(Vec::new());
         }
 
         let group_values = self.group_values.as_single_mut();
@@ -1023,10 +1034,11 @@
         // over the target memory size after emission, we can emit again rather than returning Err.
         let _ = self.update_memory_reservation();
         let batch = RecordBatch::try_new(schema, output)?;
-        Ok(batch)
+
+        Ok(vec![batch])
     }
 
-    fn emit_partitioned(&mut self, emit_to: EmitTo, spilling: bool) -> Result<Vec<RecordBatch>> {
+    fn emit_partitioned(&mut self, emit_to: EmitTo) -> Result<Vec<RecordBatch>> {
         assert!(
             self.mode == AggregateMode::Partial
                 && matches!(self.group_ordering, GroupOrdering::None)
@@ -1035,22 +1047,35 @@
         let schema = self.schema();
 
         if self.group_values.is_empty() {
-            return Ok(RecordBatch::new_empty(schema));
+            return Ok(Vec::new());
         }
 
         let group_values = self.group_values.as_partitioned_mut();
-        let mut output = group_values.emit(emit_to)?;
+        let mut partitioned_outputs = group_values.emit(emit_to)?;
 
         // Next output each aggregate value
-        for acc in self.accumulators[0].iter_mut() {
-            output.extend(acc.state(emit_to)?)
+        for (output, accs) in partitioned_outputs
+            .iter_mut()
+            .zip(self.accumulators.iter_mut())
+        {
+            for acc in accs.iter_mut() {
+                output.extend(acc.state(emit_to)?);
+            }
         }
 
         // emit reduces the memory usage. Ignore Err from update_memory_reservation. Even if it is
         // over the target memory size after emission, we can emit again rather than returning Err.
         let _ = self.update_memory_reservation();
-        let batch = RecordBatch::try_new(schema, output)?;
-        Ok(batch)
+
+        let batch_parts = partitioned_outputs
+            .into_iter()
+            .map(|part| {
+                RecordBatch::try_new(schema.clone(), part)
+                    .map_err(|e| arrow_datafusion_err!(e))
+            })
+            .collect::<Result<Vec<_>>>()?;
+
+        Ok(batch_parts)
     }
 
     /// Optimistically, [`Self::group_aggregate_batch`] allows to exceed the memory target slightly
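In miniature, the pattern the new emit_partitioned follows: each hash partition owns its group-key columns and its own accumulator set, emission zips the two so every partition's columns are extended with its accumulators' state, and each partition's column set then becomes one batch. Below is a std-only sketch of that zip-and-extend pattern with hypothetical stand-in types (Acc, plain Vec<i64> columns) rather than the DataFusion API:

    /// Stand-in for an aggregate accumulator: its state is a set of columns.
    /// (Hypothetical type for illustration, not from the commit.)
    struct Acc {
        state: Vec<Vec<i64>>,
    }

    impl Acc {
        fn state(&self) -> Vec<Vec<i64>> {
            self.state.clone()
        }
    }

    /// One "batch" per partition: group-key columns first, then each
    /// accumulator's state columns, mirroring the zip in the diff above.
    fn emit_partitioned(
        mut partitioned_outputs: Vec<Vec<Vec<i64>>>, // per-partition group columns
        accumulators: &mut [Vec<Acc>],               // per-partition accumulators
    ) -> Vec<Vec<Vec<i64>>> {
        for (output, accs) in partitioned_outputs.iter_mut().zip(accumulators.iter_mut()) {
            for acc in accs.iter_mut() {
                output.extend(acc.state());
            }
        }
        partitioned_outputs
    }

    fn main() {
        // Two partitions: one with group keys [1, 2], one with group key [3].
        let parts = vec![vec![vec![1, 2]], vec![vec![3]]];
        let mut accs = vec![
            vec![Acc { state: vec![vec![10, 20]] }],
            vec![Acc { state: vec![vec![30]] }],
        ];
        for (i, batch) in emit_partitioned(parts, &mut accs).into_iter().enumerate() {
            println!("partition {i}: {batch:?}");
        }
    }

Returning a Vec of batches (empty when there are no groups) lets the existing single-partition path become the one-element case, which is why emit_single now returns Ok(vec![batch]) and both empty checks return Ok(Vec::new()).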
