Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking β€œSign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: completely refactor alias tracking and support nw.all, nw.nth, and selectors across the API #1866

Merged
merged 43 commits into from
Jan 26, 2025
Merged
Show file tree
Hide file tree
Changes from 41 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
e24da74
chore: skip ibis tests if duckdb not installed
MarcoGorelli Jan 24, 2025
0ff19a6
__array__ defaults
MarcoGorelli Jan 24, 2025
14e9eb4
simpler fix
MarcoGorelli Jan 24, 2025
eea221a
wip
MarcoGorelli Jan 24, 2025
9a958be
wip
MarcoGorelli Jan 24, 2025
6d2e95d
wip
MarcoGorelli Jan 24, 2025
2ce8f68
wip
MarcoGorelli Jan 24, 2025
ef3be7e
rename
MarcoGorelli Jan 24, 2025
6793e3a
rev
MarcoGorelli Jan 24, 2025
26bdb90
fix horizontal functions
MarcoGorelli Jan 25, 2025
6279653
fixup dask selectors
MarcoGorelli Jan 25, 2025
f789e38
this is amazing
MarcoGorelli Jan 25, 2025
91d9987
start fixing pyspark
MarcoGorelli Jan 25, 2025
ae92462
wip
MarcoGorelli Jan 25, 2025
6fbd537
hey everythingn passed!
MarcoGorelli Jan 25, 2025
907e450
clean up
MarcoGorelli Jan 25, 2025
936adea
Merge remote-tracking branch 'upstream/main' into the-big-dask-refactor
MarcoGorelli Jan 25, 2025
0a794bd
wait, duckdb is passing now?
MarcoGorelli Jan 25, 2025
7b0fda2
wait, duckdb is passing now?
MarcoGorelli Jan 25, 2025
945a5b9
wait, duckdb is passing now?
MarcoGorelli Jan 25, 2025
38dc3bc
pyarrow, getting there!
MarcoGorelli Jan 25, 2025
02c9ab2
fixing pandas
MarcoGorelli Jan 25, 2025
e65796a
come on getting there
MarcoGorelli Jan 25, 2025
a37d088
come one, almost there
MarcoGorelli Jan 25, 2025
83351d5
come one, almost there
MarcoGorelli Jan 25, 2025
6e40a89
one more
MarcoGorelli Jan 25, 2025
e86792b
restore dask
MarcoGorelli Jan 25, 2025
5ae54a8
clean up
MarcoGorelli Jan 25, 2025
265308b
remove more stuff
MarcoGorelli Jan 25, 2025
5e43f7a
raise NotImplementedError in `over` in some cases for now
MarcoGorelli Jan 26, 2025
31c84e6
Merge remote-tracking branch 'upstream/main' into the-big-dask-refactor
MarcoGorelli Jan 26, 2025
f3c5166
duckdb fixup
MarcoGorelli Jan 26, 2025
9f7aa3b
outdated comment
MarcoGorelli Jan 26, 2025
37f6315
loudly raise for over with key overlap
MarcoGorelli Jan 26, 2025
c3dc751
better types in selectors.
MarcoGorelli Jan 26, 2025
c5a22a2
fixup
MarcoGorelli Jan 26, 2025
7753490
Merge remote-tracking branch 'upstream/main' into the-big-dask-refactor
MarcoGorelli Jan 26, 2025
7bca332
versions compat
MarcoGorelli Jan 26, 2025
6ea2745
fixup
MarcoGorelli Jan 26, 2025
e2d9448
coverage
MarcoGorelli Jan 26, 2025
02680c1
fixup
MarcoGorelli Jan 26, 2025
ac4681c
comprehensions
MarcoGorelli Jan 26, 2025
613966f
redundat comment
MarcoGorelli Jan 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 31 additions & 22 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@
from narwhals._arrow.expr_name import ArrowExprNameNamespace
from narwhals._arrow.expr_str import ArrowExprStringNamespace
from narwhals._arrow.series import ArrowSeries
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._expression_parsing import reuse_series_implementation
from narwhals.dependencies import get_numpy
from narwhals.dependencies import is_numpy_array
from narwhals.exceptions import AnonymousExprError
from narwhals.exceptions import ColumnNotFoundError
from narwhals.typing import CompliantExpr
from narwhals.utils import Implementation
Expand All @@ -39,19 +39,18 @@ def __init__(
*,
depth: int,
function_name: str,
root_names: list[str] | None,
output_names: list[str] | None,
evaluate_output_names: Callable[[ArrowDataFrame], Sequence[str]],
alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None,
backend_version: tuple[int, ...],
version: Version,
kwargs: dict[str, Any],
) -> None:
self._call = call
self._depth = depth
self._function_name = function_name
self._root_names = root_names
self._depth = depth
self._output_names = output_names
self._implementation = Implementation.PYARROW
self._evaluate_output_names = evaluate_output_names
self._alias_output_names = alias_output_names
self._backend_version = backend_version
self._version = version
self._kwargs = kwargs
Expand All @@ -61,8 +60,6 @@ def __repr__(self: Self) -> str: # pragma: no cover
f"ArrowExpr("
f"depth={self._depth}, "
f"function_name={self._function_name}, "
f"root_names={self._root_names}, "
f"output_names={self._output_names}"
)

