Skip to content

Commit 198342b

Browse files
committed
support min,max for decimal data type
1 parent e3a682d commit 198342b

File tree

16 files changed

+695
-54
lines changed

16 files changed

+695
-54
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion::arrow::datatypes::{DataType, Field, Schema};
19+
use datafusion::error::Result;
20+
use datafusion::prelude::*;
21+
use std::sync::Arc;
22+
23+
/// This example demonstrates executing a simple query against an Arrow data source (CSV) and
24+
/// fetching results
25+
#[tokio::main]
26+
async fn main() -> Result<()> {
27+
// create local execution context
28+
let mut ctx = ExecutionContext::new();
29+
30+
let testdata = datafusion::test_util::arrow_test_data();
31+
32+
// schema with decimal type
33+
let schema = Arc::new(Schema::new(vec![
34+
Field::new("c1", DataType::Decimal(10, 6), false),
35+
Field::new("c2", DataType::Float64, false),
36+
Field::new("c3", DataType::Boolean, false),
37+
]));
38+
39+
// register csv file with the execution context
40+
ctx.register_csv(
41+
"aggregate_simple",
42+
&format!("{}/csv/aggregate_simple.csv", testdata),
43+
CsvReadOptions::new().schema(&schema),
44+
)
45+
.await?;
46+
47+
// execute the query
48+
let df = ctx.sql("select c1 from aggregate_simple").await?;
49+
50+
// print the results
51+
df.show().await?;
52+
53+
Ok(())
54+
}

datafusion/Cargo.toml

+4-2
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,10 @@ avro = ["avro-rs", "num-traits"]
5252
[dependencies]
5353
ahash = "0.7"
5454
hashbrown = { version = "0.11", features = ["raw"] }
55-
arrow = { version = "6.2.0", features = ["prettyprint"] }
56-
parquet = { version = "6.2.0", features = ["arrow"] }
55+
arrow = { path = "/Users/kliu3/Documents/github/arrow-rs/arrow", features = ["prettyprint"] }
56+
#arrow = { version = "6.2.0", features = ["prettyprint"] }
57+
#parquet = { version = "6.2.0", features = ["arrow"] }
58+
parquet = { path = "/Users/kliu3/Documents/github/arrow-rs/parquet", features = ["arrow"] }
5759
sqlparser = "0.12"
5860
paste = "^1.0"
5961
num_cpus = "1.13.0"

datafusion/src/execution/context.rs

+25
Original file line numberDiff line numberDiff line change
@@ -3895,6 +3895,31 @@ mod tests {
38953895
Ok(())
38963896
}
38973897

3898+
#[tokio::test]
3899+
async fn aggregate_decimal() -> Result<()> {
3900+
let mut ctx = ExecutionContext::new();
3901+
// schema with data
3902+
let schema = Arc::new(Schema::new(vec![
3903+
Field::new("c1", DataType::Decimal(10, 6), false),
3904+
Field::new("c2", DataType::Float64, false),
3905+
Field::new("c3", DataType::Boolean, false),
3906+
]));
3907+
3908+
ctx.register_csv(
3909+
"aggregate_simple",
3910+
"tests/aggregate_simple.csv",
3911+
CsvReadOptions::new().schema(&schema),
3912+
)
3913+
.await?;
3914+
3915+
// decimal query
3916+
let result = plan_and_collect(&mut ctx, "select min(c1) from aggregate_simple")
3917+
.await
3918+
.unwrap();
3919+
println!("{:?}", result);
3920+
Ok(())
3921+
}
3922+
38983923
#[tokio::test]
38993924
async fn create_external_table_with_timestamps() {
39003925
let mut ctx = ExecutionContext::new();

datafusion/src/physical_plan/aggregates.rs

+97-39
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,16 @@
2828
2929
use super::{
3030
functions::{Signature, Volatility},
31-
type_coercion::{coerce, data_types},
3231
Accumulator, AggregateExpr, PhysicalExpr,
3332
};
3433
use crate::error::{DataFusionError, Result};
34+
use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs, coerce_types};
3535
use crate::physical_plan::distinct_expressions;
3636
use crate::physical_plan::expressions;
3737
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
3838
use expressions::{avg_return_type, sum_return_type};
3939
use std::{fmt, str::FromStr, sync::Arc};
40+
4041
/// the implementation of an aggregate function
4142
pub type AccumulatorFunctionImplementation =
4243
Arc<dyn Fn() -> Result<Box<dyn Accumulator>> + Send + Sync>;
@@ -87,96 +88,125 @@ impl FromStr for AggregateFunction {
8788
return Err(DataFusionError::Plan(format!(
8889
"There is no built-in function named {}",
8990
name
90-
)))
91+
)));
9192
}
9293
})
9394
}
9495
}
9596

