Skip to content

Commit

Permalink
wip: support interchange protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Jul 13, 2024
1 parent 85a2b4c commit 6af274c
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 0 deletions.
Empty file.
75 changes: 75 additions & 0 deletions narwhals/_interchange/dataframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from __future__ import annotations

import enum
from typing import Any

from narwhals import dtypes


class DtypeKind(enum.IntEnum):
# https://data-apis.org/dataframe-protocol/latest/API.html
INT = 0
UINT = 1
FLOAT = 2
BOOL = 20
STRING = 21 # UTF-8
DATETIME = 22
CATEGORICAL = 23


def map_interchange_dtype_to_narwhals_dtype(
interchange_dtype: tuple[DtypeKind, int, Any, Any],
) -> dtypes.DType:
if interchange_dtype[0] == DtypeKind.INT:
if interchange_dtype[1] == 64:
return dtypes.Int64()
if interchange_dtype[1] == 32:
return dtypes.Int32()
if interchange_dtype[1] == 16:
return dtypes.Int16()
if interchange_dtype[1] == 8:
return dtypes.Int8()
raise AssertionError("Invalid bit width for INT")
if interchange_dtype[0] == DtypeKind.UINT:
if interchange_dtype[1] == 64:
return dtypes.UInt64()
if interchange_dtype[1] == 32:
return dtypes.UInt32()
if interchange_dtype[1] == 16:
return dtypes.UInt16()
if interchange_dtype[1] == 8:
return dtypes.UInt8()
raise AssertionError("Invalid bit width for UINT")
if interchange_dtype[0] == DtypeKind.FLOAT:
if interchange_dtype[1] == 64:
return dtypes.Float64()
if interchange_dtype[1] == 32:
return dtypes.Float32()
raise AssertionError("Invalid bit width for FLOAT")
if interchange_dtype[0] == DtypeKind.BOOL:
return dtypes.Boolean()
if interchange_dtype[0] == DtypeKind.STRING:
return dtypes.String()
if interchange_dtype[0] == DtypeKind.DATETIME:
return dtypes.Datetime()
if interchange_dtype[0] == DtypeKind.CATEGORICAL:
return dtypes.Categorical()
msg = f"Invalid dtype, got: {interchange_dtype}"
raise AssertionError(msg)


class InterchangeFrame:
def __init__(self, df: Any) -> None:
self._df = df

def __narwhals_dataframe__(self) -> Any:
return self

@property
def schema(self) -> dict[str, dtypes.DType]:
return {
column_name: map_interchange_dtype_to_narwhals_dtype(
self._df.get_column_by_name(column_name).dtype
)
for column_name in self._df.column_names()
}
11 changes: 11 additions & 0 deletions narwhals/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def from_native( # noqa: PLR0915
"""
from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.series import ArrowSeries
from narwhals._interchange.dataframe import InterchangeFrame
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
from narwhals._pandas_like.series import PandasLikeSeries
from narwhals._pandas_like.utils import Implementation
Expand Down Expand Up @@ -332,6 +333,16 @@ def from_native( # noqa: PLR0915
is_polars=False,
backend_version=parse_version(pa.__version__),
)
elif hasattr(native_object, "__dataframe__"):
if series_only:
msg = "Cannot only use `series_only` with object which only implements __dataframe__"
raise TypeError(msg)
# placeholder (0,) version here, as we wouldn't use it in this case anyway.
return DataFrame(
InterchangeFrame(native_object.__dataframe__()),
is_polars=False,
backend_version=(0,),
)
elif hasattr(native_object, "__narwhals_dataframe__"):
if series_only:
msg = "Cannot only use `series_only` with dataframe"
Expand Down

0 comments on commit 6af274c

Please sign in to comment.