# fix: support various reductions in pyspark #1870
**`narwhals/_spark_like/dataframe.py`**
```diff
@@ -6,6 +6,9 @@
 from typing import Literal
 from typing import Sequence
 
+from pyspark.sql import Window
+from pyspark.sql import functions as F  # noqa: N812
+
 from narwhals._spark_like.utils import native_to_narwhals_dtype
 from narwhals._spark_like.utils import parse_exprs_and_named_exprs
 from narwhals.typing import CompliantLazyFrame
```
```diff
@@ -94,8 +97,8 @@ def select(
         *exprs: SparkLikeExpr,
         **named_exprs: SparkLikeExpr,
     ) -> Self:
-        new_columns = parse_exprs_and_named_exprs(
-            self, *exprs, with_columns_context=False, **named_exprs
+        new_columns, returns_scalar = parse_exprs_and_named_exprs(self)(
+            *exprs, **named_exprs
         )
 
         if not new_columns:
```
```diff
@@ -107,8 +110,38 @@ def select(
 
             return self._from_native_frame(spark_df)
 
-        new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()]
-        return self._from_native_frame(self._native_frame.select(*new_columns_list))
+        if all(returns_scalar):
+            new_columns_list = [
+                col.alias(col_name) for col_name, col in new_columns.items()
+            ]
+            return self._from_native_frame(self._native_frame.agg(*new_columns_list))
+        else:
+            new_columns_list = [
+                col.over(Window.partitionBy(F.lit(1))).alias(col_name)
+                if _returns_scalar
+                else col.alias(col_name)
+                for (col_name, col), _returns_scalar in zip(
+                    new_columns.items(), returns_scalar
+                )
+            ]
+            return self._from_native_frame(self._native_frame.select(*new_columns_list))
+
+    def with_columns(
+        self: Self,
+        *exprs: SparkLikeExpr,
+        **named_exprs: SparkLikeExpr,
+    ) -> Self:
+        new_columns, returns_scalar = parse_exprs_and_named_exprs(self)(
+            *exprs, **named_exprs
+        )
+
+        new_columns_map = {
+            col_name: col.over(Window.partitionBy(F.lit(1))) if _returns_scalar else col
+            for (col_name, col), _returns_scalar in zip(
+                new_columns.items(), returns_scalar
+            )
+        }
+        return self._from_native_frame(self._native_frame.withColumns(new_columns_map))
 
     def filter(self: Self, *predicates: SparkLikeExpr, **constraints: Any) -> Self:
         plx = self.__narwhals_namespace__()
```

> **Review comment** (on the moved `with_columns`): I just moved this closer to `select`.
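To make the two branches above concrete, here is a minimal hedged PySpark sketch (not part of the PR; the session and toy frame are assumptions) of why a frame of pure reductions can go through `.agg`, while a mix of columns and reductions needs the scalar results broadcast with `.over(Window.partitionBy(F.lit(1)))`:

```python
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1,), (2,), (3,)], ["a"])

# All outputs are scalars -> a plain aggregation yields a single row,
# matching the `select` branch where `all(returns_scalar)` holds.
df.agg(F.mean("a").alias("a")).show()  # one row: 2.0

# Mixing a full column with a reduction -> the reduction must become a
# window expression over one global partition so PySpark broadcasts it
# to every row instead of raising an analysis error.
df.select(
    F.col("a"),
    F.mean("a").over(Window.partitionBy(F.lit(1))).alias("a_mean"),
).show()  # three rows, a_mean == 2.0 on each
```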
```diff
@@ -132,16 +165,6 @@ def schema(self: Self) -> dict[str, DType]:
     def collect_schema(self: Self) -> dict[str, DType]:
         return self.schema
 
-    def with_columns(
-        self: Self,
-        *exprs: SparkLikeExpr,
-        **named_exprs: SparkLikeExpr,
-    ) -> Self:
-        new_columns_map = parse_exprs_and_named_exprs(
-            self, *exprs, with_columns_context=True, **named_exprs
-        )
-        return self._from_native_frame(self._native_frame.withColumns(new_columns_map))
-
     def drop(self: Self, columns: list[str], strict: bool) -> Self:  # noqa: FBT001
         columns_to_drop = parse_columns_to_drop(
             compliant_frame=self, columns=columns, strict=strict
```
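Taken together, the two `dataframe.py` hunks mean that reductions can now be mixed with columns on the PySpark backend. A hedged illustration using the public narwhals API (the toy data and exact calls are my assumptions, not from the PR):

```python
import narwhals as nw
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
spark_df = spark.createDataFrame([(1,), (2,), (3,)], ["a"])

lf = nw.from_native(spark_df)

# A pure reduction: routed through `.agg`, producing a one-row frame.
print(lf.select(nw.col("a").mean()).to_native().collect())

# A column alongside a reduction: the mean is broadcast via a window.
print(
    lf.select(nw.col("a"), nw.col("a").mean().alias("a_mean"))
    .to_native()
    .collect()
)
```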
**`narwhals/_spark_like/expr.py`**
```diff
@@ -140,7 +140,7 @@ def func(df: SparkLikeLazyFrame) -> list[Column]:
             function_name=f"{self._function_name}->{expr_name}",
             evaluate_output_names=self._evaluate_output_names,
             alias_output_names=self._alias_output_names,
-            returns_scalar=self._returns_scalar or returns_scalar,
+            returns_scalar=returns_scalar,
             backend_version=self._backend_version,
             version=self._version,
             kwargs=expressifiable_args,
```

> **Review comment** (on the `returns_scalar=` change): This gave me so much headache before spotting it.
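(A plausible reading, inferred rather than stated in the PR: with the old `self._returns_scalar or returns_scalar`, any expression chained onto a reduction kept the scalar flag forever, even when the chained operation produces a full column again; the fix lets each operation declare its own scalar-ness.)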
**`narwhals/_spark_like/utils.py`**
```diff
@@ -5,6 +5,7 @@
 from typing import Any
 from typing import Callable
 
+from pyspark.sql import Column
 from pyspark.sql import Window
 from pyspark.sql import functions as F  # noqa: N812
 
@@ -13,7 +14,6 @@
 from narwhals.utils import isinstance_or_issubclass
 
 if TYPE_CHECKING:
-    from pyspark.sql import Column
     from pyspark.sql import types as pyspark_types
 
     from narwhals._spark_like.dataframe import SparkLikeLazyFrame
```
```diff
@@ -113,49 +113,42 @@ def narwhals_to_native_dtype(
 
 def parse_exprs_and_named_exprs(
     df: SparkLikeLazyFrame,
-    *exprs: SparkLikeExpr,
-    with_columns_context: bool,
-    **named_exprs: SparkLikeExpr,
-) -> dict[str, Column]:
-    native_results: dict[str, Column] = {}
-
-    # `returns_scalar` keeps track if an expression returns a scalar and is not lit.
-    # Notice that lit is quite special case, since it gets broadcasted by pyspark
-    # without the need of adding `.over(Window.partitionBy(F.lit(1)))`
-    returns_scalar: list[bool] = []
-
-    for expr in exprs:
-        native_series_list = expr._call(df)
-        output_names = expr._evaluate_output_names(df)
-        if expr._alias_output_names is not None:
-            output_names = expr._alias_output_names(output_names)
-        if len(output_names) != len(native_series_list):  # pragma: no cover
-            msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
-            raise AssertionError(msg)
-        native_results.update(zip(output_names, native_series_list))
-        returns_scalar.extend(
-            [expr._returns_scalar and expr._function_name != "lit"] * len(output_names)
-        )
-    for col_alias, expr in named_exprs.items():
-        native_series_list = expr._call(df)
-        if len(native_series_list) != 1:  # pragma: no cover
-            msg = "Named expressions must return a single column"
-            raise ValueError(msg)
-        native_results[col_alias] = native_series_list[0]
-        returns_scalar.append(expr._returns_scalar and expr._function_name != "lit")
-
-    if all(returns_scalar) and not with_columns_context:
-        return native_results
-    else:
-        return {
-            col_name: col.over(Window.partitionBy(F.lit(1))) if _returns_scalar else col
-            for (col_name, col), _returns_scalar in zip(
-                native_results.items(), returns_scalar
-            )
-        }
+) -> Callable[..., tuple[dict[str, Column], list[bool]]]:
+    def func(
+        *exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr
+    ) -> tuple[dict[str, Column], list[bool]]:
+        native_results: dict[str, list[Column]] = {}
+
+        # `returns_scalar` keeps track if an expression returns a scalar and is not lit.
+        # Notice that lit is quite special case, since it gets broadcasted by pyspark
+        # without the need of adding `.over(Window.partitionBy(F.lit(1)))`
+        returns_scalar: list[bool] = []
+        for expr in exprs:
+            native_series_list = expr._call(df)
+            output_names = expr._evaluate_output_names(df)
+            if expr._alias_output_names is not None:
+                output_names = expr._alias_output_names(output_names)
+            if len(output_names) != len(native_series_list):  # pragma: no cover
+                msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
+                raise AssertionError(msg)
+            native_results.update(zip(output_names, native_series_list))
+            returns_scalar.extend(
+                [expr._returns_scalar and expr._function_name != "lit"]
+                * len(output_names)
+            )
+        for col_alias, expr in named_exprs.items():
+            native_series_list = expr._call(df)
+            if len(native_series_list) != 1:  # pragma: no cover
+                msg = "Named expressions must return a single column"
+                raise ValueError(msg)
+            native_results[col_alias] = native_series_list[0]
+            returns_scalar.append(expr._returns_scalar and expr._function_name != "lit")
+        return native_results, returns_scalar
+
+    return func
 
 
-def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any, *, returns_scalar: bool) -> Any:
+def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any, *, returns_scalar: bool) -> Column:
    from narwhals._spark_like.expr import SparkLikeExpr
 
    if isinstance(obj, SparkLikeExpr):
```

> **Review thread** (on the `lit` comment; the inline code samples were lost in extraction, later resolved):
> - Out of interest, do we run into any issues if we use […]?
> - We end up with […]
> - Isn't this still going to break for, say, […]?
> - Yes, correct.
> - Added a test case and now it works!
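The `lit` special case called out in the comment above can be seen directly in PySpark. A hedged standalone sketch (not PR code; session and data are assumptions):

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1,), (2,), (3,)], ["a"])

# A literal is broadcast to every row by PySpark itself, so no
# `.over(Window.partitionBy(F.lit(1)))` is needed for `lit` expressions.
df.select(F.col("a"), F.lit(42).alias("b")).show()  # b == 42 on all rows
```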
```diff
@@ -164,7 +157,7 @@ def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any, *, returns_scalar: bool) ->
             msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) not supported in this context"
             raise NotImplementedError(msg)
         column_result = column_results[0]
-        if obj._returns_scalar and not returns_scalar and obj._function_name != "lit":
+        if obj._returns_scalar and obj._function_name != "lit" and not returns_scalar:
             # Returns scalar, but overall expression doesn't.
             # Let PySpark do its broadcasting
             return column_result.over(Window.partitionBy(F.lit(1)))
```

_FBruzzesi marked this conversation as resolved._
> **Review comment:** I think this may be too late to set `over`. For example, in `nw.col('a') - nw.col('a').mean()`, I think it's in the binary operation `__sub__` that `nw.col('a').mean()` needs to become `nw.col('a').mean().over(lit(1))`. As in, we want to translate `nw.col('a') - nw.col('a').mean()` to `F.col('a') - F.col('a').mean().over(F.lit(1))`; the code, however, as far as I can tell, translates it to `(F.col('a') - F.col('a').mean()).over(F.lit(1))`.
>
> **Reply:** That happens in `maybe_evaluate`, and you can see that the `reduction_test` cases are now passing. I know it's not ideal to have the logic for setting `over` in two places, but I couldn't figure out a unique place in which to handle this, as `maybe_evaluate` is called only to evaluate other arguments.
>
> **Reply:** Ah I see! Yes, this might be fine then, thanks!
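To illustrate the distinction in that exchange, a minimal hedged PySpark sketch (the session and toy data are assumptions, not from the PR): the window must wrap only the reduction, so the subtraction happens between a full column and a broadcast scalar.

```python
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1,), (2,), (3,)], ["a"])

# Desired translation of `nw.col('a') - nw.col('a').mean()`:
# only the mean is wrapped in the global window, then broadcast row-wise.
centered = F.col("a") - F.mean("a").over(Window.partitionBy(F.lit(1)))
df.select(centered.alias("a_centered")).show()  # -1.0, 0.0, 1.0

# Wrapping the whole binary expression instead, i.e.
# (F.col("a") - F.mean("a")).over(...), does not express
# "subtract the broadcast mean from each row".
```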