Skip to content

Commit

Permalink
Merge pull request #116 from stefanradev93/Development
Browse files Browse the repository at this point in the history
Update README.md with forum
  • Loading branch information
stefanradev93 authored Dec 17, 2023
2 parents ec61454 + b098a42 commit f5f7a6a
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 13 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ For starters, check out some of our walk-through notebooks:
7. [Model comparison for cognitive models](examples/Model_Comparison_MPT.ipynb)
8. [Hierarchical model comparison for cognitive models](examples/Hierarchical_Model_Comparison_MPT.ipynb)

## Project Documentation
## Documentation \& Help

The project documentation is available at <https://bayesflow.org>
The project documentation is available at <https://bayesflow.org>. Please use the [BayesFlow Forums](https://discuss.bayesflow.org/) for any BayesFlow-related questions and discussions, and [GitHub Issues](https://github.com/stefanradev93/BayesFlow/issues) for bug reports and feature requests.

## Installation

Expand Down
23 changes: 14 additions & 9 deletions bayesflow/diagnostics.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def plot_recovery(
n_row=None,
xlabel="Ground truth",
ylabel="Estimated",
**kwargs
**kwargs,
):
"""Creates and plots publication-ready recovery plot with true vs. point estimate + uncertainty.
The point estimate can be controlled with the ``point_agg`` argument, and the uncertainty estimate
Expand Down Expand Up @@ -110,7 +110,7 @@ def plot_recovery(
**kwargs : optional
Additional keyword arguments passed to ax.errorbar or ax.scatter.
Example: `rasterized=True` to reduce PDF file size with many dots
Returns
-------
f : plt.Figure - the figure instance for optional saving
Expand Down Expand Up @@ -240,7 +240,7 @@ def plot_z_score_contraction(
tick_fontsize=12,
color="#8f2727",
n_col=None,
n_row=None
n_row=None,
):
"""Implements a graphical check for global model sensitivity by plotting the posterior
z-score over the posterior contraction for each set of posterior samples in ``post_samples``
Expand Down Expand Up @@ -567,7 +567,7 @@ def plot_sbc_histograms(
tick_fontsize=12,
hist_color="#a34f4f",
n_row=None,
n_col=None
n_col=None,
):
"""Creates and plots publication-ready histograms of rank statistics for simulation-based calibration
(SBC) checks according to [1].
Expand Down Expand Up @@ -910,7 +910,7 @@ def plot_losses(
for i, ax in enumerate(looper):
# Plot train curve
ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training")
if moving_average:
if moving_average and train_losses.columns[i] == "Loss":
moving_average_window = int(train_losses.shape[0] * ma_window_fraction)
smoothed_loss = train_losses.iloc[:, i].rolling(window=moving_average_window).mean()
ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)")
Expand All @@ -929,7 +929,7 @@ def plot_losses(
)
# Schmuck
ax.set_xlabel("Training step #", fontsize=label_fontsize)
ax.set_ylabel("Loss value", fontsize=label_fontsize)
ax.set_ylabel("Value", fontsize=label_fontsize)
sns.despine(ax=ax)
ax.grid(alpha=grid_alpha)
ax.set_title(train_losses.columns[i], fontsize=title_fontsize)
Expand Down Expand Up @@ -1061,7 +1061,7 @@ def plot_calibration_curves(
fig_size=None,
color="#8f2727",
n_row=None,
n_col=None
n_col=None,
):
"""Plots the calibration curves, the ECEs and the marginal histograms of predicted posterior model probabilities
for a model comparison problem. The marginal histograms inform about the fraction of predictions in each bin.
Expand Down Expand Up @@ -1114,7 +1114,6 @@ def plot_calibration_curves(
elif n_row is not None and n_col is None:
n_col = int(np.ceil(num_models / n_row))


# Compute calibration
cal_errs, probs_true, probs_pred = expected_calibration_error(true_models, pred_models, num_bins)

Expand Down Expand Up @@ -1273,7 +1272,13 @@ def plot_confusion_matrix(
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
ax.text(
j, i, format(cm[i, j], fmt), fontsize=value_fontsize, ha="center", va="center", color="white" if cm[i, j] > thresh else "black"
j,
i,
format(cm[i, j], fmt),
fontsize=value_fontsize,
ha="center",
va="center",
color="white" if cm[i, j] > thresh else "black",
)
if title:
ax.set_title("Confusion Matrix", fontsize=title_fontsize)
Expand Down
5 changes: 3 additions & 2 deletions bayesflow/summary_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def __init__(
# Construct final attention layer, which will perform cross-attention
# between the outputs ot the self-attention layers and the dynamic template
if bidirectional:
final_input_dim = template_dim*2
final_input_dim = template_dim * 2
else:
final_input_dim = template_dim
self.output_attention = MultiHeadAttentionBlock(
Expand Down Expand Up @@ -184,7 +184,8 @@ def call(self, x, **kwargs):

class SetTransformer(tf.keras.Model):
"""Implements the set transformer architecture from [1] which ultimately represents
a learnable permutation-invariant function.
a learnable permutation-invariant function. Designed to naturally model interactions in
the input set, which may be hard to capture with the simpler ``DeepSet`` architecture.
[1] Lee, J., Lee, Y., Kim, J., Kosiorek, A., Choi, S., & Teh, Y. W. (2019).
Set transformer: A framework for attention-based permutation-invariant neural networks.
Expand Down

0 comments on commit f5f7a6a

Please sign in to comment.