Skip to content

Commit

Permalink
rename plot_pava and minor fixes (#140)
Browse files Browse the repository at this point in the history
  • Loading branch information
aloctavodia authored Feb 19, 2025
1 parent ae1b36d commit 3bd0288
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 25 deletions.
2 changes: 1 addition & 1 deletion docs/source/api/plots.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ A complementary introduction and guide to ``plot_...`` functions is available at
plot_ess
plot_ess_evolution
plot_forest
plot_pava_calibration
plot_ppc_pava
plot_ppc_dist
plot_psense_dist
plot_psense_quantities
Expand Down
4 changes: 2 additions & 2 deletions docs/source/gallery/model_criticism/plot_pava_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
---
:::{seealso}
API Documentation: {func}`~arviz_plots.plot_pava_calibration`
API Documentation: {func}`~arviz_plots.plot_ppc_pava`
:::
"""
from arviz_base import load_arviz_data
Expand All @@ -16,7 +16,7 @@
azp.style.use("arviz-variat")

dt = load_arviz_data("classification10d")
pc = azp.plot_pava_calibration(
pc = azp.plot_ppc_pava(
dt,
backend="none",
)
Expand Down
2 changes: 1 addition & 1 deletion src/arviz_plots/backend/bokeh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_default_aes(aes_key, n, kwargs=None):
elif aes_key in {"linestyle", "line_dash"}:
vals = ["solid", "dashed", "dotted", "dashdot"]
elif aes_key == "marker":
vals = ["circle", "cross", "triangle", "x", "diamond", "square"]
vals = ["circle", "cross", "triangle", "x", "diamond", "square", "dot"]
else:
return get_agnostic_default_aes(aes_key, n)
return get_agnostic_default_aes(aes_key, n, {aes_key: vals})
Expand Down
2 changes: 1 addition & 1 deletion src/arviz_plots/backend/matplotlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_default_aes(aes_key, n, kwargs=None):
vals = ["-", "--", ":", "-."]
vals = default_prop_cycle.get("linestyle", vals)
elif aes_key in {"marker", "m"}:
vals = ["o", "+", "^", "x", "d", "s"]
vals = ["o", "+", "^", "x", "d", "s", "."]
vals = default_prop_cycle.get("marker", vals)
elif aes_key in default_prop_cycle:
vals = default_prop_cycle[aes_key]
Expand Down
3 changes: 2 additions & 1 deletion src/arviz_plots/backend/plotly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def get_default_aes(aes_key, n, kwargs=None):
elif aes_key in {"linestyle", "dash"}:
vals = ["solid", "dash", "dot", "dashdot"]
elif aes_key in {"marker", "style"}:
vals = ["circle", "cross", "triangle-up", "x", "diamond", "square"]
# plotly does not have "dot" using "circle-open" instead
vals = ["circle", "cross", "triangle-up", "x", "diamond", "square", "circle-open"]
else:
return get_agnostic_default_aes(aes_key, n, {})
return get_agnostic_default_aes(aes_key, n, {aes_key: vals})
Expand Down
4 changes: 2 additions & 2 deletions src/arviz_plots/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .essplot import plot_ess
from .evolutionplot import plot_ess_evolution
from .forestplot import plot_forest
from .pavacalibrationplot import plot_pava_calibration
from .pavacalibrationplot import plot_ppc_pava
from .ppcdistplot import plot_ppc_dist
from .psensedistplot import plot_psense_dist
from .psensequantitiesplot import plot_psense_quantities
Expand All @@ -27,7 +27,7 @@
"plot_ess_evolution",
"plot_ppc_dist",
"plot_ridge",
"plot_pava_calibration",
"plot_ppc_pava",
"plot_psense_dist",
"plot_psense_quantities",
]
32 changes: 15 additions & 17 deletions src/arviz_plots/plots/pavacalibrationplot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""psense quantities plot code."""
"""Plot ppc using PAV-adjusted calibration plot."""
from copy import copy
from importlib import import_module

Expand All @@ -19,7 +19,7 @@
)


def plot_pava_calibration(
def plot_ppc_pava(
dt,
n_bootstaps=1000,
ci_prob=None,
Expand Down Expand Up @@ -76,8 +76,8 @@ def plot_pava_calibration(
plot_kwargs : mapping of {str : mapping or False}, optional
Valid keys are:
* calibration_line -> passed to :func:`~arviz_plots.visuals.line_xy`
* calibration_markers -> passed to :func:`~arviz_plots.visuals.scatter_xy`
* lines -> passed to :func:`~arviz_plots.visuals.line_xy`
* markers -> passed to :func:`~arviz_plots.visuals.scatter_xy`
* reference_line -> passed to :func:`~arviz_plots.visuals.line_xy`
* ci -> passed to :func:`~arviz_plots.visuals.fill_between_y`
* xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x`
Expand All @@ -99,14 +99,14 @@ def plot_pava_calibration(
.. plot::
:context: close-figs
>>> from arviz_plots import plot_pava_calibration, style
>>> from arviz_plots import plot_ppc_pava, style
>>> style.use("arviz-variat")
>>> from arviz_base import load_arviz_data
>>> dt = load_arviz_data('rugby')
>>> plot_pava_calibration(dt, ci_prob=0.90)
>>> plot_ppc_pava(dt, ci_prob=0.90)
.. minigallery:: plot_pava_calibration
.. minigallery:: plot_ppc_pava
References
----------
Expand Down Expand Up @@ -145,6 +145,7 @@ def plot_pava_calibration(

plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
colors = plot_bknd.get_default_aes("color", 1, {})
markers = plot_bknd.get_default_aes("marker", 7, {})
lines = plot_bknd.get_default_aes("linestyle", 2, {})

if plot_collection is None:
Expand Down Expand Up @@ -201,34 +202,31 @@ def plot_pava_calibration(
)

## markers
calibration_ms_kwargs = copy(plot_kwargs.get("calibration_line", {}))
calibration_ms_kwargs = copy(plot_kwargs.get("markers", {}))

if calibration_ms_kwargs is not False:
_, _, calibration_ms_ignore = filter_aes(
plot_collection, aes_map, "calibration_line", sample_dims
)
_, _, calibration_ms_ignore = filter_aes(plot_collection, aes_map, "lines", sample_dims)
calibration_ms_kwargs.setdefault("color", colors[0])
calibration_ms_kwargs.setdefault("marker", markers[6])

plot_collection.map(
scatter_xy,
"calibration_markers",
"markers",
data=ds_calibration,
ignore_aes=calibration_ms_ignore,
**calibration_ms_kwargs,
)

## lines
calibration_ls_kwargs = copy(plot_kwargs.get("calibration_line", {}))
calibration_ls_kwargs = copy(plot_kwargs.get("lines", {}))

if calibration_ls_kwargs is not False:
_, _, calibration_ls_ignore = filter_aes(
plot_collection, aes_map, "calibration_line", sample_dims
)
_, _, calibration_ls_ignore = filter_aes(plot_collection, aes_map, "lines", sample_dims)
calibration_ls_kwargs.setdefault("color", colors[0])

plot_collection.map(
line_xy,
"calibration_line",
"lines",
data=ds_calibration,
ignore_aes=calibration_ls_ignore,
**calibration_ls_kwargs,
Expand Down

0 comments on commit 3bd0288

Please sign in to comment.