Skip to content

Commit

Permalink
fix: fix type of __getitem__ (#1958)
Browse files Browse the repository at this point in the history
  • Loading branch information
EdAbati authored Feb 7, 2025
1 parent f33e82c commit 20d52d7
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 131 deletions.
53 changes: 18 additions & 35 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,50 +857,33 @@ def estimated_size(self: Self, unit: SizeUnit = "b") -> int | float:
return self._compliant_frame.estimated_size(unit=unit) # type: ignore[no-any-return]

@overload
def __getitem__(self: Self, item: tuple[Sequence[int], slice]) -> Self: ...
@overload
def __getitem__(self: Self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ...
@overload
def __getitem__(self: Self, item: tuple[slice, Sequence[int]]) -> Self: ...
@overload
def __getitem__(self: Self, item: tuple[Sequence[int], str]) -> Series[Any]: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self: Self, item: tuple[slice, str]) -> Series[Any]: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self: Self, item: tuple[Sequence[int], Sequence[str]]) -> Self: ...
@overload
def __getitem__(self: Self, item: tuple[slice, Sequence[str]]) -> Self: ...
@overload
def __getitem__(self: Self, item: tuple[Sequence[int], int]) -> Series[Any]: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self: Self, item: tuple[slice, int]) -> Series[Any]: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self: Self, item: Sequence[int]) -> Self: ...

@overload
def __getitem__(self: Self, item: str) -> Series[Any]: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self: Self, item: Sequence[str]) -> Self: ...
def __getitem__( # type: ignore[overload-overlap]
self: Self, key: str | tuple[slice | Sequence[int] | np.ndarray, int | str]
) -> Series[Any]: ...

@overload
def __getitem__(self: Self, item: slice) -> Self: ...

@overload
def __getitem__(self: Self, item: tuple[slice, slice]) -> Self: ...

