fix: support various reductions in pyspark #1870

Merged
9 commits merged on Jan 27, 2025
Changes from 3 commits
10 changes: 7 additions & 3 deletions narwhals/_spark_like/dataframe.py
@@ -8,6 +8,7 @@

from narwhals._spark_like.utils import native_to_narwhals_dtype
from narwhals._spark_like.utils import parse_exprs_and_named_exprs
from narwhals.typing import CompliantLazyFrame
from narwhals.utils import Implementation
from narwhals.utils import check_column_exists
from narwhals.utils import parse_columns_to_drop
@@ -26,7 +27,6 @@
from narwhals._spark_like.namespace import SparkLikeNamespace
from narwhals.dtypes import DType
from narwhals.utils import Version
from narwhals.typing import CompliantLazyFrame


class SparkLikeLazyFrame(CompliantLazyFrame):
@@ -94,7 +94,9 @@ def select(
*exprs: SparkLikeExpr,
**named_exprs: SparkLikeExpr,
) -> Self:
new_columns = parse_exprs_and_named_exprs(self, *exprs, **named_exprs)
new_columns = parse_exprs_and_named_exprs(
self, *exprs, with_columns_context=False, **named_exprs
)

if not new_columns:
# return empty dataframe, like Polars does
@@ -135,7 +137,9 @@ def with_columns(
*exprs: SparkLikeExpr,
**named_exprs: SparkLikeExpr,
) -> Self:
new_columns_map = parse_exprs_and_named_exprs(self, *exprs, **named_exprs)
new_columns_map = parse_exprs_and_named_exprs(
self, *exprs, with_columns_context=True, **named_exprs
)
return self._from_native_frame(self._native_frame.withColumns(new_columns_map))

def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
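
The change above just lets select and with_columns tell parse_exprs_and_named_exprs which context they are evaluated in. A rough usage sketch of what this enables for the PySpark backend (the data, column names, and running SparkSession are illustrative assumptions, not taken from this PR):

import narwhals as nw
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
native = spark.createDataFrame([(1, 4), (2, 5), (3, 6)], ["a", "b"])
df = nw.from_native(native)

# A pure reduction in `select` collapses to a single row...
df.select(nw.col("a").min(), nw.col("a").max().alias("a_max")).to_native().show()

# ...while the same reduction in `with_columns` is broadcast to every row.
df.with_columns(a_sum=nw.col("a").sum()).to_native().show()
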
26 changes: 15 additions & 11 deletions narwhals/_spark_like/expr.py
@@ -126,7 +126,7 @@ def _from_call(
def func(df: SparkLikeLazyFrame) -> list[Column]:
native_series_list = self._call(df)
other_native_series = {
key: maybe_evaluate(df, value)
key: maybe_evaluate(df, value, returns_scalar=returns_scalar)
for key, value in expressifiable_args.items()
}
return [
@@ -349,26 +349,30 @@ def sum(self: Self) -> Self:
return self._from_call(F.sum, "sum", returns_scalar=True)

def std(self: Self, ddof: int) -> Self:
from functools import partial

import numpy as np # ignore-banned-import

from narwhals._spark_like.utils import _std

func = partial(_std, ddof=ddof, np_version=parse_version(np.__version__))

return self._from_call(func, "std", returns_scalar=True, ddof=ddof)
return self._from_call(
_std,
"std",
returns_scalar=True,
ddof=ddof,
np_version=parse_version(np.__version__),
)

def var(self: Self, ddof: int) -> Self:
from functools import partial

import numpy as np # ignore-banned-import

from narwhals._spark_like.utils import _var

func = partial(_var, ddof=ddof, np_version=parse_version(np.__version__))

return self._from_call(func, "var", returns_scalar=True, ddof=ddof)
return self._from_call(
_var,
"var",
returns_scalar=True,
ddof=ddof,
np_version=parse_version(np.__version__),
)

def clip(
self: Self,
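
Besides threading returns_scalar into maybe_evaluate, std and var above drop the functools.partial wrapper and forward ddof and np_version to _from_call as plain keyword arguments. The _std and _var helpers themselves are not part of this diff; as a loose sketch only (assuming nothing about their real signatures), a ddof-aware standard deviation can be expressed in PySpark as:

from pyspark.sql import Column
from pyspark.sql import functions as F

def std_with_ddof(col: Column, ddof: int) -> Column:
    # var_samp divides by n - 1; rescale the sum of squared deviations
    # to divide by n - ddof instead, e.g. ddof=0 gives the population std.
    n = F.count(col)
    return F.sqrt(F.var_samp(col) * (n - F.lit(1)) / (n - F.lit(ddof)))

# usage sketch: df.select(std_with_ddof(F.col("a"), ddof=0))
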
38 changes: 30 additions & 8 deletions narwhals/_spark_like/utils.py
@@ -4,6 +4,7 @@
from typing import TYPE_CHECKING
from typing import Any

from pyspark.sql import Window
from pyspark.sql import functions as F # noqa: N812

from narwhals.exceptions import UnsupportedDTypeError
@@ -110,9 +111,18 @@ def narwhals_to_native_dtype(


def parse_exprs_and_named_exprs(
df: SparkLikeLazyFrame, *exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr
df: SparkLikeLazyFrame,
*exprs: SparkLikeExpr,
with_columns_context: bool,
**named_exprs: SparkLikeExpr,
) -> dict[str, Column]:
native_results: dict[str, list[Column]] = {}
native_results: dict[str, Column] = {}

# `returns_scalar` keeps track of whether an expression returns a scalar and is not `lit`.
# Notice that `lit` is quite a special case, since it gets broadcast by pyspark
# without the need to add `.over(Window.partitionBy(F.lit(1)))`.
returns_scalar: list[bool] = []

for expr in exprs:
native_series_list = expr._call(df)
output_names = expr._evaluate_output_names(df)
@@ -122,16 +132,29 @@
msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
raise AssertionError(msg)
native_results.update(zip(output_names, native_series_list))
returns_scalar.extend(
[expr._returns_scalar and expr._function_name != "lit"] * len(output_names)
)
for col_alias, expr in named_exprs.items():
native_series_list = expr._call(df)
if len(native_series_list) != 1: # pragma: no cover
msg = "Named expressions must return a single column"
raise ValueError(msg)
native_results[col_alias] = native_series_list[0]
return native_results
returns_scalar.append(expr._returns_scalar and expr._function_name != "lit")

if all(returns_scalar) and not with_columns_context:
return native_results
else:
return {
col_name: col.over(Window.partitionBy(F.lit(1))) if _returns_scalar else col
for (col_name, col), _returns_scalar in zip(
native_results.items(), returns_scalar
)
}

def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any) -> Any:

def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any, *, returns_scalar: bool) -> Any:
from narwhals._spark_like.expr import SparkLikeExpr

if isinstance(obj, SparkLikeExpr):
@@ -140,10 +163,9 @@ def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any) -> Any:
msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) not supported in this context"
raise NotImplementedError(msg)
column_result = column_results[0]
if obj._returns_scalar:
# Return scalar, let PySpark do its broadcasting
from pyspark.sql.window import Window

if obj._returns_scalar and not returns_scalar and obj._function_name != "lit":
# Returns scalar, but overall expression doesn't.
# Let PySpark do its broadcasting
return column_result.over(Window.partitionBy(F.lit(1)))
return column_result
return obj
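
The new branch in parse_exprs_and_named_exprs boils down to: if every expression in a select reduces to a scalar, the aggregates are returned untouched and PySpark produces a one-row frame; otherwise each scalar result is broadcast with .over(Window.partitionBy(F.lit(1))), with lit excluded because PySpark already broadcasts literals on its own. A minimal PySpark-level illustration of those three cases (column names are made up for this sketch):

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1,), (2,), (3,)], ["a"])

# All-scalar select: a plain aggregation already yields a single row.
df.select(F.min("a").alias("lo"), F.max("a").alias("hi")).show()

# Mixed select: the scalar has to be broadcast to line up with the full column.
df.select(F.col("a"), F.max("a").over(Window.partitionBy(F.lit(1))).alias("hi")).show()

# `lit` needs no window; PySpark broadcasts literals by itself.
df.select(F.col("a"), F.lit(1).alias("one")).show()
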
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -176,7 +176,8 @@ xfail_strict = true
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
env = [
"MODIN_ENGINE=python",
"PYARROW_IGNORE_TIMEZONE=1"
"PYARROW_IGNORE_TIMEZONE=1",
"TZ=UTC",
]

[tool.coverage.run]
7 changes: 0 additions & 7 deletions tests/expr_and_series/lit_test.py
@@ -105,13 +105,6 @@ def test_lit_operation(
and DASK_VERSION < (2024, 10)
):
request.applymarker(pytest.mark.xfail)
if "pyspark" in str(constructor) and col_name in {
"left_lit_with_agg",
"left_scalar_with_agg",
"right_lit_with_agg",
"right_lit",
}:
request.applymarker(pytest.mark.xfail)

data = {"a": [1, 3, 2]}
df_raw = constructor(data)
14 changes: 3 additions & 11 deletions tests/expr_and_series/reduction_test.py
@@ -33,13 +33,6 @@ def test_scalar_reduction_select(
expected: dict[str, list[Any]],
request: pytest.FixtureRequest,
) -> None:
if "pyspark" in str(constructor) and request.node.callspec.id in {
"pyspark-2",
"pyspark-3",
"pyspark-4",
}:
request.applymarker(pytest.mark.xfail)

if "duckdb" in str(constructor) and request.node.callspec.id not in {"duckdb-0"}:
request.applymarker(pytest.mark.xfail)

@@ -72,9 +65,7 @@ def test_scalar_reduction_with_columns(
expected: dict[str, list[Any]],
request: pytest.FixtureRequest,
) -> None:
if "duckdb" in str(constructor) or (
"pyspark" in str(constructor) and request.node.callspec.id != "pyspark-1"
):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {"a": [1, 2, 3], "b": [4, 5, 6]}
df = nw.from_native(constructor(data))
@@ -85,6 +76,7 @@
def test_empty_scalar_reduction_select(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
# pyspark doesn't necessarily fail, but returns all Nones
if "pyspark" in str(constructor) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {
@@ -118,7 +110,7 @@ def test_empty_scalar_reduction_select(
def test_empty_scalar_reduction_with_columns(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "pyspark" in str(constructor) or "duckdb" in str(constructor):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
from itertools import chain

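
On the comment in test_empty_scalar_reduction_select above: aggregating an empty PySpark frame does not raise, it just yields nulls, which is why that test keeps its pyspark xfail. A small illustration (the schema here is an arbitrary choice for the sketch):

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import LongType, StructField, StructType

spark = SparkSession.builder.getOrCreate()
empty = spark.createDataFrame([], StructType([StructField("a", LongType())]))
# Result is a single row containing null, not an error.
empty.select(F.sum("a").alias("a_sum")).show()
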
2 changes: 1 addition & 1 deletion tests/stable_api_test.py
@@ -22,7 +22,7 @@ def remove_docstring_examples(doc: str) -> str:
def test_renamed_taxicab_norm(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
# Suppose we need to rename `_l1_norm` to `_taxicab_norm`.
# We need `narwhals.stable.v1` to stay stable. So, we