diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index 4ba0f0a73a2..9def2d5494b 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -6880,7 +6880,7 @@ def groupby( [[nan, nan, nan], [ 3., 4., 5.]]]) Coordinates: - * x_bins (x_bins) object 16B (5, 15] (15, 25] + * x_bins (x_bins) interval[int64, right] 16B (5, 15] (15, 25] * letters (letters) object 16B 'a' 'b' Dimensions without coordinates: y diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 449f502c43a..01996b0571b 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -10391,7 +10391,7 @@ def groupby( Size: 128B Dimensions: (y: 3, x_bins: 2, letters: 2) Coordinates: - * x_bins (x_bins) object 16B (5, 15] (15, 25] + * x_bins (x_bins) interval[int64, right] 16B (5, 15] (15, 25] * letters (letters) object 16B 'a' 'b' Dimensions without coordinates: y Data variables: diff --git a/xarray/core/indexes.py b/xarray/core/indexes.py index 37711275bce..caf68b28737 100644 --- a/xarray/core/indexes.py +++ b/xarray/core/indexes.py @@ -600,7 +600,10 @@ def __init__( self.index = index self.dim = dim - if coord_dtype is None: + if pd.api.types.is_extension_array_dtype(index.dtype): + cast(pd.api.extensions.ExtensionDtype, index.dtype) + coord_dtype = index.dtype + elif coord_dtype is None: coord_dtype = get_valid_numpy_dtype(index) self.coord_dtype = coord_dtype @@ -697,6 +700,8 @@ def concat( if not indexes: coord_dtype = None + elif len(set(idx.coord_dtype for idx in indexes)) == 1: + coord_dtype = indexes[0].coord_dtype else: coord_dtype = np.result_type(*[idx.coord_dtype for idx in indexes]) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 521abcdfddd..4c74d5d0db4 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -10,10 +10,11 @@ from dataclasses import dataclass, field from datetime import timedelta from html import escape -from typing import TYPE_CHECKING, Any, overload +from typing import TYPE_CHECKING, Any, cast, overload import numpy as np import pandas as pd +from numpy.typing import DTypeLike from packaging.version import Version from xarray.core import duck_array_ops @@ -34,8 +35,6 @@ from xarray.namedarray.pycompat import array_type, integer_types, is_chunked_array if TYPE_CHECKING: - from numpy.typing import DTypeLike - from xarray.core.indexes import Index from xarray.core.types import Self from xarray.core.variable import Variable @@ -1744,27 +1743,44 @@ class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin): __slots__ = ("_dtype", "array") array: pd.Index - _dtype: np.dtype + _dtype: np.dtype | pd.api.extensions.ExtensionDtype - def __init__(self, array: pd.Index, dtype: DTypeLike = None): + def __init__( + self, + array: pd.Index, + dtype: DTypeLike | pd.api.extensions.ExtensionDtype | None = None, + ): from xarray.core.indexes import safe_cast_to_index self.array = safe_cast_to_index(array) if dtype is None: - self._dtype = get_valid_numpy_dtype(array) + if pd.api.types.is_extension_array_dtype(array.dtype): + cast(pd.api.extensions.ExtensionDtype, array.dtype) + self._dtype = array.dtype + else: + self._dtype = get_valid_numpy_dtype(array) + elif pd.api.types.is_extension_array_dtype(dtype): + self._dtype = cast(pd.api.extensions.ExtensionDtype, dtype) else: - self._dtype = np.dtype(dtype) + self._dtype = np.dtype(cast(DTypeLike, dtype)) @property - def dtype(self) -> np.dtype: + def dtype(self) -> np.dtype | pd.api.extensions.ExtensionDtype: # type: ignore[override] return self._dtype def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, + dtype: np.typing.DTypeLike | pd.api.extensions.ExtensionDtype | None = None, + /, + *, + copy: bool | None = None, ) -> np.ndarray: if dtype is None: dtype = self.dtype + if pd.api.types.is_extension_array_dtype(dtype): + dtype = get_valid_numpy_dtype(self.array) + dtype = cast(np.dtype, dtype) array = self.array if isinstance(array, pd.PeriodIndex): with suppress(AttributeError): @@ -1783,7 +1799,7 @@ def get_duck_array(self) -> np.ndarray: def shape(self) -> _Shape: return (len(self.array),) - def _convert_scalar(self, item): + def _convert_scalar(self, item) -> np.ndarray: if item is pd.NaT: # work around the impossibility of casting NaT with asarray # note: it probably would be better in general to return @@ -1799,7 +1815,10 @@ def _convert_scalar(self, item): # numpy fails to convert pd.Timestamp to np.datetime64[ns] item = np.asarray(item.to_datetime64()) elif self.dtype != object: - item = np.asarray(item, dtype=self.dtype) + dtype = self.dtype + if pd.api.types.is_extension_array_dtype(dtype): + dtype = get_valid_numpy_dtype(self.array) + item = np.asarray(item, dtype=cast(np.dtype, dtype)) # as for numpy.ndarray indexing, we always want the result to be # a NumPy array. @@ -1914,23 +1933,28 @@ class PandasMultiIndexingAdapter(PandasIndexingAdapter): __slots__ = ("_dtype", "adapter", "array", "level") array: pd.MultiIndex - _dtype: np.dtype + _dtype: np.dtype | pd.api.extensions.ExtensionDtype level: str | None def __init__( self, array: pd.MultiIndex, - dtype: DTypeLike = None, + dtype: DTypeLike | pd.api.extensions.ExtensionDtype | None = None, level: str | None = None, ): super().__init__(array, dtype) self.level = level def __array__( - self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None + self, + dtype: DTypeLike | pd.api.extensions.ExtensionDtype | None = None, + /, + *, + copy: bool | None = None, ) -> np.ndarray: if dtype is None: dtype = self.dtype + dtype = cast(np.dtype, dtype) if self.level is not None: return np.asarray( self.array.get_level_values(self.level).values, dtype=dtype diff --git a/xarray/core/variable.py b/xarray/core/variable.py index ed860dc0e6b..6e154b2fa87 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -13,7 +13,6 @@ import numpy as np import pandas as pd from numpy.typing import ArrayLike -from pandas.api.types import is_extension_array_dtype import xarray as xr # only for Dataset and DataArray from xarray.core import common, dtypes, duck_array_ops, indexing, nputils, ops, utils @@ -412,6 +411,10 @@ def data(self): if is_duck_array(self._data): return self._data elif isinstance(self._data, indexing.ExplicitlyIndexed): + if pd.api.types.is_extension_array_dtype(self._data) and isinstance( + self._data, PandasIndexingAdapter + ): + return self._data.array return self._data.get_duck_array() else: return self.values @@ -2592,11 +2595,6 @@ def chunk( # type: ignore[override] dask.array.from_array """ - if is_extension_array_dtype(self): - raise ValueError( - f"{self} was found to be a Pandas ExtensionArray. Please convert to numpy first." - ) - if from_array_kwargs is None: from_array_kwargs = {} diff --git a/xarray/namedarray/core.py b/xarray/namedarray/core.py index cdf9eab5c8d..21b3b993b53 100644 --- a/xarray/namedarray/core.py +++ b/xarray/namedarray/core.py @@ -5,6 +5,7 @@ import sys import warnings from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence +from importlib.util import find_spec from types import EllipsisType from typing import ( TYPE_CHECKING, @@ -834,7 +835,18 @@ def chunk( if chunkmanager.is_chunked_array(data_old): data_chunked = chunkmanager.rechunk(data_old, chunks) # type: ignore[arg-type] else: - if not isinstance(data_old, ExplicitlyIndexed): + if find_spec("pandas"): + from pandas.api.types import is_extension_array_dtype + else: + + def is_extension_array_dtype(dtype: Any) -> Literal[False]: # type: ignore[misc] + return False + + ndata: duckarray[Any, Any] + if is_extension_array_dtype(data_old.dtype): + # One of PandasExtensionArray or PandasIndexingAdapter? + ndata = np.asarray(data_old) + elif not isinstance(data_old, ExplicitlyIndexed): ndata = data_old else: # Unambiguously handle array storage backends (like NetCDF4 and h5py) diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index cca9fe4f561..e8535f67124 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -524,7 +524,7 @@ def line( assert hueplt is not None ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label) - if np.issubdtype(xplt.dtype, np.datetime64): + if isinstance(xplt.dtype, np.dtype) and np.issubdtype(xplt.dtype, np.datetime64): # type: ignore[redundant-expr] _set_concise_date(ax, axis="x") _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 7be2d13f9dd..af23ac81396 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -4351,8 +4351,13 @@ def test_setitem_pandas(self) -> None: ds = self.make_example_math_dataset() ds["x"] = np.arange(3) ds_copy = ds.copy() - ds_copy["bar"] = ds["bar"].to_pandas() - + series = ds["bar"].to_pandas() + # to_pandas will actually give the result where the internal array of the series is a NumpyExtensionArray + # but ds["bar"] is a numpy array. + # TODO: should assert_equal be updated to handle? + assert (ds["bar"] == series).all() + del ds["bar"] + del ds_copy["bar"] assert_equal(ds, ds_copy) def test_setitem_auto_align(self) -> None: @@ -4943,6 +4948,16 @@ def test_to_and_from_dataframe(self) -> None: expected = pd.DataFrame([[]], index=idx) assert expected.equals(actual), (expected, actual) + def test_from_dataframe_categorical_dtype_index(self) -> None: + cat = pd.CategoricalIndex(list("abcd")) + df = pd.DataFrame({"f": [0, 1, 2, 3]}, index=cat) + ds = df.to_xarray() + restored = ds.to_dataframe() + df.index.name = ( + "index" # restored gets the name because it has the coord with the name + ) + pd.testing.assert_frame_equal(df, restored) + def test_from_dataframe_categorical_index(self) -> None: cat = pd.CategoricalDtype( categories=["foo", "bar", "baz", "qux", "quux", "corge"] @@ -4967,7 +4982,7 @@ def test_from_dataframe_categorical_index_string_categories(self) -> None: ) ser = pd.Series(1, index=cat) ds = ser.to_xarray() - assert ds.coords.dtypes["index"] == np.dtype("O") + assert ds.coords.dtypes["index"] == ser.index.dtype @requires_sparse def test_from_dataframe_sparse(self) -> None: diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index d42f86f5ea6..188545bae21 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -1118,7 +1118,8 @@ def test_groupby_math_nD_group() -> None: expected = da.isel(x=slice(30)) - expanded_mean expected["labels"] = expected.labels.broadcast_like(expected.labels2d) expected["num"] = expected.num.broadcast_like(expected.num2d) - expected["num2d_bins"] = (("x", "y"), mean.num2d_bins.data[idxr]) + # mean.num2d_bins.data is a pandas IntervalArray so needs to be put in `numpy` to allow indexing + expected["num2d_bins"] = (("x", "y"), mean.num2d_bins.data.to_numpy()[idxr]) actual = g - mean assert_identical(expected, actual) diff --git a/xarray/tests/test_variable.py b/xarray/tests/test_variable.py index c283797bd08..293f20a1204 100644 --- a/xarray/tests/test_variable.py +++ b/xarray/tests/test_variable.py @@ -656,7 +656,7 @@ def test_pandas_categorical_dtype(self): data = pd.Categorical(np.arange(10, dtype="int64")) v = self.cls("x", data) print(v) # should not error - assert v.dtype == "int64" + assert v.dtype == data.dtype def test_pandas_datetime64_with_tz(self): data = pd.date_range( @@ -667,9 +667,12 @@ def test_pandas_datetime64_with_tz(self): ) v = self.cls("x", data) print(v) # should not error - if "America/New_York" in str(data.dtype): - # pandas is new enough that it has datetime64 with timezone dtype - assert v.dtype == "object" + if v.dtype == np.dtype("O"): + import dask.array as da + + assert isinstance(v.data, da.Array) + else: + assert v.dtype == data.dtype def test_multiindex(self): idx = pd.MultiIndex.from_product([list("abc"), [0, 1]]) @@ -1592,14 +1595,6 @@ def test_pandas_categorical_dtype(self): print(v) # should not error assert pd.api.types.is_extension_array_dtype(v.dtype) - def test_pandas_categorical_no_chunk(self): - data = pd.Categorical(np.arange(10, dtype="int64")) - v = self.cls("x", data) - with pytest.raises( - ValueError, match=r".*was found to be a Pandas ExtensionArray.*" - ): - v.chunk((5,)) - def test_squeeze(self): v = Variable(["x", "y"], [[1]]) assert_identical(Variable([], 1), v.squeeze()) @@ -2414,8 +2409,8 @@ def test_pad(self, mode, xr_arg, np_arg): def test_pandas_categorical_dtype(self): data = pd.Categorical(np.arange(10, dtype="int64")) - with pytest.raises(ValueError, match="was found to be a Pandas ExtensionArray"): - self.cls("x", data) + v = self.cls("x", data) + assert (v.data.compute() == data.to_numpy()).all() @requires_sparse @@ -3019,7 +3014,7 @@ def test_datetime_conversion(values, unit) -> None: # todo: check for redundancy (suggested per review) dims = ["time"] if isinstance(values, np.ndarray | pd.Index | pd.Series) else [] var = Variable(dims, values) - if var.dtype.kind == "M": + if var.dtype.kind == "M" and isinstance(var.dtype, np.dtype): assert var.dtype == np.dtype(f"datetime64[{unit}]") else: # The only case where a non-datetime64 dtype can occur currently is in @@ -3061,8 +3056,12 @@ def test_pandas_two_only_datetime_conversion_warnings( # todo: check for redundancy (suggested per review) var = Variable(["time"], data.astype(dtype)) # type: ignore[arg-type] - if var.dtype.kind == "M": + # we internally convert series to numpy representations to avoid too much nastiness with extension arrays + # when calling data.array e.g., with NumpyExtensionArrays + if isinstance(data, pd.Series): assert var.dtype == np.dtype("datetime64[s]") + elif var.dtype.kind == "M": + assert var.dtype == dtype else: # The only case where a non-datetime64 dtype can occur currently is in # the case that the variable is backed by a timezone-aware