Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: extension ops #21191

Closed
wants to merge 12 commits into from
Closed
12 changes: 11 additions & 1 deletion doc/source/whatsnew/v0.24.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,22 @@ Datetimelike API Changes

- For :class:`DatetimeIndex` and :class:`TimedeltaIndex` with non-``None`` ``freq`` attribute, addition or subtraction of integer-dtyped array or ``Index`` will return an object of the same class (:issue:`19959`)

.. _whatsnew_0240.api.extension:

ExtensionType Changes
^^^^^^^^^^^^^^^^^^^^^

- ``ExtensionArray`` has gained the abstract methods ``.dropna()`` and ``.append()`` (:issue:`21185`)
- ``ExtensionDtype`` has gained the ability to instantiate from string dtypes, e.g. ``decimal`` would instantiate a registered ``DecimalDtype``; furthermore
the dtype has gained the ``construct_array_type`` (:issue:`21185`)
- The ``ExtensionArray`` constructor, ``_from_sequence`` now take the keyword arg ``copy=False`` (:issue:`21185`)

.. _whatsnew_0240.api.other:

Other API Changes
^^^^^^^^^^^^^^^^^

-
- Invalid consruction of ``IntervalDtype`` will now always raise a ``TypeError`` rather than a ``ValueError`` if the subdtype is invalid (:issue:`21185`)
-
-

Expand Down
9 changes: 9 additions & 0 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ def observed(request):
def all_arithmetic_operators(request):
"""
Fixture for dunder names for common arithmetic operations
"""
return request.param


@pytest.fixture(params=['__eq__', '__ne__', '__le__',
'__lt__', '__ge__', '__gt__'])
def all_compare_operators(request):
"""
Fixture for dunder names for common compare operations
"""
return request.param

Expand Down
4 changes: 2 additions & 2 deletions pandas/core/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _reconstruct_data(values, dtype, original):
"""
from pandas import Index
if is_extension_array_dtype(dtype):
pass
values = dtype.construct_array_type(values)._from_sequence(values)
elif is_datetime64tz_dtype(dtype) or is_period_dtype(dtype):
values = Index(original)._shallow_copy(values, name=None)
elif is_bool_dtype(dtype):
Expand Down Expand Up @@ -705,7 +705,7 @@ def value_counts(values, sort=True, ascending=False, normalize=False,

else:

if is_categorical_dtype(values) or is_sparse(values):
if is_extension_array_dtype(values) or is_sparse(values):

# handle Categorical and sparse,
result = Series(values)._values.value_counts(dropna=dropna)
Expand Down
100 changes: 99 additions & 1 deletion pandas/core/arrays/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

from pandas.errors import AbstractMethodError
from pandas.compat.numpy import function as nv
from pandas.compat import set_function_name, PY3
from pandas.core import ops
import operator

_not_implemented_message = "{} does not implement {}."

Expand Down Expand Up @@ -36,6 +39,7 @@ class ExtensionArray(object):
* isna
* take
* copy
* append
* _concat_same_type

An additional method is available to satisfy pandas' internal,
Expand All @@ -49,6 +53,7 @@ class ExtensionArray(object):
methods:

* fillna
* dropna
* unique
* factorize / _values_for_factorize
* argsort / _values_for_argsort
Expand Down Expand Up @@ -82,14 +87,16 @@ class ExtensionArray(object):
# Constructors
# ------------------------------------------------------------------------
@classmethod
def _from_sequence(cls, scalars):
def _from_sequence(cls, scalars, copy=False):
"""Construct a new ExtensionArray from a sequence of scalars.

Parameters
----------
scalars : Sequence
Each element will be an instance of the scalar type for this
array, ``cls.dtype.type``.
copy : boolean, default True
if True, copy the underlying data
Returns
-------
ExtensionArray
Expand Down Expand Up @@ -379,6 +386,16 @@ def fillna(self, value=None, method=None, limit=None):
new_values = self.copy()
return new_values

def dropna(self):
""" Return ExtensionArray without NA values

Returns
-------
valid : ExtensionArray
"""

return self[~self.isna()]

def unique(self):
"""Compute the ExtensionArray of unique values.

Expand Down Expand Up @@ -567,6 +584,34 @@ def copy(self, deep=False):
"""
raise AbstractMethodError(self)

def append(self, other):
"""
Append a collection of Arrays together

Parameters
----------
other : ExtensionArray or list/tuple of ExtensionArrays

Returns
-------
appended : ExtensionArray
"""

to_concat = [self]
cls = self.__class__

if isinstance(other, (list, tuple)):
to_concat = to_concat + list(other)
else:
to_concat.append(other)

for obj in to_concat:
if not isinstance(obj, cls):
raise TypeError('all inputs must be of type {}'.format(
cls.__name__))

return cls._concat_same_type(to_concat)

