
Commit

Merge remote-tracking branch 'apache/master' into alamb/consolidate_parquet
alamb committed Oct 24, 2022
2 parents 8180c3a + e1f866e commit 1f2067f
Showing 39 changed files with 855 additions and 283 deletions.
14 changes: 12 additions & 2 deletions benchmarks/expected-plans/q15.txt
@@ -4,8 +4,18 @@ Sort: supplier.s_suppkey ASC NULLS LAST
Inner Join: revenue0.total_revenue = __sq_1.__value
Inner Join: supplier.s_suppkey = revenue0.supplier_no
TableScan: supplier projection=[s_suppkey, s_name, s_address, s_phone]
TableScan: revenue0 projection=[supplier_no, total_revenue]
Projection: supplier_no, total_revenue, alias=revenue0
Projection: lineitem.l_suppkey AS supplier_no, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue
Projection: lineitem.l_suppkey, SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)
Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[SUM(CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]]
Filter: lineitem.l_shipdate >= Date32("9496") AND lineitem.l_shipdate < Date32("9587")
TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate]
Projection: MAX(revenue0.total_revenue) AS __value, alias=__sq_1
Aggregate: groupBy=[[]], aggr=[[MAX(revenue0.total_revenue)]]
TableScan: revenue0 projection=[total_revenue]
Projection: total_revenue, alias=revenue0
Projection: SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS total_revenue
Projection: SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)
Aggregate: groupBy=[[lineitem.l_suppkey]], aggr=[[SUM(CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]]
Filter: lineitem.l_shipdate >= Date32("9496") AND lineitem.l_shipdate < Date32("9587")
TableScan: lineitem projection=[l_suppkey, l_extendedprice, l_discount, l_shipdate]
EmptyRelation
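Editor's note: the expected plan now appears to inline the definition of the `revenue0` view (a projection + aggregate over `lineitem`) where previously only a bare `TableScan: revenue0` appeared, consistent with `DataFrame`-backed tables exposing their logical plan via the new `TableProvider::get_logical_plan` (see `datafusion/core/src/dataframe.rs` below). A hedged sketch of how such a view might be registered — the SQL follows TPC-H Q15, but the exact registration path in the benchmark may differ:

```rust
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // assumes a `lineitem` table has already been registered with `ctx`
    let revenue0 = ctx
        .sql(
            "SELECT l_suppkey AS supplier_no, \
             SUM(l_extendedprice * (1 - l_discount)) AS total_revenue \
             FROM lineitem GROUP BY l_suppkey",
        )
        .await?;
    // DataFrame implements TableProvider (see dataframe.rs below), so the
    // optimizer can now see and inline the plan behind `revenue0` instead
    // of stopping at a bare table scan
    ctx.register_table("revenue0", revenue0)?;
    Ok(())
}
```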
40 changes: 20 additions & 20 deletions benchmarks/src/bin/tpch.rs
@@ -615,8 +615,8 @@ mod tests {

use datafusion::arrow::array::*;
use datafusion::arrow::util::display::array_value_to_string;
use datafusion::logical_expr::expr::Cast;
use datafusion::logical_expr::Expr;
use datafusion::logical_expr::Expr::Cast;
use datafusion::logical_expr::Expr::ScalarFunction;
use datafusion::sql::TableReference;
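The import swap above tracks an upstream API change picked up by this merge: `Cast` is no longer a struct variant of `Expr` but a standalone struct in `logical_expr::expr`, wrapped by the `Expr::Cast` tuple variant. A minimal before/after sketch, assuming the DataFusion 13.x-era signatures used in this file:

```rust
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::expr::Cast;
use datafusion::logical_expr::Expr;
use datafusion::prelude::col;

fn main() {
    // before: Cast was a struct variant, built directly as
    //   Expr::Cast { expr: Box::new(col("a")), data_type: DataType::Float64 }
    // after: build the standalone Cast struct, then wrap it in the tuple variant
    let cast = Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64));
    println!("{:?}", cast);
}
```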

@@ -799,9 +799,9 @@ mod tests {
let path = Path::new(&path);
if let Ok(expected) = read_text_file(path) {
assert_eq!(expected, actual,
// generate output that is easier to copy/paste/update
"\n\nMismatch of expected content in: {:?}\nExpected:\n\n{}\n\nActual:\n\n{}\n\n",
path, expected, actual);
// generate output that is easier to copy/paste/update
"\n\nMismatch of expected content in: {:?}\nExpected:\n\n{}\n\nActual:\n\n{}\n\n",
path, expected, actual);
found = true;
break;
}
@@ -1265,10 +1265,10 @@ mod tests {
args: vec![col(Field::name(field)).mul(lit(100))],
}.div(lit(100)));
Expr::Alias(
Box::new(Cast {
expr: round,
data_type: DataType::Decimal128(38, 2),
}),
Box::new(Expr::Cast(Cast::new(
round,
DataType::Decimal128(38, 2),
))),
Field::name(field).to_string(),
)
}
@@ -1344,23 +1344,23 @@
DataType::Decimal128(_, _) => {
// there's no support for casting from Utf8 to Decimal, so
// we'll cast from Utf8 to Float64 to Decimal for Decimal types
let inner_cast = Box::new(Cast {
expr: Box::new(trim(col(Field::name(field)))),
data_type: DataType::Float64,
});
let inner_cast = Box::new(Expr::Cast(Cast::new(
Box::new(trim(col(Field::name(field)))),
DataType::Float64,
)));
Expr::Alias(
Box::new(Cast {
expr: inner_cast,
data_type: Field::data_type(field).to_owned(),
}),
Box::new(Expr::Cast(Cast::new(
inner_cast,
Field::data_type(field).to_owned(),
))),
Field::name(field).to_string(),
)
}
_ => Expr::Alias(
Box::new(Cast {
expr: Box::new(trim(col(Field::name(field)))),
data_type: Field::data_type(field).to_owned(),
}),
Box::new(Expr::Cast(Cast::new(
Box::new(trim(col(Field::name(field)))),
Field::data_type(field).to_owned(),
))),
Field::name(field).to_string(),
),
}
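The arms above encode the workaround named in the comment: Arrow at this version cannot cast `Utf8` directly to `Decimal128`, so decimal columns take a two-step cast through `Float64`. A hedged, self-contained sketch of the same expression shape (hypothetical column name `price`; assumes `col`/`trim` are in scope as in this file):

```rust
use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::expr::Cast;
use datafusion::logical_expr::Expr;
use datafusion::prelude::{col, trim};

fn main() {
    // Utf8 -> Float64 first (direct Utf8 -> Decimal128 casts are unsupported)...
    let as_f64 = Expr::Cast(Cast::new(
        Box::new(trim(col("price"))),
        DataType::Float64,
    ));
    // ...then Float64 -> Decimal128(38, 2), aliased back to the column name
    let as_decimal = Expr::Alias(
        Box::new(Expr::Cast(Cast::new(
            Box::new(as_f64),
            DataType::Decimal128(38, 2),
        ))),
        "price".to_string(),
    );
    println!("{:?}", as_decimal);
}
```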
2 changes: 1 addition & 1 deletion datafusion/common/Cargo.toml
@@ -41,7 +41,7 @@ pyarrow = ["pyo3", "arrow/pyarrow"]
[dependencies]
apache-avro = { version = "0.14", default-features = false, features = ["snappy"], optional = true }
arrow = { version = "25.0.0", default-features = false }
cranelift-module = { version = "0.88.0", optional = true }
cranelift-module = { version = "0.89.0", optional = true }
object_store = { version = "0.5.0", default-features = false, optional = true }
ordered-float = "3.0"
parquet = { version = "25.0.0", default-features = false, optional = true }
79 changes: 78 additions & 1 deletion datafusion/common/src/scalar.rs
@@ -77,6 +77,8 @@ pub enum ScalarValue {
LargeUtf8(Option<String>),
/// binary
Binary(Option<Vec<u8>>),
/// fixed size binary
FixedSizeBinary(i32, Option<Vec<u8>>),
/// large binary
LargeBinary(Option<Vec<u8>>),
/// list of nested ScalarValue
@@ -159,6 +161,8 @@ impl PartialEq for ScalarValue {
(LargeUtf8(_), _) => false,
(Binary(v1), Binary(v2)) => v1.eq(v2),
(Binary(_), _) => false,
(FixedSizeBinary(_, v1), FixedSizeBinary(_, v2)) => v1.eq(v2),
(FixedSizeBinary(_, _), _) => false,
(LargeBinary(v1), LargeBinary(v2)) => v1.eq(v2),
(LargeBinary(_), _) => false,
(List(v1, t1), List(v2, t2)) => v1.eq(v2) && t1.eq(t2),
@@ -247,6 +251,8 @@ impl PartialOrd for ScalarValue {
(LargeUtf8(_), _) => None,
(Binary(v1), Binary(v2)) => v1.partial_cmp(v2),
(Binary(_), _) => None,
(FixedSizeBinary(_, v1), FixedSizeBinary(_, v2)) => v1.partial_cmp(v2),
(FixedSizeBinary(_, _), _) => None,
(LargeBinary(v1), LargeBinary(v2)) => v1.partial_cmp(v2),
(LargeBinary(_), _) => None,
(List(v1, t1), List(v2, t2)) => {
@@ -536,6 +542,7 @@ impl std::hash::Hash for ScalarValue {
Utf8(v) => v.hash(state),
LargeUtf8(v) => v.hash(state),
Binary(v) => v.hash(state),
FixedSizeBinary(_, v) => v.hash(state),
LargeBinary(v) => v.hash(state),
List(v, t) => {
v.hash(state);
@@ -900,6 +907,7 @@ impl ScalarValue {
ScalarValue::Utf8(_) => DataType::Utf8,
ScalarValue::LargeUtf8(_) => DataType::LargeUtf8,
ScalarValue::Binary(_) => DataType::Binary,
ScalarValue::FixedSizeBinary(sz, _) => DataType::FixedSizeBinary(*sz),
ScalarValue::LargeBinary(_) => DataType::LargeBinary,
ScalarValue::List(_, field) => DataType::List(Box::new(Field::new(
"item",
@@ -987,6 +995,7 @@ impl ScalarValue {
ScalarValue::Utf8(v) => v.is_none(),
ScalarValue::LargeUtf8(v) => v.is_none(),
ScalarValue::Binary(v) => v.is_none(),
ScalarValue::FixedSizeBinary(_, v) => v.is_none(),
ScalarValue::LargeBinary(v) => v.is_none(),
ScalarValue::List(v, _) => v.is_none(),
ScalarValue::Date32(v) => v.is_none(),
@@ -1393,13 +1402,30 @@ impl ScalarValue {
_ => unreachable!("Invalid dictionary keys type: {:?}", key_type),
}
}
DataType::FixedSizeBinary(_) => {
let array = scalars
.map(|sv| {
if let ScalarValue::FixedSizeBinary(_, v) = sv {
Ok(v)
} else {
Err(DataFusionError::Internal(format!(
"Inconsistent types in ScalarValue::iter_to_array. \
Expected {:?}, got {:?}",
data_type, sv
)))
}
})
.collect::<Result<Vec<_>>>()?;
let array =
FixedSizeBinaryArray::try_from_sparse_iter(array.into_iter())?;
Arc::new(array)
}
// explicitly enumerate unsupported types so newly added
// types must be acknowledged
DataType::Float16
| DataType::Time32(_)
| DataType::Time64(_)
| DataType::Duration(_)
| DataType::FixedSizeBinary(_)
| DataType::FixedSizeList(_, _)
| DataType::Interval(_)
| DataType::LargeList(_)
@@ -1602,6 +1628,20 @@ impl ScalarValue {
Arc::new(repeat(None::<&str>).take(size).collect::<BinaryArray>())
}
},
ScalarValue::FixedSizeBinary(_, e) => match e {
Some(value) => Arc::new(
FixedSizeBinaryArray::try_from_sparse_iter(
repeat(Some(value.as_slice())).take(size),
)
.unwrap(),
),
None => Arc::new(
FixedSizeBinaryArray::try_from_sparse_iter(
repeat(None::<&[u8]>).take(size),
)
.unwrap(),
),
},
ScalarValue::LargeBinary(e) => match e {
Some(value) => Arc::new(
repeat(Some(value.as_slice()))
@@ -1887,6 +1927,23 @@ impl ScalarValue {
};
ScalarValue::new_list(value, nested_type.data_type().clone())
}
DataType::FixedSizeBinary(_) => {
let array = array
.as_any()
.downcast_ref::<FixedSizeBinaryArray>()
.unwrap();
let size = match array.data_type() {
DataType::FixedSizeBinary(size) => *size,
_ => unreachable!(),
};
ScalarValue::FixedSizeBinary(
size,
match array.is_null(index) {
true => None,
false => Some(array.value(index).into()),
},
)
}
other => {
return Err(DataFusionError::NotImplemented(format!(
"Can't create a scalar from array of type \"{:?}\"",
@@ -1973,6 +2030,9 @@ impl ScalarValue {
ScalarValue::Binary(val) => {
eq_array_primitive!(array, index, BinaryArray, val)
}
ScalarValue::FixedSizeBinary(_, val) => {
eq_array_primitive!(array, index, FixedSizeBinaryArray, val)
}
ScalarValue::LargeBinary(val) => {
eq_array_primitive!(array, index, LargeBinaryArray, val)
}
@@ -2317,6 +2377,17 @@ impl fmt::Display for ScalarValue {
)?,
None => write!(f, "NULL")?,
},
ScalarValue::FixedSizeBinary(_, e) => match e {
Some(l) => write!(
f,
"{}",
l.iter()
.map(|v| format!("{}", v))
.collect::<Vec<_>>()
.join(",")
)?,
None => write!(f, "NULL")?,
},
ScalarValue::LargeBinary(e) => match e {
Some(l) => write!(
f,
@@ -2397,6 +2468,12 @@ impl fmt::Debug for ScalarValue {
ScalarValue::LargeUtf8(Some(_)) => write!(f, "LargeUtf8(\"{}\")", self),
ScalarValue::Binary(None) => write!(f, "Binary({})", self),
ScalarValue::Binary(Some(_)) => write!(f, "Binary(\"{}\")", self),
ScalarValue::FixedSizeBinary(size, None) => {
write!(f, "FixedSizeBinary({}, {})", size, self)
}
ScalarValue::FixedSizeBinary(size, Some(_)) => {
write!(f, "FixedSizeBinary({}, \"{}\")", size, self)
}
ScalarValue::LargeBinary(None) => write!(f, "LargeBinary({})", self),
ScalarValue::LargeBinary(Some(_)) => write!(f, "LargeBinary(\"{}\")", self),
ScalarValue::List(_, _) => write!(f, "List([{}])", self),
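Taken together, the scalar.rs hunks thread a new `ScalarValue::FixedSizeBinary(i32, Option<Vec<u8>>)` variant through equality, ordering, hashing, array conversion, and formatting. A small usage sketch, assuming this era's `get_datatype`/`to_array_of_size`/`eq_array` accessor names:

```rust
use datafusion::arrow::datatypes::DataType;
use datafusion::scalar::ScalarValue;

fn main() {
    // a non-null 4-byte fixed-size binary scalar
    let v = ScalarValue::FixedSizeBinary(4, Some(vec![0xde, 0xad, 0xbe, 0xef]));
    assert_eq!(v.get_datatype(), DataType::FixedSizeBinary(4));
    assert!(!v.is_null());

    // materialize three copies as a FixedSizeBinaryArray ...
    let array = v.to_array_of_size(3);
    assert_eq!(array.len(), 3);
    // ... and compare an element back against the scalar
    assert!(v.eq_array(&array, 0));
}
```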
67 changes: 56 additions & 11 deletions datafusion/core/src/dataframe.rs
@@ -36,9 +36,9 @@ use crate::physical_plan::SendableRecordBatchStream;
use crate::physical_plan::{collect, collect_partitioned};
use crate::physical_plan::{execute_stream, execute_stream_partitioned, ExecutionPlan};
use crate::prelude::SessionContext;
use crate::scalar::ScalarValue;
use async_trait::async_trait;
use datafusion_common::{Column, DFSchema};
use datafusion_expr::TableProviderFilterPushDown;
use parking_lot::RwLock;
use parquet::file::properties::WriterProperties;
use std::any::Any;
@@ -773,6 +773,18 @@ impl TableProvider for DataFrame {
self
}

fn get_logical_plan(&self) -> Option<&LogicalPlan> {
Some(&self.plan)
}

fn supports_filter_pushdown(
&self,
_filter: &Expr,
) -> Result<TableProviderFilterPushDown> {
// A filter is added on the DataFrame when given
Ok(TableProviderFilterPushDown::Exact)
}

fn schema(&self) -> SchemaRef {
let schema: Schema = self.plan.schema().as_ref().into();
Arc::new(schema)
@@ -789,7 +801,7 @@ impl TableProvider for DataFrame {
filters: &[Expr],
limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let expr = projection
let mut expr = projection
.as_ref()
// construct projections
.map_or_else(
@@ -806,12 +818,12 @@
.collect::<Vec<_>>();
self.select_columns(names.as_slice())
},
)?
// add predicates, otherwise use `true` as the predicate
.filter(filters.iter().cloned().fold(
Expr::Literal(ScalarValue::Boolean(Some(true))),
|acc, new| acc.and(new),
))?;
)?;
// Add filter when given
let filter = filters.iter().cloned().reduce(|acc, new| acc.and(new));
if let Some(filter) = filter {
expr = expr.filter(filter)?
}
// add a limit if given
Self::new(
self.session_state.clone(),
@@ -830,9 +842,10 @@ mod tests {
use std::vec;

use super::*;
use crate::execution::options::CsvReadOptions;
use crate::execution::options::{CsvReadOptions, ParquetReadOptions};
use crate::physical_plan::ColumnarValue;
use crate::test_util;
use crate::test_util::parquet_test_data;
use crate::{assert_batches_sorted_eq, execution::context::SessionContext};
use arrow::array::Int32Array;
use arrow::datatypes::DataType;
@@ -1328,8 +1341,12 @@ mod tests {
\n Limit: skip=0, fetch=1\
\n Sort: t1.c1 ASC NULLS FIRST, t1.c2 ASC NULLS FIRST, t1.c3 ASC NULLS FIRST, t2.c1 ASC NULLS FIRST, t2.c2 ASC NULLS FIRST, t2.c3 ASC NULLS FIRST, fetch=1\
\n Inner Join: t1.c1 = t2.c1\
\n TableScan: t1 projection=[c1, c2, c3]\
\n TableScan: t2 projection=[c1, c2, c3]",
\n Projection: aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c3, alias=t1\
\n Projection: aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c3\
\n TableScan: aggregate_test_100 projection=[c1, c2, c3]\
\n Projection: aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c3, alias=t2\
\n Projection: aggregate_test_100.c1, aggregate_test_100.c2, aggregate_test_100.c3\
\n TableScan: aggregate_test_100 projection=[c1, c2, c3]",
format!("{:?}", df_renamed.to_logical_plan()?)
);

@@ -1349,6 +1366,34 @@
Ok(())
}

#[tokio::test]
async fn filter_pushdown_dataframe() -> Result<()> {
let ctx = SessionContext::new();

ctx.register_parquet(
"test",
&format!("{}/alltypes_plain.snappy.parquet", parquet_test_data()),
ParquetReadOptions::default(),
)
.await?;

ctx.register_table("t1", ctx.table("test")?)?;

let df = ctx
.table("t1")?
.filter(col("id").eq(lit(1)))?
.select_columns(&["bool_col", "int_col"])?;

let plan = df.explain(false, false)?.collect().await?;
// Filters all the way to Parquet
let formatted = arrow::util::pretty::pretty_format_batches(&plan)
.unwrap()
.to_string();
assert!(formatted.contains("predicate=id_min@0 <= 1 AND 1 <= id_max@1"));

Ok(())
}

#[tokio::test]
async fn cast_expr_test() -> Result<()> {
let df = test_table()
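Two behavioral changes in `scan` above: filters are now combined with `reduce` rather than folded from a literal `true` seed, so no `Filter` node is added when the pushed-down list is empty; and `supports_filter_pushdown` reports `Exact`, letting predicates travel all the way to the Parquet scan (as the new `filter_pushdown_dataframe` test asserts). A minimal sketch of the `reduce` pattern with hypothetical column names:

```rust
use datafusion::logical_expr::Expr;
use datafusion::prelude::{col, lit};

fn main() {
    // Some(f1 AND f2 AND ...) when filters exist
    let filters = vec![col("id").eq(lit(1)), col("x").gt(lit(0))];
    let combined: Option<Expr> = filters.into_iter().reduce(|acc, new| acc.and(new));
    assert!(combined.is_some());

    // with no filters, None: no Filter node needs to be added at all
    let empty: Vec<Expr> = vec![];
    assert!(empty.into_iter().reduce(|acc, new| acc.and(new)).is_none());
}
```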