diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 176a79259..8953bd6f9 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -294,7 +294,7 @@ def simple_select(self, *column_names: str) -> Self: return self._from_native_frame(self._native_frame.select(list(column_names))) def select(self: Self, *exprs: IntoArrowExpr, **named_exprs: IntoArrowExpr) -> Self: - new_series = evaluate_into_exprs(self, *exprs, **named_exprs) + new_series: list[ArrowSeries] = evaluate_into_exprs(self)(*exprs, **named_exprs) if not new_series: # return empty dataframe, like Polars does return self._from_native_frame(self._native_frame.__class__.from_arrays([])) @@ -306,7 +306,7 @@ def with_columns( self: Self, *exprs: IntoArrowExpr, **named_exprs: IntoArrowExpr ) -> Self: native_frame = self._native_frame - new_columns = evaluate_into_exprs(self, *exprs, **named_exprs) + new_columns: list[ArrowSeries] = evaluate_into_exprs(self)(*exprs, **named_exprs) length = len(self) columns = self.columns diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index 3213a0eaa..cb3c136be 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -75,7 +75,7 @@ def _from_native_frame(self: Self, df: Any) -> Self: def with_columns(self: Self, *exprs: DaskExpr, **named_exprs: DaskExpr) -> Self: df = self._native_frame - new_series = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) + new_series = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs) df = df.assign(**new_series) return self._from_native_frame(df) @@ -115,7 +115,7 @@ def simple_select(self: Self, *column_names: str) -> Self: ) def select(self: Self, *exprs: DaskExpr, **named_exprs: DaskExpr) -> Self: - new_series = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) + new_series = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs) if not new_series: # return empty dataframe, like Polars does diff --git a/narwhals/_dask/utils.py b/narwhals/_dask/utils.py index 33584efdb..6a2222f95 100644 --- a/narwhals/_dask/utils.py +++ b/narwhals/_dask/utils.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING from typing import Any +from typing import Callable from narwhals._expression_parsing import evaluate_output_names_and_aliases from narwhals._pandas_like.utils import select_columns_by_name @@ -44,29 +45,32 @@ def maybe_evaluate(df: DaskLazyFrame, obj: Any) -> Any: return obj -def parse_exprs_and_named_exprs( - df: DaskLazyFrame, *exprs: DaskExpr, **named_exprs: DaskExpr -) -> dict[str, dx.Series]: - native_results: dict[str, dx.Series] = {} - for expr in exprs: - native_series_list = expr._call(df) - return_scalar = getattr(expr, "_returns_scalar", False) - _, aliases = evaluate_output_names_and_aliases(expr, df, []) - if len(aliases) != len(native_series_list): # pragma: no cover - msg = f"Internal error: got aliases {aliases}, but only got {len(native_series_list)} results" - raise AssertionError(msg) - for native_series, alias in zip(native_series_list, aliases): - native_results[alias] = native_series[0] if return_scalar else native_series - for name, value in named_exprs.items(): - native_series_list = value._call(df) - if len(native_series_list) != 1: # pragma: no cover - msg = "Named expressions must return a single column" - raise AssertionError(msg) - return_scalar = getattr(value, "_returns_scalar", False) - native_results[name] = ( - native_series_list[0][0] if return_scalar else native_series_list[0] - ) - return native_results +def parse_exprs_and_named_exprs(df: DaskLazyFrame) -> Callable[..., dict[str, dx.Series]]: + def func(*exprs: DaskExpr, **named_exprs: DaskExpr) -> dict[str, dx.Series]: + native_results: dict[str, dx.Series] = {} + for expr in exprs: + native_series_list = expr._call(df) + return_scalar = getattr(expr, "_returns_scalar", False) + _, aliases = evaluate_output_names_and_aliases(expr, df, []) + if len(aliases) != len(native_series_list): # pragma: no cover + msg = f"Internal error: got aliases {aliases}, but only got {len(native_series_list)} results" + raise AssertionError(msg) + for native_series, alias in zip(native_series_list, aliases): + native_results[alias] = ( + native_series[0] if return_scalar else native_series + ) + for name, value in named_exprs.items(): + native_series_list = value._call(df) + if len(native_series_list) != 1: # pragma: no cover + msg = "Named expressions must return a single column" + raise AssertionError(msg) + return_scalar = getattr(value, "_returns_scalar", False) + native_results[name] = ( + native_series_list[0][0] if return_scalar else native_series_list[0] + ) + return native_results + + return func def add_row_index( diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 7ee11a12a..c34028e84 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -104,7 +104,7 @@ def select( *exprs: DuckDBExpr, **named_exprs: DuckDBExpr, ) -> Self: - new_columns_map = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) + new_columns_map = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs) if not new_columns_map: # TODO(marco): return empty relation with 0 columns? return self._from_native_frame(self._native_frame.limit(0)) @@ -139,7 +139,7 @@ def with_columns( *exprs: DuckDBExpr, **named_exprs: DuckDBExpr, ) -> Self: - new_columns_map = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) + new_columns_map = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs) result = [] for col in self._native_frame.columns: if col in new_columns_map: diff --git a/narwhals/_duckdb/utils.py b/narwhals/_duckdb/utils.py index 8457843e4..16cb7b92c 100644 --- a/narwhals/_duckdb/utils.py +++ b/narwhals/_duckdb/utils.py @@ -4,6 +4,7 @@ from functools import lru_cache from typing import TYPE_CHECKING from typing import Any +from typing import Callable import duckdb @@ -36,25 +37,30 @@ def maybe_evaluate(df: DuckDBLazyFrame, obj: Any) -> Any: def parse_exprs_and_named_exprs( - df: DuckDBLazyFrame, *exprs: DuckDBExpr, **named_exprs: DuckDBExpr -) -> dict[str, duckdb.Expression]: - native_results: dict[str, list[duckdb.Expression]] = {} - for expr in exprs: - native_series_list = expr._call(df) - output_names = expr._evaluate_output_names(df) - if expr._alias_output_names is not None: - output_names = expr._alias_output_names(output_names) - if len(output_names) != len(native_series_list): # pragma: no cover - msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results" - raise AssertionError(msg) - native_results.update(zip(output_names, native_series_list)) - for col_alias, expr in named_exprs.items(): - native_series_list = expr._call(df) - if len(native_series_list) != 1: # pragma: no cover - msg = "Named expressions must return a single column" - raise ValueError(msg) - native_results[col_alias] = native_series_list[0] - return native_results + df: DuckDBLazyFrame, +) -> Callable[..., dict[str, duckdb.Expression]]: + def func( + *exprs: DuckDBExpr, **named_exprs: DuckDBExpr + ) -> dict[str, duckdb.Expression]: + native_results: dict[str, list[duckdb.Expression]] = {} + for expr in exprs: + native_series_list = expr._call(df) + output_names = expr._evaluate_output_names(df) + if expr._alias_output_names is not None: + output_names = expr._alias_output_names(output_names) + if len(output_names) != len(native_series_list): # pragma: no cover + msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results" + raise AssertionError(msg) + native_results.update(zip(output_names, native_series_list)) + for col_alias, expr in named_exprs.items(): + native_series_list = expr._call(df) + if len(native_series_list) != 1: # pragma: no cover + msg = "Named expressions must return a single column" + raise ValueError(msg) + native_results[col_alias] = native_series_list[0] + return native_results + + return func @lru_cache(maxsize=16) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index af4742f35..5fa186dbb 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -62,23 +62,28 @@ def evaluate_into_expr( def evaluate_into_exprs( df: CompliantDataFrame, - *exprs: IntoCompliantExpr[CompliantSeriesT_co], - **named_exprs: IntoCompliantExpr[CompliantSeriesT_co], -) -> Sequence[CompliantSeriesT_co]: +) -> Callable[..., list[CompliantSeriesT_co]]: """Evaluate each expr into Series.""" - series = [ - item - for sublist in (evaluate_into_expr(df, into_expr) for into_expr in exprs) - for item in sublist - ] - for name, expr in named_exprs.items(): - evaluated_expr = evaluate_into_expr(df, expr) - if len(evaluated_expr) > 1: - msg = "Named expressions must return a single column" # pragma: no cover - raise AssertionError(msg) - to_append = evaluated_expr[0].alias(name) - series.append(to_append) - return series + + def func( + *exprs: IntoCompliantExpr[CompliantSeriesT_co], + **named_exprs: IntoCompliantExpr[CompliantSeriesT_co], + ) -> list[CompliantSeriesT_co]: + series = [ + item + for sublist in (evaluate_into_expr(df, into_expr) for into_expr in exprs) + for item in sublist + ] + for name, expr in named_exprs.items(): + evaluated_expr = evaluate_into_expr(df, expr) + if len(evaluated_expr) > 1: + msg = "Named expressions must return a single column" # pragma: no cover + raise AssertionError(msg) + to_append = evaluated_expr[0].alias(name) + series.append(to_append) + return series + + return func def maybe_evaluate_expr( diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 809b3de05..1b4c29b8e 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -363,7 +363,9 @@ def select( *exprs: IntoPandasLikeExpr, **named_exprs: IntoPandasLikeExpr, ) -> Self: - new_series = evaluate_into_exprs(self, *exprs, **named_exprs) + new_series: list[PandasLikeSeries] = evaluate_into_exprs(self)( + *exprs, **named_exprs + ) if not new_series: # return empty dataframe, like Polars does return self._from_native_frame(self._native_frame.__class__()) @@ -433,7 +435,9 @@ def with_columns( **named_exprs: IntoPandasLikeExpr, ) -> Self: index = self._native_frame.index - new_columns = evaluate_into_exprs(self, *exprs, **named_exprs) + new_columns: list[PandasLikeSeries] = evaluate_into_exprs(self)( + *exprs, **named_exprs + ) if not new_columns and len(self) == 0: return self diff --git a/narwhals/_spark_like/dataframe.py b/narwhals/_spark_like/dataframe.py index 3c4ac2fe7..fed410f89 100644 --- a/narwhals/_spark_like/dataframe.py +++ b/narwhals/_spark_like/dataframe.py @@ -94,7 +94,7 @@ def select( *exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr, ) -> Self: - new_columns = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) + new_columns = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs) if not new_columns: # return empty dataframe, like Polars does @@ -135,7 +135,7 @@ def with_columns( *exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr, ) -> Self: - new_columns_map = parse_exprs_and_named_exprs(self, *exprs, **named_exprs) + new_columns_map = parse_exprs_and_named_exprs(self)(*exprs, **named_exprs) return self._from_native_frame(self._native_frame.withColumns(new_columns_map)) def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001 diff --git a/narwhals/_spark_like/utils.py b/narwhals/_spark_like/utils.py index f2f7abfde..259e124b9 100644 --- a/narwhals/_spark_like/utils.py +++ b/narwhals/_spark_like/utils.py @@ -3,6 +3,7 @@ from functools import lru_cache from typing import TYPE_CHECKING from typing import Any +from typing import Callable from pyspark.sql import functions as F # noqa: N812 @@ -110,25 +111,28 @@ def narwhals_to_native_dtype( def parse_exprs_and_named_exprs( - df: SparkLikeLazyFrame, *exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr -) -> dict[str, Column]: - native_results: dict[str, list[Column]] = {} - for expr in exprs: - native_series_list = expr._call(df) - output_names = expr._evaluate_output_names(df) - if expr._alias_output_names is not None: - output_names = expr._alias_output_names(output_names) - if len(output_names) != len(native_series_list): # pragma: no cover - msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results" - raise AssertionError(msg) - native_results.update(zip(output_names, native_series_list)) - for col_alias, expr in named_exprs.items(): - native_series_list = expr._call(df) - if len(native_series_list) != 1: # pragma: no cover - msg = "Named expressions must return a single column" - raise ValueError(msg) - native_results[col_alias] = native_series_list[0] - return native_results + df: SparkLikeLazyFrame, +) -> Callable[..., dict[str, Column]]: + def func(*exprs: SparkLikeExpr, **named_exprs: SparkLikeExpr) -> dict[str, Column]: + native_results: dict[str, list[Column]] = {} + for expr in exprs: + native_series_list = expr._call(df) + output_names = expr._evaluate_output_names(df) + if expr._alias_output_names is not None: + output_names = expr._alias_output_names(output_names) + if len(output_names) != len(native_series_list): # pragma: no cover + msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results" + raise AssertionError(msg) + native_results.update(zip(output_names, native_series_list)) + for col_alias, expr in named_exprs.items(): + native_series_list = expr._call(df) + if len(native_series_list) != 1: # pragma: no cover + msg = "Named expressions must return a single column" + raise ValueError(msg) + native_results[col_alias] = native_series_list[0] + return native_results + + return func def maybe_evaluate(df: SparkLikeLazyFrame, obj: Any) -> Any: