This repository was archived by the owner on Dec 5, 2024. It is now read-only.

Variables refactor #136

Merged 35 commits on May 2, 2022
Commits (35)
66bbcbc Rewrite NDVariableArray (antoine-dedieu, Apr 12, 2022)
d816f63 Falke8 (antoine-dedieu, Apr 12, 2022)
58b0115 Test + HashableDict (antoine-dedieu, Apr 12, 2022)
c8f2c5e Minbor (antoine-dedieu, Apr 13, 2022)
39b546a Variables as tuple + Remove BPOuputs/HashableDict (antoine-dedieu, Apr 20, 2022)
ef92d1b Start tests + mypy (antoine-dedieu, Apr 21, 2022)
7e55e5c Tests passing (antoine-dedieu, Apr 21, 2022)
9a2dcd2 Variables (antoine-dedieu, Apr 21, 2022)
c6ae8d8 Tests + mypy (antoine-dedieu, Apr 21, 2022)
5c8b381 Some docstrings (antoine-dedieu, Apr 22, 2022)
40ec519 Stannis first comments (antoine-dedieu, Apr 22, 2022)
96c7fe3 Remove add_factor (antoine-dedieu, Apr 22, 2022)
88f8e23 Test (antoine-dedieu, Apr 23, 2022)
033d176 Docstring (antoine-dedieu, Apr 23, 2022)
89767c8 Coverage (antoine-dedieu, Apr 25, 2022)
1ccfcf5 Coverage (antoine-dedieu, Apr 25, 2022)
ecaab6c Coverage 100% (antoine-dedieu, Apr 25, 2022)
cbd136b Remove factor group names (antoine-dedieu, Apr 25, 2022)
22b604e Remove factor group names (antoine-dedieu, Apr 25, 2022)
87fcfd9 Modify hash + add_factors (antoine-dedieu, Apr 26, 2022)
0f639fe Stannis' comments (antoine-dedieu, Apr 26, 2022)
546b790 Flattent / unflattent (antoine-dedieu, Apr 26, 2022)
35ce6c0 Unflatten with nan (antoine-dedieu, Apr 26, 2022)
704ee3a Speeding up (antoine-dedieu, Apr 27, 2022)
aaa67ef max size (antoine-dedieu, Apr 27, 2022)
04c4d89 Understand timings (antoine-dedieu, Apr 27, 2022)
a276ce5 Some comments (antoine-dedieu, Apr 27, 2022)
a839be6 Comments (antoine-dedieu, Apr 27, 2022)
05de350 Minor (antoine-dedieu, Apr 27, 2022)
1e0c93d Docstring (antoine-dedieu, Apr 27, 2022)
63c6738 Minor changes (antoine-dedieu, Apr 28, 2022)
0397f9b Doc (antoine-dedieu, Apr 28, 2022)
8b3d60e Rename this_hash (StannisZhou, Apr 30, 2022)
8751e90 Final comments (antoine-dedieu, May 2, 2022)
b265f14 Minor (antoine-dedieu, May 2, 2022)
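
Taken together, these commits replace PGMax's name-based API with direct variable and factor-group objects: `FactorGraph` now takes `variable_groups`, factor groups are constructed directly and added with `fg.add_factors`, and evidence updates, log-potential updates, and beliefs are keyed by the objects themselves instead of by string names or `None`. A minimal before/after sketch of the pattern, inferred from the example diffs below (the import paths follow the pre-refactor PGMax layout and are an assumption; the diffs show the exact calls):

```python
import jax
import numpy as np

# Assumed import paths; see the example diffs for the modules actually used.
from pgmax.fg import graph
from pgmax.groups import enumeration
from pgmax.groups import variables as vgroup

variables = vgroup.NDVariableArray(num_states=2, shape=(3, 3))

# Old API (removed by this PR): factors referenced variables by index
# tuples and were registered under a string name.
#   fg = graph.FactorGraph(variables=variables)
#   fg.add_factor_group(
#       factory=enumeration.PairwiseFactorGroup,
#       variable_names_for_factors=[[(0, 0), (0, 1)]],
#       name="pair",
#   )

# New API: build the factor group from the variable objects themselves.
fg = graph.FactorGraph(variable_groups=variables)
pair = enumeration.PairwiseFactorGroup(
    variables_for_factors=[[variables[0, 0], variables[0, 1]]],
    log_potential_matrix=np.array([[1.0, -1.0], [-1.0, 1.0]]),
)
fg.add_factors(pair)

# Updates and results are keyed by the objects, not by strings:
bp = graph.BP(fg.bp_state, temperature=1.0)
bp_arrays = bp.run_bp(
    bp.init(
        evidence_updates={variables: jax.device_put(np.random.gumbel(size=(3, 3, 2)))}
    ),
    num_iters=10,
    damping=0.0,
)
beliefs = bp.get_beliefs(bp_arrays)[variables]
```

Making the groups hashable so they can serve as dictionary keys (the "Modify hash" commit) is what removes the need for a global name registry.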
79 changes: 46 additions & 33 deletions examples/gmrf.py
@@ -38,8 +38,8 @@

# %%
# Load saved log potentials
log_potentials = dict(**np.load("example_data/gmrf_log_potentials.npz"))
n_clones = log_potentials.pop("n_clones")
grmf_log_potentials = dict(**np.load("example_data/gmrf_log_potentials.npz"))
n_clones = grmf_log_potentials.pop("n_clones")
p_contour = jax.device_put(np.repeat(data["p_contour"], n_clones))
prototype_targets = jax.device_put(
np.array(
@@ -58,42 +58,54 @@
fg = graph.FactorGraph(variables)

# %%
# Add top-down factors
fg.add_factor_group(
factory=enumeration.PairwiseFactorGroup,
variable_names_for_factors=[
[(ii, jj), (ii + 1, jj)] for ii in range(M - 1) for jj in range(N)
# Create top-down factors
top_down = enumeration.PairwiseFactorGroup(
variables_for_factors=[
[variables[ii, jj], variables[ii + 1, jj]]
for ii in range(M - 1)
for jj in range(N)
],
name="top_down",
)
# Add left-right factors
fg.add_factor_group(
factory=enumeration.PairwiseFactorGroup,
variable_names_for_factors=[
[(ii, jj), (ii, jj + 1)] for ii in range(M) for jj in range(N - 1)

# Create left-right factors
left_right = enumeration.PairwiseFactorGroup(
variables_for_factors=[
[variables[ii, jj], variables[ii, jj + 1]]
for ii in range(M)
for jj in range(N - 1)
],
name="left_right",
)
# Add diagonal factors
fg.add_factor_group(
factory=enumeration.PairwiseFactorGroup,
variable_names_for_factors=[
[(ii, jj), (ii + 1, jj + 1)] for ii in range(M - 1) for jj in range(N - 1)

# Create diagonal factors
diagonal0 = enumeration.PairwiseFactorGroup(
variables_for_factors=[
[variables[ii, jj], variables[ii + 1, jj + 1]]
for ii in range(M - 1)
for jj in range(N - 1)
],
name="diagonal0",
)
fg.add_factor_group(
factory=enumeration.PairwiseFactorGroup,
variable_names_for_factors=[
[(ii, jj), (ii - 1, jj + 1)] for ii in range(1, M) for jj in range(N - 1)
diagonal1 = enumeration.PairwiseFactorGroup(
variables_for_factors=[
[variables[ii, jj], variables[ii - 1, jj + 1]]
for ii in range(1, M)
for jj in range(N - 1)
],
name="diagonal1",
)

# Add factors
fg.add_factors([top_down, left_right, diagonal0, diagonal1])

# %%
bp = graph.BP(fg.bp_state, temperature=1.0)

# %%
log_potentials = {
top_down: grmf_log_potentials["top_down"],
left_right: grmf_log_potentials["left_right"],
diagonal0: grmf_log_potentials["diagonal0"],
diagonal1: grmf_log_potentials["diagonal1"],
}

n_plots = 5
indices = np.random.permutation(noisy_images.shape[0])[:n_plots]
fig, ax = plt.subplots(n_plots, 3, figsize=(30, 10 * n_plots))
@@ -106,14 +118,15 @@
bp.get_beliefs(
bp.run_bp(
bp.init(
evidence_updates={None: evidence},
evidence_updates={variables: evidence},
log_potentials_updates=log_potentials,
),
num_iters=15,
damping=0.0,
)
)
)
)[variables]

pred_image = np.argmax(
np.stack(
[
@@ -153,15 +166,15 @@ def loss(noisy_image, target_image, log_potentials):
bp.get_beliefs(
bp.run_bp(
bp.init(
evidence_updates={None: evidence},
evidence_updates={variables: evidence},
log_potentials_updates=log_potentials,
),
num_iters=15,
damping=0.0,
)
)
)
logp = jnp.mean(jnp.log(jnp.sum(target * marginals, axis=-1)))
logp = jnp.mean(jnp.log(jnp.sum(target * marginals[variables], axis=-1)))
return -logp


@@ -191,10 +204,10 @@ def update(step, batch_noisy_images, batch_target_images, opt_state):
# %%
opt_state = init_fun(
{
"top_down": np.random.randn(num_states, num_states),
"left_right": np.random.randn(num_states, num_states),
"diagonal0": np.random.randn(num_states, num_states),
"diagonal1": np.random.randn(num_states, num_states),
top_down: np.random.randn(num_states, num_states),
left_right: np.random.randn(num_states, num_states),
diagonal0: np.random.randn(num_states, num_states),
diagonal1: np.random.randn(num_states, num_states),
}
)

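The `update` step above optimizes the log potentials as a dict keyed by the factor-group objects. A hedged sketch of reading the trained values back, assuming the optimizer is the usual `(init_fun, update_fun, get_params)` triple from `jax.example_libraries.optimizers` (the diff shows `init_fun` and `update` but not the optimizer's construction):

```python
from jax.example_libraries import optimizers

# Assumption: the example builds its optimizer along these lines.
init_fun, update_fun, get_params = optimizers.adam(step_size=1e-3)

# After training, parameters come back keyed by factor-group objects:
trained = get_params(opt_state)
top_down_potentials = trained[top_down]  # shape (num_states, num_states)
```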
40 changes: 21 additions & 19 deletions examples/ising_model.py
@@ -29,21 +29,21 @@

# %%
variables = vgroup.NDVariableArray(num_states=2, shape=(50, 50))
fg = graph.FactorGraph(variables=variables)
variable_names_for_factors = []
fg = graph.FactorGraph(variable_groups=variables)

variables_for_factors = []
for ii in range(50):
for jj in range(50):
kk = (ii + 1) % 50
ll = (jj + 1) % 50
variable_names_for_factors.append([(ii, jj), (kk, jj)])
variable_names_for_factors.append([(ii, jj), (ii, ll)])
variables_for_factors.append([variables[ii, jj], variables[kk, jj]])
variables_for_factors.append([variables[ii, jj], variables[ii, ll]])

fg.add_factor_group(
factory=enumeration.PairwiseFactorGroup,
variable_names_for_factors=variable_names_for_factors,
factor_group = enumeration.PairwiseFactorGroup(
variables_for_factors=variables_for_factors,
log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]),
name="factors",
)
fg.add_factors(factor_group)

# %% [markdown]
# ### Run inference and visualize results
@@ -53,12 +53,13 @@

# %%
bp_arrays = bp.init(
evidence_updates={None: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
evidence_updates={variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
)
bp_arrays = bp.run_bp(bp_arrays, num_iters=3000)
beliefs = bp.get_beliefs(bp_arrays)

# %%
img = graph.decode_map_states(bp.get_beliefs(bp_arrays))
img = graph.decode_map_states(beliefs)[variables]
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
ax.imshow(img)

@@ -73,19 +74,20 @@ def loss(log_potentials_updates, evidence_updates):
)
bp_arrays = bp.run_bp(bp_arrays, num_iters=3000)
beliefs = bp.get_beliefs(bp_arrays)
loss = -jnp.sum(beliefs)
loss = -jnp.sum(beliefs[variables])
return loss


batch_loss = jax.jit(jax.vmap(loss, in_axes=(None, {None: 0}), out_axes=0))
batch_loss = jax.jit(jax.vmap(loss, in_axes=(None, {variables: 0}), out_axes=0))
log_potentials_grads = jax.jit(jax.grad(loss, argnums=0))

# %%
batch_loss(None, {None: jax.device_put(np.random.gumbel(size=(10, 50, 50, 2)))})
batch_loss(None, {variables: jax.device_put(np.random.gumbel(size=(10, 50, 50, 2)))})

# %%
grads = log_potentials_grads(
{"factors": jnp.eye(2)}, {None: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
{factor_group: jnp.eye(2)},
{variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))},
)

# %% [markdown]
@@ -95,15 +97,15 @@ def loss(log_potentials_updates, evidence_updates):
bp_state = bp.to_bp_state(bp_arrays)

# Query evidence for variable (0, 0)
bp_state.evidence[0, 0]
bp_state.evidence[variables[0, 0]]

# %%
# Set evidence for variable (0, 0)
bp_state.evidence[0, 0] = np.array([1.0, 1.0])
bp_state.evidence[0, 0]
bp_state.evidence[variables[0, 0]] = np.array([1.0, 1.0])
bp_state.evidence[variables[0, 0]]

# %%
# Set evidence for all variables using an array
evidence = np.random.randn(50, 50, 2)
bp_state.evidence[None] = evidence
bp_state.evidence[10, 10] == evidence[10, 10]
bp_state.evidence[variables] = evidence
np.allclose(bp_state.evidence[variables[10, 10]], evidence[10, 10])
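
Since `log_potentials_grads` returns its gradients in a dict keyed by the factor-group object, a gradient step needs no name lookups. A minimal sketch of one manual descent step, reusing the `loss`, `factor_group`, and `variables` defined above (the step size is illustrative):

```python
# Sketch: one gradient-descent step on the pairwise log potentials.
log_potentials = {factor_group: jnp.eye(2)}
evidence = {variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
grads = log_potentials_grads(log_potentials, evidence)
log_potentials = {k: v - 1e-2 * grads[k] for k, v in log_potentials.items()}
```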
71 changes: 30 additions & 41 deletions examples/pmp_binary_deconvolution.py
@@ -134,15 +134,15 @@ def plot_images(images, display=True, nr=None):
X = vgroup.NDVariableArray(num_states=2, shape=X_gt.shape)

# %% [markdown]
# For computation efficiency, we add large FactorGroups via `fg.add_factor_group` instead of adding individual Factors
# For computation efficiency, we construct large FactorGroups instead of individual factors

# %%
# Factor graph
fg = graph.FactorGraph(variables=dict(S=S, W=W, SW=SW, X=X))
fg = graph.FactorGraph(variable_groups=[S, W, SW, X])

# Define the ANDFactors
variable_names_for_ANDFactors = []
variable_names_for_ORFactors_dict = defaultdict(list)
variables_for_ANDFactors = []
variables_for_ORFactors_dict = defaultdict(list)
for idx_img in tqdm(range(n_images)):
for idx_chan in range(n_chan):
for idx_s_height in range(s_height):
@@ -152,52 +152,39 @@ def plot_images(images, display=True, nr=None):
for idx_feat_width in range(feat_width):
idx_img_height = idx_feat_height + idx_s_height
idx_img_width = idx_feat_width + idx_s_width
SW_var = (
"SW",
SW_var = SW[
idx_img,
idx_chan,
idx_img_height,
idx_img_width,
idx_feat,
idx_feat_height,
idx_feat_width,
)

variable_names_for_ANDFactor = [
("S", idx_img, idx_feat, idx_s_height, idx_s_width),
(
"W",
idx_chan,
idx_feat,
idx_feat_height,
idx_feat_width,
),
]

variables_for_ANDFactor = [
S[idx_img, idx_feat, idx_s_height, idx_s_width],
W[idx_chan, idx_feat, idx_feat_height, idx_feat_width],
SW_var,
]
variable_names_for_ANDFactors.append(
variable_names_for_ANDFactor
)
variables_for_ANDFactors.append(variables_for_ANDFactor)

X_var = (idx_img, idx_chan, idx_img_height, idx_img_width)
variable_names_for_ORFactors_dict[X_var].append(SW_var)
X_var = X[idx_img, idx_chan, idx_img_height, idx_img_width]
variables_for_ORFactors_dict[X_var].append(SW_var)

# Add ANDFactorGroup, which is computationally efficient
fg.add_factor_group(
factory=logical.ANDFactorGroup,
variable_names_for_factors=variable_names_for_ANDFactors,
)
AND_factor_group = logical.ANDFactorGroup(variables_for_ANDFactors)
fg.add_factors(AND_factor_group)

# Define the ORFactors
variable_names_for_ORFactors = [
list(tuple(variable_names_for_ORFactors_dict[X_var]) + (("X",) + X_var,))
for X_var in variable_names_for_ORFactors_dict
variables_for_ORFactors = [
list(tuple(variables_for_ORFactors_dict[X_var]) + (X_var,))
for X_var in variables_for_ORFactors_dict
]

# Add ORFactorGroup, which is computationally efficient
fg.add_factor_group(
factory=logical.ORFactorGroup,
variable_names_for_factors=variable_names_for_ORFactors,
)
OR_factor_group = logical.ORFactorGroup(variables_for_ORFactors)
fg.add_factors(OR_factor_group)

for factor_type, factor_groups in fg.factor_groups.items():
if len(factor_groups) > 0:
@@ -222,7 +209,7 @@ def plot_images(images, display=True, nr=None):

# %%
pW = 0.25
pS = 1e-100
pS = 1e-75
pX = 1e-100

# Sparsity inducing priors for W and S
@@ -237,25 +224,27 @@ def plot_images(images, display=True, nr=None):
uX[..., 0] = (2 * X_gt - 1) * logit(pX)

# %% [markdown]
# We draw a batch of samples from the posterior in parallel by transforming `run_bp`/`get_beliefs` with `jax.vmap`
# We draw a batch of samples from the posterior in parallel by transforming `bp.init`/`bp.run_bp`/`bp.get_beliefs` with `jax.vmap`

# %%
np.random.seed(seed=40)
np.random.seed(seed=0)
n_samples = 4

bp_arrays = jax.vmap(bp.init, in_axes=0, out_axes=0)(
evidence_updates={
"S": uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape),
"W": uW[None] + np.random.gumbel(size=(n_samples,) + uW.shape),
"SW": np.zeros(shape=(n_samples,) + SW.shape),
"X": uX[None] + np.zeros(shape=(n_samples,) + uX.shape),
S: uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape),
W: uW[None] + np.random.gumbel(size=(n_samples,) + uW.shape),
SW: np.zeros(shape=(n_samples,) + SW.shape),
X: uX[None] + np.zeros(shape=(n_samples,) + uX.shape),
},
)

bp_arrays = jax.vmap(
functools.partial(bp.run_bp, num_iters=100, damping=0.5),
in_axes=0,
out_axes=0,
)(bp_arrays)

beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
map_states = graph.decode_map_states(beliefs)

Expand All @@ -265,4 +254,4 @@ def plot_images(images, display=True, nr=None):
# Because we have used one extra feature for inference, each posterior sample recovers the 4 basic features used to generate the images, and includes an extra symbol.

# %%
_ = plot_images(map_states["W"].reshape(-1, feat_height, feat_width), nr=n_samples)
_ = plot_images(map_states[W].reshape(-1, feat_height, feat_width), nr=n_samples)
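
Because `bp.init`, `bp.run_bp`, and `bp.get_beliefs` are all transformed with `jax.vmap` over the sample axis, indexing the decoded states by a variable group returns one array per sample. A small sanity-check sketch, assuming the objects defined in this example:

```python
# Sketch: map_states is keyed by variable groups; each entry carries the
# vmapped sample axis in front of the group's own shape.
assert map_states[W].shape == (n_samples,) + W.shape
assert map_states[S].shape == (n_samples,) + S.shape
```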