Skip to content

Commit

Permalink
Propagate attrs with unary, binary functions (#4195)
Browse files Browse the repository at this point in the history
* Propagate attrs with unary, binary functions

Closes #3490
Closes #4065
Closes #3433
Closes #3595

* Un xfail test

* bugfix

* Some progress. Still need keep_attrs in DataArray._unary_op

* Fix dataset attrs

* whats-new

* small fix

* Fix imag, real

* fix variable tests

* fix multiple return variables.

* review comments

* Update doc/whats-new.rst

* Propagate attrs with DataArray unary ops

* More tests

* Small cleanup

* Review comments.

* Fix duplication

Co-authored-by: Maximilian Roos <[email protected]>
  • Loading branch information
dcherian and max-sixty authored Oct 14, 2020
1 parent 92e49f9 commit db4f03e
Show file tree
Hide file tree
Showing 10 changed files with 110 additions and 20 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ New Features
- :py:func:`open_dataset` and :py:func:`open_mfdataset`
now works with ``engine="zarr"`` (:issue:`3668`, :pull:`4003`, :pull:`4187`).
By `Miguel Jimenez <https://github.com/Mikejmnez>`_ and `Wei Ji Leong <https://github.com/weiji14>`_.
- Unary & binary operations follow the ``keep_attrs`` flag (:issue:`3490`, :issue:`4065`, :issue:`3433`, :issue:`3595`, :pull:`4195`).
By `Deepak Cherian <https://github.com/dcherian>`_.

Bug fixes
~~~~~~~~~
Expand Down
3 changes: 2 additions & 1 deletion xarray/core/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

from .options import OPTIONS
from .options import OPTIONS, _get_keep_attrs
from .pycompat import dask_array_type
from .utils import not_implemented

Expand Down Expand Up @@ -77,6 +77,7 @@ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
dataset_fill_value=np.nan,
kwargs=kwargs,
dask="allowed",
keep_attrs=_get_keep_attrs(default=True),
)

# this has no runtime function - these are listed so IDEs know these
Expand Down
38 changes: 30 additions & 8 deletions xarray/core/computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@
_JOINS_WITHOUT_FILL_VALUES = frozenset({"inner", "exact"})


def _first_of_type(args, kind):
""" Return either first object of type 'kind' or raise if not found. """
for arg in args:
if isinstance(arg, kind):
return arg
raise ValueError("This should be unreachable.")


class _UFuncSignature:
"""Core dimensions signature for a given function.
Expand Down Expand Up @@ -252,8 +260,9 @@ def apply_dataarray_vfunc(
args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False
)

if keep_attrs and hasattr(args[0], "name"):
name = args[0].name
if keep_attrs:
first_obj = _first_of_type(args, DataArray)
name = first_obj.name
else:
name = result_name(args)
result_coords = build_output_coords(args, signature, exclude_dims)
Expand All @@ -270,6 +279,14 @@ def apply_dataarray_vfunc(
(coords,) = result_coords
out = DataArray(result_var, coords, name=name, fastpath=True)

if keep_attrs:
if isinstance(out, tuple):
for da in out:
# This is adding attrs in place
da._copy_attrs_from(first_obj)
else:
out._copy_attrs_from(first_obj)

return out


Expand Down Expand Up @@ -390,15 +407,16 @@ def apply_dataset_vfunc(
"""
from .dataset import Dataset

first_obj = args[0] # we'll copy attrs from this in case keep_attrs=True

if dataset_join not in _JOINS_WITHOUT_FILL_VALUES and fill_value is _NO_FILL_VALUE:
raise TypeError(
"to apply an operation to datasets with different "
"data variables with apply_ufunc, you must supply the "
"dataset_fill_value argument."
)

if keep_attrs:
first_obj = _first_of_type(args, Dataset)

