Skip to content

Commit

Permalink
c
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Nov 21, 2024
1 parent 5f61d70 commit 37b57a4
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ fn should_block_join_specific(
// the join can produce null values
// TODO! check if we can be less conservative here
BinaryExpr { op, left, right } => match op {
Operator::NotEq => LeftRight(false, false),
Operator::Eq => {
Operator::Eq | Operator::NotEq => {
let LeftRight(bleft, bright) = join_produces_null(how);

let l_name = aexpr_output_name(*left, expr_arena).unwrap();
Expand All @@ -62,7 +61,7 @@ fn should_block_join_specific(
},
_ => join_produces_null(how),
},
_ => LeftRight(false, false),
_ => join_produces_null(how),
}
}

Expand Down
35 changes: 30 additions & 5 deletions py-polars/tests/unit/test_predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,16 +555,41 @@ def test_predicate_pushdown_struct_unnest_19632() -> None:
)


def test_predicate_pushdown_right_join_19772() -> None:
left = pl.LazyFrame({"k": [1], "v": [7]})
right = pl.LazyFrame({"k": [1, 2]})
@pytest.mark.parametrize(
"predicate",
[
pl.col("v") == 7,
pl.col("v") != 99,
pl.col("v") > 0,
pl.col("v") < 999,
pl.col("v").is_in([7]),
pl.col("v").cast(pl.Boolean),
pl.col("b"),
],
)
@pytest.mark.parametrize("alias", [True, False])
@pytest.mark.parametrize("join_type", ["left", "right"])
def test_predicate_pushdown_join_19772(
predicate: pl.Expr, join_type: str, alias: bool
) -> None:
left = pl.LazyFrame({"k": [1, 2]})
right = pl.LazyFrame({"k": [1], "v": [7], "b": True})

if join_type == "right":
[left, right] = [right, left]

q = left.join(right, on="k", how="right").filter(pl.col("v") == 7)
if alias:
predicate = predicate.alias(":V")

q = left.join(right, on="k", how=join_type).filter(predicate) # type: ignore[arg-type]

plan = q.explain()
assert plan.startswith("FILTER")

expect = pl.DataFrame({"v": 7, "k": 1})
expect = pl.DataFrame({"k": 1, "v": 7, "b": True})

if join_type == "right":
expect = expect.select("v", "b", "k")

assert_frame_equal(q.collect(no_optimization=True), expect)
assert_frame_equal(q.collect(), expect)

0 comments on commit 37b57a4

Please sign in to comment.