Skip to content

Commit

Permalink
feat: add "interchange"-level support for libraries which implement t…
Browse files Browse the repository at this point in the history
…he interchange protocol (#517)

* wip: support interchange protocol

* raise on invalid attributes

* raise on invalid attributes

* typing

* typing

* typing

* no default

* fixup

* fixup types

* fixup types

* fixup types

* wip

* fixup

* coverage

* coverage

* coverage

* coverage

* match error

* change signature

* cov

* cov

* rename
  • Loading branch information
MarcoGorelli authored Jul 14, 2024
1 parent 6dbfe6c commit 7644991
Show file tree
Hide file tree
Showing 23 changed files with 545 additions and 57 deletions.
1 change: 1 addition & 0 deletions docs/api-reference/narwhals.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Here are the top-level functions available in Narwhals.
- col
- concat
- from_native
- get_level
- get_native_namespace
- is_ordered_categorical
- len
Expand Down
2 changes: 2 additions & 0 deletions narwhals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from narwhals.expression import sum
from narwhals.expression import sum_horizontal
from narwhals.functions import concat
from narwhals.functions import get_level
from narwhals.functions import show_versions
from narwhals.series import Series
from narwhals.translate import from_native
Expand All @@ -49,6 +50,7 @@
__all__ = [
"selectors",
"concat",
"get_level",
"to_native",
"from_native",
"is_ordered_categorical",
Expand Down
4 changes: 4 additions & 0 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def rows(
def get_column(self, name: str) -> ArrowSeries:
from narwhals._arrow.series import ArrowSeries

if not isinstance(name, str):
msg = f"Expected str, got: {type(name)}"
raise TypeError(msg)

return ArrowSeries(
self._native_dataframe[name],
name=name,
Expand Down
Empty file.
96 changes: 96 additions & 0 deletions narwhals/_interchange/dataframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from __future__ import annotations

import enum
from typing import TYPE_CHECKING
from typing import Any
from typing import NoReturn

from narwhals import dtypes

if TYPE_CHECKING:
from narwhals._interchange.series import InterchangeSeries


class DtypeKind(enum.IntEnum):
# https://data-apis.org/dataframe-protocol/latest/API.html
INT = 0
UINT = 1
FLOAT = 2
BOOL = 20
STRING = 21 # UTF-8
DATETIME = 22
CATEGORICAL = 23


def map_interchange_dtype_to_narwhals_dtype(
interchange_dtype: tuple[DtypeKind, int, Any, Any],
) -> dtypes.DType:
if interchange_dtype[0] == DtypeKind.INT:
if interchange_dtype[1] == 64:
return dtypes.Int64()
if interchange_dtype[1] == 32:
return dtypes.Int32()
if interchange_dtype[1] == 16:
return dtypes.Int16()
if interchange_dtype[1] == 8:
return dtypes.Int8()
raise AssertionError("Invalid bit width for INT")
if interchange_dtype[0] == DtypeKind.UINT:
if interchange_dtype[1] == 64:
return dtypes.UInt64()
if interchange_dtype[1] == 32:
return dtypes.UInt32()
if interchange_dtype[1] == 16:
return dtypes.UInt16()
if interchange_dtype[1] == 8:
return dtypes.UInt8()
raise AssertionError("Invalid bit width for UINT")
if interchange_dtype[0] == DtypeKind.FLOAT:
if interchange_dtype[1] == 64:
return dtypes.Float64()
if interchange_dtype[1] == 32:
return dtypes.Float32()
raise AssertionError("Invalid bit width for FLOAT")
if interchange_dtype[0] == DtypeKind.BOOL:
return dtypes.Boolean()
if interchange_dtype[0] == DtypeKind.STRING:
return dtypes.String()
if interchange_dtype[0] == DtypeKind.DATETIME:
return dtypes.Datetime()
if interchange_dtype[0] == DtypeKind.CATEGORICAL: # pragma: no cover
# upstream issue: https://github.com/ibis-project/ibis/issues/9570
return dtypes.Categorical()
msg = f"Invalid dtype, got: {interchange_dtype}" # pragma: no cover
raise AssertionError(msg)


class InterchangeFrame:
def __init__(self, df: Any) -> None:
self._native_dataframe = df

def __narwhals_dataframe__(self) -> Any:
return self

def __getitem__(self, item: str) -> InterchangeSeries:
from narwhals._interchange.series import InterchangeSeries

return InterchangeSeries(self._native_dataframe.get_column_by_name(item))

@property
def schema(self) -> dict[str, dtypes.DType]:
return {
column_name: map_interchange_dtype_to_narwhals_dtype(
self._native_dataframe.get_column_by_name(column_name).dtype
)
for column_name in self._native_dataframe.column_names()
}

def __getattr__(self, attr: str) -> NoReturn:
msg = (
f"Attribute {attr} is not supported for metadata-only dataframes.\n\n"
"Hint: you probably called `nw.from_native` on an object which isn't fully "
"supported by Narwhals, yet implements `__dataframe__`. If you would like to "
"see this kind of object supported in Narwhals, please open a feature request "
"at https://github.com/narwhals-dev/narwhals/issues."
)
raise NotImplementedError(msg)
32 changes: 32 additions & 0 deletions narwhals/_interchange/series.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any
from typing import NoReturn

from narwhals._interchange.dataframe import map_interchange_dtype_to_narwhals_dtype

if TYPE_CHECKING:
from narwhals import dtypes


class InterchangeSeries:
def __init__(self, df: Any) -> None:
self._native_series = df

def __narwhals_series__(self) -> Any:
return self

@property
def dtype(self) -> dtypes.DType:
return map_interchange_dtype_to_narwhals_dtype(self._native_series.dtype)

def __getattr__(self, attr: str) -> NoReturn:
msg = (
f"Attribute {attr} is not supported for metadata-only dataframes.\n\n"
"Hint: you probably called `nw.from_native` on an object which isn't fully "
"supported by Narwhals, yet implements `__dataframe__`. If you would like to "
"see this kind of object supported in Narwhals, please open a feature request "
"at https://github.com/narwhals-dev/narwhals/issues."
)
raise NotImplementedError(msg)
13 changes: 13 additions & 0 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class BaseFrame(Generic[FrameT]):
_compliant_frame: Any
_is_polars: bool
_backend_version: tuple[int, ...]
_level: Literal["full", "interchange"]

def __len__(self) -> Any:
return self._compliant_frame.__len__()
Expand All @@ -58,6 +59,7 @@ def _from_compliant_dataframe(self, df: Any) -> Self:
df,
is_polars=self._is_polars,
backend_version=self._backend_version,
level=self._level,
)

def _flatten_and_extract(self, *args: Any, **kwargs: Any) -> Any:
Expand Down Expand Up @@ -119,6 +121,7 @@ def lazy(self) -> LazyFrame[Any]:
self._compliant_frame.lazy(),
is_polars=self._is_polars,
backend_version=self._backend_version,
level=self._level,
)

def with_columns(
Expand Down Expand Up @@ -218,9 +221,11 @@ def __init__(
*,
backend_version: tuple[int, ...],
is_polars: bool,
level: Literal["full", "interchange"],
) -> None:
self._is_polars = is_polars
self._backend_version = backend_version
self._level: Literal["full", "interchange"] = level
if hasattr(df, "__narwhals_dataframe__"):
self._compliant_frame: Any = df.__narwhals_dataframe__()
elif is_polars and isinstance(df, get_polars().DataFrame):
Expand Down Expand Up @@ -453,6 +458,7 @@ def get_column(self, name: str) -> Series:
self._compliant_frame.get_column(name),
backend_version=self._backend_version,
is_polars=self._is_polars,
level=self._level,
)

@overload
Expand Down Expand Up @@ -522,6 +528,7 @@ def __getitem__(
self._compliant_frame[item],
backend_version=self._backend_version,
is_polars=self._is_polars,
level=self._level,
)

elif isinstance(item, (Sequence, slice)) or (
Expand Down Expand Up @@ -587,6 +594,7 @@ def to_dict(
value,
backend_version=self._backend_version,
is_polars=self._is_polars,
level=self._level,
)
for key, value in self._compliant_frame.to_dict(
as_series=as_series
Expand Down Expand Up @@ -1700,6 +1708,7 @@ def is_duplicated(self: Self) -> Series:
self._compliant_frame.is_duplicated(),
backend_version=self._backend_version,
is_polars=self._is_polars,
level=self._level,
)

def is_empty(self: Self) -> bool:
Expand Down Expand Up @@ -1786,6 +1795,7 @@ def is_unique(self: Self) -> Series:
self._compliant_frame.is_unique(),
backend_version=self._backend_version,
is_polars=self._is_polars,
level=self._level,
)

def null_count(self: Self) -> Self:
Expand Down Expand Up @@ -1927,9 +1937,11 @@ def __init__(
*,
is_polars: bool,
backend_version: tuple[int, ...],
level: Literal["full", "interchange"],
) -> None:
self._is_polars = is_polars
self._backend_version = backend_version
self._level = level
if hasattr(df, "__narwhals_lazyframe__"):
self._compliant_frame: Any = df.__narwhals_lazyframe__()
elif is_polars and (
Expand Down Expand Up @@ -1998,6 +2010,7 @@ def collect(self) -> DataFrame[Any]:
self._compliant_frame.collect(),
is_polars=self._is_polars,
backend_version=self._backend_version,
level=self._level,
)

# inherited
Expand Down
19 changes: 19 additions & 0 deletions narwhals/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import platform
import sys
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Literal
from typing import TypeVar
Expand All @@ -17,6 +19,9 @@
# The rest of the annotations seem to work fine with this anyway
FrameT = TypeVar("FrameT", bound=Union[DataFrame, LazyFrame]) # type: ignore[type-arg]

if TYPE_CHECKING:
from narwhals.series import Series


def concat(
items: Iterable[FrameT],
Expand Down Expand Up @@ -116,3 +121,17 @@ def show_versions() -> None:
print("\nPython dependencies:") # noqa: T201
for k, stat in deps_info.items():
print(f"{k:>13}: {stat}") # noqa: T201


def get_level(
obj: DataFrame[Any] | LazyFrame[Any] | Series,
) -> Literal["full", "interchange"]:
"""
Level of support Narwhals has for current object.
This can be one of:
- 'full': full Narwhals API support
- 'metadata': only metadata operations are supported (`df.schema`)
"""
return obj._level
9 changes: 8 additions & 1 deletion narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,11 @@ def __init__(
*,
backend_version: tuple[int, ...],
is_polars: bool,
level: Literal["full", "interchange"],
) -> None:
self._is_polars = is_polars
self._backend_version = backend_version
self._level = level
if hasattr(series, "__narwhals_series__"):
self._compliant_series = series.__narwhals_series__()
elif is_polars and (
Expand Down Expand Up @@ -102,7 +104,10 @@ def _extract_native(self, arg: Any) -> Any:

def _from_compliant_series(self, series: Any) -> Self:
return self.__class__(
series, is_polars=self._is_polars, backend_version=self._backend_version
series,
is_polars=self._is_polars,
backend_version=self._backend_version,
level=self._level,
)

def __repr__(self) -> str: # pragma: no cover
Expand Down Expand Up @@ -296,6 +301,7 @@ def to_frame(self) -> DataFrame[Any]:
self._compliant_series.to_frame(),
is_polars=self._is_polars,
backend_version=self._backend_version,
level=self._level,
)

def to_list(self) -> list[Any]:
Expand Down Expand Up @@ -1674,6 +1680,7 @@ def value_counts(
self._compliant_series.value_counts(sort=sort, parallel=parallel),
is_polars=self._is_polars,
backend_version=self._backend_version,
level=self._level,
)

def quantile(
Expand Down
Loading

0 comments on commit 7644991

Please sign in to comment.