def __getitem__(
self: Self,
key: (
slice
| Sequence[int]
| Sequence[str]
| tuple[
slice | Sequence[int] | np.ndarray, slice | Sequence[int] | Sequence[str]
]
),
) -> Self: ...
def __getitem__(
self: Self,
item: (
str
| slice
| Sequence[int]
| Sequence[str]
| tuple[Sequence[int], str | int]
| tuple[slice, str | int]
| tuple[slice | Sequence[int], Sequence[int] | Sequence[str] | slice]
| tuple[slice, slice]
| tuple[slice | Sequence[int] | np.ndarray, int | str]
| tuple[
slice | Sequence[int] | np.ndarray, slice | Sequence[int] | Sequence[str]
]
),
) -> Series[Any] | Self:
"""Extract column or slice of DataFrame.
Expand Down
45 changes: 14 additions & 31 deletions narwhals/stable/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,38 +132,21 @@ def _lazyframe(self: Self) -> type[LazyFrame[Any]]:
return LazyFrame

@overload
def __getitem__(self: Self, item: tuple[Sequence[int], slice]) -> Self: ...
def __getitem__( # type: ignore[overload-overlap]
self: Self, key: str | tuple[slice | Sequence[int] | np.ndarray, int | str]
) -> Series: ...
@overload
def __getitem__(self: Self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ...
@overload
def __getitem__(self: Self, item: tuple[slice, Sequence[int]]) -> Self: ...
@overload
def __getitem__(self: Self, item: tuple[Sequence[int], str]) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self: Self, item: tuple[slice, str]) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self: Self, item: tuple[Sequence[int], Sequence[str]]) -> Self: ...
@overload
def __getitem__(self: Self, item: tuple[slice, Sequence[str]]) -> Self: ...
@overload
def __getitem__(self: Self, item: tuple[Sequence[int], int]) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self: Self, item: tuple[slice, int]) -> Series: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self: Self, item: Sequence[int]) -> Self: ...

@overload
def __getitem__(self: Self, item: str) -> Series: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self: Self, item: Sequence[str]) -> Self: ...

@overload
def __getitem__(self: Self, item: slice) -> Self: ...

@overload
def __getitem__(self: Self, item: tuple[slice, slice]) -> Self: ...
def __getitem__(
self: Self,
key: (
slice
| Sequence[int]
| Sequence[str]
| tuple[
slice | Sequence[int] | np.ndarray, slice | Sequence[int] | Sequence[str]
]
),
) -> Self: ...

def __getitem__(self: Self, item: Any) -> Any:
    """Runtime implementation behind the typed ``__getitem__`` overloads.

    The ``@overload`` signatures declared alongside this method exist only
    for static type checkers; this is the single definition that actually
    runs. It is deliberately typed ``Any -> Any`` and performs no work of its
    own: ``item`` is forwarded unchanged to the parent class's
    ``__getitem__`` and that call's result is returned as-is.
    """
    return super().__getitem__(item)
Expand Down
104 changes: 39 additions & 65 deletions tests/frame/getitem_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Any

import numpy as np
import pandas as pd
import polars as pl
Expand Down Expand Up @@ -117,73 +119,45 @@ def test_slice_int_rows_str_columns(constructor_eager: ConstructorEager) -> None
assert_equal_data(result, expected)


def test_slice_slice_columns(constructor_eager: ConstructorEager) -> None: # noqa: PLR0915
@pytest.mark.parametrize(
("row_selector", "col_selector", "expected"),
[
([0, 1], slice("b", "c"), {"b": [4, 5], "c": [7, 8]}),
([0, 1], slice(None, "c"), {"a": [1, 2], "b": [4, 5], "c": [7, 8]}),
([0, 1], slice("a", "d", 2), {"a": [1, 2], "c": [7, 8]}),
([0, 1], slice("b", None), {"b": [4, 5], "c": [7, 8], "d": [1, 4]}),
([0, 1], slice(1, 3), {"b": [4, 5], "c": [7, 8]}),
([0, 1], slice(None, 3), {"a": [1, 2], "b": [4, 5], "c": [7, 8]}),
([0, 1], slice(0, 4, 2), {"a": [1, 2], "c": [7, 8]}),
([0, 1], slice(1, None), {"b": [4, 5], "c": [7, 8], "d": [1, 4]}),
(slice(None), ["b", "d"], {"b": [4, 5, 6], "d": [1, 4, 2]}),
(slice(None), [0, 2], {"a": [1, 2, 3], "c": [7, 8, 9]}),
(slice(None, 2), [0, 2], {"a": [1, 2], "c": [7, 8]}),
(slice(None, 2), ["a", "c"], {"a": [1, 2], "c": [7, 8]}),
(slice(1, None), [0, 2], {"a": [2, 3], "c": [8, 9]}),
(slice(1, None), ["a", "c"], {"a": [2, 3], "c": [8, 9]}),
(["b", "c"], None, {"b": [4, 5, 6], "c": [7, 8, 9]}),
(slice(None, 2), None, {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]}),
(slice(2, None), None, {"a": [3], "b": [6], "c": [9], "d": [2]}),
(slice("a", "b"), None, {"a": [1, 2, 3], "b": [4, 5, 6]}),
((0, 1), slice(None), {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]}),
([0, 1], slice(None), {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]}),
(
[0, 1],
["a", "b", "c", "d"],
{"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]},
),
],
)
def test_slice_slice_columns(
constructor_eager: ConstructorEager,
row_selector: Any,
col_selector: Any,
expected: dict[str, list[Any]],
) -> None:
data = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9], "d": [1, 4, 2]}
df = nw.from_native(constructor_eager(data), eager_only=True)
result = df[[0, 1], "b":"c"] # type: ignore[misc]
expected = {"b": [4, 5], "c": [7, 8]}
assert_equal_data(result, expected)
result = df[[0, 1], :"c"] # type: ignore[misc]
expected = {"a": [1, 2], "b": [4, 5], "c": [7, 8]}
assert_equal_data(result, expected)
result = df[[0, 1], "a":"d":2] # type: ignore[misc]
expected = {"a": [1, 2], "c": [7, 8]}
assert_equal_data(result, expected)
result = df[[0, 1], "b":] # type: ignore[misc]
expected = {"b": [4, 5], "c": [7, 8], "d": [1, 4]}
assert_equal_data(result, expected)
result = df[[0, 1], 1:3]
expected = {"b": [4, 5], "c": [7, 8]}
assert_equal_data(result, expected)
result = df[[0, 1], :3]
expected = {"a": [1, 2], "b": [4, 5], "c": [7, 8]}
assert_equal_data(result, expected)
result = df[[0, 1], 0:4:2]
expected = {"a": [1, 2], "c": [7, 8]}
assert_equal_data(result, expected)
result = df[[0, 1], 1:]
expected = {"b": [4, 5], "c": [7, 8], "d": [1, 4]}
assert_equal_data(result, expected)
result = df[:, ["b", "d"]]
expected = {"b": [4, 5, 6], "d": [1, 4, 2]}
assert_equal_data(result, expected)
result = df[:, [0, 2]]
expected = {"a": [1, 2, 3], "c": [7, 8, 9]}
assert_equal_data(result, expected)
result = df[:2, [0, 2]]
expected = {"a": [1, 2], "c": [7, 8]}
assert_equal_data(result, expected)
result = df[:2, ["a", "c"]]
expected = {"a": [1, 2], "c": [7, 8]}
assert_equal_data(result, expected)
result = df[1:, [0, 2]]
expected = {"a": [2, 3], "c": [8, 9]}
assert_equal_data(result, expected)
result = df[1:, ["a", "c"]]
expected = {"a": [2, 3], "c": [8, 9]}
assert_equal_data(result, expected)
result = df[["b", "c"]]
expected = {"b": [4, 5, 6], "c": [7, 8, 9]}
assert_equal_data(result, expected)
result = df[:2]
expected = {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]}
assert_equal_data(result, expected)
result = df[2:]
expected = {"a": [3], "b": [6], "c": [9], "d": [2]}
assert_equal_data(result, expected)
# mypy says "Slice index must be an integer", but we do in fact support
# using string slices
result = df["a":"b"] # type: ignore[misc]
expected = {"a": [1, 2, 3], "b": [4, 5, 6]}
assert_equal_data(result, expected)
result = df[(0, 1), :]
expected = {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]}
assert_equal_data(result, expected)
result = df[[0, 1], :]
expected = {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]}
assert_equal_data(result, expected)
result = df[[0, 1], df.columns]
expected = {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]}
result = df[row_selector] if col_selector is None else df[row_selector, col_selector]
assert_equal_data(result, expected)


Expand Down

0 comments on commit 20d52d7

Please sign in to comment.