after type coercion CommonSubexprEliminate will produce invalid projection #3635

Closed
liukun4515 opened this issue Sep 28, 2022 · 9 comments · Fixed by #3726 or #4487
Labels
bug Something isn't working

Comments

@liukun4515
Contributor

Describe the bug

After the PR that moves type coercion to the beginning of the optimizer, CommonSubexprEliminate generates an invalid projection.

This may be similar to #2907.

I think many of the optimizer rules do not account for the cast/try_cast expressions introduced by type coercion.


@liukun4515
Contributor Author

cc @andygrove

@liukun4515
Contributor Author

Do you have time to take a look at this issue?

@alex-spies
Contributor

I was able to hunt down the cause of the bug, I think:

In arrow-datafusion/datafusion/optimizer/src/common_subexpr_eliminate.rs, ExpressionVisitor::post_visit assigns the same datatype to every sub-expression in the expr_set; more specifically, it assigns the datatype of the overall expression result.

Accordingly, I don't think the issue is limited to cast expressions.

This can be seen in the following test (place this in the same file to reproduce):

    #[test]
    fn filter_schema_after_optimization() {
        use datafusion_expr::cast;

        let schema = Schema::new(vec![
            Field::new("a", DataType::UInt64, false),
            Field::new("b", DataType::UInt64, false),
            Field::new("c", DataType::UInt64, false),
        ]);

        let plan = table_scan(Some("table"), &schema, None)
            .unwrap()
            .filter(
                cast(col("a"), DataType::Int64)
                    .lt(lit(1_i64))
                    .and(cast(col("a"), DataType::Int64).not_eq(lit(1_i64))),
            )
            .unwrap()
            .build()
            .unwrap();
        let rule = CommonSubexprEliminate {};
        let optimized_plan = rule.optimize(&plan, &mut OptimizerConfig::new()).unwrap();

        /* The optimized_plan has an additional column for the cast; the datatype should be i64,
         * but it's boolean instead.
         *
         * DFField {
         *    qualifier: None,
         *    field: Field {
         *        name: "CAST(#table.a AS Int64)#table.a",
         *        data_type: Boolean,
         *        nullable: true,
         *        dict_id: 0,
         *        dict_is_ordered: false,
         *        metadata: None,
         *    },
         *},
         *
         */

        println!("{:#?}", optimized_plan);
        println!("{:#?}", optimized_plan.schema());

        panic!();
    }

In ExpressionVisitor::post_visit, one can see that we're assigning the same datatype to all sub-expressions:

    fn post_visit(mut self, expr: &Expr) -> Result<Self> {
        self.series_number += 1;

        let (idx, sub_expr_desc) = self.pop_enter_mark();
        // skip exprs should not be recognize.
        if matches!(
            expr,
            Expr::Literal(..)
                | Expr::Column(..)
                | Expr::ScalarVariable(..)
                | Expr::Alias(..)
                | Expr::Sort { .. }
                | Expr::Wildcard
        ) {
            self.id_array[idx].0 = self.series_number;
            let desc = Self::desc_expr(expr);
            self.visit_stack.push(VisitRecord::ExprItem(desc));
            return Ok(self);
        }

        let mut desc = Self::desc_expr(expr);
        desc.push_str(&sub_expr_desc);

        self.id_array[idx] = (self.series_number, desc.clone());
        self.visit_stack.push(VisitRecord::ExprItem(desc.clone()));
        // Error: data type of a sub-expression can be different from the final type.
        // This leads to a wrong schema of the resulting logical plan.
        let data_type = self.data_type.clone();
        self.expr_set
            .entry(desc)
            .or_insert_with(|| (expr.clone(), 0, data_type))
            .1 += 1;
        Ok(self)
    }
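
To make the point above concrete, here is a minimal, self-contained sketch of recording a data type per sub-expression instead of reusing the type of the overall expression. The `Expr` and `DataType` enums, `collect`, and `demo` are simplified stand-ins invented for illustration; they are not DataFusion's types, and this is not necessarily how the actual fix was implemented.

```rust
use std::collections::HashMap;

// Toy stand-ins for illustration only; not DataFusion's real types.
#[derive(Debug, Clone, PartialEq)]
enum DataType {
    Boolean,
    Int64,
    UInt64,
}

#[derive(Debug, Clone)]
enum Expr {
    Column(String, DataType),
    Literal(i64),
    Cast(Box<Expr>, DataType),
    Lt(Box<Expr>, Box<Expr>),
    And(Box<Expr>, Box<Expr>),
}

impl Expr {
    // Each node reports its own result type.
    fn data_type(&self) -> DataType {
        match self {
            Expr::Column(_, t) | Expr::Cast(_, t) => t.clone(),
            Expr::Literal(_) => DataType::Int64,
            Expr::Lt(..) | Expr::And(..) => DataType::Boolean,
        }
    }

    // Textual identifier used as the key of the expression set.
    fn desc(&self) -> String {
        format!("{:?}", self)
    }
}

// Post-order walk: count every non-trivial sub-expression and store the
// type of *that* sub-expression, not the type of the root expression.
fn collect(expr: &Expr, set: &mut HashMap<String, (usize, DataType)>) {
    match expr {
        // Columns and literals are skipped, as in the original visitor.
        Expr::Column(..) | Expr::Literal(_) => return,
        Expr::Cast(inner, _) => collect(inner, set),
        Expr::Lt(l, r) | Expr::And(l, r) => {
            collect(l, set);
            collect(r, set);
        }
    }
    let entry = set
        .entry(expr.desc())
        .or_insert_with(|| (0, expr.data_type()));
    entry.0 += 1;
}

// CAST(a AS Int64) < 1 AND CAST(a AS Int64) < 2
fn demo() -> HashMap<String, (usize, DataType)> {
    let cast_a = || {
        Expr::Cast(
            Box::new(Expr::Column("a".to_string(), DataType::UInt64)),
            DataType::Int64,
        )
    };
    let pred = Expr::And(
        Box::new(Expr::Lt(Box::new(cast_a()), Box::new(Expr::Literal(1)))),
        Box::new(Expr::Lt(Box::new(cast_a()), Box::new(Expr::Literal(2)))),
    );
    let mut set = HashMap::new();
    collect(&pred, &mut set);
    set
}

fn main() {
    for (desc, (count, data_type)) in demo() {
        println!("{desc}: seen {count} time(s), type {data_type:?}");
    }
}
```

In this toy version the repeated cast is recorded with type Int64 while the comparisons and the conjunction are recorded as Boolean, which is the distinction the shared `self.data_type` in `post_visit` loses.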

@liukun4515
Contributor Author

@alex-natzka Could you please take a look at the test multiple_or_predicates? I think it is failing now.
https://github.com/apache/arrow-datafusion/blob/d72eb9a1c4c18bcabbf941541a9c1defa83a592c/datafusion/core/tests/sql/predicates.rs#L433

I ignored this test because it produces an unexpected result. cc @alamb @andygrove

expected:

[
    "Explain [plan_type:Utf8, plan:Utf8]",
    "  Projection: lineitem.l_partkey [l_partkey:Int64]",
    "    Filter: part.p_brand = Utf8(\"Brand#12\") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND CAST(part.p_size AS Int64) BETWEEN Int64(1) AND Int64(5) OR part.p_brand = Utf8(\"Brand#23\") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND CAST(part.p_size AS Int64) BETWEEN Int64(1) AND Int64(10) OR part.p_brand = Utf8(\"Brand#34\") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND CAST(part.p_size AS Int64) BETWEEN Int64(1) AND Int64(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
    "      Inner Join: lineitem.l_partkey = part.p_partkey [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
    "        TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
    "        TableScan: part projection=[p_partkey, p_brand, p_size] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
]
actual:

[
    "Explain [plan_type:Utf8, plan:Utf8]",
    "  Projection: lineitem.l_partkey [l_partkey:Int64]",
    "    Projection: part.p_size >= Int32(1) AS part.p_size >= Int32(1)Int32(1)part.p_size, lineitem.l_partkey, lineitem.l_quantity, part.p_brand, part.p_size [part.p_size >= Int32(1)Int32(1)part.p_size:Boolean;N, l_partkey:Int64, l_quantity:Decimal128(15, 2), p_brand:Utf8, p_size:Int32]",
    "      Filter: part.p_brand = Utf8(\"Brand#12\") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
    "        Inner Join: lineitem.l_partkey = part.p_partkey [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
    "          TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
    "          Filter: part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
    "            TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
]

@liukun4515 liukun4515 reopened this Oct 11, 2022
@alex-spies
Contributor

@liukun4515 It looks like you needed to ignore the test for your previous PR #3636. I'm a bit confused because, at least at first glance, the actual output does not look wrong. (I didn't look very thoroughly at this complicated expression, though.) Also, the unexpected test output seems unrelated to the wrong extra columns that the CommonSubexprEliminate optimizer rule created, which is now fixed.

It may be safe to create a PR that un-ignores the test and updates the expected value with what we see here, though maybe @alamb and @andygrove can weigh in on this.

@alex-spies
Contributor

@liukun4515, after staring at the actual output for a while longer, I realize that there is a double projection above the filter, which is indeed unexpected:

    "  Projection: lineitem.l_partkey [l_partkey:Int64]",
    "    Projection: part.p_size >= Int32(1) AS part.p_size >= Int32(1)Int32(1)part.p_size, lineitem.l_partkey, lineitem.l_quantity, part.p_brand, part.p_size [part.p_size >= Int32(1)Int32(1)part.p_size:Boolean;N, l_partkey:Int64, l_quantity:Decimal128(15, 2), p_brand:Utf8, p_size:Int32]",

The projection on the second line seems completely unnecessary.

Sorry, seems like I solved a different problem than what you meant :/ Unfortunately, I'm not familiar enough with this optimization rule to fix this. At least it seems to me like the optimized logical plan is not wrong, just that the unnecessary projection is inefficient.

@alex-spies
Contributor

It looks like CommonSubexprEliminate creates the second projection, which is initially before the filter node, but then the filter gets pushed down; cf. the ordering of optimization rules here: https://github.com/apache/arrow-datafusion/blob/61c38b7114e802f9f289bf5364a031395f5799a6/datafusion/optimizer/src/optimizer.rs#L153-L162

I just tried and I get a more sensible execution plan in the test multiple_or_predicates in https://github.com/apache/arrow-datafusion/blob/d72eb9a1c4c18bcabbf941541a9c1defa83a592c/datafusion/core/tests/sql/predicates.rs#L433, if I re-order the optimization rules and place CommonSubexprEliminate after FilterPushdown.

Changing the order of optimization rules here https://github.com/apache/arrow-datafusion/blob/61c38b7114e802f9f289bf5364a031395f5799a6/datafusion/optimizer/src/optimizer.rs#L138 like this
[image: reordered optimization rules, not preserved]

yields the following test output:

expected:

[
    "Explain [plan_type:Utf8, plan:Utf8]",
    "  Projection: lineitem.l_partkey [l_partkey:Int64]",
    "    Filter: part.p_brand = Utf8(\"Brand#12\") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND CAST(part.p_size AS Int64) BETWEEN Int64(1) AND Int64(5) OR part.p_brand = Utf8(\"Brand#23\") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND CAST(part.p_size AS Int64) BETWEEN Int64(1) AND Int64(10) OR part.p_brand = Utf8(\"Brand#34\") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND CAST(part.p_size AS Int64) BETWEEN Int64(1) AND Int64(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
    "      Inner Join: lineitem.l_partkey = part.p_partkey [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
    "        TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
    "        TableScan: part projection=[p_partkey, p_brand, p_size] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
]
actual:

[
    "Explain [plan_type:Utf8, plan:Utf8]",
    "  Projection: lineitem.l_partkey [l_partkey:Int64]",
    "    Filter: part.p_brand = Utf8(\"Brand#12\") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
    "      Inner Join: lineitem.l_partkey = part.p_partkey [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
    "        TableScan: lineitem projection=[l_partkey, l_quantity] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]",
    "        Filter: part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
    "          TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
]

which looks better to me.

@alamb
Contributor

alamb commented Oct 12, 2022

which looks better to me.

I agree -- this is a very nice writeup @alex-natzka -- thank you (love the diagram).

Shall we make a PR that proposes switching the order of the passes (and add comments explaining why we do filter pushdown first)?

@alex-spies
Contributor

@alamb sorry for the late reply - I was away for a couple of days.

Created a draft PR #3861.
