Skip to content

Commit

Permalink
ENH: add ops to extension array
Browse files Browse the repository at this point in the history
  • Loading branch information
jreback committed May 24, 2018
1 parent a854f06 commit 338ea4d
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 11 deletions.
9 changes: 9 additions & 0 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,16 @@ def observed(request):
def all_arithmetic_operators(request):
"""
Fixture for dunder names for common arithmetic operations
"""
return request.param


@pytest.fixture(params=['__{}__'.format(short_name)
                        for short_name in ('eq', 'ne', 'le',
                                           'lt', 'ge', 'gt')])
def all_compare_operators(request):
    """
    Fixture yielding the dunder name of each rich-comparison operator.

    Parametrized, in order, over ``__eq__``, ``__ne__``, ``__le__``,
    ``__lt__``, ``__ge__`` and ``__gt__``.
    """
    return request.param


Expand Down
54 changes: 54 additions & 0 deletions pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

from pandas.errors import AbstractMethodError
from pandas.compat.numpy import function as nv
from pandas.compat import set_function_name, PY3
from pandas.core import ops
import operator

_not_implemented_message = "{} does not implement {}."

Expand Down Expand Up @@ -652,3 +655,54 @@ def _ndarray_values(self):
used for interacting with our indexers.
"""
return np.array(self)

# ------------------------------------------------------------------------
# ops-related methods
# ------------------------------------------------------------------------

@classmethod
def _add_comparison_methods_binary(cls):
    """
    Attach the six rich-comparison dunder methods to ``cls``.

    Each method is produced by ``cls.make_comparison_op`` from the
    matching function of the ``operator`` module.
    """
    # The factory defined on this class is ``make_comparison_op``; there
    # is no ``_make_comparison_op`` attribute, so that spelling raised
    # AttributeError as soon as this method was called.
    cls.__eq__ = cls.make_comparison_op(operator.eq)
    cls.__ne__ = cls.make_comparison_op(operator.ne)
    cls.__lt__ = cls.make_comparison_op(operator.lt)
    cls.__gt__ = cls.make_comparison_op(operator.gt)
    cls.__le__ = cls.make_comparison_op(operator.le)
    cls.__ge__ = cls.make_comparison_op(operator.ge)

@classmethod
def _add_numeric_methods_binary(cls):
    """
    Attach the binary arithmetic dunder methods to ``cls``.

    Each method is produced by ``cls.make_arithmetic_op``: forward
    operators come from the ``operator`` module (and the ``divmod``
    builtin), reflected operators from ``ops``.
    """
    # The factory defined on this class is ``make_arithmetic_op``; there
    # is no ``_make_arithmetic_op`` attribute, so that spelling raised
    # AttributeError as soon as this method was called.
    cls.__add__ = cls.make_arithmetic_op(operator.add)
    cls.__radd__ = cls.make_arithmetic_op(ops.radd)
    cls.__sub__ = cls.make_arithmetic_op(operator.sub)
    cls.__rsub__ = cls.make_arithmetic_op(ops.rsub)
    cls.__mul__ = cls.make_arithmetic_op(operator.mul)
    cls.__rmul__ = cls.make_arithmetic_op(ops.rmul)
    cls.__rpow__ = cls.make_arithmetic_op(ops.rpow)
    cls.__pow__ = cls.make_arithmetic_op(operator.pow)
    cls.__mod__ = cls.make_arithmetic_op(operator.mod)
    cls.__floordiv__ = cls.make_arithmetic_op(operator.floordiv)
    cls.__rfloordiv__ = cls.make_arithmetic_op(ops.rfloordiv)
    cls.__truediv__ = cls.make_arithmetic_op(operator.truediv)
    cls.__rtruediv__ = cls.make_arithmetic_op(ops.rtruediv)
    if not PY3:
        # classic division is only wired up when not running on Python 3
        cls.__div__ = cls.make_arithmetic_op(operator.div)
        cls.__rdiv__ = cls.make_arithmetic_op(ops.rdiv)

    cls.__divmod__ = cls.make_arithmetic_op(divmod)

@classmethod
def make_comparison_op(cls, op):
    """
    Build the placeholder comparison method for ``op``.

    The returned function always raises NotImplementedError. It is
    passed through ``set_function_name`` with the name
    ``__<op.__name__>__`` and ``cls``.
    """
    dunder = '__{}__'.format(op.__name__)

    def cmp_method(self, other):
        raise NotImplementedError

    return set_function_name(cmp_method, dunder, cls)

@classmethod
def make_arithmetic_op(cls, op):
    """
    Build the placeholder arithmetic method for ``op``.

    The returned function always raises NotImplementedError. It is
    passed through ``set_function_name`` with the name
    ``__<op.__name__>__`` and ``cls``.
    """
    dunder = '__{}__'.format(op.__name__)

    def integer_arithmetic_method(self, other):
        raise NotImplementedError

    return set_function_name(integer_arithmetic_method, dunder, cls)
3 changes: 2 additions & 1 deletion pandas/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,8 @@ def fill_zeros(result, x, y, name, fill):
# if we have a fill of inf, then sign it correctly
# (GH 6178 and PR 9308)
if np.isinf(fill):
signs = np.sign(y if name.startswith(('r', '__r')) else x)
signs = y if name.startswith(('r', '__r')) else x
signs = np.sign(signs.astype('float', copy=False))
negative_inf_mask = (signs.ravel() < 0) & mask
np.putmask(result, negative_inf_mask, -fill)

Expand Down
38 changes: 30 additions & 8 deletions pandas/core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
is_integer_dtype, is_categorical_dtype,
is_object_dtype, is_timedelta64_dtype,
is_datetime64_dtype, is_datetime64tz_dtype,
is_bool_dtype,
is_bool_dtype, is_extension_array_dtype,
is_list_like,
is_scalar,
_ensure_object)
Expand Down Expand Up @@ -1003,8 +1003,18 @@ def _arith_method_SERIES(cls, op, special):
if op is divmod else _construct_result)

def na_op(x, y):
import pandas.core.computation.expressions as expressions
# handle extension array ops
# TODO(extension)
# the ops *between* non-same-type extension arrays are not
# very well defined
if (is_extension_array_dtype(x) or is_extension_array_dtype(y)):
if (op_name.startswith('__r') and not
is_extension_array_dtype(y) and not
is_scalar(y)):
y = x.__class__._from_sequence(y)
return op(x, y)

import pandas.core.computation.expressions as expressions
try:
result = expressions.evaluate(op, str_rep, x, y, **eval_kwargs)
except TypeError:
Expand All @@ -1025,6 +1035,7 @@ def na_op(x, y):
return result

def safe_na_op(lvalues, rvalues):
# all others
try:
with np.errstate(all='ignore'):
return na_op(lvalues, rvalues)
Expand All @@ -1035,14 +1046,21 @@ def safe_na_op(lvalues, rvalues):
raise

def wrapper(left, right):

if isinstance(right, ABCDataFrame):
return NotImplemented

left, right = _align_method_SERIES(left, right)
res_name = get_op_result_name(left, right)

if is_datetime64_dtype(left) or is_datetime64tz_dtype(left):
if is_categorical_dtype(left):
raise TypeError("{typ} cannot perform the operation "
"{op}".format(typ=type(left).__name__, op=str_rep))

elif (is_extension_array_dtype(left) or
is_extension_array_dtype(right)):
pass

elif is_datetime64_dtype(left) or is_datetime64tz_dtype(left):
result = dispatch_to_index_op(op, left, right, pd.DatetimeIndex)
return construct_result(left, result,
index=left.index, name=res_name,
Expand All @@ -1054,10 +1072,6 @@ def wrapper(left, right):
index=left.index, name=res_name,
dtype=result.dtype)

elif is_categorical_dtype(left):
raise TypeError("{typ} cannot perform the operation "
"{op}".format(typ=type(left).__name__, op=str_rep))

lvalues = left.values
rvalues = right
if isinstance(rvalues, ABCSeries):
Expand Down Expand Up @@ -1136,6 +1150,14 @@ def na_op(x, y):
# The `not is_scalar(y)` check excludes the string "category"
return op(y, x)

# handle extension array ops
# TODO(extension)
# the ops *between* non-same-type extension arrays are not
# very well defined
elif (is_extension_array_dtype(x) or
is_extension_array_dtype(y)):
return op(x, y)

elif is_object_dtype(x.dtype):
result = _comp_method_OBJECT_ARRAY(op, x, y)

Expand Down
54 changes: 54 additions & 0 deletions pandas/tests/extension/base/ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest
import numpy as np
import pandas as pd
from .base import BaseExtensionTests


class BaseOpsTests(BaseExtensionTests):
    """Various Series and DataFrame ops methods."""

    def compare(self, s, op, other, exc=NotImplementedError):
        """
        Assert that calling the dunder ``op`` on ``s`` with ``other``
        raises ``exc``.
        """
        with pytest.raises(exc):
            getattr(s, op)(other)

    def test_arith_scalar(self, data, all_arithmetic_operators):
        # scalar
        op = all_arithmetic_operators
        s = pd.Series(data)
        self.compare(s, op, 1, exc=TypeError)

    def test_arith_array(self, data, all_arithmetic_operators):
        # ndarray & other series
        op = all_arithmetic_operators
        s = pd.Series(data)
        self.compare(s, op, np.ones(len(s), dtype=s.dtype.type), exc=TypeError)

    def test_compare_scalar(self, data, all_compare_operators):
        op = all_compare_operators
        s = pd.Series(data)

        # compare with ``==``: ``op in '__eq__'`` is a substring test and
        # only matched the intended operator by accident
        if op == '__eq__':
            assert getattr(data, op)(0) is NotImplemented
            assert not getattr(s, op)(0).all()
        elif op == '__ne__':
            assert getattr(data, op)(0) is NotImplemented
            assert getattr(s, op)(0).all()
        else:
            # array: expected to return NotImplemented, mirroring the
            # __eq__ / __ne__ branches (this used to be a bare expression
            # compared against the NotImplementedError class, which
            # verified nothing)
            assert getattr(data, op)(0) is NotImplemented

            # series
            with pytest.raises(TypeError):
                getattr(s, op)(0)

    def test_error(self, data, all_arithmetic_operators):
        # invalid ops: the array itself must not expose the dunder
        op = all_arithmetic_operators
        with pytest.raises(AttributeError):
            getattr(data, op)
14 changes: 13 additions & 1 deletion pandas/tests/extension/category/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,19 @@ class TestDtype(base.BaseDtypeTests):


class TestOps(base.BaseOpsTests):

    def test_compare_scalar(self, data, all_compare_operators):
        """
        Compare the array against the scalar ``0``.

        ``__eq__`` must be False somewhere, ``__ne__`` must be True
        everywhere, and every ordering comparison must raise TypeError.
        """
        op = all_compare_operators

        if op not in ('__eq__', '__ne__'):
            # ordering comparisons
            with pytest.raises(TypeError):
                getattr(data, op)(0)
            return

        outcome = getattr(data, op)(0)
        if op == '__eq__':
            assert not outcome.all()
        else:
            assert outcome.all()


class TestInterface(base.BaseInterfaceTests):
Expand Down
33 changes: 32 additions & 1 deletion pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,38 @@ class TestInterface(BaseDecimal, base.BaseInterfaceTests):


class TestOps(BaseDecimal, base.BaseOpsTests):

    def compare(self, s, op, other):
        """Placeholder comparison helper; currently xfails immediately."""
        # TODO(extension)
        # pytest.xfail raises, so the statements below do not run yet
        pytest.xfail("not implemented")

        result = getattr(s, op)(other)
        self.assert_series_equal(result, result)

    def test_arith_scalar(self, data, all_arithmetic_operators):
        # scalar
        self.compare(pd.Series(data), all_arithmetic_operators, 1)

    def test_arith_array(self, data, all_arithmetic_operators):
        # ndarray & other series
        s = pd.Series(data)
        other = np.ones(len(s), dtype=s.dtype.type)
        self.compare(s, all_arithmetic_operators, other)

    @pytest.mark.xfail(reason="Not implemented")
    def test_compare_scalar(self, data, all_compare_operators):
        op = all_compare_operators

        # array
        actual = getattr(data, op)(0)
        wanted = getattr(data.data, op)(0)

        tm.assert_series_equal(actual, wanted)


class TestConstructors(BaseDecimal, base.BaseConstructorsTests):
Expand Down

0 comments on commit 338ea4d

Please sign in to comment.