Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rename plot_pava and minor fixes #140

Merged
merged 1 commit into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/api/plots.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ A complementary introduction and guide to ``plot_...`` functions is available at
plot_ess
plot_ess_evolution
plot_forest
plot_pava_calibration
plot_ppc_pava
plot_ppc_dist
plot_psense_dist
plot_psense_quantities
Expand Down
4 changes: 2 additions & 2 deletions docs/source/gallery/model_criticism/plot_pava_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
---

:::{seealso}
API Documentation: {func}`~arviz_plots.plot_pava_calibration`
API Documentation: {func}`~arviz_plots.plot_ppc_pava`
:::
"""
from arviz_base import load_arviz_data
Expand All @@ -16,7 +16,7 @@
azp.style.use("arviz-variat")

dt = load_arviz_data("classification10d")
pc = azp.plot_pava_calibration(
pc = azp.plot_ppc_pava(
dt,
backend="none",
)
Expand Down
2 changes: 1 addition & 1 deletion src/arviz_plots/backend/bokeh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_default_aes(aes_key, n, kwargs=None):
elif aes_key in {"linestyle", "line_dash"}:
vals = ["solid", "dashed", "dotted", "dashdot"]
elif aes_key == "marker":
vals = ["circle", "cross", "triangle", "x", "diamond", "square"]
vals = ["circle", "cross", "triangle", "x", "diamond", "square", "dot"]
else:
return get_agnostic_default_aes(aes_key, n)
return get_agnostic_default_aes(aes_key, n, {aes_key: vals})
Expand Down
2 changes: 1 addition & 1 deletion src/arviz_plots/backend/matplotlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def get_default_aes(aes_key, n, kwargs=None):
vals = ["-", "--", ":", "-."]
vals = default_prop_cycle.get("linestyle", vals)
elif aes_key in {"marker", "m"}:
vals = ["o", "+", "^", "x", "d", "s"]
vals = ["o", "+", "^", "x", "d", "s", "."]
vals = default_prop_cycle.get("marker", vals)
elif aes_key in default_prop_cycle:
vals = default_prop_cycle[aes_key]
Expand Down
3 changes: 2 additions & 1 deletion src/arviz_plots/backend/plotly/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def get_default_aes(aes_key, n, kwargs=None):
elif aes_key in {"linestyle", "dash"}:
vals = ["solid", "dash", "dot", "dashdot"]
elif aes_key in {"marker", "style"}:
vals = ["circle", "cross", "triangle-up", "x", "diamond", "square"]
# plotly does not have "dot" using "circle-open" instead
vals = ["circle", "cross", "triangle-up", "x", "diamond", "square", "circle-open"]
else:
return get_agnostic_default_aes(aes_key, n, {})
return get_agnostic_default_aes(aes_key, n, {aes_key: vals})
Expand Down
4 changes: 2 additions & 2 deletions src/arviz_plots/plots/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .essplot import plot_ess
from .evolutionplot import plot_ess_evolution
from .forestplot import plot_forest
from .pavacalibrationplot import plot_pava_calibration
from .pavacalibrationplot import plot_ppc_pava
from .ppcdistplot import plot_ppc_dist
from .psensedistplot import plot_psense_dist
from .psensequantitiesplot import plot_psense_quantities
Expand All @@ -27,7 +27,7 @@
"plot_ess_evolution",
"plot_ppc_dist",
"plot_ridge",
"plot_pava_calibration",
"plot_ppc_pava",
"plot_psense_dist",
"plot_psense_quantities",
]
32 changes: 15 additions & 17 deletions src/arviz_plots/plots/pavacalibrationplot.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""psense quantities plot code."""
"""Plot ppc using PAV-adjusted calibration plot."""
from copy import copy
from importlib import import_module

Expand All @@ -19,7 +19,7 @@
)


def plot_pava_calibration(
def plot_ppc_pava(
dt,
n_bootstaps=1000,
ci_prob=None,
Expand Down Expand Up @@ -76,8 +76,8 @@ def plot_pava_calibration(
plot_kwargs : mapping of {str : mapping or False}, optional
Valid keys are:

* calibration_line -> passed to :func:`~arviz_plots.visuals.line_xy`
* calibration_markers -> passed to :func:`~arviz_plots.visuals.scatter_xy`
* lines -> passed to :func:`~arviz_plots.visuals.line_xy`
* markers -> passed to :func:`~arviz_plots.visuals.scatter_xy`
* reference_line -> passed to :func:`~arviz_plots.visuals.line_xy`
* ci -> passed to :func:`~arviz_plots.visuals.fill_between_y`
* xlabel -> passed to :func:`~arviz_plots.visuals.labelled_x`
Expand All @@ -99,14 +99,14 @@ def plot_pava_calibration(
.. plot::
:context: close-figs

>>> from arviz_plots import plot_pava_calibration, style
>>> from arviz_plots import plot_ppc_pava, style
>>> style.use("arviz-variat")
>>> from arviz_base import load_arviz_data
>>> dt = load_arviz_data('rugby')
>>> plot_pava_calibration(dt, ci_prob=0.90)
>>> plot_ppc_pava(dt, ci_prob=0.90)


.. minigallery:: plot_pava_calibration
.. minigallery:: plot_ppc_pava

References
----------
Expand Down Expand Up @@ -145,6 +145,7 @@ def plot_pava_calibration(

plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
colors = plot_bknd.get_default_aes("color", 1, {})
markers = plot_bknd.get_default_aes("marker", 7, {})
lines = plot_bknd.get_default_aes("linestyle", 2, {})

if plot_collection is None:
Expand Down Expand Up @@ -201,34 +202,31 @@ def plot_pava_calibration(
)

## markers
calibration_ms_kwargs = copy(plot_kwargs.get("calibration_line", {}))
calibration_ms_kwargs = copy(plot_kwargs.get("markers", {}))

if calibration_ms_kwargs is not False:
_, _, calibration_ms_ignore = filter_aes(
plot_collection, aes_map, "calibration_line", sample_dims
)
_, _, calibration_ms_ignore = filter_aes(plot_collection, aes_map, "lines", sample_dims)
calibration_ms_kwargs.setdefault("color", colors[0])
calibration_ms_kwargs.setdefault("marker", markers[6])

plot_collection.map(
scatter_xy,
"calibration_markers",
"markers",
data=ds_calibration,
ignore_aes=calibration_ms_ignore,
**calibration_ms_kwargs,
)

## lines
calibration_ls_kwargs = copy(plot_kwargs.get("calibration_line", {}))
calibration_ls_kwargs = copy(plot_kwargs.get("lines", {}))

if calibration_ls_kwargs is not False:
_, _, calibration_ls_ignore = filter_aes(
plot_collection, aes_map, "calibration_line", sample_dims
)
_, _, calibration_ls_ignore = filter_aes(plot_collection, aes_map, "lines", sample_dims)
calibration_ls_kwargs.setdefault("color", colors[0])

plot_collection.map(
line_xy,
"calibration_line",
"lines",
data=ds_calibration,
ignore_aes=calibration_ls_ignore,
**calibration_ls_kwargs,
Expand Down