From 7bcd9f02ae3ecb4ec02ab1670b086eea2507f0d5 Mon Sep 17 00:00:00 2001
From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com>
Date: Tue, 7 Jan 2025 07:58:49 +0100
Subject: [PATCH 1/2] increase pyspark min version

---
 .github/workflows/extremes.yml   |  4 ++--
 narwhals/_spark_like/expr.py     | 14 ++------------
 narwhals/_spark_like/group_by.py | 13 +++----------
 narwhals/_spark_like/utils.py    | 18 ++++--------------
 pyproject.toml                   |  2 +-
 5 files changed, 12 insertions(+), 39 deletions(-)

diff --git a/.github/workflows/extremes.yml b/.github/workflows/extremes.yml
index 9e7e997b2..ac0d6a163 100644
--- a/.github/workflows/extremes.yml
+++ b/.github/workflows/extremes.yml
@@ -61,7 +61,7 @@ jobs:
         cache-suffix: ${{ matrix.python-version }}
         cache-dependency-glob: "pyproject.toml"
     - name: install-pretty-old-versions
-      run: uv pip install pipdeptree tox virtualenv setuptools pandas==1.1.5 polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 "pyarrow-stubs<17" pyspark==3.3.0 scipy==1.5.0 scikit-learn==1.1.0 tzdata --system
+      run: uv pip install pipdeptree tox virtualenv setuptools pandas==1.1.5 polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 "pyarrow-stubs<17" pyspark==3.5.0 scipy==1.5.0 scikit-learn==1.1.0 tzdata --system
     - name: install-reqs
       run: uv pip install -e ".[dev]" --system
     - name: show-deps
@@ -99,7 +99,7 @@ jobs:
         cache-suffix: ${{ matrix.python-version }}
         cache-dependency-glob: "pyproject.toml"
     - name: install-not-so-old-versions
-      run: uv pip install tox virtualenv setuptools pandas==2.0.3 polars==0.20.8 numpy==1.24.4 pyarrow==15.0.0 "pyarrow-stubs<17" pyspark==3.4.0 scipy==1.8.0 scikit-learn==1.3.0 dask[dataframe]==2024.10 tzdata --system
+      run: uv pip install tox virtualenv setuptools pandas==2.0.3 polars==0.20.8 numpy==1.24.4 pyarrow==15.0.0 "pyarrow-stubs<17" pyspark==3.5.0 scipy==1.8.0 scikit-learn==1.3.0 dask[dataframe]==2024.10 tzdata --system
     - name: install-reqs
       run: uv pip install -e ".[dev]" --system
     - name: show-deps
diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py
index b74aea678..d190b5667 100644
--- a/narwhals/_spark_like/expr.py
+++ b/narwhals/_spark_like/expr.py
@@ -225,12 +225,7 @@ def std(self: Self, ddof: int) -> Self:
 
         from narwhals._spark_like.utils import _std
 
-        func = partial(
-            _std,
-            ddof=ddof,
-            backend_version=self._backend_version,
-            np_version=parse_version(np.__version__),
-        )
+        func = partial(_std, ddof=ddof, np_version=parse_version(np.__version__))
 
         return self._from_call(func, "std", returns_scalar=True, ddof=ddof)
 
@@ -241,11 +236,6 @@ def var(self: Self, ddof: int) -> Self:
 
         from narwhals._spark_like.utils import _var
 
-        func = partial(
-            _var,
-            ddof=ddof,
-            backend_version=self._backend_version,
-            np_version=parse_version(np.__version__),
-        )
+        func = partial(_var, ddof=ddof, np_version=parse_version(np.__version__))
 
         return self._from_call(func, "var", returns_scalar=True, ddof=ddof)
diff --git a/narwhals/_spark_like/group_by.py b/narwhals/_spark_like/group_by.py
index c7cc52bf1..7f3dc077d 100644
--- a/narwhals/_spark_like/group_by.py
+++ b/narwhals/_spark_like/group_by.py
@@ -79,16 +79,13 @@ def _from_native_frame(self, df: SparkLikeLazyFrame) -> SparkLikeLazyFrame:
         )
 
 
-def get_spark_function(
-    function_name: str, backend_version: tuple[int, ...], **kwargs: Any
-) -> Column:
+def get_spark_function(function_name: str, **kwargs: Any) -> Column:
     if function_name in {"std", "var"}:
         import numpy as np  # ignore-banned-import
 
         return partial(
             _std if function_name == "std" else _var,
             ddof=kwargs.get("ddof", 1),
-            backend_version=backend_version,
             np_version=parse_version(np.__version__),
         )
 
     from pyspark.sql import functions as F  # noqa: N812
