# Patch overview (commentary only; nothing below the first "diff --git" is altered).
#
# Files touched:
#   1. crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs
#   2. py-polars/tests/unit/test_predicates.py
#
# 1. join.rs -- hunk inside `should_block_join_specific`
#    Before:
#      * `BinaryExpr { op, left, right }` dispatched on `op`:
#          - `Operator::NotEq` returned `LeftRight(false, false)`;
#          - `Operator::Eq` looked up the output names of both operands, set
#            `is_in_on` when either name is in `on_names`, derived
#            `block_left` / `block_right` from `schema_left` / `schema_right`,
#            and OR-ed them with the pair returned by `join_produces_null(how)`;
#          - every other operator returned `join_produces_null(how)`.
#      * The outer fallback arm `_` returned `LeftRight(false, false)`.
#    After:
#      * `Operator::Eq | Operator::NotEq` share the name/schema based logic
#        that previously applied to `Eq` only.
#      * The outer fallback arm `_` returns `join_produces_null(how)`. Other
#        binary operators now reach this arm and get the same result as before;
#        any expression kind not matched by the earlier arms changes from
#        `LeftRight(false, false)` to `join_produces_null(how)`.
#    NOTE(review): the arms above this hunk are not visible here, so the exact
#    set of expressions that reach the fallback arm should be confirmed against
#    the full file.
#    NOTE(review): `aexpr_output_name(..).unwrap()` is now evaluated for `NotEq`
#    as well as `Eq`; confirm it cannot fail for operands of a `!=` expression.
#
# 2. test_predicates.py
#    `test_predicate_pushdown_right_join_19772` is replaced by the parametrized
#    `test_predicate_pushdown_join_19772`, which runs over
#      * 7 predicates on the joined frame (==, !=, >, <, is_in, cast to Boolean,
#        and a plain boolean column `b`),
#      * `alias` in {True, False} (when True the predicate is aliased to ":V"),
#      * `join_type` in {"left", "right"} (for "right" the two input frames are
#        swapped and the expected column order becomes v, b, k).
#    Each case asserts that the explained plan starts with "FILTER" (presumably
#    meaning the predicate was not pushed below the join -- confirm) and that
#    `collect()` and `collect(no_optimization=True)` both equal the expected
#    single-row frame.
#
diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs index 7e4710e709a1..3437bbc8b3e1 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs @@ -44,25 +44,25 @@ fn should_block_join_specific( // any operation that checks for equality or ordering can be wrong because // the join can produce null values // TODO! check if we can be less conservative here - BinaryExpr { op, left, right } => match op { - Operator::NotEq => LeftRight(false, false), - Operator::Eq => { - let LeftRight(bleft, bright) = join_produces_null(how); + BinaryExpr { + op: Operator::Eq | Operator::NotEq, + left, + right, + } => { + let LeftRight(bleft, bright) = join_produces_null(how); - let l_name = aexpr_output_name(*left, expr_arena).unwrap(); - let r_name = aexpr_output_name(*right, expr_arena).unwrap(); + let l_name = aexpr_output_name(*left, expr_arena).unwrap(); + let r_name = aexpr_output_name(*right, expr_arena).unwrap(); - let is_in_on = on_names.contains(&l_name) || on_names.contains(&r_name); + let is_in_on = on_names.contains(&l_name) || on_names.contains(&r_name); - let block_left = - is_in_on && (schema_left.contains(&l_name) || schema_left.contains(&r_name)); - let block_right = - is_in_on && (schema_right.contains(&l_name) || schema_right.contains(&r_name)); - LeftRight(block_left | bleft, block_right | bright) - }, - _ => join_produces_null(how), + let block_left = + is_in_on && (schema_left.contains(&l_name) || schema_left.contains(&r_name)); + let block_right = + is_in_on && (schema_right.contains(&l_name) || schema_right.contains(&r_name)); + LeftRight(block_left | bleft, block_right | bright) }, - _ => LeftRight(false, false), + _ => join_produces_null(how), } } diff --git a/py-polars/tests/unit/test_predicates.py b/py-polars/tests/unit/test_predicates.py index e8f0be927cb9..7a458d65e7fe 
100644 --- a/py-polars/tests/unit/test_predicates.py +++ b/py-polars/tests/unit/test_predicates.py @@ -555,16 +555,41 @@ def test_predicate_pushdown_struct_unnest_19632() -> None: ) -def test_predicate_pushdown_right_join_19772() -> None: - left = pl.LazyFrame({"k": [1], "v": [7]}) - right = pl.LazyFrame({"k": [1, 2]}) +@pytest.mark.parametrize( + "predicate", + [ + pl.col("v") == 7, + pl.col("v") != 99, + pl.col("v") > 0, + pl.col("v") < 999, + pl.col("v").is_in([7]), + pl.col("v").cast(pl.Boolean), + pl.col("b"), + ], +) +@pytest.mark.parametrize("alias", [True, False]) +@pytest.mark.parametrize("join_type", ["left", "right"]) +def test_predicate_pushdown_join_19772( + predicate: pl.Expr, join_type: str, alias: bool +) -> None: + left = pl.LazyFrame({"k": [1, 2]}) + right = pl.LazyFrame({"k": [1], "v": [7], "b": True}) + + if join_type == "right": + [left, right] = [right, left] - q = left.join(right, on="k", how="right").filter(pl.col("v") == 7) + if alias: + predicate = predicate.alias(":V") + + q = left.join(right, on="k", how=join_type).filter(predicate) # type: ignore[arg-type] plan = q.explain() assert plan.startswith("FILTER") - expect = pl.DataFrame({"v": 7, "k": 1}) + expect = pl.DataFrame({"k": 1, "v": 7, "b": True}) + + if join_type == "right": + expect = expect.select("v", "b", "k") assert_frame_equal(q.collect(no_optimization=True), expect) assert_frame_equal(q.collect(), expect)