fix: support various reductions in pyspark #1870

Merged
merged 9 commits on Jan 27, 2025
Changes from 1 commit
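For context before the diff, a rough sketch of the kind of narwhals usage this PR targets (illustrative only; the data, column names, and exact backend behaviour are assumptions, not taken from the PR):

import narwhals as nw
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
native = spark.createDataFrame([(1, 4), (3, 4), (2, 6)], ["a", "b"])
df = nw.from_native(native)

# Every expression is a reduction: the PySpark backend can translate this
# straight to DataFrame.agg.
df.select(nw.col("a").mean(), nw.col("b").max())

# Mixed reduction and full-length column: the reduction has to be broadcast
# back to full length before the select.
df.select(nw.col("a"), nw.col("a").mean().alias("a_mean"))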
51 changes: 37 additions & 14 deletions narwhals/_spark_like/dataframe.py
@@ -6,6 +6,9 @@
from typing import Literal
from typing import Sequence

from pyspark.sql import Window
from pyspark.sql import functions as F # noqa: N812

from narwhals._spark_like.utils import native_to_narwhals_dtype
from narwhals._spark_like.utils import parse_exprs_and_named_exprs
from narwhals.typing import CompliantLazyFrame
@@ -94,8 +97,8 @@ def select(
*exprs: SparkLikeExpr,
**named_exprs: SparkLikeExpr,
) -> Self:
new_columns = parse_exprs_and_named_exprs(
self, *exprs, with_columns_context=False, **named_exprs
new_columns, returns_scalar = parse_exprs_and_named_exprs(self)(
*exprs, **named_exprs
)

if not new_columns:
@@ -107,8 +110,38 @@

return self._from_native_frame(spark_df)

new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()]
return self._from_native_frame(self._native_frame.select(*new_columns_list))
if all(returns_scalar):
new_columns_list = [
col.alias(col_name) for col_name, col in new_columns.items()
]
return self._from_native_frame(self._native_frame.agg(*new_columns_list))
else:
new_columns_list = [
col.over(Window.partitionBy(F.lit(1))).alias(col_name)
if _returns_scalar
else col.alias(col_name)
for (col_name, col), _returns_scalar in zip(
new_columns.items(), returns_scalar
)
]
Comment on lines +119 to +126
Member

I think this may be too late to set over.

For example, in nw.col('a') - nw.col('a').mean(), I think it's in the binary operation __sub__ that nw.col('a').mean() needs to become nw.col('a').mean().over(lit(1)).

As in, we want to translate nw.col('a') - nw.col('a').mean() to F.col('a') - F.col('a').mean().over(F.lit(1)). The code, however, as far as I can tell, translates it to (F.col('a') - F.col('a').mean()).over(F.lit(1)).

Member Author (@FBruzzesi), Jan 26, 2025
That happens in maybe_evaluate, and you can see that the reduction_test tests are now passing.
I know it's not ideal to have the logic for setting over in two places, but I couldn't figure out a single place in which to handle this, as maybe_evaluate is called only to evaluate other arguments.

Member
Ah, I see! Yes, this might be fine then, thanks!
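
For illustration (not part of the diff), a minimal PySpark sketch of the two translations discussed above; column names and data are made up:

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1.0,), (3.0,), (2.0,)], ["a"])
w = Window.partitionBy(F.lit(1))

# Intended translation of nw.col('a') - nw.col('a').mean():
# only the reduction is broadcast, the subtraction stays elementwise.
good = df.select((F.col("a") - F.mean("a").over(w)).alias("a_centered"))

# The translation the reviewer was worried about: windowing the whole
# binary expression instead of just the mean. Left commented out since
# it does not express the intended per-row subtraction.
# bad = df.select((F.col("a") - F.mean("a")).over(w).alias("a_centered"))

# When every selected expression is a reduction, the PR instead routes
# select() through DataFrame.agg, e.g. df.agg(F.mean("a").alias("a")).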

return self._from_native_frame(self._native_frame.select(*new_columns_list))

def with_columns(
Member Author
I just moved this closer to the select method - it was easier to debug them while on the same screen 🙈

self: Self,
*exprs: SparkLikeExpr,
**named_exprs: SparkLikeExpr,
) -> Self:
new_columns, returns_scalar = parse_exprs_and_named_exprs(self)(
*exprs, **named_exprs
)

new_columns_map = {
col_name: col.over(Window.partitionBy(F.lit(1))) if _returns_scalar else col
for (col_name, col), _returns_scalar in zip(
new_columns.items(), returns_scalar
)
}
return self._from_native_frame(self._native_frame.withColumns(new_columns_map))

def filter(self: Self, *predicates: SparkLikeExpr, **constraints: Any) -> Self:
plx = self.__narwhals_namespace__()
@@ -132,16 +165,6 @@ def schema(self: Self) -> dict[str, DType]:
def collect_schema(self: Self) -> dict[str, DType]:
return self.schema

def with_columns(
self: Self,
*exprs: SparkLikeExpr,
**named_exprs: SparkLikeExpr,
) -> Self:
new_columns_map = parse_exprs_and_named_exprs(
self, *exprs, with_columns_context=True, **named_exprs
)
return self._from_native_frame(self._native_frame.withColumns(new_columns_map))

def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
columns_to_drop = parse_columns_to_drop(
compliant_frame=self, columns=columns, strict=strict
2 changes: 1 addition & 1 deletion narwhals/_spark_like/expr.py
@@ -140,7 +140,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
function_name=f"{self._function_name}->{expr_name}",
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
returns_scalar=self._returns_scalar or returns_scalar,
returns_scalar=returns_scalar,
Member Author
This gave me such a headache before spotting it.
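
A hedged sketch of the kind of chain where the old or-based propagation seems to go wrong; the flag values below are inferred from this diff, not verified against the library internals:

import narwhals as nw

# nw.col("a").mean() is a reduction, so it carries returns_scalar=True.
expr = nw.col("a").mean() - nw.col("a")

# Old: the result inherited self._returns_scalar or returns_scalar,
#      i.e. True or False -> True, even though subtracting a full-length
#      column broadcasts the mean back to full length.
# New: __sub__ declares its own returns_scalar (False), so select() and
#      maybe_evaluate() no longer treat the result as a reduction.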

backend_version=self._backend_version,
version=self._version,
kwargs=expressifiable_args,
77 changes: 35 additions & 42 deletions narwhals/_spark_like/utils.py
@@ -5,6 +5,7 @@
from typing import Any
from typing import Callable

from pyspark.sql import Column
from pyspark.sql import Window
from pyspark.sql import functions as F # noqa: N812

@@ -13,7 +14,6 @@
from narwhals.utils import isinstance_or_issubclass

if TYPE_CHECKING:
from pyspark.sql import Column
from pyspark.sql import types as pyspark_types

from narwhals._spark_like.dataframe import SparkLikeLazyFrame
@@ -113,49 +113,42 @@ def narwhals_to_native_dtype(

def parse_exprs_and_named_exprs(
df: SparkLikeLazyFrame,
*exprs: SparkLikeExpr,
with_columns_context: bool,
**named_exprs: SparkLikeExpr,
) -> dict[str, Column]:
native_results: dict[str, Column] = {}

# `returns_scalar` keeps track if an expression returns a scalar and is not lit.
# Notice that lit is quite special case, since it gets broadcasted by pyspark
# without the need of adding `.over(Window.partitionBy(F.lit(1)))`
returns_scalar: list[bool] = []

for expr in exprs:
native_series_list = expr._call(df)
output_names = expr._evaluate_output_names(df)
if expr._alias_output_names is not None:
output_names = expr._alias_output_names(output_names)
if len(output_names) != len(native_series_list): # pragma: no cover
msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
raise AssertionError(msg)
native_results.update(zip(output_names, native_series_list))
returns_scalar.extend(
[expr._returns_scalar and expr._function_name != "lit"] * len(output_names)
)
for col_alias, expr in named_exprs.items():
native_series_list = expr._call(df)
if len(native_series_list) != 1: # pragma: no cover
msg = "Named expressions must return a single column"
raise ValueError(msg)
native_results[col_alias] = native_series_list[0]
returns_scalar.append(expr._returns_scalar and expr._function_name != "lit")

if all(returns_scalar) and not with_columns_context:
return native_results
else:
return {
col_name: col.over(Window.partitionBy(F.lit(1))) if _returns_scalar else col
for (col_name, col), _returns_scalar in zip(
native_results.items(), returns_scalar
) -> Callable[..., tuple[dict[str, Column], list[bool]]]:
def func(
*exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr
) -> tuple[dict[str, Column], list[bool]]:
native_results: dict[str, list[Column]] = {}

# `returns_scalar` keeps track if an expression returns a scalar and is not lit.
# Notice that lit is quite special case, since it gets broadcasted by pyspark
Member
Out of interest, do we run into any issues if we use over with lit anyway? No objections to special-casing it, just curious.

Member Author
We end up with tests/expr_and_series/lit_test.py failing 3 tests due to:

pyspark.errors.exceptions.captured.AnalysisException: [UNSUPPORTED_EXPR_FOR_WINDOW] Expression "1" not supported within a window function.;

Member
Isn't this still going to break for, say, df.with_columns(nw.lit(2) + 1)?

Member Author
Yes, correct.

Member Author
Added a test case and now it works!
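
To make the lit special case concrete, a small PySpark sketch (illustrative only; data and column names are made up):

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1,), (3,), (2,)], ["a"])

# A plain literal broadcasts on its own, no window needed.
df.withColumn("two", F.lit(2))

# A genuine reduction needs the window to be broadcast back to full length.
df.withColumn("a_max", F.max("a").over(Window.partitionBy(F.lit(1))))

# Wrapping the literal itself in a window is what triggers the
# UNSUPPORTED_EXPR_FOR_WINDOW error quoted above, hence the special case.
# df.withColumn("two", F.lit(2).over(Window.partitionBy(F.lit(1))))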

# without the need of adding `.over(Window.partitionBy(F.lit(1)))`
returns_scalar: list[bool] = []
for expr in exprs:
native_series_list = expr._call(df)
output_names = expr._evaluate_output_names(df)
if expr._alias_output_names is not None:
output_names = expr._alias_output_names(output_names)
if len(output_names) != len(native_series_list): # pragma: no cover
msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
raise AssertionError(msg)
native_results.update(zip(output_names, native_series_list))
returns_scalar.extend(
[expr._returns_scalar and expr._function_name != "lit"]
* len(output_names)
)
}
for col_alias, expr in named_exprs.items():
native_series_list = expr._call(df)
if len(native_series_list) != 1: # pragma: no cover
msg = "Named expressions must return a single column"
raise ValueError(msg)
native_results[col_alias] = native_series_list[0]
returns_scalar.append(expr._returns_scalar and expr._function_name != "lit")
return native_results, returns_scalar

return func

def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any, *, returns_scalar: bool) -> Any:

def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any, *, returns_scalar: bool) -> Column:
from narwhals._spark_like.expr import SparkLikeExpr

if isinstance(obj, SparkLikeExpr):
@@ -164,7 +157,7 @@ def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any, *, returns_scalar: bool) ->
msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) not supported in this context"
raise NotImplementedError(msg)
column_result = column_results[0]
if obj._returns_scalar and not returns_scalar and obj._function_name != "lit":
if obj._returns_scalar and obj._function_name != "lit" and not returns_scalar:
# Returns scalar, but overall expression doesn't.
# Let PySpark do its broadcasting
return column_result.over(Window.partitionBy(F.lit(1)))
4 changes: 1 addition & 3 deletions tests/expr_and_series/binary_test.py
@@ -9,9 +9,7 @@


def test_expr_binary(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if ("dask" in str(constructor) and DASK_VERSION < (2024, 10)) or "pyspark" in str(
constructor
):
if "dask" in str(constructor) and DASK_VERSION < (2024, 10):
request.applymarker(pytest.mark.xfail)
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]}
df_raw = constructor(data)
You are viewing a condensed version of this merge commit.