Skip to content

Commit

Permalink
fixup dask
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Jan 4, 2025
1 parent 4d4e5b3 commit 89597a0
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 20 deletions.
35 changes: 21 additions & 14 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,26 +66,30 @@ def nth(self, *column_indices: int) -> DaskExpr:
)

def lit(self, value: Any, dtype: DType | None) -> DaskExpr:
def convert_if_dtype(
series: dask_expr.Series, dtype: DType | type[DType]
) -> dask_expr.Series:
return (
series.astype(narwhals_to_native_dtype(dtype, self._version))
if dtype
else series
)
import dask.dataframe as dd
import pandas as pd

return DaskExpr(
lambda df: [
df._native_frame.assign(literal=value)["literal"].pipe(
convert_if_dtype, dtype
def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
return [
dd.from_pandas(
pd.Series(
[value],
dtype=narwhals_to_native_dtype(dtype, self._version)
if dtype is not None
else None,
name="literal",
),
npartitions=df._native_frame.npartitions,
)
],
]

return DaskExpr(
func,
depth=0,
function_name="lit",
root_names=None,
output_names=["literal"],
returns_scalar=False,
returns_scalar=True,
backend_version=self._backend_version,
version=self._version,
kwargs={},
Expand Down Expand Up @@ -415,6 +419,9 @@ def __call__(self, df: DaskLazyFrame) -> Sequence[dask_expr.Series]:
return [value_series.where(condition, self._otherwise_value)]
otherwise_series = otherwise_expr(df)[0]
validate_comparand(condition, otherwise_series)

if otherwise_expr._returns_scalar: # type: ignore[attr-defined]
return [value_series.where(condition, otherwise_series[0])]
return [value_series.where(condition, otherwise_series)]

def then(self, value: DaskExpr | Any) -> DaskThen:
Expand Down
3 changes: 2 additions & 1 deletion narwhals/_dask/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any:
msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) not supported in this context"
raise NotImplementedError(msg)
result = results[0]
validate_comparand(df._native_frame, result)
if not obj._returns_scalar:
validate_comparand(df._native_frame, result)
if obj._returns_scalar:
# Return scalar, let Dask do its broadcasting
return result[0]
Expand Down
2 changes: 1 addition & 1 deletion tests/expr_and_series/arithmetic_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_right_arithmetic_expr(

data = {"a": [1, 2, 3]}
df = nw.from_native(constructor(data))
result = df.select(getattr(nw.col("a"), attr)(rhs))
result = df.with_columns(literal=getattr(nw.col("a"), attr)(rhs)).select("literal")
assert_equal_data(result, {"literal": expected})


Expand Down
7 changes: 6 additions & 1 deletion tests/expr_and_series/when_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,12 @@ def test_when_then_otherwise_into_expr(constructor: Constructor) -> None:
assert_equal_data(result, expected)


def test_when_then_otherwise_lit_str(constructor: Constructor) -> None:
def test_when_then_otherwise_lit_str(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "dask" in str(constructor):
# TODO(marco): bug in dask? exprs are not co-aligned
request.applymarker(pytest.mark.xfail)
df = nw.from_native(constructor(data))
result = df.select(nw.when(nw.col("a") > 1).then(nw.col("b")).otherwise(nw.lit("z")))
expected = {"b": ["z", "b", "c"]}
Expand Down
3 changes: 0 additions & 3 deletions tests/frame/lit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,7 @@ def test_lit_operation(
col_name: str,
expr: nw.Expr,
expected_result: list[int],
request: pytest.FixtureRequest,
) -> None:
if "dask_lazy_p2" in str(constructor) and "lit_with_agg" in col_name:
request.applymarker(pytest.mark.xfail)
data = {"a": [1, 3, 2]}
df_raw = constructor(data)
df = nw.from_native(df_raw).lazy()
Expand Down

0 comments on commit 89597a0

Please sign in to comment.