diff --git a/References.md b/References.md index 7f356b7..c8e6aa3 100644 --- a/References.md +++ b/References.md @@ -48,5 +48,12 @@ See https://github.com/arviz-devs/arviz-stats/issues/56 References ---------- - .. [1] Betancourt. Diagnosing Suboptimal Cotangent Disintegrations in - Hamiltonian Monte Carlo. (2016) https://arxiv.org/abs/1604.00695 \ No newline at end of file + .. [1] Betancourt. *Diagnosing Suboptimal Cotangent Disintegrations in + Hamiltonian Monte Carlo*. (2016) https://arxiv.org/abs/1604.00695 + +## Rootograms + + References + ---------- + .. [1] Kleiber C, Zeileis A. *Visualizing Count Data Regressions Using Rootograms*. + The American Statistician, 70(3). (2016) https://doi.org/10.1080/00031305.2016.1173590 \ No newline at end of file diff --git a/docs/source/api/plots.rst b/docs/source/api/plots.rst index f8e3b25..a17609e 100644 --- a/docs/source/api/plots.rst +++ b/docs/source/api/plots.rst @@ -23,8 +23,9 @@ A complementary introduction and guide to ``plot_...`` functions is available at plot_ess plot_ess_evolution plot_forest - plot_ppc_pava plot_ppc_dist + plot_ppc_pava + plot_ppc_rootogram plot_psense_dist plot_psense_quantities plot_ridge diff --git a/docs/source/gallery/model_criticism/plot_ppc_rootogram.py b/docs/source/gallery/model_criticism/plot_ppc_rootogram.py new file mode 100644 index 0000000..b9e7a30 --- /dev/null +++ b/docs/source/gallery/model_criticism/plot_ppc_rootogram.py @@ -0,0 +1,23 @@ +""" +# Rootogram plot + +Rootogram for the posterior predictive and observed data. 
+
+---
+
+:::{seealso}
+API Documentation: {func}`~arviz_plots.plot_ppc_rootogram`
+:::
+"""
+from arviz_base import load_arviz_data
+
+import arviz_plots as azp
+
+azp.style.use("arviz-variat")
+
+dt = load_arviz_data("rugby")
+pc = azp.plot_ppc_rootogram(
+    dt,
+    backend="none",
+)
+pc.show()
diff --git a/src/arviz_plots/backend/bokeh/__init__.py b/src/arviz_plots/backend/bokeh/__init__.py
index d8190b9..18ae0cf 100644
--- a/src/arviz_plots/backend/bokeh/__init__.py
+++ b/src/arviz_plots/backend/bokeh/__init__.py
@@ -390,6 +390,44 @@ def hline(y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset,
     return span_element
 
 
+def ciliney(
+    x,
+    y_bottom,
+    y_top,
+    target,
+    *,
+    color=unset,
+    alpha=unset,
+    width=unset,
+    linestyle=unset,
+    **artist_kws,
+):
+    """Interface to bokeh for a line from y_bottom to y_top at given value of x.
+
+    Parameters
+    ----------
+    x, y_bottom, y_top : scalar or array-like
+        Position and end points of each segment. They are converted with
+        ``np.atleast_1d`` so scalars are accepted too.
+    target : object with a ``segment`` method
+        Plot on which the segments are drawn (presumably a bokeh figure).
+    color, alpha, width, linestyle : optional
+        Forwarded as ``color``, ``alpha``, ``line_width`` and ``line_dash``.
+    **artist_kws
+        Combined with the arguments above through ``_filter_kwargs``.
+
+    Returns
+    -------
+    object
+        Whatever ``target.segment`` returns.
+    """
+    kwargs = {"color": color, "alpha": alpha, "line_width": width, "line_dash": linestyle}
+    x = np.atleast_1d(x)
+    y_bottom = np.atleast_1d(y_bottom)
+    y_top = np.atleast_1d(y_top)
+    return target.segment(x0=x, x1=x, y0=y_bottom, y1=y_top, **_filter_kwargs(kwargs, artist_kws))
+
+
 # general plot appeareance
 def title(string, target, *, size=unset, color=unset, **artist_kws):
     """Interface to bokeh for adding a title to a plot."""
@@ -473,3 +511,21 @@ def remove_axis(target, axis="y"):
         target.axis.visible = False
     else:
         raise ValueError(f"axis must be one of 'x', 'y' or 'both', got '{axis}'")
+
+
+def set_y_scale(target, scale):
+    """Interface to bokeh for setting the y scale of a plot.
+
+    Bokeh figures have no ``set_yscale`` method (that is matplotlib API), so
+    only the linear scale is accepted. In that case `target` is left untouched,
+    which assumes it was created with a linear y axis.
+
+    Raises
+    ------
+    NotImplementedError
+        If `scale` is anything other than "linear".
+    """
+    if scale != "linear":
+        raise NotImplementedError(
+            f"The bokeh backend only supports a linear y scale, got '{scale}'"
+        )
diff --git a/src/arviz_plots/backend/matplotlib/__init__.py b/src/arviz_plots/backend/matplotlib/__init__.py
index 7d28d07..2816b05 100644
--- a/src/arviz_plots/backend/matplotlib/__init__.py
+++ b/src/arviz_plots/backend/matplotlib/__init__.py
@@ -1,3 +1,4 @@
+# pylint: disable=no-self-use
 """Matplotlib interface layer.
 
