Skip to content

Commit 450ccbb

Browse files
committed
support min,max for decimal data type
1 parent e3a682d commit 450ccbb

File tree

10 files changed

+384
-8
lines changed

10 files changed

+384
-8
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion::arrow::datatypes::{DataType, Field, Schema};
19+
use datafusion::error::Result;
20+
use datafusion::prelude::*;
21+
use std::sync::Arc;
22+
23+
/// This example demonstrates executing a simple query against an Arrow data source (CSV) and
24+
/// fetching results
25+
#[tokio::main]
26+
async fn main() -> Result<()> {
27+
// create local execution context
28+
let mut ctx = ExecutionContext::new();
29+
30+
let testdata = datafusion::test_util::arrow_test_data();
31+
32+
// schema with decimal type
33+
let schema = Arc::new(Schema::new(vec![
34+
Field::new("c1", DataType::Decimal(10, 6), false),
35+
Field::new("c2", DataType::Float64, false),
36+
Field::new("c3", DataType::Boolean, false),
37+
]));
38+
39+
// register csv file with the execution context
40+
ctx.register_csv(
41+
"aggregate_simple",
42+
&format!("{}/csv/aggregate_simple.csv", testdata),
43+
CsvReadOptions::new().schema(&schema),
44+
)
45+
.await?;
46+
47+
// execute the query
48+
let df = ctx.sql("select c1 from aggregate_simple").await?;
49+
50+
// print the results
51+
df.show().await?;
52+
53+
Ok(())
54+
}

datafusion/Cargo.toml

+4-2
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,10 @@ avro = ["avro-rs", "num-traits"]
5252
[dependencies]
5353
ahash = "0.7"
5454
hashbrown = { version = "0.11", features = ["raw"] }
55-
arrow = { version = "6.2.0", features = ["prettyprint"] }
56-
parquet = { version = "6.2.0", features = ["arrow"] }
55+
arrow = { path = "/Users/kliu3/Documents/github/arrow-rs/arrow", features = ["prettyprint"] }
56+
#arrow = { version = "6.2.0", features = ["prettyprint"] }
57+
#parquet = { version = "6.2.0", features = ["arrow"] }
58+
parquet = { path = "/Users/kliu3/Documents/github/arrow-rs/parquet", features = ["arrow"] }
5759
sqlparser = "0.12"
5860
paste = "^1.0"
5961
num_cpus = "1.13.0"

datafusion/src/execution/context.rs

+25
Original file line numberDiff line numberDiff line change
@@ -3895,6 +3895,31 @@ mod tests {
38953895
Ok(())
38963896
}
38973897

