Skip to content

Commit 6786203

Browse files
yjshenalamb
andauthored
Grouped Aggregate in row format (#2375)
* first move: re-group aggregates functionalities in core/physical_p/aggregates * basic accumulators * main updating procedure * output as record batch * aggregate with row state * make row non-optional * address comments, add docs, part fix #2455 * Apply suggestions from code review Co-authored-by: Andrew Lamb <[email protected]> Co-authored-by: Andrew Lamb <[email protected]>
1 parent 32cf354 commit 6786203

21 files changed

+1581
-50
lines changed

datafusion/core/Cargo.toml

+2-5
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,6 @@ force_hash_collisions = []
4848
jit = ["datafusion-jit"]
4949
pyarrow = ["pyo3", "arrow/pyarrow", "datafusion-common/pyarrow"]
5050
regex_expressions = ["datafusion-physical-expr/regex_expressions"]
51-
# Used to enable row format experiment
52-
row = ["datafusion-row"]
5351
# Used to enable scheduler
5452
scheduler = ["rayon"]
5553
simd = ["arrow/simd"]
@@ -66,7 +64,7 @@ datafusion-data-access = { path = "../../data-access", version = "1.0.0" }
6664
datafusion-expr = { path = "../expr", version = "7.0.0" }
6765
datafusion-jit = { path = "../jit", version = "7.0.0", optional = true }
6866
datafusion-physical-expr = { path = "../physical-expr", version = "7.0.0" }
69-
datafusion-row = { path = "../row", version = "7.0.0", optional = true }
67+
datafusion-row = { path = "../row", version = "7.0.0" }
7068
futures = "0.3"
7169
hashbrown = { version = "0.12", features = ["raw"] }
7270
lazy_static = { version = "^1.4.0" }
@@ -134,8 +132,7 @@ name = "sql_planner"
134132
[[bench]]
135133
harness = false
136134
name = "jit"
137-
required-features = ["row", "jit"]
135+
required-features = ["jit"]
138136

139137
[[test]]
140138
name = "row"
141-
required-features = ["row"]

datafusion/core/benches/aggregate_query_sql.rs

+10
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,16 @@ fn criterion_benchmark(c: &mut Criterion) {
133133
)
134134
})
135135
});
136+
137+
c.bench_function("aggregate_query_group_by_u64_multiple_keys", |b| {
138+
b.iter(|| {
139+
query(
140+
ctx.clone(),
141+
"SELECT u64_wide, utf8, MIN(f64), AVG(f64), COUNT(f64) \
142+
FROM t GROUP BY u64_wide, utf8",
143+
)
144+
})
145+
});
136146
}
137147

138148
criterion_group!(benches, criterion_benchmark);

datafusion/core/src/lib.rs

-1
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,6 @@ pub use datafusion_data_access;
233233
pub use datafusion_expr as logical_expr;
234234
pub use datafusion_physical_expr as physical_expr;
235235

236-
#[cfg(feature = "row")]
237236
pub use datafusion_row as row;
238237

239238
#[cfg(feature = "jit")]

datafusion/core/src/physical_plan/aggregates/hash.rs

+3-22
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@ use futures::{
2828
};
2929

3030
use crate::error::Result;
31-
use crate::physical_plan::aggregates::{AccumulatorItem, AggregateMode};
31+
use crate::physical_plan::aggregates::{
32+
evaluate, evaluate_many, AccumulatorItem, AggregateMode,
33+
};
3234
use crate::physical_plan::hash_utils::create_hashes;
3335
use crate::physical_plan::metrics::{BaselineMetrics, RecordOutput};
3436
use crate::physical_plan::{aggregates, AggregateExpr, PhysicalExpr};
@@ -380,27 +382,6 @@ impl std::fmt::Debug for Accumulators {
380382
}
381383
}
382384

