diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 7f0bcbec8..c0e86de61 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -11,6 +11,7 @@ from narwhals._arrow.expr import ArrowExpr from narwhals._arrow.selectors import ArrowSelectorNamespace from narwhals._arrow.series import ArrowSeries +from narwhals._arrow.utils import broadcast_series from narwhals._arrow.utils import horizontal_concat from narwhals._arrow.utils import vertical_concat from narwhals._expression_parsing import combine_root_names @@ -353,7 +354,8 @@ def __call__(self, df: ArrowDataFrame) -> list[ArrowSeries]: self._otherwise_value, namespace=plx )._call(df)[0] # type: ignore[arg-type] except TypeError: - # `self._otherwise_value` is a scalar and can't be converted to an expression + # `self._otherwise_value` is a scalar and can't be converted to an expression. + # Remark that string values _are_ converted into expressions! return [ value_series._from_native_series( pc.if_else( @@ -364,7 +366,14 @@ def __call__(self, df: ArrowDataFrame) -> list[ArrowSeries]: else: otherwise_series = cast(ArrowSeries, otherwise_series) condition = cast(ArrowSeries, condition) - return [value_series.zip_with(condition, otherwise_series)] + condition_native, otherwise_native = broadcast_series( + [condition, otherwise_series] + ) + return [ + value_series._from_native_series( + pc.if_else(condition_native, value_series_native, otherwise_native) + ) + ] def then(self, value: ArrowExpr | ArrowSeries | Any) -> ArrowThen: self._then_value = value diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 993988744..6fabaa68b 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -138,3 +138,10 @@ def test_when_then_otherwise_into_expr(constructor: Constructor) -> None: result = df.select(nw.when(nw.col("a") > 1).then("c").otherwise("e")) expected = {"c": [7, 5, 6]} compare_dicts(result, expected) + + +def test_when_then_otherwise_lit_str(constructor: Constructor) -> None: + df = nw.from_native(constructor(data)) + result = df.select(nw.when(nw.col("a") > 1).then(nw.col("b")).otherwise(nw.lit("z"))) + expected = {"b": ["z", "b", "c"]} + compare_dicts(result, expected)