Skip to content

Commit ce4e1f3

Browse files
committed
refactor: ballista approx_quantile support
Adds ballista wire encoding for approx_quantile. Adding support for this required modifying the AggregateExprNode proto message to support propagating multiple LogicalExprNode aggregate arguments - all the existing aggregations take a single argument, so this wasn't needed before. This commit adds "repeated" to the expr field, which I believe is backwards compatible as described here: https://developers.google.com/protocol-buffers/docs/proto3#updating - specifically, regarding adding "repeated" to an existing message field: "For ... message fields, optional is compatible with repeated". No existing tests needed fixing, and a new roundtrip test is included that covers the change to allow multiple expr.
1 parent d8ffb99 commit ce4e1f3

File tree

5 files changed

+34
-7
lines changed

5 files changed

+34
-7
lines changed

ballista/rust/core/proto/ballista.proto

+2-1
Original file line numberDiff line numberDiff line change
@@ -173,11 +173,12 @@ enum AggregateFunction {
173173
VARIANCE_POP=8;
174174
STDDEV=9;
175175
STDDEV_POP=10;
176+
APPROX_QUANTILE = 11;
176177
}
177178

178179
message AggregateExprNode {
179180
AggregateFunction aggr_function = 1;
180-
LogicalExprNode expr = 2;
181+
repeated LogicalExprNode expr = 2;
181182
}
182183

183184
enum BuiltInWindowFunction {

ballista/rust/core/src/serde/logical_plan/from_proto.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -962,7 +962,11 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
962962

963963
Ok(Expr::AggregateFunction {
964964
fun,
965-
args: vec![parse_required_expr(&expr.expr)?],
965+
args: expr
966+
.expr
967+
.iter()
968+
.map(|e| e.try_into())
969+
.collect::<Result<Vec<_>, _>>()?,
966970
distinct: false, //TODO
967971
})
968972
}

ballista/rust/core/src/serde/logical_plan/mod.rs

+14-1
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ mod roundtrip_tests {
2424
use super::super::{super::error::Result, protobuf};
2525
use crate::error::BallistaError;
2626
use core::panic;
27-
use datafusion::logical_plan::Repartition;
2827
use datafusion::{
2928
arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit},
3029
datasource::object_store::local::LocalFileSystem,
@@ -37,6 +36,7 @@ mod roundtrip_tests {
3736
scalar::ScalarValue,
3837
sql::parser::FileType,
3938
};
39+
use datafusion::{logical_plan::Repartition, physical_plan::aggregates};
4040
use protobuf::arrow_type;
4141
use std::{convert::TryInto, sync::Arc};
4242

@@ -988,4 +988,17 @@ mod roundtrip_tests {
988988

989989
Ok(())
990990
}
991+
992+
#[test]
993+
fn roundtrip_approx_quantile() -> Result<()> {
994+
let test_expr = Expr::AggregateFunction {
995+
fun: aggregates::AggregateFunction::ApproxQuantile,
996+
args: vec![col("bananas"), lit(0.42)],
997+
distinct: false,
998+
};
999+
1000+
roundtrip_test!(test_expr, protobuf::LogicalExprNode, Expr);
1001+
1002+
Ok(())
1003+
}
9911004
}

ballista/rust/core/src/serde/logical_plan/to_proto.rs

+10-4
Original file line numberDiff line numberDiff line change
@@ -1020,6 +1020,9 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
10201020
AggregateFunction::ApproxDistinct => {
10211021
protobuf::AggregateFunction::ApproxDistinct
10221022
}
1023+
AggregateFunction::ApproxQuantile => {
1024+
protobuf::AggregateFunction::ApproxQuantile
1025+
}
10231026
AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg,
10241027
AggregateFunction::Min => protobuf::AggregateFunction::Min,
10251028
AggregateFunction::Max => protobuf::AggregateFunction::Max,
@@ -1036,11 +1039,13 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
10361039
}
10371040
};
10381041

1039-
let arg = &args[0];
1040-
let aggregate_expr = Box::new(protobuf::AggregateExprNode {
1042+
let aggregate_expr = protobuf::AggregateExprNode {
10411043
aggr_function: aggr_function.into(),
1042-
expr: Some(Box::new(arg.try_into()?)),
1043-
});
1044+
expr: args
1045+
.iter()
1046+
.map(|v| v.try_into())
1047+
.collect::<Result<Vec<_>, _>>()?,
1048+
};
10441049
Ok(protobuf::LogicalExprNode {
10451050
expr_type: Some(ExprType::AggregateExpr(aggregate_expr)),
10461051
})
@@ -1268,6 +1273,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
12681273
AggregateFunction::VariancePop => Self::VariancePop,
12691274
AggregateFunction::Stddev => Self::Stddev,
12701275
AggregateFunction::StddevPop => Self::StddevPop,
1276+
AggregateFunction::ApproxQuantile => Self::ApproxQuantile,
12711277
}
12721278
}
12731279
}

ballista/rust/core/src/serde/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,9 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
123123
protobuf::AggregateFunction::VariancePop => AggregateFunction::VariancePop,
124124
protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev,
125125
protobuf::AggregateFunction::StddevPop => AggregateFunction::StddevPop,
126+
protobuf::AggregateFunction::ApproxQuantile => {
127+
AggregateFunction::ApproxQuantile
128+
}
126129
}
127130
}
128131
}

0 commit comments

Comments
 (0)