From 6af274c6fa72c40fac3b91cb4270cc39df7a6363 Mon Sep 17 00:00:00 2001
From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com>
Date: Sat, 13 Jul 2024 21:38:01 +0100
Subject: [PATCH] wip: support interchange protocol

---
Review notes (this section is ignored when the patch is applied):
the added lines were revised by hand, so the blob ids on the `index` lines
still refer to the pre-review content, and the indentation of context lines
was reconstructed from a whitespace-collapsed copy of the patch.

 narwhals/_interchange/__init__.py  |   0
 narwhals/_interchange/dataframe.py | 102 +++++++++++++++++++++++++++++
 narwhals/translate.py              |  15 +++++
 3 files changed, 117 insertions(+)
 create mode 100644 narwhals/_interchange/__init__.py
 create mode 100644 narwhals/_interchange/dataframe.py

diff --git a/narwhals/_interchange/__init__.py b/narwhals/_interchange/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/narwhals/_interchange/dataframe.py b/narwhals/_interchange/dataframe.py
new file mode 100644
index 0000000000..1459dfd40b
--- /dev/null
+++ b/narwhals/_interchange/dataframe.py
@@ -0,0 +1,102 @@
+from __future__ import annotations
+
+import enum
+from typing import Any
+
+from narwhals import dtypes
+
+
+class DtypeKind(enum.IntEnum):
+    """Data type kinds from the dataframe interchange protocol specification."""
+
+    # https://data-apis.org/dataframe-protocol/latest/API.html
+    INT = 0
+    UINT = 1
+    FLOAT = 2
+    BOOL = 20
+    STRING = 21  # UTF-8
+    DATETIME = 22
+    CATEGORICAL = 23
+
+
+# Kinds whose narwhals dtype depends on the bit width.
+_SIZED_DTYPES: dict[DtypeKind, dict[int, type[dtypes.DType]]] = {
+    DtypeKind.INT: {
+        64: dtypes.Int64,
+        32: dtypes.Int32,
+        16: dtypes.Int16,
+        8: dtypes.Int8,
+    },
+    DtypeKind.UINT: {
+        64: dtypes.UInt64,
+        32: dtypes.UInt32,
+        16: dtypes.UInt16,
+        8: dtypes.UInt8,
+    },
+    DtypeKind.FLOAT: {
+        64: dtypes.Float64,
+        32: dtypes.Float32,
+    },
+}
+
+# Kinds which map to a single narwhals dtype, whatever the bit width.
+_UNSIZED_DTYPES: dict[DtypeKind, type[dtypes.DType]] = {
+    DtypeKind.BOOL: dtypes.Boolean,
+    DtypeKind.STRING: dtypes.String,
+    DtypeKind.DATETIME: dtypes.Datetime,
+    DtypeKind.CATEGORICAL: dtypes.Categorical,
+}
+
+
+def map_interchange_dtype_to_narwhals_dtype(
+    interchange_dtype: tuple[DtypeKind, int, Any, Any],
+) -> dtypes.DType:
+    """Translate an interchange-protocol dtype into a narwhals dtype.
+
+    Arguments:
+        interchange_dtype: Dtype tuple of a column. Only its first two
+            elements (the kind and the bit width) are inspected.
+
+    Returns:
+        An instance of the matching narwhals dtype.
+
+    Raises:
+        AssertionError: If the kind is not recognised, or if the bit width
+            is not supported for that kind.
+    """
+    kind = interchange_dtype[0]
+    sized = _SIZED_DTYPES.get(kind)
+    if sized is not None:
+        dtype = sized.get(interchange_dtype[1])
+        if dtype is None:
+            msg = f"Invalid bit width for {DtypeKind(kind).name}"
+            raise AssertionError(msg)
+        return dtype()
+    unsized = _UNSIZED_DTYPES.get(kind)
+    if unsized is not None:
+        return unsized()
+    msg = f"Invalid dtype, got: {interchange_dtype}"
+    raise AssertionError(msg)
+
+
+class InterchangeFrame:
+    """Thin wrapper around a dataframe interchange object.
+
+    Only schema inspection is implemented.
+    """
+
+    def __init__(self, df: Any) -> None:
+        self._df = df
+
+    def __narwhals_dataframe__(self) -> Any:
+        return self
+
+    @property
+    def schema(self) -> dict[str, dtypes.DType]:
+        """Narwhals dtype of each column, keyed by column name."""
+        return {
+            column_name: map_interchange_dtype_to_narwhals_dtype(
+                self._df.get_column_by_name(column_name).dtype
+            )
+            for column_name in self._df.column_names()
+        }
diff --git a/narwhals/translate.py b/narwhals/translate.py
index 0c05fa664c..92c93cfd7a 100644
--- a/narwhals/translate.py
+++ b/narwhals/translate.py
@@ -249,6 +249,7 @@ def from_native(  # noqa: PLR0915
     """
     from narwhals._arrow.dataframe import ArrowDataFrame
     from narwhals._arrow.series import ArrowSeries
+    from narwhals._interchange.dataframe import InterchangeFrame
     from narwhals._pandas_like.dataframe import PandasLikeDataFrame
     from narwhals._pandas_like.series import PandasLikeSeries
     from narwhals._pandas_like.utils import Implementation
@@ -332,6 +333,20 @@ def from_native(  # noqa: PLR0915
             is_polars=False,
             backend_version=parse_version(pa.__version__),
         )
+    elif hasattr(native_object, "__dataframe__") and not hasattr(
+        native_object, "__narwhals_dataframe__"
+    ):
+        # Objects which implement the full narwhals protocol are handled by the
+        # next branch; the interchange protocol is only a fallback.
+        if series_only:
+            msg = "Cannot use `series_only` with object which only implements __dataframe__"
+            raise TypeError(msg)
+        # placeholder (0,) version here, as we wouldn't use it in this case anyway.
+        return DataFrame(
+            InterchangeFrame(native_object.__dataframe__()),
+            is_polars=False,
+            backend_version=(0,),
+        )
     elif hasattr(native_object, "__narwhals_dataframe__"):
         if series_only:
             msg = "Cannot only use `series_only` with dataframe"