Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add copy option in DataTree.from_dict #9193

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
13 changes: 9 additions & 4 deletions xarray/core/datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,15 +882,17 @@ def __getitem__(self: DataTree, key: str) -> DataTree | DataArray:
else:
raise ValueError(f"Invalid format for key: {key}")

def _set(self, key: str, val: DataTree | CoercibleValue) -> None:
def _set(
self, key: str, val: DataTree | CoercibleValue, *, copy: bool = True
) -> None:
"""
Set the child node or variable with the specified key to value.

Counterpart to the public .get method, and also only works on the immediate node, not other nodes in the tree.
"""
if isinstance(val, DataTree):
# create and assign a shallow copy here so as not to alter original name of node in grafted tree
new_node = val.copy(deep=False)
new_node = val.copy(deep=False) if copy else val
new_node.name = key
new_node.parent = self
else:
Expand Down Expand Up @@ -1052,6 +1054,8 @@ def from_dict(
cls,
d: MutableMapping[str, Dataset | DataArray | DataTree | None],
name: str | None = None,
*,
copy: bool = True,
) -> DataTree:
"""
Create a datatree from a dictionary of data objects, organised by paths into the tree.
Expand Down Expand Up @@ -1080,7 +1084,7 @@ def from_dict(
# First create the root node
root_data = d.pop("/", None)
if isinstance(root_data, DataTree):
obj = root_data.copy()
obj = root_data.copy() if copy else root_data
obj.orphan()
else:
obj = cls(name=name, data=root_data, parent=None, children=None)
Expand All @@ -1091,7 +1095,7 @@ def from_dict(
# Create and set new node
node_name = NodePath(path).name
if isinstance(data, DataTree):
new_node = data.copy()
new_node = data.copy() if copy else data
new_node.orphan()
else:
new_node = cls(name=node_name, data=data)
Expand All @@ -1100,6 +1104,7 @@ def from_dict(
new_node,
allow_overwrite=False,
new_nodes_along_path=True,
copy=copy,
)

return obj
Expand Down
10 changes: 6 additions & 4 deletions xarray/core/treenode.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,7 +468,7 @@ def _get_item(self: Tree, path: str | NodePath) -> Tree | T_DataArray:
current_node = current_node.get(part)
return current_node

def _set(self: Tree, key: str, val: Tree) -> None:
def _set(self: Tree, key: str, val: Tree, *, copy: bool = True) -> None:
"""
Set the child node with the specified key to value.

Expand All @@ -483,6 +483,8 @@ def _set_item(
item: Tree | T_DataArray,
new_nodes_along_path: bool = False,
allow_overwrite: bool = True,
*,
copy: bool = True,
) -> None:
"""
Set a new item in the tree, overwriting anything already present at that path.
Expand Down Expand Up @@ -539,19 +541,19 @@ def _set_item(
elif new_nodes_along_path:
# Want child classes (i.e. DataTree) to populate tree with their own types
new_node = type(self)()
current_node._set(part, new_node)
current_node._set(part, new_node, copy=copy)
current_node = current_node.children[part]
else:
raise KeyError(f"Could not reach node at path {path}")

if name in current_node.children:
# Deal with anything already existing at this location
if allow_overwrite:
current_node._set(name, item)
current_node._set(name, item, copy=copy)
else:
raise KeyError(f"Already a node object at path {path}")
else:
current_node._set(name, item)
current_node._set(name, item, copy=copy)

def __delitem__(self: Tree, key: str):
"""Remove a child node from this tree object."""
Expand Down
7 changes: 7 additions & 0 deletions xarray/tests/test_datatree.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,13 @@ def test_roundtrip_unnamed_root(self, simple_datatree):
roundtrip = DataTree.from_dict(dt.to_dict())
assert roundtrip.equals(dt)

@pytest.mark.parametrize("copy", [True, False])
def test_copy(self, copy: bool) -> None:
    """Check that ``from_dict`` honours its ``copy`` flag for DataTree inputs.

    With ``copy=True`` the node stored in the new tree must be a different
    object from the one passed in; with ``copy=False`` it must be the very
    same object.
    """
    source_node = DataTree.from_dict({"run1": xr.Dataset({"a": 1})})
    tree = DataTree.from_dict({"run1": source_node}, copy=copy)

    # Identity (not equality) is what distinguishes a copy from a re-used node.
    if copy:
        assert tree["run1"] is not source_node
    else:
        assert tree["run1"] is source_node


class TestDatasetView:
def test_view_contents(self):
Expand Down
Loading