# ------------------------------------------------------------------------
# Block-related methods
# ------------------------------------------------------------------------
Expand Down Expand Up @@ -610,3 +655,56 @@ def _ndarray_values(self):
used for interacting with our indexers.
"""
return np.array(self)

# ------------------------------------------------------------------------
# ops-related methods
# ------------------------------------------------------------------------

@classmethod
def _add_comparison_methods_binary(cls):
cls.__eq__ = cls._make_comparison_op(operator.eq)
cls.__ne__ = cls._make_comparison_op(operator.ne)
cls.__lt__ = cls._make_comparison_op(operator.lt)
cls.__gt__ = cls._make_comparison_op(operator.gt)
cls.__le__ = cls._make_comparison_op(operator.le)
cls.__ge__ = cls._make_comparison_op(operator.ge)

@classmethod
def _add_numeric_methods_binary(cls):
""" add in numeric methods """
cls.__add__ = cls._make_arithmetic_op(operator.add)
cls.__radd__ = cls._make_arithmetic_op(ops.radd)
cls.__sub__ = cls._make_arithmetic_op(operator.sub)
cls.__rsub__ = cls._make_arithmetic_op(ops.rsub)
cls.__mul__ = cls._make_arithmetic_op(operator.mul)
cls.__rmul__ = cls._make_arithmetic_op(ops.rmul)
cls.__rpow__ = cls._make_arithmetic_op(ops.rpow)
cls.__pow__ = cls._make_arithmetic_op(operator.pow)
cls.__mod__ = cls._make_arithmetic_op(operator.mod)
cls.__rmod__ = cls._make_arithmetic_op(ops.rmod)
cls.__floordiv__ = cls._make_arithmetic_op(operator.floordiv)
cls.__rfloordiv__ = cls._make_arithmetic_op(ops.rfloordiv)
cls.__truediv__ = cls._make_arithmetic_op(operator.truediv)
cls.__rtruediv__ = cls._make_arithmetic_op(ops.rtruediv)
if not PY3:
cls.__div__ = cls._make_arithmetic_op(operator.div)
cls.__rdiv__ = cls._make_arithmetic_op(ops.rdiv)

cls.__divmod__ = cls._make_arithmetic_op(divmod)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rdivmod?

cls.__rdivmod__ = cls._make_arithmetic_op(ops.rdivmod)

@classmethod
def make_comparison_op(cls, op):
def cmp_method(self, other):
raise NotImplementedError

name = '__{name}__'.format(name=op.__name__)
return set_function_name(cmp_method, name, cls)

@classmethod
def make_arithmetic_op(cls, op):
def integer_arithmetic_method(self, other):
raise NotImplementedError

name = '__{name}__'.format(name=op.__name__)
return set_function_name(integer_arithmetic_method, name, cls)
21 changes: 21 additions & 0 deletions pandas/core/dtypes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ class ExtensionDtype(_DtypeOpsMixin):
* name
* construct_from_string

Optionally one can override construct_array_type for construction
with the name of this dtype via the Registry

* construct_array_type

The `na_value` class attribute can be used to set the default NA value
for this type. :attr:`numpy.nan` is used by default.

Expand Down Expand Up @@ -156,6 +161,22 @@ def name(self):
"""
raise AbstractMethodError(self)

@classmethod
def construct_array_type(cls, array=None):
"""Return the array type associated with this dtype

Parameters
----------
array : array-like, optional

Returns
-------
type
"""
if array is None:
return cls
raise NotImplementedError

@classmethod
def construct_from_string(cls, string):
"""Attempt to construct this type from a string.
Expand Down
5 changes: 5 additions & 0 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,11 @@ def conv(r, dtype):
def astype_nansafe(arr, dtype, copy=True):
""" return a view if copy is False, but
need to be very careful as the result shape could change! """

# dispatch on extension dtype if needed
if is_extension_array_dtype(dtype):
return dtype.array_type._from_sequence(arr, copy=copy)

if not isinstance(dtype, np.dtype):
dtype = pandas_dtype(dtype)

Expand Down
39 changes: 7 additions & 32 deletions pandas/core/dtypes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
DatetimeTZDtype, DatetimeTZDtypeType,
PeriodDtype, PeriodDtypeType,
IntervalDtype, IntervalDtypeType,
ExtensionDtype, PandasExtensionDtype)
ExtensionDtype, registry)
from .generic import (ABCCategorical, ABCPeriodIndex,
ABCDatetimeIndex, ABCSeries,
ABCSparseArray, ABCSparseSeries, ABCCategoricalIndex,
Expand Down Expand Up @@ -1975,38 +1975,13 @@ def pandas_dtype(dtype):
np.dtype or a pandas dtype
"""

if isinstance(dtype, DatetimeTZDtype):
return dtype
elif isinstance(dtype, PeriodDtype):
return dtype
elif isinstance(dtype, CategoricalDtype):
return dtype
elif isinstance(dtype, IntervalDtype):
return dtype
elif isinstance(dtype, string_types):
try:
return DatetimeTZDtype.construct_from_string(dtype)
except TypeError:
pass

if dtype.startswith('period[') or dtype.startswith('Period['):
# do not parse string like U as period[U]
try:
return PeriodDtype.construct_from_string(dtype)
except TypeError:
pass

elif dtype.startswith('interval') or dtype.startswith('Interval'):
try:
return IntervalDtype.construct_from_string(dtype)
except TypeError:
pass
# registered extension types
result = registry.find(dtype)
if result is not None:
return result

try:
return CategoricalDtype.construct_from_string(dtype)
except TypeError:
pass
elif isinstance(dtype, (PandasExtensionDtype, ExtensionDtype)):
# un-registered extension types
if isinstance(dtype, ExtensionDtype):
return dtype

try:
Expand Down
Loading