Skip to content

Commit

Permalink
fix: pyarrow unpivot upcast numeric (#1140)
Browse files Browse the repository at this point in the history
* fix: pyarrow unpivot upcast numeric

* pin min pyarrow version
  • Loading branch information
FBruzzesi authored Oct 6, 2024
1 parent 9e1dbb6 commit efc6a52
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 1 deletion.
10 changes: 9 additions & 1 deletion narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,11 @@ def unpivot(

n_rows = len(self)

promote_kwargs = (
{"promote_options": "permissive"}
if self._backend_version >= (14, 0, 0)
else {}
)
return self._from_native_frame(
pa.concat_tables(
[
Expand All @@ -694,6 +699,9 @@ def unpivot(
names=[*index_, variable_name, value_name],
)
for on_col in on_
]
],
**promote_kwargs,
)
)
# TODO(Unassigned): Even with promote_options="permissive", pyarrow does not
# upcast numeric to non-numeric (e.g. string) datatypes
34 changes: 34 additions & 0 deletions tests/frame/unpivot_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any

import pyarrow as pa
import pytest

import narwhals.stable.v1 as nw
from narwhals.utils import parse_version
from tests.utils import Constructor
from tests.utils import compare_dicts

if TYPE_CHECKING:
from narwhals.stable.v1.dtypes import DType

data = {
"a": ["x", "y", "z"],
"b": [1, 3, 5],
Expand Down Expand Up @@ -70,3 +78,29 @@ def test_unpivot_default_var_value_names(constructor: Constructor) -> None:
result = df.unpivot(on=["b", "c"], index=["a"])

assert result.collect_schema().names()[-2:] == ["variable", "value"]


@pytest.mark.parametrize(
    ("data", "expected_dtypes"),
    [
        (
            {"idx": [0, 1], "a": [1, 2], "b": [1.5, 2.5]},
            [nw.Int64(), nw.String(), nw.Float64()],
        ),
    ],
)
def test_unpivot_mixed_types(
    request: pytest.FixtureRequest,
    constructor: Constructor,
    data: dict[str, Any],
    expected_dtypes: list[DType],
) -> None:
    """Check the schema produced by unpivoting columns of different numeric dtypes.

    The frame built from ``data`` is unpivoted on ``a`` and ``b`` with ``idx``
    as the index, and the dtypes of the resulting schema are compared against
    ``expected_dtypes``.

    The test is marked as an expected failure for dask constructors and for
    pyarrow tables when the installed pyarrow is older than 14.0.0.
    """
    backend_name = str(constructor)
    runs_on_dask = "dask" in backend_name
    # NOTE(review): presumably pyarrow < 14.0.0 cannot promote mixed numeric
    # types when concatenating tables -- confirm against the pyarrow changelog.
    runs_on_old_pyarrow = "pyarrow_table" in backend_name and (
        parse_version(pa.__version__) < parse_version("14.0.0")
    )
    if runs_on_dask or runs_on_old_pyarrow:
        request.applymarker(pytest.mark.xfail)

    frame = nw.from_native(constructor(data))
    unpivoted = frame.unpivot(on=["a", "b"], index="idx")
    actual_dtypes = unpivoted.collect_schema().dtypes()

    assert actual_dtypes == expected_dtypes

0 comments on commit efc6a52

Please sign in to comment.