fix: support various reductions in pyspark #1870

Merged
merged 9 commits on Jan 27, 2025
53 changes: 40 additions & 13 deletions narwhals/_spark_like/dataframe.py
@@ -6,8 +6,12 @@
from typing import Literal
from typing import Sequence

from pyspark.sql import Window
from pyspark.sql import functions as F # noqa: N812

from narwhals._spark_like.utils import native_to_narwhals_dtype
from narwhals._spark_like.utils import parse_exprs_and_named_exprs
from narwhals.typing import CompliantLazyFrame
from narwhals.utils import Implementation
from narwhals.utils import check_column_exists
from narwhals.utils import parse_columns_to_drop
@@ -26,7 +30,6 @@
from narwhals._spark_like.namespace import SparkLikeNamespace
from narwhals.dtypes import DType
from narwhals.utils import Version
from narwhals.typing import CompliantLazyFrame


class SparkLikeLazyFrame(CompliantLazyFrame):
@@ -94,7 +97,9 @@ def select(
*exprs: SparkLikeExpr,
**named_exprs: SparkLikeExpr,
) -> Self:
new_columns = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs)
new_columns, returns_scalar = parse_exprs_and_named_exprs(self)(
*exprs, **named_exprs
)

if not new_columns:
# return empty dataframe, like Polars does
@@ -105,8 +110,38 @@

return self._from_native_frame(spark_df)

new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()]
return self._from_native_frame(self._native_frame.select(*new_columns_list))
if all(returns_scalar):
new_columns_list = [
col.alias(col_name) for col_name, col in new_columns.items()
]
return self._from_native_frame(self._native_frame.agg(*new_columns_list))
else:
new_columns_list = [
col.over(Window.partitionBy(F.lit(1))).alias(col_name)
if _returns_scalar
else col.alias(col_name)
for (col_name, col), _returns_scalar in zip(
new_columns.items(), returns_scalar
)
]
Comment on lines +119 to +126
Member:

I think this may be too late to set `over`.

For example, in `nw.col('a') - nw.col('a').mean()`, I think it's in the binary operation `__sub__` that `nw.col('a').mean()` needs to become `nw.col('a').mean().over(lit(1))`.

As in, we want to translate `nw.col('a') - nw.col('a').mean()` to `F.col('a') - F.col('a').mean().over(F.lit(1))`. The code, however, as far as I can tell, translates it to `(F.col('a') - F.col('a').mean()).over(F.lit(1))`.

Member Author (@FBruzzesi), Jan 26, 2025:

That happens in `maybe_evaluate`, and you can see that the reduction tests are now passing.
I know it's not ideal to have the logic for setting `over` in two places, but I couldn't figure out a single place in which to handle this, as `maybe_evaluate` is only called to evaluate other arguments.

Member:

ah I see! yes this might be fine then, thanks!
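
For reference, a minimal PySpark sketch of the translation discussed above (hypothetical data, local session assumed): the broadcasting window is applied to the scalar sub-expression only, not to the whole binary expression.

```python
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1.0,), (2.0,), (3.0,)], ["a"])

# Intended translation of `nw.col('a') - nw.col('a').mean()`:
# only the scalar `mean` is wrapped in the trivial window, so each row
# keeps its own `a` and gets the column mean subtracted from it.
centered = F.col("a") - F.mean("a").over(Window.partitionBy(F.lit(1)))
df.select(centered.alias("a_centered")).show()
```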

return self._from_native_frame(self._native_frame.select(*new_columns_list))

def with_columns(
Member Author:

I just moved this closer to the select method - it was easier to debug them while on the same screen 🙈

self: Self,
*exprs: SparkLikeExpr,
**named_exprs: SparkLikeExpr,
) -> Self:
new_columns, returns_scalar = parse_exprs_and_named_exprs(self)(
*exprs, **named_exprs
)

new_columns_map = {
col_name: col.over(Window.partitionBy(F.lit(1))) if _returns_scalar else col
for (col_name, col), _returns_scalar in zip(
new_columns.items(), returns_scalar
)
}
return self._from_native_frame(self._native_frame.withColumns(new_columns_map))

def filter(self: Self, *predicates: SparkLikeExpr, **constraints: Any) -> Self:
plx = self.__narwhals_namespace__()
@@ -130,14 +165,6 @@ def schema(self: Self) -> dict[str, DType]:
def collect_schema(self: Self) -> dict[str, DType]:
return self.schema

def with_columns(
self: Self,
*exprs: SparkLikeExpr,
**named_exprs: SparkLikeExpr,
) -> Self:
new_columns_map = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs)
return self._from_native_frame(self._native_frame.withColumns(new_columns_map))

def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
columns_to_drop = parse_columns_to_drop(
compliant_frame=self, columns=columns, strict=strict
@@ -155,7 +182,7 @@ def group_by(self: Self, *keys: str, drop_null_keys: bool) -> SparkLikeLazyGroupBy:
from narwhals._spark_like.group_by import SparkLikeLazyGroupBy

return SparkLikeLazyGroupBy(
df=self, keys=list(keys), drop_null_keys=drop_null_keys
compliant_frame=self, keys=list(keys), drop_null_keys=drop_null_keys
)

def sort(
8 changes: 4 additions & 4 deletions narwhals/_spark_like/expr.py
@@ -122,7 +122,7 @@ def _from_call(
def func(df: SparkLikeLazyFrame) -> list[Column]:
native_series_list = self._call(df)
other_native_series = {
key: maybe_evaluate(df, value)
key: maybe_evaluate(df, value, returns_scalar=returns_scalar)
for key, value in expressifiable_args.items()
}
return [
@@ -136,7 +136,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
function_name=f"{self._function_name}->{expr_name}",
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
returns_scalar=self._returns_scalar or returns_scalar,
returns_scalar=returns_scalar,
Member Author:

This gave me such a headache before I spotted it.

backend_version=self._backend_version,
version=self._version,
)
@@ -349,7 +349,7 @@ def std(self: Self, ddof: int) -> Self:

func = partial(_std, ddof=ddof, np_version=parse_version(np.__version__))

return self._from_call(func, f"std[{ddof}]", returns_scalar=True)
return self._from_call(func, "std", returns_scalar=True)

def var(self: Self, ddof: int) -> Self:
from functools import partial
@@ -360,7 +360,7 @@ def var(self: Self, ddof: int) -> Self:

func = partial(_var, ddof=ddof, np_version=parse_version(np.__version__))

return self._from_call(func, f"var[{ddof}]", returns_scalar=True)
return self._from_call(func, "var", returns_scalar=True)

def clip(
self: Self,
160 changes: 34 additions & 126 deletions narwhals/_spark_like/group_by.py
@@ -1,150 +1,58 @@
from __future__ import annotations

import re
from functools import partial
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Sequence

from pyspark.sql import functions as F # noqa: N812

from narwhals._expression_parsing import is_simple_aggregation
from narwhals._spark_like.utils import _std
from narwhals._spark_like.utils import _var
from narwhals.utils import parse_version

if TYPE_CHECKING:
from pyspark.sql import Column
from pyspark.sql import GroupedData
from typing_extensions import Self

from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.typing import SparkLikeExpr
from narwhals.typing import CompliantExpr
from narwhals._spark_like.expr import SparkLikeExpr


class SparkLikeLazyGroupBy:
def __init__(
self: Self,
df: SparkLikeLazyFrame,
compliant_frame: SparkLikeLazyFrame,
keys: list[str],
drop_null_keys: bool, # noqa: FBT001
) -> None:
self._df = df
self._keys = keys
if drop_null_keys:
self._grouped = self._df._native_frame.dropna(subset=self._keys).groupBy(
*self._keys
)
self._compliant_frame = compliant_frame.drop_nulls(subset=None)
else:
self._grouped = self._df._native_frame.groupBy(*self._keys)

def agg(
self: Self,
*exprs: SparkLikeExpr,
) -> SparkLikeLazyFrame:
return agg_pyspark(
self._df,
self._grouped,
exprs,
self._keys,
self._from_native_frame,
)

def _from_native_frame(self: Self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame:
from narwhals._spark_like.dataframe import SparkLikeLazyFrame

return SparkLikeLazyFrame(
df, backend_version=self._df._backend_version, version=self._df._version
)


def get_spark_function(function_name: str) -> Column:
if (stem := function_name.split("[", maxsplit=1)[0]) in ("std", "var"):
import numpy as np # ignore-banned-import

return partial(
_std if stem == "std" else _var,
ddof=int(function_name.split("[", maxsplit=1)[1].rstrip("]")),
np_version=parse_version(np.__version__),
)

elif function_name == "len":
# Use count(*) to count all rows including nulls
def _count(*_args: Any, **_kwargs: Any) -> Column:
return F.count("*")

return _count

elif function_name == "n_unique":
from pyspark.sql.types import IntegerType

def _n_unique(_input: Column) -> Column:
return F.count_distinct(_input) + F.max(F.isnull(_input).cast(IntegerType()))

return _n_unique

else:
return getattr(F, function_name)


def agg_pyspark(
df: SparkLikeLazyFrame,
grouped: GroupedData,
exprs: Sequence[CompliantExpr[Column]],
keys: list[str],
from_dataframe: Callable[[Any], SparkLikeLazyFrame],
) -> SparkLikeLazyFrame:
if not exprs:
# No aggregation provided
return from_dataframe(df._native_frame.select(*keys).dropDuplicates(subset=keys))
self._compliant_frame = compliant_frame
self._keys = keys

for expr in exprs:
if not is_simple_aggregation(expr): # pragma: no cover
msg = (
"Non-trivial complex aggregation found.\n\n"
"Hint: you were probably trying to apply a non-elementary aggregation with a "
"dask dataframe.\n"
"Please rewrite your query such that group-by aggregations "
"are elementary. For example, instead of:\n\n"
" df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
"use:\n\n"
" df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
def agg(self: Self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame:
agg_columns = []
df = self._compliant_frame
for expr in exprs:
output_names = expr._evaluate_output_names(df)
aliases = (
output_names
if expr._alias_output_names is None
else expr._alias_output_names(output_names)
)
raise ValueError(msg)

simple_aggregations: dict[str, Column] = {}
for expr in exprs:
output_names = expr._evaluate_output_names(df)
aliases = (
output_names
if expr._alias_output_names is None
else expr._alias_output_names(output_names)
)
if len(output_names) > 1:
# For multi-output aggregations, e.g. `df.group_by('a').agg(nw.all().mean())`, we skip
# the keys, else they would appear duplicated in the output.
output_names, aliases = zip(
*[(x, alias) for x, alias in zip(output_names, aliases) if x not in keys]
native_expressions = expr(df)
exclude = (
self._keys
if expr._function_name.split("->", maxsplit=1)[0] in ("all", "selector")
else []
)
agg_columns.extend(
[
native_expression.alias(alias)
for native_expression, output_name, alias in zip(
native_expressions, output_names, aliases
)
if output_name not in exclude
]
)
if expr._depth == 0: # pragma: no cover
# e.g. agg(nw.len()) # noqa: ERA001
agg_func = get_spark_function(expr._function_name)
simple_aggregations.update({alias: agg_func(keys[0]) for alias in aliases})
continue

# e.g. agg(nw.mean('a')) # noqa: ERA001
function_name = re.sub(r"(\w+->)", "", expr._function_name)
agg_func = get_spark_function(function_name)
if not agg_columns:
return self._compliant_frame._from_native_frame(
self._compliant_frame._native_frame.select(*self._keys).dropDuplicates()
)

simple_aggregations.update(
{
alias: agg_func(output_name)
for alias, output_name in zip(aliases, output_names)
}
return self._compliant_frame._from_native_frame(
self._compliant_frame._native_frame.groupBy(self._keys).agg(*agg_columns)
)

agg_columns = [col_.alias(name) for name, col_ in simple_aggregations.items()]
result_simple = grouped.agg(*agg_columns)
return from_dataframe(result_simple)
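
As an aside, a minimal sketch (hypothetical data, local session assumed) of roughly what the rewritten `agg` produces for a simple call like `df.group_by('a').agg(nw.col('b').mean())`: each expression is evaluated into aliased native columns and handed straight to `groupBy(...).agg(...)`.

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
native_df = spark.createDataFrame([("x", 1.0), ("x", 3.0), ("y", 2.0)], ["a", "b"])

# Group keys stay as-is; the aggregation keeps its output name via an alias,
# mirroring `groupBy(self._keys).agg(*agg_columns)` above.
native_df.groupBy(["a"]).agg(F.mean("b").alias("b")).show()
```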
39 changes: 32 additions & 7 deletions narwhals/_spark_like/utils.py
@@ -5,9 +5,10 @@
from typing import Any
from typing import Callable

from pyspark.sql import Column
from pyspark.sql import Window
from pyspark.sql import functions as F # noqa: N812
from pyspark.sql import types as pyspark_types
from pyspark.sql.window import Window

from narwhals.exceptions import UnsupportedDTypeError
from narwhals.utils import import_dtypes_module
@@ -109,9 +110,16 @@ def narwhals_to_native_dtype(

def parse_exprs_and_named_exprs(
df: SparkLikeLazyFrame,
) -> Callable[..., dict[str, Column]]:
def func(*exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr) -> dict[str, Column]:
) -> Callable[..., tuple[dict[str, Column], list[bool]]]:
def func(
*exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr
) -> tuple[dict[str, Column], list[bool]]:
native_results: dict[str, list[Column]] = {}

# `returns_scalar` keeps track if an expression returns a scalar and is not lit.
# Notice that lit is quite special case, since it gets broadcasted by pyspark
Member:

Out of interest, do we run into any issues if we use `over` anyway with `lit`? No objections to special-casing it, just curious.

Member Author:

We end up with tests/expr_and_series/lit_test.py failing 3 tests due to:

pyspark.errors.exceptions.captured.AnalysisException: [UNSUPPORTED_EXPR_FOR_WINDOW] Expression "1" not supported within a window function.;

Member:

Isn't this still going to break for, say, `df.with_columns(nw.lit(2) + 1)`?

Member Author:

Yes correct

Member Author:

Added a test case and now it works!
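
To illustrate the special case (a minimal sketch, assuming a local PySpark >= 3.4 session): a plain literal is broadcast by PySpark on its own, while forcing one through a window function is rejected at analysis time with the error quoted above.

```python
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F
from pyspark.errors import AnalysisException  # available in PySpark >= 3.4

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1,), (2,), (3,)], ["a"])

# A bare literal is broadcast to every row without any window.
df.select(F.col("a"), F.lit(2).alias("b")).show()

# Wrapping the same literal in `over(...)` trips the analyzer
# (UNSUPPORTED_EXPR_FOR_WINDOW), hence the special-casing of `lit`.
try:
    df.select(F.lit(2).over(Window.partitionBy(F.lit(1))).alias("b")).collect()
except AnalysisException as exc:
    print(type(exc).__name__, exc)
```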

# without the need of adding `.over(Window.partitionBy(F.lit(1)))`
returns_scalar: list[bool] = []
for expr in exprs:
native_series_list = expr._call(df)
output_names = expr._evaluate_output_names(df)
@@ -121,18 +129,30 @@ def func(*exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr) -> dict[str, Column]:
msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
raise AssertionError(msg)
native_results.update(zip(output_names, native_series_list))
returns_scalar.extend(
[
expr._returns_scalar
and expr._function_name.split("->", maxsplit=1)[0] != "lit"
]
* len(output_names)
)
for col_alias, expr in named_exprs.items():
native_series_list = expr._call(df)
if len(native_series_list) != 1: # pragma: no cover
msg = "Named expressions must return a single column"
raise ValueError(msg)
native_results[col_alias] = native_series_list[0]
return native_results
returns_scalar.append(
expr._returns_scalar
and expr._function_name.split("->", maxsplit=1)[0] != "lit"
)

return native_results, returns_scalar

return func


def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any) -> Any:
def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any, *, returns_scalar: bool) -> Column:
from narwhals._spark_like.expr import SparkLikeExpr

if isinstance(obj, SparkLikeExpr):
@@ -141,8 +161,13 @@ def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any) -> Any:
msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) not supported in this context"
raise NotImplementedError(msg)
column_result = column_results[0]
if obj._returns_scalar:
# Return scalar, let PySpark do its broadcasting
if (
obj._returns_scalar
and obj._function_name.split("->", maxsplit=1)[0] != "lit"
and not returns_scalar
):
# Returns scalar, but overall expression doesn't.
# Let PySpark do its broadcasting
return column_result.over(Window.partitionBy(F.lit(1)))
return column_result
return F.lit(obj)