Skip to content

Commit

Permalink
Change isinstance checks to duck Dask Array checks #4208 (#4221)
Browse files Browse the repository at this point in the history
* Change isinstance checks to duck Dask Array checks #4208

* Use is_dask_collection in is_duck_dask_array

* Use is_dask_collection in is_duck_dask_array

* Revert to isinstance checks according to review discussion

* Move is_duck_dask_array to pycompat.py and use tokenize for comparisons

* isort

* Implement `is_duck_array` to replace `is_array_like`

* Rename `is_array_like` to `is_duck_array`

* `is_duck_array` checks for `__array_function__` and `__array_ufunc__`
  in addition to previous checks

* Replace checks for `is_duck_dask_array` and `__array_function__` with
  `is_duck_array`

* Skip numpy duck array tests when NEP18 is not active

* Use utils.is_duck_array in xarray/core/formatting.py

* Replace locally defined `is_duck_array` in _diff_mapping_repr

* Replace `"__array_function__"` and `is_duck_dask_array` check in
  `short_data_repr`

* Revert back to isinstance check for iris cube

* Add is_duck_array_or_ndarray function to utils

* Use is_duck_array_or_ndarray for duck array checks without NEP18

* Remove is_duck_dask_array_or_ndarray, replace checks with is_duck_array

* Add explicit check for NumPy array to is_duck_array

* Replace is_duck_array_or_ndarray checks with is_duck_array

* Remove is_duck_array check for deep copy

Co-authored-by: keewis <[email protected]>

* Use is_duck_array check in load

* Move duck dask array tokenize tests from test_units.py to test_dask.py

* Use _importorskip to require pint >=0.15 instead of pytest.mark.skipif

Co-authored-by: Deepak Cherian <[email protected]>
Co-authored-by: keewis <[email protected]>
  • Loading branch information
3 people authored Sep 2, 2020
1 parent 9ee0f01 commit dc2dd89
Show file tree
Hide file tree
Showing 23 changed files with 156 additions and 103 deletions.
4 changes: 2 additions & 2 deletions xarray/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from ..conventions import cf_encoder
from ..core import indexing
from ..core.pycompat import dask_array_type
from ..core.pycompat import is_duck_dask_array
from ..core.utils import FrozenDict, NdimSizeLenMixin

# Create a logger object, but don't add any handlers. Leave that to user code.
Expand Down Expand Up @@ -134,7 +134,7 @@ def __init__(self, lock=None):
self.lock = lock

