Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.

Speeding up adding factors + computing wiring #129

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
a7fcdf6
Factor groups adding
antoine-dedieu Mar 28, 2022
c10ec7c
Speed up wiring
antoine-dedieu Mar 28, 2022
1d1a0e3
Leave warning for user, dont call __get_item__ internally
antoine-dedieu Mar 28, 2022
b4a7d04
Do not instantiate factors
antoine-dedieu Mar 29, 2022
9c1aeff
Unit test
antoine-dedieu Mar 29, 2022
e7bf493
Unit test
antoine-dedieu Mar 29, 2022
9d1b99f
Add logical factor group
antoine-dedieu Mar 29, 2022
589d488
Enumeration and Logical have same wiring for factors and ggroups
antoine-dedieu Mar 29, 2022
282f1a7
Remove graph
antoine-dedieu Mar 30, 2022
29d66c1
Unaries
antoine-dedieu Mar 30, 2022
25d5268
Simplify
antoine-dedieu Mar 30, 2022
1d4a25d
Groups folder
antoine-dedieu Mar 30, 2022
f8ab1b7
Mypy
antoine-dedieu Mar 30, 2022
971c765
Mypy
antoine-dedieu Mar 30, 2022
8b6ea24
Pytest
antoine-dedieu Mar 30, 2022
3f75e31
Test pgmax
antoine-dedieu Mar 30, 2022
a32b736
Docstring
antoine-dedieu Mar 30, 2022
eb3eaef
Docs
antoine-dedieu Mar 30, 2022
527c15d
Coverage
antoine-dedieu Mar 31, 2022
f3b0062
More tests + coverage at 99%
antoine-dedieu Mar 31, 2022
4b389e0
coverage 100%
antoine-dedieu Mar 31, 2022
46ad92c
Comments - Part 1
antoine-dedieu Apr 1, 2022
5eba118
Comments - Part 2
antoine-dedieu Apr 1, 2022
c970d67
Wiring function only done once
antoine-dedieu Apr 1, 2022
3d80958
Commets - part3
antoine-dedieu Apr 1, 2022
67d441c
Mypy
StannisZhou Apr 4, 2022
e604c49
Fixes
StannisZhou Apr 4, 2022
06f466d
Fix black import error
StannisZhou Apr 4, 2022
85a6d9f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 4, 2022
ddc3f08
Type hints
StannisZhou Apr 4, 2022
5bda961
Docstring
StannisZhou Apr 4, 2022
46f2ab0
Graph updates
StannisZhou Apr 4, 2022
94e27e8
Final commit
antoine-dedieu Apr 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ repos:
- id: isort

- repo: https://github.com/psf/black
rev: 21.12b0
rev: 22.3.0
hooks:
- id: black
language_version: python3.7
Expand Down
16 changes: 9 additions & 7 deletions examples/gmrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.13.2
# jupytext_version: 1.13.7
# kernelspec:
# display_name: Python 3 (ipykernel)
# language: python
Expand All @@ -22,7 +22,9 @@
from jax.example_libraries import optimizers
from tqdm.notebook import tqdm

from pgmax.fg import graph, groups
from pgmax.fg import graph
from pgmax.groups import enumeration
from pgmax.groups import variables as vgroup

# %% [markdown]
# # Visualize a trained GMRF
Expand Down Expand Up @@ -52,36 +54,36 @@
# %%
M, N = target_images.shape[-2:]
num_states = np.sum(n_clones)
variables = groups.NDVariableArray(num_states=num_states, shape=(M, N))
variables = vgroup.NDVariableArray(num_states=num_states, shape=(M, N))
fg = graph.FactorGraph(variables)

