Skip to content

Commit

Permalink
fix: when-then-otherwise lit string for arrow backend (#1137)
Browse files Browse the repository at this point in the history
* fix: when-then-otherwise lit string for arrow backend

* rename test
  • Loading branch information
FBruzzesi authored Oct 6, 2024
1 parent efc6a52 commit 5aa4e12
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
13 changes: 11 additions & 2 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.selectors import ArrowSelectorNamespace
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.utils import broadcast_series
from narwhals._arrow.utils import horizontal_concat
from narwhals._arrow.utils import vertical_concat
from narwhals._expression_parsing import combine_root_names
Expand Down Expand Up @@ -353,7 +354,8 @@ def __call__(self, df: ArrowDataFrame) -> list[ArrowSeries]:
self._otherwise_value, namespace=plx
)._call(df)[0] # type: ignore[arg-type]
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
# `self._otherwise_value` is a scalar and can't be converted to an expression.
# Remark that string values _are_ converted into expressions!
return [
value_series._from_native_series(
pc.if_else(
Expand All @@ -364,7 +366,14 @@ def __call__(self, df: ArrowDataFrame) -> list[ArrowSeries]:
else:
otherwise_series = cast(ArrowSeries, otherwise_series)
condition = cast(ArrowSeries, condition)
return [value_series.zip_with(condition, otherwise_series)]
condition_native, otherwise_native = broadcast_series(
[condition, otherwise_series]
)
return [
value_series._from_native_series(
pc.if_else(condition_native, value_series_native, otherwise_native)
)
]

def then(self, value: ArrowExpr | ArrowSeries | Any) -> ArrowThen:
self._then_value = value
Expand Down
7 changes: 7 additions & 0 deletions tests/expr_and_series/when_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,10 @@ def test_when_then_otherwise_into_expr(constructor: Constructor) -> None:
result = df.select(nw.when(nw.col("a") > 1).then("c").otherwise("e"))
expected = {"c": [7, 5, 6]}
compare_dicts(result, expected)


def test_when_then_otherwise_lit_str(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") > 1).then(nw.col("b")).otherwise(nw.lit("z")))
expected = {"b": ["z", "b", "c"]}
compare_dicts(result, expected)

0 comments on commit 5aa4e12

Please sign in to comment.