28
28
29
29
use super :: {
30
30
functions:: { Signature , Volatility } ,
31
- type_coercion:: { coerce, data_types} ,
32
31
Accumulator , AggregateExpr , PhysicalExpr ,
33
32
} ;
34
33
use crate :: error:: { DataFusionError , Result } ;
34
+ use crate :: physical_plan:: coercion_rule:: aggregate_rule:: { coerce_exprs, coerce_types} ;
35
35
use crate :: physical_plan:: distinct_expressions;
36
36
use crate :: physical_plan:: expressions;
37
37
use arrow:: datatypes:: { DataType , Field , Schema , TimeUnit } ;
38
38
use expressions:: { avg_return_type, sum_return_type} ;
39
39
use std:: { fmt, str:: FromStr , sync:: Arc } ;
40
+
40
41
/// the implementation of an aggregate function
41
42
pub type AccumulatorFunctionImplementation =
42
43
Arc < dyn Fn ( ) -> Result < Box < dyn Accumulator > > + Send + Sync > ;
@@ -87,96 +88,125 @@ impl FromStr for AggregateFunction {
87
88
return Err ( DataFusionError :: Plan ( format ! (
88
89
"There is no built-in function named {}" ,
89
90
name
90
- ) ) )
91
+ ) ) ) ;
91
92
}
92
93
} )
93
94
}
94
95
}
95
96
96
- /// Returns the datatype of the scalar function
97
+ /// Returns the datatype of the aggregate function.
98
+ /// This is used to get the returned data type for an aggregate logical expression.
97
99
pub fn return_type ( fun : & AggregateFunction , arg_types : & [ DataType ] ) -> Result < DataType > {
98
100
// Note that this function *must* return the same type that the respective physical expression returns
99
101
// or the execution panics.
100
102
101
- // verify that this is a valid set of data types for this function
102
- data_types ( arg_types, & signature ( fun) ) ?;
103
+ let coerced_data_types = coerce_types ( fun, arg_types, & signature ( fun) ) ?;
103
104
104
105
match fun {
106
+ // TODO liukun4515
107
+ // If DataFusion is to be compatible with PostgreSQL, the returned data type should be INT64.
105
108
AggregateFunction :: Count | AggregateFunction :: ApproxDistinct => {
106
109
Ok ( DataType :: UInt64 )
107
110
}
108
- AggregateFunction :: Max | AggregateFunction :: Min => Ok ( arg_types[ 0 ] . clone ( ) ) ,
109
- AggregateFunction :: Sum => sum_return_type ( & arg_types[ 0 ] ) ,
110
- AggregateFunction :: Avg => avg_return_type ( & arg_types[ 0 ] ) ,
111
+ AggregateFunction :: Max | AggregateFunction :: Min => {
112
+ // For the min and max aggregate functions, the returned type is the same as the input type.
113
+ // The coerced_data_types are the same as the input types.
114
+ Ok ( coerced_data_types[ 0 ] . clone ( ) )
115
+ }
116
+ AggregateFunction :: Sum => sum_return_type ( & coerced_data_types[ 0 ] ) ,
117
+ AggregateFunction :: Avg => avg_return_type ( & coerced_data_types[ 0 ] ) ,
111
118
AggregateFunction :: ArrayAgg => Ok ( DataType :: List ( Box :: new ( Field :: new (
112
119
"item" ,
113
- arg_types [ 0 ] . clone ( ) ,
120
+ coerced_data_types [ 0 ] . clone ( ) ,
114
121
true ,
115
122
) ) ) ) ,
116
123
}
117
124
}
118
125
119
126
/// Create a physical (function) expression.
120
- /// This function errors when `args `' can't be coerced to a valid argument type of the function.
127
+ /// This function errors when `input_phy_exprs` can't be coerced to a valid argument type of the function.
121
128
pub fn create_aggregate_expr (
122
129
fun : & AggregateFunction ,
123
130
distinct : bool ,
124
- args : & [ Arc < dyn PhysicalExpr > ] ,
131
+ input_phy_exprs : & [ Arc < dyn PhysicalExpr > ] ,
125
132
input_schema : & Schema ,
126
133
name : impl Into < String > ,
127
134
) -> Result < Arc < dyn AggregateExpr > > {
128
135
let name = name. into ( ) ;
129
- let arg = coerce ( args, input_schema, & signature ( fun) ) ?;
130
- if arg. is_empty ( ) {
136
+ // get the coerced physical exprs, in case some exprs need a try-cast
137
+ let coerced_phy_exprs =
138
+ coerce_exprs ( fun, input_phy_exprs, input_schema, & signature ( fun) ) ?;
139
+ if coerced_phy_exprs. is_empty ( ) {
131
140
return Err ( DataFusionError :: Plan ( format ! (
132
141
"Invalid or wrong number of arguments passed to aggregate: '{}'" ,
133
142
name,
134
143
) ) ) ;
135
144
}
136
- let arg = arg[ 0 ] . clone ( ) ;
137
-
138
- let arg_types = args
145
+ let coerced_types = coerced_phy_exprs
139
146
. iter ( )
140
147
. map ( |e| e. data_type ( input_schema) )
141
148
. collect :: < Result < Vec < _ > > > ( ) ?;
149
+ let first_coerced_phy_expr = coerced_phy_exprs[ 0 ] . clone ( ) ;
142
150
143
- let return_type = return_type ( fun, & arg_types) ?;
151
+ // get the result data type for this aggregate function
152
+ let input_phy_types = input_phy_exprs
153
+ . iter ( )
154
+ . map ( |e| e. data_type ( input_schema) )
155
+ . collect :: < Result < Vec < _ > > > ( ) ?;
156
+ let return_type = return_type ( fun, & input_phy_types) ?;
144
157
145
158
Ok ( match ( fun, distinct) {
146
- ( AggregateFunction :: Count , false ) => {
147
- Arc :: new ( expressions:: Count :: new ( arg, name, return_type) )
148
- }
159
+ ( AggregateFunction :: Count , false ) => Arc :: new ( expressions:: Count :: new (
160
+ first_coerced_phy_expr,
161
+ name,
162
+ return_type,
163
+ ) ) ,
149
164
( AggregateFunction :: Count , true ) => {
150
165
Arc :: new ( distinct_expressions:: DistinctCount :: new (
151
- arg_types ,
152
- args . to_vec ( ) ,
166
+ coerced_types ,
167
+ coerced_phy_exprs ,
153
168
name,
154
169
return_type,
155
170
) )
156
171
}
157
- ( AggregateFunction :: Sum , false ) => {
158
- Arc :: new ( expressions:: Sum :: new ( arg, name, return_type) )
159
- }
172
+ ( AggregateFunction :: Sum , false ) => Arc :: new ( expressions:: Sum :: new (
173
+ first_coerced_phy_expr,
174
+ name,
175
+ return_type,
176
+ ) ) ,
160
177
( AggregateFunction :: Sum , true ) => {
161
178
return Err ( DataFusionError :: NotImplemented (
162
179
"SUM(DISTINCT) aggregations are not available" . to_string ( ) ,
163
180
) ) ;
164
181
}
165
- ( AggregateFunction :: ApproxDistinct , _) => Arc :: new (
166
- expressions:: ApproxDistinct :: new ( arg, name, arg_types[ 0 ] . clone ( ) ) ,
167
- ) ,
168
- ( AggregateFunction :: ArrayAgg , _) => {
169
- Arc :: new ( expressions:: ArrayAgg :: new ( arg, name, arg_types[ 0 ] . clone ( ) ) )
170
- }
171
- ( AggregateFunction :: Min , _) => {
172
- Arc :: new ( expressions:: Min :: new ( arg, name, return_type) )
173
- }
174
- ( AggregateFunction :: Max , _) => {
175
- Arc :: new ( expressions:: Max :: new ( arg, name, return_type) )
176
- }
177
- ( AggregateFunction :: Avg , false ) => {
178
- Arc :: new ( expressions:: Avg :: new ( arg, name, return_type) )
182
+ // TODO optimize the expr and make arg more meaningful
183
+ ( AggregateFunction :: ApproxDistinct , _) => {
184
+ Arc :: new ( expressions:: ApproxDistinct :: new (
185
+ first_coerced_phy_expr,
186
+ name,
187
+ coerced_types[ 0 ] . clone ( ) ,
188
+ ) )
179
189
}
190
+ ( AggregateFunction :: ArrayAgg , _) => Arc :: new ( expressions:: ArrayAgg :: new (
191
+ first_coerced_phy_expr,
192
+ name,
193
+ coerced_types[ 0 ] . clone ( ) ,
194
+ ) ) ,
195
+ ( AggregateFunction :: Min , _) => Arc :: new ( expressions:: Min :: new (
196
+ first_coerced_phy_expr,
197
+ name,
198
+ return_type,
199
+ ) ) ,
200
+ ( AggregateFunction :: Max , _) => Arc :: new ( expressions:: Max :: new (
201
+ first_coerced_phy_expr,
202
+ name,
203
+ return_type,
204
+ ) ) ,
205
+ ( AggregateFunction :: Avg , false ) => Arc :: new ( expressions:: Avg :: new (
206
+ first_coerced_phy_expr,
207
+ name,
208
+ return_type,
209
+ ) ) ,
180
210
( AggregateFunction :: Avg , true ) => {
181
211
return Err ( DataFusionError :: NotImplemented (
182
212
"AVG(DISTINCT) aggregations are not available" . to_string ( ) ,
@@ -209,6 +239,7 @@ static TIMESTAMPS: &[DataType] = &[
209
239
210
240
static DATES : & [ DataType ] = & [ DataType :: Date32 , DataType :: Date64 ] ;
211
241
242
+ // TODO liukun4515
212
243
/// the signatures supported by the function `fun`.
213
244
pub fn signature ( fun : & AggregateFunction ) -> Signature {
214
245
// note: the physical expression must accept the type returned by this function or the execution panics.
@@ -244,6 +275,29 @@ mod tests {
244
275
245
276
let observed = return_type ( & AggregateFunction :: Max , & [ DataType :: Int32 ] ) ?;
246
277
assert_eq ! ( DataType :: Int32 , observed) ;
278
+
279
+ let observed = return_type ( & AggregateFunction :: Min , & [ DataType :: Decimal ( 10 , 6 ) ] ) ?;
280
+ assert_eq ! ( DataType :: Decimal ( 10 , 6 ) , observed) ;
281
+
282
+ let observed =
283
+ return_type ( & AggregateFunction :: Max , & [ DataType :: Decimal ( 28 , 13 ) ] ) ?;
284
+ assert_eq ! ( DataType :: Decimal ( 28 , 13 ) , observed) ;
285
+
286
+ // TODO liukun4515
287
+ // Add test to create agg expr
288
+ let fun = AggregateFunction :: Min ;
289
+ let input_schema =
290
+ Schema :: new ( vec ! [ Field :: new( "c1" , DataType :: Decimal ( 10 , 6 ) , false ) ] ) ;
291
+ // let input_schema =
292
+ // Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
293
+ let expr: Arc < dyn PhysicalExpr > =
294
+ Arc :: new ( expressions:: Column :: new_with_schema ( "c1" , & input_schema) . unwrap ( ) ) ;
295
+ let args = vec ! [ expr] ;
296
+ let result_agg =
297
+ create_aggregate_expr ( & fun, false , & args[ 0 ..1 ] , & input_schema, "c1" ) ?;
298
+ // assert the field type
299
+ assert_eq ! ( input_schema. field( 0 ) , & result_agg. field( ) . unwrap( ) ) ;
300
+ assert_eq ! ( "min(c1)" , result_agg. name( ) ) ;
247
301
Ok ( ( ) )
248
302
}
249
303
@@ -267,6 +321,10 @@ mod tests {
267
321
268
322
let observed = return_type ( & AggregateFunction :: Count , & [ DataType :: Int8 ] ) ?;
269
323
assert_eq ! ( DataType :: UInt64 , observed) ;
324
+
325
+ let observed =
326
+ return_type ( & AggregateFunction :: Count , & [ DataType :: Decimal ( 28 , 13 ) ] ) ?;
327
+ assert_eq ! ( DataType :: UInt64 , observed) ;
270
328
Ok ( ( ) )
271
329
}
272
330
0 commit comments