Skip to content

Commit 066fcb8

Browse files
committed
refactor: use input type as return type
Casts the calculated quantile value to the same type as the input data.
1 parent ce4e1f3 commit 066fcb8

File tree

3 files changed

+29
-11
lines changed

3 files changed

+29
-11
lines changed

datafusion/src/physical_plan/aggregates.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ pub fn return_type(
145145
coerced_data_types[0].clone(),
146146
true,
147147
)))),
148-
AggregateFunction::ApproxQuantile => Ok(DataType::Float64),
148+
AggregateFunction::ApproxQuantile => Ok(coerced_data_types[0].clone()),
149149
}
150150
}
151151

@@ -455,7 +455,7 @@ mod tests {
455455
assert!(result_agg_phy_exprs.as_any().is::<ApproxQuantile>());
456456
assert_eq!("c1", result_agg_phy_exprs.name());
457457
assert_eq!(
458-
Field::new("c1", DataType::Float64, false),
458+
Field::new("c1", data_type.clone(), false),
459459
result_agg_phy_exprs.field().unwrap()
460460
);
461461
}

datafusion/src/physical_plan/expressions/approx_quantile.rs

+26-8
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ impl AggregateExpr for ApproxQuantile {
105105
}
106106

107107
fn field(&self) -> Result<Field> {
108-
Ok(Field::new(&self.name, DataType::Float64, false))
108+
Ok(Field::new(&self.name, self.input_data_type.clone(), false))
109109
}
110110

111111
/// See [`TDigest::to_scalar_state()`] for a description of the serialised
@@ -151,7 +151,9 @@ impl AggregateExpr for ApproxQuantile {
151151

152152
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
153153
let accumulator: Box<dyn Accumulator> = match &self.input_data_type {
154-
DataType::UInt8
154+
t
155+
@
156+
(DataType::UInt8
155157
| DataType::UInt16
156158
| DataType::UInt32
157159
| DataType::UInt64
@@ -160,8 +162,8 @@ impl AggregateExpr for ApproxQuantile {
160162
| DataType::Int32
161163
| DataType::Int64
162164
| DataType::Float32
163-
| DataType::Float64 => {
164-
Box::new(ApproxQuantileAccumulator::new(self.quantile))
165+
| DataType::Float64) => {
166+
Box::new(ApproxQuantileAccumulator::new(self.quantile, t.clone()))
165167
}
166168
other => {
167169
return Err(DataFusionError::NotImplemented(format!(
@@ -182,13 +184,15 @@ impl AggregateExpr for ApproxQuantile {
182184
pub struct ApproxQuantileAccumulator {
183185
digest: TDigest,
184186
quantile: f64,
187+
return_type: DataType,
185188
}
186189

187190
impl ApproxQuantileAccumulator {
188-
pub fn new(quantile: f64) -> Self {
191+
pub fn new(quantile: f64, return_type: DataType) -> Self {
189192
Self {
190193
digest: TDigest::new(100),
191194
quantile,
195+
return_type,
192196
}
193197
}
194198
}
@@ -283,8 +287,22 @@ impl Accumulator for ApproxQuantileAccumulator {
283287
}
284288

285289
fn evaluate(&self) -> Result<ScalarValue> {
286-
Ok(ScalarValue::Float64(Some(
287-
self.digest.estimate_quantile(self.quantile),
288-
)))
290+
let q = self.digest.estimate_quantile(self.quantile);
291+
292+
// These acceptable return types MUST match the validation in
293+
// ApproxQuantile::create_accumulator.
294+
Ok(match &self.return_type {
295+
DataType::Int8 => ScalarValue::Int8(Some(q as i8)),
296+
DataType::Int16 => ScalarValue::Int16(Some(q as i16)),
297+
DataType::Int32 => ScalarValue::Int32(Some(q as i32)),
298+
DataType::Int64 => ScalarValue::Int64(Some(q as i64)),
299+
DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)),
300+
DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
301+
DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
302+
DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
303+
DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
304+
DataType::Float64 => ScalarValue::Float64(Some(q as f64)),
305+
v => unreachable!("unexpected return type {:?}", v),
306+
})
289307
}
290308
}

datafusion/tests/sql/aggregates.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ async fn csv_query_approx_quantile() -> Result<()> {
348348
// within 5% of the $actual quantile value.
349349
macro_rules! quantile_test {
350350
($ctx:ident, column=$column:literal, quantile=$quantile:literal, actual=$actual:literal) => {
351-
let sql = format!("SELECT (ABS(1 - approx_quantile({}, {}) / {}) < 0.05) AS q FROM aggregate_test_100", $column, $quantile, $actual);
351+
let sql = format!("SELECT (ABS(1 - CAST(approx_quantile({}, {}) AS DOUBLE) / {}) < 0.05) AS q FROM aggregate_test_100", $column, $quantile, $actual);
352352
let actual = execute_to_batches(&mut ctx, &sql).await;
353353
//
354354
// "+------+",

0 commit comments

Comments
 (0)