Skip to content

Commit

Permalink
TYP: get_indexer (pandas-dev#40612)
Browse files Browse the repository at this point in the history
* TYP: get_indexer

* update per discussion in pandas-dev#40612

* one more overload

* pre-commit fixup
  • Loading branch information
jbrockmendel authored and yeshsurya committed Apr 21, 2021
1 parent 331fef4 commit 8fa4f11
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 6 deletions.
33 changes: 30 additions & 3 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Sequence,
TypeVar,
cast,
overload,
)
import warnings

Expand Down Expand Up @@ -159,6 +160,8 @@
)

if TYPE_CHECKING:
from typing import Literal

from pandas import (
CategoricalIndex,
DataFrame,
Expand Down Expand Up @@ -5212,7 +5215,8 @@ def set_value(self, arr, key, value):
"""

@Appender(_index_shared_docs["get_indexer_non_unique"] % _index_doc_kwargs)
def get_indexer_non_unique(self, target):
def get_indexer_non_unique(self, target) -> tuple[np.ndarray, np.ndarray]:
# both returned ndarrays are np.intp
target = ensure_index(target)

if not self._should_compare(target) and not is_interval_dtype(self.dtype):
Expand All @@ -5236,7 +5240,7 @@ def get_indexer_non_unique(self, target):
tgt_values = target._get_engine_target()

indexer, missing = self._engine.get_indexer_non_unique(tgt_values)
return ensure_platform_int(indexer), missing
return ensure_platform_int(indexer), ensure_platform_int(missing)

@final
def get_indexer_for(self, target, **kwargs) -> np.ndarray:
Expand All @@ -5256,8 +5260,31 @@ def get_indexer_for(self, target, **kwargs) -> np.ndarray:
indexer, _ = self.get_indexer_non_unique(target)
return indexer

@overload
def _get_indexer_non_comparable(
self, target: Index, method, unique: Literal[True] = ...
) -> np.ndarray:
# returned ndarray is np.intp
...

@overload
def _get_indexer_non_comparable(
self, target: Index, method, unique: Literal[False]
) -> tuple[np.ndarray, np.ndarray]:
# both returned ndarrays are np.intp
...

@overload
def _get_indexer_non_comparable(
self, target: Index, method, unique: bool = True
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
# any returned ndarrays are np.intp
...

@final
def _get_indexer_non_comparable(self, target: Index, method, unique: bool = True):
def _get_indexer_non_comparable(
self, target: Index, method, unique: bool = True
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
"""
Called from get_indexer or get_indexer_non_unique when the target
is of a non-comparable dtype.
Expand Down
11 changes: 8 additions & 3 deletions pandas/core/indexes/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,18 +482,23 @@ def _get_indexer(
limit: int | None = None,
tolerance=None,
) -> np.ndarray:
# returned ndarray is np.intp

if self.equals(target):
return np.arange(len(self), dtype="intp")

return self._get_indexer_non_unique(target._values)[0]

@Appender(_index_shared_docs["get_indexer_non_unique"] % _index_doc_kwargs)
def get_indexer_non_unique(self, target):
def get_indexer_non_unique(self, target) -> tuple[np.ndarray, np.ndarray]:
# both returned ndarrays are np.intp
target = ibase.ensure_index(target)
return self._get_indexer_non_unique(target._values)

def _get_indexer_non_unique(self, values: ArrayLike):
def _get_indexer_non_unique(
self, values: ArrayLike
) -> tuple[np.ndarray, np.ndarray]:
# both returned ndarrays are np.intp
"""
get_indexer_non_unique but after unwrapping the target Index object.
"""
Expand All @@ -512,7 +517,7 @@ def _get_indexer_non_unique(self, values: ArrayLike):
codes = self.categories.get_indexer(values)

indexer, missing = self._engine.get_indexer_non_unique(codes)
return ensure_platform_int(indexer), missing
return ensure_platform_int(indexer), ensure_platform_int(missing)

@doc(Index._convert_list_indexer)
def _convert_list_indexer(self, keyarr):
Expand Down
3 changes: 3 additions & 0 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,7 @@ def _get_indexer(
limit: int | None = None,
tolerance: Any | None = None,
) -> np.ndarray:
# returned ndarray is np.intp

if isinstance(target, IntervalIndex):
# equal indexes -> 1:1 positional match
Expand Down Expand Up @@ -753,6 +754,7 @@ def _get_indexer(

@Appender(_index_shared_docs["get_indexer_non_unique"] % _index_doc_kwargs)
def get_indexer_non_unique(self, target: Index) -> tuple[np.ndarray, np.ndarray]:
# both returned ndarrays are np.intp
target = ensure_index(target)

if isinstance(target, IntervalIndex) and not self._should_compare(target):
Expand All @@ -772,6 +774,7 @@ def get_indexer_non_unique(self, target: Index) -> tuple[np.ndarray, np.ndarray]
return ensure_platform_int(indexer), ensure_platform_int(missing)

def _get_indexer_pointwise(self, target: Index) -> tuple[np.ndarray, np.ndarray]:
# both returned ndarrays are np.intp
"""
pointwise implementation for get_indexer and get_indexer_non_unique.
"""
Expand Down

0 comments on commit 8fa4f11

Please sign in to comment.