From 5b0ebc7eba0cd7ec632f266ff7af3202226d16ad Mon Sep 17 00:00:00 2001 From: Dr-Irv Date: Wed, 30 May 2018 14:25:39 -0400 Subject: [PATCH 01/12] ENH: Support ExtensionArray operators via a mixin --- doc/source/whatsnew/v0.24.0.txt | 2 +- pandas/core/arrays/__init__.py | 2 +- pandas/core/arrays/base.py | 93 +++++++++++++++++++ pandas/core/indexes/base.py | 10 +- pandas/core/ops.py | 21 +++++ pandas/core/series.py | 27 ++++-- pandas/tests/extension/base/getitem.py | 2 +- pandas/tests/extension/decimal/array.py | 8 +- .../tests/extension/decimal/test_decimal.py | 36 +++++++ pandas/tests/series/test_operators.py | 6 +- pandas/util/testing.py | 7 +- 11 files changed, 192 insertions(+), 22 deletions(-) diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt index e931450cb5c01..b0f7de6354a3a 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -13,7 +13,7 @@ New features Other Enhancements ^^^^^^^^^^^^^^^^^^ - :func:`to_datetime` now supports the ``%Z`` and ``%z`` directive when passed into ``format`` (:issue:`13486`) -- +- ``ExtensionArray`` has a ``ExtensionOpsMixin`` factory that allows default operators to be defined (:issue:`20659`, :issue:`19577`) - .. 
_whatsnew_0240.api_breaking: diff --git a/pandas/core/arrays/__init__.py b/pandas/core/arrays/__init__.py index f8adcf520c15b..693fe5d4cc0d1 100644 --- a/pandas/core/arrays/__init__.py +++ b/pandas/core/arrays/__init__.py @@ -1,2 +1,2 @@ -from .base import ExtensionArray # noqa +from .base import ExtensionArray, ExtensionOpsMixin # noqa from .categorical import Categorical # noqa diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 1922801c30719..c57819c186661 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -9,6 +9,11 @@ from pandas.errors import AbstractMethodError from pandas.compat.numpy import function as nv +from pandas.compat import set_function_name, PY3 +import pandas.core.common as com +from pandas.core.dtypes.common import ( + is_extension_array_dtype, + is_list_like) _not_implemented_message = "{} does not implement {}." @@ -610,3 +615,91 @@ def _ndarray_values(self): used for interacting with our indexers. """ return np.array(self) + + +def ExtensionOpsMixin(include_arith_ops, include_logic_ops): + """A mixin factory for creating default arithmetic and logical operators, + which are based on the underlying dtype backing the ExtensionArray + + Parameters + ---------- + include_arith_ops : boolean indicating whether arithmetic ops should be + created + include_logic_ops : boolean indicating whether logical ops should be + created + + Returns + ------- + A mixin class that has the associated operators defined. 
+ + Usage + ------ + If you have defined a subclass MyClass(ExtensionArray), then + use MyClass(ExtensionArray, ExtensionOpsMixin(True, True)) to + get both the arithmetic and logical operators + """ + class _ExtensionOpsMixin: + pass + + def create_method(op_name): + def _binop(self, other): + def convert_values(parm): + if isinstance(parm, ExtensionArray): + ovalues = list(parm) + elif is_extension_array_dtype(parm): + ovalues = parm.values + elif is_list_like(parm): + ovalues = parm + else: # Assume its an object + ovalues = [parm] * len(self) + return ovalues + lvalues = convert_values(self) + rvalues = convert_values(other) + + # Get the method for each object. + def callfunc(a, b): + f = getattr(a, op_name, None) + if f is not None: + return f(b) + else: + return NotImplemented + res = [callfunc(a, b) for (a, b) in zip(lvalues, rvalues)] + + # We can't use (NotImplemented in res) because the + # results might be objects that have overridden __eq__ + if any(isinstance(r, type(NotImplemented)) for r in res): + msg = "invalid operation {opn} between {one} and {two}" + raise TypeError(msg.format(opn=op_name, + one=type(lvalues), + two=type(rvalues))) + + res_values = com._values_from_object(res) + + try: + res_values = self._from_sequence(res_values) + except TypeError: + pass + + return res_values + + name = '__{name}__'.format(name=op_name) + return set_function_name(_binop, name, _ExtensionOpsMixin) + + if include_arith_ops: + arithops = ['__add__', '__radd__', '__sub__', '__rsub__', '__mul__', + '__rmul__', '__pow__', '__rpow__', '__mod__', '__rmod__', + '__floordiv__', '__rfloordiv__', '__truediv__', + '__rtruediv__', '__divmod__', '__rdivmod__'] + if not PY3: + arithops.extend(['__div__', '__rdiv__']) + + for op_name in arithops: + setattr(_ExtensionOpsMixin, op_name, create_method(op_name)) + + if include_logic_ops: + logicops = ['__eq__', '__ne__', '__lt__', '__gt__', + '__le__', '__ge__'] + for op_name in logicops: + setattr(_ExtensionOpsMixin, op_name, 
create_method(op_name)) + + return _ExtensionOpsMixin diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 229624c7e6645..1d54017c2357c 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -2972,16 +2972,20 @@ def get_value(self, series, key): # use this, e.g. DatetimeIndex s = getattr(series, '_values', None) if isinstance(s, (ExtensionArray, Index)) and is_scalar(key): - # GH 20825 + # GH 20882, 21257 # Unify Index and ExtensionArray treatment # First try to convert the key to a location - # If that fails, see if key is an integer, and + # If that fails, raise a KeyError if an integer + # index, otherwise, see if key is an integer, and # try that try: iloc = self.get_loc(key) return s[iloc] except KeyError: - if is_integer(key): + if (len(self) > 0 and + self.inferred_type in ['integer', 'boolean']): + raise + elif is_integer(key): return s[key] s = com._values_from_object(series) diff --git a/pandas/core/ops.py b/pandas/core/ops.py index e14f82906cd06..a384d31f62246 100644 --- a/pandas/core/ops.py +++ b/pandas/core/ops.py @@ -30,6 +30,7 @@ is_bool_dtype, is_list_like, is_scalar, + is_extension_array_dtype, _ensure_object) from pandas.core.dtypes.cast import ( maybe_upcast_putmask, find_common_type, @@ -990,6 +991,20 @@ def _construct_divmod_result(left, result, index, name, dtype): ) +def dispatch_to_extension_op(left, right, op_name): + """ + Assume that left is a Series backed by an ExtensionArray, + apply the operator defined by op_name. 
+ """ + + method = getattr(left.values, op_name, None) + res_values = method(right) + + res_name = get_op_result_name(left, right) + return left._constructor(res_values, index=left.index, + name=res_name) + + def _arith_method_SERIES(cls, op, special): """ Wrapper function for Series arithmetic operations, to avoid @@ -1058,6 +1073,9 @@ def wrapper(left, right): raise TypeError("{typ} cannot perform the operation " "{op}".format(typ=type(left).__name__, op=str_rep)) + elif is_extension_array_dtype(left): + return dispatch_to_extension_op(left, right, op_name) + lvalues = left.values rvalues = right if isinstance(rvalues, ABCSeries): @@ -1208,6 +1226,9 @@ def wrapper(self, other, axis=None): return self._constructor(res_values, index=self.index, name=res_name) + elif is_extension_array_dtype(self): + return dispatch_to_extension_op(self, other, op_name) + elif isinstance(other, ABCSeries): # By this point we have checked that self._indexed_same(other) res_values = na_op(self.values, other.values) diff --git a/pandas/core/series.py b/pandas/core/series.py index c9329e8b9e572..8e1e3640639f0 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -2196,23 +2196,22 @@ def _binop(self, other, func, level=None, fill_value=None): result.name = None return result - def combine(self, other, func, fill_value=np.nan): + def combine(self, other, func, fill_value=None): """ Perform elementwise binary operation on two Series using given function with optional fill value when an index is missing from one Series or the other - Parameters ---------- other : Series or scalar value func : function Function that takes two scalars as inputs and return a scalar fill_value : scalar value - + The default specifies to use the appropriate NaN value for + the underlying dtype of the Series Returns ------- result : Series - Examples -------- >>> s1 = Series([1, 2]) @@ -2221,26 +2220,36 @@ def combine(self, other, func, fill_value=np.nan): 0 0 1 2 dtype: int64 - See Also -------- 
Series.combine_first : Combine Series values, choosing the calling Series's values first """ + self_is_ext = is_extension_array_dtype(self.values) + if fill_value is None: + fill_value = na_value_for_dtype(self.dtype, False) + if isinstance(other, Series): new_index = self.index.union(other.index) new_name = ops.get_op_result_name(self, other) - new_values = np.empty(len(new_index), dtype=self.dtype) - for i, idx in enumerate(new_index): + new_values = [] + for idx in new_index: lv = self.get(idx, fill_value) rv = other.get(idx, fill_value) with np.errstate(all='ignore'): - new_values[i] = func(lv, rv) + new_values.append(func(lv, rv)) else: new_index = self.index with np.errstate(all='ignore'): - new_values = func(self._values, other) + new_values = [func(lv, other) for lv in self._values] new_name = self.name + + if self_is_ext and not is_categorical_dtype(self.values): + try: + new_values = self._values._from_sequence(new_values) + except TypeError: + pass + return self._constructor(new_values, index=new_index, name=new_name) def combine_first(self, other): diff --git a/pandas/tests/extension/base/getitem.py b/pandas/tests/extension/base/getitem.py index 883b3f5588aef..390971c134642 100644 --- a/pandas/tests/extension/base/getitem.py +++ b/pandas/tests/extension/base/getitem.py @@ -130,7 +130,7 @@ def test_get(self, data): expected = s.iloc[[0, 1]] self.assert_series_equal(result, expected) - assert s.get(-1) == s.iloc[-1] + assert s.get(-1) is None assert s.get(s.index.max() + 1) is None s = pd.Series(data[:6], index=list('abcdef')) diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index 90f0181beab0d..83143f0dde455 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -6,7 +6,7 @@ import numpy as np import pandas as pd -from pandas.core.arrays import ExtensionArray +from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin from pandas.core.dtypes.base import 
ExtensionDtype @@ -24,11 +24,13 @@ def construct_from_string(cls, string): "'{}'".format(cls, string)) -class DecimalArray(ExtensionArray): +class DecimalArray(ExtensionArray, ExtensionOpsMixin(True, True)): dtype = DecimalDtype() def __init__(self, values): - assert all(isinstance(v, decimal.Decimal) for v in values) + for val in values: + if not isinstance(val, self.dtype.type): + raise TypeError values = np.asarray(values, dtype=object) self._data = values diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 1f8cf0264f62f..6f6ba5ff6bab3 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -7,6 +7,9 @@ from pandas.tests.extension import base +from pandas.tests.series.test_operators import TestSeriesOperators +from pandas.util._decorators import cache_readonly + from .array import DecimalDtype, DecimalArray, make_data @@ -183,3 +186,36 @@ def test_dataframe_constructor_with_different_dtype_raises(): xpr = "Cannot coerce extension array to dtype 'int64'. 
" with tm.assert_raises_regex(ValueError, xpr): pd.DataFrame({"A": arr}, dtype='int64') + + +_ts = pd.Series(DecimalArray(make_data())) + + +class TestOperator(BaseDecimal, TestSeriesOperators): + @cache_readonly + def ts(self): + ts = _ts.copy() + ts.name = 'ts' + return ts + + def test_operators(self): + def absfunc(v): + if isinstance(v, pd.Series): + vals = v.values + return pd.Series(vals._from_sequence([abs(i) for i in vals])) + else: + return abs(v) + context = decimal.getcontext() + divbyzerotrap = context.traps[decimal.DivisionByZero] + invalidoptrap = context.traps[decimal.InvalidOperation] + context.traps[decimal.DivisionByZero] = 0 + context.traps[decimal.InvalidOperation] = 0 + super(TestOperator, self).test_operators(absfunc) + context.traps[decimal.DivisionByZero] = divbyzerotrap + context.traps[decimal.InvalidOperation] = invalidoptrap + + def test_operators_corner(self): + pytest.skip("Cannot add empty Series of float64 to DecimalArray") + + def test_divmod(self): + pytest.skip("divmod not appropriate for Decimal type") diff --git a/pandas/tests/series/test_operators.py b/pandas/tests/series/test_operators.py index ecb74622edf10..ed0ae64be53ad 100644 --- a/pandas/tests/series/test_operators.py +++ b/pandas/tests/series/test_operators.py @@ -1216,11 +1216,11 @@ def test_neg(self): def test_invert(self): assert_series_equal(-(self.series < 0), ~(self.series < 0)) - def test_operators(self): + def test_operators(self, absfunc=np.abs): def _check_op(series, other, op, pos_only=False, check_dtype=True): - left = np.abs(series) if pos_only else series - right = np.abs(other) if pos_only else other + left = absfunc(series) if pos_only else series + right = absfunc(other) if pos_only else other cython_or_numpy = op(left, right) python = left.combine(right, op) diff --git a/pandas/util/testing.py b/pandas/util/testing.py index 233eba6490937..cd6c54f084594 100644 --- a/pandas/util/testing.py +++ b/pandas/util/testing.py @@ -30,7 +30,8 @@ 
is_categorical_dtype, is_interval_dtype, is_sequence, - is_list_like) + is_list_like, + is_extension_array_dtype) from pandas.io.formats.printing import pprint_thing from pandas.core.algorithms import take_1d import pandas.core.common as com @@ -1225,6 +1226,10 @@ def assert_series_equal(left, right, check_dtype=True, right = pd.IntervalIndex(right) assert_index_equal(left, right, obj='{obj}.index'.format(obj=obj)) + elif (is_extension_array_dtype(left) and not is_categorical_dtype(left) and + is_extension_array_dtype(right) and not is_categorical_dtype(right)): + return assert_extension_array_equal(left.values, right.values) + else: _testing.assert_almost_equal(left.get_values(), right.get_values(), check_less_precise=check_less_precise, From 7f2b0a1df4160dfc7ad269cf0de79289ea0aae10 Mon Sep 17 00:00:00 2001 From: Dr-Irv Date: Thu, 31 May 2018 18:26:42 -0400 Subject: [PATCH 02/12] No longer use factory. Use op instead of op_name. Catch exceptions --- pandas/api/extensions/__init__.py | 5 +- pandas/core/arrays/__init__.py | 5 +- pandas/core/arrays/base.py | 141 ++++++++++++++--------- pandas/core/ops.py | 8 +- pandas/tests/extension/decimal/array.py | 7 +- pandas/tests/extension/json/array.py | 4 +- pandas/tests/extension/json/test_json.py | 18 +++ 7 files changed, 128 insertions(+), 60 deletions(-) diff --git a/pandas/api/extensions/__init__.py b/pandas/api/extensions/__init__.py index 97c33d4d75ef6..c47ec527fb407 100644 --- a/pandas/api/extensions/__init__.py +++ b/pandas/api/extensions/__init__.py @@ -3,5 +3,8 @@ register_index_accessor, register_series_accessor) from pandas.core.algorithms import take # noqa -from pandas.core.arrays.base import ExtensionArray, ExtensionOpsMixin # noqa +from pandas.core.arrays.base import (ExtensionArray, # noqa + ExtensionArithmeticMixin, + ExtensionComparisonMixin, + ExtensionOpsBase) from pandas.core.dtypes.dtypes import ExtensionDtype # noqa diff --git a/pandas/core/arrays/__init__.py b/pandas/core/arrays/__init__.py index 
693fe5d4cc0d1..edd14776a6797 100644 --- a/pandas/core/arrays/__init__.py +++ b/pandas/core/arrays/__init__.py @@ -1,2 +1,5 @@ -from .base import ExtensionArray, ExtensionOpsMixin # noqa +from .base import (ExtensionArray, # noqa + ExtensionArithmeticMixin, + ExtensionComparisonMixin, + ExtensionOpsBase) from .categorical import Categorical # noqa diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 5d99bbebb393a..599e6d22c6e78 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -7,6 +7,8 @@ """ import numpy as np +import operator + from pandas.errors import AbstractMethodError from pandas.compat.numpy import function as nv from pandas.compat import set_function_name, PY3 @@ -14,6 +16,7 @@ from pandas.core.dtypes.common import ( is_extension_array_dtype, is_list_like) +from pandas.core import ops _not_implemented_message = "{} does not implement {}." @@ -617,31 +620,36 @@ def _ndarray_values(self): return np.array(self) -def ExtensionOpsMixin(include_arith_ops, include_logic_ops): - """A mixin factory for creating default arithmetic and logical operators, - which are based on the underlying dtype backing the ExtensionArray +class ExtensionOpsBase(object): + """ + A base class for the mixins for different operators. + Can also be used to define an individual method for a specific + operator using the class method create_method() + """ + @classmethod + def create_method(cls, op): + """ + A class method that returns a method that will correspond to an + operator for an ExtensionArray subclass. 
+ + Parameters + ---------- + op: An operator that takes arguments op(a, b) + + Returns + ------- + A method that can be bound to a method of a class - Parameters - ---------- - include_arith_ops : boolean indicating whether arithmetic ops should be - created - include_logic_ops : boolean indicating whether logical ops should be - created + Usage + ----- + Given an ExtensionArray subclass called MyClass, use - Returns - ------- - A mixin class that has the associated operators defined. + mymethod = create_method(my_operator) + in the class definition of MyClass to create the operator - Usage - ------ - If you have defined a subclass MyClass(ExtensionArray), then - use MyClass(ExtensionArray, ExtensionOpsMixin(True, True)) to - get both the arithmetic and logical operators - """ - class _ExtensionOpsMixin(object): - pass + """ + op_name = ops._get_op_name(op, False) - def create_method(op_name): def _binop(self, other): def convert_values(parm): if isinstance(parm, ExtensionArray): @@ -656,19 +664,11 @@ def convert_values(parm): lvalues = convert_values(self) rvalues = convert_values(other) - # Get the method for each object. 
- def callfunc(a, b): - f = getattr(a, op_name, None) - if f is not None: - return f(b) - else: - return NotImplemented - res = [callfunc(a, b) for (a, b) in zip(lvalues, rvalues)] - - # We can't use (NotImplemented in res) because the - # results might be objects that have overridden __eq__ - if any(isinstance(r, type(NotImplemented)) for r in res): - msg = "invalid operation {opn} between {one} and {two}" + try: + res = [op(a, b) for (a, b) in zip(lvalues, rvalues)] + except TypeError: + msg = ("ExtensionDtype invalid operation " + + "{opn} between {one} and {two}") raise TypeError(msg.format(opn=op_name, one=type(lvalues), two=type(rvalues))) @@ -683,23 +683,56 @@ def callfunc(a, b): return res_values name = '__{name}__'.format(name=op_name) - return set_function_name(_binop, name, _ExtensionOpsMixin) - - if include_arith_ops: - arithops = ['__add__', '__radd__', '__sub__', '__rsub__', '__mul__', - '__rmul__', '__pow__', '__rpow__', '__mod__', '__rmod__', - '__floordiv__', '__rfloordiv__', '__truediv__', - '__rtruediv__', '__divmod__', '__rdivmod__'] - if not PY3: - arithops.extend(['__div__', '__rdiv__']) - - for op_name in arithops: - setattr(_ExtensionOpsMixin, op_name, create_method(op_name)) - - if include_logic_ops: - logicops = ['__eq__', '__ne__', '__lt__', '__gt__', - '__le__', '__ge__'] - for op_name in logicops: - setattr(_ExtensionOpsMixin, op_name, create_method(op_name)) - - return _ExtensionOpsMixin + return set_function_name(_binop, name, cls) + + +class ExtensionArithmeticMixin(ExtensionOpsBase): + """A mixin for defining the arithmetic operations on an ExtensionArray + class, where it assumed that the underlying objects have the operators + already defined. 
+ + Usage + ------ + If you have defined a subclass MyClass(ExtensionArray), then + use MyClass(ExtensionArray, ExtensionArithmeticMixin) to + get the arithmetic operators + """ + + __add__ = ExtensionOpsBase.create_method(operator.add) + __radd__ = ExtensionOpsBase.create_method(ops.radd) + __sub__ = ExtensionOpsBase.create_method(operator.sub) + __rsub__ = ExtensionOpsBase.create_method(ops.rsub) + __mul__ = ExtensionOpsBase.create_method(operator.mul) + __rmul__ = ExtensionOpsBase.create_method(ops.rmul) + __pow__ = ExtensionOpsBase.create_method(operator.pow) + __rpow__ = ExtensionOpsBase.create_method(ops.rpow) + __mod__ = ExtensionOpsBase.create_method(operator.mod) + __rmod__ = ExtensionOpsBase.create_method(ops.rmod) + __floordiv__ = ExtensionOpsBase.create_method(operator.floordiv) + __rfloordiv__ = ExtensionOpsBase.create_method(ops.rfloordiv) + __truediv__ = ExtensionOpsBase.create_method(operator.truediv) + __rtruediv__ = ExtensionOpsBase.create_method(ops.rtruediv) + if not PY3: + __div__ = ExtensionOpsBase.create_method(operator.div) + __rdiv__ = ExtensionOpsBase.create_method(ops.rdiv) + + __divmod__ = ExtensionOpsBase.create_method(divmod) + + +class ExtensionComparisonMixin(ExtensionOpsBase): + """A mixin for defining the comparison operations on an ExtensionArray + class, where it assumed that the underlying objects have the operators + already defined. 
+ + Usage + ------ + If you have defined a subclass MyClass(ExtensionArray), then + use MyClass(ExtensionArray, ExtensionComparisonMixin) to + get the arithmetic operators + """ + __eq__ = ExtensionOpsBase.create_method(operator.eq) + __ne__ = ExtensionOpsBase.create_method(operator.ne) + __lt__ = ExtensionOpsBase.create_method(operator.lt) + __gt__ = ExtensionOpsBase.create_method(operator.gt) + __le__ = ExtensionOpsBase.create_method(operator.le) + __ge__ = ExtensionOpsBase.create_method(operator.ge) diff --git a/pandas/core/ops.py b/pandas/core/ops.py index a384d31f62246..898a4f129db39 100644 --- a/pandas/core/ops.py +++ b/pandas/core/ops.py @@ -998,7 +998,13 @@ def dispatch_to_extension_op(left, right, op_name): """ method = getattr(left.values, op_name, None) - res_values = method(right) + if method is not None: + res_values = method(right) + if method is None or res_values is NotImplemented: + msg = "ExtensionArray invalid operation {opn} between {one} and {two}" + raise TypeError(msg.format(opn=op_name, + one=type(left.values), + two=type(right))) res_name = get_op_result_name(left, right) return left._constructor(res_values, index=left.index, diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index 83143f0dde455..59db8439b8747 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -6,7 +6,9 @@ import numpy as np import pandas as pd -from pandas.core.arrays import ExtensionArray, ExtensionOpsMixin +from pandas.core.arrays import (ExtensionArray, + ExtensionArithmeticMixin, + ExtensionComparisonMixin) from pandas.core.dtypes.base import ExtensionDtype @@ -24,7 +26,8 @@ def construct_from_string(cls, string): "'{}'".format(cls, string)) -class DecimalArray(ExtensionArray, ExtensionOpsMixin(True, True)): +class DecimalArray(ExtensionArray, ExtensionArithmeticMixin, + ExtensionComparisonMixin): dtype = DecimalDtype() def __init__(self, values): diff --git 
a/pandas/tests/extension/json/array.py b/pandas/tests/extension/json/array.py index 10be7836cb8d7..06139d5c21e8c 100644 --- a/pandas/tests/extension/json/array.py +++ b/pandas/tests/extension/json/array.py @@ -16,11 +16,12 @@ import random import string import sys +import operator import numpy as np from pandas.core.dtypes.base import ExtensionDtype -from pandas.core.arrays import ExtensionArray +from pandas.core.arrays import ExtensionArray, ExtensionOpsBase class JSONDtype(ExtensionDtype): @@ -43,6 +44,7 @@ def construct_from_string(cls, string): class JSONArray(ExtensionArray): dtype = JSONDtype() + __le__ = ExtensionOpsBase.create_method(operator.le) def __init__(self, values): for val in values: diff --git a/pandas/tests/extension/json/test_json.py b/pandas/tests/extension/json/test_json.py index b7ac8033f3f6d..b61004b705d66 100644 --- a/pandas/tests/extension/json/test_json.py +++ b/pandas/tests/extension/json/test_json.py @@ -230,3 +230,21 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping): super(TestGroupby, self).test_groupby_extension_agg( as_index, data_for_grouping ) + + +def test_ops(data): + s1 = pd.Series(data) + s2 = pd.Series(data) + # Here we test if the mixin method was defined but the underlying Dtype + # did not have the method defined + with tm.assert_raises_regex(TypeError, "ExtensionDtype invalid operation"): + (s1 <= s2) + + # An object will always have __lt__ defined, so test if we catch that it + # was not implemented + with tm.assert_raises_regex(TypeError, "ExtensionArray invalid operation"): + (s1 < s2) + + # Test that if method is not defined at all, we catch that as well + with tm.assert_raises_regex(TypeError, "ExtensionArray invalid operation"): + s1 + s2 From ec96841d0c4b9a2bb8a2c52c96de10d3696a429c Mon Sep 17 00:00:00 2001 From: Dr-Irv Date: Fri, 1 Jun 2018 17:56:24 -0400 Subject: [PATCH 03/12] Numerous changes based on PR feedback --- doc/source/whatsnew/v0.24.0.txt | 24 +++- pandas/api/extensions/__init__.py | 
6 +- pandas/conftest.py | 12 +- pandas/core/arrays/__init__.py | 6 +- pandas/core/arrays/base.py | 110 +++++++++--------- pandas/core/ops.py | 18 +-- pandas/tests/extension/base/__init__.py | 1 + pandas/tests/extension/base/ops.py | 66 +++++++++++ .../extension/category/test_categorical.py | 15 +++ pandas/tests/extension/decimal/array.py | 11 +- .../tests/extension/decimal/test_decimal.py | 70 +++++++---- pandas/tests/extension/json/array.py | 7 +- pandas/tests/extension/json/test_json.py | 18 +-- 13 files changed, 242 insertions(+), 122 deletions(-) create mode 100644 pandas/tests/extension/base/ops.py diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt index abfd3a5c42375..8f5aa7605074e 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -10,11 +10,31 @@ New features .. _whatsnew_0240.enhancements.other: +``ExtensionArray`` operator support +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +An ``ExtensionArray`` subclass consisting of objects that have arithmetic and +comparison operators defined on the underlying objects can easily support +those operators on the ``ExtensionArray``, and therefore the operators +on ``Series`` built on those ``ExtensionArray`` classes will work as expected. + +Two new mixin classes, :class:`ExtensionScalarArithmeticMixin` and +:class:`ExtensionScalarComparisonMixin`, support this capability. +If developing an ``ExtensionArray`` subclass, for example ``MyExtensionArray``, +simply include ``ExtensionScalarArithmeticMixin`` and/or +``ExtensionScalarComparisonMixin`` as parent classes of ``MyExtensionArray`` +as follows: + +.. 
code-block:: python + + class MyExtensionArray(ExtensionArray, ExtensionScalarArithmeticMixin, + ExtensionScalarComparisonMixin): + + Other Enhancements ^^^^^^^^^^^^^^^^^^ - :func:`to_datetime` now supports the ``%Z`` and ``%z`` directive when passed into ``format`` (:issue:`13486`) -- :func:`to_csv` now supports ``compression`` keyword when a file handle is passed. (:issue:`21227`) -- ``ExtensionArray`` has a ``ExtensionOpsMixin`` factory that allows default operators to be defined (:issue:`20659`, :issue:`19577`) +- :func:`to_csv` now supports ``compression`` keyword when a file handle is passed. (:issue:`21227`)- ``ExtensionArray`` has a ``ExtensionOpsMixin`` factory that allows default operators to be defined (:issue:`20659`, :issue:`19577`) - .. _whatsnew_0240.api_breaking: diff --git a/pandas/api/extensions/__init__.py b/pandas/api/extensions/__init__.py index c47ec527fb407..5057c707e91d2 100644 --- a/pandas/api/extensions/__init__.py +++ b/pandas/api/extensions/__init__.py @@ -4,7 +4,7 @@ register_series_accessor) from pandas.core.algorithms import take # noqa from pandas.core.arrays.base import (ExtensionArray, # noqa - ExtensionArithmeticMixin, - ExtensionComparisonMixin, - ExtensionOpsBase) + ExtensionScalarArithmeticMixin, + ExtensionScalarComparisonMixin, + ExtensionScalarOpsMixin) from pandas.core.dtypes.dtypes import ExtensionDtype # noqa diff --git a/pandas/conftest.py b/pandas/conftest.py index a463f573c82e0..f83b796fee2cc 100644 --- a/pandas/conftest.py +++ b/pandas/conftest.py @@ -83,7 +83,8 @@ def observed(request): '__mul__', '__rmul__', '__floordiv__', '__rfloordiv__', '__truediv__', '__rtruediv__', - '__pow__', '__rpow__'] + '__pow__', '__rpow__', + '__mod__', '__rmod__'] if not PY3: _all_arithmetic_operators.extend(['__div__', '__rdiv__']) @@ -96,6 +97,15 @@ def all_arithmetic_operators(request): return request.param +@pytest.fixture(params=['__eq__', '__ne__', '__le__', + '__lt__', '__ge__', '__gt__']) +def all_compare_operators(request): + 
""" + Fixture for dunder names for common compare operations + """ + return request.param + + @pytest.fixture(params=[None, 'gzip', 'bz2', 'zip', pytest.param('xz', marks=td.skip_if_no_lzma)]) def compression(request): diff --git a/pandas/core/arrays/__init__.py b/pandas/core/arrays/__init__.py index edd14776a6797..b37f02bea9ea7 100644 --- a/pandas/core/arrays/__init__.py +++ b/pandas/core/arrays/__init__.py @@ -1,5 +1,5 @@ from .base import (ExtensionArray, # noqa - ExtensionArithmeticMixin, - ExtensionComparisonMixin, - ExtensionOpsBase) + ExtensionScalarArithmeticMixin, + ExtensionScalarComparisonMixin, + ExtensionScalarOpsMixin) from .categorical import Categorical # noqa diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 599e6d22c6e78..c404f87db421f 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -12,7 +12,6 @@ from pandas.errors import AbstractMethodError from pandas.compat.numpy import function as nv from pandas.compat import set_function_name, PY3 -import pandas.core.common as com from pandas.core.dtypes.common import ( is_extension_array_dtype, is_list_like) @@ -620,73 +619,74 @@ def _ndarray_values(self): return np.array(self) -class ExtensionOpsBase(object): +class ExtensionScalarOpsMixin(object): """ A base class for the mixins for different operators. Can also be used to define an individual method for a specific operator using the class method create_method() """ + @classmethod - def create_method(cls, op): + def _create_method(cls, op, coerce_to_dtype=True): """ A class method that returns a method that will correspond to an - operator for an ExtensionArray subclass. + operator for an ExtensionArray subclass, by dispatching to the + relevant operator defined on the individual elements of the + ExtensionArray. 
Parameters ---------- - op: An operator that takes arguments op(a, b) + op: An operator that takes arguments op(a, b) + coerce_to_dtype: boolean indicating whether to attempt to convert + the result to the underlying ExtensionArray dtype + (default True) Returns ------- A method that can be bound to a method of a class - Usage - ----- + Example + ------- Given an ExtensionArray subclass called MyClass, use - mymethod = create_method(my_operator) + >>> __add__ = ExtensionScalarOpsMixin.create_method(operator.add) + in the class definition of MyClass to create the operator + for addition. """ + op_name = ops._get_op_name(op, False) def _binop(self, other): def convert_values(parm): - if isinstance(parm, ExtensionArray): - ovalues = list(parm) + if isinstance(parm, ExtensionArray) or is_list_like(parm): + ovalues = parm elif is_extension_array_dtype(parm): ovalues = parm.values - elif is_list_like(parm): - ovalues = parm else: # Assume its an object ovalues = [parm] * len(self) return ovalues lvalues = convert_values(self) rvalues = convert_values(other) - try: - res = [op(a, b) for (a, b) in zip(lvalues, rvalues)] - except TypeError: - msg = ("ExtensionDtype invalid operation " + - "{opn} between {one} and {two}") - raise TypeError(msg.format(opn=op_name, - one=type(lvalues), - two=type(rvalues))) - - res_values = com._values_from_object(res) + # If the operator is not defined for the underlying objects, + # a TypeError should be raised + res = [op(a, b) for (a, b) in zip(lvalues, rvalues)] - try: - res_values = self._from_sequence(res_values) - except TypeError: - pass + if coerce_to_dtype: + try: + res = self._from_sequence(res) + except TypeError: + pass - return res_values + return res name = '__{name}__'.format(name=op_name) return set_function_name(_binop, name, cls) -class ExtensionArithmeticMixin(ExtensionOpsBase): +class ExtensionScalarArithmeticMixin(ExtensionScalarOpsMixin): """A mixin for defining the arithmetic operations on an ExtensionArray class, 
where it assumed that the underlying objects have the operators already defined. @@ -694,32 +694,33 @@ class ExtensionArithmeticMixin(ExtensionOpsBase): Usage ------ If you have defined a subclass MyClass(ExtensionArray), then - use MyClass(ExtensionArray, ExtensionArithmeticMixin) to + use MyClass(ExtensionArray, ExtensionScalarArithmeticMixin) to get the arithmetic operators """ - __add__ = ExtensionOpsBase.create_method(operator.add) - __radd__ = ExtensionOpsBase.create_method(ops.radd) - __sub__ = ExtensionOpsBase.create_method(operator.sub) - __rsub__ = ExtensionOpsBase.create_method(ops.rsub) - __mul__ = ExtensionOpsBase.create_method(operator.mul) - __rmul__ = ExtensionOpsBase.create_method(ops.rmul) - __pow__ = ExtensionOpsBase.create_method(operator.pow) - __rpow__ = ExtensionOpsBase.create_method(ops.rpow) - __mod__ = ExtensionOpsBase.create_method(operator.mod) - __rmod__ = ExtensionOpsBase.create_method(ops.rmod) - __floordiv__ = ExtensionOpsBase.create_method(operator.floordiv) - __rfloordiv__ = ExtensionOpsBase.create_method(ops.rfloordiv) - __truediv__ = ExtensionOpsBase.create_method(operator.truediv) - __rtruediv__ = ExtensionOpsBase.create_method(ops.rtruediv) + __add__ = ExtensionScalarOpsMixin._create_method(operator.add) + __radd__ = ExtensionScalarOpsMixin._create_method(ops.radd) + __sub__ = ExtensionScalarOpsMixin._create_method(operator.sub) + __rsub__ = ExtensionScalarOpsMixin._create_method(ops.rsub) + __mul__ = ExtensionScalarOpsMixin._create_method(operator.mul) + __rmul__ = ExtensionScalarOpsMixin._create_method(ops.rmul) + __pow__ = ExtensionScalarOpsMixin._create_method(operator.pow) + __rpow__ = ExtensionScalarOpsMixin._create_method(ops.rpow) + __mod__ = ExtensionScalarOpsMixin._create_method(operator.mod) + __rmod__ = ExtensionScalarOpsMixin._create_method(ops.rmod) + __floordiv__ = ExtensionScalarOpsMixin._create_method(operator.floordiv) + __rfloordiv__ = ExtensionScalarOpsMixin._create_method(ops.rfloordiv) + __truediv__ = 
ExtensionScalarOpsMixin._create_method(operator.truediv) + __rtruediv__ = ExtensionScalarOpsMixin._create_method(ops.rtruediv) if not PY3: - __div__ = ExtensionOpsBase.create_method(operator.div) - __rdiv__ = ExtensionOpsBase.create_method(ops.rdiv) + __div__ = ExtensionScalarOpsMixin._create_method(operator.div) + __rdiv__ = ExtensionScalarOpsMixin._create_method(ops.rdiv) - __divmod__ = ExtensionOpsBase.create_method(divmod) + __divmod__ = ExtensionScalarOpsMixin._create_method(divmod) + __rdivmod__ = ExtensionScalarOpsMixin._create_method(ops.rdivmod) -class ExtensionComparisonMixin(ExtensionOpsBase): +class ExtensionScalarComparisonMixin(ExtensionScalarOpsMixin): """A mixin for defining the comparison operations on an ExtensionArray class, where it assumed that the underlying objects have the operators already defined. @@ -728,11 +729,12 @@ class ExtensionComparisonMixin(ExtensionOpsBase): ------ If you have defined a subclass MyClass(ExtensionArray), then use MyClass(ExtensionArray, ExtensionComparisonMixin) to - get the arithmetic operators + get the comparison operators """ - __eq__ = ExtensionOpsBase.create_method(operator.eq) - __ne__ = ExtensionOpsBase.create_method(operator.ne) - __lt__ = ExtensionOpsBase.create_method(operator.lt) - __gt__ = ExtensionOpsBase.create_method(operator.gt) - __le__ = ExtensionOpsBase.create_method(operator.le) - __ge__ = ExtensionOpsBase.create_method(operator.ge) + + __eq__ = ExtensionScalarOpsMixin._create_method(operator.eq, False) + __ne__ = ExtensionScalarOpsMixin._create_method(operator.ne, False) + __lt__ = ExtensionScalarOpsMixin._create_method(operator.lt, False) + __gt__ = ExtensionScalarOpsMixin._create_method(operator.gt, False) + __le__ = ExtensionScalarOpsMixin._create_method(operator.le, False) + __ge__ = ExtensionScalarOpsMixin._create_method(operator.ge, False) diff --git a/pandas/core/ops.py b/pandas/core/ops.py index 898a4f129db39..13f3ce0083d26 100644 --- a/pandas/core/ops.py +++ b/pandas/core/ops.py @@ 
-991,20 +991,14 @@ def _construct_divmod_result(left, result, index, name, dtype): ) -def dispatch_to_extension_op(left, right, op_name): +def dispatch_to_extension_op(op, left, right): """ Assume that left is a Series backed by an ExtensionArray, - apply the operator defined by op_name. + apply the operator defined by op. """ - method = getattr(left.values, op_name, None) - if method is not None: - res_values = method(right) - if method is None or res_values is NotImplemented: - msg = "ExtensionArray invalid operation {opn} between {one} and {two}" - raise TypeError(msg.format(opn=op_name, - one=type(left.values), - two=type(right))) + # This will raise TypeError if the op is not defined on the ExtensionArray + res_values = op(left.values, right) res_name = get_op_result_name(left, right) return left._constructor(res_values, index=left.index, @@ -1080,7 +1074,7 @@ def wrapper(left, right): "{op}".format(typ=type(left).__name__, op=str_rep)) elif is_extension_array_dtype(left): - return dispatch_to_extension_op(left, right, op_name) + return dispatch_to_extension_op(op, left, right) lvalues = left.values rvalues = right @@ -1233,7 +1227,7 @@ def wrapper(self, other, axis=None): name=res_name) elif is_extension_array_dtype(self): - return dispatch_to_extension_op(self, other, op_name) + return dispatch_to_extension_op(op, self, other) elif isinstance(other, ABCSeries): # By this point we have checked that self._indexed_same(other) diff --git a/pandas/tests/extension/base/__init__.py b/pandas/tests/extension/base/__init__.py index 9da985625c4ee..f2e41bfe6a21d 100644 --- a/pandas/tests/extension/base/__init__.py +++ b/pandas/tests/extension/base/__init__.py @@ -47,6 +47,7 @@ class TestMyDtype(BaseDtypeTests): from .groupby import BaseGroupbyTests # noqa from .interface import BaseInterfaceTests # noqa from .methods import BaseMethodsTests # noqa +from .ops import BaseOpsTests # noqa from .missing import BaseMissingTests # noqa from .reshaping import BaseReshapingTests 
# noqa from .setitem import BaseSetitemTests # noqa diff --git a/pandas/tests/extension/base/ops.py b/pandas/tests/extension/base/ops.py new file mode 100644 index 0000000000000..9f51f709a18c2 --- /dev/null +++ b/pandas/tests/extension/base/ops.py @@ -0,0 +1,66 @@ +import pytest +import numpy as np +import pandas as pd +from .base import BaseExtensionTests + + +class BaseOpsTests(BaseExtensionTests): + """Various Series and DataFrame ops methos.""" + + def check_op(self, s, op, other, exc=NotImplementedError): + + with pytest.raises(exc): + getattr(s, op)(other) + + def test_arith_scalar(self, data, all_arithmetic_operators): + # scalar + op = all_arithmetic_operators + s = pd.Series(data) + self.check_op(s, op, 1, exc=TypeError) + + def test_arith_array(self, data, all_arithmetic_operators): + # ndarray & other series + op = all_arithmetic_operators + s = pd.Series(data) + self.check_op(s, op, np.ones(len(s), dtype=s.dtype.type), + exc=TypeError) + + def test_divmod(self, data): + s = pd.Series(data) + self.check_op(s, divmod, 1, exc=TypeError) + + def _compare_other(self, data, op, other): + s = pd.Series(data) + + if op in '__eq__': + assert getattr(data, op)(other) is NotImplemented + assert not getattr(s, op)(other).all() + elif op in '__ne__': + assert getattr(data, op)(other) is NotImplemented + assert getattr(s, op)(other).all() + + else: + + # array + getattr(data, op)(other) is NotImplementedError + + # series + s = pd.Series(data) + with pytest.raises(TypeError): + getattr(s, op)(other) + + def test_compare_scalar(self, data, all_compare_operators): + op = all_compare_operators + self._compare_other(data, op, 0) + + def test_compare_array(self, data, all_compare_operators): + op = all_compare_operators + other = [0] * len(data) + self._compare_other(data, op, other) + + def test_error(self, data, all_arithmetic_operators): + + # invalid ops + op = all_arithmetic_operators + with pytest.raises(AttributeError): + getattr(data, op) diff --git 
a/pandas/tests/extension/category/test_categorical.py b/pandas/tests/extension/category/test_categorical.py index 530a4e7a22a7a..155810b54ba68 100644 --- a/pandas/tests/extension/category/test_categorical.py +++ b/pandas/tests/extension/category/test_categorical.py @@ -157,3 +157,18 @@ def test_value_counts(self, all_data, dropna): class TestCasting(base.BaseCastingTests): pass + + +class TestOps(base.BaseOpsTests): + + def _compare_other(self, data, op, other): + + if op == '__eq__': + assert not getattr(data, op)(other).all() + + elif op == '__ne__': + assert getattr(data, op)(other).all() + + else: + with pytest.raises(TypeError): + getattr(data, op)(other) diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index 59db8439b8747..e4010023805ca 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -7,8 +7,8 @@ import pandas as pd from pandas.core.arrays import (ExtensionArray, - ExtensionArithmeticMixin, - ExtensionComparisonMixin) + ExtensionScalarArithmeticMixin, + ExtensionScalarComparisonMixin) from pandas.core.dtypes.base import ExtensionDtype @@ -26,14 +26,15 @@ def construct_from_string(cls, string): "'{}'".format(cls, string)) -class DecimalArray(ExtensionArray, ExtensionArithmeticMixin, - ExtensionComparisonMixin): +class DecimalArray(ExtensionArray, ExtensionScalarArithmeticMixin, + ExtensionScalarComparisonMixin): dtype = DecimalDtype() def __init__(self, values): for val in values: if not isinstance(val, self.dtype.type): - raise TypeError + raise TypeError("All values must be of type " + + str(self.dtype.type)) values = np.asarray(values, dtype=object) self._data = values diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 6f6ba5ff6bab3..a42fe3320ae68 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -1,4 +1,5 @@ import decimal +import 
operator import numpy as np import pandas as pd @@ -7,9 +8,6 @@ from pandas.tests.extension import base -from pandas.tests.series.test_operators import TestSeriesOperators -from pandas.util._decorators import cache_readonly - from .array import DecimalDtype, DecimalArray, make_data @@ -188,34 +186,62 @@ def test_dataframe_constructor_with_different_dtype_raises(): pd.DataFrame({"A": arr}, dtype='int64') -_ts = pd.Series(DecimalArray(make_data())) +class TestOps(BaseDecimal, base.BaseOpsTests): + def check_op(self, s, op_name, other): + + short_opname = op_name.strip('_') + if short_opname[0] == 'r': + short_opname = short_opname[1:] + op = getattr(operator, short_opname) + result = op(s, other) + expected = s.combine(other, op) + self.assert_series_equal(result, expected) + def test_arith_scalar(self, data, all_arithmetic_operators): + # scalar + op_name = all_arithmetic_operators + s = pd.Series(data) + self.check_op(s, op_name, decimal.Decimal(1.5)) -class TestOperator(BaseDecimal, TestSeriesOperators): - @cache_readonly - def ts(self): - ts = _ts.copy() - ts.name = 'ts' - return ts + def test_arith_array(self, data, all_arithmetic_operators): + op_name = all_arithmetic_operators + s = pd.Series(data) - def test_operators(self): - def absfunc(v): - if isinstance(v, pd.Series): - vals = v.values - return pd.Series(vals._from_sequence([abs(i) for i in vals])) - else: - return abs(v) context = decimal.getcontext() divbyzerotrap = context.traps[decimal.DivisionByZero] invalidoptrap = context.traps[decimal.InvalidOperation] context.traps[decimal.DivisionByZero] = 0 context.traps[decimal.InvalidOperation] = 0 - super(TestOperator, self).test_operators(absfunc) + + if "mod" not in op_name: + self.check_op(s, op_name, s * 2) + else: + self.check_op(s, op_name, pd.Series([int(d * 10) for d in data])) + + self.check_op(s, op_name, 0) context.traps[decimal.DivisionByZero] = divbyzerotrap context.traps[decimal.InvalidOperation] = invalidoptrap - def 
test_operators_corner(self): - pytest.skip("Cannot add empty Series of float64 to DecimalArray") + @pytest.mark.skip(reason="divmod not appropriate for decimal") + def test_divmod(self, data): + pass + + def _compare_other(self, data, op_name, other): + s = pd.Series(data) + self.check_op(s, op_name, other) + + def test_compare_scalar(self, data, all_compare_operators): + op_name = all_compare_operators + self._compare_other(data, op_name, 0.5) + + def test_compare_array(self, data, all_compare_operators): + op_name = all_compare_operators + + alter = np.random.choice([-1, 0, 1], len(data)) + # Randomly double, halve or keep same value + other = pd.Series(data) * [decimal.Decimal(pow(2.0, i)) + for i in alter] + self._compare_other(data, op_name, other) - def test_divmod(self): - pytest.skip("divmod not appropriate for Decimal type") + def test_error(self): + pass diff --git a/pandas/tests/extension/json/array.py b/pandas/tests/extension/json/array.py index 06139d5c21e8c..d3043bf0852d2 100644 --- a/pandas/tests/extension/json/array.py +++ b/pandas/tests/extension/json/array.py @@ -16,12 +16,11 @@ import random import string import sys -import operator import numpy as np from pandas.core.dtypes.base import ExtensionDtype -from pandas.core.arrays import ExtensionArray, ExtensionOpsBase +from pandas.core.arrays import ExtensionArray class JSONDtype(ExtensionDtype): @@ -44,12 +43,12 @@ def construct_from_string(cls, string): class JSONArray(ExtensionArray): dtype = JSONDtype() - __le__ = ExtensionOpsBase.create_method(operator.le) def __init__(self, values): for val in values: if not isinstance(val, self.dtype.type): - raise TypeError + raise TypeError("All values must be of type " + + str(self.dtype.type)) self.data = values # Some aliases for common attribute names to ensure pandas supports diff --git a/pandas/tests/extension/json/test_json.py b/pandas/tests/extension/json/test_json.py index b61004b705d66..07b92e57e3dc3 100644 --- 
a/pandas/tests/extension/json/test_json.py +++ b/pandas/tests/extension/json/test_json.py @@ -232,19 +232,5 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping): ) -def test_ops(data): - s1 = pd.Series(data) - s2 = pd.Series(data) - # Here we test if the mixin method was defined but the underlying Dtype - # did not have the method defined - with tm.assert_raises_regex(TypeError, "ExtensionDtype invalid operation"): - (s1 <= s2) - - # An object will always have __lt__ defined, so test if we catch that it - # was not implemented - with tm.assert_raises_regex(TypeError, "ExtensionArray invalid operation"): - (s1 < s2) - - # Test that if method is not defined at all, we catch that as well - with tm.assert_raises_regex(TypeError, "ExtensionArray invalid operation"): - s1 + s2 +class TestOps(BaseJSON, base.BaseOpsTests): + pass From 7bad5591bc9be95ece4f2336bd62ac05e278d266 Mon Sep 17 00:00:00 2001 From: Dr-Irv Date: Mon, 4 Jun 2018 12:06:16 -0400 Subject: [PATCH 04/12] Update docs and reorganize tests --- doc/source/extending.rst | 39 ++++++++- pandas/core/arrays/base.py | 33 ++++--- pandas/core/ops.py | 20 +++-- pandas/tests/extension/base/__init__.py | 2 +- pandas/tests/extension/base/ops.py | 87 ++++++++++++------- .../extension/category/test_categorical.py | 13 ++- .../tests/extension/decimal/test_decimal.py | 46 +++++----- pandas/tests/extension/json/test_json.py | 6 +- pandas/tests/series/test_operators.py | 6 +- 9 files changed, 163 insertions(+), 89 deletions(-) diff --git a/doc/source/extending.rst b/doc/source/extending.rst index f665b219a7bd1..197dbb75ebe64 100644 --- a/doc/source/extending.rst +++ b/doc/source/extending.rst @@ -61,7 +61,7 @@ Extension Types .. warning:: - The :class:`pandas.api.extension.ExtensionDtype` and :class:`pandas.api.extension.ExtensionArray` APIs are new and + The :class:`pandas.api.extensions.ExtensionDtype` and :class:`pandas.api.extensions.ExtensionArray` APIs are new and experimental. 
They may change between versions without warning. Pandas defines an interface for implementing data types and arrays that *extend* @@ -79,10 +79,10 @@ on :ref:`ecosystem.extensions`. The interface consists of two classes. -:class:`~pandas.api.extension.ExtensionDtype` +:class:`~pandas.api.extensions.ExtensionDtype` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -A :class:`pandas.api.extension.ExtensionDtype` is similar to a ``numpy.dtype`` object. It describes the +A :class:`pandas.api.extensions.ExtensionDtype` is similar to a ``numpy.dtype`` object. It describes the data type. Implementors are responsible for a few unique items like the name. One particularly important item is the ``type`` property. This should be the @@ -91,7 +91,7 @@ extension array for IP Address data, this might be ``ipaddress.IPv4Address``. See the `extension dtype source`_ for interface definition. -:class:`~pandas.api.extension.ExtensionArray` +:class:`~pandas.api.extensions.ExtensionArray` ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ This class provides all the array-like functionality. ExtensionArrays are @@ -113,6 +113,37 @@ by some other storage type, like Python lists. See the `extension array source`_ for the interface definition. The docstrings and comments contain guidance for properly implementing the interface. +:class:`~pandas.api.extensions.ExtensionArray` Operator Support +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +By default, there are no operators defined for the class :class:`~pandas.api.extensions.ExtensionArray`. +There are two ways that you can provide operator support for your ExtensionArray. One +is to define each of the operators on your ExtensionArray subclass. The second method +assumes that the underlying elements of the ExtensionArray have the individual operators +already defined. 
An ``ExtensionArray`` subclass consisting of objects that have arithmetic and +comparison operators defined on the underlying objects can easily support +those operators on the ``ExtensionArray``, and therefore the operators +on ``Series`` built on those ``ExtensionArray`` classes will work as expected. + +Two mixin classes, :class:`~pandas.api.extensions.ExtensionScalarArithmeticMixin` and +:class:`~pandas.api.extensions.ExtensionScalarComparisonMixin`, support this capability. +If developing an ``ExtensionArray`` subclass, for example ``MyExtensionArray``, +simply include ``ExtensionScalarArithmeticMixin`` and/or +``ExtensionScalarComparisonMixin`` as parent classes of ``MyExtensionArray`` +as follows: + +.. code-block:: python + + class MyExtensionArray(ExtensionArray, ExtensionScalarArithmeticMixin, + ExtensionScalarComparisonMixin): + +Note that since ``pandas`` automatically calls the underlying operator on each +element one-by-one, this might not be as performant as implementing your own +version of the associated operators directly on the ExtensionArray. + +Testing Extension Arrays +^^^^^^^^^^^^^^^^^^^^^^^^ + We provide a test suite for ensuring that your extension arrays satisfy the expected behavior. To use the test suite, you must provide several pytest fixtures and inherit from the base test class. The required fixtures are found in diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index c404f87db421f..4a19a7ca66cdd 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -12,10 +12,9 @@ from pandas.errors import AbstractMethodError from pandas.compat.numpy import function as nv from pandas.compat import set_function_name, PY3 -from pandas.core.dtypes.common import ( - is_extension_array_dtype, - is_list_like) +from pandas.core.dtypes.common import is_list_like from pandas.core import ops +from pandas.core.ops import _get_op_name _not_implemented_message = "{} does not implement {}." 
@@ -636,10 +635,12 @@ def _create_method(cls, op, coerce_to_dtype=True): Parameters ---------- - op: An operator that takes arguments op(a, b) - coerce_to_dtype: boolean indicating whether to attempt to convert - the result to the underlying ExtensionArray dtype - (default True) + op: function + An operator that takes arguments op(a, b) + coerce_to_dtype: bool + boolean indicating whether to attempt to convert + the result to the underlying ExtensionArray dtype + (default True) Returns ------- @@ -656,18 +657,14 @@ def _create_method(cls, op, coerce_to_dtype=True): """ - op_name = ops._get_op_name(op, False) - def _binop(self, other): - def convert_values(parm): - if isinstance(parm, ExtensionArray) or is_list_like(parm): - ovalues = parm - elif is_extension_array_dtype(parm): - ovalues = parm.values + def convert_values(param): + if isinstance(param, ExtensionArray) or is_list_like(param): + ovalues = param else: # Assume its an object - ovalues = [parm] * len(self) + ovalues = [param] * len(self) return ovalues - lvalues = convert_values(self) + lvalues = self rvalues = convert_values(other) # If the operator is not defined for the underlying objects, @@ -682,8 +679,8 @@ def convert_values(parm): return res - name = '__{name}__'.format(name=op_name) - return set_function_name(_binop, name, cls) + op_name = _get_op_name(op, True) + return set_function_name(_binop, op_name, cls) class ExtensionScalarArithmeticMixin(ExtensionScalarOpsMixin): diff --git a/pandas/core/ops.py b/pandas/core/ops.py index 13f3ce0083d26..e6d381231793e 100644 --- a/pandas/core/ops.py +++ b/pandas/core/ops.py @@ -993,12 +993,18 @@ def _construct_divmod_result(left, result, index, name, dtype): def dispatch_to_extension_op(op, left, right): """ - Assume that left is a Series backed by an ExtensionArray, + Assume that left or right is a Series backed by an ExtensionArray, apply the operator defined by op. 
""" - # This will raise TypeError if the op is not defined on the ExtensionArray - res_values = op(left.values, right) + # The op calls will raise TypeError if the op is not defined + # on the ExtensionArray + if is_extension_array_dtype(left): + res_values = op(left.values, right) + else: + # We know that left is not ExtensionArray and is Series and right is + # ExtensionArray. Want to force ExtensionArray op to get called + res_values = op(list(left.values), right.values) res_name = get_op_result_name(left, right) return left._constructor(res_values, index=left.index, @@ -1073,7 +1079,9 @@ def wrapper(left, right): raise TypeError("{typ} cannot perform the operation " "{op}".format(typ=type(left).__name__, op=str_rep)) - elif is_extension_array_dtype(left): + elif (is_extension_array_dtype(left) or + (is_extension_array_dtype(right) and + not is_categorical_dtype(right))): return dispatch_to_extension_op(op, left, right) lvalues = left.values @@ -1226,7 +1234,9 @@ def wrapper(self, other, axis=None): return self._constructor(res_values, index=self.index, name=res_name) - elif is_extension_array_dtype(self): + elif (is_extension_array_dtype(self) or + (is_extension_array_dtype(other) and + not is_categorical_dtype(other))): return dispatch_to_extension_op(op, self, other) elif isinstance(other, ABCSeries): diff --git a/pandas/tests/extension/base/__init__.py b/pandas/tests/extension/base/__init__.py index 4f79d04e63441..640b894e2245f 100644 --- a/pandas/tests/extension/base/__init__.py +++ b/pandas/tests/extension/base/__init__.py @@ -47,7 +47,7 @@ class TestMyDtype(BaseDtypeTests): from .groupby import BaseGroupbyTests # noqa from .interface import BaseInterfaceTests # noqa from .methods import BaseMethodsTests # noqa -from .ops import BaseOpsTests # noqa +from .ops import BaseArithmeticOpsTests, BaseComparisonOpsTests # noqa from .missing import BaseMissingTests # noqa from .reshaping import BaseReshapingTests # noqa from .setitem import BaseSetitemTests # noqa 
diff --git a/pandas/tests/extension/base/ops.py b/pandas/tests/extension/base/ops.py index 9f51f709a18c2..12ca85fad5624 100644 --- a/pandas/tests/extension/base/ops.py +++ b/pandas/tests/extension/base/ops.py @@ -1,66 +1,89 @@ import pytest -import numpy as np + +import operator + import pandas as pd from .base import BaseExtensionTests -class BaseOpsTests(BaseExtensionTests): - """Various Series and DataFrame ops methos.""" +class BaseOpsUtil(BaseExtensionTests): + def check_opname(self, s, op_name, other, exc=NotImplementedError): + + short_opname = op_name.strip('_') + try: + op = getattr(operator, short_opname) + except AttributeError: + # Assume it is the reverse operator + rop = getattr(operator, short_opname[1:]) + op = lambda x, y: rop(y, x) - def check_op(self, s, op, other, exc=NotImplementedError): + self._check_op(s, op, other, exc) + + def _check_op(self, s, op, other, exc=NotImplementedError): + if exc is None: + result = op(s, other) + expected = s.combine(other, op) + self.assert_series_equal(result, expected) + else: + with pytest.raises(exc): + op(s, other) - with pytest.raises(exc): - getattr(s, op)(other) + +class BaseArithmeticOpsTests(BaseOpsUtil): + """Various Series and DataFrame arithmetic ops methods.""" def test_arith_scalar(self, data, all_arithmetic_operators): # scalar - op = all_arithmetic_operators + op_name = all_arithmetic_operators s = pd.Series(data) - self.check_op(s, op, 1, exc=TypeError) + self.check_opname(s, op_name, s.iloc[0], exc=TypeError) def test_arith_array(self, data, all_arithmetic_operators): # ndarray & other series - op = all_arithmetic_operators + op_name = all_arithmetic_operators s = pd.Series(data) - self.check_op(s, op, np.ones(len(s), dtype=s.dtype.type), - exc=TypeError) + self.check_opname(s, op_name, [s.iloc[0]] * len(s), exc=TypeError) def test_divmod(self, data): s = pd.Series(data) - self.check_op(s, divmod, 1, exc=TypeError) + self._check_op(s, divmod, 1, exc=TypeError) + self._check_op(1, divmod, s, 
exc=TypeError) - def _compare_other(self, data, op, other): + def test_error(self, data, all_arithmetic_operators): + # invalid ops + op_name = all_arithmetic_operators + with pytest.raises(AttributeError): + getattr(data, op_name) + + +class BaseComparisonOpsTests(BaseOpsUtil): + """Various Series and DataFrame comparison ops methods.""" + + def _compare_other(self, data, op_name, other): s = pd.Series(data) - if op in '__eq__': - assert getattr(data, op)(other) is NotImplemented - assert not getattr(s, op)(other).all() - elif op in '__ne__': - assert getattr(data, op)(other) is NotImplemented - assert getattr(s, op)(other).all() + if op_name == '__eq__': + assert getattr(data, op_name)(other) is NotImplemented + assert not getattr(s, op_name)(other).all() + elif op_name == '__ne__': + assert getattr(data, op_name)(other) is NotImplemented + assert getattr(s, op_name)(other).all() else: # array - getattr(data, op)(other) is NotImplementedError + getattr(data, op_name)(other) is NotImplementedError # series s = pd.Series(data) with pytest.raises(TypeError): - getattr(s, op)(other) + getattr(s, op_name)(other) def test_compare_scalar(self, data, all_compare_operators): - op = all_compare_operators - self._compare_other(data, op, 0) + op_name = all_compare_operators + self._compare_other(data, op_name, 0) def test_compare_array(self, data, all_compare_operators): - op = all_compare_operators + op_name = all_compare_operators other = [0] * len(data) - self._compare_other(data, op, other) - - def test_error(self, data, all_arithmetic_operators): - - # invalid ops - op = all_arithmetic_operators - with pytest.raises(AttributeError): - getattr(data, op) + self._compare_other(data, op_name, other) diff --git a/pandas/tests/extension/category/test_categorical.py b/pandas/tests/extension/category/test_categorical.py index 155810b54ba68..c27a509c9e6ac 100644 --- a/pandas/tests/extension/category/test_categorical.py +++ b/pandas/tests/extension/category/test_categorical.py @@ 
-159,7 +159,18 @@ class TestCasting(base.BaseCastingTests): pass -class TestOps(base.BaseOpsTests): +class TestArithmeticOps(base.BaseArithmeticOpsTests): + + def test_arith_scalar(self, data, all_arithmetic_operators): + + op_name = all_arithmetic_operators + if op_name != '__rmod__': + super(TestArithmeticOps, self).test_arith_scalar(data, op_name) + else: + pytest.skip('rmod never called when string is first argument') + + +class TestComparisonOps(base.BaseComparisonOpsTests): def _compare_other(self, data, op, other): diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index a42fe3320ae68..ae6587ab13425 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -1,5 +1,4 @@ import decimal -import operator import numpy as np import pandas as pd @@ -186,22 +185,11 @@ def test_dataframe_constructor_with_different_dtype_raises(): pd.DataFrame({"A": arr}, dtype='int64') -class TestOps(BaseDecimal, base.BaseOpsTests): - def check_op(self, s, op_name, other): +class TestArithmeticOps(BaseDecimal, base.BaseArithmeticOpsTests): - short_opname = op_name.strip('_') - if short_opname[0] == 'r': - short_opname = short_opname[1:] - op = getattr(operator, short_opname) - result = op(s, other) - expected = s.combine(other, op) - self.assert_series_equal(result, expected) - - def test_arith_scalar(self, data, all_arithmetic_operators): - # scalar - op_name = all_arithmetic_operators - s = pd.Series(data) - self.check_op(s, op_name, decimal.Decimal(1.5)) + def check_opname(self, s, op_name, other, exc=None): + super(TestArithmeticOps, self).check_opname(s, op_name, + other, exc=None) def test_arith_array(self, data, all_arithmetic_operators): op_name = all_arithmetic_operators @@ -213,12 +201,15 @@ def test_arith_array(self, data, all_arithmetic_operators): context.traps[decimal.DivisionByZero] = 0 context.traps[decimal.InvalidOperation] = 0 + # Decimal supports ops 
with int, but not float + other = pd.Series([int(d * 100) for d in data]) + self.check_opname(s, op_name, other) + if "mod" not in op_name: - self.check_op(s, op_name, s * 2) - else: - self.check_op(s, op_name, pd.Series([int(d * 10) for d in data])) + self.check_opname(s, op_name, s * 2) - self.check_op(s, op_name, 0) + self.check_opname(s, op_name, 0) + self.check_opname(s, op_name, 5) context.traps[decimal.DivisionByZero] = divbyzerotrap context.traps[decimal.InvalidOperation] = invalidoptrap @@ -226,9 +217,19 @@ def test_arith_array(self, data, all_arithmetic_operators): def test_divmod(self, data): pass + def test_error(self): + pass + + +class TestComparisonOps(BaseDecimal, base.BaseComparisonOpsTests): + + def check_opname(self, s, op_name, other, exc=None): + super(TestComparisonOps, self).check_opname(s, op_name, + other, exc=None) + def _compare_other(self, data, op_name, other): s = pd.Series(data) - self.check_op(s, op_name, other) + self.check_opname(s, op_name, other) def test_compare_scalar(self, data, all_compare_operators): op_name = all_compare_operators @@ -242,6 +243,3 @@ def test_compare_array(self, data, all_compare_operators): other = pd.Series(data) * [decimal.Decimal(pow(2.0, i)) for i in alter] self._compare_other(data, op_name, other) - - def test_error(self): - pass diff --git a/pandas/tests/extension/json/test_json.py b/pandas/tests/extension/json/test_json.py index 07b92e57e3dc3..63c6fc0bc741b 100644 --- a/pandas/tests/extension/json/test_json.py +++ b/pandas/tests/extension/json/test_json.py @@ -232,5 +232,9 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping): ) -class TestOps(BaseJSON, base.BaseOpsTests): +class TestArithmeticOps(BaseJSON, base.BaseArithmeticOpsTests): + pass + + +class TestComparisonOps(BaseJSON, base.BaseComparisonOpsTests): pass diff --git a/pandas/tests/series/test_operators.py b/pandas/tests/series/test_operators.py index ed0ae64be53ad..ecb74622edf10 100644 --- 
a/pandas/tests/series/test_operators.py +++ b/pandas/tests/series/test_operators.py @@ -1216,11 +1216,11 @@ def test_neg(self): def test_invert(self): assert_series_equal(-(self.series < 0), ~(self.series < 0)) - def test_operators(self, absfunc=np.abs): + def test_operators(self): def _check_op(series, other, op, pos_only=False, check_dtype=True): - left = absfunc(series) if pos_only else series - right = absfunc(other) if pos_only else other + left = np.abs(series) if pos_only else series + right = np.abs(other) if pos_only else other cython_or_numpy = op(left, right) python = left.combine(right, op) From dfcda3bcf03159b98aafbb572e5b47c3afac3f38 Mon Sep 17 00:00:00 2001 From: Dr-Irv Date: Tue, 5 Jun 2018 15:29:36 -0400 Subject: [PATCH 05/12] Allow users to define their own operators. Additional docs --- doc/source/extending.rst | 66 +++++++-- doc/source/whatsnew/v0.24.0.txt | 32 ++++- pandas/api/extensions/__init__.py | 4 +- pandas/core/arrays/__init__.py | 4 +- pandas/core/arrays/base.py | 172 ++++++++++++++++++----- pandas/tests/extension/json/array.py | 7 +- pandas/tests/extension/json/test_json.py | 3 +- 7 files changed, 230 insertions(+), 58 deletions(-) diff --git a/doc/source/extending.rst b/doc/source/extending.rst index 197dbb75ebe64..08f9e2f69f2e6 100644 --- a/doc/source/extending.rst +++ b/doc/source/extending.rst @@ -113,21 +113,67 @@ by some other storage type, like Python lists. See the `extension array source`_ for the interface definition. The docstrings and comments contain guidance for properly implementing the interface. +.. _extending.extension.operator: + :class:`~pandas.api.extensions.ExtensionArray` Operator Support ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ By default, there are no operators defined for the class :class:`~pandas.api.extensions.ExtensionArray`. -There are two ways that you can provide operator support for your ExtensionArray. One -is to define each of the operators on your ExtensionArray subclass. 
The second method -assumes that the underlying elements of the ExtensionArray have the individual operators -already defined. An ``ExtensionArray`` subclass consisting of objects that have arithmetic and -comparison operators defined on the underlying objects can easily support -those operators on the ``ExtensionArray``, and therefore the operators -on ``Series`` built on those ``ExtensionArray`` classes will work as expected. +There are two approaches for providing operator support for your ExtensionArray: + +1. Define each of the operators on your ExtensionArray subclass. +2. Use operators from pandas defined on the ExtensionArray subclass based on already defined + operators on the underlying elements. + +For the first approach, you will need to create a mixin class with a single class method, +with the following signature: + +.. code-block:: python + + @classmethod + def _create_method(cls, op, coerce_to_dtype=True): + +The method ``create_method`` should return a method with the signature +``binop(self, other)`` that returns the result of applying the operator ``op`` +to your ExtensionArray subclass. Your mixin class will then become a base class +for the provided :class:`ExtensionArithmeticOpsMixin` and +:class:`ExtensionComparisonOpsMixin` classes. + +For example, if your ExtensionArray subclass +is called ``MyExtensionArray``, you could create a mixin class ``MyOpsMixin`` +that has the following skeleton: + +.. code-block:: python + + class MyOpsMixin(object): + @classmethod + def _create_method(cls, op, coerce_to_dtype=True): + def _binop(self, other): + # Your implementation of the operator op + return _binop + +Then to use this class to define the operators for ``MyExtensionArray``, you can write: + +.. 
code-block:: python + + class MyExtensionArray(ExtensionArray, + ExtensionArithmeticOpsMixin(MyOpsMixin), + ExtensionComparisonOpsMixin(MyOpsMixin)) + +The mixin classes :class:`ExtensionArithmeticOpsMixin` and +:class:`ExtensionComparisonOpsMixin` will then define the appropriate operators +using your implementation of those operators in ``MyOpsMixin``. + +The second approach assumes that the underlying elements of the ExtensionArray +have the individual operators already defined. In other words, if your ExtensionArray +named ``MyExtensionArray`` is implemented so that each element is an instance +of the class ``MyExtensionElement``, then if the operators are defined +for ``MyExtensionElement``, the second approach will automatically +define the operators for ``MyExtensionArray``. Two mixin classes, :class:`~pandas.api.extensions.ExtensionScalarArithmeticMixin` and -:class:`~pandas.api.extensions.ExtensionScalarComparisonMixin`, support this capability. -If developing an ``ExtensionArray`` subclass, for example ``MyExtensionArray``, +:class:`~pandas.api.extensions.ExtensionScalarComparisonMixin`, support this second +approach. If developing an ``ExtensionArray`` subclass, for example ``MyExtensionArray``, simply include ``ExtensionScalarArithmeticMixin`` and/or ``ExtensionScalarComparisonMixin`` as parent classes of ``MyExtensionArray`` as follows: @@ -141,6 +187,8 @@ Note that since ``pandas`` automatically calls the underlying operator on each element one-by-one, this might not be as performant as implementing your own version of the associated operators directly on the ExtensionArray. +.. _extending.extension.testing: + Testing Extension Arrays ^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt index e8f0dfa9bf84d..1f87e29c6cb33 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -10,16 +10,31 @@ New features .. _whatsnew_0240.enhancements.other: +.. 
_whatsnew_0240.enhancements.extension_array_operators + ``ExtensionArray`` operator support ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -An ``ExtensionArray`` subclass consisting of objects that have arithmetic and -comparison operators defined on the underlying objects can easily support -those operators on the ``ExtensionArray``, and therefore the operators -on ``Series`` built on those ``ExtensionArray`` classes will work as expected. - -Two new mixin classes, :class:`ExtensionScalarArithmeticMixin` and -:class:`ExtensionScalarComparisonMixin`, support this capability. +A ``Series`` based on ``ExtensionArray`` now supports arithmetic and comparison +operators. There are two approaches for providing operator support for an ExtensionArray: + +1. Define each of the operators on your ExtensionArray subclass. +2. Use operators from pandas defined on the ExtensionArray subclass based on already defined + operators on the underlying elements. + +To use the first approach where you define your own implementation of the operators, +use one or both of the mixin classes, :class:`ExtensionArithmeticOpsMixin` and +:class:`ExtensionComparisonOpsMixin` that, by default, will create +operators that are ``NotImplemented``. To use those classes, you will need to create +a class that has the implementation of the operator methods. Details can be found in the +:ref:`ExtensionArray Operator Support <extending.extension.operator>` documentation section. + +For the second approach, which is appropriate if your ExtensionArray contains +elements that already have the operators +defined on a per-element basis, pandas provides two mixins, +:class:`ExtensionScalarArithmeticMixin` and :class:`ExtensionScalarComparisonMixin`, +that you can use that will automatically define the operators on your ExtensionArray +subclass.
If developing an ``ExtensionArray`` subclass, for example ``MyExtensionArray``, simply include ``ExtensionScalarArithmeticMixin`` and/or ``ExtensionScalarComparisonMixin`` as parent classes of ``MyExtensionArray`` @@ -30,6 +45,9 @@ as follows: class MyExtensionArray(ExtensionArray, ExtensionScalarArithmeticMixin, ExtensionScalarComparisonMixin): +See the :ref:`ExtensionArray Operator Support +<extending.extension.operator>` documentation section for details on both +ways of adding operator support. Other Enhancements ^^^^^^^^^^^^^^^^^^ diff --git a/pandas/api/extensions/__init__.py b/pandas/api/extensions/__init__.py index 5057c707e91d2..152a45a185aa2 100644 --- a/pandas/api/extensions/__init__.py +++ b/pandas/api/extensions/__init__.py @@ -6,5 +6,7 @@ from pandas.core.arrays.base import (ExtensionArray, # noqa ExtensionScalarArithmeticMixin, ExtensionScalarComparisonMixin, - ExtensionScalarOpsMixin) + ExtensionScalarOpsMixin, + ExtensionArithmeticOpsMixin, + ExtensionComparisonOpsMixin) from pandas.core.dtypes.dtypes import ExtensionDtype # noqa diff --git a/pandas/core/arrays/__init__.py b/pandas/core/arrays/__init__.py index b37f02bea9ea7..5ee1a6d47f36f 100644 --- a/pandas/core/arrays/__init__.py +++ b/pandas/core/arrays/__init__.py @@ -1,5 +1,7 @@ from .base import (ExtensionArray, # noqa ExtensionScalarArithmeticMixin, ExtensionScalarComparisonMixin, - ExtensionScalarOpsMixin) + ExtensionScalarOpsMixin, + ExtensionArithmeticOpsMixin, + ExtensionComparisonOpsMixin) from .categorical import Categorical # noqa diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 4a19a7ca66cdd..a39b75250de98 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -14,7 +14,6 @@ from pandas.compat import set_function_name, PY3 from pandas.core.dtypes.common import is_list_like from pandas.core import ops -from pandas.core.ops import _get_op_name _not_implemented_message = "{} does not implement {}."
@@ -618,6 +617,37 @@ def _ndarray_values(self): return np.array(self) +class ExtensionDefaultOpsMixin(object): + """ + A base class for default ops, that returns NotImplemented + """ + @classmethod + def _create_method(cls, op, coerce_to_dtype=True): + """ + A class method that returns a method will correspond to an + operator for an ExtensionArray subclass, by always indicating + the operator is NotImplemented + + Parameters + ---------- + op: function + An operator that takes arguments op(a, b) + coerce_to_dtype: bool + boolean indicating whether to attempt to convert + the result to the underlying ExtensionArray dtype + (default True) + + Returns + ------- + A method that can be bound to a method of a class + """ + + def returnNotImplemented(self, other): + return NotImplemented + + return returnNotImplemented + + class ExtensionScalarOpsMixin(object): """ A base class for the mixins for different operators. @@ -653,7 +683,8 @@ def _create_method(cls, op, coerce_to_dtype=True): >>> __add__ = ExtensionScalarOpsMixin.create_method(operator.add) in the class definition of MyClass to create the operator - for addition. + for addition, that will be based on the operator implementation + of the underlying elements of the ExtensionArray """ @@ -679,13 +710,105 @@ def convert_values(param): return res - op_name = _get_op_name(op, True) + op_name = ops._get_op_name(op, True) return set_function_name(_binop, op_name, cls) -class ExtensionScalarArithmeticMixin(ExtensionScalarOpsMixin): +def ExtensionArithmeticOpsMixin(base=ExtensionDefaultOpsMixin): + """A mixin that will define the arithmetic operators. 
+ + Parameters + ---------- + base: class + A class with the class method _create_method() defined as follows: + + @classmethod + def _create_method(cls, op, coerce_to_dtype=True): + Parameters + ---------- + op: function + An operator that takes arguments op(a, b) + coerce_to_dtype: bool + boolean indicating whether to attempt to convert + the result to the underlying ExtensionArray dtype + (default True) + + Returns + ------- + A method that can be bound to a method of a class. That method + should return the result of op(self, other) where self is + an ExtensionArray subclass + """ + + class _ExtensionArithmeticOpsMixin(base): + __add__ = base._create_method(operator.add) + __radd__ = base._create_method(ops.radd) + __sub__ = base._create_method(operator.sub) + __rsub__ = base._create_method(ops.rsub) + __mul__ = base._create_method(operator.mul) + __rmul__ = base._create_method(ops.rmul) + __pow__ = base._create_method(operator.pow) + __rpow__ = base._create_method(ops.rpow) + __mod__ = base._create_method(operator.mod) + __rmod__ = base._create_method(ops.rmod) + __floordiv__ = base._create_method(operator.floordiv) + __rfloordiv__ = base._create_method(ops.rfloordiv) + __truediv__ = base._create_method(operator.truediv) + __rtruediv__ = base._create_method(ops.rtruediv) + if not PY3: + __div__ = base._create_method(operator.div) + __rdiv__ = base._create_method(ops.rdiv) + + __divmod__ = base._create_method(divmod) + __rdivmod__ = base._create_method(ops.rdivmod) + + _ExtensionArithmeticOpsMixin.__name__ = ( + "ExtensionArithmeticOpsMixin_" + base.__name__) + return _ExtensionArithmeticOpsMixin + + +def ExtensionComparisonOpsMixin(base=ExtensionDefaultOpsMixin): + """A mixin that will define the comparison operators. 
+ + Parameters + ---------- + base: class + A class with the class method _create_method() defined as follows: + + @classmethod + def _create_method(cls, op, coerce_to_dtype=True): + Parameters + ---------- + op: function + An operator that takes arguments op(a, b) + coerce_to_dtype: bool + boolean indicating whether to attempt to convert + the result to the underlying ExtensionArray dtype + (default True) + + Returns + ------- + A method that can be bound to a method of a class. That method + should return the result of op(self, other) where self is + an ExtensionArray subclass + """ + class _ExtensionComparisonOpsMixin(base): + __eq__ = base._create_method(operator.eq, False) + __ne__ = base._create_method(operator.ne, False) + __lt__ = base._create_method(operator.lt, False) + __gt__ = base._create_method(operator.gt, False) + __le__ = base._create_method(operator.le, False) + __ge__ = base._create_method(operator.ge, False) + + _ExtensionComparisonOpsMixin.__name__ = ( + "ExtensionComparisonOpsMixin_" + base.__name__) + return _ExtensionComparisonOpsMixin + + +class ExtensionScalarArithmeticMixin( + ExtensionArithmeticOpsMixin(ExtensionScalarOpsMixin)): """A mixin for defining the arithmetic operations on an ExtensionArray - class, where it assumed that the underlying objects have the operators + class, where it is assumed that the underlying objects have the operators already defined. 
Usage @@ -694,44 +817,19 @@ class ExtensionScalarArithmeticMixin(ExtensionScalarOpsMixin): use MyClass(ExtensionArray, ExtensionScalarArithmeticMixin) to get the arithmetic operators """ + pass + - __add__ = ExtensionScalarOpsMixin._create_method(operator.add) - __radd__ = ExtensionScalarOpsMixin._create_method(ops.radd) - __sub__ = ExtensionScalarOpsMixin._create_method(operator.sub) - __rsub__ = ExtensionScalarOpsMixin._create_method(ops.rsub) - __mul__ = ExtensionScalarOpsMixin._create_method(operator.mul) - __rmul__ = ExtensionScalarOpsMixin._create_method(ops.rmul) - __pow__ = ExtensionScalarOpsMixin._create_method(operator.pow) - __rpow__ = ExtensionScalarOpsMixin._create_method(ops.rpow) - __mod__ = ExtensionScalarOpsMixin._create_method(operator.mod) - __rmod__ = ExtensionScalarOpsMixin._create_method(ops.rmod) - __floordiv__ = ExtensionScalarOpsMixin._create_method(operator.floordiv) - __rfloordiv__ = ExtensionScalarOpsMixin._create_method(ops.rfloordiv) - __truediv__ = ExtensionScalarOpsMixin._create_method(operator.truediv) - __rtruediv__ = ExtensionScalarOpsMixin._create_method(ops.rtruediv) - if not PY3: - __div__ = ExtensionScalarOpsMixin._create_method(operator.div) - __rdiv__ = ExtensionScalarOpsMixin._create_method(ops.rdiv) - - __divmod__ = ExtensionScalarOpsMixin._create_method(divmod) - __rdivmod__ = ExtensionScalarOpsMixin._create_method(ops.rdivmod) - - -class ExtensionScalarComparisonMixin(ExtensionScalarOpsMixin): +class ExtensionScalarComparisonMixin( + ExtensionComparisonOpsMixin(ExtensionScalarOpsMixin)): """A mixin for defining the comparison operations on an ExtensionArray - class, where it assumed that the underlying objects have the operators + class, where it is assumed that the underlying objects have the operators already defined. 
Usage ------ If you have defined a subclass MyClass(ExtensionArray), then - use MyClass(ExtensionArray, ExtensionComparisonMixin) to + use MyClass(ExtensionArray, ExtensionScalarComparisonMixin) to get the comparison operators """ - - __eq__ = ExtensionScalarOpsMixin._create_method(operator.eq, False) - __ne__ = ExtensionScalarOpsMixin._create_method(operator.ne, False) - __lt__ = ExtensionScalarOpsMixin._create_method(operator.lt, False) - __gt__ = ExtensionScalarOpsMixin._create_method(operator.gt, False) - __le__ = ExtensionScalarOpsMixin._create_method(operator.le, False) - __ge__ = ExtensionScalarOpsMixin._create_method(operator.ge, False) + pass diff --git a/pandas/tests/extension/json/array.py b/pandas/tests/extension/json/array.py index d3043bf0852d2..12bc1e072f5b1 100644 --- a/pandas/tests/extension/json/array.py +++ b/pandas/tests/extension/json/array.py @@ -20,7 +20,9 @@ import numpy as np from pandas.core.dtypes.base import ExtensionDtype -from pandas.core.arrays import ExtensionArray +from pandas.core.arrays import (ExtensionArray, + ExtensionArithmeticOpsMixin, + ExtensionComparisonOpsMixin) class JSONDtype(ExtensionDtype): @@ -41,7 +43,8 @@ def construct_from_string(cls, string): "'{}'".format(cls, string)) -class JSONArray(ExtensionArray): +class JSONArray(ExtensionArray, ExtensionArithmeticOpsMixin(), + ExtensionComparisonOpsMixin()): dtype = JSONDtype() def __init__(self, values): diff --git a/pandas/tests/extension/json/test_json.py b/pandas/tests/extension/json/test_json.py index 63c6fc0bc741b..f0c1703a20172 100644 --- a/pandas/tests/extension/json/test_json.py +++ b/pandas/tests/extension/json/test_json.py @@ -233,7 +233,8 @@ def test_groupby_extension_agg(self, as_index, data_for_grouping): class TestArithmeticOps(BaseJSON, base.BaseArithmeticOpsTests): - pass + def test_error(self, data, all_arithmetic_operators): + pass class TestComparisonOps(BaseJSON, base.BaseComparisonOpsTests): From aaaa8fd2445473177bbfd665c08997ea032c5f76 Mon Sep 17 
00:00:00 2001 From: Dr-Irv Date: Tue, 5 Jun 2018 18:25:35 -0400 Subject: [PATCH 06/12] Use approach of initializing operators via func calls --- doc/source/extending.rst | 62 ++----- doc/source/whatsnew/v0.24.0.txt | 25 ++- pandas/api/extensions/__init__.py | 6 +- pandas/core/arrays/__init__.py | 6 +- pandas/core/arrays/base.py | 205 ++++++------------------ pandas/tests/extension/decimal/array.py | 9 +- pandas/tests/extension/json/array.py | 7 +- 7 files changed, 86 insertions(+), 234 deletions(-) diff --git a/doc/source/extending.rst b/doc/source/extending.rst index 08f9e2f69f2e6..8c2b187a85573 100644 --- a/doc/source/extending.rst +++ b/doc/source/extending.rst @@ -122,66 +122,34 @@ By default, there are no operators defined for the class :class:`~pandas.api.ext There are two approaches for providing operator support for your ExtensionArray: 1. Define each of the operators on your ExtensionArray subclass. -2. Use operators from pandas defined on the ExtensionArray subclass based on already defined - operators on the underlying elements. +2. Use an operator implementation from pandas that depends on operators that are already defined + on the underlying elements (scalars) of the ExtensionArray. -For the first approach, you will need to create a mixin class with a single class method, -with the following signature: +For the first approach, you define selected operators, e.g., ``_add__``, ``__le__``, etc. that +you want your ExtensionArray subclass to support. -.. code-block:: python - - @classmethod - def _create_method(cls, op, coerce_to_dtype=True): - -The method ``create_method`` should return a method with the signature -``binop(self, other)`` that returns the result of applying the operator ``op`` -to your ExtensionArray subclass. Your mixin class will then become a base class -for the provided :class:`ExtensionArithmeticOpsMixin` and -:class:`ExtensionComparisonOpsMixin` classes. 
- -For example, if your ExtensionArray subclass -is called ``MyExtensionArray``, you could create a mixin class ``MyOpsMixin`` -that has the following skeleton: - -.. code-block:: python - - class MyOpsMixin(object): - @classmethod - def _create_method(cls, op, coerce_to_dtype=True): - def _binop(self, other): - # Your implementation of the operator op - return _binop - -Then to use this class to define the operators for ``MyExtensionArray``, you can write: - -.. code-block:: python - - class MyExtensionArray(ExtensionArray, - ExtensionArithmeticOpsMixin(MyOpsMixin), - ExtensionComparisonOpsMixin(MyOpsMixin)) - -The mixin classes :class:`ExtensionArithmeticOpsMixin` and -:class:`ExtensionComparisonOpsMixin` will then define the appropriate operators -using your implementation of those operators in ``MyOpsMixin``. - -The second approach assumes that the underlying elements of the ExtensionArray +The second approach assumes that the underlying elements (i.e., scalar type) of the ExtensionArray have the individual operators already defined. In other words, if your ExtensionArray named ``MyExtensionArray`` is implemented so that each element is an instance of the class ``MyExtensionElement``, then if the operators are defined for ``MyExtensionElement``, the second approach will automatically define the operators for ``MyExtensionArray``. -Two mixin classes, :class:`~pandas.api.extensions.ExtensionScalarArithmeticMixin` and -:class:`~pandas.api.extensions.ExtensionScalarComparisonMixin`, support this second +A mixin class, :class:`~pandas.api.extensions.ExtensionScalarOpscMixin` supports this second approach. 
If developing an ``ExtensionArray`` subclass, for example ``MyExtensionArray``, -simply include ``ExtensionScalarArithmeticMixin`` and/or -``ExtensionScalarComparisonMixin`` as parent classes of ``MyExtensionArray`` -as follows: +simply include ``ExtensionScalarOpsMixin`` as a parent class of ``MyExtensionArray`` +and then call the methods :meth:`~MyExtensionArray.addArithmeticOps` and/or +:meth:`~MyExtensionArray.addComparisonOps` to hook the operators into +your ``MyExtensionArray`` class, as follows: .. code-block:: python - class MyExtensionArray(ExtensionArray, ExtensionScalarArithmeticMixin, + class MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin, ExtensionScalarComparisonMixin): + pass + + MyExtensionArray.addArithmeticOps() + MyExtensionArray.addComparisonOps() Note that since ``pandas`` automatically calls the underlying operator on each element one-by-one, this might not be as performant as implementing your own diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt index 1f87e29c6cb33..23c66e87f50cf 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -19,26 +19,23 @@ A ``Series`` based on ``ExtensionArray`` now supports arithmetic and comparison operators. There are two approaches for providing operator support for an ExtensionArray: 1. Define each of the operators on your ExtensionArray subclass. -2. Use operators from pandas defined on the ExtensionArray subclass based on already defined - operators on the underlying elements. +2. Use an operator implementation from pandas that depends on operators that are already defined + on the underlying elements (scalars) of the ExtensionArray. To use the first approach where you define your own implementation of the operators, -use one or both of the mixin classes, :class:`ExtensionArithmeticOpsMixin` and -:class:`ExtensionComparisonOpsMixin` that, by default, will create -operators that are ``NotImplemented``. 
To use those classes, you will need to create -a class that has the implementation of the operator methods. Details can be found in the -:ref:`ExtensionArray Operator Support <extending.extension.operator>` documentation section. +you define each operator such as `__add__`, __le__`, etc. on your ExtensionArray +subclass. For the second approach, which is appropriate if your ExtensionArray contains elements that already have the operators -defined on a per-element basis, pandas provides two mixins, -:class:`ExtensionScalarArithmeticMixin` and :class:`ExtensionScalarComparisonMixin`, -that you can use that will automatically define the operators on your ExtensionArray -subclass. +defined on a per-element basis, pandas provides a mixin, +:class:`ExtensionScalarOpsMixin` that you can use that can +define the operators on your ExtensionArray subclass. If developing an ``ExtensionArray`` subclass, for example ``MyExtensionArray``, -simply include ``ExtensionScalarArithmeticMixin`` and/or -``ExtensionScalarComparisonMixin`` as parent classes of ``MyExtensionArray`` -as follows: +simply include ``ExtensionScalarOpsMixin`` as a parent class of ``MyExtensionArray`` +and then call the methods :meth:`~MyExtensionArray.addArithmeticOps` and/or +:meth:`~MyExtensionArray.addComparisonOps` to hook the operators into +your ``MyExtensionArray`` class, as follows: ..
code-block:: python diff --git a/pandas/api/extensions/__init__.py b/pandas/api/extensions/__init__.py index 152a45a185aa2..851a63725952a 100644 --- a/pandas/api/extensions/__init__.py +++ b/pandas/api/extensions/__init__.py @@ -4,9 +4,5 @@ register_series_accessor) from pandas.core.algorithms import take # noqa from pandas.core.arrays.base import (ExtensionArray, # noqa - ExtensionScalarArithmeticMixin, - ExtensionScalarComparisonMixin, - ExtensionScalarOpsMixin, - ExtensionArithmeticOpsMixin, - ExtensionComparisonOpsMixin) + ExtensionScalarOpsMixin) from pandas.core.dtypes.dtypes import ExtensionDtype # noqa diff --git a/pandas/core/arrays/__init__.py b/pandas/core/arrays/__init__.py index 5ee1a6d47f36f..f57348116c195 100644 --- a/pandas/core/arrays/__init__.py +++ b/pandas/core/arrays/__init__.py @@ -1,7 +1,3 @@ from .base import (ExtensionArray, # noqa - ExtensionScalarArithmeticMixin, - ExtensionScalarComparisonMixin, - ExtensionScalarOpsMixin, - ExtensionArithmeticOpsMixin, - ExtensionComparisonOpsMixin) + ExtensionScalarOpsMixin) from .categorical import Categorical # noqa diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index a39b75250de98..7f981e4739faa 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -617,42 +617,60 @@ def _ndarray_values(self): return np.array(self) -class ExtensionDefaultOpsMixin(object): + +class ExtensionOpsMixin(object): """ - A base class for default ops, that returns NotImplemented + A base class for linking the operators to their dunder names """ @classmethod - def _create_method(cls, op, coerce_to_dtype=True): - """ - A class method that returns a method will correspond to an - operator for an ExtensionArray subclass, by always indicating - the operator is NotImplemented - - Parameters - ---------- - op: function - An operator that takes arguments op(a, b) - coerce_to_dtype: bool - boolean indicating whether to attempt to convert - the result to the underlying ExtensionArray dtype - 
(default True) - - Returns - ------- - A method that can be bound to a method of a class - """ - - def returnNotImplemented(self, other): - return NotImplemented + def addArithmeticOps(cls): + cls.__add__ = cls._create_method(operator.add) + cls.__radd__ = cls._create_method(ops.radd) + cls.__sub__ = cls._create_method(operator.sub) + cls.__rsub__ = cls._create_method(ops.rsub) + cls.__mul__ = cls._create_method(operator.mul) + cls.__rmul__ = cls._create_method(ops.rmul) + cls.__pow__ = cls._create_method(operator.pow) + cls.__rpow__ = cls._create_method(ops.rpow) + cls.__mod__ = cls._create_method(operator.mod) + cls.__rmod__ = cls._create_method(ops.rmod) + cls.__floordiv__ = cls._create_method(operator.floordiv) + cls.__rfloordiv__ = cls._create_method(ops.rfloordiv) + cls.__truediv__ = cls._create_method(operator.truediv) + cls.__rtruediv__ = cls._create_method(ops.rtruediv) + if not PY3: + cls.__div__ = cls._create_method(operator.div) + cls.__rdiv__ = cls._create_method(ops.rdiv) + + cls.__divmod__ = cls._create_method(divmod) + cls.__rdivmod__ = cls._create_method(ops.rdivmod) + + @classmethod + def addComparisonOps(cls): + cls.__eq__ = cls._create_method(operator.eq, False) + cls.__ne__ = cls._create_method(operator.ne, False) + cls.__lt__ = cls._create_method(operator.lt, False) + cls.__gt__ = cls._create_method(operator.gt, False) + cls.__le__ = cls._create_method(operator.le, False) + cls.__ge__ = cls._create_method(operator.ge, False) - return returnNotImplemented +class ExtensionScalarOpsMixin(ExtensionOpsMixin): + """A mixin for defining the arithmetic and logical operations on + an ExtensionArray class, where it is assumed that the underlying objects + have the operators already defined. -class ExtensionScalarOpsMixin(object): - """ - A base class for the mixins for different operators. 
- Can also be used to define an individual method for a specific - operator using the class method create_method() + Usage + ------ + If you have defined a subclass MyExtensionArray(ExtensionArray), then + use MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin) to + get the arithmetic operators. After the definition of MyExtensionArray, + insert the lines + + MyExtensionArray.addArithmeticOperators() + MyExtensionArray.addComparisonOperators() + + to link the operators to your class. """ @classmethod @@ -678,11 +696,11 @@ def _create_method(cls, op, coerce_to_dtype=True): Example ------- - Given an ExtensionArray subclass called MyClass, use + Given an ExtensionArray subclass called MyExtensionArray, use - >>> __add__ = ExtensionScalarOpsMixin.create_method(operator.add) + >>> __add__ = cls._create_method(operator.add) - in the class definition of MyClass to create the operator + in the class definition of MyExtensionArray to create the operator for addition, that will be based on the operator implementation of the underlying elements of the ExtensionArray @@ -712,124 +730,3 @@ def convert_values(param): op_name = ops._get_op_name(op, True) return set_function_name(_binop, op_name, cls) - - -def ExtensionArithmeticOpsMixin(base=ExtensionDefaultOpsMixin): - """A mixin that will define the arithmetic operators. - - Parameters - ---------- - base: class - A class with the class method _create_method() defined as follows: - - @classmethod - def _create_method(cls, op, coerce_to_dtype=True): - Parameters - ---------- - op: function - An operator that takes arguments op(a, b) - coerce_to_dtype: bool - boolean indicating whether to attempt to convert - the result to the underlying ExtensionArray dtype - (default True) - - Returns - ------- - A method that can be bound to a method of a class. 
That method - should return the result of op(self, other) where self is - an ExtensionArray subclass - """ - - class _ExtensionArithmeticOpsMixin(base): - __add__ = base._create_method(operator.add) - __radd__ = base._create_method(ops.radd) - __sub__ = base._create_method(operator.sub) - __rsub__ = base._create_method(ops.rsub) - __mul__ = base._create_method(operator.mul) - __rmul__ = base._create_method(ops.rmul) - __pow__ = base._create_method(operator.pow) - __rpow__ = base._create_method(ops.rpow) - __mod__ = base._create_method(operator.mod) - __rmod__ = base._create_method(ops.rmod) - __floordiv__ = base._create_method(operator.floordiv) - __rfloordiv__ = base._create_method(ops.rfloordiv) - __truediv__ = base._create_method(operator.truediv) - __rtruediv__ = base._create_method(ops.rtruediv) - if not PY3: - __div__ = base._create_method(operator.div) - __rdiv__ = base._create_method(ops.rdiv) - - __divmod__ = base._create_method(divmod) - __rdivmod__ = base._create_method(ops.rdivmod) - - _ExtensionArithmeticOpsMixin.__name__ = ( - "ExtensionArithmeticOpsMixin_" + base.__name__) - return _ExtensionArithmeticOpsMixin - - -def ExtensionComparisonOpsMixin(base=ExtensionDefaultOpsMixin): - """A mixin that will define the comparison operators. - - Parameters - ---------- - base: class - A class with the class method _create_method() defined as follows: - - @classmethod - def _create_method(cls, op, coerce_to_dtype=True): - Parameters - ---------- - op: function - An operator that takes arguments op(a, b) - coerce_to_dtype: bool - boolean indicating whether to attempt to convert - the result to the underlying ExtensionArray dtype - (default True) - - Returns - ------- - A method that can be bound to a method of a class. 
That method - should return the result of op(self, other) where self is - an ExtensionArray subclass - """ - class _ExtensionComparisonOpsMixin(base): - __eq__ = base._create_method(operator.eq, False) - __ne__ = base._create_method(operator.ne, False) - __lt__ = base._create_method(operator.lt, False) - __gt__ = base._create_method(operator.gt, False) - __le__ = base._create_method(operator.le, False) - __ge__ = base._create_method(operator.ge, False) - - _ExtensionComparisonOpsMixin.__name__ = ( - "ExtensionComparisonOpsMixin_" + base.__name__) - return _ExtensionComparisonOpsMixin - - -class ExtensionScalarArithmeticMixin( - ExtensionArithmeticOpsMixin(ExtensionScalarOpsMixin)): - """A mixin for defining the arithmetic operations on an ExtensionArray - class, where it is assumed that the underlying objects have the operators - already defined. - - Usage - ------ - If you have defined a subclass MyClass(ExtensionArray), then - use MyClass(ExtensionArray, ExtensionScalarArithmeticMixin) to - get the arithmetic operators - """ - pass - - -class ExtensionScalarComparisonMixin( - ExtensionComparisonOpsMixin(ExtensionScalarOpsMixin)): - """A mixin for defining the comparison operations on an ExtensionArray - class, where it is assumed that the underlying objects have the operators - already defined. 
- - Usage - ------ - If you have defined a subclass MyClass(ExtensionArray), then - use MyClass(ExtensionArray, ExtensionScalarComparisonMixin) to - get the comparison operators - """ - pass diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index e4010023805ca..39bb840cbb85d 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -7,8 +7,7 @@ import pandas as pd from pandas.core.arrays import (ExtensionArray, - ExtensionScalarArithmeticMixin, - ExtensionScalarComparisonMixin) + ExtensionScalarOpsMixin) from pandas.core.dtypes.base import ExtensionDtype @@ -26,8 +25,7 @@ def construct_from_string(cls, string): "'{}'".format(cls, string)) -class DecimalArray(ExtensionArray, ExtensionScalarArithmeticMixin, - ExtensionScalarComparisonMixin): +class DecimalArray(ExtensionArray, ExtensionScalarOpsMixin): dtype = DecimalDtype() def __init__(self, values): @@ -106,6 +104,9 @@ def _na_value(self): def _concat_same_type(cls, to_concat): return cls(np.concatenate([x._data for x in to_concat])) +DecimalArray.addArithmeticOps() +DecimalArray.addComparisonOps() + def make_data(): return [decimal.Decimal(random.random()) for _ in range(100)] diff --git a/pandas/tests/extension/json/array.py b/pandas/tests/extension/json/array.py index 12bc1e072f5b1..d3043bf0852d2 100644 --- a/pandas/tests/extension/json/array.py +++ b/pandas/tests/extension/json/array.py @@ -20,9 +20,7 @@ import numpy as np from pandas.core.dtypes.base import ExtensionDtype -from pandas.core.arrays import (ExtensionArray, - ExtensionArithmeticOpsMixin, - ExtensionComparisonOpsMixin) +from pandas.core.arrays import ExtensionArray class JSONDtype(ExtensionDtype): @@ -43,8 +41,7 @@ def construct_from_string(cls, string): "'{}'".format(cls, string)) -class JSONArray(ExtensionArray, ExtensionArithmeticOpsMixin(), - ExtensionComparisonOpsMixin()): +class JSONArray(ExtensionArray): dtype = JSONDtype() def __init__(self, values): 
From 4bcf9780ae944cdda31f62e2da2551b56413b5ea Mon Sep 17 00:00:00 2001 From: Dr-Irv Date: Tue, 5 Jun 2018 20:07:03 -0400 Subject: [PATCH 07/12] Fix lint errors and make add methods private --- doc/source/extending.rst | 11 ++-- doc/source/whatsnew/v0.24.0.txt | 11 ++-- pandas/core/arrays/base.py | 71 ++++++++++++++----------- pandas/tests/extension/decimal/array.py | 5 +- 4 files changed, 54 insertions(+), 44 deletions(-) diff --git a/doc/source/extending.rst b/doc/source/extending.rst index 8c2b187a85573..e74bbbd7b8918 100644 --- a/doc/source/extending.rst +++ b/doc/source/extending.rst @@ -138,18 +138,17 @@ define the operators for ``MyExtensionArray``. A mixin class, :class:`~pandas.api.extensions.ExtensionScalarOpscMixin` supports this second approach. If developing an ``ExtensionArray`` subclass, for example ``MyExtensionArray``, simply include ``ExtensionScalarOpsMixin`` as a parent class of ``MyExtensionArray`` -and then call the methods :meth:`~MyExtensionArray.addArithmeticOps` and/or -:meth:`~MyExtensionArray.addComparisonOps` to hook the operators into +and then call the methods :meth:`~MyExtensionArray._add_arithmetic_ops` and/or +:meth:`~MyExtensionArray._add_comparison_ops` to hook the operators into your ``MyExtensionArray`` class, as follows: .. 
code-block:: python - class MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin, - ExtensionScalarComparisonMixin): + class MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin): pass - MyExtensionArray.addArithmeticOps() - MyExtensionArray.addComparisonOps() + MyExtensionArray._add_arithmetic_ops() + MyExtensionArray._add_comparison_ops() Note that since ``pandas`` automatically calls the underlying operator on each element one-by-one, this might not be as performant as implementing your own diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt index 23c66e87f50cf..4f3abda6d3bd2 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -33,14 +33,17 @@ defined on a per-element basis, pandas provides a mixin, define the operators on your ExtensionArray subclass. If developing an ``ExtensionArray`` subclass, for example ``MyExtensionArray``, simply include ``ExtensionScalarOpsMixin`` as a parent class of ``MyExtensionArray`` -and then call the methods :meth:`~MyExtensionArray.addArithmeticOps` and/or -:meth:`~MyExtensionArray.addComparisonOps` to hook the operators into +and then call the methods :meth:`~MyExtensionArray._add_arithmetic_ops` and/or +:meth:`~MyExtensionArray._add_comparison_ops` to hook the operators into your ``MyExtensionArray`` class, as follows: .. 
code-block:: python - class MyExtensionArray(ExtensionArray, ExtensionScalarArithmeticMixin, - ExtensionScalarComparisonMixin): + class MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin): + pass + + MyExtensionArray._add_arithmetic_ops() + MyExtensionArray._add_comparison_ops() See the :ref:`ExtensionArray Operator Support ` documentation section for details on both diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 7f981e4739faa..d49f46dd6a099 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -617,47 +617,46 @@ def _ndarray_values(self): return np.array(self) - class ExtensionOpsMixin(object): """ A base class for linking the operators to their dunder names """ @classmethod - def addArithmeticOps(cls): - cls.__add__ = cls._create_method(operator.add) - cls.__radd__ = cls._create_method(ops.radd) - cls.__sub__ = cls._create_method(operator.sub) - cls.__rsub__ = cls._create_method(ops.rsub) - cls.__mul__ = cls._create_method(operator.mul) - cls.__rmul__ = cls._create_method(ops.rmul) - cls.__pow__ = cls._create_method(operator.pow) - cls.__rpow__ = cls._create_method(ops.rpow) - cls.__mod__ = cls._create_method(operator.mod) - cls.__rmod__ = cls._create_method(ops.rmod) - cls.__floordiv__ = cls._create_method(operator.floordiv) - cls.__rfloordiv__ = cls._create_method(ops.rfloordiv) - cls.__truediv__ = cls._create_method(operator.truediv) - cls.__rtruediv__ = cls._create_method(ops.rtruediv) + def _add_arithmetic_ops(cls): + cls.__add__ = cls._create_arithmetic_method(operator.add) + cls.__radd__ = cls._create_arithmetic_method(ops.radd) + cls.__sub__ = cls._create_arithmetic_method(operator.sub) + cls.__rsub__ = cls._create_arithmetic_method(ops.rsub) + cls.__mul__ = cls._create_arithmetic_method(operator.mul) + cls.__rmul__ = cls._create_arithmetic_method(ops.rmul) + cls.__pow__ = cls._create_arithmetic_method(operator.pow) + cls.__rpow__ = cls._create_arithmetic_method(ops.rpow) + cls.__mod__ = 
cls._create_arithmetic_method(operator.mod) + cls.__rmod__ = cls._create_arithmetic_method(ops.rmod) + cls.__floordiv__ = cls._create_arithmetic_method(operator.floordiv) + cls.__rfloordiv__ = cls._create_arithmetic_method(ops.rfloordiv) + cls.__truediv__ = cls._create_arithmetic_method(operator.truediv) + cls.__rtruediv__ = cls._create_arithmetic_method(ops.rtruediv) if not PY3: - cls.__div__ = cls._create_method(operator.div) - cls.__rdiv__ = cls._create_method(ops.rdiv) - - cls.__divmod__ = cls._create_method(divmod) - cls.__rdivmod__ = cls._create_method(ops.rdivmod) - + cls.__div__ = cls._create_arithmetic_method(operator.div) + cls.__rdiv__ = cls._create_arithmetic_method(ops.rdiv) + + cls.__divmod__ = cls._create_arithmetic_method(divmod) + cls.__rdivmod__ = cls._create_arithmetic_method(ops.rdivmod) + @classmethod - def addComparisonOps(cls): - cls.__eq__ = cls._create_method(operator.eq, False) - cls.__ne__ = cls._create_method(operator.ne, False) - cls.__lt__ = cls._create_method(operator.lt, False) - cls.__gt__ = cls._create_method(operator.gt, False) - cls.__le__ = cls._create_method(operator.le, False) - cls.__ge__ = cls._create_method(operator.ge, False) + def _add_comparison_ops(cls): + cls.__eq__ = cls._create_comparison_method(operator.eq) + cls.__ne__ = cls._create_comparison_method(operator.ne) + cls.__lt__ = cls._create_comparison_method(operator.lt) + cls.__gt__ = cls._create_comparison_method(operator.gt) + cls.__le__ = cls._create_comparison_method(operator.le) + cls.__ge__ = cls._create_comparison_method(operator.ge) class ExtensionScalarOpsMixin(ExtensionOpsMixin): """A mixin for defining the arithmetic and logical operations on - an ExtensionArray class, where it is assumed that the underlying objects + an ExtensionArray class, where it is assumed that the underlying objects have the operators already defined. 
Usage @@ -666,10 +665,10 @@ class ExtensionScalarOpsMixin(ExtensionOpsMixin): use MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin) to get the arithmetic operators. After the definition of MyExtensionArray, insert the lines - + MyExtensionArray.addArithmeticOperators() MyExtensionArray.addComparisonOperators() - + to link the operators to your class. """ @@ -730,3 +729,11 @@ def convert_values(param): op_name = ops._get_op_name(op, True) return set_function_name(_binop, op_name, cls) + + @classmethod + def _create_arithmetic_method(cls, op): + return cls._create_method(op) + + @classmethod + def _create_comparison_method(cls, op): + return cls._create_method(op, False) diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index 39bb840cbb85d..3f2f24cd26af0 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -104,8 +104,9 @@ def _na_value(self): def _concat_same_type(cls, to_concat): return cls(np.concatenate([x._data for x in to_concat])) -DecimalArray.addArithmeticOps() -DecimalArray.addComparisonOps() + +DecimalArray._add_arithmetic_ops() +DecimalArray._add_comparison_ops() def make_data(): From 41dc5cae616a3dcfd765c471f7759a85c5a663ae Mon Sep 17 00:00:00 2001 From: Dr-Irv Date: Fri, 8 Jun 2018 10:10:57 -0400 Subject: [PATCH 08/12] Fix up merge issue --- pandas/core/series.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pandas/core/series.py b/pandas/core/series.py index b1db2d2dd5806..2ba1f15044952 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -2209,6 +2209,7 @@ def combine(self, other, func, fill_value=None): Perform elementwise binary operation on two Series using given function with optional fill value when an index is missing from one Series or the other + Parameters ---------- other : Series or scalar value @@ -2221,6 +2222,7 @@ def combine(self, other, func, fill_value=None): Returns ------- result : Series + Examples -------- >>> 
s1 = Series([1, 2]) @@ -2229,6 +2231,7 @@ def combine(self, other, func, fill_value=None): 0 0 1 2 dtype: int64 + See Also -------- Series.combine_first : Combine Series values, choosing the calling From a0f503cc880b768a6888f45523e0e6e4e2a7a204 Mon Sep 17 00:00:00 2001 From: Dr-Irv Date: Thu, 21 Jun 2018 11:32:54 -0400 Subject: [PATCH 09/12] Remove changes related to get() bug fix --- doc/source/extending.rst | 2 +- doc/source/whatsnew/v0.24.0.txt | 2 +- pandas/core/arrays/base.py | 4 ++-- pandas/core/indexes/base.py | 10 +++------- pandas/tests/extension/base/getitem.py | 2 +- 5 files changed, 8 insertions(+), 12 deletions(-) diff --git a/doc/source/extending.rst b/doc/source/extending.rst index a45b03bdca1e7..cb56bdb15f2bc 100644 --- a/doc/source/extending.rst +++ b/doc/source/extending.rst @@ -135,7 +135,7 @@ of the class ``MyExtensionElement``, then if the operators are defined for ``MyExtensionElement``, the second approach will automatically define the operators for ``MyExtensionArray``. -A mixin class, :class:`~pandas.api.extensions.ExtensionScalarOpscMixin` supports this second +A mixin class, :class:`~pandas.api.extensions.ExtensionScalarOpsMixin` supports this second approach. If developing an ``ExtensionArray`` subclass, for example ``MyExtensionArray``, simply include ``ExtensionScalarOpsMixin`` as a parent class of ``MyExtensionArray`` and then call the methods :meth:`~MyExtensionArray._add_arithmetic_ops` and/or diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt index 7ec061963e8a2..77f2fa47a5c86 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -31,7 +31,7 @@ subclass. 
For the second approach, which is appropriate if your ExtensionArray contains elements that already have the operators defined on a per-element basis, pandas provides a mixin, -:class:`ExtensionScalarOpsMixin` that you can use that can +:class:`ExtensionScalarOpsMixin` that you can use to define the operators on your ExtensionArray subclass. If developing an ``ExtensionArray`` subclass, for example ``MyExtensionArray``, simply include ``ExtensionScalarOpsMixin`` as a parent class of ``MyExtensionArray`` diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index d8d2f7360266a..60f653825b63c 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -666,8 +666,8 @@ class ExtensionScalarOpsMixin(ExtensionOpsMixin): get the arithmetic operators. After the definition of MyExtensionArray, insert the lines - MyExtensionArray.addArithmeticOperators() - MyExtensionArray.addComparisonOperators() + MyExtensionArray._add_arithmetic_ops() + MyExtensionArray._add_comparison_ops() to link the operators to your class. """ diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index d163a6384263d..4f140a6e77b2f 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -2987,20 +2987,16 @@ def get_value(self, series, key): # use this, e.g. 
DatetimeIndex s = getattr(series, '_values', None) if isinstance(s, (ExtensionArray, Index)) and is_scalar(key): - # GH 20882, 21257 + # GH 20825 # Unify Index and ExtensionArray treatment # First try to convert the key to a location - # If that fails, raise a KeyError if an integer - # index, otherwise, see if key is an integer, and + # If that fails, see if key is an integer, and # try that try: iloc = self.get_loc(key) return s[iloc] except KeyError: - if (len(self) > 0 and - self.inferred_type in ['integer', 'boolean']): - raise - elif is_integer(key): + if is_integer(key): return s[key] s = com._values_from_object(series) diff --git a/pandas/tests/extension/base/getitem.py b/pandas/tests/extension/base/getitem.py index 390971c134642..883b3f5588aef 100644 --- a/pandas/tests/extension/base/getitem.py +++ b/pandas/tests/extension/base/getitem.py @@ -130,7 +130,7 @@ def test_get(self, data): expected = s.iloc[[0, 1]] self.assert_series_equal(result, expected) - assert s.get(-1) is None + assert s.get(-1) == s.iloc[-1] assert s.get(s.index.max() + 1) is None s = pd.Series(data[:6], index=list('abcdef')) From 700d75b8ade7a7449c7d3a3d0e63fcdbce0aaecf Mon Sep 17 00:00:00 2001 From: Dr-Irv Date: Fri, 22 Jun 2018 18:19:01 -0400 Subject: [PATCH 10/12] changes per joris comments --- doc/source/whatsnew/v0.24.0.txt | 4 +-- pandas/core/arrays/base.py | 6 ++--- pandas/tests/extension/base/ops.py | 27 +++++++++++-------- .../extension/category/test_categorical.py | 14 +++++----- .../tests/extension/decimal/test_decimal.py | 9 ++++--- 5 files changed, 33 insertions(+), 27 deletions(-) diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt index 77f2fa47a5c86..18912eb55ff44 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -10,8 +10,6 @@ New features - ``ExcelWriter`` now accepts ``mode`` as a keyword argument, enabling append to existing workbooks when using the ``openpyxl`` engine (:issue:`3441`) -.. 
_whatsnew_0240.enhancements.other: - .. _whatsnew_0240.enhancements.extension_array_operators ``ExtensionArray`` operator support @@ -51,6 +49,8 @@ See the :ref:`ExtensionArray Operator Support ` documentation section for details on both ways of adding operator support. +.. _whatsnew_0240.enhancements.other: + Other Enhancements ^^^^^^^^^^^^^^^^^^ - :func:`to_datetime` now supports the ``%Z`` and ``%z`` directive when passed into ``format`` (:issue:`13486`) diff --git a/pandas/core/arrays/base.py b/pandas/core/arrays/base.py index 60f653825b63c..59ae684403382 100644 --- a/pandas/core/arrays/base.py +++ b/pandas/core/arrays/base.py @@ -682,9 +682,9 @@ def _create_method(cls, op, coerce_to_dtype=True): Parameters ---------- - op: function + op : function An operator that takes arguments op(a, b) - coerce_to_dtype: bool + coerce_to_dtype : bool boolean indicating whether to attempt to convert the result to the underlying ExtensionArray dtype (default True) @@ -736,4 +736,4 @@ def _create_arithmetic_method(cls, op): @classmethod def _create_comparison_method(cls, op): - return cls._create_method(op, False) + return cls._create_method(op, coerce_to_dtype=False) diff --git a/pandas/tests/extension/base/ops.py b/pandas/tests/extension/base/ops.py index 12ca85fad5624..659b9757ac1e3 100644 --- a/pandas/tests/extension/base/ops.py +++ b/pandas/tests/extension/base/ops.py @@ -7,8 +7,7 @@ class BaseOpsUtil(BaseExtensionTests): - def check_opname(self, s, op_name, other, exc=NotImplementedError): - + def get_op_from_name(self, op_name): short_opname = op_name.strip('_') try: op = getattr(operator, short_opname) @@ -17,6 +16,11 @@ def check_opname(self, s, op_name, other, exc=NotImplementedError): rop = getattr(operator, short_opname[1:]) op = lambda x, y: rop(y, x) + return op + + def check_opname(self, s, op_name, other, exc=NotImplementedError): + op = self.get_op_from_name(op_name) + self._check_op(s, op, other, exc) def _check_op(self, s, op, other, 
exc=NotImplementedError): @@ -59,31 +63,32 @@ def test_error(self, data, all_arithmetic_operators): class BaseComparisonOpsTests(BaseOpsUtil): """Various Series and DataFrame comparison ops methods.""" - def _compare_other(self, data, op_name, other): - s = pd.Series(data) - + def _compare_other(self, s, data, op_name, other): + op = self.get_op_from_name(op_name) if op_name == '__eq__': assert getattr(data, op_name)(other) is NotImplemented - assert not getattr(s, op_name)(other).all() + assert not op(s, other).all() elif op_name == '__ne__': assert getattr(data, op_name)(other) is NotImplemented - assert getattr(s, op_name)(other).all() + assert op(s, other).all() else: # array - getattr(data, op_name)(other) is NotImplementedError + assert getattr(data, op_name)(other) is NotImplemented # series s = pd.Series(data) with pytest.raises(TypeError): - getattr(s, op_name)(other) + op(s, other) def test_compare_scalar(self, data, all_compare_operators): op_name = all_compare_operators - self._compare_other(data, op_name, 0) + s = pd.Series(data) + self._compare_other(s, data, op_name, 0) def test_compare_array(self, data, all_compare_operators): op_name = all_compare_operators + s = pd.Series(data) other = [0] * len(data) - self._compare_other(data, op_name, other) + self._compare_other(s, data, op_name, other) diff --git a/pandas/tests/extension/category/test_categorical.py b/pandas/tests/extension/category/test_categorical.py index 71bba1c7abdba..ae0d72c204d13 100644 --- a/pandas/tests/extension/category/test_categorical.py +++ b/pandas/tests/extension/category/test_categorical.py @@ -198,14 +198,14 @@ def test_arith_scalar(self, data, all_arithmetic_operators): class TestComparisonOps(base.BaseComparisonOpsTests): - def _compare_other(self, data, op, other): + def _compare_other(self, s, data, op_name, other): + op = self.get_op_from_name(op_name) + if op_name == '__eq__': + assert not op(data, other).all() - if op == '__eq__': - assert not getattr(data, 
op)(other).all() - - elif op == '__ne__': - assert getattr(data, op)(other).all() + elif op_name == '__ne__': + assert op(data, other).all() else: with pytest.raises(TypeError): - getattr(data, op)(other) + op(data, other) diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 8e0d15b037d11..45ee7f227c4f0 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -235,19 +235,20 @@ def check_opname(self, s, op_name, other, exc=None): super(TestComparisonOps, self).check_opname(s, op_name, other, exc=None) - def _compare_other(self, data, op_name, other): - s = pd.Series(data) + def _compare_other(self, s, data, op_name, other): self.check_opname(s, op_name, other) def test_compare_scalar(self, data, all_compare_operators): op_name = all_compare_operators - self._compare_other(data, op_name, 0.5) + s = pd.Series(data) + self._compare_other(s, data, op_name, 0.5) def test_compare_array(self, data, all_compare_operators): op_name = all_compare_operators + s = pd.Series(data) alter = np.random.choice([-1, 0, 1], len(data)) # Randomly double, halve or keep same value other = pd.Series(data) * [decimal.Decimal(pow(2.0, i)) for i in alter] - self._compare_other(data, op_name, other) + self._compare_other(s, data, op_name, other) From 97bd2910b6daf5763687c2ad4c0a64e5ff50d9c1 Mon Sep 17 00:00:00 2001 From: Jeff Reback Date: Thu, 28 Jun 2018 20:13:09 -0400 Subject: [PATCH 11/12] doc & clean --- doc/source/extending.rst | 36 +++++++++++++++++---------------- doc/source/whatsnew/v0.24.0.txt | 22 ++++++++++---------- pandas/conftest.py | 24 +++++++--------------- 3 files changed, 37 insertions(+), 45 deletions(-) diff --git a/doc/source/extending.rst b/doc/source/extending.rst index cb56bdb15f2bc..38b3b19031a0e 100644 --- a/doc/source/extending.rst +++ b/doc/source/extending.rst @@ -118,41 +118,43 @@ and comments contain guidance for properly implementing the 
interface. :class:`~pandas.api.extensions.ExtensionArray` Operator Support ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +.. versionadded:: 0.24.0 + By default, there are no operators defined for the class :class:`~pandas.api.extensions.ExtensionArray`. There are two approaches for providing operator support for your ExtensionArray: -1. Define each of the operators on your ExtensionArray subclass. +1. Define each of the operators on your ``ExtensionArray`` subclass. 2. Use an operator implementation from pandas that depends on operators that are already defined on the underlying elements (scalars) of the ExtensionArray. - -For the first approach, you define selected operators, e.g., ``_add__``, ``__le__``, etc. that -you want your ExtensionArray subclass to support. - -The second approach assumes that the underlying elements (i.e., scalar type) of the ExtensionArray -have the individual operators already defined. In other words, if your ExtensionArray -named ``MyExtensionArray`` is implemented so that each element is an instance -of the class ``MyExtensionElement``, then if the operators are defined + +For the first approach, you define selected operators, e.g., ``__add__``, ``__le__``, etc. that +you want your ``ExtensionArray`` subclass to support. + +The second approach assumes that the underlying elements (i.e., scalar type) of the ``ExtensionArray`` +have the individual operators already defined. In other words, if your ``ExtensionArray`` +named ``MyExtensionArray`` is implemented so that each element is an instance +of the class ``MyExtensionElement``, then if the operators are defined for ``MyExtensionElement``, the second approach will automatically define the operators for ``MyExtensionArray``. A mixin class, :class:`~pandas.api.extensions.ExtensionScalarOpsMixin` supports this second approach. 
If developing an ``ExtensionArray`` subclass, for example ``MyExtensionArray``, -simply include ``ExtensionScalarOpsMixin`` as a parent class of ``MyExtensionArray`` +you can simply include ``ExtensionScalarOpsMixin`` as a parent class of ``MyExtensionArray``, and then call the methods :meth:`~MyExtensionArray._add_arithmetic_ops` and/or -:meth:`~MyExtensionArray._add_comparison_ops` to hook the operators into +:meth:`~MyExtensionArray._add_comparison_ops` to hook the operators into your ``MyExtensionArray`` class, as follows: .. code-block:: python class MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin): pass - + MyExtensionArray._add_arithmetic_ops() MyExtensionArray._add_comparison_ops() Note that since ``pandas`` automatically calls the underlying operator on each element one-by-one, this might not be as performant as implementing your own -version of the associated operators directly on the ExtensionArray. +version of the associated operators directly on the ``ExtensionArray``. .. _extending.extension.testing: @@ -220,11 +222,11 @@ There are 3 constructor properties to be defined: Following table shows how ``pandas`` data structures define constructor properties by default. =========================== ======================= ============= -Property Attributes ``Series`` ``DataFrame`` +Property Attributes ``Series`` ``DataFrame`` =========================== ======================= ============= -``_constructor`` ``Series`` ``DataFrame`` -``_constructor_sliced`` ``NotImplementedError`` ``Series`` -``_constructor_expanddim`` ``DataFrame`` ``Panel`` +``_constructor`` ``Series`` ``DataFrame`` +``_constructor_sliced`` ``NotImplementedError`` ``Series`` +``_constructor_expanddim`` ``DataFrame`` ``Panel`` =========================== ======================= ============= Below example shows how to define ``SubclassedSeries`` and ``SubclassedDataFrame`` overriding constructor properties.
diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt index e4bc478d8a739..b6008a6e8f135 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -15,38 +15,38 @@ New features ``ExtensionArray`` operator support ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -A ``Series`` based on ``ExtensionArray`` now supports arithmetic and comparison -operators. There are two approaches for providing operator support for an ExtensionArray: +A ``Series`` based on an ``ExtensionArray`` now supports arithmetic and comparison +operators. There are two approaches for providing operator support for an ExtensionArray: -1. Define each of the operators on your ExtensionArray subclass. +1. Define each of the operators on your ``ExtensionArray`` subclass. 2. Use an operator implementation from pandas that depends on operators that are already defined - on the underlying elements (scalars) of the ExtensionArray. + on the underlying elements (scalars) of the ``ExtensionArray``. -To use the first approach where you define your own implementation of the operators, -you define each operator such as `__add__`, __le__`, etc. on your ExtensionArray +To use the first approach, where you define your own implementation of the operators, +you define each operator such as `__add__`, __le__`, etc. on your ``ExtensionArray`` subclass. -For the second approach, which is appropriate if your ExtensionArray contains +For the second approach, which is appropriate if your ``ExtensionArray`` contains elements that already have the operators defined on a per-element basis, pandas provides a mixin, :class:`ExtensionScalarOpsMixin` that you can use to define the operators on your ExtensionArray subclass. 
If developing an ``ExtensionArray`` subclass, for example ``MyExtensionArray``, -simply include ``ExtensionScalarOpsMixin`` as a parent class of ``MyExtensionArray`` + an simply include ``ExtensionScalarOpsMixin``, as a parent class of ``MyExtensionArray`` and then call the methods :meth:`~MyExtensionArray._add_arithmetic_ops` and/or -:meth:`~MyExtensionArray._add_comparison_ops` to hook the operators into +:meth:`~MyExtensionArray._add_comparison_ops` to hook the operators into your ``MyExtensionArray`` class, as follows: .. code-block:: python class MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin): pass - + MyExtensionArray._add_arithmetic_ops() MyExtensionArray._add_comparison_ops() See the :ref:`ExtensionArray Operator Support -` documentation section for details on both +` documentation section for details on both ways of adding operator support. .. _whatsnew_0240.enhancements.other: diff --git a/pandas/conftest.py b/pandas/conftest.py index 6bcc5a66b6014..8ca90722d17f7 100644 --- a/pandas/conftest.py +++ b/pandas/conftest.py @@ -108,6 +108,13 @@ def all_arithmetic_operators(request): def all_compare_operators(request): """ Fixture for dunder names for common compare operations + + * >= + * > + * == + * != + * < + * <= """ return request.param @@ -330,20 +337,3 @@ def mock(): return importlib.import_module("unittest.mock") else: return pytest.importorskip("mock") - - -@pytest.fixture(params=['__eq__', '__ne__', '__le__', - '__lt__', '__ge__', '__gt__']) -def all_compare_operators(request): - """ - Fixture for dunder names for common compare operations - - * >= - * > - * == - * != - * < - * <= - """ - - return request.param From 8fc93e49a28eccc4c66294e7238ba63f0e7bda29 Mon Sep 17 00:00:00 2001 From: Jeff Reback Date: Thu, 28 Jun 2018 20:18:28 -0400 Subject: [PATCH 12/12] moar doc --- doc/source/whatsnew/v0.24.0.txt | 25 +------------------------ 1 file changed, 1 insertion(+), 24 deletions(-) diff --git a/doc/source/whatsnew/v0.24.0.txt 
b/doc/source/whatsnew/v0.24.0.txt index b6008a6e8f135..2b38e7b1d5cc3 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -16,35 +16,12 @@ New features ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ A ``Series`` based on an ``ExtensionArray`` now supports arithmetic and comparison -operators. There are two approaches for providing operator support for an ExtensionArray: +operators (:issue:`19577`). There are two approaches for providing operator support for an ``ExtensionArray``: 1. Define each of the operators on your ``ExtensionArray`` subclass. 2. Use an operator implementation from pandas that depends on operators that are already defined on the underlying elements (scalars) of the ``ExtensionArray``. -To use the first approach, where you define your own implementation of the operators, -you define each operator such as `__add__`, __le__`, etc. on your ``ExtensionArray`` -subclass. - -For the second approach, which is appropriate if your ``ExtensionArray`` contains -elements that already have the operators -defined on a per-element basis, pandas provides a mixin, -:class:`ExtensionScalarOpsMixin` that you can use to -define the operators on your ExtensionArray subclass. -If developing an ``ExtensionArray`` subclass, for example ``MyExtensionArray``, - an simply include ``ExtensionScalarOpsMixin``, as a parent class of ``MyExtensionArray`` -and then call the methods :meth:`~MyExtensionArray._add_arithmetic_ops` and/or -:meth:`~MyExtensionArray._add_comparison_ops` to hook the operators into -your ``MyExtensionArray`` class, as follows: - -.. code-block:: python - - class MyExtensionArray(ExtensionArray, ExtensionScalarOpsMixin): - pass - - MyExtensionArray._add_arithmetic_ops() - MyExtensionArray._add_comparison_ops() - See the :ref:`ExtensionArray Operator Support <extending.extension.operator>` documentation section for details on both ways of adding operator support.