Skip to content

Commit

Permalink
Remove CountWildcardRule in Analyzer and move the functionality in Ex…
Browse files Browse the repository at this point in the history
…prPlanner, add `plan_aggregate` and `plan_window` to planner (#14689)

* count planner

* window

* update slt

* remove rule

* rm rule

* doc

* fix name

* fix name

* fix test

* tpch test

* fix avro

* rename

* switch to count(*)

* use count(*)

* rename

* doc

* rename window function

* fmt

* rm print

* upd logic

* count null
  • Loading branch information
jayzhan211 authored Feb 21, 2025
1 parent 22156b2 commit e03f9f6
Show file tree
Hide file tree
Showing 42 changed files with 652 additions and 442 deletions.
2 changes: 2 additions & 0 deletions datafusion/core/src/execution/session_state_defaults.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ impl SessionStateDefaults {
feature = "unicode_expressions"
))]
Arc::new(functions::planner::UserDefinedFunctionPlanner),
Arc::new(functions_aggregate::planner::AggregateFunctionPlanner),
Arc::new(functions_window::planner::WindowFunctionPlanner),
];

expr_planners
Expand Down
34 changes: 33 additions & 1 deletion datafusion/core/tests/dataframe/dataframe_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ use arrow::{
array::{Int32Array, StringArray},
record_batch::RecordBatch,
};
use datafusion_functions_aggregate::count::count_all;
use std::sync::Arc;

use datafusion::error::Result;
Expand All @@ -31,7 +32,7 @@ use datafusion::prelude::*;
use datafusion::assert_batches_eq;
use datafusion_common::{DFSchema, ScalarValue};
use datafusion_expr::expr::Alias;
use datafusion_expr::ExprSchemable;
use datafusion_expr::{table_scan, ExprSchemable, LogicalPlanBuilder};
use datafusion_functions_aggregate::expr_fn::{approx_median, approx_percentile_cont};
use datafusion_functions_nested::map::map;

Expand Down Expand Up @@ -1123,3 +1124,34 @@ async fn test_fn_map() -> Result<()> {

Ok(())
}

