Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix type __get_item__ #1958

Merged
merged 3 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 18 additions & 35 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -857,50 +857,33 @@ def estimated_size(self: Self, unit: SizeUnit = "b") -> int | float:
return self._compliant_frame.estimated_size(unit=unit) # type: ignore[no-any-return]

@overload
def __getitem__(self: Self, item: tuple[Sequence[int], slice]) -> Self: ...
@overload
def __getitem__(self: Self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ...
@overload
def __getitem__(self: Self, item: tuple[slice, Sequence[int]]) -> Self: ...
@overload
def __getitem__(self: Self, item: tuple[Sequence[int], str]) -> Series[Any]: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self: Self, item: tuple[slice, str]) -> Series[Any]: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self: Self, item: tuple[Sequence[int], Sequence[str]]) -> Self: ...
@overload
def __getitem__(self: Self, item: tuple[slice, Sequence[str]]) -> Self: ...
@overload
def __getitem__(self: Self, item: tuple[Sequence[int], int]) -> Series[Any]: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self: Self, item: tuple[slice, int]) -> Series[Any]: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self: Self, item: Sequence[int]) -> Self: ...

@overload
def __getitem__(self: Self, item: str) -> Series[Any]: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self: Self, item: Sequence[str]) -> Self: ...
def __getitem__( # type: ignore[overload-overlap]
self: Self, key: str | tuple[slice | Sequence[int] | np.ndarray, int | str]
) -> Series[Any]: ...

@overload
def __getitem__(self: Self, item: slice) -> Self: ...

@overload
def __getitem__(self: Self, item: tuple[slice, slice]) -> Self: ...

