-
-
Notifications
You must be signed in to change notification settings - Fork 18.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement NA.__array_ufunc__ (#30245)
- Loading branch information
1 parent
ea73e0b
commit 97153bf
Showing
8 changed files
with
256 additions
and
113 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
DISPATCHED_UFUNCS = { | ||
"add", | ||
"sub", | ||
"mul", | ||
"pow", | ||
"mod", | ||
"floordiv", | ||
"truediv", | ||
"divmod", | ||
"eq", | ||
"ne", | ||
"lt", | ||
"gt", | ||
"le", | ||
"ge", | ||
"remainder", | ||
"matmul", | ||
"or", | ||
"xor", | ||
"and", | ||
} | ||
UFUNC_ALIASES = { | ||
"subtract": "sub", | ||
"multiply": "mul", | ||
"floor_divide": "floordiv", | ||
"true_divide": "truediv", | ||
"power": "pow", | ||
"remainder": "mod", | ||
"divide": "div", | ||
"equal": "eq", | ||
"not_equal": "ne", | ||
"less": "lt", | ||
"less_equal": "le", | ||
"greater": "gt", | ||
"greater_equal": "ge", | ||
"bitwise_or": "or", | ||
"bitwise_and": "and", | ||
"bitwise_xor": "xor", | ||
} | ||
|
||
# For op(., Array) -> Array.__r{op}__ | ||
REVERSED_NAMES = { | ||
"lt": "__gt__", | ||
"le": "__ge__", | ||
"gt": "__lt__", | ||
"ge": "__le__", | ||
"eq": "__eq__", | ||
"ne": "__ne__", | ||
} | ||
|
||
|
||
def maybe_dispatch_ufunc_to_dunder_op( | ||
object self, object ufunc, str method, *inputs, **kwargs | ||
): | ||
""" | ||
Dispatch a ufunc to the equivalent dunder method. | ||
Parameters | ||
---------- | ||
self : ArrayLike | ||
The array whose dunder method we dispatch to | ||
ufunc : Callable | ||
A NumPy ufunc | ||
method : {'reduce', 'accumulate', 'reduceat', 'outer', 'at', '__call__'} | ||
inputs : ArrayLike | ||
The input arrays. | ||
kwargs : Any | ||
The additional keyword arguments, e.g. ``out``. | ||
Returns | ||
------- | ||
result : Any | ||
The result of applying the ufunc | ||
""" | ||
# special has the ufuncs we dispatch to the dunder op on | ||
|
||
op_name = ufunc.__name__ | ||
op_name = UFUNC_ALIASES.get(op_name, op_name) | ||
|
||
def not_implemented(*args, **kwargs): | ||
return NotImplemented | ||
|
||
if (method == "__call__" | ||
and op_name in DISPATCHED_UFUNCS | ||
and kwargs.get("out") is None): | ||
if isinstance(inputs[0], type(self)): | ||
name = f"__{op_name}__" | ||
return getattr(self, name, not_implemented)(inputs[1]) | ||
else: | ||
name = REVERSED_NAMES.get(op_name, f"__r{op_name}__") | ||
result = getattr(self, name, not_implemented)(inputs[0]) | ||
return result | ||
else: | ||
return NotImplemented |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.