-
Notifications
You must be signed in to change notification settings - Fork 123
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add "interchange"-level support for libraries which implement the interchange protocol (#517)
* wip: support interchange protocol * raise on invalid attributes * raise on invalid attributes * typing * typing * typing * no default * fixup * fixup types * fixup types * fixup types * wip * fixup * coverage * coverage * coverage * coverage * match error * change signature * cov * cov * rename
- Loading branch information
1 parent
6dbfe6c
commit 7644991
Showing
23 changed files
with
545 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
from __future__ import annotations | ||
|
||
import enum | ||
from typing import TYPE_CHECKING | ||
from typing import Any | ||
from typing import NoReturn | ||
|
||
from narwhals import dtypes | ||
|
||
if TYPE_CHECKING: | ||
from narwhals._interchange.series import InterchangeSeries | ||
|
||
|
||
class DtypeKind(enum.IntEnum):
    """Integer codes identifying the kind of a dataframe-interchange dtype.

    The members follow the ``DtypeKind`` enumeration described in the
    dataframe interchange protocol specification:
    https://data-apis.org/dataframe-protocol/latest/API.html
    """

    # Numeric kinds.
    INT = 0
    UINT = 1
    FLOAT = 2

    # Non-numeric kinds.
    BOOL = 20
    STRING = 21  # UTF-8
    DATETIME = 22
    CATEGORICAL = 23
|
||
|
||
def map_interchange_dtype_to_narwhals_dtype(
    interchange_dtype: tuple[DtypeKind, int, Any, Any],
) -> dtypes.DType:
    """Translate an interchange-protocol dtype tuple into a Narwhals dtype.

    Only the first two entries of the tuple are inspected: the kind and,
    for the INT, UINT and FLOAT kinds, the bit width.

    Arguments:
        interchange_dtype: Dtype tuple whose first entry is the kind and
            whose second entry is the bit width.

    Returns:
        A new instance of the matching Narwhals dtype.

    Raises:
        AssertionError: If the kind is not one of the handled kinds, or if
            the bit width is not handled for an INT, UINT or FLOAT kind.
    """
    kind = interchange_dtype[0]

    # Kinds whose Narwhals dtype depends on the bit width, paired with the
    # error raised when the width is not one of the listed ones.
    width_dependent = (
        (
            DtypeKind.INT,
            "Invalid bit width for INT",
            (
                (64, dtypes.Int64),
                (32, dtypes.Int32),
                (16, dtypes.Int16),
                (8, dtypes.Int8),
            ),
        ),
        (
            DtypeKind.UINT,
            "Invalid bit width for UINT",
            (
                (64, dtypes.UInt64),
                (32, dtypes.UInt32),
                (16, dtypes.UInt16),
                (8, dtypes.UInt8),
            ),
        ),
        (
            DtypeKind.FLOAT,
            "Invalid bit width for FLOAT",
            (
                (64, dtypes.Float64),
                (32, dtypes.Float32),
            ),
        ),
    )
    for candidate_kind, width_error, by_width in width_dependent:
        if kind == candidate_kind:
            bit_width = interchange_dtype[1]
            for width, narwhals_dtype in by_width:
                if bit_width == width:
                    return narwhals_dtype()
            raise AssertionError(width_error)

    # Kinds that map to a single Narwhals dtype regardless of bit width.
    if kind == DtypeKind.BOOL:
        return dtypes.Boolean()
    if kind == DtypeKind.STRING:
        return dtypes.String()
    if kind == DtypeKind.DATETIME:
        return dtypes.Datetime()
    if kind == DtypeKind.CATEGORICAL:  # pragma: no cover
        # upstream issue: https://github.com/ibis-project/ibis/issues/9570
        return dtypes.Categorical()

    raise AssertionError(f"Invalid dtype, got: {interchange_dtype}")  # pragma: no cover
|
||
|
||
class InterchangeFrame:
    """Metadata-only wrapper around an object exposing the interchange API.

    Only column lookup (``frame[name]``) and ``schema`` are provided. Any
    other regular attribute access raises ``NotImplementedError`` with a
    hint for the user.
    """

    def __init__(self, df: Any) -> None:
        # `df` is used through `get_column_by_name` and `column_names`.
        self._native_dataframe = df

    def __narwhals_dataframe__(self) -> Any:
        """Return this object itself."""
        return self

    def __getitem__(self, item: str) -> InterchangeSeries:
        """Return the column named `item` wrapped in an `InterchangeSeries`."""
        # Imported lazily: the series module imports from this module, so a
        # top-level import here would be circular.
        from narwhals._interchange.series import InterchangeSeries

        return InterchangeSeries(self._native_dataframe.get_column_by_name(item))

    @property
    def schema(self) -> dict[str, dtypes.DType]:
        """Map each column name to its Narwhals dtype."""
        native = self._native_dataframe
        return {
            column_name: map_interchange_dtype_to_narwhals_dtype(
                native.get_column_by_name(column_name).dtype
            )
            for column_name in native.column_names()
        }

    def __getattr__(self, attr: str) -> NoReturn:
        """Reject access to any attribute that is not defined on this class.

        Raises:
            AttributeError: For dunder names, so that protocol probing via
                `hasattr` / `getattr(..., default)` (as done by `copy` and
                `pickle`, for example) keeps working.
            NotImplementedError: For every other name, with a hint explaining
                that only metadata is available.
        """
        if attr.startswith("__") and attr.endswith("__"):
            # Python signals a missing special method with AttributeError;
            # any other exception would leak out of `hasattr` and friends.
            msg = f"{type(self).__name__!r} object has no attribute {attr!r}"
            raise AttributeError(msg)
        msg = (
            f"Attribute {attr} is not supported for metadata-only dataframes.\n\n"
            "Hint: you probably called `nw.from_native` on an object which isn't fully "
            "supported by Narwhals, yet implements `__dataframe__`. If you would like to "
            "see this kind of object supported in Narwhals, please open a feature request "
            "at https://github.com/narwhals-dev/narwhals/issues."
        )
        raise NotImplementedError(msg)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
from typing import Any | ||
from typing import NoReturn | ||
|
||
from narwhals._interchange.dataframe import map_interchange_dtype_to_narwhals_dtype | ||
|
||
if TYPE_CHECKING: | ||
from narwhals import dtypes | ||
|
||
|
||
class InterchangeSeries:
    """Metadata-only wrapper around a column obtained via the interchange API.

    Only ``dtype`` is provided. Any other regular attribute access raises
    ``NotImplementedError`` with a hint for the user.
    """

    def __init__(self, df: Any) -> None:
        # Despite the parameter name, this is stored as the native series;
        # only its `dtype` attribute is read by this class.
        self._native_series = df

    def __narwhals_series__(self) -> Any:
        """Return this object itself."""
        return self

    @property
    def dtype(self) -> dtypes.DType:
        """Narwhals dtype corresponding to the native column's dtype."""
        return map_interchange_dtype_to_narwhals_dtype(self._native_series.dtype)

    def __getattr__(self, attr: str) -> NoReturn:
        """Reject access to any attribute that is not defined on this class.

        Raises:
            AttributeError: For dunder names, so that protocol probing via
                `hasattr` / `getattr(..., default)` (as done by `copy` and
                `pickle`, for example) keeps working.
            NotImplementedError: For every other name, with a hint explaining
                that only metadata is available.
        """
        if attr.startswith("__") and attr.endswith("__"):
            # Python signals a missing special method with AttributeError;
            # any other exception would leak out of `hasattr` and friends.
            msg = f"{type(self).__name__!r} object has no attribute {attr!r}"
            raise AttributeError(msg)
        msg = (
            f"Attribute {attr} is not supported for metadata-only dataframes.\n\n"
            "Hint: you probably called `nw.from_native` on an object which isn't fully "
            "supported by Narwhals, yet implements `__dataframe__`. If you would like to "
            "see this kind of object supported in Narwhals, please open a feature request "
            "at https://github.com/narwhals-dev/narwhals/issues."
        )
        raise NotImplementedError(msg)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.