if len(args) > 1:
args = deep_align(
args, join=join, copy=False, exclude=exclude_dims, raise_on_invalid=False
Expand All @@ -417,9 +435,11 @@ def apply_dataset_vfunc(
(coord_vars,) = list_of_coords
out = _fast_dataset(result_vars, coord_vars)

if keep_attrs and isinstance(first_obj, Dataset):
if keep_attrs:
if isinstance(out, tuple):
out = tuple(ds._copy_attrs_from(first_obj) for ds in out)
for ds in out:
# This is adding attrs in place
ds._copy_attrs_from(first_obj)
else:
out._copy_attrs_from(first_obj)
return out
Expand Down Expand Up @@ -595,6 +615,8 @@ def apply_variable_ufunc(
"""Apply a ndarray level function over Variable and/or ndarray objects."""
from .variable import Variable, as_compatible_data

first_obj = _first_of_type(args, Variable)

dim_sizes = unified_dim_sizes(
(a for a in args if hasattr(a, "dims")), exclude_dims=exclude_dims
)
Expand Down Expand Up @@ -734,8 +756,8 @@ def func(*arrays):
)
)

if keep_attrs and isinstance(args[0], Variable):
var.attrs.update(args[0].attrs)
if keep_attrs:
var.attrs.update(first_obj.attrs)
output.append(var)

if signature.num_outputs == 1:
Expand Down
10 changes: 8 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from .indexes import Indexes, default_indexes, propagate_indexes
from .indexing import is_fancy_indexer
from .merge import PANDAS_TYPES, MergeError, _extract_indexes_from_coords
from .options import OPTIONS
from .options import OPTIONS, _get_keep_attrs
from .utils import Default, ReprObject, _check_inplace, _default, either_dict_or_kwargs
from .variable import (
IndexVariable,
Expand Down Expand Up @@ -2734,13 +2734,19 @@ def __rmatmul__(self, other):
def _unary_op(f: Callable[..., Any]) -> Callable[..., "DataArray"]:
@functools.wraps(f)
def func(self, *args, **kwargs):
keep_attrs = kwargs.pop("keep_attrs", None)
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=True)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", r"All-NaN (slice|axis) encountered")
warnings.filterwarnings(
"ignore", r"Mean of empty slice", category=RuntimeWarning
)
with np.errstate(all="ignore"):
return self.__array_wrap__(f(self.variable.data, *args, **kwargs))
da = self.__array_wrap__(f(self.variable.data, *args, **kwargs))
if keep_attrs:
da.attrs = self.attrs
return da

return func

Expand Down
18 changes: 13 additions & 5 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4403,12 +4403,15 @@ def map(
foo (dim_0, dim_1) float64 1.764 0.4002 0.9787 2.241 1.868 0.9773
bar (x) float64 1.0 2.0
"""
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)
variables = {
k: maybe_wrap_array(v, func(v, *args, **kwargs))
for k, v in self.data_vars.items()
}
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)
if keep_attrs:
for k, v in variables.items():
v._copy_attrs_from(self.data_vars[k])
attrs = self.attrs if keep_attrs else None
return type(self)(variables, attrs=attrs)

Expand Down Expand Up @@ -4939,15 +4942,20 @@ def from_dict(cls, d):
return obj

@staticmethod
def _unary_op(f, keep_attrs=False):
def _unary_op(f):
@functools.wraps(f)
def func(self, *args, **kwargs):
variables = {}
keep_attrs = kwargs.pop("keep_attrs", None)
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=True)
for k, v in self._variables.items():
if k in self._coord_names:
variables[k] = v
else:
variables[k] = f(v, *args, **kwargs)
if keep_attrs:
variables[k].attrs = v._attrs
attrs = self._attrs if keep_attrs else None
return self._replace_with_new_dims(variables, attrs=attrs)

Expand Down Expand Up @@ -5684,11 +5692,11 @@ def _integrate_one(self, coord, datetime_unit=None):

@property
def real(self):
return self._unary_op(lambda x: x.real, keep_attrs=True)(self)
return self.map(lambda x: x.real, keep_attrs=True)

@property
def imag(self):
return self._unary_op(lambda x: x.imag, keep_attrs=True)(self)
return self.map(lambda x: x.imag, keep_attrs=True)

plot = utils.UncachedAccessor(_Dataset_PlotMethods)

Expand Down
8 changes: 7 additions & 1 deletion xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -2102,8 +2102,14 @@ def __array_wrap__(self, obj, context=None):
def _unary_op(f):
@functools.wraps(f)
def func(self, *args, **kwargs):
keep_attrs = kwargs.pop("keep_attrs", None)
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=True)
with np.errstate(all="ignore"):
return self.__array_wrap__(f(self.data, *args, **kwargs))
result = self.__array_wrap__(f(self.data, *args, **kwargs))
if keep_attrs:
result.attrs = self.attrs
return result

return func

Expand Down
25 changes: 24 additions & 1 deletion xarray/tests/test_dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,15 @@
import pytest

import xarray as xr
from xarray import DataArray, Dataset, IndexVariable, Variable, align, broadcast
from xarray import (
DataArray,
Dataset,
IndexVariable,
Variable,
align,
broadcast,
set_options,
)
from xarray.coding.times import CFDatetimeCoder
from xarray.convert import from_cdms2
from xarray.core import dtypes
Expand Down Expand Up @@ -2486,6 +2494,21 @@ def test_assign_attrs(self):
assert_identical(new_actual, expected)
assert actual.attrs == {"a": 1, "b": 2}

@pytest.mark.parametrize(
"func", [lambda x: x.clip(0, 1), lambda x: np.float64(1.0) * x, np.abs, abs]
)
def test_propagate_attrs(self, func):
da = DataArray(self.va)

# test defaults
assert func(da).attrs == da.attrs

with set_options(keep_attrs=False):
assert func(da).attrs == {}

with set_options(keep_attrs=True):
assert func(da).attrs == da.attrs

def test_fillna(self):
a = DataArray([np.nan, 1, np.nan, 3], coords={"x": range(4)}, dims="x")
actual = a.fillna(-1)
Expand Down
22 changes: 22 additions & 0 deletions xarray/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4473,6 +4473,28 @@ def test_fillna(self):
assert actual.a.name == "a"
assert actual.a.attrs == ds.a.attrs

@pytest.mark.parametrize(
"func", [lambda x: x.clip(0, 1), lambda x: np.float64(1.0) * x, np.abs, abs]
)
def test_propagate_attrs(self, func):

da = DataArray(range(5), name="a", attrs={"attr": "da"})
ds = Dataset({"a": da}, attrs={"attr": "ds"})

# test defaults
assert func(ds).attrs == ds.attrs
with set_options(keep_attrs=False):
assert func(ds).attrs != ds.attrs
assert func(ds).a.attrs != ds.a.attrs

with set_options(keep_attrs=False):
assert func(ds).attrs != ds.attrs
assert func(ds).a.attrs != ds.a.attrs

with set_options(keep_attrs=True):
assert func(ds).attrs == ds.attrs
assert func(ds).a.attrs == ds.a.attrs

def test_where(self):
ds = Dataset({"a": ("x", range(5))})
expected = Dataset({"a": ("x", [np.nan, np.nan, 2, 3, 4])})
Expand Down
3 changes: 2 additions & 1 deletion xarray/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,8 @@ def test_1d_math(self):
assert_array_equal(y - v, 1 - v)
# verify attributes are dropped
v2 = self.cls(["x"], x, {"units": "meters"})
assert_identical(base_v, +v2)
with set_options(keep_attrs=False):
assert_identical(base_v, +v2)
# binary ops with all variables
assert_array_equal(v + v, 2 * v)
w = self.cls(["x"], y, {"foo": "bar"})
Expand Down
1 change: 0 additions & 1 deletion xarray/tests/test_weighted.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,6 @@ def test_weighted_operations_keep_attr(operation, as_dataset, keep_attrs):
assert not result.attrs


@pytest.mark.xfail(reason="xr.Dataset.map does not copy attrs of DataArrays GH: 3595")
@pytest.mark.parametrize("operation", ("sum", "mean"))
def test_weighted_operations_keep_attr_da_in_ds(operation):
# GH #3595
Expand Down

0 comments on commit db4f03e

Please sign in to comment.