Notes
@@ -9,7 +10,10 @@
 import warnings
 from typing import Any, Dict
 
+import matplotlib.scale as mscale
+import matplotlib.transforms as mtransforms
 import numpy as np
+from matplotlib import ticker
 from matplotlib.cbook import normalize_kwargs
 from matplotlib.collections import PathCollection
 from matplotlib.lines import Line2D
@@ -29,6 +33,73 @@ class UnsetDefault:
 unset = UnsetDefault()
 
 
+class SquareRootScale(mscale.ScaleBase):
+    """Axis scale that spaces values according to their square root.
+
+    The scale is registered in matplotlib under the name ``"sqrt"``.
+    Lower axis limits below 0 are clipped to 0 by
+    :meth:`limit_range_for_scale`.
+    """
+
+    name = "sqrt"
+
+    def __init__(self, axis, **kwargs):  # pylint: disable=unused-argument
+        super().__init__(axis)
+
+    def set_default_locators_and_formatters(self, axis):
+        """Use automatic major ticks with scalar labels and no minor ticks."""
+        axis.set_major_locator(ticker.AutoLocator())
+        axis.set_major_formatter(ticker.ScalarFormatter())
+        axis.set_minor_locator(ticker.NullLocator())
+        axis.set_minor_formatter(ticker.NullFormatter())
+
+    def limit_range_for_scale(self, vmin, vmax, minpos):  # pylint: disable=unused-argument
+        """Limit the range of the scale so the lower bound is never negative."""
+        return max(0.0, vmin), vmax
+
+    class SquareRootTransform(mtransforms.Transform):
+        """Square root transformation, from data values to scaled values."""
+
+        input_dims = 1
+        output_dims = 1
+        is_separable = True
+        has_inverse = True
+
+        def transform_non_affine(self, values):
+            """Return the square root of `values`."""
+            return np.array(values) ** 0.5
+
+        def inverted(self):
+            """Return the transformation that undoes this one."""
+            return SquareRootScale.InvertedSquareRootTransform()
+
+    class InvertedSquareRootTransform(mtransforms.Transform):
+        """Inverted square root transformation, from scaled values to data values."""
+
+        input_dims = 1
+        output_dims = 1
+        is_separable = True
+        has_inverse = True
+
+        # ``transform_non_affine`` is the method matplotlib expects subclasses to
+        # implement; overriding ``transform`` leaves the non-affine step as the
+        # identity whenever it is called on its own.
+        def transform_non_affine(self, values):
+            """Return `values` squared."""
+            return np.array(values) ** 2
+
+        def inverted(self):
+            """Return the transformation that undoes this one."""
+            return SquareRootScale.SquareRootTransform()
+
+    def get_transform(self):
+        """Return the transformation used by this scale."""
+        return self.SquareRootTransform()
+
+
+mscale.register_scale(SquareRootScale)
+
+
 # generation of default values
for aesthetics
 def get_default_aes(aes_key, n, kwargs=None):
     """Generate `n` *matplotlib valid* default values for a given aesthetics keyword."""
