Skip to content

Commit

Permalink
feat: completely refactor alias tracking and support nw.all, `nw.nth`, and selectors across the API (#1866)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jan 26, 2025
1 parent 26c5cbc commit a585548
Show file tree
Hide file tree
Showing 42 changed files with 1,359 additions and 1,679 deletions.
53 changes: 31 additions & 22 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from narwhals._arrow.expr_name import ArrowExprNameNamespace
from narwhals._arrow.expr_str import ArrowExprStringNamespace
from narwhals._arrow.series import ArrowSeries
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._expression_parsing import reuse_series_implementation
from narwhals.dependencies import get_numpy
from narwhals.dependencies import is_numpy_array
from narwhals.exceptions import AnonymousExprError
from narwhals.exceptions import ColumnNotFoundError
from narwhals.typing import CompliantExpr
from narwhals.utils import Implementation
Expand All @@ -39,19 +39,18 @@ def __init__(
*,
depth: int,
function_name: str,
root_names: list[str] | None,
output_names: list[str] | None,
evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]],
alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None,
backend_version: tuple[int, ...],
version: Version,
kwargs: dict[str, Any],
) -> None:
self._call = call
self._depth = depth
self._function_name = function_name
self._root_names = root_names
self._depth = depth
self._output_names = output_names
self._implementation = Implementation.PYARROW
self._evaluate_output_names = evaluate_output_names
self._alias_output_names = alias_output_names
self._backend_version = backend_version
self._version = version
self._kwargs = kwargs
Expand All @@ -61,8 +60,6 @@ def __repr__(self: Self) -> str: # pragma: no cover
f"ArrowExpr("
f"depth={self._depth}, "
f"function_name={self._function_name}, "
f"root_names={self._root_names}, "
f"output_names={self._output_names}"
)

def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:
Expand Down Expand Up @@ -98,8 +95,8 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
func,
depth=0,
function_name="col",
root_names=list(column_names),
output_names=list(column_names),
evaluate_output_names=lambda _df: list(column_names),
alias_output_names=None,
backend_version=backend_version,
version=version,
kwargs={},
Expand Down Expand Up @@ -129,8 +126,8 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
func,
depth=0,
function_name="nth",
root_names=None,
output_names=None,
evaluate_output_names=lambda df: [df.columns[i] for i in column_indices],
alias_output_names=None,
backend_version=backend_version,
version=version,
kwargs={},
Expand Down Expand Up @@ -265,14 +262,20 @@ def shift(self: Self, n: int) -> Self:
return reuse_series_implementation(self, "shift", n=n)

def alias(self: Self, name: str) -> Self:
def alias_output_names(names: Sequence[str]) -> Sequence[str]:
if len(names) != 1:
msg = f"Expected function with single output, found output names: {names}"
raise ValueError(msg)
return [name]

