diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index 6b00a5284ec5b..c6386db82c23e 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -65,6 +65,7 @@ class ExtensionDtype: * _is_numeric * _is_boolean * _get_common_dtype + * _find_compatible_dtype The `na_value` class attribute can be used to set the default NA value for this type. :attr:`numpy.nan` is used by default. @@ -445,6 +446,25 @@ def _can_fast_transpose(self) -> bool: """ return False + def _find_compatible_dtype(self, item: Any) -> tuple[DtypeObj, Any]: + """ + Find the minimal dtype that we need to upcast to in order to hold + the given item. + + This is used in when doing Series setitem-with-expansion on a Series + with our dtype. + + Parameters + ---------- + item : object + + Returns + ------- + np.dtype or ExtensionDtype + object + """ + return np.dtype(object), item + class StorageExtensionDtype(ExtensionDtype): """ExtensionDtype that may be backed by more than one implementation.""" diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index 320f028f4484c..3a104f5719148 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -50,7 +50,6 @@ ensure_int16, ensure_int32, ensure_int64, - ensure_object, ensure_str, is_bool, is_complex, @@ -548,13 +547,13 @@ def ensure_dtype_can_hold_na(dtype: DtypeObj) -> DtypeObj: } -def maybe_promote(dtype: np.dtype, fill_value=np.nan): +def maybe_promote(dtype: DtypeObj, fill_value=np.nan): """ Find the minimal dtype that can hold both the given dtype and fill_value. Parameters ---------- - dtype : np.dtype + dtype : np.dtype or ExtensionDtype fill_value : scalar, default np.nan Returns @@ -611,9 +610,13 @@ def _maybe_promote_cached(dtype, fill_value, fill_value_type): return _maybe_promote(dtype, fill_value) -def _maybe_promote(dtype: np.dtype, fill_value=np.nan): +def _maybe_promote(dtype: DtypeObj, fill_value=np.nan): # The actual implementation of the function, use `maybe_promote` above for # a cached version. + + if not isinstance(dtype, np.dtype): + return dtype._find_compatible_dtype(fill_value) + if not is_scalar(fill_value): # with object dtype there is nothing to promote, and the user can # pass pretty much any weird fill_value they like @@ -629,12 +632,6 @@ def _maybe_promote(dtype: np.dtype, fill_value=np.nan): fv = na_value_for_dtype(dtype) return dtype, fv - elif isinstance(dtype, CategoricalDtype): - if fill_value in dtype.categories or isna(fill_value): - return dtype, fill_value - else: - return object, ensure_object(fill_value) - elif isna(fill_value): dtype = _dtype_obj if fill_value is None: diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index 07fb25008cfb2..3a9d953d39004 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -686,6 +686,15 @@ def index_class(self) -> type_t[CategoricalIndex]: return CategoricalIndex + def _find_compatible_dtype(self, item) -> tuple[DtypeObj, Any]: + from pandas.core.dtypes.missing import is_valid_na_for_dtype + + if item in self.categories or is_valid_na_for_dtype( + item, self.categories.dtype + ): + return self, item + return np.dtype(object), item + @register_extension_dtype class DatetimeTZDtype(PandasExtensionDtype): @@ -1606,6 +1615,18 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None: except (KeyError, NotImplementedError): return None + def _find_compatible_dtype(self, item) -> tuple[DtypeObj, Any]: + from pandas.core.dtypes.cast import maybe_promote + from pandas.core.dtypes.missing import is_valid_na_for_dtype + + if is_valid_na_for_dtype(item, self): + return self, item + + dtype, item = maybe_promote(self.numpy_dtype, item) + if dtype.kind in "iufb": + return type(self).from_numpy_dtype(dtype), item + return dtype, item + @register_extension_dtype class SparseDtype(ExtensionDtype): @@ -2344,3 +2365,29 @@ def __from_arrow__(self, array: pa.Array | pa.ChunkedArray): array_class = self.construct_array_type() arr = array.cast(self.pyarrow_dtype, safe=True) return array_class(arr) + + def _find_compatible_dtype(self, item: Any) -> tuple[DtypeObj, Any]: + if isinstance(item, pa.Scalar): + if not item.is_valid: + # TODO: ask joris for help making these checks more robust + if item.type == self.pyarrow_dtype: + return self, item.as_py() + if item.type.to_pandas_dtype() == np.int64 and self.kind == "i": + # FIXME: kludge + return self, item.as_py() + + item = item.as_py() + + elif item is None or item is libmissing.NA: + # TODO: np.nan? use is_valid_na_for_dtype + return self, item + + from pandas.core.dtypes.cast import maybe_promote + + dtype, item = maybe_promote(self.numpy_dtype, item) + + if dtype == self.numpy_dtype: + return self, item + + # TODO: implement from_numpy_dtype analogous to MaskedDtype.from_numpy_dtype + return np.dtype(object), item diff --git a/pandas/core/indexing.py b/pandas/core/indexing.py index e3928621a4e48..62dc8cc793f46 100644 --- a/pandas/core/indexing.py +++ b/pandas/core/indexing.py @@ -56,12 +56,7 @@ ABCDataFrame, ABCSeries, ) -from pandas.core.dtypes.missing import ( - infer_fill_value, - is_valid_na_for_dtype, - isna, - na_value_for_dtype, -) +from pandas.core.dtypes.missing import infer_fill_value from pandas.core import algorithms as algos import pandas.core.common as com @@ -2203,26 +2198,14 @@ def _setitem_with_indexer_missing(self, indexer, value): return self._setitem_with_indexer(new_indexer, value, "loc") # this preserves dtype of the value and of the object - if not is_scalar(value): - new_dtype = None - - elif is_valid_na_for_dtype(value, self.obj.dtype): - if not is_object_dtype(self.obj.dtype): - # Every NA value is suitable for object, no conversion needed - value = na_value_for_dtype(self.obj.dtype, compat=False) - - new_dtype = maybe_promote(self.obj.dtype, value)[0] - - elif isna(value): - new_dtype = None + new_dtype = None + if is_list_like(value): + pass elif not self.obj.empty and not is_object_dtype(self.obj.dtype): # We should not cast, if we have object dtype because we can # set timedeltas into object series curr_dtype = self.obj.dtype - curr_dtype = getattr(curr_dtype, "numpy_dtype", curr_dtype) - new_dtype = maybe_promote(curr_dtype, value)[0] - else: - new_dtype = None + new_dtype, value = maybe_promote(curr_dtype, value) new_values = Series([value], dtype=new_dtype)._values diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index 9ce7ac309b6d3..19a085c4fe2e8 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -65,6 +65,14 @@ def construct_array_type(cls) -> type_t[DecimalArray]: def _is_numeric(self) -> bool: return True + def _find_compatible_dtype(self, item): + if isinstance(item, decimal.Decimal): + # TODO: need to watch out for precision? + # TODO: allow non-decimal numeric? + # TODO: allow np.nan, pd.NA? + return (self, item) + return (np.dtype(object), item) + class DecimalArray(OpsMixin, ExtensionScalarOpsMixin, ExtensionArray): __array_priority__ = 1000 diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index 5ccffd1d25b3d..562d97f739ae1 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -535,3 +535,12 @@ def test_array_copy_on_write(using_copy_on_write): {"a": [decimal.Decimal(2), decimal.Decimal(3)]}, dtype=DecimalDtype() ) tm.assert_equal(df2.values, expected.values) + + +def test_setitem_with_expansion(): + # GH#32346 dont upcast to object + arr = DecimalArray(make_data()) + ser = pd.Series(arr[:3]) + + ser[3] = ser[0] + assert ser.dtype == arr.dtype diff --git a/pandas/tests/extension/test_arrow.py b/pandas/tests/extension/test_arrow.py index 1941e359299b6..060e6a3c93c6a 100644 --- a/pandas/tests/extension/test_arrow.py +++ b/pandas/tests/extension/test_arrow.py @@ -2751,6 +2751,29 @@ def test_describe_timedelta_data(pa_type): tm.assert_series_equal(result, expected) +@pytest.mark.parametrize( + "value, target_value, dtype", + [ + (pa.scalar(4, type="int32"), 4, "int32[pyarrow]"), + (pa.scalar(4, type="int64"), 4, "int32[pyarrow]"), + # (pa.scalar(4.5, type="float64"), 4, "int32[pyarrow]"), + (4, 4, "int32[pyarrow]"), + (pd.NA, None, "int32[pyarrow]"), + (None, None, "int32[pyarrow]"), + (pa.scalar(None, type="int32"), None, "int32[pyarrow]"), + (pa.scalar(None, type="int64"), None, "int32[pyarrow]"), + ], +) +def test_series_setitem_with_enlargement(value, target_value, dtype): + # GH#52235 + # similar to series/inedexing/test_setitem.py::test_setitem_keep_precision + # and test_setitem_enlarge_with_na, but for arrow dtypes + ser = pd.Series([1, 2, 3], dtype=dtype) + ser[3] = value + expected = pd.Series([1, 2, 3, target_value], dtype=dtype) + tm.assert_series_equal(ser, expected) + + @pytest.mark.parametrize("pa_type", tm.DATETIME_PYARROW_DTYPES) def test_describe_datetime_data(pa_type): # GH53001 diff --git a/pandas/tests/indexing/test_loc.py b/pandas/tests/indexing/test_loc.py index ce7dde3c4cb42..b026b59502f43 100644 --- a/pandas/tests/indexing/test_loc.py +++ b/pandas/tests/indexing/test_loc.py @@ -1994,6 +1994,28 @@ def test_loc_drops_level(self): class TestLocSetitemWithExpansion: + def test_series_loc_setitem_with_expansion_categorical(self): + # NA value that can't be held in integer categories + ser = Series([1, 2, 3], dtype="category") + ser.loc[3] = pd.NaT + assert ser.dtype == object + + def test_series_loc_setitem_with_expansion_interval(self): + idx = pd.interval_range(1, 3) + ser2 = Series(idx) + ser2.loc[2] = np.nan + assert ser2.dtype == "interval[float64, right]" + + ser2.loc[3] = pd.NaT + assert ser2.dtype == object + + def test_series_loc_setitem_with_expansion_list_object(self): + ser3 = Series(range(3)) + ser3.loc[3] = [] + assert ser3.dtype == object + item = ser3.loc[3] + assert isinstance(item, list) and len(item) == 0 + def test_loc_setitem_with_expansion_large_dataframe(self, monkeypatch): # GH#10692 size_cutoff = 50 diff --git a/pandas/tests/series/indexing/test_setitem.py b/pandas/tests/series/indexing/test_setitem.py index e583d55101a8b..b77473e7fc9ef 100644 --- a/pandas/tests/series/indexing/test_setitem.py +++ b/pandas/tests/series/indexing/test_setitem.py @@ -621,7 +621,10 @@ def test_setitem_enlargement_object_none(self, nulls_fixture, using_infer_string ) expected = Series(["a", "b", nulls_fixture], index=[0, 1, 3], dtype=dtype) tm.assert_series_equal(ser, expected) - if using_infer_string: + if isinstance(nulls_fixture, float): + # We retain the same type, but maybe not the same _object_ + assert np.isnan(ser[3]) + elif using_infer_string: ser[3] is np.nan else: assert ser[3] is nulls_fixture