@@ -105,7 +105,7 @@ impl AggregateExpr for ApproxQuantile {
105
105
}
106
106
107
107
fn field ( & self ) -> Result < Field > {
108
- Ok ( Field :: new ( & self . name , DataType :: Float64 , false ) )
108
+ Ok ( Field :: new ( & self . name , self . input_data_type . clone ( ) , false ) )
109
109
}
110
110
111
111
/// See [`TDigest::to_scalar_state()`] for a description of the serialised
@@ -151,7 +151,9 @@ impl AggregateExpr for ApproxQuantile {
151
151
152
152
fn create_accumulator ( & self ) -> Result < Box < dyn Accumulator > > {
153
153
let accumulator: Box < dyn Accumulator > = match & self . input_data_type {
154
- DataType :: UInt8
154
+ t
155
+ @
156
+ ( DataType :: UInt8
155
157
| DataType :: UInt16
156
158
| DataType :: UInt32
157
159
| DataType :: UInt64
@@ -160,8 +162,8 @@ impl AggregateExpr for ApproxQuantile {
160
162
| DataType :: Int32
161
163
| DataType :: Int64
162
164
| DataType :: Float32
163
- | DataType :: Float64 => {
164
- Box :: new ( ApproxQuantileAccumulator :: new ( self . quantile ) )
165
+ | DataType :: Float64 ) => {
166
+ Box :: new ( ApproxQuantileAccumulator :: new ( self . quantile , t . clone ( ) ) )
165
167
}
166
168
other => {
167
169
return Err ( DataFusionError :: NotImplemented ( format ! (
@@ -182,13 +184,15 @@ impl AggregateExpr for ApproxQuantile {
182
184
pub struct ApproxQuantileAccumulator {
183
185
digest : TDigest ,
184
186
quantile : f64 ,
187
+ return_type : DataType ,
185
188
}
186
189
187
190
impl ApproxQuantileAccumulator {
188
- pub fn new ( quantile : f64 ) -> Self {
191
+ pub fn new ( quantile : f64 , return_type : DataType ) -> Self {
189
192
Self {
190
193
digest : TDigest :: new ( 100 ) ,
191
194
quantile,
195
+ return_type,
192
196
}
193
197
}
194
198
}
@@ -283,8 +287,22 @@ impl Accumulator for ApproxQuantileAccumulator {
283
287
}
284
288
285
289
fn evaluate ( & self ) -> Result < ScalarValue > {
286
- Ok ( ScalarValue :: Float64 ( Some (
287
- self . digest . estimate_quantile ( self . quantile ) ,
288
- ) ) )
290
+ let q = self . digest . estimate_quantile ( self . quantile ) ;
291
+
292
+ // These acceptable return types MUST match the validation in
293
+ // ApproxQuantile::create_accumulator.
294
+ Ok ( match & self . return_type {
295
+ DataType :: Int8 => ScalarValue :: Int8 ( Some ( q as i8 ) ) ,
296
+ DataType :: Int16 => ScalarValue :: Int16 ( Some ( q as i16 ) ) ,
297
+ DataType :: Int32 => ScalarValue :: Int32 ( Some ( q as i32 ) ) ,
298
+ DataType :: Int64 => ScalarValue :: Int64 ( Some ( q as i64 ) ) ,
299
+ DataType :: UInt8 => ScalarValue :: UInt8 ( Some ( q as u8 ) ) ,
300
+ DataType :: UInt16 => ScalarValue :: UInt16 ( Some ( q as u16 ) ) ,
301
+ DataType :: UInt32 => ScalarValue :: UInt32 ( Some ( q as u32 ) ) ,
302
+ DataType :: UInt64 => ScalarValue :: UInt64 ( Some ( q as u64 ) ) ,
303
+ DataType :: Float32 => ScalarValue :: Float32 ( Some ( q as f32 ) ) ,
304
+ DataType :: Float64 => ScalarValue :: Float64 ( Some ( q as f64 ) ) ,
305
+ v => unreachable ! ( "unexpected return type {:?}" , v) ,
306
+ } )
289
307
}
290
308
}
0 commit comments