add read decimal parquet test and prune test for decimal
liukun4515 committed Jul 24, 2022
1 parent 834924f commit ee1cc40
Showing 2 changed files with 147 additions and 2 deletions.
47 changes: 45 additions & 2 deletions datafusion/core/src/datasource/file_format/parquet.rs
@@ -525,8 +525,8 @@ mod tests {
     use crate::physical_plan::metrics::MetricValue;
     use crate::prelude::{SessionConfig, SessionContext};
     use arrow::array::{
-        ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array,
-        StringArray, TimestampNanosecondArray,
+        Array, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array,
+        Int32Array, StringArray, TimestampNanosecondArray,
     };
     use arrow::record_batch::RecordBatch;
     use async_trait::async_trait;
@@ -1023,6 +1023,49 @@ mod tests {
        Ok(())
    }

    #[tokio::test]
    async fn read_decimal_parquet() -> Result<()> {
        let session_ctx = SessionContext::new();
        let task_ctx = session_ctx.task_ctx();

        // parquet uses int32 as the physical type to store this decimal
        let exec = get_exec("int32_decimal.parquet", None, None).await?;
        let batches = collect(exec, task_ctx.clone()).await?;
        assert_eq!(1, batches.len());
        assert_eq!(1, batches[0].num_columns());
        let column = batches[0].column(0);
        assert_eq!(&DataType::Decimal(4, 2), column.data_type());

        // parquet uses int64 as the physical type to store this decimal
        let exec = get_exec("int64_decimal.parquet", None, None).await?;
        let batches = collect(exec, task_ctx.clone()).await?;
        assert_eq!(1, batches.len());
        assert_eq!(1, batches[0].num_columns());
        let column = batches[0].column(0);
        assert_eq!(&DataType::Decimal(10, 2), column.data_type());

        // parquet uses fixed-length binary as the physical type to store this decimal
        let exec = get_exec("fixed_length_decimal.parquet", None, None).await?;
        let batches = collect(exec, task_ctx.clone()).await?;
        assert_eq!(1, batches.len());
        assert_eq!(1, batches[0].num_columns());
        let column = batches[0].column(0);
        assert_eq!(&DataType::Decimal(25, 2), column.data_type());

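        // the legacy writer layout also stores this decimal as fixed-length binary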
        let exec = get_exec("fixed_length_decimal_legacy.parquet", None, None).await?;
        let batches = collect(exec, task_ctx.clone()).await?;
        assert_eq!(1, batches.len());
        assert_eq!(1, batches[0].num_columns());
        let column = batches[0].column(0);
        assert_eq!(&DataType::Decimal(25, 2), column.data_type());

        // parquet uses byte array as the physical type to store this decimal
        // TODO: arrow-rs does not yet support converting the byte array physical type to decimal
        // let exec = get_exec("byte_array_decimal.parquet", None, None).await?;

        Ok(())
    }

    fn assert_bytes_scanned(exec: Arc<dyn ExecutionPlan>, expected: usize) {
        let actual = exec
            .metrics()
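
For reference on what the asserted logical types mean: arrow represents a Decimal(p, s) value as an unscaled i128, so 1.23 at scale 2 is stored as the integer 123. A minimal sketch using the same arrow APIs this commit relies on (the sample values are illustrative, not taken from the test files):

    use arrow::array::{Array, DecimalArray};
    use arrow::datatypes::DataType;

    // Collect unscaled i128 values, then attach precision and scale.
    let array = [Some(123_i128), Some(-50), None]
        .into_iter()
        .collect::<DecimalArray>()
        .with_precision_and_scale(4, 2)
        .unwrap();
    assert_eq!(&DataType::Decimal(4, 2), array.data_type());

This also explains the precisions asserted above: Parquet stores decimals with precision at most 9 as INT32, at most 18 as INT64, and wider decimals as fixed-length (or plain) byte arrays.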
102 changes: 102 additions & 0 deletions datafusion/core/src/physical_optimizer/pruning.rs
@@ -800,10 +800,12 @@ mod tests {
     use crate::from_slice::FromSlice;
     use crate::logical_plan::{col, lit};
     use crate::{assert_batches_eq, physical_optimizer::pruning::StatisticsType};
+    use arrow::array::DecimalArray;
     use arrow::{
         array::{BinaryArray, Int32Array, Int64Array, StringArray},
         datatypes::{DataType, TimeUnit},
     };
+    use datafusion_common::ScalarValue;
     use std::collections::HashMap;

    #[derive(Debug)]
    struct ContainerStats {
        min: ArrayRef,
        max: ArrayRef,
    }

    impl ContainerStats {
        fn new_decimal128(
            min: impl IntoIterator<Item = Option<i128>>,
            max: impl IntoIterator<Item = Option<i128>>,
            precision: usize,
            scale: usize,
        ) -> Self {
            Self {
                min: Arc::new(
                    min.into_iter()
                        .collect::<DecimalArray>()
                        .with_precision_and_scale(precision, scale)
                        .unwrap(),
                ),
                max: Arc::new(
                    max.into_iter()
                        .collect::<DecimalArray>()
                        .with_precision_and_scale(precision, scale)
                        .unwrap(),
                ),
            }
        }

        fn new_i64(
            min: impl IntoIterator<Item = Option<i64>>,
            max: impl IntoIterator<Item = Option<i64>>,
        ) -> Self {
            Self {
                min: Arc::new(min.into_iter().collect::<Int64Array>()),
                max: Arc::new(max.into_iter().collect::<Int64Array>()),
            }
        }

        fn new_i32(
            min: impl IntoIterator<Item = Option<i32>>,
            max: impl IntoIterator<Item = Option<i32>>,
@@ -1418,6 +1452,74 @@
        Ok(())
    }

    #[test]
    fn prune_decimal_data() {
        // decimal(9,2)
        let schema = Arc::new(Schema::new(vec![Field::new(
            "s1",
            DataType::Decimal(9, 2),
            true,
        )]));
        // s1 > 5, written as the unscaled value 500 at scale 2
        let expr = col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 9, 2)));
        // If the data is written by Spark, the physical type in the Parquet file is INT32,
        // so we use INT32 statistics.
        let statistics = TestStatistics::new().with(
            "s1",
            ContainerStats::new_i32(
                vec![Some(0), Some(4), None, Some(3)], // min
                vec![Some(5), Some(6), Some(4), None], // max
            ),
        );
        let p = PruningPredicate::try_new(expr, schema).unwrap();
        let result = p.prune(&statistics).unwrap();
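        // The INT32 stats are cast to Decimal(9, 2), so the guard is effectively
        // max(s1) > 5.00: max 5 -> prune, 6 -> keep, 4 -> prune, None -> keep
        // (a container with unknown stats can never be pruned).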
        let expected = vec![false, true, false, true];
        assert_eq!(result, expected);

        // decimal(18,2)
        let schema = Arc::new(Schema::new(vec![Field::new(
            "s1",
            DataType::Decimal(18, 2),
            true,
        )]));
        // s1 > 5
        let expr = col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 18, 2)));
        // If the data is written by Spark, the physical type in the Parquet file is INT64,
        // so we use INT64 statistics.
        let statistics = TestStatistics::new().with(
            "s1",
            ContainerStats::new_i64(
                vec![Some(0), Some(4), None, Some(3)], // min
                vec![Some(5), Some(6), Some(4), None], // max
            ),
        );
        let p = PruningPredicate::try_new(expr, schema).unwrap();
        let result = p.prune(&statistics).unwrap();
        let expected = vec![false, true, false, true];
        assert_eq!(result, expected);

        // decimal(23,2)
        let schema = Arc::new(Schema::new(vec![Field::new(
            "s1",
            DataType::Decimal(23, 2),
            true,
        )]));
        // s1 > 5
        let expr = col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 23, 2)));
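        // Decimals with precision > 18 are stored as binary in Parquet, so the
        // statistics here are themselves Decimal(23, 2) arrays of unscaled i128 values.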
        let statistics = TestStatistics::new().with(
            "s1",
            ContainerStats::new_decimal128(
                vec![Some(0), Some(400), None, Some(300)], // min
                vec![Some(500), Some(600), Some(400), None], // max
                23,
                2,
            ),
        );
        let p = PruningPredicate::try_new(expr, schema).unwrap();
        let result = p.prune(&statistics).unwrap();
        let expected = vec![false, true, false, true];
        assert_eq!(result, expected);
    }

    #[test]
    fn prune_api() {
        let schema = Arc::new(Schema::new(vec![

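
As an end-to-end illustration of what these tests cover, one might read a decimal file and filter on it through the public API. A hedged sketch (assuming the SessionContext::read_parquet API of DataFusion in this era and a local copy of the test file; "value" is the column name used by the parquet-testing decimal files):

    use datafusion::prelude::*;
    use datafusion_common::ScalarValue;

    let ctx = SessionContext::new();
    let df = ctx
        .read_parquet("int32_decimal.parquet", ParquetReadOptions::default())
        .await?;
    // 5 at Decimal(4, 2) is the unscaled value 500; row groups whose max is
    // at or below 5.00 can be pruned using exactly the logic tested above.
    let df = df.filter(col("value").gt(lit(ScalarValue::Decimal128(Some(500), 4, 2))))?;
    let results = df.collect().await?;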