Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for series[other_series] #2013

Merged
merged 9 commits into from
Feb 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,14 +327,16 @@ def __narwhals_series__(self: Self) -> Self:
def __getitem__(self: Self, idx: int) -> Any: ...

@overload
def __getitem__(self: Self, idx: slice | Sequence[int]) -> Self: ...
def __getitem__(self: Self, idx: slice | Sequence[int] | pa.ChunkedArray) -> Self: ...

def __getitem__(self: Self, idx: int | slice | Sequence[int]) -> Any | Self:
def __getitem__(
self: Self, idx: int | slice | Sequence[int] | pa.ChunkedArray
) -> Any | Self:
if isinstance(idx, int):
return maybe_extract_py_scalar(
self._native_series[idx], return_py_scalar=True
)
if isinstance(idx, Sequence):
if isinstance(idx, (Sequence, pa.ChunkedArray)):
return self._from_native_series(self._native_series.take(idx))
return self._from_native_series(self._native_series[idx])

Expand Down
6 changes: 4 additions & 2 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,11 @@ def dtype(self: Self) -> DType:
def __getitem__(self: Self, item: int) -> Any: ...

@overload
def __getitem__(self: Self, item: slice | Sequence[int]) -> Self: ...
def __getitem__(self: Self, item: slice | Sequence[int] | pl.Series) -> Self: ...

def __getitem__(self: Self, item: int | slice | Sequence[int]) -> Any | Self:
def __getitem__(
self: Self, item: int | slice | Sequence[int] | pl.Series
) -> Any | Self:
return self._from_native_object(self._native_series.__getitem__(item))

def cast(self: Self, dtype: DType) -> Self:
Expand Down
13 changes: 8 additions & 5 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from narwhals.series_dt import SeriesDateTimeNamespace
from narwhals.series_list import SeriesListNamespace
from narwhals.series_str import SeriesStringNamespace
from narwhals.translate import to_native
from narwhals.typing import IntoSeriesT
from narwhals.utils import _validate_rolling_arguments
from narwhals.utils import generate_repr
Expand Down Expand Up @@ -118,17 +119,17 @@ def __array__(self: Self, dtype: Any = None, copy: bool | None = None) -> _1DArr
def __getitem__(self: Self, idx: int) -> Any: ...

@overload
def __getitem__(self: Self, idx: slice | Sequence[int]) -> Self: ...
def __getitem__(self: Self, idx: slice | Sequence[int] | Self) -> Self: ...

def __getitem__(self: Self, idx: int | slice | Sequence[int]) -> Any | Self:
def __getitem__(self: Self, idx: int | slice | Sequence[int] | Self) -> Any | Self:
"""Retrieve elements from the object using integer indexing or slicing.

Arguments:
idx: The index, slice, or sequence of indices to retrieve.

- If `idx` is an integer, a single element is returned.
- If `idx` is a slice or a sequence of integers,
a subset of the Series is returned.
- If `idx` is a slice, a sequence of integers, or another Series
(with integer values), a subset of the Series is returned.

Returns:
A single element if `idx` is an integer, else a subset of the Series.
Expand Down Expand Up @@ -156,7 +157,9 @@ def __getitem__(self: Self, idx: int | slice | Sequence[int]) -> Any | Self:
is_numpy_scalar(idx) and idx.dtype.kind in ("i", "u")
):
return self._compliant_series[idx]
return self._from_compliant_series(self._compliant_series[idx])
return self._from_compliant_series(
self._compliant_series[to_native(idx, pass_through=True)]
Copy link
Member Author

@FBruzzesi FBruzzesi Feb 14, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Kind of unrelated, but we have a method called _extract_native which actually (maybe) extracts compliant 🙈

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that's an awfully good point

)

def __native_namespace__(self: Self) -> ModuleType:
return self._compliant_series.__native_namespace__() # type: ignore[no-any-return]
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pyarrow as pa
import pytest

import narwhals.stable.v1 as nw
from tests.utils import ConstructorEager
from tests.utils import assert_equal_data

if TYPE_CHECKING:
from tests.utils import ConstructorEager


def test_slice(constructor_eager: ConstructorEager) -> None:
def test_by_slice(constructor_eager: ConstructorEager) -> None:
data = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9], "d": [1, 4, 2]}
df = nw.from_native(constructor_eager(data), eager_only=True)
result = {"a": df["a"][[0, 1]]}
Expand Down Expand Up @@ -45,3 +50,14 @@ def test_index(constructor_eager: ConstructorEager) -> None:
df = constructor_eager({"a": [0, 1, 2]})
snw = nw.from_native(df, eager_only=True)["a"]
assert snw[snw[0]] == 0


@pytest.mark.filterwarnings(
"ignore:.*_array__ implementation doesn't accept a copy keyword.*:DeprecationWarning:modin"
)
Comment on lines +55 to +57
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this just something that Modin calls (on itself?) internally? not anything we need to concern ourselves with?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it traces back to modin/pandas/indexing.py

def test_getitem_other_series(constructor_eager: ConstructorEager) -> None:
series = nw.from_native(constructor_eager({"a": [1, None, 2, 3]}), eager_only=True)[
"a"
]
other = nw.from_native(constructor_eager({"b": [1, 3]}), eager_only=True)["b"]
assert_equal_data(series[other].to_frame(), {"a": [None, 3]})
Loading