Skip to content

Commit

Permalink
Add rootogram
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia committed Feb 21, 2025
1 parent 602d554 commit 08268eb
Show file tree
Hide file tree
Showing 10 changed files with 546 additions and 3 deletions.
11 changes: 9 additions & 2 deletions References.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,5 +48,12 @@ See https://github.com/arviz-devs/arviz-stats/issues/56

References
----------
.. [1] Betancourt. Diagnosing Suboptimal Cotangent Disintegrations in
Hamiltonian Monte Carlo. (2016) https://arxiv.org/abs/1604.00695
.. [1] Betancourt. *Diagnosing Suboptimal Cotangent Disintegrations in
Hamiltonian Monte Carlo*. (2016) https://arxiv.org/abs/1604.00695

## Rootograms

References
----------
.. [1] Kleiber C, Zeileis A. *Visualizing Count Data Regressions Using Rootograms*.
The American Statistician, 70(3). (2016) https://doi.org/10.1080/00031305.2016.1173590
3 changes: 2 additions & 1 deletion docs/source/api/plots.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ A complementary introduction and guide to ``plot_...`` functions is available at
plot_ess
plot_ess_evolution
plot_forest
plot_ppc_pava
plot_ppc_dist
plot_ppc_pava
plot_ppc_rootogram
plot_psense_dist
plot_psense_quantities
plot_ridge
Expand Down
23 changes: 23 additions & 0 deletions docs/source/gallery/model_criticism/plot_ppc_rootogram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""
# Rootogram plot
Rootogram for the posterior predictive and observed data.
---
:::{seealso}
API Documentation: {func}`~arviz_plots.plot_ppc_rootogram`
:::
"""
from arviz_base import load_arviz_data

import arviz_plots as azp

azp.style.use("arviz-variat")

dt = load_arviz_data("rugby")
pc = azp.plot_ppc_rootogram(
dt,
backend="none",
)
pc.show()
25 changes: 25 additions & 0 deletions src/arviz_plots/backend/bokeh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,26 @@ def hline(y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset,
return span_element


def ciliney(
x,
y_bottom,
y_top,
target,
*,
color=unset,
alpha=unset,
width=unset,
linestyle=unset,
**artist_kws,
):
"""Interface to bokeh for a line from y_bottom to y_top at given value of x."""
kwargs = {"color": color, "alpha": alpha, "line_width": width, "line_dash": linestyle}
x = np.atleast_1d(x)
y_bottom = np.atleast_1d(y_bottom)
y_top = np.atleast_1d(y_top)
return target.segment(x0=x, x1=x, y0=y_bottom, y1=y_top, **_filter_kwargs(kwargs, artist_kws))


# general plot appeareance
def title(string, target, *, size=unset, color=unset, **artist_kws):
"""Interface to bokeh for adding a title to a plot."""
Expand Down Expand Up @@ -473,3 +493,8 @@ def remove_axis(target, axis="y"):
target.axis.visible = False
else:
raise ValueError(f"axis must be one of 'x', 'y' or 'both', got '{axis}'")


def set_y_scale(target, scale):
"""Interface to matplotlib for setting the y scale of a plot."""
target.set_yscale(scale)
84 changes: 84 additions & 0 deletions src/arviz_plots/backend/matplotlib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pylint: disable=no-self-use
"""Matplotlib interface layer.
Notes
Expand All @@ -9,7 +10,10 @@
import warnings
from typing import Any, Dict

import matplotlib.scale as mscale
import matplotlib.transforms as mtransforms
import numpy as np
from matplotlib import ticker
from matplotlib.cbook import normalize_kwargs
from matplotlib.collections import PathCollection
from matplotlib.lines import Line2D
Expand All @@ -29,6 +33,63 @@ class UnsetDefault:
unset = UnsetDefault()


class SquareRootScale(mscale.ScaleBase):
"""ScaleBase class for generating square root scale."""

name = "sqrt"

def __init__(self, axis, **kwargs): # pylint: disable=unused-argument
mscale.ScaleBase.__init__(self, axis)

def set_default_locators_and_formatters(self, axis):
"""Set the locators and formatters to default."""
axis.set_major_locator(ticker.AutoLocator())
axis.set_major_formatter(ticker.ScalarFormatter())
axis.set_minor_locator(ticker.NullLocator())
axis.set_minor_formatter(ticker.NullFormatter())

def limit_range_for_scale(self, vmin, vmax, minpos): # pylint: disable=unused-argument
"""Limit the range of the scale."""
return max(0.0, vmin), vmax

class SquareRootTransform(mtransforms.Transform):
"""Square root transformation."""

input_dims = 1
output_dims = 1
is_separable = True

def transform_non_affine(self, values):
"""Transform the data."""
return np.array(values) ** 0.5

def inverted(self):
"""Invert the transformation."""
return SquareRootScale.InvertedSquareRootTransform()

