Skip to content

Commit

Permalink
TYP: Fix overlapping overloads issue in 2->1 ufuncs
Browse files Browse the repository at this point in the history
  • Loading branch information
jorenham authored and charris committed Jan 17, 2025
1 parent 32b58cd commit f782790
Showing 1 changed file with 111 additions and 49 deletions.
160 changes: 111 additions & 49 deletions numpy/_typing/_ufunc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -4,32 +4,32 @@ The signatures of the ufuncs are too varied to reasonably type
with a single class. So instead, `ufunc` has been expanded into
four private subclasses, one for each combination of
`~ufunc.nin` and `~ufunc.nout`.
"""

from typing import (
Any,
Generic,
Literal,
NoReturn,
TypedDict,
overload,
Protocol,
SupportsIndex,
TypeAlias,
TypedDict,
TypeVar,
Literal,
SupportsIndex,
Protocol,
overload,
type_check_only,
)

from typing_extensions import LiteralString, Unpack

import numpy as np
from numpy import ufunc, _CastingKind, _OrderKACF
from numpy import _CastingKind, _OrderKACF, ufunc
from numpy.typing import NDArray

from ._shape import _ShapeLike
from ._scalars import _ScalarLike_co
from ._array_like import ArrayLike, _ArrayLikeBool_co, _ArrayLikeInt_co
from ._dtype_like import DTypeLike
from ._scalars import _ScalarLike_co
from ._shape import _ShapeLike

_T = TypeVar("_T")
_2Tuple: TypeAlias = tuple[_T, _T]
Expand Down Expand Up @@ -61,6 +61,13 @@ class _SupportsArrayUFunc(Protocol):
**kwargs: Any,
) -> Any: ...

@type_check_only
class _UFunc3Kwargs(TypedDict, total=False):
where: _ArrayLikeBool_co | None
casting: _CastingKind
order: _OrderKACF
subok: bool
signature: _3Tuple[str | None] | str | None

# NOTE: `reduce`, `accumulate`, `reduceat` and `outer` raise a ValueError for
# ufuncs that don't accept two input arguments and return one output argument.
Expand All @@ -72,6 +79,8 @@ class _SupportsArrayUFunc(Protocol):
# NOTE: If 2 output types are returned then `out` must be a
# 2-tuple of arrays. Otherwise `None` or a plain array are also acceptable

# pyright: reportIncompatibleMethodOverride=false

@type_check_only
class _UFunc_Nin1_Nout1(ufunc, Generic[_NameType, _NTypes, _IDType]): # type: ignore[misc]
@property
Expand Down Expand Up @@ -162,34 +171,61 @@ class _UFunc_Nin2_Nout1(ufunc, Generic[_NameType, _NTypes, _IDType]): # type: i
@property
def signature(self) -> None: ...

@overload
@overload # (scalar, scalar) -> scalar
def __call__(
self,
__x1: _ScalarLike_co,
__x2: _ScalarLike_co,
out: None = ...,
x1: _ScalarLike_co,
x2: _ScalarLike_co,
/,
out: None = None,
*,
where: None | _ArrayLikeBool_co = ...,
casting: _CastingKind = ...,
order: _OrderKACF = ...,
dtype: DTypeLike = ...,
subok: bool = ...,
signature: str | _3Tuple[None | str] = ...,
dtype: DTypeLike | None = None,
**kwds: Unpack[_UFunc3Kwargs],
) -> Any: ...
@overload
@overload # (array-like, array) -> array
def __call__(
self,
__x1: ArrayLike,
__x2: ArrayLike,
out: None | NDArray[Any] | tuple[NDArray[Any]] = ...,
x1: ArrayLike,
x2: NDArray[np.generic],
/,
out: NDArray[np.generic] | tuple[NDArray[np.generic]] | None = None,
*,
where: None | _ArrayLikeBool_co = ...,
casting: _CastingKind = ...,
order: _OrderKACF = ...,
dtype: DTypeLike = ...,
subok: bool = ...,
signature: str | _3Tuple[None | str] = ...,
dtype: DTypeLike | None = None,
**kwds: Unpack[_UFunc3Kwargs],
) -> NDArray[Any]: ...
@overload # (array, array-like) -> array
def __call__(
self,
x1: NDArray[np.generic],
x2: ArrayLike,
/,
out: NDArray[np.generic] | tuple[NDArray[np.generic]] | None = None,
*,
dtype: DTypeLike | None = None,
**kwds: Unpack[_UFunc3Kwargs],
) -> NDArray[Any]: ...
@overload # (array-like, array-like, out=array) -> array
def __call__(
self,
x1: ArrayLike,
x2: ArrayLike,
/,
out: NDArray[np.generic] | tuple[NDArray[np.generic]],
*,
dtype: DTypeLike | None = None,
**kwds: Unpack[_UFunc3Kwargs],
) -> NDArray[Any]: ...
@overload # (array-like, array-like) -> array | scalar
def __call__(
self,
x1: ArrayLike,
x2: ArrayLike,
/,
out: NDArray[np.generic] | tuple[NDArray[np.generic]] | None = None,
*,
dtype: DTypeLike | None = None,
**kwds: Unpack[_UFunc3Kwargs],
) -> NDArray[Any] | Any: ...

def at(
self,
Expand Down Expand Up @@ -227,35 +263,61 @@ class _UFunc_Nin2_Nout1(ufunc, Generic[_NameType, _NTypes, _IDType]): # type: i
out: None | NDArray[Any] = ...,
) -> NDArray[Any]: ...

# Expand `**kwargs` into explicit keyword-only arguments
@overload
@overload # (scalar, scalar) -> scalar
def outer(
self,
A: _ScalarLike_co,
B: _ScalarLike_co,
/, *,
out: None = ...,
where: None | _ArrayLikeBool_co = ...,
casting: _CastingKind = ...,
order: _OrderKACF = ...,
dtype: DTypeLike = ...,
subok: bool = ...,
signature: str | _3Tuple[None | str] = ...,
/,
*,
out: None = None,
dtype: DTypeLike | None = None,
**kwds: Unpack[_UFunc3Kwargs],
) -> Any: ...
@overload
def outer( # type: ignore[misc]
@overload # (array-like, array) -> array
def outer(
self,
A: ArrayLike,
B: NDArray[np.generic],
/,
*,
out: NDArray[np.generic] | tuple[NDArray[np.generic]] | None = None,
dtype: DTypeLike | None = None,
**kwds: Unpack[_UFunc3Kwargs],
) -> NDArray[Any]: ...
@overload # (array, array-like) -> array
def outer(
self,
A: NDArray[np.generic],
B: ArrayLike,
/, *,
out: None | NDArray[Any] | tuple[NDArray[Any]] = ...,
where: None | _ArrayLikeBool_co = ...,
casting: _CastingKind = ...,
order: _OrderKACF = ...,
dtype: DTypeLike = ...,
subok: bool = ...,
signature: str | _3Tuple[None | str] = ...,
/,
*,
out: NDArray[np.generic] | tuple[NDArray[np.generic]] | None = None,
dtype: DTypeLike | None = None,
**kwds: Unpack[_UFunc3Kwargs],
) -> NDArray[Any]: ...
@overload # (array-like, array-like, out=array) -> array
def outer(
self,
A: ArrayLike,
B: ArrayLike,
/,
*,
out: NDArray[np.generic] | tuple[NDArray[np.generic]],
dtype: DTypeLike | None = None,
**kwds: Unpack[_UFunc3Kwargs],
) -> NDArray[Any]: ...
@overload # (array-like, array-like) -> array | scalar
def outer(
self,
A: ArrayLike,
B: ArrayLike,
/,
*,
out: NDArray[np.generic] | tuple[NDArray[np.generic]] | None = None,
dtype: DTypeLike | None = None,
**kwds: Unpack[_UFunc3Kwargs],
) -> NDArray[Any] | Any: ...

@type_check_only
class _UFunc_Nin1_Nout2(ufunc, Generic[_NameType, _NTypes, _IDType]): # type: ignore[misc]
Expand Down

0 comments on commit f782790

Please sign in to comment.