-
-
Notifications
You must be signed in to change notification settings - Fork 18.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: Support ExtensionArray operators via a mixin #21261
Changes from all commits
5b0ebc7
d7596c6
7f2b0a1
ec96841
a07bb49
1d7b2b3
7bad559
dfcda3b
aaaa8fd
4bcf978
f958d7b
ef83c3a
41dc5ca
be6656b
a0f503c
700d75b
87e8f55
97bd291
8fc93e4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .base import ExtensionArray # noqa | ||
from .base import (ExtensionArray, # noqa | ||
ExtensionScalarOpsMixin) | ||
from .categorical import Categorical # noqa |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,8 +7,13 @@ | |
""" | ||
import numpy as np | ||
|
||
import operator | ||
|
||
from pandas.errors import AbstractMethodError | ||
from pandas.compat.numpy import function as nv | ||
from pandas.compat import set_function_name, PY3 | ||
from pandas.core.dtypes.common import is_list_like | ||
from pandas.core import ops | ||
|
||
_not_implemented_message = "{} does not implement {}." | ||
|
||
|
@@ -610,3 +615,125 @@ def _ndarray_values(self): | |
used for interacting with our indexers. | ||
""" | ||
return np.array(self) | ||
|
||
|
||
class ExtensionOpsMixin(object): | ||
""" | ||
A base class for linking the operators to their dunder names | ||
""" | ||
@classmethod | ||
def _add_arithmetic_ops(cls): | ||
cls.__add__ = cls._create_arithmetic_method(operator.add) | ||
cls.__radd__ = cls._create_arithmetic_method(ops.radd) | ||
cls.__sub__ = cls._create_arithmetic_method(operator.sub) | ||
cls.__rsub__ = cls._create_arithmetic_method(ops.rsub) | ||
cls.__mul__ = cls._create_arithmetic_method(operator.mul) | ||
cls.__rmul__ = cls._create_arithmetic_method(ops.rmul) | ||
cls.__pow__ = cls._create_arithmetic_method(operator.pow) | ||
cls.__rpow__ = cls._create_arithmetic_method(ops.rpow) | ||
cls.__mod__ = cls._create_arithmetic_method(operator.mod) | ||
cls.__rmod__ = cls._create_arithmetic_method(ops.rmod) | ||
cls.__floordiv__ = cls._create_arithmetic_method(operator.floordiv) | ||
cls.__rfloordiv__ = cls._create_arithmetic_method(ops.rfloordiv) | ||
cls.__truediv__ = cls._create_arithmetic_method(operator.truediv) | ||
cls.__rtruediv__ = cls._create_arithmetic_method(ops.rtruediv) | ||
if not PY3: | ||
cls.__div__ = cls._create_arithmetic_method(operator.div) | ||
cls.__rdiv__ = cls._create_arithmetic_method(ops.rdiv) | ||
|
||
cls.__divmod__ = cls._create_arithmetic_method(divmod) | ||
cls.__rdivmod__ = cls._create_arithmetic_method(ops.rdivmod) | ||
|
||
@classmethod | ||
def _add_comparison_ops(cls): | ||
cls.__eq__ = cls._create_comparison_method(operator.eq) | ||
cls.__ne__ = cls._create_comparison_method(operator.ne) | ||
cls.__lt__ = cls._create_comparison_method(operator.lt) | ||
cls.__gt__ = cls._create_comparison_method(operator.gt) | ||
cls.__le__ = cls._create_comparison_method(operator.le) | ||
cls.__ge__ = cls._create_comparison_method(operator.ge) | ||
|
||
|
||
class ExtensionScalarOpsMixin(ExtensionOpsMixin): | ||
"""A mixin for defining the arithmetic and logical operations on | ||
an ExtensionArray class, where it is assumed that the underlying objects | ||
have the operators already defined. | ||
|
||
Usage | ||
------ | ||
If you have defined a subclass MyExtensionArray(ExtensionArray), then | ||
use MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin) to | ||
get the arithmetic operators. After the definition of MyExtensionArray, | ||
insert the lines | ||
|
||
MyExtensionArray._add_arithmetic_ops() | ||
MyExtensionArray._add_comparison_ops() | ||
|
||
to link the operators to your class. | ||
""" | ||
|
||
@classmethod | ||
def _create_method(cls, op, coerce_to_dtype=True): | ||
""" | ||
A class method that returns a method that will correspond to an | ||
operator for an ExtensionArray subclass, by dispatching to the | ||
relevant operator defined on the individual elements of the | ||
ExtensionArray. | ||
|
||
Parameters | ||
---------- | ||
op : function | ||
An operator that takes arguments op(a, b) | ||
coerce_to_dtype : bool | ||
boolean indicating whether to attempt to convert | ||
the result to the underlying ExtensionArray dtype | ||
(default True) | ||
|
||
Returns | ||
------- | ||
A method that can be bound to a method of a class | ||
|
||
Example | ||
------- | ||
Given an ExtensionArray subclass called MyExtensionArray, use | ||
|
||
>>> __add__ = cls._create_method(operator.add) | ||
|
||
in the class definition of MyExtensionArray to create the operator | ||
for addition, that will be based on the operator implementation | ||
of the underlying elements of the ExtensionArray | ||
|
||
""" | ||
|
||
def _binop(self, other): | ||
def convert_values(param): | ||
if isinstance(param, ExtensionArray) or is_list_like(param): | ||
ovalues = param | ||
else: # Assume its an object | ||
ovalues = [param] * len(self) | ||
return ovalues | ||
lvalues = self | ||
rvalues = convert_values(other) | ||
|
||
# If the operator is not defined for the underlying objects, | ||
# a TypeError should be raised | ||
res = [op(a, b) for (a, b) in zip(lvalues, rvalues)] | ||
|
||
if coerce_to_dtype: | ||
try: | ||
res = self._from_sequence(res) | ||
except TypeError: | ||
pass | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if this fails, I think we should still convert it to an array instead of keeping it as a list? Or does that happen on another level? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we convert to an array, then we could have a dtype problem. This allows the result to be of any type. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
It will be converted to an array anyhow, if not here, then at the level above when the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jorisvandenbossche But why repeat that logic? If we leave it as a list, then the Series constructor will do the inference on the dtype. |
||
|
||
return res | ||
|
||
op_name = ops._get_op_name(op, True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you use a parameter name instead of the positional argument ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I could do that, but not specifying the parameter is consistent with all the other usages of |
||
return set_function_name(_binop, op_name, cls) | ||
|
||
@classmethod | ||
def _create_arithmetic_method(cls, op): | ||
return cls._create_method(op) | ||
|
||
@classmethod | ||
def _create_comparison_method(cls, op): | ||
return cls._create_method(op, coerce_to_dtype=False) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,6 +33,7 @@ | |
is_bool_dtype, | ||
is_list_like, | ||
is_scalar, | ||
is_extension_array_dtype, | ||
_ensure_object) | ||
from pandas.core.dtypes.cast import ( | ||
maybe_upcast_putmask, find_common_type, | ||
|
@@ -993,6 +994,26 @@ def _construct_divmod_result(left, result, index, name, dtype): | |
) | ||
|
||
|
||
def dispatch_to_extension_op(op, left, right): | ||
""" | ||
Assume that left or right is a Series backed by an ExtensionArray, | ||
apply the operator defined by op. | ||
""" | ||
|
||
# The op calls will raise TypeError if the op is not defined | ||
# on the ExtensionArray | ||
if is_extension_array_dtype(left): | ||
res_values = op(left.values, right) | ||
else: | ||
# We know that left is not ExtensionArray and is Series and right is | ||
# ExtensionArray. Want to force ExtensionArray op to get called | ||
res_values = op(list(left.values), right.values) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. was this needed to fix failing tests? (this was not here in a previous version I think? and eg the dispatch to index is only dealing with the left case) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was added a few commits ago. I discovered that the tests were not properly testing the reverse operators. So there were a couple of changes related to that:
Without these changes, an operator such as As best as I can tell, Another possible implementation might be to add a test in the wrappers that says that if is_extension_dtype(right) return NotImplemented, which will then make python call the reverse operator. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This is much closer to what we've had in mind with the recent refactoring in Why is the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jbrockmendel Wrote:
Yes, I think it would be something like:
But I would have to test the concept. I don't think the upcast is necessary. We know that left is a Series. So python would then see that I'm not going to try making that change without feedback from others.
Because There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I think this is the better way to go. For me it is fine to not yet do this in this PR, but then I would also not include the above change (and live with that the first iteration in this PR is not yet ideal for cases where the ExtensionArray is the right variable) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I would at least try this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jorisvandenbossche Yes, I will try this in a new PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jorisvandenbossche (and @jbrockmendel) I tried the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not entirely sure I follow. Are you saying that both The dispatch logic should do something like
It is There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @jbrockmendel I've tried two different alternatives (regarding replacing this code else:
# We know that left is not ExtensionArray and is Series and right is
# ExtensionArray. Want to force ExtensionArray op to get called
res_values = op(list(left.values), right.values) that this discussion refers to). Note the section of code is hit when
else:
return NotImplemented
else:
res_values = op(left.values, right.values) In the first case, python sees that In the second case, |
||
|
||
res_name = get_op_result_name(left, right) | ||
return left._constructor(res_values, index=left.index, | ||
name=res_name) | ||
|
||
|
||
def _arith_method_SERIES(cls, op, special): | ||
""" | ||
Wrapper function for Series arithmetic operations, to avoid | ||
|
@@ -1061,6 +1082,11 @@ def wrapper(left, right): | |
raise TypeError("{typ} cannot perform the operation " | ||
"{op}".format(typ=type(left).__name__, op=str_rep)) | ||
|
||
elif (is_extension_array_dtype(left) or | ||
(is_extension_array_dtype(right) and | ||
not is_categorical_dtype(right))): | ||
return dispatch_to_extension_op(op, left, right) | ||
|
||
lvalues = left.values | ||
rvalues = right | ||
if isinstance(rvalues, ABCSeries): | ||
|
@@ -1238,6 +1264,11 @@ def wrapper(self, other, axis=None): | |
return self._constructor(res_values, index=self.index, | ||
name=res_name) | ||
|
||
elif (is_extension_array_dtype(self) or | ||
(is_extension_array_dtype(other) and | ||
not is_categorical_dtype(other))): | ||
return dispatch_to_extension_op(op, self, other) | ||
|
||
elif isinstance(other, ABCSeries): | ||
# By this point we have checked that self._indexed_same(other) | ||
res_values = na_op(self.values, other.values) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we make this
is_list_like(param)
more strict asis_array_like
?For example, if you create ExtensionArray of sets, and do an operation where the right value is a single set, the current code will not work.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or similar with dicts (for which we already have a dummy implementation in the tests)
(I am also fine with leaving this for later, as there are other places where we have problems with iterable scalar elements)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jorisvandenbossche I think I want to leave it for now. Because you'd like to be able to do an operation such as
EABackedSeries + list(objects)
and usingis_array_type
means you have to have a dtype.