@@ -319,6 +380,42 @@ def hline(y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset,
     return target.axhline(y, **_filter_kwargs(kwargs, Line2D, artist_kws))
 
 
+def ciliney(
+    x,
+    y_bottom,
+    y_top,
+    target,
+    *,
+    color=unset,
+    alpha=unset,
+    width=unset,
+    linestyle=unset,
+    **artist_kws,
+):
+    """Interface to matplotlib for a line from y_bottom to y_top at given value of x.
+
+    Parameters
+    ----------
+    x, y_bottom, y_top : scalar or array-like
+        Position and end points of the line(s).
+    target : object with a ``plot`` method
+        Plot on which the line is drawn (presumably a matplotlib ``Axes``).
+    color, alpha, width, linestyle : optional
+        Forwarded as ``color``, ``alpha``, ``linewidth`` and ``linestyle``.
+    **artist_kws
+        Combined with the arguments above through ``_filter_kwargs``.
+        ``zorder`` defaults to 2.
+
+    Returns
+    -------
+    object
+        First element of the list returned by ``target.plot``.
+    """
+    artist_kws.setdefault("zorder", 2)
+    kwargs = {"color": color, "alpha": alpha, "linewidth": width, "linestyle": linestyle}
+    return target.plot([x, x], [y_bottom, y_top], **_filter_kwargs(kwargs, Line2D, artist_kws))[0]
+
+
 # general plot appeareance
 def title(string, target, *, size=unset, color=unset, **artist_kws):
     """Interface to matplotlib for adding a title to a plot."""
@@ -396,3 +493,11 @@ def remove_axis(target, axis="y"):
         target.spines["bottom"].set_visible(False)
     else:
         raise ValueError(f"axis must be one of 'x', 'y' or 'both', got '{axis}'")
+
+
+def set_y_scale(target, scale):
+    """Interface to matplotlib for setting the y scale of a plot.
+
+    `scale` is passed to ``target.set_yscale``.
+    """
+    target.set_yscale(scale)
diff --git a/src/arviz_plots/backend/none/__init__.py b/src/arviz_plots/backend/none/__init__.py
index 0ae91d7..454bb17 100644
--- a/src/arviz_plots/backend/none/__init__.py
+++ b/src/arviz_plots/backend/none/__init__.py
@@ -357,6 +357,42 @@ def hline(y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset,
     return artist_element
 
 
+def ciliney(
+    x,
+    y_bottom,
+    y_top,
+    target,
+    *,
+    color=unset,
+    alpha=unset,
+    width=unset,
+    linestyle=unset,
+    **artist_kws,
+):
+    """Interface to a line from y_bottom to y_top at given value of x.
+
+    Nothing is drawn: a dictionary describing the call is appended to `target`
+    and returned.
+
+    Raises
+    ------
+    ValueError
+        If `artist_kws` is not empty while ``ALLOW_KWARGS`` is false.
+    """
+    kwargs = {"color": color, "alpha": alpha, "width": width, "linestyle": linestyle}
+    if not ALLOW_KWARGS and artist_kws:
+        raise ValueError("artist_kws not empty")
+    artist_element = {
+        "function": "line",
+        "x": np.atleast_1d(x),
+        "y_bottom": np.atleast_1d(y_bottom),
+        "y_top": np.atleast_1d(y_top),
+        **_filter_kwargs(kwargs, artist_kws),
+    }
+    target.append(artist_element)
+    return artist_element
+
+
 # general plot appeareance
 def title(string, target, *, size=unset, color=unset, **artist_kws):
     """Interface to adding a title to a plot."""
diff --git a/src/arviz_plots/backend/plotly/__init__.py b/src/arviz_plots/backend/plotly/__init__.py
index 425f67c..b33ecc0 100644
--- a/src/arviz_plots/backend/plotly/__init__.py
+++ b/src/arviz_plots/backend/plotly/__init__.py
@@ -462,6 +462,48 @@ def hline(y, target, *, color=unset, alpha=unset, width=unset, linestyle=unset,
     )
 
 
+def ciliney(
+    x,
+    y_bottom,
+    y_top,
+    target,
+    *,
+    color=unset,
+    alpha=unset,
+    width=unset,
+    linestyle=unset,
+    **artist_kws,
+):
+    """Interface to plotly for a line from y_bottom to y_top at given value of x.
+
+    One ``go.Scatter`` trace in "lines" mode is added to `target` for every
+    element of `x`, `y_bottom` and `y_top`, which must therefore be iterable.
+
+    Returns
+    -------
+    go.Scatter or None
+        The last trace added to `target`, None if there was nothing to draw.
+    """
+    artist_kws.setdefault("showlegend", False)
+    line_kwargs = {"color": color, "width": width, "dash": linestyle}
+    line_artist_kws = artist_kws.pop("line", {}).copy()
+    kwargs = {"opacity": alpha}
+
+    # a loop with one trace per segment is used so that the lines are
+    # not connected to each other
+    line_object = None  # keeps the return below valid when the inputs are empty
+    for x_i, y_t, y_b in zip(x, y_top, y_bottom):
+        line_object = go.Scatter(
+            x=[x_i, x_i],
+            y=[y_b, y_t],
+            mode="lines",
+            line=_filter_kwargs(line_kwargs, line_artist_kws),
+            **_filter_kwargs(kwargs, artist_kws),
+        )
+        target.add_trace(line_object)
+    return line_object
+
+
 # general plot appeareance
 def title(string, target, *, size=unset, color=unset, **artist_kws):
     """Interface to plotly for adding a title to a plot."""
diff --git a/src/arviz_plots/plots/__init__.py b/src/arviz_plots/plots/__init__.py
index efc8c69..dfda262 100644
--- a/src/arviz_plots/plots/__init__.py
+++ b/src/arviz_plots/plots/__init__.py
@@ -9,6 +9,7 @@
 from .forestplot import plot_forest
 from .pavacalibrationplot import plot_ppc_pava
 from .ppcdistplot import plot_ppc_dist
+from .ppcrootogramplot import plot_ppc_rootogram
 from .psensedistplot import plot_psense_dist
 from .psensequantitiesplot import plot_psense_quantities
 from .ridgeplot import plot_ridge
@@ -26,6 +27,7 @@
     "plot_ess",
     "plot_ess_evolution",
     "plot_ppc_dist",
+    "plot_ppc_rootogram",
     "plot_ridge",
     "plot_ppc_pava",
     "plot_psense_dist",
diff --git a/src/arviz_plots/plots/ppcrootogramplot.py b/src/arviz_plots/plots/ppcrootogramplot.py
new file mode 100644
index 0000000..dd59455
--- /dev/null
+++ b/src/arviz_plots/plots/ppcrootogramplot.py
@@ -0,0 +1,330 @@
+"""Plot ppc rootogram for discrete (count) data."""
+from copy import copy
+from importlib import import_module
+
+from arviz_base import rcParams
+from arviz_base.labels import BaseLabeller
+from arviz_stats.helper_stats import point_interval_unique, point_unique
+
+from arviz_plots.plot_collection import PlotCollection, process_facet_dims
+from arviz_plots.plots.utils import filter_aes, process_group_variables_coords
+from arviz_plots.visuals import (
+    ci_line_y,
+    labelled_title,
+    labelled_x,
+    labelled_y,
+    scatter_xy,
+    set_y_scale,
+)
+
+
+def plot_ppc_rootogram(
+    dt,
+    ci_prob=None,
+    yscale="sqrt",
+    data_pairs=None,
+    var_names=None,
+    filter_vars=None,
+    group="posterior_predictive",
+    coords=None,
+    sample_dims=None,
+    plot_collection=None,
+    backend=None,
+    labeller=None,
+    aes_map=None,
+    plot_kwargs=None,
+    pc_kwargs=None,
+):
+    """Rootogram with confidence intervals per predicted count.
+
+    Rootograms are useful to check the calibration of count models.
+    A rootogram shows the difference between observed and predicted counts. The y-axis,
+    showing frequencies, is on the square root scale. This makes it easier to compare
+    observed and expected frequencies even for low frequencies [1]_.
+
+
+    Parameters
+    ----------
+    dt : DataTree
+        Input data
+    ci_prob : float, optional
+        Probability for the credible interval. Defaults to ``rcParams["stats.ci_prob"]``.
+    yscale : str, optional
+        Scale for the y-axis. Defaults to "sqrt", pass "linear" for linear scale.
+        It is only applied with the "matplotlib" backend. For "bokeh" and "plotly"
+        the y-axis is linear.
+    data_pairs : dict, optional
+        Dictionary mapping predictive variable names (keys) to the names of the
+        corresponding observed data variables (values).
+        If None, it will assume that the observed data and the predictive data have
+        the same variable name.
+    var_names : str or list of str, optional
+        One or more variables to be plotted. Currently only one variable is supported.
+        Prefix the variables by ~ when you want to exclude them from the plot.
+    filter_vars : {None, “like”, “regex”}, optional, default=None
+        If None (default), interpret var_names as the real variables names.
+        If “like”, interpret var_names as substrings of the real variables names.
+        If “regex”, interpret var_names as regular expressions on the real variables names.
+    group : str, default "posterior_predictive"
+        Group of `dt` with the predictive samples to plot.
+    coords : dict, optional
+        Coordinates to plot.
+    sample_dims : str or sequence of hashable, optional
+        Dimensions to reduce unless mapped to an aesthetic.
+        Defaults to ``rcParams["data.sample_dims"]``
+    plot_collection : PlotCollection, optional
+    backend : {"matplotlib", "bokeh", "plotly"}, optional
+    labeller : labeller, optional
+    aes_map : mapping of {str : sequence of str}, optional
+        Mapping of artists to aesthetics that should use their mapping in `plot_collection`
+        when plotted. Valid keys are the same as for `plot_kwargs`.
+
+    plot_kwargs : mapping of {str : mapping or False}, optional
+        Valid keys are:
+
+        * predictive_markers -> passed to :func:`~arviz_plots.visuals.scatter_xy`
+        * observed_markers -> passed to :func:`~arviz_plots.visuals.scatter_xy`
+        * ci -> passed to :func:`~arviz_plots.visuals.ci_line_y`
+        * xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x`
+        * ylabel -> passed to :func:`~arviz_plots.visuals.labelled_y`
+        * title -> passed to :func:`~arviz_plots.visuals.labelled_title`
+
+    pc_kwargs : mapping
+        Passed to :class:`arviz_plots.PlotCollection.grid`
+
+    Returns
+    -------
+    PlotCollection
+
+    Examples
+    --------
+    Plot the rootogram for the crabs dataset.
+
+    .. plot::
+        :context: close-figs
+
+        >>> from arviz_plots import plot_ppc_rootogram, style
+        >>> style.use("arviz-variat")
+        >>> from arviz_base import load_arviz_data
+        >>> dt = load_arviz_data('crabs')
+        >>> plot_ppc_rootogram(dt)
+
+
+    .. minigallery:: plot_ppc_rootogram
+
+    References
+    ----------
+    .. [1] Kleiber C, Zeileis A. *Visualizing Count Data Regressions Using Rootograms*.
+        The American Statistician, 70(3). (2016) https://doi.org/10.1080/00031305.2016.1173590
+    """
+    if ci_prob is None:
+        ci_prob = rcParams["stats.ci_prob"]
+    if sample_dims is None:
+        sample_dims = rcParams["data.sample_dims"]
+    if isinstance(sample_dims, str):
+        sample_dims = [sample_dims]
+    sample_dims = list(sample_dims)
+    if plot_kwargs is None:
+        plot_kwargs = {}
+    else:
+        plot_kwargs = plot_kwargs.copy()
+    if pc_kwargs is None:
+        pc_kwargs = {}
+    else:
+        pc_kwargs = pc_kwargs.copy()
+
+    if backend is None:
+        if plot_collection is None:
+            backend = rcParams["plot.backend"]
+        else:
+            backend = plot_collection.backend
+
+    if labeller is None:
+        labeller = BaseLabeller()
+
+    if data_pairs is None:
+        data_pairs = (var_names, var_names)
+    else:
+        data_pairs = (list(data_pairs.keys()), list(data_pairs.values()))
+
+    predictive_dist = process_group_variables_coords(
+        dt, group=group, var_names=data_pairs[0], filter_vars=filter_vars, coords=coords
+    )
+
+    observed_dist = process_group_variables_coords(
+        dt, group="observed_data", var_names=data_pairs[1], filter_vars=filter_vars, coords=coords
+    )
+
+    # rootograms are meant for count data, refuse variables stored as floats
+    predictive_types = [
+        predictive_dist[var].values.dtype.kind == "f" for var in predictive_dist.data_vars
+    ]
+    observed_types = [
+        observed_dist[var].values.dtype.kind == "f" for var in observed_dist.data_vars
+    ]
+
+    if any(predictive_types + observed_types):
+        raise ValueError(
+            "Detected at least one continuous variable.\n"
+            "Use plot_ppc variants specific for continuous data, "
+            "such as plot_ppc_dist.",
+        )
+
+    predictive_ds = point_interval_unique(dt, predictive_dist.data_vars, group, ci_prob)
+    observed_ds = point_unique(dt, observed_dist.data_vars)
+
+    plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
+    colors = plot_bknd.get_default_aes("color", 1, {})
+    markers = plot_bknd.get_default_aes("marker", 7, {})
+
+    if plot_collection is None:
+        pc_kwargs["plot_grid_kws"] = pc_kwargs.get("plot_grid_kws", {}).copy()
+
+        pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
+        pc_kwargs.setdefault("col_wrap", 5)
+        pc_kwargs.setdefault("cols", "__variable__")
+        pc_kwargs.setdefault("rows", None)
+
+        figsize = pc_kwargs["plot_grid_kws"].get("figsize", None)
+        figsize_units = pc_kwargs["plot_grid_kws"].get("figsize_units", "inches")
+        col_dims = pc_kwargs["cols"]
+        row_dims = pc_kwargs["rows"]
+        if figsize is None:
+            figsize = plot_bknd.scale_fig_size(
+                figsize,
+                rows=process_facet_dims(predictive_ds, row_dims)[0],
+                cols=process_facet_dims(predictive_ds, col_dims)[0],
+                figsize_units=figsize_units,
+            )
+            figsize_units = "dots"
+        pc_kwargs["plot_grid_kws"]["figsize"] = figsize
+        pc_kwargs["plot_grid_kws"]["figsize_units"] = figsize_units
+
+        plot_collection = PlotCollection.grid(
+            predictive_ds,
+            backend=backend,
+            **pc_kwargs,
+        )
+
+    if aes_map is None:
+        aes_map = {}
+    else:
+        aes_map = aes_map.copy()
+
+    aes_map.setdefault("predictive_markers", plot_collection.aes_set)
+    aes_map.setdefault("ci", plot_collection.aes_set)
+    ## predictive_markers
+    predictive_ms_kwargs = copy(plot_kwargs.get("predictive_markers", {}))
+
+    if predictive_ms_kwargs is not False:
+        _, predictive_ms_aes, predictive_ms_ignore = filter_aes(
+            plot_collection, aes_map, "predictive_markers", sample_dims
+        )
+        if "color" not in predictive_ms_aes:
+            predictive_ms_kwargs.setdefault("color", colors[0])
+
+        predictive_ms_kwargs.setdefault("marker", markers[4])
+
+        plot_collection.map(
+            scatter_xy,
+            "predictive_markers",
+            data=predictive_ds,
+            ignore_aes=predictive_ms_ignore,
+            **predictive_ms_kwargs,
+        )
+
+    ## confidence intervals
+    ci_kwargs = copy(plot_kwargs.get("ci", {}))
+    _, ci_aes, ci_ignore = filter_aes(plot_collection, aes_map, "ci", sample_dims)
+
+    if ci_kwargs is not False:
+        if "color" not in ci_aes:
+            ci_kwargs.setdefault("color", colors[0])
+
+        ci_kwargs.setdefault("alpha", 0.3)
+        ci_kwargs.setdefault("width", 3)
+
+        plot_collection.map(
+            ci_line_y,
+            "ci",
+            data=predictive_ds,
+            ignore_aes=ci_ignore,
+            **ci_kwargs,
+        )
+
+    ## observed_markers
+    observed_ms_kwargs = copy(plot_kwargs.get("observed_markers", {}))
+
+    if observed_ms_kwargs is not False:
+        _, _, observed_ms_ignore = filter_aes(
+            plot_collection, aes_map, "observed_markers", sample_dims
+        )
+        observed_ms_kwargs.setdefault("color", "black")
+        observed_ms_kwargs.setdefault("marker", markers[6])
+
+        plot_collection.map(
+            scatter_xy,
+            "observed_markers",
+            data=observed_ds,
+            ignore_aes=observed_ms_ignore,
+            **observed_ms_kwargs,
+        )
+
+    # set xlabel
+    _, xlabels_aes, xlabels_ignore = filter_aes(plot_collection, aes_map, "xlabel", sample_dims)
+    xlabel_kwargs = copy(plot_kwargs.get("xlabel", {}))
+    if xlabel_kwargs is not False:
+        if "color" not in xlabels_aes:
+            xlabel_kwargs.setdefault("color", "black")
+
+        xlabel_kwargs.setdefault("text", "counts")
+
+        plot_collection.map(
+            labelled_x,
+            "xlabel",
+            ignore_aes=xlabels_ignore,
+            subset_info=True,
+            **xlabel_kwargs,
+        )
+
+    # set ylabel
+    _, ylabels_aes, ylabels_ignore = filter_aes(plot_collection, aes_map, "ylabel", sample_dims)
+    ylabel_kwargs = copy(plot_kwargs.get("ylabel", {}))
+    if ylabel_kwargs is not False:
+        if "color" not in ylabels_aes:
+            ylabel_kwargs.setdefault("color", "black")
+
+        ylabel_kwargs.setdefault("text", "frequency")
+
+        plot_collection.map(
+            labelled_y,
+            "ylabel",
+            ignore_aes=ylabels_ignore,
+            subset_info=True,
+            **ylabel_kwargs,
+        )
+
+    # title
+    title_kwargs = copy(plot_kwargs.get("title", {}))
+    _, _, title_ignore = filter_aes(plot_collection, aes_map, "title", sample_dims)
+
+    if title_kwargs is not False:
+        plot_collection.map(
+            labelled_title,
+            "title",
+            ignore_aes=title_ignore,
+            subset_info=True,
+            labeller=labeller,
+            **title_kwargs,
+        )
+
+    # the y scale is only set for matplotlib, other backends keep their default axis
+    if backend == "matplotlib":
+        plot_collection.map(
+            set_y_scale,
+            store_artist=False,
+            ignore_aes=plot_collection.aes_set,
+            scale=yscale,
+        )
+
+    return plot_collection
diff --git a/src/arviz_plots/visuals/__init__.py
b/src/arviz_plots/visuals/__init__.py
index 1ff5bff..7435df7 100644
--- a/src/arviz_plots/visuals/__init__.py
+++ b/src/arviz_plots/visuals/__init__.py
@@ -40,6 +40,23 @@ def line_xy(da, target, backend, x=None, y=None, **kwargs):
     return plot_backend.line(x, y, target, **kwargs)
 
 
+def ci_line_y(values, target, backend, **kwargs):
+    """Plot a line from y_bottom to y_top at given value of x.
+
+    Dispatch to ``ciliney`` function in backend. `values` is indexed along
+    ``plot_axis`` to get the "x", "y_bottom" and "y_top" values, which are
+    passed on in that order.
+    """
+    plot_backend = import_module(f"arviz_plots.backend.{backend}")
+    return plot_backend.ciliney(
+        values.sel(plot_axis="x"),
+        values.sel(plot_axis="y_bottom"),
+        values.sel(plot_axis="y_top"),
+        target,
+        **kwargs,
+    )
+
+
 def line_x(da, target, backend, y=None, **kwargs):
     """Plot a line along the x axis (y constant)."""
     if y is None:
@@ -289,3 +306,13 @@ def set_xticks(da, target, backend, values, labels, **kwargs):
     """Dispatch to ``set_xticks`` function in backend."""
     plot_backend = import_module(f"arviz_plots.backend.{backend}")
     plot_backend.xticks(values, labels, target, **kwargs)
+
+
+def set_y_scale(da, target, backend, scale, **kwargs):
+    """Dispatch to ``set_y_scale`` function in backend.
+
+    `da` is accepted but not used. NOTE(review): the backend functions are
+    called with `kwargs` too, confirm that they accept extra arguments.
+    """
+    plot_backend = import_module(f"arviz_plots.backend.{backend}")
+    plot_backend.set_y_scale(target, scale, **kwargs)