Skip to content

Commit

Permalink
Refactor Expr::Cast to use a struct. (#3931)
Browse files Browse the repository at this point in the history
* Refactor Expr::Cast to use a struct.

* fix

* fix fmt

* fix review
  • Loading branch information
jackwener authored Oct 24, 2022
1 parent e669480 commit b5c23c2
Show file tree
Hide file tree
Showing 20 changed files with 200 additions and 209 deletions.
40 changes: 20 additions & 20 deletions benchmarks/src/bin/tpch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -614,8 +614,8 @@ mod tests {

use datafusion::arrow::array::*;
use datafusion::arrow::util::display::array_value_to_string;
use datafusion::logical_expr::expr::Cast;
use datafusion::logical_expr::Expr;
use datafusion::logical_expr::Expr::Cast;
use datafusion::logical_expr::Expr::ScalarFunction;
use datafusion::sql::TableReference;

Expand Down Expand Up @@ -798,9 +798,9 @@ mod tests {
let path = Path::new(&path);
if let Ok(expected) = read_text_file(path) {
assert_eq!(expected, actual,
// generate output that is easier to copy/paste/update
"\n\nMismatch of expected content in: {:?}\nExpected:\n\n{}\n\nActual:\n\n{}\n\n",
path, expected, actual);
// generate output that is easier to copy/paste/update
"\n\nMismatch of expected content in: {:?}\nExpected:\n\n{}\n\nActual:\n\n{}\n\n",
path, expected, actual);
found = true;
break;
}
Expand Down Expand Up @@ -1264,10 +1264,10 @@ mod tests {
args: vec![col(Field::name(field)).mul(lit(100))],
}.div(lit(100)));
Expr::Alias(
Box::new(Cast {
expr: round,
data_type: DataType::Decimal128(38, 2),
}),
Box::new(Expr::Cast(Cast::new(
round,
DataType::Decimal128(38, 2),
))),
Field::name(field).to_string(),
)
}
Expand Down Expand Up @@ -1343,23 +1343,23 @@ mod tests {
DataType::Decimal128(_, _) => {
// there's no support for casting from Utf8 to Decimal, so
// we'll cast from Utf8 to Float64 to Decimal for Decimal types
let inner_cast = Box::new(Cast {
expr: Box::new(trim(col(Field::name(field)))),
data_type: DataType::Float64,
});
let inner_cast = Box::new(Expr::Cast(Cast::new(
Box::new(trim(col(Field::name(field)))),
DataType::Float64,
)));
Expr::Alias(
Box::new(Cast {
expr: inner_cast,
data_type: Field::data_type(field).to_owned(),
}),
Box::new(Expr::Cast(Cast::new(
inner_cast,
Field::data_type(field).to_owned(),
))),
Field::name(field).to_string(),
)
}
_ => Expr::Alias(
Box::new(Cast {
expr: Box::new(trim(col(Field::name(field)))),
data_type: Field::data_type(field).to_owned(),
}),
Box::new(Expr::Cast(Cast::new(
Box::new(trim(col(Field::name(field)))),
Field::data_type(field).to_owned(),
))),
Field::name(field).to_string(),
),
}
Expand Down
13 changes: 6 additions & 7 deletions datafusion/core/src/physical_optimizer/pruning.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use arrow::{
record_batch::RecordBatch,
};
use datafusion_common::{downcast_value, ScalarValue};
use datafusion_expr::expr::BinaryExpr;
use datafusion_expr::expr::{BinaryExpr, Cast};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
use datafusion_expr::utils::expr_to_columns;
use datafusion_expr::{binary_expr, cast, try_cast, ExprSchemable};
Expand Down Expand Up @@ -190,11 +190,10 @@ impl PruningPredicate {
let predicate_array = downcast_value!(array, BooleanArray);

Ok(predicate_array
.into_iter()
.map(|x| x.unwrap_or(true)) // None -> true per comments above
.collect::<Vec<_>>())

},
.into_iter()
.map(|x| x.unwrap_or(true)) // None -> true per comments above
.collect::<Vec<_>>())
}
// result was a single scalar boolean rather than an array
ColumnarValue::Scalar(ScalarValue::Boolean(v)) => {
let v = v.unwrap_or(true); // None -> true per comments above
Expand Down Expand Up @@ -530,7 +529,7 @@ fn rewrite_expr_to_prunable(
// `col op lit()`
Expr::Column(_) => Ok((column_expr.clone(), op, scalar_expr.clone())),
// `cast(col) op lit()`
Expr::Cast { expr, data_type } => {
Expr::Cast(Cast { expr, data_type }) => {
let from_type = expr.get_type(&schema)?;
verify_support_type_for_prune(&from_type, data_type)?;
let (left, op, right) =
Expand Down
51 changes: 27 additions & 24 deletions datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,9 @@ use arrow::compute::SortOptions;
use arrow::datatypes::{Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_expr::expr::{Between, BinaryExpr, GetIndexedField, GroupingSet, Like};
use datafusion_expr::expr::{
Between, BinaryExpr, Cast, GetIndexedField, GroupingSet, Like,
};
use datafusion_expr::expr_rewriter::unnormalize_cols;
use datafusion_expr::utils::{expand_wildcard, expr_to_columns};
use datafusion_expr::WindowFrameUnits;
Expand Down Expand Up @@ -126,7 +128,7 @@ fn create_physical_name(e: &Expr, is_first_expr: bool) -> Result<String> {
name += "END";
Ok(name)
}
Expr::Cast { expr, .. } => {
Expr::Cast(Cast { expr, .. }) => {
// CAST does not change the expression name
create_physical_name(expr, false)
}
Expand Down Expand Up @@ -462,7 +464,7 @@ impl DefaultPhysicalPlanner {
) -> BoxFuture<'a, Result<Arc<dyn ExecutionPlan>>> {
async move {
let exec_plan: Result<Arc<dyn ExecutionPlan>> = match logical_plan {
LogicalPlan::TableScan (TableScan {
LogicalPlan::TableScan(TableScan {
source,
projection,
filters,
Expand All @@ -484,7 +486,7 @@ impl DefaultPhysicalPlanner {
let exec_schema = schema.as_ref().to_owned().into();
let exprs = values.iter()
.map(|row| {
row.iter().map(|expr|{
row.iter().map(|expr| {
self.create_physical_expr(
expr,
schema,
Expand All @@ -497,7 +499,7 @@ impl DefaultPhysicalPlanner {
.collect::<Result<Vec<_>>>()?;
let value_exec = ValuesExec::try_new(
SchemaRef::new(exec_schema),
exprs
exprs,
)?;
Ok(Arc::new(value_exec))
}
Expand Down Expand Up @@ -612,7 +614,7 @@ impl DefaultPhysicalPlanner {
window_expr,
input_exec,
physical_input_schema,
)?) )
)?))
}
LogicalPlan::Aggregate(Aggregate {
input,
Expand Down Expand Up @@ -692,16 +694,16 @@ impl DefaultPhysicalPlanner {
aggregates,
initial_aggr,
physical_input_schema.clone(),
)?) )
)?))
}
LogicalPlan::Distinct(Distinct {input}) => {
LogicalPlan::Distinct(Distinct { input }) => {
// Convert distinct to groupby with no aggregations
let group_expr = expand_wildcard(input.schema(), input)?;
let aggregate = LogicalPlan::Aggregate(Aggregate::try_new_with_schema(
input.clone(),
group_expr,
vec![],
input.schema().clone() // input schema and aggregate schema are the same in this case
let aggregate = LogicalPlan::Aggregate(Aggregate::try_new_with_schema(
input.clone(),
group_expr,
vec![],
input.schema().clone(), // input schema and aggregate schema are the same in this case
)?);
Ok(self.create_initial_plan(&aggregate, session_state).await?)
}
Expand Down Expand Up @@ -755,7 +757,7 @@ impl DefaultPhysicalPlanner {
Ok(Arc::new(ProjectionExec::try_new(
physical_exprs,
input_exec,
)?) )
)?))
}
LogicalPlan::Filter(filter) => {
let physical_input = self.create_initial_plan(filter.input(), session_state).await?;
Expand All @@ -768,14 +770,14 @@ impl DefaultPhysicalPlanner {
&input_schema,
session_state,
)?;
Ok(Arc::new(FilterExec::try_new(runtime_expr, physical_input)?) )
Ok(Arc::new(FilterExec::try_new(runtime_expr, physical_input)?))
}
LogicalPlan::Union(Union { inputs, .. }) => {
let physical_plans = futures::stream::iter(inputs)
.then(|lp| self.create_initial_plan(lp, session_state))
.try_collect::<Vec<_>>()
.await?;
Ok(Arc::new(UnionExec::new(physical_plans)) )
Ok(Arc::new(UnionExec::new(physical_plans)))
}
LogicalPlan::Repartition(Repartition {
input,
Expand Down Expand Up @@ -803,13 +805,13 @@ impl DefaultPhysicalPlanner {
Partitioning::Hash(runtime_expr, *n)
}
LogicalPartitioning::DistributeBy(_) => {
return Err(DataFusionError::NotImplemented("Physical plan does not support DistributeBy partitioning".to_string()))
return Err(DataFusionError::NotImplemented("Physical plan does not support DistributeBy partitioning".to_string()));
}
};
Ok(Arc::new(RepartitionExec::try_new(
physical_input,
physical_partitioning,
)?) )
)?))
}
LogicalPlan::Sort(Sort { expr, input, fetch, .. }) => {
let physical_input = self.create_initial_plan(input, session_state).await?;
Expand Down Expand Up @@ -852,7 +854,8 @@ impl DefaultPhysicalPlanner {
Arc::new(merge)
} else {
Arc::new(SortExec::try_new(sort_expr, physical_input, *fetch)?)
}) }
})
}
LogicalPlan::Join(Join {
left,
right,
Expand Down Expand Up @@ -922,14 +925,14 @@ impl DefaultPhysicalPlanner {
expr,
&filter_df_schema,
&filter_schema,
&session_state.execution_props
&session_state.execution_props,
)?;
let column_indices = join_utils::JoinFilter::build_column_indices(left_field_indices, right_field_indices);

Some(join_utils::JoinFilter::new(
filter_expr,
column_indices,
filter_schema
filter_schema,
))
}
_ => None
Expand Down Expand Up @@ -995,15 +998,15 @@ impl DefaultPhysicalPlanner {
*produce_one_row,
SchemaRef::new(schema.as_ref().to_owned().into()),
))),
LogicalPlan::SubqueryAlias(SubqueryAlias { input,.. }) => {
LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => {
match input.as_ref() {
LogicalPlan::TableScan(..) => {
self.create_initial_plan(input, session_state).await
}
_ => Err(DataFusionError::Plan("SubqueryAlias should only wrap TableScan".to_string()))
}
}
LogicalPlan::Limit(Limit { input, skip, fetch,.. }) => {
LogicalPlan::Limit(Limit { input, skip, fetch, .. }) => {
let input = self.create_initial_plan(input, session_state).await?;

// GlobalLimitExec requires a single partition for input
Expand Down Expand Up @@ -1055,7 +1058,7 @@ impl DefaultPhysicalPlanner {
SchemaRef::new(Schema::empty()),
)))
}
LogicalPlan::Explain (_) => Err(DataFusionError::Internal(
LogicalPlan::Explain(_) => Err(DataFusionError::Internal(
"Unsupported logical plan: Explain must be root of the plan".to_string(),
)),
LogicalPlan::Analyze(a) => {
Expand Down
10 changes: 5 additions & 5 deletions datafusion/core/tests/provider_filter_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use datafusion::physical_plan::{
use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;
use datafusion_common::DataFusionError;
use datafusion_expr::expr::BinaryExpr;
use datafusion_expr::expr::{BinaryExpr, Cast};
use std::ops::Deref;
use std::sync::Arc;

Expand Down Expand Up @@ -153,7 +153,7 @@ impl TableProvider for CustomProvider {
Expr::Literal(ScalarValue::Int16(Some(i))) => *i as i64,
Expr::Literal(ScalarValue::Int32(Some(i))) => *i as i64,
Expr::Literal(ScalarValue::Int64(Some(i))) => *i as i64,
Expr::Cast { expr, data_type: _ } => match expr.deref() {
Expr::Cast(Cast { expr, data_type: _ }) => match expr.deref() {
Expr::Literal(lit_value) => match lit_value {
ScalarValue::Int8(Some(v)) => *v as i64,
ScalarValue::Int16(Some(v)) => *v as i64,
Expand All @@ -163,21 +163,21 @@ impl TableProvider for CustomProvider {
return Err(DataFusionError::NotImplemented(format!(
"Do not support value {:?}",
other_value
)))
)));
}
},
other_expr => {
return Err(DataFusionError::NotImplemented(format!(
"Do not support expr {:?}",
other_expr
)))
)));
}
},
other_expr => {
return Err(DataFusionError::NotImplemented(format!(
"Do not support expr {:?}",
other_expr
)))
)));
}
};

Expand Down
Loading

0 comments on commit b5c23c2

Please sign in to comment.