Skip to content

Commit

Permalink
TYP: get_indexer (pandas-dev#40612)
Browse files Browse the repository at this point in the history
* TYP: get_indexer

* update per discussion in pandas-dev#40612

* one more overload

* pre-commit fixup
  • Loading branch information
jbrockmendel authored and yeshsurya committed May 6, 2021
1 parent 6882acf commit 985c5fa
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 6 deletions.
33 changes: 30 additions & 3 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Sequence,
TypeVar,
cast,
overload,
)
import warnings

Expand Down Expand Up @@ -159,6 +160,8 @@
)

if TYPE_CHECKING:
from typing import Literal

from pandas import (
CategoricalIndex,
DataFrame,
Expand Down Expand Up @@ -5212,7 +5215,8 @@ def set_value(self, arr, key, value):
"""

@Appender(_index_shared_docs["get_indexer_non_unique"] % _index_doc_kwargs)
def get_indexer_non_unique(self, target):
def get_indexer_non_unique(self, target) -> tuple[np.ndarray, np.ndarray]:
# both returned ndarrays are np.intp
target = ensure_index(target)

if not self._should_compare(target) and not is_interval_dtype(self.dtype):
Expand All @@ -5236,7 +5240,7 @@ def get_indexer_non_unique(self, target):
tgt_values = target._get_engine_target()

indexer, missing = self._engine.get_indexer_non_unique(tgt_values)
return ensure_platform_int(indexer), missing
return ensure_platform_int(indexer), ensure_platform_int(missing)

@final
def get_indexer_for(self, target, **kwargs) -> np.ndarray:
Expand All @@ -5256,8 +5260,31 @@ def get_indexer_for(self, target, **kwargs) -> np.ndarray:
indexer, _ = self.get_indexer_non_unique(target)
return indexer

@overload
def _get_indexer_non_comparable(
self, target: Index, method, unique: Literal[True] = ...
) -> np.ndarray:
# returned ndarray is np.intp
...

@overload
def _get_indexer_non_comparable(
self, target: Index, method, unique: Literal[False]
) -> tuple[np.ndarray, np.ndarray]:
# both returned ndarrays are np.intp
...

@overload
def _get_indexer_non_comparable(
self, target: Index, method, unique: bool = True
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
# any returned ndarrays are np.intp
...

@final
def _get_indexer_non_comparable(self, target: Index, method, unique: bool = True):
def _get_indexer_non_comparable(
self, target: Index, method, unique: bool = True
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
"""
Called from get_indexer or get_indexer_non_unique when the target
is of a non-comparable dtype.
Expand Down
11 changes: 8 additions & 3 deletions pandas/core/indexes/category.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,18 +491,23 @@ def _get_indexer(
limit: int | None = None,
tolerance=None,
) -> np.ndarray:
# returned ndarray is np.intp

if self.equals(target):
return np.arange(len(self), dtype="intp")

return self._get_indexer_non_unique(target._values)[0]

@Appender(_index_shared_docs["get_indexer_non_unique"] % _index_doc_kwargs)
def get_indexer_non_unique(self, target):
def get_indexer_non_unique(self, target) -> tuple[np.ndarray, np.ndarray]:
# both returned ndarrays are np.intp
target = ibase.ensure_index(target)
return self._get_indexer_non_unique(target._values)

def _get_indexer_non_unique(self, values: ArrayLike):
def _get_indexer_non_unique(
self, values: ArrayLike
) -> tuple[np.ndarray, np.ndarray]:
# both returned ndarrays are np.intp
"""
get_indexer_non_unique but after unrapping the target Index object.
"""
Expand All @@ -521,7 +526,7 @@ def _get_indexer_non_unique(self, values: ArrayLike):
codes = self.categories.get_indexer(values)

indexer, missing = self._engine.get_indexer_non_unique(codes)
return ensure_platform_int(indexer), missing
return ensure_platform_int(indexer), ensure_platform_int(missing)

@doc(Index._convert_list_indexer)
def _convert_list_indexer(self, keyarr):
Expand Down

0 comments on commit 985c5fa

Please sign in to comment.