diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 8b6bbaae3..c3c9ca4b9 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -857,39 +857,22 @@ def estimated_size(self: Self, unit: SizeUnit = "b") -> int | float: return self._compliant_frame.estimated_size(unit=unit) # type: ignore[no-any-return] @overload - def __getitem__(self: Self, item: tuple[Sequence[int], slice]) -> Self: ... - @overload - def __getitem__(self: Self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ... - @overload - def __getitem__(self: Self, item: tuple[slice, Sequence[int]]) -> Self: ... - @overload - def __getitem__(self: Self, item: tuple[Sequence[int], str]) -> Series[Any]: ... # type: ignore[overload-overlap] - @overload - def __getitem__(self: Self, item: tuple[slice, str]) -> Series[Any]: ... # type: ignore[overload-overlap] - @overload - def __getitem__(self: Self, item: tuple[Sequence[int], Sequence[str]]) -> Self: ... - @overload - def __getitem__(self: Self, item: tuple[slice, Sequence[str]]) -> Self: ... - @overload - def __getitem__(self: Self, item: tuple[Sequence[int], int]) -> Series[Any]: ... # type: ignore[overload-overlap] - @overload - def __getitem__(self: Self, item: tuple[slice, int]) -> Series[Any]: ... # type: ignore[overload-overlap] - - @overload - def __getitem__(self: Self, item: Sequence[int]) -> Self: ... - - @overload - def __getitem__(self: Self, item: str) -> Series[Any]: ... # type: ignore[overload-overlap] - - @overload - def __getitem__(self: Self, item: Sequence[str]) -> Self: ... + def __getitem__( # type: ignore[overload-overlap] + self: Self, key: str | tuple[slice | Sequence[int] | np.ndarray, int | str] + ) -> Series[Any]: ... @overload - def __getitem__(self: Self, item: slice) -> Self: ... - - @overload - def __getitem__(self: Self, item: tuple[slice, slice]) -> Self: ... - + def __getitem__( + self: Self, + key: ( + slice + | Sequence[int] + | Sequence[str] + | tuple[ + slice | Sequence[int] | np.ndarray, slice | Sequence[int] | Sequence[str] + ] + ), + ) -> Self: ... def __getitem__( self: Self, item: ( @@ -897,10 +880,10 @@ def __getitem__( | slice | Sequence[int] | Sequence[str] - | tuple[Sequence[int], str | int] - | tuple[slice, str | int] - | tuple[slice | Sequence[int], Sequence[int] | Sequence[str] | slice] - | tuple[slice, slice] + | tuple[slice | Sequence[int] | np.ndarray, int | str] + | tuple[ + slice | Sequence[int] | np.ndarray, slice | Sequence[int] | Sequence[str] + ] ), ) -> Series[Any] | Self: """Extract column or slice of DataFrame. diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index af8e11922..77124ecb5 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -132,38 +132,21 @@ def _lazyframe(self: Self) -> type[LazyFrame[Any]]: return LazyFrame @overload - def __getitem__(self: Self, item: tuple[Sequence[int], slice]) -> Self: ... + def __getitem__( # type: ignore[overload-overlap] + self: Self, key: str | tuple[slice | Sequence[int] | np.ndarray, int | str] + ) -> Series: ... @overload - def __getitem__(self: Self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ... - @overload - def __getitem__(self: Self, item: tuple[slice, Sequence[int]]) -> Self: ... - @overload - def __getitem__(self: Self, item: tuple[Sequence[int], str]) -> Series: ... # type: ignore[overload-overlap] - @overload - def __getitem__(self: Self, item: tuple[slice, str]) -> Series: ... # type: ignore[overload-overlap] - @overload - def __getitem__(self: Self, item: tuple[Sequence[int], Sequence[str]]) -> Self: ... - @overload - def __getitem__(self: Self, item: tuple[slice, Sequence[str]]) -> Self: ... - @overload - def __getitem__(self: Self, item: tuple[Sequence[int], int]) -> Series: ... # type: ignore[overload-overlap] - @overload - def __getitem__(self: Self, item: tuple[slice, int]) -> Series: ... # type: ignore[overload-overlap] - - @overload - def __getitem__(self: Self, item: Sequence[int]) -> Self: ... - - @overload - def __getitem__(self: Self, item: str) -> Series: ... # type: ignore[overload-overlap] - - @overload - def __getitem__(self: Self, item: Sequence[str]) -> Self: ... - - @overload - def __getitem__(self: Self, item: slice) -> Self: ... - - @overload - def __getitem__(self: Self, item: tuple[slice, slice]) -> Self: ... + def __getitem__( + self: Self, + key: ( + slice + | Sequence[int] + | Sequence[str] + | tuple[ + slice | Sequence[int] | np.ndarray, slice | Sequence[int] | Sequence[str] + ] + ), + ) -> Self: ... def __getitem__(self: Self, item: Any) -> Any: return super().__getitem__(item) diff --git a/tests/frame/getitem_test.py b/tests/frame/getitem_test.py index 9f5a9b52d..379f99b46 100644 --- a/tests/frame/getitem_test.py +++ b/tests/frame/getitem_test.py @@ -1,5 +1,7 @@ from __future__ import annotations +from typing import Any + import numpy as np import pandas as pd import polars as pl @@ -117,73 +119,45 @@ def test_slice_int_rows_str_columns(constructor_eager: ConstructorEager) -> None assert_equal_data(result, expected) -def test_slice_slice_columns(constructor_eager: ConstructorEager) -> None: # noqa: PLR0915 +@pytest.mark.parametrize( + ("row_selector", "col_selector", "expected"), + [ + ([0, 1], slice("b", "c"), {"b": [4, 5], "c": [7, 8]}), + ([0, 1], slice(None, "c"), {"a": [1, 2], "b": [4, 5], "c": [7, 8]}), + ([0, 1], slice("a", "d", 2), {"a": [1, 2], "c": [7, 8]}), + ([0, 1], slice("b", None), {"b": [4, 5], "c": [7, 8], "d": [1, 4]}), + ([0, 1], slice(1, 3), {"b": [4, 5], "c": [7, 8]}), + ([0, 1], slice(None, 3), {"a": [1, 2], "b": [4, 5], "c": [7, 8]}), + ([0, 1], slice(0, 4, 2), {"a": [1, 2], "c": [7, 8]}), + ([0, 1], slice(1, None), {"b": [4, 5], "c": [7, 8], "d": [1, 4]}), + (slice(None), ["b", "d"], {"b": [4, 5, 6], "d": [1, 4, 2]}), + (slice(None), [0, 2], {"a": [1, 2, 3], "c": [7, 8, 9]}), + (slice(None, 2), [0, 2], {"a": [1, 2], "c": [7, 8]}), + (slice(None, 2), ["a", "c"], {"a": [1, 2], "c": [7, 8]}), + (slice(1, None), [0, 2], {"a": [2, 3], "c": [8, 9]}), + (slice(1, None), ["a", "c"], {"a": [2, 3], "c": [8, 9]}), + (["b", "c"], None, {"b": [4, 5, 6], "c": [7, 8, 9]}), + (slice(None, 2), None, {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]}), + (slice(2, None), None, {"a": [3], "b": [6], "c": [9], "d": [2]}), + (slice("a", "b"), None, {"a": [1, 2, 3], "b": [4, 5, 6]}), + ((0, 1), slice(None), {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]}), + ([0, 1], slice(None), {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]}), + ( + [0, 1], + ["a", "b", "c", "d"], + {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]}, + ), + ], +) +def test_slice_slice_columns( + constructor_eager: ConstructorEager, + row_selector: Any, + col_selector: Any, + expected: dict[str, list[Any]], +) -> None: data = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9], "d": [1, 4, 2]} df = nw.from_native(constructor_eager(data), eager_only=True) - result = df[[0, 1], "b":"c"] # type: ignore[misc] - expected = {"b": [4, 5], "c": [7, 8]} - assert_equal_data(result, expected) - result = df[[0, 1], :"c"] # type: ignore[misc] - expected = {"a": [1, 2], "b": [4, 5], "c": [7, 8]} - assert_equal_data(result, expected) - result = df[[0, 1], "a":"d":2] # type: ignore[misc] - expected = {"a": [1, 2], "c": [7, 8]} - assert_equal_data(result, expected) - result = df[[0, 1], "b":] # type: ignore[misc] - expected = {"b": [4, 5], "c": [7, 8], "d": [1, 4]} - assert_equal_data(result, expected) - result = df[[0, 1], 1:3] - expected = {"b": [4, 5], "c": [7, 8]} - assert_equal_data(result, expected) - result = df[[0, 1], :3] - expected = {"a": [1, 2], "b": [4, 5], "c": [7, 8]} - assert_equal_data(result, expected) - result = df[[0, 1], 0:4:2] - expected = {"a": [1, 2], "c": [7, 8]} - assert_equal_data(result, expected) - result = df[[0, 1], 1:] - expected = {"b": [4, 5], "c": [7, 8], "d": [1, 4]} - assert_equal_data(result, expected) - result = df[:, ["b", "d"]] - expected = {"b": [4, 5, 6], "d": [1, 4, 2]} - assert_equal_data(result, expected) - result = df[:, [0, 2]] - expected = {"a": [1, 2, 3], "c": [7, 8, 9]} - assert_equal_data(result, expected) - result = df[:2, [0, 2]] - expected = {"a": [1, 2], "c": [7, 8]} - assert_equal_data(result, expected) - result = df[:2, ["a", "c"]] - expected = {"a": [1, 2], "c": [7, 8]} - assert_equal_data(result, expected) - result = df[1:, [0, 2]] - expected = {"a": [2, 3], "c": [8, 9]} - assert_equal_data(result, expected) - result = df[1:, ["a", "c"]] - expected = {"a": [2, 3], "c": [8, 9]} - assert_equal_data(result, expected) - result = df[["b", "c"]] - expected = {"b": [4, 5, 6], "c": [7, 8, 9]} - assert_equal_data(result, expected) - result = df[:2] - expected = {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]} - assert_equal_data(result, expected) - result = df[2:] - expected = {"a": [3], "b": [6], "c": [9], "d": [2]} - assert_equal_data(result, expected) - # mypy says "Slice index must be an integer", but we do in fact support - # using string slices - result = df["a":"b"] # type: ignore[misc] - expected = {"a": [1, 2, 3], "b": [4, 5, 6]} - assert_equal_data(result, expected) - result = df[(0, 1), :] - expected = {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]} - assert_equal_data(result, expected) - result = df[[0, 1], :] - expected = {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]} - assert_equal_data(result, expected) - result = df[[0, 1], df.columns] - expected = {"a": [1, 2], "b": [4, 5], "c": [7, 8], "d": [1, 4]} + result = df[row_selector] if col_selector is None else df[row_selector, col_selector] assert_equal_data(result, expected)