Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: make Ctes a struct to also store data types provided by prepare stmt #4520

Merged
merged 2 commits into from
Dec 6, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions datafusion/core/tests/sqllogictests/src/insert/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ use datafusion::datasource::MemTable;
use datafusion::prelude::SessionContext;
use datafusion_common::{DFSchema, DataFusionError};
use datafusion_expr::Expr as DFExpr;
use datafusion_sql::planner::SqlToRel;
use datafusion_sql::planner::{Ctes, SqlToRel};
use sqlparser::ast::{Expr, SetExpr, Statement as SQLStatement};
use std::collections::HashMap;
use std::sync::Arc;

pub async fn insert(ctx: &SessionContext, insert_stmt: &SQLStatement) -> Result<String> {
Expand Down Expand Up @@ -65,9 +64,7 @@ pub async fn insert(ctx: &SessionContext, insert_stmt: &SQLStatement) -> Result<
for row in insert_values.into_iter() {
let logical_exprs = row
.into_iter()
.map(|expr| {
sql_to_rel.sql_to_rex(expr, &DFSchema::empty(), &mut HashMap::new())
})
.map(|expr| sql_to_rel.sql_to_rex(expr, &DFSchema::empty(), &mut Ctes::new()))
.collect::<std::result::Result<Vec<DFExpr>, DataFusionError>>()?;
// Directly use `select` to get `RecordBatch`
let dataframe = ctx.read_empty()?;
Expand Down
110 changes: 64 additions & 46 deletions datafusion/sql/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,31 @@ pub struct ParserOptions {
parse_float_as_decimal: bool,
}

#[derive(Debug, Clone)]
/// Struct to store Common Table Expression provided with WITH clause and Parameter Data Types provided with PREPARE statement
pub struct Ctes {
/// Data type provided with prepare statement
pub prepare_param_data_types: Vec<DataType>,
/// Map of CTE name to logical plan of the WITH clause
pub ctes: HashMap<String, LogicalPlan>,
}

impl Default for Ctes {
fn default() -> Self {
Self::new()
}
}

impl Ctes {
/// Create a new Ctes
pub fn new() -> Self {
Self {
prepare_param_data_types: vec![],
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is always empty in this PR. After this is merged, I will use and set it in PR #4490

ctes: HashMap::new(),
}
}
}

/// SQL query planner
pub struct SqlToRel<'a, S: ContextProvider> {
schema_provider: &'a S,
Expand Down Expand Up @@ -181,7 +206,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
describe_alias: _,
..
} => self.explain_statement_to_plan(verbose, analyze, *statement),
Statement::Query(query) => self.query_to_plan(*query, &mut HashMap::new()),
Statement::Query(query) => self.query_to_plan(*query, &mut Ctes::new()),
Statement::ShowVariable { variable } => self.show_variable_to_plan(&variable),
Statement::SetVariable {
local,
Expand All @@ -204,7 +229,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
&& table_properties.is_empty()
&& with_options.is_empty() =>
{
let plan = self.query_to_plan(*query, &mut HashMap::new())?;
let plan = self.query_to_plan(*query, &mut Ctes::new())?;
let input_schema = plan.schema();

let plan = if !columns.is_empty() {
Expand Down Expand Up @@ -248,7 +273,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
with_options,
..
} if with_options.is_empty() => {
let mut plan = self.query_to_plan(*query, &mut HashMap::new())?;
let mut plan = self.query_to_plan(*query, &mut Ctes::new())?;
plan = Self::apply_expr_alias(plan, &columns)?;

Ok(LogicalPlan::CreateView(CreateView {
Expand Down Expand Up @@ -367,19 +392,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
}

/// Generate a logical plan from an SQL query
pub fn query_to_plan(
&self,
query: Query,
ctes: &mut HashMap<String, LogicalPlan>,
) -> Result<LogicalPlan> {
pub fn query_to_plan(&self, query: Query, ctes: &mut Ctes) -> Result<LogicalPlan> {
self.query_to_plan_with_alias(query, None, ctes, None)
}

/// Generate a logical plan from a SQL subquery
pub fn subquery_to_plan(
&self,
query: Query,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
outer_query_schema: &DFSchema,
) -> Result<LogicalPlan> {
self.query_to_plan_with_alias(query, None, ctes, Some(outer_query_schema))
Expand All @@ -390,7 +411,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
&self,
query: Query,
alias: Option<String>,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
outer_query_schema: Option<&DFSchema>,
) -> Result<LogicalPlan> {
let set_expr = query.body;
Expand All @@ -406,7 +427,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
for cte in with.cte_tables {
// A `WITH` block can't use the same name more than once
let cte_name = normalize_ident(&cte.alias.name);
if ctes.contains_key(&cte_name) {
if ctes.ctes.contains_key(&cte_name) {
return Err(DataFusionError::SQL(ParserError(format!(
"WITH query name {:?} specified more than once",
cte_name
Expand All @@ -424,7 +445,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
// projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2").
let logical_plan = self.apply_table_alias(logical_plan, cte.alias)?;

ctes.insert(cte_name, logical_plan);
ctes.ctes.insert(cte_name, logical_plan);
}
}
let plan = self.set_expr_to_plan(*set_expr, alias, ctes, outer_query_schema)?;
Expand All @@ -436,7 +457,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
&self,
set_expr: SetExpr,
alias: Option<String>,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
outer_query_schema: Option<&DFSchema>,
) -> Result<LogicalPlan> {
match set_expr {
Expand Down Expand Up @@ -618,7 +639,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
fn plan_from_tables(
&self,
mut from: Vec<TableWithJoins>,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
outer_query_schema: Option<&DFSchema>,
) -> Result<LogicalPlan> {
match from.len() {
Expand All @@ -644,7 +665,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
fn plan_table_with_joins(
&self,
t: TableWithJoins,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
outer_query_schema: Option<&DFSchema>,
) -> Result<LogicalPlan> {
// From clause may exist CTEs, we should separate them from global CTEs.
Expand Down Expand Up @@ -683,7 +704,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
&self,
left: LogicalPlan,
join: Join,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
outer_query_schema: Option<&DFSchema>,
) -> Result<LogicalPlan> {
let right = self.create_relation(join.relation, ctes, outer_query_schema)?;
Expand Down Expand Up @@ -722,7 +743,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
right: LogicalPlan,
constraint: JoinConstraint,
join_type: JoinType,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
) -> Result<LogicalPlan> {
match constraint {
JoinConstraint::On(sql_expr) => {
Expand Down Expand Up @@ -837,7 +858,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
fn create_relation(
&self,
relation: TableFactor,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
outer_query_schema: Option<&DFSchema>,
) -> Result<LogicalPlan> {
let (plan, alias) = match relation {
Expand All @@ -850,7 +871,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let table_name = normalize_sql_object_name(sql_object_name);
let table_ref: TableReference = table_name.as_str().into();
let table_alias = alias.as_ref().map(|a| normalize_ident(&a.name));
let cte = ctes.get(&table_name);
let cte = ctes.ctes.get(&table_name);
(
match (cte, self.schema_provider.get_table_provider(table_ref)) {
(Some(cte_plan), _) => match table_alias {
Expand Down Expand Up @@ -953,7 +974,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
selection: Option<SQLExpr>,
plan: LogicalPlan,
outer_query_schema: Option<&DFSchema>,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
) -> Result<LogicalPlan> {
match selection {
Some(predicate_expr) => {
Expand Down Expand Up @@ -1017,7 +1038,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
fn select_to_plan(
&self,
select: Select,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
alias: Option<String>,
outer_query_schema: Option<&DFSchema>,
) -> Result<LogicalPlan> {
Expand Down Expand Up @@ -1218,7 +1239,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
projection: Vec<SelectItem>,
empty_from: bool,
outer_query_schema: Option<&DFSchema>,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
from_schema: &DFSchema,
) -> Result<Vec<Expr>> {
projection
Expand Down Expand Up @@ -1371,7 +1392,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Some(skip_expr) => match self.sql_to_rex(
skip_expr.value,
input.schema(),
&mut HashMap::new(),
&mut Ctes::new(),
)? {
Expr::Literal(ScalarValue::Int64(Some(s))) => {
if s < 0 {
Expand All @@ -1394,7 +1415,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let n = match self.sql_to_rex(
limit_expr,
input.schema(),
&mut HashMap::new(),
&mut Ctes::new(),
)? {
Expr::Literal(ScalarValue::Int64(Some(n))) => Ok(n as usize),
_ => Err(DataFusionError::Plan(
Expand Down Expand Up @@ -1456,7 +1477,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
let field = schema.field(field_index - 1);
Expr::Column(field.qualified_column())
}
e => self.sql_expr_to_logical_expr(e, schema, &mut HashMap::new())?,
e => self.sql_expr_to_logical_expr(e, schema, &mut Ctes::new())?,
};
Ok({
let asc = asc.unwrap_or(true);
Expand Down Expand Up @@ -1537,7 +1558,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
plan: &LogicalPlan,
empty_from: bool,
outer_query_schema: Option<&DFSchema>,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
from_schema: &DFSchema,
) -> Result<Vec<Expr>> {
let input_schema = match outer_query_schema {
Expand Down Expand Up @@ -1586,7 +1607,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
&self,
sql: SQLExpr,
schema: &DFSchema,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
) -> Result<Expr> {
let mut expr = self.sql_expr_to_logical_expr(sql, schema, ctes)?;
expr = self.rewrite_partial_qualifier(expr, schema);
Expand Down Expand Up @@ -1626,7 +1647,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
&self,
sql: FunctionArg,
schema: &DFSchema,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
) -> Result<Expr> {
match sql {
FunctionArg::Named {
Expand Down Expand Up @@ -1654,7 +1675,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
op: BinaryOperator,
right: SQLExpr,
schema: &DFSchema,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
) -> Result<Expr> {
let operator = match op {
BinaryOperator::Gt => Ok(Operator::Gt),
Expand Down Expand Up @@ -1698,7 +1719,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
op: UnaryOperator,
expr: SQLExpr,
schema: &DFSchema,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
) -> Result<Expr> {
match op {
UnaryOperator::Not => Ok(Expr::Not(Box::new(
Expand Down Expand Up @@ -1747,19 +1768,16 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Ok(Expr::Literal(ScalarValue::Null))
}
SQLExpr::Value(Value::Boolean(n)) => Ok(lit(n)),
SQLExpr::UnaryOp { op, expr } => self.parse_sql_unary_op(
op,
*expr,
&schema,
&mut HashMap::new(),
),
SQLExpr::UnaryOp { op, expr } => {
self.parse_sql_unary_op(op, *expr, &schema, &mut Ctes::new())
}
SQLExpr::BinaryOp { left, op, right } => self
.parse_sql_binary_op(
*left,
op,
*right,
&schema,
&mut HashMap::new(),
&mut Ctes::new(),
),
SQLExpr::TypedString { data_type, value } => {
Ok(Expr::Cast(Cast::new(
Expand All @@ -1771,7 +1789,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
Box::new(self.sql_expr_to_logical_expr(
*expr,
&schema,
&mut HashMap::new(),
&mut Ctes::new(),
)?),
self.convert_data_type(&data_type)?,
))),
Expand All @@ -1790,7 +1808,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
&self,
sql: SQLExpr,
schema: &DFSchema,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
) -> Result<Expr> {
match sql {
SQLExpr::Value(Value::Number(n, _)) => self.parse_sql_number(&n),
Expand Down Expand Up @@ -2338,7 +2356,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
subquery: Query,
negated: bool,
input_schema: &DFSchema,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
) -> Result<Expr> {
Ok(Expr::Exists {
subquery: Subquery {
Expand All @@ -2358,7 +2376,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
subquery: Query,
negated: bool,
input_schema: &DFSchema,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
) -> Result<Expr> {
Ok(Expr::InSubquery {
expr: Box::new(self.sql_to_rex(expr, input_schema, ctes)?),
Expand All @@ -2377,7 +2395,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
&self,
subquery: Query,
input_schema: &DFSchema,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
) -> Result<Expr> {
Ok(Expr::ScalarSubquery(Subquery {
subquery: Arc::new(self.subquery_to_plan(subquery, ctes, input_schema)?),
Expand All @@ -2388,7 +2406,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
&self,
array_agg: ArrayAgg,
input_schema: &DFSchema,
ctes: &mut HashMap<String, LogicalPlan>,
ctes: &mut Ctes,
) -> Result<Expr> {
// Some dialects have special syntax for array_agg. DataFusion only supports it like a function.
let ArrayAgg {
Expand Down Expand Up @@ -2437,7 +2455,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
schema: &DFSchema,
) -> Result<Vec<Expr>> {
args.into_iter()
.map(|a| self.sql_fn_arg_to_logical_expr(a, schema, &mut HashMap::new()))
.map(|a| self.sql_fn_arg_to_logical_expr(a, schema, &mut Ctes::new()))
.collect::<Result<Vec<Expr>>>()
}

Expand All @@ -2455,7 +2473,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => {
Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone()))
}
_ => self.sql_fn_arg_to_logical_expr(a, schema, &mut HashMap::new()),
_ => self.sql_fn_arg_to_logical_expr(a, schema, &mut Ctes::new()),
})
.collect::<Result<Vec<Expr>>>()?,
_ => self.function_args_to_expr(args, schema)?,
Expand Down Expand Up @@ -2708,7 +2726,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {

for element in elements {
let value =
self.sql_expr_to_logical_expr(element, schema, &mut HashMap::new())?;
self.sql_expr_to_logical_expr(element, schema, &mut Ctes::new())?;
match value {
Expr::Literal(scalar) => {
values.push(scalar);
Expand Down