From d0225b3359c302d56420c042655b4aee7390d283 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Sun, 15 Dec 2024 09:02:59 +0000 Subject: [PATCH] chore: Add some Compliant Protocols (#1522) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- narwhals/_arrow/dataframe.py | 5 +- narwhals/_arrow/expr.py | 11 +- narwhals/_arrow/group_by.py | 6 +- narwhals/_arrow/namespace.py | 35 ++-- narwhals/_arrow/selectors.py | 9 +- narwhals/_arrow/series.py | 3 +- narwhals/_arrow/typing.py | 2 +- narwhals/_arrow/utils.py | 2 +- narwhals/_dask/dataframe.py | 3 +- narwhals/_dask/expr.py | 8 +- narwhals/_dask/group_by.py | 6 +- narwhals/_dask/namespace.py | 67 ++----- narwhals/_dask/selectors.py | 2 +- narwhals/_expression_parsing.py | 270 +++++++++++------------------ narwhals/_pandas_like/expr.py | 8 +- narwhals/_pandas_like/group_by.py | 8 +- narwhals/_pandas_like/namespace.py | 37 ++-- narwhals/_pandas_like/selectors.py | 2 +- narwhals/_pandas_like/series.py | 3 +- narwhals/_pandas_like/utils.py | 2 +- narwhals/_polars/namespace.py | 5 +- narwhals/_spark_like/expr.py | 7 +- narwhals/_spark_like/group_by.py | 5 +- narwhals/_spark_like/namespace.py | 38 +--- narwhals/typing.py | 51 ++++++ 25 files changed, 277 insertions(+), 318 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index ea6ed4697..34758bd82 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -39,8 +39,11 @@ from narwhals.typing import SizeUnit from narwhals.utils import Version +from narwhals.typing import CompliantDataFrame +from narwhals.typing import CompliantLazyFrame -class ArrowDataFrame: + +class ArrowDataFrame(CompliantDataFrame, CompliantLazyFrame): # --- not in the spec --- def __init__( self: Self, diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index cde4dc2aa..c875b7b56 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -18,18 +18,20 @@ from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.namespace import ArrowNamespace - from narwhals._arrow.series import ArrowSeries from narwhals._arrow.typing import IntoArrowExpr from narwhals.dtypes import DType from narwhals.utils import Version +from narwhals._arrow.series import ArrowSeries +from narwhals.typing import CompliantExpr -class ArrowExpr: + +class ArrowExpr(CompliantExpr[ArrowSeries]): _implementation: Implementation = Implementation.PYARROW def __init__( self: Self, - call: Callable[[ArrowDataFrame], list[ArrowSeries]], + call: Callable[[ArrowDataFrame], Sequence[ArrowSeries]], *, depth: int, function_name: str, @@ -57,6 +59,9 @@ def __repr__(self: Self) -> str: # pragma: no cover f"output_names={self._output_names}" ) + def __call__(self, df: ArrowDataFrame) -> Sequence[ArrowSeries]: + return self._call(df) + @classmethod def from_column_names( cls: type[Self], diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index 11f9afd08..7e2422a17 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -5,6 +5,7 @@ from typing import Any from typing import Callable from typing import Iterator +from typing import Sequence from narwhals._expression_parsing import is_simple_aggregation from narwhals._expression_parsing import parse_into_exprs @@ -17,8 +18,9 @@ from typing_extensions import Self from narwhals._arrow.dataframe import ArrowDataFrame - from narwhals._arrow.expr import ArrowExpr + from narwhals._arrow.series import ArrowSeries from narwhals._arrow.typing import 
IntoArrowExpr + from narwhals.typing import CompliantExpr def polars_to_arrow_aggregations() -> ( @@ -122,7 +124,7 @@ def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]: def agg_arrow( grouped: pa.TableGroupBy, - exprs: list[ArrowExpr], + exprs: Sequence[CompliantExpr[ArrowSeries]], keys: list[str], output_names: list[str], from_dataframe: Callable[[Any], ArrowDataFrame], diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 033b69da8..c4da2e824 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -5,6 +5,7 @@ from typing import Any from typing import Iterable from typing import Literal +from typing import Sequence from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.expr import ArrowExpr @@ -17,6 +18,7 @@ from narwhals._expression_parsing import combine_root_names from narwhals._expression_parsing import parse_into_exprs from narwhals._expression_parsing import reduce_output_names +from narwhals.typing import CompliantNamespace from narwhals.utils import Implementation from narwhals.utils import import_dtypes_module @@ -30,10 +32,10 @@ from narwhals.utils import Version -class ArrowNamespace: +class ArrowNamespace(CompliantNamespace[ArrowSeries]): def _create_expr_from_callable( self: Self, - func: Callable[[ArrowDataFrame], list[ArrowSeries]], + func: Callable[[ArrowDataFrame], Sequence[ArrowSeries]], *, depth: int, function_name: str, @@ -181,7 +183,7 @@ def all_horizontal(self: Self, *exprs: IntoArrowExpr) -> ArrowExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) def func(df: ArrowDataFrame) -> list[ArrowSeries]: - series = (s for _expr in parsed_exprs for s in _expr._call(df)) + series = (s for _expr in parsed_exprs for s in _expr(df)) return [reduce(lambda x, y: x & y, series)] return self._create_expr_from_callable( @@ -196,7 +198,7 @@ def any_horizontal(self: Self, *exprs: IntoArrowExpr) -> ArrowExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) def func(df: ArrowDataFrame) -> list[ArrowSeries]: - series = (s for _expr in parsed_exprs for s in _expr._call(df)) + series = (s for _expr in parsed_exprs for s in _expr(df)) return [reduce(lambda x, y: x | y, series)] return self._create_expr_from_callable( @@ -214,7 +216,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: series = ( s.fill_null(0, strategy=None, limit=None) for _expr in parsed_exprs - for s in _expr._call(df) + for s in _expr(df) ) return [reduce(lambda x, y: x + y, series)] @@ -234,12 +236,12 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: series = ( s.fill_null(0, strategy=None, limit=None) for _expr in parsed_exprs - for s in _expr._call(df) + for s in _expr(df) ) non_na = ( 1 - s.is_null().cast(dtypes.Int64()) for _expr in parsed_exprs - for s in _expr._call(df) + for s in _expr(df) ) return [ reduce(lambda x, y: x + y, series) / reduce(lambda x, y: x + y, non_na) @@ -259,7 +261,7 @@ def min_horizontal(self: Self, *exprs: IntoArrowExpr) -> ArrowExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) def func(df: ArrowDataFrame) -> list[ArrowSeries]: - init_series, *series = [s for _expr in parsed_exprs for s in _expr._call(df)] + init_series, *series = [s for _expr in parsed_exprs for s in _expr(df)] return [ ArrowSeries( native_series=reduce( @@ -287,7 +289,7 @@ def max_horizontal(self: Self, *exprs: IntoArrowExpr) -> ArrowExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) def func(df: ArrowDataFrame) -> list[ArrowSeries]: - init_series, *series = [s for _expr in parsed_exprs for 
s in _expr._call(df)] + init_series, *series = [s for _expr in parsed_exprs for s in _expr(df)] return [ ArrowSeries( native_series=reduce( @@ -387,7 +389,7 @@ def concat_str( ) -> ArrowExpr: import pyarrow.compute as pc - parsed_exprs: list[ArrowExpr] = [ + parsed_exprs = [ *parse_into_exprs(*exprs, namespace=self), *parse_into_exprs(*more_exprs, namespace=self), ] @@ -397,7 +399,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: series = ( s._native_series for _expr in parsed_exprs - for s in _expr.cast(dtypes.String())._call(df) + for s in _expr.cast(dtypes.String())(df) ) null_handling = "skip" if ignore_nulls else "emit_null" result_series = pc.binary_join_element_wise( @@ -437,7 +439,7 @@ def __init__( self._otherwise_value = otherwise_value self._version = version - def __call__(self: Self, df: ArrowDataFrame) -> list[ArrowSeries]: + def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]: import pyarrow as pa import pyarrow.compute as pc @@ -446,9 +448,9 @@ def __call__(self: Self, df: ArrowDataFrame) -> list[ArrowSeries]: plx = ArrowNamespace(backend_version=self._backend_version, version=self._version) - condition = parse_into_expr(self._condition, namespace=plx)._call(df)[0] + condition = parse_into_expr(self._condition, namespace=plx)(df)[0] try: - value_series = parse_into_expr(self._then_value, namespace=plx)._call(df)[0] + value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0] except TypeError: # `self._otherwise_value` is a scalar and can't be converted to an expression value_series = condition.__class__._from_iterable( @@ -471,9 +473,7 @@ def __call__(self: Self, df: ArrowDataFrame) -> list[ArrowSeries]: ) ] try: - otherwise_series = parse_into_expr( - self._otherwise_value, namespace=plx - )._call(df)[0] + otherwise_expr = parse_into_expr(self._otherwise_value, namespace=plx) except TypeError: # `self._otherwise_value` is a scalar and can't be converted to an expression. # Remark that string values _are_ converted into expressions! 
@@ -485,6 +485,7 @@ def __call__(self: Self, df: ArrowDataFrame) -> list[ArrowSeries]: ) ] else: + otherwise_series = otherwise_expr(df)[0] condition_native, otherwise_native = broadcast_series( [condition, otherwise_series] ) diff --git a/narwhals/_arrow/selectors.py b/narwhals/_arrow/selectors.py index 80f4eb4fa..1d0180c4f 100644 --- a/narwhals/_arrow/selectors.py +++ b/narwhals/_arrow/selectors.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING from typing import Any from typing import NoReturn +from typing import Sequence from narwhals._arrow.expr import ArrowExpr from narwhals.utils import Implementation @@ -127,10 +128,10 @@ def call(df: ArrowDataFrame) -> list[ArrowSeries]: def __or__(self: Self, other: Self | Any) -> ArrowSelector | Any: if isinstance(other, ArrowSelector): - def call(df: ArrowDataFrame) -> list[ArrowSeries]: - lhs = self._call(df) - rhs = other._call(df) - return [x for x in lhs if x.name not in {x.name for x in rhs}] + rhs + def call(df: ArrowDataFrame) -> Sequence[ArrowSeries]: + lhs = self(df) + rhs = other(df) + return [*(x for x in lhs if x.name not in {x.name for x in rhs}), *rhs] return ArrowSelector( call, diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index ae964cb84..8e3dacd66 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -30,6 +30,7 @@ from narwhals._arrow.namespace import ArrowNamespace from narwhals.dtypes import DType from narwhals.utils import Version +from narwhals.typing import CompliantSeries def maybe_extract_py_scalar(value: Any, return_py_scalar: bool) -> Any: # noqa: FBT001 @@ -38,7 +39,7 @@ def maybe_extract_py_scalar(value: Any, return_py_scalar: bool) -> Any: # noqa: return value -class ArrowSeries: +class ArrowSeries(CompliantSeries): def __init__( self: Self, native_series: pa.ChunkedArray, diff --git a/narwhals/_arrow/typing.py b/narwhals/_arrow/typing.py index ab68e044e..9d7130a60 100644 --- a/narwhals/_arrow/typing.py +++ b/narwhals/_arrow/typing.py @@ -14,4 +14,4 @@ from narwhals._arrow.expr import ArrowExpr from narwhals._arrow.series import ArrowSeries - IntoArrowExpr: TypeAlias = Union[ArrowExpr, str, int, float, ArrowSeries] + IntoArrowExpr: TypeAlias = Union[ArrowExpr, str, ArrowSeries] diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index 1742fd199..f9a8fa890 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -338,7 +338,7 @@ def cast_for_truediv( return arrow_array, pa_object -def broadcast_series(series: list[ArrowSeries]) -> list[Any]: +def broadcast_series(series: Sequence[ArrowSeries]) -> list[Any]: lengths = [len(s) for s in series] max_length = max(lengths) fast_path = all(_len == max_length for _len in lengths) diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 4184ae409..7a79a2d36 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -29,9 +29,10 @@ from narwhals._dask.typing import IntoDaskExpr from narwhals.dtypes import DType from narwhals.utils import Version +from narwhals.typing import CompliantLazyFrame -class DaskLazyFrame: +class DaskLazyFrame(CompliantLazyFrame): def __init__( self, native_dataframe: dd.DataFrame, diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 8990f0a8f..748d332b2 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -15,6 +15,7 @@ from narwhals._pandas_like.utils import calculate_timestamp_datetime from narwhals._pandas_like.utils import native_to_narwhals_dtype from narwhals.exceptions import ColumnNotFoundError +from 
narwhals.typing import CompliantExpr from narwhals.utils import Implementation from narwhals.utils import generate_temporary_column_name from narwhals.utils import import_dtypes_module @@ -29,12 +30,12 @@ from narwhals.utils import Version -class DaskExpr: +class DaskExpr(CompliantExpr["dask_expr.Series"]): _implementation: Implementation = Implementation.DASK def __init__( self, - call: Callable[[DaskLazyFrame], list[dask_expr.Series]], + call: Callable[[DaskLazyFrame], Sequence[dask_expr.Series]], *, depth: int, function_name: str, @@ -55,6 +56,9 @@ def __init__( self._backend_version = backend_version self._version = version + def __call__(self, df: DaskLazyFrame) -> Sequence[dask_expr.Series]: + return self._call(df) + def __narwhals_expr__(self) -> None: ... def __narwhals_namespace__(self) -> DaskNamespace: # pragma: no cover diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index 7165b8c7e..af269abd7 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING from typing import Any from typing import Callable +from typing import Sequence from narwhals._expression_parsing import is_simple_aggregation from narwhals._expression_parsing import parse_into_exprs @@ -11,11 +12,12 @@ if TYPE_CHECKING: import dask.dataframe as dd + import dask_expr import pandas as pd from narwhals._dask.dataframe import DaskLazyFrame - from narwhals._dask.expr import DaskExpr from narwhals._dask.typing import IntoDaskExpr + from narwhals.typing import CompliantExpr def n_unique() -> dd.Aggregation: @@ -101,7 +103,7 @@ def _from_native_frame(self, df: DaskLazyFrame) -> DaskLazyFrame: def agg_dask( df: DaskLazyFrame, grouped: Any, - exprs: list[DaskExpr], + exprs: Sequence[CompliantExpr[dask_expr.Series]], keys: list[str], from_dataframe: Callable[[Any], DaskLazyFrame], ) -> DaskLazyFrame: diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index b3e2814ca..a64734bae 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -3,10 +3,9 @@ from functools import reduce from typing import TYPE_CHECKING from typing import Any -from typing import Callable from typing import Iterable from typing import Literal -from typing import NoReturn +from typing import Sequence from typing import cast from narwhals._dask.dataframe import DaskLazyFrame @@ -19,6 +18,7 @@ from narwhals._expression_parsing import combine_root_names from narwhals._expression_parsing import parse_into_exprs from narwhals._expression_parsing import reduce_output_names +from narwhals.typing import CompliantNamespace if TYPE_CHECKING: import dask_expr @@ -28,7 +28,7 @@ from narwhals.utils import Version -class DaskNamespace: +class DaskNamespace(CompliantNamespace["dask_expr.Series"]): @property def selectors(self) -> DaskSelectorNamespace: return DaskSelectorNamespace( @@ -144,7 +144,7 @@ def all_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) def func(df: DaskLazyFrame) -> list[dask_expr.Series]: - series = [s for _expr in parsed_exprs for s in _expr._call(df)] + series = [s for _expr in parsed_exprs for s in _expr(df)] return [reduce(lambda x, y: x & y, series).rename(series[0].name)] return DaskExpr( @@ -162,7 +162,7 @@ def any_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) def func(df: DaskLazyFrame) -> list[dask_expr.Series]: - series = [s for _expr in parsed_exprs for s in _expr._call(df)] + series = [s for 
_expr in parsed_exprs for s in _expr(df)] return [reduce(lambda x, y: x | y, series).rename(series[0].name)] return DaskExpr( @@ -180,7 +180,7 @@ def sum_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) def func(df: DaskLazyFrame) -> list[dask_expr.Series]: - series = [s.fillna(0) for _expr in parsed_exprs for s in _expr._call(df)] + series = [s.fillna(0) for _expr in parsed_exprs for s in _expr(df)] return [reduce(lambda x, y: x + y, series).rename(series[0].name)] return DaskExpr( @@ -255,8 +255,8 @@ def mean_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) def func(df: DaskLazyFrame) -> list[dask_expr.Series]: - series = (s.fillna(0) for _expr in parsed_exprs for s in _expr._call(df)) - non_na = (1 - s.isna() for _expr in parsed_exprs for s in _expr._call(df)) + series = (s.fillna(0) for _expr in parsed_exprs for s in _expr(df)) + non_na = (1 - s.isna() for _expr in parsed_exprs for s in _expr(df)) return [ name_preserving_div( reduce(name_preserving_sum, series), @@ -281,7 +281,7 @@ def min_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) def func(df: DaskLazyFrame) -> list[dask_expr.Series]: - series = [s for _expr in parsed_exprs for s in _expr._call(df)] + series = [s for _expr in parsed_exprs for s in _expr(df)] return [dd.concat(series, axis=1).min(axis=1).rename(series[0].name)] @@ -302,7 +302,7 @@ def max_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) def func(df: DaskLazyFrame) -> list[dask_expr.Series]: - series = [s for _expr in parsed_exprs for s in _expr._call(df)] + series = [s for _expr in parsed_exprs for s in _expr(df)] return [dd.concat(series, axis=1).max(axis=1).rename(series[0].name)] @@ -317,36 +317,6 @@ def func(df: DaskLazyFrame) -> list[dask_expr.Series]: version=self._version, ) - def _create_expr_from_series(self, _: Any) -> NoReturn: - msg = "`_create_expr_from_series` for DaskNamespace exists only for compatibility" - raise NotImplementedError(msg) - - def _create_compliant_series(self, _: Any) -> NoReturn: - msg = "`_create_compliant_series` for DaskNamespace exists only for compatibility" - raise NotImplementedError(msg) - - def _create_series_from_scalar( - self, value: Any, *, reference_series: DaskExpr - ) -> NoReturn: - msg = ( - "`_create_series_from_scalar` for DaskNamespace exists only for compatibility" - ) - raise NotImplementedError(msg) - - def _create_expr_from_callable( # pragma: no cover - self, - func: Callable[[DaskLazyFrame], list[DaskExpr]], - *, - depth: int, - function_name: str, - root_names: list[str] | None, - output_names: list[str] | None, - ) -> DaskExpr: - msg = ( - "`_create_expr_from_callable` for DaskNamespace exists only for compatibility" - ) - raise NotImplementedError(msg) - def when( self, *predicates: IntoDaskExpr, @@ -369,14 +339,14 @@ def concat_str( separator: str = "", ignore_nulls: bool = False, ) -> DaskExpr: - parsed_exprs: list[DaskExpr] = [ + parsed_exprs = [ *parse_into_exprs(*exprs, namespace=self), *parse_into_exprs(*more_exprs, namespace=self), ] def func(df: DaskLazyFrame) -> list[dask_expr.Series]: - series = (s.astype(str) for _expr in parsed_exprs for s in _expr._call(df)) - null_mask = [s for _expr in parsed_exprs for s in _expr.is_null()._call(df)] + series = (s.astype(str) for _expr in parsed_exprs for s in _expr(df)) + null_mask = [s for _expr in parsed_exprs for s 
in _expr.is_null()(df)] if not ignore_nulls: null_mask_result = reduce(lambda x, y: x | y, null_mask) @@ -430,16 +400,16 @@ def __init__( self._returns_scalar = returns_scalar self._version = version - def __call__(self, df: DaskLazyFrame) -> list[dask_expr.Series]: + def __call__(self, df: DaskLazyFrame) -> Sequence[dask_expr.Series]: from narwhals._dask.namespace import DaskNamespace from narwhals._expression_parsing import parse_into_expr plx = DaskNamespace(backend_version=self._backend_version, version=self._version) - condition = parse_into_expr(self._condition, namespace=plx)._call(df)[0] + condition = parse_into_expr(self._condition, namespace=plx)(df)[0] condition = cast("dask_expr.Series", condition) try: - value_series = parse_into_expr(self._then_value, namespace=plx)._call(df)[0] + value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0] except TypeError: # `self._otherwise_value` is a scalar and can't be converted to an expression _df = condition.to_frame("a") @@ -451,12 +421,11 @@ def __call__(self, df: DaskLazyFrame) -> list[dask_expr.Series]: if self._otherwise_value is None: return [value_series.where(condition)] try: - otherwise_series = parse_into_expr( - self._otherwise_value, namespace=plx - )._call(df)[0] + otherwise_expr = parse_into_expr(self._otherwise_value, namespace=plx) except TypeError: # `self._otherwise_value` is a scalar and can't be converted to an expression return [value_series.where(condition, self._otherwise_value)] + otherwise_series = otherwise_expr(df)[0] validate_comparand(condition, otherwise_series) return [value_series.where(condition, otherwise_series)] diff --git a/narwhals/_dask/selectors.py b/narwhals/_dask/selectors.py index c1704bdbd..d4064353d 100644 --- a/narwhals/_dask/selectors.py +++ b/narwhals/_dask/selectors.py @@ -134,7 +134,7 @@ def __or__(self: Self, other: DaskSelector | Any) -> DaskSelector | Any: def call(df: DaskLazyFrame) -> list[dask_expr.Series]: lhs = self._call(df) rhs = other._call(df) - return [x for x in lhs if x.name not in {x.name for x in rhs}] + rhs + return [*(x for x in lhs if x.name not in {x.name for x in rhs}), *rhs] return DaskSelector( call, diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index 9321641bf..e15746f0b 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -17,99 +17,47 @@ from narwhals.utils import Implementation if TYPE_CHECKING: - from narwhals._arrow.dataframe import ArrowDataFrame + from typing_extensions import TypeAlias + from narwhals._arrow.expr import ArrowExpr - from narwhals._arrow.namespace import ArrowNamespace - from narwhals._arrow.series import ArrowSeries - from narwhals._arrow.typing import IntoArrowExpr - from narwhals._dask.dataframe import DaskLazyFrame - from narwhals._dask.expr import DaskExpr - from narwhals._dask.namespace import DaskNamespace - from narwhals._dask.typing import IntoDaskExpr - from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.expr import PandasLikeExpr - from narwhals._pandas_like.namespace import PandasLikeNamespace - from narwhals._pandas_like.series import PandasLikeSeries - from narwhals._pandas_like.typing import IntoPandasLikeExpr - from narwhals._polars.expr import PolarsExpr - from narwhals._polars.namespace import PolarsNamespace - from narwhals._polars.series import PolarsSeries - from narwhals._polars.typing import IntoPolarsExpr - from narwhals._spark_like.dataframe import SparkLikeLazyFrame - from narwhals._spark_like.expr 
import SparkLikeExpr - from narwhals._spark_like.namespace import SparkLikeNamespace - from narwhals._spark_like.typing import IntoSparkLikeExpr - - CompliantNamespace = Union[ - PandasLikeNamespace, - ArrowNamespace, - DaskNamespace, - PolarsNamespace, - SparkLikeNamespace, - ] - CompliantExpr = Union[PandasLikeExpr, ArrowExpr, DaskExpr, PolarsExpr, SparkLikeExpr] - IntoCompliantExpr = Union[ - IntoPandasLikeExpr, IntoArrowExpr, IntoDaskExpr, IntoPolarsExpr, IntoSparkLikeExpr - ] - IntoCompliantExprT = TypeVar("IntoCompliantExprT", bound=IntoCompliantExpr) - CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExpr) - CompliantSeries = Union[PandasLikeSeries, ArrowSeries, PolarsSeries] - ListOfCompliantSeries = Union[ - list[PandasLikeSeries], list[ArrowSeries], list[DaskExpr], list[PolarsSeries] - ] - ListOfCompliantExpr = Union[ - list[PandasLikeExpr], - list[ArrowExpr], - list[DaskExpr], - list[PolarsExpr], - list[SparkLikeExpr], - ] - CompliantDataFrame = Union[ - PandasLikeDataFrame, ArrowDataFrame, DaskLazyFrame, SparkLikeLazyFrame - ] + from narwhals.typing import CompliantDataFrame + from narwhals.typing import CompliantExpr + from narwhals.typing import CompliantLazyFrame + from narwhals.typing import CompliantNamespace + from narwhals.typing import CompliantSeries + from narwhals.typing import CompliantSeriesT_co + + IntoCompliantExpr: TypeAlias = ( + CompliantExpr[CompliantSeriesT_co] | str | CompliantSeriesT_co + ) + CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExpr[Any]) + + ArrowOrPandasLikeExpr = TypeVar( + "ArrowOrPandasLikeExpr", bound=Union[ArrowExpr, PandasLikeExpr] + ) + PandasLikeExprT = TypeVar("PandasLikeExprT", bound=PandasLikeExpr) + ArrowExprT = TypeVar("ArrowExprT", bound=ArrowExpr) T = TypeVar("T") def evaluate_into_expr( - df: CompliantDataFrame, into_expr: IntoCompliantExpr -) -> ListOfCompliantSeries: + df: CompliantDataFrame | CompliantLazyFrame, + into_expr: IntoCompliantExpr[CompliantSeriesT_co], +) -> Sequence[CompliantSeriesT_co]: """Return list of raw columns.""" - expr = parse_into_expr(into_expr, namespace=df.__narwhals_namespace__()) # type: ignore[arg-type] - return expr._call(df) # type: ignore[arg-type] - - -@overload -def evaluate_into_exprs( - df: PandasLikeDataFrame, - *exprs: IntoPandasLikeExpr, - **named_exprs: IntoPandasLikeExpr, -) -> list[PandasLikeSeries]: ... - - -@overload -def evaluate_into_exprs( - df: ArrowDataFrame, - *exprs: IntoArrowExpr, - **named_exprs: IntoArrowExpr, -) -> list[ArrowSeries]: ... - - -@overload -def evaluate_into_exprs( - df: DaskLazyFrame, - *exprs: IntoDaskExpr, - **named_exprs: IntoDaskExpr, -) -> list[DaskExpr]: ... 
+ expr = parse_into_expr(into_expr, namespace=df.__narwhals_namespace__()) + return expr(df) def evaluate_into_exprs( df: CompliantDataFrame, - *exprs: IntoCompliantExprT, - **named_exprs: IntoCompliantExprT, -) -> ListOfCompliantSeries: + *exprs: IntoCompliantExpr[CompliantSeriesT_co], + **named_exprs: IntoCompliantExpr[CompliantSeriesT_co], +) -> Sequence[CompliantSeriesT_co]: """Evaluate each expr into Series.""" - series: ListOfCompliantSeries = [ # type: ignore[assignment] + series = [ item for sublist in (evaluate_into_expr(df, into_expr) for into_expr in exprs) for item in sublist @@ -120,98 +68,40 @@ def evaluate_into_exprs( msg = "Named expressions must return a single column" # pragma: no cover raise AssertionError(msg) to_append = evaluated_expr[0].alias(name) - series.append(to_append) # type: ignore[arg-type] + series.append(to_append) return series def maybe_evaluate_expr( - df: CompliantDataFrame, expr: CompliantExpr | T -) -> ListOfCompliantSeries | T: + df: CompliantDataFrame, expr: CompliantExpr[CompliantSeriesT_co] | T +) -> Sequence[CompliantSeriesT_co] | T: """Evaluate `expr` if it's an expression, otherwise return it as is.""" if hasattr(expr, "__narwhals_expr__"): - expr = cast("CompliantExpr", expr) - return expr._call(df) # type: ignore[arg-type] + compliant_expr = cast("CompliantExpr[Any]", expr) + return compliant_expr(df) return expr -@overload -def parse_into_exprs( - *exprs: IntoPandasLikeExpr, - namespace: PandasLikeNamespace, - **named_exprs: IntoPandasLikeExpr, -) -> list[PandasLikeExpr]: ... - - -@overload -def parse_into_exprs( - *exprs: IntoArrowExpr, - namespace: ArrowNamespace, - **named_exprs: IntoArrowExpr, -) -> list[ArrowExpr]: ... - - -@overload -def parse_into_exprs( - *exprs: IntoDaskExpr, - namespace: DaskNamespace, - **named_exprs: IntoDaskExpr, -) -> list[DaskExpr]: ... - - -@overload -def parse_into_exprs( - *exprs: IntoPolarsExpr, - namespace: PolarsNamespace, - **named_exprs: IntoPolarsExpr, -) -> list[PolarsExpr]: ... - - -@overload -def parse_into_exprs( - *exprs: IntoSparkLikeExpr, - namespace: SparkLikeNamespace, - **named_exprs: IntoSparkLikeExpr, -) -> list[SparkLikeExpr]: ... - - def parse_into_exprs( - *exprs: IntoCompliantExpr, - namespace: CompliantNamespace, - **named_exprs: IntoCompliantExpr, -) -> ListOfCompliantExpr: + *exprs: IntoCompliantExpr[CompliantSeriesT_co], + namespace: CompliantNamespace[CompliantSeriesT_co], + **named_exprs: IntoCompliantExpr[CompliantSeriesT_co], +) -> Sequence[CompliantExpr[CompliantSeriesT_co]]: """Parse each input as an expression (if it's not already one). See `parse_into_expr` for more details. """ - return [parse_into_expr(into_expr, namespace=namespace) for into_expr in exprs] + [ # type: ignore[arg-type] - parse_into_expr(expr, namespace=namespace).alias(name) # type: ignore[arg-type] + return [parse_into_expr(into_expr, namespace=namespace) for into_expr in exprs] + [ + parse_into_expr(expr, namespace=namespace).alias(name) for name, expr in named_exprs.items() ] -@overload -def parse_into_expr(into_expr: IntoArrowExpr, namespace: ArrowNamespace) -> ArrowExpr: ... -@overload -def parse_into_expr( - into_expr: IntoPandasLikeExpr, namespace: PandasLikeNamespace -) -> PandasLikeExpr: ... -@overload -def parse_into_expr( - into_expr: IntoPolarsExpr, namespace: PolarsNamespace -) -> PolarsExpr: ... -@overload def parse_into_expr( - into_expr: IntoSparkLikeExpr, namespace: SparkLikeNamespace -) -> SparkLikeExpr: ... 
-@overload -def parse_into_expr(into_expr: IntoDaskExpr, namespace: DaskNamespace) -> DaskExpr: ... - - -def parse_into_expr( # type: ignore[misc] - into_expr: IntoCompliantExpr, + into_expr: IntoCompliantExpr[CompliantSeriesT_co], *, - namespace: CompliantNamespace, -) -> CompliantExpr: + namespace: CompliantNamespace[CompliantSeriesT_co], +) -> CompliantExpr[CompliantSeriesT_co]: """Parse `into_expr` as an expression. For example, in Polars, we can do both `df.select('a')` and `df.select(pl.col('a'))`. @@ -226,22 +116,42 @@ def parse_into_expr( # type: ignore[misc] if hasattr(into_expr, "__narwhals_expr__"): return into_expr # type: ignore[return-value] if hasattr(into_expr, "__narwhals_series__"): - return namespace._create_expr_from_series(into_expr) # type: ignore[arg-type] + return namespace._create_expr_from_series(into_expr) # type: ignore[no-any-return, attr-defined] if isinstance(into_expr, str): return namespace.col(into_expr) if is_numpy_array(into_expr): series = namespace._create_compliant_series(into_expr) - return namespace._create_expr_from_series(series) # type: ignore[arg-type] + return namespace._create_expr_from_series(series) raise InvalidIntoExprError.from_invalid_type(type(into_expr)) +@overload def reuse_series_implementation( - expr: CompliantExprT, + expr: PandasLikeExprT, attr: str, *args: Any, returns_scalar: bool = False, **kwargs: Any, -) -> CompliantExprT: +) -> PandasLikeExprT: ... + + +@overload +def reuse_series_implementation( + expr: ArrowExprT, + attr: str, + *args: Any, + returns_scalar: bool = False, + **kwargs: Any, +) -> ArrowExprT: ... + + +def reuse_series_implementation( + expr: ArrowExprT | PandasLikeExprT, + attr: str, + *args: Any, + returns_scalar: bool = False, + **kwargs: Any, +) -> ArrowExprT | PandasLikeExprT: """Reuse Series implementation for expression. If Series.foo is already defined, and we'd like Expr.foo to be the same, we can @@ -257,9 +167,9 @@ def reuse_series_implementation( """ plx = expr.__narwhals_namespace__() - def func(df: CompliantDataFrame) -> list[CompliantSeries]: - _args = [maybe_evaluate_expr(df, arg) for arg in args] - _kwargs = { + def func(df: CompliantDataFrame) -> Sequence[CompliantSeries]: + _args = [maybe_evaluate_expr(df, arg) for arg in args] # type: ignore[var-annotated] + _kwargs = { # type: ignore[var-annotated] arg_name: maybe_evaluate_expr(df, arg_value) for arg_name, arg_value in kwargs.items() } @@ -279,7 +189,7 @@ def func(df: CompliantDataFrame) -> list[CompliantSeries]: ) if returns_scalar else getattr(series, attr)(*_args, **_kwargs) - for series in expr._call(df) # type: ignore[arg-type] + for series in expr(df) # type: ignore[arg-type] ] if expr._output_names is not None and ( [s.name for s in out] != expr._output_names @@ -326,16 +236,38 @@ def func(df: CompliantDataFrame) -> list[CompliantSeries]: ) +@overload def reuse_series_namespace_implementation( - expr: CompliantExprT, series_namespace: str, attr: str, *args: Any, **kwargs: Any -) -> CompliantExprT: - # Just like `reuse_series_implementation`, but for e.g. `Expr.dt.foo` instead - # of `Expr.foo`. + expr: ArrowExprT, series_namespace: str, attr: str, *args: Any, **kwargs: Any +) -> ArrowExprT: ... +@overload +def reuse_series_namespace_implementation( + expr: PandasLikeExprT, series_namespace: str, attr: str, *args: Any, **kwargs: Any +) -> PandasLikeExprT: ... 
+def reuse_series_namespace_implementation( + expr: ArrowExprT | PandasLikeExprT, + series_namespace: str, + attr: str, + *args: Any, + **kwargs: Any, +) -> ArrowExprT | PandasLikeExprT: + """Reuse Series implementation for expression. + + Just like `reuse_series_implementation`, but for e.g. `Expr.dt.foo` instead + of `Expr.foo`. + + Arguments: + expr: expression object. + series_namespace: The Series namespace (e.g. `dt`, `cat`, `str`, `list`, `name`) + attr: name of method. + args: arguments to pass to function. + kwargs: keyword arguments to pass to function. + """ plx = expr.__narwhals_namespace__() return plx._create_expr_from_callable( # type: ignore[return-value] lambda df: [ getattr(getattr(series, series_namespace), attr)(*args, **kwargs) - for series in expr._call(df) # type: ignore[arg-type] + for series in expr(df) # type: ignore[arg-type] ], depth=expr._depth + 1, function_name=f"{expr._function_name}->{series_namespace}.{attr}", @@ -344,7 +276,7 @@ def reuse_series_namespace_implementation( ) -def is_simple_aggregation(expr: CompliantExpr) -> bool: +def is_simple_aggregation(expr: CompliantExpr[Any]) -> bool: """Check if expr is a very simple one. Examples: @@ -361,10 +293,10 @@ def is_simple_aggregation(expr: CompliantExpr) -> bool: return expr._depth < 2 -def combine_root_names(parsed_exprs: Sequence[CompliantExpr]) -> list[str] | None: +def combine_root_names(parsed_exprs: Sequence[CompliantExpr[Any]]) -> list[str] | None: root_names = copy(parsed_exprs[0]._root_names) for arg in parsed_exprs[1:]: - if root_names is not None and hasattr(arg, "__narwhals_expr__"): + if root_names is not None: if arg._root_names is not None: root_names.extend(arg._root_names) else: @@ -373,7 +305,7 @@ def combine_root_names(parsed_exprs: Sequence[CompliantExpr]) -> list[str] | Non return root_names -def reduce_output_names(parsed_exprs: Sequence[CompliantExpr]) -> list[str] | None: +def reduce_output_names(parsed_exprs: Sequence[CompliantExpr[Any]]) -> list[str] | None: """Returns the left-most output name.""" return ( parsed_exprs[0]._output_names[:1] diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py index ce90db0c5..2f6ec16a9 100644 --- a/narwhals/_pandas_like/expr.py +++ b/narwhals/_pandas_like/expr.py @@ -12,6 +12,7 @@ from narwhals.dependencies import get_numpy from narwhals.dependencies import is_numpy_array from narwhals.exceptions import ColumnNotFoundError +from narwhals.typing import CompliantExpr if TYPE_CHECKING: from typing_extensions import Self @@ -23,10 +24,10 @@ from narwhals.utils import Version -class PandasLikeExpr: +class PandasLikeExpr(CompliantExpr[PandasLikeSeries]): def __init__( self: Self, - call: Callable[[PandasLikeDataFrame], list[PandasLikeSeries]], + call: Callable[[PandasLikeDataFrame], Sequence[PandasLikeSeries]], *, depth: int, function_name: str, @@ -45,6 +46,9 @@ def __init__( self._backend_version = backend_version self._version = version + def __call__(self, df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: + return self._call(df) + def __repr__(self) -> str: # pragma: no cover return ( f"PandasLikeExpr(" diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py index 5735953bd..a1abe2ebc 100644 --- a/narwhals/_pandas_like/group_by.py +++ b/narwhals/_pandas_like/group_by.py @@ -7,6 +7,7 @@ from typing import Any from typing import Callable from typing import Iterator +from typing import Sequence from narwhals._expression_parsing import is_simple_aggregation from narwhals._expression_parsing import 
parse_into_exprs @@ -20,8 +21,9 @@ if TYPE_CHECKING: from narwhals._pandas_like.dataframe import PandasLikeDataFrame - from narwhals._pandas_like.expr import PandasLikeExpr + from narwhals._pandas_like.series import PandasLikeSeries from narwhals._pandas_like.typing import IntoPandasLikeExpr + from narwhals.typing import CompliantExpr POLARS_TO_PANDAS_AGGREGATIONS = { "sum": "sum", @@ -136,7 +138,7 @@ def __iter__(self) -> Iterator[tuple[Any, PandasLikeDataFrame]]: def agg_pandas( # noqa: PLR0915 grouped: Any, - exprs: list[PandasLikeExpr], + exprs: Sequence[CompliantExpr[PandasLikeSeries]], keys: list[str], output_names: list[str], from_dataframe: Callable[[Any], PandasLikeDataFrame], @@ -280,7 +282,7 @@ def func(df: Any) -> Any: out_group = [] out_names = [] for expr in exprs: - results_keys = expr._call(from_dataframe(df)) + results_keys = expr(from_dataframe(df)) if not all(len(x) == 1 for x in results_keys): msg = f"Aggregation '{expr._function_name}' failed to aggregate - does your aggregation function return a scalar?" raise ValueError(msg) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 2163e19f6..f6918d01b 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -6,6 +6,7 @@ from typing import Callable from typing import Iterable from typing import Literal +from typing import Sequence from narwhals._expression_parsing import combine_root_names from narwhals._expression_parsing import parse_into_exprs @@ -19,6 +20,7 @@ from narwhals._pandas_like.utils import horizontal_concat from narwhals._pandas_like.utils import rename from narwhals._pandas_like.utils import vertical_concat +from narwhals.typing import CompliantNamespace from narwhals.utils import import_dtypes_module if TYPE_CHECKING: @@ -28,7 +30,7 @@ from narwhals.utils import Version -class PandasLikeNamespace: +class PandasLikeNamespace(CompliantNamespace[PandasLikeSeries]): @property def selectors(self) -> PandasSelectorNamespace: return PandasSelectorNamespace( @@ -50,7 +52,7 @@ def __init__( def _create_expr_from_callable( self, - func: Callable[[PandasLikeDataFrame], list[PandasLikeSeries]], + func: Callable[[PandasLikeDataFrame], Sequence[PandasLikeSeries]], *, depth: int, function_name: str, @@ -229,7 +231,7 @@ def sum_horizontal(self, *exprs: IntoPandasLikeExpr) -> PandasLikeExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: - series = (s.fill_null(0) for _expr in parsed_exprs for s in _expr._call(df)) + series = (s.fill_null(0) for _expr in parsed_exprs for s in _expr(df)) return [reduce(lambda x, y: x + y, series)] return self._create_expr_from_callable( @@ -244,7 +246,7 @@ def all_horizontal(self, *exprs: IntoPandasLikeExpr) -> PandasLikeExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: - series = (s for _expr in parsed_exprs for s in _expr._call(df)) + series = (s for _expr in parsed_exprs for s in _expr(df)) return [reduce(lambda x, y: x & y, series)] return self._create_expr_from_callable( @@ -259,7 +261,7 @@ def any_horizontal(self, *exprs: IntoPandasLikeExpr) -> PandasLikeExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: - series = (s for _expr in parsed_exprs for s in _expr._call(df)) + series = (s for _expr in parsed_exprs for s in _expr(df)) return [reduce(lambda x, y: x | y, series)] return 
self._create_expr_from_callable( @@ -274,8 +276,8 @@ def mean_horizontal(self, *exprs: IntoPandasLikeExpr) -> PandasLikeExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: - series = (s.fill_null(0) for _expr in parsed_exprs for s in _expr._call(df)) - non_na = (1 - s.is_null() for _expr in parsed_exprs for s in _expr._call(df)) + series = (s.fill_null(0) for _expr in parsed_exprs for s in _expr(df)) + non_na = (1 - s.is_null() for _expr in parsed_exprs for s in _expr(df)) return [ reduce(lambda x, y: x + y, series) / reduce(lambda x, y: x + y, non_na) ] @@ -292,7 +294,7 @@ def min_horizontal(self, *exprs: IntoPandasLikeExpr) -> PandasLikeExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: - series = [s for _expr in parsed_exprs for s in _expr._call(df)] + series = [s for _expr in parsed_exprs for s in _expr(df)] return [ PandasLikeSeries( @@ -322,7 +324,7 @@ def max_horizontal(self, *exprs: IntoPandasLikeExpr) -> PandasLikeExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: - series = [s for _expr in parsed_exprs for s in _expr._call(df)] + series = [s for _expr in parsed_exprs for s in _expr(df)] return [ PandasLikeSeries( @@ -415,7 +417,7 @@ def concat_str( separator: str = "", ignore_nulls: bool = False, ) -> PandasLikeExpr: - parsed_exprs: list[PandasLikeExpr] = [ + parsed_exprs = [ *parse_into_exprs(*exprs, namespace=self), *parse_into_exprs(*more_exprs, namespace=self), ] @@ -423,9 +425,9 @@ def concat_str( def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: series = ( - s for _expr in parsed_exprs for s in _expr.cast(dtypes.String())._call(df) + s for _expr in parsed_exprs for s in _expr.cast(dtypes.String())(df) ) - null_mask = [s for _expr in parsed_exprs for s in _expr.is_null()._call(df)] + null_mask = [s for _expr in parsed_exprs for s in _expr.is_null()(df)] if not ignore_nulls: null_mask_result = reduce(lambda x, y: x | y, null_mask) @@ -481,7 +483,7 @@ def __init__( self._otherwise_value = otherwise_value self._version = version - def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: + def __call__(self, df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: from narwhals._expression_parsing import parse_into_expr from narwhals._pandas_like.namespace import PandasLikeNamespace from narwhals._pandas_like.utils import broadcast_align_and_extract_native @@ -492,9 +494,9 @@ def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: version=self._version, ) - condition = parse_into_expr(self._condition, namespace=plx)._call(df)[0] + condition = parse_into_expr(self._condition, namespace=plx)(df)[0] try: - value_series = parse_into_expr(self._then_value, namespace=plx)._call(df)[0] + value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0] except TypeError: # `self._otherwise_value` is a scalar and can't be converted to an expression value_series = condition.__class__._from_iterable( @@ -516,9 +518,7 @@ def __call__(self, df: PandasLikeDataFrame) -> list[PandasLikeSeries]: ) ] try: - otherwise_series = parse_into_expr( - self._otherwise_value, namespace=plx - )._call(df)[0] + otherwise_expr = parse_into_expr(self._otherwise_value, namespace=plx) except TypeError: # `self._otherwise_value` is a scalar and can't be converted to an expression return [ @@ -527,6 +527,7 @@ def __call__(self, df: PandasLikeDataFrame) -> 
list[PandasLikeSeries]: ) ] else: + otherwise_series = otherwise_expr(df)[0] return [value_series.zip_with(condition, otherwise_series)] def then(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasThen: diff --git a/narwhals/_pandas_like/selectors.py b/narwhals/_pandas_like/selectors.py index 9a2c8b7be..2f775aa6c 100644 --- a/narwhals/_pandas_like/selectors.py +++ b/narwhals/_pandas_like/selectors.py @@ -136,7 +136,7 @@ def __or__(self, other: PandasSelector | Any) -> PandasSelector | Any: def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: lhs = self._call(df) rhs = other._call(df) - return [x for x in lhs if x.name not in {x.name for x in rhs}] + rhs + return [*(x for x in lhs if x.name not in {x.name for x in rhs}), *rhs] return PandasSelector( call, diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 145d15678..9f5b562d0 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -19,6 +19,7 @@ from narwhals._pandas_like.utils import select_columns_by_name from narwhals._pandas_like.utils import set_axis from narwhals._pandas_like.utils import to_datetime +from narwhals.typing import CompliantSeries from narwhals.utils import Implementation from narwhals.utils import import_dtypes_module @@ -77,7 +78,7 @@ } -class PandasLikeSeries: +class PandasLikeSeries(CompliantSeries): def __init__( self, native_series: Any, diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 4076a6b88..30c69254c 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -630,7 +630,7 @@ def narwhals_to_native_dtype( # noqa: PLR0915 raise AssertionError(msg) -def broadcast_series(series: list[PandasLikeSeries]) -> list[Any]: +def broadcast_series(series: Sequence[PandasLikeSeries]) -> list[Any]: native_namespace = series[0].__native_namespace__() lengths = [len(s) for s in series] diff --git a/narwhals/_polars/namespace.py b/narwhals/_polars/namespace.py index 7903de15b..3e1ea1761 100644 --- a/narwhals/_polars/namespace.py +++ b/narwhals/_polars/namespace.py @@ -5,6 +5,7 @@ from typing import Iterable from typing import Literal from typing import Sequence +from typing import cast from typing import overload from narwhals._expression_parsing import parse_into_exprs @@ -143,7 +144,7 @@ def mean_horizontal(self: Self, *exprs: IntoPolarsExpr) -> PolarsExpr: from narwhals._polars.expr import PolarsExpr - polars_exprs = parse_into_exprs(*exprs, namespace=self) + polars_exprs = cast("list[PolarsExpr]", parse_into_exprs(*exprs, namespace=self)) if self._backend_version < (0, 20, 8): return PolarsExpr( @@ -182,7 +183,7 @@ def concat_str( from narwhals._polars.expr import PolarsExpr pl_exprs: list[pl.Expr] = [ - expr._native_expr + expr._native_expr # type: ignore[attr-defined] for expr in ( *parse_into_exprs(*exprs, namespace=self), *parse_into_exprs(*more_exprs, namespace=self), diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index c98c79a5a..5695faf51 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -4,9 +4,11 @@ from copy import copy from typing import TYPE_CHECKING from typing import Callable +from typing import Sequence from narwhals._spark_like.utils import get_column_name from narwhals._spark_like.utils import maybe_evaluate +from narwhals.typing import CompliantExpr from narwhals.utils import Implementation from narwhals.utils import parse_version @@ -19,7 +21,7 @@ from narwhals.utils import Version -class SparkLikeExpr: +class 
SparkLikeExpr(CompliantExpr["Column"]): _implementation = Implementation.PYSPARK def __init__( @@ -45,6 +47,9 @@ def __init__( self._backend_version = backend_version self._version = version + def __call__(self, df: SparkLikeLazyFrame) -> Sequence[Column]: + return self._call(df) + def __narwhals_expr__(self) -> None: ... def __narwhals_namespace__(self) -> SparkLikeNamespace: # pragma: no cover diff --git a/narwhals/_spark_like/group_by.py b/narwhals/_spark_like/group_by.py index a3d557c02..ecd9f235d 100644 --- a/narwhals/_spark_like/group_by.py +++ b/narwhals/_spark_like/group_by.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING from typing import Any from typing import Callable +from typing import Sequence from narwhals._expression_parsing import is_simple_aggregation from narwhals._expression_parsing import parse_into_exprs @@ -14,8 +15,8 @@ from pyspark.sql import GroupedData from narwhals._spark_like.dataframe import SparkLikeLazyFrame - from narwhals._spark_like.expr import SparkLikeExpr from narwhals._spark_like.typing import IntoSparkLikeExpr + from narwhals.typing import CompliantExpr POLARS_TO_PYSPARK_AGGREGATIONS = { "len": "count", @@ -84,7 +85,7 @@ def get_spark_function(function_name: str) -> Column: def agg_pyspark( grouped: GroupedData, - exprs: list[SparkLikeExpr], + exprs: Sequence[CompliantExpr[Column]], keys: list[str], from_dataframe: Callable[[Any], SparkLikeLazyFrame], ) -> SparkLikeLazyFrame: diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index a762c26c8..44053e7c6 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -3,15 +3,13 @@ import operator from functools import reduce from typing import TYPE_CHECKING -from typing import Any -from typing import Callable -from typing import NoReturn from narwhals._expression_parsing import combine_root_names from narwhals._expression_parsing import parse_into_exprs from narwhals._expression_parsing import reduce_output_names from narwhals._spark_like.expr import SparkLikeExpr from narwhals._spark_like.utils import get_column_name +from narwhals.typing import CompliantNamespace if TYPE_CHECKING: from pyspark.sql import Column @@ -21,44 +19,18 @@ from narwhals.utils import Version -class SparkLikeNamespace: +class SparkLikeNamespace(CompliantNamespace["Column"]): def __init__(self, *, backend_version: tuple[int, ...], version: Version) -> None: self._backend_version = backend_version self._version = version - def _create_expr_from_series(self, _: Any) -> NoReturn: - msg = "`_create_expr_from_series` for PySparkNamespace exists only for compatibility" - raise NotImplementedError(msg) - - def _create_compliant_series(self, _: Any) -> NoReturn: - msg = "`_create_compliant_series` for PySparkNamespace exists only for compatibility" - raise NotImplementedError(msg) - - def _create_series_from_scalar( - self, value: Any, *, reference_series: SparkLikeExpr - ) -> NoReturn: - msg = "`_create_series_from_scalar` for PySparkNamespace exists only for compatibility" - raise NotImplementedError(msg) - - def _create_expr_from_callable( # pragma: no cover - self, - func: Callable[[SparkLikeLazyFrame], list[SparkLikeExpr]], - *, - depth: int, - function_name: str, - root_names: list[str] | None, - output_names: list[str] | None, - ) -> SparkLikeExpr: - msg = "`_create_expr_from_callable` for PySparkNamespace exists only for compatibility" - raise NotImplementedError(msg) - def all(self) -> SparkLikeExpr: def _all(df: SparkLikeLazyFrame) -> list[Column]: import 
pyspark.sql.functions as F # noqa: N812 return [F.col(col_name) for col_name in df.columns] - return SparkLikeExpr( + return SparkLikeExpr( # type: ignore[abstract] call=_all, depth=0, function_name="all", @@ -73,11 +45,11 @@ def all_horizontal(self, *exprs: IntoSparkLikeExpr) -> SparkLikeExpr: parsed_exprs = parse_into_exprs(*exprs, namespace=self) def func(df: SparkLikeLazyFrame) -> list[Column]: - cols = [c for _expr in parsed_exprs for c in _expr._call(df)] + cols = [c for _expr in parsed_exprs for c in _expr(df)] col_name = get_column_name(df, cols[0]) return [reduce(operator.and_, cols).alias(col_name)] - return SparkLikeExpr( + return SparkLikeExpr( # type: ignore[abstract] call=func, depth=max(x._depth for x in parsed_exprs) + 1, function_name="all_horizontal", diff --git a/narwhals/typing.py b/narwhals/typing.py index c2fd7fd1f..e6503e1ae 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -2,19 +2,26 @@ from typing import TYPE_CHECKING from typing import Any +from typing import Generic from typing import Literal from typing import Protocol +from typing import Sequence from typing import TypeVar from typing import Union if TYPE_CHECKING: import sys + from narwhals.dtypes import DType + from narwhals.utils import Implementation + if sys.version_info >= (3, 10): from typing import TypeAlias else: from typing_extensions import TypeAlias + from typing_extensions import Self + from narwhals import dtypes from narwhals.dataframe import DataFrame from narwhals.dataframe import LazyFrame @@ -37,6 +44,47 @@ class DataFrameLike(Protocol): def __dataframe__(self, *args: Any, **kwargs: Any) -> Any: ... +class CompliantSeries(Protocol): + @property + def name(self) -> str: ... + def __narwhals_series__(self) -> CompliantSeries: ... + def alias(self, name: str) -> Self: ... + + +class CompliantDataFrame(Protocol): + def __narwhals_dataframe__(self) -> CompliantDataFrame: ... + def __narwhals_namespace__(self) -> Any: ... + + +class CompliantLazyFrame(Protocol): + def __narwhals_lazyframe__(self) -> CompliantLazyFrame: ... + def __narwhals_namespace__(self) -> Any: ... + + +CompliantSeriesT_co = TypeVar( + "CompliantSeriesT_co", bound=CompliantSeries, covariant=True +) + + +class CompliantExpr(Protocol, Generic[CompliantSeriesT_co]): + _implementation: Implementation + _output_names: list[str] | None + _root_names: list[str] | None + _depth: int + _function_name: str + + def __call__(self, df: Any) -> Sequence[CompliantSeriesT_co]: ... + def __narwhals_expr__(self) -> None: ... + def __narwhals_namespace__(self) -> CompliantNamespace[CompliantSeriesT_co]: ... + def is_null(self) -> Self: ... + def alias(self, name: str) -> Self: ... + def cast(self, dtype: DType) -> Self: ... + + +class CompliantNamespace(Protocol, Generic[CompliantSeriesT_co]): + def col(self, *column_names: str) -> CompliantExpr[CompliantSeriesT_co]: ... + + IntoExpr: TypeAlias = Union["Expr", str, "Series[Any]"] """Anything which can be converted to an expression. @@ -218,6 +266,9 @@ class DTypes: __all__ = [ + "CompliantDataFrame", + "CompliantLazyFrame", + "CompliantSeries", "DataFrameT", "Frame", "FrameT",
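The core of this patch is the set of structural protocols added to narwhals/typing.py (CompliantSeries, CompliantDataFrame, CompliantLazyFrame, CompliantExpr, CompliantNamespace) together with a __call__ method on each backend expression that delegates to its existing _call attribute. The sketch below is illustrative only and is not part of the patch: the ListBackend* names are hypothetical, and only the protocol members this diff actually exercises are shown (the real backends additionally carry _implementation, cast, is_null, and so on).

from __future__ import annotations

from typing import Any
from typing import Callable
from typing import Sequence


class ListBackendSeries:
    # Structural match for CompliantSeries: name, __narwhals_series__, alias.
    def __init__(self, name: str, values: list[Any]) -> None:
        self._name = name
        self._values = values

    @property
    def name(self) -> str:
        return self._name

    def __narwhals_series__(self) -> ListBackendSeries:
        return self

    def alias(self, name: str) -> ListBackendSeries:
        return ListBackendSeries(name, self._values)


class ListBackendExpr:
    # The part of CompliantExpr this patch leans on: an expression is callable
    # on a frame and returns a sequence of compliant series, so call sites can
    # write `expr(df)` instead of reaching into `expr._call`.
    _depth = 0
    _function_name = "col"
    _root_names: list[str] | None = None
    _output_names: list[str] | None = None

    def __init__(
        self, call: Callable[[dict[str, list[Any]]], Sequence[ListBackendSeries]]
    ) -> None:
        self._call = call

    def __call__(self, df: dict[str, list[Any]]) -> Sequence[ListBackendSeries]:
        return self._call(df)

    def __narwhals_expr__(self) -> None: ...


class ListBackendNamespace:
    # CompliantNamespace only requires `col`.
    def col(self, *column_names: str) -> ListBackendExpr:
        return ListBackendExpr(
            lambda df: [ListBackendSeries(name, df[name]) for name in column_names]
        )


ns = ListBackendNamespace()
print([s.name for s in ns.col("a", "b")({"a": [1, 2], "b": [3, 4]})])  # ['a', 'b']

This __call__ delegation is exactly what ArrowExpr, DaskExpr, PandasLikeExpr and SparkLikeExpr gain in the hunks above, and it is what allows the namespaces, group-bys and selectors to replace every `_expr._call(df)` call site with `_expr(df)`.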
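The covariant CompliantSeriesT_co parameter is also what lets narwhals/_expression_parsing.py drop its per-backend overloads of parse_into_expr(s) and evaluate_into_expr(s) in favour of one generic signature. The following is a simplified, hedged sketch of that pattern, assuming a narwhals version that already includes this patch; parse_into_exprs_sketch is a made-up name, and the real function additionally handles series, numpy arrays and **named_exprs.

from __future__ import annotations

from typing import Sequence
from typing import TypeVar
from typing import Union

from narwhals.typing import CompliantExpr
from narwhals.typing import CompliantNamespace
from narwhals.typing import CompliantSeries

CompliantSeriesT_co = TypeVar(
    "CompliantSeriesT_co", bound=CompliantSeries, covariant=True
)
IntoCompliantExpr = Union[CompliantExpr[CompliantSeriesT_co], str, CompliantSeriesT_co]


def parse_into_exprs_sketch(
    *exprs: IntoCompliantExpr[CompliantSeriesT_co],
    namespace: CompliantNamespace[CompliantSeriesT_co],
) -> Sequence[CompliantExpr[CompliantSeriesT_co]]:
    # Strings become `namespace.col(...)`; expressions pass through unchanged.
    return [
        e if hasattr(e, "__narwhals_expr__") else namespace.col(e)  # type: ignore[arg-type, list-item]
        for e in exprs
    ]

Because every backend namespace now satisfies CompliantNamespace[...] for its own series type, one signature like this type-checks against pandas-like, PyArrow, Dask, Polars and PySpark inputs alike, which is why the long stacks of @overload definitions are removed in this diff.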