From dc768621a51a927f36f669c42d3efc806028cc4a Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 28 Jun 2022 10:38:47 +0100 Subject: [PATCH 01/12] Support NumPy array API (experimental) --- xarray/core/duck_array_ops.py | 6 ++++- xarray/core/indexing.py | 43 ++++++++++++++++++++++++++++++ xarray/core/utils.py | 7 +++-- xarray/core/variable.py | 6 ++++- xarray/tests/test_array_api.py | 48 ++++++++++++++++++++++++++++++++++ 5 files changed, 106 insertions(+), 4 deletions(-) create mode 100644 xarray/tests/test_array_api.py diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 6e73ee41b40..2cd2fb3af04 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -329,7 +329,11 @@ def f(values, axis=None, skipna=None, **kwargs): if name in ["sum", "prod"]: kwargs.pop("min_count", None) - func = getattr(np, name) + if hasattr(values, "__array_namespace__"): + xp = values.__array_namespace__() + func = getattr(xp, name) + else: + func = getattr(np, name) try: with warnings.catch_warnings(): diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 9a29b63f4d0..9d1f397b3b4 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -679,6 +679,8 @@ def as_indexable(array): return DaskIndexingAdapter(array) if hasattr(array, "__array_function__"): return NdArrayLikeIndexingAdapter(array) + if hasattr(array, "__array_namespace__"): + return ArrayApiIndexingAdapter(array) raise TypeError(f"Invalid array type: {type(array)}") @@ -1288,6 +1290,47 @@ def __init__(self, array): self.array = array +class ArrayApiIndexingAdapter(ExplicitlyIndexedNDArrayMixin): + """Wrap an array API array to use explicit indexing.""" + + __slots__ = ("array",) + + def __init__(self, array): + if not hasattr(array, "__array_namespace__"): + raise TypeError( + "ArrayApiIndexingAdapter must wrap an object that " + "implements the __array_namespace__ protocol" + ) + self.array = array + + def __getitem__(self, key): + if isinstance(key, BasicIndexer): + return self.array[key.tuple] + elif isinstance(key, OuterIndexer): + # manual orthogonal indexing (implemented like DaskIndexingAdapter) + key = key.tuple + value = self.array + for axis, subkey in reversed(list(enumerate(key))): + value = value[(slice(None),) * axis + (subkey, Ellipsis)] + return value + else: + assert isinstance(key, VectorizedIndexer) + raise TypeError("Vectorized indexing is not supported") + + def __setitem__(self, key, value): + if isinstance(key, BasicIndexer): + self.array[key.tuple] = value + elif isinstance(key, OuterIndexer): + self.array[key.tuple] = value + else: + assert isinstance(key, VectorizedIndexer) + raise TypeError("Vectorized indexing is not supported") + + def transpose(self, order): + xp = self.array.__array_namespace__() + return xp.permute_dims(self.array, order) + + class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin): """Wrap a dask array to support explicit indexing.""" diff --git a/xarray/core/utils.py b/xarray/core/utils.py index ab3f8d3a282..51bf1346506 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -263,8 +263,10 @@ def is_duck_array(value: Any) -> bool: hasattr(value, "ndim") and hasattr(value, "shape") and hasattr(value, "dtype") - and hasattr(value, "__array_function__") - and hasattr(value, "__array_ufunc__") + and ( + (hasattr(value, "__array_function__") and hasattr(value, "__array_ufunc__")) + or hasattr(value, "__array_namespace__") + ) ) @@ -298,6 +300,7 @@ def _is_scalar(value, include_0d): or not ( isinstance(value, (Iterable,) + NON_NUMPY_SUPPORTED_ARRAY_TYPES) or hasattr(value, "__array_function__") + or hasattr(value, "__array_namespace__") ) ) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 90edf652284..f5c913b1764 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -211,7 +211,11 @@ def as_compatible_data(data, fastpath=False): if isinstance(data, (Variable, DataArray)): return data.data - if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES): + if ( + isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES) + or hasattr(data, "__array_function__") + or hasattr(data, "__array_namespace__") + ): return _maybe_wrap_data(data) if isinstance(data, tuple): diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py new file mode 100644 index 00000000000..01860bee4e4 --- /dev/null +++ b/xarray/tests/test_array_api.py @@ -0,0 +1,48 @@ +import numpy.array_api as xp +import pytest +from numpy.array_api._array_object import Array + +import xarray as xr +from xarray.testing import assert_equal + +np = pytest.importorskip("numpy", minversion="1.22") + + +@pytest.fixture +def arrays(): + np_arr = xr.DataArray(np.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]}) + xp_arr = xr.DataArray(xp.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]}) + assert isinstance(xp_arr.data, Array) + return np_arr, xp_arr + + +def test_arithmetic(arrays): + np_arr, xp_arr = arrays + expected = np_arr + 7 + actual = xp_arr + 7 + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + +def test_aggregation(arrays): + np_arr, xp_arr = arrays + expected = np_arr.sum(skipna=False) + actual = xp_arr.sum(skipna=False) + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + +def test_indexing(arrays): + np_arr, xp_arr = arrays + expected = np_arr[:, 0] + actual = xp_arr[:, 0] + assert isinstance(actual.data, Array) + assert_equal(actual, expected) + + +def test_reorganizing_operation(arrays): + np_arr, xp_arr = arrays + expected = np_arr.transpose() + actual = xp_arr.transpose() + assert isinstance(actual.data, Array) + assert_equal(actual, expected) From 14ef4435496b1540a66f34f837f20f1c65f45f1f Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 19 Jul 2022 10:26:38 +0100 Subject: [PATCH 02/12] Address feedback --- xarray/core/indexing.py | 10 +++++----- xarray/core/variable.py | 10 ++++------ 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 9d1f397b3b4..7005b45b910 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1314,16 +1314,16 @@ def __getitem__(self, key): value = value[(slice(None),) * axis + (subkey, Ellipsis)] return value else: - assert isinstance(key, VectorizedIndexer) + if not isinstance(key, VectorizedIndexer): + raise TypeError(f"Unrecognized indexer: {key}") raise TypeError("Vectorized indexing is not supported") def __setitem__(self, key, value): - if isinstance(key, BasicIndexer): - self.array[key.tuple] = value - elif isinstance(key, OuterIndexer): + if isinstance(key, (BasicIndexer, OuterIndexer)): self.array[key.tuple] = value else: - assert isinstance(key, VectorizedIndexer) + if not isinstance(key, VectorizedIndexer): + raise TypeError(f"Unrecognized indexer: {key}") raise TypeError("Vectorized indexing is not supported") def transpose(self, order): diff --git a/xarray/core/variable.py b/xarray/core/variable.py index f5c913b1764..502bf8482f2 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -211,11 +211,7 @@ def as_compatible_data(data, fastpath=False): if isinstance(data, (Variable, DataArray)): return data.data - if ( - isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES) - or hasattr(data, "__array_function__") - or hasattr(data, "__array_namespace__") - ): + if isinstance(data, NON_NUMPY_SUPPORTED_ARRAY_TYPES): return _maybe_wrap_data(data) if isinstance(data, tuple): @@ -241,7 +237,9 @@ def as_compatible_data(data, fastpath=False): else: data = np.asarray(data) - if not isinstance(data, np.ndarray) and hasattr(data, "__array_function__"): + if not isinstance(data, np.ndarray) and ( + hasattr(data, "__array_function__") or hasattr(data, "__array_namespace__") + ): return data # validate whether the data is valid data types. From ffbb6023defd5bdfccbd2be57cd6a58142f652e8 Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 19 Jul 2022 12:03:19 +0100 Subject: [PATCH 03/12] Update xarray/core/indexing.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/core/indexing.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 7005b45b910..f0b69d5ce2e 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1322,9 +1322,10 @@ def __setitem__(self, key, value): if isinstance(key, (BasicIndexer, OuterIndexer)): self.array[key.tuple] = value else: - if not isinstance(key, VectorizedIndexer): + if isinstance(key, VectorizedIndexer): + raise TypeError("Vectorized indexing is not supported") + else: raise TypeError(f"Unrecognized indexer: {key}") - raise TypeError("Vectorized indexing is not supported") def transpose(self, order): xp = self.array.__array_namespace__() From fe3320258114ee08df56072a2b2f4517d63fb2cc Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 19 Jul 2022 12:03:33 +0100 Subject: [PATCH 04/12] Update xarray/core/indexing.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/core/indexing.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index f0b69d5ce2e..72ca60d4d5e 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -1314,9 +1314,10 @@ def __getitem__(self, key): value = value[(slice(None),) * axis + (subkey, Ellipsis)] return value else: - if not isinstance(key, VectorizedIndexer): + if isinstance(key, VectorizedIndexer): + raise TypeError("Vectorized indexing is not supported") + else: raise TypeError(f"Unrecognized indexer: {key}") - raise TypeError("Vectorized indexing is not supported") def __setitem__(self, key, value): if isinstance(key, (BasicIndexer, OuterIndexer)): From 0da1f5e176df5a328be3d1d63371ad7263df3721 Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 19 Jul 2022 13:44:35 +0100 Subject: [PATCH 05/12] Update xarray/tests/test_array_api.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/tests/test_array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index 01860bee4e4..bd46d1029ce 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -9,7 +9,7 @@ @pytest.fixture -def arrays(): +def arrays() -> tuple[xr.DataArray, xr.DataArray]: np_arr = xr.DataArray(np.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]}) xp_arr = xr.DataArray(xp.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]}) assert isinstance(xp_arr.data, Array) From 7e91bd5d2d1bb95137eb51074a3db023439435b3 Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 19 Jul 2022 13:44:43 +0100 Subject: [PATCH 06/12] Update xarray/tests/test_array_api.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/tests/test_array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index bd46d1029ce..d63de8db4ed 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -16,7 +16,7 @@ def arrays() -> tuple[xr.DataArray, xr.DataArray]: return np_arr, xp_arr -def test_arithmetic(arrays): +def test_arithmetic(arrays) -> None: np_arr, xp_arr = arrays expected = np_arr + 7 actual = xp_arr + 7 From fa9ea14decf972d5daccd1e47cff96118f48e420 Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 19 Jul 2022 13:44:50 +0100 Subject: [PATCH 07/12] Update xarray/tests/test_array_api.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/tests/test_array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index d63de8db4ed..5ff3f8c9943 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -24,7 +24,7 @@ def test_arithmetic(arrays) -> None: assert_equal(actual, expected) -def test_aggregation(arrays): +def test_aggregation(arrays) -> None: np_arr, xp_arr = arrays expected = np_arr.sum(skipna=False) actual = xp_arr.sum(skipna=False) From 3cc3cb40db9b0810ef48ae1beac705b6f7fd3f54 Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 19 Jul 2022 13:44:59 +0100 Subject: [PATCH 08/12] Update xarray/tests/test_array_api.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/tests/test_array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index 5ff3f8c9943..1b7eb61125b 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -32,7 +32,7 @@ def test_aggregation(arrays) -> None: assert_equal(actual, expected) -def test_indexing(arrays): +def test_indexing(arrays) -> None: np_arr, xp_arr = arrays expected = np_arr[:, 0] actual = xp_arr[:, 0] From f6df2555cee2e04161ac3a0d9108f82c9ab46ccb Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 19 Jul 2022 13:45:07 +0100 Subject: [PATCH 09/12] Update xarray/tests/test_array_api.py Co-authored-by: Illviljan <14371165+Illviljan@users.noreply.github.com> --- xarray/tests/test_array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index 1b7eb61125b..4419b79b9d5 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -40,7 +40,7 @@ def test_indexing(arrays) -> None: assert_equal(actual, expected) -def test_reorganizing_operation(arrays): +def test_reorganizing_operation(arrays) -> None: np_arr, xp_arr = arrays expected = np_arr.transpose() actual = xp_arr.transpose() From afe3d9f8f42bd285500a73eb7b4c9e3c9f089e62 Mon Sep 17 00:00:00 2001 From: Tom White Date: Tue, 19 Jul 2022 13:47:37 +0100 Subject: [PATCH 10/12] Fix import order --- xarray/tests/test_array_api.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index 4419b79b9d5..b5a0f4c41f6 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -1,4 +1,5 @@ -import numpy.array_api as xp +from typing import Tuple + import pytest from numpy.array_api._array_object import Array @@ -7,9 +8,11 @@ np = pytest.importorskip("numpy", minversion="1.22") +import numpy.array_api as xp # isort:skip + @pytest.fixture -def arrays() -> tuple[xr.DataArray, xr.DataArray]: +def arrays() -> Tuple[xr.DataArray, xr.DataArray]: np_arr = xr.DataArray(np.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]}) xp_arr = xr.DataArray(xp.ones((2, 3)), dims=("x", "y"), coords={"x": [10, 20]}) assert isinstance(xp_arr.data, Array) From 0f8120990e0052a1615650dffee089505920c3b6 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Tue, 19 Jul 2022 08:44:33 -0600 Subject: [PATCH 11/12] Fix import order --- xarray/tests/test_array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_array_api.py b/xarray/tests/test_array_api.py index b5a0f4c41f6..8e378054c29 100644 --- a/xarray/tests/test_array_api.py +++ b/xarray/tests/test_array_api.py @@ -1,7 +1,6 @@ from typing import Tuple import pytest -from numpy.array_api._array_object import Array import xarray as xr from xarray.testing import assert_equal @@ -9,6 +8,7 @@ np = pytest.importorskip("numpy", minversion="1.22") import numpy.array_api as xp # isort:skip +from numpy.array_api._array_object import Array # isort:skip @pytest.fixture From 6d7f13e3a777e0d4bdb4feb78f0b14dd916651ca Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Wed, 20 Jul 2022 00:52:18 -0500 Subject: [PATCH 12/12] update whatsnew --- doc/whats-new.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 9f6f3622f71..f859f6c420a 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -30,6 +30,9 @@ New Features :py:meth:`coarsen`, :py:meth:`weighted`, :py:meth:`resample`, (:pull:`6702`) By `Michael Niklas `_. +- Experimental support for wrapping any array type that conforms to the python array api standard. + (:pull:`6804`) + By `Tom White `_. Deprecations ~~~~~~~~~~~~