def __getitem__(
self: Self,
key: (
slice
| Sequence[int]
| Sequence[str]
| tuple[
slice | Sequence[int] | np.ndarray, slice | Sequence[int] | Sequence[str]
]
),
) -> Self: ...
def __getitem__(
self: Self,
item: (
str
| slice
| Sequence[int]
| Sequence[str]
| tuple[Sequence[int], str | int]
| tuple[slice, str | int]
| tuple[slice | Sequence[int], Sequence[int] | Sequence[str] | slice]
| tuple[slice, slice]
| tuple[slice | Sequence[int] | np.ndarray, int | str]
| tuple[
slice | Sequence[int] | np.ndarray, slice | Sequence[int] | Sequence[str]
]
),
) -> Series[Any] | Self:
"""Extract column or slice of DataFrame.
Expand Down
45 changes: 14 additions & 31 deletions narwhals/stable/v1/__init__.py
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if changing this should be allowed, but in theory we are just expanding the allowed types (adding np.ndarray, which is supported anyway).
🤔

Copy link
Member

@dangotbanned dangotbanned Feb 7, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, totally allowed 😄 Adding types should still be backwards-compatible.

Original file line number Diff line number Diff line change
Expand Up @@ -132,38 +132,21 @@ def _lazyframe(self: Self) -> type[LazyFrame[Any]]:
return LazyFrame

@overload
def __getitem__(self: Self, item: tuple[Sequence[int], slice]) -> Self: ...
def __getitem__( # type: ignore[overload-overlap]
self: Self, key: str | tuple[slice | Sequence[int] | np.ndarray, int | str]
) -> Series: ...
@overload
def __getitem__(self: Self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ...
@overload
def __getitem__(self: Self, item: tuple[slice, Sequence[int]]) -> Self: ...
@overload
def __getitem__(self: Self, item: tuple[Sequence[int], str]) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self: Self, item: tuple[slice, str]) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self: Self, item: tuple[Sequence[int], Sequence[str]]) -> Self: ...
@overload
def __getitem__(self: Self, item: tuple[slice, Sequence[str]]) -> Self: ...
@overload
def __getitem__(self: Self, item: tuple[Sequence[int], int]) -> Series: ... # type: ignore[overload-overlap]
@overload
def __getitem__(self: Self, item: tuple[slice, int]) -> Series: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self: Self, item: Sequence[int]) -> Self: ...

@overload
def __getitem__(self: Self, item: str) -> Series: ... # type: ignore[overload-overlap]

@overload
def __getitem__(self: Self, item: Sequence[str]) -> Self: ...

@overload
def __getitem__(self: Self, item: slice) -> Self: ...

@overload
def __getitem__(self: Self, item: tuple[slice, slice]) -> Self: ...
def __getitem__(
self: Self,
key: (
slice
| Sequence[int]
| Sequence[str]
| tuple[
slice | Sequence[int] | np.ndarray, slice | Sequence[int] | Sequence[str]
]
),
) -> Self: ...

def __getitem__(self: Self, item: Any) -> Any:
return super().__getitem__(item)
Expand Down
104 changes: 39 additions & 65 deletions tests/frame/getitem_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from typing import Any

import numpy as np
import pandas as pd
import polars as pl
Expand Down Expand Up @@ -117,73 +119,45 @@ def test_slice_int_rows_str_columns(constructor_eager: ConstructorEager) -> None
assert_equal_data(result, expected)


def test_slice_slice_columns(constructor_eager: ConstructorEager) -> None: # noqa: PLR0915
@pytest.mark.parametrize(
("row_selector", "col_selector", "expected"),
[
([0, 1], slice("b", "c"), {"b": [4, 5], "c": [7, 8]}),
([0, 1], slice(None, "c"), {"a": [1, 2], "b": [4, 5], "c": [7, 8]}),
([0, 1], slice("a", "d", 2), {"a": [1, 2], "c": [7, 8]}),
([0, 1], slice("b", None), {"b": [4, 5], "c": [7, 8], "d": [1, 4]}),
([0, 1], slice(1, 3), {"b": [4, 5], "c": [7, 8]}),
([0, 1], slice(None, 3), {"a": [1, 2], "b": [4, 5], "c": [7, 8]}),
([0, 1], slice(0, 4, 2), {"a": [1, 2], "c": [7, 8]}),
([0, 1], slice(1, None), {"b": [4, 5], "c": [7, 8], "d": [1, 4]}),
(slice(None), ["b", "d"], {"b": [4, 5, 6], "d": [1, 4, 2]}),
(slice(None), [0, 2], {"a": [1, 2, 3], "c": [7, 8, 9]}),
(slice(None, 2), [0, 2], {"a": [1, 2], "c": [7, 8]}),
(slice(None, 2), ["a", "c"], {"a": [1, 2], "c": [7, 8]}),
(slice(1, None), [0, 2], {"a": [2, 3], "c": [8, 9]}),
(slice(1, None), ["a", "c"], {"a": [2, 3], "c": [8, 9]}),
(["b", "c"], None, {"b": [4, 5, 6], "c": [7, 8, 9]}),
(slice(None, 2), None, {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]}),
(slice(2, None), None, {"a": [3], "b": [6], "c": [9], "d": [2]}),
(slice("a", "b"), None, {"a": [1, 2, 3], "b": [4, 5, 6]}),
((0, 1), slice(None), {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]}),
([0, 1], slice(None), {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]}),
(
[0, 1],
["a", "b", "c", "d"],
{"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]},
),
],
)
def test_slice_slice_columns(
constructor_eager: ConstructorEager,
row_selector: Any,
col_selector: Any,
expected: dict[str, list[Any]],
) -> None:
data = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9], "d": [1, 4, 2]}
df = nw.from_native(constructor_eager(data), eager_only=True)
result = df[[0, 1], "b":"c"] # type: ignore[misc]
expected = {"b": [4, 5], "c": [7, 8]}
assert_equal_data(result, expected)
result = df[[0, 1], :"c"] # type: ignore[misc]
expected = {"a": [1, 2], "b": [4, 5], "c": [7, 8]}
assert_equal_data(result, expected)
result = df[[0, 1], "a":"d":2] # type: ignore[misc]
expected = {"a": [1, 2], "c": [7, 8]}
assert_equal_data(result, expected)
result = df[[0, 1], "b":] # type: ignore[misc]
expected = {"b": [4, 5], "c": [7, 8], "d": [1, 4]}
assert_equal_data(result, expected)
result = df[[0, 1], 1:3]
expected = {"b": [4, 5], "c": [7, 8]}
assert_equal_data(result, expected)
result = df[[0, 1], :3]
expected = {"a": [1, 2], "b": [4, 5], "c": [7, 8]}
assert_equal_data(result, expected)
result = df[[0, 1], 0:4:2]
expected = {"a": [1, 2], "c": [7, 8]}
assert_equal_data(result, expected)
result = df[[0, 1], 1:]
expected = {"b": [4, 5], "c": [7, 8], "d": [1, 4]}
assert_equal_data(result, expected)
result = df[:, ["b", "d"]]
expected = {"b": [4, 5, 6], "d": [1, 4, 2]}
assert_equal_data(result, expected)
result = df[:, [0, 2]]
expected = {"a": [1, 2, 3], "c": [7, 8, 9]}
assert_equal_data(result, expected)
result = df[:2, [0, 2]]
expected = {"a": [1, 2], "c": [7, 8]}
assert_equal_data(result, expected)
result = df[:2, ["a", "c"]]
expected = {"a": [1, 2], "c": [7, 8]}
assert_equal_data(result, expected)
result = df[1:, [0, 2]]
expected = {"a": [2, 3], "c": [8, 9]}
assert_equal_data(result, expected)
result = df[1:, ["a", "c"]]
expected = {"a": [2, 3], "c": [8, 9]}
assert_equal_data(result, expected)
result = df[["b", "c"]]
expected = {"b": [4, 5, 6], "c": [7, 8, 9]}
assert_equal_data(result, expected)
result = df[:2]
expected = {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]}
assert_equal_data(result, expected)
result = df[2:]
expected = {"a": [3], "b": [6], "c": [9], "d": [2]}
assert_equal_data(result, expected)
# mypy says "Slice index must be an integer", but we do in fact support
# using string slices
result = df["a":"b"] # type: ignore[misc]
expected = {"a": [1, 2, 3], "b": [4, 5, 6]}
assert_equal_data(result, expected)
result = df[(0, 1), :]
expected = {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]}
assert_equal_data(result, expected)
result = df[[0, 1], :]
expected = {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]}
assert_equal_data(result, expected)
result = df[[0, 1], df.columns]
expected = {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]}
result = df[row_selector] if col_selector is None else df[row_selector, col_selector]
assert_equal_data(result, expected)


Expand Down
Loading