# %%
# Add top-down factors
fg.add_factor_group(
factory=groups.PairwiseFactorGroup,
factory=enumeration.PairwiseFactorGroup,
variable_names_for_factors=[
[(ii, jj), (ii + 1, jj)] for ii in range(M - 1) for jj in range(N)
],
name="top_down",
)
# Add left-right factors
fg.add_factor_group(
factory=groups.PairwiseFactorGroup,
factory=enumeration.PairwiseFactorGroup,
variable_names_for_factors=[
[(ii, jj), (ii, jj + 1)] for ii in range(M) for jj in range(N - 1)
],
name="left_right",
)
# Add diagonal factors
fg.add_factor_group(
factory=groups.PairwiseFactorGroup,
factory=enumeration.PairwiseFactorGroup,
variable_names_for_factors=[
[(ii, jj), (ii + 1, jj + 1)] for ii in range(M - 1) for jj in range(N - 1)
],
name="diagonal0",
)
fg.add_factor_group(
factory=groups.PairwiseFactorGroup,
factory=enumeration.PairwiseFactorGroup,
variable_names_for_factors=[
[(ii, jj), (ii - 1, jj + 1)] for ii in range(1, M) for jj in range(N - 1)
],
Expand Down
26 changes: 7 additions & 19 deletions examples/ising_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
# extension: .py
# format_name: percent
# format_version: '1.3'
# jupytext_version: 1.11.4
# jupytext_version: 1.13.7
# kernelspec:
# display_name: Python 3
# display_name: Python 3 (ipykernel)
# language: python
# name: python3
# ---
Expand All @@ -20,13 +20,15 @@
import matplotlib.pyplot as plt
import numpy as np

from pgmax.fg import graph, groups
from pgmax.fg import graph
from pgmax.groups import enumeration
from pgmax.groups import variables as vgroup

# %% [markdown]
# ### Construct variable grid, initialize factor graph, and add factors

# %%
variables = groups.NDVariableArray(num_states=2, shape=(50, 50))
variables = vgroup.NDVariableArray(num_states=2, shape=(50, 50))
fg = graph.FactorGraph(variables=variables)
variable_names_for_factors = []
for ii in range(50):
Expand All @@ -37,7 +39,7 @@
variable_names_for_factors.append([(ii, jj), (ii, ll)])

fg.add_factor_group(
factory=groups.PairwiseFactorGroup,
factory=enumeration.PairwiseFactorGroup,
variable_names_for_factors=variable_names_for_factors,
log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]),
name="factors",
Expand Down Expand Up @@ -102,17 +104,3 @@ def loss(log_potentials_updates, evidence_updates):
evidence = np.random.randn(50, 50, 2)
bp_state.evidence[None] = evidence
bp_state.evidence[10, 10] == evidence[10, 10]

# %%
# Query messages from the factor involving (0, 0), (0, 1) in factor group "factors" to variable (0, 0)
bp_state.ftov_msgs[[(0, 0), (0, 1)], (0, 0)]

# %%
# Set messages from the factor involving (0, 0), (0, 1) in factor group "factors" to variable (0, 0)
bp_state.ftov_msgs[[(0, 0), (0, 1)], (0, 0)] = np.array([1.0, 1.0])
bp_state.ftov_msgs[[(0, 0), (0, 1)], (0, 0)]

# %%
# Uniformly spread expected belief at a variable to all connected factors
bp_state.ftov_msgs[0, 0] = np.array([1.0, 1.0])
bp_state.ftov_msgs[[(0, 0), (0, 1)], (0, 0)]
70 changes: 40 additions & 30 deletions examples/pmp_binary_deconvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,17 @@

# %%
import functools
from collections import defaultdict

import jax
import matplotlib.pyplot as plt
import numpy as np
from scipy.special import logit
from tqdm.notebook import tqdm

from pgmax.factors import logical
from pgmax.fg import graph, groups
from pgmax.fg import graph
from pgmax.groups import logical
from pgmax.groups import variables as vgroup


# %%
Expand Down Expand Up @@ -115,32 +117,32 @@ def plot_images(images, display=True, nr=None):
s_width = im_width - feat_width + 1

