diff --git a/datafusion/core/tests/sqllogictests/src/insert/mod.rs b/datafusion/core/tests/sqllogictests/src/insert/mod.rs index a8f24a051601..695b6d26d56b 100644 --- a/datafusion/core/tests/sqllogictests/src/insert/mod.rs +++ b/datafusion/core/tests/sqllogictests/src/insert/mod.rs @@ -24,9 +24,8 @@ use datafusion::datasource::MemTable; use datafusion::prelude::SessionContext; use datafusion_common::{DFSchema, DataFusionError}; use datafusion_expr::Expr as DFExpr; -use datafusion_sql::planner::SqlToRel; +use datafusion_sql::planner::{PlannerContext, SqlToRel}; use sqlparser::ast::{Expr, SetExpr, Statement as SQLStatement}; -use std::collections::HashMap; use std::sync::Arc; pub async fn insert(ctx: &SessionContext, insert_stmt: &SQLStatement) -> Result { @@ -66,7 +65,11 @@ pub async fn insert(ctx: &SessionContext, insert_stmt: &SQLStatement) -> Result< let logical_exprs = row .into_iter() .map(|expr| { - sql_to_rel.sql_to_rex(expr, &DFSchema::empty(), &mut HashMap::new()) + sql_to_rel.sql_to_rex( + expr, + &DFSchema::empty(), + &mut PlannerContext::new(), + ) }) .collect::, DataFusionError>>()?; // Directly use `select` to get `RecordBatch` diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 7dacadbac05e..77fc3a82cfcc 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -103,6 +103,32 @@ pub struct ParserOptions { parse_float_as_decimal: bool, } +#[derive(Debug, Clone)] +/// Struct to store Common Table Expression (CTE) provided with WITH clause and +/// Parameter Data Types provided with PREPARE statement +pub struct PlannerContext { + /// Data type provided with prepare statement + pub prepare_param_data_types: Vec, + /// Map of CTE name to logical plan of the WITH clause + pub ctes: HashMap, +} + +impl Default for PlannerContext { + fn default() -> Self { + Self::new() + } +} + +impl PlannerContext { + /// Create a new PlannerContext + pub fn new() -> Self { + Self { + prepare_param_data_types: vec![], + ctes: HashMap::new(), + } + } +} + /// SQL query planner pub struct SqlToRel<'a, S: ContextProvider> { schema_provider: &'a S, @@ -181,7 +207,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { describe_alias: _, .. } => self.explain_statement_to_plan(verbose, analyze, *statement), - Statement::Query(query) => self.query_to_plan(*query, &mut HashMap::new()), + Statement::Query(query) => { + self.query_to_plan(*query, &mut PlannerContext::new()) + } Statement::ShowVariable { variable } => self.show_variable_to_plan(&variable), Statement::SetVariable { local, @@ -204,7 +232,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { && table_properties.is_empty() && with_options.is_empty() => { - let plan = self.query_to_plan(*query, &mut HashMap::new())?; + let plan = self.query_to_plan(*query, &mut PlannerContext::new())?; let input_schema = plan.schema(); let plan = if !columns.is_empty() { @@ -248,7 +276,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { with_options, .. } if with_options.is_empty() => { - let mut plan = self.query_to_plan(*query, &mut HashMap::new())?; + let mut plan = self.query_to_plan(*query, &mut PlannerContext::new())?; plan = Self::apply_expr_alias(plan, &columns)?; Ok(LogicalPlan::CreateView(CreateView { @@ -370,19 +398,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub fn query_to_plan( &self, query: Query, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, ) -> Result { - self.query_to_plan_with_alias(query, None, ctes, None) + self.query_to_plan_with_alias(query, None, planner_context, None) } /// Generate a logical plan from a SQL subquery pub fn subquery_to_plan( &self, query: Query, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, outer_query_schema: &DFSchema, ) -> Result { - self.query_to_plan_with_alias(query, None, ctes, Some(outer_query_schema)) + self.query_to_plan_with_alias( + query, + None, + planner_context, + Some(outer_query_schema), + ) } /// Generate a logic plan from an SQL query with optional alias @@ -390,7 +423,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, query: Query, alias: Option, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, outer_query_schema: Option<&DFSchema>, ) -> Result { let set_expr = query.body; @@ -406,7 +439,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { for cte in with.cte_tables { // A `WITH` block can't use the same name more than once let cte_name = normalize_ident(&cte.alias.name); - if ctes.contains_key(&cte_name) { + if planner_context.ctes.contains_key(&cte_name) { return Err(DataFusionError::SQL(ParserError(format!( "WITH query name {:?} specified more than once", cte_name @@ -416,7 +449,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let logical_plan = self.query_to_plan_with_alias( *cte.query, Some(cte_name.clone()), - &mut ctes.clone(), + &mut planner_context.clone(), outer_query_schema, )?; @@ -424,10 +457,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2"). let logical_plan = self.apply_table_alias(logical_plan, cte.alias)?; - ctes.insert(cte_name, logical_plan); + planner_context.ctes.insert(cte_name, logical_plan); } } - let plan = self.set_expr_to_plan(*set_expr, alias, ctes, outer_query_schema)?; + let plan = + self.set_expr_to_plan(*set_expr, alias, planner_context, outer_query_schema)?; let plan = self.order_by(plan, query.order_by)?; self.limit(plan, query.offset, query.limit) } @@ -436,12 +470,12 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, set_expr: SetExpr, alias: Option, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, outer_query_schema: Option<&DFSchema>, ) -> Result { match set_expr { SetExpr::Select(s) => { - self.select_to_plan(*s, ctes, alias, outer_query_schema) + self.select_to_plan(*s, planner_context, alias, outer_query_schema) } SetExpr::Values(v) => self.sql_values_to_plan(v), SetExpr::SetOperation { @@ -455,10 +489,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SetQuantifier::Distinct | SetQuantifier::None => false, }; - let left_plan = - self.set_expr_to_plan(*left, None, ctes, outer_query_schema)?; - let right_plan = - self.set_expr_to_plan(*right, None, ctes, outer_query_schema)?; + let left_plan = self.set_expr_to_plan( + *left, + None, + planner_context, + outer_query_schema, + )?; + let right_plan = self.set_expr_to_plan( + *right, + None, + planner_context, + outer_query_schema, + )?; match (op, all) { (SetOperator::Union, true) => LogicalPlanBuilder::from(left_plan) .union(right_plan)? @@ -480,7 +522,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } } } - SetExpr::Query(q) => self.query_to_plan(*q, ctes), + SetExpr::Query(q) => self.query_to_plan(*q, planner_context), _ => Err(DataFusionError::NotImplemented(format!( "Query {} not implemented yet", set_expr @@ -618,19 +660,21 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fn plan_from_tables( &self, mut from: Vec, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, outer_query_schema: Option<&DFSchema>, ) -> Result { match from.len() { 0 => Ok(LogicalPlanBuilder::empty(true).build()?), 1 => { let from = from.remove(0); - self.plan_table_with_joins(from, ctes, outer_query_schema) + self.plan_table_with_joins(from, planner_context, outer_query_schema) } _ => { let plans = from .into_iter() - .map(|t| self.plan_table_with_joins(t, ctes, outer_query_schema)) + .map(|t| { + self.plan_table_with_joins(t, planner_context, outer_query_schema) + }) .collect::>>()?; let mut left = plans[0].clone(); for right in plans.iter().skip(1) { @@ -644,7 +688,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fn plan_table_with_joins( &self, t: TableWithJoins, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, outer_query_schema: Option<&DFSchema>, ) -> Result { // From clause may exist CTEs, we should separate them from global CTEs. @@ -652,28 +696,33 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Such as `select * from (WITH source AS (select 1 as e) SELECT * FROM source) t1, (WITH source AS (select 1 as e) SELECT * FROM source) t2;` which is valid. // So always use original global CTEs to plan CTEs in from clause. // Btw, don't need to add CTEs in from to global CTEs. - let origin_ctes = ctes.clone(); - let left = self.create_relation(t.relation, ctes, outer_query_schema)?; + let origin_planner_context = planner_context.clone(); + let left = + self.create_relation(t.relation, planner_context, outer_query_schema)?; match t.joins.len() { 0 => { - *ctes = origin_ctes; + *planner_context = origin_planner_context; Ok(left) } _ => { let mut joins = t.joins.into_iter(); - *ctes = origin_ctes.clone(); + *planner_context = origin_planner_context.clone(); let mut left = self.parse_relation_join( left, joins.next().unwrap(), // length of joins > 0 - ctes, + planner_context, outer_query_schema, )?; for join in joins { - *ctes = origin_ctes.clone(); - left = - self.parse_relation_join(left, join, ctes, outer_query_schema)?; + *planner_context = origin_planner_context.clone(); + left = self.parse_relation_join( + left, + join, + planner_context, + outer_query_schema, + )?; } - *ctes = origin_ctes; + *planner_context = origin_planner_context; Ok(left) } } @@ -683,22 +732,23 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, left: LogicalPlan, join: Join, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, outer_query_schema: Option<&DFSchema>, ) -> Result { - let right = self.create_relation(join.relation, ctes, outer_query_schema)?; + let right = + self.create_relation(join.relation, planner_context, outer_query_schema)?; match join.join_operator { JoinOperator::LeftOuter(constraint) => { - self.parse_join(left, right, constraint, JoinType::Left, ctes) + self.parse_join(left, right, constraint, JoinType::Left, planner_context) } JoinOperator::RightOuter(constraint) => { - self.parse_join(left, right, constraint, JoinType::Right, ctes) + self.parse_join(left, right, constraint, JoinType::Right, planner_context) } JoinOperator::Inner(constraint) => { - self.parse_join(left, right, constraint, JoinType::Inner, ctes) + self.parse_join(left, right, constraint, JoinType::Inner, planner_context) } JoinOperator::FullOuter(constraint) => { - self.parse_join(left, right, constraint, JoinType::Full, ctes) + self.parse_join(left, right, constraint, JoinType::Full, planner_context) } JoinOperator::CrossJoin => self.parse_cross_join(left, &right), other => Err(DataFusionError::NotImplemented(format!( @@ -722,7 +772,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { right: LogicalPlan, constraint: JoinConstraint, join_type: JoinType, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, ) -> Result { match constraint { JoinConstraint::On(sql_expr) => { @@ -730,7 +780,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let join_schema = left.schema().join(right.schema())?; // parse ON expression - let expr = self.sql_to_rex(sql_expr, &join_schema, ctes)?; + let expr = self.sql_to_rex(sql_expr, &join_schema, planner_context)?; // ambiguous check ensure_any_column_reference_is_unambiguous( @@ -837,7 +887,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fn create_relation( &self, relation: TableFactor, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, outer_query_schema: Option<&DFSchema>, ) -> Result { let (plan, alias) = match relation { @@ -850,7 +900,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let table_name = normalize_sql_object_name(sql_object_name); let table_ref: TableReference = table_name.as_str().into(); let table_alias = alias.as_ref().map(|a| normalize_ident(&a.name)); - let cte = ctes.get(&table_name); + let cte = planner_context.ctes.get(&table_name); ( match (cte, self.schema_provider.get_table_provider(table_ref)) { (Some(cte_plan), _) => match table_alias { @@ -877,7 +927,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let logical_plan = self.query_to_plan_with_alias( *subquery, None, - ctes, + planner_context, outer_query_schema, )?; let normalized_alias = alias.as_ref().map(|a| normalize_ident(&a.name)); @@ -891,7 +941,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { table_with_joins, alias, } => ( - self.plan_table_with_joins(*table_with_joins, ctes, outer_query_schema)?, + self.plan_table_with_joins( + *table_with_joins, + planner_context, + outer_query_schema, + )?, alias, ), // @todo Support TableFactory::TableFunction? @@ -953,7 +1007,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { selection: Option, plan: LogicalPlan, outer_query_schema: Option<&DFSchema>, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, ) -> Result { match selection { Some(predicate_expr) => { @@ -968,7 +1022,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } let x: Vec<&DFSchemaRef> = all_schemas.iter().collect(); - let filter_expr = self.sql_to_rex(predicate_expr, &join_schema, ctes)?; + let filter_expr = + self.sql_to_rex(predicate_expr, &join_schema, planner_context)?; let mut using_columns = HashSet::new(); expr_to_columns(&filter_expr, &mut using_columns)?; let filter_expr = normalize_col_with_schemas( @@ -1017,7 +1072,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fn select_to_plan( &self, select: Select, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, alias: Option, outer_query_schema: Option<&DFSchema>, ) -> Result { @@ -1036,15 +1091,20 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } // process `from` clause - let plan = self.plan_from_tables(select.from, ctes, outer_query_schema)?; + let plan = + self.plan_from_tables(select.from, planner_context, outer_query_schema)?; let empty_from = matches!(plan, LogicalPlan::EmptyRelation(_)); // build from schema for unqualifier column ambiguous check // we should get only one field for unqualifier column from schema. let from_schema = self.build_schema_for_ambiguous_check(&plan)?; // process `where` clause - let plan = - self.plan_selection(select.selection, plan, outer_query_schema, ctes)?; + let plan = self.plan_selection( + select.selection, + plan, + outer_query_schema, + planner_context, + )?; // process the SELECT expressions, with wildcards expanded. let select_exprs = self.prepare_select_exprs( @@ -1052,7 +1112,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { select.projection, empty_from, outer_query_schema, - ctes, + planner_context, &from_schema, )?; @@ -1068,8 +1128,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let having_expr_opt = select .having .map::, _>(|having_expr| { - let having_expr = - self.sql_expr_to_logical_expr(having_expr, &combined_schema, ctes)?; + let having_expr = self.sql_expr_to_logical_expr( + having_expr, + &combined_schema, + planner_context, + )?; // This step "dereferences" any aliases in the HAVING clause. // // This is how we support queries with HAVING expressions that @@ -1107,7 +1170,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .into_iter() .map(|e| { let group_by_expr = - self.sql_expr_to_logical_expr(e, &combined_schema, ctes)?; + self.sql_expr_to_logical_expr(e, &combined_schema, planner_context)?; // aliases from the projection can conflict with same-named expressions in the input let mut alias_map = alias_map.clone(); for f in plan.schema().fields() { @@ -1199,7 +1262,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let x = select .distribute_by .iter() - .map(|e| self.sql_expr_to_logical_expr(e.clone(), &combined_schema, ctes)) + .map(|e| { + self.sql_expr_to_logical_expr( + e.clone(), + &combined_schema, + planner_context, + ) + }) .collect::>>()?; LogicalPlanBuilder::from(plan) .repartition(Partitioning::DistributeBy(x))? @@ -1218,7 +1287,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { projection: Vec, empty_from: bool, outer_query_schema: Option<&DFSchema>, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, from_schema: &DFSchema, ) -> Result> { projection @@ -1229,7 +1298,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan, empty_from, outer_query_schema, - ctes, + planner_context, from_schema, ) }) @@ -1371,7 +1440,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Some(skip_expr) => match self.sql_to_rex( skip_expr.value, input.schema(), - &mut HashMap::new(), + &mut PlannerContext::new(), )? { Expr::Literal(ScalarValue::Int64(Some(s))) => { if s < 0 { @@ -1394,7 +1463,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let n = match self.sql_to_rex( limit_expr, input.schema(), - &mut HashMap::new(), + &mut PlannerContext::new(), )? { Expr::Literal(ScalarValue::Int64(Some(n))) => Ok(n as usize), _ => Err(DataFusionError::Plan( @@ -1456,7 +1525,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let field = schema.field(field_index - 1); Expr::Column(field.qualified_column()) } - e => self.sql_expr_to_logical_expr(e, schema, &mut HashMap::new())?, + e => self.sql_expr_to_logical_expr(e, schema, &mut PlannerContext::new())?, }; Ok({ let asc = asc.unwrap_or(true); @@ -1537,7 +1606,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { plan: &LogicalPlan, empty_from: bool, outer_query_schema: Option<&DFSchema>, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, from_schema: &DFSchema, ) -> Result> { let input_schema = match outer_query_schema { @@ -1551,12 +1620,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { match sql { SelectItem::UnnamedExpr(expr) => { - let expr = self.sql_to_rex(expr, &input_schema, ctes)?; + let expr = self.sql_to_rex(expr, &input_schema, planner_context)?; self.column_reference_ambiguous_check(from_schema, &[expr.clone()])?; Ok(vec![normalize_col(expr, plan)?]) } SelectItem::ExprWithAlias { expr, alias } => { - let select_expr = self.sql_to_rex(expr, &input_schema, ctes)?; + let select_expr = + self.sql_to_rex(expr, &input_schema, planner_context)?; self.column_reference_ambiguous_check( from_schema, &[select_expr.clone()], @@ -1586,9 +1656,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, sql: SQLExpr, schema: &DFSchema, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, ) -> Result { - let mut expr = self.sql_expr_to_logical_expr(sql, schema, ctes)?; + let mut expr = self.sql_expr_to_logical_expr(sql, schema, planner_context)?; expr = self.rewrite_partial_qualifier(expr, schema); self.validate_schema_satisfies_exprs(schema, &[expr.clone()])?; Ok(expr) @@ -1626,19 +1696,19 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, sql: FunctionArg, schema: &DFSchema, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, ) -> Result { match sql { FunctionArg::Named { name: _, arg: FunctionArgExpr::Expr(arg), - } => self.sql_expr_to_logical_expr(arg, schema, ctes), + } => self.sql_expr_to_logical_expr(arg, schema, planner_context), FunctionArg::Named { name: _, arg: FunctionArgExpr::Wildcard, } => Ok(Expr::Wildcard), FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)) => { - self.sql_expr_to_logical_expr(arg, schema, ctes) + self.sql_expr_to_logical_expr(arg, schema, planner_context) } FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => Ok(Expr::Wildcard), _ => Err(DataFusionError::NotImplemented(format!( @@ -1654,7 +1724,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { op: BinaryOperator, right: SQLExpr, schema: &DFSchema, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, ) -> Result { let operator = match op { BinaryOperator::Gt => Ok(Operator::Gt), @@ -1687,9 +1757,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { }?; Ok(Expr::BinaryExpr(BinaryExpr::new( - Box::new(self.sql_expr_to_logical_expr(left, schema, ctes)?), + Box::new(self.sql_expr_to_logical_expr(left, schema, planner_context)?), operator, - Box::new(self.sql_expr_to_logical_expr(right, schema, ctes)?), + Box::new(self.sql_expr_to_logical_expr(right, schema, planner_context)?), ))) } @@ -1698,13 +1768,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { op: UnaryOperator, expr: SQLExpr, schema: &DFSchema, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, ) -> Result { match op { UnaryOperator::Not => Ok(Expr::Not(Box::new( - self.sql_expr_to_logical_expr(expr, schema, ctes)?, + self.sql_expr_to_logical_expr(expr, schema, planner_context)?, ))), - UnaryOperator::Plus => Ok(self.sql_expr_to_logical_expr(expr, schema, ctes)?), + UnaryOperator::Plus => { + Ok(self.sql_expr_to_logical_expr(expr, schema, planner_context)?) + } UnaryOperator::Minus => { match expr { // optimization: if it's a number literal, we apply the negative operator @@ -1720,7 +1792,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { })?)), }, // not a literal, apply negative operator on expression - _ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr(expr, schema, ctes)?))), + _ => Ok(Expr::Negative(Box::new(self.sql_expr_to_logical_expr(expr, schema, planner_context)?))), } } _ => Err(DataFusionError::NotImplemented(format!( @@ -1751,7 +1823,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { op, *expr, &schema, - &mut HashMap::new(), + &mut PlannerContext::new(), ), SQLExpr::BinaryOp { left, op, right } => self .parse_sql_binary_op( @@ -1759,7 +1831,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { op, *right, &schema, - &mut HashMap::new(), + &mut PlannerContext::new(), ), SQLExpr::TypedString { data_type, value } => { Ok(Expr::Cast(Cast::new( @@ -1771,7 +1843,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { Box::new(self.sql_expr_to_logical_expr( *expr, &schema, - &mut HashMap::new(), + &mut PlannerContext::new(), )?), self.convert_data_type(&data_type)?, ))), @@ -1790,7 +1862,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, sql: SQLExpr, schema: &DFSchema, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, ) -> Result { match sql { SQLExpr::Value(Value::Number(n, _)) => self.parse_sql_number(&n), @@ -1801,7 +1873,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fun: BuiltinScalarFunction::DatePart, args: vec![ Expr::Literal(ScalarValue::Utf8(Some(format!("{}", field)))), - self.sql_expr_to_logical_expr(*expr, schema, ctes)?, + self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ], }), @@ -1857,7 +1929,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } SQLExpr::ArrayIndex { obj, indexes } => { - let expr = self.sql_expr_to_logical_expr(*obj, schema, ctes)?; + let expr = self.sql_expr_to_logical_expr(*obj, schema, planner_context)?; plan_indexed(expr, indexes) } @@ -1918,20 +1990,20 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { else_result, } => { let expr = if let Some(e) = operand { - Some(Box::new(self.sql_expr_to_logical_expr(*e, schema, ctes)?)) + Some(Box::new(self.sql_expr_to_logical_expr(*e, schema, planner_context)?)) } else { None }; let when_expr = conditions .into_iter() - .map(|e| self.sql_expr_to_logical_expr(e, schema, ctes)) + .map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)) .collect::>>()?; let then_expr = results .into_iter() - .map(|e| self.sql_expr_to_logical_expr(e, schema, ctes)) + .map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)) .collect::>>()?; let else_expr = if let Some(e) = else_result { - Some(Box::new(self.sql_expr_to_logical_expr(*e, schema, ctes)?)) + Some(Box::new(self.sql_expr_to_logical_expr(*e, schema, planner_context)?)) } else { None }; @@ -1951,7 +2023,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { expr, data_type, } => Ok(Expr::Cast(Cast::new( - Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?), + Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?), self.convert_data_type(&data_type)?, ))), @@ -1959,7 +2031,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { expr, data_type, } => Ok(Expr::TryCast { - expr: Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?), + expr: Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?), data_type: self.convert_data_type(&data_type)?, }), @@ -1972,38 +2044,38 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))), SQLExpr::IsNull(expr) => Ok(Expr::IsNull(Box::new( - self.sql_expr_to_logical_expr(*expr, schema, ctes)?, + self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ))), SQLExpr::IsNotNull(expr) => Ok(Expr::IsNotNull(Box::new( - self.sql_expr_to_logical_expr(*expr, schema, ctes)?, + self.sql_expr_to_logical_expr(*expr, schema, planner_context)?, ))), SQLExpr::IsDistinctFrom(left, right) => Ok(Expr::BinaryExpr(BinaryExpr::new( - Box::new(self.sql_expr_to_logical_expr(*left, schema, ctes)?), + Box::new(self.sql_expr_to_logical_expr(*left, schema, planner_context)?), Operator::IsDistinctFrom, - Box::new(self.sql_expr_to_logical_expr(*right, schema, ctes)?), + Box::new(self.sql_expr_to_logical_expr(*right, schema, planner_context)?), ))), SQLExpr::IsNotDistinctFrom(left, right) => Ok(Expr::BinaryExpr(BinaryExpr::new( - Box::new(self.sql_expr_to_logical_expr(*left, schema, ctes)?), + Box::new(self.sql_expr_to_logical_expr(*left, schema, planner_context)?), Operator::IsNotDistinctFrom, - Box::new(self.sql_expr_to_logical_expr(*right, schema, ctes)?), + Box::new(self.sql_expr_to_logical_expr(*right, schema, planner_context)?), ))), - SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?))), + SQLExpr::IsTrue(expr) => Ok(Expr::IsTrue(Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?))), - SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?))), + SQLExpr::IsFalse(expr) => Ok(Expr::IsFalse(Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?))), - SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?))), + SQLExpr::IsNotTrue(expr) => Ok(Expr::IsNotTrue(Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?))), - SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?))), + SQLExpr::IsNotFalse(expr) => Ok(Expr::IsNotFalse(Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?))), - SQLExpr::IsUnknown(expr) => Ok(Expr::IsUnknown(Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?))), + SQLExpr::IsUnknown(expr) => Ok(Expr::IsUnknown(Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?))), - SQLExpr::IsNotUnknown(expr) => Ok(Expr::IsNotUnknown(Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?))), + SQLExpr::IsNotUnknown(expr) => Ok(Expr::IsNotUnknown(Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?))), - SQLExpr::UnaryOp { op, expr } => self.parse_sql_unary_op(op, *expr, schema, ctes), + SQLExpr::UnaryOp { op, expr } => self.parse_sql_unary_op(op, *expr, schema, planner_context), SQLExpr::Between { expr, @@ -2011,10 +2083,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { low, high, } => Ok(Expr::Between(Between::new( - Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?), + Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?), negated, - Box::new(self.sql_expr_to_logical_expr(*low, schema, ctes)?), - Box::new(self.sql_expr_to_logical_expr(*high, schema, ctes)?), + Box::new(self.sql_expr_to_logical_expr(*low, schema, planner_context)?), + Box::new(self.sql_expr_to_logical_expr(*high, schema, planner_context)?), ))), SQLExpr::InList { @@ -2024,18 +2096,18 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } => { let list_expr = list .into_iter() - .map(|e| self.sql_expr_to_logical_expr(e, schema, ctes)) + .map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)) .collect::>>()?; Ok(Expr::InList { - expr: Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?), + expr: Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?), list: list_expr, negated, }) } SQLExpr::Like { negated, expr, pattern, escape_char } => { - let pattern = self.sql_expr_to_logical_expr(*pattern, schema, ctes)?; + let pattern = self.sql_expr_to_logical_expr(*pattern, schema, planner_context)?; let pattern_type = pattern.get_type(schema)?; if pattern_type != DataType::Utf8 && pattern_type != DataType::Null { return Err(DataFusionError::Plan( @@ -2044,14 +2116,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } Ok(Expr::Like(Like::new( negated, - Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?), + Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?), Box::new(pattern), escape_char, ))) } SQLExpr::ILike { negated, expr, pattern, escape_char } => { - let pattern = self.sql_expr_to_logical_expr(*pattern, schema, ctes)?; + let pattern = self.sql_expr_to_logical_expr(*pattern, schema, planner_context)?; let pattern_type = pattern.get_type(schema)?; if pattern_type != DataType::Utf8 && pattern_type != DataType::Null { return Err(DataFusionError::Plan( @@ -2060,14 +2132,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } Ok(Expr::ILike(Like::new( negated, - Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?), + Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?), Box::new(pattern), escape_char, ))) } SQLExpr::SimilarTo { negated, expr, pattern, escape_char } => { - let pattern = self.sql_expr_to_logical_expr(*pattern, schema, ctes)?; + let pattern = self.sql_expr_to_logical_expr(*pattern, schema, planner_context)?; let pattern_type = pattern.get_type(schema)?; if pattern_type != DataType::Utf8 && pattern_type != DataType::Null { return Err(DataFusionError::Plan( @@ -2076,7 +2148,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } Ok(Expr::SimilarTo(Like::new( negated, - Box::new(self.sql_expr_to_logical_expr(*expr, schema, ctes)?), + Box::new(self.sql_expr_to_logical_expr(*expr, schema, planner_context)?), Box::new(pattern), escape_char, ))) @@ -2086,7 +2158,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { left, op, right, - } => self.parse_sql_binary_op(*left, op, *right, schema, ctes), + } => self.parse_sql_binary_op(*left, op, *right, schema, planner_context), #[cfg(feature = "unicode_expressions")] SQLExpr::Substring { @@ -2096,24 +2168,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } => { let args = match (substring_from, substring_for) { (Some(from_expr), Some(for_expr)) => { - let arg = self.sql_expr_to_logical_expr(*expr, schema, ctes)?; + let arg = self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; let from_logic = - self.sql_expr_to_logical_expr(*from_expr, schema, ctes)?; + self.sql_expr_to_logical_expr(*from_expr, schema, planner_context)?; let for_logic = - self.sql_expr_to_logical_expr(*for_expr, schema, ctes)?; + self.sql_expr_to_logical_expr(*for_expr, schema, planner_context)?; vec![arg, from_logic, for_logic] } (Some(from_expr), None) => { - let arg = self.sql_expr_to_logical_expr(*expr, schema, ctes)?; + let arg = self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; let from_logic = - self.sql_expr_to_logical_expr(*from_expr, schema, ctes)?; + self.sql_expr_to_logical_expr(*from_expr, schema, planner_context)?; vec![arg, from_logic] } (None, Some(for_expr)) => { - let arg = self.sql_expr_to_logical_expr(*expr, schema, ctes)?; + let arg = self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; let from_logic = Expr::Literal(ScalarValue::Int64(Some(1))); let for_logic = - self.sql_expr_to_logical_expr(*for_expr, schema, ctes)?; + self.sql_expr_to_logical_expr(*for_expr, schema, planner_context)?; vec![arg, from_logic, for_logic] } (None, None) => { @@ -2159,10 +2231,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } None => BuiltinScalarFunction::Trim }; - let arg = self.sql_expr_to_logical_expr(*expr, schema, ctes)?; + let arg = self.sql_expr_to_logical_expr(*expr, schema, planner_context)?; let args = match trim_what { Some(to_trim) => { - let to_trim = self.sql_expr_to_logical_expr(*to_trim, schema, ctes)?; + let to_trim = self.sql_expr_to_logical_expr(*to_trim, schema, planner_context)?; vec![arg, to_trim] } None => vec![arg], @@ -2171,10 +2243,10 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } SQLExpr::AggregateExpressionWithFilter { expr, filter } => { - match self.sql_expr_to_logical_expr(*expr, schema, ctes)? { + match self.sql_expr_to_logical_expr(*expr, schema, planner_context)? { Expr::AggregateFunction { fun, args, distinct, .. - } => Ok(Expr::AggregateFunction { fun, args, distinct, filter: Some(Box::new(self.sql_expr_to_logical_expr(*filter, schema, ctes)?)) }), + } => Ok(Expr::AggregateFunction { fun, args, distinct, filter: Some(Box::new(self.sql_expr_to_logical_expr(*filter, schema, planner_context)?)) }), _ => Err(DataFusionError::Internal("AggregateExpressionWithFilter expression was not an AggregateFunction".to_string())) } } @@ -2208,7 +2280,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let partition_by = window .partition_by .into_iter() - .map(|e| self.sql_expr_to_logical_expr(e, schema, ctes)) + .map(|e| self.sql_expr_to_logical_expr(e, schema, planner_context)) .collect::>>()?; let order_by = window .order_by @@ -2306,25 +2378,25 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLExpr::Floor { expr, field: _field } => { let fun = BuiltinScalarFunction::Floor; - let args = vec![self.sql_expr_to_logical_expr(*expr, schema, ctes)?]; + let args = vec![self.sql_expr_to_logical_expr(*expr, schema, planner_context)?]; Ok(Expr::ScalarFunction { fun, args }) } SQLExpr::Ceil { expr, field: _field } => { let fun = BuiltinScalarFunction::Ceil; - let args = vec![self.sql_expr_to_logical_expr(*expr, schema, ctes)?]; + let args = vec![self.sql_expr_to_logical_expr(*expr, schema, planner_context)?]; Ok(Expr::ScalarFunction { fun, args }) } - SQLExpr::Nested(e) => self.sql_expr_to_logical_expr(*e, schema, ctes), + SQLExpr::Nested(e) => self.sql_expr_to_logical_expr(*e, schema, planner_context), - SQLExpr::Exists { subquery, negated } => self.parse_exists_subquery(*subquery, negated, schema, ctes), + SQLExpr::Exists { subquery, negated } => self.parse_exists_subquery(*subquery, negated, schema, planner_context), - SQLExpr::InSubquery { expr, subquery, negated } => self.parse_in_subquery(*expr, *subquery, negated, schema, ctes), + SQLExpr::InSubquery { expr, subquery, negated } => self.parse_in_subquery(*expr, *subquery, negated, schema, planner_context), - SQLExpr::Subquery(subquery) => self.parse_scalar_subquery(*subquery, schema, ctes), + SQLExpr::Subquery(subquery) => self.parse_scalar_subquery(*subquery, schema, planner_context), - SQLExpr::ArrayAgg(array_agg) => self.parse_array_agg(array_agg, schema, ctes), + SQLExpr::ArrayAgg(array_agg) => self.parse_array_agg(array_agg, schema, planner_context), _ => Err(DataFusionError::NotImplemented(format!( "Unsupported ast node in sqltorel: {:?}", @@ -2338,13 +2410,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { subquery: Query, negated: bool, input_schema: &DFSchema, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, ) -> Result { Ok(Expr::Exists { subquery: Subquery { subquery: Arc::new(self.subquery_to_plan( subquery, - ctes, + planner_context, input_schema, )?), }, @@ -2358,14 +2430,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { subquery: Query, negated: bool, input_schema: &DFSchema, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, ) -> Result { Ok(Expr::InSubquery { - expr: Box::new(self.sql_to_rex(expr, input_schema, ctes)?), + expr: Box::new(self.sql_to_rex(expr, input_schema, planner_context)?), subquery: Subquery { subquery: Arc::new(self.subquery_to_plan( subquery, - ctes, + planner_context, input_schema, )?), }, @@ -2377,10 +2449,14 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, subquery: Query, input_schema: &DFSchema, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, ) -> Result { Ok(Expr::ScalarSubquery(Subquery { - subquery: Arc::new(self.subquery_to_plan(subquery, ctes, input_schema)?), + subquery: Arc::new(self.subquery_to_plan( + subquery, + planner_context, + input_schema, + )?), })) } @@ -2388,7 +2464,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { &self, array_agg: ArrayAgg, input_schema: &DFSchema, - ctes: &mut HashMap, + planner_context: &mut PlannerContext, ) -> Result { // Some dialects have special syntax for array_agg. DataFusion only supports it like a function. let ArrayAgg { @@ -2419,7 +2495,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { )); } - let args = vec![self.sql_expr_to_logical_expr(*expr, input_schema, ctes)?]; + let args = + vec![self.sql_expr_to_logical_expr(*expr, input_schema, planner_context)?]; // next, aggregate built-ins let fun = AggregateFunction::ArrayAgg; @@ -2437,7 +2514,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, ) -> Result> { args.into_iter() - .map(|a| self.sql_fn_arg_to_logical_expr(a, schema, &mut HashMap::new())) + .map(|a| { + self.sql_fn_arg_to_logical_expr(a, schema, &mut PlannerContext::new()) + }) .collect::>>() } @@ -2455,7 +2534,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => { Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone())) } - _ => self.sql_fn_arg_to_logical_expr(a, schema, &mut HashMap::new()), + _ => self.sql_fn_arg_to_logical_expr( + a, + schema, + &mut PlannerContext::new(), + ), }) .collect::>>()?, _ => self.function_args_to_expr(args, schema)?, @@ -2707,8 +2790,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let mut values = Vec::with_capacity(elements.len()); for element in elements { - let value = - self.sql_expr_to_logical_expr(element, schema, &mut HashMap::new())?; + let value = self.sql_expr_to_logical_expr( + element, + schema, + &mut PlannerContext::new(), + )?; match value { Expr::Literal(scalar) => { values.push(scalar);