Skip to content

Commit

Permalink
chore: get_<dependency> cleanup (#1074)
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi authored Sep 27, 2024
1 parent 00ab2c3 commit bbabe44
Show file tree
Hide file tree
Showing 18 changed files with 193 additions and 150 deletions.
11 changes: 8 additions & 3 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from narwhals._arrow.utils import translate_dtype
from narwhals._arrow.utils import validate_dataframe_comparand
from narwhals._expression_parsing import evaluate_into_exprs
from narwhals.dependencies import get_pyarrow
from narwhals.dependencies import is_numpy_array
from narwhals.utils import Implementation
from narwhals.utils import flatten
Expand All @@ -23,6 +22,8 @@
from narwhals.utils import parse_columns_to_drop

if TYPE_CHECKING:
from types import ModuleType

import numpy as np
import pyarrow as pa
from typing_extensions import Self
Expand All @@ -48,8 +49,12 @@ def __narwhals_namespace__(self) -> ArrowNamespace:

return ArrowNamespace(backend_version=self._backend_version)

def __native_namespace__(self) -> Any:
return get_pyarrow()
def __native_namespace__(self: Self) -> ModuleType:
if self._implementation is Implementation.PYARROW:
return self._implementation.to_native_namespace()

msg = f"Expected pyarrow, got: {type(self._implementation)}" # pragma: no cover
raise AssertionError(msg)

def __narwhals_dataframe__(self) -> Self:
return self
Expand Down
3 changes: 2 additions & 1 deletion narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Literal
from typing import cast

from narwhals import dtypes
Expand Down Expand Up @@ -239,7 +240,7 @@ def concat(
self,
items: Iterable[ArrowDataFrame],
*,
how: str = "vertical",
how: Literal["horizontal", "vertical"],
) -> ArrowDataFrame:
dfs: list[Any] = [item._native_frame for item in items]

Expand Down
15 changes: 10 additions & 5 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from narwhals._arrow.utils import narwhals_to_native_dtype
from narwhals._arrow.utils import translate_dtype
from narwhals._arrow.utils import validate_column_comparand
from narwhals.dependencies import get_pandas
from narwhals.dependencies import get_pyarrow
from narwhals.utils import Implementation
from narwhals.utils import generate_unique_token

if TYPE_CHECKING:
from types import ModuleType

import pyarrow as pa
from typing_extensions import Self

Expand Down Expand Up @@ -303,8 +303,12 @@ def n_unique(self) -> int:
unique_values = pc.unique(self._native_series)
return pc.count(unique_values, mode="all") # type: ignore[no-any-return]

def __native_namespace__(self) -> Any: # pragma: no cover
return get_pyarrow()
def __native_namespace__(self: Self) -> ModuleType:
if self._implementation is Implementation.PYARROW:
return self._implementation.to_native_namespace()

msg = f"Expected pyarrow, got: {type(self._implementation)}" # pragma: no cover
raise AssertionError(msg)

@property
def name(self) -> str:
Expand Down Expand Up @@ -573,7 +577,8 @@ def to_frame(self: Self) -> ArrowDataFrame:
return ArrowDataFrame(df, backend_version=self._backend_version)

def to_pandas(self: Self) -> Any:
pd = get_pandas()
import pandas as pd # ignore-banned-import()

return pd.Series(self._native_series, name=self.name)

def is_duplicated(self: Self) -> ArrowSeries:
Expand Down
16 changes: 11 additions & 5 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@
from narwhals._dask.utils import add_row_index
from narwhals._dask.utils import parse_exprs_and_named_exprs
from narwhals._pandas_like.utils import translate_dtype
from narwhals.dependencies import get_dask_dataframe
from narwhals.dependencies import get_pandas
from narwhals.utils import Implementation
from narwhals.utils import flatten
from narwhals.utils import generate_unique_token
from narwhals.utils import parse_columns_to_drop
from narwhals.utils import parse_version

if TYPE_CHECKING:
from types import ModuleType

import dask.dataframe as dd
from typing_extensions import Self

Expand All @@ -36,8 +36,12 @@ def __init__(
self._backend_version = backend_version
self._implementation = Implementation.DASK

def __native_namespace__(self) -> Any: # pragma: no cover
return get_dask_dataframe()
def __native_namespace__(self: Self) -> ModuleType:
if self._implementation is Implementation.DASK:
return self._implementation.to_native_namespace()

msg = f"Expected dask, got: {type(self._implementation)}" # pragma: no cover
raise AssertionError(msg)

def __narwhals_namespace__(self) -> DaskNamespace:
from narwhals._dask.namespace import DaskNamespace
Expand All @@ -57,13 +61,15 @@ def with_columns(self, *exprs: DaskExpr, **named_exprs: DaskExpr) -> Self:
return self._from_native_frame(df)

def collect(self) -> Any:
import pandas as pd # ignore-banned-import()

from narwhals._pandas_like.dataframe import PandasLikeDataFrame

result = self._native_frame.compute()
return PandasLikeDataFrame(
result,
implementation=Implementation.PANDAS,
backend_version=parse_version(get_pandas().__version__),
backend_version=parse_version(pd.__version__),
)

@property
Expand Down
5 changes: 3 additions & 2 deletions narwhals/_dask/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from narwhals._dask.utils import add_row_index
from narwhals._dask.utils import maybe_evaluate
from narwhals._dask.utils import reverse_translate_dtype
from narwhals.dependencies import get_dask
from narwhals.utils import generate_unique_token

if TYPE_CHECKING:
Expand Down Expand Up @@ -803,8 +802,10 @@ def slice(self, offset: int, length: int | None = None) -> DaskExpr:
)

def to_datetime(self, format: str | None = None) -> DaskExpr: # noqa: A002
import dask.dataframe as dd # ignore-banned-import()

return self._expr._from_call(
lambda _input, fmt: get_dask().dataframe.to_datetime(_input, format=fmt),
lambda _input, fmt: dd.to_datetime(_input, format=fmt),
"to_datetime",
format,
returns_scalar=False,
Expand Down
3 changes: 2 additions & 1 deletion narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any
from typing import Callable
from typing import Iterable
from typing import Literal
from typing import NoReturn
from typing import cast

Expand Down Expand Up @@ -208,7 +209,7 @@ def concat(
self,
items: Iterable[DaskLazyFrame],
*,
how: str = "vertical",
how: Literal["horizontal", "vertical"],
) -> DaskLazyFrame:
import dask.dataframe as dd # ignore-banned-import

Expand Down
20 changes: 10 additions & 10 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,6 @@
from narwhals._pandas_like.utils import horizontal_concat
from narwhals._pandas_like.utils import translate_dtype
from narwhals._pandas_like.utils import validate_dataframe_comparand
from narwhals.dependencies import get_cudf
from narwhals.dependencies import get_modin
from narwhals.dependencies import get_pandas
from narwhals.dependencies import is_numpy_array
from narwhals.utils import Implementation
from narwhals.utils import flatten
Expand All @@ -27,6 +24,8 @@
from narwhals.utils import parse_columns_to_drop

if TYPE_CHECKING:
from types import ModuleType

import numpy as np
import pandas as pd
from typing_extensions import Self
Expand Down Expand Up @@ -63,13 +62,14 @@ def __narwhals_namespace__(self) -> PandasLikeNamespace:

return PandasLikeNamespace(self._implementation, self._backend_version)

def __native_namespace__(self) -> Any:
if self._implementation is Implementation.PANDAS:
return get_pandas()
if self._implementation is Implementation.MODIN: # pragma: no cover
return get_modin()
if self._implementation is Implementation.CUDF: # pragma: no cover
return get_cudf()
def __native_namespace__(self: Self) -> ModuleType:
if self._implementation in {
Implementation.PANDAS,
Implementation.MODIN,
Implementation.CUDF,
}:
return self._implementation.to_native_namespace()

msg = f"Expected pandas/modin/cudf, got: {type(self._implementation)}" # pragma: no cover
raise AssertionError(msg)

Expand Down
3 changes: 2 additions & 1 deletion narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any
from typing import Callable
from typing import Iterable
from typing import Literal
from typing import cast

from narwhals import dtypes
Expand Down Expand Up @@ -273,7 +274,7 @@ def concat(
self,
items: Iterable[PandasLikeDataFrame],
*,
how: str = "vertical",
how: Literal["horizontal", "vertical"],
) -> PandasLikeDataFrame:
dfs: list[Any] = [item._native_frame for item in items]
if how == "horizontal":
Expand Down
20 changes: 10 additions & 10 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
from narwhals._pandas_like.utils import to_datetime
from narwhals._pandas_like.utils import translate_dtype
from narwhals._pandas_like.utils import validate_column_comparand
from narwhals.dependencies import get_cudf
from narwhals.dependencies import get_modin
from narwhals.dependencies import get_pandas
from narwhals.utils import Implementation

if TYPE_CHECKING:
from types import ModuleType

from typing_extensions import Self

from narwhals._pandas_like.dataframe import PandasLikeDataFrame
Expand Down Expand Up @@ -97,13 +96,14 @@ def __init__(
else:
self._use_copy_false = False

def __native_namespace__(self) -> Any:
if self._implementation is Implementation.PANDAS:
return get_pandas()
if self._implementation is Implementation.MODIN: # pragma: no cover
return get_modin()
if self._implementation is Implementation.CUDF: # pragma: no cover
return get_cudf()
def __native_namespace__(self: Self) -> ModuleType:
if self._implementation in {
Implementation.PANDAS,
Implementation.MODIN,
Implementation.CUDF,
}:
return self._implementation.to_native_namespace()

msg = f"Expected pandas/modin/cudf, got: {type(self._implementation)}" # pragma: no cover
raise AssertionError(msg)

Expand Down
Loading

0 comments on commit bbabe44

Please sign in to comment.