From 985c5fac7c5ce94c694b6ab69e50b8946b280690 Mon Sep 17 00:00:00 2001 From: jbrockmendel Date: Mon, 19 Apr 2021 11:54:32 -0700 Subject: [PATCH] TYP: get_indexer (#40612) * TYP: get_indexer * update per discussion in #40612 * one more overload * pre-commit fixup --- pandas/core/indexes/base.py | 33 ++++++++++++++++++++++++++++++--- pandas/core/indexes/category.py | 11 ++++++++--- 2 files changed, 38 insertions(+), 6 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 59550e675b846..58f5ca3de5dce 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -13,6 +13,7 @@ Sequence, TypeVar, cast, + overload, ) import warnings @@ -159,6 +160,8 @@ ) if TYPE_CHECKING: + from typing import Literal + from pandas import ( CategoricalIndex, DataFrame, @@ -5212,7 +5215,8 @@ def set_value(self, arr, key, value): """ @Appender(_index_shared_docs["get_indexer_non_unique"] % _index_doc_kwargs) - def get_indexer_non_unique(self, target): + def get_indexer_non_unique(self, target) -> tuple[np.ndarray, np.ndarray]: + # both returned ndarrays are np.intp target = ensure_index(target) if not self._should_compare(target) and not is_interval_dtype(self.dtype): @@ -5236,7 +5240,7 @@ def get_indexer_non_unique(self, target): tgt_values = target._get_engine_target() indexer, missing = self._engine.get_indexer_non_unique(tgt_values) - return ensure_platform_int(indexer), missing + return ensure_platform_int(indexer), ensure_platform_int(missing) @final def get_indexer_for(self, target, **kwargs) -> np.ndarray: @@ -5256,8 +5260,31 @@ def get_indexer_for(self, target, **kwargs) -> np.ndarray: indexer, _ = self.get_indexer_non_unique(target) return indexer + @overload + def _get_indexer_non_comparable( + self, target: Index, method, unique: Literal[True] = ... + ) -> np.ndarray: + # returned ndarray is np.intp + ... + + @overload + def _get_indexer_non_comparable( + self, target: Index, method, unique: Literal[False] + ) -> tuple[np.ndarray, np.ndarray]: + # both returned ndarrays are np.intp + ... + + @overload + def _get_indexer_non_comparable( + self, target: Index, method, unique: bool = True + ) -> np.ndarray | tuple[np.ndarray, np.ndarray]: + # any returned ndarrays are np.intp + ... + @final - def _get_indexer_non_comparable(self, target: Index, method, unique: bool = True): + def _get_indexer_non_comparable( + self, target: Index, method, unique: bool = True + ) -> np.ndarray | tuple[np.ndarray, np.ndarray]: """ Called from get_indexer or get_indexer_non_unique when the target is of a non-comparable dtype. diff --git a/pandas/core/indexes/category.py b/pandas/core/indexes/category.py index 8d15b460a79df..0624a1a64c9f8 100644 --- a/pandas/core/indexes/category.py +++ b/pandas/core/indexes/category.py @@ -491,6 +491,7 @@ def _get_indexer( limit: int | None = None, tolerance=None, ) -> np.ndarray: + # returned ndarray is np.intp if self.equals(target): return np.arange(len(self), dtype="intp") @@ -498,11 +499,15 @@ def _get_indexer( return self._get_indexer_non_unique(target._values)[0] @Appender(_index_shared_docs["get_indexer_non_unique"] % _index_doc_kwargs) - def get_indexer_non_unique(self, target): + def get_indexer_non_unique(self, target) -> tuple[np.ndarray, np.ndarray]: + # both returned ndarrays are np.intp target = ibase.ensure_index(target) return self._get_indexer_non_unique(target._values) - def _get_indexer_non_unique(self, values: ArrayLike): + def _get_indexer_non_unique( + self, values: ArrayLike + ) -> tuple[np.ndarray, np.ndarray]: + # both returned ndarrays are np.intp """ get_indexer_non_unique but after unrapping the target Index object. """ @@ -521,7 +526,7 @@ def _get_indexer_non_unique(self, values: ArrayLike): codes = self.categories.get_indexer(values) indexer, missing = self._engine.get_indexer_non_unique(codes) - return ensure_platform_int(indexer), missing + return ensure_platform_int(indexer), ensure_platform_int(missing) @doc(Index._convert_list_indexer) def _convert_list_indexer(self, keyarr):