fix: support various reductions in pyspark #1870

Merged: 9 commits, merged Jan 27, 2025
Changes from 5 commits
53 changes: 40 additions & 13 deletions narwhals/_spark_like/dataframe.py
@@ -6,8 +6,12 @@
from typing import Literal
from typing import Sequence

from pyspark.sql import Window
from pyspark.sql import functions as F # noqa: N812

from narwhals._spark_like.utils import native_to_narwhals_dtype
from narwhals._spark_like.utils import parse_exprs_and_named_exprs
from narwhals.typing import CompliantLazyFrame
from narwhals.utils import Implementation
from narwhals.utils import check_column_exists
from narwhals.utils import parse_columns_to_drop
@@ -26,7 +30,6 @@
from narwhals._spark_like.namespace import SparkLikeNamespace
from narwhals.dtypes import DType
from narwhals.utils import Version
from narwhals.typing import CompliantLazyFrame


class SparkLikeLazyFrame(CompliantLazyFrame):
@@ -94,7 +97,9 @@ def select(
*exprs: SparkLikeExpr,
**named_exprs: SparkLikeExpr,
) -> Self:
new_columns = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs)
new_columns, returns_scalar = parse_exprs_and_named_exprs(self)(
*exprs, **named_exprs
)

if not new_columns:
# return empty dataframe, like Polars does
@@ -105,8 +110,38 @@

return self._from_native_frame(spark_df)

new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()]
return self._from_native_frame(self._native_frame.select(*new_columns_list))
if all(returns_scalar):
new_columns_list = [
col.alias(col_name) for col_name, col in new_columns.items()
]
return self._from_native_frame(self._native_frame.agg(*new_columns_list))
else:
new_columns_list = [
col.over(Window.partitionBy(F.lit(1))).alias(col_name)
if _returns_scalar
else col.alias(col_name)
for (col_name, col), _returns_scalar in zip(
new_columns.items(), returns_scalar
)
]
Comment on lines +119 to +126

Member:
I think this may be too late to set over.

For example, in nw.col('a') - nw.col('a').mean(), I think it's in the binary operation __sub__ that nw.col('a').mean() needs to become nw.col('a').mean().over(lit(1)).

As in, we want to translate nw.col('a') - nw.col('a').mean() to F.col('a') - F.col('a').mean().over(F.lit(1)). The code, however, as far as I can tell, translates it to (F.col('a') - F.col('a').mean()).over(F.lit(1)).

Member Author (@FBruzzesi, Jan 26, 2025):
That happens in maybe_evaluate, and you can see that the reduction_test cases are now passing.
I know it's not ideal to have the logic for setting over in two places, but I couldn't figure out a single place in which to handle this, as maybe_evaluate is only called to evaluate other arguments.

Member:
Ah, I see! Yes, this might be fine then, thanks!
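
For readers following the thread, here is a standalone PySpark sketch of the translation being discussed: only the reduction gets the window over a constant partition, so the scalar is broadcast row-wise before the subtraction. The DataFrame and column name below are made up for illustration and are not taken from this PR.

```python
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F  # noqa: N812

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1.0,), (2.0,), (3.0,)], ["a"])  # hypothetical data

# Broadcast only the reduction, mirroring nw.col('a') - nw.col('a').mean():
# the mean is attached to a window over a constant partition, then subtracted.
window = Window.partitionBy(F.lit(1))
result = df.select((F.col("a") - F.mean("a").over(window)).alias("a_centered"))
result.show()  # rows: -1.0, 0.0, 1.0
```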

return self._from_native_frame(self._native_frame.select(*new_columns_list))

def with_columns(
Member Author:
I just moved this closer to the select method - it was easier to debug them while on the same screen πŸ™ˆ

self: Self,
*exprs: SparkLikeExpr,
**named_exprs: SparkLikeExpr,
) -> Self:
new_columns, returns_scalar = parse_exprs_and_named_exprs(self)(
*exprs, **named_exprs
)

new_columns_map = {
col_name: col.over(Window.partitionBy(F.lit(1))) if _returns_scalar else col
for (col_name, col), _returns_scalar in zip(
new_columns.items(), returns_scalar
)
}
return self._from_native_frame(self._native_frame.withColumns(new_columns_map))

def filter(self: Self, *predicates: SparkLikeExpr, **constraints: Any) -> Self:
plx = self.__narwhals_namespace__()
@@ -130,14 +165,6 @@ def schema(self: Self) -> dict[str, DType]:
def collect_schema(self: Self) -> dict[str, DType]:
return self.schema

def with_columns(
self: Self,
*exprs: SparkLikeExpr,
**named_exprs: SparkLikeExpr,
) -> Self:
new_columns_map = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs)
return self._from_native_frame(self._native_frame.withColumns(new_columns_map))

def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
columns_to_drop = parse_columns_to_drop(
compliant_frame=self, columns=columns, strict=strict
@@ -155,7 +182,7 @@ def group_by(self: Self, *keys: str, drop_null_keys: bool) -> SparkLikeLazyGroup
from narwhals._spark_like.group_by import SparkLikeLazyGroupBy

return SparkLikeLazyGroupBy(
df=self, keys=list(keys), drop_null_keys=drop_null_keys
compliant_frame=self, keys=list(keys), drop_null_keys=drop_null_keys
)

def sort(
28 changes: 16 additions & 12 deletions narwhals/_spark_like/expr.py
@@ -126,7 +126,7 @@ def _from_call(
def func(df: SparkLikeLazyFrame) -> list[Column]:
native_series_list = self._call(df)
other_native_series = {
key: maybe_evaluate(df, value)
key: maybe_evaluate(df, value, returns_scalar=returns_scalar)
for key, value in expressifiable_args.items()
}
return [
@@ -140,7 +140,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
function_name=f"{self._function_name}->{expr_name}",
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
returns_scalar=self._returns_scalar or returns_scalar,
returns_scalar=returns_scalar,
Member Author:

This gave me so much headache before spotting it.

backend_version=self._backend_version,
version=self._version,
kwargs=expressifiable_args,
@@ -349,26 +349,30 @@ def sum(self: Self) -> Self:
return self._from_call(F.sum, "sum", returns_scalar=True)

def std(self: Self, ddof: int) -> Self:
from functools import partial

import numpy as np # ignore-banned-import

from narwhals._spark_like.utils import _std

func = partial(_std, ddof=ddof, np_version=parse_version(np.__version__))

return self._from_call(func, "std", returns_scalar=True, ddof=ddof)
return self._from_call(
_std,
"std",
returns_scalar=True,
ddof=ddof,
np_version=parse_version(np.__version__),
)

def var(self: Self, ddof: int) -> Self:
from functools import partial

import numpy as np # ignore-banned-import

from narwhals._spark_like.utils import _var

func = partial(_var, ddof=ddof, np_version=parse_version(np.__version__))

return self._from_call(func, "var", returns_scalar=True, ddof=ddof)
return self._from_call(
_var,
"var",
returns_scalar=True,
ddof=ddof,
np_version=parse_version(np.__version__),
)
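
As an aside, not part of this diff: PySpark's built-in stddev_samp and var_samp fix ddof=1, so one way a helper like _std can support an arbitrary ddof is to rescale the sample statistic. The sketch below only illustrates that idea; it is not necessarily what narwhals' _std and _var actually do (they also take a np_version argument that this sketch ignores).

```python
from pyspark.sql import Column
from pyspark.sql import functions as F  # noqa: N812


def std_with_ddof(col: Column, ddof: int) -> Column:
    # s_ddof = s_1 * sqrt((n - 1) / (n - ddof)): both share the same sum of
    # squared deviations and differ only in the denominator.
    n = F.count(col)
    return F.stddev_samp(col) * F.sqrt((n - F.lit(1)) / (n - F.lit(ddof)))


def var_with_ddof(col: Column, ddof: int) -> Column:
    n = F.count(col)
    return F.var_samp(col) * (n - F.lit(1)) / (n - F.lit(ddof))


# Usage (hypothetical): df.agg(std_with_ddof(F.col("a"), ddof=0).alias("std_pop"))
```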

def clip(
self: Self,
160 changes: 34 additions & 126 deletions narwhals/_spark_like/group_by.py
@@ -1,150 +1,58 @@
from __future__ import annotations

import re
from functools import partial
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Sequence

from pyspark.sql import functions as F # noqa: N812

from narwhals._expression_parsing import is_simple_aggregation
from narwhals._spark_like.utils import _std
from narwhals._spark_like.utils import _var
from narwhals.utils import parse_version

if TYPE_CHECKING:
from pyspark.sql import Column
from pyspark.sql import GroupedData
from typing_extensions import Self

from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.typing import SparkLikeExpr
from narwhals.typing import CompliantExpr
from narwhals._spark_like.expr import SparkLikeExpr


class SparkLikeLazyGroupBy:
def __init__(
self: Self,
df: SparkLikeLazyFrame,
compliant_frame: SparkLikeLazyFrame,
keys: list[str],
drop_null_keys: bool, # noqa: FBT001
) -> None:
self._df = df
self._keys = keys
if drop_null_keys:
self._grouped = self._df._native_frame.dropna(subset=self._keys).groupBy(
*self._keys
)
self._compliant_frame = compliant_frame.drop_nulls(subset=None)
else:
self._grouped = self._df._native_frame.groupBy(*self._keys)

def agg(
self: Self,
*exprs: SparkLikeExpr,
) -> SparkLikeLazyFrame:
return agg_pyspark(
self._df,
self._grouped,
exprs,
self._keys,
self._from_native_frame,
)

def _from_native_frame(self: Self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame:
from narwhals._spark_like.dataframe import SparkLikeLazyFrame

return SparkLikeLazyFrame(
df, backend_version=self._df._backend_version, version=self._df._version
)


def get_spark_function(function_name: str, **kwargs: Any) -> Column:
if function_name in {"std", "var"}:
import numpy as np # ignore-banned-import

return partial(
_std if function_name == "std" else _var,
ddof=kwargs["ddof"],
np_version=parse_version(np.__version__),
)

elif function_name == "len":
# Use count(*) to count all rows including nulls
def _count(*_args: Any, **_kwargs: Any) -> Column:
return F.count("*")

return _count

elif function_name == "n_unique":
from pyspark.sql.types import IntegerType

def _n_unique(_input: Column) -> Column:
return F.count_distinct(_input) + F.max(F.isnull(_input).cast(IntegerType()))

return _n_unique

else:
return getattr(F, function_name)


def agg_pyspark(
df: SparkLikeLazyFrame,
grouped: GroupedData,
exprs: Sequence[CompliantExpr[Column]],
keys: list[str],
from_dataframe: Callable[[Any], SparkLikeLazyFrame],
) -> SparkLikeLazyFrame:
if not exprs:
# No aggregation provided
return from_dataframe(df._native_frame.select(*keys).dropDuplicates(subset=keys))
self._compliant_frame = compliant_frame
self._keys = keys

for expr in exprs:
if not is_simple_aggregation(expr): # pragma: no cover
msg = (
"Non-trivial complex aggregation found.\n\n"
"Hint: you were probably trying to apply a non-elementary aggregation with a "
"dask dataframe.\n"
"Please rewrite your query such that group-by aggregations "
"are elementary. For example, instead of:\n\n"
" df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
"use:\n\n"
" df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
def agg(self: Self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame:
agg_columns = []
df = self._compliant_frame
for expr in exprs:
output_names = expr._evaluate_output_names(df)
aliases = (
output_names
if expr._alias_output_names is None
else expr._alias_output_names(output_names)
)
raise ValueError(msg)

simple_aggregations: dict[str, Column] = {}
for expr in exprs:
output_names = expr._evaluate_output_names(df)
aliases = (
output_names
if expr._alias_output_names is None
else expr._alias_output_names(output_names)
)
if len(output_names) > 1:
# For multi-output aggregations, e.g. `df.group_by('a').agg(nw.all().mean())`, we skip
# the keys, else they would appear duplicated in the output.
output_names, aliases = zip(
*[(x, alias) for x, alias in zip(output_names, aliases) if x not in keys]
native_expressions = expr(df)
exclude = (
self._keys
if expr._function_name.split("->", maxsplit=1)[0] in ("all", "selector")
else []
)
agg_columns.extend(
[
native_expression.alias(alias)
for native_expression, output_name, alias in zip(
native_expressions, output_names, aliases
)
if output_name not in exclude
]
)
if expr._depth == 0: # pragma: no cover
# e.g. agg(nw.len()) # noqa: ERA001
agg_func = get_spark_function(expr._function_name, **expr._kwargs)
simple_aggregations.update({alias: agg_func(keys[0]) for alias in aliases})
continue

# e.g. agg(nw.mean('a')) # noqa: ERA001
function_name = re.sub(r"(\w+->)", "", expr._function_name)
agg_func = get_spark_function(function_name, **expr._kwargs)
if not agg_columns:
return self._compliant_frame._from_native_frame(
self._compliant_frame._native_frame.select(*self._keys).dropDuplicates()
)

simple_aggregations.update(
{
alias: agg_func(output_name)
for alias, output_name in zip(aliases, output_names)
}
return self._compliant_frame._from_native_frame(
self._compliant_frame._native_frame.groupBy(self._keys).agg(*agg_columns)
)

agg_columns = [col_.alias(name) for name, col_ in simple_aggregations.items()]
result_simple = grouped.agg(*agg_columns)
return from_dataframe(result_simple)
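
Finally, a usage-level sketch of what these changes are meant to unlock at the narwhals API level with the PySpark backend. The data, column names, and exact calls are illustrative assumptions based on narwhals' public interface at the time, not code taken from this PR's tests.

```python
import narwhals as nw
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
native = spark.createDataFrame([(1, 4.0), (1, 6.0), (2, 3.0)], ["a", "b"])
df = nw.from_native(native)  # lazy frame backed by PySpark

# Reductions in select: every expression is a scalar, so we get a one-row frame.
summary = df.select(nw.col("b").mean(), nw.col("b").std(ddof=1).alias("b_std"))

# Reductions mixed with columns: the scalar part is broadcast across rows.
centered = df.with_columns(b_demeaned=nw.col("b") - nw.col("b").mean())

# Group-by reductions.
per_group = df.group_by("a").agg(nw.col("b").mean(), nw.len())

summary.to_native().show()
```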