Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.

Make BP closer to jax optimizer #135

Merged
merged 9 commits into from
Apr 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions examples/gmrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
)

# %%
run_bp, _, get_beliefs = graph.BP(fg.bp_state, 15, 1.0)
bp = graph.BP(fg.bp_state, temperature=1.0)

# %%
n_plots = 5
Expand All @@ -103,10 +103,13 @@
evidence = jnp.log(jnp.where(noisy_image[..., None] == 0, p_contour, 1 - p_contour))
target = prototype_targets[target_image]
marginals = graph.get_marginals(
get_beliefs(
run_bp(
evidence_updates={None: evidence},
log_potentials_updates=log_potentials,
bp.get_beliefs(
bp.run_bp(
bp.init(
evidence_updates={None: evidence},
log_potentials_updates=log_potentials,
),
num_iters=15,
damping=0.0,
)
)
Expand Down Expand Up @@ -147,10 +150,13 @@ def loss(noisy_image, target_image, log_potentials):
evidence = jnp.log(jnp.where(noisy_image[..., None] == 0, p_contour, 1 - p_contour))
target = prototype_targets[target_image]
marginals = graph.get_marginals(
get_beliefs(
run_bp(
evidence_updates={None: evidence},
log_potentials_updates=log_potentials,
bp.get_beliefs(
bp.run_bp(
bp.init(
evidence_updates={None: evidence},
log_potentials_updates=log_potentials,
),
num_iters=15,
damping=0.0,
)
)
Expand Down
15 changes: 9 additions & 6 deletions examples/ising_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,16 @@
# ### Run inference and visualize results

# %%
bp_state = fg.bp_state
run_bp, _, get_beliefs = graph.BP(bp_state, 3000)
bp = graph.BP(fg.bp_state, temperature=0)

# %%
bp_arrays = run_bp(
bp_arrays = bp.init(
evidence_updates={None: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
)
bp_arrays = bp.run_bp(bp_arrays, num_iters=3000)

# %%
img = graph.decode_map_states(get_beliefs(bp_arrays))
img = graph.decode_map_states(bp.get_beliefs(bp_arrays))
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(img)

Expand All @@ -68,10 +68,11 @@

# %%
def loss(log_potentials_updates, evidence_updates):
bp_arrays = run_bp(
bp_arrays = bp.init(
log_potentials_updates=log_potentials_updates, evidence_updates=evidence_updates
)
beliefs = get_beliefs(bp_arrays)
bp_arrays = bp.run_bp(bp_arrays, num_iters=3000)
beliefs = bp.get_beliefs(bp_arrays)
loss = -jnp.sum(beliefs)
return loss

Expand All @@ -91,6 +92,8 @@ def loss(log_potentials_updates, evidence_updates):
# ### Message and evidence manipulation

# %%
bp_state = bp.to_bp_state(bp_arrays)

# Query evidence for variable (0, 0)
bp_state.evidence[0, 0]

Expand Down
11 changes: 8 additions & 3 deletions examples/pmp_binary_deconvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def plot_images(images, display=True, nr=None):
# in the same manner does not change X, so this naturally results in multiple equivalent modes.

# %%
run_bp, get_bp_state, get_beliefs = graph.BP(fg.bp_state, 3000)
bp = graph.BP(fg.bp_state, temperature=0.0)

# %% [markdown]
# We first compute the evidence without perturbation, similar to the PMP paper.
Expand Down Expand Up @@ -243,15 +243,20 @@ def plot_images(images, display=True, nr=None):
np.random.seed(seed=40)
n_samples = 4

bp_arrays = jax.vmap(functools.partial(run_bp, damping=0.5), in_axes=0, out_axes=0)(
bp_arrays = jax.vmap(bp.init, in_axes=0, out_axes=0)(
evidence_updates={
"S": uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape),
"W": uW[None] + np.random.gumbel(size=(n_samples,) + uW.shape),
"SW": np.zeros(shape=(n_samples,) + SW.shape),
"X": uX[None] + np.zeros(shape=(n_samples,) + uX.shape),
},
)
beliefs = jax.vmap(get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
bp_arrays = jax.vmap(
functools.partial(bp.run_bp, num_iters=100, damping=0.5),
in_axes=0,
out_axes=0,
)(bp_arrays)
beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
map_states = graph.decode_map_states(beliefs)

# %% [markdown]
Expand Down
36 changes: 22 additions & 14 deletions examples/rbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@
#
# An alternative way of creating the above factors is to add them iteratively by calling [`fg.add_factor`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph.add_factor) as below. This approach is not recommended as it is not computationally efficient.
# ~~~python
# import itertools
# from tqdm import tqdm
#
# # Add unary factors
# for ii in range(bh.shape[0]):
# fg.add_factor(
Expand Down Expand Up @@ -126,36 +129,34 @@
#
# Once we have added the factors, we can run max-product LBP and get MAP decoding by
# ~~~python
# run_bp, get_bp_state, get_beliefs = graph.BP(fg.bp_state, num_iters=100, temperature=0.0)
# bp_arrays = run_bp(damping=0.5)
# beliefs = get_beliefs(bp_arrays)
# bp = graph.BP(fg.bp_state, temperature=0.0)
# bp_arrays = bp.run_bp(bp.init(), num_iters=100, damping=0.5)
# beliefs = bp.get_beliefs(bp_arrays)
# map_states = graph.decode_map_states(beliefs)
# ~~~
# and run sum-product LBP and get estimated marginals by
# ~~~python
# run_bp, get_bp_state, get_beliefs = graph.BP(fg.bp_state, num_iters=100, temperature=1.0)
# bp_arrays = run_bp(damping=0.5)
# beliefs = get_beliefs(bp_arrays)
# bp = graph.BP(fg.bp_state, temperature=1.0)
# bp_arrays = bp.run_bp(bp.init(), num_iters=100, damping=0.5)
# beliefs = bp.get_beliefs(bp_arrays)
# marginals = graph.get_marginals(beliefs)
# ~~~
# More generally, PGMax implements LBP with temperature, with `temperature=0.0` and `temperature=1.0` corresponding to the commonly used max/sum-product LBP respectively.
#
# Now we are ready to demonstrate PMP sampling from RBM. PMP perturbs the model with [Gumbel](https://numpy.org/doc/stable/reference/random/generated/numpy.random.gumbel.html) unary potentials, and draws a sample from the RBM as the MAP decoding from running max-product LBP on the perturbed model

# %%
run_bp, get_bp_state, get_beliefs = graph.BP(
fg.bp_state, num_iters=100, temperature=0.0
)
bp = graph.BP(fg.bp_state, temperature=0.0)

# %%
bp_arrays = run_bp(
bp_arrays = bp.init(
evidence_updates={
"hidden": np.random.gumbel(size=(bh.shape[0], 2)),
"visible": np.random.gumbel(size=(bv.shape[0], 2)),
},
damping=0.5,
)
beliefs = get_beliefs(bp_arrays)
bp_arrays = bp.run_bp(bp_arrays, num_iters=100, damping=0.5)
beliefs = bp.get_beliefs(bp_arrays)
map_states = graph.decode_map_states(beliefs)

# %% [markdown]
Expand Down Expand Up @@ -191,13 +192,18 @@

# %%
n_samples = 10
bp_arrays = jax.vmap(functools.partial(run_bp, damping=0.5), in_axes=0, out_axes=0)(
bp_arrays = jax.vmap(bp.init, in_axes=0, out_axes=0)(
evidence_updates={
"hidden": np.random.gumbel(size=(n_samples, bh.shape[0], 2)),
"visible": np.random.gumbel(size=(n_samples, bv.shape[0], 2)),
},
)
beliefs = jax.vmap(get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
bp_arrays = jax.vmap(
functools.partial(bp.run_bp, num_iters=100, damping=0.5),
in_axes=0,
out_axes=0,
)(bp_arrays)
beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
map_states = graph.decode_map_states(beliefs)

# %% [markdown]
Expand All @@ -212,3 +218,5 @@
ax[np.unravel_index(ii, (2, 5))].axis("off")

fig.tight_layout()

# %%
7 changes: 3 additions & 4 deletions examples/rcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def get_evidence(bu_msg: np.ndarray, frc: np.ndarray) -> np.ndarray:

# %%
frcs_dict = {model_idx: frcs[model_idx] for model_idx in range(frcs.shape[0])}
run_bp, _, get_beliefs = graph.BP(fg.bp_state, 30)
bp = graph.BP(fg.bp_state, temperature=0.0)
scores = np.zeros((len(test_set), frcs.shape[0]))
map_states_dict = {}

Expand All @@ -398,9 +398,8 @@ def get_evidence(bu_msg: np.ndarray, frc: np.ndarray) -> np.ndarray:
print(f"Initializing evidences took {end-start:.3f} seconds for image {test_idx}.")

start = end
map_states = graph.decode_map_states(
get_beliefs(run_bp(evidence_updates=evidence_updates))
)
bp_arrays = bp.run_bp(bp.init(evidence_updates=evidence_updates), num_iters=30)
map_states = graph.decode_map_states(bp.get_beliefs(bp_arrays))
end = time.time()
print(f"Max product inference took {end-start:.3f} seconds for image {test_idx}.")

Expand Down
Loading