Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Plot ppc rootogram #142

Merged
merged 1 commit into from
Feb 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions References.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,12 @@ See https://github.com/arviz-devs/arviz-stats/issues/56

References
----------
.. [1] Betancourt. Diagnosing Suboptimal Cotangent Disintegrations in
Hamiltonian Monte Carlo. (2016) https://arxiv.org/abs/1604.00695
.. [1] Betancourt. *Diagnosing Suboptimal Cotangent Disintegrations in
Hamiltonian Monte Carlo*. (2016) https://arxiv.org/abs/1604.00695

## Rootograms

References
----------
.. [1] Kleiber C, Zeileis A. *Visualizing Count Data Regressions Using Rootograms*.
The American Statistician, 70(3). (2016) https://doi.org/10.1080/00031305.2016.1173590
3 changes: 2 additions & 1 deletion docs/source/api/plots.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ A complementary introduction and guide to ``plot_...`` functions is available at
plot_ess
plot_ess_evolution
plot_forest
plot_ppc_pava
plot_ppc_dist
plot_ppc_pava
plot_ppc_rootogram
plot_psense_dist
plot_psense_quantities
plot_ridge
Expand Down
23 changes: 23 additions & 0 deletions docs/source/gallery/model_criticism/plot_ppc_rootogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
# Rootogram plot

Rootogram for the posterior predictive and observed data.

---

:::{seealso}
API Documentation: {func}`~arviz_plots.plot_ppc_rootogram`
:::
"""
from arviz_base import load_arviz_data

import arviz_plots as azp

azp.style.use("arviz-variat")

dt = load_arviz_data("rugby")
pc = azp.plot_ppc_rootogram(
dt,
backend="none",
)
pc.show()
25 changes: 25 additions & 0 deletions src/arviz_plots/backend/bokeh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,26 @@ def hline(y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset,
return span_element


def ciliney(
x,
y_bottom,
y_top,
target,
*,
color=unset,
alpha=unset,
width=unset,
linestyle=unset,
**artist_kws,
):
"""Interface to bokeh for a line from y_bottom to y_top at given value of x."""
kwargs = {"color": color, "alpha": alpha, "line_width": width, "line_dash": linestyle}
x = np.atleast_1d(x)
y_bottom = np.atleast_1d(y_bottom)
y_top = np.atleast_1d(y_top)
return target.segment(x0=x, x1=x, y0=y_bottom, y1=y_top, **_filter_kwargs(kwargs, artist_kws))


# general plot appeareance
def title(string, target, *, size=unset, color=unset, **artist_kws):
"""Interface to bokeh for adding a title to a plot."""
Expand Down Expand Up @@ -473,3 +493,8 @@ def remove_axis(target, axis="y"):
target.axis.visible = False
else:
raise ValueError(f"axis must be one of 'x', 'y' or 'both', got '{axis}'")


def set_y_scale(target, scale):
"""Interface to matplotlib for setting the y scale of a plot."""
target.set_yscale(scale)
84 changes: 84 additions & 0 deletions src/arviz_plots/backend/matplotlib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pylint: disable=no-self-use
"""Matplotlib interface layer.

Notes
Expand All @@ -9,7 +10,10 @@
import warnings
from typing import Any, Dict

import matplotlib.scale as mscale
import matplotlib.transforms as mtransforms
import numpy as np
from matplotlib import ticker
from matplotlib.cbook import normalize_kwargs
from matplotlib.collections import PathCollection
from matplotlib.lines import Line2D
Expand All @@ -29,6 +33,63 @@ class UnsetDefault:
unset = UnsetDefault()


class SquareRootScale(mscale.ScaleBase):
"""ScaleBase class for generating square root scale."""

name = "sqrt"

def __init__(self, axis, **kwargs): # pylint: disable=unused-argument
mscale.ScaleBase.__init__(self, axis)

def set_default_locators_and_formatters(self, axis):
"""Set the locators and formatters to default."""
axis.set_major_locator(ticker.AutoLocator())
axis.set_major_formatter(ticker.ScalarFormatter())
axis.set_minor_locator(ticker.NullLocator())
axis.set_minor_formatter(ticker.NullFormatter())

def limit_range_for_scale(self, vmin, vmax, minpos): # pylint: disable=unused-argument
"""Limit the range of the scale."""
return max(0.0, vmin), vmax

class SquareRootTransform(mtransforms.Transform):
"""Square root transformation."""

input_dims = 1
output_dims = 1
is_separable = True

def transform_non_affine(self, values):
"""Transform the data."""
return np.array(values) ** 0.5

def inverted(self):
"""Invert the transformation."""
return SquareRootScale.InvertedSquareRootTransform()