@@ -127,9 +124,7 @@ def agg_pyspark(
             function_name = POLARS_TO_PYSPARK_AGGREGATIONS.get(
                 expr._function_name, expr._function_name
             )
-            agg_func = get_spark_function(
-                function_name, backend_version=expr._backend_version, **expr._kwargs
-            )
+            agg_func = get_spark_function(function_name, **expr._kwargs)
             simple_aggregations.update(
                 {output_name: agg_func(keys[0]) for output_name in expr._output_names}
             )
@@ -146,9 +141,7 @@ def agg_pyspark(
             pyspark_function = POLARS_TO_PYSPARK_AGGREGATIONS.get(
                 function_name, function_name
             )
-            agg_func = get_spark_function(
-                pyspark_function, backend_version=expr._backend_version, **expr._kwargs
-            )
+            agg_func = get_spark_function(pyspark_function, **expr._kwargs)
             simple_aggregations.update(
                 {
diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py
index a3c77033c..fb3a3f3c4 100644
--- a/narwhals/_spark_like/utils.py
+++ b/narwhals/_spark_like/utils.py
@@ -120,13 +120,8 @@ def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any) -> Any:
     return obj
 
 
-def _std(
-    _input: Column | str,
-    ddof: int,
-    backend_version: tuple[int, ...],
-    np_version: tuple[int, ...],
-) -> Column:
-    if backend_version < (3, 5) or np_version > (2, 0):
+def _std(_input: Column | str, ddof: int, np_version: tuple[int, ...]) -> Column:
+    if np_version > (2, 0):
         from pyspark.sql import functions as F  # noqa: N812
 
         if ddof == 1:
@@ -142,13 +137,8 @@ def _std(
     return stddev(input_col, ddof=ddof)
 
 
-def _var(
-    _input: Column | str,
-    ddof: int,
-    backend_version: tuple[int, ...],
-    np_version: tuple[int, ...],
-) -> Column:
-    if backend_version < (3, 5) or np_version > (2, 0):
+def _var(_input: Column | str, ddof: int, np_version: tuple[int, ...]) -> Column:
+    if np_version > (2, 0):
         from pyspark.sql import functions as F  # noqa: N812
 
         if ddof == 1:
diff --git a/pyproject.toml b/pyproject.toml
index c16407d80..bb89564b7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,7 +31,7 @@ pandas = ["pandas>=0.25.3"]
 modin = ["modin"]
 cudf = ["cudf>=24.10.0"]
 pyarrow = ["pyarrow>=11.0.0"]
-pyspark = ["pyspark>=3.3.0"]
+pyspark = ["pyspark>=3.5.0"]
 polars = ["polars>=0.20.3"]
 dask = ["dask[dataframe]>=2024.8"]
 duckdb = ["duckdb>=1.0"]

From b5d72f9744b2097c19134ba6df2354cda46a4770 Mon Sep 17 00:00:00 2001
From: Edoardo Abati <29585319+EdAbati@users.noreply.github.com>
Date: Tue, 7 Jan 2025 08:12:20 +0100
Subject: [PATCH 2/2] update extremes

---
 .github/workflows/extremes.yml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/.github/workflows/extremes.yml b/.github/workflows/extremes.yml
index ac0d6a163..47ebc85ea 100644
--- a/.github/workflows/extremes.yml
+++ b/.github/workflows/extremes.yml
@@ -75,7 +75,7 @@ jobs:
         echo "$DEPS" | grep 'polars==0.20.3'
         echo "$DEPS" | grep 'numpy==1.17.5'
         echo "$DEPS" | grep 'pyarrow==11.0.0'
-        echo "$DEPS" | grep 'pyspark==3.3.0'
+        echo "$DEPS" | grep 'pyspark==3.5.0'
         echo "$DEPS" | grep 'scipy==1.5.0'
         echo "$DEPS" | grep 'scikit-learn==1.1.0'
     - name: Run pytest
@@ -111,7 +111,7 @@ jobs:
         echo "$DEPS" | grep 'polars==0.20.8'
         echo "$DEPS" | grep 'numpy==1.24.4'
         echo "$DEPS" | grep 'pyarrow==15.0.0'
-        echo "$DEPS" | grep 'pyspark==3.4.0'
+        echo "$DEPS" | grep 'pyspark==3.5.0'
         echo "$DEPS" | grep 'scipy==1.8.0'
         echo "$DEPS" | grep 'scikit-learn==1.3.0'
         echo "$DEPS" | grep 'dask==2024.10'