From c250928e00502a74cf3f5aaf431b27dc88c8c804 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 26 Jan 2025 18:47:20 +0100 Subject: [PATCH 1/6] test and test env --- pyproject.toml | 3 ++- tests/expr_and_series/lit_test.py | 7 ------- tests/expr_and_series/reduction_test.py | 14 +++----------- tests/stable_api_test.py | 2 +- 4 files changed, 6 insertions(+), 20 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d283f4d59..195ef1027 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -174,7 +174,8 @@ xfail_strict = true markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] env = [ "MODIN_ENGINE=python", - "PYARROW_IGNORE_TIMEZONE=1" + "PYARROW_IGNORE_TIMEZONE=1", + "TZ=UTC", ] [tool.coverage.run] diff --git a/tests/expr_and_series/lit_test.py b/tests/expr_and_series/lit_test.py index b29ab89ee..ddc1267e0 100644 --- a/tests/expr_and_series/lit_test.py +++ b/tests/expr_and_series/lit_test.py @@ -105,13 +105,6 @@ def test_lit_operation( and DASK_VERSION < (2024, 10) ): request.applymarker(pytest.mark.xfail) - if "pyspark" in str(constructor) and col_name in { - "left_lit_with_agg", - "left_scalar_with_agg", - "right_lit_with_agg", - "right_lit", - }: - request.applymarker(pytest.mark.xfail) data = {"a": [1, 3, 2]} df_raw = constructor(data) diff --git a/tests/expr_and_series/reduction_test.py b/tests/expr_and_series/reduction_test.py index 49a3fddba..b1d84e85e 100644 --- a/tests/expr_and_series/reduction_test.py +++ b/tests/expr_and_series/reduction_test.py @@ -33,13 -6 +33,6 @@ def test_scalar_reduction_select( expected: dict[str, list[Any]], request: pytest.FixtureRequest, ) -> None: - if "pyspark" in str(constructor) and request.node.callspec.id in { - "pyspark-2", - "pyspark-3", - "pyspark-4", - }: - request.applymarker(pytest.mark.xfail) - if "duckdb" in str(constructor) and request.node.callspec.id not in {"duckdb-0"}: request.applymarker(pytest.mark.xfail) @@ -72,9 +65,7 @@ def test_scalar_reduction_with_columns( expected: dict[str, list[Any]], request: pytest.FixtureRequest, ) -> None: - if "duckdb" in str(constructor) or ( - "pyspark" in str(constructor) and request.node.callspec.id != "pyspark-1" - ): + if "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) data = {"a": [1, 2, 3], "b": [4, 5, 6]} df = nw.from_native(constructor(data)) @@ -85,6 +76,7 @@ def test_scalar_reduction_with_columns( def test_empty_scalar_reduction_select( constructor: Constructor, request: pytest.FixtureRequest ) -> None: + # pyspark doesn't necessarily fail, but returns all Nones if "pyspark" in str(constructor) or "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) data = { @@ -118,7 +110,7 @@ def test_empty_scalar_reduction_select( def test_empty_scalar_reduction_with_columns( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if "pyspark" in str(constructor) or "duckdb" in str(constructor): + if "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) from itertools import chain diff --git a/tests/stable_api_test.py b/tests/stable_api_test.py index 8d5348834..ad3416892 100644 --- a/tests/stable_api_test.py +++ b/tests/stable_api_test.py @@ -22,7 +22,7 @@ def remove_docstring_examples(doc: str) -> str: def test_renamed_taxicab_norm( constructor: Constructor, request: pytest.FixtureRequest ) -> None: - if ("pyspark" in str(constructor)) or "duckdb" in str(constructor): + if "duckdb" in str(constructor): request.applymarker(pytest.mark.xfail) # Suppose we need to rename `_l1_norm` to `_taxicab_norm`. 
# We need `narwhals.stable.v1` to stay stable. So, we From a2c3679c477e7089dd0afb350e87840b5a4a9d71 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 26 Jan 2025 18:57:23 +0100 Subject: [PATCH 2/6] working solution --- narwhals/_spark_like/dataframe.py | 10 +++++--- narwhals/_spark_like/expr.py | 26 ++++++++++++--------- narwhals/_spark_like/utils.py | 38 ++++++++++++++++++++++++------- 3 files changed, 52 insertions(+), 22 deletions(-) diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 3c4ac2fe7..665e09e8a 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -8,6 +8,7 @@ from narwhals._spark_like.utils import native_to_narwhals_dtype from narwhals._spark_like.utils import parse_exprs_and_named_exprs +from narwhals.typing import CompliantLazyFrame from narwhals.utils import Implementation from narwhals.utils import check_column_exists from narwhals.utils import parse_columns_to_drop @@ -26,7 +27,6 @@ from narwhals._spark_like.namespace import SparkLikeNamespace from narwhals.dtypes import DType from narwhals.utils import Version -from narwhals.typing import CompliantLazyFrame class SparkLikeLazyFrame(CompliantLazyFrame): @@ -94,7 +94,9 @@ def select( *exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr, ) -> Self: - new_columns = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) + new_columns = parse_exprs_and_named_exprs( + self, *exprs, with_columns_context=False, **named_exprs + ) if not new_columns: # return empty dataframe, like Polars does @@ -135,7 +137,9 @@ def with_columns( *exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr, ) -> Self: - new_columns_map = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) + new_columns_map = parse_exprs_and_named_exprs( + self, *exprs, with_columns_context=True, **named_exprs + ) return self._from_native_frame(self._native_frame.withColumns(new_columns_map)) def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001 diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index df9e8d42a..feed13473 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -126,7 +126,7 @@ def _from_call( def func(df: SparkLikeLazyFrame) -> list[Column]: native_series_list = self._call(df) other_native_series = { - key: maybe_evaluate(df, value) + key: maybe_evaluate(df, value, returns_scalar=returns_scalar) for key, value in expressifiable_args.items() } return [ @@ -349,26 +349,30 @@ def sum(self: Self) -> Self: return self._from_call(F.sum, "sum", returns_scalar=True) def std(self: Self, ddof: int) -> Self: - from functools import partial - import numpy as np # ignore-banned-import from narwhals._spark_like.utils import _std - func = partial(_std, ddof=ddof, np_version=parse_version(np.__version__)) - - return self._from_call(func, "std", returns_scalar=True, ddof=ddof) + return self._from_call( + _std, + "std", + returns_scalar=True, + ddof=ddof, + np_version=parse_version(np.__version__), + ) def var(self: Self, ddof: int) -> Self: - from functools import partial - import numpy as np # ignore-banned-import from narwhals._spark_like.utils import _var - func = partial(_var, ddof=ddof, np_version=parse_version(np.__version__)) - - return self._from_call(func, "var", returns_scalar=True, ddof=ddof) + return self._from_call( + _var, + "var", + returns_scalar=True, + ddof=ddof, + np_version=parse_version(np.__version__), + ) def clip( self: Self, diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py 
index f2f7abfde..b568d0816 100644 --- a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING from typing import Any +from pyspark.sql import Window from pyspark.sql import functions as F # noqa: N812 from narwhals.exceptions import UnsupportedDTypeError @@ -110,9 +111,18 @@ def narwhals_to_native_dtype( def parse_exprs_and_named_exprs( - df: SparkLikeLazyFrame, *exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr + df: SparkLikeLazyFrame, + *exprs: SparkLikeExpr, + with_columns_context: bool, + **named_exprs: SparkLikeExpr, ) -> dict[str, Column]: - native_results: dict[str, list[Column]] = {} + native_results: dict[str, Column] = {} + + # `returns_scalar` keeps track of whether an expression returns a scalar and is not `lit`. + # Notice that `lit` is quite a special case, since it gets broadcast by PySpark + # without the need to add `.over(Window.partitionBy(F.lit(1)))` + returns_scalar: list[bool] = [] + for expr in exprs: native_series_list = expr._call(df) output_names = expr._evaluate_output_names(df) @@ -122,16 +132,29 @@ def parse_exprs_and_named_exprs( msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results" raise AssertionError(msg) native_results.update(zip(output_names, native_series_list)) + returns_scalar.extend( + [expr._returns_scalar and expr._function_name != "lit"] * len(output_names) + ) for col_alias, expr in named_exprs.items(): native_series_list = expr._call(df) if len(native_series_list) != 1: # pragma: no cover msg = "Named expressions must return a single column" raise ValueError(msg) native_results[col_alias] = native_series_list[0] - return native_results + returns_scalar.append(expr._returns_scalar and expr._function_name != "lit") + if all(returns_scalar) and not with_columns_context: + return native_results + else: + return { + col_name: col.over(Window.partitionBy(F.lit(1))) if _returns_scalar else col + for (col_name, col), _returns_scalar in zip( + native_results.items(), returns_scalar ) + } -def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any) -> Any: + +def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any, *, returns_scalar: bool) -> Any: from narwhals._spark_like.expr import SparkLikeExpr if isinstance(obj, SparkLikeExpr): @@ -140,10 +163,9 @@ def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any) -> Any: msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) not supported in this context" raise NotImplementedError(msg) column_result = column_results[0] - if obj._returns_scalar: - # Return scalar, let PySpark do its broadcasting - from pyspark.sql.window import Window - + if obj._returns_scalar and not returns_scalar and obj._function_name != "lit": + # Returns scalar, but overall expression doesn't. 
+ # Let PySpark do its broadcasting return column_result.over(Window.partitionBy(F.lit(1))) return column_result return obj From 1c91326b038c9d112292109ebd1f655507cf9bab Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Sun, 26 Jan 2025 21:53:31 +0100 Subject: [PATCH 3/6] also the great group by refactor --- narwhals/_spark_like/dataframe.py | 2 +- narwhals/_spark_like/group_by.py | 160 +++++++----------------------- 2 files changed, 35 insertions(+), 127 deletions(-) diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 21d6d5302..62865159f 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -182,7 +182,7 @@ def group_by(self: Self, *keys: str, drop_null_keys: bool) -> SparkLikeLazyGroup from narwhals._spark_like.group_by import SparkLikeLazyGroupBy return SparkLikeLazyGroupBy( - df=self, keys=list(keys), drop_null_keys=drop_null_keys + compliant_frame=self, keys=list(keys), drop_null_keys=drop_null_keys ) def sort( diff --git a/narwhals/_spark_like/group_by.py b/narwhals/_spark_like/group_by.py index dca7dbf16..7d87b11dc 100644 --- a/narwhals/_spark_like/group_by.py +++ b/narwhals/_spark_like/group_by.py @@ -1,150 +1,58 @@ from __future__ import annotations -import re -from functools import partial from typing import TYPE_CHECKING -from typing import Any -from typing import Callable -from typing import Sequence - -from pyspark.sql import functions as F # noqa: N812 - -from narwhals._expression_parsing import is_simple_aggregation -from narwhals._spark_like.utils import _std -from narwhals._spark_like.utils import _var -from narwhals.utils import parse_version if TYPE_CHECKING: - from pyspark.sql import Column - from pyspark.sql import GroupedData from typing_extensions import Self from narwhals._spark_like.dataframe import SparkLikeLazyFrame - from narwhals._spark_like.typing import SparkLikeExpr - from narwhals.typing import CompliantExpr + from narwhals._spark_like.expr import SparkLikeExpr class SparkLikeLazyGroupBy: def __init__( self: Self, - df: SparkLikeLazyFrame, + compliant_frame: SparkLikeLazyFrame, keys: list[str], drop_null_keys: bool, # noqa: FBT001 ) -> None: - self._df = df - self._keys = keys if drop_null_keys: - self._grouped = self._df._native_frame.dropna(subset=self._keys).groupBy( - *self._keys - ) + self._compliant_frame = compliant_frame.drop_nulls(subset=None) else: - self._grouped = self._df._native_frame.groupBy(*self._keys) - - def agg( - self: Self, - *exprs: SparkLikeExpr, - ) -> SparkLikeLazyFrame: - return agg_pyspark( - self._df, - self._grouped, - exprs, - self._keys, - self._from_native_frame, - ) - - def _from_native_frame(self: Self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame: - from narwhals._spark_like.dataframe import SparkLikeLazyFrame - - return SparkLikeLazyFrame( - df, backend_version=self._df._backend_version, version=self._df._version - ) - - -def get_spark_function(function_name: str, **kwargs: Any) -> Column: - if function_name in {"std", "var"}: - import numpy as np # ignore-banned-import - - return partial( - _std if function_name == "std" else _var, - ddof=kwargs["ddof"], - np_version=parse_version(np.__version__), - ) - - elif function_name == "len": - # Use count(*) to count all rows including nulls - def _count(*_args: Any, **_kwargs: Any) -> Column: - return F.count("*") - - return _count - - elif function_name == "n_unique": - from pyspark.sql.types import IntegerType - - def _n_unique(_input: Column) -> Column: - return F.count_distinct(_input) + 
F.max(F.isnull(_input).cast(IntegerType())) - - return _n_unique - - else: - return getattr(F, function_name) - - -def agg_pyspark( - df: SparkLikeLazyFrame, - grouped: GroupedData, - exprs: Sequence[CompliantExpr[Column]], - keys: list[str], - from_dataframe: Callable[[Any], SparkLikeLazyFrame], -) -> SparkLikeLazyFrame: - if not exprs: - # No aggregation provided - return from_dataframe(df._native_frame.select(*keys).dropDuplicates(subset=keys)) + self._compliant_frame = compliant_frame + self._keys = keys - for expr in exprs: - if not is_simple_aggregation(expr): # pragma: no cover - msg = ( - "Non-trivial complex aggregation found.\n\n" - "Hint: you were probably trying to apply a non-elementary aggregation with a " - "dask dataframe.\n" - "Please rewrite your query such that group-by aggregations " - "are elementary. For example, instead of:\n\n" - " df.group_by('a').agg(nw.col('b').round(2).mean())\n\n" - "use:\n\n" - " df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n" + def agg(self: Self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame: + agg_columns = [] + df = self._compliant_frame + for expr in exprs: + output_names = expr._evaluate_output_names(df) + aliases = ( + output_names + if expr._alias_output_names is None + else expr._alias_output_names(output_names) ) - raise ValueError(msg) - - simple_aggregations: dict[str, Column] = {} - for expr in exprs: - output_names = expr._evaluate_output_names(df) - aliases = ( - output_names - if expr._alias_output_names is None - else expr._alias_output_names(output_names) - ) - if len(output_names) > 1: - # For multi-output aggregations, e.g. `df.group_by('a').agg(nw.all().mean())`, we skip - # the keys, else they would appear duplicated in the output. - output_names, aliases = zip( - *[(x, alias) for x, alias in zip(output_names, aliases) if x not in keys] + native_expressions = expr(df) + exclude = ( + self._keys + if expr._function_name.split("->", maxsplit=1)[0] in ("all", "selector") + else [] + ) + agg_columns.extend( + [ + native_expression.alias(alias) + for native_expression, output_name, alias in zip( + native_expressions, output_names, aliases + ) + if output_name not in exclude + ] ) - if expr._depth == 0: # pragma: no cover - # e.g. agg(nw.len()) # noqa: ERA001 - agg_func = get_spark_function(expr._function_name, **expr._kwargs) - simple_aggregations.update({alias: agg_func(keys[0]) for alias in aliases}) - continue - # e.g. 
agg(nw.mean('a')) # noqa: ERA001 - function_name = re.sub(r"(\w+->)", "", expr._function_name) - agg_func = get_spark_function(function_name, **expr._kwargs) + if not agg_columns: + return self._compliant_frame._from_native_frame( + self._compliant_frame._native_frame.select(*self._keys).dropDuplicates() + ) - simple_aggregations.update( - { - alias: agg_func(output_name) - for alias, output_name in zip(aliases, output_names) - } + return self._compliant_frame._from_native_frame( + self._compliant_frame._native_frame.groupBy(self._keys).agg(*agg_columns) ) - - agg_columns = [col_.alias(name) for name, col_ in simple_aggregations.items()] - result_simple = grouped.agg(*agg_columns) - return from_dataframe(result_simple) From a1c8a2b457ddd53807b2805d351a6e2e71c0da3e Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Mon, 27 Jan 2025 10:12:05 +0100 Subject: [PATCH 4/6] additional test --- tests/expr_and_series/lit_test.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/tests/expr_and_series/lit_test.py b/tests/expr_and_series/lit_test.py index ddc1267e0..b4a85a0d0 100644 --- a/tests/expr_and_series/lit_test.py +++ b/tests/expr_and_series/lit_test.py @@ -85,7 +85,7 @@ def test_lit_out_name(constructor: Constructor) -> None: ("right_scalar_with_agg", nw.col("a").mean() - 1, [1]), ], ) -def test_lit_operation( +def test_lit_operation_in_select( constructor: Constructor, col_name: str, expr: nw.Expr, @@ -114,6 +114,30 @@ def test_lit_operation( assert_equal_data(result, expected) +@pytest.mark.parametrize( + ("col_name", "expr", "expected_result"), + [ + ("lit_and_scalar", (nw.lit(2) + 1), [3, 3, 3]), + ("scalar_and_lit", (1 + nw.lit(2)), [3, 3, 3]), + ], +) +def test_lit_operation_in_with_columns( + constructor: Constructor, + col_name: str, + expr: nw.Expr, + expected_result: list[int], +) -> None: + data = {"a": [1, 3, 2]} + df_raw = constructor(data) + df = nw.from_native(df_raw).lazy() + result = df.with_columns(expr.alias(col_name)) + expected = { + "a": data["a"], + col_name: expected_result, + } + assert_equal_data(result, expected) + + @pytest.mark.skipif(PANDAS_VERSION < (1, 5), reason="too old for pyarrow") def test_date_lit(constructor: Constructor, request: pytest.FixtureRequest) -> None: if "dask" in str(constructor) or "pyspark" in str(constructor): From 598df7d329c9b8e153ba8ef45be8b182e555dbb7 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Mon, 27 Jan 2025 10:35:45 +0100 Subject: [PATCH 5/6] one duckdb xfail --- tests/expr_and_series/lit_test.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/expr_and_series/lit_test.py b/tests/expr_and_series/lit_test.py index 1d0fe2e72..525584f15 100644 --- a/tests/expr_and_series/lit_test.py +++ b/tests/expr_and_series/lit_test.py @@ -126,7 +126,10 @@ def test_lit_operation_in_with_columns( col_name: str, expr: nw.Expr, expected_result: list[int], + request: pytest.FixtureRequest, ) -> None: + if "duckdb" in str(constructor) and col_name == "scalar_and_lit": + request.applymarker(pytest.mark.xfail) data = {"a": [1, 3, 2]} df_raw = constructor(data) df = nw.from_native(df_raw).lazy() From a93c0fa714844d9b65c1cd3b1d037c5f98d69c85 Mon Sep 17 00:00:00 2001 From: FBruzzesi Date: Mon, 27 Jan 2025 10:43:37 +0100 Subject: [PATCH 6/6] split func name in maybe_evaluate as well --- narwhals/_spark_like/utils.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py index 09d102bb0..ab7298e52 100644 --- 
a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -161,7 +161,11 @@ def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any, *, returns_scalar: bool) -> msg = "Multi-output expressions (e.g. `nw.all()` or `nw.col('a', 'b')`) not supported in this context" raise NotImplementedError(msg) column_result = column_results[0] - if obj._returns_scalar and obj._function_name != "lit" and not returns_scalar: + if ( + obj._returns_scalar + and obj._function_name.split("->", maxsplit=1)[0] != "lit" + and not returns_scalar + ): # Returns scalar, but overall expression doesn't. # Let PySpark do its broadcasting return column_result.over(Window.partitionBy(F.lit(1)))
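
For context, here is a minimal standalone sketch of the two PySpark techniques this series relies on. The Spark session setup, sample data, and column names below are illustrative only and are not part of the patches.

First, the broadcasting trick behind `maybe_evaluate` and `parse_exprs_and_named_exprs`: a bare aggregate cannot be selected next to non-aggregated columns, but wrapping it in a window over a constant partition broadcasts the scalar result to every row. Literals are the exception, since PySpark already broadcasts them on its own, which is why the `lit` checks above skip the window.

    from pyspark.sql import SparkSession, Window
    from pyspark.sql import functions as F

    spark = SparkSession.builder.master("local[1]").getOrCreate()
    df = spark.createDataFrame([(1, 4), (3, 5), (2, 6)], ["a", "b"])

    # Selecting an aggregate alongside a plain column fails:
    # df.select(F.col("a"), F.mean("a")) raises an AnalysisException.

    # A window over a constant partition broadcasts the scalar instead,
    # which is what `.over(Window.partitionBy(F.lit(1)))` does in the patches.
    broadcast = Window.partitionBy(F.lit(1))
    df.select(
        F.col("a"),
        F.mean("a").over(broadcast).alias("a_mean"),  # 2.0 on every row
    ).show()

    # Literals need no window: PySpark broadcasts them by itself,
    # mirroring the new test_lit_operation_in_with_columns cases.
    df.select(F.col("a"), (F.lit(2) + 1).alias("lit_and_scalar")).show()

Note that partitioning by a constant funnels every row into a single partition, so this is a correctness device rather than a performance optimization. The group-by refactor sidesteps it entirely: each expression's aggregate columns are aliased and handed to a single native `groupBy(...).agg(...)` call, roughly as follows (continuing with the same `df`).

    # Sketch of the refactored group-by path: collect aliased aggregate
    # columns from the expressions, then issue one native agg call.
    agg_columns = [F.sum("b").alias("b_sum"), F.count("*").alias("len")]
    df.groupBy("a").agg(*agg_columns).show()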