ENH: Support ExtensionArray operators via a mixin (#21261)
Dr-Irv authored and jreback committed Jun 29, 2018
1 parent 5c761f1 commit 0b63e81
Showing 15 changed files with 460 additions and 33 deletions.
64 changes: 56 additions & 8 deletions doc/source/extending.rst
@@ -61,7 +61,7 @@ Extension Types

.. warning::

The :class:`pandas.api.extension.ExtensionDtype` and :class:`pandas.api.extension.ExtensionArray` APIs are new and
The :class:`pandas.api.extensions.ExtensionDtype` and :class:`pandas.api.extensions.ExtensionArray` APIs are new and
experimental. They may change between versions without warning.

Pandas defines an interface for implementing data types and arrays that *extend*
@@ -79,10 +79,10 @@ on :ref:`ecosystem.extensions`.

The interface consists of two classes.

:class:`~pandas.api.extension.ExtensionDtype`
:class:`~pandas.api.extensions.ExtensionDtype`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

A :class:`pandas.api.extension.ExtensionDtype` is similar to a ``numpy.dtype`` object. It describes the
A :class:`pandas.api.extensions.ExtensionDtype` is similar to a ``numpy.dtype`` object. It describes the
data type. Implementors are responsible for a few unique items like the name.

One particularly important item is the ``type`` property. This should be the
@@ -91,7 +91,7 @@ extension array for IP Address data, this might be ``ipaddress.IPv4Address``.

See the `extension dtype source`_ for interface definition.

:class:`~pandas.api.extension.ExtensionArray`
:class:`~pandas.api.extensions.ExtensionArray`
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

This class provides all the array-like functionality. ExtensionArrays are
@@ -113,6 +113,54 @@ by some other storage type, like Python lists.
See the `extension array source`_ for the interface definition. The docstrings
and comments contain guidance for properly implementing the interface.

.. _extending.extension.operator:

:class:`~pandas.api.extensions.ExtensionArray` Operator Support
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. versionadded:: 0.24.0

By default, there are no operators defined for the class :class:`~pandas.api.extensions.ExtensionArray`.
There are two approaches for providing operator support for your ExtensionArray:

1. Define each of the operators on your ``ExtensionArray`` subclass.
2. Use an operator implementation from pandas that depends on operators that are already defined
on the underlying elements (scalars) of the ExtensionArray.

For the first approach, you define selected operators, e.g., ``__add__``, ``__le__``, etc. that
you want your ``ExtensionArray`` subclass to support.
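
As a minimal sketch of the first approach, the operators can be written by hand. The class ``IntList`` below is a hypothetical list-backed container used only for illustration; a real ``ExtensionArray`` subclass would also have to implement the rest of the interface.

```python
class IntList:
    """Hypothetical list-backed container standing in for an ExtensionArray subclass."""

    def __init__(self, values):
        self._data = list(values)

    def __len__(self):
        return len(self._data)

    def _other_values(self, other):
        # Accept another IntList, or broadcast a scalar to our length.
        if isinstance(other, IntList):
            return other._data
        return [other] * len(self)

    def __add__(self, other):
        # Hand-written elementwise addition; returns the same container type.
        pairs = zip(self._data, self._other_values(other))
        return IntList(a + b for a, b in pairs)

    def __le__(self, other):
        # Hand-written elementwise comparison; returns a plain list of bools.
        pairs = zip(self._data, self._other_values(other))
        return [a <= b for a, b in pairs]


summed = IntList([1, 2, 3]) + 10
mask = IntList([1, 2, 3]) <= IntList([2, 2, 2])
```

Only the operators that are explicitly defined are available; anything else raises ``TypeError`` as usual.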

The second approach assumes that the underlying elements (i.e., scalar type) of the ``ExtensionArray``
have the individual operators already defined. In other words, if your ``ExtensionArray``
named ``MyExtensionArray`` is implemented so that each element is an instance
of the class ``MyExtensionElement``, then if the operators are defined
for ``MyExtensionElement``, the second approach will automatically
define the operators for ``MyExtensionArray``.

A mixin class, :class:`~pandas.api.extensions.ExtensionScalarOpsMixin`, supports this second
approach. If you are developing an ``ExtensionArray`` subclass, for example ``MyExtensionArray``,
you can simply include ``ExtensionScalarOpsMixin`` as a parent class of ``MyExtensionArray``,
and then call the methods :meth:`~MyExtensionArray._add_arithmetic_ops` and/or
:meth:`~MyExtensionArray._add_comparison_ops` to hook the operators into
your ``MyExtensionArray`` class, as follows:

.. code-block:: python

    class MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin):
        pass

    MyExtensionArray._add_arithmetic_ops()
    MyExtensionArray._add_comparison_ops()

Note that since ``pandas`` automatically calls the underlying operator on each
element one-by-one, this might not be as performant as implementing your own
version of the associated operators directly on the ``ExtensionArray``.
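
The per-element dispatch can be illustrated without pandas. The sketch below is a simplified stand-in, not the pandas implementation: ``ScalarOpsSketch`` and ``DecimalList`` are hypothetical names, only a few operators are wired up, and ``decimal.Decimal`` plays the role of the scalar type that already defines its own operators.

```python
import operator
from decimal import Decimal


class ScalarOpsSketch:
    """Hypothetical stand-in for the mixin (not the pandas class)."""

    @classmethod
    def _create_method(cls, op, coerce_to_dtype=True):
        def _binop(self, other):
            # Keep sequences as they are; broadcast a scalar to our length.
            if isinstance(other, (list, tuple, ScalarOpsSketch)):
                rvalues = other
            else:
                rvalues = [other] * len(self)
            # Apply the scalar operator to each pair of elements.
            res = [op(a, b) for a, b in zip(self, rvalues)]
            if coerce_to_dtype:
                try:
                    res = self._from_sequence(res)
                except TypeError:
                    pass  # fall back to the plain list
            return res

        _binop.__name__ = "__{}__".format(op.__name__)
        return _binop

    @classmethod
    def _add_arithmetic_ops(cls):
        cls.__add__ = cls._create_method(operator.add)
        cls.__mul__ = cls._create_method(operator.mul)

    @classmethod
    def _add_comparison_ops(cls):
        cls.__eq__ = cls._create_method(operator.eq, coerce_to_dtype=False)
        cls.__lt__ = cls._create_method(operator.lt, coerce_to_dtype=False)


class DecimalList(ScalarOpsSketch):
    """Hypothetical array-like whose elements are Decimal instances."""

    def __init__(self, values):
        self._data = [Decimal(v) for v in values]

    @classmethod
    def _from_sequence(cls, scalars):
        return cls(scalars)

    def __len__(self):
        return len(self._data)

    def __iter__(self):
        return iter(self._data)


# Hook the generated operators into the class, as in the snippet above.
DecimalList._add_arithmetic_ops()
DecimalList._add_comparison_ops()

total = DecimalList(["1.5", "2.5"]) + Decimal("1")
flags = DecimalList(["1.5", "2.5"]) < DecimalList(["2", "2"])
```

Here the arithmetic result is converted back into a ``DecimalList``, while the comparison result is left as a plain list of booleans. Every operation runs one Python-level call per element, which is the performance cost the note above refers to.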

.. _extending.extension.testing:

Testing Extension Arrays
^^^^^^^^^^^^^^^^^^^^^^^^

We provide a test suite for ensuring that your extension arrays satisfy the expected
behavior. To use the test suite, you must provide several pytest fixtures and inherit
from the base test class. The required fixtures are found in
@@ -174,11 +222,11 @@ There are 3 constructor properties to be defined:
The following table shows how ``pandas`` data structures define constructor properties by default.

=========================== ======================= =============
Property Attributes ``Series`` ``DataFrame``
Property Attributes ``Series`` ``DataFrame``
=========================== ======================= =============
``_constructor`` ``Series`` ``DataFrame``
``_constructor_sliced`` ``NotImplementedError`` ``Series``
``_constructor_expanddim`` ``DataFrame`` ``Panel``
``_constructor`` ``Series`` ``DataFrame``
``_constructor_sliced`` ``NotImplementedError`` ``Series``
``_constructor_expanddim`` ``DataFrame`` ``Panel``
=========================== ======================= =============

The example below shows how to define ``SubclassedSeries`` and ``SubclassedDataFrame`` overriding constructor properties.
16 changes: 16 additions & 0 deletions doc/source/whatsnew/v0.24.0.txt
@@ -10,6 +10,22 @@ New features

- ``ExcelWriter`` now accepts ``mode`` as a keyword argument, enabling append to existing workbooks when using the ``openpyxl`` engine (:issue:`3441`)

.. _whatsnew_0240.enhancements.extension_array_operators:

``ExtensionArray`` operator support
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

A ``Series`` based on an ``ExtensionArray`` now supports arithmetic and comparison
operators (:issue:`19577`). There are two approaches for providing operator support for an ``ExtensionArray``:

1. Define each of the operators on your ``ExtensionArray`` subclass.
2. Use an operator implementation from pandas that depends on operators that are already defined
on the underlying elements (scalars) of the ``ExtensionArray``.

See the :ref:`ExtensionArray Operator Support
<extending.extension.operator>` documentation section for details on both
ways of adding operator support.

.. _whatsnew_0240.enhancements.other:

Other Enhancements
3 changes: 2 additions & 1 deletion pandas/api/extensions/__init__.py
@@ -3,5 +3,6 @@
register_index_accessor,
register_series_accessor)
from pandas.core.algorithms import take # noqa
from pandas.core.arrays.base import ExtensionArray # noqa
from pandas.core.arrays.base import (ExtensionArray, # noqa
ExtensionScalarOpsMixin)
from pandas.core.dtypes.dtypes import ExtensionDtype # noqa
36 changes: 18 additions & 18 deletions pandas/conftest.py
@@ -89,7 +89,8 @@ def observed(request):
'__mul__', '__rmul__',
'__floordiv__', '__rfloordiv__',
'__truediv__', '__rtruediv__',
'__pow__', '__rpow__']
'__pow__', '__rpow__',
'__mod__', '__rmod__']
if not PY3:
_all_arithmetic_operators.extend(['__div__', '__rdiv__'])

@@ -102,6 +103,22 @@ def all_arithmetic_operators(request):
return request.param


@pytest.fixture(params=['__eq__', '__ne__', '__le__',
                        '__lt__', '__ge__', '__gt__'])
def all_compare_operators(request):
    """
    Fixture for dunder names for common compare operations

    * >=
    * >
    * ==
    * !=
    * <
    * <=
    """
    return request.param


@pytest.fixture(params=[None, 'gzip', 'bz2', 'zip',
pytest.param('xz', marks=td.skip_if_no_lzma)])
def compression(request):
@@ -320,20 +337,3 @@ def mock():
return importlib.import_module("unittest.mock")
else:
return pytest.importorskip("mock")


@pytest.fixture(params=['__eq__', '__ne__', '__le__',
                        '__lt__', '__ge__', '__gt__'])
def all_compare_operators(request):
    """
    Fixture for dunder names for common compare operations

    * >=
    * >
    * ==
    * !=
    * <
    * <=
    """

    return request.param
3 changes: 2 additions & 1 deletion pandas/core/arrays/__init__.py
@@ -1,2 +1,3 @@
from .base import ExtensionArray # noqa
from .base import (ExtensionArray, # noqa
ExtensionScalarOpsMixin)
from .categorical import Categorical # noqa
127 changes: 127 additions & 0 deletions pandas/core/arrays/base.py
@@ -7,8 +7,13 @@
"""
import numpy as np

import operator

from pandas.errors import AbstractMethodError
from pandas.compat.numpy import function as nv
from pandas.compat import set_function_name, PY3
from pandas.core.dtypes.common import is_list_like
from pandas.core import ops

_not_implemented_message = "{} does not implement {}."

@@ -610,3 +615,125 @@ def _ndarray_values(self):
used for interacting with our indexers.
"""
return np.array(self)


class ExtensionOpsMixin(object):
    """
    A base class for linking the operators to their dunder names
    """
    @classmethod
    def _add_arithmetic_ops(cls):
        cls.__add__ = cls._create_arithmetic_method(operator.add)
        cls.__radd__ = cls._create_arithmetic_method(ops.radd)
        cls.__sub__ = cls._create_arithmetic_method(operator.sub)
        cls.__rsub__ = cls._create_arithmetic_method(ops.rsub)
        cls.__mul__ = cls._create_arithmetic_method(operator.mul)
        cls.__rmul__ = cls._create_arithmetic_method(ops.rmul)
        cls.__pow__ = cls._create_arithmetic_method(operator.pow)
        cls.__rpow__ = cls._create_arithmetic_method(ops.rpow)
        cls.__mod__ = cls._create_arithmetic_method(operator.mod)
        cls.__rmod__ = cls._create_arithmetic_method(ops.rmod)
        cls.__floordiv__ = cls._create_arithmetic_method(operator.floordiv)
        cls.__rfloordiv__ = cls._create_arithmetic_method(ops.rfloordiv)
        cls.__truediv__ = cls._create_arithmetic_method(operator.truediv)
        cls.__rtruediv__ = cls._create_arithmetic_method(ops.rtruediv)
        if not PY3:
            cls.__div__ = cls._create_arithmetic_method(operator.div)
            cls.__rdiv__ = cls._create_arithmetic_method(ops.rdiv)

        cls.__divmod__ = cls._create_arithmetic_method(divmod)
        cls.__rdivmod__ = cls._create_arithmetic_method(ops.rdivmod)

    @classmethod
    def _add_comparison_ops(cls):
        cls.__eq__ = cls._create_comparison_method(operator.eq)
        cls.__ne__ = cls._create_comparison_method(operator.ne)
        cls.__lt__ = cls._create_comparison_method(operator.lt)
        cls.__gt__ = cls._create_comparison_method(operator.gt)
        cls.__le__ = cls._create_comparison_method(operator.le)
        cls.__ge__ = cls._create_comparison_method(operator.ge)


class ExtensionScalarOpsMixin(ExtensionOpsMixin):
    """A mixin for defining the arithmetic and logical operations on
    an ExtensionArray class, where it is assumed that the underlying objects
    have the operators already defined.

    Usage
    ------
    If you have defined a subclass MyExtensionArray(ExtensionArray), then
    use MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin) to
    get the arithmetic operators. After the definition of MyExtensionArray,
    insert the lines

    MyExtensionArray._add_arithmetic_ops()
    MyExtensionArray._add_comparison_ops()

    to link the operators to your class.
    """

    @classmethod
    def _create_method(cls, op, coerce_to_dtype=True):
        """
        A class method that returns a method that will correspond to an
        operator for an ExtensionArray subclass, by dispatching to the
        relevant operator defined on the individual elements of the
        ExtensionArray.

        Parameters
        ----------
        op : function
            An operator that takes arguments op(a, b)
        coerce_to_dtype : bool
            boolean indicating whether to attempt to convert
            the result to the underlying ExtensionArray dtype
            (default True)

        Returns
        -------
        A method that can be bound to a method of a class

        Example
        -------
        Given an ExtensionArray subclass called MyExtensionArray, use

        >>> __add__ = cls._create_method(operator.add)

        in the class definition of MyExtensionArray to create the operator
        for addition, that will be based on the operator implementation
        of the underlying elements of the ExtensionArray
        """

        def _binop(self, other):
            def convert_values(param):
                if isinstance(param, ExtensionArray) or is_list_like(param):
                    ovalues = param
                else:  # Assume its an object
                    ovalues = [param] * len(self)
                return ovalues
            lvalues = self
            rvalues = convert_values(other)

            # If the operator is not defined for the underlying objects,
            # a TypeError should be raised
            res = [op(a, b) for (a, b) in zip(lvalues, rvalues)]

            if coerce_to_dtype:
                try:
                    res = self._from_sequence(res)
                except TypeError:
                    pass

            return res

        op_name = ops._get_op_name(op, True)
        return set_function_name(_binop, op_name, cls)

    @classmethod
    def _create_arithmetic_method(cls, op):
        return cls._create_method(op)

    @classmethod
    def _create_comparison_method(cls, op):
        return cls._create_method(op, coerce_to_dtype=False)
31 changes: 31 additions & 0 deletions pandas/core/ops.py
@@ -33,6 +33,7 @@
is_bool_dtype,
is_list_like,
is_scalar,
is_extension_array_dtype,
_ensure_object)
from pandas.core.dtypes.cast import (
maybe_upcast_putmask, find_common_type,
@@ -993,6 +994,26 @@ def _construct_divmod_result(left, result, index, name, dtype):
)


def dispatch_to_extension_op(op, left, right):
    """
    Assume that left or right is a Series backed by an ExtensionArray,
    apply the operator defined by op.
    """

    # The op calls will raise TypeError if the op is not defined
    # on the ExtensionArray
    if is_extension_array_dtype(left):
        res_values = op(left.values, right)
    else:
        # We know that left is not ExtensionArray and is Series and right is
        # ExtensionArray. Want to force ExtensionArray op to get called
        res_values = op(list(left.values), right.values)

    res_name = get_op_result_name(left, right)
    return left._constructor(res_values, index=left.index,
                             name=res_name)
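
The ``else`` branch above converts the left operand to a plain list so that the operator defined on the right operand is the one that runs. The sketch below illustrates the underlying Python mechanism only; ``TaggedValues`` is a hypothetical class, not part of pandas, and it defines nothing but the reflected addition.

```python
import operator


class TaggedValues:
    """Hypothetical array-like that only defines the reflected add."""

    def __init__(self, values):
        self.values = list(values)

    def __radd__(self, other):
        # Reached for `other + self` when `other` cannot handle the addition.
        return TaggedValues(o + v for o, v in zip(other, self.values))


left_values = [1, 2, 3]              # plain data, like list(left.values)
right = TaggedValues([10, 20, 30])   # stands in for right.values

# A list does not know how to add a TaggedValues, so Python hands the
# operation to TaggedValues.__radd__.
result = operator.add(left_values, right)
```

The same fallback applies to the other reflected arithmetic methods, which is presumably why the mixin registers ``__radd__``, ``__rsub__`` and the rest alongside the forward operators.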


def _arith_method_SERIES(cls, op, special):
"""
Wrapper function for Series arithmetic operations, to avoid
@@ -1061,6 +1082,11 @@ def wrapper(left, right):
raise TypeError("{typ} cannot perform the operation "
"{op}".format(typ=type(left).__name__, op=str_rep))

elif (is_extension_array_dtype(left) or
(is_extension_array_dtype(right) and
not is_categorical_dtype(right))):
return dispatch_to_extension_op(op, left, right)

lvalues = left.values
rvalues = right
if isinstance(rvalues, ABCSeries):
@@ -1238,6 +1264,11 @@ def wrapper(self, other, axis=None):
return self._constructor(res_values, index=self.index,
name=res_name)

elif (is_extension_array_dtype(self) or
(is_extension_array_dtype(other) and
not is_categorical_dtype(other))):
return dispatch_to_extension_op(op, self, other)

elif isinstance(other, ABCSeries):
# By this point we have checked that self._indexed_same(other)
res_values = na_op(self.values, other.values)
1 change: 1 addition & 0 deletions pandas/tests/extension/base/__init__.py
@@ -47,6 +47,7 @@ class TestMyDtype(BaseDtypeTests):
from .groupby import BaseGroupbyTests # noqa
from .interface import BaseInterfaceTests # noqa
from .methods import BaseMethodsTests # noqa
from .ops import BaseArithmeticOpsTests, BaseComparisonOpsTests # noqa
from .missing import BaseMissingTests # noqa
from .reshaping import BaseReshapingTests # noqa
from .setitem import BaseSetitemTests # noqa