From 5886b727a3f10faa0e605bb83b3e9c144d9597d7 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 24 Jan 2024 09:18:36 -0700 Subject: [PATCH 1/2] groupby: Set method="None" by default on flox>=0.9 --- doc/whats-new.rst | 3 +++ xarray/core/groupby.py | 11 +++++++---- xarray/tests/test_groupby.py | 18 ++++++++++++++++++ 3 files changed, 28 insertions(+), 4 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 317f3b1a824..a91d345e301 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -23,6 +23,9 @@ v2024.02.0 (unreleased) New Features ~~~~~~~~~~~~ +- Xarray now defers to flox's `heuristics `_ + to set default `method` for groupby problems. This only applies to ``flox>=0.9``. + By `Deepak Cherian `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/xarray/core/groupby.py b/xarray/core/groupby.py index ebb488d42c9..3a9c3ee6470 100644 --- a/xarray/core/groupby.py +++ b/xarray/core/groupby.py @@ -16,6 +16,7 @@ import numpy as np import pandas as pd +from packaging.version import Version from xarray.core import dtypes, duck_array_ops, nputils, ops from xarray.core._aggregations import ( @@ -1003,6 +1004,7 @@ def _flox_reduce( **kwargs: Any, ): """Adaptor function that translates our groupby API to that of flox.""" + import flox from flox.xarray import xarray_reduce from xarray.core.dataset import Dataset @@ -1014,10 +1016,11 @@ def _flox_reduce( if keep_attrs is None: keep_attrs = _get_keep_attrs(default=True) - # preserve current strategy (approximately) for dask groupby. - # We want to control the default anyway to prevent surprises - # if flox decides to change its default - kwargs.setdefault("method", "cohorts") + if Version(flox.__version__) < Version("0.9"): + # preserve current strategy (approximately) for dask groupby + # on older flox versions to prevent surprises. + # flox >=0.9 will choose this on its own. + kwargs.setdefault("method", "cohorts") numeric_only = kwargs.pop("numeric_only", None) if numeric_only: diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index e45d8ed0bef..70a798fbd47 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -3,10 +3,12 @@ import datetime import operator import warnings +from unittest import mock import numpy as np import pandas as pd import pytest +from packaging.version import Version import xarray as xr from xarray import DataArray, Dataset, Variable @@ -2499,3 +2501,19 @@ def test_groupby_dim_no_dim_equal(use_flox): actual1 = da.drop_vars("lat").groupby("lat", squeeze=False).sum() actual2 = da.groupby("lat", squeeze=False).sum() assert_identical(actual1, actual2.drop_vars("lat")) + + +@requires_flox +def test_default_flox_method(): + import flox.xarray + + da = xr.DataArray([1, 2, 3], dims="x", coords={"label": ("x", [2, 2, 1])}) + + result = xr.DataArray([3, 3], dims="label", coords={"label": [1, 2]}) + with mock.patch("flox.xarray.xarray_reduce", return_value=result) as mocked_reduce: + da.groupby("label").sum() + + if Version(flox.__version__) < Version("0.9.0"): + mocked_reduce.assert_called_once_with(method="cohorts") + else: + assert "method" not in mocked_reduce.call_args From b10cbe6f3bef1e8b744085f9741f25d40566c0f8 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Wed, 24 Jan 2024 14:44:22 -0700 Subject: [PATCH 2/2] Fix? --- xarray/tests/test_groupby.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index 70a798fbd47..25fabd5e2b9 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -2513,7 +2513,8 @@ def test_default_flox_method(): with mock.patch("flox.xarray.xarray_reduce", return_value=result) as mocked_reduce: da.groupby("label").sum() + kwargs = mocked_reduce.call_args.kwargs if Version(flox.__version__) < Version("0.9.0"): - mocked_reduce.assert_called_once_with(method="cohorts") + assert kwargs["method"] == "cohorts" else: - assert "method" not in mocked_reduce.call_args + assert "method" not in kwargs