96-
/// Returns the datatype of the scalar function
97+
/// Returns the datatype of the aggregate function.
98+
/// This is used to get the returned data type for aggregate logical expr.
9799
pub fn return_type(fun: &AggregateFunction, arg_types: &[DataType]) -> Result<DataType> {
98100
// Note that this function *must* return the same type that the respective physical expression returns
99101
// or the execution panics.
100102

101-
// verify that this is a valid set of data types for this function
102-
data_types(arg_types, &signature(fun))?;
103+
let coerced_data_types = coerce_types(fun, arg_types, &signature(fun))?;
103104

104105
match fun {
106+
// TODO liukun4515
107+
// If the datafusion is compatible with PostgreSQL, the returned data type should be INT64.
105108
AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
106109
Ok(DataType::UInt64)
107110
}
108-
AggregateFunction::Max | AggregateFunction::Min => Ok(arg_types[0].clone()),
109-
AggregateFunction::Sum => sum_return_type(&arg_types[0]),
110-
AggregateFunction::Avg => avg_return_type(&arg_types[0]),
111+
AggregateFunction::Max | AggregateFunction::Min => {
112+
// For min and max agg function, the returned type is same as input type.
113+
// The coerced_data_types is same with input_types.
114+
Ok(coerced_data_types[0].clone())
115+
}
116+
AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]),
117+
AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]),
111118
AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new(
112119
"item",
113-
arg_types[0].clone(),
120+
coerced_data_types[0].clone(),
114121
true,
115122
)))),
116123
}
117124
}
118125

