From e035853161002635f8a8dfb191f8b99627eb334c Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Fri, 13 Sep 2024 16:03:09 -0700 Subject: [PATCH] De-duplicate inherited coordinates & dimensions --- xarray/core/datatree.py | 105 ++++++++++++++++++++++++++++------ xarray/tests/test_datatree.py | 40 ++++++++++++- 2 files changed, 124 insertions(+), 21 deletions(-) diff --git a/xarray/core/datatree.py b/xarray/core/datatree.py index 2b3179cb79d..96af7c43b26 100644 --- a/xarray/core/datatree.py +++ b/xarray/core/datatree.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib import itertools import textwrap from collections import ChainMap @@ -373,6 +374,10 @@ def map( # type: ignore[override] return Dataset(variables) +CONFLICTING_COORDS_MODES = {"error", "ignore"} +ConflictingCoordsMode = Literal["ignore", "error"] + + class DataTree( NamedNode, MappedDatasetMethodsMixin, @@ -414,6 +419,7 @@ class DataTree( _attrs: dict[Hashable, Any] | None _encoding: dict[Hashable, Any] | None _close: Callable[[], None] | None + _conflicting_coords_mode: ConflictingCoordsMode __slots__ = ( "_name", @@ -427,6 +433,7 @@ class DataTree( "_attrs", "_encoding", "_close", + "_conflicting_coords_mode", ) def __init__( @@ -459,9 +466,10 @@ def __init__( -------- DataTree.from_dict """ + self._conflicting_coords_mode = "ignore" self._set_node_data(_to_new_dataset(dataset)) - - # comes after setting node data as this will check for clashes between child names and existing variable names + # comes after setting node data as this will check for clashes between + # child names and existing variable names super().__init__(name=name, children=children) def _set_node_data(self, dataset: Dataset): @@ -486,6 +494,55 @@ def _pre_attach(self: DataTree, parent: DataTree, name: str) -> None: parent_ds = parent._to_dataset_view(rebuild_dims=False, inherited=True) check_alignment(path, node_ds, parent_ds, self.children) + @contextlib.contextmanager + def 
_with_conflicting_coords_mode(self, mode: ConflictingCoordsMode): + """Set handling of duplicated inherited coordinates on this object. + + This private context manager is an indirect way to control the handling + of duplicated inherited coordinates, which is done inside _post_attach() + and hence is difficult to directly parameterize. + """ + if mode not in CONFLICTING_COORDS_MODES: + raise ValueError( + f"conflicting coordinates mode must be in {CONFLICTING_COORDS_MODES}, got {mode!r}" + ) + assert self.parent is None + original_mode = self._conflicting_coords_mode + self._conflicting_coords_mode = mode + try: + yield + finally: + self._conflicting_coords_mode = original_mode + + def _dedup_inherited_coordinates(self): + removed_something = False + for name in self._coord_variables.parents: + if name in self._node_coord_variables: + indexed = name in self._node_indexes + if indexed: + del self._node_indexes[name] + elif self.root._conflicting_coords_mode == "error": + # error only for non-indexed coordinates, which are not + # checked for equality + raise ValueError( + f"coordinate {name!r} on node {self.path!r} is also found on a parent node" + ) + + del self._node_coord_variables[name] + removed_something = True + + if removed_something: + self._node_dims = calculate_dimensions( + self._data_variables | self._node_coord_variables + ) + + for child in self._children.values(): + child._dedup_inherited_coordinates() + + def _post_attach(self: DataTree, parent: DataTree, name: str) -> None: + super()._post_attach(parent, name) + self._dedup_inherited_coordinates() + @property def _coord_variables(self) -> ChainMap[Hashable, Variable]: return ChainMap( @@ -1055,6 +1112,7 @@ def from_dict( d: Mapping[str, Dataset | DataTree | None], /, name: str | None = None, + conflicting_coords: ConflictingCoordsMode = "ignore", ) -> DataTree: """ Create a datatree from a dictionary of data objects, organised by paths into the tree. 
@@ -1070,6 +1128,11 @@ def from_dict( To assign data to the root node of the tree use "/" as the path. name : Hashable | None, optional Name for the root node of the tree. Default is None. + conflicting_coords : "ignore" or "error" + How to handle repeated coordinates without an associated index + between parent and child nodes. By default, such coordinates are + dropped from child nodes. Alternatively, an error can be raised when + they are encountered. Returns ------- @@ -1097,23 +1160,27 @@ def depth(item) -> int: return len(NodePath(pathstr).parts) if d_cast: - # Populate tree with children determined from data_objects mapping - # Sort keys by depth so as to insert nodes from root first (see GH issue #9276) - for path, data in sorted(d_cast.items(), key=depth): - # Create and set new node - node_name = NodePath(path).name - if isinstance(data, DataTree): - new_node = data.copy() - elif isinstance(data, Dataset) or data is None: - new_node = cls(name=node_name, dataset=data) - else: - raise TypeError(f"invalid values: {data}") - obj._set_item( - path, - new_node, - allow_overwrite=False, - new_nodes_along_path=True, - ) + # Populate tree with children determined from data_objects mapping. + # Ensure the result object has the right handling of conflicting + # coordinates. 
+ with obj._with_conflicting_coords_mode(conflicting_coords): + # Sort keys by depth so as to insert nodes from root first (see + # GH issue #9276) + for path, data in sorted(d_cast.items(), key=depth): + # Create and set new node + node_name = NodePath(path).name + if isinstance(data, DataTree): + new_node = data.copy() + elif isinstance(data, Dataset) or data is None: + new_node = cls(name=node_name, dataset=data) + else: + raise TypeError(f"invalid values: {data}") + obj._set_item( + path, + new_node, + allow_overwrite=False, + new_nodes_along_path=True, + ) return obj diff --git a/xarray/tests/test_datatree.py b/xarray/tests/test_datatree.py index 0435b199b9b..26e4c809d20 100644 --- a/xarray/tests/test_datatree.py +++ b/xarray/tests/test_datatree.py @@ -852,6 +852,42 @@ def test_array_values(self): with pytest.raises(TypeError): DataTree.from_dict(data) # type: ignore[arg-type] + def test_conflicting_coords(self): + tree = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": 1}), + "/sub": xr.Dataset(coords={"x": 2}), + }, + conflicting_coords="ignore", + ) + expected = DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": 1}), + "/sub": xr.Dataset(), + } + ) + assert_identical(tree, expected) + + with pytest.raises( + ValueError, + match="coordinate 'x' on node '/sub' is also found on a parent node", + ): + DataTree.from_dict( + { + "/": xr.Dataset(coords={"x": 1}), + "/sub": xr.Dataset(coords={"x": 2}), + }, + conflicting_coords="error", + ) + + with pytest.raises( + ValueError, + match=( + r"conflicting coordinates mode must be in \{.*\}, got 'invalid'" + ), + ): + DataTree.from_dict({"/sub": xr.Dataset()}, conflicting_coords="invalid") # type: ignore + class TestDatasetView: def test_view_contents(self): @@ -1166,11 +1202,11 @@ def test_inherited_coords_override(self): ) assert dt.coords.keys() == {"x", "y"} root_coords = {"x": 1, "y": 2} - sub_coords = {"x": 4, "y": 2, "z": 3} + sub_coords = {"x": 1, "y": 2, "z": 3} 
xr.testing.assert_equal(dt["/x"], xr.DataArray(1, coords=root_coords)) xr.testing.assert_equal(dt["/y"], xr.DataArray(2, coords=root_coords)) assert dt["/b"].coords.keys() == {"x", "y", "z"} - xr.testing.assert_equal(dt["/b/x"], xr.DataArray(4, coords=sub_coords)) + xr.testing.assert_equal(dt["/b/x"], xr.DataArray(1, coords=sub_coords)) xr.testing.assert_equal(dt["/b/y"], xr.DataArray(2, coords=sub_coords)) xr.testing.assert_equal(dt["/b/z"], xr.DataArray(3, coords=sub_coords))