fix: support various reductions in pyspark (#1870)
FBruzzesi authored Jan 27, 2025
1 parent 267eb53 commit 702eea5
Showing 9 changed files with 142 additions and 174 deletions.
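
For context on the change summarized above: the PySpark backend can evaluate a reduction (e.g. `mean`, `max`) either as a frame-level aggregation or as a value broadcast to every row. The sketch below is not part of the diff; it is a minimal, standalone illustration of those two PySpark patterns, assuming a local SparkSession and toy data.

```python
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(1, 4.0), (2, 5.0), (3, 6.0)], ["a", "b"])

# Every expression is a reduction -> aggregate the whole frame into one row.
df.agg(F.mean("a").alias("a_mean"), F.max("b").alias("b_max")).show()

# A reduction mixed with a plain column -> broadcast the scalar over a
# degenerate window (a single partition) so it aligns with every row.
window = Window.partitionBy(F.lit(1))
df.select(F.col("a"), F.mean("b").over(window).alias("b_mean")).show()
```

This is exactly the dispatch that the new `select` and `with_columns` implementations below perform.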
53 changes: 40 additions & 13 deletions narwhals/_spark_like/dataframe.py
@@ -6,8 +6,12 @@
from typing import Literal
from typing import Sequence

from pyspark.sql import Window
from pyspark.sql import functions as F # noqa: N812

from narwhals._spark_like.utils import native_to_narwhals_dtype
from narwhals._spark_like.utils import parse_exprs_and_named_exprs
from narwhals.typing import CompliantLazyFrame
from narwhals.utils import Implementation
from narwhals.utils import check_column_exists
from narwhals.utils import parse_columns_to_drop
@@ -26,7 +30,6 @@
from narwhals._spark_like.namespace import SparkLikeNamespace
from narwhals.dtypes import DType
from narwhals.utils import Version
from narwhals.typing import CompliantLazyFrame


class SparkLikeLazyFrame(CompliantLazyFrame):
@@ -94,7 +97,9 @@ def select(
*exprs: SparkLikeExpr,
**named_exprs: SparkLikeExpr,
) -> Self:
new_columns = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs)
new_columns, returns_scalar = parse_exprs_and_named_exprs(self)(
*exprs, **named_exprs
)

if not new_columns:
# return empty dataframe, like Polars does
@@ -105,8 +110,38 @@

return self._from_native_frame(spark_df)

new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()]
return self._from_native_frame(self._native_frame.select(*new_columns_list))
if all(returns_scalar):
new_columns_list = [
col.alias(col_name) for col_name, col in new_columns.items()
]
return self._from_native_frame(self._native_frame.agg(*new_columns_list))
else:
new_columns_list = [
col.over(Window.partitionBy(F.lit(1))).alias(col_name)
if _returns_scalar
else col.alias(col_name)
for (col_name, col), _returns_scalar in zip(
new_columns.items(), returns_scalar
)
]
return self._from_native_frame(self._native_frame.select(*new_columns_list))

def with_columns(
self: Self,
*exprs: SparkLikeExpr,
**named_exprs: SparkLikeExpr,
) -> Self:
new_columns, returns_scalar = parse_exprs_and_named_exprs(self)(
*exprs, **named_exprs
)

new_columns_map = {
col_name: col.over(Window.partitionBy(F.lit(1))) if _returns_scalar else col
for (col_name, col), _returns_scalar in zip(
new_columns.items(), returns_scalar
)
}
return self._from_native_frame(self._native_frame.withColumns(new_columns_map))

def filter(self: Self, *predicates: SparkLikeExpr, **constraints: Any) -> Self:
plx = self.__narwhals_namespace__()
@@ -130,14 +165,6 @@ def schema(self: Self) -> dict[str, DType]:
def collect_schema(self: Self) -> dict[str, DType]:
return self.schema

def with_columns(
self: Self,
*exprs: SparkLikeExpr,
**named_exprs: SparkLikeExpr,
) -> Self:
new_columns_map = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs)
return self._from_native_frame(self._native_frame.withColumns(new_columns_map))

def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
columns_to_drop = parse_columns_to_drop(
compliant_frame=self, columns=columns, strict=strict
@@ -155,7 +182,7 @@ def group_by(self: Self, *keys: str, drop_null_keys: bool) -> SparkLikeLazyGroup
from narwhals._spark_like.group_by import SparkLikeLazyGroupBy

return SparkLikeLazyGroupBy(
df=self, keys=list(keys), drop_null_keys=drop_null_keys
compliant_frame=self, keys=list(keys), drop_null_keys=drop_null_keys
)

def sort(
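The `select`/`with_columns` changes above boil down to the following dispatch. This is a condensed, standalone restatement of the logic in the diff; the function and parameter names here are illustrative, not part of narwhals.

```python
from __future__ import annotations

from pyspark.sql import Column, DataFrame, Window
from pyspark.sql import functions as F


def select_with_broadcast(
    df: DataFrame, columns: dict[str, Column], returns_scalar: list[bool]
) -> DataFrame:
    # If every expression reduces to a scalar, aggregate the frame (one row).
    if all(returns_scalar):
        return df.agg(*(col.alias(name) for name, col in columns.items()))
    # Otherwise broadcast only the scalar results over a single-partition
    # window so they line up with the non-scalar columns.
    window = Window.partitionBy(F.lit(1))
    return df.select(
        *(
            col.over(window).alias(name) if is_scalar else col.alias(name)
            for (name, col), is_scalar in zip(columns.items(), returns_scalar)
        )
    )
```

`with_columns` applies the same broadcasting rule, but always via `withColumns`, since the frame keeps its original length.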
8 changes: 4 additions & 4 deletions narwhals/_spark_like/expr.py
@@ -122,7 +122,7 @@ def _from_call(
def func(df: SparkLikeLazyFrame) -> list[Column]:
native_series_list = self._call(df)
other_native_series = {
key: maybe_evaluate(df, value)
key: maybe_evaluate(df, value, returns_scalar=returns_scalar)
for key, value in expressifiable_args.items()
}
return [
@@ -136,7 +136,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
function_name=f"{self._function_name}->{expr_name}",
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
returns_scalar=self._returns_scalar or returns_scalar,
returns_scalar=returns_scalar,
backend_version=self._backend_version,
version=self._version,
)
@@ -349,7 +349,7 @@ def std(self: Self, ddof: int) -> Self:

func = partial(_std, ddof=ddof, np_version=parse_version(np.__version__))

return self._from_call(func, f"std[{ddof}]", returns_scalar=True)
return self._from_call(func, "std", returns_scalar=True)

def var(self: Self, ddof: int) -> Self:
from functools import partial
@@ -360,7 +360,7 @@ def var(self: Self, ddof: int) -> Self:

func = partial(_var, ddof=ddof, np_version=parse_version(np.__version__))

return self._from_call(func, f"var[{ddof}]", returns_scalar=True)
return self._from_call(func, "var", returns_scalar=True)

def clip(
self: Self,
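The `std`/`var` changes above pass `ddof` into the `_std`/`_var` helpers directly instead of encoding it in the function name (previously `std[{ddof}]`), since the rewritten group-by path no longer parses the name. For reference, one way a ddof-aware standard deviation can be expressed in PySpark is sketched below; this is an assumption for illustration, not necessarily how narwhals' `_std` helper (which also takes an `np_version` argument) is written.

```python
from pyspark.sql import Column
from pyspark.sql import functions as F


def std_with_ddof(column: Column, ddof: int) -> Column:
    # var_samp divides by (n - 1); rescale so the divisor becomes (n - ddof).
    n = F.count(column)
    return F.sqrt(F.var_samp(column) * (n - F.lit(1)) / (n - F.lit(ddof)))
```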
160 changes: 34 additions & 126 deletions narwhals/_spark_like/group_by.py
@@ -1,150 +1,58 @@
from __future__ import annotations

import re
from functools import partial
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Sequence

from pyspark.sql import functions as F # noqa: N812

from narwhals._expression_parsing import is_simple_aggregation
from narwhals._spark_like.utils import _std
from narwhals._spark_like.utils import _var
from narwhals.utils import parse_version

if TYPE_CHECKING:
from pyspark.sql import Column
from pyspark.sql import GroupedData
from typing_extensions import Self

from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.typing import SparkLikeExpr
from narwhals.typing import CompliantExpr
from narwhals._spark_like.expr import SparkLikeExpr


class SparkLikeLazyGroupBy:
def __init__(
self: Self,
df: SparkLikeLazyFrame,
compliant_frame: SparkLikeLazyFrame,
keys: list[str],
drop_null_keys: bool, # noqa: FBT001
) -> None:
self._df = df
self._keys = keys
if drop_null_keys:
self._grouped = self._df._native_frame.dropna(subset=self._keys).groupBy(
*self._keys
)
self._compliant_frame = compliant_frame.drop_nulls(subset=None)
else:
self._grouped = self._df._native_frame.groupBy(*self._keys)

def agg(
self: Self,
*exprs: SparkLikeExpr,
) -> SparkLikeLazyFrame:
return agg_pyspark(
self._df,
self._grouped,
exprs,
self._keys,
self._from_native_frame,
)

def _from_native_frame(self: Self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame:
from narwhals._spark_like.dataframe import SparkLikeLazyFrame

return SparkLikeLazyFrame(
df, backend_version=self._df._backend_version, version=self._df._version
)


def get_spark_function(function_name: str) -> Column:
if (stem := function_name.split("[", maxsplit=1)[0]) in ("std", "var"):
import numpy as np # ignore-banned-import

return partial(
_std if stem == "std" else _var,
ddof=int(function_name.split("[", maxsplit=1)[1].rstrip("]")),
np_version=parse_version(np.__version__),
)

elif function_name == "len":
# Use count(*) to count all rows including nulls
def _count(*_args: Any, **_kwargs: Any) -> Column:
return F.count("*")

return _count

elif function_name == "n_unique":
from pyspark.sql.types import IntegerType

def _n_unique(_input: Column) -> Column:
return F.count_distinct(_input) + F.max(F.isnull(_input).cast(IntegerType()))

return _n_unique

else:
return getattr(F, function_name)


def agg_pyspark(
df: SparkLikeLazyFrame,
grouped: GroupedData,
exprs: Sequence[CompliantExpr[Column]],
keys: list[str],
from_dataframe: Callable[[Any], SparkLikeLazyFrame],
) -> SparkLikeLazyFrame:
if not exprs:
# No aggregation provided
return from_dataframe(df._native_frame.select(*keys).dropDuplicates(subset=keys))
self._compliant_frame = compliant_frame
self._keys = keys

for expr in exprs:
if not is_simple_aggregation(expr): # pragma: no cover
msg = (
"Non-trivial complex aggregation found.\n\n"
"Hint: you were probably trying to apply a non-elementary aggregation with a "
"dask dataframe.\n"
"Please rewrite your query such that group-by aggregations "
"are elementary. For example, instead of:\n\n"
" df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
"use:\n\n"
" df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
def agg(self: Self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame:
agg_columns = []
df = self._compliant_frame
for expr in exprs:
output_names = expr._evaluate_output_names(df)
aliases = (
output_names
if expr._alias_output_names is None
else expr._alias_output_names(output_names)
)
raise ValueError(msg)

simple_aggregations: dict[str, Column] = {}
for expr in exprs:
output_names = expr._evaluate_output_names(df)
aliases = (
output_names
if expr._alias_output_names is None
else expr._alias_output_names(output_names)
)
if len(output_names) > 1:
# For multi-output aggregations, e.g. `df.group_by('a').agg(nw.all().mean())`, we skip
# the keys, else they would appear duplicated in the output.
output_names, aliases = zip(
*[(x, alias) for x, alias in zip(output_names, aliases) if x not in keys]
native_expressions = expr(df)
exclude = (
self._keys
if expr._function_name.split("->", maxsplit=1)[0] in ("all", "selector")
else []
)
agg_columns.extend(
[
native_expression.alias(alias)
for native_expression, output_name, alias in zip(
native_expressions, output_names, aliases
)
if output_name not in exclude
]
)
if expr._depth == 0: # pragma: no cover
# e.g. agg(nw.len()) # noqa: ERA001
agg_func = get_spark_function(expr._function_name)
simple_aggregations.update({alias: agg_func(keys[0]) for alias in aliases})
continue

# e.g. agg(nw.mean('a')) # noqa: ERA001
function_name = re.sub(r"(\w+->)", "", expr._function_name)
agg_func = get_spark_function(function_name)
if not agg_columns:
return self._compliant_frame._from_native_frame(
self._compliant_frame._native_frame.select(*self._keys).dropDuplicates()
)

simple_aggregations.update(
{
alias: agg_func(output_name)
for alias, output_name in zip(aliases, output_names)
}
return self._compliant_frame._from_native_frame(
self._compliant_frame._native_frame.groupBy(self._keys).agg(*agg_columns)
)

agg_columns = [col_.alias(name) for name, col_ in simple_aggregations.items()]
result_simple = grouped.agg(*agg_columns)
return from_dataframe(result_simple)
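
The rewritten `agg` above evaluates each expression against the compliant frame, aliases the resulting native columns, skips the group keys for multi-output expressions such as `nw.all()`, and falls back to the distinct key combinations when no aggregations are given. A plain-PySpark illustration of the resulting query shape (toy data and a local SparkSession assumed; not part of the diff):

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame(
    [("x", 1, 10.0), ("x", 2, 20.0), ("y", 3, 30.0)], ["g", "a", "b"]
)

keys = ["g"]
# Rough equivalent of `group_by("g").agg(nw.all().mean())`: the key column is
# excluded from the aggregations, every remaining column is averaged per group.
agg_columns = [F.mean(c).alias(c) for c in df.columns if c not in keys]
df.groupBy(*keys).agg(*agg_columns).show()

# With no aggregations, the new code returns the distinct key combinations.
df.select(*keys).dropDuplicates().show()
```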
39 changes: 32 additions & 7 deletions narwhals/_spark_like/utils.py
@@ -5,9 +5,10 @@
from typing import Any
from typing import Callable

from pyspark.sql import Column
from pyspark.sql import Window
from pyspark.sql import functions as F # noqa: N812
from pyspark.sql import types as pyspark_types
from pyspark.sql.window import Window

from narwhals.exceptions import UnsupportedDTypeError
from narwhals.utils import import_dtypes_module
@@ -109,9 +110,16 @@ def narwhals_to_native_dtype(

def parse_exprs_and_named_exprs(
df: SparkLikeLazyFrame,
) -> Callable[..., dict[str, Column]]:
def func(*exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr) -> dict[str, Column]:
) -> Callable[..., tuple[dict[str, Column], list[bool]]]:
def func(
*exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr
) -> tuple[dict[str, Column], list[bool]]:
native_results: dict[str, list[Column]] = {}

# `returns_scalar` keeps track of whether an expression returns a scalar and is not `lit`.
# Notice that `lit` is quite a special case, since PySpark broadcasts it on its own,
# without the need for `.over(Window.partitionBy(F.lit(1)))`.
returns_scalar: list[bool] = []
for expr in exprs:
native_series_list = expr._call(df)
output_names = expr._evaluate_output_names(df)
@@ -121,18 +129,30 @@ def func(*exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr) -> dict[str, Colum
msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
raise AssertionError(msg)
native_results.update(zip(output_names, native_series_list))
returns_scalar.extend(
[
expr._returns_scalar
and expr._function_name.split("->", maxsplit=1)[0] != "lit"
]
* len(output_names)
)
for col_alias, expr in named_exprs.items():
native_series_list = expr._call(df)
if len(native_series_list) != 1: # pragma: no cover
msg = "Named expressions must return a single column"
raise ValueError(msg)
native_results[col_alias] = native_series_list[0]
return native_results
returns_scalar.append(
expr._returns_scalar
and expr._function_name.split("->", maxsplit=1)[0] != "lit"
)

return native_results, returns_scalar

return func


def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any) -> Any:
def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any, *, returns_scalar: bool) -> Column:
from narwhals._spark_like.expr import SparkLikeExpr

if isinstance(obj, SparkLikeExpr):
Expand All @@ -141,8 +161,13 @@ def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any) -> Any:
msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) not supported in this context"
raise NotImplementedError(msg)
column_result = column_results[0]
if obj._returns_scalar:
# Return scalar, let PySpark do its broadcasting
if (
obj._returns_scalar
and obj._function_name.split("->", maxsplit=1)[0] != "lit"
and not returns_scalar
):
# Returns scalar, but overall expression doesn't.
# Let PySpark do its broadcasting
return column_result.over(Window.partitionBy(F.lit(1)))
return column_result
return F.lit(obj)
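The updated `maybe_evaluate` above broadcasts a sub-expression only when it reduces to a scalar, is not a `lit` (PySpark broadcasts literals on its own), and the enclosing expression is not itself a reduction. A condensed restatement of that decision, with illustrative names:

```python
from pyspark.sql import Column, Window
from pyspark.sql import functions as F


def broadcast_if_needed(
    column: Column, *, is_scalar: bool, is_lit: bool, outer_returns_scalar: bool
) -> Column:
    # `column` is assumed to be an aggregate expression (e.g. F.mean(...)) when
    # `is_scalar` is True, so attaching `.over(...)` is valid.
    if is_scalar and not is_lit and not outer_returns_scalar:
        return column.over(Window.partitionBy(F.lit(1)))
    return column
```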