Skip to content

Commit

Permalink
fix tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
jayzhan211 committed Feb 26, 2025
1 parent 610c9a3 commit cb6c975
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 68 deletions.
11 changes: 5 additions & 6 deletions datafusion/optimizer/tests/optimizer_integration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,8 @@ fn join_keys_in_subquery_alias_1() {
fn push_down_filter_groupby_expr_contains_alias() {
let sql = "SELECT * FROM (SELECT (col_int32 + col_uint32) AS c, count(*) FROM test GROUP BY 1) where c > 3";
let plan = test_sql(sql).unwrap();
let expected = "Projection: test.col_int32 + test.col_uint32 AS c, count(*)\
\n Aggregate: groupBy=[[test.col_int32 + CAST(test.col_uint32 AS Int32)]], aggr=[[count(Int64(1)) AS count(*)]]\
let expected = "Projection: test.col_int32 + test.col_uint32 AS c, count(Int64(1)) AS count(*)\
\n Aggregate: groupBy=[[test.col_int32 + CAST(test.col_uint32 AS Int32)]], aggr=[[count(Int64(1))]]\
\n Filter: test.col_int32 + CAST(test.col_uint32 AS Int32) > Int32(3)\
\n TableScan: test projection=[col_int32, col_uint32]";
assert_eq!(expected, format!("{plan}"));
Expand Down Expand Up @@ -312,10 +312,9 @@ fn eliminate_redundant_null_check_on_count() {
GROUP BY col_int32
HAVING c IS NOT NULL";
let plan = test_sql(sql).unwrap();
let expected = "\
Projection: test.col_int32, count(*) AS c\
\n Aggregate: groupBy=[[test.col_int32]], aggr=[[count(Int64(1)) AS count(*)]]\
\n TableScan: test projection=[col_int32]";
let expected = "Projection: test.col_int32, count(Int64(1)) AS count(*) AS c\
\n Aggregate: groupBy=[[test.col_int32]], aggr=[[count(Int64(1))]]\
\n TableScan: test projection=[col_int32]";
assert_eq!(expected, format!("{plan}"));
}

Expand Down
145 changes: 83 additions & 62 deletions datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -684,40 +684,50 @@ async fn roundtrip_union_all() -> Result<()> {

/// Checks the plan text expected for `count` over an `INTERSECT` subquery.
///
/// Every case is driven through `assert_expected_plan` with the same SQL
/// skeleton; only the `count(...)` spelling and the expected plan differ:
///
/// * `count(*)` / `count()` expect an aggregate of `count(Int64(1))` whose
///   original spelling is restored by a top-level `Projection` alias.
/// * Constant arguments (`count(1)`, `count(2)`, `count(1 + 2)`) expect a bare
///   `Aggregate` root carrying the supplied aggregate expression.
///
/// Earlier inline assertions that expected an `Aggregate ... AS count(*)` root
/// for the wildcard forms contradicted `check_wildcard` for the very same SQL,
/// and the inline `count(1)` assertion duplicated `check_constant`; both have
/// been dropped so each query has exactly one expectation.
#[tokio::test]
async fn simple_intersect() -> Result<()> {
    // NOTE(review): every plan line below is indented by a single space,
    // regardless of depth. If the plan printer nests children by depth,
    // confirm these literals against its real output.

    /// Expects `syntax` (a wildcard-style count) to plan as an aggregate of
    /// `count(Int64(1))` re-aliased to `syntax` by a projection on top.
    async fn check_wildcard(syntax: &str) -> Result<()> {
        let expected_plan_str = format!(
            "Projection: count(Int64(1)) AS {syntax}\
            \n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\
            \n Projection: \
            \n LeftSemi Join: data.a = data2.a\
            \n Aggregate: groupBy=[[data.a]], aggr=[[]]\
            \n TableScan: data projection=[a]\
            \n TableScan: data2 projection=[a]"
        );

        // The trailing `true` is forwarded as-is; its meaning is defined by
        // `assert_expected_plan`, which is not visible from this function.
        assert_expected_plan(
            &format!("SELECT {syntax} FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);"),
            &expected_plan_str,
            true,
        )
        .await
    }

    /// Expects `sql_syntax` (a count over a constant expression) to plan as a
    /// bare aggregate whose aggregate expression prints as `plan_expr`.
    async fn check_constant(sql_syntax: &str, plan_expr: &str) -> Result<()> {
        let expected_plan_str = format!(
            "Aggregate: groupBy=[[]], aggr=[[{plan_expr}]]\
            \n Projection: \
            \n LeftSemi Join: data.a = data2.a\
            \n Aggregate: groupBy=[[data.a]], aggr=[[]]\
            \n TableScan: data projection=[a]\
            \n TableScan: data2 projection=[a]"
        );

        assert_expected_plan(
            &format!("SELECT {sql_syntax} FROM (SELECT data.a FROM data INTERSECT SELECT data2.a FROM data2);"),
            &expected_plan_str,
            true,
        )
        .await
    }

    check_wildcard("count(*)").await?;
    check_wildcard("count()").await?;
    check_constant("count(1)", "count(Int64(1))").await?;
    check_constant("count(2)", "count(Int64(2))").await?;
    check_constant(
        "count(1 + 2)",
        "count(Int64(3)) AS count(Int64(1) + Int64(2))",
    )
    .await?;
    Ok(())
}

Expand Down Expand Up @@ -843,44 +853,55 @@ async fn simple_intersect_table_reuse() -> Result<()> {
// Instead, when we consume Substrait, we add aliases before a join that'd otherwise collide.
// In this case the aliasing happens at a different point in the plan, so we cannot use roundtrip.
// Schema check works because we set aliases to what the Substrait consumer will generate.
assert_expected_plan(
"SELECT count(*) FROM (SELECT left.a FROM data AS left INTERSECT SELECT right.a FROM data AS right);",
"Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]\
\n Projection: \
\n LeftSemi Join: left.a = right.a\
\n SubqueryAlias: left\
\n Aggregate: groupBy=[[data.a]], aggr=[[]]\
\n TableScan: data projection=[a]\
\n SubqueryAlias: right\
\n TableScan: data projection=[a]",
true
).await?;

assert_expected_plan(
"SELECT count() FROM (SELECT left.a FROM data AS left INTERSECT SELECT right.a FROM data AS right);",
"Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count()]]\
\n Projection: \
\n LeftSemi Join: left.a = right.a\
\n SubqueryAlias: left\
\n Aggregate: groupBy=[[data.a]], aggr=[[]]\
\n TableScan: data projection=[a]\
\n SubqueryAlias: right\
\n TableScan: data projection=[a]",
true
).await?;
/// Runs one wildcard-style count (`syntax`) over the self-intersect of `data`
/// and hands the query plus its expected plan text to `assert_expected_plan`.
///
/// The expected plan has a `Projection` root aliasing `count(Int64(1))` back
/// to `syntax`, followed by the fixed aggregate/join subtree spelled out below.
async fn check_wildcard(syntax: &str) -> Result<()> {
    // Only the root line depends on `syntax`; the rest of the plan is constant.
    let mut wanted = format!("Projection: count(Int64(1)) AS {}", syntax);
    wanted.push_str(
        "\n Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\
        \n Projection: \
        \n LeftSemi Join: left.a = right.a\
        \n SubqueryAlias: left\
        \n Aggregate: groupBy=[[data.a]], aggr=[[]]\
        \n TableScan: data projection=[a]\
        \n SubqueryAlias: right\
        \n TableScan: data projection=[a]",
    );

    let query = format!(
        "SELECT {} FROM (SELECT left.a FROM data AS left INTERSECT SELECT right.a FROM data AS right);",
        syntax
    );

    assert_expected_plan(&query, &wanted, true).await
}

assert_expected_plan(
"SELECT count(1) FROM (SELECT left.a FROM data AS left INTERSECT SELECT right.a FROM data AS right);",
"Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]]\
async fn check_constant(sql_syntax: &str, plan_expr: &str) -> Result<()> {
let expected_plan_str = format!(
"Aggregate: groupBy=[[]], aggr=[[{plan_expr}]]\
\n Projection: \
\n LeftSemi Join: left.a = right.a\
\n SubqueryAlias: left\
\n Aggregate: groupBy=[[data.a]], aggr=[[]]\
\n TableScan: data projection=[a]\
\n SubqueryAlias: right\
\n TableScan: data projection=[a]",
true
).await?;
\n TableScan: data projection=[a]"
);

assert_expected_plan(
&format!("SELECT {sql_syntax} FROM (SELECT left.a FROM data AS left INTERSECT SELECT right.a FROM data AS right);"),
&expected_plan_str,
true
).await
}

check_wildcard("count(*)").await?;
check_wildcard("count()").await?;
check_constant("count(1)", "count(Int64(1))").await?;
check_constant("count(2)", "count(Int64(2))").await?;
check_constant(
"count(1 + 2)",
"count(Int64(3)) AS count(Int64(1) + Int64(2))",
)
.await?;

Ok(())
}
Expand Down

0 comments on commit cb6c975

Please sign in to comment.