def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:
Expand Down Expand Up @@ -98,8 +95,8 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
func,
depth=0,
function_name="col",
root_names=list(column_names),
output_names=list(column_names),
evaluate_output_names=lambda _df: list(column_names),
alias_output_names=None,
backend_version=backend_version,
version=version,
kwargs={},
Expand Down Expand Up @@ -129,8 +126,8 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
func,
depth=0,
function_name="nth",
root_names=None,
output_names=None,
evaluate_output_names=lambda df: [df.columns[i] for i in column_indices],
alias_output_names=None,
backend_version=backend_version,
version=version,
kwargs={},
Expand Down Expand Up @@ -265,14 +262,20 @@ def shift(self: Self, n: int) -> Self:
return reuse_series_implementation(self, "shift", n=n)

def alias(self: Self, name: str) -> Self:
def alias_output_names(names: Sequence[str]) -> Sequence[str]:
if len(names) != 1:
msg = f"Expected function with single output, found output names: {names}"
raise ValueError(msg)
return [name]

# Define this one manually, so that we can
# override `output_names` and not increase depth
return self.__class__(
lambda df: [series.alias(name) for series in self._call(df)],
depth=self._depth,
function_name=self._function_name,
root_names=self._root_names,
output_names=[name],
evaluate_output_names=self._evaluate_output_names,
alias_output_names=alias_output_names,
backend_version=self._backend_version,
version=self._version,
kwargs={**self._kwargs, "name": name},
Expand Down Expand Up @@ -390,22 +393,28 @@ def clip(self: Self, lower_bound: Any | None, upper_bound: Any | None) -> Self:

def over(self: Self, keys: list[str]) -> Self:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
if self._output_names is None:
msg = ".over"
raise AnonymousExprError.from_expr_name(msg)
output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
if overlap := set(output_names).intersection(keys):
# E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined,
# we just don't support it yet.
msg = (
f"Column names {overlap} appear in both expression output names and in `over` keys.\n"
"This is not yet supported."
)
raise NotImplementedError(msg)

tmp = df.group_by(*keys, drop_null_keys=False).agg(self)
tmp = df.simple_select(*keys).join(
tmp, how="left", left_on=keys, right_on=keys, suffix="_right"
)
return [tmp[name] for name in self._output_names]
return [tmp[alias] for alias in aliases]

return self.__class__(
func,
depth=self._depth + 1,
function_name=self._function_name + "->over",
root_names=self._root_names,
output_names=self._output_names,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
backend_version=self._backend_version,
version=self._version,
kwargs={**self._kwargs, "keys": keys},
Expand Down Expand Up @@ -446,8 +455,8 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
func,
depth=self._depth + 1,
function_name=self._function_name + "->map_batches",
root_names=self._root_names,
output_names=self._output_names,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
backend_version=self._backend_version,
version=self._version,
kwargs={**self._kwargs, "function": function, "return_dtype": return_dtype},
Expand Down
117 changes: 51 additions & 66 deletions narwhals/_arrow/expr_name.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
from typing import TYPE_CHECKING
from typing import Callable

from narwhals.exceptions import AnonymousExprError

if TYPE_CHECKING:
from typing_extensions import Self

Expand All @@ -16,131 +14,118 @@ def __init__(self: Self, expr: ArrowExpr) -> None:
self._compliant_expr = expr

def keep(self: Self) -> ArrowExpr:
root_names = self._compliant_expr._root_names

if root_names is None:
msg = ".name.keep"
raise AnonymousExprError.from_expr_name(msg)

return self._compliant_expr.__class__(
lambda df: [
series.alias(name)
for series, name in zip(self._compliant_expr._call(df), root_names)
for series, name in zip(
self._compliant_expr._call(df),
self._compliant_expr._evaluate_output_names(df),
)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=root_names,
evaluate_output_names=self._compliant_expr._evaluate_output_names,
alias_output_names=None,
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs=self._compliant_expr._kwargs,
)

def map(self: Self, function: Callable[[str], str]) -> ArrowExpr:
root_names = self._compliant_expr._root_names

if root_names is None:
msg = ".name.map"
raise AnonymousExprError.from_expr_name(msg)

output_names = [function(str(name)) for name in root_names]

return self._compliant_expr.__class__(
lambda df: [
series.alias(name)
for series, name in zip(self._compliant_expr._call(df), output_names)
series.alias(function(str(name)))
for series, name in zip(
self._compliant_expr._call(df),
self._compliant_expr._evaluate_output_names(df),
)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=output_names,
evaluate_output_names=self._compliant_expr._evaluate_output_names,
alias_output_names=lambda output_names: [
function(str(name)) for name in output_names
],
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={**self._compliant_expr._kwargs, "function": function},
)

def prefix(self: Self, prefix: str) -> ArrowExpr:
root_names = self._compliant_expr._root_names
if root_names is None:
msg = ".name.prefix"
raise AnonymousExprError.from_expr_name(msg)

output_names = [prefix + str(name) for name in root_names]
return self._compliant_expr.__class__(
lambda df: [
series.alias(name)
for series, name in zip(self._compliant_expr._call(df), output_names)
series.alias(f"{prefix}{name}")
for series, name in zip(
self._compliant_expr._call(df),
self._compliant_expr._evaluate_output_names(df),
)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=output_names,
evaluate_output_names=self._compliant_expr._evaluate_output_names,
alias_output_names=lambda output_names: [
f"{prefix}{output_name}" for output_name in output_names
],
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={**self._compliant_expr._kwargs, "prefix": prefix},
)

def suffix(self: Self, suffix: str) -> ArrowExpr:
root_names = self._compliant_expr._root_names
if root_names is None:
msg = ".name.suffix"
raise AnonymousExprError.from_expr_name(msg)

output_names = [str(name) + suffix for name in root_names]

return self._compliant_expr.__class__(
lambda df: [
series.alias(name)
for series, name in zip(self._compliant_expr._call(df), output_names)
series.alias(f"{name}{suffix}")
for series, name in zip(
self._compliant_expr._call(df),
self._compliant_expr._evaluate_output_names(df),
)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=output_names,
evaluate_output_names=self._compliant_expr._evaluate_output_names,
alias_output_names=lambda output_names: [
f"{output_name}{suffix}" for output_name in output_names
],
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs={**self._compliant_expr._kwargs, "suffix": suffix},
)

def to_lowercase(self: Self) -> ArrowExpr:
root_names = self._compliant_expr._root_names

if root_names is None:
msg = ".name.to_lowercase"
raise AnonymousExprError.from_expr_name(msg)
output_names = [str(name).lower() for name in root_names]

return self._compliant_expr.__class__(
lambda df: [
series.alias(name)
for series, name in zip(self._compliant_expr._call(df), output_names)
series.alias(str(name).lower())
for series, name in zip(
self._compliant_expr._call(df),
self._compliant_expr._evaluate_output_names(df),
)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=output_names,
evaluate_output_names=self._compliant_expr._evaluate_output_names,
alias_output_names=lambda output_names: [
str(name).lower() for name in output_names
],
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs=self._compliant_expr._kwargs,
)

def to_uppercase(self: Self) -> ArrowExpr:
root_names = self._compliant_expr._root_names

if root_names is None:
msg = ".name.to_uppercase"
raise AnonymousExprError.from_expr_name(msg)
output_names = [str(name).upper() for name in root_names]

return self._compliant_expr.__class__(
lambda df: [
series.alias(name)
for series, name in zip(self._compliant_expr._call(df), output_names)
series.alias(str(name).upper())
for series, name in zip(
self._compliant_expr._call(df),
self._compliant_expr._evaluate_output_names(df),
)
],
depth=self._compliant_expr._depth,
function_name=self._compliant_expr._function_name,
root_names=root_names,
output_names=output_names,
evaluate_output_names=self._compliant_expr._evaluate_output_names,
alias_output_names=lambda output_names: [
str(name).upper() for name in output_names
],
backend_version=self._compliant_expr._backend_version,
version=self._compliant_expr._version,
kwargs=self._compliant_expr._kwargs,
Expand Down
Loading
Loading