383-
/// Evaluates expressions against a record batch.
384-
fn evaluate(
385-
expr: &[Arc<dyn PhysicalExpr>],
386-
batch: &RecordBatch,
387-
) -> Result<Vec<ArrayRef>> {
388-
expr.iter()
389-
.map(|expr| expr.evaluate(batch))
390-
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
391-
.collect::<Result<Vec<_>>>()
392-
}
393-
394-
/// Evaluates expressions against a record batch.
395-
fn evaluate_many(
396-
expr: &[Vec<Arc<dyn PhysicalExpr>>],
397-
batch: &RecordBatch,
398-
) -> Result<Vec<Vec<ArrayRef>>> {
399-
expr.iter()
400-
.map(|expr| evaluate(expr, batch))
401-
.collect::<Result<Vec<_>>>()
402-
}
403-
404385
/// Create a RecordBatch with all group keys and accumulator' states or values.
405386
fn create_batch_from_map(
406387
mode: &AggregateMode,

datafusion/core/src/physical_plan/aggregates/mod.rs

+67
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ use crate::physical_plan::{
2929
};
3030
use arrow::array::ArrayRef;
3131
use arrow::datatypes::{Field, Schema, SchemaRef};
32+
use arrow::record_batch::RecordBatch;
3233
use datafusion_common::Result;
3334
use datafusion_expr::Accumulator;
3435
use datafusion_physical_expr::expressions::Column;
@@ -40,9 +41,13 @@ use std::sync::Arc;
4041

4142
mod hash;
4243
mod no_grouping;
44+
mod row_hash;
4345

46+
use crate::physical_plan::aggregates::row_hash::GroupedHashAggregateStreamV2;
4447
pub use datafusion_expr::AggregateFunction;
48+
use datafusion_physical_expr::aggregate::row_accumulator::RowAccumulator;
4549
pub use datafusion_physical_expr::expressions::create_aggregate_expr;
50+
use datafusion_row::{row_supported, RowType};
4651

4752
/// Hash aggregate modes
4853
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
@@ -142,6 +147,12 @@ impl AggregateExec {
142147
pub fn input_schema(&self) -> SchemaRef {
143148
self.input_schema.clone()
144149
}
150+
151+
fn row_aggregate_supported(&self) -> bool {
152+
let group_schema = group_schema(&self.schema, self.group_expr.len());
153+
row_supported(&group_schema, RowType::Compact)
154+
&& accumulator_v2_supported(&self.aggr_expr)
155+
}
145156
}
146157

147158
impl ExecutionPlan for AggregateExec {
@@ -212,6 +223,15 @@ impl ExecutionPlan for AggregateExec {
212223
input,
213224
baseline_metrics,
214225
)?))
226+
} else if self.row_aggregate_supported() {
227+
Ok(Box::pin(GroupedHashAggregateStreamV2::new(
228+
self.mode,
229+
self.schema.clone(),
230+
group_expr,
231+
self.aggr_expr.clone(),
232+
input,
233+
baseline_metrics,
234+
)?))
215235
} else {
216236
Ok(Box::pin(GroupedHashAggregateStream::new(
217237
self.mode,
@@ -315,6 +335,11 @@ fn create_schema(
315335
Ok(Schema::new(fields))
316336
}
317337

338+
fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef {
339+
let group_fields = schema.fields()[0..group_count].to_vec();
340+
Arc::new(Schema::new(group_fields))
341+
}
342+
318343
/// returns physical expressions to evaluate against a batch
319344
/// The expressions are different depending on `mode`:
320345
/// * Partial: AggregateExpr::expressions
@@ -362,6 +387,7 @@ fn merge_expressions(
362387
}
363388

364389
pub(crate) type AccumulatorItem = Box<dyn Accumulator>;
390+
pub(crate) type AccumulatorItemV2 = Box<dyn RowAccumulator>;
365391

366392
fn create_accumulators(
367393
aggr_expr: &[Arc<dyn AggregateExpr>],
@@ -372,6 +398,26 @@ fn create_accumulators(
372398
.collect::<datafusion_common::Result<Vec<_>>>()
373399
}
374400

401+
fn accumulator_v2_supported(aggr_expr: &[Arc<dyn AggregateExpr>]) -> bool {
402+
aggr_expr
403+
.iter()
404+
.all(|expr| expr.row_accumulator_supported())
405+
}
406+
407+
fn create_accumulators_v2(
408+
aggr_expr: &[Arc<dyn AggregateExpr>],
409+
) -> datafusion_common::Result<Vec<AccumulatorItemV2>> {
410+
let mut state_index = 0;
411+
aggr_expr
412+
.iter()
413+
.map(|expr| {
414+
let result = expr.create_row_accumulator(state_index);
415+
state_index += expr.state_fields().unwrap().len();
416+
result
417+
})
418+
.collect::<datafusion_common::Result<Vec<_>>>()
419+
}
420+
375421
/// returns a vector of ArrayRefs, where each entry corresponds to either the
376422
/// final value (mode = Final) or states (mode = Partial)
377423
fn finalize_aggregation(
@@ -402,6 +448,27 @@ fn finalize_aggregation(
402448
}
403449
}
404450

451+
/// Evaluates expressions against a record batch.
452+
fn evaluate(
453+
expr: &[Arc<dyn PhysicalExpr>],
454+
batch: &RecordBatch,
455+
) -> Result<Vec<ArrayRef>> {
456+
expr.iter()
457+
.map(|expr| expr.evaluate(batch))
458+
.map(|r| r.map(|v| v.into_array(batch.num_rows())))
459+
.collect::<Result<Vec<_>>>()
460+
}
461+
462+
/// Evaluates expressions against a record batch.
463+
fn evaluate_many(
464+
expr: &[Vec<Arc<dyn PhysicalExpr>>],
465+
batch: &RecordBatch,
466+
) -> Result<Vec<Vec<ArrayRef>>> {
467+
expr.iter()
468+
.map(|expr| evaluate(expr, batch))
469+
.collect::<Result<Vec<_>>>()
470+
}
471+
405472
#[cfg(test)]
406473
mod tests {
407474
use crate::execution::context::TaskContext;

0 commit comments

Comments
 (0)