# Define this one manually, so that we can
# override `output_names` and not increase depth
return self.__class__(
lambda df: [series.alias(name) for series in self._call(df)],
depth=self._depth,
function_name=self._function_name,
root_names=self._root_names,
output_names=[name],
evaluate_output_names=self._evaluate_output_names,
alias_output_names=alias_output_names,
backend_version=self._backend_version,
version=self._version,
kwargs={**self._kwargs, "name": name},
Expand Down Expand Up @@ -390,22 +393,28 @@ def clip(self: Self, lower_bound: Any | None, upper_bound: Any | None) -> Self:

def over(self: Self, keys: list[str]) -> Self:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
if self._output_names is None:
msg = ".over"
raise AnonymousExprError.from_expr_name(msg)
output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
if overlap := set(output_names).intersection(keys):
# E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined,
# we just don't support it yet.
msg = (
f"Column names {overlap} appear in both expression output names and in `over` keys.\n"
"This is not yet supported."
)
raise NotImplementedError(msg)

tmp = df.group_by(*keys, drop_null_keys=False).agg(self)
tmp = df.simple_select(*keys).join(
tmp, how="left", left_on=keys, right_on=keys, suffix="_right"
)
return [tmp[name] for name in self._output_names]
return [tmp[alias] for alias in aliases]

return self.__class__(
func,
depth=self._depth + 1,
function_name=self._function_name + "->over",
root_names=self._root_names,
output_names=self._output_names,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
backend_version=self._backend_version,
version=self._version,
kwargs={**self._kwargs, "keys": keys},
Expand Down Expand Up @@ -446,8 +455,8 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
func,
depth=self._depth + 1,
function_name=self._function_name + "->map_batches",
root_names=self._root_names,
output_names=self._output_names,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
backend_version=self._backend_version,
version=self._version,
kwargs={**self._kwargs, "function": function, "return_dtype": return_dtype},
Expand Down
117 changes: 51 additions & 66 deletions narwhals/_arrow/expr_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from typing import TYPE_CHECKING
from typing import Callable

from narwhals.exceptions import AnonymousExprError

if TYPE_CHECKING:
from typing_extensions import Self

Expand All @@ -16,131 +14,118 @@ def __init__(self: Self, expr: ArrowExpr) -> None:
self._compliant_expr = expr

def keep(self: Self) -> ArrowExpr:
root_names = self._compliant_expr._root_names

if root_names is None:
msg = ".name.keep"
raise AnonymousExprError.from_expr_name(msg)

return self._compliant_expr.__class__(
lambda df: [
series.alias(name)
for series, name in zip(self._compliant_expr._call(df), root_names)
for series, name in zip(
self._compliant_expr._call(df),
self._compliant_expr._evaluate_output_names(df),
)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=root_names,
evaluate_output_names=self._compliant_expr._evaluate_output_names,
alias_output_names=None,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs=self._compliant_expr._kwargs,
)

def map(self: Self, function: Callable[[str], str]) -> ArrowExpr:
root_names = self._compliant_expr._root_names

if root_names is None:
msg = ".name.map"
raise AnonymousExprError.from_expr_name(msg)

output_names = [function(str(name)) for name in root_names]

return self._compliant_expr.__class__(
lambda df: [
series.alias(name)
for series, name in zip(self._compliant_expr._call(df), output_names)
series.alias(function(str(name)))
for series, name in zip(
self._compliant_expr._call(df),
self._compliant_expr._evaluate_output_names(df),
)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=output_names,
evaluate_output_names=self._compliant_expr._evaluate_output_names,
alias_output_names=lambda output_names: [
function(str(name)) for name in output_names
],
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={**self._compliant_expr._kwargs, "function": function},
)

def prefix(self: Self, prefix: str) -> ArrowExpr:
root_names = self._compliant_expr._root_names
if root_names is None:
msg = ".name.prefix"
raise AnonymousExprError.from_expr_name(msg)

output_names = [prefix + str(name) for name in root_names]
return self._compliant_expr.__class__(
lambda df: [
series.alias(name)
for series, name in zip(self._compliant_expr._call(df), output_names)
series.alias(f"{prefix}{name}")
for series, name in zip(
self._compliant_expr._call(df),
self._compliant_expr._evaluate_output_names(df),
)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=output_names,
evaluate_output_names=self._compliant_expr._evaluate_output_names,
alias_output_names=lambda output_names: [
f"{prefix}{output_name}" for output_name in output_names
],
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={**self._compliant_expr._kwargs, "prefix": prefix},
)

def suffix(self: Self, suffix: str) -> ArrowExpr:
root_names = self._compliant_expr._root_names
if root_names is None:
msg = ".name.suffix"
raise AnonymousExprError.from_expr_name(msg)

output_names = [str(name) + suffix for name in root_names]

return self._compliant_expr.__class__(
lambda df: [
series.alias(name)
for series, name in zip(self._compliant_expr._call(df), output_names)
series.alias(f"{name}{suffix}")
for series, name in zip(
self._compliant_expr._call(df),
self._compliant_expr._evaluate_output_names(df),
)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=output_names,
evaluate_output_names=self._compliant_expr._evaluate_output_names,
alias_output_names=lambda output_names: [
f"{output_name}{suffix}" for output_name in output_names
],
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={**self._compliant_expr._kwargs, "suffix": suffix},
)

def to_lowercase(self: Self) -> ArrowExpr:
root_names = self._compliant_expr._root_names

if root_names is None:
msg = ".name.to_lowercase"
raise AnonymousExprError.from_expr_name(msg)
output_names = [str(name).lower() for name in root_names]

return self._compliant_expr.__class__(
lambda df: [
series.alias(name)
for series, name in zip(self._compliant_expr._call(df), output_names)
series.alias(str(name).lower())
for series, name in zip(
self._compliant_expr._call(df),
self._compliant_expr._evaluate_output_names(df),
)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=output_names,
evaluate_output_names=self._compliant_expr._evaluate_output_names,
alias_output_names=lambda output_names: [
str(name).lower() for name in output_names
],
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs=self._compliant_expr._kwargs,
)

def to_uppercase(self: Self) -> ArrowExpr:
root_names = self._compliant_expr._root_names

if root_names is None:
msg = ".name.to_uppercase"
raise AnonymousExprError.from_expr_name(msg)
output_names = [str(name).upper() for name in root_names]

return self._compliant_expr.__class__(
lambda df: [
series.alias(name)
for series, name in zip(self._compliant_expr._call(df), output_names)
series.alias(str(name).upper())
for series, name in zip(
self._compliant_expr._call(df),
self._compliant_expr._evaluate_output_names(df),
)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=output_names,
evaluate_output_names=self._compliant_expr._evaluate_output_names,
alias_output_names=lambda output_names: [
str(name).upper() for name in output_names
],
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs=self._compliant_expr._kwargs,
Expand Down
Loading

0 comments on commit a585548

Please sign in to comment.