119126
/// Create a physical (function) expression.
120-
/// This function errors when `args`' can't be coerced to a valid argument type of the function.
127+
/// This function errors when `input_phy_exprs`' can't be coerced to a valid argument type of the function.
121128
pub fn create_aggregate_expr(
122129
fun: &AggregateFunction,
123130
distinct: bool,
124-
args: &[Arc<dyn PhysicalExpr>],
131+
input_phy_exprs: &[Arc<dyn PhysicalExpr>],
125132
input_schema: &Schema,
126133
name: impl Into<String>,
127134
) -> Result<Arc<dyn AggregateExpr>> {
128135
let name = name.into();
129-
let arg = coerce(args, input_schema, &signature(fun))?;
130-
if arg.is_empty() {
136+
// get the coerced phy exprs if some expr need try cast
137+
let coerced_phy_exprs =
138+
coerce_exprs(fun, input_phy_exprs, input_schema, &signature(fun))?;
139+
if coerced_phy_exprs.is_empty() {
131140
return Err(DataFusionError::Plan(format!(
132141
"Invalid or wrong number of arguments passed to aggregate: '{}'",
133142
name,
134143
)));
135144
}
136-
let arg = arg[0].clone();
137-
138-
let arg_types = args
145+
let coerced_types = coerced_phy_exprs
139146
.iter()
140147
.map(|e| e.data_type(input_schema))
141148
.collect::<Result<Vec<_>>>()?;
149+
let first_coerced_phy_expr = coerced_phy_exprs[0].clone();
142150

143-
let return_type = return_type(fun, &arg_types)?;
151+
// get the result data type for this aggregate function
152+
let input_phy_types = input_phy_exprs
153+
.iter()
154+
.map(|e| e.data_type(input_schema))
155+
.collect::<Result<Vec<_>>>()?;
156+
let return_type = return_type(fun, &input_phy_types)?;
144157

145158
Ok(match (fun, distinct) {
146-
(AggregateFunction::Count, false) => {
147-
Arc::new(expressions::Count::new(arg, name, return_type))
148-
}
159+
(AggregateFunction::Count, false) => Arc::new(expressions::Count::new(
160+
first_coerced_phy_expr,
161+
name,
162+
return_type,
163+
)),
149164
(AggregateFunction::Count, true) => {
150165
Arc::new(distinct_expressions::DistinctCount::new(
151-
arg_types,
152-
args.to_vec(),
166+
coerced_types,
167+
coerced_phy_exprs,
153168
name,
154169
return_type,
155170
))
156171
}
157-
(AggregateFunction::Sum, false) => {
158-
Arc::new(expressions::Sum::new(arg, name, return_type))
159-
}
172+
(AggregateFunction::Sum, false) => Arc::new(expressions::Sum::new(
173+
first_coerced_phy_expr,
174+
name,
175+
return_type,
176+
)),
160177
(AggregateFunction::Sum, true) => {
161178
return Err(DataFusionError::NotImplemented(
162179
"SUM(DISTINCT) aggregations are not available".to_string(),
163180
));
164181
}
165-
(AggregateFunction::ApproxDistinct, _) => Arc::new(
166-
expressions::ApproxDistinct::new(arg, name, arg_types[0].clone()),
167-
),
168-
(AggregateFunction::ArrayAgg, _) => {
169-
Arc::new(expressions::ArrayAgg::new(arg, name, arg_types[0].clone()))
170-
}
171-
(AggregateFunction::Min, _) => {
172-
Arc::new(expressions::Min::new(arg, name, return_type))
173-
}
174-
(AggregateFunction::Max, _) => {
175-
Arc::new(expressions::Max::new(arg, name, return_type))
176-
}
177-
(AggregateFunction::Avg, false) => {
178-
Arc::new(expressions::Avg::new(arg, name, return_type))
182+
// TODO optimize the expr and make arg more meaningful
183+
(AggregateFunction::ApproxDistinct, _) => {
184+
Arc::new(expressions::ApproxDistinct::new(
185+
first_coerced_phy_expr,
186+
name,
187+
coerced_types[0].clone(),
188+
))
179189
}
190+
(AggregateFunction::ArrayAgg, _) => Arc::new(expressions::ArrayAgg::new(
191+
first_coerced_phy_expr,
192+
name,
193+
coerced_types[0].clone(),
194+
)),
195+
(AggregateFunction::Min, _) => Arc::new(expressions::Min::new(
196+
first_coerced_phy_expr,
197+
name,
198+
return_type,
199+
)),
200+
(AggregateFunction::Max, _) => Arc::new(expressions::Max::new(
201+
first_coerced_phy_expr,
202+
name,
203+
return_type,
204+
)),
205+
(AggregateFunction::Avg, false) => Arc::new(expressions::Avg::new(
206+
first_coerced_phy_expr,
207+
name,
208+
return_type,
209+
)),
180210
(AggregateFunction::Avg, true) => {
181211
return Err(DataFusionError::NotImplemented(
182212
"AVG(DISTINCT) aggregations are not available".to_string(),
@@ -209,6 +239,7 @@ static TIMESTAMPS: &[DataType] = &[
209239

210240
static DATES: &[DataType] = &[DataType::Date32, DataType::Date64];
211241

242+
// TODO liukun4515
212243
/// the signatures supported by the function `fun`.
213244
pub fn signature(fun: &AggregateFunction) -> Signature {
214245
// note: the physical expression must accept the type returned by this function or the execution panics.
@@ -244,6 +275,29 @@ mod tests {
244275

245276
let observed = return_type(&AggregateFunction::Max, &[DataType::Int32])?;
246277
assert_eq!(DataType::Int32, observed);
278+
279+
let observed = return_type(&AggregateFunction::Min, &[DataType::Decimal(10, 6)])?;
280+
assert_eq!(DataType::Decimal(10, 6), observed);
281+
282+
let observed =
283+
return_type(&AggregateFunction::Max, &[DataType::Decimal(28, 13)])?;
284+
assert_eq!(DataType::Decimal(28, 13), observed);
285+
286+
// TODO liukun4515
287+
// Add test to create agg expr
288+
let fun = AggregateFunction::Min;
289+
let input_schema =
290+
Schema::new(vec![Field::new("c1", DataType::Decimal(10, 6), false)]);
291+
// let input_schema =
292+
// Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
293+
let expr: Arc<dyn PhysicalExpr> =
294+
Arc::new(expressions::Column::new_with_schema("c1", &input_schema).unwrap());
295+
let args = vec![expr];
296+
let result_agg =
297+
create_aggregate_expr(&fun, false, &args[0..1], &input_schema, "c1")?;
298+
// assert the filed type
299+
assert_eq!(input_schema.field(0), &result_agg.field().unwrap());
300+
assert_eq!("min(c1)", result_agg.name());
247301
Ok(())
248302
}
249303

@@ -267,6 +321,10 @@ mod tests {
267321

268322
let observed = return_type(&AggregateFunction::Count, &[DataType::Int8])?;
269323
assert_eq!(DataType::UInt64, observed);
324+
325+
let observed =
326+
return_type(&AggregateFunction::Count, &[DataType::Decimal(28, 13)])?;
327+
assert_eq!(DataType::UInt64, observed);
270328
Ok(())
271329
}
272330

0 commit comments

Comments
 (0)