Skip to content

Commit

Permalink
support filter keys in diagnostic plots and metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
paul-buerkner committed Feb 12, 2025
1 parent f83e92a commit d7b89ee
Show file tree
Hide file tree
Showing 10 changed files with 131 additions and 23 deletions.
17 changes: 13 additions & 4 deletions bayesflow/diagnostics/metrics/calibration_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@
def calibration_error(
targets: Mapping[str, np.ndarray] | np.ndarray,
references: Mapping[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
resolution: int = 20,
aggregation: Callable = np.median,
min_quantile: float = 0.005,
max_quantile: float = 0.995,
variable_names: Sequence[str] = None,
) -> Mapping[str, Any]:
"""Computes an aggregate score for the marginal calibration error over an ensemble of approximate
posteriors. The calibration error is given as the aggregate (e.g., median) of the absolute deviation
Expand All @@ -25,6 +26,11 @@ def calibration_error(
The random draws from the approximate posteriors over ``num_datasets``
references : np.ndarray of shape (num_datasets, num_variables)
The corresponding ground-truth values sampled from the prior
filter_keys : Sequence[str], optional (default = None)
Select keys from the dictionaries provided in targets and references.
By default, select all keys.
variable_names : Sequence[str], optional (default = None)
Optional variable names to show in the output.
resolution : int, optional, default: 20
The number of credibility intervals (CIs) to consider
aggregation : callable or None, optional, default: np.median
Expand All @@ -34,8 +40,6 @@ def calibration_error(
The minimum posterior quantile to consider.
max_quantile : float in (0, 1), optional, default: 0.995
The maximum posterior quantile to consider.
variable_names : Sequence[str], optional (default = None)
Optional variable names to select from the available variables.
Returns
-------
Expand All @@ -49,7 +53,12 @@ def calibration_error(
The (inferred) variable names.
"""

samples = dicts_to_arrays(targets=targets, references=references, variable_names=variable_names)
samples = dicts_to_arrays(
targets=targets,
references=references,
filter_keys=filter_keys,
variable_names=variable_names,
)

# Define alpha values and the corresponding quantile bounds
alphas = np.linspace(start=min_quantile, stop=max_quantile, num=resolution)
Expand Down
17 changes: 13 additions & 4 deletions bayesflow/diagnostics/metrics/posterior_contraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
def posterior_contraction(
targets: Mapping[str, np.ndarray] | np.ndarray,
references: Mapping[str, np.ndarray] | np.ndarray,
aggregation: Callable = np.median,
filter_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
aggregation: Callable = np.median,
) -> Mapping[str, Any]:
"""Computes the posterior contraction (PC) from prior to posterior for the given samples.
Expand All @@ -20,10 +21,13 @@ def posterior_contraction(
for each data set from `num_datasets`.
references : np.ndarray of shape (num_datasets, num_variables)
Prior samples, comprising `num_datasets` ground truths.
filter_keys : Sequence[str], optional (default = None)
Select keys from the dictionaries provided in targets and references.
By default, select all keys.
variable_names : Sequence[str], optional (default = None)
Optional variable names to show in the output.
aggregation : callable, optional (default = np.median)
Function to aggregate the PC across draws. Typically `np.mean` or `np.median`.
variable_names : Sequence[str], optional (default = None)
Optional variable names to select from the available variables.
Returns
-------
Expand All @@ -43,7 +47,12 @@ def posterior_contraction(
indicate low contraction.
"""

samples = dicts_to_arrays(targets=targets, references=references, variable_names=variable_names)
samples = dicts_to_arrays(
targets=targets,
references=references,
filter_keys=filter_keys,
variable_names=variable_names,
)

post_vars = samples["targets"].var(axis=1, ddof=1)
prior_vars = samples["references"].var(axis=0, keepdims=True, ddof=1)
Expand Down
17 changes: 13 additions & 4 deletions bayesflow/diagnostics/metrics/root_mean_squared_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@
def root_mean_squared_error(
targets: Mapping[str, np.ndarray] | np.ndarray,
references: Mapping[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
normalize: bool = True,
aggregation: Callable = np.median,
variable_names: Sequence[str] = None,
) -> Mapping[str, Any]:
"""Computes the (Normalized) Root Mean Squared Error (RMSE/NRMSE) for the given posterior and prior samples.
Expand All @@ -21,12 +22,15 @@ def root_mean_squared_error(
for each data set from `num_datasets`.
references : np.ndarray of shape (num_datasets, num_variables)
Prior samples, comprising `num_datasets` ground truths.
filter_keys : Sequence[str], optional (default = None)
Select keys from the dictionaries provided in targets and references.
By default, select all keys.
variable_names : Sequence[str], optional (default = None)
Optional variable names to show in the output.
normalize : bool, optional (default = True)
Whether to normalize the RMSE using the range of the prior samples.
aggregation : callable, optional (default = np.median)
Function to aggregate the RMSE across draws. Typically `np.mean` or `np.median`.
variable_names : Sequence[str], optional (default = None)
Optional variable names to select from the available variables.
Notes
-----
Expand All @@ -45,7 +49,12 @@ def root_mean_squared_error(
The (inferred) variable names.
"""

samples = dicts_to_arrays(targets=targets, references=references, variable_names=variable_names)
samples = dicts_to_arrays(
targets=targets,
references=references,
filter_keys=filter_keys,
variable_names=variable_names,
)

rmse = np.sqrt(np.mean((samples["targets"] - samples["references"][:, None, :]) ** 2, axis=0))

Expand Down
5 changes: 5 additions & 0 deletions bayesflow/diagnostics/plots/calibration_ecdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
def calibration_ecdf(
targets: dict[str, np.ndarray] | np.ndarray,
references: dict[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
difference: bool = False,
stacked: bool = False,
Expand Down Expand Up @@ -69,6 +70,9 @@ def calibration_ecdf(
You can pass a reference array in the same shape as the
`prior_samples` array by setting `references` in the ``ranks_kwargs``.
This is motivated by [2].
filter_keys : list or None, optional, default: None
Select keys from the dictionaries provided in targets and references.
By default, select all keys.
variable_names : list or None, optional, default: None
The parameter names for nice plot titles.
Inferred if None. Only relevant if `stacked=False`.
Expand Down Expand Up @@ -117,6 +121,7 @@ def calibration_ecdf(
plot_data = prepare_plot_data(
targets=targets,
references=references,
filter_keys=filter_keys,
variable_names=variable_names,
num_col=num_col,
num_row=num_row,
Expand Down
5 changes: 5 additions & 0 deletions bayesflow/diagnostics/plots/calibration_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
def calibration_histogram(
targets: dict[str, np.ndarray] | np.ndarray,
references: dict[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
figsize: Sequence[float] = None,
num_bins: int = 10,
Expand Down Expand Up @@ -39,6 +40,9 @@ def calibration_histogram(
The posterior draws obtained from n_data_sets
references : np.ndarray of shape (n_data_sets, n_params)
The prior draws obtained for generating n_data_sets
filter_keys : list or None, optional, default: None
Select keys from the dictionaries provided in targets and references.
By default, select all keys.
variable_names : list or None, optional, default: None
The parameter names for nice plot titles. Inferred if None
figsize : tuple or None, optional, default : None
Expand Down Expand Up @@ -73,6 +77,7 @@ def calibration_histogram(
plot_data = prepare_plot_data(
targets=targets,
references=references,
filter_keys=filter_keys,
variable_names=variable_names,
num_col=num_col,
num_row=num_row,
Expand Down
19 changes: 15 additions & 4 deletions bayesflow/diagnostics/plots/pairs_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@

def pairs_samples(
samples: dict[str, np.ndarray] | np.ndarray = None,
context: str = None,
filter_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
context: str = None,
height: float = 2.5,
color: str | tuple = "#132a70",
alpha: float = 0.9,
Expand All @@ -25,9 +26,14 @@ def pairs_samples(
----------
samples : dict[str, Tensor], default: None
Sample draws from any dataset
filter_keys : list or None, optional, default: None
Select keys from the dictionary provided in samples.
By default, select all keys.
variable_names : list or None, optional, default: None
The parameter names for nice plot titles. Inferred if None
context : str, default: None
The context that the sample represents. If specified,
should usually either be `Prior` or `Posterior`.
The context that the sample represents.
If specified, should usually either be `Prior` or `Posterior`.
height : float, optional, default: 2.5
The height of the pair plot
color : str or tuple, optional, default : '#132a70'
Expand All @@ -44,7 +50,12 @@ def pairs_samples(
Additional keyword arguments passed to the sns.PairGrid constructor
"""

plot_data = dicts_to_arrays(targets=samples, variable_names=variable_names, default_name=context)
plot_data = dicts_to_arrays(
targets=samples,
filter_keys=filter_keys,
variable_names=variable_names,
default_name=context,
)

dim = plot_data["targets"].shape[-1]
if context is None:
Expand Down
36 changes: 34 additions & 2 deletions bayesflow/diagnostics/plots/recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@
def recovery(
targets: dict[str, np.ndarray] | np.ndarray,
references: dict[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
point_agg=np.median,
uncertainty_agg=median_abs_deviation,
add_corr: bool = True,
figsize: Sequence[int] = None,
label_fontsize: int = 16,
title_fontsize: int = 18,
metric_fontsize: int = 16,
tick_fontsize: int = 12,
add_corr: bool = True,
color: str = "#132a70",
num_col: int = None,
num_row: int = None,
Expand All @@ -46,7 +47,37 @@ def recovery(
Parameters
----------
#TODO
targets : np.ndarray of shape (num_datasets, num_post_draws, num_params)
The posterior draws obtained from num_datasets
references : np.ndarray of shape (num_datasets, num_params)
The prior draws (true parameters) used for generating the num_datasets
filter_keys : list or None, optional, default: None
Select keys from the dictionaries provided in targets and references.
By default, select all keys.
variable_names : list or None, optional, default: None
The individual parameter names for nice plot titles. Inferred if None
point_agg : function to compute point estimates. Default: median
uncertainty_agg : function to compute uncertainty estimates. Default: MAD
add_corr : boolean, default: True
Should correlations between estimates and ground truth values be shown?
figsize : tuple or None, optional, default : None
The figure size passed to the matplotlib constructor. Inferred if None.
label_fontsize : int, optional, default: 16
The font size of the y-label text.
title_fontsize : int, optional, default: 18
The font size of the title text.
metric_fontsize : int, optional, default: 16
The font size of the metrics shown as text.
tick_fontsize : int, optional, default: 12
The font size of the axis ticklabels.
color : str, optional, default: '#132a70'
The color for the true vs. estimated scatter points and error bars.
num_row : int, optional, default: None
The number of rows for the subplots. Dynamically determined if None.
num_col : int, optional, default: None
The number of columns for the subplots. Dynamically determined if None.
xlabel:
ylabel:
Returns
-------
Expand All @@ -62,6 +93,7 @@ def recovery(
plot_data = prepare_plot_data(
targets=targets,
references=references,
filter_keys=filter_keys,
variable_names=variable_names,
num_col=num_col,
num_row=num_row,
Expand Down
5 changes: 5 additions & 0 deletions bayesflow/diagnostics/plots/z_score_contraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
def z_score_contraction(
targets: dict[str, np.ndarray] | np.ndarray,
references: dict[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
figsize: Sequence[int] = None,
label_fontsize: int = 16,
Expand Down Expand Up @@ -56,6 +57,9 @@ def z_score_contraction(
The posterior draws obtained from num_datasets
references : np.ndarray of shape (num_datasets, num_params)
The prior draws (true parameters) used for generating the num_datasets
filter_keys : list or None, optional, default: None
Select keys from the dictionaries provided in targets and references.
By default, select all keys.
variable_names : list or None, optional, default: None
The parameter names for nice plot titles. Inferred if None
figsize : tuple or None, optional, default : None
Expand Down Expand Up @@ -87,6 +91,7 @@ def z_score_contraction(
plot_data = prepare_plot_data(
targets=targets,
references=references,
filter_keys=filter_keys,
variable_names=variable_names,
num_col=num_col,
num_row=num_row,
Expand Down
26 changes: 22 additions & 4 deletions bayesflow/utils/dict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ def split_arrays(data: Mapping[str, np.ndarray], axis: int = -1) -> Mapping[str,
result = {}

for key, value in data.items():
if not hasattr(value, "shape"):
result[key] = np.array([value])
continue

if len(value.shape) == 1:
result[key] = value
continue
Expand All @@ -130,8 +134,9 @@ def split_arrays(data: Mapping[str, np.ndarray], axis: int = -1) -> Mapping[str,
def dicts_to_arrays(
targets: Mapping[str, np.ndarray] | np.ndarray,
references: Mapping[str, np.ndarray] | np.ndarray = None,
filter_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
default_name: str = "var",
default_name: str = "v",
) -> Mapping[str, Any]:
"""Helper function that prepares estimates and optional ground truths for diagnostics
(plotting or computation of metrics).
Expand Down Expand Up @@ -175,13 +180,23 @@ def dicts_to_arrays(

# Case dictionaries provided
if isinstance(targets, dict):
if filter_keys is not None:
targets = {k: targets[k] for k in filter_keys}

# to ensure safe subsetting of references if specified
filter_keys = targets.keys()

targets = split_arrays(targets)
variable_names = list(targets.keys()) if variable_names is None else variable_names
targets = np.stack([v for k, v in targets.items() if k in variable_names], axis=-1)

if variable_names is None:
variable_names = list(targets.keys())

targets = np.stack(list(targets.values()), axis=-1)

if references is not None:
references = {k: references[k] for k in filter_keys}
references = split_arrays(references)
references = np.stack([v for k, v in references.items() if k in variable_names], axis=-1)
references = np.stack(list(references.values()), axis=-1)

# Case arrays provided
elif isinstance(targets, np.ndarray):
Expand All @@ -194,6 +209,9 @@ def dicts_to_arrays(
f"Only dicts and tensors are supported as arguments, but your targets are of type {type(targets)}"
)

if len(variable_names) is not targets.shape[-1]:
raise ValueError("Length of 'variable_names' should be the same as the number of variables.")

return dict(
targets=targets,
references=references,
Expand Down
7 changes: 6 additions & 1 deletion bayesflow/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
def prepare_plot_data(
targets: Mapping[str, np.ndarray] | np.ndarray,
references: Mapping[str, np.ndarray] | np.ndarray,
filter_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
num_col: int = None,
num_row: int = None,
Expand Down Expand Up @@ -50,7 +51,11 @@ def prepare_plot_data(
"""

plot_data = dicts_to_arrays(
targets=targets, references=references, variable_names=variable_names, default_name=default_name
targets=targets,
references=references,
filter_keys=filter_keys,
variable_names=variable_names,
default_name=default_name,
)
check_estimates_prior_shapes(plot_data["targets"], plot_data["references"])

Expand Down

0 comments on commit d7b89ee

Please sign in to comment.