diff --git a/doc/whats-new.rst b/doc/whats-new.rst index c76f3800f2b..462872bb80e 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -106,6 +106,8 @@ Bug fixes By `Kai Mühlbauer `_. - Fix weighted ``polyfit`` for arrays with more than two dimensions (:issue:`9972`, :pull:`9974`). By `Mattia Almansi `_. +- Preserve order of variables in :py:func:`xarray.combine_by_coords` (:issue:`8828`, :pull:`9070`). + By `Kai Mühlbauer `_. - Cast ``numpy`` scalars to arrays in :py:meth:`NamedArray.from_arrays` (:issue:`10005`, :pull:`10008`) By `Justus Magin `_. diff --git a/xarray/core/combine.py b/xarray/core/combine.py index f2852443d60..f02d046fff6 100644 --- a/xarray/core/combine.py +++ b/xarray/core/combine.py @@ -1,8 +1,7 @@ from __future__ import annotations -import itertools -from collections import Counter -from collections.abc import Iterable, Iterator, Sequence +from collections import Counter, defaultdict +from collections.abc import Callable, Hashable, Iterable, Iterator, Sequence from typing import TYPE_CHECKING, Literal, TypeVar, Union, cast import pandas as pd @@ -269,10 +268,7 @@ def _combine_all_along_first_dim( combine_attrs: CombineAttrsOptions = "drop", ): # Group into lines of datasets which must be combined along dim - # need to sort by _new_tile_id first for groupby to work - # TODO: is the sorted need? - combined_ids = dict(sorted(combined_ids.items(), key=_new_tile_id)) - grouped = itertools.groupby(combined_ids.items(), key=_new_tile_id) + grouped = groupby_defaultdict(list(combined_ids.items()), key=_new_tile_id) # Combine all of these datasets along dim new_combined_ids = {} @@ -606,6 +602,21 @@ def vars_as_keys(ds): return tuple(sorted(ds)) +K = TypeVar("K", bound=Hashable) + + +def groupby_defaultdict( + iter: list[T], + key: Callable[[T], K], +) -> Iterator[tuple[K, Iterator[T]]]: + """replacement for itertools.groupby""" + idx = defaultdict(list) + for i, obj in enumerate(iter): + idx[key(obj)].append(i) + for k, ix in idx.items(): + yield k, (iter[i] for i in ix) + + def _combine_single_variable_hypercube( datasets, fill_value=dtypes.NA, @@ -965,8 +976,7 @@ def combine_by_coords( ] # Group by data vars - sorted_datasets = sorted(data_objects, key=vars_as_keys) - grouped_by_vars = itertools.groupby(sorted_datasets, key=vars_as_keys) + grouped_by_vars = groupby_defaultdict(data_objects, key=vars_as_keys) # Perform the multidimensional combine on each group of data variables # before merging back together diff --git a/xarray/tests/test_combine.py b/xarray/tests/test_combine.py index dfd047e692c..cc20ab414ee 100644 --- a/xarray/tests/test_combine.py +++ b/xarray/tests/test_combine.py @@ -1043,6 +1043,20 @@ def test_combine_by_coords_incomplete_hypercube(self): with pytest.raises(ValueError): combine_by_coords([x1, x2, x3], fill_value=None) + def test_combine_by_coords_override_order(self) -> None: + # regression test for https://github.com/pydata/xarray/issues/8828 + x1 = Dataset({"a": (("y", "x"), [[1]])}, coords={"y": [0], "x": [0]}) + x2 = Dataset( + {"a": (("y", "x"), [[2]]), "b": (("y", "x"), [[1]])}, + coords={"y": [0], "x": [0]}, + ) + actual = combine_by_coords([x1, x2], compat="override") + assert_equal(actual["a"], actual["b"]) + assert_equal(actual["a"], x1["a"]) + + actual = combine_by_coords([x2, x1], compat="override") + assert_equal(actual["a"], x2["a"]) + class TestCombineMixedObjectsbyCoords: def test_combine_by_coords_mixed_unnamed_dataarrays(self):