fix: support various reductions in pyspark #1870
Changes from 5 commits
narwhals/_spark_like/dataframe.py
@@ -6,8 +6,12 @@
 from typing import Literal
 from typing import Sequence

+from pyspark.sql import Window
+from pyspark.sql import functions as F  # noqa: N812
+
 from narwhals._spark_like.utils import native_to_narwhals_dtype
 from narwhals._spark_like.utils import parse_exprs_and_named_exprs
+from narwhals.typing import CompliantLazyFrame
 from narwhals.utils import Implementation
 from narwhals.utils import check_column_exists
 from narwhals.utils import parse_columns_to_drop
@@ -26,7 +30,6 @@ | |
from narwhals._spark_like.namespace import SparkLikeNamespace | ||
from narwhals.dtypes import DType | ||
from narwhals.utils import Version | ||
from narwhals.typing import CompliantLazyFrame | ||
|
||
|
||
class SparkLikeLazyFrame(CompliantLazyFrame): | ||
|
@@ -94,7 +97,9 @@ def select(
         *exprs: SparkLikeExpr,
         **named_exprs: SparkLikeExpr,
     ) -> Self:
-        new_columns = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs)
+        new_columns, returns_scalar = parse_exprs_and_named_exprs(self)(
+            *exprs, **named_exprs
+        )

         if not new_columns:
             # return empty dataframe, like Polars does
@@ -105,8 +110,38 @@ def select(

             return self._from_native_frame(spark_df)

-        new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()]
-        return self._from_native_frame(self._native_frame.select(*new_columns_list))
+        if all(returns_scalar):
+            new_columns_list = [
+                col.alias(col_name) for col_name, col in new_columns.items()
+            ]
+            return self._from_native_frame(self._native_frame.agg(*new_columns_list))
+        else:
+            new_columns_list = [
+                col.over(Window.partitionBy(F.lit(1))).alias(col_name)
+                if _returns_scalar
+                else col.alias(col_name)
+                for (col_name, col), _returns_scalar in zip(
+                    new_columns.items(), returns_scalar
+                )
+            ]
+            return self._from_native_frame(self._native_frame.select(*new_columns_list))
+
+    def with_columns(
+        self: Self,
+        *exprs: SparkLikeExpr,
+        **named_exprs: SparkLikeExpr,
+    ) -> Self:
+        new_columns, returns_scalar = parse_exprs_and_named_exprs(self)(
+            *exprs, **named_exprs
+        )
+
+        new_columns_map = {
+            col_name: col.over(Window.partitionBy(F.lit(1))) if _returns_scalar else col
+            for (col_name, col), _returns_scalar in zip(
+                new_columns.items(), returns_scalar
+            )
+        }
+        return self._from_native_frame(self._native_frame.withColumns(new_columns_map))

     def filter(self: Self, *predicates: SparkLikeExpr, **constraints: Any) -> Self:
         plx = self.__narwhals_namespace__()

Review comment (on `def with_columns`): I just moved this closer to `select`.
@@ -130,14 +165,6 @@ def schema(self: Self) -> dict[str, DType]:
     def collect_schema(self: Self) -> dict[str, DType]:
         return self.schema

-    def with_columns(
-        self: Self,
-        *exprs: SparkLikeExpr,
-        **named_exprs: SparkLikeExpr,
-    ) -> Self:
-        new_columns_map = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs)
-        return self._from_native_frame(self._native_frame.withColumns(new_columns_map))
-
     def drop(self: Self, columns: list[str], strict: bool) -> Self:  # noqa: FBT001
         columns_to_drop = parse_columns_to_drop(
             compliant_frame=self, columns=columns, strict=strict
@@ -155,7 +182,7 @@ def group_by(self: Self, *keys: str, drop_null_keys: bool) -> SparkLikeLazyGroupBy:
     from narwhals._spark_like.group_by import SparkLikeLazyGroupBy

     return SparkLikeLazyGroupBy(
-        df=self, keys=list(keys), drop_null_keys=drop_null_keys
+        compliant_frame=self, keys=list(keys), drop_null_keys=drop_null_keys
     )

 def sort(
narwhals/_spark_like/expr.py
@@ -126,7 +126,7 @@ def _from_call(
         def func(df: SparkLikeLazyFrame) -> list[Column]:
             native_series_list = self._call(df)
             other_native_series = {
-                key: maybe_evaluate(df, value)
+                key: maybe_evaluate(df, value, returns_scalar=returns_scalar)
                 for key, value in expressifiable_args.items()
             }
             return [
@@ -140,7 +140,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
             function_name=f"{self._function_name}->{expr_name}",
             evaluate_output_names=self._evaluate_output_names,
             alias_output_names=self._alias_output_names,
-            returns_scalar=self._returns_scalar or returns_scalar,
+            returns_scalar=returns_scalar,
             backend_version=self._backend_version,
             version=self._version,
             kwargs=expressifiable_args,

Review comment (on the `returns_scalar=` line): This gave me so much headache before spotting it.
@@ -349,26 +349,30 @@ def sum(self: Self) -> Self:
         return self._from_call(F.sum, "sum", returns_scalar=True)

     def std(self: Self, ddof: int) -> Self:
-        from functools import partial
-
         import numpy as np  # ignore-banned-import

         from narwhals._spark_like.utils import _std

-        func = partial(_std, ddof=ddof, np_version=parse_version(np.__version__))
-
-        return self._from_call(func, "std", returns_scalar=True, ddof=ddof)
+        return self._from_call(
+            _std,
+            "std",
+            returns_scalar=True,
+            ddof=ddof,
+            np_version=parse_version(np.__version__),
+        )

     def var(self: Self, ddof: int) -> Self:
-        from functools import partial
-
         import numpy as np  # ignore-banned-import

         from narwhals._spark_like.utils import _var

-        func = partial(_var, ddof=ddof, np_version=parse_version(np.__version__))
-
-        return self._from_call(func, "var", returns_scalar=True, ddof=ddof)
+        return self._from_call(
+            _var,
+            "var",
+            returns_scalar=True,
+            ddof=ddof,
+            np_version=parse_version(np.__version__),
+        )

     def clip(
         self: Self,
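For reference, a minimal sketch of how an arbitrary-`ddof` standard deviation can be written in plain PySpark. This is an assumption about what a helper like `_std` could do, not its actual source (the `np_version` parameter suggests the real helper in `narwhals._spark_like.utils` has a NumPy-version-dependent code path):

```python
from pyspark.sql import Column
from pyspark.sql import functions as F  # noqa: N812


def sample_std(column: Column, ddof: int) -> Column:
    # stddev_samp is the ddof=1 estimator; rescale the variance by
    # (n - 1) / (n - ddof) to get an arbitrary ddof, i.e.
    #   std_ddof = stddev_samp * sqrt((n - 1) / (n - ddof))
    n = F.count(column)
    return F.stddev_samp(column) * F.sqrt((n - F.lit(1)) / (n - F.lit(ddof)))
```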
narwhals/_spark_like/group_by.py
@@ -1,150 +1,58 @@
 from __future__ import annotations

-import re
-from functools import partial
 from typing import TYPE_CHECKING
-from typing import Any
-from typing import Callable
-from typing import Sequence
-
-from pyspark.sql import functions as F  # noqa: N812
-
-from narwhals._expression_parsing import is_simple_aggregation
-from narwhals._spark_like.utils import _std
-from narwhals._spark_like.utils import _var
-from narwhals.utils import parse_version

 if TYPE_CHECKING:
-    from pyspark.sql import Column
-    from pyspark.sql import GroupedData
     from typing_extensions import Self

     from narwhals._spark_like.dataframe import SparkLikeLazyFrame
-    from narwhals._spark_like.typing import SparkLikeExpr
-    from narwhals.typing import CompliantExpr
+    from narwhals._spark_like.expr import SparkLikeExpr


 class SparkLikeLazyGroupBy:
     def __init__(
         self: Self,
-        df: SparkLikeLazyFrame,
+        compliant_frame: SparkLikeLazyFrame,
         keys: list[str],
         drop_null_keys: bool,  # noqa: FBT001
     ) -> None:
-        self._df = df
-        self._keys = keys
         if drop_null_keys:
-            self._grouped = self._df._native_frame.dropna(subset=self._keys).groupBy(
-                *self._keys
-            )
+            self._compliant_frame = compliant_frame.drop_nulls(subset=None)
         else:
-            self._grouped = self._df._native_frame.groupBy(*self._keys)
+            self._compliant_frame = compliant_frame
+        self._keys = keys

-    def agg(
-        self: Self,
-        *exprs: SparkLikeExpr,
-    ) -> SparkLikeLazyFrame:
-        return agg_pyspark(
-            self._df,
-            self._grouped,
-            exprs,
-            self._keys,
-            self._from_native_frame,
-        )
-
-    def _from_native_frame(self: Self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame:
-        from narwhals._spark_like.dataframe import SparkLikeLazyFrame
-
-        return SparkLikeLazyFrame(
-            df, backend_version=self._df._backend_version, version=self._df._version
-        )
-
-
-def get_spark_function(function_name: str, **kwargs: Any) -> Column:
-    if function_name in {"std", "var"}:
-        import numpy as np  # ignore-banned-import
-
-        return partial(
-            _std if function_name == "std" else _var,
-            ddof=kwargs["ddof"],
-            np_version=parse_version(np.__version__),
-        )
-
-    elif function_name == "len":
-        # Use count(*) to count all rows including nulls
-        def _count(*_args: Any, **_kwargs: Any) -> Column:
-            return F.count("*")
-
-        return _count
-
-    elif function_name == "n_unique":
-        from pyspark.sql.types import IntegerType
-
-        def _n_unique(_input: Column) -> Column:
-            return F.count_distinct(_input) + F.max(F.isnull(_input).cast(IntegerType()))
-
-        return _n_unique
-
-    else:
-        return getattr(F, function_name)
-
-
-def agg_pyspark(
-    df: SparkLikeLazyFrame,
-    grouped: GroupedData,
-    exprs: Sequence[CompliantExpr[Column]],
-    keys: list[str],
-    from_dataframe: Callable[[Any], SparkLikeLazyFrame],
-) -> SparkLikeLazyFrame:
-    if not exprs:
-        # No aggregation provided
-        return from_dataframe(df._native_frame.select(*keys).dropDuplicates(subset=keys))
-
-    for expr in exprs:
-        if not is_simple_aggregation(expr):  # pragma: no cover
-            msg = (
-                "Non-trivial complex aggregation found.\n\n"
-                "Hint: you were probably trying to apply a non-elementary aggregation with a "
-                "dask dataframe.\n"
-                "Please rewrite your query such that group-by aggregations "
-                "are elementary. For example, instead of:\n\n"
-                "    df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
-                "use:\n\n"
-                "    df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
-            )
-            raise ValueError(msg)
-
-    simple_aggregations: dict[str, Column] = {}
-    for expr in exprs:
-        output_names = expr._evaluate_output_names(df)
-        aliases = (
-            output_names
-            if expr._alias_output_names is None
-            else expr._alias_output_names(output_names)
-        )
-        if len(output_names) > 1:
-            # For multi-output aggregations, e.g. `df.group_by('a').agg(nw.all().mean())`,
-            # we skip the keys, else they would appear duplicated in the output.
-            output_names, aliases = zip(
-                *[(x, alias) for x, alias in zip(output_names, aliases) if x not in keys]
-            )
-        if expr._depth == 0:  # pragma: no cover
-            # e.g. agg(nw.len())  # noqa: ERA001
-            agg_func = get_spark_function(expr._function_name, **expr._kwargs)
-            simple_aggregations.update({alias: agg_func(keys[0]) for alias in aliases})
-            continue
-
-        # e.g. agg(nw.mean('a'))  # noqa: ERA001
-        function_name = re.sub(r"(\w+->)", "", expr._function_name)
-        agg_func = get_spark_function(function_name, **expr._kwargs)
-
-        simple_aggregations.update(
-            {
-                alias: agg_func(output_name)
-                for alias, output_name in zip(aliases, output_names)
-            }
-        )
-
-    agg_columns = [col_.alias(name) for name, col_ in simple_aggregations.items()]
-    result_simple = grouped.agg(*agg_columns)
-    return from_dataframe(result_simple)
+    def agg(self: Self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame:
+        agg_columns = []
+        df = self._compliant_frame
+        for expr in exprs:
+            output_names = expr._evaluate_output_names(df)
+            aliases = (
+                output_names
+                if expr._alias_output_names is None
+                else expr._alias_output_names(output_names)
+            )
+            native_expressions = expr(df)
+            exclude = (
+                self._keys
+                if expr._function_name.split("->", maxsplit=1)[0] in ("all", "selector")
+                else []
+            )
+            agg_columns.extend(
+                [
+                    native_expression.alias(alias)
+                    for native_expression, output_name, alias in zip(
+                        native_expressions, output_names, aliases
+                    )
+                    if output_name not in exclude
+                ]
+            )
+
+        if not agg_columns:
+            return self._compliant_frame._from_native_frame(
+                self._compliant_frame._native_frame.select(*self._keys).dropDuplicates()
+            )
+
+        return self._compliant_frame._from_native_frame(
+            self._compliant_frame._native_frame.groupBy(self._keys).agg(*agg_columns)
+        )
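At the user level, the rewritten `agg` means a multi-output aggregation no longer duplicates the grouping keys. A usage sketch (hypothetical data, assuming the PySpark backend wiring this PR builds on):

```python
import narwhals as nw
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
native_df = spark.createDataFrame([(1, 4.0), (1, 5.0), (2, 6.0)], ["a", "b"])

lf = nw.from_native(native_df)
# nw.all().mean() expands over all columns, but "a" is a grouping key, so it
# is excluded from the aggregations and the result has columns ["a", "b"]
# rather than a duplicated key.
result = lf.group_by("a").agg(nw.all().mean())
print(result.to_native().collect())
```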
Reviewer: I think this may be too late to set `over`. For example, in `nw.col('a') - nw.col('a').mean()`, I think it's in the binary operation `__sub__` that `nw.col('a').mean()` needs to become `nw.col('a').mean().over(lit(1))`. As in, we want to translate `nw.col('a') - nw.col('a').mean()` to `F.col('a') - F.col('a').mean().over(F.lit(1))`; the code, however, as far as I can tell, translates it to `(F.col('a') - F.col('a').mean()).over(F.lit(1))`.
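To make the concern concrete, a minimal PySpark sketch (toy data; `w` plays the role of the `lit(1)` shorthand above):

```python
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F  # noqa: N812

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1,), (2,), (3,)], ["a"])
w = Window.partitionBy(F.lit(1))

# Desired translation: only the reduction is windowed; the subtraction
# stays elementwise.
df.select((F.col("a") - F.mean("a").over(w)).alias("centered")).show()

# Feared translation: windowing the whole binary expression. Spark should
# reject this at analysis time, since `a - avg(a)` is not itself an
# aggregate/window function.
# df.select((F.col("a") - F.mean("a")).over(w))  # AnalysisException expected
```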
Author: That happens in `maybe_evaluate`, and you can see that the `reduction_test` cases are now passing. I know it's not ideal to have the logic for setting `over` in two places, but I couldn't figure out a unique place in which to handle this, as `maybe_evaluate` is only called to evaluate the other arguments.
Reviewer: Ah, I see! Yes, this might be fine then, thanks!