fix: support various reductions in pyspark #1870

Merged · 9 commits · Jan 27, 2025
Changes from 4 commits
51 changes: 39 additions & 12 deletions narwhals/_spark_like/dataframe.py
@@ -6,8 +6,12 @@
from typing import Literal
from typing import Sequence

from pyspark.sql import Window
from pyspark.sql import functions as F # noqa: N812

from narwhals._spark_like.utils import native_to_narwhals_dtype
from narwhals._spark_like.utils import parse_exprs_and_named_exprs
from narwhals.typing import CompliantLazyFrame
from narwhals.utils import Implementation
from narwhals.utils import check_column_exists
from narwhals.utils import parse_columns_to_drop
@@ -26,7 +30,6 @@
from narwhals._spark_like.namespace import SparkLikeNamespace
from narwhals.dtypes import DType
from narwhals.utils import Version
from narwhals.typing import CompliantLazyFrame


class SparkLikeLazyFrame(CompliantLazyFrame):
@@ -94,7 +97,9 @@ def select(
*exprs: SparkLikeExpr,
**named_exprs: SparkLikeExpr,
) -> Self:
new_columns = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs)
new_columns, returns_scalar = parse_exprs_and_named_exprs(self)(
*exprs, **named_exprs
)

if not new_columns:
# return empty dataframe, like Polars does
@@ -105,8 +110,38 @@

return self._from_native_frame(spark_df)

new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()]
return self._from_native_frame(self._native_frame.select(*new_columns_list))
if all(returns_scalar):
new_columns_list = [
col.alias(col_name) for col_name, col in new_columns.items()
]
return self._from_native_frame(self._native_frame.agg(*new_columns_list))
else:
new_columns_list = [
col.over(Window.partitionBy(F.lit(1))).alias(col_name)
if _returns_scalar
else col.alias(col_name)
for (col_name, col), _returns_scalar in zip(
new_columns.items(), returns_scalar
)
]
Comment on lines +119 to +126

Member:
I think this may be too late to set over.

For example, in nw.col('a') - nw.col('a').mean(), I think it's in the binary operation __sub__ that nw.col('a').mean() needs to become nw.col('a').mean().over(lit(1)).

As in, we want to translate nw.col('a') - nw.col('a').mean() to F.col('a') - F.col('a').mean().over(F.lit(1)). The code, however, as far as I can tell, translates it to (F.col('a') - F.col('a').mean()).over(F.lit(1)).

Member Author (@FBruzzesi, Jan 26, 2025):
That happens in maybe_evaluate, and you can see that the reduction_test cases are now passing.
I know it's not ideal to have the logic for setting over in two places, but I couldn't figure out a unique place in which to handle this, as maybe_evaluate is called only to evaluate other arguments.

Member:
Ah I see! Yes, this might be fine then, thanks!
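A minimal PySpark sketch of the two translations being discussed, assuming a local SparkSession and toy data (the SparkSession, the DataFrame, and the column name "a" are all made up for illustration and are not part of the diff):

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1.0,), (3.0,), (2.0,)], ["a"])
w = Window.partitionBy(F.lit(1))

# Intended translation: only the scalar sub-expression carries the window,
# so the mean is broadcast to every row and then subtracted row by row.
df.select((F.col("a") - F.mean("a").over(w)).alias("a_centered")).show()

# The translation the reviewer worried about: the window wraps the whole
# binary expression, which is no longer a plain aggregate, so PySpark
# rejects it at analysis time.
# df.select((F.col("a") - F.mean("a")).over(w).alias("a_centered"))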

return self._from_native_frame(self._native_frame.select(*new_columns_list))

def with_columns(
Member Author:
I just moved this closer to the select method - it was easier to debug them while on the same screen 🙈

self: Self,
*exprs: SparkLikeExpr,
**named_exprs: SparkLikeExpr,
) -> Self:
new_columns, returns_scalar = parse_exprs_and_named_exprs(self)(
*exprs, **named_exprs
)

new_columns_map = {
col_name: col.over(Window.partitionBy(F.lit(1))) if _returns_scalar else col
for (col_name, col), _returns_scalar in zip(
new_columns.items(), returns_scalar
)
}
return self._from_native_frame(self._native_frame.withColumns(new_columns_map))
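A rough end-to-end usage sketch of what these select / with_columns changes enable. The example data and column names are assumptions; only the narwhals calls (from_native, select, with_columns, to_native) are real API:

import narwhals as nw
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark_df = spark.createDataFrame([(1, 4), (2, 5), (3, 6)], ["a", "b"])
df = nw.from_native(spark_df)

# All-scalar select: routed through a single native .agg(...) call.
print(df.select(nw.col("a").sum(), nw.col("b").max()).to_native().collect())

# Scalar reduction in with_columns: broadcast back to every row via
# over(Window.partitionBy(F.lit(1))) under the hood.
print(df.with_columns(a_mean=nw.col("a").mean()).to_native().collect())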

def filter(self: Self, *predicates: SparkLikeExpr, **constraints: Any) -> Self:
plx = self.__narwhals_namespace__()
@@ -130,14 +165,6 @@ def schema(self: Self) -> dict[str, DType]:
def collect_schema(self: Self) -> dict[str, DType]:
return self.schema

def with_columns(
self: Self,
*exprs: SparkLikeExpr,
**named_exprs: SparkLikeExpr,
) -> Self:
new_columns_map = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs)
return self._from_native_frame(self._native_frame.withColumns(new_columns_map))

def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
columns_to_drop = parse_columns_to_drop(
compliant_frame=self, columns=columns, strict=strict
28 changes: 16 additions & 12 deletions narwhals/_spark_like/expr.py
@@ -126,7 +126,7 @@ def _from_call(
def func(df: SparkLikeLazyFrame) -> list[Column]:
native_series_list = self._call(df)
other_native_series = {
key: maybe_evaluate(df, value)
key: maybe_evaluate(df, value, returns_scalar=returns_scalar)
for key, value in expressifiable_args.items()
}
return [
@@ -140,7 +140,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
function_name=f"{self._function_name}->{expr_name}",
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
returns_scalar=self._returns_scalar or returns_scalar,
returns_scalar=returns_scalar,
Member Author:
This gave me so much headache before spotting it.

backend_version=self._backend_version,
version=self._version,
kwargs=expressifiable_args,
@@ -349,26 +349,30 @@ def sum(self: Self) -> Self:
return self._from_call(F.sum, "sum", returns_scalar=True)

def std(self: Self, ddof: int) -> Self:
from functools import partial

import numpy as np # ignore-banned-import

from narwhals._spark_like.utils import _std

func = partial(_std, ddof=ddof, np_version=parse_version(np.__version__))

return self._from_call(func, "std", returns_scalar=True, ddof=ddof)
return self._from_call(
_std,
"std",
returns_scalar=True,
ddof=ddof,
np_version=parse_version(np.__version__),
)

def var(self: Self, ddof: int) -> Self:
from functools import partial

import numpy as np # ignore-banned-import

from narwhals._spark_like.utils import _var

func = partial(_var, ddof=ddof, np_version=parse_version(np.__version__))

return self._from_call(func, "var", returns_scalar=True, ddof=ddof)
return self._from_call(
_var,
"var",
returns_scalar=True,
ddof=ddof,
np_version=parse_version(np.__version__),
)
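This is not the library's _std / _var helpers themselves, just a hedged sketch of how a ddof-aware standard deviation can be expressed in PySpark: rescale the sample standard deviation (ddof=1) by sqrt((n-1)/(n-ddof)). The SparkSession, data, and column name are assumptions:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1.0,), (2.0,), (4.0,)], ["a"])

ddof = 0
n = F.count("a")
# stddev_samp uses ddof=1; rescale to the requested ddof.
std_ddof = F.stddev_samp("a") * F.sqrt((n - F.lit(1)) / (n - F.lit(ddof)))
df.agg(std_ddof.alias("a_std")).show()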

def clip(
self: Self,
30 changes: 21 additions & 9 deletions narwhals/_spark_like/utils.py
@@ -5,14 +5,15 @@
from typing import Any
from typing import Callable

from pyspark.sql import Column
from pyspark.sql import Window
from pyspark.sql import functions as F # noqa: N812

from narwhals.exceptions import UnsupportedDTypeError
from narwhals.utils import import_dtypes_module
from narwhals.utils import isinstance_or_issubclass

if TYPE_CHECKING:
from pyspark.sql import Column
from pyspark.sql import types as pyspark_types

from narwhals._spark_like.dataframe import SparkLikeLazyFrame
@@ -112,9 +113,16 @@ def narwhals_to_native_dtype(

def parse_exprs_and_named_exprs(
df: SparkLikeLazyFrame,
) -> Callable[..., dict[str, Column]]:
def func(*exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr) -> dict[str, Column]:
) -> Callable[..., tuple[dict[str, Column], list[bool]]]:
def func(
*exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr
) -> tuple[dict[str, Column], list[bool]]:
native_results: dict[str, list[Column]] = {}

# `returns_scalar` keeps track of whether an expression returns a scalar and is not lit.
# Notice that lit is quite a special case, since it gets broadcast by pyspark
# without the need of adding `.over(Window.partitionBy(F.lit(1)))`
returns_scalar: list[bool] = []

Member:
Out of interest, do we run into any issues if we use over anyway with lit? No objections to special-casing it, just curious.

Member Author:
We end up with tests/expr_and_series/lit_test.py failing 3 tests due to:

pyspark.errors.exceptions.captured.AnalysisException: [UNSUPPORTED_EXPR_FOR_WINDOW] Expression "1" not supported within a window function.;

Member:
Isn't this still going to break for, say, df.with_columns(nw.lit(2)+1)?

Member Author:
Yes, correct.

Member Author:
Added a test case and now it works!
for expr in exprs:
native_series_list = expr._call(df)
output_names = expr._evaluate_output_names(df)
@@ -124,18 +132,23 @@ def func(*exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr) -> dict[str, Column
msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
raise AssertionError(msg)
native_results.update(zip(output_names, native_series_list))
returns_scalar.extend(
[expr._returns_scalar and expr._function_name != "lit"]
* len(output_names)
)
for col_alias, expr in named_exprs.items():
native_series_list = expr._call(df)
if len(native_series_list) != 1: # pragma: no cover
msg = "Named expressions must return a single column"
raise ValueError(msg)
native_results[col_alias] = native_series_list[0]
return native_results
returns_scalar.append(expr._returns_scalar and expr._function_name != "lit")
return native_results, returns_scalar

return func


def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any) -> Any:
def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any, *, returns_scalar: bool) -> Column:
from narwhals._spark_like.expr import SparkLikeExpr

if isinstance(obj, SparkLikeExpr):
@@ -144,10 +157,9 @@ def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any) -> Any:
msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) not supported in this context"
raise NotImplementedError(msg)
column_result = column_results[0]
if obj._returns_scalar:
# Return scalar, let PySpark do its broadcasting
from pyspark.sql.window import Window

if obj._returns_scalar and obj._function_name != "lit" and not returns_scalar:
# Returns scalar, but overall expression doesn't.
# Let PySpark do its broadcasting
return column_result.over(Window.partitionBy(F.lit(1)))
return column_result
return obj
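A hypothetical repro of the lit special case discussed above. The SparkSession, data, and literal values are assumptions; the exact error text may differ from the one quoted in the thread:

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1,), (3,), (2,)], ["a"])

# A literal is broadcast by PySpark on its own, no window needed.
df.select(F.lit(2).alias("b")).show()

# Wrapping a literal in a window is rejected with an error like
# AnalysisException: [UNSUPPORTED_EXPR_FOR_WINDOW] Expression "2" not supported within a window function.
# df.select(F.lit(2).over(Window.partitionBy(F.lit(1))).alias("b"))

# An aggregate, by contrast, does need the window to broadcast inside select/withColumns.
df.select(F.sum("a").over(Window.partitionBy(F.lit(1))).alias("a_sum")).show()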
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -176,7 +176,8 @@ xfail_strict = true
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
env = [
"MODIN_ENGINE=python",
"PYARROW_IGNORE_TIMEZONE=1"
"PYARROW_IGNORE_TIMEZONE=1",
"TZ=UTC",
]

[tool.coverage.run]
4 changes: 1 addition & 3 deletions tests/expr_and_series/binary_test.py
@@ -9,9 +9,7 @@


def test_expr_binary(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if ("dask" in str(constructor) and DASK_VERSION < (2024, 10)) or "pyspark" in str(
constructor
):
if "dask" in str(constructor) and DASK_VERSION < (2024, 10):
request.applymarker(pytest.mark.xfail)
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]}
df_raw = constructor(data)
7 changes: 0 additions & 7 deletions tests/expr_and_series/lit_test.py
@@ -105,13 +105,6 @@ def test_lit_operation(
and DASK_VERSION < (2024, 10)
):
request.applymarker(pytest.mark.xfail)
if "pyspark" in str(constructor) and col_name in {
"left_lit_with_agg",
"left_scalar_with_agg",
"right_lit_with_agg",
"right_lit",
}:
request.applymarker(pytest.mark.xfail)

data = {"a": [1, 3, 2]}
df_raw = constructor(data)
14 changes: 3 additions & 11 deletions tests/expr_and_series/reduction_test.py
@@ -33,13 +33,6 @@ def test_scalar_reduction_select(
expected: dict[str, list[Any]],
request: pytest.FixtureRequest,
) -> None:
if "pyspark" in str(constructor) and request.node.callspec.id in {
"pyspark-2",
"pyspark-3",
"pyspark-4",
}:
request.applymarker(pytest.mark.xfail)

if "duckdb" in str(constructor) and request.node.callspec.id not in {"duckdb-0"}:
request.applymarker(pytest.mark.xfail)

@@ -72,9 +65,7 @@ def test_scalar_reduction_with_columns(
expected: dict[str, list[Any]],
request: pytest.FixtureRequest,
) -> None:
if "duckdb" in str(constructor) or (
"pyspark" in str(constructor) and request.node.callspec.id != "pyspark-1"
):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {"a": [1, 2, 3], "b": [4, 5, 6]}
df = nw.from_native(constructor(data))
@@ -85,6 +76,7 @@ def test_empty_scalar_reduction_select(
def test_empty_scalar_reduction_select(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
# pyspark doesn't necessarily fail, but returns all Nones
if "pyspark" in str(constructor) or "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {
@@ -118,7 +110,7 @@ def test_empty_scalar_reduction_with_columns(
def test_empty_scalar_reduction_with_columns(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "pyspark" in str(constructor) or "duckdb" in str(constructor):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
from itertools import chain

2 changes: 1 addition & 1 deletion tests/stable_api_test.py
@@ -22,7 +22,7 @@ def remove_docstring_examples(doc: str) -> str:
def test_renamed_taxicab_norm(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if ("pyspark" in str(constructor)) or "duckdb" in str(constructor):
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
# Suppose we need to rename `_l1_norm` to `_taxicab_norm`.
# We need `narwhals.stable.v1` to stay stable. So, we