Skip to content

Commit

Permalink
Fix unparser derived table with columns include calculations, limit/o…
Browse files Browse the repository at this point in the history
…rder/distinct (#24)

* compare format output to make sure the two level of projects match

* add method to find inner projection that could be nested under limit/order/distinct

* use format! for matching in unparser sort optimization too

* refactor

* use to_string and also put comments in
  • Loading branch information
y-f-u authored Jul 29, 2024
1 parent ffe792d commit 1c65aff
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 65 deletions.
66 changes: 4 additions & 62 deletions datafusion/sql/src/unparser/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ use super::{
BuilderError, DerivedRelationBuilder, QueryBuilder, RelationBuilder,
SelectBuilder, TableRelationBuilder, TableWithJoinsBuilder,
},
rewrite::{normalize_union_schema, rewrite_plan_for_sort_on_non_projected_fields},
rewrite::{
normalize_union_schema, rewrite_plan_for_sort_on_non_projected_fields,
subquery_alias_inner_query_and_columns,
},
utils::{find_agg_node_within_select, unproject_window_exprs, AggVariant},
Unparser,
};
Expand Down Expand Up @@ -604,67 +607,6 @@ impl Unparser<'_> {
}
}

// This logic is to work out the columns and inner query for SubqueryAlias plan for both types of
// subquery
// - `(SELECT column_a as a from table) AS A`
// - `(SELECT column_a from table) AS A (a)`
//
// A roundtrip example for table alias with columns
//
// query: SELECT id FROM (SELECT j1_id from j1) AS c (id)
//
// LogicPlan:
// Projection: c.id
// SubqueryAlias: c
// Projection: j1.j1_id AS id
// Projection: j1.j1_id
// TableScan: j1
//
// Before introducing this logic, the unparsed query would be `SELECT c.id FROM (SELECT j1.j1_id AS
// id FROM (SELECT j1.j1_id FROM j1)) AS c`.
// The query is invalid as `j1.j1_id` is not a valid identifier in the derived table
// `(SELECT j1.j1_id FROM j1)`
//
// With this logic, the unparsed query will be:
// `SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)`
//
// Caveat: this won't handle the case like `select * from (select 1, 2) AS a (b, c)`
// as the parser gives a wrong plan which has mismatch `Int(1)` types: Literal and
// Column in the Projections. Once the parser side is fixed, this logic should work
fn subquery_alias_inner_query_and_columns(
subquery_alias: &datafusion_expr::SubqueryAlias,
) -> (&LogicalPlan, Vec<Ident>) {
let plan: &LogicalPlan = subquery_alias.input.as_ref();

let LogicalPlan::Projection(outer_projections) = plan else {
return (plan, vec![]);
};

// check if it's projection inside projection
let LogicalPlan::Projection(inner_projection) = outer_projections.input.as_ref()
else {
return (plan, vec![]);
};

let mut columns: Vec<Ident> = vec![];
// check if the inner projection and outer projection have a matching pattern like
// Projection: j1.j1_id AS id
// Projection: j1.j1_id
for (i, inner_expr) in inner_projection.expr.iter().enumerate() {
let Expr::Alias(ref outer_alias) = &outer_projections.expr[i] else {
return (plan, vec![]);
};

if outer_alias.expr.as_ref() != inner_expr {
return (plan, vec![]);
};

columns.push(outer_alias.name.as_str().into());
}

(outer_projections.input.as_ref(), columns)
}

impl From<BuilderError> for DataFusionError {
fn from(e: BuilderError) -> Self {
DataFusionError::External(Box::new(e))
Expand Down
93 changes: 90 additions & 3 deletions datafusion/sql/src/unparser/rewrite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use datafusion_common::{
Result,
};
use datafusion_expr::{Expr, LogicalPlan, Projection, Sort};
use sqlparser::ast::Ident;

/// Normalize the schema of a union plan to remove qualifiers from the schema fields and sort expressions.
///
Expand Down Expand Up @@ -143,6 +144,9 @@ pub(super) fn rewrite_plan_for_sort_on_non_projected_fields(
map.insert(a.clone(), f.clone());
a
} else {
// inner expr may have different type to outer expr: e.g. a + 1 is a column of
// string in outer, but a expr of math in inner
map.insert(Expr::Column(f.to_string().into()), f.clone());
f.clone()
}
})
Expand All @@ -155,9 +159,15 @@ pub(super) fn rewrite_plan_for_sort_on_non_projected_fields(
}
}

