From dc2dd89b999b16e08ba51e9cf623896b01be7297 Mon Sep 17 00:00:00 2001 From: Russell Manser Date: Wed, 2 Sep 2020 13:28:11 -0500 Subject: [PATCH] Change isinstance checks to duck Dask Array checks #4208 (#4221) * Change isinstance checks to duck Dask Array checks #4208 * Use is_dask_collection in is_duck_dask_array * Use is_dask_collection in is_duck_dask_array * Revert to isinstance checks according to review discussion * Move is_duck_dask_array to pycompat.py and use tokenize for comparisons * isort * Implement `is_duck_array` to replace `is_array_like` * Rename `is_array_like` to `is_duck_array` * `is_duck_array` checks for `__array_function__` and `__array_ufunc__` in addition to previous checks * Replace checks for `is_duck_dask_array` and `__array_function__` with `is_duck_array` * Skip numpy duck array tests when NEP18 is not active * Use utils.is_duck_array in xarray/core/formatting.py * Replace locally defined `is_duck_array` in _diff_mapping_repr * Replace `"__array_function__"` and `is_duck_dask_array` check in `short_data_repr` * Revert back to isinstance check for iris cube * Add is_duck_array_or_ndarray function to utils * Use is_duck_array_or_ndarray for duck array checks without NEP18 * Remove is_duck_dask_array_or_ndarray, replace checks with is_duck_array * Add explicit check for NumPy array to is_duck_array * Replace is_duck_array_or_ndarray checks with is_duck_array * Remove is_duck_array check for deep copy Co-authored-by: keewis * Use is_duck_array check in load * Move duck dask array tokenize tests from test_units.py to test_dask.py * Use _importorskip to require pint >=0.15 instead of pytest.mark.skipif Co-authored-by: Deepak Cherian Co-authored-by: keewis --- xarray/backends/common.py | 4 +-- xarray/coding/strings.py | 6 ++--- xarray/coding/variables.py | 6 ++--- xarray/conventions.py | 6 ++--- xarray/convert.py | 3 +-- xarray/core/accessor_dt.py | 8 +++--- xarray/core/common.py | 6 ++--- xarray/core/computation.py | 8 +++--- xarray/core/dask_array_compat.py | 4 +-- xarray/core/dataset.py | 12 +++------ xarray/core/duck_array_ops.py | 43 ++++++++++++++++---------------- xarray/core/formatting.py | 14 +++-------- xarray/core/indexing.py | 9 +++++-- xarray/core/missing.py | 5 ++-- xarray/core/pycompat.py | 12 ++++++++- xarray/core/rolling.py | 8 +++--- xarray/core/rolling_exp.py | 6 ++--- xarray/core/utils.py | 10 ++++++-- xarray/core/variable.py | 34 ++++++++++++------------- xarray/testing.py | 8 +++--- xarray/tests/__init__.py | 2 ++ xarray/tests/test_dask.py | 35 ++++++++++++++++++++++++++ xarray/tests/test_testing.py | 10 +++++++- 23 files changed, 156 insertions(+), 103 deletions(-) diff --git a/xarray/backends/common.py b/xarray/backends/common.py index da619905ce6..a8c5f61e7ef 100644 --- a/xarray/backends/common.py +++ b/xarray/backends/common.py @@ -6,7 +6,7 @@ from ..conventions import cf_encoder from ..core import indexing -from ..core.pycompat import dask_array_type +from ..core.pycompat import is_duck_dask_array from ..core.utils import FrozenDict, NdimSizeLenMixin # Create a logger object, but don't add any handlers. Leave that to user code. @@ -134,7 +134,7 @@ def __init__(self, lock=None): self.lock = lock def add(self, source, target, region=None): - if isinstance(source, dask_array_type): + if is_duck_dask_array(source): self.sources.append(source) self.targets.append(target) self.regions.append(region) diff --git a/xarray/coding/strings.py b/xarray/coding/strings.py index 8d7f777d1d5..dfe0175947c 100644 --- a/xarray/coding/strings.py +++ b/xarray/coding/strings.py @@ -4,7 +4,7 @@ import numpy as np from ..core import indexing -from ..core.pycompat import dask_array_type +from ..core.pycompat import is_duck_dask_array from ..core.variable import Variable from .variables import ( VariableCoder, @@ -130,7 +130,7 @@ def bytes_to_char(arr): if arr.dtype.kind != "S": raise ValueError("argument must have a fixed-width bytes dtype") - if isinstance(arr, dask_array_type): + if is_duck_dask_array(arr): import dask.array as da return da.map_blocks( @@ -166,7 +166,7 @@ def char_to_bytes(arr): # can't make an S0 dtype return np.zeros(arr.shape[:-1], dtype=np.string_) - if isinstance(arr, dask_array_type): + if is_duck_dask_array(arr): import dask.array as da if len(arr.chunks[-1]) > 1: diff --git a/xarray/coding/variables.py b/xarray/coding/variables.py index afb50fa517a..dd27bda107f 100644 --- a/xarray/coding/variables.py +++ b/xarray/coding/variables.py @@ -7,7 +7,7 @@ import pandas as pd from ..core import dtypes, duck_array_ops, indexing -from ..core.pycompat import dask_array_type +from ..core.pycompat import is_duck_dask_array from ..core.variable import Variable @@ -54,7 +54,7 @@ class _ElementwiseFunctionArray(indexing.ExplicitlyIndexedNDArrayMixin): """ def __init__(self, array, func, dtype): - assert not isinstance(array, dask_array_type) + assert not is_duck_dask_array(array) self.array = indexing.as_indexable(array) self.func = func self._dtype = dtype @@ -91,7 +91,7 @@ def lazy_elemwise_func(array, func, dtype): ------- Either a dask.array.Array or _ElementwiseFunctionArray. """ - if isinstance(array, dask_array_type): + if is_duck_dask_array(array): return array.map_blocks(func, dtype=dtype) else: return _ElementwiseFunctionArray(array, func, dtype) diff --git a/xarray/conventions.py b/xarray/conventions.py index da5ad7eea85..da69ce52527 100644 --- a/xarray/conventions.py +++ b/xarray/conventions.py @@ -8,7 +8,7 @@ from .coding.variables import SerializationWarning, pop_to from .core import duck_array_ops, indexing from .core.common import contains_cftime_datetimes -from .core.pycompat import dask_array_type +from .core.pycompat import is_duck_dask_array from .core.variable import IndexVariable, Variable, as_variable @@ -178,7 +178,7 @@ def ensure_dtype_not_object(var, name=None): if var.dtype.kind == "O": dims, data, attrs, encoding = _var_as_tuple(var) - if isinstance(data, dask_array_type): + if is_duck_dask_array(data): warnings.warn( "variable {} has data in the form of a dask array with " "dtype=object, which means it is being loaded into memory " @@ -351,7 +351,7 @@ def decode_cf_variable( del attributes["dtype"] data = BoolTypeArray(data) - if not isinstance(data, dask_array_type): + if not is_duck_dask_array(data): data = indexing.LazilyOuterIndexedArray(data) return Variable(dimensions, data, attributes, encoding=encoding) diff --git a/xarray/convert.py b/xarray/convert.py index 43e9ce94fb7..0fbd1e13163 100644 --- a/xarray/convert.py +++ b/xarray/convert.py @@ -10,6 +10,7 @@ from .core import duck_array_ops from .core.dataarray import DataArray from .core.dtypes import get_fill_value +from .core.pycompat import dask_array_type cdms2_ignored_attrs = {"name", "tileIndex"} iris_forbidden_keys = { @@ -246,8 +247,6 @@ def from_iris(cube): """Convert a Iris cube into an DataArray""" import iris.exceptions - from xarray.core.pycompat import dask_array_type - name = _name(cube) if name == "unknown": name = None diff --git a/xarray/core/accessor_dt.py b/xarray/core/accessor_dt.py index 214b4352c8a..a4ec7a2c30e 100644 --- a/xarray/core/accessor_dt.py +++ b/xarray/core/accessor_dt.py @@ -6,7 +6,7 @@ is_np_datetime_like, is_np_timedelta_like, ) -from .pycompat import dask_array_type +from .pycompat import is_duck_dask_array def _season_from_months(months): @@ -69,7 +69,7 @@ def _get_date_field(values, name, dtype): else: access_method = _access_through_cftimeindex - if isinstance(values, dask_array_type): + if is_duck_dask_array(values): from dask.array import map_blocks return map_blocks(access_method, values, name, dtype=dtype) @@ -114,7 +114,7 @@ def _round_field(values, name, freq): Array-like of datetime fields accessed for each element in values """ - if isinstance(values, dask_array_type): + if is_duck_dask_array(values): from dask.array import map_blocks dtype = np.datetime64 if is_np_datetime_like(values.dtype) else np.dtype("O") @@ -151,7 +151,7 @@ def _strftime(values, date_format): access_method = _strftime_through_series else: access_method = _strftime_through_cftimeindex - if isinstance(values, dask_array_type): + if is_duck_dask_array(values): from dask.array import map_blocks return map_blocks(access_method, values, date_format) diff --git a/xarray/core/common.py b/xarray/core/common.py index b693ed7832f..38803f821d4 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -23,7 +23,7 @@ from .arithmetic import SupportsArithmetic from .npcompat import DTypeLike from .options import OPTIONS, _get_keep_attrs -from .pycompat import dask_array_type +from .pycompat import is_duck_dask_array from .rolling_exp import RollingExp from .utils import Frozen, either_dict_or_kwargs, is_scalar @@ -1507,7 +1507,7 @@ def _full_like_variable(other, fill_value, dtype: DTypeLike = None): if fill_value is dtypes.NA: fill_value = dtypes.get_fill_value(dtype if dtype is not None else other.dtype) - if isinstance(other.data, dask_array_type): + if is_duck_dask_array(other.data): import dask.array if dtype is None: @@ -1652,7 +1652,7 @@ def _contains_cftime_datetimes(array) -> bool: else: if array.dtype == np.dtype("O") and array.size > 0: sample = array.ravel()[0] - if isinstance(sample, dask_array_type): + if is_duck_dask_array(sample): sample = sample.compute() if isinstance(sample, np.ndarray): sample = sample.item() diff --git a/xarray/core/computation.py b/xarray/core/computation.py index a2fec799a70..c6fea0e5cd1 100644 --- a/xarray/core/computation.py +++ b/xarray/core/computation.py @@ -29,7 +29,7 @@ from .alignment import align, deep_align from .merge import merge_coordinates_without_align from .options import OPTIONS -from .pycompat import dask_array_type +from .pycompat import is_duck_dask_array from .utils import is_dict_like from .variable import Variable @@ -610,7 +610,7 @@ def apply_variable_ufunc( for arg, core_dims in zip(args, signature.input_core_dims) ] - if any(isinstance(array, dask_array_type) for array in input_data): + if any(is_duck_dask_array(array) for array in input_data): if dask == "forbidden": raise ValueError( "apply_ufunc encountered a dask array on an " @@ -726,7 +726,7 @@ def func(*arrays): def apply_array_ufunc(func, *args, dask="forbidden"): """Apply a ndarray level function over ndarray objects.""" - if any(isinstance(arg, dask_array_type) for arg in args): + if any(is_duck_dask_array(arg) for arg in args): if dask == "forbidden": raise ValueError( "apply_ufunc encountered a dask array on an " @@ -1604,7 +1604,7 @@ def _calc_idxminmax( indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna) # Handle dask arrays. - if isinstance(array.data, dask_array_type): + if is_duck_dask_array(array.data): import dask.array chunks = dict(zip(array.dims, array.chunks)) diff --git a/xarray/core/dask_array_compat.py b/xarray/core/dask_array_compat.py index 4efbaad0855..50c87eacde1 100644 --- a/xarray/core/dask_array_compat.py +++ b/xarray/core/dask_array_compat.py @@ -4,7 +4,7 @@ import numpy as np -from .pycompat import dask_array_type +from .pycompat import is_duck_dask_array try: import dask.array as da @@ -39,7 +39,7 @@ def meta_from_array(x, ndim=None, dtype=None): """ # If using x._meta, x must be a Dask Array, some libraries (e.g. zarr) # implement a _meta attribute that are incompatible with Dask Array._meta - if hasattr(x, "_meta") and isinstance(x, dask_array_type): + if hasattr(x, "_meta") and is_duck_dask_array(x): x = x._meta if dtype is None and x is None: diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index dbbae01dd22..92de628f5ad 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -80,7 +80,7 @@ ) from .missing import get_clean_interp_index from .options import OPTIONS, _get_keep_attrs -from .pycompat import dask_array_type +from .pycompat import is_duck_dask_array from .utils import ( Default, Frozen, @@ -645,9 +645,7 @@ def load(self, **kwargs) -> "Dataset": """ # access .data to coerce everything to numpy or dask arrays lazy_data = { - k: v._data - for k, v in self.variables.items() - if isinstance(v._data, dask_array_type) + k: v._data for k, v in self.variables.items() if is_duck_dask_array(v._data) } if lazy_data: import dask.array as da @@ -815,9 +813,7 @@ def _persist_inplace(self, **kwargs) -> "Dataset": """Persist all Dask arrays in memory""" # access .data to coerce everything to numpy or dask arrays lazy_data = { - k: v._data - for k, v in self.variables.items() - if isinstance(v._data, dask_array_type) + k: v._data for k, v in self.variables.items() if is_duck_dask_array(v._data) } if lazy_data: import dask @@ -6043,7 +6039,7 @@ def polyfit( if dim not in da.dims: continue - if isinstance(da.data, dask_array_type) and ( + if is_duck_dask_array(da.data) and ( rank != order or full or skipna is None ): # Current algorithm with dask and skipna=False neither supports diff --git a/xarray/core/duck_array_ops.py b/xarray/core/duck_array_ops.py index 16bdd0e0fa6..53849a3eac8 100644 --- a/xarray/core/duck_array_ops.py +++ b/xarray/core/duck_array_ops.py @@ -15,10 +15,17 @@ from . import dask_array_compat, dask_array_ops, dtypes, npcompat, nputils from .nputils import nanfirst, nanlast -from .pycompat import cupy_array_type, dask_array_type, sparse_array_type +from .pycompat import ( + cupy_array_type, + dask_array_type, + is_duck_dask_array, + sparse_array_type, +) +from .utils import is_duck_array try: import dask.array as dask_array + from dask.base import tokenize except ImportError: dask_array = None # type: ignore @@ -39,7 +46,7 @@ def f(*args, **kwargs): dispatch_args = args[0] else: dispatch_args = args[array_args] - if any(isinstance(a, dask_array_type) for a in dispatch_args): + if any(is_duck_dask_array(a) for a in dispatch_args): try: wrapped = getattr(dask_module, name) except AttributeError as e: @@ -57,7 +64,7 @@ def f(*args, **kwargs): def fail_on_dask_array_input(values, msg=None, func_name=None): - if isinstance(values, dask_array_type): + if is_duck_dask_array(values): if msg is None: msg = "%r is not yet a valid method on dask arrays" if func_name is None: @@ -129,7 +136,7 @@ def notnull(data): def gradient(x, coord, axis, edge_order): - if isinstance(x, dask_array_type): + if is_duck_dask_array(x): return dask_array.gradient(x, coord, axis=axis, edge_order=edge_order) return np.gradient(x, coord, axis=axis, edge_order=edge_order) @@ -174,11 +181,7 @@ def astype(data, **kwargs): def asarray(data, xp=np): - return ( - data - if (isinstance(data, dask_array_type) or hasattr(data, "__array_function__")) - else xp.asarray(data) - ) + return data if is_duck_array(data) else xp.asarray(data) def as_shared_dtype(scalars_or_arrays): @@ -200,10 +203,10 @@ def as_shared_dtype(scalars_or_arrays): def lazy_array_equiv(arr1, arr2): """Like array_equal, but doesn't actually compare values. - Returns True when arr1, arr2 identical or their dask names are equal. + Returns True when arr1, arr2 identical or their dask tokens are equal. Returns False when shapes are not equal. Returns None when equality cannot determined: one or both of arr1, arr2 are numpy arrays; - or their dask names are not equal + or their dask tokens are not equal """ if arr1 is arr2: return True @@ -211,13 +214,9 @@ def lazy_array_equiv(arr1, arr2): arr2 = asarray(arr2) if arr1.shape != arr2.shape: return False - if ( - dask_array - and isinstance(arr1, dask_array_type) - and isinstance(arr2, dask_array_type) - ): - # GH3068 - if arr1.name == arr2.name: + if dask_array and is_duck_dask_array(arr1) and is_duck_dask_array(arr2): + # GH3068, GH4221 + if tokenize(arr1) == tokenize(arr2): return True else: return None @@ -331,7 +330,7 @@ def f(values, axis=None, skipna=None, **kwargs): try: return func(values, axis=axis, **kwargs) except AttributeError: - if not isinstance(values, dask_array_type): + if not is_duck_dask_array(values): raise try: # dask/dask#3133 dask sometimes needs dtype argument # if func does not accept dtype, then raises TypeError @@ -545,7 +544,7 @@ def mean(array, axis=None, skipna=None, **kwargs): + offset ) elif _contains_cftime_datetimes(array): - if isinstance(array, dask_array_type): + if is_duck_dask_array(array): raise NotImplementedError( "Computing the mean of an array containing " "cftime.datetime objects is not yet implemented on " @@ -614,7 +613,7 @@ def rolling_window(array, axis, window, center, fill_value): Make an ndarray with a rolling window of axis-th dimension. The rolling dimension will be placed at the last dimension. """ - if isinstance(array, dask_array_type): + if is_duck_dask_array(array): return dask_array_ops.rolling_window(array, axis, window, center, fill_value) else: # np.ndarray return nputils.rolling_window(array, axis, window, center, fill_value) @@ -622,7 +621,7 @@ def rolling_window(array, axis, window, center, fill_value): def least_squares(lhs, rhs, rcond=None, skipna=False): """Return the coefficients and residuals of a least-squares fit.""" - if isinstance(rhs, dask_array_type): + if is_duck_dask_array(rhs): return dask_array_ops.least_squares(lhs, rhs, rcond=rcond, skipna=skipna) else: return nputils.least_squares(lhs, rhs, rcond=rcond, skipna=skipna) diff --git a/xarray/core/formatting.py b/xarray/core/formatting.py index e06ca4bd0f8..3ed8c6dc241 100644 --- a/xarray/core/formatting.py +++ b/xarray/core/formatting.py @@ -13,6 +13,7 @@ from .duck_array_ops import array_equiv from .options import OPTIONS from .pycompat import dask_array_type, sparse_array_type +from .utils import is_duck_array def pretty_print(x, numchars: int): @@ -457,9 +458,7 @@ def short_data_repr(array): internal_data = getattr(array, "variable", array)._data if isinstance(array, np.ndarray): return short_numpy_repr(array) - elif hasattr(internal_data, "__array_function__") or isinstance( - internal_data, dask_array_type - ): + elif is_duck_array(internal_data): return limit_lines(repr(array.data), limit=40) elif array._in_memory or array.size < 1e5: return short_numpy_repr(array) @@ -527,13 +526,6 @@ def diff_dim_summary(a, b): def _diff_mapping_repr(a_mapping, b_mapping, compat, title, summarizer, col_width=None): - def is_array_like(value): - return ( - hasattr(value, "ndim") - and hasattr(value, "shape") - and hasattr(value, "dtype") - ) - def extra_items_repr(extra_keys, mapping, ab_side): extra_repr = [summarizer(k, mapping[k], col_width) for k in extra_keys] if extra_repr: @@ -559,7 +551,7 @@ def extra_items_repr(extra_keys, mapping, ab_side): is_variable = True except AttributeError: # compare attribute value - if is_array_like(a_mapping[k]) or is_array_like(b_mapping[k]): + if is_duck_array(a_mapping[k]) or is_duck_array(b_mapping[k]): compatible = array_equiv(a_mapping[k], b_mapping[k]) else: compatible = a_mapping[k] == b_mapping[k] diff --git a/xarray/core/indexing.py b/xarray/core/indexing.py index 68c61ac13dd..919a9cf5293 100644 --- a/xarray/core/indexing.py +++ b/xarray/core/indexing.py @@ -11,7 +11,12 @@ from . import duck_array_ops, nputils, utils from .npcompat import DTypeLike -from .pycompat import dask_array_type, integer_types, sparse_array_type +from .pycompat import ( + dask_array_type, + integer_types, + is_duck_dask_array, + sparse_array_type, +) from .utils import is_dict_like, maybe_cast_to_coords_dtype @@ -1108,7 +1113,7 @@ def _masked_result_drop_slice(key, data=None): new_keys = [] for k in key: if isinstance(k, np.ndarray): - if isinstance(data, dask_array_type): + if is_duck_dask_array(data): new_keys.append(_dask_array_with_chunks_hint(k, chunks_hint)) elif isinstance(data, sparse_array_type): import sparse diff --git a/xarray/core/missing.py b/xarray/core/missing.py index 7a5ffa48f77..f608468ed9f 100644 --- a/xarray/core/missing.py +++ b/xarray/core/missing.py @@ -11,8 +11,9 @@ from . import utils from .common import _contains_datetime_like_objects, ones_like from .computation import apply_ufunc -from .duck_array_ops import dask_array_type, datetime_to_numeric, timedelta_to_numeric +from .duck_array_ops import datetime_to_numeric, timedelta_to_numeric from .options import _get_keep_attrs +from .pycompat import is_duck_dask_array from .utils import OrderedSet, is_scalar from .variable import Variable, broadcast_variables @@ -695,7 +696,7 @@ def interp_func(var, x, new_x, method, kwargs): else: func, kwargs = _get_interpolator_nd(method, **kwargs) - if isinstance(var, dask_array_type): + if is_duck_dask_array(var): import dask.array as da nconst = var.ndim - len(x) diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index dcb78d17cf8..8d613038957 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -1,14 +1,24 @@ import numpy as np +from .utils import is_duck_array + integer_types = (int, np.integer) try: - # solely for isinstance checks import dask.array + from dask.base import is_dask_collection + # solely for isinstance checks dask_array_type = (dask.array.Array,) + + def is_duck_dask_array(x): + return is_duck_array(x) and is_dask_collection(x) + + except ImportError: # pragma: no cover dask_array_type = () + is_duck_dask_array = lambda _: False + is_dask_collection = lambda _: False try: # solely for isinstance checks diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 4af9a4bb0f7..0c4614e0b57 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -8,7 +8,7 @@ from .dask_array_ops import dask_rolling_wrapper from .ops import inject_reduce_methods from .options import _get_keep_attrs -from .pycompat import dask_array_type +from .pycompat import is_duck_dask_array try: import bottleneck @@ -376,7 +376,7 @@ def _bottleneck_reduce(self, func, **kwargs): padded = self.obj.variable if self.center[0]: - if isinstance(padded.data, dask_array_type): + if is_duck_dask_array(padded.data): # Workaround to make the padded chunk size is larger than # self.window-1 shift = -(self.window[0] + 1) // 2 @@ -389,7 +389,7 @@ def _bottleneck_reduce(self, func, **kwargs): valid = (slice(None),) * axis + (slice(-shift, None),) padded = padded.pad({self.dim[0]: (0, -shift)}, mode="constant") - if isinstance(padded.data, dask_array_type): + if is_duck_dask_array(padded.data): raise AssertionError("should not be reachable") values = dask_rolling_wrapper( func, padded.data, window=self.window[0], min_count=min_count, axis=axis @@ -418,7 +418,7 @@ def _numpy_or_bottleneck_reduce( if ( bottleneck_move_func is not None - and not isinstance(self.obj.data, dask_array_type) + and not is_duck_dask_array(self.obj.data) and len(self.dim) == 1 ): # TODO: renable bottleneck with dask after the issues diff --git a/xarray/core/rolling_exp.py b/xarray/core/rolling_exp.py index 525867cc025..96444f0f864 100644 --- a/xarray/core/rolling_exp.py +++ b/xarray/core/rolling_exp.py @@ -1,7 +1,7 @@ import numpy as np from .pdcompat import count_not_none -from .pycompat import dask_array_type +from .pycompat import is_duck_dask_array def _get_alpha(com=None, span=None, halflife=None, alpha=None): @@ -13,8 +13,8 @@ def _get_alpha(com=None, span=None, halflife=None, alpha=None): def move_exp_nanmean(array, *, axis, alpha): - if isinstance(array, dask_array_type): - raise TypeError("rolling_exp is not currently support for dask arrays") + if is_duck_dask_array(array): + raise TypeError("rolling_exp is not currently support for dask-like arrays") import numbagg if axis == (): diff --git a/xarray/core/utils.py b/xarray/core/utils.py index 0952d185f85..cfb627f7af5 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -247,9 +247,15 @@ def is_list_like(value: Any) -> bool: return isinstance(value, list) or isinstance(value, tuple) -def is_array_like(value: Any) -> bool: +def is_duck_array(value: Any) -> bool: + if isinstance(value, np.ndarray): + return True return ( - hasattr(value, "ndim") and hasattr(value, "shape") and hasattr(value, "dtype") + hasattr(value, "ndim") + and hasattr(value, "shape") + and hasattr(value, "dtype") + and hasattr(value, "__array_function__") + and hasattr(value, "__array_ufunc__") ) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 98d74ebbe1f..203f7437914 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -33,7 +33,12 @@ ) from .npcompat import IS_NEP18_ACTIVE from .options import _get_keep_attrs -from .pycompat import cupy_array_type, dask_array_type, integer_types +from .pycompat import ( + cupy_array_type, + dask_array_type, + integer_types, + is_duck_dask_array, +) from .utils import ( OrderedSet, _default, @@ -42,6 +47,7 @@ either_dict_or_kwargs, ensure_us_time_resolution, infix_dims, + is_duck_array, ) NON_NUMPY_SUPPORTED_ARRAY_TYPES = ( @@ -347,9 +353,7 @@ def _in_memory(self): @property def data(self): - if hasattr(self._data, "__array_function__") or isinstance( - self._data, dask_array_type - ): + if is_duck_array(self._data): return self._data else: return self.values @@ -427,9 +431,9 @@ def load(self, **kwargs): -------- dask.array.compute """ - if isinstance(self._data, dask_array_type): + if is_duck_dask_array(self._data): self._data = as_compatible_data(self._data.compute(**kwargs)) - elif not hasattr(self._data, "__array_function__"): + elif not is_duck_array(self._data): self._data = np.asarray(self._data) return self @@ -462,7 +466,7 @@ def __dask_tokenize__(self): return normalize_token((type(self), self._dims, self.data, self._attrs)) def __dask_graph__(self): - if isinstance(self._data, dask_array_type): + if is_duck_dask_array(self._data): return self._data.__dask_graph__() else: return None @@ -788,7 +792,7 @@ def _getitem_with_mask(self, key, fill_value=dtypes.NA): dims, indexer, new_order = self._broadcast_indexes(key) if self.size: - if isinstance(self._data, dask_array_type): + if is_duck_dask_array(self._data): # dask's indexing is faster this way; also vindex does not # support negative indices yet: # https://github.com/dask/dask/pull/2967 @@ -932,11 +936,7 @@ def copy(self, deep=True, data=None): # don't share caching between copies data = indexing.MemoryCachedArray(data.array) - if deep and ( - hasattr(data, "__array_function__") - or isinstance(data, dask_array_type) - or (not IS_NEP18_ACTIVE and isinstance(data, np.ndarray)) - ): + if deep: data = copy.deepcopy(data) else: @@ -1024,7 +1024,7 @@ def chunk(self, chunks=None, name=None, lock=False): chunks = self.chunks or self.shape data = self._data - if isinstance(data, da.Array): + if is_duck_dask_array(data): data = data.rechunk(chunks) else: if isinstance(data, indexing.ExplicitlyIndexed): @@ -1171,7 +1171,7 @@ def _shift_one_dim(self, dim, count, fill_value=dtypes.NA): constant_values=fill_value, ) - if isinstance(data, dask_array_type): + if is_duck_dask_array(data): # chunked data should come out with the same chunks; this makes # it feasible to combine shifted and unshifted data # TODO: remove this once dask.array automatically aligns chunks @@ -1330,7 +1330,7 @@ def _roll_one_dim(self, dim, count): data = duck_array_ops.concatenate(arrays, axis) - if isinstance(data, dask_array_type): + if is_duck_dask_array(data): # chunked data should come out with the same chunks; this makes # it feasible to combine shifted and unshifted data # TODO: remove this once dask.array automatically aligns chunks @@ -1902,7 +1902,7 @@ def rank(self, dim, pct=False): data = self.data - if isinstance(data, dask_array_type): + if is_duck_dask_array(data): raise TypeError( "rank does not work for arrays stored as dask " "arrays. Load the data via .compute() or .load() " diff --git a/xarray/testing.py b/xarray/testing.py index 13efd57579c..ca72a4bee8e 100644 --- a/xarray/testing.py +++ b/xarray/testing.py @@ -196,14 +196,14 @@ def assert_duckarray_equal(x, y, err_msg="", verbose=True): """ Like `np.testing.assert_array_equal`, but for duckarrays """ __tracebackhide__ = True - if not utils.is_array_like(x) and not utils.is_scalar(x): + if not utils.is_duck_array(x) and not utils.is_scalar(x): x = np.asarray(x) - if not utils.is_array_like(y) and not utils.is_scalar(y): + if not utils.is_duck_array(y) and not utils.is_scalar(y): y = np.asarray(y) - if (utils.is_array_like(x) and utils.is_scalar(y)) or ( - utils.is_scalar(x) and utils.is_array_like(y) + if (utils.is_duck_array(x) and utils.is_scalar(y)) or ( + utils.is_scalar(x) and utils.is_duck_array(y) ): equiv = (x == y).all() else: diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 6ad30007f9f..9e1fdc0df33 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -78,6 +78,8 @@ def LooseVersion(vstring): has_seaborn, requires_seaborn = _importorskip("seaborn") has_sparse, requires_sparse = _importorskip("sparse") has_cartopy, requires_cartopy = _importorskip("cartopy") +# Need Pint 0.15 for __dask_tokenize__ tests for Quantity wrapped Dask Arrays +has_pint_0_15, requires_pint_0_15 = _importorskip("pint", minversion="0.15") # some special cases has_scipy_or_netCDF4 = has_scipy or has_netCDF4 diff --git a/xarray/tests/test_dask.py b/xarray/tests/test_dask.py index 358ea731b90..46685a29a47 100644 --- a/xarray/tests/test_dask.py +++ b/xarray/tests/test_dask.py @@ -24,6 +24,7 @@ assert_frame_equal, assert_identical, raises_regex, + requires_pint_0_15, requires_scipy_or_netCDF4, ) from .test_backends import create_tmp_file @@ -292,6 +293,22 @@ def test_persist(self): self.assertLazyAndAllClose(u + 1, v) self.assertLazyAndAllClose(u + 1, v2) + @requires_pint_0_15(reason="Need __dask_tokenize__") + def test_tokenize_duck_dask_array(self): + import pint + + unit_registry = pint.UnitRegistry() + + q = unit_registry.Quantity(self.data, "meter") + variable = xr.Variable(("x", "y"), q) + + token = dask.base.tokenize(variable) + post_op = variable + 5 * unit_registry.meter + + assert dask.base.tokenize(variable) != dask.base.tokenize(post_op) + # Immutability check + assert dask.base.tokenize(variable) == token + class TestDataArrayAndDataset(DaskTestCase): def assertLazyAndIdentical(self, expected, actual): @@ -715,6 +732,24 @@ def test_from_dask_variable(self): a = DataArray(self.lazy_array.variable, coords={"x": range(4)}, name="foo") self.assertLazyAndIdentical(self.lazy_array, a) + @requires_pint_0_15(reason="Need __dask_tokenize__") + def test_tokenize_duck_dask_array(self): + import pint + + unit_registry = pint.UnitRegistry() + + q = unit_registry.Quantity(self.data, unit_registry.meter) + data_array = xr.DataArray( + data=q, coords={"x": range(4)}, dims=("x", "y"), name="foo" + ) + + token = dask.base.tokenize(data_array) + post_op = data_array + 5 * unit_registry.meter + + assert dask.base.tokenize(data_array) != dask.base.tokenize(post_op) + # Immutability check + assert dask.base.tokenize(data_array) == token + class TestToDaskDataFrame: def test_to_dask_dataframe(self): diff --git a/xarray/tests/test_testing.py b/xarray/tests/test_testing.py index adc29a3cc92..0f2ae8b31d4 100644 --- a/xarray/tests/test_testing.py +++ b/xarray/tests/test_testing.py @@ -2,6 +2,7 @@ import pytest import xarray as xr +from xarray.core.npcompat import IS_NEP18_ACTIVE from . import has_dask @@ -98,7 +99,14 @@ def test_assert_duckarray_equal_failing(duckarray, obj1, obj2): @pytest.mark.parametrize( "duckarray", ( - pytest.param(np.array, id="numpy"), + pytest.param( + np.array, + id="numpy", + marks=pytest.mark.skipif( + not IS_NEP18_ACTIVE, + reason="NUMPY_EXPERIMENTAL_ARRAY_FUNCTION is not enabled", + ), + ), pytest.param( dask_from_array, id="dask",