# Binary features
W = groups.NDVariableArray(
W = vgroup.NDVariableArray(
num_states=2, shape=(n_chan, n_feat, feat_height, feat_width)
)

# Binary indicators of features locations
S = groups.NDVariableArray(num_states=2, shape=(n_images, n_feat, s_height, s_width))
S = vgroup.NDVariableArray(num_states=2, shape=(n_images, n_feat, s_height, s_width))

# Auxiliary binary variables combining W and S
SW = groups.NDVariableArray(
SW = vgroup.NDVariableArray(
num_states=2,
shape=(n_images, n_chan, im_height, im_width, n_feat, feat_height, feat_width),
)

# Binary images obtained by convolution
X = groups.NDVariableArray(num_states=2, shape=X_gt.shape)
X = vgroup.NDVariableArray(num_states=2, shape=X_gt.shape)

# Factor graph
fg = graph.FactorGraph(variables=dict(S=S, W=W, SW=SW, X=X))
# %% [markdown]
# For computation efficiency, we add large FactorGroups via `fg.add_factor_group` instead of adding individual Factors

# %%
# Note: although adding Factors is currently handled via for loops,
# we have plans to make this more efficient in the near future through the use of FactorGroups

variable_names_for_ORFactors = {}
# Factor graph
fg = graph.FactorGraph(variables=dict(S=S, W=W, SW=SW, X=X))

# Add ANDFactors
# Define the ANDFactors
variable_names_for_ANDFactors = []
variable_names_for_ORFactors_dict = defaultdict(list)
for idx_img in tqdm(range(n_images)):
for idx_chan in range(n_chan):
for idx_s_height in range(s_height):
Expand All @@ -161,7 +163,7 @@ def plot_images(images, display=True, nr=None):
idx_feat_width,
)

variable_names = [
variable_names_for_ANDFactor = [
("S", idx_img, idx_feat, idx_s_height, idx_s_width),
(
"W",
Expand All @@ -172,28 +174,36 @@ def plot_images(images, display=True, nr=None):
),
SW_var,
]

fg.add_factor_by_type(
variable_names=variable_names,
factor_type=logical.ANDFactor,
variable_names_for_ANDFactors.append(
variable_names_for_ANDFactor
)

X_var = (idx_img, idx_chan, idx_img_height, idx_img_width)
if X_var not in variable_names_for_ORFactors:
variable_names_for_ORFactors[X_var] = [SW_var]
else:
variable_names_for_ORFactors[X_var].append(SW_var)
variable_names_for_ORFactors_dict[X_var].append(SW_var)

# Add ANDFactorGroup, which is computationally efficient
fg.add_factor_group(
factory=logical.ANDFactorGroup,
variable_names_for_factors=variable_names_for_ANDFactors,
)

# Define the ORFactors
variable_names_for_ORFactors = [
list(tuple(variable_names_for_ORFactors_dict[X_var]) + (("X",) + X_var,))
for X_var in variable_names_for_ORFactors_dict
]

# Add ORFactorGroup, which is computationally efficient
fg.add_factor_group(
factory=logical.ORFactorGroup,
variable_names_for_factors=variable_names_for_ORFactors,
)

# Add ORFactors
for X_var, variable_names_for_ORFactor in variable_names_for_ORFactors.items():
fg.add_factor_by_type(
variable_names=variable_names_for_ORFactor + [("X",) + X_var], # type: ignore
factor_type=logical.ORFactor,
)
for factor_type, factor_groups in fg.factor_groups.items():
if len(factor_groups) > 0:
assert len(factor_groups) == 1
print(f"The factor graph contains {factor_groups[0].num_factors} {factor_type}")

for factor_type, factors in fg.factors.items():
print(f"The factor graph contains {len(factors)} {factor_type}")

# %% [markdown]
# ### Run inference and visualize results
Expand Down
Loading