class InvertedSquareRootTransform(mtransforms.Transform):
"""Inverted square root transformation."""

input_dims = 1
output_dims = 1
is_separable = True

def transform(self, values):
"""Transform the data."""
return np.array(values) ** 2

def inverted(self):
"""Invert the transformation."""
return SquareRootScale.SquareRootTransform()

def get_transform(self):
"""Get the transformation."""
return self.SquareRootTransform()


mscale.register_scale(SquareRootScale)


# generation of default values for aesthetics
def get_default_aes(aes_key, n, kwargs=None):
"""Generate `n` *matplotlib valid* default values for a given aesthetics keyword."""
Expand Down Expand Up @@ -319,6 +380,24 @@ def hline(y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset,
return target.axhline(y, **_filter_kwargs(kwargs, Line2D, artist_kws))


def ciliney(
x,
y_bottom,
y_top,
target,
*,
color=unset,
alpha=unset,
width=unset,
linestyle=unset,
**artist_kws,
):
"""Interface to matplotlib for a line from y_bottom to y_top at given value of x."""
artist_kws.setdefault("zorder", 2)
kwargs = {"color": color, "alpha": alpha, "linewidth": width, "linestyle": linestyle}
return target.plot([x, x], [y_bottom, y_top], **_filter_kwargs(kwargs, Line2D, artist_kws))[0]


# general plot appeareance
def title(string, target, *, size=unset, color=unset, **artist_kws):
"""Interface to matplotlib for adding a title to a plot."""
Expand Down Expand Up @@ -396,3 +475,8 @@ def remove_axis(target, axis="y"):
target.spines["bottom"].set_visible(False)
else:
raise ValueError(f"axis must be one of 'x', 'y' or 'both', got '{axis}'")


def set_y_scale(target, scale):
"""Interface to matplotlib for setting the y scale of a plot."""
target.set_yscale(scale)
27 changes: 27 additions & 0 deletions src/arviz_plots/backend/none/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,33 @@ def hline(y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset,
return artist_element


def ciliney(
x,
y_bottom,
y_top,
target,
*,
color=unset,
alpha=unset,
width=unset,
linestyle=unset,
**artist_kws,
):
"""Interface to a line from y_bottom to y_top at given value of x."""
kwargs = {"color": color, "alpha": alpha, "width": width, "linestyle": linestyle}
if not ALLOW_KWARGS and artist_kws:
raise ValueError("artist_kws not empty")
artist_element = {
"function": "line",
"x": np.atleast_1d(x),
"y_bottom": np.atleast_1d(y_bottom),
"y_top": np.atleast_1d(y_top),
**_filter_kwargs(kwargs, artist_kws),
}
target.append(artist_element)
return artist_element


# general plot appeareance
def title(string, target, *, size=unset, color=unset, **artist_kws):
"""Interface to adding a title to a plot."""
Expand Down
32 changes: 32 additions & 0 deletions src/arviz_plots/backend/plotly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,38 @@ def hline(y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset,
)


def ciliney(
x,
y_bottom,
y_top,
target,
*,
color=unset,
alpha=unset,
width=unset,
linestyle=unset,
**artist_kws,
):
"""Interface to plotly for a line from y_bottom to y_top at given value of x."""
artist_kws.setdefault("showlegend", False)
line_kwargs = {"color": color, "width": width, "dash": linestyle}
line_artist_kws = artist_kws.pop("line", {}).copy()
kwargs = {"opacity": alpha}

# I was not able to figure it out how to do this withouth a loop
# all solutions I tried did not plot the lines or the lines were connected
for x_i, y_t, y_b in zip(x, y_top, y_bottom):
line_object = go.Scatter(
x=[x_i, x_i],
y=[y_b, y_t],
mode="lines",
line=_filter_kwargs(line_kwargs, line_artist_kws),
**_filter_kwargs(kwargs, artist_kws),
)
target.add_trace(line_object)
return line_object


# general plot appeareance
def title(string, target, *, size=unset, color=unset, **artist_kws):
"""Interface to plotly for adding a title to a plot."""
Expand Down
2 changes: 2 additions & 0 deletions src/arviz_plots/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .forestplot import plot_forest
from .pavacalibrationplot import plot_ppc_pava
from .ppcdistplot import plot_ppc_dist
from .ppcrootogramplot import plot_ppc_rootogram
from .psensedistplot import plot_psense_dist
from .psensequantitiesplot import plot_psense_quantities
from .ridgeplot import plot_ridge
Expand All @@ -26,6 +27,7 @@
"plot_ess",
"plot_ess_evolution",
"plot_ppc_dist",
"plot_ppc_rootogram",
"plot_ridge",
"plot_ppc_pava",
"plot_psense_dist",
Expand Down
Loading

0 comments on commit 08268eb

Please sign in to comment.