3898+
#[tokio::test]
3899+
async fn aggregate_decimal() -> Result<()> {
3900+
let mut ctx = ExecutionContext::new();
3901+
// schema with data
3902+
let schema = Arc::new(Schema::new(vec![
3903+
Field::new("c1", DataType::Decimal(10, 6), false),
3904+
Field::new("c2", DataType::Float64, false),
3905+
Field::new("c3", DataType::Boolean, false),
3906+
]));
3907+
3908+
ctx.register_csv(
3909+
"aggregate_simple",
3910+
"tests/aggregate_simple.csv",
3911+
CsvReadOptions::new().schema(&schema),
3912+
)
3913+
.await?;
3914+
3915+
// decimal query
3916+
let result = plan_and_collect(&mut ctx, "select min(c1) from aggregate_simple")
3917+
.await
3918+
.unwrap();
3919+
println!("{:?}", result);
3920+
Ok(())
3921+
}
3922+
38983923
#[tokio::test]
38993924
async fn create_external_table_with_timestamps() {
39003925
let mut ctx = ExecutionContext::new();

datafusion/src/physical_plan/aggregates.rs

+28
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,7 @@ static TIMESTAMPS: &[DataType] = &[
209209

210210
static DATES: &[DataType] = &[DataType::Date32, DataType::Date64];
211211

212+
// TODO liukun4515
212213
/// the signatures supported by the function `fun`.
213214
pub fn signature(fun: &AggregateFunction) -> Signature {
214215
// note: the physical expression must accept the type returned by this function or the execution panics.
@@ -244,6 +245,29 @@ mod tests {
244245

245246
let observed = return_type(&AggregateFunction::Max, &[DataType::Int32])?;
246247
assert_eq!(DataType::Int32, observed);
248+
249+
let observed = return_type(&AggregateFunction::Min, &[DataType::Decimal(10, 6)])?;
250+
assert_eq!(DataType::Decimal(10, 6), observed);
251+
252+
let observed =
253+
return_type(&AggregateFunction::Max, &[DataType::Decimal(28, 13)])?;
254+
assert_eq!(DataType::Decimal(28, 13), observed);
255+
256+
// TODO liukun4515
257+
// Add test to create agg expr
258+
let fun = AggregateFunction::Min;
259+
let input_schema =
260+
Schema::new(vec![Field::new("c1", DataType::Decimal(10, 6), false)]);
261+
// let input_schema =
262+
// Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
263+
let expr: Arc<dyn PhysicalExpr> =
264+
Arc::new(expressions::Column::new_with_schema("c1", &input_schema).unwrap());
265+
let args = vec![expr];
266+
let result_agg =
267+
create_aggregate_expr(&fun, false, &args[0..1], &input_schema, "c1")?;
268+
// assert the filed type
269+
assert_eq!(input_schema.field(0), &result_agg.field().unwrap());
270+
assert_eq!("min(c1)", result_agg.name());
247271
Ok(())
248272
}
249273

@@ -267,6 +291,10 @@ mod tests {
267291

268292
let observed = return_type(&AggregateFunction::Count, &[DataType::Int8])?;
269293
assert_eq!(DataType::UInt64, observed);
294+
295+
let observed =
296+
return_type(&AggregateFunction::Count, &[DataType::Decimal(28, 13)])?;
297+
assert_eq!(DataType::UInt64, observed);
270298
Ok(())
271299
}
272300

datafusion/src/physical_plan/expressions/coercion.rs

+82-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
use arrow::datatypes::DataType;
2121

2222
/// Determine if a DataType is signed numeric or not
23-
pub fn is_signed_numeric(dt: &DataType) -> bool {
23+
pub(crate) fn is_signed_numeric(dt: &DataType) -> bool {
2424
matches!(
2525
dt,
2626
DataType::Int8
@@ -29,12 +29,13 @@ pub fn is_signed_numeric(dt: &DataType) -> bool {
2929
| DataType::Int64
3030
| DataType::Float16
3131
| DataType::Float32
32-
| DataType::Float64
32+
| DataType::Float64 // TODO liukun4515
33+
// | DataType::Decimal(_,_)
3334
)
3435
}
3536

3637
/// Determine if a DataType is numeric or not
37-
pub fn is_numeric(dt: &DataType) -> bool {
38+
fn is_numeric(dt: &DataType) -> bool {
3839
is_signed_numeric(dt)
3940
|| match dt {
4041
DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
@@ -125,6 +126,11 @@ pub fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<Da
125126
return Some(lhs_type.clone());
126127
}
127128

129+
// TODO liukun4515
130+
// In the decimal data type, what to do if the metadata of decimal type is diff.
131+
// add decimal data type, diff operator we should have diff rule to do coercion.
132+
// first step, we can just support decimal type in case which left and right datatype are the same
133+
128134
// these are ordered from most informative to least informative so
129135
// that the coercion removes the least amount of information
130136
match (lhs_type, rhs_type) {
@@ -170,6 +176,8 @@ pub fn order_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataTy
170176
#[cfg(test)]
171177
mod tests {
172178
use super::*;
179+
use crate::arrow::datatypes::DataType::Int8;
180+
use arrow::datatypes::DataType::{Float32, Float64, Int16, Int32, Int64};
173181

174182
#[test]
175183
fn test_dictionary_type_coersion() {
@@ -192,4 +200,75 @@ mod tests {
192200
let rhs_type = Dictionary(Box::new(Int8), Box::new(Utf8));
193201
assert_eq!(dictionary_coercion(&lhs_type, &rhs_type), Some(Utf8));
194202
}
203+
204+
#[test]
205+
fn test_is_signed_numeric() {
206+
assert!(is_signed_numeric(&DataType::Int8));
207+
assert!(is_signed_numeric(&DataType::Int16));
208+
assert!(is_signed_numeric(&DataType::Int32));
209+
assert!(is_signed_numeric(&DataType::Int64));
210+
assert!(is_signed_numeric(&DataType::Float16));
211+
assert!(is_signed_numeric(&DataType::Float32));
212+
assert!(is_signed_numeric(&DataType::Float64));
213+
214+
// decimal data type
215+
assert!(is_signed_numeric(&DataType::Decimal(12, 2)));
216+
assert!(is_signed_numeric(&DataType::Decimal(14, 10)));
217+
218+
// negative test
219+
assert!(!is_signed_numeric(&DataType::UInt64));
220+
assert!(!is_signed_numeric(&DataType::UInt16));
221+
}
222+
223+
#[test]
224+
fn test_is_numeric() {
225+
assert!(is_numeric(&DataType::Int8));
226+
assert!(is_numeric(&DataType::Int16));
227+
assert!(is_numeric(&DataType::Int32));
228+
assert!(is_numeric(&DataType::Int64));
229+
assert!(is_numeric(&DataType::Float16));
230+
assert!(is_numeric(&DataType::Float32));
231+
assert!(is_numeric(&DataType::Float64));
232+
233+
// decimal data type
234+
assert!(is_numeric(&DataType::Decimal(12, 2)));
235+
assert!(is_numeric(&DataType::Decimal(14, 10)));
236+
237+
// unsigned test
238+
assert!(is_numeric(&DataType::UInt8));
239+
assert!(is_numeric(&DataType::UInt16));
240+
assert!(is_numeric(&DataType::UInt32));
241+
assert!(is_numeric(&DataType::UInt64));
242+
243+
// negative test
244+
assert!(!is_numeric(&DataType::Boolean));
245+
assert!(!is_numeric(&DataType::Date32));
246+
}
247+
248+
#[test]
249+
fn test_numerical_coercion() {
250+
// negative test
251+
assert_eq!(
252+
None,
253+
numerical_coercion(&DataType::Float64, &DataType::Binary)
254+
);
255+
assert_eq!(
256+
None,
257+
numerical_coercion(&DataType::Float64, &DataType::Utf8)
258+
);
259+
260+
// positive test
261+
let test_types = vec![Int8, Int16, Int32, Int64, Float32, Float64];
262+
let mut index = test_types.len();
263+
while index > 0 {
264+
let this_type = &test_types[index - 1];
265+
for i in 0..index {
266+
assert_eq!(
267+
Some(this_type.clone()),
268+
numerical_coercion(this_type, &test_types[i])
269+
);
270+
}
271+
index -= 1;
272+
}
273+
}
195274
}

datafusion/src/physical_plan/expressions/min_max.rs

+30
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,36 @@ mod tests {
487487
use arrow::datatypes::*;
488488
use arrow::record_batch::RecordBatch;
489489

490+
#[test]
491+
fn min_decimal() -> Result<()> {
492+
todo!()
493+
}
494+
495+
#[test]
496+
fn max_decimal() -> Result<()> {
497+
todo!()
498+
}
499+
500+
#[test]
501+
fn min_decimal_with_nulls() -> Result<()> {
502+
todo!()
503+
}
504+
505+
#[test]
506+
fn max_decimal_with_nulls() -> Result<()> {
507+
todo!()
508+
}
509+
510+
#[test]
511+
fn min_decimal_with_all_nulls() -> Result<()> {
512+
todo!()
513+
}
514+
515+
#[test]
516+
fn max_decimal_with_all_nulls() -> Result<()> {
517+
todo!()
518+
}
519+
490520
#[test]
491521
fn max_i32() -> Result<()> {
492522
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5]));

datafusion/src/physical_plan/functions.rs

+1
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ impl Signature {
144144
}
145145

146146
///A function's volatility, which defines the functions eligibility for certain optimizations
147+
///Ref from postgresql https://www.postgresql.org/docs/current/xfunc-volatility.html
147148
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
148149
pub enum Volatility {
149150
/// Immutable - An immutable function will always return the same output when given the same input. An example of this is [BuiltinScalarFunction::Cos].

datafusion/src/physical_plan/type_coercion.rs

+6-2
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ pub fn data_types(
7676
if current_types.is_empty() {
7777
return Ok(vec![]);
7878
}
79+
7980
let valid_types = get_valid_types(&signature.type_signature, current_types)?;
8081

8182
if valid_types
@@ -144,6 +145,8 @@ fn maybe_data_types(
144145
valid_types: &[DataType],
145146
current_types: &[DataType],
146147
) -> Option<Vec<DataType>> {
148+
// TODO liukun4515
149+
147150
if valid_types.len() != current_types.len() {
148151
return None;
149152
}
@@ -155,7 +158,6 @@ fn maybe_data_types(
155158
if current_type == valid_type {
156159
new_type.push(current_type.clone())
157160
} else {
158-
// attempt to coerce
159161
if can_coerce_from(valid_type, current_type) {
160162
new_type.push(valid_type.clone())
161163
} else {
@@ -171,9 +173,11 @@ fn maybe_data_types(
171173
/// (losslessly converted) into a value of `type_to`
172174
///
173175
/// See the module level documentation for more detail on coercion.
174-
pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
176+
fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool {
175177
use self::DataType::*;
178+
// TODO liukun4515
176179
match type_into {
180+
// TODO, decimal data type, we just support the decimal
177181
Int8 => matches!(type_from, Int8),
178182
Int16 => matches!(type_from, Int8 | Int16 | UInt8),
179183
Int32 => matches!(type_from, Int8 | Int16 | Int32 | UInt8 | UInt16),

0 commit comments

Comments
 (0)