Skip to content

Commit

Permalink
Use DataTree class from xarray (#111)
Browse files Browse the repository at this point in the history
* start work for xarray datatree compatibility

* use datatree from xarray

* fix docs

* remove unused import
  • Loading branch information
OriolAbril authored Dec 5, 2024
1 parent d88c5ea commit d259b47
Show file tree
Hide file tree
Showing 9 changed files with 25 additions and 38 deletions.
3 changes: 1 addition & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,14 @@
numpydoc_xref_aliases = {
"DataArray": ":class:`xarray.DataArray`",
"Dataset": ":class:`xarray.Dataset`",
"DataTree": ":class:`datatree.DataTree`",
"DataTree": ":class:`xarray.DataTree`",
"mapping": ":term:`python:mapping`",
"hashable": ":term:`python:hashable`",
**{f"{singular}s": f":any:`{singular}s <{singular}>`" for singular in singulars},
}

intersphinx_mapping = {
"arviz_org": ("https://www.arviz.org/en/latest/", None),
"datatree": ("https://xarray-datatree.readthedocs.io/en/latest/", None),
"numpy": ("https://numpy.org/doc/stable/", None),
"python": ("https://docs.python.org/3/", None),
"xarray": ("https://docs.xarray.dev/en/stable/", None),
Expand Down
2 changes: 1 addition & 1 deletion docs/source/gallery/mixed/plot_forest_ess.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
color = plot_bknd.get_default_aes("color", 1, {})[0]

centered = load_arviz_data("centered_eight")
c_aux = centered["posterior"].expand_dims(
c_aux = centered["posterior"].dataset.expand_dims(
column=3
).assign_coords(column=["labels", "forest", "ess"])
pc = azp.plot_forest(c_aux, combined=True, backend=backend)
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ classifiers = [
]
dynamic = ["version", "description"]
dependencies = [
"arviz-base @ git+https://github.com/arviz-devs/arviz-base",
"arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats"
"arviz-base @ git+https://github.com/arviz-devs/arviz-base@xarray_datatree",
"arviz-stats[xarray] @ git+https://github.com/arviz-devs/arviz-stats@xarray_datatree",
]

[tool.flit.module]
Expand Down
37 changes: 14 additions & 23 deletions src/arviz_plots/plot_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import xarray as xr
from arviz_base import rcParams
from arviz_base.sel_utils import xarray_sel_iter
from datatree import DataTree


def concat_model_dict(data):
Expand Down Expand Up @@ -243,7 +242,7 @@ def aes(self):
"""
if self.coords is None:
return self._aes_dt
return DataTree.from_dict(
return xr.DataTree.from_dict(
{
group: ds.to_dataset().sel(sel_subset(self.coords, ds.dims))
for group, ds in self._aes_dt.children.items()
Expand Down Expand Up @@ -286,7 +285,7 @@ def viz(self):
}
home_ds = self._viz_dt.to_dataset()
sliced_viz_dict["/"] = home_ds.sel(sel_subset(self.coords, home_ds.dims))
return DataTree.from_dict(sliced_viz_dict)
return xr.DataTree.from_dict(sliced_viz_dict)

@viz.setter
def viz(self, value):
Expand Down Expand Up @@ -376,22 +375,11 @@ def generate_aes_dt(self, aes=None, **kwargs):
.. jupyter-execute::
from datatree import DataTree
from arviz_base import load_arviz_data
from arviz_plots import PlotCollection
from arviz_base.datasets import REMOTE_DATASETS, RemoteFileMetadata
# TODO: remove this monkeypatching once the arviz_example_data repo has been updated
REMOTE_DATASETS.update({
"rugby_field": RemoteFileMetadata(
name="rugby_field",
filename="rugby_field.nc",
url="http://figshare.com/ndownloader/files/44667112",
checksum="53a99da7ac40d82cd01bb0b089263b9633ee016f975700e941b4c6ea289a1fb0",
description="Variant of the rugby model."
)
})
import xarray as xr
idata = load_arviz_data("rugby_field")
pc = PlotCollection(idata.posterior, DataTree(), backend="matplotlib")
pc = PlotCollection(idata.posterior, xr.DataTree(), backend="matplotlib")
pc.generate_aes_dt(
aes={
"color": ["team"],
Expand Down Expand Up @@ -560,7 +548,7 @@ def generate_aes_dt(self, aes=None, **kwargs):
ds[aes_key] = aes_da
else:
ds[aes_key] = neutral_element
self._aes_dt = DataTree.from_dict(ds_dict)
self._aes_dt = xr.DataTree.from_dict(ds_dict)

def get_aes_as_dataset(self, aes_key):
"""Get the values of the provided aes_key for all variables as a Dataset.
Expand Down Expand Up @@ -625,7 +613,7 @@ def rename_artists(self, name_dict=None, **names):
"""
viz_dt = self.viz
for group, child in viz_dt.children.items():
viz_dt[group] = child.rename_vars(name_dict, **names)
viz_dt[group] = child.dataset.rename_vars(name_dict, **names)
self.viz = viz_dt

@classmethod
Expand Down Expand Up @@ -707,7 +695,7 @@ def wrap(
coords={dim: data[dim] for dim in dims},
)
else:
viz_dict["/"] = xr.Dataset({"chart": xr.DataArray(fig)})
viz_dict["/"] = xr.Dataset({"chart": np.array(fig, dtype=object)})
all_dims = cols
facet_cumulative = 0
for var_name, da in data.items():
Expand Down Expand Up @@ -736,7 +724,10 @@ def wrap(
},
coords={dim: da[dim] for dim in dims},
)
viz_dt = DataTree.from_dict(viz_dict)
viz_dt = xr.DataTree(
viz_dict["/"],
children={key: xr.DataTree(value) for key, value in viz_dict.items() if key != "/"},
)
return cls(data, viz_dt, backend=backend, **kwargs)

@classmethod
Expand Down Expand Up @@ -815,7 +806,7 @@ def grid(
coords={dim: data[dim] for dim in dims},
)
else:
viz_dict["/"] = xr.Dataset({"chart": xr.DataArray(fig)})
viz_dict["/"] = xr.Dataset({"chart": np.array(fig, dtype=object)})
all_dims = tuple((*rows, *cols)) # use provided dim orders, not existing ones
facet_cumulative = 0
for var_name, da in data.items():
Expand Down Expand Up @@ -852,7 +843,7 @@ def grid(
},
coords={dim: da[dim] for dim in dims},
)
viz_dt = DataTree.from_dict(viz_dict)
viz_dt = xr.DataTree.from_dict(viz_dict)
return cls(data, viz_dt, backend=backend, **kwargs)

def update_aes(self, ignore_aes=frozenset(), coords=None):
Expand All @@ -870,7 +861,7 @@ def allocate_artist(self, fun_label, data, all_loop_dims, artist_dims=None):
artist_dims = {}
for var_name, da in data.items():
if var_name not in self.viz.children:
DataTree(name=var_name, parent=self.viz)
self.viz[var_name] = xr.DataTree()
inherited_dims = [dim for dim in da.dims if dim in all_loop_dims]
artist_shape = [da.sizes[dim] for dim in inherited_dims] + list(artist_dims.values())
all_artist_dims = inherited_dims + list(artist_dims.keys())
Expand Down
3 changes: 1 addition & 2 deletions src/arviz_plots/plots/compareplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

import numpy as np
from arviz_base import rcParams
from datatree import DataTree
from xarray import Dataset
from xarray import Dataset, DataTree

from arviz_plots.plot_collection import PlotCollection

Expand Down
2 changes: 1 addition & 1 deletion src/arviz_plots/plots/ridgeplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def plot_ridge(
>>> from arviz_plots import visuals
>>> import arviz_stats # make accessor available
>>>
>>> c_aux = centered["posterior"].expand_dims(
>>> c_aux = centered["posterior"].dataset.expand_dims(
>>> column=3
>>> ).assign_coords(column=["labels", "ridge", "ess"])
>>> pc = plot_ridge(c_aux, combined=True)
Expand Down
5 changes: 2 additions & 3 deletions tests/test_hypothesis_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import numpy as np
import pytest
from arviz_base import from_dict
from datatree import DataTree
from hypothesis import given
from scipy.stats import halfnorm, norm

Expand Down Expand Up @@ -54,8 +53,8 @@ def datatree(seed=31):
dt["point_estimate"] = dt.posterior.mean(("chain", "draw"))
# TODO: should become dt.azstats.eti() after fix in arviz-stats
post = dt.posterior.ds
DataTree(name="trunk", parent=dt, data=post.azstats.eti(prob=0.5))
DataTree(name="twig", parent=dt, data=post.azstats.eti(prob=0.9))
dt["trunk"] = post.azstats.eti(prob=0.5)
dt["twig"] = post.azstats.eti(prob=0.9)
return dt


Expand Down
3 changes: 1 addition & 2 deletions tests/test_plot_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import pytest
import xarray.testing as xrt
from arviz_base import dict_to_dataset, load_arviz_data
from datatree import DataTree
from xarray import DataArray, Dataset, concat, full_like
from xarray import DataArray, Dataset, DataTree, concat, full_like

from arviz_plots import PlotCollection
from arviz_plots.plot_collection import _get_aes_dict_from_dt
Expand Down
4 changes: 2 additions & 2 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def test_plot_forest_models(self, datatree, datatree2, backend):
def test_plot_forest_extendable(self, datatree, backend):
dt_aux = (
datatree["posterior"]
.expand_dims(column=3)
.dataset.expand_dims(column=3)
.assign_coords(column=["labels", "forest", "ess"])
)
pc = plot_forest(dt_aux, combined=True, backend=backend)
Expand Down Expand Up @@ -274,7 +274,7 @@ def test_plot_ridge_models(self, datatree, datatree2, backend):
def test_plot_ridge_extendable(self, datatree, backend):
dt_aux = (
datatree["posterior"]
.expand_dims(column=3)
.dataset.expand_dims(column=3)
.assign_coords(column=["labels", "ridge", "ess"])
)
pc = plot_ridge(dt_aux, combined=True, backend=backend)
Expand Down

0 comments on commit d259b47

Please sign in to comment.