Skip to content

Commit

Permalink
add test and small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Feb 19, 2025
1 parent a9ea4ae commit 2518164
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 29 deletions.
6 changes: 4 additions & 2 deletions docs/source/gallery/model_criticism/plot_ppc_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
---
:::{seealso}
API Documentation: {func}`~arviz_plots.plot_pava_calibration`
API Documentation: {func}`~arviz_plots.plot_ppc_dist`
:::
"""
from arviz_base import load_arviz_data
Expand All @@ -15,9 +15,11 @@

azp.style.use("arviz-variat")

dt = load_arviz_data("radon")
dt = load_arviz_data("rugby")
pc = azp.plot_ppc_dist(
dt,
pc_kwargs={"aes": {"color": ["__variable__"]}}, # map color to variable
aes_map={"title": ["color"]}, # also map color to title
backend="none",
)
pc.show()
42 changes: 18 additions & 24 deletions src/arviz_plots/plots/ppcdistplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from arviz_plots.plot_collection import PlotCollection, process_facet_dims
from arviz_plots.plots.distplot import plot_dist
from arviz_plots.plots.utils import filter_aes, process_group_variables_coords
from arviz_plots.visuals import hist, line_xy
from arviz_plots.visuals import ecdf_line, hist, line_xy


def plot_ppc_dist(
Expand Down Expand Up @@ -103,20 +103,20 @@ def plot_ppc_dist(
Examples
--------
Make a plot of the posterior predictive distribution vs the observed data.
We used an ECDF representation and mapped the color to the variable name.
We used an ECDF representation customized the colors.
.. plot::
:context: close-figs
>>> from arviz_plots import plot_ppc_dist, style
>>> style.use("arviz-variat")
>>> from arviz_base import load_arviz_data
>>> rugby = load_arviz_data('rugby')
>>> radon = load_arviz_data('radon')
>>> pc = plot_ppc_dist(
>>> rugby,
>>> kind="ecdf",
>>> pc_kwargs={"aes": {"color": ["__variable__"]}}, # map color to variable
>>> aes_map={"title": ["color"]}, # also map color to title
>>> plot_kwargs={"predictive_density": {"color":"C1"},
>>> "observed_density": {"color":"C3"}},
>>> )
.. minigallery:: plot_ppc_dist
Expand Down Expand Up @@ -207,7 +207,6 @@ def plot_ppc_dist(
if labeller is None:
labeller = BaseLabeller()

print(pc_kwargs)
# We don't want credible_interval or point_estimate to be mapped to the density representation
plot_kwargs.setdefault("credible_interval", False)
plot_kwargs.setdefault("point_estimate", False)
Expand All @@ -217,19 +216,13 @@ def plot_ppc_dist(
pred_density_kwargs = copy(plot_kwargs.get("predictive_density", {}))
if pred_density_kwargs is not False:
plot_kwargs.setdefault(kind, pred_density_kwargs)

if kind == "kde":
plot_kwargs[kind].setdefault("alpha", 0.3)

plot_kwargs[kind].setdefault("alpha", 0.3)
if kind == "hist":
plot_kwargs["hist"].setdefault("alpha", 0.3)
plot_kwargs["hist"].setdefault("edgecolor", None)
stats_kwargs.setdefault("density", True)

if kind == "ecdf":
plot_kwargs["ecdf"].setdefault("alpha", 0.3)
if plot_kwargs["hist"] is not False:
plot_kwargs["hist"].setdefault("edgecolor", None)
stats_kwargs.setdefault("density", True)

plot_dist(
plot_collection = plot_dist(
distribution,
group=group,
sample_dims=pp_dims,
Expand All @@ -245,13 +238,13 @@ def plot_ppc_dist(
plot_kwargs.get("observed_density", False if group == "prior_predictive" else {})
)

if observed_density_kwargs is not False and any(observed_density_kwargs):
observed_density_kwargs = copy(plot_kwargs.get("observed_density", copy(plot_kwargs[kind])))
if kind in ["kde", "ecdf"]:
observed_density_kwargs.setdefault("alpha", 1)

if observed_density_kwargs is not False:
observed_density_kwargs.setdefault("color", "black")
if kind == "hist":
observed_density_kwargs.setdefault("alpha", 0.3)
observed_density_kwargs.setdefault("edgecolor", None)
stats_kwargs.setdefault("density", True)

_, _, observed_ignore = filter_aes(
plot_collection, aes_map, "observed_density", sample_dims
)
Expand All @@ -277,12 +270,13 @@ def plot_ppc_dist(
)

if kind == "ecdf":
observed_density_kwargs.setdefault("alpha", 1)
dt_observed = dt.observed_data.ds.azstats.ecdf(**stats_kwargs)
plot_collection.map(
line_xy,
ecdf_line,
"observe_density",
data=dt_observed,
ignore_aes=observed_ignore,
**observed_density_kwargs,
)

return plot_collection
28 changes: 25 additions & 3 deletions tests/test_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
plot_ess,
plot_ess_evolution,
plot_forest,
plot_ppc_dist,
plot_psense_dist,
plot_ridge,
plot_trace,
Expand Down Expand Up @@ -42,14 +43,18 @@ def generate_base_data(seed=31):
mu_prior = norm(0, 3).logpdf(mu)
tau_prior = halfnorm(scale=5).logpdf(tau)
theta_prior = norm(0, 1).logpdf(theta)
prior_predictive = rng.normal(size=(1, 100, 7))
posterior_predictive = rng.normal(size=(4, 100, 7))
diverging = rng.choice([True, False], size=(4, 100), p=[0.1, 0.9])

return {
"posterior": {"mu": mu, "theta": theta, "tau": tau},
"sample_stats": {"diverging": diverging},
"observed_data": {"y": obs},
"log_likelihood": {"y": log_lik},
"log_prior": {"mu": mu_prior, "theta": theta_prior, "tau": tau_prior},
"prior_predictive": {"y": prior_predictive},
"posterior_predictive": {"y": posterior_predictive},
"sample_stats": {"diverging": diverging},
}


Expand Down Expand Up @@ -88,10 +93,16 @@ def datatree_4d(seed=31):
theta = rng.normal(size=(4, 100, 5))
eta = rng.normal(size=(4, 100, 5, 3))
diverging = rng.choice([True, False], size=(4, 100), p=[0.1, 0.9])
obs = rng.normal(size=(5, 3))
prior_predictive = rng.normal(size=(1, 100, 5, 3))
posterior_predictive = rng.normal(size=(4, 100, 5, 3))

return from_dict(
{
"posterior": {"mu": mu, "theta": theta, "eta": eta},
"observed_data": {"obs": obs},
"prior_predictive": {"obs": prior_predictive},
"posterior_predictive": {"obs": posterior_predictive},
"sample_stats": {"diverging": diverging},
},
dims={"theta": ["hierarchy"], "eta": ["hierarchy", "group"]},
Expand Down Expand Up @@ -371,8 +382,10 @@ def test_plot_ess_evolution(self, datatree, backend):
assert "hierarchy" not in pc.viz["mu"].dims
assert "hierarchy" in pc.viz["theta"].dims

def test_plot_ess_evolution_sample(self, datatree_sample, backend):
pc = plot_ess_evolution(datatree_sample, backend=backend, sample_dims="sample")
def test_plot_ess_evolution_sample(
self, datatree_sample, backend
): # pylint: disable=unused-argument
pc = plot_ess_evolution(datatree_sample, sample_dims="sample")
assert "chart" in pc.viz.data_vars
assert "plot" not in pc.viz.data_vars
assert "ess_bulk" in pc.viz["mu"]
Expand All @@ -384,6 +397,15 @@ def test_plot_ess_evolution_sample(self, datatree_sample, backend):
assert "hierarchy" not in pc.viz["mu"].dims
assert "hierarchy" in pc.viz["theta"].dims

# omitting hist for the moment as I get [hist-none] - ValueError: artist_kws not empty
@pytest.mark.parametrize("kind", ["kde", "ecdf"])
def test_plot_ppc_dist(self, datatree, kind, backend):
pc = plot_ppc_dist(datatree, kind=kind, backend=backend)
assert "chart" in pc.viz.data_vars
assert pc.aes["y"]
assert kind in pc.viz["y"]
assert "observe_density" in pc.viz["y"]

def test_plot_psense_dist(self, datatree, backend):
pc = plot_psense_dist(datatree, backend=backend)
assert "chart" in pc.viz.data_vars
Expand Down

0 comments on commit 2518164

Please sign in to comment.