From d43e86acd3b8412fbf0f7a7ec0b9c47cb63a3df0 Mon Sep 17 00:00:00 2001 From: Jeff Reback Date: Sun, 13 May 2018 15:21:34 -0400 Subject: [PATCH] add in extension dtype registry --- pandas/core/dtypes/common.py | 39 ++-------- pandas/core/dtypes/dtypes.py | 80 +++++++++++++++++++++ pandas/core/series.py | 1 + pandas/tests/dtypes/test_dtypes.py | 23 +++++- pandas/tests/extension/base/constructors.py | 12 ++++ 5 files changed, 120 insertions(+), 35 deletions(-) diff --git a/pandas/core/dtypes/common.py b/pandas/core/dtypes/common.py index c45838e6040a98..4d9846b3518145 100644 --- a/pandas/core/dtypes/common.py +++ b/pandas/core/dtypes/common.py @@ -9,7 +9,7 @@ DatetimeTZDtype, DatetimeTZDtypeType, PeriodDtype, PeriodDtypeType, IntervalDtype, IntervalDtypeType, - ExtensionDtype, PandasExtensionDtype) + ExtensionDtype, registry) from .generic import (ABCCategorical, ABCPeriodIndex, ABCDatetimeIndex, ABCSeries, ABCSparseArray, ABCSparseSeries, ABCCategoricalIndex, @@ -1975,39 +1975,10 @@ def pandas_dtype(dtype): np.dtype or a pandas dtype """ - if isinstance(dtype, DatetimeTZDtype): - return dtype - elif isinstance(dtype, PeriodDtype): - return dtype - elif isinstance(dtype, CategoricalDtype): - return dtype - elif isinstance(dtype, IntervalDtype): - return dtype - elif isinstance(dtype, string_types): - try: - return DatetimeTZDtype.construct_from_string(dtype) - except TypeError: - pass - - if dtype.startswith('period[') or dtype.startswith('Period['): - # do not parse string like U as period[U] - try: - return PeriodDtype.construct_from_string(dtype) - except TypeError: - pass - - elif dtype.startswith('interval') or dtype.startswith('Interval'): - try: - return IntervalDtype.construct_from_string(dtype) - except TypeError: - pass - - try: - return CategoricalDtype.construct_from_string(dtype) - except TypeError: - pass - elif isinstance(dtype, (PandasExtensionDtype, ExtensionDtype)): - return dtype + # registered extension types + result = registry.find(dtype) + if result is not None: + return result try: npdtype = np.dtype(dtype) diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 708f54f5ca75ba..795f8ec54f3d57 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -2,12 +2,64 @@ import re import numpy as np +from collections import OrderedDict from pandas import compat from pandas.core.dtypes.generic import ABCIndexClass, ABCCategoricalIndex from .base import ExtensionDtype, _DtypeOpsMixin +class Registry: + """ class to register our dtypes for inference + + We can directly construct dtypes in pandas_dtypes if they are + a type; the registry allows us to register an extension dtype + to try inference from a string or a dtype class + + These are tried in order for inference. + """ + dtypes = OrderedDict() + + @classmethod + def register(self, dtype, constructor=None): + """ + Parameters + ---------- + dtype : PandasExtension Dtype + """ + if not issubclass(dtype, PandasExtensionDtype): + raise ValueError("can only register pandas extension dtypes") + + if constructor is None: + constructor = dtype.construct_from_string + + self.dtypes[dtype] = constructor + + def find(self, dtype): + """ + Parameters + ---------- + dtype : PandasExtensionDtype or string + + Returns + ------- + return the first matching dtype, otherwise return None + """ + for dtype_type, constructor in self.dtypes.items(): + if isinstance(dtype, dtype_type): + return dtype + if isinstance(dtype, compat.string_types): + try: + return constructor(dtype) + except TypeError: + pass + + return None + + +registry = Registry() + + class PandasExtensionDtype(_DtypeOpsMixin): """ A np.dtype duck-typed class, suitable for holding a custom dtype. @@ -564,6 +616,17 @@ def construct_from_string(cls, string): pass raise TypeError("could not construct PeriodDtype") + @classmethod + def construct_from_string_strict(cls, string): + """ + Strict construction from a string, raise a TypeError if not + possible + """ + if string.startswith('period[') or string.startswith('Period['): + # do not parse string like U as period[U] + return PeriodDtype.construct_from_string(string) + raise TypeError("could not construct PeriodDtype") + def __unicode__(self): return "period[{freq}]".format(freq=self.freq.freqstr) @@ -683,6 +746,16 @@ def construct_from_string(cls, string): msg = "a string needs to be passed, got type {typ}" raise TypeError(msg.format(typ=type(string))) + @classmethod + def construct_from_string_strict(cls, string): + """ + Strict construction from a string, raise a TypeError if not + possible + """ + if string.startswith('interval') or string.startswith('Interval'): + return IntervalDtype.construct_from_string(string) + raise TypeError("cannot construct IntervalDtype") + def __unicode__(self): if self.subtype is None: return "interval" @@ -723,3 +796,10 @@ def is_dtype(cls, dtype): else: return False return super(IntervalDtype, cls).is_dtype(dtype) + + +# register the dtypes in search order +registry.register(DatetimeTZDtype) +registry.register(PeriodDtype, PeriodDtype.construct_from_string_strict) +registry.register(IntervalDtype, IntervalDtype.construct_from_string_strict) +registry.register(CategoricalDtype) diff --git a/pandas/core/series.py b/pandas/core/series.py index 0e2ae22f35af7b..0e5005ac562758 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -4060,6 +4060,7 @@ def _try_cast(arr, take_fast_path): "Pass the extension array directly.".format(dtype)) raise ValueError(msg) + elif dtype is not None and raise_cast_failure: raise else: diff --git a/pandas/tests/dtypes/test_dtypes.py b/pandas/tests/dtypes/test_dtypes.py index cc833af03ae66d..6c353283ba2db8 100644 --- a/pandas/tests/dtypes/test_dtypes.py +++ b/pandas/tests/dtypes/test_dtypes.py @@ -12,7 +12,7 @@ from pandas.compat import string_types from pandas.core.dtypes.dtypes import ( DatetimeTZDtype, PeriodDtype, - IntervalDtype, CategoricalDtype) + IntervalDtype, CategoricalDtype, registry) from pandas.core.dtypes.common import ( is_categorical_dtype, is_categorical, is_datetime64tz_dtype, is_datetimetz, @@ -767,3 +767,24 @@ def test_update_dtype_errors(self, bad_dtype): msg = 'a CategoricalDtype must be passed to perform an update, ' with tm.assert_raises_regex(ValueError, msg): dtype.update_dtype(bad_dtype) + + +@pytest.mark.parametrize( + 'dtype', + [DatetimeTZDtype, CategoricalDtype, + PeriodDtype, IntervalDtype]) +def test_registry(dtype): + assert dtype in registry.dtypes + + +@pytest.mark.parametrize( + 'dtype, expected', + [('int64', None), + ('interval', IntervalDtype()), + ('interval[int64]', IntervalDtype()), + ('category', CategoricalDtype()), + ('period[D]', PeriodDtype('D')), + ('datetime64[ns, US/Eastern]', DatetimeTZDtype('ns', 'US/Eastern'))]) +def test_registry_find(dtype, expected): + + assert registry.find(dtype) == expected diff --git a/pandas/tests/extension/base/constructors.py b/pandas/tests/extension/base/constructors.py index 489a430bb40200..972ef7f37accae 100644 --- a/pandas/tests/extension/base/constructors.py +++ b/pandas/tests/extension/base/constructors.py @@ -1,5 +1,6 @@ import pytest +import numpy as np import pandas as pd import pandas.util.testing as tm from pandas.core.internals import ExtensionBlock @@ -45,3 +46,14 @@ def test_series_given_mismatched_index_raises(self, data): msg = 'Length of passed values is 3, index implies 5' with tm.assert_raises_regex(ValueError, msg): pd.Series(data[:3], index=[0, 1, 2, 3, 4]) + + def test_from_dtype(self, data): + # construct from our dtype & string dtype + dtype = data.dtype + + expected = pd.Series(data) + result = pd.Series(np.array(data), dtype=dtype) + self.assert_series_equal(result, expected) + + result = pd.Series(np.array(data), dtype=str(dtype)) + self.assert_series_equal(result, expected)