diff --git a/narwhals/_arrow/expr.py b/narwhals/_arrow/expr.py index a79d68ac4..3f53fe661 100644 --- a/narwhals/_arrow/expr.py +++ b/narwhals/_arrow/expr.py @@ -12,10 +12,10 @@ from narwhals._arrow.expr_name import ArrowExprNameNamespace from narwhals._arrow.expr_str import ArrowExprStringNamespace from narwhals._arrow.series import ArrowSeries +from narwhals._expression_parsing import evaluate_output_names_and_aliases from narwhals._expression_parsing import reuse_series_implementation from narwhals.dependencies import get_numpy from narwhals.dependencies import is_numpy_array -from narwhals.exceptions import AnonymousExprError from narwhals.exceptions import ColumnNotFoundError from narwhals.typing import CompliantExpr from narwhals.utils import Implementation @@ -39,8 +39,8 @@ def __init__( *, depth: int, function_name: str, - root_names: list[str] | None, - output_names: list[str] | None, + evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]], + alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, backend_version: tuple[int, ...], version: Version, kwargs: dict[str, Any], @@ -48,10 +48,9 @@ def __init__( self._call = call self._depth = depth self._function_name = function_name - self._root_names = root_names self._depth = depth - self._output_names = output_names - self._implementation = Implementation.PYARROW + self._evaluate_output_names = evaluate_output_names + self._alias_output_names = alias_output_names self._backend_version = backend_version self._version = version self._kwargs = kwargs @@ -61,8 +60,6 @@ def __repr__(self: Self) -> str: # pragma: no cover f"ArrowExpr(" f"depth={self._depth}, " f"function_name={self._function_name}, " - f"root_names={self._root_names}, " - f"output_names={self._output_names}" ) def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]: @@ -98,8 +95,8 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: func, depth=0, function_name="col", - root_names=list(column_names), - output_names=list(column_names), + evaluate_output_names=lambda _df: list(column_names), + alias_output_names=None, backend_version=backend_version, version=version, kwargs={}, @@ -129,8 +126,8 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: func, depth=0, function_name="nth", - root_names=None, - output_names=None, + evaluate_output_names=lambda df: [df.columns[i] for i in column_indices], + alias_output_names=None, backend_version=backend_version, version=version, kwargs={}, @@ -265,14 +262,20 @@ def shift(self: Self, n: int) -> Self: return reuse_series_implementation(self, "shift", n=n) def alias(self: Self, name: str) -> Self: + def alias_output_names(names: Sequence[str]) -> Sequence[str]: + if len(names) != 1: + msg = f"Expected function with single output, found output names: {names}" + raise ValueError(msg) + return [name] + # Define this one manually, so that we can # override `output_names` and not increase depth return self.__class__( lambda df: [series.alias(name) for series in self._call(df)], depth=self._depth, function_name=self._function_name, - root_names=self._root_names, - output_names=[name], + evaluate_output_names=self._evaluate_output_names, + alias_output_names=alias_output_names, backend_version=self._backend_version, version=self._version, kwargs={**self._kwargs, "name": name}, @@ -390,22 +393,28 @@ def clip(self: Self, lower_bound: Any | None, upper_bound: Any | None) -> Self: def over(self: Self, keys: list[str]) -> Self: def func(df: ArrowDataFrame) -> list[ArrowSeries]: - if self._output_names is 
None: - msg = ".over" - raise AnonymousExprError.from_expr_name(msg) + output_names, aliases = evaluate_output_names_and_aliases(self, df, []) + if overlap := set(output_names).intersection(keys): + # E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined, + # we just don't support it yet. + msg = ( + f"Column names {overlap} appear in both expression output names and in `over` keys.\n" + "This is not yet supported." + ) + raise NotImplementedError(msg) tmp = df.group_by(*keys, drop_null_keys=False).agg(self) tmp = df.simple_select(*keys).join( tmp, how="left", left_on=keys, right_on=keys, suffix="_right" ) - return [tmp[name] for name in self._output_names] + return [tmp[alias] for alias in aliases] return self.__class__( func, depth=self._depth + 1, function_name=self._function_name + "->over", - root_names=self._root_names, - output_names=self._output_names, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, backend_version=self._backend_version, version=self._version, kwargs={**self._kwargs, "keys": keys}, @@ -446,8 +455,8 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: func, depth=self._depth + 1, function_name=self._function_name + "->map_batches", - root_names=self._root_names, - output_names=self._output_names, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, backend_version=self._backend_version, version=self._version, kwargs={**self._kwargs, "function": function, "return_dtype": return_dtype}, diff --git a/narwhals/_arrow/expr_name.py b/narwhals/_arrow/expr_name.py index 4c4991bfe..a8741bf61 100644 --- a/narwhals/_arrow/expr_name.py +++ b/narwhals/_arrow/expr_name.py @@ -3,8 +3,6 @@ from typing import TYPE_CHECKING from typing import Callable -from narwhals.exceptions import AnonymousExprError - if TYPE_CHECKING: from typing_extensions import Self @@ -16,131 +14,118 @@ def __init__(self: Self, expr: ArrowExpr) -> None: self._compliant_expr = expr def keep(self: Self) -> ArrowExpr: - root_names = self._compliant_expr._root_names - - if root_names is None: - msg = ".name.keep" - raise AnonymousExprError.from_expr_name(msg) - return self._compliant_expr.__class__( lambda df: [ series.alias(name) - for series, name in zip(self._compliant_expr._call(df), root_names) + for series, name in zip( + self._compliant_expr._call(df), + self._compliant_expr._evaluate_output_names(df), + ) ], depth=self._compliant_expr._depth, function_name=self._compliant_expr._function_name, - root_names=root_names, - output_names=root_names, + evaluate_output_names=self._compliant_expr._evaluate_output_names, + alias_output_names=None, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, kwargs=self._compliant_expr._kwargs, ) def map(self: Self, function: Callable[[str], str]) -> ArrowExpr: - root_names = self._compliant_expr._root_names - - if root_names is None: - msg = ".name.map" - raise AnonymousExprError.from_expr_name(msg) - - output_names = [function(str(name)) for name in root_names] - return self._compliant_expr.__class__( lambda df: [ - series.alias(name) - for series, name in zip(self._compliant_expr._call(df), output_names) + series.alias(function(str(name))) + for series, name in zip( + self._compliant_expr._call(df), + self._compliant_expr._evaluate_output_names(df), + ) ], depth=self._compliant_expr._depth, function_name=self._compliant_expr._function_name, - root_names=root_names, - output_names=output_names, + 
evaluate_output_names=self._compliant_expr._evaluate_output_names, + alias_output_names=lambda output_names: [ + function(str(name)) for name in output_names + ], backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, kwargs={**self._compliant_expr._kwargs, "function": function}, ) def prefix(self: Self, prefix: str) -> ArrowExpr: - root_names = self._compliant_expr._root_names - if root_names is None: - msg = ".name.prefix" - raise AnonymousExprError.from_expr_name(msg) - - output_names = [prefix + str(name) for name in root_names] return self._compliant_expr.__class__( lambda df: [ - series.alias(name) - for series, name in zip(self._compliant_expr._call(df), output_names) + series.alias(f"{prefix}{name}") + for series, name in zip( + self._compliant_expr._call(df), + self._compliant_expr._evaluate_output_names(df), + ) ], depth=self._compliant_expr._depth, function_name=self._compliant_expr._function_name, - root_names=root_names, - output_names=output_names, + evaluate_output_names=self._compliant_expr._evaluate_output_names, + alias_output_names=lambda output_names: [ + f"{prefix}{output_name}" for output_name in output_names + ], backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, kwargs={**self._compliant_expr._kwargs, "prefix": prefix}, ) def suffix(self: Self, suffix: str) -> ArrowExpr: - root_names = self._compliant_expr._root_names - if root_names is None: - msg = ".name.suffix" - raise AnonymousExprError.from_expr_name(msg) - - output_names = [str(name) + suffix for name in root_names] - return self._compliant_expr.__class__( lambda df: [ - series.alias(name) - for series, name in zip(self._compliant_expr._call(df), output_names) + series.alias(f"{name}{suffix}") + for series, name in zip( + self._compliant_expr._call(df), + self._compliant_expr._evaluate_output_names(df), + ) ], depth=self._compliant_expr._depth, function_name=self._compliant_expr._function_name, - root_names=root_names, - output_names=output_names, + evaluate_output_names=self._compliant_expr._evaluate_output_names, + alias_output_names=lambda output_names: [ + f"{output_name}{suffix}" for output_name in output_names + ], backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, kwargs={**self._compliant_expr._kwargs, "suffix": suffix}, ) def to_lowercase(self: Self) -> ArrowExpr: - root_names = self._compliant_expr._root_names - - if root_names is None: - msg = ".name.to_lowercase" - raise AnonymousExprError.from_expr_name(msg) - output_names = [str(name).lower() for name in root_names] - return self._compliant_expr.__class__( lambda df: [ - series.alias(name) - for series, name in zip(self._compliant_expr._call(df), output_names) + series.alias(str(name).lower()) + for series, name in zip( + self._compliant_expr._call(df), + self._compliant_expr._evaluate_output_names(df), + ) ], depth=self._compliant_expr._depth, function_name=self._compliant_expr._function_name, - root_names=root_names, - output_names=output_names, + evaluate_output_names=self._compliant_expr._evaluate_output_names, + alias_output_names=lambda output_names: [ + str(name).lower() for name in output_names + ], backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, kwargs=self._compliant_expr._kwargs, ) def to_uppercase(self: Self) -> ArrowExpr: - root_names = self._compliant_expr._root_names - - if root_names is None: - msg = ".name.to_uppercase" - raise 
AnonymousExprError.from_expr_name(msg) - output_names = [str(name).upper() for name in root_names] - return self._compliant_expr.__class__( lambda df: [ - series.alias(name) - for series, name in zip(self._compliant_expr._call(df), output_names) + series.alias(str(name).upper()) + for series, name in zip( + self._compliant_expr._call(df), + self._compliant_expr._evaluate_output_names(df), + ) ], depth=self._compliant_expr._depth, function_name=self._compliant_expr._function_name, - root_names=root_names, - output_names=output_names, + evaluate_output_names=self._compliant_expr._evaluate_output_names, + alias_output_names=lambda output_names: [ + str(name).upper() for name in output_names + ], backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, kwargs=self._compliant_expr._kwargs, diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index 86ce73cb0..6e1f61b93 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -1,27 +1,23 @@ from __future__ import annotations import collections +import re from typing import TYPE_CHECKING from typing import Any -from typing import Callable from typing import Iterator -from typing import Sequence import pyarrow as pa import pyarrow.compute as pc +from narwhals._expression_parsing import evaluate_output_names_and_aliases from narwhals._expression_parsing import is_simple_aggregation -from narwhals.exceptions import AnonymousExprError from narwhals.utils import generate_temporary_column_name -from narwhals.utils import remove_prefix if TYPE_CHECKING: from typing_extensions import Self from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.expr import ArrowExpr - from narwhals._arrow.series import ArrowSeries - from narwhals.typing import CompliantExpr POLARS_TO_ARROW_AGGREGATIONS = { "sum": "sum", @@ -49,18 +45,97 @@ def __init__( self._grouped = pa.TableGroupBy(self._df._native_frame, list(self._keys)) def agg(self: Self, *exprs: ArrowExpr) -> ArrowDataFrame: + all_simple_aggs = True for expr in exprs: - if expr._output_names is None: - msg = "group_by.agg" - raise AnonymousExprError.from_expr_name(msg) - - return agg_arrow( - self._grouped, - exprs, - self._keys, - self._df._from_native_frame, - backend_version=self._df._backend_version, - ) + if not ( + is_simple_aggregation(expr) + and re.sub(r"(\w+->)", "", expr._function_name) + in POLARS_TO_ARROW_AGGREGATIONS + ): + all_simple_aggs = False + break + + if not all_simple_aggs: + msg = ( + "Non-trivial complex aggregation found.\n\n" + "Hint: you were probably trying to apply a non-elementary aggregation with a " + "pyarrow table.\n" + "Please rewrite your query such that group-by aggregations " + "are elementary. For example, instead of:\n\n" + " df.group_by('a').agg(nw.col('b').round(2).mean())\n\n" + "use:\n\n" + " df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n" + ) + raise ValueError(msg) + + aggs: list[tuple[str, str, pc.FunctionOptions | None]] = [] + expected_pyarrow_column_names: list[str] = self._keys.copy() + new_column_names: list[str] = self._keys.copy() + + for expr in exprs: + output_names, aliases = evaluate_output_names_and_aliases( + expr, self._df, self._keys + ) + + if expr._depth == 0: + # e.g. 
agg(nw.len()) # noqa: ERA001 + if expr._function_name != "len": # pragma: no cover + msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" + raise AssertionError(msg) + + new_column_names.append(aliases[0]) + expected_pyarrow_column_names.append(f"{self._keys[0]}_count") + aggs.append((self._keys[0], "count", pc.CountOptions(mode="all"))) + + continue + + function_name = re.sub(r"(\w+->)", "", expr._function_name) + if function_name in {"std", "var"}: + option = pc.VarianceOptions(ddof=expr._kwargs["ddof"]) + elif function_name in {"len", "n_unique"}: + option = pc.CountOptions(mode="all") + elif function_name == "count": + option = pc.CountOptions(mode="only_valid") + else: + option = None + + function_name = POLARS_TO_ARROW_AGGREGATIONS[function_name] + + new_column_names.extend(aliases) + expected_pyarrow_column_names.extend( + [f"{output_name}_{function_name}" for output_name in output_names] + ) + aggs.extend( + [(output_name, function_name, option) for output_name in output_names] + ) + + result_simple = self._grouped.aggregate(aggs) + + # Rename columns, being very careful + expected_old_names_indices: dict[str, list[int]] = collections.defaultdict(list) + for idx, item in enumerate(expected_pyarrow_column_names): + expected_old_names_indices[item].append(idx) + if not ( + set(result_simple.column_names) == set(expected_pyarrow_column_names) + and len(result_simple.column_names) == len(expected_pyarrow_column_names) + ): # pragma: no cover + msg = ( + f"Safety assertion failed, expected {expected_pyarrow_column_names} " + f"got {result_simple.column_names}, " + "please report a bug at https://github.com/narwhals-dev/narwhals/issues" + ) + raise AssertionError(msg) + index_map: list[int] = [ + expected_old_names_indices[item].pop(0) for item in result_simple.column_names + ] + new_column_names = [new_column_names[i] for i in index_map] + result_simple = result_simple.rename_columns(new_column_names) + if self._df._backend_version < (12, 0, 0): + columns = result_simple.column_names + result_simple = result_simple.select( + [*self._keys, *[col for col in columns if col not in self._keys]] + ) + return self._df._from_native_frame(result_simple) def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]: col_token = generate_temporary_column_name(n_bytes=8, columns=self._df.columns) @@ -91,109 +166,3 @@ def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]: ) for v in pc.unique(key_values) ) - - -def agg_arrow( - grouped: pa.TableGroupBy, - exprs: Sequence[CompliantExpr[ArrowSeries]], - keys: list[str], - from_dataframe: Callable[[Any], ArrowDataFrame], - backend_version: tuple[int, ...], -) -> ArrowDataFrame: - all_simple_aggs = True - for expr in exprs: - if not ( - is_simple_aggregation(expr) - and remove_prefix(expr._function_name, "col->") - in POLARS_TO_ARROW_AGGREGATIONS - ): - all_simple_aggs = False - break - - if not all_simple_aggs: - msg = ( - "Non-trivial complex aggregation found.\n\n" - "Hint: you were probably trying to apply a non-elementary aggregation with a " - "pyarrow table.\n" - "Please rewrite your query such that group-by aggregations " - "are elementary. 
For example, instead of:\n\n" - " df.group_by('a').agg(nw.col('b').round(2).mean())\n\n" - "use:\n\n" - " df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n" - ) - raise ValueError(msg) - - aggs: list[tuple[str, str, pc.FunctionOptions | None]] = [] - expected_pyarrow_column_names: list[str] = keys.copy() - new_column_names: list[str] = keys.copy() - - for expr in exprs: - if expr._depth == 0: - # e.g. agg(nw.len()) # noqa: ERA001 - if ( - expr._output_names is None or expr._function_name != "len" - ): # pragma: no cover - msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" - raise AssertionError(msg) - - new_column_names.append(expr._output_names[0]) - expected_pyarrow_column_names.append(f"{keys[0]}_count") - aggs.append((keys[0], "count", pc.CountOptions(mode="all"))) - - continue - - # e.g. agg(nw.mean('a')) # noqa: ERA001 - if ( - expr._depth != 1 or expr._root_names is None or expr._output_names is None - ): # pragma: no cover - msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" - raise AssertionError(msg) - - function_name = remove_prefix(expr._function_name, "col->") - - if function_name in {"std", "var"}: - option = pc.VarianceOptions(ddof=expr._kwargs["ddof"]) - elif function_name in {"len", "n_unique"}: - option = pc.CountOptions(mode="all") - elif function_name == "count": - option = pc.CountOptions(mode="only_valid") - else: - option = None - - function_name = POLARS_TO_ARROW_AGGREGATIONS[function_name] - - new_column_names.extend(expr._output_names) - expected_pyarrow_column_names.extend( - [f"{root_name}_{function_name}" for root_name in expr._root_names] - ) - aggs.extend( - [(root_name, function_name, option) for root_name in expr._root_names] - ) - - result_simple = grouped.aggregate(aggs) - - # Rename columns, being very careful - expected_old_names_indices: dict[str, list[int]] = collections.defaultdict(list) - for idx, item in enumerate(expected_pyarrow_column_names): - expected_old_names_indices[item].append(idx) - if not ( - set(result_simple.column_names) == set(expected_pyarrow_column_names) - and len(result_simple.column_names) == len(expected_pyarrow_column_names) - ): # pragma: no cover - msg = ( - f"Safety assertion failed, expected {expected_pyarrow_column_names} " - f"got {result_simple.column_names}, " - "please report a bug at https://github.com/narwhals-dev/narwhals/issues" - ) - raise AssertionError(msg) - index_map: list[int] = [ - expected_old_names_indices[item].pop(0) for item in result_simple.column_names - ] - new_column_names = [new_column_names[i] for i in index_map] - result_simple = result_simple.rename_columns(new_column_names) - if backend_version < (12, 0, 0): - columns = result_simple.column_names - result_simple = result_simple.select( - [*keys, *[col for col in columns if col not in keys]] - ) - return from_dataframe(result_simple) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index f02ee44c8..8274e5ffc 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -3,6 +3,7 @@ from functools import reduce from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Iterable from typing import Literal from typing import Sequence @@ -18,9 +19,9 @@ from narwhals._arrow.utils import diagonal_concat from narwhals._arrow.utils import horizontal_concat from narwhals._arrow.utils import vertical_concat -from 
narwhals._expression_parsing import combine_root_names +from narwhals._expression_parsing import combine_alias_output_names +from narwhals._expression_parsing import combine_evaluate_output_names from narwhals._expression_parsing import parse_into_exprs -from narwhals._expression_parsing import reduce_output_names from narwhals.typing import CompliantNamespace from narwhals.utils import Implementation from narwhals.utils import import_dtypes_module @@ -42,8 +43,8 @@ def _create_expr_from_callable( *, depth: int, function_name: str, - root_names: list[str] | None, - output_names: list[str] | None, + evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]], + alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, kwargs: dict[str, Any], ) -> ArrowExpr: from narwhals._arrow.expr import ArrowExpr @@ -52,8 +53,8 @@ def _create_expr_from_callable( func, depth=depth, function_name=function_name, - root_names=root_names, - output_names=output_names, + evaluate_output_names=evaluate_output_names, + alias_output_names=alias_output_names, backend_version=self._backend_version, version=self._version, kwargs=kwargs, @@ -66,8 +67,8 @@ def _create_expr_from_series(self: Self, series: ArrowSeries) -> ArrowExpr: lambda _df: [series], depth=0, function_name="series", - root_names=None, - output_names=None, + evaluate_output_names=lambda _df: [series.name], + alias_output_names=None, backend_version=self._backend_version, version=self._version, kwargs={}, @@ -133,8 +134,8 @@ def len(self: Self) -> ArrowExpr: ], depth=0, function_name="len", - root_names=None, - output_names=["len"], + evaluate_output_names=lambda _df: ["len"], + alias_output_names=None, backend_version=self._backend_version, version=self._version, kwargs={}, @@ -156,8 +157,8 @@ def all(self: Self) -> ArrowExpr: ], depth=0, function_name="all", - root_names=None, - output_names=None, + evaluate_output_names=lambda df: df.columns, + alias_output_names=None, backend_version=self._backend_version, version=self._version, kwargs={}, @@ -179,8 +180,8 @@ def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries: lambda df: [_lit_arrow_series(df)], depth=0, function_name="lit", - root_names=None, - output_names=["literal"], + evaluate_output_names=lambda _df: ["literal"], + alias_output_names=None, backend_version=self._backend_version, version=self._version, kwargs={}, @@ -197,8 +198,8 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: func=func, depth=max(x._depth for x in parsed_exprs) + 1, function_name="all_horizontal", - root_names=combine_root_names(parsed_exprs), - output_names=reduce_output_names(parsed_exprs), + evaluate_output_names=combine_evaluate_output_names(*parsed_exprs), + alias_output_names=combine_alias_output_names(*parsed_exprs), kwargs={"exprs": exprs}, ) @@ -213,8 +214,8 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: func=func, depth=max(x._depth for x in parsed_exprs) + 1, function_name="any_horizontal", - root_names=combine_root_names(parsed_exprs), - output_names=reduce_output_names(parsed_exprs), + evaluate_output_names=combine_evaluate_output_names(*parsed_exprs), + alias_output_names=combine_alias_output_names(*parsed_exprs), kwargs={"exprs": exprs}, ) @@ -233,8 +234,8 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: func=func, depth=max(x._depth for x in parsed_exprs) + 1, function_name="sum_horizontal", - root_names=combine_root_names(parsed_exprs), - output_names=reduce_output_names(parsed_exprs), + evaluate_output_names=combine_evaluate_output_names(*parsed_exprs), + 
alias_output_names=combine_alias_output_names(*parsed_exprs), kwargs={"exprs": exprs}, ) @@ -261,8 +262,8 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: func=func, depth=max(x._depth for x in parsed_exprs) + 1, function_name="mean_horizontal", - root_names=combine_root_names(parsed_exprs), - output_names=reduce_output_names(parsed_exprs), + evaluate_output_names=combine_evaluate_output_names(*parsed_exprs), + alias_output_names=combine_alias_output_names(*parsed_exprs), kwargs={"exprs": exprs}, ) @@ -288,8 +289,8 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: func=func, depth=max(x._depth for x in parsed_exprs) + 1, function_name="min_horizontal", - root_names=combine_root_names(parsed_exprs), - output_names=reduce_output_names(parsed_exprs), + evaluate_output_names=combine_evaluate_output_names(*parsed_exprs), + alias_output_names=combine_alias_output_names(*parsed_exprs), kwargs={"exprs": exprs}, ) @@ -315,8 +316,8 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: func=func, depth=max(x._depth for x in parsed_exprs) + 1, function_name="max_horizontal", - root_names=combine_root_names(parsed_exprs), - output_names=reduce_output_names(parsed_exprs), + evaluate_output_names=combine_evaluate_output_names(*parsed_exprs), + alias_output_names=combine_alias_output_names(*parsed_exprs), kwargs={"exprs": exprs}, ) @@ -391,8 +392,8 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: func=func, depth=max(x._depth for x in parsed_exprs) + 1, function_name="concat_str", - root_names=combine_root_names(parsed_exprs), - output_names=reduce_output_names(parsed_exprs), + evaluate_output_names=combine_evaluate_output_names(*parsed_exprs), + alias_output_names=combine_alias_output_names(*parsed_exprs), kwargs={ "exprs": exprs, "separator": separator, @@ -472,8 +473,10 @@ def then(self: Self, value: ArrowExpr | ArrowSeries | Any) -> ArrowThen: self, depth=0, function_name="whenthen", - root_names=None, - output_names=None, + evaluate_output_names=getattr( + value, "_evaluate_output_names", lambda _df: ["literal"] + ), + alias_output_names=getattr(value, "_alias_output_names", None), backend_version=self._backend_version, version=self._version, kwargs={"value": value}, @@ -487,8 +490,8 @@ def __init__( *, depth: int, function_name: str, - root_names: list[str] | None, - output_names: list[str] | None, + evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]], + alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, backend_version: tuple[int, ...], version: Version, kwargs: dict[str, Any], ) -> None: self._call = call self._depth = depth self._function_name = function_name - self._root_names = root_names - self._output_names = output_names + self._evaluate_output_names = evaluate_output_names + self._alias_output_names = alias_output_names self._kwargs = kwargs def otherwise(self: Self, value: ArrowExpr | ArrowSeries | Any) -> ArrowExpr: diff --git a/narwhals/_arrow/selectors.py b/narwhals/_arrow/selectors.py index 36feb5d56..3fee25472 100644 --- a/narwhals/_arrow/selectors.py +++ b/narwhals/_arrow/selectors.py @@ -29,12 +29,15 @@ def by_dtype(self: Self, dtypes: list[DType | type[DType]]) -> ArrowSelector: def func(df: ArrowDataFrame) -> list[ArrowSeries]: return [df[col] for col in df.columns if df.schema[col] in dtypes] + def evaluate_output_names(df: ArrowDataFrame) -> Sequence[str]: + return [col for col in df.columns if df.schema[col] in dtypes] + return ArrowSelector( func, depth=0, - function_name="type_selector", - root_names=None, -
output_names=None, + function_name="selector", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, backend_version=self._backend_version, version=self._version, kwargs={"dtypes": dtypes}, @@ -76,9 +79,9 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]: return ArrowSelector( func, depth=0, - function_name="type_selector", - root_names=None, - output_names=None, + function_name="selector", + evaluate_output_names=lambda df: df.columns, + alias_output_names=None, backend_version=self._backend_version, version=self._version, kwargs={}, @@ -90,9 +93,7 @@ def __repr__(self: Self) -> str: # pragma: no cover return ( f"ArrowSelector(" f"depth={self._depth}, " - f"function_name={self._function_name}, " - f"root_names={self._root_names}, " - f"output_names={self._output_names}" + f"function_name={self._function_name})" ) def _to_expr(self: Self) -> ArrowExpr: @@ -100,8 +101,8 @@ self._call, depth=self._depth, function_name=self._function_name, - root_names=self._root_names, - output_names=self._output_names, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, backend_version=self._backend_version, version=self._version, kwargs=self._kwargs, ) @@ -111,16 +112,22 @@ def __sub__(self: Self, other: Self | Any) -> ArrowSelector | Any: if isinstance(other, ArrowSelector): def call(df: ArrowDataFrame) -> list[ArrowSeries]: + lhs_names = self._evaluate_output_names(df) + rhs_names = other._evaluate_output_names(df) lhs = self._call(df) - rhs = other._call(df) - return [x for x in lhs if x.name not in {x.name for x in rhs}] + return [x for x, name in zip(lhs, lhs_names) if name not in rhs_names] + + def evaluate_output_names(df: ArrowDataFrame) -> list[str]: + lhs_names = self._evaluate_output_names(df) + rhs_names = other._evaluate_output_names(df) + return [x for x in lhs_names if x not in rhs_names] return ArrowSelector( call, depth=0, - function_name="type_selector", - root_names=None, - output_names=None, + function_name="selector", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, backend_version=self._backend_version, version=self._version, kwargs={**self._kwargs, "other": other}, @@ -131,17 +138,27 @@ def __or__(self: Self, other: Self | Any) -> ArrowSelector | Any: if isinstance(other, ArrowSelector): - def call(df: ArrowDataFrame) -> Sequence[ArrowSeries]: - lhs = self(df) - rhs = other(df) - return [*(x for x in lhs if x.name not in {x.name for x in rhs}), *rhs] + def call(df: ArrowDataFrame) -> list[ArrowSeries]: + lhs_names = self._evaluate_output_names(df) + rhs_names = other._evaluate_output_names(df) + lhs = self._call(df) + rhs = other._call(df) + return [ + *(x for x, name in zip(lhs, lhs_names) if name not in rhs_names), + *rhs, + ] + + def evaluate_output_names(df: ArrowDataFrame) -> list[str]: + lhs_names = self._evaluate_output_names(df) + rhs_names = other._evaluate_output_names(df) + return [*(x for x in lhs_names if x not in rhs_names), *rhs_names] return ArrowSelector( call, depth=0, - function_name="type_selector", - root_names=None, - output_names=None, + function_name="selector", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, backend_version=self._backend_version, version=self._version, kwargs={**self._kwargs, "other": other}, @@ -153,16 +170,22 @@ def __and__(self: Self, other: Self | Any) -> ArrowSelector | Any: if isinstance(other, ArrowSelector): def call(df:
ArrowDataFrame) -> list[ArrowSeries]: + lhs_names = self._evaluate_output_names(df) + rhs_names = other._evaluate_output_names(df) lhs = self._call(df) - rhs = other._call(df) - return [x for x in lhs if x.name in {x.name for x in rhs}] + return [x for x, name in zip(lhs, lhs_names) if name in rhs_names] + + def evaluate_output_names(df: ArrowDataFrame) -> list[str]: + lhs_names = self._evaluate_output_names(df) + rhs_names = other._evaluate_output_names(df) + return [x for x in lhs_names if x in rhs_names] return ArrowSelector( call, depth=0, - function_name="type_selector", - root_names=None, - output_names=None, + function_name="selector", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, backend_version=self._backend_version, version=self._version, kwargs={**self._kwargs, "other": other}, diff --git a/narwhals/_dask/expr.py b/narwhals/_dask/expr.py index 777916282..25b1cf893 100644 --- a/narwhals/_dask/expr.py +++ b/narwhals/_dask/expr.py @@ -13,9 +13,8 @@ from narwhals._dask.utils import binary_operation_returns_scalar from narwhals._dask.utils import maybe_evaluate from narwhals._dask.utils import narwhals_to_native_dtype -from narwhals._expression_parsing import infer_new_root_output_names +from narwhals._expression_parsing import evaluate_output_names_and_aliases from narwhals._pandas_like.utils import native_to_narwhals_dtype -from narwhals.exceptions import AnonymousExprError from narwhals.exceptions import ColumnNotFoundError from narwhals.exceptions import InvalidOperationError from narwhals.typing import CompliantExpr @@ -45,8 +44,8 @@ def __init__( *, depth: int, function_name: str, - root_names: list[str] | None, - output_names: list[str] | None, + evaluate_output_names: Callable[[DaskLazyFrame], Sequence[str]], + alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, # Whether the expression is a length-1 Series resulting from # a reduction, such as `nw.col('a').sum()` returns_scalar: bool, @@ -57,8 +56,8 @@ def __init__( self._call = call self._depth = depth self._function_name = function_name - self._root_names = root_names - self._output_names = output_names + self._evaluate_output_names = evaluate_output_names + self._alias_output_names = alias_output_names self._returns_scalar = returns_scalar self._backend_version = backend_version self._version = version @@ -96,8 +95,8 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: func, depth=0, function_name="col", - root_names=list(column_names), - output_names=list(column_names), + evaluate_output_names=lambda _df: list(column_names), + alias_output_names=None, returns_scalar=False, backend_version=backend_version, version=version, @@ -120,8 +119,8 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: func, depth=0, function_name="nth", - root_names=None, - output_names=None, + evaluate_output_names=lambda df: [df.columns[i] for i in column_indices], + alias_output_names=None, returns_scalar=False, backend_version=backend_version, version=version, @@ -135,48 +134,51 @@ def _from_call( expr_name: str, *, returns_scalar: bool, - **kwargs: Any, + **expressifiable_args: Self | Any, ) -> Self: def func(df: DaskLazyFrame) -> list[dx.Series]: - results = [] - inputs = self._call(df) - _kwargs = {key: maybe_evaluate(df, value) for key, value in kwargs.items()} - for _input in inputs: - name = _input.name + native_results: list[dx.Series] = [] + native_series_list = self._call(df) + other_native_series = { + key: maybe_evaluate(df, value) + for key, value in expressifiable_args.items() + } + for 
native_series in native_series_list: if self._returns_scalar: - _input = _input[0] - result = call(_input, **_kwargs) + result_native = call(native_series[0], **other_native_series) + else: + result_native = call(native_series, **other_native_series) if returns_scalar: - result = result.to_series() - result = result.rename(name) - results.append(result) - return results - - root_names, output_names = infer_new_root_output_names(self, **kwargs) + native_results.append(result_native.to_series()) + else: + native_results.append(result_native) + return native_results return self.__class__( func, depth=self._depth + 1, function_name=f"{self._function_name}->{expr_name}", - root_names=root_names, - output_names=output_names, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, returns_scalar=returns_scalar, backend_version=self._backend_version, version=self._version, - kwargs={**self._kwargs, **kwargs}, + kwargs={**self._kwargs, **expressifiable_args}, ) def alias(self: Self, name: str) -> Self: - def func(df: DaskLazyFrame) -> list[dx.Series]: - inputs = self._call(df) - return [_input.rename(name) for _input in inputs] + def alias_output_names(names: Sequence[str]) -> Sequence[str]: + if len(names) != 1: + msg = f"Expected function with single output, found output names: {names}" + raise ValueError(msg) + return [name] return self.__class__( - func, + self._call, depth=self._depth, function_name=self._function_name, - root_names=self._root_names, - output_names=[name], + evaluate_output_names=self._evaluate_output_names, + alias_output_names=alias_output_names, returns_scalar=self._returns_scalar, backend_version=self._backend_version, version=self._version, @@ -659,10 +661,15 @@ def null_count(self: Self) -> Self: def over(self: Self, keys: list[str]) -> Self: def func(df: DaskLazyFrame) -> list[Any]: - if self._output_names is None: - msg = "over" - raise AnonymousExprError.from_expr_name(msg) - + output_names, aliases = evaluate_output_names_and_aliases(self, df, []) + if overlap := set(output_names).intersection(keys): + # E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined, + # we just don't support it yet. + msg = ( + f"Column names {overlap} appear in both expression output names and in `over` keys.\n" + "This is not yet supported." + ) + raise NotImplementedError(msg) if df._native_frame.npartitions == 1: # pragma: no cover tmp = df.group_by(*keys, drop_null_keys=False).agg(self) tmp_native = ( @@ -670,7 +677,8 @@ def func(df: DaskLazyFrame) -> list[Any]: .join(tmp, how="left", left_on=keys, right_on=keys, suffix="_right") ._native_frame ) - return [tmp_native[name] for name in self._output_names] + return [tmp_native[name] for name in aliases] + # https://github.com/dask/dask/issues/6659 msg = ( "`Expr.over` is not supported for Dask backend with multiple partitions." 
) @@ -680,8 +688,8 @@ def func(df: DaskLazyFrame) -> list[Any]: func, depth=self._depth + 1, function_name=self._function_name + "->over", - root_names=self._root_names, - output_names=self._output_names, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, returns_scalar=False, backend_version=self._backend_version, version=self._version, diff --git a/narwhals/_dask/expr_name.py b/narwhals/_dask/expr_name.py index bbed6addc..a261dcdf1 100644 --- a/narwhals/_dask/expr_name.py +++ b/narwhals/_dask/expr_name.py @@ -3,8 +3,6 @@ from typing import TYPE_CHECKING from typing import Callable -from narwhals.exceptions import AnonymousExprError - if TYPE_CHECKING: from typing_extensions import Self @@ -16,21 +14,12 @@ def __init__(self: Self, expr: DaskExpr) -> None: self._compliant_expr = expr def keep(self: Self) -> DaskExpr: - root_names = self._compliant_expr._root_names - - if root_names is None: - msg = ".name.keep" - raise AnonymousExprError.from_expr_name(msg) - return self._compliant_expr.__class__( - lambda df: [ - series.rename(name) - for series, name in zip(self._compliant_expr._call(df), root_names) - ], + self._compliant_expr._call, depth=self._compliant_expr._depth, function_name=self._compliant_expr._function_name, - root_names=root_names, - output_names=root_names, + evaluate_output_names=self._compliant_expr._evaluate_output_names, + alias_output_names=None, returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, @@ -38,23 +27,14 @@ def keep(self: Self) -> DaskExpr: ) def map(self: Self, function: Callable[[str], str]) -> DaskExpr: - root_names = self._compliant_expr._root_names - - if root_names is None: - msg = ".name.map" - raise AnonymousExprError.from_expr_name(msg) - - output_names = [function(str(name)) for name in root_names] - return self._compliant_expr.__class__( - lambda df: [ - series.rename(name) - for series, name in zip(self._compliant_expr._call(df), output_names) - ], + self._compliant_expr._call, depth=self._compliant_expr._depth, function_name=self._compliant_expr._function_name, - root_names=root_names, - output_names=output_names, + evaluate_output_names=self._compliant_expr._evaluate_output_names, + alias_output_names=lambda output_names: [ + function(str(name)) for name in output_names + ], returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, @@ -62,21 +42,14 @@ def map(self: Self, function: Callable[[str], str]) -> DaskExpr: ) def prefix(self: Self, prefix: str) -> DaskExpr: - root_names = self._compliant_expr._root_names - if root_names is None: - msg = ".name.prefix" - raise AnonymousExprError.from_expr_name(msg) - - output_names = [prefix + str(name) for name in root_names] return self._compliant_expr.__class__( - lambda df: [ - series.rename(name) - for series, name in zip(self._compliant_expr._call(df), output_names) - ], + self._compliant_expr._call, depth=self._compliant_expr._depth, function_name=self._compliant_expr._function_name, - root_names=root_names, - output_names=output_names, + evaluate_output_names=self._compliant_expr._evaluate_output_names, + alias_output_names=lambda output_names: [ + f"{prefix}{output_name}" for output_name in output_names + ], returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, @@ -84,22 +57,14 
@@ def prefix(self: Self, prefix: str) -> DaskExpr: ) def suffix(self: Self, suffix: str) -> DaskExpr: - root_names = self._compliant_expr._root_names - if root_names is None: - msg = ".name.suffix" - raise AnonymousExprError.from_expr_name(msg) - - output_names = [str(name) + suffix for name in root_names] - return self._compliant_expr.__class__( - lambda df: [ - series.rename(name) - for series, name in zip(self._compliant_expr._call(df), output_names) - ], + self._compliant_expr._call, depth=self._compliant_expr._depth, function_name=self._compliant_expr._function_name, - root_names=root_names, - output_names=output_names, + evaluate_output_names=self._compliant_expr._evaluate_output_names, + alias_output_names=lambda output_names: [ + f"{output_name}{suffix}" for output_name in output_names + ], returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, @@ -107,22 +72,14 @@ def suffix(self: Self, suffix: str) -> DaskExpr: ) def to_lowercase(self: Self) -> DaskExpr: - root_names = self._compliant_expr._root_names - - if root_names is None: - msg = ".name.to_lowercase" - raise AnonymousExprError.from_expr_name(msg) - output_names = [str(name).lower() for name in root_names] - return self._compliant_expr.__class__( - lambda df: [ - series.rename(name) - for series, name in zip(self._compliant_expr._call(df), output_names) - ], + self._compliant_expr._call, depth=self._compliant_expr._depth, function_name=self._compliant_expr._function_name, - root_names=root_names, - output_names=output_names, + evaluate_output_names=self._compliant_expr._evaluate_output_names, + alias_output_names=lambda output_names: [ + str(name).lower() for name in output_names + ], returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, @@ -130,22 +87,14 @@ def to_lowercase(self: Self) -> DaskExpr: ) def to_uppercase(self: Self) -> DaskExpr: - root_names = self._compliant_expr._root_names - - if root_names is None: - msg = ".name.to_uppercase" - raise AnonymousExprError.from_expr_name(msg) - output_names = [str(name).upper() for name in root_names] - return self._compliant_expr.__class__( - lambda df: [ - series.rename(name) - for series, name in zip(self._compliant_expr._call(df), output_names) - ], + self._compliant_expr._call, depth=self._compliant_expr._depth, function_name=self._compliant_expr._function_name, - root_names=root_names, - output_names=output_names, + evaluate_output_names=self._compliant_expr._evaluate_output_names, + alias_output_names=lambda output_names: [ + str(name).upper() for name in output_names + ], returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, diff --git a/narwhals/_dask/group_by.py b/narwhals/_dask/group_by.py index 7fd434af1..e4ef92dbe 100644 --- a/narwhals/_dask/group_by.py +++ b/narwhals/_dask/group_by.py @@ -1,6 +1,6 @@ from __future__ import annotations -from copy import copy +import re from typing import TYPE_CHECKING from typing import Any from typing import Callable @@ -8,9 +8,8 @@ import dask.dataframe as dd +from narwhals._expression_parsing import evaluate_output_names_and_aliases from narwhals._expression_parsing import is_simple_aggregation -from narwhals.exceptions import AnonymousExprError -from narwhals.utils import remove_prefix try: import dask.dataframe.dask_expr as dx @@ -90,14 +89,6 @@ def 
agg( self: Self, *exprs: DaskExpr, ) -> DaskLazyFrame: - output_names: list[str] = copy(self._keys) - for expr in exprs: - if expr._output_names is None: - msg = "group_by.agg" - raise AnonymousExprError.from_expr_name(msg) - - output_names.extend(expr._output_names) - return agg_dask( self._df, self._grouped, @@ -134,7 +125,7 @@ def agg_dask( for expr in exprs: if not ( is_simple_aggregation(expr) - and remove_prefix(expr._function_name, "col->") in POLARS_TO_DASK_AGGREGATIONS + and re.sub(r"(\w+->)", "", expr._function_name) in POLARS_TO_DASK_AGGREGATIONS ): all_simple_aggs = False break @@ -142,31 +133,19 @@ def agg_dask( if all_simple_aggs: simple_aggregations: dict[str, tuple[str, str | dd.Aggregation]] = {} for expr in exprs: + output_names, aliases = evaluate_output_names_and_aliases(expr, df, keys) if expr._depth == 0: # e.g. agg(nw.len()) # noqa: ERA001 - if expr._output_names is None: # pragma: no cover - msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" - raise AssertionError(msg) - function_name = POLARS_TO_DASK_AGGREGATIONS.get( expr._function_name, expr._function_name ) simple_aggregations.update( - { - output_name: (keys[0], function_name) - for output_name in expr._output_names - } + {alias: (keys[0], function_name) for alias in aliases} ) continue # e.g. agg(nw.mean('a')) # noqa: ERA001 - if ( - expr._depth != 1 or expr._root_names is None or expr._output_names is None - ): # pragma: no cover - msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" - raise AssertionError(msg) - - function_name = remove_prefix(expr._function_name, "col->") + function_name = re.sub(r"(\w+->)", "", expr._function_name) kwargs = ( {"ddof": expr._kwargs["ddof"]} if function_name in {"std", "var"} else {} ) @@ -179,10 +158,8 @@ def agg_dask( simple_aggregations.update( { - output_name: (root_name, agg_function) - for root_name, output_name in zip( - expr._root_names, expr._output_names - ) + alias: (output_name, agg_function) + for alias, output_name in zip(aliases, output_names) } ) result_simple = grouped.agg(**simple_aggregations) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 9c8a48e85..8c06365a7 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -3,6 +3,7 @@ from functools import reduce from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Iterable from typing import Literal from typing import Sequence @@ -18,8 +19,8 @@ from narwhals._dask.utils import name_preserving_sum from narwhals._dask.utils import narwhals_to_native_dtype from narwhals._dask.utils import validate_comparand -from narwhals._expression_parsing import combine_root_names -from narwhals._expression_parsing import reduce_output_names +from narwhals._expression_parsing import combine_alias_output_names +from narwhals._expression_parsing import combine_evaluate_output_names from narwhals.typing import CompliantNamespace if TYPE_CHECKING: @@ -55,8 +56,8 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: func, depth=0, function_name="all", - root_names=None, - output_names=None, + evaluate_output_names=lambda df: df.columns, + alias_output_names=None, returns_scalar=False, backend_version=self._backend_version, version=self._version, @@ -92,8 +93,8 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: func, depth=0, function_name="lit", - root_names=None, - output_names=["literal"], + evaluate_output_names=lambda _df: 
["literal"], + alias_output_names=None, returns_scalar=True, backend_version=self._backend_version, version=self._version, @@ -109,15 +110,15 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: npartitions=df._native_frame.npartitions, ) ] - return [df._native_frame[df.columns[0]].size.to_series().rename("len")] + return [df._native_frame[df.columns[0]].size.to_series()] # coverage bug? this is definitely hit return DaskExpr( # pragma: no cover func, depth=0, function_name="len", - root_names=None, - output_names=["len"], + evaluate_output_names=lambda _df: ["len"], + alias_output_names=None, returns_scalar=True, backend_version=self._backend_version, version=self._version, @@ -127,14 +128,14 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: def all_horizontal(self: Self, *exprs: DaskExpr) -> DaskExpr: def func(df: DaskLazyFrame) -> list[dx.Series]: series = [s for _expr in exprs for s in _expr(df)] - return [reduce(lambda x, y: x & y, series).rename(series[0].name)] + return [reduce(lambda x, y: x & y, series)] return DaskExpr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="all_horizontal", - root_names=combine_root_names(exprs), - output_names=reduce_output_names(exprs), + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), returns_scalar=False, backend_version=self._backend_version, version=self._version, @@ -144,14 +145,14 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: def any_horizontal(self: Self, *exprs: DaskExpr) -> DaskExpr: def func(df: DaskLazyFrame) -> list[dx.Series]: series = [s for _expr in exprs for s in _expr(df)] - return [reduce(lambda x, y: x | y, series).rename(series[0].name)] + return [reduce(lambda x, y: x | y, series)] return DaskExpr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="any_horizontal", - root_names=combine_root_names(exprs), - output_names=reduce_output_names(exprs), + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), returns_scalar=False, backend_version=self._backend_version, version=self._version, @@ -161,14 +162,14 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: def sum_horizontal(self: Self, *exprs: DaskExpr) -> DaskExpr: def func(df: DaskLazyFrame) -> list[dx.Series]: series = [s.fillna(0) for _expr in exprs for s in _expr(df)] - return [reduce(lambda x, y: x + y, series).rename(series[0].name)] + return [reduce(lambda x, y: x + y, series)] return DaskExpr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="sum_horizontal", - root_names=combine_root_names(exprs), - output_names=reduce_output_names(exprs), + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), returns_scalar=False, backend_version=self._backend_version, version=self._version, @@ -245,8 +246,8 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: call=func, depth=max(x._depth for x in exprs) + 1, function_name="mean_horizontal", - root_names=combine_root_names(exprs), - output_names=reduce_output_names(exprs), + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), returns_scalar=False, backend_version=self._backend_version, version=self._version, @@ -257,14 +258,14 @@ def min_horizontal(self: Self, *exprs: DaskExpr) -> DaskExpr: def func(df: DaskLazyFrame) -> list[dx.Series]: series = [s for _expr in exprs for s in _expr(df)] - return [dd.concat(series, 
axis=1).min(axis=1).rename(series[0].name)] + return [dd.concat(series, axis=1).min(axis=1)] return DaskExpr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="min_horizontal", - root_names=combine_root_names(exprs), - output_names=reduce_output_names(exprs), + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), returns_scalar=False, backend_version=self._backend_version, version=self._version, @@ -275,14 +276,14 @@ def max_horizontal(self: Self, *exprs: DaskExpr) -> DaskExpr: def func(df: DaskLazyFrame) -> list[dx.Series]: series = [s for _expr in exprs for s in _expr(df)] - return [dd.concat(series, axis=1).max(axis=1).rename(series[0].name)] + return [dd.concat(series, axis=1).max(axis=1)] return DaskExpr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="max_horizontal", - root_names=combine_root_names(exprs), - output_names=reduce_output_names(exprs), + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), returns_scalar=False, backend_version=self._backend_version, version=self._version, @@ -329,14 +330,16 @@ def func(df: DaskLazyFrame) -> list[dx.Series]: init_value, ) - return [result.rename(null_mask[0].name)] + return [result] return DaskExpr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="concat_str", - root_names=combine_root_names(exprs), - output_names=reduce_output_names(exprs), + evaluate_output_names=getattr( + exprs[0], "_evaluate_output_names", lambda _df: ["literal"] + ), + alias_output_names=getattr(exprs[0], "_alias_output_names", None), returns_scalar=False, backend_version=self._backend_version, version=self._version, @@ -410,8 +413,10 @@ def then(self: Self, value: DaskExpr | Any) -> DaskThen: self, depth=0, function_name="whenthen", - root_names=None, - output_names=None, + evaluate_output_names=getattr( + value, "_evaluate_output_names", lambda _df: ["literal"] + ), + alias_output_names=getattr(value, "_alias_output_names", None), returns_scalar=self._returns_scalar, backend_version=self._backend_version, version=self._version, @@ -426,8 +431,8 @@ def __init__( *, depth: int, function_name: str, - root_names: list[str] | None, - output_names: list[str] | None, + evaluate_output_names: Callable[[DaskLazyFrame], Sequence[str]], + alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, returns_scalar: bool, backend_version: tuple[int, ...], version: Version, @@ -438,8 +443,8 @@ def __init__( self._call = call self._depth = depth self._function_name = function_name - self._root_names = root_names - self._output_names = output_names + self._evaluate_output_names = evaluate_output_names + self._alias_output_names = alias_output_names self._returns_scalar = returns_scalar self._kwargs = kwargs diff --git a/narwhals/_dask/selectors.py b/narwhals/_dask/selectors.py index 9e6cc6302..62083c676 100644 --- a/narwhals/_dask/selectors.py +++ b/narwhals/_dask/selectors.py @@ -2,21 +2,23 @@ from typing import TYPE_CHECKING from typing import Any +from typing import Sequence from narwhals._dask.expr import DaskExpr from narwhals.utils import import_dtypes_module if TYPE_CHECKING: - try: - import dask.dataframe.dask_expr as dx - except ModuleNotFoundError: - import dask_expr as dx from typing_extensions import Self from narwhals._dask.dataframe import DaskLazyFrame from narwhals.dtypes import DType from narwhals.utils import Version + try: + import dask.dataframe.dask_expr as dx + except 
ModuleNotFoundError: + import dask_expr as dx + class DaskSelectorNamespace: def __init__( @@ -26,17 +28,20 @@ def __init__( self._version = version def by_dtype(self: Self, dtypes: list[DType | type[DType]]) -> DaskSelector: - def func(df: DaskLazyFrame) -> list[Any]: + def func(df: DaskLazyFrame) -> list[dx.Series]: return [ df._native_frame[col] for col in df.columns if df.schema[col] in dtypes ] + def evaluate_output_names(df: DaskLazyFrame) -> Sequence[str]: + return [col for col in df.columns if df.schema[col] in dtypes] + return DaskSelector( func, depth=0, - function_name="type_selector", - root_names=None, - output_names=None, + function_name="selector", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, backend_version=self._backend_version, returns_scalar=False, version=self._version, @@ -73,15 +78,15 @@ def boolean(self: Self) -> DaskSelector: return self.by_dtype([dtypes.Boolean]) def all(self: Self) -> DaskSelector: - def func(df: DaskLazyFrame) -> list[Any]: + def func(df: DaskLazyFrame) -> list[dx.Series]: return [df._native_frame[col] for col in df.columns] return DaskSelector( func, depth=0, - function_name="type_selector", - root_names=None, - output_names=None, + function_name="selector", + evaluate_output_names=lambda df: df.columns, + alias_output_names=None, backend_version=self._backend_version, returns_scalar=False, version=self._version, @@ -94,9 +99,7 @@ def __repr__(self: Self) -> str: # pragma: no cover return ( f"DaskSelector(" f"depth={self._depth}, " - f"function_name={self._function_name}, " - f"root_names={self._root_names}, " - f"output_names={self._output_names}" + f"function_name={self._function_name})" ) def _to_expr(self: Self) -> DaskExpr: @@ -104,8 +107,8 @@ self._call, depth=self._depth, function_name=self._function_name, - root_names=self._root_names, - output_names=self._output_names, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, backend_version=self._backend_version, returns_scalar=self._returns_scalar, version=self._version, ) @@ -115,17 +118,23 @@ def __sub__(self: Self, other: DaskSelector | Any) -> DaskSelector | Any: if isinstance(other, DaskSelector): - def call(df: DaskLazyFrame) -> list[Any]: + def call(df: DaskLazyFrame) -> list[dx.Series]: + lhs_names = self._evaluate_output_names(df) + rhs_names = other._evaluate_output_names(df) lhs = self._call(df) - rhs = other._call(df) - return [x for x in lhs if x.name not in {x.name for x in rhs}] + return [x for x, name in zip(lhs, lhs_names) if name not in rhs_names] + + def evaluate_output_names(df: DaskLazyFrame) -> list[str]: + lhs_names = self._evaluate_output_names(df) + rhs_names = other._evaluate_output_names(df) + return [x for x in lhs_names if x not in rhs_names] return DaskSelector( call, depth=0, - function_name="type_selector", - root_names=None, - output_names=None, + function_name="selector", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, backend_version=self._backend_version, returns_scalar=self._returns_scalar, version=self._version, @@ -138,16 +147,26 @@ def __or__(self: Self, other: DaskSelector | Any) -> DaskSelector | Any: if isinstance(other, DaskSelector): def call(df: DaskLazyFrame) -> list[dx.Series]: + lhs_names = self._evaluate_output_names(df) + rhs_names = other._evaluate_output_names(df) lhs = self._call(df) rhs = other._call(df) - return [*(x for x in lhs if x.name not in {x.name for x in
@@ -138,16 +147,26 @@ def __or__(self: Self, other: DaskSelector | Any) -> DaskSelector | Any:
         if isinstance(other, DaskSelector):
 
             def call(df: DaskLazyFrame) -> list[dx.Series]:
+                lhs_names = self._evaluate_output_names(df)
+                rhs_names = other._evaluate_output_names(df)
                 lhs = self._call(df)
                 rhs = other._call(df)
-                return [*(x for x in lhs if x.name not in {x.name for x in rhs}), *rhs]
+                return [
+                    *(x for x, name in zip(lhs, lhs_names) if name not in rhs_names),
+                    *rhs,
+                ]
+
+            def evaluate_output_names(df: DaskLazyFrame) -> list[str]:
+                lhs_names = self._evaluate_output_names(df)
+                rhs_names = other._evaluate_output_names(df)
+                return [*(x for x in lhs_names if x not in rhs_names), *rhs_names]
 
             return DaskSelector(
                 call,
                 depth=0,
-                function_name="type_selector",
-                root_names=None,
-                output_names=None,
+                function_name="selector",
+                evaluate_output_names=evaluate_output_names,
+                alias_output_names=None,
                 backend_version=self._backend_version,
                 returns_scalar=self._returns_scalar,
                 version=self._version,
@@ -159,17 +178,23 @@ def call(df: DaskLazyFrame) -> list[dx.Series]:
 
     def __and__(self: Self, other: DaskSelector | Any) -> DaskSelector | Any:
         if isinstance(other, DaskSelector):
 
-            def call(df: DaskLazyFrame) -> list[Any]:
+            def call(df: DaskLazyFrame) -> list[dx.Series]:
+                lhs_names = self._evaluate_output_names(df)
+                rhs_names = other._evaluate_output_names(df)
                 lhs = self._call(df)
-                rhs = other._call(df)
-                return [x for x in lhs if x.name in {x.name for x in rhs}]
+                return [x for x, name in zip(lhs, lhs_names) if name in rhs_names]
+
+            def evaluate_output_names(df: DaskLazyFrame) -> list[str]:
+                lhs_names = self._evaluate_output_names(df)
+                rhs_names = other._evaluate_output_names(df)
+                return [x for x in lhs_names if x in rhs_names]
 
             return DaskSelector(
                 call,
                 depth=0,
-                function_name="type_selector",
-                root_names=None,
-                output_names=None,
+                function_name="selector",
+                evaluate_output_names=evaluate_output_names,
+                alias_output_names=None,
                 backend_version=self._backend_version,
                 returns_scalar=self._returns_scalar,
                 version=self._version,
diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py
index c7140e96f..33584efdb 100644
--- a/narwhals/_dask/utils.py
+++ b/narwhals/_dask/utils.py
@@ -3,6 +3,7 @@
 from typing import TYPE_CHECKING
 from typing import Any
 
+from narwhals._expression_parsing import evaluate_output_names_and_aliases
 from narwhals._pandas_like.utils import select_columns_by_name
 from narwhals.dependencies import get_pandas
 from narwhals.dependencies import get_pyarrow
@@ -46,22 +47,26 @@ def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any:
 def parse_exprs_and_named_exprs(
     df: DaskLazyFrame, *exprs: DaskExpr, **named_exprs: DaskExpr
 ) -> dict[str, dx.Series]:
-    results = {}
+    native_results: dict[str, dx.Series] = {}
     for expr in exprs:
-        _results = expr._call(df)
+        native_series_list = expr._call(df)
         return_scalar = getattr(expr, "_returns_scalar", False)
-        for _result in _results:
-            results[_result.name] = _result[0] if return_scalar else _result
-
+        _, aliases = evaluate_output_names_and_aliases(expr, df, [])
+        if len(aliases) != len(native_series_list):  # pragma: no cover
+            msg = f"Internal error: got aliases {aliases}, but only got {len(native_series_list)} results"
+            raise AssertionError(msg)
+        for native_series, alias in zip(native_series_list, aliases):
+            native_results[alias] = native_series[0] if return_scalar else native_series
     for name, value in named_exprs.items():
-        _results = value._call(df)
-        if len(_results) != 1:  # pragma: no cover
+        native_series_list = value._call(df)
+        if len(native_series_list) != 1:  # pragma: no cover
             msg = "Named expressions must return a single column"
             raise AssertionError(msg)
         return_scalar = getattr(value, "_returns_scalar", False)
-        for _result in _results:
-            results[name] = _result[0] if return_scalar else _result
-    return results
+        native_results[name] = (
+            native_series_list[0][0] if return_scalar else native_series_list[0]
+        )
+    return native_results
 
 
 def add_row_index(
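To trace the alias bookkeeping in `parse_exprs_and_named_exprs`, consider a hypothetical multi-output expression (column names invented for illustration):

    # expr = nw.col("a", "b").sum().name.suffix("_sum")
    # expr._call(df)                         -> [<series a>, <series b>]
    # evaluate_output_names_and_aliases(...) -> (["a", "b"], ["a_sum", "b_sum"])
    # native_results                         -> {"a_sum": ..., "b_sum": ...}
    # i.e. the dict keys now come from the aliases, not from native series names.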
diff --git a/narwhals/_duckdb/expr.py b/narwhals/_duckdb/expr.py
index 031e48e67..04eadacfd 100644
--- a/narwhals/_duckdb/expr.py
+++ b/narwhals/_duckdb/expr.py
@@ -17,10 +17,8 @@
 from narwhals._duckdb.expr_name import DuckDBExprNameNamespace
 from narwhals._duckdb.expr_str import DuckDBExprStringNamespace
 from narwhals._duckdb.utils import binary_operation_returns_scalar
-from narwhals._duckdb.utils import get_column_name
 from narwhals._duckdb.utils import maybe_evaluate
 from narwhals._duckdb.utils import narwhals_to_native_dtype
-from narwhals._expression_parsing import infer_new_root_output_names
 from narwhals.typing import CompliantExpr
 from narwhals.utils import Implementation
 
@@ -43,8 +41,8 @@ def __init__(
         *,
         depth: int,
         function_name: str,
-        root_names: list[str] | None,
-        output_names: list[str] | None,
+        evaluate_output_names: Callable[[DuckDBLazyFrame], Sequence[str]],
+        alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None,
         # Whether the expression is a length-1 Column resulting from
         # a reduction, such as `nw.col('a').sum()`
         returns_scalar: bool,
@@ -55,8 +53,8 @@ def __init__(
         self._call = call
         self._depth = depth
         self._function_name = function_name
-        self._root_names = root_names
-        self._output_names = output_names
+        self._evaluate_output_names = evaluate_output_names
+        self._alias_output_names = alias_output_names
         self._returns_scalar = returns_scalar
         self._backend_version = backend_version
         self._version = version
@@ -89,8 +87,8 @@ def func(_: DuckDBLazyFrame) -> list[duckdb.Expression]:
             func,
             depth=0,
             function_name="col",
-            root_names=list(column_names),
-            output_names=list(column_names),
+            evaluate_output_names=lambda _df: list(column_names),
+            alias_output_names=None,
             returns_scalar=False,
             backend_version=backend_version,
             version=version,
@@ -113,8 +111,8 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
             func,
             depth=0,
             function_name="nth",
-            root_names=None,
-            output_names=None,
+            evaluate_output_names=lambda df: [df.columns[i] for i in column_indices],
+            alias_output_names=None,
             returns_scalar=False,
             backend_version=backend_version,
             version=version,
@@ -127,44 +125,29 @@ def _from_call(
         expr_name: str,
         *,
         returns_scalar: bool,
-        **kwargs: Any,
+        **expressifiable_args: Self | Any,
     ) -> Self:
         def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
-            results = []
-            inputs = self._call(df)
-            _kwargs = {key: maybe_evaluate(df, value) for key, value in kwargs.items()}
-            for _input in inputs:
-                input_col_name = get_column_name(
-                    df, _input, returns_scalar=self._returns_scalar
-                )
-                if self._returns_scalar:
-                    # TODO(marco): once WindowExpression is supported, then
-                    # we may need to call it with `over(1)` here,
-                    # depending on the context?
-                    pass
-
-                column_result = call(_input, **_kwargs)
-                column_result = column_result.alias(input_col_name)
-                if returns_scalar:
-                    # TODO(marco): once WindowExpression is supported, then
-                    # we may need to call it with `over(1)` here,
-                    # depending on the context?
-                    pass
-                results.append(column_result)
-            return results
-
-        root_names, output_names = infer_new_root_output_names(self, **kwargs)
+            native_series_list = self._call(df)
+            other_native_series = {
+                key: maybe_evaluate(df, value)
+                for key, value in expressifiable_args.items()
+            }
+            return [
+                call(native_series, **other_native_series)
+                for native_series in native_series_list
+            ]
 
         return self.__class__(
             func,
             depth=self._depth + 1,
             function_name=f"{self._function_name}->{expr_name}",
-            root_names=root_names,
-            output_names=output_names,
+            evaluate_output_names=self._evaluate_output_names,
+            alias_output_names=self._alias_output_names,
             returns_scalar=returns_scalar,
             backend_version=self._backend_version,
             version=self._version,
-            kwargs=kwargs,
+            kwargs=expressifiable_args,
         )
 
     def __and__(self: Self, other: DuckDBExpr) -> Self:
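The renamed `expressifiable_args` makes explicit that keyword arguments may themselves be expressions; `maybe_evaluate` resolves each one against `df`. A hypothetical trace (the argument names are illustrative, not taken from the diff):

    # nw.col("a").is_between(3, nw.col("b")) reaches _from_call roughly as:
    #   expressifiable_args = {"lower_bound": 3, "upper_bound": <DuckDBExpr "b">}
    # maybe_evaluate(df, 3)              -> 3 (scalar passthrough)
    # maybe_evaluate(df, <DuckDBExpr b>) -> native duckdb expression for "b"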
@@ -295,17 +278,20 @@ def __invert__(self: Self) -> Self:
         )
 
     def alias(self: Self, name: str) -> Self:
-        def _alias(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
-            return [col.alias(name) for col in self._call(df)]
+        def alias_output_names(names: Sequence[str]) -> Sequence[str]:
+            if len(names) != 1:
+                msg = f"Expected function with single output, found output names: {names}"
+                raise ValueError(msg)
+            return [name]
 
         # Define this one manually, so that we can
         # override `output_names` and not increase depth
         return self.__class__(
-            _alias,
+            self._call,
             depth=self._depth,
             function_name=self._function_name,
-            root_names=self._root_names,
-            output_names=[name],
+            evaluate_output_names=self._evaluate_output_names,
+            alias_output_names=alias_output_names,
             returns_scalar=self._returns_scalar,
             backend_version=self._backend_version,
             version=self._version,
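After this change an alias is pure metadata: the native expression is left untouched, and the rename is applied only when the lazy frame materializes columns. Assumed trace:

    # expr = nw.col("a").sum().alias("total")
    # expr._evaluate_output_names(df) -> ["a"]
    # expr._alias_output_names(["a"]) -> ["total"]
    # no duckdb-level .alias() call happens inside the expression itself.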
diff --git a/narwhals/_duckdb/expr_name.py b/narwhals/_duckdb/expr_name.py
index 2ed2b2ea8..a2cc890de 100644
--- a/narwhals/_duckdb/expr_name.py
+++ b/narwhals/_duckdb/expr_name.py
@@ -3,8 +3,6 @@
 from typing import TYPE_CHECKING
 from typing import Callable
 
-from narwhals.exceptions import AnonymousExprError
-
 if TYPE_CHECKING:
     from typing_extensions import Self
 
@@ -16,19 +14,12 @@ def __init__(self: Self, expr: DuckDBExpr) -> None:
         self._compliant_expr = expr
 
     def keep(self: Self) -> DuckDBExpr:
-        root_names = self._compliant_expr._root_names
-        if root_names is None:
-            msg = ".name.keep"
-            raise AnonymousExprError.from_expr_name(msg)
         return self._compliant_expr.__class__(
-            lambda df: [
-                expr.alias(name)
-                for expr, name in zip(self._compliant_expr._call(df), root_names)
-            ],
+            self._compliant_expr._call,
             depth=self._compliant_expr._depth,
             function_name=self._compliant_expr._function_name,
-            root_names=root_names,
-            output_names=root_names,
+            evaluate_output_names=self._compliant_expr._evaluate_output_names,
+            alias_output_names=None,
             returns_scalar=self._compliant_expr._returns_scalar,
             backend_version=self._compliant_expr._backend_version,
             version=self._compliant_expr._version,
@@ -36,22 +27,14 @@ def keep(self: Self) -> DuckDBExpr:
         )
 
     def map(self: Self, function: Callable[[str], str]) -> DuckDBExpr:
-        root_names = self._compliant_expr._root_names
-        if root_names is None:
-            msg = ".name.map"
-            raise AnonymousExprError.from_expr_name(msg)
-
-        output_names = [function(str(name)) for name in root_names]
-
         return self._compliant_expr.__class__(
-            lambda df: [
-                expr.alias(name)
-                for expr, name in zip(self._compliant_expr._call(df), output_names)
-            ],
+            self._compliant_expr._call,
             depth=self._compliant_expr._depth,
             function_name=self._compliant_expr._function_name,
-            root_names=root_names,
-            output_names=output_names,
+            evaluate_output_names=self._compliant_expr._evaluate_output_names,
+            alias_output_names=lambda output_names: [
+                function(str(name)) for name in output_names
+            ],
             returns_scalar=self._compliant_expr._returns_scalar,
             backend_version=self._compliant_expr._backend_version,
             version=self._compliant_expr._version,
@@ -59,21 +42,14 @@ def map(self: Self, function: Callable[[str], str]) -> DuckDBExpr:
         )
 
     def prefix(self: Self, prefix: str) -> DuckDBExpr:
-        root_names = self._compliant_expr._root_names
-        if root_names is None:
-            msg = ".name.prefix"
-            raise AnonymousExprError.from_expr_name(msg)
-
-        output_names = [prefix + str(name) for name in root_names]
         return self._compliant_expr.__class__(
-            lambda df: [
-                expr.alias(name)
-                for expr, name in zip(self._compliant_expr._call(df), output_names)
-            ],
+            self._compliant_expr._call,
            depth=self._compliant_expr._depth,
             function_name=self._compliant_expr._function_name,
-            root_names=root_names,
-            output_names=output_names,
+            evaluate_output_names=self._compliant_expr._evaluate_output_names,
+            alias_output_names=lambda output_names: [
+                f"{prefix}{output_name}" for output_name in output_names
+            ],
             returns_scalar=self._compliant_expr._returns_scalar,
             backend_version=self._compliant_expr._backend_version,
             version=self._compliant_expr._version,
@@ -81,22 +57,14 @@ def prefix(self: Self, prefix: str) -> DuckDBExpr:
         )
 
     def suffix(self: Self, suffix: str) -> DuckDBExpr:
-        root_names = self._compliant_expr._root_names
-        if root_names is None:
-            msg = ".name.suffix"
-            raise AnonymousExprError.from_expr_name(msg)
-
-        output_names = [str(name) + suffix for name in root_names]
-
         return self._compliant_expr.__class__(
-            lambda df: [
-                expr.alias(name)
-                for expr, name in zip(self._compliant_expr._call(df), output_names)
-            ],
+            self._compliant_expr._call,
             depth=self._compliant_expr._depth,
             function_name=self._compliant_expr._function_name,
-            root_names=root_names,
-            output_names=output_names,
+            evaluate_output_names=self._compliant_expr._evaluate_output_names,
+            alias_output_names=lambda output_names: [
+                f"{output_name}{suffix}" for output_name in output_names
+            ],
             returns_scalar=self._compliant_expr._returns_scalar,
             backend_version=self._compliant_expr._backend_version,
             version=self._compliant_expr._version,
@@ -104,22 +72,14 @@ def suffix(self: Self, suffix: str) -> DuckDBExpr:
         )
 
     def to_lowercase(self: Self) -> DuckDBExpr:
-        root_names = self._compliant_expr._root_names
-        if root_names is None:
-            msg = ".name.to_lowercase"
-            raise AnonymousExprError.from_expr_name(msg)
-
-        output_names = [str(name).lower() for name in root_names]
-
         return self._compliant_expr.__class__(
-            lambda df: [
-                expr.alias(name)
-                for expr, name in zip(self._compliant_expr._call(df), output_names)
-            ],
+            self._compliant_expr._call,
             depth=self._compliant_expr._depth,
             function_name=self._compliant_expr._function_name,
-            root_names=root_names,
-            output_names=output_names,
+            evaluate_output_names=self._compliant_expr._evaluate_output_names,
+            alias_output_names=lambda output_names: [
+                str(name).lower() for name in output_names
+            ],
             returns_scalar=self._compliant_expr._returns_scalar,
             backend_version=self._compliant_expr._backend_version,
             version=self._compliant_expr._version,
@@ -127,21 +87,14 @@ def to_lowercase(self: Self) -> DuckDBExpr:
         )
 
     def to_uppercase(self: Self) -> DuckDBExpr:
-        root_names = self._compliant_expr._root_names
-        if root_names is None:
-            msg = ".name.to_uppercase"
-            raise AnonymousExprError.from_expr_name(msg)
-        output_names = [str(name).upper() for name in root_names]
-
         return self._compliant_expr.__class__(
-            lambda df: [
-                expr.alias(name)
-                for expr, name in zip(self._compliant_expr._call(df), output_names)
-            ],
+            self._compliant_expr._call,
             depth=self._compliant_expr._depth,
             function_name=self._compliant_expr._function_name,
-            root_names=root_names,
-            output_names=output_names,
+            evaluate_output_names=self._compliant_expr._evaluate_output_names,
+            alias_output_names=lambda output_names: [
+                str(name).upper() for name in output_names
+            ],
             returns_scalar=self._compliant_expr._returns_scalar,
             backend_version=self._compliant_expr._backend_version,
             version=self._compliant_expr._version,
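Each `name.*` method now forwards `self._compliant_expr._call` unchanged and expresses the rename as an `alias_output_names` closure, so no anonymous-expression error paths are needed. For example (hypothetical columns):

    # nw.col("foo", "bar").name.prefix("x_"):
    # evaluate_output_names(df)          -> ["foo", "bar"]
    # alias_output_names(["foo", "bar"]) -> ["x_foo", "x_bar"]
    # .name.keep() instead resets alias_output_names to None, restoring roots.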
diff --git a/narwhals/_duckdb/group_by.py b/narwhals/_duckdb/group_by.py
index 11f629bec..f005ea0e2 100644
--- a/narwhals/_duckdb/group_by.py
+++ b/narwhals/_duckdb/group_by.py
@@ -1,10 +1,7 @@
 from __future__ import annotations
 
-from copy import copy
 from typing import TYPE_CHECKING
 
-from narwhals.exceptions import AnonymousExprError
-
 if TYPE_CHECKING:
     from typing_extensions import Self
 
@@ -25,22 +22,32 @@ def __init__(
         self._compliant_frame = compliant_frame
         self._keys = keys
 
-    def agg(
-        self: Self,
-        *exprs: DuckDBExpr,
-    ) -> DuckDBLazyFrame:
-        output_names: list[str] = copy(self._keys)
+    def agg(self: Self, *exprs: DuckDBExpr) -> DuckDBLazyFrame:
+        agg_columns = self._keys.copy()
+        df = self._compliant_frame
         for expr in exprs:
-            if expr._output_names is None:  # pragma: no cover
-                msg = "group_by.agg"
-                raise AnonymousExprError.from_expr_name(msg)
-
-            output_names.extend(expr._output_names)
+            output_names = expr._evaluate_output_names(df)
+            aliases = (
+                output_names
+                if expr._alias_output_names is None
+                else expr._alias_output_names(output_names)
+            )
+            native_expressions = expr(df)
+            exclude = (
+                self._keys
+                if expr._function_name.split("->", maxsplit=1)[0] in ("all", "selector")
+                else []
+            )
+            agg_columns.extend(
+                [
+                    native_expression.alias(alias)
+                    for native_expression, output_name, alias in zip(
+                        native_expressions, output_names, aliases
+                    )
+                    if output_name not in exclude
+                ]
+            )
 
-        agg_columns = [
-            *self._keys,
-            *(x for expr in exprs for x in expr(self._compliant_frame)),
-        ]
         return self._compliant_frame._from_native_frame(
             self._compliant_frame._native_frame.aggregate(
                 agg_columns, group_expr=",".join(f'"{key}"' for key in self._keys)
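The `exclude` logic above keeps group keys from being emitted twice when an expanding expression covers the whole frame; working through the example the codebase itself uses (invented columns a, b, c):

    # df.group_by("a").agg(nw.all().mean()):
    # output_names -> ["a", "b", "c"]; the function name's first segment is
    # "all", so exclude=self._keys drops "a", and only mean("b") and mean("c")
    # are added alongside the key column in the aggregate.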
diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py
index 2ceffc4eb..d3b70826e 100644
--- a/narwhals/_duckdb/namespace.py
+++ b/narwhals/_duckdb/namespace.py
@@ -5,6 +5,7 @@
 from functools import reduce
 from typing import TYPE_CHECKING
 from typing import Any
+from typing import Callable
 from typing import Literal
 from typing import Sequence
 from typing import cast
@@ -17,8 +18,8 @@
 from narwhals._duckdb.expr import DuckDBExpr
 from narwhals._duckdb.utils import narwhals_to_native_dtype
-from narwhals._expression_parsing import combine_root_names
-from narwhals._expression_parsing import reduce_output_names
+from narwhals._expression_parsing import combine_alias_output_names
+from narwhals._expression_parsing import combine_evaluate_output_names
 from narwhals.typing import CompliantNamespace
 
 if TYPE_CHECKING:
@@ -30,10 +31,6 @@
 from narwhals.utils import Version
 
 
-def get_column_name(df: DuckDBLazyFrame, column: duckdb.Expression) -> str:
-    return str(df._native_frame.select(column).columns[0])
-
-
 class DuckDBNamespace(CompliantNamespace["duckdb.Expression"]):
     def __init__(
         self: Self, *, backend_version: tuple[int, ...], version: Version
@@ -49,8 +46,8 @@ def _all(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
             call=_all,
             depth=0,
             function_name="all",
-            root_names=None,
-            output_names=None,
+            evaluate_output_names=lambda df: df.columns,
+            alias_output_names=None,
             returns_scalar=False,
             backend_version=self._backend_version,
             version=self._version,
@@ -88,7 +85,6 @@ def concat_str(
         def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
             cols = [s for _expr in exprs for s in _expr(df)]
             null_mask = [s.isnull() for _expr in exprs for s in _expr(df)]
-            first_column_name = get_column_name(df, cols[0])
 
             if not ignore_nulls:
                 null_mask_result = reduce(lambda x, y: x | y, null_mask)
@@ -128,14 +124,14 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
                     init_value,
                 )
 
-            return [result.alias(first_column_name)]
+            return [result]
 
         return DuckDBExpr(
             call=func,
             depth=max(x._depth for x in exprs) + 1,
             function_name="concat_str",
-            root_names=combine_root_names(exprs),
-            output_names=reduce_output_names(exprs),
+            evaluate_output_names=combine_evaluate_output_names(*exprs),
+            alias_output_names=combine_alias_output_names(*exprs),
             returns_scalar=False,
             backend_version=self._backend_version,
             version=self._version,
@@ -149,15 +145,14 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
     def all_horizontal(self: Self, *exprs: DuckDBExpr) -> DuckDBExpr:
         def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
             cols = [c for _expr in exprs for c in _expr(df)]
-            col_name = get_column_name(df, cols[0])
-            return [reduce(operator.and_, cols).alias(col_name)]
+            return [reduce(operator.and_, cols)]
 
         return DuckDBExpr(
             call=func,
             depth=max(x._depth for x in exprs) + 1,
             function_name="all_horizontal",
-            root_names=combine_root_names(exprs),
-            output_names=reduce_output_names(exprs),
+            evaluate_output_names=combine_evaluate_output_names(*exprs),
+            alias_output_names=combine_alias_output_names(*exprs),
             returns_scalar=False,
             backend_version=self._backend_version,
             version=self._version,
@@ -167,15 +162,14 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
     def any_horizontal(self: Self, *exprs: DuckDBExpr) -> DuckDBExpr:
         def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
             cols = [c for _expr in exprs for c in _expr(df)]
-            col_name = get_column_name(df, cols[0])
-            return [reduce(operator.or_, cols).alias(col_name)]
+            return [reduce(operator.or_, cols)]
 
         return DuckDBExpr(
             call=func,
             depth=max(x._depth for x in exprs) + 1,
             function_name="or_horizontal",
-            root_names=combine_root_names(exprs),
-            output_names=reduce_output_names(exprs),
+            evaluate_output_names=combine_evaluate_output_names(*exprs),
+            alias_output_names=combine_alias_output_names(*exprs),
             returns_scalar=False,
             backend_version=self._backend_version,
             version=self._version,
@@ -185,15 +179,14 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
     def max_horizontal(self: Self, *exprs: DuckDBExpr) -> DuckDBExpr:
         def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
             cols = [c for _expr in exprs for c in _expr(df)]
-            col_name = get_column_name(df, cols[0])
-            return [FunctionExpression("greatest", *cols).alias(col_name)]
+            return [FunctionExpression("greatest", *cols)]
 
         return DuckDBExpr(
             call=func,
             depth=max(x._depth for x in exprs) + 1,
             function_name="max_horizontal",
-            root_names=combine_root_names(exprs),
-            output_names=reduce_output_names(exprs),
+            evaluate_output_names=combine_evaluate_output_names(*exprs),
+            alias_output_names=combine_alias_output_names(*exprs),
             returns_scalar=False,
             backend_version=self._backend_version,
             version=self._version,
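With the eager `.alias(col_name)` calls gone, the horizontal functions get their name from `combine_evaluate_output_names`, i.e. the left-hand rule. A small sketch (pandas backend used for brevity; the naming rule is backend-independent):

    import narwhals as nw
    import pandas as pd

    df = nw.from_native(pd.DataFrame({"a": [1, 2], "b": [3, 4]}))
    # The result column is named "a": the first output name of the first input,
    # computed without a native round-trip through select().columns.
    print(df.select(nw.max_horizontal(nw.col("a"), nw.col("b"))).to_native())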
@@ -203,15 +196,14 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
     def min_horizontal(self: Self, *exprs: DuckDBExpr) -> DuckDBExpr:
         def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
             cols = [c for _expr in exprs for c in _expr(df)]
-            col_name = get_column_name(df, cols[0])
-            return [FunctionExpression("least", *cols).alias(col_name)]
+            return [FunctionExpression("least", *cols)]
 
         return DuckDBExpr(
             call=func,
             depth=max(x._depth for x in exprs) + 1,
             function_name="min_horizontal",
-            root_names=combine_root_names(exprs),
-            output_names=reduce_output_names(exprs),
+            evaluate_output_names=combine_evaluate_output_names(*exprs),
+            alias_output_names=combine_alias_output_names(*exprs),
             returns_scalar=False,
             backend_version=self._backend_version,
             version=self._version,
@@ -221,20 +213,19 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
     def sum_horizontal(self: Self, *exprs: DuckDBExpr) -> DuckDBExpr:
         def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
             cols = [c for _expr in exprs for c in _expr(df)]
-            col_name = get_column_name(df, cols[0])
             return [
                 reduce(
                     operator.add,
                     (CoalesceOperator(col, ConstantExpression(0)) for col in cols),
-                ).alias(col_name)
+                )
             ]
 
         return DuckDBExpr(
             call=func,
             depth=max(x._depth for x in exprs) + 1,
             function_name="sum_horizontal",
-            root_names=combine_root_names(exprs),
-            output_names=reduce_output_names(exprs),
+            evaluate_output_names=combine_evaluate_output_names(*exprs),
+            alias_output_names=combine_alias_output_names(*exprs),
             returns_scalar=False,
             backend_version=self._backend_version,
             version=self._version,
@@ -244,7 +235,6 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
     def mean_horizontal(self: Self, *exprs: DuckDBExpr) -> DuckDBExpr:
         def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
             cols = [c for _expr in exprs for c in _expr(df)]
-            col_name = get_column_name(df, cols[0])
             return [
                 (
                     reduce(
@@ -252,15 +242,15 @@ def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
                         (CoalesceOperator(col, ConstantExpression(0)) for col in cols),
                     )
                     / reduce(operator.add, (col.isnotnull().cast("int") for col in cols))
-                ).alias(col_name)
+                )
             ]
 
         return DuckDBExpr(
             call=func,
             depth=max(x._depth for x in exprs) + 1,
             function_name="mean_horizontal",
-            root_names=combine_root_names(exprs),
-            output_names=reduce_output_names(exprs),
+            evaluate_output_names=combine_evaluate_output_names(*exprs),
+            alias_output_names=combine_alias_output_names(*exprs),
             returns_scalar=False,
             backend_version=self._backend_version,
             version=self._version,
@@ -291,18 +281,18 @@ def lit(self: Self, value: Any, dtype: DType | None) -> DuckDBExpr:
         def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]:
             if dtype is not None:
                 return [
-                    ConstantExpression(value)
-                    .cast(narwhals_to_native_dtype(dtype, version=self._version))
-                    .alias("literal")
+                    ConstantExpression(value).cast(
+                        narwhals_to_native_dtype(dtype, version=self._version)
+                    )
                 ]
-            return [ConstantExpression(value).alias("literal")]
+            return [ConstantExpression(value)]
 
         return DuckDBExpr(
             func,
             depth=0,
             function_name="lit",
-            root_names=None,
-            output_names=["literal"],
+            evaluate_output_names=lambda _df: ["literal"],
+            alias_output_names=None,
             returns_scalar=True,
             backend_version=self._backend_version,
             version=self._version,
@@ -311,14 +301,14 @@ def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]:
     def len(self: Self) -> DuckDBExpr:
         def func(_df: DuckDBLazyFrame) -> list[duckdb.Expression]:
-            return [FunctionExpression("count").alias("len")]
+            return [FunctionExpression("count")]
 
         return DuckDBExpr(
             call=func,
             depth=0,
             function_name="len",
-            root_names=None,
-            output_names=["len"],
+            evaluate_output_names=lambda _df: ["len"],
+            alias_output_names=None,
             returns_scalar=True,
             backend_version=self._backend_version,
             version=self._version,
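`lit` and `len` keep their Polars-style default names, now supplied as metadata instead of as native aliases. Assumed behaviour:

    # nw.lit(1) -> output name "literal"; nw.len() -> output name "len"
    # e.g. df.select(nw.lit(1), nw.len()) yields columns ["literal", "len"],
    # with the names attached only when the frame materializes the columns.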
@@ -352,25 +342,20 @@ def __call__(self: Self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]:
             value = self._then_value(df)[0]
         else:
             # `self._then_value` is a scalar
-            value = ConstantExpression(self._then_value).alias("literal")
+            value = ConstantExpression(self._then_value)
         value = cast("duckdb.Expression", value)
-        value_name = get_column_name(df, value)
 
         if self._otherwise_value is None:
-            return [CaseExpression(condition=condition, value=value).alias(value_name)]
+            return [CaseExpression(condition=condition, value=value)]
         if not isinstance(self._otherwise_value, DuckDBExpr):
             # `self._otherwise_value` is a scalar
             return [
-                CaseExpression(condition=condition, value=value)
-                .otherwise(ConstantExpression(self._otherwise_value))
-                .alias(value_name)
+                CaseExpression(condition=condition, value=value).otherwise(
+                    ConstantExpression(self._otherwise_value)
+                )
             ]
         otherwise = self._otherwise_value(df)[0]
-        return [
-            CaseExpression(condition=condition, value=value)
-            .otherwise(otherwise)
-            .alias(value_name)
-        ]
+        return [CaseExpression(condition=condition, value=value).otherwise(otherwise)]
 
     def then(self: Self, value: DuckDBExpr | Any) -> DuckDBThen:
         self._then_value = value
@@ -379,8 +364,10 @@ def then(self: Self, value: DuckDBExpr | Any) -> DuckDBThen:
             self,
             depth=0,
             function_name="whenthen",
-            root_names=None,
-            output_names=None,
+            evaluate_output_names=getattr(
+                value, "_evaluate_output_names", lambda _df: ["literal"]
+            ),
+            alias_output_names=getattr(value, "_alias_output_names", None),
             returns_scalar=self._returns_scalar,
             backend_version=self._backend_version,
             version=self._version,
@@ -395,8 +382,8 @@ def __init__(
         *,
         depth: int,
         function_name: str,
-        root_names: list[str] | None,
-        output_names: list[str] | None,
+        evaluate_output_names: Callable[[DuckDBLazyFrame], Sequence[str]],
+        alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None,
         returns_scalar: bool,
         backend_version: tuple[int, ...],
         version: Version,
@@ -407,8 +394,8 @@ def __init__(
         self._call = call
         self._depth = depth
         self._function_name = function_name
-        self._root_names = root_names
-        self._output_names = output_names
+        self._evaluate_output_names = evaluate_output_names
+        self._alias_output_names = alias_output_names
         self._returns_scalar = returns_scalar
         self._kwargs = kwargs
diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py
index d45123267..8457843e4 100644
--- a/narwhals/_duckdb/utils.py
+++ b/narwhals/_duckdb/utils.py
@@ -4,31 +4,20 @@
 from functools import lru_cache
 from typing import TYPE_CHECKING
 from typing import Any
-from typing import Sequence
+
+import duckdb
 
 from narwhals.dtypes import DType
 from narwhals.utils import import_dtypes_module
 from narwhals.utils import isinstance_or_issubclass
 
 if TYPE_CHECKING:
-    import duckdb
-
     from narwhals._duckdb.dataframe import DuckDBLazyFrame
     from narwhals._duckdb.expr import DuckDBExpr
     from narwhals.utils import Version
 
 
-def get_column_name(
-    df: DuckDBLazyFrame, column: duckdb.Expression, *, returns_scalar: bool
-) -> str:
-    if returns_scalar:
-        return str(df._native_frame.aggregate([column]).columns[0])
-    return str(df._native_frame.select(column).columns[0])
-
-
 def maybe_evaluate(df: DuckDBLazyFrame, obj: Any) -> Any:
-    import duckdb
-
     from narwhals._duckdb.expr import DuckDBExpr
 
     if isinstance(obj, DuckDBExpr):
@@ -47,40 +36,25 @@ def maybe_evaluate(df: DuckDBLazyFrame, obj: Any) -> Any:
 def parse_exprs_and_named_exprs(
-    df: DuckDBLazyFrame,
-    *exprs: DuckDBExpr,
-    **named_exprs: DuckDBExpr,
+    df: DuckDBLazyFrame, *exprs: DuckDBExpr, **named_exprs: DuckDBExpr
 ) -> dict[str, duckdb.Expression]:
-    result_columns: dict[str, list[duckdb.Expression]] = {}
+    native_results: dict[str, list[duckdb.Expression]] = {}
     for expr in exprs:
-        column_list = _columns_from_expr(df, expr)
-        if expr._output_names is None:
-            output_names = [
-                get_column_name(df, col, returns_scalar=expr._returns_scalar)
-                for col in column_list
-            ]
-        else:
-            output_names = expr._output_names
-        result_columns.update(zip(output_names, column_list))
+        native_series_list = expr._call(df)
+        output_names = expr._evaluate_output_names(df)
+        if expr._alias_output_names is not None:
+            output_names = expr._alias_output_names(output_names)
+        if len(output_names) != len(native_series_list):  # pragma: no cover
+            msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
+            raise AssertionError(msg)
+        native_results.update(zip(output_names, native_series_list))
     for col_alias, expr in named_exprs.items():
-        columns_list = _columns_from_expr(df, expr)
-        if len(columns_list) != 1:  # pragma: no cover
+        native_series_list = expr._call(df)
+        if len(native_series_list) != 1:  # pragma: no cover
             msg = "Named expressions must return a single column"
-            raise AssertionError(msg)
-        result_columns[col_alias] = columns_list[0]
-    return result_columns
-
-
-def _columns_from_expr(
-    df: DuckDBLazyFrame, expr: DuckDBExpr
-) -> Sequence[duckdb.Expression]:
-    col_output_list = expr._call(df)
-    if expr._output_names is not None and (
-        len(col_output_list) != len(expr._output_names)
-    ):  # pragma: no cover
-        msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
-        raise AssertionError(msg)
-    return col_output_list
+            raise ValueError(msg)
+        native_results[col_alias] = native_series_list[0]
+    return native_results
 
 
 @lru_cache(maxsize=16)
diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py
index 96f1b00c6..af4742f35 100644
--- a/narwhals/_expression_parsing.py
+++ b/narwhals/_expression_parsing.py
@@ -3,9 +3,9 @@
 # and pandas or PyArrow.
 from __future__ import annotations
 
-from copy import copy
 from typing import TYPE_CHECKING
 from typing import Any
+from typing import Callable
 from typing import Sequence
 from typing import TypeVar
 from typing import Union
@@ -42,9 +42,22 @@ def evaluate_into_expr(
     df: CompliantDataFrame | CompliantLazyFrame,
     into_expr: IntoCompliantExpr[CompliantSeriesT_co],
 ) -> Sequence[CompliantSeriesT_co]:
-    """Return list of raw columns."""
+    """Return list of raw columns.
+
+    This is only used for eager backends (pandas, PyArrow), where we
+    alias operations at each step. As a safety precaution, we check here
+    that the result names match those computed by the various
+    `evaluate_output_names` / `alias_output_names` calls. Note that for
+    PySpark / DuckDB, we are less free to set aliases liberally.
+ """ expr = parse_into_expr(into_expr, namespace=df.__narwhals_namespace__()) - return expr(df) + _, aliases = evaluate_output_names_and_aliases(expr, df, []) + result = expr(df) + if list(aliases) != [s.name for s in result]: # pragma: no cover + msg = f"Safety assertion failed, expected {aliases}, got {result}" + raise AssertionError(msg) + return result def evaluate_into_exprs( @@ -119,38 +132,6 @@ def parse_into_expr( raise InvalidIntoExprError.from_invalid_type(type(into_expr)) -def infer_new_root_output_names( - expr: CompliantExpr[Any], **kwargs: Any -) -> tuple[list[str] | None, list[str] | None]: - """Return new root and output names after chaining expressions. - - Try tracking root and output names by combining them from all expressions appearing in kwargs. - If any anonymous expression appears (e.g. nw.all()), then give up on tracking root names - and just set it to None. - """ - root_names = copy(expr._root_names) - output_names = expr._output_names - for arg in list(kwargs.values()): - if root_names is not None and isinstance(arg, expr.__class__): - if arg._root_names is not None: - root_names.extend(arg._root_names) - else: - root_names = None - output_names = None - break - elif root_names is None: - output_names = None - break - - if not ( - (output_names is None and root_names is None) - or (output_names is not None and root_names is not None) - ): # pragma: no cover - msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" - raise AssertionError(msg) - return root_names, output_names - - @overload def reuse_series_implementation( expr: PandasLikeExprT, @@ -176,7 +157,7 @@ def reuse_series_implementation( attr: str, *, returns_scalar: bool = False, - **kwargs: Any, + **expressifiable_args: Any, ) -> ArrowExprT | PandasLikeExprT: """Reuse Series implementation for expression. @@ -189,14 +170,15 @@ def reuse_series_implementation( returns_scalar: whether the Series version returns a scalar. In this case, the expression version should return a 1-row Series. args: arguments to pass to function. - kwargs: keyword arguments to pass to function. + expressifiable_args: keyword arguments to pass to function, which may + be expressifiable (e.g. `nw.col('a').is_between(3, nw.col('b')))`). """ plx = expr.__narwhals_namespace__() def func(df: CompliantDataFrame) -> Sequence[CompliantSeries]: _kwargs = { # type: ignore[var-annotated] arg_name: maybe_evaluate_expr(df, arg_value) - for arg_name, arg_value in kwargs.items() + for arg_name, arg_value in expressifiable_args.items() } # For PyArrow.Series, we return Python Scalars (like Polars does) instead of PyArrow Scalars. 
@@ -216,26 +198,23 @@ def func(df: CompliantDataFrame) -> Sequence[CompliantSeries]:
             else getattr(series, attr)(**_kwargs)
             for series in expr(df)  # type: ignore[arg-type]
         ]
-        if expr._output_names is not None and (
-            [s.name for s in out] != expr._output_names
-        ):  # pragma: no cover
+        _, aliases = evaluate_output_names_and_aliases(expr, df, [])
+        if [s.name for s in out] != list(aliases):  # pragma: no cover
             msg = (
                 f"Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues\n"
-                f"Expression output names: {expr._output_names}\n"
+                f"Expression aliases: {aliases}\n"
                 f"Series names: {[s.name for s in out]}"
             )
             raise AssertionError(msg)
         return out
 
-    root_names, output_names = infer_new_root_output_names(expr, **kwargs)
-
     return plx._create_expr_from_callable(  # type: ignore[return-value]
         func,  # type: ignore[arg-type]
         depth=expr._depth + 1,
         function_name=f"{expr._function_name}->{attr}",
-        root_names=root_names,
-        output_names=output_names,
-        kwargs={**expr._kwargs, **kwargs},
+        evaluate_output_names=expr._evaluate_output_names,  # type: ignore[arg-type]
+        alias_output_names=expr._alias_output_names,
+        kwargs={**expr._kwargs, **expressifiable_args},
     )
 
 
@@ -273,8 +252,8 @@ def reuse_series_namespace_implementation(
         ],
         depth=expr._depth + 1,
         function_name=f"{expr._function_name}->{series_namespace}.{attr}",
-        root_names=expr._root_names,
-        output_names=expr._output_names,
+        evaluate_output_names=expr._evaluate_output_names,  # type: ignore[arg-type]
+        alias_output_names=expr._alias_output_names,
         kwargs={**expr._kwargs, **kwargs},
     )
 
@@ -296,25 +275,34 @@ def is_simple_aggregation(expr: CompliantExpr[Any]) -> bool:
     return expr._depth < 2
 
 
-def combine_root_names(parsed_exprs: Sequence[CompliantExpr[Any]]) -> list[str] | None:
-    root_names = copy(parsed_exprs[0]._root_names)
-    for arg in parsed_exprs[1:]:
-        if root_names is not None:
-            if arg._root_names is not None:
-                root_names.extend(arg._root_names)
-            else:
-                root_names = None
-                break
-    return root_names
-
-
-def reduce_output_names(parsed_exprs: Sequence[CompliantExpr[Any]]) -> list[str] | None:
-    """Returns the left-most output name."""
-    return (
-        parsed_exprs[0]._output_names[:1]
-        if parsed_exprs[0]._output_names is not None
-        else None
-    )
+def combine_evaluate_output_names(
+    *exprs: CompliantExpr[Any],
+) -> Callable[[CompliantDataFrame | CompliantLazyFrame], Sequence[str]]:
+    # Follow left-hand-rule for naming. E.g. `nw.sum_horizontal(expr1, expr2)` takes the
+    # first name of `expr1`.
+    def evaluate_output_names(
+        df: CompliantDataFrame | CompliantLazyFrame,
+    ) -> Sequence[str]:
+        if not hasattr(exprs[0], "__narwhals_expr__"):  # pragma: no cover
+            msg = f"Safety assertion failed, expected expression, got: {type(exprs[0])}. Please report a bug."
+            raise AssertionError(msg)
+        return exprs[0]._evaluate_output_names(df)[:1]
+
+    return evaluate_output_names
+
+
+def combine_alias_output_names(
+    *exprs: CompliantExpr[Any],
+) -> Callable[[Sequence[str]], Sequence[str]] | None:
+    # Follow left-hand-rule for naming. E.g. `nw.sum_horizontal(expr1.alias(alias), expr2)` takes the
+    # aliasing function of `expr1` and applies it to the first output name of `expr1`.
+    if exprs[0]._alias_output_names is None:
+        return None
+
+    def alias_output_names(names: Sequence[str]) -> Sequence[str]:
+        return exprs[0]._alias_output_names(names)[:1]  # type: ignore[misc]
+
+    return alias_output_names
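Read together, the two combinators implement a single policy. A hypothetical trace:

    # nw.sum_horizontal(nw.col("a").alias("x"), nw.col("b")):
    # combine_evaluate_output_names(...)(df) -> ["a"]  (first name of first input)
    # combine_alias_output_names(...)(["a"]) -> ["x"]  (first input's aliaser, truncated)
    # so the horizontal sum is ultimately written out as column "x".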
 def extract_compliant(
@@ -380,3 +368,25 @@ def operation_aggregates(*args: IntoExpr | Any) -> bool:
     # expression does not aggregate, then broadcasting will take
     # place and the result will not be an aggregate.
     return all(getattr(x, "_aggregates", True) for x in args)
+
+
+def evaluate_output_names_and_aliases(
+    expr: CompliantExpr[Any],
+    df: CompliantDataFrame | CompliantLazyFrame,
+    exclude: Sequence[str],
+) -> tuple[Sequence[str], Sequence[str]]:
+    output_names = expr._evaluate_output_names(df)
+    if not output_names:
+        return [], []
+    aliases = (
+        output_names
+        if expr._alias_output_names is None
+        else expr._alias_output_names(output_names)
+    )
+    if expr._function_name.split("->", maxsplit=1)[0] in {"all", "selector"}:
+        # For multi-output aggregations, e.g. `df.group_by('a').agg(nw.all().mean())`, we skip
+        # the keys, else they would appear duplicated in the output.
+        output_names, aliases = zip(
+            *[(x, alias) for x, alias in zip(output_names, aliases) if x not in exclude]
+        )
+    return output_names, aliases
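Since every backend funnels through this helper, a self-contained mirror of its logic may help make it concrete (illustrative only, not part of the diff):

    from typing import Callable, Optional, Sequence

    def names_and_aliases(
        output_names: Sequence[str],
        aliaser: Optional[Callable[[Sequence[str]], Sequence[str]]],
        function_name: str,
        exclude: Sequence[str],
    ) -> tuple[Sequence[str], Sequence[str]]:
        # Same logic as evaluate_output_names_and_aliases, minus the expr object.
        if not output_names:
            return [], []
        aliases = output_names if aliaser is None else aliaser(output_names)
        if function_name.split("->", maxsplit=1)[0] in {"all", "selector"}:
            pairs = [(n, a) for n, a in zip(output_names, aliases) if n not in exclude]
            output_names, aliases = zip(*pairs)
        return output_names, aliases

    # nw.all().sum().name.suffix("_sum"), grouped by "a", over columns a and b:
    # names_and_aliases(["a", "b"], lambda ns: [f"{n}_sum" for n in ns], "all->sum", ["a"])
    # -> (("b",), ("b_sum",))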
diff --git a/narwhals/_pandas_like/expr.py b/narwhals/_pandas_like/expr.py
index 23021317e..2fb135dab 100644
--- a/narwhals/_pandas_like/expr.py
+++ b/narwhals/_pandas_like/expr.py
@@ -1,11 +1,14 @@
 from __future__ import annotations
 
+import re
 from typing import TYPE_CHECKING
 from typing import Any
 from typing import Callable
 from typing import Literal
 from typing import Sequence
 
+from narwhals._expression_parsing import evaluate_output_names_and_aliases
+from narwhals._expression_parsing import is_simple_aggregation
 from narwhals._expression_parsing import reuse_series_implementation
 from narwhals._pandas_like.expr_cat import PandasLikeExprCatNamespace
 from narwhals._pandas_like.expr_dt import PandasLikeExprDateTimeNamespace
@@ -16,7 +19,6 @@
 from narwhals._pandas_like.utils import rename
 from narwhals.dependencies import get_numpy
 from narwhals.dependencies import is_numpy_array
-from narwhals.exceptions import AnonymousExprError
 from narwhals.exceptions import ColumnNotFoundError
 from narwhals.typing import CompliantExpr
 
@@ -31,16 +33,16 @@
 from narwhals.utils import Version
 
 MANY_TO_MANY_AGG_FUNCTIONS_TO_PANDAS_EQUIVALENT = {
-    "col->cum_sum": "cumsum",
-    "col->cum_min": "cummin",
-    "col->cum_max": "cummax",
-    "col->cum_prod": "cumprod",
+    "cum_sum": "cumsum",
+    "cum_min": "cummin",
+    "cum_max": "cummax",
+    "cum_prod": "cumprod",
     # Pandas cumcount starts counting from 0 while Polars starts from 1
     # Pandas cumcount counts nulls while Polars does not
     # So, instead of using "cumcount" we use "cumsum" on notna() to get the same result
-    "col->cum_count": "cumsum",
-    "col->shift": "shift",
-    "col->rank": "rank",
+    "cum_count": "cumsum",
+    "shift": "shift",
+    "rank": "rank",
 }
 
@@ -51,8 +53,8 @@ def __init__(
         *,
         depth: int,
         function_name: str,
-        root_names: list[str] | None,
-        output_names: list[str] | None,
+        evaluate_output_names: Callable[[PandasLikeDataFrame], Sequence[str]],
+        alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None,
         implementation: Implementation,
         backend_version: tuple[int, ...],
         version: Version,
@@ -61,8 +63,8 @@ def __init__(
         self._call = call
         self._depth = depth
         self._function_name = function_name
-        self._root_names = root_names
-        self._output_names = output_names
+        self._evaluate_output_names = evaluate_output_names
+        self._alias_output_names = alias_output_names
         self._implementation = implementation
         self._backend_version = backend_version
         self._version = version
@@ -76,8 +78,6 @@ def __repr__(self) -> str:  # pragma: no cover
             f"PandasLikeExpr("
             f"depth={self._depth}, "
             f"function_name={self._function_name}, "
-            f"root_names={self._root_names}, "
-            f"output_names={self._output_names}"
             ")"
         )
 
@@ -120,8 +120,8 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
             func,
             depth=0,
             function_name="col",
-            root_names=list(column_names),
-            output_names=list(column_names),
+            evaluate_output_names=lambda _df: list(column_names),
+            alias_output_names=None,
             implementation=implementation,
             backend_version=backend_version,
             version=version,
@@ -151,8 +151,8 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
             func,
             depth=0,
             function_name="nth",
-            root_names=None,
-            output_names=None,
+            evaluate_output_names=lambda df: [df.columns[i] for i in column_indices],
+            alias_output_names=None,
             implementation=implementation,
             backend_version=backend_version,
             version=version,
@@ -376,14 +376,20 @@ def sample(
         )
 
     def alias(self: Self, name: str) -> Self:
+        def alias_output_names(names: Sequence[str]) -> Sequence[str]:
+            if len(names) != 1:
+                msg = f"Expected function with single output, found output names: {names}"
+                raise ValueError(msg)
+            return [name]
+
         # Define this one manually, so that we can
         # override `output_names` and not increase depth
         return self.__class__(
             lambda df: [series.alias(name) for series in self._call(df)],
             depth=self._depth,
             function_name=self._function_name,
-            root_names=self._root_names,
-            output_names=[name],
+            evaluate_output_names=self._evaluate_output_names,
+            alias_output_names=alias_output_names,
             implementation=self._implementation,
             backend_version=self._backend_version,
             version=self._version,
@@ -391,15 +397,14 @@ def alias(self: Self, name: str) -> Self:
         )
 
     def over(self: Self, keys: list[str]) -> Self:
-        if self._function_name in MANY_TO_MANY_AGG_FUNCTIONS_TO_PANDAS_EQUIVALENT:
+        if (
+            is_simple_aggregation(self)
+            and (function_name := re.sub(r"(\w+->)", "", self._function_name))
+            in MANY_TO_MANY_AGG_FUNCTIONS_TO_PANDAS_EQUIVALENT
+        ):
 
             def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
-                if (
-                    self._output_names is None or self._root_names is None
-                ):  # pragma: no cover
-                    # Technically unreachable, but keep this for safety
-                    msg = "over"
-                    raise AnonymousExprError.from_expr_name(msg)
+                output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
 
                 reverse = self._kwargs.get("reverse", False)
                 if reverse:
@@ -409,13 +414,13 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
                     )
                     raise NotImplementedError(msg)
 
-                if self._function_name == "col->cum_count":
+                if function_name == "cum_count":
                     plx = self.__narwhals_namespace__()
-                    df = df.with_columns(~plx.col(*self._root_names).is_null())
+                    df = df.with_columns(~plx.col(*output_names).is_null())
 
-                if self._function_name == "col->shift":
+                if function_name == "shift":
                     kwargs = {"periods": self._kwargs["n"]}
-                elif self._function_name == "col->rank":
+                elif function_name == "rank":
                     _method = self._kwargs.get("method", "average")
                     kwargs = {
                         "method": "first" if _method == "ordinal" else _method,
@@ -427,40 +432,46 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
                     kwargs = {"skipna": True}
 
                 res_native = getattr(
-                    df._native_frame.groupby(list(keys), as_index=False)[
-                        self._root_names
+                    df._native_frame.groupby([df._native_frame[key] for key in keys])[
+                        list(output_names)
                     ],
-                    MANY_TO_MANY_AGG_FUNCTIONS_TO_PANDAS_EQUIVALENT[self._function_name],
+                    MANY_TO_MANY_AGG_FUNCTIONS_TO_PANDAS_EQUIVALENT[function_name],
                 )(**kwargs)
-
                 result_frame = df._from_native_frame(
                     rename(
                         res_native,
-                        columns=dict(zip(self._root_names, self._output_names)),
+                        columns=dict(zip(output_names, aliases)),
                         implementation=self._implementation,
                         backend_version=self._backend_version,
                     )
                 )
-                return [result_frame[name] for name in self._output_names]
+                return [result_frame[name] for name in aliases]
 
         else:
 
             def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
-                if self._output_names is None:
-                    msg = "over"
-                    raise AnonymousExprError.from_expr_name(msg)
+                output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
+                if overlap := set(output_names).intersection(keys):
+                    # E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined,
+                    # we just don't support it yet.
+                    msg = (
+                        f"Column names {overlap} appear in both expression output names and in `over` keys.\n"
+                        "This is not yet supported."
+                    )
+                    raise NotImplementedError(msg)
+
                 tmp = df.group_by(*keys, drop_null_keys=False).agg(self)
                 tmp = df.simple_select(*keys).join(
                     tmp, how="left", left_on=keys, right_on=keys, suffix="_right"
                 )
-                return [tmp[name] for name in self._output_names]
+                return [tmp[name] for name in aliases]
 
         return self.__class__(
             func,
             depth=self._depth + 1,
             function_name=self._function_name + "->over",
-            root_names=self._root_names,
-            output_names=self._output_names,
+            evaluate_output_names=self._evaluate_output_names,
+            alias_output_names=self._alias_output_names,
             implementation=self._implementation,
             backend_version=self._backend_version,
             version=self._version,
@@ -536,8 +547,8 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
             func,
             depth=self._depth + 1,
             function_name=self._function_name + "->map_batches",
-            root_names=self._root_names,
-            output_names=self._output_names,
+            evaluate_output_names=self._evaluate_output_names,
+            alias_output_names=self._alias_output_names,
             implementation=self._implementation,
             backend_version=self._backend_version,
             version=self._version,
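A runnable sketch of the fast path the rewritten `over` match unlocks (pandas backend, invented data):

    import narwhals as nw
    import pandas as pd

    df = nw.from_native(pd.DataFrame({"a": [1, 1, 2], "b": [1, 2, 3]}))
    # The function name "col->cum_sum" strips to "cum_sum", which is in the
    # many-to-many table, so this runs as a native groupby cumsum rather
    # than falling back to the agg-plus-join path.
    print(df.with_columns(nw.col("b").cum_sum().over("a")).to_native())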
diff --git a/narwhals/_pandas_like/expr_name.py b/narwhals/_pandas_like/expr_name.py
index fb56c944a..065e4c34f 100644
--- a/narwhals/_pandas_like/expr_name.py
+++ b/narwhals/_pandas_like/expr_name.py
@@ -3,8 +3,6 @@
 from typing import TYPE_CHECKING
 from typing import Callable
 
-from narwhals.exceptions import AnonymousExprError
-
 if TYPE_CHECKING:
     from typing_extensions import Self
 
@@ -16,21 +14,18 @@ def __init__(self: Self, expr: PandasLikeExpr) -> None:
         self._compliant_expr = expr
 
     def keep(self: Self) -> PandasLikeExpr:
-        root_names = self._compliant_expr._root_names
-
-        if root_names is None:
-            msg = ".name.keep"
-            raise AnonymousExprError.from_expr_name(msg)
-
         return self._compliant_expr.__class__(
             lambda df: [
                 series.alias(name)
-                for series, name in zip(self._compliant_expr._call(df), root_names)
+                for series, name in zip(
+                    self._compliant_expr._call(df),
+                    self._compliant_expr._evaluate_output_names(df),
+                )
             ],
             depth=self._compliant_expr._depth,
             function_name=self._compliant_expr._function_name,
-            root_names=root_names,
-            output_names=root_names,
+            evaluate_output_names=self._compliant_expr._evaluate_output_names,
+            alias_output_names=None,
             implementation=self._compliant_expr._implementation,
             backend_version=self._compliant_expr._backend_version,
             version=self._compliant_expr._version,
@@ -38,23 +33,20 @@ def keep(self: Self) -> PandasLikeExpr:
         )
 
     def map(self: Self, function: Callable[[str], str]) -> PandasLikeExpr:
-        root_names = self._compliant_expr._root_names
-
-        if root_names is None:
-            msg = ".name.map"
-            raise AnonymousExprError.from_expr_name(msg)
-
-        output_names = [function(str(name)) for name in root_names]
-
         return self._compliant_expr.__class__(
             lambda df: [
-                series.alias(name)
-                for series, name in zip(self._compliant_expr._call(df), output_names)
+                series.alias(function(str(name)))
+                for series, name in zip(
+                    self._compliant_expr._call(df),
+                    self._compliant_expr._evaluate_output_names(df),
+                )
             ],
             depth=self._compliant_expr._depth,
             function_name=self._compliant_expr._function_name,
-            root_names=root_names,
-            output_names=output_names,
+            evaluate_output_names=self._compliant_expr._evaluate_output_names,
+            alias_output_names=lambda output_names: [
+                function(str(name)) for name in output_names
+            ],
             implementation=self._compliant_expr._implementation,
             backend_version=self._compliant_expr._backend_version,
             version=self._compliant_expr._version,
@@ -62,21 +54,20 @@ def map(self: Self, function: Callable[[str], str]) -> PandasLikeExpr:
         )
 
     def prefix(self: Self, prefix: str) -> PandasLikeExpr:
-        root_names = self._compliant_expr._root_names
-        if root_names is None:
-            msg = ".name.prefix"
-            raise AnonymousExprError.from_expr_name(msg)
-
-        output_names = [prefix + str(name) for name in root_names]
         return self._compliant_expr.__class__(
             lambda df: [
-                series.alias(name)
-                for series, name in zip(self._compliant_expr._call(df), output_names)
+                series.alias(f"{prefix}{name}")
+                for series, name in zip(
+                    self._compliant_expr._call(df),
+                    self._compliant_expr._evaluate_output_names(df),
+                )
             ],
             depth=self._compliant_expr._depth,
             function_name=self._compliant_expr._function_name,
-            root_names=root_names,
-            output_names=output_names,
+            evaluate_output_names=self._compliant_expr._evaluate_output_names,
+            alias_output_names=lambda output_names: [
+                f"{prefix}{output_name}" for output_name in output_names
+            ],
             implementation=self._compliant_expr._implementation,
             backend_version=self._compliant_expr._backend_version,
             version=self._compliant_expr._version,
@@ -84,22 +75,20 @@ def prefix(self: Self, prefix: str) -> PandasLikeExpr:
         )
 
     def suffix(self: Self, suffix: str) -> PandasLikeExpr:
-        root_names = self._compliant_expr._root_names
-        if root_names is None:
-            msg = ".name.suffix"
-            raise AnonymousExprError.from_expr_name(msg)
-
-        output_names = [str(name) + suffix for name in root_names]
-
         return self._compliant_expr.__class__(
             lambda df: [
-                series.alias(name)
-                for series, name in zip(self._compliant_expr._call(df), output_names)
+                series.alias(f"{name}{suffix}")
+                for series, name in zip(
+                    self._compliant_expr._call(df),
+                    self._compliant_expr._evaluate_output_names(df),
+                )
             ],
             depth=self._compliant_expr._depth,
             function_name=self._compliant_expr._function_name,
-            root_names=root_names,
-            output_names=output_names,
+            evaluate_output_names=self._compliant_expr._evaluate_output_names,
+            alias_output_names=lambda output_names: [
+                f"{output_name}{suffix}" for output_name in output_names
+            ],
             implementation=self._compliant_expr._implementation,
             backend_version=self._compliant_expr._backend_version,
             version=self._compliant_expr._version,
@@ -107,23 +96,20 @@ def suffix(self: Self, suffix: str) -> PandasLikeExpr:
         )
 
     def to_lowercase(self: Self) -> PandasLikeExpr:
-        root_names = self._compliant_expr._root_names
-
-        if root_names is None:
-            msg = ".name.to_lowercase"
-            raise AnonymousExprError.from_expr_name(msg)
-
-        output_names = [str(name).lower() for name in root_names]
-
         return self._compliant_expr.__class__(
             lambda df: [
-                series.alias(name)
-                for series, name in zip(self._compliant_expr._call(df), output_names)
+                series.alias(str(name).lower())
+                for series, name in zip(
+                    self._compliant_expr._call(df),
+                    self._compliant_expr._evaluate_output_names(df),
+                )
             ],
             depth=self._compliant_expr._depth,
             function_name=self._compliant_expr._function_name,
-            root_names=root_names,
-            output_names=output_names,
+            evaluate_output_names=self._compliant_expr._evaluate_output_names,
+            alias_output_names=lambda output_names: [
+                str(name).lower() for name in output_names
+            ],
             implementation=self._compliant_expr._implementation,
             backend_version=self._compliant_expr._backend_version,
             version=self._compliant_expr._version,
@@ -131,23 +117,20 @@ def to_lowercase(self: Self) -> PandasLikeExpr:
         )
 
     def to_uppercase(self: Self) -> PandasLikeExpr:
-        root_names = self._compliant_expr._root_names
-
-        if root_names is None:
-            msg = ".name.to_uppercase"
-            raise AnonymousExprError.from_expr_name(msg)
-
-        output_names = [str(name).upper() for name in root_names]
-
         return self._compliant_expr.__class__(
             lambda df: [
-                series.alias(name)
-                for series, name in zip(self._compliant_expr._call(df), output_names)
+                series.alias(str(name).upper())
+                for series, name in zip(
+                    self._compliant_expr._call(df),
+                    self._compliant_expr._evaluate_output_names(df),
+                )
             ],
             depth=self._compliant_expr._depth,
             function_name=self._compliant_expr._function_name,
-            root_names=root_names,
-            output_names=output_names,
+            evaluate_output_names=self._compliant_expr._evaluate_output_names,
+            alias_output_names=lambda output_names: [
+                str(name).upper() for name in output_names
+            ],
             implementation=self._compliant_expr._implementation,
             backend_version=self._compliant_expr._backend_version,
             version=self._compliant_expr._version,
diff --git a/narwhals/_pandas_like/group_by.py b/narwhals/_pandas_like/group_by.py
index b817f3dd6..95ad7730d 100644
--- a/narwhals/_pandas_like/group_by.py
+++ b/narwhals/_pandas_like/group_by.py
@@ -1,31 +1,27 @@
 from __future__ import annotations
 
 import collections
+import re
 import warnings
 from copy import copy
 from typing import TYPE_CHECKING
 from typing import Any
-from typing import Callable
 from typing import Iterator
-from typing import Sequence
 
+from narwhals._expression_parsing import evaluate_output_names_and_aliases
 from narwhals._expression_parsing import is_simple_aggregation
 from narwhals._pandas_like.utils import horizontal_concat
 from narwhals._pandas_like.utils import native_series_from_iterable
 from narwhals._pandas_like.utils import select_columns_by_name
 from narwhals._pandas_like.utils import set_columns
-from narwhals.exceptions import AnonymousExprError
 from narwhals.utils import Implementation
 from narwhals.utils import find_stacklevel
-from narwhals.utils import remove_prefix
 
 if TYPE_CHECKING:
     from typing_extensions import Self
 
     from narwhals._pandas_like.dataframe import PandasLikeDataFrame
     from narwhals._pandas_like.expr import PandasLikeExpr
-    from narwhals._pandas_like.series import PandasLikeSeries
-    from narwhals.typing import CompliantExpr
 
 POLARS_TO_PANDAS_AGGREGATIONS = {
     "sum": "sum",
@@ -80,286 +76,234 @@ def __init__(
             observed=True,
         )
 
-    def agg(
-        self: Self,
-        *exprs: PandasLikeExpr,
-    ) -> PandasLikeDataFrame:
-        implementation: Implementation = self._df._implementation
-        output_names: list[str] = copy(self._keys)
-        for expr in exprs:
-            if expr._output_names is None:
-                msg = "group_by.agg"
-                raise AnonymousExprError.from_expr_name(msg)
+    def agg(self: Self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame:  # noqa: PLR0915
+        implementation = self._df._implementation
+        backend_version = self._df._backend_version
 
-            output_names.extend(expr._output_names)
+        new_names: list[str] = copy(self._keys)
+        for expr in exprs:
+            _, aliases = evaluate_output_names_and_aliases(expr, self._df, self._keys)
+            new_names.extend(aliases)
 
-        return agg_pandas(
-            self._grouped,
-            exprs,
-            self._keys,
-            output_names,
-            self._from_native_frame,
-            dataframe_is_empty=self._df._native_frame.empty,
-            implementation=implementation,
-            backend_version=self._df._backend_version,
-            native_namespace=self._df.__native_namespace__(),
-        )
+        all_aggs_are_simple = True
+        for expr in exprs:
+            if not (
+                is_simple_aggregation(expr)
+                and re.sub(r"(\w+->)", "", expr._function_name)
+                in POLARS_TO_PANDAS_AGGREGATIONS
+            ):
+                all_aggs_are_simple = False
 
-    def _from_native_frame(self: Self, df: PandasLikeDataFrame) -> PandasLikeDataFrame:
-        from narwhals._pandas_like.dataframe import PandasLikeDataFrame
+        # dict of {output_name: root_name} that we count n_unique on
+        # We need to do this separately from the rest so that we
+        # can pass the `dropna` kwargs.
+        nunique_aggs: dict[str, str] = {}
+        simple_aggs: dict[str, list[str]] = collections.defaultdict(list)
 
-        return PandasLikeDataFrame(
-            df,
-            implementation=self._df._implementation,
-            backend_version=self._df._backend_version,
-            version=self._df._version,
+        # ddof to (output_names, aliases) mapping
+        std_aggs: dict[int, tuple[list[str], list[str]]] = collections.defaultdict(
+            lambda: ([], [])
+        )
+        var_aggs: dict[int, tuple[list[str], list[str]]] = collections.defaultdict(
+            lambda: ([], [])
         )
 
-    def __iter__(self: Self) -> Iterator[tuple[Any, PandasLikeDataFrame]]:
-        with warnings.catch_warnings():
-            warnings.filterwarnings(
-                "ignore",
-                message=".*a length 1 tuple will be returned",
-                category=FutureWarning,
-            )
-            for key, group in self._grouped:
-                yield (key, self._from_native_frame(group))
-
-
-def agg_pandas(  # noqa: PLR0915
-    grouped: Any,
-    exprs: Sequence[CompliantExpr[PandasLikeSeries]],
-    keys: list[str],
-    output_names: list[str],
-    from_dataframe: Callable[[Any], PandasLikeDataFrame],
-    *,
-    implementation: Any,
-    backend_version: tuple[int, ...],
-    dataframe_is_empty: bool,
-    native_namespace: Any,
-) -> PandasLikeDataFrame:
-    """This should be the fastpath, but cuDF is too far behind to use it.
-
-    - https://github.com/rapidsai/cudf/issues/15118
-    - https://github.com/rapidsai/cudf/issues/15084
-    """
-    all_aggs_are_simple = True
-    for expr in exprs:
-        if not (
-            is_simple_aggregation(expr)
-            and remove_prefix(expr._function_name, "col->")
-            in POLARS_TO_PANDAS_AGGREGATIONS
-        ):
-            all_aggs_are_simple = False
-            break
-
-    # dict of {output_name: root_name} that we count n_unique on
-    # We need to do this separately from the rest so that we
-    # can pass the `dropna` kwargs.
-    nunique_aggs: dict[str, str] = {}
-    simple_aggs: dict[str, list[str]] = collections.defaultdict(list)
-
-    # ddof to (root_names, output_names) mapping
-    std_aggs: dict[int, tuple[list[str], list[str]]] = collections.defaultdict(
-        lambda: ([], [])
-    )
-    var_aggs: dict[int, tuple[list[str], list[str]]] = collections.defaultdict(
-        lambda: ([], [])
-    )
-
-    expected_old_names: list[str] = []
-    new_names: list[str] = []
-
-    if all_aggs_are_simple:
-        for expr in exprs:
-            if expr._depth == 0:
-                # e.g. agg(nw.len())  # noqa: ERA001
agg(nw.len()) # noqa: ERA001 - if expr._output_names is None: # pragma: no cover - msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" - raise AssertionError(msg) + expected_old_names: list[str] = [] + simple_agg_new_names: list[str] = [] + if all_aggs_are_simple: + for expr in exprs: + output_names, aliases = evaluate_output_names_and_aliases( + expr, self._df, self._keys + ) + if expr._depth == 0: + # e.g. agg(nw.len()) # noqa: ERA001 + function_name = POLARS_TO_PANDAS_AGGREGATIONS.get( + expr._function_name, expr._function_name + ) + for alias in aliases: + expected_old_names.append(f"{self._keys[0]}_{function_name}") + simple_aggs[self._keys[0]].append(function_name) + simple_agg_new_names.append(alias) + continue + + # e.g. agg(nw.mean('a')) # noqa: ERA001 + function_name = re.sub(r"(\w+->)", "", expr._function_name) function_name = POLARS_TO_PANDAS_AGGREGATIONS.get( - expr._function_name, expr._function_name + function_name, function_name ) - for output_name in expr._output_names: - new_names.append(output_name) - expected_old_names.append(f"{keys[0]}_{function_name}") - simple_aggs[keys[0]].append(function_name) - continue - - # e.g. agg(nw.mean('a')) # noqa: ERA001 - if ( - expr._depth != 1 or expr._root_names is None or expr._output_names is None - ): # pragma: no cover - msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" - raise AssertionError(msg) - - function_name = remove_prefix(expr._function_name, "col->") - function_name = POLARS_TO_PANDAS_AGGREGATIONS.get( - function_name, function_name - ) - is_n_unique = function_name == "nunique" - is_std = function_name == "std" - is_var = function_name == "var" - for root_name, output_name in zip(expr._root_names, expr._output_names): - if is_n_unique: - nunique_aggs[output_name] = root_name - elif is_std and (ddof := expr._kwargs["ddof"]) != 1: - std_aggs[ddof][0].append(root_name) - std_aggs[ddof][1].append(output_name) - elif is_var and (ddof := expr._kwargs["ddof"]) != 1: - var_aggs[ddof][0].append(root_name) - var_aggs[ddof][1].append(output_name) - else: - new_names.append(output_name) - expected_old_names.append(f"{root_name}_{function_name}") - simple_aggs[root_name].append(function_name) - - result_aggs = [] + is_n_unique = function_name == "nunique" + is_std = function_name == "std" + is_var = function_name == "var" + for output_name, alias in zip(output_names, aliases): + if is_n_unique: + nunique_aggs[alias] = output_name + elif is_std and (ddof := expr._kwargs["ddof"]) != 1: + std_aggs[ddof][0].append(output_name) + std_aggs[ddof][1].append(alias) + elif is_var and (ddof := expr._kwargs["ddof"]) != 1: + var_aggs[ddof][0].append(output_name) + var_aggs[ddof][1].append(alias) + else: + expected_old_names.append(f"{output_name}_{function_name}") + simple_aggs[output_name].append(function_name) + simple_agg_new_names.append(alias) + + result_aggs = [] + + if simple_aggs: + result_simple_aggs = self._grouped.agg(simple_aggs) + result_simple_aggs.columns = [ + f"{a}_{b}" for a, b in result_simple_aggs.columns + ] + if not ( + set(result_simple_aggs.columns) == set(expected_old_names) + and len(result_simple_aggs.columns) == len(expected_old_names) + ): # pragma: no cover + msg = ( + f"Safety assertion failed, expected {expected_old_names} " + f"got {result_simple_aggs.columns}, " + "please report a bug at https://github.com/narwhals-dev/narwhals/issues" + ) + raise AssertionError(msg) - if simple_aggs: - result_simple_aggs = 
grouped.agg(simple_aggs) - result_simple_aggs.columns = [ - f"{a}_{b}" for a, b in result_simple_aggs.columns - ] - if not ( - set(result_simple_aggs.columns) == set(expected_old_names) - and len(result_simple_aggs.columns) == len(expected_old_names) - ): # pragma: no cover - msg = ( - f"Safety assertion failed, expected {expected_old_names} " - f"got {result_simple_aggs.columns}, " - "please report a bug at https://github.com/narwhals-dev/narwhals/issues" + # Rename columns, being very careful + expected_old_names_indices: dict[str, list[int]] = ( + collections.defaultdict(list) ) - raise AssertionError(msg) - - # Rename columns, being very careful - expected_old_names_indices: dict[str, list[int]] = collections.defaultdict( - list - ) - for idx, item in enumerate(expected_old_names): - expected_old_names_indices[item].append(idx) - index_map: list[int] = [ - expected_old_names_indices[item].pop(0) - for item in result_simple_aggs.columns - ] - new_names = [new_names[i] for i in index_map] - result_simple_aggs.columns = new_names + for idx, item in enumerate(expected_old_names): + expected_old_names_indices[item].append(idx) + index_map: list[int] = [ + expected_old_names_indices[item].pop(0) + for item in result_simple_aggs.columns + ] + result_simple_aggs.columns = [simple_agg_new_names[i] for i in index_map] + result_aggs.append(result_simple_aggs) - result_aggs.append(result_simple_aggs) + if nunique_aggs: + result_nunique_aggs = self._grouped[list(nunique_aggs.values())].nunique( + dropna=False + ) + result_nunique_aggs.columns = list(nunique_aggs.keys()) + + result_aggs.append(result_nunique_aggs) + + if std_aggs: + result_aggs.extend( + [ + set_columns( + self._grouped[std_output_names].std(ddof=ddof), + columns=std_aliases, + implementation=implementation, + backend_version=backend_version, + ) + for ddof, (std_output_names, std_aliases) in std_aggs.items() + ] + ) + if var_aggs: + result_aggs.extend( + [ + set_columns( + self._grouped[var_output_names].var(ddof=ddof), + columns=var_aliases, + implementation=implementation, + backend_version=backend_version, + ) + for ddof, (var_output_names, var_aliases) in var_aggs.items() + ] + ) - if nunique_aggs: - result_nunique_aggs = grouped[list(nunique_aggs.values())].nunique( - dropna=False + if result_aggs: + output_names_counter = collections.Counter( + [c for frame in result_aggs for c in frame] + ) + if any(v > 1 for v in output_names_counter.values()): + msg = "" + for key, value in output_names_counter.items(): + if value > 1: + msg += f"\n- '{key}' {value} times" + else: # pragma: no cover + pass + msg = f"Expected unique output names, got:{msg}" + raise ValueError(msg) + result = horizontal_concat( + dfs=result_aggs, + implementation=implementation, + backend_version=backend_version, + ) + else: + # No aggregation provided + result = self._df.__native_namespace__().DataFrame( + list(self._grouped.groups.keys()), columns=self._keys + ) + # Keep inplace=True to avoid making a redundant copy. 
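# ---- Editor's aside (not part of the diff) ---------------------------------
# A minimal sketch of the careful rename above. pandas labels each aggregated
# column "<column>_<agg>", and `.agg()` may hand the columns back in a
# different order than the expressions were given in, so aliases are matched
# back positionally. All names and data below are hypothetical.
import collections

expected_old_names = ["b_sum", "a_sum"]        # order the expressions came in
simple_agg_new_names = ["b_total", "a_total"]  # their aliases, same order
result_columns = ["a_sum", "b_sum"]            # order pandas returned them in

indices: dict[str, list[int]] = collections.defaultdict(list)
for idx, item in enumerate(expected_old_names):
    indices[item].append(idx)
index_map = [indices[item].pop(0) for item in result_columns]
assert [simple_agg_new_names[i] for i in index_map] == ["a_total", "b_total"]
# -----------------------------------------------------------------------------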
+ # This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files + result.reset_index(inplace=True) # noqa: PD002 + return self._df._from_native_frame( + select_columns_by_name(result, new_names, backend_version, implementation) ) - result_nunique_aggs.columns = list(nunique_aggs.keys()) - - result_aggs.append(result_nunique_aggs) - if std_aggs: - result_aggs.extend( - [ - set_columns( - grouped[std_root_names].std(ddof=ddof), - columns=std_output_names, - implementation=implementation, - backend_version=backend_version, - ) - for ddof, (std_root_names, std_output_names) in std_aggs.items() - ] - ) - if var_aggs: - result_aggs.extend( - [ - set_columns( - grouped[var_root_names].var(ddof=ddof), - columns=var_output_names, - implementation=implementation, - backend_version=backend_version, - ) - for ddof, (var_root_names, var_output_names) in var_aggs.items() - ] + if self._df._native_frame.empty: + # Don't even attempt this, it's way too inconsistent across pandas versions. + msg = ( + "No results for group-by aggregation.\n\n" + "Hint: you were probably trying to apply a non-elementary aggregation with a " + "pandas-like API.\n" + "Please rewrite your query such that group-by aggregations " + "are elementary. For example, instead of:\n\n" + " df.group_by('a').agg(nw.col('b').round(2).mean())\n\n" + "use:\n\n" + " df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n" ) + raise ValueError(msg) + + warnings.warn( + "Found complex group-by expression, which can't be expressed efficiently with the " + "pandas API. If you can, please rewrite your query such that group-by aggregations " + "are simple (e.g. mean, std, min, max, ...). \n\n" + "Please see: " + "https://narwhals-dev.github.io/narwhals/pandas_like_concepts/improve_group_by_operation/", + UserWarning, + stacklevel=find_stacklevel(), + ) - if result_aggs: - output_names_counter = collections.Counter( - [c for frame in result_aggs for c in frame] - ) - if any(v > 1 for v in output_names_counter.values()): - msg = "" - for key, value in output_names_counter.items(): - if value > 1: - msg += f"\n- '{key}' {value} times" - else: # pragma: no cover - pass - msg = f"Expected unique output names, got:{msg}" - raise ValueError(msg) - result = horizontal_concat( - dfs=result_aggs, + def func(df: Any) -> Any: + out_group = [] + out_names = [] + for expr in exprs: + results_keys = expr(self._df._from_native_frame(df)) + for result_keys in results_keys: + out_group.append(result_keys._native_series.iloc[0]) + out_names.append(result_keys.name) + return native_series_from_iterable( + out_group, + index=out_names, + name="", implementation=implementation, - backend_version=backend_version, ) - else: - # No aggregation provided - result = native_namespace.DataFrame(list(grouped.groups.keys()), columns=keys) - # Keep inplace=True to avoid making a redundant copy. - # This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files - result.reset_index(inplace=True) # noqa: PD002 - return from_dataframe( - select_columns_by_name(result, output_names, backend_version, implementation) - ) - if dataframe_is_empty: - # Don't even attempt this, it's way too inconsistent across pandas versions. - msg = ( - "No results for group-by aggregation.\n\n" - "Hint: you were probably trying to apply a non-elementary aggregation with a " - "pandas-like API.\n" - "Please rewrite your query such that group-by aggregations " - "are elementary. 
For example, instead of:\n\n" - " df.group_by('a').agg(nw.col('b').round(2).mean())\n\n" - "use:\n\n" - " df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n" - ) - raise ValueError(msg) + if implementation is Implementation.PANDAS and backend_version >= (2, 2): + result_complex = self._grouped.apply(func, include_groups=False) + else: # pragma: no cover + result_complex = self._grouped.apply(func) - warnings.warn( - "Found complex group-by expression, which can't be expressed efficiently with the " - "pandas API. If you can, please rewrite your query such that group-by aggregations " - "are simple (e.g. mean, std, min, max, ...). \n\n" - "Please see: " - "https://narwhals-dev.github.io/narwhals/pandas_like_concepts/improve_group_by_operation/", - UserWarning, - stacklevel=find_stacklevel(), - ) + # Keep inplace=True to avoid making a redundant copy. + # This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files + result_complex.reset_index(inplace=True) # noqa: PD002 - def func(df: Any) -> Any: - out_group = [] - out_names = [] - for expr in exprs: - results_keys = expr(from_dataframe(df)) - for result_keys in results_keys: - out_group.append(result_keys._native_series.iloc[0]) - out_names.append(result_keys.name) - return native_series_from_iterable( - out_group, - index=out_names, - name="", - implementation=implementation, + return self._df._from_native_frame( + select_columns_by_name( + result_complex, new_names, backend_version, implementation + ) ) - if implementation is Implementation.PANDAS and backend_version >= (2, 2): - result_complex = grouped.apply(func, include_groups=False) - else: # pragma: no cover - result_complex = grouped.apply(func) - - # Keep inplace=True to avoid making a redundant copy. 
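# ---- Editor's aside (not part of the diff) ---------------------------------
# A rough plain-pandas sketch of the slow fallback above: each expression is
# evaluated per group and reduced to a single value, and pandas stitches one
# row per group together via `apply`. The data and the round/mean pipeline are
# hypothetical, echoing the hint in the error message.
import pandas as pd

pd_df = pd.DataFrame({"a": [1, 1, 2], "b": [4.567, 5.123, 6.89]})

def per_group(g: pd.DataFrame) -> pd.Series:
    # stand-in for evaluating each narwhals expression and taking `.iloc[0]`
    return pd.Series({"b": g["b"].round(2).mean()})

result = pd_df.groupby("a")[["b"]].apply(per_group).reset_index()
# -----------------------------------------------------------------------------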
- # This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files - result_complex.reset_index(inplace=True) # noqa: PD002 - - return from_dataframe( - select_columns_by_name( - result_complex, output_names, backend_version, implementation - ) - ) + def __iter__(self: Self) -> Iterator[tuple[Any, PandasLikeDataFrame]]: + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message=".*a length 1 tuple will be returned", + category=FutureWarning, + ) + for key, group in self._grouped: + yield (key, self._df._from_native_frame(group)) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 8ace4351b..8ba7879f6 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -8,9 +8,9 @@ from typing import Literal from typing import Sequence -from narwhals._expression_parsing import combine_root_names +from narwhals._expression_parsing import combine_alias_output_names +from narwhals._expression_parsing import combine_evaluate_output_names from narwhals._expression_parsing import parse_into_exprs -from narwhals._expression_parsing import reduce_output_names from narwhals._pandas_like.dataframe import PandasLikeDataFrame from narwhals._pandas_like.expr import PandasLikeExpr from narwhals._pandas_like.selectors import PandasSelectorNamespace @@ -57,16 +57,16 @@ def _create_expr_from_callable( *, depth: int, function_name: str, - root_names: list[str] | None, - output_names: list[str] | None, + evaluate_output_names: Callable[[PandasLikeDataFrame], Sequence[str]], + alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, kwargs: dict[str, Any], ) -> PandasLikeExpr: return PandasLikeExpr( func, depth=depth, function_name=function_name, - root_names=root_names, - output_names=output_names, + evaluate_output_names=evaluate_output_names, + alias_output_names=alias_output_names, implementation=self._implementation, backend_version=self._backend_version, version=self._version, @@ -90,8 +90,8 @@ def _create_expr_from_series(self: Self, series: PandasLikeSeries) -> PandasLike lambda _df: [series], depth=0, function_name="series", - root_names=None, - output_names=None, + evaluate_output_names=lambda _df: [series.name], + alias_output_names=None, implementation=self._implementation, backend_version=self._backend_version, version=self._version, @@ -136,8 +136,8 @@ def all(self: Self) -> PandasLikeExpr: ], depth=0, function_name="all", - root_names=None, - output_names=None, + evaluate_output_names=lambda df: df.columns, + alias_output_names=None, implementation=self._implementation, backend_version=self._backend_version, version=self._version, @@ -162,8 +162,8 @@ def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries: lambda df: [_lit_pandas_series(df)], depth=0, function_name="lit", - root_names=None, - output_names=["literal"], + evaluate_output_names=lambda _df: ["literal"], + alias_output_names=None, implementation=self._implementation, backend_version=self._backend_version, version=self._version, @@ -184,8 +184,8 @@ def len(self: Self) -> PandasLikeExpr: ], depth=0, function_name="len", - root_names=None, - output_names=["len"], + evaluate_output_names=lambda _df: ["len"], + alias_output_names=None, implementation=self._implementation, backend_version=self._backend_version, version=self._version, @@ -208,8 +208,8 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: func=func, depth=max(x._depth for x in parsed_exprs) + 1, function_name="sum_horizontal", - 
root_names=combine_root_names(parsed_exprs), - output_names=reduce_output_names(parsed_exprs), + evaluate_output_names=combine_evaluate_output_names(*parsed_exprs), + alias_output_names=combine_alias_output_names(*parsed_exprs), kwargs={"exprs": exprs}, ) @@ -224,8 +224,8 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: func=func, depth=max(x._depth for x in parsed_exprs) + 1, function_name="all_horizontal", - root_names=combine_root_names(parsed_exprs), - output_names=reduce_output_names(parsed_exprs), + evaluate_output_names=combine_evaluate_output_names(*parsed_exprs), + alias_output_names=combine_alias_output_names(*parsed_exprs), kwargs={"exprs": exprs}, ) @@ -240,8 +240,8 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: func=func, depth=max(x._depth for x in parsed_exprs) + 1, function_name="any_horizontal", - root_names=combine_root_names(parsed_exprs), - output_names=reduce_output_names(parsed_exprs), + evaluate_output_names=combine_evaluate_output_names(*parsed_exprs), + alias_output_names=combine_alias_output_names(*parsed_exprs), kwargs={"exprs": exprs}, ) @@ -263,8 +263,8 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: func=func, depth=max(x._depth for x in parsed_exprs) + 1, function_name="mean_horizontal", - root_names=combine_root_names(parsed_exprs), - output_names=reduce_output_names(parsed_exprs), + evaluate_output_names=combine_evaluate_output_names(*parsed_exprs), + alias_output_names=combine_alias_output_names(*parsed_exprs), kwargs={"exprs": exprs}, ) @@ -289,8 +289,8 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: func=func, depth=max(x._depth for x in parsed_exprs) + 1, function_name="min_horizontal", - root_names=combine_root_names(parsed_exprs), - output_names=reduce_output_names(parsed_exprs), + evaluate_output_names=combine_evaluate_output_names(*parsed_exprs), + alias_output_names=combine_alias_output_names(*parsed_exprs), kwargs={"exprs": exprs}, ) @@ -315,8 +315,8 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: func=func, depth=max(x._depth for x in parsed_exprs) + 1, function_name="max_horizontal", - root_names=combine_root_names(parsed_exprs), - output_names=reduce_output_names(parsed_exprs), + evaluate_output_names=combine_evaluate_output_names(*parsed_exprs), + alias_output_names=combine_alias_output_names(*parsed_exprs), kwargs={"exprs": exprs}, ) @@ -421,8 +421,8 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: func=func, depth=max(x._depth for x in parsed_exprs) + 1, function_name="concat_str", - root_names=combine_root_names(parsed_exprs), - output_names=reduce_output_names(parsed_exprs), + evaluate_output_names=combine_evaluate_output_names(*parsed_exprs), + alias_output_names=combine_alias_output_names(*parsed_exprs), kwargs={ "exprs": exprs, "separator": separator, @@ -500,8 +500,10 @@ def then(self: Self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasTh self, depth=0, function_name="whenthen", - root_names=None, - output_names=None, + evaluate_output_names=getattr( + value, "_evaluate_output_names", lambda _df: ["literal"] + ), + alias_output_names=getattr(value, "_alias_output_names", None), implementation=self._implementation, backend_version=self._backend_version, version=self._version, @@ -516,8 +518,8 @@ def __init__( *, depth: int, function_name: str, - root_names: list[str] | None, - output_names: list[str] | None, + evaluate_output_names: Callable[[PandasLikeDataFrame], Sequence[str]], + alias_output_names: Callable[[Sequence[str]], 
Sequence[str]] | None, implementation: Implementation, backend_version: tuple[int, ...], version: Version, @@ -529,8 +531,8 @@ def __init__( self._call = call self._depth = depth self._function_name = function_name - self._root_names = root_names - self._output_names = output_names + self._evaluate_output_names = evaluate_output_names + self._alias_output_names = alias_output_names self._kwargs = kwargs def otherwise( diff --git a/narwhals/_pandas_like/selectors.py b/narwhals/_pandas_like/selectors.py index b3518283f..17abc541c 100644 --- a/narwhals/_pandas_like/selectors.py +++ b/narwhals/_pandas_like/selectors.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING from typing import Any +from typing import Sequence from narwhals._pandas_like.expr import PandasLikeExpr from narwhals.utils import import_dtypes_module @@ -32,12 +33,15 @@ def by_dtype(self: Self, dtypes: list[DType | type[DType]]) -> PandasSelector: def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return [df[col] for col in df.columns if df.schema[col] in dtypes] + def evaluate_output_names(df: PandasLikeDataFrame) -> Sequence[str]: + return [col for col in df.columns if df.schema[col] in dtypes] + return PandasSelector( func, depth=0, - function_name="type_selector", - root_names=None, - output_names=None, + function_name="selector", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, implementation=self._implementation, backend_version=self._backend_version, version=self._version, @@ -80,9 +84,9 @@ def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: return PandasSelector( func, depth=0, - function_name="type_selector", - root_names=None, - output_names=None, + function_name="selector", + evaluate_output_names=lambda df: df.columns, + alias_output_names=None, implementation=self._implementation, backend_version=self._backend_version, version=self._version, @@ -96,8 +100,6 @@ def __repr__(self) -> str: # pragma: no cover f"PandasSelector(" f"depth={self._depth}, " f"function_name={self._function_name}, " - f"root_names={self._root_names}, " - f"output_names={self._output_names}" ) def _to_expr(self: Self) -> PandasLikeExpr: return PandasLikeExpr( self._call, depth=self._depth, function_name=self._function_name, - root_names=self._root_names, - output_names=self._output_names, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, implementation=self._implementation, backend_version=self._backend_version, version=self._version, @@ -117,16 +119,22 @@ def __sub__(self: Self, other: PandasSelector | Any) -> PandasSelector | Any: if isinstance(other, PandasSelector): def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: + lhs_names = self._evaluate_output_names(df) + rhs_names = other._evaluate_output_names(df) lhs = self._call(df) - rhs = other._call(df) - return [x for x in lhs if x.name not in {x.name for x in rhs}] + return [x for x, name in zip(lhs, lhs_names) if name not in rhs_names] + + def evaluate_output_names(df: PandasLikeDataFrame) -> list[str]: + lhs_names = self._evaluate_output_names(df) + rhs_names = other._evaluate_output_names(df) + return [x for x in lhs_names if x not in rhs_names] return PandasSelector( call, depth=0, - function_name="type_selector", - root_names=None, - output_names=None, + function_name="selector", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, implementation=self._implementation, backend_version=self._backend_version,
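# ---- Editor's aside (not part of the diff) ---------------------------------
# The selector set operations in this file (`__sub__`, `__or__`, `__and__`)
# reduce to order-preserving name arithmetic over lazily-evaluated column
# names; with hypothetical names:
lhs_names, rhs_names = ["a", "b", "c"], ["b"]
assert [x for x in lhs_names if x not in rhs_names] == ["a", "c"]                      # __sub__
assert [*(x for x in lhs_names if x not in rhs_names), *rhs_names] == ["a", "c", "b"]  # __or__
assert [x for x in lhs_names if x in rhs_names] == ["b"]                               # __and__
# -----------------------------------------------------------------------------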
version=self._version, @@ -139,16 +147,26 @@ def __or__(self: Self, other: PandasSelector | Any) -> PandasSelector | Any: if isinstance(other, PandasSelector): def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: + lhs_names = self._evaluate_output_names(df) + rhs_names = other._evaluate_output_names(df) lhs = self._call(df) rhs = other._call(df) - return [*(x for x in lhs if x.name not in {x.name for x in rhs}), *rhs] + return [ + *(x for x, name in zip(lhs, lhs_names) if name not in rhs_names), + *rhs, + ] + + def evaluate_output_names(df: PandasLikeDataFrame) -> list[str]: + lhs_names = self._evaluate_output_names(df) + rhs_names = other._evaluate_output_names(df) + return [*(x for x in lhs_names if x not in rhs_names), *rhs_names] return PandasSelector( call, depth=0, - function_name="type_selector", - root_names=None, - output_names=None, + function_name="selector", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, implementation=self._implementation, backend_version=self._backend_version, version=self._version, @@ -161,16 +179,22 @@ def __and__(self: Self, other: PandasSelector | Any) -> PandasSelector | Any: if isinstance(other, PandasSelector): def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: + lhs_names = self._evaluate_output_names(df) + rhs_names = other._evaluate_output_names(df) lhs = self._call(df) - rhs = other._call(df) - return [x for x in lhs if x.name in {x.name for x in rhs}] + return [x for x, name in zip(lhs, lhs_names) if name in rhs_names] + + def evaluate_output_names(df: PandasLikeDataFrame) -> list[str]: + lhs_names = self._evaluate_output_names(df) + rhs_names = other._evaluate_output_names(df) + return [x for x in lhs_names if x in rhs_names] return PandasSelector( call, depth=0, - function_name="type_selector", - root_names=None, - output_names=None, + function_name="selector", + evaluate_output_names=evaluate_output_names, + alias_output_names=None, implementation=self._implementation, backend_version=self._backend_version, version=self._version, diff --git a/narwhals/_spark_like/expr.py b/narwhals/_spark_like/expr.py index 22ca07a7b..df9e8d42a 100644 --- a/narwhals/_spark_like/expr.py +++ b/narwhals/_spark_like/expr.py @@ -9,11 +9,10 @@ from pyspark.sql import Window from pyspark.sql import functions as F # noqa: N812 -from narwhals._expression_parsing import infer_new_root_output_names from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace from narwhals._spark_like.expr_name import SparkLikeExprNameNamespace from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace -from narwhals._spark_like.utils import get_column_name +from narwhals._spark_like.utils import binary_operation_returns_scalar from narwhals._spark_like.utils import maybe_evaluate from narwhals._spark_like.utils import narwhals_to_native_dtype from narwhals.typing import CompliantExpr @@ -39,8 +38,8 @@ def __init__( *, depth: int, function_name: str, - root_names: list[str] | None, - output_names: list[str] | None, + evaluate_output_names: Callable[[SparkLikeLazyFrame], Sequence[str]], + alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, # Whether the expression is a length-1 Column resulting from # a reduction, such as `nw.col('a').sum()` returns_scalar: bool, @@ -51,8 +50,8 @@ def __init__( self._call = call self._depth = depth self._function_name = function_name - self._root_names = root_names - self._output_names = output_names + self._evaluate_output_names = evaluate_output_names + 
self._alias_output_names = alias_output_names self._returns_scalar = returns_scalar self._backend_version = backend_version self._version = version @@ -85,8 +84,8 @@ def func(_: SparkLikeLazyFrame) -> list[Column]: func, depth=0, function_name="col", - root_names=list(column_names), - output_names=list(column_names), + evaluate_output_names=lambda _df: list(column_names), + alias_output_names=None, returns_scalar=False, backend_version=backend_version, version=version, @@ -108,8 +107,8 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: func, depth=0, function_name="nth", - root_names=None, - output_names=None, + evaluate_output_names=lambda df: [df.columns[i] for i in column_indices], + alias_output_names=None, returns_scalar=False, backend_version=backend_version, version=version, @@ -122,32 +121,29 @@ def _from_call( expr_name: str, *, returns_scalar: bool, - **kwargs: Any, + **expressifiable_args: Self | Any, ) -> Self: def func(df: SparkLikeLazyFrame) -> list[Column]: - results = [] - inputs = self._call(df) - _kwargs = {key: maybe_evaluate(df, value) for key, value in kwargs.items()} - for _input in inputs: - input_col_name = get_column_name(df, _input) - column_result = call(_input, **_kwargs) - if not returns_scalar: - column_result = column_result.alias(input_col_name) - results.append(column_result) - return results - - root_names, output_names = infer_new_root_output_names(self, **kwargs) + native_series_list = self._call(df) + other_native_series = { + key: maybe_evaluate(df, value) + for key, value in expressifiable_args.items() + } + return [ + call(native_series, **other_native_series) + for native_series in native_series_list + ] return self.__class__( func, depth=self._depth + 1, function_name=f"{self._function_name}->{expr_name}", - root_names=root_names, - output_names=output_names, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, returns_scalar=self._returns_scalar or returns_scalar, backend_version=self._backend_version, version=self._version, - kwargs=kwargs, + kwargs=expressifiable_args, ) def __eq__(self: Self, other: SparkLikeExpr) -> Self: # type: ignore[override] @@ -155,7 +151,7 @@ def __eq__(self: Self, other: SparkLikeExpr) -> Self: # type: ignore[override] lambda _input, other: _input.__eq__(other), "__eq__", other=other, - returns_scalar=False, + returns_scalar=binary_operation_returns_scalar(self, other), ) def __ne__(self: Self, other: SparkLikeExpr) -> Self: # type: ignore[override] @@ -163,7 +159,7 @@ def __ne__(self: Self, other: SparkLikeExpr) -> Self: # type: ignore[override] lambda _input, other: _input.__ne__(other), "__ne__", other=other, - returns_scalar=False, + returns_scalar=binary_operation_returns_scalar(self, other), ) def __add__(self: Self, other: SparkLikeExpr) -> Self: @@ -171,7 +167,7 @@ def __add__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__add__(other), "__add__", other=other, - returns_scalar=False, + returns_scalar=binary_operation_returns_scalar(self, other), ) def __sub__(self: Self, other: SparkLikeExpr) -> Self: @@ -179,7 +175,7 @@ def __sub__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__sub__(other), "__sub__", other=other, - returns_scalar=False, + returns_scalar=binary_operation_returns_scalar(self, other), ) def __mul__(self: Self, other: SparkLikeExpr) -> Self: @@ -187,7 +183,7 @@ def __mul__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__mul__(other), "__mul__", other=other, - 
returns_scalar=False, + returns_scalar=binary_operation_returns_scalar(self, other), ) def __truediv__(self: Self, other: SparkLikeExpr) -> Self: @@ -195,7 +191,7 @@ def __truediv__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__truediv__(other), "__truediv__", other=other, - returns_scalar=False, + returns_scalar=binary_operation_returns_scalar(self, other), ) def __floordiv__(self: Self, other: SparkLikeExpr) -> Self: @@ -203,7 +199,10 @@ def _floordiv(_input: Column, other: Column) -> Column: return F.floor(_input / other) return self._from_call( - _floordiv, "__floordiv__", other=other, returns_scalar=False + _floordiv, + "__floordiv__", + other=other, + returns_scalar=binary_operation_returns_scalar(self, other), ) def __pow__(self: Self, other: SparkLikeExpr) -> Self: @@ -211,7 +210,7 @@ def __pow__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__pow__(other), "__pow__", other=other, - returns_scalar=False, + returns_scalar=binary_operation_returns_scalar(self, other), ) def __mod__(self: Self, other: SparkLikeExpr) -> Self: @@ -219,7 +218,7 @@ def __mod__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__mod__(other), "__mod__", other=other, - returns_scalar=False, + returns_scalar=binary_operation_returns_scalar(self, other), ) def __ge__(self: Self, other: SparkLikeExpr) -> Self: @@ -227,7 +226,7 @@ def __ge__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__ge__(other), "__ge__", other=other, - returns_scalar=False, + returns_scalar=binary_operation_returns_scalar(self, other), ) def __gt__(self: Self, other: SparkLikeExpr) -> Self: @@ -235,7 +234,7 @@ def __gt__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input > other, "__gt__", other=other, - returns_scalar=False, + returns_scalar=binary_operation_returns_scalar(self, other), ) def __le__(self: Self, other: SparkLikeExpr) -> Self: @@ -243,7 +242,7 @@ def __le__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__le__(other), "__le__", other=other, - returns_scalar=False, + returns_scalar=binary_operation_returns_scalar(self, other), ) def __lt__(self: Self, other: SparkLikeExpr) -> Self: @@ -251,7 +250,7 @@ def __lt__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__lt__(other), "__lt__", other=other, - returns_scalar=False, + returns_scalar=binary_operation_returns_scalar(self, other), ) def __and__(self: Self, other: SparkLikeExpr) -> Self: @@ -259,7 +258,7 @@ def __and__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__and__(other), "__and__", other=other, - returns_scalar=False, + returns_scalar=binary_operation_returns_scalar(self, other), ) def __or__(self: Self, other: SparkLikeExpr) -> Self: @@ -267,7 +266,7 @@ def __or__(self: Self, other: SparkLikeExpr) -> Self: lambda _input, other: _input.__or__(other), "__or__", other=other, - returns_scalar=False, + returns_scalar=binary_operation_returns_scalar(self, other), ) def __invert__(self: Self) -> Self: @@ -281,17 +280,20 @@ def abs(self: Self) -> Self: return self._from_call(F.abs, "abs", returns_scalar=self._returns_scalar) def alias(self: Self, name: str) -> Self: - def _alias(df: SparkLikeLazyFrame) -> list[Column]: - return [col.alias(name) for col in self._call(df)] + def alias_output_names(names: Sequence[str]) -> Sequence[str]: + if len(names) != 1: + msg = f"Expected function with single output, found output names: {names}" + raise ValueError(msg) + return [name] # 
Define this one manually, so that we can # override `output_names` and not increase depth return self.__class__( - _alias, + self._call, depth=self._depth, function_name=self._function_name, - root_names=self._root_names, - output_names=[name], + evaluate_output_names=self._evaluate_output_names, + alias_output_names=alias_output_names, returns_scalar=self._returns_scalar, backend_version=self._backend_version, version=self._version, @@ -490,14 +492,6 @@ def _n_unique(_input: Column) -> Column: return self._from_call(_n_unique, "n_unique", returns_scalar=True) def over(self: Self, keys: list[str]) -> Self: - if self._output_names is None: - msg = ( - "Anonymous expressions are not supported in over.\n" - "Instead of `nw.all()`, try using a named expression, such as " - "`nw.col('a', 'b')`\n" - ) - raise ValueError(msg) - def func(df: SparkLikeLazyFrame) -> list[Column]: return [expr.over(Window.partitionBy(*keys)) for expr in self._call(df)] @@ -505,8 +499,8 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: func, depth=self._depth + 1, function_name=self._function_name + "->over", - root_names=self._root_names, - output_names=self._output_names, + evaluate_output_names=self._evaluate_output_names, + alias_output_names=self._alias_output_names, backend_version=self._backend_version, version=self._version, returns_scalar=False, diff --git a/narwhals/_spark_like/expr_name.py b/narwhals/_spark_like/expr_name.py index c32305270..8682d3dd1 100644 --- a/narwhals/_spark_like/expr_name.py +++ b/narwhals/_spark_like/expr_name.py @@ -3,8 +3,6 @@ from typing import TYPE_CHECKING from typing import Callable -from narwhals.exceptions import AnonymousExprError - if TYPE_CHECKING: from typing_extensions import Self @@ -16,20 +14,12 @@ def __init__(self: Self, expr: SparkLikeExpr) -> None: self._compliant_expr = expr def keep(self: Self) -> SparkLikeExpr: - root_names = self._compliant_expr._root_names - if root_names is None: - msg = ".name.keep" - raise AnonymousExprError.from_expr_name(msg) - return self._compliant_expr.__class__( - lambda df: [ - expr.alias(name) - for expr, name in zip(self._compliant_expr._call(df), root_names) - ], + self._compliant_expr._call, depth=self._compliant_expr._depth, function_name=self._compliant_expr._function_name, - root_names=root_names, - output_names=root_names, + evaluate_output_names=self._compliant_expr._evaluate_output_names, + alias_output_names=None, returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, @@ -37,22 +27,14 @@ def keep(self: Self) -> SparkLikeExpr: ) def map(self: Self, function: Callable[[str], str]) -> SparkLikeExpr: - root_names = self._compliant_expr._root_names - if root_names is None: - msg = ".name.map" - raise AnonymousExprError.from_expr_name(msg) - - output_names = [function(str(name)) for name in root_names] - return self._compliant_expr.__class__( - lambda df: [ - expr.alias(name) - for expr, name in zip(self._compliant_expr._call(df), output_names) - ], + self._compliant_expr._call, depth=self._compliant_expr._depth, function_name=self._compliant_expr._function_name, - root_names=root_names, - output_names=output_names, + evaluate_output_names=self._compliant_expr._evaluate_output_names, + alias_output_names=lambda output_names: [ + function(str(name)) for name in output_names + ], returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, @@ -60,21 
+42,14 @@ def map(self: Self, function: Callable[[str], str]) -> SparkLikeExpr: ) def prefix(self: Self, prefix: str) -> SparkLikeExpr: - root_names = self._compliant_expr._root_names - if root_names is None: - msg = ".name.prefix" - raise AnonymousExprError.from_expr_name(msg) - - output_names = [prefix + str(name) for name in root_names] return self._compliant_expr.__class__( - lambda df: [ - expr.alias(name) - for expr, name in zip(self._compliant_expr._call(df), output_names) - ], + self._compliant_expr._call, depth=self._compliant_expr._depth, function_name=self._compliant_expr._function_name, - root_names=root_names, - output_names=output_names, + evaluate_output_names=self._compliant_expr._evaluate_output_names, + alias_output_names=lambda output_names: [ + f"{prefix}{output_name}" for output_name in output_names + ], returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, @@ -82,22 +57,14 @@ def prefix(self: Self, prefix: str) -> SparkLikeExpr: ) def suffix(self: Self, suffix: str) -> SparkLikeExpr: - root_names = self._compliant_expr._root_names - if root_names is None: - msg = ".name.suffix" - raise AnonymousExprError.from_expr_name(msg) - - output_names = [str(name) + suffix for name in root_names] - return self._compliant_expr.__class__( - lambda df: [ - expr.alias(name) - for expr, name in zip(self._compliant_expr._call(df), output_names) - ], + self._compliant_expr._call, depth=self._compliant_expr._depth, function_name=self._compliant_expr._function_name, - root_names=root_names, - output_names=output_names, + evaluate_output_names=self._compliant_expr._evaluate_output_names, + alias_output_names=lambda output_names: [ + f"{output_name}{suffix}" for output_name in output_names + ], returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, @@ -105,21 +72,14 @@ def suffix(self: Self, suffix: str) -> SparkLikeExpr: ) def to_lowercase(self: Self) -> SparkLikeExpr: - root_names = self._compliant_expr._root_names - if root_names is None: - msg = ".name.to_lowercase" - raise AnonymousExprError.from_expr_name(msg) - output_names = [str(name).lower() for name in root_names] - return self._compliant_expr.__class__( - lambda df: [ - expr.alias(name) - for expr, name in zip(self._compliant_expr._call(df), output_names) - ], + self._compliant_expr._call, depth=self._compliant_expr._depth, function_name=self._compliant_expr._function_name, - root_names=root_names, - output_names=output_names, + evaluate_output_names=self._compliant_expr._evaluate_output_names, + alias_output_names=lambda output_names: [ + str(name).lower() for name in output_names + ], returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, @@ -127,21 +87,14 @@ def to_lowercase(self: Self) -> SparkLikeExpr: ) def to_uppercase(self: Self) -> SparkLikeExpr: - root_names = self._compliant_expr._root_names - if root_names is None: - msg = ".name.to_uppercase" - raise AnonymousExprError.from_expr_name(msg) - output_names = [str(name).upper() for name in root_names] - return self._compliant_expr.__class__( - lambda df: [ - expr.alias(name) - for expr, name in zip(self._compliant_expr._call(df), output_names) - ], + self._compliant_expr._call, depth=self._compliant_expr._depth, function_name=self._compliant_expr._function_name, - root_names=root_names, - 
output_names=output_names, + evaluate_output_names=self._compliant_expr._evaluate_output_names, + alias_output_names=lambda output_names: [ + str(name).upper() for name in output_names + ], returns_scalar=self._compliant_expr._returns_scalar, backend_version=self._compliant_expr._backend_version, version=self._compliant_expr._version, diff --git a/narwhals/_spark_like/group_by.py b/narwhals/_spark_like/group_by.py index 319a95dc2..dca7dbf16 100644 --- a/narwhals/_spark_like/group_by.py +++ b/narwhals/_spark_like/group_by.py @@ -1,6 +1,6 @@ from __future__ import annotations -from copy import copy +import re from functools import partial from typing import TYPE_CHECKING from typing import Any @@ -12,9 +12,7 @@ from narwhals._expression_parsing import is_simple_aggregation from narwhals._spark_like.utils import _std from narwhals._spark_like.utils import _var -from narwhals.exceptions import AnonymousExprError from narwhals.utils import parse_version -from narwhals.utils import remove_prefix if TYPE_CHECKING: from pyspark.sql import Column @@ -46,14 +44,6 @@ def agg( self: Self, *exprs: SparkLikeExpr, ) -> SparkLikeLazyFrame: - output_names: list[str] = copy(self._keys) - for expr in exprs: - if expr._output_names is None: # pragma: no cover - msg = "group_by.agg" - raise AnonymousExprError.from_expr_name(msg) - - output_names.extend(expr._output_names) - return agg_pyspark( self._df, self._grouped, @@ -126,31 +116,32 @@ def agg_pyspark( simple_aggregations: dict[str, Column] = {} for expr in exprs: + output_names = expr._evaluate_output_names(df) + aliases = ( + output_names + if expr._alias_output_names is None + else expr._alias_output_names(output_names) + ) + if len(output_names) > 1: + # For multi-output aggregations, e.g. `df.group_by('a').agg(nw.all().mean())`, we skip + # the keys, else they would appear duplicated in the output. + output_names, aliases = zip( + *[(x, alias) for x, alias in zip(output_names, aliases) if x not in keys] + ) if expr._depth == 0: # pragma: no cover # e.g. agg(nw.len()) # noqa: ERA001 - if expr._output_names is None: # pragma: no cover - msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" - raise AssertionError(msg) agg_func = get_spark_function(expr._function_name, **expr._kwargs) - simple_aggregations.update( - {output_name: agg_func(keys[0]) for output_name in expr._output_names} - ) + simple_aggregations.update({alias: agg_func(keys[0]) for alias in aliases}) continue # e.g. 
agg(nw.mean('a')) # noqa: ERA001 - if ( - expr._depth != 1 or expr._root_names is None or expr._output_names is None - ): # pragma: no cover - msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" - raise AssertionError(msg) - - function_name = remove_prefix(expr._function_name, "col->") + function_name = re.sub(r"(\w+->)", "", expr._function_name) agg_func = get_spark_function(function_name, **expr._kwargs) simple_aggregations.update( { - output_name: agg_func(root_name) - for root_name, output_name in zip(expr._root_names, expr._output_names) + alias: agg_func(output_name) + for alias, output_name in zip(aliases, output_names) } ) diff --git a/narwhals/_spark_like/namespace.py b/narwhals/_spark_like/namespace.py index 59a15dde1..c17500e62 100644 --- a/narwhals/_spark_like/namespace.py +++ b/narwhals/_spark_like/namespace.py @@ -4,16 +4,18 @@ from functools import reduce from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Iterable from typing import Literal +from typing import Sequence from pyspark.sql import functions as F # noqa: N812 +from pyspark.sql.types import IntegerType -from narwhals._expression_parsing import combine_root_names -from narwhals._expression_parsing import reduce_output_names +from narwhals._expression_parsing import combine_alias_output_names +from narwhals._expression_parsing import combine_evaluate_output_names from narwhals._spark_like.dataframe import SparkLikeLazyFrame from narwhals._spark_like.expr import SparkLikeExpr -from narwhals._spark_like.utils import get_column_name from narwhals.typing import CompliantNamespace if TYPE_CHECKING: @@ -34,16 +36,14 @@ def __init__( def all(self: Self) -> SparkLikeExpr: def _all(df: SparkLikeLazyFrame) -> list[Column]: - import pyspark.sql.functions as F # noqa: N812 - return [F.col(col_name) for col_name in df.columns] return SparkLikeExpr( call=_all, depth=0, function_name="all", - root_names=None, - output_names=None, + evaluate_output_names=lambda df: df.columns, + alias_output_names=None, returns_scalar=False, backend_version=self._backend_version, version=self._version, @@ -68,14 +68,14 @@ def lit(self: Self, value: object, dtype: DType | None) -> SparkLikeExpr: def _lit(_: SparkLikeLazyFrame) -> list[Column]: import pyspark.sql.functions as F # noqa: N812 - return [F.lit(value).alias("literal")] + return [F.lit(value)] return SparkLikeExpr( call=_lit, depth=0, function_name="lit", - root_names=None, - output_names=["literal"], + evaluate_output_names=lambda _df: ["literal"], + alias_output_names=None, returns_scalar=True, backend_version=self._backend_version, version=self._version, @@ -84,16 +84,14 @@ def _lit(_: SparkLikeLazyFrame) -> list[Column]: def len(self: Self) -> SparkLikeExpr: def func(_: SparkLikeLazyFrame) -> list[Column]: - import pyspark.sql.functions as F # noqa: N812 - - return [F.count("*").alias("len")] + return [F.count("*")] return SparkLikeExpr( func, depth=0, function_name="len", - root_names=None, - output_names=["len"], + evaluate_output_names=lambda _df: ["len"], + alias_output_names=None, returns_scalar=True, backend_version=self._backend_version, version=self._version, @@ -103,15 +101,14 @@ def func(_: SparkLikeLazyFrame) -> list[Column]: def all_horizontal(self: Self, *exprs: SparkLikeExpr) -> SparkLikeExpr: def func(df: SparkLikeLazyFrame) -> list[Column]: cols = [c for _expr in exprs for c in _expr(df)] - col_name = get_column_name(df, cols[0]) - return [reduce(operator.and_, 
cols).alias(col_name)] + return [reduce(operator.and_, cols)] return SparkLikeExpr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="all_horizontal", - root_names=combine_root_names(exprs), - output_names=reduce_output_names(exprs), + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), returns_scalar=False, backend_version=self._backend_version, version=self._version, @@ -121,15 +118,14 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: def any_horizontal(self: Self, *exprs: SparkLikeExpr) -> SparkLikeExpr: def func(df: SparkLikeLazyFrame) -> list[Column]: cols = [c for _expr in exprs for c in _expr(df)] - col_name = get_column_name(df, cols[0]) - return [reduce(operator.or_, cols).alias(col_name)] + return [reduce(operator.or_, cols)] return SparkLikeExpr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="any_horizontal", - root_names=combine_root_names(exprs), - output_names=reduce_output_names(exprs), + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), returns_scalar=False, backend_version=self._backend_version, version=self._version, @@ -138,23 +134,20 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: def sum_horizontal(self: Self, *exprs: SparkLikeExpr) -> SparkLikeExpr: def func(df: SparkLikeLazyFrame) -> list[Column]: - import pyspark.sql.functions as F # noqa: N812 - cols = [c for _expr in exprs for c in _expr(df)] - col_name = get_column_name(df, cols[0]) return [ reduce( operator.add, (F.coalesce(col, F.lit(0)) for col in cols), - ).alias(col_name) + ) ] return SparkLikeExpr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="sum_horizontal", - root_names=combine_root_names(exprs), - output_names=reduce_output_names(exprs), + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), returns_scalar=False, backend_version=self._backend_version, version=self._version, @@ -162,11 +155,8 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: ) def mean_horizontal(self: Self, *exprs: SparkLikeExpr) -> SparkLikeExpr: - from pyspark.sql.types import IntegerType - def func(df: SparkLikeLazyFrame) -> list[Column]: cols = [c for _expr in exprs for c in _expr(df)] - col_name = get_column_name(df, cols[0]) return [ ( reduce(operator.add, (F.coalesce(col, F.lit(0)) for col in cols)) @@ -174,15 +164,15 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: operator.add, (col.isNotNull().cast(IntegerType()) for col in cols), ) - ).alias(col_name) + ) ] return SparkLikeExpr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="mean_horizontal", - root_names=combine_root_names(exprs), - output_names=reduce_output_names(exprs), + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), returns_scalar=False, backend_version=self._backend_version, version=self._version, @@ -192,15 +182,14 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: def max_horizontal(self: Self, *exprs: SparkLikeExpr) -> SparkLikeExpr: def func(df: SparkLikeLazyFrame) -> list[Column]: cols = [c for _expr in exprs for c in _expr(df)] - col_name = get_column_name(df, cols[0]) - return [F.greatest(*cols).alias(col_name)] + return [F.greatest(*cols)] return SparkLikeExpr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="max_horizontal", - root_names=combine_root_names(exprs), - 
output_names=reduce_output_names(exprs), + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), returns_scalar=False, backend_version=self._backend_version, version=self._version, @@ -210,15 +199,14 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: def min_horizontal(self: Self, *exprs: SparkLikeExpr) -> SparkLikeExpr: def func(df: SparkLikeLazyFrame) -> list[Column]: cols = [c for _expr in exprs for c in _expr(df)] - col_name = get_column_name(df, cols[0]) - return [F.least(*cols).alias(col_name)] + return [F.least(*cols)] return SparkLikeExpr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="min_horizontal", - root_names=combine_root_names(exprs), - output_names=reduce_output_names(exprs), + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), returns_scalar=False, backend_version=self._backend_version, version=self._version, @@ -279,7 +267,6 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: cols = [s for _expr in exprs for s in _expr(df)] cols_casted = [s.cast(StringType()) for s in cols] null_mask = [F.isnull(s) for _expr in exprs for s in _expr(df)] - first_column_name = get_column_name(df, cols[0]) if not ignore_nulls: null_mask_result = reduce(lambda x, y: x | y, null_mask) @@ -306,14 +293,14 @@ def func(df: SparkLikeLazyFrame) -> list[Column]: init_value, ) - return [result.alias(first_column_name)] + return [result] return SparkLikeExpr( call=func, depth=max(x._depth for x in exprs) + 1, function_name="concat_str", - root_names=combine_root_names(exprs), - output_names=reduce_output_names(exprs), + evaluate_output_names=combine_evaluate_output_names(*exprs), + alias_output_names=combine_alias_output_names(*exprs), returns_scalar=False, backend_version=self._backend_version, version=self._version, @@ -355,11 +342,9 @@ def __call__(self: Self, df: SparkLikeLazyFrame) -> list[Column]: if isinstance(self._then_value, SparkLikeExpr): value_ = self._then_value(df)[0] - col_name = get_column_name(df, value_) else: # `self._then_value` is a scalar value_ = F.lit(self._then_value) - col_name = "literal" if isinstance(self._otherwise_value, SparkLikeExpr): other_ = self._otherwise_value(df)[0] @@ -367,11 +352,7 @@ def __call__(self: Self, df: SparkLikeLazyFrame) -> list[Column]: # `self._otherwise_value` is a scalar other_ = F.lit(self._otherwise_value) - return [ - F.when(condition=condition, value=value_) - .otherwise(value=other_) - .alias(col_name) - ] + return [F.when(condition=condition, value=value_).otherwise(value=other_)] def then(self: Self, value: SparkLikeExpr | Any) -> SparkLikeThen: self._then_value = value @@ -380,8 +361,10 @@ def then(self: Self, value: SparkLikeExpr | Any) -> SparkLikeThen: self, depth=0, function_name="whenthen", - root_names=None, - output_names=None, + evaluate_output_names=getattr( + value, "_evaluate_output_names", lambda _df: ["literal"] + ), + alias_output_names=getattr(value, "_alias_output_names", None), returns_scalar=self._returns_scalar, backend_version=self._backend_version, version=self._version, @@ -396,8 +379,8 @@ def __init__( *, depth: int, function_name: str, - root_names: list[str] | None, - output_names: list[str] | None, + evaluate_output_names: Callable[[SparkLikeLazyFrame], Sequence[str]], + alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None, returns_scalar: bool, backend_version: tuple[int, ...], version: Version, @@ -408,8 +391,8 @@ def __init__( self._call = call 
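# ---- Editor's aside (not part of the diff) ---------------------------------
# How `then` resolves output names above: an expression value contributes its
# own lazily-evaluated names, while a plain scalar falls back to "literal".
# `_FakeExpr` is a hypothetical stand-in for a compliant expression.
class _FakeExpr:
    _evaluate_output_names = staticmethod(lambda _df: ["a"])
    _alias_output_names = None

for value in (_FakeExpr(), 42):
    evaluate = getattr(value, "_evaluate_output_names", lambda _df: ["literal"])
    print(evaluate(None))  # ['a'] for the expression, ['literal'] for the scalar
# -----------------------------------------------------------------------------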
self._depth = depth self._function_name = function_name - self._root_names = root_names - self._output_names = output_names + self._evaluate_output_names = evaluate_output_names + self._alias_output_names = alias_output_names self._returns_scalar = returns_scalar self._kwargs = kwargs diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py index 37a2426d4..f2f7abfde 100644 --- a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -109,38 +109,26 @@ def narwhals_to_native_dtype( raise AssertionError(msg) -def get_column_name(df: SparkLikeLazyFrame, column: Column) -> str: - return str(df._native_frame.select(column).columns[0]) - - -def _columns_from_expr(df: SparkLikeLazyFrame, expr: SparkLikeExpr) -> list[Column]: - col_output_list = expr._call(df) - if expr._output_names is not None and ( - len(col_output_list) != len(expr._output_names) - ): # pragma: no cover - msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues" - raise AssertionError(msg) - return col_output_list - - def parse_exprs_and_named_exprs( df: SparkLikeLazyFrame, *exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr ) -> dict[str, Column]: - result_columns: dict[str, list[Column]] = {} + native_results: dict[str, list[Column]] = {} for expr in exprs: - column_list = _columns_from_expr(df, expr) - if expr._output_names is None: - output_names = [get_column_name(df, col) for col in column_list] - else: - output_names = expr._output_names - result_columns.update(zip(output_names, column_list)) + native_series_list = expr._call(df) + output_names = expr._evaluate_output_names(df) + if expr._alias_output_names is not None: + output_names = expr._alias_output_names(output_names) + if len(output_names) != len(native_series_list): # pragma: no cover + msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results" + raise AssertionError(msg) + native_results.update(zip(output_names, native_series_list)) for col_alias, expr in named_exprs.items(): - columns_list = _columns_from_expr(df, expr) - if len(columns_list) != 1: # pragma: no cover + native_series_list = expr._call(df) + if len(native_series_list) != 1: # pragma: no cover msg = "Named expressions must return a single column" - raise AssertionError(msg) - result_columns[col_alias] = columns_list[0] - return result_columns + raise ValueError(msg) + native_results[col_alias] = native_series_list[0] + return native_results def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any) -> Any: @@ -187,3 +175,10 @@ def _var(_input: Column | str, ddof: int, np_version: tuple[int, ...]) -> Column input_col = F.col(_input) if isinstance(_input, str) else _input return var(input_col, ddof=ddof) + + +def binary_operation_returns_scalar(lhs: SparkLikeExpr, rhs: SparkLikeExpr | Any) -> bool: + # If `rhs` is a SparkLikeExpr, we look at `_returns_scalar`. If it isn't, + # it means that it was a scalar (e.g. nw.col('a') + 1), and so we default + # to `True`. 
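# Editor's note (not part of the diff) — a worked truth table for the rule
# below, with hypothetical expressions:
#   nw.col('a').sum() + nw.col('b').sum()  -> scalar  (True and True)
#   nw.col('a').sum() + 1                  -> scalar  (non-expression rhs defaults to True)
#   nw.col('a') + nw.col('b').sum()        -> column  (lhs is full-length)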
+ return lhs._returns_scalar and getattr(rhs, "_returns_scalar", True) diff --git a/narwhals/exceptions.py b/narwhals/exceptions.py index 817cca4e1..94bd8ebcd 100644 --- a/narwhals/exceptions.py +++ b/narwhals/exceptions.py @@ -71,7 +71,7 @@ def from_invalid_type(cls: type, invalid_type: type) -> InvalidIntoExprError: return InvalidIntoExprError(message) -class AnonymousExprError(ValueError): +class AnonymousExprError(ValueError): # pragma: no cover """Exception raised when trying to perform operations on anonymous expressions.""" def __init__(self: Self, message: str) -> None: diff --git a/narwhals/typing.py b/narwhals/typing.py index c37b24351..42bfc1409 100644 --- a/narwhals/typing.py +++ b/narwhals/typing.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING from typing import Any +from typing import Callable from typing import Generic from typing import Literal from typing import Protocol @@ -80,8 +81,10 @@ def __narwhals_namespace__(self) -> Any: ... class CompliantExpr(Protocol, Generic[CompliantSeriesT_co]): _implementation: Implementation _backend_version: tuple[int, ...] - _output_names: list[str] | None - _root_names: list[str] | None + _evaluate_output_names: Callable[ + [CompliantDataFrame | CompliantLazyFrame], Sequence[str] + ] + _alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None _depth: int _function_name: str _kwargs: dict[str, Any] diff --git a/narwhals/utils.py b/narwhals/utils.py index 752f7150f..4192b820d 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -326,10 +326,10 @@ def import_dtypes_module(version: Version) -> DTypes: return dtypes # type: ignore[return-value] -def remove_prefix(text: str, prefix: str) -> str: +def remove_prefix(text: str, prefix: str) -> str: # pragma: no cover if text.startswith(prefix): return text[len(prefix) :] - return text # pragma: no cover + return text def remove_suffix(text: str, suffix: str) -> str: # pragma: no cover diff --git a/tests/expr_and_series/n_unique_test.py b/tests/expr_and_series/n_unique_test.py index c761b20fd..c29b1abb3 100644 --- a/tests/expr_and_series/n_unique_test.py +++ b/tests/expr_and_series/n_unique_test.py @@ -13,7 +13,7 @@ def test_n_unique(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) - result = df.select(nw.col("a", "b").n_unique()) + result = df.select(nw.all().n_unique()) expected = {"a": [3], "b": [4]} assert_equal_data(result, expected) diff --git a/tests/expr_and_series/name/keep_test.py b/tests/expr_and_series/name/keep_test.py index 9c172dba6..9456f80d0 100644 --- a/tests/expr_and_series/name/keep_test.py +++ b/tests/expr_and_series/name/keep_test.py @@ -1,12 +1,6 @@ from __future__ import annotations -from contextlib import nullcontext as does_not_raise - -import polars as pl -import pytest - import narwhals.stable.v1 as nw -from narwhals.exceptions import AnonymousExprError from tests.utils import Constructor from tests.utils import assert_equal_data @@ -27,18 +21,9 @@ def test_keep_after_alias(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_keep_raise_anonymous(constructor: Constructor) -> None: +def test_keep_anonymous(constructor: Constructor) -> None: df_raw = constructor(data) df = nw.from_native(df_raw) - - context = ( - does_not_raise() - if isinstance(df_raw, (pl.LazyFrame, pl.DataFrame)) - else pytest.raises( - AnonymousExprError, - match="Anonymous expressions are not supported in `.name.keep`.", - ) - ) - - with context: - df.select(nw.all().name.keep()) + result = 
df.select("foo").select(nw.all().alias("fdfsad").name.keep()) + expected = {"foo": [1, 2, 3]} + assert_equal_data(result, expected) diff --git a/tests/expr_and_series/name/map_test.py b/tests/expr_and_series/name/map_test.py index 5b93de213..c1c370303 100644 --- a/tests/expr_and_series/name/map_test.py +++ b/tests/expr_and_series/name/map_test.py @@ -1,12 +1,6 @@ from __future__ import annotations -from contextlib import nullcontext as does_not_raise - -import polars as pl -import pytest - import narwhals.stable.v1 as nw -from narwhals.exceptions import AnonymousExprError from tests.utils import Constructor from tests.utils import assert_equal_data @@ -20,29 +14,20 @@ def map_func(s: str | None) -> str: def test_map(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select((nw.col("foo", "BAR") * 2).name.map(function=map_func)) - expected = {map_func(k): [e * 2 for e in v] for k, v in data.items()} + expected = {"oof": [2, 4, 6], "rab": [8, 10, 12]} assert_equal_data(result, expected) def test_map_after_alias(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select((nw.col("foo")).alias("alias_for_foo").name.map(function=map_func)) - expected = {map_func("foo"): data["foo"]} + expected = {"oof": data["foo"]} assert_equal_data(result, expected) -def test_map_raise_anonymous(constructor: Constructor) -> None: +def test_map_anonymous(constructor: Constructor) -> None: df_raw = constructor(data) df = nw.from_native(df_raw) - - context = ( - does_not_raise() - if isinstance(df_raw, (pl.LazyFrame, pl.DataFrame)) - else pytest.raises( - AnonymousExprError, - match="Anonymous expressions are not supported in `.name.map`.", - ) - ) - - with context: - df.select(nw.all().name.map(function=map_func)) + result = df.select(nw.all().name.map(function=map_func)) + expected = {"oof": [1, 2, 3], "rab": [4, 5, 6]} + assert_equal_data(result, expected) diff --git a/tests/expr_and_series/name/prefix_test.py b/tests/expr_and_series/name/prefix_test.py index 5894153be..953bbe4d9 100644 --- a/tests/expr_and_series/name/prefix_test.py +++ b/tests/expr_and_series/name/prefix_test.py @@ -1,12 +1,6 @@ from __future__ import annotations -from contextlib import nullcontext as does_not_raise - -import polars as pl -import pytest - import narwhals.stable.v1 as nw -from narwhals.exceptions import AnonymousExprError from tests.utils import Constructor from tests.utils import assert_equal_data @@ -17,29 +11,20 @@ def test_prefix(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select((nw.col("foo", "BAR") * 2).name.prefix(prefix)) - expected = {prefix + str(k): [e * 2 for e in v] for k, v in data.items()} + expected = {"with_prefix_foo": [2, 4, 6], "with_prefix_BAR": [8, 10, 12]} assert_equal_data(result, expected) def test_suffix_after_alias(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select((nw.col("foo")).alias("alias_for_foo").name.prefix(prefix)) - expected = {prefix + "foo": data["foo"]} + expected = {"with_prefix_foo": [1, 2, 3]} assert_equal_data(result, expected) -def test_prefix_raise_anonymous(constructor: Constructor) -> None: +def test_prefix_anonymous(constructor: Constructor) -> None: df_raw = constructor(data) df = nw.from_native(df_raw) - - context = ( - does_not_raise() - if isinstance(df_raw, (pl.LazyFrame, pl.DataFrame)) - else pytest.raises( - AnonymousExprError, - match="Anonymous expressions are not supported in `.name.prefix`.", - ) - ) - - with 
diff --git a/tests/expr_and_series/name/prefix_test.py b/tests/expr_and_series/name/prefix_test.py
index 5894153be..953bbe4d9 100644
--- a/tests/expr_and_series/name/prefix_test.py
+++ b/tests/expr_and_series/name/prefix_test.py
@@ -1,12 +1,6 @@
 from __future__ import annotations
 
-from contextlib import nullcontext as does_not_raise
-
-import polars as pl
-import pytest
-
 import narwhals.stable.v1 as nw
-from narwhals.exceptions import AnonymousExprError
 from tests.utils import Constructor
 from tests.utils import assert_equal_data
 
@@ -17,29 +11,20 @@ def test_prefix(constructor: Constructor) -> None:
     df = nw.from_native(constructor(data))
     result = df.select((nw.col("foo", "BAR") * 2).name.prefix(prefix))
-    expected = {prefix + str(k): [e * 2 for e in v] for k, v in data.items()}
+    expected = {"with_prefix_foo": [2, 4, 6], "with_prefix_BAR": [8, 10, 12]}
     assert_equal_data(result, expected)
 
 
 def test_suffix_after_alias(constructor: Constructor) -> None:
     df = nw.from_native(constructor(data))
     result = df.select((nw.col("foo")).alias("alias_for_foo").name.prefix(prefix))
-    expected = {prefix + "foo": data["foo"]}
+    expected = {"with_prefix_foo": [1, 2, 3]}
     assert_equal_data(result, expected)
 
 
-def test_prefix_raise_anonymous(constructor: Constructor) -> None:
+def test_prefix_anonymous(constructor: Constructor) -> None:
     df_raw = constructor(data)
     df = nw.from_native(df_raw)
-
-    context = (
-        does_not_raise()
-        if isinstance(df_raw, (pl.LazyFrame, pl.DataFrame))
-        else pytest.raises(
-            AnonymousExprError,
-            match="Anonymous expressions are not supported in `.name.prefix`.",
-        )
-    )
-
-    with context:
-        df.select(nw.all().name.prefix(prefix))
+    result = df.select(nw.all().name.prefix(prefix))
+    expected = {"with_prefix_foo": [1, 2, 3], "with_prefix_BAR": [4, 5, 6]}
+    assert_equal_data(result, expected)
diff --git a/tests/expr_and_series/name/suffix_test.py b/tests/expr_and_series/name/suffix_test.py
index 1c5816154..d58e5b0e6 100644
--- a/tests/expr_and_series/name/suffix_test.py
+++ b/tests/expr_and_series/name/suffix_test.py
@@ -1,10 +1,5 @@
 from __future__ import annotations
 
-from contextlib import nullcontext as does_not_raise
-
-import polars as pl
-import pytest
-
 import narwhals.stable.v1 as nw
 from tests.utils import Constructor
 from tests.utils import assert_equal_data
 
@@ -16,29 +11,20 @@ def test_suffix(constructor: Constructor) -> None:
     df = nw.from_native(constructor(data))
     result = df.select((nw.col("foo", "BAR") * 2).name.suffix(suffix))
-    expected = {str(k) + suffix: [e * 2 for e in v] for k, v in data.items()}
+    expected = {"foo_with_suffix": [2, 4, 6], "BAR_with_suffix": [8, 10, 12]}
     assert_equal_data(result, expected)
 
 
 def test_suffix_after_alias(constructor: Constructor) -> None:
     df = nw.from_native(constructor(data))
     result = df.select((nw.col("foo")).alias("alias_for_foo").name.suffix(suffix))
-    expected = {"foo" + suffix: data["foo"]}
+    expected = {"foo_with_suffix": [1, 2, 3]}
     assert_equal_data(result, expected)
 
 
-def test_suffix_raise_anonymous(constructor: Constructor) -> None:
+def test_suffix_anonymous(constructor: Constructor) -> None:
     df_raw = constructor(data)
     df = nw.from_native(df_raw)
-
-    context = (
-        does_not_raise()
-        if isinstance(df_raw, (pl.LazyFrame, pl.DataFrame))
-        else pytest.raises(
-            ValueError,
-            match="Anonymous expressions are not supported in `.name.suffix`.",
-        )
-    )
-
-    with context:
-        df.select(nw.all().name.suffix(suffix))
+    result = df.select(nw.all().name.suffix(suffix))
+    expected = {"foo_with_suffix": [1, 2, 3], "BAR_with_suffix": [4, 5, 6]}
+    assert_equal_data(result, expected)
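Worth noting: `name.prefix` and `name.suffix` behave like `name.map` with a trivial function, which is why the expected names in these tests can be spelled out literally. A small equivalence sketch (pandas again as an arbitrary backend):

```python
import pandas as pd

import narwhals.stable.v1 as nw

df = nw.from_native(pd.DataFrame({"foo": [1, 2, 3], "BAR": [4, 5, 6]}))
# `name.prefix(p)` should be indistinguishable from `name.map` with `p + name`:
via_prefix = df.select(nw.all().name.prefix("with_prefix_"))
via_map = df.select(nw.all().name.map(function=lambda s: f"with_prefix_{s}"))
assert via_prefix.columns == via_map.columns == ["with_prefix_foo", "with_prefix_BAR"]
```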
diff --git a/tests/expr_and_series/name/to_lowercase_test.py b/tests/expr_and_series/name/to_lowercase_test.py
index 7acf9af59..e8606d088 100644
--- a/tests/expr_and_series/name/to_lowercase_test.py
+++ b/tests/expr_and_series/name/to_lowercase_test.py
@@ -1,12 +1,6 @@
 from __future__ import annotations
 
-from contextlib import nullcontext as does_not_raise
-
-import polars as pl
-import pytest
-
 import narwhals.stable.v1 as nw
-from narwhals.exceptions import AnonymousExprError
 from tests.utils import Constructor
 from tests.utils import assert_equal_data
 
@@ -16,29 +10,20 @@ def test_to_lowercase(constructor: Constructor) -> None:
     df = nw.from_native(constructor(data))
     result = df.select((nw.col("foo", "BAR") * 2).name.to_lowercase())
-    expected = {k.lower(): [e * 2 for e in v] for k, v in data.items()}
+    expected = {"foo": [2, 4, 6], "bar": [8, 10, 12]}
     assert_equal_data(result, expected)
 
 
 def test_to_lowercase_after_alias(constructor: Constructor) -> None:
     df = nw.from_native(constructor(data))
     result = df.select((nw.col("BAR")).alias("ALIAS_FOR_BAR").name.to_lowercase())
-    expected = {"bar": data["BAR"]}
+    expected = {"bar": [4, 5, 6]}
     assert_equal_data(result, expected)
 
 
-def test_to_lowercase_raise_anonymous(constructor: Constructor) -> None:
+def test_to_lowercase_anonymous(constructor: Constructor) -> None:
     df_raw = constructor(data)
     df = nw.from_native(df_raw)
-
-    context = (
-        does_not_raise()
-        if isinstance(df_raw, (pl.LazyFrame, pl.DataFrame))
-        else pytest.raises(
-            AnonymousExprError,
-            match="Anonymous expressions are not supported in `.name.to_lowercase`.",
-        )
-    )
-
-    with context:
-        df.select(nw.all().name.to_lowercase())
+    result = df.select(nw.all().name.to_lowercase())
+    expected = {"foo": [1, 2, 3], "bar": [4, 5, 6]}
+    assert_equal_data(result, expected)
diff --git a/tests/expr_and_series/name/to_uppercase_test.py b/tests/expr_and_series/name/to_uppercase_test.py
index 7d0bd7e57..c38b867df 100644
--- a/tests/expr_and_series/name/to_uppercase_test.py
+++ b/tests/expr_and_series/name/to_uppercase_test.py
@@ -1,12 +1,6 @@
 from __future__ import annotations
 
-from contextlib import nullcontext as does_not_raise
-
-import polars as pl
-import pytest
-
 import narwhals.stable.v1 as nw
-from narwhals.exceptions import AnonymousExprError
 from tests.utils import Constructor
 from tests.utils import assert_equal_data
 
@@ -16,29 +10,20 @@ def test_to_uppercase(constructor: Constructor) -> None:
     df = nw.from_native(constructor(data))
     result = df.select((nw.col("foo", "BAR") * 2).name.to_uppercase())
-    expected = {k.upper(): [e * 2 for e in v] for k, v in data.items()}
+    expected = {"FOO": [2, 4, 6], "BAR": [8, 10, 12]}
     assert_equal_data(result, expected)
 
 
 def test_to_uppercase_after_alias(constructor: Constructor) -> None:
     df = nw.from_native(constructor(data))
     result = df.select((nw.col("foo")).alias("alias_for_foo").name.to_uppercase())
-    expected = {"FOO": data["foo"]}
+    expected = {"FOO": [1, 2, 3]}
     assert_equal_data(result, expected)
 
 
-def test_to_uppercase_raise_anonymous(constructor: Constructor) -> None:
+def test_to_uppercase_anonymous(constructor: Constructor) -> None:
     df_raw = constructor(data)
     df = nw.from_native(df_raw)
-
-    context = (
-        does_not_raise()
-        if isinstance(df_raw, (pl.LazyFrame, pl.DataFrame))
-        else pytest.raises(
-            AnonymousExprError,
-            match="Anonymous expressions are not supported in `.name.to_uppercase`.",
-        )
-    )
-
-    with context:
-        df.select(nw.all().name.to_uppercase())
+    result = df.select(nw.all().name.to_uppercase())
+    expected = {"FOO": [1, 2, 3], "BAR": [4, 5, 6]}
+    assert_equal_data(result, expected)
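One edge case the renamed tests don't exercise: case-mapping can collide. If a frame contains both `foo` and `FOO`, `name.to_lowercase()` yields duplicate output names, which should fail the same way as the duplicate-`alias` case tested below. An assumed illustration in plain Polars (the exact error message varies by version):

```python
import polars as pl

df = pl.DataFrame({"foo": [1], "FOO": [2]})
try:
    # Both columns map to "foo", so the select cannot produce a valid schema.
    df.select(pl.all().name.to_lowercase())
except pl.exceptions.DuplicateError as exc:
    print(exc)
```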
diff --git a/tests/expr_and_series/over_test.py b/tests/expr_and_series/over_test.py
index ee8a14806..600bde179 100644
--- a/tests/expr_and_series/over_test.py
+++ b/tests/expr_and_series/over_test.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import re
+from contextlib import nullcontext as does_not_raise
 
 import pandas as pd
 import pytest
@@ -61,17 +62,6 @@ def test_over_multiple(request: pytest.FixtureRequest, constructor: Constructor)
     assert_equal_data(result, expected)
 
 
-def test_over_invalid(request: pytest.FixtureRequest, constructor: Constructor) -> None:
-    if "polars" in str(constructor):
-        request.applymarker(pytest.mark.xfail)
-    if "duckdb" in str(constructor):
-        request.applymarker(pytest.mark.xfail)
-
-    df = nw.from_native(constructor(data))
-    with pytest.raises(ValueError, match="Anonymous expressions"):
-        df.with_columns(c_min=nw.all().min().over("a", "b"))
-
-
 def test_over_cumsum(
     request: pytest.FixtureRequest, constructor_eager: ConstructorEager
 ) -> None:
@@ -177,10 +167,57 @@ def test_over_cumprod(
     assert_equal_data(result, expected)
 
 
-def test_over_anonymous() -> None:
-    df = pd.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]})
-    with pytest.raises(ValueError, match="Anonymous expressions"):
-        nw.from_native(df).select(nw.all().cum_max().over("a"))
+def test_over_anonymous_cumulative(constructor_eager: ConstructorEager) -> None:
+    df = nw.from_native(constructor_eager({"a": [1, 1, 2], "b": [4, 5, 6]}))
+    context = (
+        pytest.raises(NotImplementedError)
+        if df.implementation.is_pyarrow()
+        else pytest.raises(KeyError)  # type: ignore[arg-type]
+        if df.implementation.is_modin()
+        or (df.implementation.is_pandas() and PANDAS_VERSION < (1, 3))
+        # TODO(unassigned): bug in old pandas + modin.
+        # df.groupby('a')[['a', 'b']].cumsum() excludes `'a'` from the result
+        else does_not_raise()
+    )
+    with context:
+        result = df.with_columns(
+            nw.all().cum_sum().over("a").name.suffix("_cum_sum")
+        ).sort("a", "b")
+        expected = {
+            "a": [1, 1, 2],
+            "b": [4, 5, 6],
+            "a_cum_sum": [1, 2, 2],
+            "b_cum_sum": [4, 9, 6],
+        }
+        assert_equal_data(result, expected)
+
+
+def test_over_anonymous_reduction(
+    constructor: Constructor, request: pytest.FixtureRequest
+) -> None:
+    if "duckdb" in str(constructor):
+        # TODO(unassigned): we should be able to support these
+        request.applymarker(pytest.mark.xfail)
+
+    df = nw.from_native(constructor({"a": [1, 1, 2], "b": [4, 5, 6]}))
+    context = (
+        pytest.raises(NotImplementedError)
+        if df.implementation.is_pyarrow()
+        or df.implementation.is_pandas_like()
+        or df.implementation.is_dask()
+        else does_not_raise()
+    )
+    with context:
+        result = df.with_columns(
+            nw.all().sum().over("a").name.suffix("_sum")
+        ).sort("a", "b")
+        expected = {
+            "a": [1, 1, 2],
+            "b": [4, 5, 6],
+            "a_sum": [2, 2, 2],
+            "b_sum": [9, 9, 6],
+        }
+        assert_equal_data(result, expected)
 
 
 def test_over_shift(
diff --git a/tests/expr_and_series/sum_horizontal_test.py b/tests/expr_and_series/sum_horizontal_test.py
index 7404eefeb..25ef6af81 100644
--- a/tests/expr_and_series/sum_horizontal_test.py
+++ b/tests/expr_and_series/sum_horizontal_test.py
@@ -35,9 +35,9 @@ def test_sumh_nullable(constructor: Constructor) -> None:
 def test_sumh_all(constructor: Constructor) -> None:
     data = {"a": [1, 2, 3], "b": [10, 20, 30]}
     df = nw.from_native(constructor(data))
-    result = df.select(nw.sum_horizontal(nw.all()))
+    result = df.select(nw.sum_horizontal(nw.all().name.suffix("_foo")))
     expected = {
-        "a": [11, 22, 33],
+        "a_foo": [11, 22, 33],
     }
     assert_equal_data(result, expected)
     result = df.select(c=nw.sum_horizontal(nw.all()))
diff --git a/tests/frame/select_test.py b/tests/frame/select_test.py
index 36b8fc881..8237b7b0d 100644
--- a/tests/frame/select_test.py
+++ b/tests/frame/select_test.py
@@ -5,6 +5,7 @@
 import pandas as pd
 import pyarrow as pa
 import pytest
+from polars.exceptions import DuplicateError
 
 import narwhals.stable.v1 as nw
 from narwhals.exceptions import ColumnNotFoundError
@@ -133,3 +134,9 @@ def test_left_to_right_broadcasting(
     result = df.select(nw.col("b").sum() + nw.col("a").sum())
     expected = {"b": [19]}
     assert_equal_data(result, expected)
+
+
+def test_alias_invalid(constructor: Constructor) -> None:
+    df = nw.from_native(constructor({"a": [1, 2, 3], "b": [4, 5, 6]}))
+    with pytest.raises((DuplicateError, ValueError)):
+        df.lazy().select(nw.all().alias("c")).collect()
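For reference, the expected values in `test_over_anonymous_reduction` follow from "aggregate per group, then broadcast back to the original height". A plain-pandas sketch of the same computation (not how narwhals implements it):

```python
import pandas as pd

df = pd.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]})
grouped = df.groupby("a")
# Roughly what `nw.all().sum().over("a").name.suffix("_sum")` computes:
# per-group sums, broadcast back to one row per input row.
out = df.assign(
    a_sum=grouped["a"].transform("sum"),
    b_sum=grouped["b"].transform("sum"),
)
assert out["a_sum"].tolist() == [2, 2, 2]
assert out["b_sum"].tolist() == [9, 9, 6]
```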
-    with pytest.raises(InvalidOperationError, match="does not aggregate"):
-        nw.from_native(df_dask).group_by("a").agg(nw.col("b"))
-
-    with pytest.raises(
-        AnonymousExprError,
-        match=r"Anonymous expressions are not supported in `group_by\.agg`",
-    ):
-        nw.from_native(df_dask).group_by("a").agg(nw.all().mean())
-
-
-@pytest.mark.filterwarnings("ignore:Found complex group-by expression:UserWarning")
-def test_invalid_group_by() -> None:
-    df = nw.from_native(df_pandas)
-    with pytest.raises(InvalidOperationError, match="does not aggregate"):
-        df.group_by("a").agg(nw.col("b"))
-    with pytest.raises(
-        AnonymousExprError,
-        match=r"Anonymous expressions are not supported in `group_by\.agg`",
-    ):
-        df.group_by("a").agg(nw.all().mean())
-    with pytest.raises(
-        AnonymousExprError,
-        match=r"Anonymous expressions are not supported in `group_by\.agg`",
-    ):
-        nw.from_native(pa.table({"a": [1, 2, 3]})).group_by("a").agg(nw.all().mean())
-    with pytest.raises(ValueError, match=r"Non-trivial complex aggregation found"):
-        nw.from_native(pa.table({"a": [1, 2, 3]})).group_by("a").agg(
-            nw.col("b").mean().min()
-        )
-
 
 def test_group_by_iter(constructor_eager: ConstructorEager) -> None:
     df = nw.from_native(constructor_eager(data), eager_only=True)
@@ -99,6 +67,16 @@ def test_group_by_iter(constructor_eager: ConstructorEager) -> None:
     assert sorted(keys) == sorted(expected_keys)
 
 
+def test_group_by_nw_all(constructor: Constructor) -> None:
+    df = nw.from_native(constructor({"a": [1, 1, 2], "b": [4, 5, 6], "c": [7, 8, 9]}))
+    result = df.group_by("a").agg(nw.all().sum()).sort("a")
+    expected = {"a": [1, 2], "b": [9, 6], "c": [15, 9]}
+    assert_equal_data(result, expected)
+    result = df.group_by("a").agg(nw.all().sum().name.suffix("_sum")).sort("a")
+    expected = {"a": [1, 2], "b_sum": [9, 6], "c_sum": [15, 9]}
+    assert_equal_data(result, expected)
+
+
 @pytest.mark.parametrize(
     ("attr", "expected"),
     [
@@ -344,8 +322,6 @@ def test_group_by_categorical(
         request.applymarker(pytest.mark.xfail)
     if "pyarrow_table" in str(constructor) and PYARROW_VERSION < (
         15,
-        0,
-        0,
     ):  # pragma: no cover
         request.applymarker(pytest.mark.xfail)
diff --git a/tests/stable_api_test.py b/tests/stable_api_test.py
index c1c3f210f..8d5348834 100644
--- a/tests/stable_api_test.py
+++ b/tests/stable_api_test.py
@@ -52,15 +52,10 @@ def test_renamed_taxicab_norm(
     assert_equal_data(result_v1, expected)
 
 
-def test_renamed_taxicab_norm_dataframe(
-    request: pytest.FixtureRequest, constructor: Constructor
-) -> None:
+def test_renamed_taxicab_norm_dataframe(constructor: Constructor) -> None:
     # Suppose we have `DataFrame._l1_norm` in `stable.v1`, but remove it
     # in the main namespace. Here, we check that it's still usable from
     # the stable api.
-    if "pyspark" in str(constructor):
-        request.applymarker(pytest.mark.xfail)
-
     def func(df_any: Any) -> Any:
         df = nw_v1.from_native(df_any)
         df = df._l1_norm()
@@ -71,16 +66,10 @@ def func(df_any: Any) -> Any:
     assert_equal_data(result, expected)
 
 
-def test_renamed_taxicab_norm_dataframe_narwhalify(
-    request: pytest.FixtureRequest, constructor: Constructor
-) -> None:
+def test_renamed_taxicab_norm_dataframe_narwhalify(constructor: Constructor) -> None:
     # Suppose we have `DataFrame._l1_norm` in `stable.v1`, but remove it
     # in the main namespace. Here, we check that it's still usable from
     # the stable api when using `narwhalify`.
-
-    if "pyspark" in str(constructor):
-        request.applymarker(pytest.mark.xfail)
-
     @nw_v1.narwhalify
     def func(df: Any) -> Any:
         return df._l1_norm()