Skip to content

Commit

Permalink
Fix DataArrayRolling.__iter__ with center=True (#6744)
Browse files Browse the repository at this point in the history
* new test_rolling module

* fix rolling iter with center=True

* add fix to whats-new

* fix DatasetRolling test names

* small code simplification
  • Loading branch information
headtr1ck authored Jul 14, 2022
1 parent e5fcd79 commit f28d7f8
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 16 deletions.
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ Bug fixes
- :py:meth:`open_dataset` with dask and ``~`` in the path now resolves the home directory
instead of raising an error. (:issue:`6707`, :pull:`6710`)
By `Michael Niklas <https://github.com/headtr1ck>`_.
- :py:meth:`DataArrayRolling.__iter__` with ``center=True`` now works correctly.
(:issue:`6739`, :pull:`6744`)
By `Michael Niklas <https://github.com/headtr1ck>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
17 changes: 11 additions & 6 deletions xarray/core/rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,16 +267,21 @@ def __init__(
# TODO legacy attribute
self.window_labels = self.obj[self.dim[0]]

def __iter__(self) -> Iterator[tuple[RollingKey, DataArray]]:
def __iter__(self) -> Iterator[tuple[DataArray, DataArray]]:
if self.ndim > 1:
raise ValueError("__iter__ is only supported for 1d-rolling")
stops = np.arange(1, len(self.window_labels) + 1)
starts = stops - int(self.window[0])
starts[: int(self.window[0])] = 0

dim0 = self.dim[0]
window0 = int(self.window[0])
offset = (window0 + 1) // 2 if self.center[0] else 1
stops = np.arange(offset, self.obj.sizes[dim0] + offset)
starts = stops - window0
starts[: window0 - offset] = 0

for (label, start, stop) in zip(self.window_labels, starts, stops):
window = self.obj.isel({self.dim[0]: slice(start, stop)})
window = self.obj.isel({dim0: slice(start, stop)})

counts = window.count(dim=self.dim[0])
counts = window.count(dim=dim0)
window = window.where(counts >= self.min_periods)

yield (label, window)
Expand Down
15 changes: 5 additions & 10 deletions xarray/tests/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@

class TestDataArrayRolling:
@pytest.mark.parametrize("da", (1, 2), indirect=True)
def test_rolling_iter(self, da) -> None:
rolling_obj = da.rolling(time=7)
@pytest.mark.parametrize("center", [True, False])
@pytest.mark.parametrize("size", [1, 2, 3, 7])
def test_rolling_iter(self, da: DataArray, center: bool, size: int) -> None:
rolling_obj = da.rolling(time=size, center=center)
rolling_obj_mean = rolling_obj.mean()

assert len(rolling_obj.window_labels) == len(da["time"])
Expand All @@ -40,14 +42,7 @@ def test_rolling_iter(self, da) -> None:
actual = rolling_obj_mean.isel(time=i)
expected = window_da.mean("time")

# TODO add assert_allclose_with_nan, which compares nan position
# as well as the closeness of the values.
assert_array_equal(actual.isnull(), expected.isnull())
if (~actual.isnull()).sum() > 0:
np.allclose(
actual.values[actual.values.nonzero()],
expected.values[expected.values.nonzero()],
)
np.testing.assert_allclose(actual.values, expected.values)

@pytest.mark.parametrize("da", (1,), indirect=True)
def test_rolling_repr(self, da) -> None:
Expand Down

0 comments on commit f28d7f8

Please sign in to comment.