if collects.iter().collect::<HashSet<_>>()
== inner_exprs.iter().collect::<HashSet<_>>()
{
// inner expr may have different type to outer expr: e.g. a + 1 is a column of
// string in outer, but a expr of math in inner
let outer_collects = collects.iter().map(Expr::to_string).collect::<HashSet<_>>();
let inner_collects = inner_exprs
.iter()
.map(Expr::to_string)
.collect::<HashSet<_>>();

if outer_collects == inner_collects {
let mut sort = sort.clone();
let mut inner_p = inner_p.clone();

Expand All @@ -175,3 +185,80 @@ pub(super) fn rewrite_plan_for_sort_on_non_projected_fields(
None
}
}

// This logic is to work out the columns and inner query for SubqueryAlias plan for both types of
// subquery
// - `(SELECT column_a as a from table) AS A`
// - `(SELECT column_a from table) AS A (a)`
//
// A roundtrip example for table alias with columns
//
// query: SELECT id FROM (SELECT j1_id from j1) AS c (id)
//
// LogicPlan:
// Projection: c.id
// SubqueryAlias: c
// Projection: j1.j1_id AS id
// Projection: j1.j1_id
// TableScan: j1
//
// Before introducing this logic, the unparsed query would be `SELECT c.id FROM (SELECT j1.j1_id AS
// id FROM (SELECT j1.j1_id FROM j1)) AS c`.
// The query is invalid as `j1.j1_id` is not a valid identifier in the derived table
// `(SELECT j1.j1_id FROM j1)`
//
// With this logic, the unparsed query will be:
// `SELECT c.id FROM (SELECT j1.j1_id FROM j1) AS c (id)`
//
// Caveat: this won't handle the case like `select * from (select 1, 2) AS a (b, c)`
// as the parser gives a wrong plan which has mismatch `Int(1)` types: Literal and
// Column in the Projections. Once the parser side is fixed, this logic should work
pub(super) fn subquery_alias_inner_query_and_columns(
subquery_alias: &datafusion_expr::SubqueryAlias,
) -> (&LogicalPlan, Vec<Ident>) {
let plan: &LogicalPlan = subquery_alias.input.as_ref();

let LogicalPlan::Projection(outer_projections) = plan else {
return (plan, vec![]);
};

// check if it's projection inside projection
let Some(inner_projection) = find_projection(outer_projections.input.as_ref()) else {
return (plan, vec![]);
};

let mut columns: Vec<Ident> = vec![];
// check if the inner projection and outer projection have a matching pattern like
// Projection: j1.j1_id AS id
// Projection: j1.j1_id
for (i, inner_expr) in inner_projection.expr.iter().enumerate() {
let Expr::Alias(ref outer_alias) = &outer_projections.expr[i] else {
return (plan, vec![]);
};

let expr = outer_alias.expr.clone();

// inner expr may have different type to outer expr: e.g. a + 1 is a column of
// string in outer, but a expr of math in inner
if expr.to_string() != inner_expr.to_string() {
return (plan, vec![]);
};

columns.push(outer_alias.name.as_str().into());
}

(outer_projections.input.as_ref(), columns)
}

fn find_projection(logical_plan: &LogicalPlan) -> Option<&Projection> {
match logical_plan {
LogicalPlan::Projection(p) => {
return Some(p);
}
LogicalPlan::Limit(p) => find_projection(p.input.as_ref()),
LogicalPlan::Distinct(p) => find_projection(p.input().as_ref()),
LogicalPlan::Sort(p) => find_projection(p.input.as_ref()),

_ => None,
}
}
20 changes: 20 additions & 0 deletions datafusion/sql/tests/cases/plan_to_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,26 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
// Test query that has calculation in derived table with columns
TestStatementWithDialect {
sql: "SELECT id FROM (SELECT j1_id + 1 * 3 from j1) AS c (id)",
expected: r#"SELECT c.id FROM (SELECT (j1.j1_id + (1 * 3)) FROM j1) AS c (id)"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
// Test query that has limit/distinct/order in derived table with columns
TestStatementWithDialect {
sql: "SELECT id FROM (SELECT distinct (j1_id + 1 * 3) FROM j1 LIMIT 1) AS c (id)",
expected: r#"SELECT c.id FROM (SELECT DISTINCT (j1.j1_id + (1 * 3)) FROM j1 LIMIT 1) AS c (id)"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
TestStatementWithDialect {
sql: "SELECT id FROM (SELECT j1_id + 1 FROM j1 ORDER BY j1_id DESC LIMIT 1) AS c (id)",
expected: r#"SELECT c.id FROM (SELECT (j1.j1_id + 1) FROM j1 ORDER BY j1.j1_id DESC NULLS FIRST LIMIT 1) AS c (id)"#,
parser_dialect: Box::new(GenericDialect {}),
unparser_dialect: Box::new(UnparserDefaultDialect {}),
},
];

for query in tests {
Expand Down

0 comments on commit 1c65aff

Please sign in to comment.