Skip to content

Commit

Permalink
Add optimizer test for simplifying predicates on timestamps (apache#3939
Browse files Browse the repository at this point in the history
)
  • Loading branch information
alamb authored and Dandandan committed Nov 5, 2022
1 parent 22ff7a8 commit fd19d30
Showing 1 changed file with 80 additions and 30 deletions.
110 changes: 80 additions & 30 deletions datafusion/optimizer/tests/integration-test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef, TimeUnit};
use chrono::{DateTime, NaiveDateTime, Utc};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource};
use datafusion_optimizer::optimizer::Optimizer;
Expand Down Expand Up @@ -87,11 +88,11 @@ fn case_when_aggregate() -> Result<()> {

/// Optimizes a query that compares an unsigned column with an integer literal
/// and checks the resulting plan: the column is cast to `Int64` for the
/// comparison and the scan is restricted to the columns the query references.
#[test]
fn unsigned_target_type() -> Result<()> {
    let sql = "SELECT col_utf8 FROM test WHERE col_uint32 > 0";
    let plan = test_sql(sql)?;
    // Each nesting level of the plan is indented by two further spaces.
    let expected = "Projection: test.col_utf8\
                    \n  Filter: CAST(test.col_uint32 AS Int64) > Int64(0)\
                    \n    TableScan: test projection=[col_uint32, col_utf8]";
    assert_eq!(expected, format!("{:?}", plan));
    Ok(())
}
Expand All @@ -111,46 +112,46 @@ fn distribute_by() -> Result<()> {
/// Optimizes a correlated `EXISTS` subquery that carries an extra non-equality
/// condition and checks that it becomes a semi join whose join filter holds
/// that condition.
#[test]
fn semi_join_with_join_filter() -> Result<()> {
    // regression test for https://github.com/apache/arrow-datafusion/issues/2888
    let sql = "SELECT col_utf8 FROM test WHERE EXISTS (\
               SELECT col_utf8 FROM test t2 WHERE test.col_int32 = t2.col_int32 \
               AND test.col_uint32 != t2.col_uint32)";
    let plan = test_sql(sql)?;
    // Each nesting level of the plan is indented by two further spaces.
    let expected = "Projection: test.col_utf8\
                    \n  Semi Join: test.col_int32 = t2.col_int32 Filter: test.col_uint32 != t2.col_uint32\
                    \n    TableScan: test projection=[col_int32, col_uint32, col_utf8]\
                    \n    SubqueryAlias: t2\
                    \n      TableScan: test projection=[col_int32, col_uint32, col_utf8]";
    assert_eq!(expected, format!("{:?}", plan));
    Ok(())
}

/// Optimizes a correlated `NOT EXISTS` subquery that carries an extra
/// non-equality condition and checks that it becomes an anti join whose join
/// filter holds that condition.
#[test]
fn anti_join_with_join_filter() -> Result<()> {
    // regression test for https://github.com/apache/arrow-datafusion/issues/2888
    let sql = "SELECT col_utf8 FROM test WHERE NOT EXISTS (\
               SELECT col_utf8 FROM test t2 WHERE test.col_int32 = t2.col_int32 \
               AND test.col_uint32 != t2.col_uint32)";
    let plan = test_sql(sql)?;
    // Each nesting level of the plan is indented by two further spaces.
    let expected = "Projection: test.col_utf8\
                    \n  Anti Join: test.col_int32 = t2.col_int32 Filter: test.col_uint32 != t2.col_uint32\
                    \n    TableScan: test projection=[col_int32, col_uint32, col_utf8]\
                    \n    SubqueryAlias: t2\
                    \n      TableScan: test projection=[col_int32, col_uint32, col_utf8]";
    assert_eq!(expected, format!("{:?}", plan));
    Ok(())
}

/// Optimizes an `EXISTS` subquery that selects `DISTINCT` values and checks
/// that the plan is a plain semi join on the correlated column, with both
/// scans restricted to that column.
#[test]
fn where_exists_distinct() -> Result<()> {
    // regression test for https://github.com/apache/arrow-datafusion/issues/3724
    let sql = "SELECT col_int32 FROM test WHERE EXISTS (\
               SELECT DISTINCT col_int32 FROM test t2 WHERE test.col_int32 = t2.col_int32)";
    let plan = test_sql(sql)?;
    // Each nesting level of the plan is indented by two further spaces.
    let expected = "Projection: test.col_int32\
                    \n  Semi Join: test.col_int32 = t2.col_int32\
                    \n    TableScan: test projection=[col_int32]\
                    \n    SubqueryAlias: t2\
                    \n      TableScan: test projection=[col_int32]";
    assert_eq!(expected, format!("{:?}", plan));
    Ok(())
}
Expand Down Expand Up @@ -225,6 +226,38 @@ fn concat_ws_literals() -> Result<()> {
Ok(())
}

/// Checks that a `now() - interval` predicate on a nanosecond timestamp column
/// *without* a timezone is folded to a single constant that can be compared to
/// the column directly.
///
/// The test is currently ignored; see the linked issue below.
#[test]
#[ignore]
// https://github.com/apache/arrow-datafusion/issues/3938
fn timestamp_nano_ts_none_predicates() -> Result<()> {
    let sql = "SELECT col_int32
        FROM test
        WHERE col_ts_nano_none < (now() - interval '1 hour')";
    let plan = test_sql(sql)?;
    // a scan should have the now()... predicate folded to a single
    // constant and compared to the column without a cast so it can be
    // pushed down / pruned
    //
    // The filter must reference the column the query actually uses
    // (`col_ts_nano_none`), and the folded literal carries no timezone so that
    // it matches the column's type without a cast. `test_sql` pins now() to
    // 1666615693 s, so one hour earlier is 1666612093 s.
    let expected = "Projection: test.col_int32\
                    \n  Filter: test.col_ts_nano_none < TimestampNanosecond(1666612093000000000, None)\
                    \n    TableScan: test projection=[col_int32, col_ts_nano_none]";
    assert_eq!(expected, format!("{:?}", plan));
    Ok(())
}

/// Checks that a `now() - interval` predicate on a nanosecond timestamp column
/// with a UTC timezone is folded to a single constant that is compared to the
/// column directly.
#[test]
fn timestamp_nano_ts_utc_predicates() -> Result<()> {
    let sql = "SELECT col_int32
        FROM test
        WHERE col_ts_nano_utc < (now() - interval '1 hour')";
    let plan = test_sql(sql)?;
    // a scan should have the now()... predicate folded to a single
    // constant and compared to the column without a cast so it can be
    // pushed down / pruned
    //
    // `test_sql` pins now() to 1666615693 s, so one hour earlier is
    // 1666612093 s. Each nesting level of the plan is indented by two
    // further spaces.
    let expected = "Projection: test.col_int32\
                    \n  Filter: test.col_ts_nano_utc < TimestampNanosecond(1666612093000000000, Some(\"UTC\"))\
                    \n    TableScan: test projection=[col_int32, col_ts_nano_utc]";
    assert_eq!(expected, format!("{:?}", plan));
    Ok(())
}

fn test_sql(sql: &str) -> Result<LogicalPlan> {
// parse the SQL
let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ...
Expand All @@ -236,9 +269,14 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
let sql_to_rel = SqlToRel::new(&schema_provider);
let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();

// optimize the logical plan
let mut config = OptimizerConfig::new().with_skip_failing_rules(false);
// hard code the return value of now()
let now_time =
DateTime::<Utc>::from_utc(NaiveDateTime::from_timestamp(1666615693, 0), Utc);
let mut config = OptimizerConfig::new()
.with_skip_failing_rules(false)
.with_query_execution_start_time(now_time);
let optimizer = Optimizer::new(&config);
// optimize the logical plan
optimizer.optimize(&plan, &mut config, &observe)
}

Expand All @@ -258,6 +296,18 @@ impl ContextProvider for MySchemaProvider {
Field::new("col_utf8", DataType::Utf8, true),
Field::new("col_date32", DataType::Date32, true),
Field::new("col_date64", DataType::Date64, true),
// timestamp with no timezone
Field::new(
"col_ts_nano_none",
DataType::Timestamp(TimeUnit::Nanosecond, None),
true,
),
// timestamp with UTC timezone
Field::new(
"col_ts_nano_utc",
DataType::Timestamp(TimeUnit::Nanosecond, Some("UTC".into())),
true,
),
],
HashMap::new(),
);
Expand Down

0 comments on commit fd19d30

Please sign in to comment.