From 1dd6b3583bec0f049916fdc83ef3c045feab9556 Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Sun, 12 Jul 2020 20:59:56 +0900 Subject: [PATCH 01/23] nd-rolling --- xarray/core/dask_array_ops.py | 8 +++ xarray/core/nputils.py | 22 ++++-- xarray/core/rolling.py | 121 +++++++++++++++++++++------------ xarray/core/variable.py | 21 ++++-- xarray/tests/test_dataarray.py | 15 +++- xarray/tests/test_dataset.py | 1 + xarray/tests/test_nputils.py | 20 ++++++ 7 files changed, 150 insertions(+), 58 deletions(-) diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 87f646352eb..c94b8f1c380 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -32,6 +32,14 @@ def rolling_window(a, axis, window, center, fill_value): """ import dask.array as da + # for nd-rolling. + # TODO It can be more efficient. Currently, the chunks at the boundaries + # will be copied, but it might be OK for many-chunked-arrays. + if hasattr(axis, '__len__'): + for ax, win, cen in zip(axis, window, center): + a = rolling_window(a, ax, win, cen, fill_value) + return a + orig_shape = a.shape if axis < 0: axis = a.ndim + axis diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index fa6df63e0ea..f77e005e5dc 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -135,14 +135,22 @@ def __setitem__(self, key, value): def rolling_window(a, axis, window, center, fill_value): """ rolling window with padding. """ pads = [(0, 0) for s in a.shape] - if center: - start = int(window / 2) # 10 -> 5, 9 -> 4 - end = window - 1 - start - pads[axis] = (start, end) - else: - pads[axis] = (window - 1, 0) + if not hasattr(axis, '__len__'): + axis = [axis] + window = [window] + center = [center] + + for ax, win, cent in zip(axis, window, center): + if cent: + start = int(win / 2) # 10 -> 5, 9 -> 4 + end = win - 1 - start + pads[ax] = (start, end) + else: + pads[ax] = (win - 1, 0) a = np.pad(a, pads, mode="constant", constant_values=fill_value) - return _rolling_window(a, window, axis) + for ax, win in zip(axis, window): + a = _rolling_window(a, win, ax) + return a def _rolling_window(a, window, axis=-1): diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index ecba5307680..86c0314aae4 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -75,21 +75,27 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None ------- rolling : type of input argument """ - if len(windows) != 1: - raise ValueError("exactly one dim/window should be provided") + dim = list(windows.keys()) + window = list(windows.values()) - dim, window = next(iter(windows.items())) - - if window <= 0: + if any([w <= 0 for w in window]): raise ValueError("window must be > 0") + if center is None or isinstance(center, bool): + center = [center] * len(dim) + + # TODO support nd-min_periods + if hasattr(min_periods, '__len__'): + raise NotImplementedError('multiple min_periods is not yet supported.') + min_periods = [min_periods] * len(dim) + self.obj = obj # attributes self.window = window - if min_periods is not None and min_periods <= 0: + if any(mp is not None and mp <= 0 for mp in min_periods): raise ValueError("min_periods must be greater than zero or None") - self.min_periods = min_periods + self.min_periods = [w if mp is None else mp for mp, w in zip(min_periods, window)] self.center = center self.dim = dim @@ -98,17 +104,13 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None keep_attrs = _get_keep_attrs(default=False) self.keep_attrs = keep_attrs - @property - def _min_periods(self): - return self.min_periods if self.min_periods is not None else self.window - def __repr__(self): """provide a nice str repr of our rolling object""" attrs = [ "{k}->{v}".format(k=k, v=getattr(self, k)) - for k in self._attributes - if getattr(self, k, None) is not None + for k in + list(self.dim) + list(self.window) + list(self.center) + list(self.min_periods) ] return "{klass} [{attrs}]".format( klass=self.__class__.__name__, attrs=",".join(attrs) @@ -142,8 +144,10 @@ def method(self, **kwargs): median = _reduce_method("median") def count(self): + if len(self.dim) > 1: + raise NotImplementedError('count is not implemented for nd-rolling.') rolling_count = self._counts() - enough_periods = rolling_count >= self._min_periods + enough_periods = rolling_count >= self.min_periods[0] return rolling_count.where(enough_periods) count.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name="count") @@ -195,18 +199,21 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None super().__init__( obj, windows, min_periods=min_periods, center=center, keep_attrs=keep_attrs ) - - self.window_labels = self.obj[self.dim] + + # TODO legacy attribute + self.window_labels = self.obj[self.dim[0]] def __iter__(self): + if len(self.dim) > 1: + raise ValueError('__iter__ is only supported for 1d-rolling') stops = np.arange(1, len(self.window_labels) + 1) - starts = stops - int(self.window) - starts[: int(self.window)] = 0 + starts = stops - int(self.window[0]) + starts[: int(self.window[0])] = 0 for (label, start, stop) in zip(self.window_labels, starts, stops): - window = self.obj.isel(**{self.dim: slice(start, stop)}) + window = self.obj.isel(**{self.dim[0]: slice(start, stop)}) - counts = window.count(dim=self.dim) - window = window.where(counts >= self._min_periods) + counts = window.count(dim=self.dim[0]) + window = window.where(counts >= self.min_periods[0]) yield (label, window) @@ -250,14 +257,18 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA): """ from .dataarray import DataArray - + + if len(self.dim) == 1 and not isinstance(window_dim, list): + window_dim = [window_dim] + if isinstance(stride, int): + stride = [stride] * len(self.dim) window = self.obj.variable.rolling_window( self.dim, self.window, window_dim, self.center, fill_value=fill_value ) result = DataArray( - window, dims=self.obj.dims + (window_dim,), coords=self.obj.coords + window, dims=self.obj.dims + tuple(window_dim), coords=self.obj.coords ) - return result.isel(**{self.dim: slice(None, None, stride)}) + return result.isel(**{d: slice(None, None, s) for d, s in zip(self.dim, stride)}) def reduce(self, func, **kwargs): """Reduce the items in this group by applying `func` along some @@ -300,25 +311,29 @@ def reduce(self, func, **kwargs): [ 4., 9., 15., 18.]]) """ - rolling_dim = utils.get_temp_dimname(self.obj.dims, "_rolling_dim") + rolling_dim = [ + utils.get_temp_dimname(self.obj.dims, "_rolling_dim_{}".format(d)) + for d in self.dim] windows = self.construct(rolling_dim) result = windows.reduce(func, dim=rolling_dim, **kwargs) # Find valid windows based on count. counts = self._counts() - return result.where(counts >= self._min_periods) + return result.where(counts >= self.min_periods[0]) def _counts(self): """ Number of non-nan entries in each rolling window. """ - rolling_dim = utils.get_temp_dimname(self.obj.dims, "_rolling_dim") + rolling_dim = [ + utils.get_temp_dimname(self.obj.dims, "_rolling_dim_{}".format(d)) + for d in self.dim] # We use False as the fill_value instead of np.nan, since boolean # array is faster to be reduced than object array. # The use of skipna==False is also faster since it does not need to # copy the strided array. counts = ( self.obj.notnull() - .rolling(center=self.center, **{self.dim: self.window}) + .rolling(center=self.center, **{d: w for d, w in zip(self.dim, self.window)}) .construct(rolling_dim, fill_value=False) .sum(dim=rolling_dim, skipna=False) ) @@ -329,39 +344,40 @@ def _bottleneck_reduce(self, func, **kwargs): # bottleneck doesn't allow min_count to be 0, although it should # work the same as if min_count = 1 - if self.min_periods is not None and self.min_periods == 0: + # Note bottleneck only works with 1d-rolling. + if self.min_periods[0] is not None and self.min_periods[0] == 0: min_count = 1 else: - min_count = self.min_periods + min_count = self.min_periods[0] - axis = self.obj.get_axis_num(self.dim) + axis = self.obj.get_axis_num(self.dim[0]) padded = self.obj.variable - if self.center: + if self.center[0]: if isinstance(padded.data, dask_array_type): # Workaround to make the padded chunk size is larger than # self.window-1 - shift = -(self.window + 1) // 2 - offset = (self.window - 1) // 2 + shift = -(self.window[0] + 1) // 2 + offset = (self.window[0] - 1) // 2 valid = (slice(None),) * axis + ( slice(offset, offset + self.obj.shape[axis]), ) else: - shift = (-self.window // 2) + 1 + shift = (-self.window[0] // 2) + 1 valid = (slice(None),) * axis + (slice(-shift, None),) - padded = padded.pad({self.dim: (0, -shift)}, mode="constant") + padded = padded.pad({self.dim[0]: (0, -shift)}, mode="constant") if isinstance(padded.data, dask_array_type): raise AssertionError("should not be reachable") values = dask_rolling_wrapper( - func, padded.data, window=self.window, min_count=min_count, axis=axis + func, padded.data, window=self.window[0], min_count=min_count, axis=axis ) else: values = func( - padded.data, window=self.window, min_count=min_count, axis=axis + padded.data, window=self.window[0], min_count=min_count, axis=axis ) - if self.center: + if self.center[0]: values = values[valid] result = DataArray(values, self.obj.coords) @@ -380,7 +396,7 @@ def _numpy_or_bottleneck_reduce( if bottleneck_move_func is not None and not isinstance( self.obj.data, dask_array_type - ): + ) and len(self.dim) == 1: # TODO: renable bottleneck with dask after the issues # underlying https://github.com/pydata/xarray/issues/2940 are # fixed. @@ -431,13 +447,19 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None DataArray.groupby """ super().__init__(obj, windows, min_periods, center, keep_attrs) - if self.dim not in self.obj.dims: + if any(d not in self.obj.dims for d in self.dim): raise KeyError(self.dim) # Keep each Rolling object as a dictionary self.rollings = {} for key, da in self.obj.data_vars.items(): # keeps rollings only for the dataset depending on slf.dim - if self.dim in da.dims: + dims, center = [], [] + for i, d in enumerate(self.dim): + if d in da.dims: + dims.append(d) + center.append(self.center[i]) + + if len(dims) > 0: self.rollings[key] = DataArrayRolling( da, windows, min_periods, center, keep_attrs ) @@ -447,7 +469,7 @@ def _dataset_implementation(self, func, **kwargs): reduced = {} for key, da in self.obj.data_vars.items(): - if self.dim in da.dims: + if any(d in da.dims for d in self.dim): reduced[key] = func(self.rollings[key], **kwargs) else: reduced[key] = self.obj[key] @@ -511,20 +533,29 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA, keep_attrs=None) """ from .dataset import Dataset + if isinstance(stride, int): + stride = [stride] * len(self.dim) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) dataset = {} for key, da in self.obj.data_vars.items(): - if self.dim in da.dims: + # keeps rollings only for the dataset depending on slf.dim + dims, center = [], [] + for i, d in enumerate(self.dim): + if d in da.dims: + dims.append(d) + center.append(self.center[i]) + + if len(dims) > 0: dataset[key] = self.rollings[key].construct( window_dim, fill_value=fill_value ) else: dataset[key] = da return Dataset(dataset, coords=self.obj.coords).isel( - **{self.dim: slice(None, None, stride)} + **{d: slice(None, None, s) for d, s in zip(self.dim, stride)} ) diff --git a/xarray/core/variable.py b/xarray/core/variable.py index c505c749557..37f4441672b 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1881,11 +1881,14 @@ def rolling_window( Parameters ---------- dim: str - Dimension over which to compute rolling_window + Dimension over which to compute rolling_window. + For nd-rolling, should be list of dimensions. window: int Window size of the rolling + For nd-rolling, should be list of integers. window_dim: str New name of the window dimension. + For nd-rolling, should be list of integers. center: boolean. default False. If True, pad fill_value for both ends. Otherwise, pad in the head of the axis. @@ -1918,13 +1921,23 @@ def rolling_window( else: dtype = self.dtype array = self.data - - new_dims = self.dims + (window_dim,) + + if isinstance(dim, list): + assert len(dim) == len(window) + assert len(dim) == len(window_dim) + assert len(dim) == len(center) + else: + dim = [dim] + window = [window] + window_dim = [window_dim] + center = [center] + axis = [self.get_axis_num(d) for d in dim] + new_dims = self.dims + tuple(window_dim) return Variable( new_dims, duck_array_ops.rolling_window( array, - axis=self.get_axis_num(dim), + axis=axis, window=window, center=center, fill_value=fill_value, diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 793090cc122..56cbacafd2a 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -6193,8 +6193,6 @@ def test_rolling_properties(da): assert rolling_obj.obj.get_axis_num("time") == 1 # catching invalid args - with pytest.raises(ValueError, match="exactly one dim/window should"): - da.rolling(time=7, x=2) with pytest.raises(ValueError, match="window must be > 0"): da.rolling(time=-2) with pytest.raises(ValueError, match="min_periods must be greater than zero"): @@ -6399,6 +6397,19 @@ def test_rolling_count_correct(): assert_equal(result, expected) +@pytest.mark.parametrize("da", (1, ), indirect=True) +@pytest.mark.parametrize("center", (True, False)) +@pytest.mark.parametrize("min_periods", (None, 1, )) +def test_ndrolling_reduce(da, center, min_periods): + rolling_obj = da.rolling(time=3, a=2, center=center, min_periods=min_periods) + + # add nan prefix to numpy methods to get similar # behavior as bottleneck + actual = rolling_obj.reduce(np.nansum) + expected = rolling_obj.sum() + assert_allclose(actual, expected) + assert actual.dims == expected.dims + + def test_raise_no_warning_for_nan_in_binary_ops(): with pytest.warns(None) as record: xr.DataArray([1, 2, np.NaN]) > 0 diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 9037013cc79..3547bdf71c6 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -5962,6 +5962,7 @@ def test_rolling_construct(center, window): df_rolling = df.rolling(window, center=center, min_periods=1).mean() ds_rolling = ds.rolling(index=window, center=center) + print(ds_rolling.construct("window")) ds_rolling_mean = ds_rolling.construct("window").mean("window") np.testing.assert_allclose(df_rolling["x"].values, ds_rolling_mean["x"].values) np.testing.assert_allclose(df_rolling.index, ds_rolling_mean["index"]) diff --git a/xarray/tests/test_nputils.py b/xarray/tests/test_nputils.py index 1002a9dd9e3..9d0bf7c6f8c 100644 --- a/xarray/tests/test_nputils.py +++ b/xarray/tests/test_nputils.py @@ -1,5 +1,6 @@ import numpy as np from numpy.testing import assert_array_equal +import pytest from xarray.core.nputils import NumpyVIndexAdapter, _is_contiguous, rolling_window @@ -47,3 +48,22 @@ def test_rolling(): actual = rolling_window(x, axis=-1, window=3, center=False, fill_value=0.0) expected = np.stack([expected, expected * 1.1], axis=0) assert_array_equal(actual, expected) + + +@pytest.mark.parametrize('center', [[True, True], [False, False]]) +@pytest.mark.parametrize('axis', [(0, 1), (1, 2), (2, 0)]) +def test_nd_rolling(center, axis): + x = np.arange(7 * 6 * 8).reshape(7, 6, 8).astype(float) + window = [3, 3] + actual = rolling_window( + x, axis=axis, window=window, + center=center, fill_value=np.nan) + expected = x + for ax, win, cent in zip(axis, window, center): + expected = rolling_window( + expected, axis=ax, window=win, center=cent, fill_value=np.nan) + assert_array_equal(actual, expected) + + + + From 62854760b17e99dcd452ff853b8ec6c808e7ce92 Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Sun, 12 Jul 2020 21:01:38 +0900 Subject: [PATCH 02/23] remove unnecessary print --- xarray/tests/test_dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 3547bdf71c6..9037013cc79 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -5962,7 +5962,6 @@ def test_rolling_construct(center, window): df_rolling = df.rolling(window, center=center, min_periods=1).mean() ds_rolling = ds.rolling(index=window, center=center) - print(ds_rolling.construct("window")) ds_rolling_mean = ds_rolling.construct("window").mean("window") np.testing.assert_allclose(df_rolling["x"].values, ds_rolling_mean["x"].values) np.testing.assert_allclose(df_rolling.index, ds_rolling_mean["index"]) From dc113dc688d3921d426bfb1ec4f0cd184c39690a Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Sun, 12 Jul 2020 21:04:33 +0900 Subject: [PATCH 03/23] black --- xarray/core/dask_array_ops.py | 4 +- xarray/core/nputils.py | 2 +- xarray/core/rolling.py | 51 +++++++++------ xarray/core/variable.py | 8 +-- xarray/tests/test_dataarray.py | 4 +- xarray/tests/test_nputils.py | 15 ++--- xarray/tests/test_testing.py | 2 +- xarray/tests/test_units.py | 114 ++++++++++----------------------- 8 files changed, 79 insertions(+), 121 deletions(-) diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index c94b8f1c380..a5f8441f59a 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -32,10 +32,10 @@ def rolling_window(a, axis, window, center, fill_value): """ import dask.array as da - # for nd-rolling. + # for nd-rolling. # TODO It can be more efficient. Currently, the chunks at the boundaries # will be copied, but it might be OK for many-chunked-arrays. - if hasattr(axis, '__len__'): + if hasattr(axis, "__len__"): for ax, win, cen in zip(axis, window, center): a = rolling_window(a, ax, win, cen, fill_value) return a diff --git a/xarray/core/nputils.py b/xarray/core/nputils.py index f77e005e5dc..4f592eb3c5c 100644 --- a/xarray/core/nputils.py +++ b/xarray/core/nputils.py @@ -135,7 +135,7 @@ def __setitem__(self, key, value): def rolling_window(a, axis, window, center, fill_value): """ rolling window with padding. """ pads = [(0, 0) for s in a.shape] - if not hasattr(axis, '__len__'): + if not hasattr(axis, "__len__"): axis = [axis] window = [window] center = [center] diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 86c0314aae4..fca0460a348 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -83,10 +83,10 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None if center is None or isinstance(center, bool): center = [center] * len(dim) - + # TODO support nd-min_periods - if hasattr(min_periods, '__len__'): - raise NotImplementedError('multiple min_periods is not yet supported.') + if hasattr(min_periods, "__len__"): + raise NotImplementedError("multiple min_periods is not yet supported.") min_periods = [min_periods] * len(dim) self.obj = obj @@ -95,7 +95,9 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None self.window = window if any(mp is not None and mp <= 0 for mp in min_periods): raise ValueError("min_periods must be greater than zero or None") - self.min_periods = [w if mp is None else mp for mp, w in zip(min_periods, window)] + self.min_periods = [ + w if mp is None else mp for mp, w in zip(min_periods, window) + ] self.center = center self.dim = dim @@ -109,8 +111,10 @@ def __repr__(self): attrs = [ "{k}->{v}".format(k=k, v=getattr(self, k)) - for k in - list(self.dim) + list(self.window) + list(self.center) + list(self.min_periods) + for k in list(self.dim) + + list(self.window) + + list(self.center) + + list(self.min_periods) ] return "{klass} [{attrs}]".format( klass=self.__class__.__name__, attrs=",".join(attrs) @@ -145,7 +149,7 @@ def method(self, **kwargs): def count(self): if len(self.dim) > 1: - raise NotImplementedError('count is not implemented for nd-rolling.') + raise NotImplementedError("count is not implemented for nd-rolling.") rolling_count = self._counts() enough_periods = rolling_count >= self.min_periods[0] return rolling_count.where(enough_periods) @@ -199,13 +203,13 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None super().__init__( obj, windows, min_periods=min_periods, center=center, keep_attrs=keep_attrs ) - + # TODO legacy attribute self.window_labels = self.obj[self.dim[0]] def __iter__(self): if len(self.dim) > 1: - raise ValueError('__iter__ is only supported for 1d-rolling') + raise ValueError("__iter__ is only supported for 1d-rolling") stops = np.arange(1, len(self.window_labels) + 1) starts = stops - int(self.window[0]) starts[: int(self.window[0])] = 0 @@ -257,7 +261,7 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA): """ from .dataarray import DataArray - + if len(self.dim) == 1 and not isinstance(window_dim, list): window_dim = [window_dim] if isinstance(stride, int): @@ -268,7 +272,9 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA): result = DataArray( window, dims=self.obj.dims + tuple(window_dim), coords=self.obj.coords ) - return result.isel(**{d: slice(None, None, s) for d, s in zip(self.dim, stride)}) + return result.isel( + **{d: slice(None, None, s) for d, s in zip(self.dim, stride)} + ) def reduce(self, func, **kwargs): """Reduce the items in this group by applying `func` along some @@ -313,7 +319,8 @@ def reduce(self, func, **kwargs): """ rolling_dim = [ utils.get_temp_dimname(self.obj.dims, "_rolling_dim_{}".format(d)) - for d in self.dim] + for d in self.dim + ] windows = self.construct(rolling_dim) result = windows.reduce(func, dim=rolling_dim, **kwargs) @@ -326,14 +333,17 @@ def _counts(self): rolling_dim = [ utils.get_temp_dimname(self.obj.dims, "_rolling_dim_{}".format(d)) - for d in self.dim] + for d in self.dim + ] # We use False as the fill_value instead of np.nan, since boolean # array is faster to be reduced than object array. # The use of skipna==False is also faster since it does not need to # copy the strided array. counts = ( self.obj.notnull() - .rolling(center=self.center, **{d: w for d, w in zip(self.dim, self.window)}) + .rolling( + center=self.center, **{d: w for d, w in zip(self.dim, self.window)} + ) .construct(rolling_dim, fill_value=False) .sum(dim=rolling_dim, skipna=False) ) @@ -394,9 +404,11 @@ def _numpy_or_bottleneck_reduce( ) del kwargs["dim"] - if bottleneck_move_func is not None and not isinstance( - self.obj.data, dask_array_type - ) and len(self.dim) == 1: + if ( + bottleneck_move_func is not None + and not isinstance(self.obj.data, dask_array_type) + and len(self.dim) == 1 + ): # TODO: renable bottleneck with dask after the issues # underlying https://github.com/pydata/xarray/issues/2940 are # fixed. @@ -458,7 +470,7 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None if d in da.dims: dims.append(d) center.append(self.center[i]) - + if len(dims) > 0: self.rollings[key] = DataArrayRolling( da, windows, min_periods, center, keep_attrs @@ -533,6 +545,7 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA, keep_attrs=None) """ from .dataset import Dataset + if isinstance(stride, int): stride = [stride] * len(self.dim) @@ -547,7 +560,7 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA, keep_attrs=None) if d in da.dims: dims.append(d) center.append(self.center[i]) - + if len(dims) > 0: dataset[key] = self.rollings[key].construct( window_dim, fill_value=fill_value diff --git a/xarray/core/variable.py b/xarray/core/variable.py index 37f4441672b..330b522cc1a 100644 --- a/xarray/core/variable.py +++ b/xarray/core/variable.py @@ -1921,7 +1921,7 @@ def rolling_window( else: dtype = self.dtype array = self.data - + if isinstance(dim, list): assert len(dim) == len(window) assert len(dim) == len(window_dim) @@ -1936,11 +1936,7 @@ def rolling_window( return Variable( new_dims, duck_array_ops.rolling_window( - array, - axis=axis, - window=window, - center=center, - fill_value=fill_value, + array, axis=axis, window=window, center=center, fill_value=fill_value ), ) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 56cbacafd2a..ddaf144ddc1 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -6397,9 +6397,9 @@ def test_rolling_count_correct(): assert_equal(result, expected) -@pytest.mark.parametrize("da", (1, ), indirect=True) +@pytest.mark.parametrize("da", (1,), indirect=True) @pytest.mark.parametrize("center", (True, False)) -@pytest.mark.parametrize("min_periods", (None, 1, )) +@pytest.mark.parametrize("min_periods", (None, 1)) def test_ndrolling_reduce(da, center, min_periods): rolling_obj = da.rolling(time=3, a=2, center=center, min_periods=min_periods) diff --git a/xarray/tests/test_nputils.py b/xarray/tests/test_nputils.py index 9d0bf7c6f8c..de8595c0f81 100644 --- a/xarray/tests/test_nputils.py +++ b/xarray/tests/test_nputils.py @@ -50,20 +50,17 @@ def test_rolling(): assert_array_equal(actual, expected) -@pytest.mark.parametrize('center', [[True, True], [False, False]]) -@pytest.mark.parametrize('axis', [(0, 1), (1, 2), (2, 0)]) +@pytest.mark.parametrize("center", [[True, True], [False, False]]) +@pytest.mark.parametrize("axis", [(0, 1), (1, 2), (2, 0)]) def test_nd_rolling(center, axis): x = np.arange(7 * 6 * 8).reshape(7, 6, 8).astype(float) window = [3, 3] actual = rolling_window( - x, axis=axis, window=window, - center=center, fill_value=np.nan) + x, axis=axis, window=window, center=center, fill_value=np.nan + ) expected = x for ax, win, cent in zip(axis, window, center): expected = rolling_window( - expected, axis=ax, window=win, center=cent, fill_value=np.nan) + expected, axis=ax, window=win, center=cent, fill_value=np.nan + ) assert_array_equal(actual, expected) - - - - diff --git a/xarray/tests/test_testing.py b/xarray/tests/test_testing.py index 39ad250246b..adc29a3cc92 100644 --- a/xarray/tests/test_testing.py +++ b/xarray/tests/test_testing.py @@ -37,7 +37,7 @@ def test_allclose_regression(): "obj1,obj2", ( pytest.param( - xr.Variable("x", [1e-17, 2]), xr.Variable("x", [0, 3]), id="Variable", + xr.Variable("x", [1e-17, 2]), xr.Variable("x", [0, 3]), id="Variable" ), pytest.param( xr.DataArray([1e-17, 2], dims="x"), diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 619fa10116d..e8dd172ba02 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -624,7 +624,7 @@ def test_align_dataset(value, unit, variant, error, dtype): units_a = extract_units(ds1) units_b = extract_units(ds2) expected_a, expected_b = func( - strip_units(ds1), strip_units(convert_units(ds2, units_a)), **stripped_kwargs, + strip_units(ds1), strip_units(convert_units(ds2, units_a)), **stripped_kwargs ) expected_a = attach_units(expected_a, units_a) if isinstance(array2, Quantity): @@ -1223,11 +1223,7 @@ def test_merge_dataset(variant, unit, error, dtype): def test_replication_dataarray(func, variant, dtype): unit = unit_registry.m - variants = { - "data": (unit, 1, 1), - "dims": (1, unit, 1), - "coords": (1, 1, unit), - } + variants = {"data": (unit, 1, 1), "dims": (1, unit, 1), "coords": (1, 1, unit)} data_unit, dim_unit, coord_unit = variants.get(variant) array = np.linspace(0, 10, 20).astype(dtype) * data_unit @@ -1308,11 +1304,7 @@ def test_replication_full_like_dataarray(variant, dtype): # fill value, we don't need to try multiple units unit = unit_registry.m - variants = { - "data": (unit, 1, 1), - "dims": (1, unit, 1), - "coords": (1, 1, unit), - } + variants = {"data": (unit, 1, 1), "dims": (1, unit, 1), "coords": (1, 1, unit)} data_unit, dim_unit, coord_unit = variants.get(variant) array = np.linspace(0, 5, 10) * data_unit @@ -1370,10 +1362,7 @@ def test_replication_full_like_dataset(variant, dtype): fill_value = -1 * unit_registry.degK - units = { - **extract_units(ds), - **{name: unit_registry.degK for name in ds.data_vars}, - } + units = {**extract_units(ds), **{name: unit_registry.degK for name in ds.data_vars}} expected = attach_units( xr.full_like(strip_units(ds), fill_value=strip_units(fill_value)), units ) @@ -1735,7 +1724,7 @@ def test_missing_value_fillna(self, unit, error): pytest.param(1, id="no_unit"), pytest.param(unit_registry.dimensionless, id="dimensionless"), pytest.param(unit_registry.s, id="incompatible_unit"), - pytest.param(unit_registry.cm, id="compatible_unit",), + pytest.param(unit_registry.cm, id="compatible_unit"), pytest.param(unit_registry.m, id="identical_unit"), ), ) @@ -2186,7 +2175,7 @@ def test_pad(self, mode, xr_arg, np_arg): v = xr.Variable(["x", "y", "z"], data) expected = attach_units( - strip_units(v).pad(mode=mode, **xr_arg), extract_units(v), + strip_units(v).pad(mode=mode, **xr_arg), extract_units(v) ) actual = v.pad(mode=mode, **xr_arg) @@ -2424,7 +2413,7 @@ def test_binary_operations(self, func, dtype): id="equal", marks=pytest.mark.xfail( # LooseVersion(pint.__version__) < "0.14", - reason="inconsistencies in the return values of pint's eq", + reason="inconsistencies in the return values of pint's eq" ), ), ), @@ -2918,8 +2907,8 @@ def test_interpolate_na(self): unit_registry.dimensionless, DimensionalityError, id="dimensionless" ), pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), - pytest.param(unit_registry.cm, None, id="compatible_unit",), - pytest.param(unit_registry.m, None, id="identical_unit",), + pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.m, None, id="identical_unit"), ), ) def test_combine_first(self, unit, error, dtype): @@ -3167,11 +3156,7 @@ def test_pad(self, dtype): def test_content_manipulation(self, func, variant, dtype): unit = unit_registry.m - variants = { - "data": (unit, 1, 1), - "dims": (1, unit, 1), - "coords": (1, 1, unit), - } + variants = {"data": (unit, 1, 1), "dims": (1, unit, 1), "coords": (1, 1, unit)} data_unit, dim_unit, coord_unit = variants.get(variant) quantity = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit @@ -3437,10 +3422,7 @@ def test_head_tail_thin(self, func, dtype): ids=repr, ) def test_interp_reindex(self, variant, func, dtype): - variants = { - "data": (unit_registry.m, 1), - "coords": (1, unit_registry.m), - } + variants = {"data": (unit_registry.m, 1), "coords": (1, unit_registry.m)} data_unit, coord_unit = variants.get(variant) array = np.linspace(1, 2, 10).astype(dtype) * data_unit @@ -3470,9 +3452,7 @@ def test_interp_reindex(self, variant, func, dtype): pytest.param(unit_registry.m, None, id="identical_unit"), ), ) - @pytest.mark.parametrize( - "func", (method("interp"), method("reindex")), ids=repr, - ) + @pytest.mark.parametrize("func", (method("interp"), method("reindex")), ids=repr) def test_interp_reindex_indexing(self, func, unit, error, dtype): array = np.linspace(1, 2, 10).astype(dtype) x = np.arange(10) * unit_registry.m @@ -3510,10 +3490,7 @@ def test_interp_reindex_indexing(self, func, unit, error, dtype): ids=repr, ) def test_interp_reindex_like(self, variant, func, dtype): - variants = { - "data": (unit_registry.m, 1), - "coords": (1, unit_registry.m), - } + variants = {"data": (unit_registry.m, 1), "coords": (1, unit_registry.m)} data_unit, coord_unit = variants.get(variant) array = np.linspace(1, 2, 10).astype(dtype) * data_unit @@ -3545,7 +3522,7 @@ def test_interp_reindex_like(self, variant, func, dtype): ), ) @pytest.mark.parametrize( - "func", (method("interp_like"), method("reindex_like")), ids=repr, + "func", (method("interp_like"), method("reindex_like")), ids=repr ) def test_interp_reindex_like_indexing(self, func, unit, error, dtype): array = np.linspace(1, 2, 10).astype(dtype) @@ -3681,11 +3658,7 @@ def test_stacking_reordering(self, func, dtype): def test_computation(self, func, variant, dtype): unit = unit_registry.m - variants = { - "data": (unit, 1, 1), - "dims": (1, unit, 1), - "coords": (1, 1, unit), - } + variants = {"data": (unit, 1, 1), "dims": (1, unit, 1), "coords": (1, 1, unit)} data_unit, dim_unit, coord_unit = variants.get(variant) array = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit @@ -3745,11 +3718,7 @@ def test_computation(self, func, variant, dtype): def test_computation_objects(self, func, variant, dtype): unit = unit_registry.m - variants = { - "data": (unit, 1, 1), - "dims": (1, unit, 1), - "coords": (1, 1, unit), - } + variants = {"data": (unit, 1, 1), "dims": (1, unit, 1), "coords": (1, 1, unit)} data_unit, dim_unit, coord_unit = variants.get(variant) array = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit @@ -3808,11 +3777,7 @@ def test_resample(self, dtype): def test_grouped_operations(self, func, variant, dtype): unit = unit_registry.m - variants = { - "data": (unit, 1, 1), - "dims": (1, unit, 1), - "coords": (1, 1, unit), - } + variants = {"data": (unit, 1, 1), "dims": (1, unit, 1), "coords": (1, 1, unit)} data_unit, dim_unit, coord_unit = variants.get(variant) array = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit @@ -3927,7 +3892,7 @@ def test_init(self, shared, unit, error, dtype): ( "data", pytest.param( - "dims", marks=pytest.mark.xfail(reason="indexes don't support units"), + "dims", marks=pytest.mark.xfail(reason="indexes don't support units") ), "coords", ), @@ -3943,11 +3908,7 @@ def test_repr(self, func, variant, dtype): x = np.arange(len(array1)) * unit_registry.s y = x.to(unit_registry.ms) - variants = { - "dims": {"x": x}, - "coords": {"y": ("x", y)}, - "data": {}, - } + variants = {"dims": {"x": x}, "coords": {"y": ("x", y)}, "data": {}} ds = xr.Dataset( data_vars={"a": ("x", array1), "b": ("x", array2)}, @@ -4195,7 +4156,7 @@ def test_missing_value_filling(self, func, dtype): unit_registry.dimensionless, DimensionalityError, id="dimensionless" ), pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), - pytest.param(unit_registry.cm, None, id="compatible_unit",), + pytest.param(unit_registry.cm, None, id="compatible_unit"), pytest.param(unit_registry.m, None, id="identical_unit"), ), ) @@ -4340,7 +4301,7 @@ def test_where(self, variant, unit, error, dtype): for key, value in kwargs.items() } - expected = attach_units(strip_units(ds).where(**kwargs_without_units), units,) + expected = attach_units(strip_units(ds).where(**kwargs_without_units), units) actual = ds.where(**kwargs) assert_units_equal(expected, actual) @@ -4359,7 +4320,7 @@ def test_interpolate_na(self, dtype): ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}) units = extract_units(ds) - expected = attach_units(strip_units(ds).interpolate_na(dim="x"), units,) + expected = attach_units(strip_units(ds).interpolate_na(dim="x"), units) actual = ds.interpolate_na(dim="x") assert_units_equal(expected, actual) @@ -4382,7 +4343,7 @@ def test_interpolate_na(self, dtype): ( "data", pytest.param( - "dims", marks=pytest.mark.xfail(reason="indexes don't support units"), + "dims", marks=pytest.mark.xfail(reason="indexes don't support units") ), ), ) @@ -4401,7 +4362,7 @@ def test_combine_first(self, variant, unit, error, dtype): ) x = np.arange(len(array1)) * dims_unit ds = xr.Dataset( - data_vars={"a": ("x", array1), "b": ("x", array2)}, coords={"x": x}, + data_vars={"a": ("x", array1), "b": ("x", array2)}, coords={"x": x} ) units = extract_units(ds) @@ -4478,7 +4439,7 @@ def test_comparisons(self, func, variant, unit, dtype): y = coord * coord_unit ds = xr.Dataset( - data_vars={"a": ("x", a), "b": ("x", b)}, coords={"x": x, "y": ("x", y)}, + data_vars={"a": ("x", a), "b": ("x", b)}, coords={"x": x, "y": ("x", y)} ) units = extract_units(ds) @@ -4535,7 +4496,7 @@ def test_comparisons(self, func, variant, unit, dtype): ( "data", pytest.param( - "dims", marks=pytest.mark.xfail(reason="indexes don't support units"), + "dims", marks=pytest.mark.xfail(reason="indexes don't support units") ), ), ) @@ -4588,7 +4549,7 @@ def test_broadcast_equals(self, unit, dtype): right_array2 = np.zeros(shape=(3,)) * unit left = xr.Dataset( - {"a": (("x", "y"), left_array1), "b": (("y", "z"), left_array2)}, + {"a": (("x", "y"), left_array1), "b": (("y", "z"), left_array2)} ) right = xr.Dataset({"a": ("x", right_array1), "b": ("y", right_array2)}) @@ -4626,15 +4587,12 @@ def test_pad(self, dtype): ( "data", pytest.param( - "dims", marks=pytest.mark.xfail(reason="indexes don't support units"), + "dims", marks=pytest.mark.xfail(reason="indexes don't support units") ), ), ) def test_stacking_stacked(self, variant, func, dtype): - variants = { - "data": (unit_registry.m, 1), - "dims": (1, unit_registry.m), - } + variants = {"data": (unit_registry.m, 1), "dims": (1, unit_registry.m)} data_unit, dim_unit = variants.get(variant) array1 = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit @@ -4677,7 +4635,7 @@ def test_to_stacked_array(self, dtype): func = method("to_stacked_array", "z", variable_dim="y", sample_dims=["x"]) actual = func(ds).rename(None) - expected = attach_units(func(strip_units(ds)).rename(None), units,) + expected = attach_units(func(strip_units(ds)).rename(None), units) assert_units_equal(expected, actual) assert_equal(expected, actual) @@ -4983,7 +4941,7 @@ def test_squeeze(self, shape, dim, dtype): data_vars={ "a": (tuple(names[: len(shape)]), array1), "b": (tuple(names[: len(shape)]), array2), - }, + } ) units = extract_units(ds) @@ -5008,10 +4966,7 @@ def test_squeeze(self, shape, dim, dtype): ids=repr, ) def test_interp_reindex(self, func, variant, dtype): - variants = { - "data": (unit_registry.m, 1), - "coords": (1, unit_registry.m), - } + variants = {"data": (unit_registry.m, 1), "coords": (1, unit_registry.m)} data_unit, coord_unit = variants.get(variant) array1 = np.linspace(-1, 0, 10).astype(dtype) * data_unit @@ -5081,10 +5036,7 @@ def test_interp_reindex_indexing(self, func, unit, error, dtype): ids=repr, ) def test_interp_reindex_like(self, func, variant, dtype): - variants = { - "data": (unit_registry.m, 1), - "coords": (1, unit_registry.m), - } + variants = {"data": (unit_registry.m, 1), "coords": (1, unit_registry.m)} data_unit, coord_unit = variants.get(variant) array1 = np.linspace(-1, 0, 10).astype(dtype) * data_unit From 9a45fc78d7917a862d6a42db0779c3139fe68bf9 Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Sun, 12 Jul 2020 21:29:40 +0900 Subject: [PATCH 04/23] finding a bug... --- xarray/tests/test_dataarray.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index ddaf144ddc1..509bdf8d8e6 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -6401,11 +6401,14 @@ def test_rolling_count_correct(): @pytest.mark.parametrize("center", (True, False)) @pytest.mark.parametrize("min_periods", (None, 1)) def test_ndrolling_reduce(da, center, min_periods): - rolling_obj = da.rolling(time=3, a=2, center=center, min_periods=min_periods) + rolling_obj = da.rolling(time=3, x=2, center=center, min_periods=min_periods) # add nan prefix to numpy methods to get similar # behavior as bottleneck - actual = rolling_obj.reduce(np.nansum) - expected = rolling_obj.sum() + actual = rolling_obj.sum(skipna=False) + expected = ( + da.rolling(time=3, center=center, min_periods=min_periods).sum(skipna=False) + .rolling(x=2, center=center, min_periods=min_periods).sum(skipna=False)) + assert_allclose(actual, expected) assert actual.dims == expected.dims From f74b2e8ea9588e2932376fd09f9bfc02e270d34d Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Mon, 13 Jul 2020 06:01:31 +0900 Subject: [PATCH 05/23] make tests for ndrolling pass --- xarray/core/rolling.py | 26 +++++++++----------------- xarray/tests/test_dataarray.py | 7 +++---- 2 files changed, 12 insertions(+), 21 deletions(-) diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index fca0460a348..2d0d0c416a9 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -84,20 +84,14 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None if center is None or isinstance(center, bool): center = [center] * len(dim) - # TODO support nd-min_periods - if hasattr(min_periods, "__len__"): - raise NotImplementedError("multiple min_periods is not yet supported.") - min_periods = [min_periods] * len(dim) - self.obj = obj # attributes self.window = window - if any(mp is not None and mp <= 0 for mp in min_periods): + if min_periods is not None and min_periods <= 0: raise ValueError("min_periods must be greater than zero or None") - self.min_periods = [ - w if mp is None else mp for mp, w in zip(min_periods, window) - ] + + self.min_periods = np.prod(window) if min_periods is None else min_periods self.center = center self.dim = dim @@ -114,7 +108,7 @@ def __repr__(self): for k in list(self.dim) + list(self.window) + list(self.center) - + list(self.min_periods) + + [self.min_periods] ] return "{klass} [{attrs}]".format( klass=self.__class__.__name__, attrs=",".join(attrs) @@ -148,10 +142,8 @@ def method(self, **kwargs): median = _reduce_method("median") def count(self): - if len(self.dim) > 1: - raise NotImplementedError("count is not implemented for nd-rolling.") rolling_count = self._counts() - enough_periods = rolling_count >= self.min_periods[0] + enough_periods = rolling_count >= self.min_periods return rolling_count.where(enough_periods) count.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name="count") @@ -217,7 +209,7 @@ def __iter__(self): window = self.obj.isel(**{self.dim[0]: slice(start, stop)}) counts = window.count(dim=self.dim[0]) - window = window.where(counts >= self.min_periods[0]) + window = window.where(counts >= self.min_periods) yield (label, window) @@ -326,7 +318,7 @@ def reduce(self, func, **kwargs): # Find valid windows based on count. counts = self._counts() - return result.where(counts >= self.min_periods[0]) + return result.where(counts >= self.min_periods) def _counts(self): """ Number of non-nan entries in each rolling window. """ @@ -355,10 +347,10 @@ def _bottleneck_reduce(self, func, **kwargs): # bottleneck doesn't allow min_count to be 0, although it should # work the same as if min_count = 1 # Note bottleneck only works with 1d-rolling. - if self.min_periods[0] is not None and self.min_periods[0] == 0: + if self.min_periods is not None and self.min_periods == 0: min_count = 1 else: - min_count = self.min_periods[0] + min_count = self.min_periods axis = self.obj.get_axis_num(self.dim[0]) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 509bdf8d8e6..7125bfeada1 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -6403,11 +6403,10 @@ def test_rolling_count_correct(): def test_ndrolling_reduce(da, center, min_periods): rolling_obj = da.rolling(time=3, x=2, center=center, min_periods=min_periods) - # add nan prefix to numpy methods to get similar # behavior as bottleneck - actual = rolling_obj.sum(skipna=False) + actual = rolling_obj.sum() expected = ( - da.rolling(time=3, center=center, min_periods=min_periods).sum(skipna=False) - .rolling(x=2, center=center, min_periods=min_periods).sum(skipna=False)) + da.rolling(time=3, center=center, min_periods=min_periods).sum() + .rolling(x=2, center=center, min_periods=min_periods).sum()) assert_allclose(actual, expected) assert actual.dims == expected.dims From d4990d7c31f71685259ab16105ad4a9593a85e3b Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Mon, 13 Jul 2020 06:55:26 +0900 Subject: [PATCH 06/23] make center and window_dim a dict --- xarray/core/rolling.py | 81 ++++++++++++++++++++++++++---------------- 1 file changed, 50 insertions(+), 31 deletions(-) diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 2d0d0c416a9..41f0489d827 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -75,26 +75,28 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None ------- rolling : type of input argument """ - dim = list(windows.keys()) - window = list(windows.values()) - - if any([w <= 0 for w in window]): - raise ValueError("window must be > 0") - - if center is None or isinstance(center, bool): - center = [center] * len(dim) + self.dim, self.window = [], [] + for d, w in windows.items(): + self.dim.append(d) + if w <= 0: + raise ValueError("window must be > 0") + self.window.append(w) + + if utils.is_dict_like(center): + self.center = [center.get(d, False) for d in self.dim] + elif isinstance(center, bool) or center is None: + self.center = [center] * len(self.dim) + else: + raise ValueError('center should be boolean or a mapping. ' + 'Given {}'.format(center)) self.obj = obj # attributes - self.window = window if min_periods is not None and min_periods <= 0: raise ValueError("min_periods must be greater than zero or None") - self.min_periods = np.prod(window) if min_periods is None else min_periods - - self.center = center - self.dim = dim + self.min_periods = np.prod(self.window) if min_periods is None else min_periods if keep_attrs is None: keep_attrs = _get_keep_attrs(default=False) @@ -148,6 +150,11 @@ def count(self): count.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name="count") + def _dict_to_list(self, arg, default=None): + if utils.is_dict_like(arg): + return [arg.get(d, default) for d in self.dim] + else: # for single argument + return [arg] * len(self.dim) class DataArrayRolling(Rolling): __slots__ = ("window_labels",) @@ -213,7 +220,10 @@ def __iter__(self): yield (label, window) - def construct(self, window_dim, stride=1, fill_value=dtypes.NA): + def construct( + self, window_dim=None, stride=1, fill_value=dtypes.NA, + **window_dim_kwargs + ): """ Convert this rolling object to xr.DataArray, where the window dimension is stacked as a new dimension @@ -254,10 +264,22 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA): from .dataarray import DataArray - if len(self.dim) == 1 and not isinstance(window_dim, list): + if window_dim is None: + if len(window_dim_kwargs) == 0: + raise ValueError('Either window_dim or window_dim_kwargs need to be specified.') + window_dim = {d: window_dim_kwargs[d] for d in self.dim} + + if len(self.dim) == 1 and not utils.is_dict_like(window_dim): window_dim = [window_dim] + else: + # make window_dim a list + window_dim = [window_dim[d] for d in self.dim] + if isinstance(stride, int): stride = [stride] * len(self.dim) + else: + stride = [stride.get(d) for d in self.dim] + window = self.obj.variable.rolling_window( self.dim, self.window, window_dim, self.center, fill_value=fill_value ) @@ -309,12 +331,12 @@ def reduce(self, func, **kwargs): [ 4., 9., 15., 18.]]) """ - rolling_dim = [ - utils.get_temp_dimname(self.obj.dims, "_rolling_dim_{}".format(d)) + rolling_dim = { + d: utils.get_temp_dimname(self.obj.dims, "_rolling_dim_{}".format(d)) for d in self.dim - ] + } windows = self.construct(rolling_dim) - result = windows.reduce(func, dim=rolling_dim, **kwargs) + result = windows.reduce(func, dim=list(rolling_dim.values()), **kwargs) # Find valid windows based on count. counts = self._counts() @@ -323,10 +345,10 @@ def reduce(self, func, **kwargs): def _counts(self): """ Number of non-nan entries in each rolling window. """ - rolling_dim = [ - utils.get_temp_dimname(self.obj.dims, "_rolling_dim_{}".format(d)) + rolling_dim = { + d: utils.get_temp_dimname(self.obj.dims, "_rolling_dim_{}".format(d)) for d in self.dim - ] + } # We use False as the fill_value instead of np.nan, since boolean # array is faster to be reduced than object array. # The use of skipna==False is also faster since it does not need to @@ -334,10 +356,11 @@ def _counts(self): counts = ( self.obj.notnull() .rolling( - center=self.center, **{d: w for d, w in zip(self.dim, self.window)} + center={d: self.center[i] for i, d in enumerate(self.dim)}, + **{d: w for d, w in zip(self.dim, self.window)} ) .construct(rolling_dim, fill_value=False) - .sum(dim=rolling_dim, skipna=False) + .sum(dim=list(rolling_dim.values()), skipna=False) ) return counts @@ -457,11 +480,11 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None self.rollings = {} for key, da in self.obj.data_vars.items(): # keeps rollings only for the dataset depending on slf.dim - dims, center = [], [] + dims, center = [], {} for i, d in enumerate(self.dim): if d in da.dims: dims.append(d) - center.append(self.center[i]) + center[d] = self.center[i] if len(dims) > 0: self.rollings[key] = DataArrayRolling( @@ -547,11 +570,7 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA, keep_attrs=None) dataset = {} for key, da in self.obj.data_vars.items(): # keeps rollings only for the dataset depending on slf.dim - dims, center = [], [] - for i, d in enumerate(self.dim): - if d in da.dims: - dims.append(d) - center.append(self.center[i]) + dims = [d for d in self.dim if d in da.dims] if len(dims) > 0: dataset[key] = self.rollings[key].construct( From 531b0ecc2bf9e9679fed9d6d1e6e17c84b882106 Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Mon, 13 Jul 2020 07:03:56 +0900 Subject: [PATCH 07/23] A cleanup. --- xarray/core/rolling.py | 55 ++++++++++++++++------------------ xarray/tests/test_dataarray.py | 7 +++-- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 41f0489d827..0fd6a8adcea 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -82,20 +82,13 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None raise ValueError("window must be > 0") self.window.append(w) - if utils.is_dict_like(center): - self.center = [center.get(d, False) for d in self.dim] - elif isinstance(center, bool) or center is None: - self.center = [center] * len(self.dim) - else: - raise ValueError('center should be boolean or a mapping. ' - 'Given {}'.format(center)) - + self.center = self._mapping_to_list(center, default=False) self.obj = obj # attributes if min_periods is not None and min_periods <= 0: raise ValueError("min_periods must be greater than zero or None") - + self.min_periods = np.prod(self.window) if min_periods is None else min_periods if keep_attrs is None: @@ -150,11 +143,21 @@ def count(self): count.__doc__ = _ROLLING_REDUCE_DOCSTRING_TEMPLATE.format(name="count") - def _dict_to_list(self, arg, default=None): + def _mapping_to_list( + self, arg, default=None, allow_default=True, allow_allsame=True + ): if utils.is_dict_like(arg): - return [arg.get(d, default) for d in self.dim] - else: # for single argument + if allow_default: + return [arg.get(d, default) for d in self.dim] + else: + return [arg.get(d, default) for d in self.dim] + elif allow_allsame: # for single argument return [arg] * len(self.dim) + elif len(self.dim) == 1: + return [arg] + else: + raise ValueError("Mapping argument is necessary.") + class DataArrayRolling(Rolling): __slots__ = ("window_labels",) @@ -221,8 +224,7 @@ def __iter__(self): yield (label, window) def construct( - self, window_dim=None, stride=1, fill_value=dtypes.NA, - **window_dim_kwargs + self, window_dim=None, stride=1, fill_value=dtypes.NA, **window_dim_kwargs ): """ Convert this rolling object to xr.DataArray, @@ -266,19 +268,15 @@ def construct( if window_dim is None: if len(window_dim_kwargs) == 0: - raise ValueError('Either window_dim or window_dim_kwargs need to be specified.') - window_dim = {d: window_dim_kwargs[d] for d in self.dim} - - if len(self.dim) == 1 and not utils.is_dict_like(window_dim): - window_dim = [window_dim] - else: - # make window_dim a list - window_dim = [window_dim[d] for d in self.dim] + raise ValueError( + "Either window_dim or window_dim_kwargs need to be specified." + ) + window_dim = {d: window_dim_kwargs[d] for d in self.dim} - if isinstance(stride, int): - stride = [stride] * len(self.dim) - else: - stride = [stride.get(d) for d in self.dim] + window_dim = self._mapping_to_list( + window_dim, allow_default=False, allow_allsame=False + ) + stride = self._mapping_to_list(stride, default=1) window = self.obj.variable.rolling_window( self.dim, self.window, window_dim, self.center, fill_value=fill_value @@ -357,7 +355,7 @@ def _counts(self): self.obj.notnull() .rolling( center={d: self.center[i] for i, d in enumerate(self.dim)}, - **{d: w for d, w in zip(self.dim, self.window)} + **{d: w for d, w in zip(self.dim, self.window)}, ) .construct(rolling_dim, fill_value=False) .sum(dim=list(rolling_dim.values()), skipna=False) @@ -561,8 +559,7 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA, keep_attrs=None) from .dataset import Dataset - if isinstance(stride, int): - stride = [stride] * len(self.dim) + stride = self._mapping_to_list(stride, default=1) if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 7125bfeada1..579b8d15ec7 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -6405,8 +6405,11 @@ def test_ndrolling_reduce(da, center, min_periods): actual = rolling_obj.sum() expected = ( - da.rolling(time=3, center=center, min_periods=min_periods).sum() - .rolling(x=2, center=center, min_periods=min_periods).sum()) + da.rolling(time=3, center=center, min_periods=min_periods) + .sum() + .rolling(x=2, center=center, min_periods=min_periods) + .sum() + ) assert_allclose(actual, expected) assert actual.dims == expected.dims From b6cf25010995ebf178846d392df282a7bc935144 Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Mon, 13 Jul 2020 07:16:48 +0900 Subject: [PATCH 08/23] Revert test_units --- xarray/tests/test_units.py | 114 ++++++++++++++++++++++++++----------- 1 file changed, 81 insertions(+), 33 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index e8dd172ba02..619fa10116d 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -624,7 +624,7 @@ def test_align_dataset(value, unit, variant, error, dtype): units_a = extract_units(ds1) units_b = extract_units(ds2) expected_a, expected_b = func( - strip_units(ds1), strip_units(convert_units(ds2, units_a)), **stripped_kwargs + strip_units(ds1), strip_units(convert_units(ds2, units_a)), **stripped_kwargs, ) expected_a = attach_units(expected_a, units_a) if isinstance(array2, Quantity): @@ -1223,7 +1223,11 @@ def test_merge_dataset(variant, unit, error, dtype): def test_replication_dataarray(func, variant, dtype): unit = unit_registry.m - variants = {"data": (unit, 1, 1), "dims": (1, unit, 1), "coords": (1, 1, unit)} + variants = { + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), + } data_unit, dim_unit, coord_unit = variants.get(variant) array = np.linspace(0, 10, 20).astype(dtype) * data_unit @@ -1304,7 +1308,11 @@ def test_replication_full_like_dataarray(variant, dtype): # fill value, we don't need to try multiple units unit = unit_registry.m - variants = {"data": (unit, 1, 1), "dims": (1, unit, 1), "coords": (1, 1, unit)} + variants = { + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), + } data_unit, dim_unit, coord_unit = variants.get(variant) array = np.linspace(0, 5, 10) * data_unit @@ -1362,7 +1370,10 @@ def test_replication_full_like_dataset(variant, dtype): fill_value = -1 * unit_registry.degK - units = {**extract_units(ds), **{name: unit_registry.degK for name in ds.data_vars}} + units = { + **extract_units(ds), + **{name: unit_registry.degK for name in ds.data_vars}, + } expected = attach_units( xr.full_like(strip_units(ds), fill_value=strip_units(fill_value)), units ) @@ -1724,7 +1735,7 @@ def test_missing_value_fillna(self, unit, error): pytest.param(1, id="no_unit"), pytest.param(unit_registry.dimensionless, id="dimensionless"), pytest.param(unit_registry.s, id="incompatible_unit"), - pytest.param(unit_registry.cm, id="compatible_unit"), + pytest.param(unit_registry.cm, id="compatible_unit",), pytest.param(unit_registry.m, id="identical_unit"), ), ) @@ -2175,7 +2186,7 @@ def test_pad(self, mode, xr_arg, np_arg): v = xr.Variable(["x", "y", "z"], data) expected = attach_units( - strip_units(v).pad(mode=mode, **xr_arg), extract_units(v) + strip_units(v).pad(mode=mode, **xr_arg), extract_units(v), ) actual = v.pad(mode=mode, **xr_arg) @@ -2413,7 +2424,7 @@ def test_binary_operations(self, func, dtype): id="equal", marks=pytest.mark.xfail( # LooseVersion(pint.__version__) < "0.14", - reason="inconsistencies in the return values of pint's eq" + reason="inconsistencies in the return values of pint's eq", ), ), ), @@ -2907,8 +2918,8 @@ def test_interpolate_na(self): unit_registry.dimensionless, DimensionalityError, id="dimensionless" ), pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), - pytest.param(unit_registry.cm, None, id="compatible_unit"), - pytest.param(unit_registry.m, None, id="identical_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit",), + pytest.param(unit_registry.m, None, id="identical_unit",), ), ) def test_combine_first(self, unit, error, dtype): @@ -3156,7 +3167,11 @@ def test_pad(self, dtype): def test_content_manipulation(self, func, variant, dtype): unit = unit_registry.m - variants = {"data": (unit, 1, 1), "dims": (1, unit, 1), "coords": (1, 1, unit)} + variants = { + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), + } data_unit, dim_unit, coord_unit = variants.get(variant) quantity = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit @@ -3422,7 +3437,10 @@ def test_head_tail_thin(self, func, dtype): ids=repr, ) def test_interp_reindex(self, variant, func, dtype): - variants = {"data": (unit_registry.m, 1), "coords": (1, unit_registry.m)} + variants = { + "data": (unit_registry.m, 1), + "coords": (1, unit_registry.m), + } data_unit, coord_unit = variants.get(variant) array = np.linspace(1, 2, 10).astype(dtype) * data_unit @@ -3452,7 +3470,9 @@ def test_interp_reindex(self, variant, func, dtype): pytest.param(unit_registry.m, None, id="identical_unit"), ), ) - @pytest.mark.parametrize("func", (method("interp"), method("reindex")), ids=repr) + @pytest.mark.parametrize( + "func", (method("interp"), method("reindex")), ids=repr, + ) def test_interp_reindex_indexing(self, func, unit, error, dtype): array = np.linspace(1, 2, 10).astype(dtype) x = np.arange(10) * unit_registry.m @@ -3490,7 +3510,10 @@ def test_interp_reindex_indexing(self, func, unit, error, dtype): ids=repr, ) def test_interp_reindex_like(self, variant, func, dtype): - variants = {"data": (unit_registry.m, 1), "coords": (1, unit_registry.m)} + variants = { + "data": (unit_registry.m, 1), + "coords": (1, unit_registry.m), + } data_unit, coord_unit = variants.get(variant) array = np.linspace(1, 2, 10).astype(dtype) * data_unit @@ -3522,7 +3545,7 @@ def test_interp_reindex_like(self, variant, func, dtype): ), ) @pytest.mark.parametrize( - "func", (method("interp_like"), method("reindex_like")), ids=repr + "func", (method("interp_like"), method("reindex_like")), ids=repr, ) def test_interp_reindex_like_indexing(self, func, unit, error, dtype): array = np.linspace(1, 2, 10).astype(dtype) @@ -3658,7 +3681,11 @@ def test_stacking_reordering(self, func, dtype): def test_computation(self, func, variant, dtype): unit = unit_registry.m - variants = {"data": (unit, 1, 1), "dims": (1, unit, 1), "coords": (1, 1, unit)} + variants = { + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), + } data_unit, dim_unit, coord_unit = variants.get(variant) array = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit @@ -3718,7 +3745,11 @@ def test_computation(self, func, variant, dtype): def test_computation_objects(self, func, variant, dtype): unit = unit_registry.m - variants = {"data": (unit, 1, 1), "dims": (1, unit, 1), "coords": (1, 1, unit)} + variants = { + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), + } data_unit, dim_unit, coord_unit = variants.get(variant) array = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit @@ -3777,7 +3808,11 @@ def test_resample(self, dtype): def test_grouped_operations(self, func, variant, dtype): unit = unit_registry.m - variants = {"data": (unit, 1, 1), "dims": (1, unit, 1), "coords": (1, 1, unit)} + variants = { + "data": (unit, 1, 1), + "dims": (1, unit, 1), + "coords": (1, 1, unit), + } data_unit, dim_unit, coord_unit = variants.get(variant) array = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit @@ -3892,7 +3927,7 @@ def test_init(self, shared, unit, error, dtype): ( "data", pytest.param( - "dims", marks=pytest.mark.xfail(reason="indexes don't support units") + "dims", marks=pytest.mark.xfail(reason="indexes don't support units"), ), "coords", ), @@ -3908,7 +3943,11 @@ def test_repr(self, func, variant, dtype): x = np.arange(len(array1)) * unit_registry.s y = x.to(unit_registry.ms) - variants = {"dims": {"x": x}, "coords": {"y": ("x", y)}, "data": {}} + variants = { + "dims": {"x": x}, + "coords": {"y": ("x", y)}, + "data": {}, + } ds = xr.Dataset( data_vars={"a": ("x", array1), "b": ("x", array2)}, @@ -4156,7 +4195,7 @@ def test_missing_value_filling(self, func, dtype): unit_registry.dimensionless, DimensionalityError, id="dimensionless" ), pytest.param(unit_registry.s, DimensionalityError, id="incompatible_unit"), - pytest.param(unit_registry.cm, None, id="compatible_unit"), + pytest.param(unit_registry.cm, None, id="compatible_unit",), pytest.param(unit_registry.m, None, id="identical_unit"), ), ) @@ -4301,7 +4340,7 @@ def test_where(self, variant, unit, error, dtype): for key, value in kwargs.items() } - expected = attach_units(strip_units(ds).where(**kwargs_without_units), units) + expected = attach_units(strip_units(ds).where(**kwargs_without_units), units,) actual = ds.where(**kwargs) assert_units_equal(expected, actual) @@ -4320,7 +4359,7 @@ def test_interpolate_na(self, dtype): ds = xr.Dataset({"a": ("x", array1), "b": ("x", array2)}) units = extract_units(ds) - expected = attach_units(strip_units(ds).interpolate_na(dim="x"), units) + expected = attach_units(strip_units(ds).interpolate_na(dim="x"), units,) actual = ds.interpolate_na(dim="x") assert_units_equal(expected, actual) @@ -4343,7 +4382,7 @@ def test_interpolate_na(self, dtype): ( "data", pytest.param( - "dims", marks=pytest.mark.xfail(reason="indexes don't support units") + "dims", marks=pytest.mark.xfail(reason="indexes don't support units"), ), ), ) @@ -4362,7 +4401,7 @@ def test_combine_first(self, variant, unit, error, dtype): ) x = np.arange(len(array1)) * dims_unit ds = xr.Dataset( - data_vars={"a": ("x", array1), "b": ("x", array2)}, coords={"x": x} + data_vars={"a": ("x", array1), "b": ("x", array2)}, coords={"x": x}, ) units = extract_units(ds) @@ -4439,7 +4478,7 @@ def test_comparisons(self, func, variant, unit, dtype): y = coord * coord_unit ds = xr.Dataset( - data_vars={"a": ("x", a), "b": ("x", b)}, coords={"x": x, "y": ("x", y)} + data_vars={"a": ("x", a), "b": ("x", b)}, coords={"x": x, "y": ("x", y)}, ) units = extract_units(ds) @@ -4496,7 +4535,7 @@ def test_comparisons(self, func, variant, unit, dtype): ( "data", pytest.param( - "dims", marks=pytest.mark.xfail(reason="indexes don't support units") + "dims", marks=pytest.mark.xfail(reason="indexes don't support units"), ), ), ) @@ -4549,7 +4588,7 @@ def test_broadcast_equals(self, unit, dtype): right_array2 = np.zeros(shape=(3,)) * unit left = xr.Dataset( - {"a": (("x", "y"), left_array1), "b": (("y", "z"), left_array2)} + {"a": (("x", "y"), left_array1), "b": (("y", "z"), left_array2)}, ) right = xr.Dataset({"a": ("x", right_array1), "b": ("y", right_array2)}) @@ -4587,12 +4626,15 @@ def test_pad(self, dtype): ( "data", pytest.param( - "dims", marks=pytest.mark.xfail(reason="indexes don't support units") + "dims", marks=pytest.mark.xfail(reason="indexes don't support units"), ), ), ) def test_stacking_stacked(self, variant, func, dtype): - variants = {"data": (unit_registry.m, 1), "dims": (1, unit_registry.m)} + variants = { + "data": (unit_registry.m, 1), + "dims": (1, unit_registry.m), + } data_unit, dim_unit = variants.get(variant) array1 = np.linspace(0, 10, 5 * 10).reshape(5, 10).astype(dtype) * data_unit @@ -4635,7 +4677,7 @@ def test_to_stacked_array(self, dtype): func = method("to_stacked_array", "z", variable_dim="y", sample_dims=["x"]) actual = func(ds).rename(None) - expected = attach_units(func(strip_units(ds)).rename(None), units) + expected = attach_units(func(strip_units(ds)).rename(None), units,) assert_units_equal(expected, actual) assert_equal(expected, actual) @@ -4941,7 +4983,7 @@ def test_squeeze(self, shape, dim, dtype): data_vars={ "a": (tuple(names[: len(shape)]), array1), "b": (tuple(names[: len(shape)]), array2), - } + }, ) units = extract_units(ds) @@ -4966,7 +5008,10 @@ def test_squeeze(self, shape, dim, dtype): ids=repr, ) def test_interp_reindex(self, func, variant, dtype): - variants = {"data": (unit_registry.m, 1), "coords": (1, unit_registry.m)} + variants = { + "data": (unit_registry.m, 1), + "coords": (1, unit_registry.m), + } data_unit, coord_unit = variants.get(variant) array1 = np.linspace(-1, 0, 10).astype(dtype) * data_unit @@ -5036,7 +5081,10 @@ def test_interp_reindex_indexing(self, func, unit, error, dtype): ids=repr, ) def test_interp_reindex_like(self, func, variant, dtype): - variants = {"data": (unit_registry.m, 1), "coords": (1, unit_registry.m)} + variants = { + "data": (unit_registry.m, 1), + "coords": (1, unit_registry.m), + } data_unit, coord_unit = variants.get(variant) array1 = np.linspace(-1, 0, 10).astype(dtype) * data_unit From 425ed4575c1ecc93b02babdef687f56239774c6d Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Mon, 13 Jul 2020 08:38:28 +0900 Subject: [PATCH 09/23] make test pass --- xarray/tests/test_dataset.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 9037013cc79..eb39c3584bd 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -5881,8 +5881,6 @@ def test_rolling_keep_attrs(): def test_rolling_properties(ds): # catching invalid args - with pytest.raises(ValueError, match="exactly one dim/window should"): - ds.rolling(time=7, x=2) with pytest.raises(ValueError, match="window must be > 0"): ds.rolling(time=-2) with pytest.raises(ValueError, match="min_periods must be greater than zero"): From 54d84e558fcc1ab86bb09f29c130f634b67724f6 Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Tue, 14 Jul 2020 09:12:55 +0900 Subject: [PATCH 10/23] More tests. --- xarray/core/rolling.py | 32 ++++++++++++++++++++++---- xarray/tests/test_dataarray.py | 39 ++++++++++++++++++++++++------- xarray/tests/test_dataset.py | 42 ++++++++++++++++++++++++++++++++++ xarray/tests/test_nputils.py | 2 +- 4 files changed, 101 insertions(+), 14 deletions(-) diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 0fd6a8adcea..846f38a4994 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -150,7 +150,10 @@ def _mapping_to_list( if allow_default: return [arg.get(d, default) for d in self.dim] else: - return [arg.get(d, default) for d in self.dim] + for d in self.dim: + if d not in arg: + raise KeyError("argument has no key {}.".format(d)) + return [arg[d] for d in self.dim] elif allow_allsame: # for single argument return [arg] * len(self.dim) elif len(self.dim) == 1: @@ -485,8 +488,9 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None center[d] = self.center[i] if len(dims) > 0: + w = {d: windows[d] for d in dims} self.rollings[key] = DataArrayRolling( - da, windows, min_periods, center, keep_attrs + da, w, min_periods, center, keep_attrs ) def _dataset_implementation(self, func, **kwargs): @@ -538,7 +542,14 @@ def _numpy_or_bottleneck_reduce( **kwargs, ) - def construct(self, window_dim, stride=1, fill_value=dtypes.NA, keep_attrs=None): + def construct( + self, + window_dim=None, + stride=1, + fill_value=dtypes.NA, + keep_attrs=None, + **window_dim_kwargs, + ): """ Convert this rolling object to xr.Dataset, where the window dimension is stacked as a new dimension @@ -559,6 +570,16 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA, keep_attrs=None) from .dataset import Dataset + if window_dim is None: + if len(window_dim_kwargs) == 0: + raise ValueError( + "Either window_dim or window_dim_kwargs need to be specified." + ) + window_dim = {d: window_dim_kwargs[d] for d in self.dim} + + window_dim = self._mapping_to_list( + window_dim, allow_default=False, allow_allsame=False + ) stride = self._mapping_to_list(stride, default=1) if keep_attrs is None: @@ -568,10 +589,11 @@ def construct(self, window_dim, stride=1, fill_value=dtypes.NA, keep_attrs=None) for key, da in self.obj.data_vars.items(): # keeps rollings only for the dataset depending on slf.dim dims = [d for d in self.dim if d in da.dims] - if len(dims) > 0: + wi = {d: window_dim[i] for i, d in enumerate(self.dim) if d in da.dims} + st = {d: stride[i] for i, d in enumerate(self.dim) if d in da.dims} dataset[key] = self.rollings[key].construct( - window_dim, fill_value=fill_value + window_dim=wi, fill_value=fill_value, stride=st ) else: dataset[key] = da diff --git a/xarray/tests/test_dataarray.py b/xarray/tests/test_dataarray.py index 579b8d15ec7..bf0edfac371 100644 --- a/xarray/tests/test_dataarray.py +++ b/xarray/tests/test_dataarray.py @@ -6400,21 +6400,44 @@ def test_rolling_count_correct(): @pytest.mark.parametrize("da", (1,), indirect=True) @pytest.mark.parametrize("center", (True, False)) @pytest.mark.parametrize("min_periods", (None, 1)) -def test_ndrolling_reduce(da, center, min_periods): +@pytest.mark.parametrize("name", ("sum", "mean", "max")) +def test_ndrolling_reduce(da, center, min_periods, name): rolling_obj = da.rolling(time=3, x=2, center=center, min_periods=min_periods) - actual = rolling_obj.sum() - expected = ( - da.rolling(time=3, center=center, min_periods=min_periods) - .sum() - .rolling(x=2, center=center, min_periods=min_periods) - .sum() - ) + actual = getattr(rolling_obj, name)() + expected = getattr( + getattr( + da.rolling(time=3, center=center, min_periods=min_periods), name + )().rolling(x=2, center=center, min_periods=min_periods), + name, + )() assert_allclose(actual, expected) assert actual.dims == expected.dims +@pytest.mark.parametrize("center", (True, False, (True, False))) +@pytest.mark.parametrize("fill_value", (np.nan, 0.0)) +def test_ndrolling_construct(center, fill_value): + da = DataArray( + np.arange(5 * 6 * 7).reshape(5, 6, 7).astype(float), + dims=["x", "y", "z"], + coords={"x": ["a", "b", "c", "d", "e"], "y": np.arange(6)}, + ) + actual = da.rolling(x=3, z=2, center=center).construct( + x="x1", z="z1", fill_value=fill_value + ) + if not isinstance(center, tuple): + center = (center, center) + expected = ( + da.rolling(x=3, center=center[0]) + .construct(x="x1", fill_value=fill_value) + .rolling(z=2, center=center[1]) + .construct(z="z1", fill_value=fill_value) + ) + assert_allclose(actual, expected) + + def test_raise_no_warning_for_nan_in_binary_ops(): with pytest.warns(None) as record: xr.DataArray([1, 2, np.NaN]) > 0 diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index eb39c3584bd..0bf3ac80124 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6005,6 +6005,48 @@ def test_rolling_reduce(ds, center, min_periods, window, name): assert src_var.dims == actual[key].dims +@pytest.mark.parametrize("ds", (1,), indirect=True) +@pytest.mark.parametrize("center", (True, False)) +@pytest.mark.parametrize("min_periods", (None, 1)) +@pytest.mark.parametrize("name", ("sum", "mean", "max")) +def test_ndrolling_reduce(ds, center, min_periods, name): + rolling_obj = ds.rolling(time=3, x=2, center=center, min_periods=min_periods) + + actual = getattr(rolling_obj, name)() + expected = getattr( + getattr( + ds.rolling(time=3, center=center, min_periods=min_periods), name + )().rolling(x=2, center=center, min_periods=min_periods), + name, + )() + + assert_allclose(actual, expected) + assert actual.dims == expected.dims + + +@pytest.mark.parametrize("center", (True, False, (True, False))) +@pytest.mark.parametrize("fill_value", (np.nan, 0.0)) +def test_ndrolling_construct(center, fill_value): + da = DataArray( + np.arange(5 * 6 * 7).reshape(5, 6, 7).astype(float), + dims=["x", "y", "z"], + coords={"x": ["a", "b", "c", "d", "e"], "y": np.arange(6)}, + ) + ds = xr.Dataset({"da": da}) + actual = ds.rolling(x=3, z=2, center=center).construct( + x="x1", z="z1", fill_value=fill_value + ) + if not isinstance(center, tuple): + center = (center, center) + expected = ( + ds.rolling(x=3, center=center[0]) + .construct(x="x1", fill_value=fill_value) + .rolling(z=2, center=center[1]) + .construct(z="z1", fill_value=fill_value) + ) + assert_allclose(actual, expected) + + def test_raise_no_warning_for_nan_in_binary_ops(): with pytest.warns(None) as record: Dataset(data_vars={"x": ("y", [1, 2, np.NaN])}) > 0 diff --git a/xarray/tests/test_nputils.py b/xarray/tests/test_nputils.py index de8595c0f81..ccb825dc7e9 100644 --- a/xarray/tests/test_nputils.py +++ b/xarray/tests/test_nputils.py @@ -1,6 +1,6 @@ import numpy as np -from numpy.testing import assert_array_equal import pytest +from numpy.testing import assert_array_equal from xarray.core.nputils import NumpyVIndexAdapter, _is_contiguous, rolling_window From f7c4911cccd2fe16ff63e755efc3af2c9031caa9 Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Tue, 14 Jul 2020 09:21:51 +0900 Subject: [PATCH 11/23] more docs --- doc/whats-new.rst | 3 +++ xarray/core/common.py | 2 +- xarray/core/rolling.py | 18 ++++++++++++------ 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index d086d4f411d..5b7f59cc6d1 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -89,6 +89,9 @@ Breaking changes New Features ~~~~~~~~~~~~ +- :py:meth:`~xarray.DataArray.rolling` and :py:meth:`~xarray.Dataset.rolling` + now accept more than 1 dimension.(:pull:`4219`) + By `Keisuke Fujii `_. - :py:meth:`DataArray.argmin` and :py:meth:`DataArray.argmax` now support sequences of 'dim' arguments, and if a sequence is passed return a dict (which can be passed to :py:meth:`isel` to get the value of the minimum) of diff --git a/xarray/core/common.py b/xarray/core/common.py index 67dc0fda461..b851225c6af 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -802,7 +802,7 @@ def rolling( Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. - center : boolean, default False + center : boolean, or a mapping, default False Set the labels at the center of the window. keep_attrs : bool, optional If True, the object's attributes (`attrs`) will be copied from diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 846f38a4994..f360082961f 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -235,12 +235,15 @@ def construct( Parameters ---------- - window_dim: str - New name of the window dimension. - stride: integer, optional + window_dim: str or a mapping, optional + A mapping from dimension name to the new window dimension names. + Just a string can be used for 1d-rolling. + stride: integer or a mapping, optional Size of stride for the rolling window. fill_value: optional. Default dtypes.NA Filling value to match the dimension size. + **window_dim_kwargs : {dim: new_name, ...}, optional + The keyword arguments form of ``window_dim``. Returns ------- @@ -456,7 +459,7 @@ def __init__(self, obj, windows, min_periods=None, center=False, keep_attrs=None Minimum number of observations in window required to have a value (otherwise result is NA). The default, None, is equivalent to setting min_periods equal to the size of the window. - center : boolean, default False + center : boolean, or a mapping from dimension name to boolean, default False Set the labels at the center of the window. keep_attrs : bool, optional If True, the object's attributes (`attrs`) will be copied from @@ -556,12 +559,15 @@ def construct( Parameters ---------- - window_dim: str - New name of the window dimension. + window_dim: str or a mapping, optional + A mapping from dimension name to the new window dimension names. + Just a string can be used for 1d-rolling. stride: integer, optional size of stride for the rolling window. fill_value: optional. Default dtypes.NA Filling value to match the dimension size. + **window_dim_kwargs : {dim: new_name, ...}, optional + The keyword arguments form of ``window_dim``. Returns ------- From 2e8a76f3117f5d17aca24bfbb3717f13d0f86103 Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Wed, 15 Jul 2020 05:32:17 +0900 Subject: [PATCH 12/23] mypy --- xarray/core/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 5bfddaa710b..6deafbc865f 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -5969,7 +5969,7 @@ def polyfit( skipna_da = np.any(da.isnull()) dims_to_stack = [dimname for dimname in da.dims if dimname != dim] - stacked_coords = {} + stacked_coords: Dict[Hashable, DataArray] = {} if dims_to_stack: stacked_dim = utils.get_temp_dimname(dims_to_stack, "stacked") rhs = da.transpose(dim, *dims_to_stack).stack( From 31243b1b15ebbce6ec79d8b26f1968ca8d27f834 Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Wed, 15 Jul 2020 05:34:04 +0900 Subject: [PATCH 13/23] improve whatsnew --- doc/whats-new.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 5b7f59cc6d1..1a48eae7c64 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -25,6 +25,9 @@ Breaking changes New Features ~~~~~~~~~~~~ +- :py:meth:`~xarray.DataArray.rolling` and :py:meth:`~xarray.Dataset.rolling` + now accept more than 1 dimension.(:pull:`4219`) + By `Keisuke Fujii `_. Bug fixes @@ -89,9 +92,6 @@ Breaking changes New Features ~~~~~~~~~~~~ -- :py:meth:`~xarray.DataArray.rolling` and :py:meth:`~xarray.Dataset.rolling` - now accept more than 1 dimension.(:pull:`4219`) - By `Keisuke Fujii `_. - :py:meth:`DataArray.argmin` and :py:meth:`DataArray.argmax` now support sequences of 'dim' arguments, and if a sequence is passed return a dict (which can be passed to :py:meth:`isel` to get the value of the minimum) of From 22ba8b91f61cfb5538c50c445971ba926010a7b1 Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Wed, 15 Jul 2020 05:44:52 +0900 Subject: [PATCH 14/23] Improve doc --- doc/computation.rst | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/doc/computation.rst b/doc/computation.rst index 3660aed93ed..474c3905981 100644 --- a/doc/computation.rst +++ b/doc/computation.rst @@ -188,9 +188,16 @@ a value when aggregating: r = arr.rolling(y=3, center=True, min_periods=2) r.mean() +From version 0.17, xarray supports multidimensional rolling, + +.. ipython:: python + + r = arr.rolling(x=2, y=3, min_periods=2) + r.mean() + .. tip:: - Note that rolling window aggregations are faster and use less memory when bottleneck_ is installed. This only applies to numpy-backed xarray objects. + Note that rolling window aggregations are faster and use less memory when bottleneck_ is installed. This only applies to numpy-backed xarray objects with 1d-rolling. .. _bottleneck: https://github.com/pydata/bottleneck/ From 4b4e64aa070f3a1a863d2db551dac3b0681b6083 Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Fri, 7 Aug 2020 14:57:14 +0900 Subject: [PATCH 15/23] Support nd-rolling in dask correctly --- xarray/core/dask_array_ops.py | 119 +++++++++++++++++----------------- 1 file changed, 61 insertions(+), 58 deletions(-) diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index a5f8441f59a..1396f161105 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -32,77 +32,80 @@ def rolling_window(a, axis, window, center, fill_value): """ import dask.array as da - # for nd-rolling. - # TODO It can be more efficient. Currently, the chunks at the boundaries - # will be copied, but it might be OK for many-chunked-arrays. - if hasattr(axis, "__len__"): - for ax, win, cen in zip(axis, window, center): - a = rolling_window(a, ax, win, cen, fill_value) - return a + if not hasattr(axis, "__len__"): + axis = [axis] + window = [window] + center = [center] orig_shape = a.shape - if axis < 0: - axis = a.ndim + axis depth = {d: 0 for d in range(a.ndim)} - depth[axis] = int(window / 2) - # For evenly sized window, we need to crop the first point of each block. - offset = 1 if window % 2 == 0 else 0 - - if depth[axis] > min(a.chunks[axis]): - raise ValueError( - "For window size %d, every chunk should be larger than %d, " - "but the smallest chunk size is %d. Rechunk your array\n" - "with a larger chunk size or a chunk size that\n" - "more evenly divides the shape of your array." - % (window, depth[axis], min(a.chunks[axis])) - ) - - # Although da.overlap pads values to boundaries of the array, - # the size of the generated array is smaller than what we want - # if center == False. - if center: - start = int(window / 2) # 10 -> 5, 9 -> 4 - end = window - 1 - start - else: - start, end = window - 1, 0 - pad_size = max(start, end) + offset - depth[axis] - drop_size = 0 - # pad_size becomes more than 0 when the overlapped array is smaller than - # needed. In this case, we need to enlarge the original array by padding - # before overlapping. - if pad_size > 0: - if pad_size < depth[axis]: - # overlapping requires each chunk larger than depth. If pad_size is - # smaller than the depth, we enlarge this and truncate it later. - drop_size = depth[axis] - pad_size - pad_size = depth[axis] - shape = list(a.shape) - shape[axis] = pad_size - chunks = list(a.chunks) - chunks[axis] = (pad_size,) - fill_array = da.full(shape, fill_value, dtype=a.dtype, chunks=chunks) - a = da.concatenate([fill_array, a], axis=axis) - + offset = [0] * a.ndim + drop_size = [0] * a.ndim + pad_size = [0] * a.ndim + for ax, win, cent in zip(axis, window, center): + if ax < 0: + ax = a.ndim + ax + depth[ax] = int(win / 2) + # For evenly sized window, we need to crop the first point of each block. + offset[ax] = 1 if win % 2 == 0 else 0 + + if depth[ax] > min(a.chunks[ax]): + raise ValueError( + "For window size %d, every chunk should be larger than %d, " + "but the smallest chunk size is %d. Rechunk your array\n" + "with a larger chunk size or a chunk size that\n" + "more evenly divides the shape of your array." + % (win, depth[ax], min(a.chunks[ax])) + ) + + # Although da.overlap pads values to boundaries of the array, + # the size of the generated array is smaller than what we want + # if center == False. + if cent: + start = int(win / 2) # 10 -> 5, 9 -> 4 + end = win - 1 - start + else: + start, end = win - 1, 0 + pad_size[ax] = max(start, end) + offset[ax] - depth[ax] + drop_size[ax] = 0 + # pad_size becomes more than 0 when the overlapped array is smaller than + # needed. In this case, we need to enlarge the original array by padding + # before overlapping. + if pad_size[ax] > 0: + if pad_size[ax] < depth[ax]: + # overlapping requires each chunk larger than depth. If pad_size is + # smaller than the depth, we enlarge this and truncate it later. + drop_size[ax] = depth[ax] - pad_size[ax] + pad_size[ax] = depth[ax] + + # TODO maybe following two lines can be summarized. + a = da.pad( + a, [(p, 0) for p in pad_size], mode="constant", constant_values=fill_value + ) boundary = {d: fill_value for d in range(a.ndim)} # create overlap arrays ag = da.overlap.overlap(a, depth=depth, boundary=boundary) - # apply rolling func - def func(x, window, axis=-1): + def func(x, window, axis): x = np.asarray(x) - rolling = nputils._rolling_window(x, window, axis) - return rolling[(slice(None),) * axis + (slice(offset, None),)] - - chunks = list(a.chunks) - chunks.append(window) + index = [slice(None)] * x.ndim + for ax, win in zip(axis, window): + x = nputils._rolling_window(x, win, ax) + index[ax] = slice(offset[ax], None) + return x[tuple(index)] + + chunks = list(a.chunks) + window + new_axis = [a.ndim + i for i in range(len(axis))] out = ag.map_blocks( - func, dtype=a.dtype, new_axis=a.ndim, chunks=chunks, window=window, axis=axis + func, dtype=a.dtype, new_axis=new_axis, chunks=chunks, window=window, axis=axis ) # crop boundary. - index = (slice(None),) * axis + (slice(drop_size, drop_size + orig_shape[axis]),) - return out[index] + index = [slice(None)] * a.ndim + for ax in axis: + index[ax] = slice(drop_size[ax], drop_size[ax] + orig_shape[ax]) + return out[tuple(index)] def least_squares(lhs, rhs, rcond=None, skipna=False): From 398638cfcc0cd1e9f04ed78f58935f1044be7f4f Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Fri, 7 Aug 2020 15:09:40 +0900 Subject: [PATCH 16/23] Cleanup according to max's comment --- xarray/core/rolling.py | 14 ++++---------- xarray/tests/test_dataset.py | 18 ++++++++++++++---- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index f360082961f..9cd5871d53d 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -100,10 +100,7 @@ def __repr__(self): attrs = [ "{k}->{v}".format(k=k, v=getattr(self, k)) - for k in list(self.dim) - + list(self.window) - + list(self.center) - + [self.min_periods] + for k in list(self.dim) + self.window + self.center + [self.min_periods] ] return "{klass} [{attrs}]".format( klass=self.__class__.__name__, attrs=",".join(attrs) @@ -272,12 +269,9 @@ def construct( from .dataarray import DataArray - if window_dim is None: - if len(window_dim_kwargs) == 0: - raise ValueError( - "Either window_dim or window_dim_kwargs need to be specified." - ) - window_dim = {d: window_dim_kwargs[d] for d in self.dim} + window_dim = utils.either_dict_or_kwargs( + window_dim, window_dim_kwargs, "Dataset.rolling" + ) window_dim = self._mapping_to_list( window_dim, allow_default=False, allow_allsame=False diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 0bf3ac80124..f39ec4ca966 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6008,15 +6008,25 @@ def test_rolling_reduce(ds, center, min_periods, window, name): @pytest.mark.parametrize("ds", (1,), indirect=True) @pytest.mark.parametrize("center", (True, False)) @pytest.mark.parametrize("min_periods", (None, 1)) -@pytest.mark.parametrize("name", ("sum", "mean", "max")) +@pytest.mark.parametrize("name", ("sum", "mean", "max", "std")) def test_ndrolling_reduce(ds, center, min_periods, name): - rolling_obj = ds.rolling(time=3, x=2, center=center, min_periods=min_periods) + rolling_obj = ds.rolling(time=4, x=3, center=center, min_periods=min_periods) actual = getattr(rolling_obj, name)() expected = getattr( getattr( - ds.rolling(time=3, center=center, min_periods=min_periods), name - )().rolling(x=2, center=center, min_periods=min_periods), + ds.rolling(time=4, center=center, min_periods=min_periods), name + )().rolling(x=3, center=center, min_periods=min_periods), + name, + )() + assert_allclose(actual, expected) + assert actual.dims == expected.dims + + # Do it in the oposite order + expected = getattr( + getattr( + ds.rolling(x=3, center=center, min_periods=min_periods), name + )().rolling(time=4, center=center, min_periods=min_periods), name, )() From c4fd353e4474e80652acba4752f8252f25f7febf Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Fri, 7 Aug 2020 15:16:24 +0900 Subject: [PATCH 17/23] flake8 --- xarray/core/dask_array_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 1396f161105..549f01aba47 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -86,7 +86,7 @@ def rolling_window(a, axis, window, center, fill_value): # create overlap arrays ag = da.overlap.overlap(a, depth=depth, boundary=boundary) - # apply rolling func + def func(x, window, axis): x = np.asarray(x) index = [slice(None)] * x.ndim From 28816f64d504da9fb704df1d91e08c61e86deca3 Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Fri, 7 Aug 2020 15:27:44 +0900 Subject: [PATCH 18/23] black --- xarray/core/dask_array_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/dask_array_ops.py b/xarray/core/dask_array_ops.py index 549f01aba47..74474f4321e 100644 --- a/xarray/core/dask_array_ops.py +++ b/xarray/core/dask_array_ops.py @@ -86,7 +86,7 @@ def rolling_window(a, axis, window, center, fill_value): # create overlap arrays ag = da.overlap.overlap(a, depth=depth, boundary=boundary) - + def func(x, window, axis): x = np.asarray(x) index = [slice(None)] * x.ndim From 404e78f8f8a5d8103646c0516841b77b09cb130b Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Fri, 7 Aug 2020 15:32:05 +0900 Subject: [PATCH 19/23] stop using either_dict_or_kwargs --- xarray/core/rolling.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/xarray/core/rolling.py b/xarray/core/rolling.py index 9cd5871d53d..5f996565243 100644 --- a/xarray/core/rolling.py +++ b/xarray/core/rolling.py @@ -269,9 +269,12 @@ def construct( from .dataarray import DataArray - window_dim = utils.either_dict_or_kwargs( - window_dim, window_dim_kwargs, "Dataset.rolling" - ) + if window_dim is None: + if len(window_dim_kwargs) == 0: + raise ValueError( + "Either window_dim or window_dim_kwargs need to be specified." + ) + window_dim = {d: window_dim_kwargs[d] for d in self.dim} window_dim = self._mapping_to_list( window_dim, allow_default=False, allow_allsame=False From 4cac857f55573a3fcf019a4083aca2bd421f6079 Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Sat, 8 Aug 2020 05:30:48 +0900 Subject: [PATCH 20/23] Better tests. --- xarray/tests/test_dataset.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index f39ec4ca966..9a955e75283 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6005,11 +6005,15 @@ def test_rolling_reduce(ds, center, min_periods, window, name): assert src_var.dims == actual[key].dims -@pytest.mark.parametrize("ds", (1,), indirect=True) +@pytest.mark.parametrize("ds", (2,), indirect=True) @pytest.mark.parametrize("center", (True, False)) @pytest.mark.parametrize("min_periods", (None, 1)) -@pytest.mark.parametrize("name", ("sum", "mean", "max", "std")) -def test_ndrolling_reduce(ds, center, min_periods, name): +@pytest.mark.parametrize("name", ("sum", "max")) +@pytest.mark.parameteris("dask", (True, False)) +def test_ndrolling_reduce(ds, center, min_periods, name, das): + if dask and has_dask: + ds = ds.chunk({"x": 4}) + rolling_obj = ds.rolling(time=4, x=3, center=center, min_periods=min_periods) actual = getattr(rolling_obj, name)() @@ -6036,13 +6040,17 @@ def test_ndrolling_reduce(ds, center, min_periods, name): @pytest.mark.parametrize("center", (True, False, (True, False))) @pytest.mark.parametrize("fill_value", (np.nan, 0.0)) -def test_ndrolling_construct(center, fill_value): +@pytest.mark.parametrize("dask", (True, False)) +def test_ndrolling_construct(center, fill_value, dask): da = DataArray( np.arange(5 * 6 * 7).reshape(5, 6, 7).astype(float), dims=["x", "y", "z"], coords={"x": ["a", "b", "c", "d", "e"], "y": np.arange(6)}, ) ds = xr.Dataset({"da": da}) + if dask and has_dask: + ds = ds.chunk({"x": 4}) + actual = ds.rolling(x=3, z=2, center=center).construct( x="x1", z="z1", fill_value=fill_value ) From 4bb780460a8ddbbf29439d848c428ee573945f5a Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Sat, 8 Aug 2020 05:42:24 +0900 Subject: [PATCH 21/23] typo --- xarray/tests/test_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 9a955e75283..234d7bfe4d0 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6010,7 +6010,7 @@ def test_rolling_reduce(ds, center, min_periods, window, name): @pytest.mark.parametrize("min_periods", (None, 1)) @pytest.mark.parametrize("name", ("sum", "max")) @pytest.mark.parameteris("dask", (True, False)) -def test_ndrolling_reduce(ds, center, min_periods, name, das): +def test_ndrolling_reduce(ds, center, min_periods, name, dask): if dask and has_dask: ds = ds.chunk({"x": 4}) @@ -6026,7 +6026,7 @@ def test_ndrolling_reduce(ds, center, min_periods, name, das): assert_allclose(actual, expected) assert actual.dims == expected.dims - # Do it in the oposite order + # Do it in the opposite order expected = getattr( getattr( ds.rolling(x=3, center=center, min_periods=min_periods), name From 9703e115838979f0fc93748c86157b1df3eee85a Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Sat, 8 Aug 2020 05:53:44 +0900 Subject: [PATCH 22/23] mypy --- xarray/core/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/core/common.py b/xarray/core/common.py index 0183c898aca..bc5035b682e 100644 --- a/xarray/core/common.py +++ b/xarray/core/common.py @@ -786,7 +786,7 @@ def rolling( self, dim: Mapping[Hashable, int] = None, min_periods: int = None, - center: bool = False, + center: Union[bool, Mapping[Hashable, bool]] = False, keep_attrs: bool = None, **window_kwargs: int, ): From f44dd5db5ee54cb01f1c6cb6a3d662f93932cd1d Mon Sep 17 00:00:00 2001 From: keisukefujii Date: Sat, 8 Aug 2020 07:58:33 +0900 Subject: [PATCH 23/23] typo2 --- xarray/tests/test_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/tests/test_dataset.py b/xarray/tests/test_dataset.py index 234d7bfe4d0..da7621dceb8 100644 --- a/xarray/tests/test_dataset.py +++ b/xarray/tests/test_dataset.py @@ -6009,7 +6009,7 @@ def test_rolling_reduce(ds, center, min_periods, window, name): @pytest.mark.parametrize("center", (True, False)) @pytest.mark.parametrize("min_periods", (None, 1)) @pytest.mark.parametrize("name", ("sum", "max")) -@pytest.mark.parameteris("dask", (True, False)) +@pytest.mark.parametrize("dask", (True, False)) def test_ndrolling_reduce(ds, center, min_periods, name, dask): if dask and has_dask: ds = ds.chunk({"x": 4})