def add(self, source, target, region=None):
if isinstance(source, dask_array_type):
if is_duck_dask_array(source):
self.sources.append(source)
self.targets.append(target)
self.regions.append(region)
Expand Down
6 changes: 3 additions & 3 deletions xarray/coding/strings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from ..core import indexing
from ..core.pycompat import dask_array_type
from ..core.pycompat import is_duck_dask_array
from ..core.variable import Variable
from .variables import (
VariableCoder,
Expand Down Expand Up @@ -130,7 +130,7 @@ def bytes_to_char(arr):
if arr.dtype.kind != "S":
raise ValueError("argument must have a fixed-width bytes dtype")

if isinstance(arr, dask_array_type):
if is_duck_dask_array(arr):
import dask.array as da

return da.map_blocks(
Expand Down Expand Up @@ -166,7 +166,7 @@ def char_to_bytes(arr):
# can't make an S0 dtype
return np.zeros(arr.shape[:-1], dtype=np.string_)

if isinstance(arr, dask_array_type):
if is_duck_dask_array(arr):
import dask.array as da

if len(arr.chunks[-1]) > 1:
Expand Down
6 changes: 3 additions & 3 deletions xarray/coding/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd

from ..core import dtypes, duck_array_ops, indexing
from ..core.pycompat import dask_array_type
from ..core.pycompat import is_duck_dask_array
from ..core.variable import Variable


Expand Down Expand Up @@ -54,7 +54,7 @@ class _ElementwiseFunctionArray(indexing.ExplicitlyIndexedNDArrayMixin):
"""

def __init__(self, array, func, dtype):
assert not isinstance(array, dask_array_type)
assert not is_duck_dask_array(array)
self.array = indexing.as_indexable(array)
self.func = func
self._dtype = dtype
Expand Down Expand Up @@ -91,7 +91,7 @@ def lazy_elemwise_func(array, func, dtype):
-------
Either a dask.array.Array or _ElementwiseFunctionArray.
"""
if isinstance(array, dask_array_type):
if is_duck_dask_array(array):
return array.map_blocks(func, dtype=dtype)
else:
return _ElementwiseFunctionArray(array, func, dtype)
Expand Down
6 changes: 3 additions & 3 deletions xarray/conventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .coding.variables import SerializationWarning, pop_to
from .core import duck_array_ops, indexing
from .core.common import contains_cftime_datetimes
from .core.pycompat import dask_array_type
from .core.pycompat import is_duck_dask_array
from .core.variable import IndexVariable, Variable, as_variable


Expand Down Expand Up @@ -178,7 +178,7 @@ def ensure_dtype_not_object(var, name=None):
if var.dtype.kind == "O":
dims, data, attrs, encoding = _var_as_tuple(var)

if isinstance(data, dask_array_type):
if is_duck_dask_array(data):
warnings.warn(
"variable {} has data in the form of a dask array with "
"dtype=object, which means it is being loaded into memory "
Expand Down Expand Up @@ -351,7 +351,7 @@ def decode_cf_variable(
del attributes["dtype"]
data = BoolTypeArray(data)

if not isinstance(data, dask_array_type):
if not is_duck_dask_array(data):
data = indexing.LazilyOuterIndexedArray(data)

return Variable(dimensions, data, attributes, encoding=encoding)
Expand Down
3 changes: 1 addition & 2 deletions xarray/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .core import duck_array_ops
from .core.dataarray import DataArray
from .core.dtypes import get_fill_value
from .core.pycompat import dask_array_type

cdms2_ignored_attrs = {"name", "tileIndex"}
iris_forbidden_keys = {
Expand Down Expand Up @@ -246,8 +247,6 @@ def from_iris(cube):
"""Convert an Iris cube into a DataArray"""
import iris.exceptions

from xarray.core.pycompat import dask_array_type

name = _name(cube)
if name == "unknown":
name = None
Expand Down
8 changes: 4 additions & 4 deletions xarray/core/accessor_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
is_np_datetime_like,
is_np_timedelta_like,
)
from .pycompat import dask_array_type
from .pycompat import is_duck_dask_array


def _season_from_months(months):
Expand Down Expand Up @@ -69,7 +69,7 @@ def _get_date_field(values, name, dtype):
else:
access_method = _access_through_cftimeindex

if isinstance(values, dask_array_type):
if is_duck_dask_array(values):
from dask.array import map_blocks

return map_blocks(access_method, values, name, dtype=dtype)
Expand Down Expand Up @@ -114,7 +114,7 @@ def _round_field(values, name, freq):
Array-like of datetime fields accessed for each element in values
"""
if isinstance(values, dask_array_type):
if is_duck_dask_array(values):
from dask.array import map_blocks

dtype = np.datetime64 if is_np_datetime_like(values.dtype) else np.dtype("O")
Expand Down Expand Up @@ -151,7 +151,7 @@ def _strftime(values, date_format):
access_method = _strftime_through_series
else:
access_method = _strftime_through_cftimeindex
if isinstance(values, dask_array_type):
if is_duck_dask_array(values):
from dask.array import map_blocks

return map_blocks(access_method, values, date_format)
Expand Down
6 changes: 3 additions & 3 deletions xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .arithmetic import SupportsArithmetic
from .npcompat import DTypeLike
from .options import OPTIONS, _get_keep_attrs
from .pycompat import dask_array_type
from .pycompat import is_duck_dask_array
from .rolling_exp import RollingExp
from .utils import Frozen, either_dict_or_kwargs, is_scalar

Expand Down Expand Up @@ -1507,7 +1507,7 @@ def _full_like_variable(other, fill_value, dtype: DTypeLike = None):
if fill_value is dtypes.NA:
fill_value = dtypes.get_fill_value(dtype if dtype is not None else other.dtype)

if isinstance(other.data, dask_array_type):
if is_duck_dask_array(other.data):
import dask.array

if dtype is None:
Expand Down Expand Up @@ -1652,7 +1652,7 @@ def _contains_cftime_datetimes(array) -> bool:
else:
if array.dtype == np.dtype("O") and array.size > 0:
sample = array.ravel()[0]
if isinstance(sample, dask_array_type):
if is_duck_dask_array(sample):
sample = sample.compute()
if isinstance(sample, np.ndarray):
sample = sample.item()
Expand Down
8 changes: 4 additions & 4 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from .alignment import align, deep_align
from .merge import merge_coordinates_without_align
from .options import OPTIONS
from .pycompat import dask_array_type
from .pycompat import is_duck_dask_array
from .utils import is_dict_like
from .variable import Variable

Expand Down Expand Up @@ -610,7 +610,7 @@ def apply_variable_ufunc(
for arg, core_dims in zip(args, signature.input_core_dims)
]

if any(isinstance(array, dask_array_type) for array in input_data):
if any(is_duck_dask_array(array) for array in input_data):
if dask == "forbidden":
raise ValueError(
"apply_ufunc encountered a dask array on an "
Expand Down Expand Up @@ -726,7 +726,7 @@ def func(*arrays):

def apply_array_ufunc(func, *args, dask="forbidden"):
"""Apply a ndarray level function over ndarray objects."""
if any(isinstance(arg, dask_array_type) for arg in args):
if any(is_duck_dask_array(arg) for arg in args):
if dask == "forbidden":
raise ValueError(
"apply_ufunc encountered a dask array on an "
Expand Down Expand Up @@ -1604,7 +1604,7 @@ def _calc_idxminmax(
indx = func(array, dim=dim, axis=None, keep_attrs=keep_attrs, skipna=skipna)

# Handle dask arrays.
if isinstance(array.data, dask_array_type):
if is_duck_dask_array(array.data):
import dask.array

chunks = dict(zip(array.dims, array.chunks))
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dask_array_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np

from .pycompat import dask_array_type
from .pycompat import is_duck_dask_array

try:
import dask.array as da
Expand Down Expand Up @@ -39,7 +39,7 @@ def meta_from_array(x, ndim=None, dtype=None):
"""
# If using x._meta, x must be a Dask Array; some libraries (e.g. zarr)
# implement a _meta attribute that is incompatible with Dask Array._meta
if hasattr(x, "_meta") and isinstance(x, dask_array_type):
if hasattr(x, "_meta") and is_duck_dask_array(x):
x = x._meta

if dtype is None and x is None:
Expand Down
12 changes: 4 additions & 8 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
)
from .missing import get_clean_interp_index
from .options import OPTIONS, _get_keep_attrs
from .pycompat import dask_array_type
from .pycompat import is_duck_dask_array
from .utils import (
Default,
Frozen,
Expand Down Expand Up @@ -645,9 +645,7 @@ def load(self, **kwargs) -> "Dataset":
"""
# access .data to coerce everything to numpy or dask arrays
lazy_data = {
k: v._data
for k, v in self.variables.items()
if isinstance(v._data, dask_array_type)
k: v._data for k, v in self.variables.items() if is_duck_dask_array(v._data)
}
if lazy_data:
import dask.array as da
Expand Down Expand Up @@ -815,9 +813,7 @@ def _persist_inplace(self, **kwargs) -> "Dataset":
"""Persist all Dask arrays in memory"""
# access .data to coerce everything to numpy or dask arrays
lazy_data = {
k: v._data
for k, v in self.variables.items()
if isinstance(v._data, dask_array_type)
k: v._data for k, v in self.variables.items() if is_duck_dask_array(v._data)
}
if lazy_data:
import dask
Expand Down Expand Up @@ -6043,7 +6039,7 @@ def polyfit(
if dim not in da.dims:
continue

if isinstance(da.data, dask_array_type) and (
if is_duck_dask_array(da.data) and (
rank != order or full or skipna is None
):
# Current algorithm with dask and skipna=False neither supports
Expand Down
43 changes: 21 additions & 22 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,17 @@

from . import dask_array_compat, dask_array_ops, dtypes, npcompat, nputils
from .nputils import nanfirst, nanlast
from .pycompat import cupy_array_type, dask_array_type, sparse_array_type
from .pycompat import (
cupy_array_type,
dask_array_type,
is_duck_dask_array,
sparse_array_type,
)
from .utils import is_duck_array

try:
import dask.array as dask_array
from dask.base import tokenize
except ImportError:
dask_array = None # type: ignore

Expand All @@ -39,7 +46,7 @@ def f(*args, **kwargs):
dispatch_args = args[0]
else:
dispatch_args = args[array_args]
if any(isinstance(a, dask_array_type) for a in dispatch_args):
if any(is_duck_dask_array(a) for a in dispatch_args):
try:
wrapped = getattr(dask_module, name)
except AttributeError as e:
Expand All @@ -57,7 +64,7 @@ def f(*args, **kwargs):


def fail_on_dask_array_input(values, msg=None, func_name=None):
if isinstance(values, dask_array_type):
if is_duck_dask_array(values):
if msg is None:
msg = "%r is not yet a valid method on dask arrays"
if func_name is None:
Expand Down Expand Up @@ -129,7 +136,7 @@ def notnull(data):


def gradient(x, coord, axis, edge_order):
if isinstance(x, dask_array_type):
if is_duck_dask_array(x):
return dask_array.gradient(x, coord, axis=axis, edge_order=edge_order)
return np.gradient(x, coord, axis=axis, edge_order=edge_order)

Expand Down Expand Up @@ -174,11 +181,7 @@ def astype(data, **kwargs):


def asarray(data, xp=np):
return (
data
if (isinstance(data, dask_array_type) or hasattr(data, "__array_function__"))
else xp.asarray(data)
)
return data if is_duck_array(data) else xp.asarray(data)


def as_shared_dtype(scalars_or_arrays):
Expand All @@ -200,24 +203,20 @@ def as_shared_dtype(scalars_or_arrays):

def lazy_array_equiv(arr1, arr2):
"""Like array_equal, but doesn't actually compare values.
Returns True when arr1, arr2 identical or their dask names are equal.
Returns True when arr1, arr2 identical or their dask tokens are equal.
Returns False when shapes are not equal.
Returns None when equality cannot be determined: one or both of arr1, arr2 are numpy arrays;
or their dask names are not equal
or their dask tokens are not equal
"""
if arr1 is arr2:
return True
arr1 = asarray(arr1)
arr2 = asarray(arr2)
if arr1.shape != arr2.shape:
return False
if (
dask_array
and isinstance(arr1, dask_array_type)
and isinstance(arr2, dask_array_type)
):
# GH3068
if arr1.name == arr2.name:
if dask_array and is_duck_dask_array(arr1) and is_duck_dask_array(arr2):
# GH3068, GH4221
if tokenize(arr1) == tokenize(arr2):
return True
else:
return None
Expand Down Expand Up @@ -331,7 +330,7 @@ def f(values, axis=None, skipna=None, **kwargs):
try:
return func(values, axis=axis, **kwargs)
except AttributeError:
if not isinstance(values, dask_array_type):
if not is_duck_dask_array(values):
raise
try: # dask/dask#3133 dask sometimes needs dtype argument
# if func does not accept dtype, then raises TypeError
Expand Down Expand Up @@ -545,7 +544,7 @@ def mean(array, axis=None, skipna=None, **kwargs):
+ offset
)
elif _contains_cftime_datetimes(array):
if isinstance(array, dask_array_type):
if is_duck_dask_array(array):
raise NotImplementedError(
"Computing the mean of an array containing "
"cftime.datetime objects is not yet implemented on "
Expand Down Expand Up @@ -614,15 +613,15 @@ def rolling_window(array, axis, window, center, fill_value):
Make an ndarray with a rolling window of axis-th dimension.
The rolling dimension will be placed at the last dimension.
"""
if isinstance(array, dask_array_type):
if is_duck_dask_array(array):
return dask_array_ops.rolling_window(array, axis, window, center, fill_value)
else: # np.ndarray
return nputils.rolling_window(array, axis, window, center, fill_value)


def least_squares(lhs, rhs, rcond=None, skipna=False):
"""Return the coefficients and residuals of a least-squares fit."""
if isinstance(rhs, dask_array_type):
if is_duck_dask_array(rhs):
return dask_array_ops.least_squares(lhs, rhs, rcond=rcond, skipna=skipna)
else:
return nputils.least_squares(lhs, rhs, rcond=rcond, skipna=skipna)
Loading

0 comments on commit dc2dd89

Please sign in to comment.