Skip to content

Commit

Permalink
fix: Incorrect explode schema for LazyFrame.explode() (#19860)
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Nov 19, 2024
1 parent 9f914b7 commit 545b9bc
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 12 deletions.
15 changes: 11 additions & 4 deletions crates/polars-plan/src/plans/functions/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,10 +132,17 @@ fn explode_schema<'a>(

// columns to string
columns.iter().try_for_each(|name| {
if let DataType::List(inner) = schema.try_get(name)? {
let inner = *inner.clone();
schema.with_column(name.clone(), inner);
};
match schema.try_get(name)? {
DataType::List(inner) => {
schema.with_column(name.clone(), inner.as_ref().clone());
},
#[cfg(feature = "dtype-array")]
DataType::Array(inner, _) => {
schema.with_column(name.clone(), inner.as_ref().clone());
},
_ => {},
}

PolarsResult::Ok(())
})?;
let schema = Arc::new(schema);
Expand Down
28 changes: 20 additions & 8 deletions py-polars/tests/unit/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def test_schema_functions_in_agg_with_literal_arg_19011() -> None:
)


def test_lf_explode_in_agg_schema_19562() -> None:
def test_lazy_explode_in_agg_schema_19562() -> None:
def new_df_check_schema(
value: dict[str, Any], schema: dict[str, Any]
) -> pl.DataFrame:
Expand Down Expand Up @@ -192,7 +192,7 @@ def new_df_check_schema(
assert_frame_equal(q.collect(), new_df_check_schema({"a": [1], "b": [[1]]}, schema))


def test_lf_nested_function_expr_agg_schema() -> None:
def test_lazy_nested_function_expr_agg_schema() -> None:
q = (
pl.LazyFrame({"k": [1, 1, 2]})
.group_by(pl.first(), maintain_order=True)
Expand All @@ -205,15 +205,15 @@ def test_lf_nested_function_expr_agg_schema() -> None:
)


def test_lf_agg_scalar_return_schema() -> None:
def test_lazy_agg_scalar_return_schema() -> None:
q = pl.LazyFrame({"k": [1]}).group_by("k").agg(pl.col("k").null_count().alias("o"))

schema = {"k": pl.Int64, "o": pl.UInt32}
assert q.collect_schema() == schema
assert_frame_equal(q.collect(), pl.DataFrame({"k": 1, "o": 0}, schema=schema))


def test_lf_agg_nested_expr_schema() -> None:
def test_lazy_agg_nested_expr_schema() -> None:
q = (
pl.LazyFrame({"k": [1]})
.group_by("k")
Expand All @@ -236,7 +236,7 @@ def test_lf_agg_nested_expr_schema() -> None:
assert_frame_equal(q.collect(), pl.DataFrame({"k": 1, "o": 0}, schema=schema))


def test_lf_agg_lit_explode() -> None:
def test_lazy_agg_lit_explode() -> None:
q = (
pl.LazyFrame({"k": [1]})
.group_by("k")
Expand All @@ -255,7 +255,7 @@ def test_lf_agg_lit_explode() -> None:
"nan_min", "null_count", "product", "sample", "skew", "std", "sum", "upper_bound",
"var"
]) # fmt: skip
def test_lf_agg_auto_agg_list_19752(expr_op: str) -> None:
def test_lazy_agg_auto_agg_list_19752(expr_op: str) -> None:
op = getattr(pl.Expr, expr_op)

lf = pl.LazyFrame({"a": 1, "b": 1})
Expand All @@ -272,15 +272,15 @@ def test_lf_agg_auto_agg_list_19752(expr_op: str) -> None:
"expr", [pl.col("b"), pl.col("b").sum(), pl.col("b").reverse()]
)
@pytest.mark.parametrize("mapping_strategy", ["explode", "join", "group_to_rows"])
def test_lf_window_schema(expr: pl.Expr, mapping_strategy: str) -> None:
def test_lazy_window_schema(expr: pl.Expr, mapping_strategy: str) -> None:
q = pl.LazyFrame({"a": 1, "b": 1}).select(
expr.over("a", mapping_strategy=mapping_strategy) # type: ignore[arg-type]
)

assert q.collect_schema() == q.collect().collect_schema()


def test_lf_explode_schema() -> None:
def test_lazy_explode_schema() -> None:
lf = pl.LazyFrame({"k": [1], "x": pl.Series([[1]], dtype=pl.Array(pl.Int64, 1))})

q = lf.select(pl.col("x").explode())
Expand All @@ -297,6 +297,18 @@ def test_lf_explode_schema() -> None:
q = lf.select(pl.col("x").list.explode())
assert q.collect_schema() == {"x": pl.Int64}

# `LazyFrame.explode()` goes through a different codepath than `Expr.explode`
lf = pl.LazyFrame().with_columns(
pl.Series([[1]], dtype=pl.List(pl.Int64)).alias("list"),
pl.Series([[1]], dtype=pl.Array(pl.Int64, 1)).alias("array"),
)

q = lf.explode("*")
assert q.collect_schema() == {"list": pl.Int64, "array": pl.Int64}

q = lf.explode("list")
assert q.collect_schema() == {"list": pl.Int64, "array": pl.Array(pl.Int64, 1)}


def test_raise_subnodes_18787() -> None:
df = pl.DataFrame({"a": [1], "b": [2]})
Expand Down

0 comments on commit 545b9bc

Please sign in to comment.