class InvertedSquareRootTransform(mtransforms.Transform):
"""Inverted square root transformation."""

input_dims = 1
output_dims = 1
is_separable = True

def transform(self, values):
"""Transform the data."""
return np.array(values) ** 2

def inverted(self):
"""Invert the transformation."""
return SquareRootScale.SquareRootTransform()

def get_transform(self):
"""Get the transformation."""
return self.SquareRootTransform()


mscale.register_scale(SquareRootScale)


# generation of default values for aesthetics
def get_default_aes(aes_key, n, kwargs=None):
"""Generate `n` *matplotlib valid* default values for a given aesthetics keyword."""
Expand Down Expand Up @@ -319,6 +380,24 @@ def hline(y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset,
return target.axhline(y, **_filter_kwargs(kwargs, Line2D, artist_kws))


def ciliney(
x,
y_bottom,
y_top,
target,
*,
color=unset,
alpha=unset,
width=unset,
linestyle=unset,
**artist_kws,
):
"""Interface to matplotlib for a line from y_bottom to y_top at given value of x."""
artist_kws.setdefault("zorder", 2)
kwargs = {"color": color, "alpha": alpha, "linewidth": width, "linestyle": linestyle}
return target.plot([x, x], [y_bottom, y_top], **_filter_kwargs(kwargs, Line2D, artist_kws))[0]


# general plot appeareance
def title(string, target, *, size=unset, color=unset, **artist_kws):
"""Interface to matplotlib for adding a title to a plot."""
Expand Down Expand Up @@ -396,3 +475,8 @@ def remove_axis(target, axis="y"):
target.spines["bottom"].set_visible(False)
else:
raise ValueError(f"axis must be one of 'x', 'y' or 'both', got '{axis}'")


def set_y_scale(target, scale):
"""Interface to matplotlib for setting the y scale of a plot."""
target.set_yscale(scale)
27 changes: 27 additions & 0 deletions src/arviz_plots/backend/none/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,33 @@ def hline(y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset,
return artist_element


def ciliney(
x,
y_bottom,
y_top,
target,
*,
color=unset,
alpha=unset,
width=unset,
linestyle=unset,
**artist_kws,
):
"""Interface to a line from y_bottom to y_top at given value of x."""
kwargs = {"color": color, "alpha": alpha, "width": width, "linestyle": linestyle}
if not ALLOW_KWARGS and artist_kws:
raise ValueError("artist_kws not empty")
artist_element = {
"function": "line",
"x": np.atleast_1d(x),
"y_bottom": np.atleast_1d(y_bottom),
"y_top": np.atleast_1d(y_top),
**_filter_kwargs(kwargs, artist_kws),
}
target.append(artist_element)
return artist_element


# general plot appeareance
def title(string, target, *, size=unset, color=unset, **artist_kws):
"""Interface to adding a title to a plot."""
Expand Down
32 changes: 32 additions & 0 deletions src/arviz_plots/backend/plotly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,38 @@ def hline(y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset,
)


def ciliney(
x,
y_bottom,
y_top,
target,
*,
color=unset,
alpha=unset,
width=unset,
linestyle=unset,
**artist_kws,
):
"""Interface to plotly for a line from y_bottom to y_top at given value of x."""
artist_kws.setdefault("showlegend", False)
line_kwargs = {"color": color, "width": width, "dash": linestyle}
line_artist_kws = artist_kws.pop("line", {}).copy()
kwargs = {"opacity": alpha}

# I was not able to figure it out how to do this withouth a loop
# all solutions I tried did not plot the lines or the lines were connected
for x_i, y_t, y_b in zip(x, y_top, y_bottom):
line_object = go.Scatter(
x=[x_i, x_i],
y=[y_b, y_t],
mode="lines",
line=_filter_kwargs(line_kwargs, line_artist_kws),
**_filter_kwargs(kwargs, artist_kws),
)
target.add_trace(line_object)
return line_object


# general plot appeareance
def title(string, target, *, size=unset, color=unset, **artist_kws):
"""Interface to plotly for adding a title to a plot."""
Expand Down
2 changes: 2 additions & 0 deletions src/arviz_plots/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .forestplot import plot_forest
from .pavacalibrationplot import plot_ppc_pava
from .ppcdistplot import plot_ppc_dist
from .ppcrootogramplot import plot_ppc_rootogram
from .psensedistplot import plot_psense_dist
from .psensequantitiesplot import plot_psense_quantities
from .ridgeplot import plot_ridge
Expand All @@ -26,6 +27,7 @@
"plot_ess",
"plot_ess_evolution",
"plot_ppc_dist",
"plot_ppc_rootogram",
"plot_ridge",
"plot_ppc_pava",
"plot_psense_dist",
Expand Down
Loading