Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.

Using PGMax to solve the PMP 2D blind deconvolution exp #127

Merged
merged 6 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,7 @@ Here are a few self-contained Colab notebooks to help you get started on using P
- [Tutorial on basic PGMax usage](https://colab.research.google.com/drive/1PQ9eVaOg336XzPqko-v_us3izEbjvWMW?usp=sharing)
- [Implementing max-product LBP](https://colab.research.google.com/drive/1mSffrA1WgQwgIiJQd2pLULPa5YKAOJOX?usp=sharing) for [Recursive Cortical Networks](https://www.science.org/doi/10.1126/science.aag2612)
- [End-to-end differentiable LBP for gradient-based PGM training](https://colab.research.google.com/drive/1yxDCLwhX0PVgFS7NHUcXG3ptMAY1CxMC?usp=sharing)


- [2D binary deconvolution](https://colab.research.google.com/drive/1w_ufQz0u18V_paM8pI97CO11965MduO4?usp=sharing)

## Citing PGMax

Expand Down
Binary file added examples/example_data/conv_problem.npz
Binary file not shown.
253 changes: 253 additions & 0 deletions examples/pmp_binary_deconvolution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
# ---
# jupyter:
# jupytext:
# formats: ipynb,py:percent
# text_representation:
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.13.7
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---

# %% [markdown]
# We use PGMax to reimplement the binary deconvolution experiment presented in the Section 5.6 of the [Perturb-and-max-product (PMP)](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) Neurips 2021 paper.
#
# The original implementation is available on the [GitHub repository of the paper.](https://github.com/vicariousinc/perturb_and_max_product/blob/master/experiments/exp6_convor.py)

# %%
import functools

import jax
import matplotlib.pyplot as plt
import numpy as np
from scipy.special import logit
from tqdm.notebook import tqdm

from pgmax.factors import logical
from pgmax.fg import graph, groups


# %%
def plot_images(images, display=True, nr=None):
"Useful function for visualizing several images"
n_images, H, W = images.shape
images = images - images.min()
images /= images.max() + 1e-10

if nr is None:
nr = nc = np.ceil(np.sqrt(n_images)).astype(int)
else:
nc = n_images // nr
assert n_images == nr * nc
big_image = np.ones(((H + 1) * nr + 1, (W + 1) * nc + 1, 3))
big_image[..., :3] = 0
big_image[:: H + 1] = [0.5, 0, 0.5]

im = 0
for r in range(nr):
for c in range(nc):
if im < n_images:
big_image[
(H + 1) * r + 1 : (H + 1) * r + 1 + H,
(W + 1) * c + 1 : (W + 1) * c + 1 + W,
:,
] = images[im, :, :, None]
im += 1

if display:
plt.figure(figsize=(10, 10))
plt.imshow(big_image, interpolation="none")
return big_image


# %% [markdown]
# ### Load the data

# %% [markdown]
# Our binary 2D convolution generative model uses two set of binary variables to form a set of binary images X:
# - a set W of 2D binary features shared across images
# - a set S of binary indicator variables representing whether each feature is present at each possible image location.
#
# Each binary entry of W and S is modeled with an independent Bernoulli prior. W and S are then combined by convolution, placing the features defined by W at the locations specified by S in order to form the image.
#
# We load the dataset of 100 images used in the PMP paper.
# We only keep the first 20 images here for the sake of speed.

# %%
data = np.load("example_data/conv_problem.npz")
W_gt = data["W"]
X_gt = data["X"]
X_gt = X_gt[:20]

_ = plot_images(X_gt[:, 0], nr=4)

# %% [markdown]
# We also visualize the four 2D binary features used to generate the images above.
#
# We aim at recovering these binary features using PGMax.

# %%
_ = plot_images(W_gt[0], nr=1)

# %% [markdown]
# ### Construct variable grid, initialize factor graph, and add factors

# %% [markdown]
# Our factor graph naturally includes the binary features W, the binary indicators of features locations S and the binary images obtained by convolution X.
#
# To generate X from W and S, we observe that a binary convolution can be represented by two set of logical factors:
# - a first set of ANDFactors, which combine the joint activations in W and S. We store the children of these ANDFactors in an auxiliary variable SW
# - a second set of ORFactors, which maps SW to X and model (binary) features overlapping.
#
# See Section 5.6 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) for more details.

# %%
# The dimensions of W used for the generation of X were (4, 5, 5) but we set them to (5, 6, 6)
# to simulate a more realistic scenario in which we do not know their ground truth values
n_feat, feat_height, feat_width = 5, 6, 6

n_images, n_chan, im_height, im_width = X_gt.shape
s_height = im_height - feat_height + 1
s_width = im_width - feat_width + 1

# Binary features
W = groups.NDVariableArray(
num_states=2, shape=(n_chan, n_feat, feat_height, feat_width)
)

# Binary indicators of features locations
S = groups.NDVariableArray(num_states=2, shape=(n_images, n_feat, s_height, s_width))

# Auxiliary binary variables combining W and S
SW = groups.NDVariableArray(
num_states=2,
shape=(n_images, n_chan, im_height, im_width, n_feat, feat_height, feat_width),
)

# Binary images obtained by convolution
X = groups.NDVariableArray(num_states=2, shape=X_gt.shape)

# Factor graph
fg = graph.FactorGraph(variables=dict(S=S, W=W, SW=SW, X=X))

# %%
# Note: although adding Factors is currently handled via for loops,
# we have plans to make this more efficient in the near future through the use of FactorGroups

variable_names_for_ORFactors = {}

# Add ANDFactors
for idx_img in tqdm(range(n_images)):
for idx_chan in range(n_chan):
for idx_s_height in range(s_height):
for idx_s_width in range(s_width):
for idx_feat in range(n_feat):
for idx_feat_height in range(feat_height):
for idx_feat_width in range(feat_width):
idx_img_height = idx_feat_height + idx_s_height
idx_img_width = idx_feat_width + idx_s_width
SW_var = (
"SW",
idx_img,
idx_chan,
idx_img_height,
idx_img_width,
idx_feat,
idx_feat_height,
idx_feat_width,
)

variable_names = [
("S", idx_img, idx_feat, idx_s_height, idx_s_width),
(
"W",
idx_chan,
idx_feat,
idx_feat_height,
idx_feat_width,
),
SW_var,
]

fg.add_factor_by_type(
variable_names=variable_names,
factor_type=logical.ANDFactor,
)

X_var = (idx_img, idx_chan, idx_img_height, idx_img_width)
if X_var not in variable_names_for_ORFactors:
variable_names_for_ORFactors[X_var] = [SW_var]
else:
variable_names_for_ORFactors[X_var].append(SW_var)


# Add ORFactors
for X_var, variable_names_for_ORFactor in variable_names_for_ORFactors.items():
fg.add_factor_by_type(
variable_names=variable_names_for_ORFactor + [("X",) + X_var], # type: ignore
factor_type=logical.ORFactor,
)

for factor_type, factors in fg.factors.items():
print(f"The factor graph contains {len(factors)} {factor_type}")

# %% [markdown]
# ### Run inference and visualize results

# %% [markdown]
# PMP perturbs the model by adding Gumbel noise to unary potentials, then samples from the joint posterior *p(W, S | X)*.
#
# Note that this posterior is highly multimodal: permuting the first dimension of W and the second dimension of S
# in the same manner does not change X, so this naturally results in multiple equivalent modes.

# %%
run_bp, get_bp_state, get_beliefs = graph.BP(fg.bp_state, 3000)

# %% [markdown]
# We first compute the evidence without perturbation, similar to the PMP paper.

# %%
pW = 0.25
pS = 1e-72
pX = 1e-100

# Sparsity inducing priors for W and S
uW = np.zeros((W.shape) + (2,))
uW[..., 1] = logit(pW)

uS = np.zeros((S.shape) + (2,))
uS[..., 1] = logit(pS)

# Likelihood the binary images given X
uX = np.zeros((X_gt.shape) + (2,))
uX[..., 0] = (2 * X_gt - 1) * logit(pX)

# %% [markdown]
# We draw a batch of samples from the posterior in parallel by transforming `run_bp`/`get_beliefs` with `jax.vmap`

# %%
np.random.seed(seed=42)
n_samples = 4

bp_arrays = jax.vmap(functools.partial(run_bp, damping=0.5), in_axes=0, out_axes=0)(
evidence_updates={
"S": uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape),
"W": uW[None] + np.random.gumbel(size=(n_samples,) + uW.shape),
"SW": np.zeros(shape=(n_samples,) + SW.shape),
"X": uX[None] + np.zeros(shape=(n_samples,) + uX.shape),
},
)
beliefs = jax.vmap(get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
map_states = graph.decode_map_states(beliefs)

# %% [markdown]
# Visualizing the MAP decoding, we see that we have 4 good random samples (one per row) from the posterior!
#
# Because we have used one extra feature for inference, each posterior sample recovers the 4 basic features used to generate the images, and includes an extra symbol.

# %%
_ = plot_images(map_states["W"].reshape(-1, feat_height, feat_width), nr=n_samples)
6 changes: 3 additions & 3 deletions pgmax/factors/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class LogicalWiring(nodes.Wiring):
parents_edge_states[ii, 0] contains the global ORFactor index,
parents_edge_states[ii, 1] contains the message index of the parent variable's state 0.
Both indices only take into account the LogicalFactors of the same subtype (OR/AND) of the FactorGraph.
The parent variable's state 1 is parents_edge_states[ii, 2] + 1.
The parent variable's state 1 is parents_edge_states[ii, 1] + 1.
children_edge_states: Array of shape (num_factors,)
children_edge_states[ii] contains the message index of the child variable's state 0,
which takes into account all the LogicalFactors of the same subtype (OR/AND) of the FactorGraph.
Expand Down Expand Up @@ -168,7 +168,7 @@ def compile_wiring(
]
)
num_parents = len(self.variables) - 1
relevant_state = (-self.edge_states_offset + 1) / 2
relevant_state = (-self.edge_states_offset + 1) // 2
parents_edge_states = np.vstack(
[
np.zeros(num_parents, dtype=int),
Expand Down Expand Up @@ -238,7 +238,7 @@ def pass_logical_fac_to_var_messages(
parents_edge_states[ii, 1] contains the message index of the parent variable's relevant state.
For ORFactors the relevant state is 0, for ANDFactors the relevant state is 1.
Both indices only take into account the LogicalFactors of the FactorGraph
The parent variable's other state is parents_edge_states[ii, 2] + edge_states_offset
The parent variable's other state is parents_edge_states[ii, 1] + edge_states_offset
children_edge_states: Array of shape (num_factors,)
children_edge_states[ii] contains the message index of the child variable's relevant state.
For ORFactors the relevant state is 0, for ANDFactors the relevant state is 1.
Expand Down