Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.

Modules reorganization #140

Merged
merged 3 commits into from
May 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 10 additions & 12 deletions examples/gmrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,7 @@
from jax.example_libraries import optimizers
from tqdm.notebook import tqdm

from pgmax.fg import graph
from pgmax.groups import enumeration
from pgmax.groups import variables as vgroup
from pgmax import fgraph, fgroup, infer, vgroup

# %% [markdown]
# # Visualize a trained GMRF
Expand Down Expand Up @@ -54,12 +52,12 @@
# %%
M, N = target_images.shape[-2:]
num_states = np.sum(n_clones)
variables = vgroup.NDVariableArray(num_states=num_states, shape=(M, N))
fg = graph.FactorGraph(variables)
variables = vgroup.NDVarArray(num_states=num_states, shape=(M, N))
fg = fgraph.FactorGraph(variables)

# %%
# Create top-down factors
top_down = enumeration.PairwiseFactorGroup(
top_down = fgroup.PairwiseFactorGroup(
variables_for_factors=[
[variables[ii, jj], variables[ii + 1, jj]]
for ii in range(M - 1)
Expand All @@ -68,7 +66,7 @@
)

# Create left-right factors
left_right = enumeration.PairwiseFactorGroup(
left_right = fgroup.PairwiseFactorGroup(
variables_for_factors=[
[variables[ii, jj], variables[ii, jj + 1]]
for ii in range(M)
Expand All @@ -77,14 +75,14 @@
)

# Create diagonal factors
diagonal0 = enumeration.PairwiseFactorGroup(
diagonal0 = fgroup.PairwiseFactorGroup(
variables_for_factors=[
[variables[ii, jj], variables[ii + 1, jj + 1]]
for ii in range(M - 1)
for jj in range(N - 1)
],
)
diagonal1 = enumeration.PairwiseFactorGroup(
diagonal1 = fgroup.PairwiseFactorGroup(
variables_for_factors=[
[variables[ii, jj], variables[ii - 1, jj + 1]]
for ii in range(1, M)
Expand All @@ -96,7 +94,7 @@
fg.add_factors([top_down, left_right, diagonal0, diagonal1])

# %%
bp = graph.BP(fg.bp_state, temperature=1.0)
bp = infer.BP(fg.bp_state, temperature=1.0)

# %%
log_potentials = {
Expand All @@ -114,7 +112,7 @@
target_image = target_images[idx]
evidence = jnp.log(jnp.where(noisy_image[..., None] == 0, p_contour, 1 - p_contour))
target = prototype_targets[target_image]
marginals = graph.get_marginals(
marginals = infer.get_marginals(
bp.get_beliefs(
bp.run_bp(
bp.init(
Expand Down Expand Up @@ -162,7 +160,7 @@
def loss(noisy_image, target_image, log_potentials):
evidence = jnp.log(jnp.where(noisy_image[..., None] == 0, p_contour, 1 - p_contour))
target = prototype_targets[target_image]
marginals = graph.get_marginals(
marginals = infer.get_marginals(
bp.get_beliefs(
bp.run_bp(
bp.init(
Expand Down
14 changes: 6 additions & 8 deletions examples/ising_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,14 @@
import matplotlib.pyplot as plt
import numpy as np

from pgmax.fg import graph
from pgmax.groups import enumeration
from pgmax.groups import variables as vgroup
from pgmax import fgraph, fgroup, infer, vgroup

# %% [markdown]
# ### Construct variable grid, initialize factor graph, and add factors

# %%
variables = vgroup.NDVariableArray(num_states=2, shape=(50, 50))
fg = graph.FactorGraph(variable_groups=variables)
variables = vgroup.NDVarArray(num_states=2, shape=(50, 50))
fg = fgraph.FactorGraph(variable_groups=variables)

variables_for_factors = []
for ii in range(50):
Expand All @@ -39,7 +37,7 @@
variables_for_factors.append([variables[ii, jj], variables[kk, jj]])
variables_for_factors.append([variables[ii, jj], variables[ii, ll]])

factor_group = enumeration.PairwiseFactorGroup(
factor_group = fgroup.PairwiseFactorGroup(
variables_for_factors=variables_for_factors,
log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]),
)
Expand All @@ -49,7 +47,7 @@
# ### Run inference and visualize results

# %%
bp = graph.BP(fg.bp_state, temperature=0)
bp = infer.BP(fg.bp_state, temperature=0)

# %%
bp_arrays = bp.init(
Expand All @@ -59,7 +57,7 @@
beliefs = bp.get_beliefs(bp_arrays)

# %%
img = graph.decode_map_states(beliefs)[variables]
img = infer.decode_map_states(beliefs)[variables]
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(img)

Expand Down
24 changes: 10 additions & 14 deletions examples/pmp_binary_deconvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@
from scipy.special import logit
from tqdm.notebook import tqdm

from pgmax.fg import graph
from pgmax.groups import logical
from pgmax.groups import variables as vgroup
from pgmax import fgraph, fgroup, infer, vgroup


# %%
Expand Down Expand Up @@ -117,28 +115,26 @@ def plot_images(images, display=True, nr=None):
s_width = im_width - feat_width + 1

# Binary features
W = vgroup.NDVariableArray(
num_states=2, shape=(n_chan, n_feat, feat_height, feat_width)
)
W = vgroup.NDVarArray(num_states=2, shape=(n_chan, n_feat, feat_height, feat_width))

# Binary indicators of features locations
S = vgroup.NDVariableArray(num_states=2, shape=(n_images, n_feat, s_height, s_width))
S = vgroup.NDVarArray(num_states=2, shape=(n_images, n_feat, s_height, s_width))

# Auxiliary binary variables combining W and S
SW = vgroup.NDVariableArray(
SW = vgroup.NDVarArray(
num_states=2,
shape=(n_images, n_chan, im_height, im_width, n_feat, feat_height, feat_width),
)

# Binary images obtained by convolution
X = vgroup.NDVariableArray(num_states=2, shape=X_gt.shape)
X = vgroup.NDVarArray(num_states=2, shape=X_gt.shape)

# %% [markdown]
# For computation efficiency, we construct large FactorGroups instead of individual factors

# %%
# Factor graph
fg = graph.FactorGraph(variable_groups=[S, W, SW, X])
fg = fgraph.FactorGraph(variable_groups=[S, W, SW, X])

# Define the ANDFactors
variables_for_ANDFactors = []
Expand Down Expand Up @@ -173,7 +169,7 @@ def plot_images(images, display=True, nr=None):
variables_for_ORFactors_dict[X_var].append(SW_var)

# Add ANDFactorGroup, which is computationally efficient
AND_factor_group = logical.ANDFactorGroup(variables_for_ANDFactors)
AND_factor_group = fgroup.ANDFactorGroup(variables_for_ANDFactors)
fg.add_factors(AND_factor_group)

# Define the ORFactors
Expand All @@ -183,7 +179,7 @@ def plot_images(images, display=True, nr=None):
]

# Add ORFactorGroup, which is computationally efficient
OR_factor_group = logical.ORFactorGroup(variables_for_ORFactors)
OR_factor_group = fgroup.ORFactorGroup(variables_for_ORFactors)
fg.add_factors(OR_factor_group)

for factor_type, factor_groups in fg.factor_groups.items():
Expand All @@ -202,7 +198,7 @@ def plot_images(images, display=True, nr=None):
# in the same manner does not change X, so this naturally results in multiple equivalent modes.

# %%
bp = graph.BP(fg.bp_state, temperature=0.0)
bp = infer.BP(fg.bp_state, temperature=0.0)

# %% [markdown]
# We first compute the evidence without perturbation, similar to the PMP paper.
Expand Down Expand Up @@ -246,7 +242,7 @@ def plot_images(images, display=True, nr=None):
)(bp_arrays)

beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
map_states = graph.decode_map_states(beliefs)
map_states = infer.decode_map_states(beliefs)

# %% [markdown]
# Visualizing the MAP decoding, we see that we have 4 good random samples (one per row) from the posterior!
Expand Down
56 changes: 27 additions & 29 deletions examples/rbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,10 @@
import matplotlib.pyplot as plt
import numpy as np

from pgmax.fg import graph
from pgmax.groups import enumeration
from pgmax.groups import variables as vgroup
from pgmax import fgraph, fgroup, infer, vgroup

# %% [markdown]
# The [`pgmax.fg.graph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.html#module-pgmax.fg.graph) module contains core classes for specifying factor graphs and implementing LBP, while the [`pgmax.fg.groups`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.html#module-pgmax.fg.graph) module contains classes for specifying groups of variables/factors.
# The [`pgmax.fgraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgraph.html#module-pgmax.fgraph) module contains classes for specifying factor graphs, the [`pgmax.fgroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.html#module-pgmax.vgroup) module contains classes for specifying groups of variables, the [`pgmax.vgroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.html#module-pgmax.fgroup) module contains classes for specifying groups of factors and the [`pgmax.infer`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.infer.html#module-pgmax.infer) module containing core functions to perform LBP.
#
# We next load the RBM trained in Sec. 5.5 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) on MNIST digits.

Expand All @@ -47,23 +45,23 @@

# %%
# Initialize factor graph
hidden_variables = vgroup.NDVariableArray(num_states=2, shape=bh.shape)
visible_variables = vgroup.NDVariableArray(num_states=2, shape=bv.shape)
fg = graph.FactorGraph(variable_groups=[hidden_variables, visible_variables])
hidden_variables = vgroup.NDVarArray(num_states=2, shape=bh.shape)
visible_variables = vgroup.NDVarArray(num_states=2, shape=bv.shape)
fg = fgraph.FactorGraph(variable_groups=[hidden_variables, visible_variables])

# %% [markdown]
# [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup) (e.g. an [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)), or a list/dictionary of [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed.
# [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVarArray.html#pgmax.fg.groups.NDVarArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VarGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VarGroup.html#pgmax.fg.groups.VarGroup) (e.g. an [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVarArray.html#pgmax.fg.groups.NDVarArray)), or a list of [`VarGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VarGroup.html#pgmax.fg.groups.VarGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed.
#
# After initialization, `fg` does not have any factors. PGMax supports imperatively adding factors to a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph). We can add the unary and pairwise factors by grouping them using [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup)

# %%
# Create unary factors
hidden_unaries = enumeration.EnumerationFactorGroup(
hidden_unaries = fgroup.EnumFactorGroup(
variables_for_factors=[[hidden_variables[ii]] for ii in range(bh.shape[0])],
factor_configs=np.arange(2)[:, None],
log_potentials=np.stack([np.zeros_like(bh), bh], axis=1),
)
visible_unaries = enumeration.EnumerationFactorGroup(
visible_unaries = fgroup.EnumFactorGroup(
variables_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])],
factor_configs=np.arange(2)[:, None],
log_potentials=np.stack([np.zeros_like(bv), bv], axis=1),
Expand All @@ -78,7 +76,7 @@
for ii in range(bh.shape[0])
for jj in range(bv.shape[0])
]
pairwise_factors = enumeration.PairwiseFactorGroup(
pairwise_factors = fgroup.PairwiseFactorGroup(
variables_for_factors=variables_for_factors,
log_potential_matrix=log_potential_matrix,
)
Expand All @@ -88,67 +86,67 @@


# %% [markdown]
# PGMax implements convenient and computationally efficient [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) for representing groups of similar factors. The code above makes use of [`EnumerationFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.EnumerationFactorGroup.html#pgmax.fg.groups.EnumerationFactorGroup) and [`PairwiseFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.PairwiseFactorGroup.html#pgmax.fg.groups.PairwiseFactorGroup).
# PGMax implements convenient and computationally efficient [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) for representing groups of similar factors. The code above makes use of [`EnumFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.EnumFactorGroup.html#pgmax.fg.groups.EnumFactorGroup) and [`PairwiseFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.PairwiseFactorGroup.html#pgmax.fg.groups.PairwiseFactorGroup).
#
# A [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) takes as argument `variables_for_factors` which is a list of lists of the variables involved in the different factors, and additional arguments specific to each [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) (e.g. `factor_configs` or `log_potential_matrix` here).
#
# In this example, since we construct `fg` with variables `hidden_variables` and `visible_variables`, which are both [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)s, we can refer to the `ii`th hidden variable as `hidden_variables[ii]` and the `jj`th visible variable as `visible_variables[jj]`.
# In this example, since we construct `fg` with variables `hidden_variables` and `visible_variables`, which are both [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVarArray.html#pgmax.fg.groups.NDVarArray)s, we can refer to the `ii`th hidden variable as `hidden_variables[ii]` and the `jj`th visible variable as `visible_variables[jj]`.
#
# An alternative way of creating the above factors is to add them iteratively without building the [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup)s as below. This approach is not recommended as it is not computationally efficient.
# ~~~python
# from pgmax.factors import enumeration as enumeration_factor
# from pgmax import factor
# import itertools
# from tqdm import tqdm
#
# # Add unary factors
# for ii in range(bh.shape[0]):
# factor = enumeration_factor.EnumerationFactor(
# unary_factor = factor.EnumFactor(
# variables=[hidden_variables[ii]],
# factor_configs=np.arange(2)[:, None],
# log_potentials=np.array([0, bh[ii]]),
# )
# fg.add_factors(factor)
# fg.add_factors(unary_factor)
#
# for jj in range(bv.shape[0]):
# factor = enumeration_factor.EnumerationFactor(
# unary_factor = factor.EnumFactor(
# variables=[visible_variables[jj]],
# factor_configs=np.arange(2)[:, None],
# log_potentials=np.array([0, bv[jj]]),
# )
# fg.add_factors(factor)
# fg.add_factors(unary_factor)
#
# # Add pairwise factors
# factor_configs = np.array(list(itertools.product(np.arange(2), repeat=2)))
# for ii in tqdm(range(bh.shape[0])):
# for jj in range(bv.shape[0]):
# factor = enumeration_factor.EnumerationFactor(
# pairwise_factor = factor.EnumFactor(
# variables=[hidden_variables[ii], visible_variables[jj]],
# factor_configs=factor_configs,
# log_potentials=np.array([0, 0, 0, W[ii, jj]]),
# )
# fg.add_factors(factor)
# fg.add_factors(pairwise_factor)
# ~~~
#
# Once we have added the factors, we can run max-product LBP and get MAP decoding by
# ~~~python
# bp = graph.BP(fg.bp_state, temperature=0.0)
# bp = infer.BP(fg.bp_state, temperature=0.0)
# bp_arrays = bp.run_bp(bp.init(), num_iters=100, damping=0.5)
# beliefs = bp.get_beliefs(bp_arrays)
# map_states = graph.decode_map_states(beliefs)
# map_states = infer.decode_map_states(beliefs)
# ~~~
# and run sum-product LBP and get estimated marginals by
# ~~~python
# bp = graph.BP(fg.bp_state, temperature=1.0)
# bp = infer.BP(fg.bp_state, temperature=1.0)
# bp_arrays = bp.run_bp(bp.init(), num_iters=100, damping=0.5)
# beliefs = bp.get_beliefs(bp_arrays)
# marginals = graph.get_marginals(beliefs)
# marginals = infer.get_marginals(beliefs)
# ~~~
# More generally, PGMax implements LBP with temperature, with `temperature=0.0` and `temperature=1.0` corresponding to the commonly used max/sum-product LBP respectively.
#
# Now we are ready to demonstrate PMP sampling from RBM. PMP perturbs the model with [Gumbel](https://numpy.org/doc/stable/reference/random/generated/numpy.random.gumbel.html) unary potentials, and draws a sample from the RBM as the MAP decoding from running max-product LBP on the perturbed model

# %%
bp = graph.BP(fg.bp_state, temperature=0.0)
bp = infer.BP(fg.bp_state, temperature=0.0)

# %%
bp_arrays = bp.init(
Expand All @@ -168,17 +166,17 @@
# %%
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(
graph.decode_map_states(beliefs)[visible_variables].copy().reshape((28, 28)),
infer.decode_map_states(beliefs)[visible_variables].copy().reshape((28, 28)),
cmap="gray",
)
ax.axis("off")

# %% [markdown]
# PGMax adopts a functional interface for implementing LBP: running LBP in PGMax starts with
# ~~~python
# bp = graph.BP(fg.bp_state, temperature=T)
# bp = infer.BP(fg.bp_state, temperature=T)
# ~~~
# where the arguments of the `bp` are several useful functions to run LBP. In particular, `bp.init`, `bp.run_bp`, `bp.get_beliefs` are pure functions with no side-effects. This design choice means that we can easily apply JAX transformations like `jit`/`vmap`/`grad`, etc., to these functions, and additionally allows PGMax to seamlessly interact with other packages in the rapidly growing JAX ecosystem (see [here](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) and [here](https://github.com/n2cholas/awesome-jax)).
# where the arguments of the `this_bp` are several useful functions to run LBP. In particular, `bp.init`, `bp.run_bp`, `bp.get_beliefs` are pure functions with no side-effects. This design choice means that we can easily apply JAX transformations like `jit`/`vmap`/`grad`, etc., to these functions, and additionally allows PGMax to seamlessly interact with other packages in the rapidly growing JAX ecosystem (see [here](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) and [here](https://github.com/n2cholas/awesome-jax)).
#
# As an example of applying `jax.vmap` to `bp.init`/`bp.run_bp`/`bp.get_beliefs` to process a batch of samples/models in parallel, instead of drawing one sample at a time as above, we can draw a batch of samples in parallel as follows:

Expand All @@ -197,7 +195,7 @@
)(bp_arrays)

beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
map_states = graph.decode_map_states(beliefs)
map_states = infer.decode_map_states(beliefs)

# %% [markdown]
# Visualizing the MAP decodings, we see that we have sampled 10 MNIST digits in parallel!
Expand Down
Loading