From a0f8c31f15baf19c5633d8fb65db0b5c7b46aa59 Mon Sep 17 00:00:00 2001 From: Brock Date: Fri, 5 May 2023 08:46:13 -0700 Subject: [PATCH 1/6] ENH: EADtype._maybe_promote --- pandas/core/arrays/arrow/dtype.py | 31 +++++++++++++++++++++++++++- pandas/core/dtypes/base.py | 3 +++ pandas/core/dtypes/cast.py | 17 +++++++-------- pandas/core/dtypes/dtypes.py | 21 +++++++++++++++++++ pandas/core/indexing.py | 5 ++--- pandas/tests/extension/test_arrow.py | 23 +++++++++++++++++++++ 6 files changed, 86 insertions(+), 14 deletions(-) diff --git a/pandas/core/arrays/arrow/dtype.py b/pandas/core/arrays/arrow/dtype.py index c416fbd03417a..266c6aaadfd15 100644 --- a/pandas/core/arrays/arrow/dtype.py +++ b/pandas/core/arrays/arrow/dtype.py @@ -8,10 +8,14 @@ ) from decimal import Decimal import re -from typing import TYPE_CHECKING +from typing import ( + TYPE_CHECKING, + Any, +) import numpy as np +from pandas._libs import missing as libmissing from pandas._libs.tslibs import ( Timedelta, Timestamp, @@ -23,6 +27,7 @@ StorageExtensionDtype, register_extension_dtype, ) +from pandas.core.dtypes.cast import maybe_promote from pandas.core.dtypes.dtypes import CategoricalDtypeType if not pa_version_under7p0: @@ -321,3 +326,27 @@ def __from_arrow__(self, array: pa.Array | pa.ChunkedArray): array_class = self.construct_array_type() arr = array.cast(self.pyarrow_dtype, safe=True) return array_class(arr) + + def _maybe_promote(self, item: Any) -> tuple[DtypeObj, Any]: + if isinstance(item, pa.Scalar): + if not item.is_valid: + # TODO: ask joris for help making these checks more robust + if item.type == self.pyarrow_dtype: + return self, item.as_py() + if item.type.to_pandas_dtype() == np.int64 and self.kind == "i": + # FIXME: kludge + return self, item.as_py() + + item = item.as_py() + + elif item is None or item is libmissing.NA: + # TODO: np.nan? 
use is_valid_na_for_dtype + return self, item + + dtype, item = maybe_promote(self.numpy_dtype, item) + + if dtype == self.numpy_dtype: + return self, item + + # TODO: implement from_numpy_dtype analogous to MaskedDtype.from_numpy_dtype + return np.dtype(object), item diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index f0e55aa178ec0..d702cf67bc330 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -391,6 +391,9 @@ def _can_hold_na(self) -> bool: """ return True + def _maybe_promote(self, item: Any) -> tuple[DtypeObj, Any]: + return np.dtype(object), item + class StorageExtensionDtype(ExtensionDtype): """ExtensionDtype that may be backed by more than one implementation.""" diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 6dabb866b8f5c..54d35e82a5fe2 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -46,7 +46,6 @@ ensure_int16, ensure_int32, ensure_int64, - ensure_object, ensure_str, is_bool, is_complex, @@ -539,13 +538,13 @@ def ensure_dtype_can_hold_na(dtype: DtypeObj) -> DtypeObj: } -def maybe_promote(dtype: np.dtype, fill_value=np.nan): +def maybe_promote(dtype: DtypeObj, fill_value=np.nan): """ Find the minimal dtype that can hold both the given dtype and fill_value. Parameters ---------- - dtype : np.dtype + dtype : np.dtype or ExtensionDtype fill_value : scalar, default np.nan Returns @@ -593,9 +592,13 @@ def _maybe_promote_cached(dtype, fill_value, fill_value_type): return _maybe_promote(dtype, fill_value) -def _maybe_promote(dtype: np.dtype, fill_value=np.nan): +def _maybe_promote(dtype: DtypeObj, fill_value=np.nan): # The actual implementation of the function, use `maybe_promote` above for # a cached version. 
+ + if not isinstance(dtype, np.dtype): + return dtype._maybe_promote(fill_value) + if not is_scalar(fill_value): # with object dtype there is nothing to promote, and the user can # pass pretty much any weird fill_value they like @@ -611,12 +614,6 @@ def _maybe_promote(dtype: np.dtype, fill_value=np.nan): fv = na_value_for_dtype(dtype) return dtype, fv - elif isinstance(dtype, CategoricalDtype): - if fill_value in dtype.categories or isna(fill_value): - return dtype, fill_value - else: - return object, ensure_object(fill_value) - elif isna(fill_value): dtype = _dtype_obj if fill_value is None: diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 093101e2ae5a4..c69f940e8bd42 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -635,6 +635,15 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None: return find_common_type(non_cat_dtypes) + def _maybe_promote(self, item) -> tuple[DtypeObj, Any]: + from pandas.core.dtypes.missing import is_valid_na_for_dtype + + if item in self.categories or is_valid_na_for_dtype( + item, self.categories.dtype + ): + return self, item + return np.dtype(object), item + @register_extension_dtype class DatetimeTZDtype(PandasExtensionDtype): @@ -1500,3 +1509,15 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None: return type(self).from_numpy_dtype(new_dtype) except (KeyError, NotImplementedError): return None + + def _maybe_promote(self, item) -> tuple[DtypeObj, Any]: + from pandas.core.dtypes.cast import maybe_promote + from pandas.core.dtypes.missing import is_valid_na_for_dtype + + if is_valid_na_for_dtype(item, self): + return self, item + + dtype, item = maybe_promote(self.numpy_dtype, item) + if dtype.kind in "iufb": + return type(self).from_numpy_dtype(dtype), item + return dtype, item diff --git a/pandas/core/indexing.py b/pandas/core/indexing.py index 2a3119515bb99..d9e4de0282c36 100644 --- a/pandas/core/indexing.py +++ 
b/pandas/core/indexing.py @@ -2091,7 +2091,7 @@ def _setitem_with_indexer_missing(self, indexer, value): return self._setitem_with_indexer(new_indexer, value, "loc") # this preserves dtype of the value and of the object - if not is_scalar(value): + if is_list_like(value): new_dtype = None elif is_valid_na_for_dtype(value, self.obj.dtype): @@ -2107,8 +2107,7 @@ def _setitem_with_indexer_missing(self, indexer, value): # We should not cast, if we have object dtype because we can # set timedeltas into object series curr_dtype = self.obj.dtype - curr_dtype = getattr(curr_dtype, "numpy_dtype", curr_dtype) - new_dtype = maybe_promote(curr_dtype, value)[0] + new_dtype, value = maybe_promote(curr_dtype, value) else: new_dtype = None diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 5078a4e8078f8..9f2fb0f60acdb 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -2855,6 +2855,29 @@ def test_describe_timedelta_data(pa_type): tm.assert_series_equal(result, expected) +@pytest.mark.parametrize( + "value, target_value, dtype", + [ + (pa.scalar(4, type="int32"), 4, "int32[pyarrow]"), + (pa.scalar(4, type="int64"), 4, "int32[pyarrow]"), + # (pa.scalar(4.5, type="float64"), 4, "int32[pyarrow]"), + (4, 4, "int32[pyarrow]"), + (pd.NA, None, "int32[pyarrow]"), + (None, None, "int32[pyarrow]"), + (pa.scalar(None, type="int32"), None, "int32[pyarrow]"), + (pa.scalar(None, type="int64"), None, "int32[pyarrow]"), + ], +) +def test_series_setitem_with_enlargement(value, target_value, dtype): + # GH#52235 + # similar to series/indexing/test_setitem.py::test_setitem_keep_precision + # and test_setitem_enlarge_with_na, but for arrow dtypes + ser = pd.Series([1, 2, 3], dtype=dtype) + ser[3] = value + expected = pd.Series([1, 2, 3, target_value], dtype=dtype) + tm.assert_series_equal(ser, expected) + + @pytest.mark.parametrize("pa_type", tm.DATETIME_PYARROW_DTYPES) def test_describe_datetime_data(pa_type):
# GH53001 From 0770a07c565ace4f78f04bd200039efed456aac2 Mon Sep 17 00:00:00 2001 From: Brock Date: Sat, 6 May 2023 13:23:07 -0700 Subject: [PATCH 2/6] simplify check --- pandas/core/indexing.py | 22 +++----------------- pandas/tests/series/indexing/test_setitem.py | 6 +++++- 2 files changed, 8 insertions(+), 20 deletions(-) diff --git a/pandas/core/indexing.py b/pandas/core/indexing.py index d9e4de0282c36..ec9f6a42bb97b 100644 --- a/pandas/core/indexing.py +++ b/pandas/core/indexing.py @@ -50,12 +50,7 @@ ABCDataFrame, ABCSeries, ) -from pandas.core.dtypes.missing import ( - infer_fill_value, - is_valid_na_for_dtype, - isna, - na_value_for_dtype, -) +from pandas.core.dtypes.missing import infer_fill_value from pandas.core import algorithms as algos import pandas.core.common as com @@ -2091,25 +2086,14 @@ def _setitem_with_indexer_missing(self, indexer, value): return self._setitem_with_indexer(new_indexer, value, "loc") # this preserves dtype of the value and of the object + new_dtype = None if is_list_like(value): - new_dtype = None - - elif is_valid_na_for_dtype(value, self.obj.dtype): - if not is_object_dtype(self.obj.dtype): - # Every NA value is suitable for object, no conversion needed - value = na_value_for_dtype(self.obj.dtype, compat=False) - - new_dtype = maybe_promote(self.obj.dtype, value)[0] - - elif isna(value): - new_dtype = None + pass elif not self.obj.empty and not is_object_dtype(self.obj.dtype): # We should not cast, if we have object dtype because we can # set timedeltas into object series curr_dtype = self.obj.dtype new_dtype, value = maybe_promote(curr_dtype, value) - else: - new_dtype = None new_values = Series([value], dtype=new_dtype)._values diff --git a/pandas/tests/series/indexing/test_setitem.py b/pandas/tests/series/indexing/test_setitem.py index 39cbf2b7bac10..d7d85098a71d1 100644 --- a/pandas/tests/series/indexing/test_setitem.py +++ b/pandas/tests/series/indexing/test_setitem.py @@ -570,7 +570,11 @@ def 
test_setitem_enlargement_object_none(self, nulls_fixture): ser[3] = nulls_fixture expected = Series(["a", "b", nulls_fixture], index=[0, 1, 3]) tm.assert_series_equal(ser, expected) - assert ser[3] is nulls_fixture + if isinstance(nulls_fixture, float): + # We retain the same type, but maybe not the same _object_ + assert np.isnan(ser[3]) + else: + assert ser[3] is nulls_fixture def test_setitem_scalar_into_readonly_backing_data(): From 9be179bb87ebaf549144ec207d04269fde994676 Mon Sep 17 00:00:00 2001 From: Brock Date: Sat, 6 May 2023 13:26:40 -0700 Subject: [PATCH 3/6] docstring --- pandas/core/dtypes/base.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index d702cf67bc330..cc762bfb218fd 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -63,6 +63,7 @@ class ExtensionDtype: * _is_numeric * _is_boolean * _get_common_dtype + * _maybe_promote The `na_value` class attribute can be used to set the default NA value for this type. :attr:`numpy.nan` is used by default. @@ -392,6 +393,22 @@ def _can_hold_na(self) -> bool: return True def _maybe_promote(self, item: Any) -> tuple[DtypeObj, Any]: + """ + Find the minimal dtype that we need to upcast to in order to hold + the given item. + + This is used when doing Series setitem-with-expansion on a Series + with our dtype.
+ + Parameters + ---------- + item : object + + Returns + ------- + np.dtype or ExtensionDtype + object + """ return np.dtype(object), item From c62dcaafde69e21e300e2c2ec582f29ec5bf9b8d Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 27 Jul 2023 14:17:38 -0700 Subject: [PATCH 4/6] Better name, tests --- pandas/core/dtypes/base.py | 2 +- pandas/core/dtypes/cast.py | 2 +- pandas/core/dtypes/dtypes.py | 6 +++--- pandas/tests/indexing/test_loc.py | 22 ++++++++++++++++++++++ 4 files changed, 27 insertions(+), 5 deletions(-) diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index b4d0ca35e1164..23a29fe6f9275 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -392,7 +392,7 @@ def _can_hold_na(self) -> bool: """ return True - def _maybe_promote(self, item: Any) -> tuple[DtypeObj, Any]: + def _find_compatible_dtype(self, item: Any) -> tuple[DtypeObj, Any]: """ Find the minimal dtype that we need to upcast to in order to hold the given item. diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index f1be02d15e6d2..38917ba1b9519 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -612,7 +612,7 @@ def _maybe_promote(dtype: DtypeObj, fill_value=np.nan): # a cached version. 
if not isinstance(dtype, np.dtype): - return dtype._maybe_promote(fill_value) + return dtype._find_compatible_dtype(fill_value) if not is_scalar(fill_value): # with object dtype there is nothing to promote, and the user can diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index c2326ea1febd4..4c5e8daa9785f 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -671,7 +671,7 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None: return find_common_type(non_cat_dtypes) - def _maybe_promote(self, item) -> tuple[DtypeObj, Any]: + def _find_compatible_dtype(self, item) -> tuple[DtypeObj, Any]: from pandas.core.dtypes.missing import is_valid_na_for_dtype if item in self.categories or is_valid_na_for_dtype( @@ -1569,7 +1569,7 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None: except (KeyError, NotImplementedError): return None - def _maybe_promote(self, item) -> tuple[DtypeObj, Any]: + def _find_compatible_dtype(self, item) -> tuple[DtypeObj, Any]: from pandas.core.dtypes.cast import maybe_promote from pandas.core.dtypes.missing import is_valid_na_for_dtype @@ -2316,7 +2316,7 @@ def __from_arrow__(self, array: pa.Array | pa.ChunkedArray): arr = array.cast(self.pyarrow_dtype, safe=True) return array_class(arr) - def _maybe_promote(self, item: Any) -> tuple[DtypeObj, Any]: + def _find_compatible_dtype(self, item: Any) -> tuple[DtypeObj, Any]: if isinstance(item, pa.Scalar): if not item.is_valid: # TODO: ask joris for help making these checks more robust diff --git a/pandas/tests/indexing/test_loc.py b/pandas/tests/indexing/test_loc.py index 2bbeebcff8ebd..7471b144e77eb 100644 --- a/pandas/tests/indexing/test_loc.py +++ b/pandas/tests/indexing/test_loc.py @@ -1956,6 +1956,28 @@ def test_loc_drops_level(self): class TestLocSetitemWithExpansion: + def test_series_loc_setitem_with_expansion_categorical(self): + # NA value that can't be held in integer categories + ser = Series([1, 2, 
3], dtype="category") + ser.loc[3] = pd.NaT + assert ser.dtype == object + + def test_series_loc_setitem_with_expansion_interval(self): + idx = pd.interval_range(1, 3) + ser2 = Series(idx) + ser2.loc[2] = np.nan + assert ser2.dtype == "interval[float64, right]" + + ser2.loc[3] = pd.NaT + assert ser2.dtype == object + + def test_series_loc_setitem_with_expansion_list_object(self): + ser3 = Series(range(3)) + ser3.loc[3] = [] + assert ser3.dtype == object + item = ser3.loc[3] + assert isinstance(item, list) and len(item) == 0 + @pytest.mark.slow def test_loc_setitem_with_expansion_large_dataframe(self): # GH#10692 From c85a3eb7db38bb5c870bd04bb3cce31ed1caeba7 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 27 Jul 2023 14:18:45 -0700 Subject: [PATCH 5/6] Update doc --- pandas/core/dtypes/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index 23a29fe6f9275..59d0e63645582 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -63,7 +63,7 @@ class ExtensionDtype: * _is_numeric * _is_boolean * _get_common_dtype - * _maybe_promote + * _find_compatible_dtype The `na_value` class attribute can be used to set the default NA value for this type. :attr:`numpy.nan` is used by default.
From fb0b03a8cc7f8e05fdfdbd657cf4b32d93627d58 Mon Sep 17 00:00:00 2001 From: Brock Date: Thu, 27 Jul 2023 15:52:11 -0700 Subject: [PATCH 6/6] Fix setitem-with-expansion for Decimal --- pandas/tests/extension/decimal/array.py | 8 ++++++++ pandas/tests/extension/decimal/test_decimal.py | 9 +++++++++ 2 files changed, 17 insertions(+) diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index 3101bbd171f75..81e1efcc4f3d4 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -65,6 +65,14 @@ def construct_array_type(cls) -> type_t[DecimalArray]: def _is_numeric(self) -> bool: return True + def _find_compatible_dtype(self, item): + if isinstance(item, decimal.Decimal): + # TODO: need to watch out for precision? + # TODO: allow non-decimal numeric? + # TODO: allow np.nan, pd.NA? + return (self, item) + return (np.dtype(object), item) + class DecimalArray(OpsMixin, ExtensionScalarOpsMixin, ExtensionArray): __array_priority__ = 1000 diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 4f0ff427dd900..88516c65340c5 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -577,3 +577,12 @@ def test_array_copy_on_write(using_copy_on_write): {"a": [decimal.Decimal(2), decimal.Decimal(3)]}, dtype=DecimalDtype() ) tm.assert_equal(df2.values, expected.values) + + +def test_setitem_with_expansion(): + # GH#32346 don't upcast to object + arr = DecimalArray(make_data()) + ser = pd.Series(arr[:3]) + + ser[3] = ser[0] + assert ser.dtype == arr.dtype