diff --git a/doc/source/extending.rst b/doc/source/extending.rst index 8018d35770924..38b3b19031a0e 100644 --- a/doc/source/extending.rst +++ b/doc/source/extending.rst @@ -61,7 +61,7 @@ Extension Types .. warning:: - The :class:`pandas.api.extension.ExtensionDtype` and :class:`pandas.api.extension.ExtensionArray` APIs are new and + The :class:`pandas.api.extensions.ExtensionDtype` and :class:`pandas.api.extensions.ExtensionArray` APIs are new and experimental. They may change between versions without warning. Pandas defines an interface for implementing data types and arrays that *extend* @@ -79,10 +79,10 @@ on :ref:`ecosystem.extensions`. The interface consists of two classes. -:class:`~pandas.api.extension.ExtensionDtype` +:class:`~pandas.api.extensions.ExtensionDtype` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -A :class:`pandas.api.extension.ExtensionDtype` is similar to a ``numpy.dtype`` object. It describes the +A :class:`pandas.api.extensions.ExtensionDtype` is similar to a ``numpy.dtype`` object. It describes the data type. Implementors are responsible for a few unique items like the name. One particularly important item is the ``type`` property. This should be the @@ -91,7 +91,7 @@ extension array for IP Address data, this might be ``ipaddress.IPv4Address``. See the `extension dtype source`_ for interface definition. -:class:`~pandas.api.extension.ExtensionArray` +:class:`~pandas.api.extensions.ExtensionArray` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ This class provides all the array-like functionality. ExtensionArrays are @@ -113,6 +113,54 @@ by some other storage type, like Python lists. See the `extension array source`_ for the interface definition. The docstrings and comments contain guidance for properly implementing the interface. +.. _extending.extension.operator: + +:class:`~pandas.api.extensions.ExtensionArray` Operator Support +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. versionadded:: 0.24.0 + +By default, there are no operators defined for the class :class:`~pandas.api.extensions.ExtensionArray`. +There are two approaches for providing operator support for your ExtensionArray: + +1. Define each of the operators on your ``ExtensionArray`` subclass. +2. Use an operator implementation from pandas that depends on operators that are already defined + on the underlying elements (scalars) of the ExtensionArray. + +For the first approach, you define selected operators, e.g., ``__add__``, ``__le__``, etc. that +you want your ``ExtensionArray`` subclass to support. + +The second approach assumes that the underlying elements (i.e., scalar type) of the ``ExtensionArray`` +have the individual operators already defined. In other words, if your ``ExtensionArray`` +named ``MyExtensionArray`` is implemented so that each element is an instance +of the class ``MyExtensionElement``, then if the operators are defined +for ``MyExtensionElement``, the second approach will automatically +define the operators for ``MyExtensionArray``. + +A mixin class, :class:`~pandas.api.extensions.ExtensionScalarOpsMixin` supports this second +approach. If developing an ``ExtensionArray`` subclass, for example ``MyExtensionArray``, +can simply include ``ExtensionScalarOpsMixin`` as a parent class of ``MyExtensionArray``, +and then call the methods :meth:`~MyExtensionArray._add_arithmetic_ops` and/or +:meth:`~MyExtensionArray._add_comparison_ops` to hook the operators into +your ``MyExtensionArray`` class, as follows: + +.. code-block:: python + + class MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin): + pass + + MyExtensionArray._add_arithmetic_ops() + MyExtensionArray._add_comparison_ops() + +Note that since ``pandas`` automatically calls the underlying operator on each +element one-by-one, this might not be as performant as implementing your own +version of the associated operators directly on the ``ExtensionArray``. + +.. _extending.extension.testing: + +Testing Extension Arrays +^^^^^^^^^^^^^^^^^^^^^^^^ + We provide a test suite for ensuring that your extension arrays satisfy the expected behavior. To use the test suite, you must provide several pytest fixtures and inherit from the base test class. The required fixtures are found in @@ -174,11 +222,11 @@ There are 3 constructor properties to be defined: Following table shows how ``pandas`` data structures define constructor properties by default. =========================== ======================= ============= -Property Attributes ``Series`` ``DataFrame`` +Property Attributes ``Series`` ``DataFrame`` =========================== ======================= ============= -``_constructor`` ``Series`` ``DataFrame`` -``_constructor_sliced`` ``NotImplementedError`` ``Series`` -``_constructor_expanddim`` ``DataFrame`` ``Panel`` +``_constructor`` ``Series`` ``DataFrame`` +``_constructor_sliced`` ``NotImplementedError`` ``Series`` +``_constructor_expanddim`` ``DataFrame`` ``Panel`` =========================== ======================= ============= Below example shows how to define ``SubclassedSeries`` and ``SubclassedDataFrame`` overriding constructor properties. diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt index 1ab67bd80a5e8..2b38e7b1d5cc3 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -10,6 +10,22 @@ New features - ``ExcelWriter`` now accepts ``mode`` as a keyword argument, enabling append to existing workbooks when using the ``openpyxl`` engine (:issue:`3441`) +.. _whatsnew_0240.enhancements.extension_array_operators + +``ExtensionArray`` operator support +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +A ``Series`` based on an ``ExtensionArray`` now supports arithmetic and comparison +operators. (:issue:`19577`). There are two approaches for providing operator support for an ``ExtensionArray``: + +1. Define each of the operators on your ``ExtensionArray`` subclass. +2. Use an operator implementation from pandas that depends on operators that are already defined + on the underlying elements (scalars) of the ``ExtensionArray``. + +See the :ref:`ExtensionArray Operator Support +` documentation section for details on both +ways of adding operator support. + .. _whatsnew_0240.enhancements.other: Other Enhancements diff --git a/pandas/api/extensions/__init__.py b/pandas/api/extensions/__init__.py index 3e6e192a3502c..851a63725952a 100644 --- a/pandas/api/extensions/__init__.py +++ b/pandas/api/extensions/__init__.py @@ -3,5 +3,6 @@ register_index_accessor, register_series_accessor) from pandas.core.algorithms import take # noqa -from pandas.core.arrays.base import ExtensionArray # noqa +from pandas.core.arrays.base import (ExtensionArray, # noqa + ExtensionScalarOpsMixin) from pandas.core.dtypes.dtypes import ExtensionDtype # noqa diff --git a/pandas/conftest.py b/pandas/conftest.py index ae08e0817de29..8ca90722d17f7 100644 --- a/pandas/conftest.py +++ b/pandas/conftest.py @@ -89,7 +89,8 @@ def observed(request): '__mul__', '__rmul__', '__floordiv__', '__rfloordiv__', '__truediv__', '__rtruediv__', - '__pow__', '__rpow__'] + '__pow__', '__rpow__', + '__mod__', '__rmod__'] if not PY3: _all_arithmetic_operators.extend(['__div__', '__rdiv__']) @@ -102,6 +103,22 @@ def all_arithmetic_operators(request): return request.param +@pytest.fixture(params=['__eq__', '__ne__', '__le__', + '__lt__', '__ge__', '__gt__']) +def all_compare_operators(request): + """ + Fixture for dunder names for common compare operations + + * >= + * > + * == + * != + * < + * <= + """ + return request.param + + @pytest.fixture(params=[None, 'gzip', 'bz2', 'zip', pytest.param('xz', marks=td.skip_if_no_lzma)]) def compression(request): @@ -320,20 +337,3 @@ def mock(): return importlib.import_module("unittest.mock") else: return pytest.importorskip("mock") - - -@pytest.fixture(params=['__eq__', '__ne__', '__le__', - '__lt__', '__ge__', '__gt__']) -def all_compare_operators(request): - """ - Fixture for dunder names for common compare operations - - * >= - * > - * == - * != - * < - * <= - """ - - return request.param diff --git a/pandas/core/arrays/__init__.py b/pandas/core/arrays/__init__.py index f8adcf520c15b..f57348116c195 100644 --- a/pandas/core/arrays/__init__.py +++ b/pandas/core/arrays/__init__.py @@ -1,2 +1,3 @@ -from .base import ExtensionArray # noqa +from .base import (ExtensionArray, # noqa + ExtensionScalarOpsMixin) from .categorical import Categorical # noqa diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 30949ca6d1d6b..a572fff1c44d7 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -7,8 +7,13 @@ """ import numpy as np +import operator + from pandas.errors import AbstractMethodError from pandas.compat.numpy import function as nv +from pandas.compat import set_function_name, PY3 +from pandas.core.dtypes.common import is_list_like +from pandas.core import ops _not_implemented_message = "{} does not implement {}." @@ -610,3 +615,125 @@ def _ndarray_values(self): used for interacting with our indexers. """ return np.array(self) + + +class ExtensionOpsMixin(object): + """ + A base class for linking the operators to their dunder names + """ + @classmethod + def _add_arithmetic_ops(cls): + cls.__add__ = cls._create_arithmetic_method(operator.add) + cls.__radd__ = cls._create_arithmetic_method(ops.radd) + cls.__sub__ = cls._create_arithmetic_method(operator.sub) + cls.__rsub__ = cls._create_arithmetic_method(ops.rsub) + cls.__mul__ = cls._create_arithmetic_method(operator.mul) + cls.__rmul__ = cls._create_arithmetic_method(ops.rmul) + cls.__pow__ = cls._create_arithmetic_method(operator.pow) + cls.__rpow__ = cls._create_arithmetic_method(ops.rpow) + cls.__mod__ = cls._create_arithmetic_method(operator.mod) + cls.__rmod__ = cls._create_arithmetic_method(ops.rmod) + cls.__floordiv__ = cls._create_arithmetic_method(operator.floordiv) + cls.__rfloordiv__ = cls._create_arithmetic_method(ops.rfloordiv) + cls.__truediv__ = cls._create_arithmetic_method(operator.truediv) + cls.__rtruediv__ = cls._create_arithmetic_method(ops.rtruediv) + if not PY3: + cls.__div__ = cls._create_arithmetic_method(operator.div) + cls.__rdiv__ = cls._create_arithmetic_method(ops.rdiv) + + cls.__divmod__ = cls._create_arithmetic_method(divmod) + cls.__rdivmod__ = cls._create_arithmetic_method(ops.rdivmod) + + @classmethod + def _add_comparison_ops(cls): + cls.__eq__ = cls._create_comparison_method(operator.eq) + cls.__ne__ = cls._create_comparison_method(operator.ne) + cls.__lt__ = cls._create_comparison_method(operator.lt) + cls.__gt__ = cls._create_comparison_method(operator.gt) + cls.__le__ = cls._create_comparison_method(operator.le) + cls.__ge__ = cls._create_comparison_method(operator.ge) + + +class ExtensionScalarOpsMixin(ExtensionOpsMixin): + """A mixin for defining the arithmetic and logical operations on + an ExtensionArray class, where it is assumed that the underlying objects + have the operators already defined. + + Usage + ------ + If you have defined a subclass MyExtensionArray(ExtensionArray), then + use MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin) to + get the arithmetic operators. After the definition of MyExtensionArray, + insert the lines + + MyExtensionArray._add_arithmetic_ops() + MyExtensionArray._add_comparison_ops() + + to link the operators to your class. + """ + + @classmethod + def _create_method(cls, op, coerce_to_dtype=True): + """ + A class method that returns a method that will correspond to an + operator for an ExtensionArray subclass, by dispatching to the + relevant operator defined on the individual elements of the + ExtensionArray. + + Parameters + ---------- + op : function + An operator that takes arguments op(a, b) + coerce_to_dtype : bool + boolean indicating whether to attempt to convert + the result to the underlying ExtensionArray dtype + (default True) + + Returns + ------- + A method that can be bound to a method of a class + + Example + ------- + Given an ExtensionArray subclass called MyExtensionArray, use + + >>> __add__ = cls._create_method(operator.add) + + in the class definition of MyExtensionArray to create the operator + for addition, that will be based on the operator implementation + of the underlying elements of the ExtensionArray + + """ + + def _binop(self, other): + def convert_values(param): + if isinstance(param, ExtensionArray) or is_list_like(param): + ovalues = param + else: # Assume its an object + ovalues = [param] * len(self) + return ovalues + lvalues = self + rvalues = convert_values(other) + + # If the operator is not defined for the underlying objects, + # a TypeError should be raised + res = [op(a, b) for (a, b) in zip(lvalues, rvalues)] + + if coerce_to_dtype: + try: + res = self._from_sequence(res) + except TypeError: + pass + + return res + + op_name = ops._get_op_name(op, True) + return set_function_name(_binop, op_name, cls) + + @classmethod + def _create_arithmetic_method(cls, op): + return cls._create_method(op) + + @classmethod + def _create_comparison_method(cls, op): + return cls._create_method(op, coerce_to_dtype=False) diff --git a/pandas/core/ops.py b/pandas/core/ops.py index 540ebeee438f6..fa6d88648cc63 100644 --- a/pandas/core/ops.py +++ b/pandas/core/ops.py @@ -33,6 +33,7 @@ is_bool_dtype, is_list_like, is_scalar, + is_extension_array_dtype, _ensure_object) from pandas.core.dtypes.cast import ( maybe_upcast_putmask, find_common_type, @@ -993,6 +994,26 @@ def _construct_divmod_result(left, result, index, name, dtype): ) +def dispatch_to_extension_op(op, left, right): + """ + Assume that left or right is a Series backed by an ExtensionArray, + apply the operator defined by op. + """ + + # The op calls will raise TypeError if the op is not defined + # on the ExtensionArray + if is_extension_array_dtype(left): + res_values = op(left.values, right) + else: + # We know that left is not ExtensionArray and is Series and right is + # ExtensionArray. Want to force ExtensionArray op to get called + res_values = op(list(left.values), right.values) + + res_name = get_op_result_name(left, right) + return left._constructor(res_values, index=left.index, + name=res_name) + + def _arith_method_SERIES(cls, op, special): """ Wrapper function for Series arithmetic operations, to avoid @@ -1061,6 +1082,11 @@ def wrapper(left, right): raise TypeError("{typ} cannot perform the operation " "{op}".format(typ=type(left).__name__, op=str_rep)) + elif (is_extension_array_dtype(left) or + (is_extension_array_dtype(right) and + not is_categorical_dtype(right))): + return dispatch_to_extension_op(op, left, right) + lvalues = left.values rvalues = right if isinstance(rvalues, ABCSeries): @@ -1238,6 +1264,11 @@ def wrapper(self, other, axis=None): return self._constructor(res_values, index=self.index, name=res_name) + elif (is_extension_array_dtype(self) or + (is_extension_array_dtype(other) and + not is_categorical_dtype(other))): + return dispatch_to_extension_op(op, self, other) + elif isinstance(other, ABCSeries): # By this point we have checked that self._indexed_same(other) res_values = na_op(self.values, other.values) diff --git a/pandas/tests/extension/base/__init__.py b/pandas/tests/extension/base/__init__.py index 9da985625c4ee..640b894e2245f 100644 --- a/pandas/tests/extension/base/__init__.py +++ b/pandas/tests/extension/base/__init__.py @@ -47,6 +47,7 @@ class TestMyDtype(BaseDtypeTests): from .groupby import BaseGroupbyTests # noqa from .interface import BaseInterfaceTests # noqa from .methods import BaseMethodsTests # noqa +from .ops import BaseArithmeticOpsTests, BaseComparisonOpsTests # noqa from .missing import BaseMissingTests # noqa from .reshaping import BaseReshapingTests # noqa from .setitem import BaseSetitemTests # noqa diff --git a/pandas/tests/extension/base/ops.py b/pandas/tests/extension/base/ops.py new file mode 100644 index 0000000000000..659b9757ac1e3 --- /dev/null +++ b/pandas/tests/extension/base/ops.py @@ -0,0 +1,94 @@ +import pytest + +import operator + +import pandas as pd +from .base import BaseExtensionTests + + +class BaseOpsUtil(BaseExtensionTests): + def get_op_from_name(self, op_name): + short_opname = op_name.strip('_') + try: + op = getattr(operator, short_opname) + except AttributeError: + # Assume it is the reverse operator + rop = getattr(operator, short_opname[1:]) + op = lambda x, y: rop(y, x) + + return op + + def check_opname(self, s, op_name, other, exc=NotImplementedError): + op = self.get_op_from_name(op_name) + + self._check_op(s, op, other, exc) + + def _check_op(self, s, op, other, exc=NotImplementedError): + if exc is None: + result = op(s, other) + expected = s.combine(other, op) + self.assert_series_equal(result, expected) + else: + with pytest.raises(exc): + op(s, other) + + +class BaseArithmeticOpsTests(BaseOpsUtil): + """Various Series and DataFrame arithmetic ops methods.""" + + def test_arith_scalar(self, data, all_arithmetic_operators): + # scalar + op_name = all_arithmetic_operators + s = pd.Series(data) + self.check_opname(s, op_name, s.iloc[0], exc=TypeError) + + def test_arith_array(self, data, all_arithmetic_operators): + # ndarray & other series + op_name = all_arithmetic_operators + s = pd.Series(data) + self.check_opname(s, op_name, [s.iloc[0]] * len(s), exc=TypeError) + + def test_divmod(self, data): + s = pd.Series(data) + self._check_op(s, divmod, 1, exc=TypeError) + self._check_op(1, divmod, s, exc=TypeError) + + def test_error(self, data, all_arithmetic_operators): + # invalid ops + op_name = all_arithmetic_operators + with pytest.raises(AttributeError): + getattr(data, op_name) + + +class BaseComparisonOpsTests(BaseOpsUtil): + """Various Series and DataFrame comparison ops methods.""" + + def _compare_other(self, s, data, op_name, other): + op = self.get_op_from_name(op_name) + if op_name == '__eq__': + assert getattr(data, op_name)(other) is NotImplemented + assert not op(s, other).all() + elif op_name == '__ne__': + assert getattr(data, op_name)(other) is NotImplemented + assert op(s, other).all() + + else: + + # array + assert getattr(data, op_name)(other) is NotImplemented + + # series + s = pd.Series(data) + with pytest.raises(TypeError): + op(s, other) + + def test_compare_scalar(self, data, all_compare_operators): + op_name = all_compare_operators + s = pd.Series(data) + self._compare_other(s, data, op_name, 0) + + def test_compare_array(self, data, all_compare_operators): + op_name = all_compare_operators + s = pd.Series(data) + other = [0] * len(data) + self._compare_other(s, data, op_name, other) diff --git a/pandas/tests/extension/category/test_categorical.py b/pandas/tests/extension/category/test_categorical.py index 61fdb8454b542..ae0d72c204d13 100644 --- a/pandas/tests/extension/category/test_categorical.py +++ b/pandas/tests/extension/category/test_categorical.py @@ -183,3 +183,29 @@ def test_combine_add(self, data_repeated): class TestCasting(base.BaseCastingTests): pass + + +class TestArithmeticOps(base.BaseArithmeticOpsTests): + + def test_arith_scalar(self, data, all_arithmetic_operators): + + op_name = all_arithmetic_operators + if op_name != '__rmod__': + super(TestArithmeticOps, self).test_arith_scalar(data, op_name) + else: + pytest.skip('rmod never called when string is first argument') + + +class TestComparisonOps(base.BaseComparisonOpsTests): + + def _compare_other(self, s, data, op_name, other): + op = self.get_op_from_name(op_name) + if op_name == '__eq__': + assert not op(data, other).all() + + elif op_name == '__ne__': + assert op(data, other).all() + + else: + with pytest.raises(TypeError): + op(data, other) diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index cc6fadc483d5e..3f2f24cd26af0 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -6,7 +6,8 @@ import numpy as np import pandas as pd -from pandas.core.arrays import ExtensionArray +from pandas.core.arrays import (ExtensionArray, + ExtensionScalarOpsMixin) from pandas.core.dtypes.base import ExtensionDtype @@ -24,13 +25,14 @@ def construct_from_string(cls, string): "'{}'".format(cls, string)) -class DecimalArray(ExtensionArray): +class DecimalArray(ExtensionArray, ExtensionScalarOpsMixin): dtype = DecimalDtype() def __init__(self, values): for val in values: if not isinstance(val, self.dtype.type): - raise TypeError + raise TypeError("All values must be of type " + + str(self.dtype.type)) values = np.asarray(values, dtype=object) self._data = values @@ -103,5 +105,9 @@ def _concat_same_type(cls, to_concat): return cls(np.concatenate([x._data for x in to_concat])) +DecimalArray._add_arithmetic_ops() +DecimalArray._add_comparison_ops() + + def make_data(): return [decimal.Decimal(random.random()) for _ in range(100)] diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index f74b4d7e94f11..45ee7f227c4f0 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -191,3 +191,64 @@ def test_dataframe_constructor_with_different_dtype_raises(): xpr = "Cannot coerce extension array to dtype 'int64'. " with tm.assert_raises_regex(ValueError, xpr): pd.DataFrame({"A": arr}, dtype='int64') + + +class TestArithmeticOps(BaseDecimal, base.BaseArithmeticOpsTests): + + def check_opname(self, s, op_name, other, exc=None): + super(TestArithmeticOps, self).check_opname(s, op_name, + other, exc=None) + + def test_arith_array(self, data, all_arithmetic_operators): + op_name = all_arithmetic_operators + s = pd.Series(data) + + context = decimal.getcontext() + divbyzerotrap = context.traps[decimal.DivisionByZero] + invalidoptrap = context.traps[decimal.InvalidOperation] + context.traps[decimal.DivisionByZero] = 0 + context.traps[decimal.InvalidOperation] = 0 + + # Decimal supports ops with int, but not float + other = pd.Series([int(d * 100) for d in data]) + self.check_opname(s, op_name, other) + + if "mod" not in op_name: + self.check_opname(s, op_name, s * 2) + + self.check_opname(s, op_name, 0) + self.check_opname(s, op_name, 5) + context.traps[decimal.DivisionByZero] = divbyzerotrap + context.traps[decimal.InvalidOperation] = invalidoptrap + + @pytest.mark.skip(reason="divmod not appropriate for decimal") + def test_divmod(self, data): + pass + + def test_error(self): + pass + + +class TestComparisonOps(BaseDecimal, base.BaseComparisonOpsTests): + + def check_opname(self, s, op_name, other, exc=None): + super(TestComparisonOps, self).check_opname(s, op_name, + other, exc=None) + + def _compare_other(self, s, data, op_name, other): + self.check_opname(s, op_name, other) + + def test_compare_scalar(self, data, all_compare_operators): + op_name = all_compare_operators + s = pd.Series(data) + self._compare_other(s, data, op_name, 0.5) + + def test_compare_array(self, data, all_compare_operators): + op_name = all_compare_operators + s = pd.Series(data) + + alter = np.random.choice([-1, 0, 1], len(data)) + # Randomly double, halve or keep same value + other = pd.Series(data) * [decimal.Decimal(pow(2.0, i)) + for i in alter] + self._compare_other(s, data, op_name, other) diff --git a/pandas/tests/extension/json/array.py b/pandas/tests/extension/json/array.py index 10be7836cb8d7..d3043bf0852d2 100644 --- a/pandas/tests/extension/json/array.py +++ b/pandas/tests/extension/json/array.py @@ -47,7 +47,8 @@ class JSONArray(ExtensionArray): def __init__(self, values): for val in values: if not isinstance(val, self.dtype.type): - raise TypeError + raise TypeError("All values must be of type " + + str(self.dtype.type)) self.data = values # Some aliases for common attribute names to ensure pandas supports diff --git a/pandas/tests/extension/json/test_json.py b/pandas/tests/extension/json/test_json.py index 85a282ae4007f..268134dc8c333 100644 --- a/pandas/tests/extension/json/test_json.py +++ b/pandas/tests/extension/json/test_json.py @@ -238,3 +238,12 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping): super(TestGroupby, self).test_groupby_extension_agg( as_index, data_for_grouping ) + + +class TestArithmeticOps(BaseJSON, base.BaseArithmeticOpsTests): + def test_error(self, data, all_arithmetic_operators): + pass + + +class TestComparisonOps(BaseJSON, base.BaseComparisonOpsTests): + pass diff --git a/pandas/util/testing.py b/pandas/util/testing.py index a5afcb6915034..11e9942079aad 100644 --- a/pandas/util/testing.py +++ b/pandas/util/testing.py @@ -29,7 +29,8 @@ is_categorical_dtype, is_interval_dtype, is_sequence, - is_list_like) + is_list_like, + is_extension_array_dtype) from pandas.io.formats.printing import pprint_thing from pandas.core.algorithms import take_1d import pandas.core.common as com @@ -1243,6 +1244,10 @@ def assert_series_equal(left, right, check_dtype=True, right = pd.IntervalIndex(right) assert_index_equal(left, right, obj='{obj}.index'.format(obj=obj)) + elif (is_extension_array_dtype(left) and not is_categorical_dtype(left) and + is_extension_array_dtype(right) and not is_categorical_dtype(right)): + return assert_extension_array_equal(left.values, right.values) + else: _testing.assert_almost_equal(left.get_values(), right.get_values(), check_less_precise=check_less_precise,