/// Checks how `count_all()` is rendered when it is used in an aggregate,
/// a projection and a sort of a plan built with `LogicalPlanBuilder`.
///
/// The plan is only built and formatted with `display_indent_schema`; it is
/// never executed.
///
/// # Errors
///
/// Returns the planning error if any builder step fails.
#[tokio::test]
async fn test_count_wildcard() -> Result<()> {
    let schema = Schema::new(vec![
        Field::new("a", DataType::UInt32, false),
        Field::new("b", DataType::UInt32, false),
        Field::new("c", DataType::UInt32, false),
    ]);

    // Propagate builder errors with `?` instead of panicking via `unwrap()`:
    // the test already returns `Result<()>`, so a failure is reported with the
    // underlying error rather than an opaque panic.
    let scan = table_scan(Some("test"), &schema, None)?.build()?;
    let plan = LogicalPlanBuilder::from(scan)
        .aggregate(vec![col("b")], vec![count_all()])?
        .project(vec![count_all()])?
        .sort(vec![count_all().sort(true, false)])?
        .build()?;

    // NOTE(review): every nested plan line below is indented by a single space;
    // confirm this matches the real `display_indent_schema` output (the
    // per-level indentation may have been collapsed when this text was copied).
    let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64]\
    \n Projection: count(*) [count(*):Int64]\
    \n Aggregate: groupBy=[[test.b]], aggr=[[count(*)]] [b:UInt32, count(*):Int64]\
    \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

    let formatted_plan = plan.display_indent_schema().to_string();
    assert_eq!(formatted_plan, expected);

    Ok(())
}
30 changes: 15 additions & 15 deletions datafusion/core/tests/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ use arrow::datatypes::{
};
use arrow::error::ArrowError;
use arrow::util::pretty::pretty_format_batches;
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_functions_aggregate::count::{count_all, count_udaf};
use datafusion_functions_aggregate::expr_fn::{
array_agg, avg, count, count_distinct, max, median, min, sum,
};
Expand Down Expand Up @@ -72,7 +73,7 @@ use datafusion_expr::expr::{GroupingSet, Sort, WindowFunction};
use datafusion_expr::var_provider::{VarProvider, VarType};
use datafusion_expr::{
cast, col, create_udf, exists, in_subquery, lit, out_ref_col, placeholder,
scalar_subquery, when, wildcard, Expr, ExprFunctionExt, ExprSchemable, LogicalPlan,
scalar_subquery, when, Expr, ExprFunctionExt, ExprSchemable, LogicalPlan,
ScalarFunctionImplementation, WindowFrame, WindowFrameBound, WindowFrameUnits,
WindowFunctionDefinition,
};
Expand Down Expand Up @@ -2463,8 +2464,8 @@ async fn test_count_wildcard_on_sort() -> Result<()> {
let df_results = ctx
.table("t1")
.await?
.aggregate(vec![col("b")], vec![count(wildcard())])?
.sort(vec![count(wildcard()).sort(true, false)])?
.aggregate(vec![col("b")], vec![count_all()])?
.sort(vec![count_all().sort(true, false)])?
.explain(false, false)?
.collect()
.await?;
Expand Down Expand Up @@ -2498,8 +2499,8 @@ async fn test_count_wildcard_on_where_in() -> Result<()> {
Arc::new(
ctx.table("t2")
.await?
.aggregate(vec![], vec![count(wildcard())])?
.select(vec![count(wildcard())])?
.aggregate(vec![], vec![count_all()])?
.select(vec![count_all()])?
.into_optimized_plan()?,
),
))?
Expand Down Expand Up @@ -2532,8 +2533,8 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> {
.filter(exists(Arc::new(
ctx.table("t2")
.await?
.aggregate(vec![], vec![count(wildcard())])?
.select(vec![count(wildcard())])?
.aggregate(vec![], vec![count_all()])?
.select(vec![count_all()])?
.into_unoptimized_plan(),
// Usually, into_optimized_plan() should be used here, but due to
// https://github.com/apache/datafusion/issues/5771,
Expand Down Expand Up @@ -2568,7 +2569,7 @@ async fn test_count_wildcard_on_window() -> Result<()> {
.await?
.select(vec![Expr::WindowFunction(WindowFunction::new(
WindowFunctionDefinition::AggregateUDF(count_udaf()),
vec![wildcard()],
vec![Expr::Literal(COUNT_STAR_EXPANSION)],
))
.order_by(vec![Sort::new(col("a"), false, true)])
.window_frame(WindowFrame::new_bounds(
Expand Down Expand Up @@ -2599,17 +2600,16 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> {
let sql_results = ctx
.sql("select count(*) from t1")
.await?
.select(vec![col("count(*)")])?
.explain(false, false)?
.collect()
.await?;

// add `.select(vec![count(wildcard())])?` to make sure we can analyze all node instead of just top node.
// add `.select(vec![count_all()])?` to make sure we can analyze all nodes instead of just the top node.
let df_results = ctx
.table("t1")
.await?
.aggregate(vec![], vec![count(wildcard())])?
.select(vec![count(wildcard())])?
.aggregate(vec![], vec![count_all()])?
.select(vec![count_all()])?
.explain(false, false)?
.collect()
.await?;
Expand Down Expand Up @@ -2646,8 +2646,8 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> {
ctx.table("t2")
.await?
.filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))?
.aggregate(vec![], vec![count(wildcard())])?
.select(vec![col(count(wildcard()).to_string())])?
.aggregate(vec![], vec![count_all()])?
.select(vec![col(count_all().to_string())])?
.into_unoptimized_plan(),
))
.gt(lit(ScalarValue::UInt8(Some(0)))),
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/tests/sql/explain_analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ async fn explain_logical_plan_only() {
let expected = vec![
vec![
"logical_plan",
"Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]\
"Aggregate: groupBy=[[]], aggr=[[count(*)]]\
\n SubqueryAlias: t\
\n Projection: \
\n Values: (Utf8(\"a\"), Int64(1), Int64(100)), (Utf8(\"a\"), Int64(2), Int64(150))"
Expand Down
1 change: 0 additions & 1 deletion datafusion/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2294,7 +2294,6 @@ impl Display for SchemaDisplay<'_> {
| Expr::OuterReferenceColumn(..)
| Expr::Placeholder(_)
| Expr::Wildcard { .. } => write!(f, "{}", self.0),

Expr::AggregateFunction(AggregateFunction { func, params }) => {
match func.schema_name(params) {
Ok(name) => {
Expand Down
62 changes: 53 additions & 9 deletions datafusion/expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,12 @@ use datafusion_common::{
config::ConfigOptions, file_options::file_type::FileType, not_impl_err, DFSchema,
Result, TableReference,
};
use sqlparser::ast;
use sqlparser::ast::{self, NullTreatment};

use crate::{AggregateUDF, Expr, GetFieldAccess, ScalarUDF, TableSource, WindowUDF};
use crate::{
AggregateUDF, Expr, GetFieldAccess, ScalarUDF, SortExpr, TableSource, WindowFrame,
WindowFunctionDefinition, WindowUDF,
};

/// Provides the `SQL` query planner meta-data about tables and
/// functions referenced in SQL statements, without a direct dependency on the
Expand Down Expand Up @@ -138,7 +141,7 @@ pub trait ExprPlanner: Debug + Send + Sync {

/// Plan an array literal, such as `[1, 2, 3]`
///
/// Returns origin expression arguments if not possible
/// Returns original expression arguments if not possible
fn plan_array_literal(
&self,
exprs: Vec<Expr>,
Expand All @@ -149,14 +152,14 @@ pub trait ExprPlanner: Debug + Send + Sync {

/// Plan a `POSITION` expression, such as `POSITION(<expr> in <expr>)`
///
/// returns origin expression arguments if not possible
/// Returns original expression arguments if not possible
fn plan_position(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
Ok(PlannerResult::Original(args))
}

/// Plan a dictionary literal, such as `{ key: value, ...}`
///
/// Returns origin expression arguments if not possible
/// Returns original expression arguments if not possible
fn plan_dictionary_literal(
&self,
expr: RawDictionaryExpr,
Expand All @@ -167,14 +170,14 @@ pub trait ExprPlanner: Debug + Send + Sync {

/// Plan an extract expression, such as `EXTRACT(month FROM foo)`
///
/// Returns origin expression arguments if not possible
/// Returns original expression arguments if not possible
fn plan_extract(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
Ok(PlannerResult::Original(args))
}

/// Plan a substring expression, such as `SUBSTRING(<expr> [FROM <expr>] [FOR <expr>])`
///
/// Returns origin expression arguments if not possible
/// Returns original expression arguments if not possible
fn plan_substring(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
Ok(PlannerResult::Original(args))
}
Expand All @@ -195,14 +198,14 @@ pub trait ExprPlanner: Debug + Send + Sync {

/// Plans an overlay expression, such as `overlay(str PLACING substr FROM pos [FOR count])`
///
/// Returns origin expression arguments if not possible
/// Returns original expression arguments if not possible
fn plan_overlay(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
Ok(PlannerResult::Original(args))
}

/// Plans a `make_map` expression, such as `make_map(key1, value1, key2, value2, ...)`
///
/// Returns origin expression arguments if not possible
/// Returns original expression arguments if not possible
fn plan_make_map(&self, args: Vec<Expr>) -> Result<PlannerResult<Vec<Expr>>> {
Ok(PlannerResult::Original(args))
}
Expand Down Expand Up @@ -230,6 +233,23 @@ pub trait ExprPlanner: Debug + Send + Sync {
fn plan_any(&self, expr: RawBinaryExpr) -> Result<PlannerResult<RawBinaryExpr>> {
Ok(PlannerResult::Original(expr))
}

/// Plans aggregate functions, such as `COUNT(<expr>)`
///
/// The pieces of the aggregate call are passed in as a [`RawAggregateExpr`].
///
/// Returns original expression arguments if not possible
///
/// The default implementation does no planning: it hands `expr` back
/// unchanged as [`PlannerResult::Original`].
fn plan_aggregate(
&self,
expr: RawAggregateExpr,
) -> Result<PlannerResult<RawAggregateExpr>> {
Ok(PlannerResult::Original(expr))
}

/// Plans window functions, such as `COUNT(<expr>) OVER (...)`
///
/// The pieces of the window call are passed in as a [`RawWindowExpr`].
///
/// Returns original expression arguments if not possible
///
/// The default implementation does no planning: it hands `expr` back
/// unchanged as [`PlannerResult::Original`].
fn plan_window(&self, expr: RawWindowExpr) -> Result<PlannerResult<RawWindowExpr>> {
Ok(PlannerResult::Original(expr))
}
}

/// An operator with two arguments to plan
Expand Down Expand Up @@ -266,6 +286,30 @@ pub struct RawDictionaryExpr {
pub values: Vec<Expr>,
}

/// This structure is used by `AggregateFunctionPlanner` to plan operators with
/// custom expressions.
///
/// It is the input and the "not handled" output of
/// [`ExprPlanner::plan_aggregate`].
#[derive(Debug, Clone)]
pub struct RawAggregateExpr {
/// The aggregate function being called.
pub func: Arc<AggregateUDF>,
/// Argument expressions of the call.
pub args: Vec<Expr>,
/// Whether the call is marked `DISTINCT`.
pub distinct: bool,
/// Optional filter expression attached to the call
/// (presumably the SQL `FILTER (WHERE ...)` clause).
pub filter: Option<Box<Expr>>,
/// Optional ordering attached to the call
/// (presumably the SQL `ORDER BY` inside the aggregate).
pub order_by: Option<Vec<SortExpr>>,
/// Optional null treatment, as parsed by `sqlparser`.
pub null_treatment: Option<NullTreatment>,
}

/// This structure is used by `WindowFunctionPlanner` to plan operators with
/// custom expressions.
///
/// It is the input and the "not handled" output of
/// [`ExprPlanner::plan_window`].
#[derive(Debug, Clone)]
pub struct RawWindowExpr {
/// Definition of the function evaluated over the window.
pub func_def: WindowFunctionDefinition,
/// Argument expressions of the call.
pub args: Vec<Expr>,
/// Partitioning expressions (presumably the `PARTITION BY` list).
pub partition_by: Vec<Expr>,
/// Sort expressions (presumably the window's `ORDER BY` list).
pub order_by: Vec<SortExpr>,
/// Frame of the window.
pub window_frame: WindowFrame,
/// Optional null treatment, as parsed by `sqlparser`.
pub null_treatment: Option<NullTreatment>,
}

/// Result of planning a raw expr with [`ExprPlanner`]
#[derive(Debug, Clone)]
pub enum PlannerResult<T> {
Expand Down
21 changes: 13 additions & 8 deletions datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -515,27 +515,32 @@ pub trait AggregateUDFImpl: Debug + Send + Sync {
null_treatment,
} = params;

let mut schema_name = String::new();
let mut display_name = String::new();

schema_name.write_fmt(format_args!(
display_name.write_fmt(format_args!(
"{}({}{})",
self.name(),
if *distinct { "DISTINCT " } else { "" },
expr_vec_fmt!(args)
))?;

if let Some(nt) = null_treatment {
schema_name.write_fmt(format_args!(" {}", nt))?;
display_name.write_fmt(format_args!(" {}", nt))?;
}
if let Some(fe) = filter {
schema_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
display_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?;
}
if let Some(order_by) = order_by {
schema_name
.write_fmt(format_args!(" ORDER BY [{}]", expr_vec_fmt!(order_by)))?;
if let Some(ob) = order_by {
display_name.write_fmt(format_args!(
" ORDER BY [{}]",
ob.iter()
.map(|o| format!("{o}"))
.collect::<Vec<String>>()
.join(", ")
))?;
}

Ok(schema_name)
Ok(display_name)
}

/// Returns the user-defined display name of function, given the arguments
Expand Down
Loading

0 comments on commit e03f9f6

Please sign in to comment.