Skip to content

Commit

Permalink
chore: Add some Compliant Protocols (#1522)
Browse files (browse the repository at this point in the history)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
MarcoGorelli and pre-commit-ci[bot] authored Dec 15, 2024
1 parent 9276ad3 commit d0225b3
Show file tree
Hide file tree
Showing 25 changed files with 277 additions and 318 deletions.
5 changes: 4 additions & 1 deletion narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,11 @@
from narwhals.typing import SizeUnit
from narwhals.utils import Version

from narwhals.typing import CompliantDataFrame
from narwhals.typing import CompliantLazyFrame

class ArrowDataFrame:

class ArrowDataFrame(CompliantDataFrame, CompliantLazyFrame):
# --- not in the spec ---
def __init__(
self: Self,
Expand Down
11 changes: 8 additions & 3 deletions narwhals/_arrow/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,20 @@

from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.namespace import ArrowNamespace
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import IntoArrowExpr
from narwhals.dtypes import DType
from narwhals.utils import Version

from narwhals._arrow.series import ArrowSeries
from narwhals.typing import CompliantExpr

class ArrowExpr:

class ArrowExpr(CompliantExpr[ArrowSeries]):
_implementation: Implementation = Implementation.PYARROW

def __init__(
self: Self,
call: Callable[[ArrowDataFrame], list[ArrowSeries]],
call: Callable[[ArrowDataFrame], Sequence[ArrowSeries]],
*,
depth: int,
function_name: str,
Expand Down Expand Up @@ -57,6 +59,9 @@ def __repr__(self: Self) -> str: # pragma: no cover
f"output_names={self._output_names}"
)

def __call__(self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:
return self._call(df)

@classmethod
def from_column_names(
cls: type[Self],
Expand Down
6 changes: 4 additions & 2 deletions narwhals/_arrow/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any
from typing import Callable
from typing import Iterator
from typing import Sequence

from narwhals._expression_parsing import is_simple_aggregation
from narwhals._expression_parsing import parse_into_exprs
Expand All @@ -17,8 +18,9 @@
from typing_extensions import Self

from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import IntoArrowExpr
from narwhals.typing import CompliantExpr


def polars_to_arrow_aggregations() -> (
Expand Down Expand Up @@ -122,7 +124,7 @@ def __iter__(self: Self) -> Iterator[tuple[Any, ArrowDataFrame]]:

def agg_arrow(
grouped: pa.TableGroupBy,
exprs: list[ArrowExpr],
exprs: Sequence[CompliantExpr[ArrowSeries]],
keys: list[str],
output_names: list[str],
from_dataframe: Callable[[Any], ArrowDataFrame],
Expand Down
35 changes: 18 additions & 17 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any
from typing import Iterable
from typing import Literal
from typing import Sequence

from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
Expand All @@ -17,6 +18,7 @@
from narwhals._expression_parsing import combine_root_names
from narwhals._expression_parsing import parse_into_exprs
from narwhals._expression_parsing import reduce_output_names
from narwhals.typing import CompliantNamespace
from narwhals.utils import Implementation
from narwhals.utils import import_dtypes_module

Expand All @@ -30,10 +32,10 @@
from narwhals.utils import Version


class ArrowNamespace:
class ArrowNamespace(CompliantNamespace[ArrowSeries]):
def _create_expr_from_callable(
self: Self,
func: Callable[[ArrowDataFrame], list[ArrowSeries]],
func: Callable[[ArrowDataFrame], Sequence[ArrowSeries]],
*,
depth: int,
function_name: str,
Expand Down Expand Up @@ -181,7 +183,7 @@ def all_horizontal(self: Self, *exprs: IntoArrowExpr) -> ArrowExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: ArrowDataFrame) -> list[ArrowSeries]:
series = (s for _expr in parsed_exprs for s in _expr._call(df))
series = (s for _expr in parsed_exprs for s in _expr(df))
return [reduce(lambda x, y: x & y, series)]

return self._create_expr_from_callable(
Expand All @@ -196,7 +198,7 @@ def any_horizontal(self: Self, *exprs: IntoArrowExpr) -> ArrowExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: ArrowDataFrame) -> list[ArrowSeries]:
series = (s for _expr in parsed_exprs for s in _expr._call(df))
series = (s for _expr in parsed_exprs for s in _expr(df))
return [reduce(lambda x, y: x | y, series)]

return self._create_expr_from_callable(
Expand All @@ -214,7 +216,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
series = (
s.fill_null(0, strategy=None, limit=None)
for _expr in parsed_exprs
for s in _expr._call(df)
for s in _expr(df)
)
return [reduce(lambda x, y: x + y, series)]

Expand All @@ -234,12 +236,12 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
series = (
s.fill_null(0, strategy=None, limit=None)
for _expr in parsed_exprs
for s in _expr._call(df)
for s in _expr(df)
)
non_na = (
1 - s.is_null().cast(dtypes.Int64())
for _expr in parsed_exprs
for s in _expr._call(df)
for s in _expr(df)
)
return [
reduce(lambda x, y: x + y, series) / reduce(lambda x, y: x + y, non_na)
Expand All @@ -259,7 +261,7 @@ def min_horizontal(self: Self, *exprs: IntoArrowExpr) -> ArrowExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: ArrowDataFrame) -> list[ArrowSeries]:
init_series, *series = [s for _expr in parsed_exprs for s in _expr._call(df)]
init_series, *series = [s for _expr in parsed_exprs for s in _expr(df)]
return [
ArrowSeries(
native_series=reduce(
Expand Down Expand Up @@ -287,7 +289,7 @@ def max_horizontal(self: Self, *exprs: IntoArrowExpr) -> ArrowExpr:
parsed_exprs = parse_into_exprs(*exprs, namespace=self)

def func(df: ArrowDataFrame) -> list[ArrowSeries]:
init_series, *series = [s for _expr in parsed_exprs for s in _expr._call(df)]
init_series, *series = [s for _expr in parsed_exprs for s in _expr(df)]
return [
ArrowSeries(
native_series=reduce(
Expand Down Expand Up @@ -387,7 +389,7 @@ def concat_str(
) -> ArrowExpr:
import pyarrow.compute as pc

parsed_exprs: list[ArrowExpr] = [
parsed_exprs = [
*parse_into_exprs(*exprs, namespace=self),
*parse_into_exprs(*more_exprs, namespace=self),
]
Expand All @@ -397,7 +399,7 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
series = (
s._native_series
for _expr in parsed_exprs
for s in _expr.cast(dtypes.String())._call(df)
for s in _expr.cast(dtypes.String())(df)
)
null_handling = "skip" if ignore_nulls else "emit_null"
result_series = pc.binary_join_element_wise(
Expand Down Expand Up @@ -437,7 +439,7 @@ def __init__(
self._otherwise_value = otherwise_value
self._version = version

def __call__(self: Self, df: ArrowDataFrame) -> list[ArrowSeries]:
def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]:
import pyarrow as pa
import pyarrow.compute as pc

Expand All @@ -446,9 +448,9 @@ def __call__(self: Self, df: ArrowDataFrame) -> list[ArrowSeries]:

plx = ArrowNamespace(backend_version=self._backend_version, version=self._version)

condition = parse_into_expr(self._condition, namespace=plx)._call(df)[0]
condition = parse_into_expr(self._condition, namespace=plx)(df)[0]
try:
value_series = parse_into_expr(self._then_value, namespace=plx)._call(df)[0]
value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0]
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression
value_series = condition.__class__._from_iterable(
Expand All @@ -471,9 +473,7 @@ def __call__(self: Self, df: ArrowDataFrame) -> list[ArrowSeries]:
)
]
try:
otherwise_series = parse_into_expr(
self._otherwise_value, namespace=plx
)._call(df)[0]
otherwise_expr = parse_into_expr(self._otherwise_value, namespace=plx)
except TypeError:
# `self._otherwise_value` is a scalar and can't be converted to an expression.
# Remark that string values _are_ converted into expressions!
Expand All @@ -485,6 +485,7 @@ def __call__(self: Self, df: ArrowDataFrame) -> list[ArrowSeries]:
)
]
else:
otherwise_series = otherwise_expr(df)[0]
condition_native, otherwise_native = broadcast_series(
[condition, otherwise_series]
)
Expand Down
9 changes: 5 additions & 4 deletions narwhals/_arrow/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import NoReturn
from typing import Sequence

from narwhals._arrow.expr import ArrowExpr
from narwhals.utils import Implementation
Expand Down Expand Up @@ -127,10 +128,10 @@ def call(df: ArrowDataFrame) -> list[ArrowSeries]:
def __or__(self: Self, other: Self | Any) -> ArrowSelector | Any:
if isinstance(other, ArrowSelector):

def call(df: ArrowDataFrame) -> list[ArrowSeries]:
lhs = self._call(df)
rhs = other._call(df)
return [x for x in lhs if x.name not in {x.name for x in rhs}] + rhs
def call(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
lhs = self(df)
rhs = other(df)
return [*(x for x in lhs if x.name not in {x.name for x in rhs}), *rhs]

return ArrowSelector(
call,
Expand Down
3 changes: 2 additions & 1 deletion narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from narwhals._arrow.namespace import ArrowNamespace
from narwhals.dtypes import DType
from narwhals.utils import Version
from narwhals.typing import CompliantSeries


def maybe_extract_py_scalar(value: Any, return_py_scalar: bool) -> Any: # noqa: FBT001
Expand All @@ -38,7 +39,7 @@ def maybe_extract_py_scalar(value: Any, return_py_scalar: bool) -> Any: # noqa:
return value


class ArrowSeries:
class ArrowSeries(CompliantSeries):
def __init__(
self: Self,
native_series: pa.ChunkedArray,
Expand Down
2 changes: 1 addition & 1 deletion narwhals/_arrow/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.series import ArrowSeries

IntoArrowExpr: TypeAlias = Union[ArrowExpr, str, int, float, ArrowSeries]
IntoArrowExpr: TypeAlias = Union[ArrowExpr, str, ArrowSeries]
2 changes: 1 addition & 1 deletion narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def cast_for_truediv(
return arrow_array, pa_object


def broadcast_series(series: list[ArrowSeries]) -> list[Any]:
def broadcast_series(series: Sequence[ArrowSeries]) -> list[Any]:
lengths = [len(s) for s in series]
max_length = max(lengths)
fast_path = all(_len == max_length for _len in lengths)
Expand Down
3 changes: 2 additions & 1 deletion narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@
from narwhals._dask.typing import IntoDaskExpr
from narwhals.dtypes import DType
from narwhals.utils import Version
from narwhals.typing import CompliantLazyFrame


class DaskLazyFrame:
class DaskLazyFrame(CompliantLazyFrame):
def __init__(
self,
native_dataframe: dd.DataFrame,
Expand Down
8 changes: 6 additions & 2 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from narwhals._pandas_like.utils import calculate_timestamp_datetime
from narwhals._pandas_like.utils import native_to_narwhals_dtype
from narwhals.exceptions import ColumnNotFoundError
from narwhals.typing import CompliantExpr
from narwhals.utils import Implementation
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import import_dtypes_module
Expand All @@ -29,12 +30,12 @@
from narwhals.utils import Version


class DaskExpr:
class DaskExpr(CompliantExpr["dask_expr.Series"]):
_implementation: Implementation = Implementation.DASK

def __init__(
self,
call: Callable[[DaskLazyFrame], list[dask_expr.Series]],
call: Callable[[DaskLazyFrame], Sequence[dask_expr.Series]],
*,
depth: int,
function_name: str,
Expand All @@ -55,6 +56,9 @@ def __init__(
self._backend_version = backend_version
self._version = version

def __call__(self, df: DaskLazyFrame) -> Sequence[dask_expr.Series]:
return self._call(df)

def __narwhals_expr__(self) -> None: ...

def __narwhals_namespace__(self) -> DaskNamespace: # pragma: no cover
Expand Down
6 changes: 4 additions & 2 deletions narwhals/_dask/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,20 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Sequence

from narwhals._expression_parsing import is_simple_aggregation
from narwhals._expression_parsing import parse_into_exprs
from narwhals.utils import remove_prefix

if TYPE_CHECKING:
import dask.dataframe as dd
import dask_expr
import pandas as pd

from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals._dask.typing import IntoDaskExpr
from narwhals.typing import CompliantExpr


def n_unique() -> dd.Aggregation:
Expand Down Expand Up @@ -101,7 +103,7 @@ def _from_native_frame(self, df: DaskLazyFrame) -> DaskLazyFrame:
def agg_dask(
df: DaskLazyFrame,
grouped: Any,
exprs: list[DaskExpr],
exprs: Sequence[CompliantExpr[dask_expr.Series]],
keys: list[str],
from_dataframe: Callable[[Any], DaskLazyFrame],
) -> DaskLazyFrame:
Expand Down
Loading

0 comments on commit d0225b3

Please sign in to comment.