From 66bbcbcd47d83d34278866ce2deae870838c7f39 Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Tue, 12 Apr 2022 18:01:04 +0000
Subject: [PATCH 01/35] Rewrite NDVariableArray

---
 .pre-commit-config.yaml              |  10 +-
 examples/ising_model.py              |  25 ++-
 examples/pmp_binary_deconvolution.py |  61 ++++--
 examples/rbm.py                      |  47 +++--
 pgmax/factors/enumeration.py         |  20 +-
 pgmax/factors/logical.py             |  10 +-
 pgmax/fg/graph.py                    | 290 ++++++++++++++++---------
 pgmax/fg/groups.py                   | 305 ++-------------------------
 pgmax/fg/nodes.py                    |  25 +--
 pgmax/groups/enumeration.py          |  43 ++--
 pgmax/groups/logical.py              |   9 +-
 pgmax/groups/variables.py            | 281 +++++++++++++-----------
 tests/fg/test_wiring.py              |  44 ++--
 13 files changed, 539 insertions(+), 631 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index cfec6112..7e553c33 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -27,10 +27,10 @@ repos:
         args: ['--config', '.flake8.config','--exit-zero']
         verbose: true
 
--   repo: https://github.com/pre-commit/mirrors-mypy
-    rev: 'v0.942'  # Use the sha / tag you want to point at
-    hooks:
-    -   id: mypy
-        additional_dependencies: [tokenize-rt==3.2.0]
+# -   repo: https://github.com/pre-commit/mirrors-mypy
+#     rev: 'v0.942'  # Use the sha / tag you want to point at
+#     hooks:
+#     -   id: mypy
+#         additional_dependencies: [tokenize-rt==3.2.0]
 ci:
   autoupdate_schedule: 'quarterly'
diff --git a/examples/ising_model.py b/examples/ising_model.py
index 5b805c1a..ef5a267b 100644
--- a/examples/ising_model.py
+++ b/examples/ising_model.py
@@ -30,35 +30,50 @@
 # %%
 variables = vgroup.NDVariableArray(num_states=2, shape=(50, 50))
 fg = graph.FactorGraph(variables=variables)
+
+# TODO: rename variable_for_factors?
 variable_names_for_factors = []
 for ii in range(50):
     for jj in range(50):
         kk = (ii + 1) % 50
         ll = (jj + 1) % 50
-        variable_names_for_factors.append([(ii, jj), (kk, jj)])
-        variable_names_for_factors.append([(ii, jj), (ii, ll)])
+        variable_names_for_factors.append([variables[ii, jj], variables[kk, jj]])
+        variable_names_for_factors.append([variables[ii, jj], variables[kk, ll]])
 
 fg.add_factor_group(
     factory=enumeration.PairwiseFactorGroup,
     variable_names_for_factors=variable_names_for_factors,
     log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]),
-    name="factors",
 )
 
+# %%
+from pgmax.factors import enumeration as enumeration_factor
+
+factors = fg.factor_groups[enumeration_factor.EnumerationFactor][0].factors
+
 # %% [markdown]
 # ### Run inference and visualize results
 
+import imp
+
 # %%
+from pgmax.fg import graph
+
+imp.reload(graph)
 bp = graph.BP(fg.bp_state, temperature=0)
 
 # %%
+# TODO: check\ time for before BP vs time for BP
+# TODO: why bug when done twice?
+# TODO: time PGMAX vs PMP
 bp_arrays = bp.init(
-    evidence_updates={None: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
+    evidence_updates={variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
 )
 bp_arrays = bp.run_bp(bp_arrays, num_iters=3000)
+output = bp.get_bp_output(bp_arrays)
 
 # %%
-img = graph.decode_map_states(bp.get_beliefs(bp_arrays))
+img = output.map_states[variables]
 fig, ax = plt.subplots(1, 1, figsize=(10, 10))
 ax.imshow(img)
 
diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py
index c00aac38..b10f5688 100644
--- a/examples/pmp_binary_deconvolution.py
+++ b/examples/pmp_binary_deconvolution.py
@@ -107,7 +107,13 @@ def plot_images(images, display=True, nr=None):
 #
 # See Section 5.6 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) for more details.
 
+import imp
+
 # %%
+from pgmax.fg import graph
+
+imp.reload(graph)
+
 # The dimensions of W used for the generation of X were (4, 5, 5) but we set them to (5, 6, 6)
 # to simulate a more realistic scenario in which we do not know their ground truth values
 n_feat, feat_height, feat_width = 5, 6, 6
@@ -116,6 +122,9 @@ def plot_images(images, display=True, nr=None):
 s_height = im_height - feat_height + 1
 s_width = im_width - feat_width + 1
 
+import time
+
+start = time.time()
 # Binary features
 W = vgroup.NDVariableArray(
     num_states=2, shape=(n_chan, n_feat, feat_height, feat_width)
@@ -132,13 +141,16 @@ def plot_images(images, display=True, nr=None):
 
 # Binary images obtained by convolution
 X = vgroup.NDVariableArray(num_states=2, shape=X_gt.shape)
+print("Time", time.time() - start)
 
 # %% [markdown]
 # For computation efficiency, we add large FactorGroups via `fg.add_factor_group` instead of adding individual Factors
 
 # %%
+start = time.time()
 # Factor graph
-fg = graph.FactorGraph(variables=dict(S=S, W=W, SW=SW, X=X))
+fg = graph.FactorGraph(variables=[S, W, SW, X])
+print("x", time.time() - start)
 
 # Define the ANDFactors
 variable_names_for_ANDFactors = []
@@ -152,8 +164,7 @@ def plot_images(images, display=True, nr=None):
                         for idx_feat_width in range(feat_width):
                             idx_img_height = idx_feat_height + idx_s_height
                             idx_img_width = idx_feat_width + idx_s_width
-                            SW_var = (
-                                "SW",
+                            SW_var = SW[
                                 idx_img,
                                 idx_chan,
                                 idx_img_height,
@@ -161,35 +172,31 @@ def plot_images(images, display=True, nr=None):
                                 idx_feat,
                                 idx_feat_height,
                                 idx_feat_width,
-                            )
+                            ]
 
                             variable_names_for_ANDFactor = [
-                                ("S", idx_img, idx_feat, idx_s_height, idx_s_width),
-                                (
-                                    "W",
-                                    idx_chan,
-                                    idx_feat,
-                                    idx_feat_height,
-                                    idx_feat_width,
-                                ),
+                                S[idx_img, idx_feat, idx_s_height, idx_s_width],
+                                W[idx_chan, idx_feat, idx_feat_height, idx_feat_width],
                                 SW_var,
                             ]
                             variable_names_for_ANDFactors.append(
                                 variable_names_for_ANDFactor
                             )
 
-                            X_var = (idx_img, idx_chan, idx_img_height, idx_img_width)
+                            X_var = X[idx_img, idx_chan, idx_img_height, idx_img_width]
                             variable_names_for_ORFactors_dict[X_var].append(SW_var)
+print(time.time() - start)
 
 # Add ANDFactorGroup, which is computationally efficient
 fg.add_factor_group(
     factory=logical.ANDFactorGroup,
     variable_names_for_factors=variable_names_for_ANDFactors,
 )
+print(time.time() - start)
 
 # Define the ORFactors
 variable_names_for_ORFactors = [
-    list(tuple(variable_names_for_ORFactors_dict[X_var]) + (("X",) + X_var,))
+    list(tuple(variable_names_for_ORFactors_dict[X_var]) + (X_var,))
     for X_var in variable_names_for_ORFactors_dict
 ]
 
@@ -198,6 +205,7 @@ def plot_images(images, display=True, nr=None):
     factory=logical.ORFactorGroup,
     variable_names_for_factors=variable_names_for_ORFactors,
 )
+print("Time", time.time() - start)
 
 for factor_type, factor_groups in fg.factor_groups.items():
     if len(factor_groups) > 0:
@@ -215,7 +223,9 @@ def plot_images(images, display=True, nr=None):
 # in the same manner does not change X, so this naturally results in multiple equivalent modes.
 
 # %%
+start = time.time()
 bp = graph.BP(fg.bp_state, temperature=0.0)
+print("Time", time.time() - start)
 
 # %% [markdown]
 # We first compute the evidence without perturbation, similar to the PMP paper.
@@ -236,6 +246,29 @@ def plot_images(images, display=True, nr=None):
 uX = np.zeros((X_gt.shape) + (2,))
 uX[..., 0] = (2 * X_gt - 1) * logit(pX)
 
+# %%
+np.random.seed(seed=40)
+n_samples = 1
+
+start = time.time()
+bp_arrays = jax.vmap(bp.init, in_axes=0, out_axes=0)(
+    evidence_updates={
+        "S": uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape),
+        "W": uW[None] + np.random.gumbel(size=(n_samples,) + uW.shape),
+        "SW": np.zeros(shape=(n_samples,) + SW.shape),
+        "X": uX[None] + np.zeros(shape=(n_samples,) + uX.shape),
+    },
+)
+print("Time", time.time() - start)
+bp_arrays = jax.vmap(
+    functools.partial(bp.run_bp, num_iters=100, damping=0.5),
+    in_axes=0,
+    out_axes=0,
+)(bp_arrays)
+print("Time", time.time() - start)
+# beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
+# map_states = graph.decode_map_states(beliefs)
+
 # %% [markdown]
 # We draw a batch of samples from the posterior in parallel by transforming `run_bp`/`get_beliefs` with `jax.vmap`
 
diff --git a/examples/rbm.py b/examples/rbm.py
index d0faa91a..e4d26a83 100644
--- a/examples/rbm.py
+++ b/examples/rbm.py
@@ -45,13 +45,21 @@
 # %% [markdown]
 # We can then initialize the factor graph for the RBM with
 
+import imp
+
 # %%
+from pgmax.fg import graph
+
+imp.reload(graph)
+
+import time
+
+start = time.time()
 # Initialize factor graph
 hidden_variables = vgroup.NDVariableArray(num_states=2, shape=bh.shape)
 visible_variables = vgroup.NDVariableArray(num_states=2, shape=bv.shape)
-fg = graph.FactorGraph(
-    variables=dict(hidden=hidden_variables, visible=visible_variables),
-)
+fg = graph.FactorGraph(variables=[hidden_variables, visible_variables])
+print("Time", time.time() - start)
 
 # %% [markdown]
 # [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup) (e.g. an [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)), or a list/dictionary of [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed.
@@ -59,17 +67,18 @@
 # After initialization, `fg` does not have any factors. PGMax supports imperatively adding factors to a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph). We can add the unary and pairwise factors by grouping them using
 
 # %%
+start = time.time()
 # Add unary factors
 fg.add_factor_group(
     factory=enumeration.EnumerationFactorGroup,
-    variable_names_for_factors=[[("hidden", ii)] for ii in range(bh.shape[0])],
+    variable_names_for_factors=[[hidden_variables[ii]] for ii in range(bh.shape[0])],
     factor_configs=np.arange(2)[:, None],
     log_potentials=np.stack([np.zeros_like(bh), bh], axis=1),
 )
 
 fg.add_factor_group(
     factory=enumeration.EnumerationFactorGroup,
-    variable_names_for_factors=[[("visible", jj)] for jj in range(bv.shape[0])],
+    variable_names_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])],
     factor_configs=np.arange(2)[:, None],
     log_potentials=np.stack([np.zeros_like(bv), bv], axis=1),
 )
@@ -81,13 +90,16 @@
 fg.add_factor_group(
     factory=enumeration.PairwiseFactorGroup,
     variable_names_for_factors=[
-        [("hidden", ii), ("visible", jj)]
+        [hidden_variables[ii], visible_variables[jj]]
         for ii in range(bh.shape[0])
         for jj in range(bv.shape[0])
     ],
     log_potential_matrix=log_potential_matrix,
 )
 
+# fg.add_factor_group(factory=enumeration.PairwiseFactorGroup, variable_names_for_factors=[[hidden_variables[ii], visible_variables[jj]]for ii in range(bh.shape[0])for jj in range(bv.shape[0])], log_potential_matrix=log_potential_matrix,)
+print("Time", time.time() - start)
+
 
 # %% [markdown]
 # PGMax implements convenient and computationally efficient [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) for representing Groups of similar factors. The code above makes use of [`EnumerationFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.EnumerationFactorGroup.html#pgmax.fg.groups.EnumerationFactorGroup) and [`PairwiseFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.PairwiseFactorGroup.html#pgmax.fg.groups.PairwiseFactorGroup), two [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup)s implemented in the [`pgmax.fg.groups`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.html#module-pgmax.fg.graph) module.
@@ -146,18 +158,23 @@
 # Now we are ready to demonstrate PMP sampling from RBM. PMP perturbs the model with [Gumbel](https://numpy.org/doc/stable/reference/random/generated/numpy.random.gumbel.html) unary potentials, and draws a sample from the RBM as the MAP decoding from running max-product LBP on the perturbed model
 
 # %%
+start = time.time()
 bp = graph.BP(fg.bp_state, temperature=0.0)
+print("Time", time.time() - start)
 
 # %%
+start = time.time()
 bp_arrays = bp.init(
     evidence_updates={
-        "hidden": np.random.gumbel(size=(bh.shape[0], 2)),
-        "visible": np.random.gumbel(size=(bv.shape[0], 2)),
+        hidden_variables: np.random.gumbel(size=(bh.shape[0], 2)),
+        visible_variables: np.random.gumbel(size=(bv.shape[0], 2)),
     },
 )
+print("Time", time.time() - start)
 bp_arrays = bp.run_bp(bp_arrays, num_iters=100, damping=0.5)
-beliefs = bp.get_beliefs(bp_arrays)
-map_states = graph.decode_map_states(beliefs)
+print("Time", time.time() - start)
+output = bp.get_bp_output(bp_arrays)
+print("Time", time.time() - start)
 
 # %% [markdown]
 # Here we use the `evidence_updates` argument of `run_bp` to perturb the model with Gumbel unary potentials. In general, `evidence_updates` can be used to incorporate evidence in the form of externally applied unary potentials in PGM inference.
@@ -166,7 +183,7 @@
 
 # %%
 fig, ax = plt.subplots(1, 1, figsize=(10, 10))
-ax.imshow(map_states["visible"].copy().reshape((28, 28)), cmap="gray")
+ax.imshow(output.map_states[visible_variables].copy().reshape((28, 28)), cmap="gray")
 ax.axis("off")
 
 # %% [markdown]
@@ -203,8 +220,12 @@
     in_axes=0,
     out_axes=0,
 )(bp_arrays)
-beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
-map_states = graph.decode_map_states(beliefs)
+
+outputs = jax.vmap(bp.get_bp_output, in_axes=0, out_axes=0)(bp_arrays)
+# map_states = graph.decode_map_states(beliefs)
+
+# %%
+O
 
 # %% [markdown]
 # Visualizing the MAP decodings (Figure [fig:rbm_multiple_digits]), we see that we have sampled 10 MNIST digits in parallel!
diff --git a/pgmax/factors/enumeration.py b/pgmax/factors/enumeration.py
index 61d349d5..e2521597 100644
--- a/pgmax/factors/enumeration.py
+++ b/pgmax/factors/enumeration.py
@@ -81,9 +81,9 @@ def __post_init__(self):
                 f"EnumerationFactor. Got a factor_configs array of shape {self.factor_configs.shape}."
             )
 
-        if len(self.variables) != self.factor_configs.shape[1]:
+        if len(self.vars_to_num_states.keys()) != self.factor_configs.shape[1]:
             raise ValueError(
-                f"Number of variables {len(self.variables)} doesn't match given configurations {self.factor_configs.shape}"
+                f"Number of variables {len(self.vars_to_num_states.keys())} doesn't match given configurations {self.factor_configs.shape}"
             )
 
         if self.log_potentials.shape != (self.factor_configs.shape[0],):
@@ -93,7 +93,7 @@ def __post_init__(self):
                 f"shape {self.log_potentials.shape}."
             )
 
-        vars_num_states = np.array([variable.num_states for variable in self.variables])
+        vars_num_states = np.array([list(self.vars_to_num_states.values())])
         if not np.logical_and(
             self.factor_configs >= 0, self.factor_configs < vars_num_states[None]
         ).all():
@@ -155,9 +155,9 @@ def concatenate_wirings(wirings: Sequence[EnumerationWiring]) -> EnumerationWiri
     @staticmethod
     def compile_wiring(
         factor_edges_num_states: np.ndarray,
-        variables_for_factors: Tuple[nodes.Variable, ...],
+        variables_for_factors: Tuple[int, ...],  # TODO: rename
         factor_configs: np.ndarray,
-        vars_to_starts: Mapping[nodes.Variable, int],
+        vars_to_starts: Mapping[int, int],
         num_factors: int,
     ) -> EnumerationWiring:
         """Compile an EnumerationWiring for an EnumerationFactor or a FactorGroup with EnumerationFactors.
@@ -166,7 +166,7 @@ def compile_wiring(
         Args:
             factor_edges_num_states: An array concatenating the number of states for the variables connected to each
                 Factor of the FactorGroup. Each variable will appear once for each Factor it connects to.
-            variables_for_factors: A tuple of tuples containing variables connected to each Factor of the FactorGroup.
+            variables_for_factors: A tuple containing variables connected to each Factor of the FactorGroup.
                 Each variable will appear once for each Factor it connects to.
             factor_configs: Array of shape (num_val_configs, num_variables) containing an explicit enumeration
                 of all valid configurations.
@@ -184,10 +184,10 @@ def compile_wiring(
         var_states = np.array(
             [vars_to_starts[variable] for variable in variables_for_factors]
         )
-        num_states = np.array(
-            [variable.num_states for variable in variables_for_factors]
-        )
-        num_states_cumsum = np.insert(np.cumsum(num_states), 0, 0)
+        # num_states = np.array(
+        #     [variable.num_states for variable in variables_for_factors]
+        # )
+        num_states_cumsum = np.insert(np.cumsum(factor_edges_num_states), 0, 0)
         var_states_for_edges = np.empty(shape=(num_states_cumsum[-1],), dtype=int)
         _compile_var_states_numba(var_states_for_edges, num_states_cumsum, var_states)
 
diff --git a/pgmax/factors/logical.py b/pgmax/factors/logical.py
index 8a27d0be..c53a5fac 100644
--- a/pgmax/factors/logical.py
+++ b/pgmax/factors/logical.py
@@ -86,12 +86,14 @@ class LogicalFactor(nodes.Factor):
     edge_states_offset: int = field(init=False)
 
     def __post_init__(self):
-        if len(self.variables) < 2:
+        if len(self.vars_to_num_states.keys()) < 2:
             raise ValueError(
                 "At least one parent variable and one child variable is required"
             )
 
-        if not np.all([variable.num_states == 2 for variable in self.variables]):
+        if not np.all(
+            [num_states == 2 for num_states in self.vars_to_num_states.values()]
+        ):
             raise ValueError("All variables should all be binary")
 
     @staticmethod
@@ -141,9 +143,9 @@ def concatenate_wirings(wirings: Sequence[LogicalWiring]) -> LogicalWiring:
     @staticmethod
     def compile_wiring(
         factor_edges_num_states: np.ndarray,
-        variables_for_factors: Tuple[nodes.Variable, ...],
+        variables_for_factors: Tuple[int, ...],  # notsure
         factor_sizes: np.ndarray,
-        vars_to_starts: Mapping[nodes.Variable, int],
+        vars_to_starts: Mapping[int, int],
         edge_states_offset: int,
     ) -> LogicalWiring:
         """Compile a LogicalWiring for a LogicalFactor or a FactorGroup with LogicalFactors.
diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index c0b4f096..f0a0aefe 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+from this import d
+
 """A module containing the core class to specify a Factor Graph."""
 
 import collections
@@ -7,7 +9,7 @@
 import functools
 import inspect
 import typing
-from dataclasses import asdict, dataclass
+from dataclasses import asdict, dataclass, field
 from types import MappingProxyType
 from typing import (
     Any,
@@ -35,6 +37,7 @@
 from pgmax.bp import infer
 from pgmax.factors import FAC_TO_VAR_UPDATES
 from pgmax.fg import groups, nodes
+from pgmax.groups import variables as vgroup
 from pgmax.groups.enumeration import EnumerationFactorGroup
 from pgmax.utils import cached_property
 
@@ -53,23 +56,34 @@ class FactorGraph:
             this input, and the individual VariableGroups will need to be accessed by indexing.
     """
 
-    variables: Union[
-        Mapping[Any, groups.VariableGroup],
-        Sequence[groups.VariableGroup],
-        groups.VariableGroup,
-    ]
+    variables: Union[groups.VariableGroup, Sequence[groups.VariableGroup]]
 
     def __post_init__(self):
-        if isinstance(self.variables, groups.VariableGroup):
-            self._variable_group = self.variables
-        else:
-            self._variable_group = groups.CompositeVariableGroup(self.variables)
-
+        import time
+
+        start = time.time()
+        # if isinstance(self.variables, groups.VariableGroup):
+        if isinstance(self.variables, vgroup.NDVariableArray):
+            self.variables = [self.variables]
+
+        self._variable_group: Mapping[
+            int, groups.VariableGroup
+        ] = collections.OrderedDict()
+        for variable_group in self.variables:
+            self._variable_group[variable_group.__hash__()] = variable_group
+
+        vars_names = []
+        vars_num_states = []
+        for variable_group in self.variables:
+            if isinstance(variable_group, vgroup.NDVariableArray):
+                vars_names.append(variable_group.variable_names.flatten())
+                vars_num_states.append(variable_group.num_states.flatten())
+        print("1", time.time() - start)
+
+        vars_names = np.concatenate(vars_names)
+        vars_num_states = np.concatenate(vars_num_states)
         vars_num_states_cumsum = np.insert(
-            np.array(
-                [variable.num_states for variable in self._variable_group.variables],
-                dtype=int,
-            ).cumsum(),
+            np.array(vars_num_states).cumsum(),
             0,
             0,
         )
@@ -85,16 +99,22 @@ def __post_init__(self):
         ] = collections.OrderedDict(
             [(factor_type, set()) for factor_type in FAC_TO_VAR_UPDATES]
         )
-
+        print("2", time.time() - start)
+
+        # Used to add FactorGroups
+        # TODO: dict is faster
+        aa = zip(vars_names, vars_num_states)
+        print("3", time.time() - start)
+        self._vars_to_num_states: OrderedDict[int, int] = collections.OrderedDict(
+            zip(vars_names, vars_num_states)
+        )
         # See FactorGraphState docstrings for documentation on the following fields
         self._num_var_states = vars_num_states_cumsum[-1]
-        self._vars_to_starts = MappingProxyType(
-            {
-                variable: vars_num_states_cumsum[vv]
-                for vv, variable in enumerate(self._variable_group.variables)
-            }
+        self._vars_to_starts: OrderedDict[int, int] = collections.OrderedDict(
+            zip(vars_names, vars_num_states_cumsum[:-1])
         )
         self._named_factor_groups: Dict[Hashable, groups.FactorGroup] = {}
+        print("3", time.time() - start)
 
     def __hash__(self) -> int:
         all_factor_groups = tuple(
@@ -130,7 +150,7 @@ def add_factor(
                 initialized.
         """
         factor_group = EnumerationFactorGroup(
-            self._variable_group,
+            self._vars_to_num_states,
             variable_names_for_factors=[variable_names],
             factor_configs=factor_configs,
             log_potentials=log_potentials,
@@ -138,7 +158,7 @@ def add_factor(
         self._register_factor_group(factor_group, name)
 
     def add_factor_by_type(
-        self, variable_names: List, factor_type: type, *args, **kwargs
+        self, variable_names: List[int], factor_type: type, *args, **kwargs
     ) -> None:
         """Function to add a single factor to the FactorGraph.
 
@@ -163,15 +183,16 @@ def add_factor_by_type(
                 f"Type {factor_type} is not one of the supported factor types {FAC_TO_VAR_UPDATES.keys()}"
             )
 
-        name = kwargs.pop("name", None)
-        variables = tuple(self._variable_group[variable_names])
-        factor = factor_type(variables, *args, **kwargs)
+        vars_to_num_states = collections.OrderedDict(
+            (var, self._vars_to_num_states[var]) for var in variable_names
+        )
+        factor = factor_type(vars_to_num_states, *args, **kwargs)
         factor_group = groups.SingleFactorGroup(
-            variable_group=self._variable_group,
+            vars_to_num_states=self._vars_to_num_states,
             variable_names_for_factors=[variable_names],
             factor=factor,
         )
-        self._register_factor_group(factor_group, name)
+        self._register_factor_group(factor_group)
 
     def add_factor_group(self, factory: Callable, *args, **kwargs) -> None:
         """Add a factor group to the factor graph
@@ -182,13 +203,10 @@ def add_factor_group(self, factory: Callable, *args, **kwargs) -> None:
             kwargs: kwargs to be passed to the factory function, and an optional "name" argument
                 for specifying the name of a named factor group.
         """
-        name = kwargs.pop("name", None)
-        factor_group = factory(self._variable_group, *args, **kwargs)
-        self._register_factor_group(factor_group, name)
+        factor_group = factory(self._vars_to_num_states, *args, **kwargs)
+        self._register_factor_group(factor_group)
 
-    def _register_factor_group(
-        self, factor_group: groups.FactorGroup, name: Optional[str] = None
-    ) -> None:
+    def _register_factor_group(self, factor_group: groups.FactorGroup) -> None:
         """Register a factor group to the factor graph, by updating the factor graph state.
 
         Args:
@@ -199,10 +217,6 @@ def _register_factor_group(
             ValueError: If the factor group with the same name or a factor involving the same variables
                 already exists in the factor graph.
         """
-        if name in self._named_factor_groups:
-            raise ValueError(
-                f"A factor group with the name {name} already exists. Please choose a different name!"
-            )
 
         factor_type = factor_group.factor_type
         for var_names_for_factor in factor_group.variable_names_for_factors:
@@ -211,15 +225,13 @@ def _register_factor_group(
                 var_names
                 in self._factor_types_to_variable_names_for_factors[factor_type]
             ):
+                print(len(var_names_for_factor))
                 raise ValueError(
                     f"A Factor of type {factor_type} involving variables {var_names} already exists. Please merge the corresponding factors."
                 )
-
             self._factor_types_to_variable_names_for_factors[factor_type].add(var_names)
 
         self._factor_types_to_groups[factor_type].append(factor_group)
-        if name is not None:
-            self._named_factor_groups[name] = factor_group
 
     @functools.lru_cache(None)
     def compute_offsets(self) -> None:
@@ -411,7 +423,7 @@ class FactorGraphState:
     """
 
     variable_group: groups.VariableGroup
-    vars_to_starts: Mapping[nodes.Variable, int]
+    vars_to_starts: Mapping[int, int]
     num_var_states: int
     total_factor_num_states: int
     named_factor_groups: Mapping[Hashable, groups.FactorGroup]
@@ -719,7 +731,7 @@ def __setitem__(self, names, data) -> None:
         )
 
 
-@functools.partial(jax.jit, static_argnames="fg_state")
+# @functools.partial(jax.jit, static_argnames="fg_state")
 def update_evidence(
     evidence: jnp.ndarray, updates: Dict[Any, jnp.ndarray], fg_state: FactorGraphState
 ) -> jnp.ndarray:
@@ -733,26 +745,21 @@ def update_evidence(
     Returns:
         A flat jnp array containing updated evidence.
     """
-    for name in updates:
-        data = updates[name]
-        if name in fg_state.variable_group.container_names:
-            if name is None:
-                variable_group = fg_state.variable_group
-            else:
-                assert isinstance(
-                    fg_state.variable_group, groups.CompositeVariableGroup
-                )
-                variable_group = fg_state.variable_group.variable_group_container[name]
-
-            start_index = fg_state.vars_to_starts[variable_group.variables[0]]
+    for var_group_name in updates:
+        data = updates[var_group_name]
+        print(data.shape)
+        if var_group_name.__hash__() in fg_state.variable_group:
+            variable_group = fg_state.variable_group[var_group_name.__hash__()]
+            first_variable = variable_group.variable_names.flatten()[0]
+            start_index = fg_state.vars_to_starts[first_variable]
             flat_data = variable_group.flatten(data)
             evidence = evidence.at[start_index : start_index + flat_data.shape[0]].set(
                 flat_data
             )
-        else:
-            var = fg_state.variable_group[name]
-            start_index = fg_state.vars_to_starts[var]
-            evidence = evidence.at[start_index : start_index + var.num_states].set(data)
+        # else:
+        #     var = fg_state.variable_group[name]
+        #     start_index = fg_state.vars_to_starts[var]
+        #     evidence = evidence.at[start_index : start_index + var.num_states].set(data)
     return evidence
 
 
@@ -889,18 +896,18 @@ class BeliefPropagation:
             Returns:
                 The reconstructed BPState
 
-        get_beliefs: Function to calculate beliefs from a BPArrays.
+        get_bp_output: Function to calculate beliefs from a BPArrays.
             Args:
                 bp_arrays: A BPArrays containing log_potentials, ftov_msgs, evidence.
             Returns:
-                beliefs: An array or a PyTree container containing the beliefs for the variables.
+                bp_output: Belief propagation output.
     """
 
     init: Callable
     update: Callable
     run_bp: Callable
     to_bp_state: Callable
-    get_beliefs: Callable
+    get_bp_output: Callable
 
 
 def BP(bp_state: BPState, temperature: float = 0.0) -> BeliefPropagation:
@@ -978,6 +985,7 @@ def update(
             )
 
         if evidence_updates is not None:
+            print(type(evidence), type(evidence_updates))
             evidence = update_evidence(evidence, evidence_updates, bp_state.fg_state)
 
         return BPArrays(
@@ -1065,62 +1073,150 @@ def to_bp_state(bp_arrays: BPArrays) -> BPState:
             evidence=Evidence(fg_state=bp_state.fg_state, value=bp_arrays.evidence),
         )
 
-    @jax.jit
-    def get_beliefs(bp_arrays: BPArrays) -> Any:
+    def get_bp_output(bp_arrays: BPArrays) -> Any:
         """Function to calculate beliefs from a BPArrays
 
         Args:
             bp_arrays: A BPArrays containing log_potentials, ftov_msgs, evidence.
 
         Returns:
-            beliefs: An array or a PyTree container containing the beliefs for the variables.
+            bp_output: Belief propagation output.
         """
-        beliefs = bp_state.fg_state.variable_group.unflatten(
-            jax.device_put(bp_arrays.evidence)
-            .at[jax.device_put(var_states_for_edges)]
-            .add(bp_arrays.ftov_msgs)
+
+        @jax.jit
+        def compute_flat_beliefs(bp_arrays):
+            flat_beliefs = (
+                jax.device_put(bp_arrays.evidence)
+                .at[jax.device_put(var_states_for_edges)]
+                .add(bp_arrays.ftov_msgs)
+            )
+            return flat_beliefs
+
+        return BeliefPropagationOutputs(
+            compute_flat_beliefs(bp_arrays), bp_state.fg_state.variable_group
         )
-        return beliefs
 
     bp = BeliefPropagation(
         init=functools.partial(update, None),
         update=update,
         run_bp=run_bp,
         to_bp_state=to_bp_state,
-        get_beliefs=get_beliefs,
+        get_bp_output=get_bp_output,
     )
     return bp
 
 
-@jax.jit
-def decode_map_states(beliefs: Any) -> Any:
-    """Function to decode MAP states given the calculated beliefs.
+@dataclass(frozen=True, eq=False)
+class HashableDict:
+    d: Dict = field(default_factory=dict)
 
-    Args:
-        beliefs: An array or a PyTree container containing beliefs for different variables.
+    def __setitem__(self, key, value):
+        self.d[key] = value
 
-    Returns:
-        An array or a PyTree container containing the MAP states for different variables.
-    """
-    map_states = jax.tree_util.tree_map(
-        lambda x: jnp.argmax(x, axis=-1),
-        beliefs,
-    )
-    return map_states
+    def __getitem__(self, value):
+        return self.d[value.__hash__()]
 
 
-@jax.jit
-def get_marginals(beliefs: Any) -> Any:
-    """Function to get marginal probabilities given the calculated beliefs.
+@dataclass(frozen=True, eq=False)
+class BeliefPropagationOutputs:
+    # beliefs: An array or a PyTree container containing the beliefs for the variables.
+    flat_beliefs: jnp.ndarray
+    variable_groups: Mapping[int, groups.VariableGroup]
+    beliefs: HashableDict = field(init=False, default_factory=dict)
 
-    Args:
-        beliefs: An array or a PyTree container containing beliefs for different variables.
+    def __post_init__(self):
+        if self.flat_beliefs.ndim != 1:
+            raise ValueError(
+                f"Can only unflatten 1D array. Got a {self.flat_beliefs.ndim}D array."
+            )
+        beliefs = self.unflatten()
+        object.__setattr__(self, "beliefs", beliefs)
 
-    Returns:
-        An array or a PyTree container containing the marginal probabilities different variables.
-    """
-    marginals = jax.tree_util.tree_map(
-        lambda x: jnp.exp(x - logsumexp(x, axis=-1, keepdims=True)),
-        beliefs,
-    )
-    return marginals
+    @functools.lru_cache
+    def unflatten(self) -> None:
+        """Function that recovers meaningful structured data from internal flat data array
+
+        Args:
+            variable_groups: TODO
+
+        Returns:
+            Meaningful structured data, with structure matching that of self.variable_group_container.
+
+        Raises:
+            ValueError: if flat_data is not of the right shape
+        """
+        # Note: this is a reimplementation of CompositeVariableGroup.unflatten
+        num_variables = 0
+        num_variable_states = 0
+        for variable_group in self.variable_groups.values():
+            if isinstance(variable_group, vgroup.NDVariableArray):
+                num_variables += variable_group.num_states.size
+                num_variable_states += variable_group.num_states.sum()
+
+        if self.flat_beliefs.shape[0] == num_variables:
+            use_num_states = False
+        elif self.flat_beliefs.shape[0] == num_variable_states:
+            use_num_states = True
+        else:
+            raise ValueError(
+                f"flat_data should be either of shape (num_variables(={len(self.variables)}),), "
+                f"or (num_variable_states(={num_variable_states}),). "
+                f"Got {self.flat_beliefs.shape}"
+            )
+
+        beliefs = {}
+        start = 0
+        for name, variable_group in self.variable_groups.items():
+            if use_num_states:
+                length = variable_group.num_states.sum()
+            else:
+                length = variable_group.num_states.size
+
+            beliefs[name] = variable_group.unflatten(
+                self.flat_beliefs[start : start + length]
+            )
+            start += length
+        return beliefs
+
+    @cached_property
+    def map_states(self) -> Any:
+        """Function to decode MAP states given the calculated beliefs.
+
+        Args:
+            beliefs: An array or a PyTree container containing beliefs for different variables.
+
+        Returns:
+            An array or a PyTree container containing the MAP states for different variables.
+        """
+
+        @jax.jit
+        def _decode_map_states(beliefs) -> Any:
+            return jax.tree_util.tree_map(lambda x: jnp.argmax(x, axis=-1), beliefs)
+
+        map_states = HashableDict()
+        for name, beliefs in self.beliefs.items():
+            map_states[name] = _decode_map_states(beliefs)
+        return map_states
+
+    @cached_property
+    def get_marginals(self) -> Any:
+        """Function to get marginal probabilities given the calculated beliefs.
+
+        Args:
+            beliefs: An array or a PyTree container containing beliefs for different variables.
+
+        Returns:
+            An array or a PyTree container containing the marginal probabilities different variables.
+        """
+
+        @jax.jit
+        def _get_marginals(beliefs) -> Any:
+            return jax.tree_util.tree_map(
+                lambda x: jnp.exp(x - logsumexp(x, axis=-1, keepdims=True)),
+                beliefs,
+            )
+
+        marginals = HashableDict()
+        for name, beliefs in self.beliefs.items():
+            marginals[name] = _get_marginals(beliefs)
+        return marginals
diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index 7b0f767d..c90fe106 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -33,34 +33,13 @@ class VariableGroup:
     All variables in the group are assumed to have the same size. Additionally, the
     variables are indexed by a variable name, and can be retrieved by direct indexing (even indexing
     a sequence of variable names) of the VariableGroup.
-
-    Attributes:
-        _names_to_variables: A private, immutable mapping from variable names to variables
     """
 
-    _names_to_variables: Mapping[Hashable, nodes.Variable] = field(init=False)
-
-    def __post_init__(self) -> None:
-        """Initialize a private, immutable mapping from variable names to variables."""
-        object.__setattr__(
-            self,
-            "_names_to_variables",
-            MappingProxyType(self._get_names_to_variables()),
-        )
-
-    @typing.overload
-    def __getitem__(self, name: Hashable) -> nodes.Variable:
-        """This function is a typing overload and is overwritten by the implemented __getitem__"""
-
-    @typing.overload
-    def __getitem__(self, name: List) -> List[nodes.Variable]:
-        """This function is a typing overload and is overwritten by the implemented __getitem__"""
-
-    def __getitem__(self, name):
+    def __getitem__(self, val):
         """Given a name, retrieve the associated Variable.
 
         Args:
-            name: a single name corresponding to a single variable, or a list of such names
+            val: a single name corresponding to a single variable, or a list of such names
 
         Returns:
             A single variable if the name is not a list. A list of variables if name is a list
@@ -68,61 +47,20 @@ def __getitem__(self, name):
         Raises:
             ValueError: if the name is not found in the group
         """
-        if isinstance(name, List):
-            names_list = name
-        else:
-            names_list = [name]
-
-        vars_list = []
-        for curr_name in names_list:
-            var = self._names_to_variables.get(curr_name)
-            if var is None:
-                raise ValueError(
-                    f"The name {curr_name} is not present in the VariableGroup {type(self)}; please ensure "
-                    "it's been added to the VariableGroup before trying to query it."
-                )
-
-            vars_list.append(var)
-
-        if isinstance(name, List):
-            return vars_list
-        else:
-            return vars_list[0]
-
-    def _get_names_to_variables(self) -> OrderedDict[Any, nodes.Variable]:
-        """Function that generates a dictionary mapping names to variables.
-
-        Returns:
-            a dictionary mapping all possible names to different variables.
-        """
         raise NotImplementedError(
             "Please subclass the VariableGroup class and override this method"
         )
 
     @cached_property
-    def names(self) -> Tuple[Any, ...]:
-        """Function to return a tuple of all names in the group.
-
-        Returns:
-            tuple of all names that are part of this VariableGroup
-        """
-        return tuple(self._names_to_variables.keys())
-
-    @cached_property
-    def variables(self) -> Tuple[nodes.Variable, ...]:
-        """Function to return a tuple of all variables in the group.
+    def variables_names(self) -> Any:
+        """Function that generates a dictionary mapping names to variables.
 
         Returns:
-            tuple of all variable that are part of this VariableGroup
-        """
-        return tuple(self._names_to_variables.values())
-
-    @cached_property
-    def container_names(self) -> Tuple:
-        """Placeholder function. Returns a tuple containing None for all variable groups
-        other than a composite variable group
+            a dictionary mapping all possible names to different variables.
         """
-        return (None,)
+        raise NotImplementedError(
+            "Please subclass the VariableGroup class and override this method"
+        )
 
     def flatten(self, data: Any) -> jnp.ndarray:
         """Function that turns meaningful structured data into a flat data array for internal use.
@@ -151,220 +89,21 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any:
         )
 
 
-@dataclass(frozen=True, eq=False)
-class CompositeVariableGroup(VariableGroup):
-    """A class to encapsulate a collection of instantiated VariableGroups.
-
-    This class enables users to wrap various different VariableGroups and then index
-    them in a straightforward manner. To index into a CompositeVariableGroup, simply
-    provide the name of the VariableGroup within this CompositeVariableGroup followed
-    by the name to be indexed within the VariableGroup.
-
-    Args:
-        variable_group_container: A container containing multiple variable groups.
-            Supported containers include mapping and sequence.
-            For a mapping, the names of the mapping are used to index the variable groups.
-            For a sequence, the indices of the sequence are used to index the variable groups.
-
-    Attributes:
-        _names_to_variables: A private, immutable mapping from names to variables
-    """
-
-    variable_group_container: Union[
-        Mapping[Hashable, VariableGroup], Sequence[VariableGroup]
-    ]
-
-    def __post_init__(self):
-        object.__setattr__(
-            self,
-            "_names_to_variables",
-            MappingProxyType(self._get_names_to_variables()),
-        )
-
-    @typing.overload
-    def __getitem__(self, name: Hashable) -> nodes.Variable:
-        """This function is a typing overload and is overwritten by the implemented __getitem__"""
-
-    @typing.overload
-    def __getitem__(self, name: List) -> List[nodes.Variable]:
-        """This function is a typing overload and is overwritten by the implemented __getitem__"""
-
-    def __getitem__(self, name):
-        """Given a name, retrieve the associated Variable from the associated VariableGroup.
-
-        Args:
-            name: a single name corresponding to a single Variable within a VariableGroup, or a list
-                of such names
-
-        Returns:
-            A single variable if the name is not a list. A list of variables if name is a list
-
-        Raises:
-            ValueError: if the name does not have the right format (tuples with at least two elements).
-        """
-        if isinstance(name, List):
-            names_list = name
-        else:
-            names_list = [name]
-
-        vars_list = []
-        for curr_name in names_list:
-            if len(curr_name) < 2:
-                raise ValueError(
-                    "The name needs to have at least 2 elements to index from a composite variable group."
-                )
-
-            variable_group = self.variable_group_container[curr_name[0]]
-            if len(curr_name) == 2:
-                vars_list.append(variable_group[curr_name[1]])
-            else:
-                vars_list.append(variable_group[curr_name[1:]])
-
-        if isinstance(name, List):
-            return vars_list
-        else:
-            return vars_list[0]
-
-    def _get_names_to_variables(self) -> OrderedDict[Hashable, nodes.Variable]:
-        """Function that generates a dictionary mapping names to variables.
-
-        Returns:
-            a dictionary mapping all possible names to different variables.
-        """
-        names_to_variables: OrderedDict[
-            Hashable, nodes.Variable
-        ] = collections.OrderedDict()
-        for container_name in self.container_names:
-            for variable_group_name in self.variable_group_container[
-                container_name
-            ].names:
-                if isinstance(variable_group_name, tuple):
-                    names_to_variables[
-                        (container_name,) + variable_group_name
-                    ] = self.variable_group_container[container_name][
-                        variable_group_name
-                    ]
-                else:
-                    names_to_variables[
-                        (container_name, variable_group_name)
-                    ] = self.variable_group_container[container_name][
-                        variable_group_name
-                    ]
-
-        return names_to_variables
-
-    def flatten(self, data: Union[Mapping, Sequence]) -> jnp.ndarray:
-        """Function that turns meaningful structured data into a flat data array for internal use.
-
-        Args:
-            data: Meaningful structured data.
-                The structure of data should match self.variable_group_container.
-
-        Returns:
-            A flat jnp.array for internal use
-        """
-        flat_data = jnp.concatenate(
-            [
-                self.variable_group_container[name].flatten(data[name])
-                for name in self.container_names
-            ]
-        )
-        return flat_data
-
-    def unflatten(
-        self, flat_data: Union[np.ndarray, jnp.ndarray]
-    ) -> Union[Mapping, Sequence]:
-        """Function that recovers meaningful structured data from internal flat data array
-
-        Args:
-            flat_data: Internal flat data array.
-
-        Returns:
-            Meaningful structured data, with structure matching that of self.variable_group_container.
-
-        Raises:
-            ValueError if:
-                (1) flat_data is not a 1D array
-                (2) flat_data is not of the right shape
-        """
-        if flat_data.ndim != 1:
-            raise ValueError(
-                f"Can only unflatten 1D array. Got a {flat_data.ndim}D array."
-            )
-
-        num_variables = 0
-        num_variable_states = 0
-        for name in self.container_names:
-            variable_group = self.variable_group_container[name]
-            num_variables += len(variable_group.variables)
-            num_variable_states += (
-                len(variable_group.variables) * variable_group.variables[0].num_states
-            )
-
-        if flat_data.shape[0] == num_variables:
-            use_num_states = False
-        elif flat_data.shape[0] == num_variable_states:
-            use_num_states = True
-        else:
-            raise ValueError(
-                f"flat_data should be either of shape (num_variables(={len(self.variables)}),), "
-                f"or (num_variable_states(={num_variable_states}),). "
-                f"Got {flat_data.shape}"
-            )
-
-        data: List[np.ndarray] = []
-        start = 0
-        for name in self.container_names:
-            variable_group = self.variable_group_container[name]
-            length = len(variable_group.variables)
-            if use_num_states:
-                length *= variable_group.variables[0].num_states
-
-            data.append(variable_group.unflatten(flat_data[start : start + length]))
-            start += length
-        if isinstance(self.variable_group_container, Mapping):
-            return dict(
-                [(name, data[kk]) for kk, name in enumerate(self.container_names)]
-            )
-        else:
-            return data
-
-    @cached_property
-    def container_names(self) -> Tuple:
-        """Function to get names referring to the variable groups within this
-        CompositeVariableGroup.
-
-        Returns:
-            a tuple of the names referring to the variable groups within this
-            CompositeVariableGroup.
-        """
-        if isinstance(self.variable_group_container, Mapping):
-            container_names = tuple(self.variable_group_container.keys())
-        else:
-            container_names = tuple(range(len(self.variable_group_container)))
-
-        return container_names
-
-
 @dataclass(frozen=True, eq=False)
 class FactorGroup:
     """Class to represent a group of Factors.
 
     Args:
-        variable_group: either a VariableGroup or - if the elements of more than one VariableGroup
-            are connected to this FactorGroup - then a CompositeVariableGroup. This holds
-            all the variables that are connected to this FactorGroup
+        vars_to_num_states: TODO
         variable_names_for_factors: A list of list of variable names, where each innermost element is the
-            name of a variable in variable_group. Each list within the outer list is taken to contain
-            the names of the variables connected to a Factor.
+            name of a variable. Each list within the outer list is taken to contain the names of the
+            variables connected to a Factor.
         factor_configs: Optional array containing an explicit enumeration of all valid configurations
         log_potentials: Array of log potentials.
 
     Attributes:
         factor_type: Factor type shared by all the Factors in the FactorGroup.
         factor_sizes: Array of the different factor sizes.
-        variables_for_factors: Tuple concatenating the variables connected to each factor in the FactorGroup.
-            Each variable will appear once for each Factor it connects to.
         factor_edges_num_states: Array concatenating the number of states for the variables connected to each Factor of
             the FactorGroup. Each variable will appear once for each Factor it connects to.
 
@@ -372,36 +111,39 @@ class FactorGroup:
         ValueError: if the FactorGroup does not contain a Factor
     """
 
-    variable_group: Union[CompositeVariableGroup, VariableGroup]
+    vars_to_num_states: Mapping[int, int]
     variable_names_for_factors: Sequence[List]
     factor_configs: np.ndarray = field(init=False)
     log_potentials: np.ndarray = field(init=False, default=np.empty((0,)))
     factor_type: Type = field(init=False)
     factor_sizes: np.ndarray = field(init=False)
-    variables_for_factors: Tuple[nodes.Variable, ...] = field(init=False)
+    variables_for_factors: Tuple[Tuple[int], ...] = field(init=False)
     factor_edges_num_states: np.ndarray = field(init=False)
 
     def __post_init__(self):
         if len(self.variable_names_for_factors) == 0:
             raise ValueError("Do not add a factor group with no factors.")
 
+        # Note: variable_names_for_factors contains the HASHes
+        # Note: this can probably be sped up by numba
         factor_sizes = []
-        variables_for_factors = []
+        flat_var_names_for_factors = []
         factor_edges_num_states = []
         for variable_names_for_factor in self.variable_names_for_factors:
             for variable_name in variable_names_for_factor:
-                variable = self.variable_group._names_to_variables[variable_name]
-                variables_for_factors.append(variable)
-                factor_edges_num_states.append(variable.num_states)
+                factor_edges_num_states.append(self.vars_to_num_states[variable_name])
+                flat_var_names_for_factors.append(variable_name)
             factor_sizes.append(len(variable_names_for_factor))
 
         object.__setattr__(self, "factor_sizes", np.array(factor_sizes))
-        object.__setattr__(self, "variables_for_factors", tuple(variables_for_factors))
+        object.__setattr__(
+            self, "variables_for_factors", np.array(flat_var_names_for_factors)
+        )
         object.__setattr__(
             self, "factor_edges_num_states", np.array(factor_edges_num_states)
         )
 
-    def __getitem__(self, variables: Union[Sequence, Collection]) -> Any:
+    def __getitem__(self, variables: Sequence[int]) -> Any:
         """Function to query individual factors in the factor group
 
         Args:
@@ -419,7 +161,6 @@ def __getitem__(self, variables: Union[Sequence, Collection]) -> Any:
             raise ValueError(
                 f"The queried factor connected to the set of variables {variables} is not present in the factor group."
             )
-
         return self._variables_to_factors[variables]
 
     @cached_property
@@ -485,7 +226,7 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any:
             "Please subclass the FactorGroup class and override this method"
         )
 
-    def compile_wiring(self, vars_to_starts: Mapping[nodes.Variable, int]) -> Any:
+    def compile_wiring(self, vars_to_starts: Mapping[int, int]) -> Any:
         """Compile an efficient wiring for the FactorGroup.
 
         Args:
diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py
index 32fa29ad..97fe7c5e 100644
--- a/pgmax/fg/nodes.py
+++ b/pgmax/fg/nodes.py
@@ -1,7 +1,7 @@
 """A module containing classes that specify the basic components of a Factor Graph."""
 
 from dataclasses import asdict, dataclass
-from typing import Sequence, Tuple, Union
+from typing import OrderedDict, Sequence, Union
 
 import jax
 import jax.numpy as jnp
@@ -10,19 +10,6 @@
 from pgmax import utils
 
 
-@dataclass(frozen=True, eq=False)
-class Variable:
-    """Base class for variables.
-    If desired, this can be sub-classed to add additional concrete
-    meta-information
-
-    Args:
-        num_states: an int representing the number of states this variable has.
-    """
-
-    num_states: int
-
-
 @jax.tree_util.register_pytree_node_class
 @dataclass(frozen=True, eq=False)
 class Wiring:
@@ -56,13 +43,14 @@ class Factor:
     """A factor
 
     Args:
-        variables: List of connected variables
+        vars_to_num_states: Dictionnary mapping the variables names, represented
+            in the form of a hash, to the variables number of states.
 
     Raises:
         NotImplementedError: If compile_wiring is not implemented
     """
 
-    variables: Tuple[Variable, ...]
+    vars_to_num_states: OrderedDict[int, int]
     log_potentials: np.ndarray
 
     def __post_init__(self):
@@ -79,10 +67,7 @@ def edges_num_states(self) -> np.ndarray:
             Array of shape (num_edges,)
             Number of states for the variables connected to each edge
         """
-        edges_num_states = np.array(
-            [variable.num_states for variable in self.variables], dtype=int
-        )
-        return edges_num_states
+        return self.vars_to_num_states.values()
 
     @staticmethod
     def concatenate_wirings(wirings: Sequence) -> Wiring:
diff --git a/pgmax/groups/enumeration.py b/pgmax/groups/enumeration.py
index 84a9c033..f091c53d 100644
--- a/pgmax/groups/enumeration.py
+++ b/pgmax/groups/enumeration.py
@@ -80,16 +80,19 @@ def _get_variables_to_factors(
         variables_to_factors = collections.OrderedDict(
             [
                 (
-                    frozenset(self.variable_names_for_factors[ii]),
+                    frozenset(variable_names_for_factor),
                     enumeration.EnumerationFactor(
-                        variables=tuple(
-                            self.variable_group[self.variable_names_for_factors[ii]]
+                        vars_to_num_states=collections.OrderedDict(
+                            (var, self.vars_to_num_states[var])
+                            for var in variable_names_for_factor
                         ),
                         factor_configs=self.factor_configs,
                         log_potentials=np.array(self.log_potentials)[ii],
                     ),
                 )
-                for ii in range(len(self.variable_names_for_factors))
+                for ii, variable_names_for_factor in enumerate(
+                    self.variable_names_for_factors
+                )
             ]
         )
         return variables_to_factors
@@ -207,12 +210,8 @@ def __post_init__(self):
         if self.log_potential_matrix is None:
             log_potential_matrix = np.zeros(
                 (
-                    self.variable_group[
-                        self.variable_names_for_factors[0][0]
-                    ].num_states,
-                    self.variable_group[
-                        self.variable_names_for_factors[0][1]
-                    ].num_states,
+                    self.vars_to_num_states[self.variable_names_for_factors[0][0]],
+                    self.vars_to_num_states[self.variable_names_for_factors[0][1]],
                 )
             )
         else:
@@ -245,18 +244,11 @@ def __post_init__(self):
                     f" {len(fac_list)} variables ({fac_list})."
                 )
 
-            # Note: num_states0 = self.variable_group[fac_list[0]] is 2x slower
-            num_states0 = self.variable_group._names_to_variables[
-                fac_list[0]
-            ].num_states
-            num_states1 = self.variable_group._names_to_variables[
-                fac_list[1]
-            ].num_states
-
+            num_states0 = self.vars_to_num_states[fac_list[0]]
+            num_states1 = self.vars_to_num_states[fac_list[1]]
             if not log_potential_matrix.shape[-2:] == (num_states0, num_states1):
                 raise ValueError(
-                    f"The specified pairwise factor {fac_list} (with "
-                    f"{(self.variable_group[fac_list[0]].num_states, self.variable_group[fac_list[1]].num_states)} "
+                    f"The specified pairwise factor {fac_list} (with {(num_states0, num_states1)}"
                     f"configurations) does not match the specified log_potential_matrix "
                     f"(with {log_potential_matrix.shape[-2:]} configurations)."
                 )
@@ -294,16 +286,19 @@ def _get_variables_to_factors(
         variables_to_factors = collections.OrderedDict(
             [
                 (
-                    frozenset(self.variable_names_for_factors[ii]),
+                    frozenset(variable_names_for_factor),
                     enumeration.EnumerationFactor(
-                        variables=tuple(
-                            self.variable_group[self.variable_names_for_factors[ii]]
+                        vars_to_num_states=collections.OrderedDict(
+                            (var, self.vars_to_num_states[var])
+                            for var in variable_names_for_factor
                         ),
                         factor_configs=self.factor_configs,
                         log_potentials=self.log_potentials[ii],
                     ),
                 )
-                for ii in range(len(self.variable_names_for_factors))
+                for ii, variable_names_for_factor in enumerate(
+                    self.variable_names_for_factors
+                )
             ]
         )
         return variables_to_factors
diff --git a/pgmax/groups/logical.py b/pgmax/groups/logical.py
index 5003df02..325a4313 100644
--- a/pgmax/groups/logical.py
+++ b/pgmax/groups/logical.py
@@ -34,14 +34,15 @@ def _get_variables_to_factors(
         variables_to_factors = collections.OrderedDict(
             [
                 (
-                    frozenset(self.variable_names_for_factors[ii]),
+                    frozenset(variable_names_for_factor),
                     self.factor_type(
-                        variables=tuple(
-                            self.variable_group[self.variable_names_for_factors[ii]]
+                        vars_to_num_states=collections.OrderedDict(
+                            (var, self.vars_to_num_states[var])
+                            for var in variable_names_for_factor
                         ),
                     ),
                 )
-                for ii in range(len(self.variable_names_for_factors))
+                for variable_names_for_factor in self.variable_names_for_factors
             ]
         )
         return variables_to_factors
diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index 959aabbc..dd381233 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -3,17 +3,27 @@
 import collections
 import itertools
 from dataclasses import dataclass
-from typing import Any, Dict, Hashable, Mapping, OrderedDict, Tuple, Union
+from typing import (
+    Any,
+    Dict,
+    Hashable,
+    Mapping,
+    Optional,
+    OrderedDict,
+    Set,
+    Tuple,
+    Union,
+)
 
 import jax
 import jax.numpy as jnp
 import numpy as np
 
-from pgmax.fg import groups, nodes
+from pgmax.utils import cached_property
 
 
 @dataclass(frozen=True, eq=False)
-class NDVariableArray(groups.VariableGroup):
+class NDVariableArray:
     """Subclass of VariableGroup for n-dimensional grids of variables.
 
     Args:
@@ -22,27 +32,37 @@ class NDVariableArray(groups.VariableGroup):
             the notion of a NumPy ndarray shape)
     """
 
-    num_states: int
     shape: Tuple[int, ...]
+    num_states: Union[int, np.ndarray]
 
-    def _get_names_to_variables(
-        self,
-    ) -> OrderedDict[Union[int, Tuple[int, ...]], nodes.Variable]:
+    def __post_init__(self):
+        # super().__post_init__()
+
+        if isinstance(self.num_states, int):
+            num_states = np.full(self.shape, fill_value=self.num_states)
+            object.__setattr__(self, "num_states", num_states)
+        elif isinstance(self.num_states, np.ndarray):
+            if self.num_states.shape != self.shape:
+                raise ValueError("Should be same shape")
+
+    @cached_property
+    def variable_names(self) -> np.ndarray:
         """Function that generates a dictionary mapping names to variables.
 
         Returns:
             a dictionary mapping all possible names to different variables.
         """
-        names_to_variables: OrderedDict[
-            Union[int, Tuple[int, ...]], nodes.Variable
-        ] = collections.OrderedDict()
-        for name in itertools.product(*[list(range(k)) for k in self.shape]):
-            if len(name) == 1:
-                names_to_variables[name[0]] = nodes.Variable(self.num_states)
-            else:
-                names_to_variables[name] = nodes.Variable(self.num_states)
-
-        return names_to_variables
+        variable_names = np.empty(self.shape, dtype=int)
+        self_hash = self.__hash__()
+        for index in itertools.product(*[list(range(k)) for k in self.shape]):
+            name = hash((self_hash, index))
+            variable_names[index] = name
+        return variable_names
+
+    def __getitem__(self, val):
+        # Numpy indexation will throw IndexError is out-of-bounds
+        # This will be used to add FactorGroups
+        return self.variable_names[val]
 
     def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
         """Function that turns meaningful structured data into a flat data array for internal use.
@@ -57,12 +77,14 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
         Raises:
             ValueError: If the data is not of the correct shape.
         """
-        if data.shape != self.shape and data.shape != self.shape + (self.num_states,):
+        # TODO: what should we do for different number of states
+        if data.shape != self.shape and data.shape != self.shape + (
+            self.num_states.max(),
+        ):
             raise ValueError(
-                f"data should be of shape {self.shape} or {self.shape + (self.num_states,)}. "
+                f"data should be of shape {self.shape} or {self.shape + (self.num_states.max(),)}. "
                 f"Got {data.shape}."
             )
-
         return jax.device_put(data).flatten()
 
     def unflatten(
@@ -89,8 +111,9 @@ def unflatten(
 
         if flat_data.size == np.product(self.shape):
             data = flat_data.reshape(self.shape)
-        elif flat_data.size == np.product(self.shape) * self.num_states:
-            data = flat_data.reshape(self.shape + (self.num_states,))
+        elif flat_data.size == self.num_states.sum():
+            # TODO: what should we dot for different number of states
+            data = flat_data.reshape(self.shape + (self.num_states.max(),))
         else:
             raise ValueError(
                 f"flat_data should be compatible with shape {self.shape} or {self.shape + (self.num_states,)}. "
@@ -100,110 +123,110 @@ def unflatten(
         return data
 
 
-@dataclass(frozen=True, eq=False)
-class VariableDict(groups.VariableGroup):
-    """A variable dictionary that contains a set of variables of the same size
-
-    Args:
-        num_states: The size of the variables in this variable group
-        variable_names: A tuple of all names of the variables in this variable group
-
-    """
-
-    num_states: int
-    variable_names: Tuple[Any, ...]
-
-    def _get_names_to_variables(self) -> OrderedDict[Tuple[int, ...], nodes.Variable]:
-        """Function that generates a dictionary mapping names to variables.
-
-        Returns:
-            a dictionary mapping all possible names to different variables.
-        """
-        names_to_variables: OrderedDict[
-            Tuple[Any, ...], nodes.Variable
-        ] = collections.OrderedDict()
-        for name in self.variable_names:
-            names_to_variables[name] = nodes.Variable(self.num_states)
-
-        return names_to_variables
-
-    def flatten(
-        self, data: Mapping[Hashable, Union[np.ndarray, jnp.ndarray]]
-    ) -> jnp.ndarray:
-        """Function that turns meaningful structured data into a flat data array for internal use.
-
-        Args:
-            data: Meaningful structured data. Should be a mapping with names from self.variable_names.
-                Each value should be an array of shape (1,) (for e.g. MAP decodings) or
-                (self.num_states,) (for e.g. evidence, beliefs).
-
-        Returns:
-            A flat jnp.array for internal use
-
-        Raises:
-            ValueError if:
-                (1) data is referring to a non-existing variable
-                (2) data is not of the correct shape
-        """
-        for name in data:
-            if name not in self._names_to_variables:
-                raise ValueError(
-                    f"data is referring to a non-existent variable {name}."
-                )
-
-            if data[name].shape != (self.num_states,) and data[name].shape != (1,):
-                raise ValueError(
-                    f"Variable {name} expects a data array of shape "
-                    f"{(self.num_states,)} or (1,). Got {data[name].shape}."
-                )
-
-        flat_data = jnp.concatenate([data[name].flatten() for name in self.names])
-        return flat_data
-
-    def unflatten(
-        self, flat_data: Union[np.ndarray, jnp.ndarray]
-    ) -> Dict[Hashable, Union[np.ndarray, jnp.ndarray]]:
-        """Function that recovers meaningful structured data from internal flat data array
-
-        Args:
-            flat_data: Internal flat data array.
-
-        Returns:
-            Meaningful structured data. Should be a mapping with names from self.variable_names.
-                Each value should be an array of shape (1,) (for e.g. MAP decodings) or
-                (self.num_states,) (for e.g. evidence, beliefs).
-
-        Raises:
-            ValueError if:
-                (1) flat_data is not a 1D array
-                (2) flat_data is not of the right shape
-        """
-        if flat_data.ndim != 1:
-            raise ValueError(
-                f"Can only unflatten 1D array. Got a {flat_data.ndim}D array."
-            )
-
-        num_variables = len(self.variable_names)
-        num_variable_states = len(self.variable_names) * self.num_states
-        if flat_data.shape[0] == num_variables:
-            use_num_states = False
-        elif flat_data.shape[0] == num_variable_states:
-            use_num_states = True
-        else:
-            raise ValueError(
-                f"flat_data should be either of shape (num_variables(={len(self.variables)}),), "
-                f"or (num_variable_states(={num_variable_states}),). "
-                f"Got {flat_data.shape}"
-            )
-
-        start = 0
-        data = {}
-        for name in self.variable_names:
-            if use_num_states:
-                data[name] = flat_data[start : start + self.num_states]
-                start += self.num_states
-            else:
-                data[name] = flat_data[np.array([start])]
-                start += 1
-
-        return data
+# @dataclass(frozen=True, eq=False)
+# class VariableDict(groups.VariableGroup):
+#     """A variable dictionary that contains a set of variables of the same size
+
+#     Args:
+#         num_states: The size of the variables in this variable group
+#         variable_names: A tuple of all names of the variables in this variable group
+
+#     """
+
+#     num_states: int
+#     variable_names: Tuple[Any, ...]
+
+#     def _get_names_to_variables(self) -> OrderedDict[Tuple[int, ...], nodes.Variable]:
+#         """Function that generates a dictionary mapping names to variables.
+
+#         Returns:
+#             a dictionary mapping all possible names to different variables.
+#         """
+#         names_to_variables: OrderedDict[
+#             Tuple[Any, ...], nodes.Variable
+#         ] = collections.OrderedDict()
+#         for name in self.variable_names:
+#             names_to_variables[name] = nodes.Variable(self.num_states)
+
+#         return names_to_variables
+
+#     def flatten(
+#         self, data: Mapping[Hashable, Union[np.ndarray, jnp.ndarray]]
+#     ) -> jnp.ndarray:
+#         """Function that turns meaningful structured data into a flat data array for internal use.
+
+#         Args:
+#             data: Meaningful structured data. Should be a mapping with names from self.variable_names.
+#                 Each value should be an array of shape (1,) (for e.g. MAP decodings) or
+#                 (self.num_states,) (for e.g. evidence, beliefs).
+
+#         Returns:
+#             A flat jnp.array for internal use
+
+#         Raises:
+#             ValueError if:
+#                 (1) data is referring to a non-existing variable
+#                 (2) data is not of the correct shape
+#         """
+#         for name in data:
+#             if name not in self._names_to_variables:
+#                 raise ValueError(
+#                     f"data is referring to a non-existent variable {name}."
+#                 )
+
+#             if data[name].shape != (self.num_states,) and data[name].shape != (1,):
+#                 raise ValueError(
+#                     f"Variable {name} expects a data array of shape "
+#                     f"{(self.num_states,)} or (1,). Got {data[name].shape}."
+#                 )
+
+#         flat_data = jnp.concatenate([data[name].flatten() for name in self.names])
+#         return flat_data
+
+#     def unflatten(
+#         self, flat_data: Union[np.ndarray, jnp.ndarray]
+#     ) -> Dict[Hashable, Union[np.ndarray, jnp.ndarray]]:
+#         """Function that recovers meaningful structured data from internal flat data array
+
+#         Args:
+#             flat_data: Internal flat data array.
+
+#         Returns:
+#             Meaningful structured data. Should be a mapping with names from self.variable_names.
+#                 Each value should be an array of shape (1,) (for e.g. MAP decodings) or
+#                 (self.num_states,) (for e.g. evidence, beliefs).
+
+#         Raises:
+#             ValueError if:
+#                 (1) flat_data is not a 1D array
+#                 (2) flat_data is not of the right shape
+#         """
+#         if flat_data.ndim != 1:
+#             raise ValueError(
+#                 f"Can only unflatten 1D array. Got a {flat_data.ndim}D array."
+#             )
+
+#         num_variables = len(self.variable_names)
+#         num_variable_states = len(self.variable_names) * self.num_states
+#         if flat_data.shape[0] == num_variables:
+#             use_num_states = False
+#         elif flat_data.shape[0] == num_variable_states:
+#             use_num_states = True
+#         else:
+#             raise ValueError(
+#                 f"flat_data should be either of shape (num_variables(={len(self.variables)}),), "
+#                 f"or (num_variable_states(={num_variable_states}),). "
+#                 f"Got {flat_data.shape}"
+#             )
+
+#         start = 0
+#         data = {}
+#         for name in self.variable_names:
+#             if use_num_states:
+#                 data[name] = flat_data[start : start + self.num_states]
+#                 start += self.num_states
+#             else:
+#                 data[name] = flat_data[np.array([start])]
+#                 start += 1
+
+#         return data
diff --git a/tests/fg/test_wiring.py b/tests/fg/test_wiring.py
index 232309ee..a3b6ae2b 100644
--- a/tests/fg/test_wiring.py
+++ b/tests/fg/test_wiring.py
@@ -19,10 +19,10 @@ def test_wiring_with_PairwiseFactorGroup():
     B = vgroup.NDVariableArray(num_states=2, shape=(10,))
 
     # First test that compile_wiring enforces the correct factor_edges_num_states shape
-    fg = graph.FactorGraph(variables=dict(A=A, B=B))
+    fg = graph.FactorGraph(variables=[A, B])
     fg.add_factor_group(
         factory=enumeration.PairwiseFactorGroup,
-        variable_names_for_factors=[[("A", idx), ("B", idx)] for idx in range(10)],
+        variable_names_for_factors=[[A[idx], B[idx]] for idx in range(10)],
     )
     factor_group = fg.factor_groups[enumeration_factor.EnumerationFactor][0]
     object.__setattr__(
@@ -37,27 +37,27 @@ def test_wiring_with_PairwiseFactorGroup():
         factor_group.compile_wiring(fg._vars_to_starts)
 
     # FactorGraph with a single PairwiseFactorGroup
-    fg1 = graph.FactorGraph(variables=dict(A=A, B=B))
+    fg1 = graph.FactorGraph(variables=[A, B])
     fg1.add_factor_group(
         factory=enumeration.PairwiseFactorGroup,
-        variable_names_for_factors=[[("A", idx), ("B", idx)] for idx in range(10)],
+        variable_names_for_factors=[[A[idx], B[idx]] for idx in range(10)],
     )
     assert len(fg1.factor_groups[enumeration_factor.EnumerationFactor]) == 1
 
     # FactorGraph with multiple PairwiseFactorGroup
-    fg2 = graph.FactorGraph(variables=dict(A=A, B=B))
+    fg2 = graph.FactorGraph(variables=[A, B])
     for idx in range(10):
         fg2.add_factor_group(
             factory=enumeration.PairwiseFactorGroup,
-            variable_names_for_factors=[[("A", idx), ("B", idx)]],
+            variable_names_for_factors=[[A[idx], B[idx]]],
         )
     assert len(fg2.factor_groups[enumeration_factor.EnumerationFactor]) == 10
 
     # FactorGraph with multiple SingleFactorGroup
-    fg3 = graph.FactorGraph(variables=dict(A=A, B=B))
+    fg3 = graph.FactorGraph(variables=[A, B])
     for idx in range(10):
         fg3.add_factor_by_type(
-            variable_names=[("A", idx), ("B", idx)],
+            variable_names=[A[idx], B[idx]],
             factor_type=enumeration_factor.EnumerationFactor,
             **{
                 "factor_configs": np.array([[0, 0], [0, 1], [1, 0], [1, 1]]),
@@ -98,29 +98,27 @@ def test_wiring_with_ORFactorGroup():
     C = vgroup.NDVariableArray(num_states=2, shape=(10,))
 
     # FactorGraph with a single ORFactorGroup
-    fg1 = graph.FactorGraph(variables=dict(A=A, B=B, C=C))
+    fg1 = graph.FactorGraph(variables=[A, B, C])
     fg1.add_factor_group(
         factory=logical.ORFactorGroup,
-        variable_names_for_factors=[
-            [("A", idx), ("B", idx), ("C", idx)] for idx in range(10)
-        ],
+        variable_names_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)],
     )
     assert len(fg1.factor_groups[logical_factor.ORFactor]) == 1
 
     # FactorGraph with multiple ORFactorGroup
-    fg2 = graph.FactorGraph(variables=dict(A=A, B=B, C=C))
+    fg2 = graph.FactorGraph(variables=[A, B, C])
     for idx in range(10):
         fg2.add_factor_group(
             factory=logical.ORFactorGroup,
-            variable_names_for_factors=[[("A", idx), ("B", idx), ("C", idx)]],
+            variable_names_for_factors=[[A[idx], B[idx], C[idx]]],
         )
     assert len(fg2.factor_groups[logical_factor.ORFactor]) == 10
 
     # FactorGraph with multiple SingleFactorGroup
-    fg3 = graph.FactorGraph(variables=dict(A=A, B=B, C=C))
+    fg3 = graph.FactorGraph(variables=[A, B, C])
     for idx in range(10):
         fg3.add_factor_by_type(
-            variable_names=[("A", idx), ("B", idx), ("C", idx)],
+            variable_names=[A[idx], B[idx], C[idx]],
             factor_type=logical_factor.ORFactor,
         )
     assert len(fg3.factor_groups[logical_factor.ORFactor]) == 10
@@ -155,29 +153,27 @@ def test_wiring_with_ANDFactorGroup():
     C = vgroup.NDVariableArray(num_states=2, shape=(10,))
 
     # FactorGraph with a single ANDFactorGroup
-    fg1 = graph.FactorGraph(variables=dict(A=A, B=B, C=C))
+    fg1 = graph.FactorGraph(variables=[A, B, C])
     fg1.add_factor_group(
         factory=logical.ANDFactorGroup,
-        variable_names_for_factors=[
-            [("A", idx), ("B", idx), ("C", idx)] for idx in range(10)
-        ],
+        variable_names_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)],
     )
     assert len(fg1.factor_groups[logical_factor.ANDFactor]) == 1
 
     # FactorGraph with multiple ANDFactorGroup
-    fg2 = graph.FactorGraph(variables=dict(A=A, B=B, C=C))
+    fg2 = graph.FactorGraph(variables=[A, B, C])
     for idx in range(10):
         fg2.add_factor_group(
             factory=logical.ANDFactorGroup,
-            variable_names_for_factors=[[("A", idx), ("B", idx), ("C", idx)]],
+            variable_names_for_factors=[[A[idx], B[idx], C[idx]]],
         )
     assert len(fg2.factor_groups[logical_factor.ANDFactor]) == 10
 
     # FactorGraph with multiple SingleFactorGroup
-    fg3 = graph.FactorGraph(variables=dict(A=A, B=B, C=C))
+    fg3 = graph.FactorGraph(variables=[A, B, C])
     for idx in range(10):
         fg3.add_factor_by_type(
-            variable_names=[("A", idx), ("B", idx), ("C", idx)],
+            variable_names=[A[idx], B[idx], C[idx]],
             factor_type=logical_factor.ANDFactor,
         )
     assert len(fg3.factor_groups[logical_factor.ANDFactor]) == 10

From d816f63a1f74ddf018d1d9fa30f83d094a7a2c57 Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Tue, 12 Apr 2022 18:10:02 +0000
Subject: [PATCH 02/35] Falke8

---
 pgmax/fg/graph.py         | 24 +++++++++++-------------
 pgmax/fg/groups.py        |  5 -----
 pgmax/groups/variables.py | 13 +------------
 3 files changed, 12 insertions(+), 30 deletions(-)

diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index f0a0aefe..ec03714c 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -1,7 +1,5 @@
 from __future__ import annotations
 
-from this import d
-
 """A module containing the core class to specify a Factor Graph."""
 
 import collections
@@ -102,9 +100,7 @@ def __post_init__(self):
         print("2", time.time() - start)
 
         # Used to add FactorGroups
-        # TODO: dict is faster
-        aa = zip(vars_names, vars_num_states)
-        print("3", time.time() - start)
+        # TODO: move to dict, which is faster
         self._vars_to_num_states: OrderedDict[int, int] = collections.OrderedDict(
             zip(vars_names, vars_num_states)
         )
@@ -434,12 +430,14 @@ class FactorGraphState:
     wiring: OrderedDict[type, nodes.Wiring]
 
     def __post_init__(self):
-        for field in self.__dataclass_fields__:
-            if isinstance(getattr(self, field), np.ndarray):
-                getattr(self, field).flags.writeable = False
+        for this_field in self.__dataclass_fields__:
+            if isinstance(getattr(self, this_field), np.ndarray):
+                getattr(self, this_field).flags.writeable = False
 
-            if isinstance(getattr(self, field), Mapping):
-                object.__setattr__(self, field, MappingProxyType(getattr(self, field)))
+            if isinstance(getattr(self, this_field), Mapping):
+                object.__setattr__(
+                    self, this_field, MappingProxyType(getattr(self, this_field))
+                )
 
 
 @dataclass(frozen=True, eq=False)
@@ -848,9 +846,9 @@ class BPArrays:
     evidence: Union[np.ndarray, jnp.ndarray]
 
     def __post_init__(self):
-        for field in self.__dataclass_fields__:
-            if isinstance(getattr(self, field), np.ndarray):
-                getattr(self, field).flags.writeable = False
+        for this_field in self.__dataclass_fields__:
+            if isinstance(getattr(self, this_field), np.ndarray):
+                getattr(self, this_field).flags.writeable = False
 
     def tree_flatten(self):
         return jax.tree_util.tree_flatten(asdict(self))
diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index c90fe106..a08541b0 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -1,15 +1,10 @@
 """A module containing the base classes for variable and factor groups in a Factor Graph."""
 
-import collections
 import inspect
-import typing
 from dataclasses import dataclass, field
-from types import MappingProxyType
 from typing import (
     Any,
-    Collection,
     FrozenSet,
-    Hashable,
     List,
     Mapping,
     OrderedDict,
diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index dd381233..1ca5d4bf 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -1,19 +1,8 @@
 """A module containing the variables group classes inheriting from the base VariableGroup."""
 
-import collections
 import itertools
 from dataclasses import dataclass
-from typing import (
-    Any,
-    Dict,
-    Hashable,
-    Mapping,
-    Optional,
-    OrderedDict,
-    Set,
-    Tuple,
-    Union,
-)
+from typing import Tuple, Union
 
 import jax
 import jax.numpy as jnp

From 58b011535233f905e91c84a2251d57e65a49335b Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Tue, 12 Apr 2022 23:54:05 +0000
Subject: [PATCH 03/35] Test + HashableDict

---
 examples/ising_model.py              | 16 ++---
 examples/pmp_binary_deconvolution.py | 41 +++++++------
 examples/rbm.py                      | 51 ++++++++++++++--
 pgmax/factors/enumeration.py         |  3 -
 pgmax/fg/graph.py                    | 71 +++++++++++++++-------
 pgmax/fg/groups.py                   | 15 +++++
 pgmax/groups/variables.py            | 37 +++++-------
 tests/factors/test_or.py             | 89 +++++++++++++++++-----------
 8 files changed, 210 insertions(+), 113 deletions(-)

diff --git a/examples/ising_model.py b/examples/ising_model.py
index ef5a267b..837c764e 100644
--- a/examples/ising_model.py
+++ b/examples/ising_model.py
@@ -46,25 +46,19 @@
     log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]),
 )
 
-# %%
-from pgmax.factors import enumeration as enumeration_factor
-
-factors = fg.factor_groups[enumeration_factor.EnumerationFactor][0].factors
-
 # %% [markdown]
 # ### Run inference and visualize results
 
-import imp
-
 # %%
-from pgmax.fg import graph
-
-imp.reload(graph)
 bp = graph.BP(fg.bp_state, temperature=0)
 
+# %%
+d = {variables: 1, variables.__hash__(): 2}
+hd = graph.HashableDict(d)
+hd[variables]
+
 # %%
 # TODO: check\ time for before BP vs time for BP
-# TODO: why bug when done twice?
 # TODO: time PGMAX vs PMP
 bp_arrays = bp.init(
     evidence_updates={variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py
index b10f5688..a5c91c22 100644
--- a/examples/pmp_binary_deconvolution.py
+++ b/examples/pmp_binary_deconvolution.py
@@ -106,6 +106,8 @@ def plot_images(images, display=True, nr=None):
 #  - a second set of ORFactors, which maps SW to X and model (binary) features overlapping.
 #
 # See Section 5.6 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) for more details.
+#
+# import imp
 
 import imp
 
@@ -150,7 +152,7 @@ def plot_images(images, display=True, nr=None):
 start = time.time()
 # Factor graph
 fg = graph.FactorGraph(variables=[S, W, SW, X])
-print("x", time.time() - start)
+print(time.time() - start)
 
 # Define the ANDFactors
 variable_names_for_ANDFactors = []
@@ -232,7 +234,7 @@ def plot_images(images, display=True, nr=None):
 
 # %%
 pW = 0.25
-pS = 1e-100
+pS = 1e-70
 pX = 1e-100
 
 # Sparsity inducing priors for W and S
@@ -248,26 +250,28 @@ def plot_images(images, display=True, nr=None):
 
 # %%
 np.random.seed(seed=40)
-n_samples = 1
 
 start = time.time()
-bp_arrays = jax.vmap(bp.init, in_axes=0, out_axes=0)(
+bp_arrays = bp.init(
     evidence_updates={
-        "S": uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape),
-        "W": uW[None] + np.random.gumbel(size=(n_samples,) + uW.shape),
-        "SW": np.zeros(shape=(n_samples,) + SW.shape),
-        "X": uX[None] + np.zeros(shape=(n_samples,) + uX.shape),
+        S: uS + np.random.gumbel(size=uS.shape),
+        W: uW + np.random.gumbel(size=uW.shape),
+        SW: np.zeros(SW.shape),
+        X: uX + np.zeros(shape=uX.shape),
     },
 )
 print("Time", time.time() - start)
-bp_arrays = jax.vmap(
-    functools.partial(bp.run_bp, num_iters=100, damping=0.5),
-    in_axes=0,
-    out_axes=0,
-)(bp_arrays)
+bp_arrays = bp.run_bp(bp_arrays, num_iters=100, damping=0.5)
 print("Time", time.time() - start)
-# beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
-# map_states = graph.decode_map_states(beliefs)
+output = bp.get_bp_output(bp_arrays)
+print("Time", time.time() - start)
+
+# %%
+_ = plot_images(output.map_states[W].reshape(-1, feat_height, feat_width), nr=1)
+
+# %%
+
+# %%
 
 # %% [markdown]
 # We draw a batch of samples from the posterior in parallel by transforming `run_bp`/`get_beliefs` with `jax.vmap`
@@ -276,6 +280,7 @@ def plot_images(images, display=True, nr=None):
 np.random.seed(seed=40)
 n_samples = 4
 
+start = time.time()
 bp_arrays = jax.vmap(bp.init, in_axes=0, out_axes=0)(
     evidence_updates={
         "S": uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape),
@@ -284,13 +289,15 @@ def plot_images(images, display=True, nr=None):
         "X": uX[None] + np.zeros(shape=(n_samples,) + uX.shape),
     },
 )
+print("Time", time.time() - start)
 bp_arrays = jax.vmap(
     functools.partial(bp.run_bp, num_iters=100, damping=0.5),
     in_axes=0,
     out_axes=0,
 )(bp_arrays)
-beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
-map_states = graph.decode_map_states(beliefs)
+print("Time", time.time() - start)
+# beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
+# map_states = graph.decode_map_states(beliefs)
 
 # %% [markdown]
 # Visualizing the MAP decoding, we see that we have 4 good random samples (one per row) from the posterior!
diff --git a/examples/rbm.py b/examples/rbm.py
index e4d26a83..0f820851 100644
--- a/examples/rbm.py
+++ b/examples/rbm.py
@@ -61,6 +61,49 @@
 fg = graph.FactorGraph(variables=[hidden_variables, visible_variables])
 print("Time", time.time() - start)
 
+# %%
+import itertools
+
+start = time.time()
+variable_names_for_factors = factors = list(
+    map(
+        lambda ij: (
+            hidden_variables.variable_names[ij[0]],
+            visible_variables.variable_names[ij[1]],
+        ),
+        list(itertools.product(range(bh.shape[0]), range(bv.shape[0]))),
+    )
+)
+print("Time", time.time() - start, len(variable_names_for_factors))
+
+import numba as nb
+
+
+@nb.jit(parallel=False, cache=True, fastmath=True, nopython=True)
+def run_numba(h, v, f):
+    for h_idx in nb.prange(h.shape[0]):
+        for v_idx in nb.prange(v.shape[0]):
+            f[h_idx * v.shape[0] + v_idx, 0] = h[h_idx]
+            f[h_idx * v.shape[0] + v_idx, 1] = v[v_idx]
+
+
+start = time.time()
+variable_names_for_factors = np.empty(shape=(bv.shape[0] * bh.shape[0], 2), dtype=int)
+run_numba(
+    hidden_variables.variable_names,
+    visible_variables.variable_names,
+    variable_names_for_factors,
+)
+print("Time", time.time() - start, len(variable_names_for_factors))
+
+start = time.time()
+variable_names_for_factors = [
+    [hidden_variables[ii], visible_variables[jj]]
+    for ii in range(bh.shape[0])
+    for jj in range(bv.shape[0])
+]
+print("Time", time.time() - start, len(variable_names_for_factors))
+
 # %% [markdown]
 # [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup) (e.g. an [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)), or a list/dictionary of [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed.
 #
@@ -82,7 +125,6 @@
     factor_configs=np.arange(2)[:, None],
     log_potentials=np.stack([np.zeros_like(bv), bv], axis=1),
 )
-
 # Add pairwise factors
 log_potential_matrix = np.zeros(W.shape + (2, 2)).reshape((-1, 2, 2))
 log_potential_matrix[:, 1, 1] = W.ravel()
@@ -94,10 +136,11 @@
         for ii in range(bh.shape[0])
         for jj in range(bv.shape[0])
     ],
+    #    variable_names_for_factors=variable_names_for_factors,
     log_potential_matrix=log_potential_matrix,
 )
 
-# fg.add_factor_group(factory=enumeration.PairwiseFactorGroup, variable_names_for_factors=[[hidden_variables[ii], visible_variables[jj]]for ii in range(bh.shape[0])for jj in range(bv.shape[0])], log_potential_matrix=log_potential_matrix,)
+# # %snakeviz fg.add_factor_group(factory=enumeration.PairwiseFactorGroup, variable_names_for_factors=v, log_potential_matrix=log_potential_matrix,)
 print("Time", time.time() - start)
 
 
@@ -168,7 +211,7 @@
     evidence_updates={
         hidden_variables: np.random.gumbel(size=(bh.shape[0], 2)),
         visible_variables: np.random.gumbel(size=(bv.shape[0], 2)),
-    },
+    }
 )
 print("Time", time.time() - start)
 bp_arrays = bp.run_bp(bp_arrays, num_iters=100, damping=0.5)
@@ -221,11 +264,11 @@
     out_axes=0,
 )(bp_arrays)
 
+# TODO: problem
 outputs = jax.vmap(bp.get_bp_output, in_axes=0, out_axes=0)(bp_arrays)
 # map_states = graph.decode_map_states(beliefs)
 
 # %%
-O
 
 # %% [markdown]
 # Visualizing the MAP decodings (Figure [fig:rbm_multiple_digits]), we see that we have sampled 10 MNIST digits in parallel!
diff --git a/pgmax/factors/enumeration.py b/pgmax/factors/enumeration.py
index e2521597..450cc5ab 100644
--- a/pgmax/factors/enumeration.py
+++ b/pgmax/factors/enumeration.py
@@ -184,9 +184,6 @@ def compile_wiring(
         var_states = np.array(
             [vars_to_starts[variable] for variable in variables_for_factors]
         )
-        # num_states = np.array(
-        #     [variable.num_states for variable in variables_for_factors]
-        # )
         num_states_cumsum = np.insert(np.cumsum(factor_edges_num_states), 0, 0)
         var_states_for_edges = np.empty(shape=(num_states_cumsum[-1],), dtype=int)
         _compile_var_states_numba(var_states_for_edges, num_states_cumsum, var_states)
diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index ec03714c..c144660f 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -127,7 +127,6 @@ def add_factor(
         variable_names: List,
         factor_configs: np.ndarray,
         log_potentials: Optional[np.ndarray] = None,
-        name: Optional[str] = None,
     ) -> None:
         """Function to add a single factor to the FactorGraph.
 
@@ -151,7 +150,7 @@ def add_factor(
             factor_configs=factor_configs,
             log_potentials=log_potentials,
         )
-        self._register_factor_group(factor_group, name)
+        self._register_factor_group(factor_group)
 
     def add_factor_by_type(
         self, variable_names: List[int], factor_type: type, *args, **kwargs
@@ -221,7 +220,6 @@ def _register_factor_group(self, factor_group: groups.FactorGroup) -> None:
                 var_names
                 in self._factor_types_to_variable_names_for_factors[factor_type]
             ):
-                print(len(var_names_for_factor))
                 raise ValueError(
                     f"A Factor of type {factor_type} involving variables {var_names} already exists. Please merge the corresponding factors."
                 )
@@ -492,8 +490,7 @@ def update_log_potentials(
         (1) Provided log_potentials shape does not match the expected log_potentials shape.
         (2) Provided name is not valid for log_potentials updates.
     """
-    for name in updates:
-        data = updates[name]
+    for name, data in updates.items():
         if name in fg_state.named_factor_groups:
             factor_group = fg_state.named_factor_groups[name]
 
@@ -613,8 +610,7 @@ def update_ftov_msgs(
         (1) provided ftov_msgs shape does not match the expected ftov_msgs shape.
         (2) provided name is not valid for ftov_msgs updates.
     """
-    for names in updates:
-        data = updates[names]
+    for names, data in updates.items():
         if names in fg_state.variable_group.names:
             variable = fg_state.variable_group[names]
             if data.shape != (variable.num_states,):
@@ -729,7 +725,7 @@ def __setitem__(self, names, data) -> None:
         )
 
 
-# @functools.partial(jax.jit, static_argnames="fg_state")
+@functools.partial(jax.jit, static_argnames="fg_state")
 def update_evidence(
     evidence: jnp.ndarray, updates: Dict[Any, jnp.ndarray], fg_state: FactorGraphState
 ) -> jnp.ndarray:
@@ -743,11 +739,9 @@ def update_evidence(
     Returns:
         A flat jnp array containing updated evidence.
     """
-    for var_group_name in updates:
-        data = updates[var_group_name]
-        print(data.shape)
-        if var_group_name.__hash__() in fg_state.variable_group:
-            variable_group = fg_state.variable_group[var_group_name.__hash__()]
+    for var_group_name, data in updates.items():
+        if var_group_name in fg_state.variable_group:
+            variable_group = fg_state.variable_group[var_group_name]
             first_variable = variable_group.variable_names.flatten()[0]
             start_index = fg_state.vars_to_starts[first_variable]
             flat_data = variable_group.flatten(data)
@@ -974,17 +968,19 @@ def update(
 
         if log_potentials_updates is not None:
             log_potentials = update_log_potentials(
-                log_potentials, log_potentials_updates, bp_state.fg_state
+                log_potentials, HashableDict(log_potentials_updates), bp_state.fg_state
             )
 
         if ftov_msgs_updates is not None:
             ftov_msgs = update_ftov_msgs(
-                ftov_msgs, ftov_msgs_updates, bp_state.fg_state
+                ftov_msgs, HashableDict(ftov_msgs_updates), bp_state.fg_state
             )
 
         if evidence_updates is not None:
-            print(type(evidence), type(evidence_updates))
-            evidence = update_evidence(evidence, evidence_updates, bp_state.fg_state)
+            # Note: if we overwrite variables.__hash__ then hash may change when we call jit
+            evidence = update_evidence(
+                evidence, HashableDict(evidence_updates), bp_state.fg_state
+            )
 
         return BPArrays(
             log_potentials=log_potentials, ftov_msgs=ftov_msgs, evidence=evidence
@@ -1082,7 +1078,7 @@ def get_bp_output(bp_arrays: BPArrays) -> Any:
         """
 
         @jax.jit
-        def compute_flat_beliefs(bp_arrays):
+        def compute_flat_beliefs(bp_arrays, var_states_for_edges):
             flat_beliefs = (
                 jax.device_put(bp_arrays.evidence)
                 .at[jax.device_put(var_states_for_edges)]
@@ -1091,7 +1087,8 @@ def compute_flat_beliefs(bp_arrays):
             return flat_beliefs
 
         return BeliefPropagationOutputs(
-            compute_flat_beliefs(bp_arrays), bp_state.fg_state.variable_group
+            compute_flat_beliefs(bp_arrays, var_states_for_edges),
+            bp_state.fg_state.variable_group,
         )
 
     bp = BeliefPropagation(
@@ -1104,23 +1101,47 @@ def compute_flat_beliefs(bp_arrays):
     return bp
 
 
+@jax.tree_util.register_pytree_node_class
 @dataclass(frozen=True, eq=False)
 class HashableDict:
+    "A convenient class"
     d: Dict = field(default_factory=dict)
 
+    def __post_init__(self):
+        # Allows to initialize from dict without knowing hash
+        new_d = {k.__hash__(): v for k, v in self.d.items()}
+        object.__setattr__(self, "d", new_d)
+
+    @functools.lru_cache()
+    def keys(self):
+        return self.d.keys()
+
+    def items(self):
+        return self.d.items()
+
     def __setitem__(self, key, value):
+        # Allows to copy keys from another HashableDict
         self.d[key] = value
 
     def __getitem__(self, value):
+        # Allows to retrieve from dict without knowing hash
         return self.d[value.__hash__()]
 
+    def tree_flatten(self):
+        return jax.tree_util.tree_flatten(asdict(self))
 
+    @classmethod
+    def tree_unflatten(cls, aux_data, children):
+        return cls(**aux_data.unflatten(children))
+
+
+@jax.tree_util.register_pytree_node_class
 @dataclass(frozen=True, eq=False)
 class BeliefPropagationOutputs:
     # beliefs: An array or a PyTree container containing the beliefs for the variables.
     flat_beliefs: jnp.ndarray
     variable_groups: Mapping[int, groups.VariableGroup]
-    beliefs: HashableDict = field(init=False, default_factory=dict)
+    beliefs: HashableDict = field(init=False)
 
     def __post_init__(self):
         if self.flat_beliefs.ndim != 1:
@@ -1162,7 +1183,7 @@ def unflatten(self) -> None:
                 f"Got {self.flat_beliefs.shape}"
             )
 
-        beliefs = {}
+        beliefs = HashableDict()
         start = 0
         for name, variable_group in self.variable_groups.items():
             if use_num_states:
@@ -1218,3 +1239,11 @@ def _get_marginals(beliefs) -> Any:
         for name, beliefs in self.beliefs.items():
             marginals[name] = _get_marginals(beliefs)
         return marginals
+
+    def tree_flatten(self):
+        return jax.tree_util.tree_flatten(asdict(self))
+
+    @classmethod
+    def tree_unflatten(cls, aux_data, children):
+        # TODO: fix this
+        return cls(**aux_data.unflatten(children))
diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index a08541b0..f9f3dd41 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -311,3 +311,18 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any:
         raise NotImplementedError(
             "SingleFactorGroup does not support vectorized factor operations."
         )
+
+
+# def get_ndvariable_names_for_factor_groups(
+#     arrays: List[Any]
+
+# ):
+#     "Util function"
+
+# import numba as nb
+# @nb.jit(parallel=False, cache=True, fastmath=True, nopython=True)
+# def run_numba(h, v, f):
+#     for h_idx in nb.prange(h.shape[0]):
+#         for v_idx in nb.prange(v.shape[0]):
+#             f[h_idx * v.shape[0] + v_idx, 0] = h[h_idx]
+#             f[h_idx * v.shape[0] + v_idx, 1] = v[v_idx]
diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index 1ca5d4bf..4c6587f3 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -1,6 +1,7 @@
 """A module containing the variables group classes inheriting from the base VariableGroup."""
 
 import itertools
+import random
 from dataclasses import dataclass
 from typing import Tuple, Union
 
@@ -34,6 +35,10 @@ def __post_init__(self):
             if self.num_states.shape != self.shape:
                 raise ValueError("Should be same shape")
 
+    def __getitem__(self, val):
+        # Numpy indexation will throw IndexError for us if out-of-bounds
+        return self.variable_names[val]
+
     @cached_property
     def variable_names(self) -> np.ndarray:
         """Function that generates a dictionary mapping names to variables.
@@ -41,17 +46,10 @@ def variable_names(self) -> np.ndarray:
         Returns:
             a dictionary mapping all possible names to different variables.
         """
-        variable_names = np.empty(self.shape, dtype=int)
-        self_hash = self.__hash__()
-        for index in itertools.product(*[list(range(k)) for k in self.shape]):
-            name = hash((self_hash, index))
-            variable_names[index] = name
-        return variable_names
-
-    def __getitem__(self, val):
-        # Numpy indexation will throw IndexError is out-of-bounds
-        # This will be used to add FactorGroups
-        return self.variable_names[val]
+        # Overwite default hash as it does not give enough spacing across consecutive objects
+        this_hash = random.randint(0, 2**63)
+        indices = np.reshape(np.arange(np.product(self.shape)), self.shape)
+        return this_hash + indices
 
     def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
         """Function that turns meaningful structured data into a flat data array for internal use.
@@ -112,32 +110,29 @@ def unflatten(
         return data
 
 
+# TODO: delete?
 # @dataclass(frozen=True, eq=False)
-# class VariableDict(groups.VariableGroup):
+# class VariableDict():
 #     """A variable dictionary that contains a set of variables of the same size
 
 #     Args:
 #         num_states: The size of the variables in this variable group
-#         variable_names: A tuple of all names of the variables in this variable group
+#         num_variables: The number of variables
 
 #     """
 
 #     num_states: int
-#     variable_names: Tuple[Any, ...]
+#     num_variables: int
 
-#     def _get_names_to_variables(self) -> OrderedDict[Tuple[int, ...], nodes.Variable]:
+#     @cached_property
+#     def variable_names(self) -> np.ndarray:
 #         """Function that generates a dictionary mapping names to variables.
 
 #         Returns:
 #             a dictionary mapping all possible names to different variables.
 #         """
-#         names_to_variables: OrderedDict[
-#             Tuple[Any, ...], nodes.Variable
-#         ] = collections.OrderedDict()
-#         for name in self.variable_names:
-#             names_to_variables[name] = nodes.Variable(self.num_states)
+#         return self.__hash__() + np.arange(self.num_states)
 
-#         return names_to_variables
 
 #     def flatten(
 #         self, data: Mapping[Hashable, Union[np.ndarray, jnp.ndarray]]
diff --git a/tests/factors/test_or.py b/tests/factors/test_or.py
index 1e538a8d..e634fbab 100644
--- a/tests/factors/test_or.py
+++ b/tests/factors/test_or.py
@@ -25,6 +25,7 @@ def test_run_bp_with_OR_factors():
     Note: for the first seed, add all the EnumerationFactors to FG1 and all the ORFactors to FG2
     """
     for idx in range(10):
+        print("it", idx)
         np.random.seed(idx)
 
         # Parameters
@@ -43,30 +44,41 @@ def test_run_bp_with_OR_factors():
         parents_variables1 = vgroup.NDVariableArray(
             num_states=2, shape=(num_parents.sum(),)
         )
-        children_variable1 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,))
-        fg1 = graph.FactorGraph(
-            variables=dict(parents=parents_variables1, children=children_variable1)
-        )
+        children_variables1 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,))
+        fg1 = graph.FactorGraph(variables=[parents_variables1, children_variables1])
 
         # Graph 2
         parents_variables2 = vgroup.NDVariableArray(
             num_states=2, shape=(num_parents.sum(),)
         )
-        children_variable2 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,))
-        fg2 = graph.FactorGraph(
-            variables=dict(parents=parents_variables2, children=children_variable2)
-        )
+        children_variables2 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,))
+        fg2 = graph.FactorGraph(variables=[parents_variables2, children_variables2])
 
-        # Option 1: Define EnumerationFactors equivalent to the ORFactors
+        # Variable names for factors
+        variable_names_for_factors1 = []
+        variable_names_for_factors2 = []
         for factor_idx in range(num_factors):
-            this_num_parents = num_parents[factor_idx]
-            variable_names = [
-                ("parents", idx)
+            variable_names1 = [
+                parents_variables1[idx]
+                for idx in range(
+                    num_parents_cumsum[factor_idx],
+                    num_parents_cumsum[factor_idx + 1],
+                )
+            ] + [children_variables1[factor_idx]]
+            variable_names_for_factors1.append(variable_names1)
+
+            variable_names2 = [
+                parents_variables2[idx]
                 for idx in range(
                     num_parents_cumsum[factor_idx],
                     num_parents_cumsum[factor_idx + 1],
                 )
-            ] + [("children", factor_idx)]
+            ] + [children_variables2[factor_idx]]
+            variable_names_for_factors2.append(variable_names2)
+
+        # Option 1: Define EnumerationFactors equivalent to the ORFactors
+        for factor_idx in range(num_factors):
+            this_num_parents = num_parents[factor_idx]
 
             configs = np.array(list(product([0, 1], repeat=this_num_parents + 1)))
             # Children state is last
@@ -82,7 +94,7 @@ def test_run_bp_with_OR_factors():
             if factor_idx < num_factors // 2:
                 # Add the first half of factors to FactorGraph1
                 fg1.add_factor(
-                    variable_names=variable_names,
+                    variable_names=variable_names_for_factors1[factor_idx],
                     factor_configs=valid_configs,
                     log_potentials=np.zeros(valid_configs.shape[0]),
                 )
@@ -90,14 +102,14 @@ def test_run_bp_with_OR_factors():
                 if idx != 0:
                     # Add the second half of factors to FactorGraph2
                     fg2.add_factor(
-                        variable_names=variable_names,
+                        variable_names=variable_names_for_factors2[factor_idx],
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
                 else:
                     # Add all the EnumerationFactors to FactorGraph1 for the first iter
                     fg1.add_factor(
-                        variable_names=variable_names,
+                        variable_names=variable_names_for_factors1[factor_idx],
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
@@ -108,28 +120,22 @@ def test_run_bp_with_OR_factors():
         variable_names_for_ORFactors_fg2 = []
 
         for factor_idx in range(num_factors):
-            variables_names_for_ORFactor = [
-                ("parents", idx)
-                for idx in range(
-                    num_parents_cumsum[factor_idx],
-                    num_parents_cumsum[factor_idx + 1],
-                )
-            ] + [("children", factor_idx)]
             if factor_idx < num_factors // 2:
                 # Add the first half of factors to FactorGraph2
-                variable_names_for_ORFactors_fg2.append(variables_names_for_ORFactor)
+                variable_names_for_ORFactors_fg2.append(
+                    variable_names_for_factors2[factor_idx]
+                )
             else:
                 if idx != 0:
                     # Add the second half of factors to FactorGraph1
                     variable_names_for_ORFactors_fg1.append(
-                        variables_names_for_ORFactor
+                        variable_names_for_factors1[factor_idx]
                     )
                 else:
                     # Add all the ORFactors to FactorGraph2 for the first iter
                     variable_names_for_ORFactors_fg2.append(
-                        variables_names_for_ORFactor
+                        variable_names_for_factors2[factor_idx]
                     )
-
         if idx != 0:
             fg1.add_factor_group(
                 factory=logical.ORFactorGroup,
@@ -144,19 +150,30 @@ def test_run_bp_with_OR_factors():
         bp1 = graph.BP(fg1.bp_state, temperature=temperature)
         bp2 = graph.BP(fg2.bp_state, temperature=temperature)
 
-        evidence_updates = {
-            "parents": jax.device_put(np.random.gumbel(size=(sum(num_parents), 2))),
-            "children": jax.device_put(np.random.gumbel(size=(num_factors, 2))),
+        evidence_parents = jax.device_put(np.random.gumbel(size=(sum(num_parents), 2)))
+        evidence_children = jax.device_put(np.random.gumbel(size=(num_factors, 2)))
+
+        evidence_updates1 = {
+            parents_variables1: evidence_parents,
+            children_variables1: evidence_children,
+        }
+        evidence_updates2 = {
+            parents_variables2: evidence_parents,
+            children_variables2: evidence_children,
         }
 
-        bp_arrays1 = bp1.init(evidence_updates=evidence_updates)
+        bp_arrays1 = bp1.init(evidence_updates=evidence_updates1)
         bp_arrays1 = bp1.run_bp(bp_arrays1, num_iters=5)
-        bp_arrays2 = bp2.init(evidence_updates=evidence_updates)
+        bp_arrays2 = bp2.init(evidence_updates=evidence_updates2)
         bp_arrays2 = bp2.run_bp(bp_arrays2, num_iters=5)
 
         # Get beliefs
-        beliefs1 = bp1.get_beliefs(bp_arrays1)
-        beliefs2 = bp2.get_beliefs(bp_arrays2)
+        beliefs1 = bp1.get_bp_output(bp_arrays1).beliefs
+        beliefs2 = bp2.get_bp_output(bp_arrays2).beliefs
 
-        assert np.allclose(beliefs1["children"], beliefs2["children"], atol=1e-4)
-        assert np.allclose(beliefs1["parents"], beliefs2["parents"], atol=1e-4)
+        assert np.allclose(
+            beliefs1[children_variables1], beliefs2[children_variables2], atol=1e-4
+        )
+        assert np.allclose(
+            beliefs1[parents_variables1], beliefs2[parents_variables2], atol=1e-4
+        )

From c8f2c5e22f2327a81c386440cd2901d2497445fd Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Wed, 13 Apr 2022 01:00:51 +0000
Subject: [PATCH 04/35] Minbor

---
 examples/pmp_binary_deconvolution.py | 2 ++
 examples/rbm.py                      | 2 ++
 pgmax/fg/graph.py                    | 2 +-
 3 files changed, 5 insertions(+), 1 deletion(-)

diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py
index a5c91c22..c61fefcf 100644
--- a/examples/pmp_binary_deconvolution.py
+++ b/examples/pmp_binary_deconvolution.py
@@ -108,6 +108,8 @@ def plot_images(images, display=True, nr=None):
 # See Section 5.6 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) for more details.
 #
 # import imp
+#
+# import imp
 
 import imp
 
diff --git a/examples/rbm.py b/examples/rbm.py
index 0f820851..3dc6a61f 100644
--- a/examples/rbm.py
+++ b/examples/rbm.py
@@ -44,6 +44,8 @@
 
 # %% [markdown]
 # We can then initialize the factor graph for the RBM with
+#
+# import imp
 
 import imp
 
diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index c144660f..be4c026a 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -1104,7 +1104,7 @@ def compute_flat_beliefs(bp_arrays, var_states_for_edges):
 @jax.tree_util.register_pytree_node_class
 @dataclass(frozen=True, eq=False)
 class HashableDict:
-    "A convenient class"
+    "Represents a dictionnary where stored keys are the hash of the elements"
     d: Dict = field(default_factory=dict)
 
     def __post_init__(self):

From 39b546a9fa88d8779e8b01ceea2858ec1c71761b Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Wed, 20 Apr 2022 22:31:42 +0000
Subject: [PATCH 05/35] Variables as tuple + Remove BPOuputs/HashableDict

---
 examples/ising_model.py              |  38 ++-
 examples/pmp_binary_deconvolution.py |  67 ++---
 examples/rbm.py                      |  29 +--
 pgmax/factors/enumeration.py         |  32 +--
 pgmax/factors/logical.py             |  36 +--
 pgmax/fg/graph.py                    | 352 +++++++++++----------------
 pgmax/fg/groups.py                   |  56 ++---
 pgmax/fg/nodes.py                    |   8 +-
 pgmax/groups/enumeration.py          |  46 ++--
 pgmax/groups/logical.py              |  11 +-
 pgmax/groups/variables.py            |  36 ++-
 tests/factors/test_and.py            |  99 ++++----
 tests/factors/test_or.py             |  40 ++-
 tests/fg/test_wiring.py              |  26 +-
 14 files changed, 394 insertions(+), 482 deletions(-)

diff --git a/examples/ising_model.py b/examples/ising_model.py
index 837c764e..88535c23 100644
--- a/examples/ising_model.py
+++ b/examples/ising_model.py
@@ -31,19 +31,19 @@
 variables = vgroup.NDVariableArray(num_states=2, shape=(50, 50))
 fg = graph.FactorGraph(variables=variables)
 
-# TODO: rename variable_for_factors?
-variable_names_for_factors = []
+variables_for_factors = []
 for ii in range(50):
     for jj in range(50):
         kk = (ii + 1) % 50
         ll = (jj + 1) % 50
-        variable_names_for_factors.append([variables[ii, jj], variables[kk, jj]])
-        variable_names_for_factors.append([variables[ii, jj], variables[kk, ll]])
+        variables_for_factors.append([variables[ii, jj], variables[kk, jj]])
+        variables_for_factors.append([variables[ii, jj], variables[kk, ll]])
 
 fg.add_factor_group(
     factory=enumeration.PairwiseFactorGroup,
-    variable_names_for_factors=variable_names_for_factors,
+    variables_for_factors=variables_for_factors,
     log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]),
+    name="factors",
 )
 
 # %% [markdown]
@@ -52,11 +52,6 @@
 # %%
 bp = graph.BP(fg.bp_state, temperature=0)
 
-# %%
-d = {variables: 1, variables.__hash__(): 2}
-hd = graph.HashableDict(d)
-hd[variables]
-
 # %%
 # TODO: check\ time for before BP vs time for BP
 # TODO: time PGMAX vs PMP
@@ -64,10 +59,10 @@
     evidence_updates={variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
 )
 bp_arrays = bp.run_bp(bp_arrays, num_iters=3000)
-output = bp.get_bp_output(bp_arrays)
+beliefs = bp.get_beliefs(bp_arrays)
 
 # %%
-img = output.map_states[variables]
+img = graph.decode_map_states(beliefs)[variables]
 fig, ax = plt.subplots(1, 1, figsize=(10, 10))
 ax.imshow(img)
 
@@ -82,19 +77,20 @@ def loss(log_potentials_updates, evidence_updates):
     )
     bp_arrays = bp.run_bp(bp_arrays, num_iters=3000)
     beliefs = bp.get_beliefs(bp_arrays)
-    loss = -jnp.sum(beliefs)
+    loss = -jnp.sum(beliefs[variables])
     return loss
 
 
-batch_loss = jax.jit(jax.vmap(loss, in_axes=(None, {None: 0}), out_axes=0))
+batch_loss = jax.jit(jax.vmap(loss, in_axes=(None, {variables: 0}), out_axes=0))
 log_potentials_grads = jax.jit(jax.grad(loss, argnums=0))
 
 # %%
-batch_loss(None, {None: jax.device_put(np.random.gumbel(size=(10, 50, 50, 2)))})
+batch_loss(None, {variables: jax.device_put(np.random.gumbel(size=(10, 50, 50, 2)))})
 
 # %%
 grads = log_potentials_grads(
-    {"factors": jnp.eye(2)}, {None: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
+    {"factors": jnp.eye(2)},
+    {variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))},
 )
 
 # %% [markdown]
@@ -104,15 +100,15 @@ def loss(log_potentials_updates, evidence_updates):
 bp_state = bp.to_bp_state(bp_arrays)
 
 # Query evidence for variable (0, 0)
-bp_state.evidence[0, 0]
+bp_state.evidence[variables[0, 0]]
 
 # %%
 # Set evidence for variable (0, 0)
-bp_state.evidence[0, 0] = np.array([1.0, 1.0])
-bp_state.evidence[0, 0]
+bp_state.evidence[variables[0, 0]] = np.array([1.0, 1.0])
+bp_state.evidence[variables[0, 0]]
 
 # %%
 # Set evidence for all variables using an array
 evidence = np.random.randn(50, 50, 2)
-bp_state.evidence[None] = evidence
-bp_state.evidence[10, 10] == evidence[10, 10]
+bp_state.evidence[variables] = evidence
+np.allclose(bp_state.evidence[variables[10, 10]], evidence[10, 10])
diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py
index c61fefcf..bd418a69 100644
--- a/examples/pmp_binary_deconvolution.py
+++ b/examples/pmp_binary_deconvolution.py
@@ -106,10 +106,6 @@ def plot_images(images, display=True, nr=None):
 #  - a second set of ORFactors, which maps SW to X and model (binary) features overlapping.
 #
 # See Section 5.6 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) for more details.
-#
-# import imp
-#
-# import imp
 
 import imp
 
@@ -157,8 +153,8 @@ def plot_images(images, display=True, nr=None):
 print(time.time() - start)
 
 # Define the ANDFactors
-variable_names_for_ANDFactors = []
-variable_names_for_ORFactors_dict = defaultdict(list)
+variables_for_ANDFactors = []
+variables_for_ORFactors_dict = defaultdict(list)
 for idx_img in tqdm(range(n_images)):
     for idx_chan in range(n_chan):
         for idx_s_height in range(s_height):
@@ -178,36 +174,34 @@ def plot_images(images, display=True, nr=None):
                                 idx_feat_width,
                             ]
 
-                            variable_names_for_ANDFactor = [
+                            variables_for_ANDFactor = [
                                 S[idx_img, idx_feat, idx_s_height, idx_s_width],
                                 W[idx_chan, idx_feat, idx_feat_height, idx_feat_width],
                                 SW_var,
                             ]
-                            variable_names_for_ANDFactors.append(
-                                variable_names_for_ANDFactor
-                            )
+                            variables_for_ANDFactors.append(variables_for_ANDFactor)
 
                             X_var = X[idx_img, idx_chan, idx_img_height, idx_img_width]
-                            variable_names_for_ORFactors_dict[X_var].append(SW_var)
+                            variables_for_ORFactors_dict[X_var].append(SW_var)
 print(time.time() - start)
 
 # Add ANDFactorGroup, which is computationally efficient
 fg.add_factor_group(
     factory=logical.ANDFactorGroup,
-    variable_names_for_factors=variable_names_for_ANDFactors,
+    variables_for_factors=variables_for_ANDFactors,
 )
 print(time.time() - start)
 
 # Define the ORFactors
-variable_names_for_ORFactors = [
-    list(tuple(variable_names_for_ORFactors_dict[X_var]) + (X_var,))
-    for X_var in variable_names_for_ORFactors_dict
+variables_for_ORFactors = [
+    list(tuple(variables_for_ORFactors_dict[X_var]) + (X_var,))
+    for X_var in variables_for_ORFactors_dict
 ]
 
 # Add ORFactorGroup, which is computationally efficient
 fg.add_factor_group(
     factory=logical.ORFactorGroup,
-    variable_names_for_factors=variable_names_for_ORFactors,
+    variables_for_factors=variables_for_ORFactors,
 )
 print("Time", time.time() - start)
 
@@ -236,7 +230,7 @@ def plot_images(images, display=True, nr=None):
 
 # %%
 pW = 0.25
-pS = 1e-70
+pS = 1e-72
 pX = 1e-100
 
 # Sparsity inducing priors for W and S
@@ -250,31 +244,6 @@ def plot_images(images, display=True, nr=None):
 uX = np.zeros((X_gt.shape) + (2,))
 uX[..., 0] = (2 * X_gt - 1) * logit(pX)
 
-# %%
-np.random.seed(seed=40)
-
-start = time.time()
-bp_arrays = bp.init(
-    evidence_updates={
-        S: uS + np.random.gumbel(size=uS.shape),
-        W: uW + np.random.gumbel(size=uW.shape),
-        SW: np.zeros(SW.shape),
-        X: uX + np.zeros(shape=uX.shape),
-    },
-)
-print("Time", time.time() - start)
-bp_arrays = bp.run_bp(bp_arrays, num_iters=100, damping=0.5)
-print("Time", time.time() - start)
-output = bp.get_bp_output(bp_arrays)
-print("Time", time.time() - start)
-
-# %%
-_ = plot_images(output.map_states[W].reshape(-1, feat_height, feat_width), nr=1)
-
-# %%
-
-# %%
-
 # %% [markdown]
 # We draw a batch of samples from the posterior in parallel by transforming `run_bp`/`get_beliefs` with `jax.vmap`
 
@@ -285,10 +254,10 @@ def plot_images(images, display=True, nr=None):
 start = time.time()
 bp_arrays = jax.vmap(bp.init, in_axes=0, out_axes=0)(
     evidence_updates={
-        "S": uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape),
-        "W": uW[None] + np.random.gumbel(size=(n_samples,) + uW.shape),
-        "SW": np.zeros(shape=(n_samples,) + SW.shape),
-        "X": uX[None] + np.zeros(shape=(n_samples,) + uX.shape),
+        S: uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape),
+        W: uW[None] + np.random.gumbel(size=(n_samples,) + uW.shape),
+        SW: np.zeros(shape=(n_samples,) + SW.shape),
+        X: uX[None] + np.zeros(shape=(n_samples,) + uX.shape),
     },
 )
 print("Time", time.time() - start)
@@ -298,8 +267,8 @@ def plot_images(images, display=True, nr=None):
     out_axes=0,
 )(bp_arrays)
 print("Time", time.time() - start)
-# beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
-# map_states = graph.decode_map_states(beliefs)
+beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
+map_states = graph.decode_map_states(beliefs)
 
 # %% [markdown]
 # Visualizing the MAP decoding, we see that we have 4 good random samples (one per row) from the posterior!
@@ -307,4 +276,4 @@ def plot_images(images, display=True, nr=None):
 # Because we have used one extra feature for inference, each posterior sample recovers the 4 basic features used to generate the images, and includes an extra symbol.
 
 # %%
-_ = plot_images(map_states["W"].reshape(-1, feat_height, feat_width), nr=n_samples)
+_ = plot_images(map_states[W].reshape(-1, feat_height, feat_width), nr=n_samples)
diff --git a/examples/rbm.py b/examples/rbm.py
index 3dc6a61f..9a46c0dc 100644
--- a/examples/rbm.py
+++ b/examples/rbm.py
@@ -44,8 +44,6 @@
 
 # %% [markdown]
 # We can then initialize the factor graph for the RBM with
-#
-# import imp
 
 import imp
 
@@ -116,14 +114,14 @@ def run_numba(h, v, f):
 # Add unary factors
 fg.add_factor_group(
     factory=enumeration.EnumerationFactorGroup,
-    variable_names_for_factors=[[hidden_variables[ii]] for ii in range(bh.shape[0])],
+    variables_for_factors=[[hidden_variables[ii]] for ii in range(bh.shape[0])],
     factor_configs=np.arange(2)[:, None],
     log_potentials=np.stack([np.zeros_like(bh), bh], axis=1),
 )
 
 fg.add_factor_group(
     factory=enumeration.EnumerationFactorGroup,
-    variable_names_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])],
+    variables_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])],
     factor_configs=np.arange(2)[:, None],
     log_potentials=np.stack([np.zeros_like(bv), bv], axis=1),
 )
@@ -133,7 +131,7 @@ def run_numba(h, v, f):
 
 fg.add_factor_group(
     factory=enumeration.PairwiseFactorGroup,
-    variable_names_for_factors=[
+    variables_for_factors=[
         [hidden_variables[ii], visible_variables[jj]]
         for ii in range(bh.shape[0])
         for jj in range(bv.shape[0])
@@ -218,7 +216,7 @@ def run_numba(h, v, f):
 print("Time", time.time() - start)
 bp_arrays = bp.run_bp(bp_arrays, num_iters=100, damping=0.5)
 print("Time", time.time() - start)
-output = bp.get_bp_output(bp_arrays)
+beliefs = bp.get_beliefs(bp_arrays)
 print("Time", time.time() - start)
 
 # %% [markdown]
@@ -228,7 +226,9 @@ def run_numba(h, v, f):
 
 # %%
 fig, ax = plt.subplots(1, 1, figsize=(10, 10))
-ax.imshow(output.map_states[visible_variables].copy().reshape((28, 28)), cmap="gray")
+ax.imshow(
+    graph.map_states(beliefs)[visible_variables].copy().reshape((28, 28)), cmap="gray"
+)
 ax.axis("off")
 
 # %% [markdown]
@@ -256,8 +256,8 @@ def run_numba(h, v, f):
 n_samples = 10
 bp_arrays = jax.vmap(bp.init, in_axes=0, out_axes=0)(
     evidence_updates={
-        "hidden": np.random.gumbel(size=(n_samples, bh.shape[0], 2)),
-        "visible": np.random.gumbel(size=(n_samples, bv.shape[0], 2)),
+        hidden_variables: np.random.gumbel(size=(n_samples, bh.shape[0], 2)),
+        visible_variables: np.random.gumbel(size=(n_samples, bv.shape[0], 2)),
     },
 )
 bp_arrays = jax.vmap(
@@ -266,11 +266,8 @@ def run_numba(h, v, f):
     out_axes=0,
 )(bp_arrays)
 
-# TODO: problem
-outputs = jax.vmap(bp.get_bp_output, in_axes=0, out_axes=0)(bp_arrays)
-# map_states = graph.decode_map_states(beliefs)
-
-# %%
+beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
+map_states = graph.decode_map_states(beliefs)
 
 # %% [markdown]
 # Visualizing the MAP decodings (Figure [fig:rbm_multiple_digits]), we see that we have sampled 10 MNIST digits in parallel!
@@ -279,10 +276,8 @@ def run_numba(h, v, f):
 fig, ax = plt.subplots(2, 5, figsize=(20, 8))
 for ii in range(10):
     ax[np.unravel_index(ii, (2, 5))].imshow(
-        map_states["visible"][ii].copy().reshape((28, 28)), cmap="gray"
+        map_states[visible_variables][ii].copy().reshape((28, 28)), cmap="gray"
     )
     ax[np.unravel_index(ii, (2, 5))].axis("off")
 
 fig.tight_layout()
-
-# %%
diff --git a/pgmax/factors/enumeration.py b/pgmax/factors/enumeration.py
index 450cc5ab..a8078e78 100644
--- a/pgmax/factors/enumeration.py
+++ b/pgmax/factors/enumeration.py
@@ -21,7 +21,7 @@ class EnumerationWiring(nodes.Wiring):
     Args:
         factor_configs_edge_states: Array of shape (num_factor_configs, 2)
             factor_configs_edge_states[ii] contains a pair of global enumeration factor_config and global edge_state indices
-            factor_configs_edge_states[ii, 0] contains the global EnumerationFactor config index,
+            factor_configs_edge_states[ii, 0] contains the global EnumerExpected factor_edges_num_statesationFactor config index,
             factor_configs_edge_states[ii, 1] contains the corresponding global edge_state index.
             Both indices only take into account the EnumerationFactors of the FactorGraph
 
@@ -81,9 +81,9 @@ def __post_init__(self):
                 f"EnumerationFactor. Got a factor_configs array of shape {self.factor_configs.shape}."
             )
 
-        if len(self.vars_to_num_states.keys()) != self.factor_configs.shape[1]:
+        if len(self.variables) != self.factor_configs.shape[1]:
             raise ValueError(
-                f"Number of variables {len(self.vars_to_num_states.keys())} doesn't match given configurations {self.factor_configs.shape}"
+                f"Number of variables {len(self.variables)} doesn't match given configurations {self.factor_configs.shape}"
             )
 
         if self.log_potentials.shape != (self.factor_configs.shape[0],):
@@ -93,7 +93,7 @@ def __post_init__(self):
                 f"shape {self.log_potentials.shape}."
             )
 
-        vars_num_states = np.array([list(self.vars_to_num_states.values())])
+        vars_num_states = np.array([variable[1] for variable in self.variables])
         if not np.logical_and(
             self.factor_configs >= 0, self.factor_configs < vars_num_states[None]
         ).all():
@@ -154,20 +154,18 @@ def concatenate_wirings(wirings: Sequence[EnumerationWiring]) -> EnumerationWiri
 
     @staticmethod
     def compile_wiring(
-        factor_edges_num_states: np.ndarray,
-        variables_for_factors: Tuple[int, ...],  # TODO: rename
+        variables_for_factors: Tuple[Tuple[int, int], ...],
         factor_configs: np.ndarray,
-        vars_to_starts: Mapping[int, int],
+        vars_to_starts: Mapping[Tuple[int, int], int],
         num_factors: int,
     ) -> EnumerationWiring:
         """Compile an EnumerationWiring for an EnumerationFactor or a FactorGroup with EnumerationFactors.
         Internally calls _compile_var_states_numba and _compile_enumeration_wiring_numba for speed.
 
         Args:
-            factor_edges_num_states: An array concatenating the number of states for the variables connected to each
-                Factor of the FactorGroup. Each variable will appear once for each Factor it connects to.
-            variables_for_factors: A tuple containing variables connected to each Factor of the FactorGroup.
-                Each variable will appear once for each Factor it connects to.
+            variables_for_factors: A list of list of variables, where each innermost element is a
+                variable. Each list within the outer list is taken to contain the names of the
+                variables connected to a Factor.
             factor_configs: Array of shape (num_val_configs, num_variables) containing an explicit enumeration
                 of all valid configurations.
             vars_to_starts: A dictionary that maps variables to their global starting indices
@@ -181,9 +179,15 @@ def compile_wiring(
         Returns:
             The EnumerationWiring
         """
-        var_states = np.array(
-            [vars_to_starts[variable] for variable in variables_for_factors]
-        )
+        var_states = []
+        factor_edges_num_states = []
+        for variables_for_factor in variables_for_factors:
+            for variable in variables_for_factor:
+                var_states.append(vars_to_starts[variable])
+                factor_edges_num_states.append(variable[1])
+        var_states = np.array(var_states)
+        factor_edges_num_states = np.array(factor_edges_num_states)
+
         num_states_cumsum = np.insert(np.cumsum(factor_edges_num_states), 0, 0)
         var_states_for_edges = np.empty(shape=(num_states_cumsum[-1],), dtype=int)
         _compile_var_states_numba(var_states_for_edges, num_states_cumsum, var_states)
diff --git a/pgmax/factors/logical.py b/pgmax/factors/logical.py
index c53a5fac..81eab93d 100644
--- a/pgmax/factors/logical.py
+++ b/pgmax/factors/logical.py
@@ -86,14 +86,12 @@ class LogicalFactor(nodes.Factor):
     edge_states_offset: int = field(init=False)
 
     def __post_init__(self):
-        if len(self.vars_to_num_states.keys()) < 2:
+        if len(self.variables) < 2:
             raise ValueError(
-                "At least one parent variable and one child variable is required"
+                "A LogicalFactor requires at least one parent variable and one child variable "
             )
 
-        if not np.all(
-            [num_states == 2 for num_states in self.vars_to_num_states.values()]
-        ):
+        if not np.all([variable[1] == 2 for variable in self.variables]):
             raise ValueError("All variables should all be binary")
 
     @staticmethod
@@ -142,9 +140,7 @@ def concatenate_wirings(wirings: Sequence[LogicalWiring]) -> LogicalWiring:
 
     @staticmethod
     def compile_wiring(
-        factor_edges_num_states: np.ndarray,
-        variables_for_factors: Tuple[int, ...],  # notsure
-        factor_sizes: np.ndarray,
+        variables_for_factors: Tuple[Tuple[int, int], ...],
         vars_to_starts: Mapping[int, int],
         edge_states_offset: int,
     ) -> LogicalWiring:
@@ -152,11 +148,9 @@ def compile_wiring(
         Internally calls _compile_var_states_numba and _compile_logical_wiring_numba for speed.
 
         Args:
-            factor_edges_num_states: An array concatenating the number of states for the variables connected to each
-                Factor of the FactorGroup. Each variable will appear once for each Factor it connects to.
-            variables_for_factors: A tuple of tuples containing variables connected to each Factor of the FactorGroup.
-                Each variable will appear once for each Factor it connects to.
-            factor_sizes: An array containing the different factor sizes.
+            variables_for_factors: A list of list of variables, where each innermost element is a
+                variable. Each list within the outer list is taken to contain the names of the
+                variables connected to a Factor.
             vars_to_starts: A dictionary that maps variables to their global starting indices
                 For an n-state variable, a global start index of m means the global indices
                 of its n variable states are m, m + 1, ..., m + n - 1
@@ -166,11 +160,21 @@ def compile_wiring(
         Returns:
             The LogicalWiring
         """
+        factor_sizes = []
+        var_states = []
+        factor_edges_num_states = []
+        for variables_for_factor in variables_for_factors:
+            factor_sizes.append(len(variables_for_factor))
+            for variable in variables_for_factor:
+                var_states.append(vars_to_starts[variable])
+                factor_edges_num_states.append(variable[1])
+        factor_sizes = np.array(factor_sizes)
+        var_states = np.array(var_states)
+        factor_edges_num_states = np.array(factor_edges_num_states)
+
+        # Relevant state differs for ANDFactors and ORFactors
         relevant_state = (-edge_states_offset + 1) // 2
 
-        var_states = np.array(
-            [vars_to_starts[variable] for variable in variables_for_factors]
-        )
         # Note: all the variables in a LogicalFactorGroup are binary
         num_states_cumsum = np.arange(0, 2 * var_states.shape[0] + 2, 2)
         var_states_for_edges = np.empty(shape=(2 * var_states.shape[0],), dtype=int)
diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index be4c026a..1a8a9833 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -7,7 +7,7 @@
 import functools
 import inspect
 import typing
-from dataclasses import asdict, dataclass, field
+from dataclasses import asdict, dataclass
 from types import MappingProxyType
 from typing import (
     Any,
@@ -62,37 +62,32 @@ def __post_init__(self):
         start = time.time()
         # if isinstance(self.variables, groups.VariableGroup):
         if isinstance(self.variables, vgroup.NDVariableArray):
-            self.variables = [self.variables]
-
-        self._variable_group: Mapping[
-            int, groups.VariableGroup
-        ] = collections.OrderedDict()
-        for variable_group in self.variables:
-            self._variable_group[variable_group.__hash__()] = variable_group
-
-        vars_names = []
-        vars_num_states = []
-        for variable_group in self.variables:
-            if isinstance(variable_group, vgroup.NDVariableArray):
-                vars_names.append(variable_group.variable_names.flatten())
-                vars_num_states.append(variable_group.num_states.flatten())
+            self.variable_groups = [self.variables]
+        else:
+            self.variable_groups = self.variables
+
+        # TODO: remove?
+        self._variable_group = self.variable_groups
+        # self._variable_group: Mapping[
+        #     int, groups.VariableGroup
+        # ] = collections.OrderedDict()
+        # for variable_group in self.variables:
+        #     self._variable_group[variable_group.__hash__()] = variable_group
+
+        self._variables = [
+            variable
+            for variable_group in self.variable_groups
+            for variable in variable_group.variables
+        ]
         print("1", time.time() - start)
 
-        vars_names = np.concatenate(vars_names)
-        vars_num_states = np.concatenate(vars_num_states)
-        vars_num_states_cumsum = np.insert(
-            np.array(vars_num_states).cumsum(),
-            0,
-            0,
-        )
-
         # Useful objects to build the FactorGraph
         self._factor_types_to_groups: OrderedDict[
             Type, List[groups.FactorGroup]
         ] = collections.OrderedDict(
             [(factor_type, []) for factor_type in FAC_TO_VAR_UPDATES]
         )
-        self._factor_types_to_variable_names_for_factors: OrderedDict[
+        self._factor_types_to_variables_for_factors: OrderedDict[
             Type, Set[FrozenSet]
         ] = collections.OrderedDict(
             [(factor_type, set()) for factor_type in FAC_TO_VAR_UPDATES]
@@ -100,14 +95,16 @@ def __post_init__(self):
         print("2", time.time() - start)
 
         # Used to add FactorGroups
-        # TODO: move to dict, which is faster
-        self._vars_to_num_states: OrderedDict[int, int] = collections.OrderedDict(
-            zip(vars_names, vars_num_states)
+        vars_num_states = [variable[1] for variable in self._variables]
+        vars_num_states_cumsum = np.insert(
+            np.array(vars_num_states).cumsum(),
+            0,
+            0,
         )
         # See FactorGraphState docstrings for documentation on the following fields
         self._num_var_states = vars_num_states_cumsum[-1]
         self._vars_to_starts: OrderedDict[int, int] = collections.OrderedDict(
-            zip(vars_names, vars_num_states_cumsum[:-1])
+            zip(self._variables, vars_num_states_cumsum[:-1])
         )
         self._named_factor_groups: Dict[Hashable, groups.FactorGroup] = {}
         print("3", time.time() - start)
@@ -127,6 +124,7 @@ def add_factor(
         variable_names: List,
         factor_configs: np.ndarray,
         log_potentials: Optional[np.ndarray] = None,
+        name: Optional[str] = None,
     ) -> None:
         """Function to add a single factor to the FactorGraph.
 
@@ -145,15 +143,14 @@ def add_factor(
                 initialized.
         """
         factor_group = EnumerationFactorGroup(
-            self._vars_to_num_states,
-            variable_names_for_factors=[variable_names],
+            variables_for_factors=[variable_names],
             factor_configs=factor_configs,
             log_potentials=log_potentials,
         )
-        self._register_factor_group(factor_group)
+        self._register_factor_group(factor_group, name)
 
     def add_factor_by_type(
-        self, variable_names: List[int], factor_type: type, *args, **kwargs
+        self, variables: List[int], factor_type: type, *args, **kwargs
     ) -> None:
         """Function to add a single factor to the FactorGraph.
 
@@ -169,7 +166,7 @@ def add_factor_by_type(
             To add an ORFactor to a FactorGraph fg, run::
 
                 fg.add_factor_by_type(
-                    variable_names=variables_names_for_OR_factor,
+                    variables=variables_for_OR_factor,
                     factor_type=logical.ORFactor
                 )
         """
@@ -178,16 +175,13 @@ def add_factor_by_type(
                 f"Type {factor_type} is not one of the supported factor types {FAC_TO_VAR_UPDATES.keys()}"
             )
 
-        vars_to_num_states = collections.OrderedDict(
-            (var, self._vars_to_num_states[var]) for var in variable_names
-        )
-        factor = factor_type(vars_to_num_states, *args, **kwargs)
+        name = kwargs.pop("name", None)
+        factor = factor_type(variables, *args, **kwargs)
         factor_group = groups.SingleFactorGroup(
-            vars_to_num_states=self._vars_to_num_states,
-            variable_names_for_factors=[variable_names],
+            variables_for_factors=[variables],
             factor=factor,
         )
-        self._register_factor_group(factor_group)
+        self._register_factor_group(factor_group, name)
 
     def add_factor_group(self, factory: Callable, *args, **kwargs) -> None:
         """Add a factor group to the factor graph
@@ -198,10 +192,13 @@ def add_factor_group(self, factory: Callable, *args, **kwargs) -> None:
             kwargs: kwargs to be passed to the factory function, and an optional "name" argument
                 for specifying the name of a named factor group.
         """
-        factor_group = factory(self._vars_to_num_states, *args, **kwargs)
-        self._register_factor_group(factor_group)
+        name = kwargs.pop("name", None)
+        factor_group = factory(*args, **kwargs)
+        self._register_factor_group(factor_group, name)
 
-    def _register_factor_group(self, factor_group: groups.FactorGroup) -> None:
+    def _register_factor_group(
+        self, factor_group: groups.FactorGroup, name: Optional[str] = None
+    ) -> None:
         """Register a factor group to the factor graph, by updating the factor graph state.
 
         Args:
@@ -212,21 +209,25 @@ def _register_factor_group(self, factor_group: groups.FactorGroup) -> None:
             ValueError: If the factor group with the same name or a factor involving the same variables
                 already exists in the factor graph.
         """
+        if name in self._named_factor_groups:
+            raise ValueError(
+                f"A factor group with the name {name} already exists. Please choose a different name!"
+            )
 
         factor_type = factor_group.factor_type
-        for var_names_for_factor in factor_group.variable_names_for_factors:
+        for var_names_for_factor in factor_group.variables_for_factors:
             var_names = frozenset(var_names_for_factor)
-            if (
-                var_names
-                in self._factor_types_to_variable_names_for_factors[factor_type]
-            ):
+            if var_names in self._factor_types_to_variables_for_factors[factor_type]:
                 raise ValueError(
                     f"A Factor of type {factor_type} involving variables {var_names} already exists. Please merge the corresponding factors."
                 )
-            self._factor_types_to_variable_names_for_factors[factor_type].add(var_names)
+            self._factor_types_to_variables_for_factors[factor_type].add(var_names)
 
         self._factor_types_to_groups[factor_type].append(factor_group)
 
+        if name is not None:
+            self._named_factor_groups[name] = factor_group
+
     @functools.lru_cache(None)
     def compute_offsets(self) -> None:
         """Compute factor messages offsets for the factor types and factor groups
@@ -258,7 +259,7 @@ def compute_offsets(self) -> None:
                     factor_group
                 ] = factor_num_configs_cumsum
 
-                factor_num_states_cumsum += sum(factor_group.factor_edges_num_states)
+                factor_num_states_cumsum += factor_group.total_num_states
                 factor_num_configs_cumsum += (
                     factor_group.factor_group_log_potentials.shape[0]
                 )
@@ -739,19 +740,22 @@ def update_evidence(
     Returns:
         A flat jnp array containing updated evidence.
     """
-    for var_group_name, data in updates.items():
-        if var_group_name in fg_state.variable_group:
-            variable_group = fg_state.variable_group[var_group_name]
-            first_variable = variable_group.variable_names.flatten()[0]
+    for name, data in updates.items():
+        # Name is a variable_group or a variable
+        if name in fg_state.variable_group:
+            first_variable = name.variables[0]
             start_index = fg_state.vars_to_starts[first_variable]
-            flat_data = variable_group.flatten(data)
+            flat_data = name.flatten(data)
             evidence = evidence.at[start_index : start_index + flat_data.shape[0]].set(
                 flat_data
             )
-        # else:
-        #     var = fg_state.variable_group[name]
-        #     start_index = fg_state.vars_to_starts[var]
-        #     evidence = evidence.at[start_index : start_index + var.num_states].set(data)
+        elif name in fg_state.vars_to_starts:
+            start_index = fg_state.vars_to_starts[name]
+            evidence = evidence.at[start_index : start_index + name[1]].set(data)
+        else:
+            raise ValueError(
+                "Got evidence for a variable or a variable group not in the FactorGraph!"
+            )
     return evidence
 
 
@@ -781,19 +785,18 @@ def __post_init__(self):
 
             object.__setattr__(self, "value", self.value)
 
-    def __getitem__(self, name: Any) -> np.ndarray:
+    def __getitem__(self, variable: Tuple[int, int]) -> np.ndarray:
         """Function to query evidence for a variable
 
         Args:
-            name: name for the variable
+            variable: Variable queried
 
         Returns:
             evidence for the queried variable
         """
         value = cast(np.ndarray, self.value)
-        variable = self.fg_state.variable_group[name]
         start = self.fg_state.vars_to_starts[variable]
-        evidence = value[start : start + variable.num_states]
+        evidence = value[start : start + variable[1]]
         return evidence
 
     def __setitem__(
@@ -888,18 +891,18 @@ class BeliefPropagation:
             Returns:
                 The reconstructed BPState
 
-        get_bp_output: Function to calculate beliefs from a BPArrays.
+        get_beliefs: Function to calculate beliefs from a BPArrays.
             Args:
                 bp_arrays: A BPArrays containing log_potentials, ftov_msgs, evidence.
             Returns:
-                bp_output: Belief propagation output.
+                beliefs: Beliefs returned by belief propagation.
     """
 
     init: Callable
     update: Callable
     run_bp: Callable
     to_bp_state: Callable
-    get_bp_output: Callable
+    get_beliefs: Callable
 
 
 def BP(bp_state: BPState, temperature: float = 0.0) -> BeliefPropagation:
@@ -968,19 +971,16 @@ def update(
 
         if log_potentials_updates is not None:
             log_potentials = update_log_potentials(
-                log_potentials, HashableDict(log_potentials_updates), bp_state.fg_state
+                log_potentials, log_potentials_updates, bp_state.fg_state
             )
 
         if ftov_msgs_updates is not None:
             ftov_msgs = update_ftov_msgs(
-                ftov_msgs, HashableDict(ftov_msgs_updates), bp_state.fg_state
+                ftov_msgs, ftov_msgs_updates, bp_state.fg_state
             )
 
         if evidence_updates is not None:
-            # Note: if we overwrite variables.__hash__ then hash may change when we call jit
-            evidence = update_evidence(
-                evidence, HashableDict(evidence_updates), bp_state.fg_state
-            )
+            evidence = update_evidence(evidence, evidence_updates, bp_state.fg_state)
 
         return BPArrays(
             log_potentials=log_potentials, ftov_msgs=ftov_msgs, evidence=evidence
@@ -1067,92 +1067,7 @@ def to_bp_state(bp_arrays: BPArrays) -> BPState:
             evidence=Evidence(fg_state=bp_state.fg_state, value=bp_arrays.evidence),
         )
 
-    def get_bp_output(bp_arrays: BPArrays) -> Any:
-        """Function to calculate beliefs from a BPArrays
-
-        Args:
-            bp_arrays: A BPArrays containing log_potentials, ftov_msgs, evidence.
-
-        Returns:
-            bp_output: Belief propagation output.
-        """
-
-        @jax.jit
-        def compute_flat_beliefs(bp_arrays, var_states_for_edges):
-            flat_beliefs = (
-                jax.device_put(bp_arrays.evidence)
-                .at[jax.device_put(var_states_for_edges)]
-                .add(bp_arrays.ftov_msgs)
-            )
-            return flat_beliefs
-
-        return BeliefPropagationOutputs(
-            compute_flat_beliefs(bp_arrays, var_states_for_edges),
-            bp_state.fg_state.variable_group,
-        )
-
-    bp = BeliefPropagation(
-        init=functools.partial(update, None),
-        update=update,
-        run_bp=run_bp,
-        to_bp_state=to_bp_state,
-        get_bp_output=get_bp_output,
-    )
-    return bp
-
-
-@jax.tree_util.register_pytree_node_class
-@dataclass(frozen=True, eq=False)
-class HashableDict:
-    "Represents a dictionnary where stored keys are the hash of the elements"
-    d: Dict = field(default_factory=dict)
-
-    def __post_init__(self):
-        # Allows to initialize from dict without knowing hash
-        new_d = {k.__hash__(): v for k, v in self.d.items()}
-        object.__setattr__(self, "d", new_d)
-
-    @functools.lru_cache()
-    def keys(self):
-        return self.d.keys()
-
-    def items(self):
-        return self.d.items()
-
-    def __setitem__(self, key, value):
-        # Allows to copy keys from another HashableDict
-        self.d[key] = value
-
-    def __getitem__(self, value):
-        # Allows to retrieve from dict without knowing hash
-        return self.d[value.__hash__()]
-
-    def tree_flatten(self):
-        return jax.tree_util.tree_flatten(asdict(self))
-
-    @classmethod
-    def tree_unflatten(cls, aux_data, children):
-        return cls(**aux_data.unflatten(children))
-
-
-@jax.tree_util.register_pytree_node_class
-@dataclass(frozen=True, eq=False)
-class BeliefPropagationOutputs:
-    # beliefs: An array or a PyTree container containing the beliefs for the variables.
-    flat_beliefs: jnp.ndarray
-    variable_groups: Mapping[int, groups.VariableGroup]
-    beliefs: HashableDict = field(init=False)
-
-    def __post_init__(self):
-        if self.flat_beliefs.ndim != 1:
-            raise ValueError(
-                f"Can only unflatten 1D array. Got a {self.flat_beliefs.ndim}D array."
-            )
-        beliefs = self.unflatten()
-        object.__setattr__(self, "beliefs", beliefs)
-
-    @functools.lru_cache
-    def unflatten(self) -> None:
+    def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]:
         """Function that recovers meaningful structured data from internal flat data array
 
         Args:
@@ -1164,86 +1079,115 @@ def unflatten(self) -> None:
         Raises:
             ValueError: if flat_data is not of the right shape
         """
-        # Note: this is a reimplementation of CompositeVariableGroup.unflatten
+        if flat_beliefs.ndim != 1:
+            raise ValueError(
+                f"Can only unflatten 1D array. Got a {flat_beliefs.ndim}D array."
+            )
+
         num_variables = 0
         num_variable_states = 0
-        for variable_group in self.variable_groups.values():
+        for variable_group in variable_groups:
             if isinstance(variable_group, vgroup.NDVariableArray):
                 num_variables += variable_group.num_states.size
                 num_variable_states += variable_group.num_states.sum()
 
-        if self.flat_beliefs.shape[0] == num_variables:
+        if flat_beliefs.shape[0] == num_variables:
             use_num_states = False
-        elif self.flat_beliefs.shape[0] == num_variable_states:
+        elif flat_beliefs.shape[0] == num_variable_states:
             use_num_states = True
         else:
             raise ValueError(
-                f"flat_data should be either of shape (num_variables(={len(self.variables)}),), "
+                f"flat_data should be either of shape (num_variables(={len(num_variables)}),), "
                 f"or (num_variable_states(={num_variable_states}),). "
-                f"Got {self.flat_beliefs.shape}"
+                f"Got {flat_beliefs.shape}"
             )
 
-        beliefs = HashableDict()
+        beliefs = {}
         start = 0
-        for name, variable_group in self.variable_groups.items():
+        for variable_group in variable_groups:
             if use_num_states:
                 length = variable_group.num_states.sum()
             else:
                 length = variable_group.num_states.size
 
-            beliefs[name] = variable_group.unflatten(
-                self.flat_beliefs[start : start + length]
+            beliefs[variable_group] = variable_group.unflatten(
+                flat_beliefs[start : start + length]
             )
             start += length
         return beliefs
 
-    @cached_property
-    def map_states(self) -> Any:
-        """Function to decode MAP states given the calculated beliefs.
+    @jax.jit
+    def get_beliefs(bp_arrays: BPArrays) -> Dict[Hashable, Any]:
+        """Function to calculate beliefs from a BPArrays
 
         Args:
-            beliefs: An array or a PyTree container containing beliefs for different variables.
+            bp_arrays: A BPArrays containing log_potentials, ftov_msgs, evidence.
 
         Returns:
-            An array or a PyTree container containing the MAP states for different variables.
+            beliefs: Beliefs returned by belief propagation.
         """
 
-        @jax.jit
-        def _decode_map_states(beliefs) -> Any:
-            return jax.tree_util.tree_map(lambda x: jnp.argmax(x, axis=-1), beliefs)
+        def compute_flat_beliefs(bp_arrays, var_states_for_edges):
+            flat_beliefs = (
+                jax.device_put(bp_arrays.evidence)
+                .at[jax.device_put(var_states_for_edges)]
+                .add(bp_arrays.ftov_msgs)
+            )
+            return flat_beliefs
+
+        return unflatten_beliefs(
+            compute_flat_beliefs(bp_arrays, var_states_for_edges),
+            bp_state.fg_state.variable_group,
+        )
+
+    bp = BeliefPropagation(
+        init=functools.partial(update, None),
+        update=update,
+        run_bp=run_bp,
+        to_bp_state=to_bp_state,
+        get_beliefs=get_beliefs,
+    )
+    return bp
 
-        map_states = HashableDict()
-        for name, beliefs in self.beliefs.items():
-            map_states[name] = _decode_map_states(beliefs)
-        return map_states
 
-    @cached_property
-    def get_marginals(self) -> Any:
-        """Function to get marginal probabilities given the calculated beliefs.
+def decode_map_states(beliefs: Dict[Hashable, Any]) -> Dict[Hashable, Any]:
+    """Function to decode MAP states given the calculated beliefs.
 
-        Args:
-            beliefs: An array or a PyTree container containing beliefs for different variables.
+    Args:
+        beliefs: An array or a PyTree container containing beliefs for different variables.
 
-        Returns:
-            An array or a PyTree container containing the marginal probabilities different variables.
-        """
+    Returns:
+        An array or a PyTree container containing the MAP states for different variables.
+    """
 
-        @jax.jit
-        def _get_marginals(beliefs) -> Any:
-            return jax.tree_util.tree_map(
-                lambda x: jnp.exp(x - logsumexp(x, axis=-1, keepdims=True)),
-                beliefs,
-            )
+    @jax.jit
+    def _decode_map_states(beliefs) -> Any:
+        return jax.tree_util.tree_map(lambda x: jnp.argmax(x, axis=-1), beliefs)
 
-        marginals = HashableDict()
-        for name, beliefs in self.beliefs.items():
-            marginals[name] = _get_marginals(beliefs)
-        return marginals
+    map_states = {}
+    for variable_group, vgroup_beliefs in beliefs.items():
+        map_states[variable_group] = _decode_map_states(vgroup_beliefs)
+    return map_states
 
-    def tree_flatten(self):
-        return jax.tree_util.tree_flatten(asdict(self))
 
-    @classmethod
-    def tree_unflatten(cls, aux_data, children):
-        # TODO: fix this
-        return cls(**aux_data.unflatten(children))
+def get_marginals(beliefs: Dict[Hashable, Any]) -> Dict[Hashable, Any]:
+    """Function to get marginal probabilities given the calculated beliefs.
+
+    Args:
+        beliefs: An array or a PyTree container containing beliefs for different variables.
+
+    Returns:
+        An array or a PyTree container containing the marginal probabilities different variables.
+    """
+
+    @jax.jit
+    def _get_marginals(beliefs) -> Any:
+        return jax.tree_util.tree_map(
+            lambda x: jnp.exp(x - logsumexp(x, axis=-1, keepdims=True)),
+            beliefs,
+        )
+
+    marginals = {}
+    for variable_group, vgroup_beliefs in beliefs.items():
+        marginals[variable_group] = _get_marginals(vgroup_beliefs)
+    return marginals
diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index f9f3dd41..1752a002 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -89,55 +89,28 @@ class FactorGroup:
     """Class to represent a group of Factors.
 
     Args:
-        vars_to_num_states: TODO
-        variable_names_for_factors: A list of list of variable names, where each innermost element is the
-            name of a variable. Each list within the outer list is taken to contain the names of the
+        variables_for_factors: A list of list of variables, where each innermost element is a
+            variable. Each list within the outer list is taken to contain the names of the
             variables connected to a Factor.
         factor_configs: Optional array containing an explicit enumeration of all valid configurations
         log_potentials: Array of log potentials.
 
     Attributes:
         factor_type: Factor type shared by all the Factors in the FactorGroup.
-        factor_sizes: Array of the different factor sizes.
-        factor_edges_num_states: Array concatenating the number of states for the variables connected to each Factor of
-            the FactorGroup. Each variable will appear once for each Factor it connects to.
 
     Raises:
         ValueError: if the FactorGroup does not contain a Factor
     """
 
-    vars_to_num_states: Mapping[int, int]
-    variable_names_for_factors: Sequence[List]
+    variables_for_factors: Sequence[List]
     factor_configs: np.ndarray = field(init=False)
     log_potentials: np.ndarray = field(init=False, default=np.empty((0,)))
     factor_type: Type = field(init=False)
-    factor_sizes: np.ndarray = field(init=False)
-    variables_for_factors: Tuple[Tuple[int], ...] = field(init=False)
-    factor_edges_num_states: np.ndarray = field(init=False)
 
     def __post_init__(self):
-        if len(self.variable_names_for_factors) == 0:
+        if len(self.variables_for_factors) == 0:
             raise ValueError("Do not add a factor group with no factors.")
 
-        # Note: variable_names_for_factors contains the HASHes
-        # Note: this can probably be sped up by numba
-        factor_sizes = []
-        flat_var_names_for_factors = []
-        factor_edges_num_states = []
-        for variable_names_for_factor in self.variable_names_for_factors:
-            for variable_name in variable_names_for_factor:
-                factor_edges_num_states.append(self.vars_to_num_states[variable_name])
-                flat_var_names_for_factors.append(variable_name)
-            factor_sizes.append(len(variable_names_for_factor))
-
-        object.__setattr__(self, "factor_sizes", np.array(factor_sizes))
-        object.__setattr__(
-            self, "variables_for_factors", np.array(flat_var_names_for_factors)
-        )
-        object.__setattr__(
-            self, "factor_edges_num_states", np.array(factor_edges_num_states)
-        )
-
     def __getitem__(self, variables: Sequence[int]) -> Any:
         """Function to query individual factors in the factor group
 
@@ -168,6 +141,17 @@ def _variables_to_factors(self) -> Mapping[FrozenSet, nodes.Factor]:
         """
         return self._get_variables_to_factors()
 
+    @cached_property
+    def total_num_states(self) -> int:
+        """TODO"""
+        return sum(
+            [
+                variable[1]
+                for variables_for_factor in self.variables_for_factors
+                for variable in variables_for_factor
+            ]
+        )
+
     @cached_property
     def factor_group_log_potentials(self) -> np.ndarray:
         """Flattened array of log potentials"""
@@ -182,7 +166,7 @@ def factors(self) -> Tuple[nodes.Factor, ...]:
     @cached_property
     def num_factors(self) -> int:
         """Returns the number of factors in the FactorGroup."""
-        return len(self.variable_names_for_factors)
+        return len(self.variables_for_factors)
 
     def _get_variables_to_factors(self) -> OrderedDict[FrozenSet, Any]:
         """Function that generates a dictionary mapping names to factors.
@@ -260,9 +244,9 @@ class SingleFactorGroup(FactorGroup):
     def __post_init__(self):
         super().__post_init__()
 
-        if not len(self.variable_names_for_factors) == 1:
+        if not len(self.variables_for_factors) == 1:
             raise ValueError(
-                f"SingleFactorGroup should only contain one factor. Got {len(self.variable_names_for_factors)}"
+                f"SingleFactorGroup should only contain one factor. Got {len(self.variables_for_factors)}"
             )
 
         object.__setattr__(self, "factor_type", type(self.factor))
@@ -282,9 +266,7 @@ def _get_variables_to_factors(
         Returns:
             A dictionary mapping all possible names to different factors.
         """
-        return OrderedDict(
-            [(frozenset(self.variable_names_for_factors[0]), self.factor)]
-        )
+        return OrderedDict([(frozenset(self.variables_for_factors[0]), self.factor)])
 
     def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
         """Function that turns meaningful structured data into a flat data array for internal use.
diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py
index 97fe7c5e..97e6bd3b 100644
--- a/pgmax/fg/nodes.py
+++ b/pgmax/fg/nodes.py
@@ -1,7 +1,7 @@
 """A module containing classes that specify the basic components of a Factor Graph."""
 
 from dataclasses import asdict, dataclass
-from typing import OrderedDict, Sequence, Union
+from typing import List, Sequence, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -43,14 +43,14 @@ class Factor:
     """A factor
 
     Args:
-        vars_to_num_states: Dictionnary mapping the variables names, represented
-            in the form of a hash, to the variables number of states.
+        variables: List of variables in the factors. Each variable is represented
+            by a tuple containing the variable hash and number of states.
 
     Raises:
         NotImplementedError: If compile_wiring is not implemented
     """
 
-    vars_to_num_states: OrderedDict[int, int]
+    variables: List[Tuple[int, int]]
     log_potentials: np.ndarray
 
     def __post_init__(self):
diff --git a/pgmax/groups/enumeration.py b/pgmax/groups/enumeration.py
index f091c53d..8c94ad46 100644
--- a/pgmax/groups/enumeration.py
+++ b/pgmax/groups/enumeration.py
@@ -80,19 +80,14 @@ def _get_variables_to_factors(
         variables_to_factors = collections.OrderedDict(
             [
                 (
-                    frozenset(variable_names_for_factor),
+                    frozenset(variables_for_factor),
                     enumeration.EnumerationFactor(
-                        vars_to_num_states=collections.OrderedDict(
-                            (var, self.vars_to_num_states[var])
-                            for var in variable_names_for_factor
-                        ),
+                        variables=variables_for_factor,
                         factor_configs=self.factor_configs,
                         log_potentials=np.array(self.log_potentials)[ii],
                     ),
                 )
-                for ii, variable_names_for_factor in enumerate(
-                    self.variable_names_for_factors
-                )
+                for ii, variables_for_factor in enumerate(self.variables_for_factors)
             ]
         )
         return variables_to_factors
@@ -188,7 +183,7 @@ class PairwiseFactorGroup(groups.FactorGroup):
     Args:
         log_potential_matrix: array of shape (var1.num_states, var2.num_states),
             where var1 and var2 are the 2 VariableGroups (that may refer to the same
-            VariableGroup) whose names are present in each sub-list from self.variable_names_for_factors.
+            VariableGroup) whose names are present in each sub-list from self.variables_for_factors.
         factor_type: Factor type shared by all the Factors in the FactorGroup.
 
     Raises:
@@ -210,8 +205,8 @@ def __post_init__(self):
         if self.log_potential_matrix is None:
             log_potential_matrix = np.zeros(
                 (
-                    self.vars_to_num_states[self.variable_names_for_factors[0][0]],
-                    self.vars_to_num_states[self.variable_names_for_factors[0][1]],
+                    self.variables_for_factors[0][0][1],
+                    self.variables_for_factors[0][1][1],
                 )
             )
         else:
@@ -230,25 +225,25 @@ def __post_init__(self):
             )
 
         if log_potential_matrix.ndim == 3 and log_potential_matrix.shape[0] != len(
-            self.variable_names_for_factors
+            self.variables_for_factors
         ):
             raise ValueError(
-                f"Expected log_potential_matrix for {len(self.variable_names_for_factors)} factors. "
+                f"Expected log_potential_matrix for {len(self.variables_for_factors)} factors. "
                 f"Got log_potential_matrix for {log_potential_matrix.shape[0]} factors."
             )
 
-        for fac_list in self.variable_names_for_factors:
-            if len(fac_list) != 2:
+        for variables_for_factor in self.variables_for_factors:
+            if len(variables_for_factor) != 2:
                 raise ValueError(
                     "All pairwise factors should connect to exactly 2 variables. Got a factor connecting to"
-                    f" {len(fac_list)} variables ({fac_list})."
+                    f" {len(variables_for_factor)} variables ({variables_for_factor})."
                 )
 
-            num_states0 = self.vars_to_num_states[fac_list[0]]
-            num_states1 = self.vars_to_num_states[fac_list[1]]
+            num_states0 = variables_for_factor[0][1]
+            num_states1 = variables_for_factor[1][1]
             if not log_potential_matrix.shape[-2:] == (num_states0, num_states1):
                 raise ValueError(
-                    f"The specified pairwise factor {fac_list} (with {(num_states0, num_states1)}"
+                    f"The specified pairwise factor {variables_for_factor} (with {(num_states0, num_states1)}"
                     f"configurations) does not match the specified log_potential_matrix "
                     f"(with {log_potential_matrix.shape[-2:]} configurations)."
                 )
@@ -264,7 +259,7 @@ def __post_init__(self):
         object.__setattr__(self, "factor_configs", factor_configs)
         log_potential_matrix = np.broadcast_to(
             log_potential_matrix,
-            (len(self.variable_names_for_factors),) + log_potential_matrix.shape[-2:],
+            (len(self.variables_for_factors),) + log_potential_matrix.shape[-2:],
         )
         log_potentials = np.empty(
             shape=(self.num_factors, self.factor_configs.shape[0])
@@ -286,19 +281,14 @@ def _get_variables_to_factors(
         variables_to_factors = collections.OrderedDict(
             [
                 (
-                    frozenset(variable_names_for_factor),
+                    frozenset(variable_for_factor),
                     enumeration.EnumerationFactor(
-                        vars_to_num_states=collections.OrderedDict(
-                            (var, self.vars_to_num_states[var])
-                            for var in variable_names_for_factor
-                        ),
+                        variables=variable_for_factor,
                         factor_configs=self.factor_configs,
                         log_potentials=self.log_potentials[ii],
                     ),
                 )
-                for ii, variable_names_for_factor in enumerate(
-                    self.variable_names_for_factors
-                )
+                for ii, variable_for_factor in enumerate(self.variables_for_factors)
             ]
         )
         return variables_to_factors
diff --git a/pgmax/groups/logical.py b/pgmax/groups/logical.py
index 325a4313..0f4714d5 100644
--- a/pgmax/groups/logical.py
+++ b/pgmax/groups/logical.py
@@ -34,15 +34,10 @@ def _get_variables_to_factors(
         variables_to_factors = collections.OrderedDict(
             [
                 (
-                    frozenset(variable_names_for_factor),
-                    self.factor_type(
-                        vars_to_num_states=collections.OrderedDict(
-                            (var, self.vars_to_num_states[var])
-                            for var in variable_names_for_factor
-                        ),
-                    ),
+                    frozenset(variables_for_factor),
+                    self.factor_type(variables=variables_for_factor),
                 )
-                for variable_names_for_factor in self.variable_names_for_factors
+                for variables_for_factor in self.variables_for_factors
             ]
         )
         return variables_to_factors
diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index 4c6587f3..a6f3e8d5 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -1,9 +1,9 @@
 """A module containing the variables group classes inheriting from the base VariableGroup."""
 
-import itertools
 import random
 from dataclasses import dataclass
-from typing import Tuple, Union
+from functools import total_ordering
+from typing import List, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -12,6 +12,7 @@
 from pgmax.utils import cached_property
 
 
+@total_ordering
 @dataclass(frozen=True, eq=False)
 class NDVariableArray:
     """Subclass of VariableGroup for n-dimensional grids of variables.
@@ -22,6 +23,9 @@ class NDVariableArray:
             the notion of a NumPy ndarray shape)
     """
 
+    # TODO: Variables = (hash, num_states)
+    # TODO: VariableGroup can be deleted
+
     shape: Tuple[int, ...]
     num_states: Union[int, np.ndarray]
 
@@ -34,10 +38,27 @@ def __post_init__(self):
         elif isinstance(self.num_states, np.ndarray):
             if self.num_states.shape != self.shape:
                 raise ValueError("Should be same shape")
+        random_hash = random.randint(0, 2**63)
+        object.__setattr__(self, "random_hash", random_hash)
+
+    def __hash__(self):
+        return self.random_hash
+
+    def __eq__(self, other):
+        return hash(self) == hash(other)
+
+    def __lt__(self, other):
+        return hash(self) < hash(other)
 
     def __getitem__(self, val):
         # Numpy indexation will throw IndexError for us if out-of-bounds
-        return self.variable_names[val]
+        return (self.variable_names[val], self.num_states[val])
+
+    @cached_property
+    def variables(self) -> List[Tuple]:
+        vars_names = self.variable_names.flatten()
+        vars_num_states = self.num_states.flatten()
+        return list(zip(vars_names, vars_num_states))
 
     @cached_property
     def variable_names(self) -> np.ndarray:
@@ -47,9 +68,8 @@ def variable_names(self) -> np.ndarray:
             a dictionary mapping all possible names to different variables.
         """
         # Overwite default hash as it does not give enough spacing across consecutive objects
-        this_hash = random.randint(0, 2**63)
         indices = np.reshape(np.arange(np.product(self.shape)), self.shape)
-        return this_hash + indices
+        return self.__hash__() + indices
 
     def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
         """Function that turns meaningful structured data into a flat data array for internal use.
@@ -64,7 +84,7 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
         Raises:
             ValueError: If the data is not of the correct shape.
         """
-        # TODO: what should we do for different number of states
+        # TODO: what should we do for different number of states -> look at maask_array
         if data.shape != self.shape and data.shape != self.shape + (
             self.num_states.max(),
         ):
@@ -110,7 +130,9 @@ def unflatten(
         return data
 
 
-# TODO: delete?
+# TODO: delete? -- NO
+
+
 # @dataclass(frozen=True, eq=False)
 # class VariableDict():
 #     """A variable dictionary that contains a set of variables of the same size
diff --git a/tests/factors/test_and.py b/tests/factors/test_and.py
index 0af04db1..e95e22a8 100644
--- a/tests/factors/test_and.py
+++ b/tests/factors/test_and.py
@@ -8,14 +8,14 @@
 from pgmax.groups import variables as vgroup
 
 
-def test_run_bp_with_AND_factors():
+def test_run_bp_with_ANDFactors():
     """
     Simultaneously test
     (1) the support of ANDFactors in a FactorGraph and their specialized inference for different temperatures
     (2) the support of several factor types in a FactorGraph and during inference
 
     To do so, observe that an ANDFactor can be defined as an equivalent EnumerationFactor
-    (which list all the valid AND configurations) and define two equivalent FactorGraphs
+    (which list all the valid OR configurations) and define two equivalent FactorGraphs
     FG1: first half of factors are defined as EnumerationFactors, second half are defined as ANDFactors
     FG2: first half of factors are defined as ANDFactors, second half are defined as EnumerationFactors
 
@@ -25,6 +25,7 @@ def test_run_bp_with_AND_factors():
     Note: for the first seed, add all the EnumerationFactors to FG1 and all the ANDFactors to FG2
     """
     for idx in range(10):
+        print("it", idx)
         np.random.seed(idx)
 
         # Parameters
@@ -43,30 +44,41 @@ def test_run_bp_with_AND_factors():
         parents_variables1 = vgroup.NDVariableArray(
             num_states=2, shape=(num_parents.sum(),)
         )
-        children_variable1 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,))
-        fg1 = graph.FactorGraph(
-            variables=dict(parents=parents_variables1, children=children_variable1)
-        )
+        children_variables1 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,))
+        fg1 = graph.FactorGraph(variables=[parents_variables1, children_variables1])
 
         # Graph 2
         parents_variables2 = vgroup.NDVariableArray(
             num_states=2, shape=(num_parents.sum(),)
         )
-        children_variable2 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,))
-        fg2 = graph.FactorGraph(
-            variables=dict(parents=parents_variables2, children=children_variable2)
-        )
+        children_variables2 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,))
+        fg2 = graph.FactorGraph(variables=[parents_variables2, children_variables2])
 
-        # Option 1: Define EnumerationFactors equivalent to the ANDFactors
+        # Variable names for factors
+        variables_for_factors1 = []
+        variables_for_factors2 = []
         for factor_idx in range(num_factors):
-            this_num_parents = num_parents[factor_idx]
-            variable_names = [
-                ("parents", idx)
+            variable_names1 = [
+                parents_variables1[idx]
                 for idx in range(
                     num_parents_cumsum[factor_idx],
                     num_parents_cumsum[factor_idx + 1],
                 )
-            ] + [("children", factor_idx)]
+            ] + [children_variables1[factor_idx]]
+            variables_for_factors1.append(variable_names1)
+
+            variable_names2 = [
+                parents_variables2[idx]
+                for idx in range(
+                    num_parents_cumsum[factor_idx],
+                    num_parents_cumsum[factor_idx + 1],
+                )
+            ] + [children_variables2[factor_idx]]
+            variables_for_factors2.append(variable_names2)
+
+        # Option 1: Define EnumerationFactors equivalent to the ANDFactors
+        for factor_idx in range(num_factors):
+            this_num_parents = num_parents[factor_idx]
 
             configs = np.array(list(product([0, 1], repeat=this_num_parents + 1)))
             # Children state is last
@@ -84,7 +96,7 @@ def test_run_bp_with_AND_factors():
             if factor_idx < num_factors // 2:
                 # Add the first half of factors to FactorGraph1
                 fg1.add_factor(
-                    variable_names=variable_names,
+                    variable_names=variables_for_factors1[factor_idx],
                     factor_configs=valid_configs,
                     log_potentials=np.zeros(valid_configs.shape[0]),
                 )
@@ -92,73 +104,76 @@ def test_run_bp_with_AND_factors():
                 if idx != 0:
                     # Add the second half of factors to FactorGraph2
                     fg2.add_factor(
-                        variable_names=variable_names,
+                        variable_names=variables_for_factors2[factor_idx],
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
                 else:
                     # Add all the EnumerationFactors to FactorGraph1 for the first iter
                     fg1.add_factor(
-                        variable_names=variable_names,
+                        variable_names=variables_for_factors1[factor_idx],
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
 
         # Option 2: Define the ANDFactors
         num_parents_cumsum = np.insert(np.cumsum(num_parents), 0, 0)
-        variable_names_for_ANDFactors_fg1 = []
-        variable_names_for_ANDFactors_fg2 = []
+        variables_for_ANDFactors_fg1 = []
+        variables_for_ANDFactors_fg2 = []
 
         for factor_idx in range(num_factors):
-            variables_names_for_ANDFactor = [
-                ("parents", idx)
-                for idx in range(
-                    num_parents_cumsum[factor_idx],
-                    num_parents_cumsum[factor_idx + 1],
-                )
-            ] + [("children", factor_idx)]
             if factor_idx < num_factors // 2:
                 # Add the first half of factors to FactorGraph2
-                variable_names_for_ANDFactors_fg2.append(variables_names_for_ANDFactor)
+                variables_for_ANDFactors_fg2.append(variables_for_factors2[factor_idx])
             else:
                 if idx != 0:
                     # Add the second half of factors to FactorGraph1
-                    variable_names_for_ANDFactors_fg1.append(
-                        variables_names_for_ANDFactor
+                    variables_for_ANDFactors_fg1.append(
+                        variables_for_factors1[factor_idx]
                     )
                 else:
                     # Add all the ANDFactors to FactorGraph2 for the first iter
-                    variable_names_for_ANDFactors_fg2.append(
-                        variables_names_for_ANDFactor
+                    variables_for_ANDFactors_fg2.append(
+                        variables_for_factors2[factor_idx]
                     )
-
         if idx != 0:
             fg1.add_factor_group(
                 factory=logical.ANDFactorGroup,
-                variable_names_for_factors=variable_names_for_ANDFactors_fg1,
+                variables_for_factors=variables_for_ANDFactors_fg1,
             )
         fg2.add_factor_group(
             factory=logical.ANDFactorGroup,
-            variable_names_for_factors=variable_names_for_ANDFactors_fg2,
+            variables_for_factors=variables_for_ANDFactors_fg2,
         )
 
         # Run inference
         bp1 = graph.BP(fg1.bp_state, temperature=temperature)
         bp2 = graph.BP(fg2.bp_state, temperature=temperature)
 
-        evidence_updates = {
-            "parents": jax.device_put(np.random.gumbel(size=(sum(num_parents), 2))),
-            "children": jax.device_put(np.random.gumbel(size=(num_factors, 2))),
+        evidence_parents = jax.device_put(np.random.gumbel(size=(sum(num_parents), 2)))
+        evidence_children = jax.device_put(np.random.gumbel(size=(num_factors, 2)))
+
+        evidence_updates1 = {
+            parents_variables1: evidence_parents,
+            children_variables1: evidence_children,
+        }
+        evidence_updates2 = {
+            parents_variables2: evidence_parents,
+            children_variables2: evidence_children,
         }
 
-        bp_arrays1 = bp1.init(evidence_updates=evidence_updates)
+        bp_arrays1 = bp1.init(evidence_updates=evidence_updates1)
         bp_arrays1 = bp1.run_bp(bp_arrays1, num_iters=5)
-        bp_arrays2 = bp2.init(evidence_updates=evidence_updates)
+        bp_arrays2 = bp2.init(evidence_updates=evidence_updates2)
         bp_arrays2 = bp2.run_bp(bp_arrays2, num_iters=5)
 
         # Get beliefs
         beliefs1 = bp1.get_beliefs(bp_arrays1)
         beliefs2 = bp2.get_beliefs(bp_arrays2)
 
-        assert np.allclose(beliefs1["children"], beliefs2["children"], atol=1e-4)
-        assert np.allclose(beliefs1["parents"], beliefs2["parents"], atol=1e-4)
+        assert np.allclose(
+            beliefs1[children_variables1], beliefs2[children_variables2], atol=1e-4
+        )
+        assert np.allclose(
+            beliefs1[parents_variables1], beliefs2[parents_variables2], atol=1e-4
+        )
diff --git a/tests/factors/test_or.py b/tests/factors/test_or.py
index e634fbab..7a6c0905 100644
--- a/tests/factors/test_or.py
+++ b/tests/factors/test_or.py
@@ -8,7 +8,7 @@
 from pgmax.groups import variables as vgroup
 
 
-def test_run_bp_with_OR_factors():
+def test_run_bp_with_ORFactors():
     """
     Simultaneously test
     (1) the support of ORFactors in a FactorGraph and their specialized inference for different temperatures
@@ -55,8 +55,8 @@ def test_run_bp_with_OR_factors():
         fg2 = graph.FactorGraph(variables=[parents_variables2, children_variables2])
 
         # Variable names for factors
-        variable_names_for_factors1 = []
-        variable_names_for_factors2 = []
+        variables_for_factors1 = []
+        variables_for_factors2 = []
         for factor_idx in range(num_factors):
             variable_names1 = [
                 parents_variables1[idx]
@@ -65,7 +65,7 @@ def test_run_bp_with_OR_factors():
                     num_parents_cumsum[factor_idx + 1],
                 )
             ] + [children_variables1[factor_idx]]
-            variable_names_for_factors1.append(variable_names1)
+            variables_for_factors1.append(variable_names1)
 
             variable_names2 = [
                 parents_variables2[idx]
@@ -74,7 +74,7 @@ def test_run_bp_with_OR_factors():
                     num_parents_cumsum[factor_idx + 1],
                 )
             ] + [children_variables2[factor_idx]]
-            variable_names_for_factors2.append(variable_names2)
+            variables_for_factors2.append(variable_names2)
 
         # Option 1: Define EnumerationFactors equivalent to the ORFactors
         for factor_idx in range(num_factors):
@@ -94,7 +94,7 @@ def test_run_bp_with_OR_factors():
             if factor_idx < num_factors // 2:
                 # Add the first half of factors to FactorGraph1
                 fg1.add_factor(
-                    variable_names=variable_names_for_factors1[factor_idx],
+                    variable_names=variables_for_factors1[factor_idx],
                     factor_configs=valid_configs,
                     log_potentials=np.zeros(valid_configs.shape[0]),
                 )
@@ -102,48 +102,46 @@ def test_run_bp_with_OR_factors():
                 if idx != 0:
                     # Add the second half of factors to FactorGraph2
                     fg2.add_factor(
-                        variable_names=variable_names_for_factors2[factor_idx],
+                        variable_names=variables_for_factors2[factor_idx],
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
                 else:
                     # Add all the EnumerationFactors to FactorGraph1 for the first iter
                     fg1.add_factor(
-                        variable_names=variable_names_for_factors1[factor_idx],
+                        variable_names=variables_for_factors1[factor_idx],
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
 
         # Option 2: Define the ORFactors
         num_parents_cumsum = np.insert(np.cumsum(num_parents), 0, 0)
-        variable_names_for_ORFactors_fg1 = []
-        variable_names_for_ORFactors_fg2 = []
+        variables_for_ORFactors_fg1 = []
+        variables_for_ORFactors_fg2 = []
 
         for factor_idx in range(num_factors):
             if factor_idx < num_factors // 2:
                 # Add the first half of factors to FactorGraph2
-                variable_names_for_ORFactors_fg2.append(
-                    variable_names_for_factors2[factor_idx]
-                )
+                variables_for_ORFactors_fg2.append(variables_for_factors2[factor_idx])
             else:
                 if idx != 0:
                     # Add the second half of factors to FactorGraph1
-                    variable_names_for_ORFactors_fg1.append(
-                        variable_names_for_factors1[factor_idx]
+                    variables_for_ORFactors_fg1.append(
+                        variables_for_factors1[factor_idx]
                     )
                 else:
                     # Add all the ORFactors to FactorGraph2 for the first iter
-                    variable_names_for_ORFactors_fg2.append(
-                        variable_names_for_factors2[factor_idx]
+                    variables_for_ORFactors_fg2.append(
+                        variables_for_factors2[factor_idx]
                     )
         if idx != 0:
             fg1.add_factor_group(
                 factory=logical.ORFactorGroup,
-                variable_names_for_factors=variable_names_for_ORFactors_fg1,
+                variables_for_factors=variables_for_ORFactors_fg1,
             )
         fg2.add_factor_group(
             factory=logical.ORFactorGroup,
-            variable_names_for_factors=variable_names_for_ORFactors_fg2,
+            variables_for_factors=variables_for_ORFactors_fg2,
         )
 
         # Run inference
@@ -168,8 +166,8 @@ def test_run_bp_with_OR_factors():
         bp_arrays2 = bp2.run_bp(bp_arrays2, num_iters=5)
 
         # Get beliefs
-        beliefs1 = bp1.get_bp_output(bp_arrays1).beliefs
-        beliefs2 = bp2.get_bp_output(bp_arrays2).beliefs
+        beliefs1 = bp1.get_beliefs(bp_arrays1)
+        beliefs2 = bp2.get_beliefs(bp_arrays2)
 
         assert np.allclose(
             beliefs1[children_variables1], beliefs2[children_variables2], atol=1e-4
diff --git a/tests/fg/test_wiring.py b/tests/fg/test_wiring.py
index a3b6ae2b..4966a66e 100644
--- a/tests/fg/test_wiring.py
+++ b/tests/fg/test_wiring.py
@@ -22,17 +22,15 @@ def test_wiring_with_PairwiseFactorGroup():
     fg = graph.FactorGraph(variables=[A, B])
     fg.add_factor_group(
         factory=enumeration.PairwiseFactorGroup,
-        variable_names_for_factors=[[A[idx], B[idx]] for idx in range(10)],
+        variables_for_factors=[[A[idx], B[idx]] for idx in range(10)],
     )
     factor_group = fg.factor_groups[enumeration_factor.EnumerationFactor][0]
     object.__setattr__(
-        factor_group,
-        "factor_edges_num_states",
-        factor_group.factor_edges_num_states[:-1],
+        factor_group, "factor_configs", factor_group.factor_configs[:, :1]
     )
     with pytest.raises(
         ValueError,
-        match=re.escape("Expected factor_edges_num_states shape is (20,). Got (19,)."),
+        match=re.escape("Expected factor_edges_num_states shape is (10,). Got (20,)."),
     ):
         factor_group.compile_wiring(fg._vars_to_starts)
 
@@ -40,7 +38,7 @@ def test_wiring_with_PairwiseFactorGroup():
     fg1 = graph.FactorGraph(variables=[A, B])
     fg1.add_factor_group(
         factory=enumeration.PairwiseFactorGroup,
-        variable_names_for_factors=[[A[idx], B[idx]] for idx in range(10)],
+        variables_for_factors=[[A[idx], B[idx]] for idx in range(10)],
     )
     assert len(fg1.factor_groups[enumeration_factor.EnumerationFactor]) == 1
 
@@ -49,7 +47,7 @@ def test_wiring_with_PairwiseFactorGroup():
     for idx in range(10):
         fg2.add_factor_group(
             factory=enumeration.PairwiseFactorGroup,
-            variable_names_for_factors=[[A[idx], B[idx]]],
+            variables_for_factors=[[A[idx], B[idx]]],
         )
     assert len(fg2.factor_groups[enumeration_factor.EnumerationFactor]) == 10
 
@@ -57,7 +55,7 @@ def test_wiring_with_PairwiseFactorGroup():
     fg3 = graph.FactorGraph(variables=[A, B])
     for idx in range(10):
         fg3.add_factor_by_type(
-            variable_names=[A[idx], B[idx]],
+            variables=[A[idx], B[idx]],
             factor_type=enumeration_factor.EnumerationFactor,
             **{
                 "factor_configs": np.array([[0, 0], [0, 1], [1, 0], [1, 1]]),
@@ -101,7 +99,7 @@ def test_wiring_with_ORFactorGroup():
     fg1 = graph.FactorGraph(variables=[A, B, C])
     fg1.add_factor_group(
         factory=logical.ORFactorGroup,
-        variable_names_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)],
+        variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)],
     )
     assert len(fg1.factor_groups[logical_factor.ORFactor]) == 1
 
@@ -110,7 +108,7 @@ def test_wiring_with_ORFactorGroup():
     for idx in range(10):
         fg2.add_factor_group(
             factory=logical.ORFactorGroup,
-            variable_names_for_factors=[[A[idx], B[idx], C[idx]]],
+            variables_for_factors=[[A[idx], B[idx], C[idx]]],
         )
     assert len(fg2.factor_groups[logical_factor.ORFactor]) == 10
 
@@ -118,7 +116,7 @@ def test_wiring_with_ORFactorGroup():
     fg3 = graph.FactorGraph(variables=[A, B, C])
     for idx in range(10):
         fg3.add_factor_by_type(
-            variable_names=[A[idx], B[idx], C[idx]],
+            variables=[A[idx], B[idx], C[idx]],
             factor_type=logical_factor.ORFactor,
         )
     assert len(fg3.factor_groups[logical_factor.ORFactor]) == 10
@@ -156,7 +154,7 @@ def test_wiring_with_ANDFactorGroup():
     fg1 = graph.FactorGraph(variables=[A, B, C])
     fg1.add_factor_group(
         factory=logical.ANDFactorGroup,
-        variable_names_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)],
+        variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)],
     )
     assert len(fg1.factor_groups[logical_factor.ANDFactor]) == 1
 
@@ -165,7 +163,7 @@ def test_wiring_with_ANDFactorGroup():
     for idx in range(10):
         fg2.add_factor_group(
             factory=logical.ANDFactorGroup,
-            variable_names_for_factors=[[A[idx], B[idx], C[idx]]],
+            variables_for_factors=[[A[idx], B[idx], C[idx]]],
         )
     assert len(fg2.factor_groups[logical_factor.ANDFactor]) == 10
 
@@ -173,7 +171,7 @@ def test_wiring_with_ANDFactorGroup():
     fg3 = graph.FactorGraph(variables=[A, B, C])
     for idx in range(10):
         fg3.add_factor_by_type(
-            variable_names=[A[idx], B[idx], C[idx]],
+            variables=[A[idx], B[idx], C[idx]],
             factor_type=logical_factor.ANDFactor,
         )
     assert len(fg3.factor_groups[logical_factor.ANDFactor]) == 10

From ef92d1b8482168d76197cd4e8dca2f433fedfbe4 Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Thu, 21 Apr 2022 00:39:47 +0000
Subject: [PATCH 06/35] Start tests + mypy

---
 examples/rbm.py              |   4 +-
 pgmax/factors/enumeration.py |   6 +-
 pgmax/factors/logical.py     |   6 +-
 pgmax/fg/graph.py            |  52 ++++----
 pgmax/fg/groups.py           |  15 ++-
 pgmax/fg/nodes.py            |  18 ++-
 pgmax/groups/enumeration.py  |   8 +-
 pgmax/groups/variables.py    | 225 +++++++++++++++++------------------
 tests/fg/test_graph.py       | 185 ++++++++++++++--------------
 tests/fg/test_nodes.py       |  31 ++---
 10 files changed, 268 insertions(+), 282 deletions(-)

diff --git a/examples/rbm.py b/examples/rbm.py
index 9a46c0dc..abd7435d 100644
--- a/examples/rbm.py
+++ b/examples/rbm.py
@@ -44,8 +44,8 @@
 
 # %% [markdown]
 # We can then initialize the factor graph for the RBM with
-
-import imp
+#
+# import imp
 
 # %%
 from pgmax.fg import graph
diff --git a/pgmax/factors/enumeration.py b/pgmax/factors/enumeration.py
index a8078e78..49bd8fd3 100644
--- a/pgmax/factors/enumeration.py
+++ b/pgmax/factors/enumeration.py
@@ -2,7 +2,7 @@
 
 import functools
 from dataclasses import dataclass
-from typing import Mapping, Sequence, Tuple, Union
+from typing import List, Mapping, Sequence, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -21,7 +21,7 @@ class EnumerationWiring(nodes.Wiring):
     Args:
         factor_configs_edge_states: Array of shape (num_factor_configs, 2)
             factor_configs_edge_states[ii] contains a pair of global enumeration factor_config and global edge_state indices
-            factor_configs_edge_states[ii, 0] contains the global EnumerExpected factor_edges_num_statesationFactor config index,
+            factor_configs_edge_states[ii, 0] contains the global EnumerationFactor config index,
             factor_configs_edge_states[ii, 1] contains the corresponding global edge_state index.
             Both indices only take into account the EnumerationFactors of the FactorGraph
 
@@ -154,7 +154,7 @@ def concatenate_wirings(wirings: Sequence[EnumerationWiring]) -> EnumerationWiri
 
     @staticmethod
     def compile_wiring(
-        variables_for_factors: Tuple[Tuple[int, int], ...],
+        variables_for_factors: Sequence[List],
         factor_configs: np.ndarray,
         vars_to_starts: Mapping[Tuple[int, int], int],
         num_factors: int,
diff --git a/pgmax/factors/logical.py b/pgmax/factors/logical.py
index 81eab93d..a7554334 100644
--- a/pgmax/factors/logical.py
+++ b/pgmax/factors/logical.py
@@ -2,7 +2,7 @@
 
 import functools
 from dataclasses import dataclass, field
-from typing import Mapping, Optional, Sequence, Tuple, Union
+from typing import List, Mapping, Optional, Sequence, Union
 
 import jax
 import jax.numpy as jnp
@@ -88,7 +88,7 @@ class LogicalFactor(nodes.Factor):
     def __post_init__(self):
         if len(self.variables) < 2:
             raise ValueError(
-                "A LogicalFactor requires at least one parent variable and one child variable "
+                "A LogicalFactor requires at least one parent variable and one child variable"
             )
 
         if not np.all([variable[1] == 2 for variable in self.variables]):
@@ -140,7 +140,7 @@ def concatenate_wirings(wirings: Sequence[LogicalWiring]) -> LogicalWiring:
 
     @staticmethod
     def compile_wiring(
-        variables_for_factors: Tuple[Tuple[int, int], ...],
+        variables_for_factors: Sequence[List],
         vars_to_starts: Mapping[int, int],
         edge_states_offset: int,
     ) -> LogicalWiring:
diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index 1a8a9833..3e5cb570 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -60,8 +60,7 @@ def __post_init__(self):
         import time
 
         start = time.time()
-        # if isinstance(self.variables, groups.VariableGroup):
-        if isinstance(self.variables, vgroup.NDVariableArray):
+        if isinstance(self.variables, (vgroup.NDVariableArray, vgroup.VariableDict)):
             self.variable_groups = [self.variables]
         else:
             self.variable_groups = self.variables
@@ -103,9 +102,9 @@ def __post_init__(self):
         )
         # See FactorGraphState docstrings for documentation on the following fields
         self._num_var_states = vars_num_states_cumsum[-1]
-        self._vars_to_starts: OrderedDict[int, int] = collections.OrderedDict(
-            zip(self._variables, vars_num_states_cumsum[:-1])
-        )
+        self._vars_to_starts: OrderedDict[
+            Tuple[int, int], int
+        ] = collections.OrderedDict(zip(self._variables, vars_num_states_cumsum[:-1]))
         self._named_factor_groups: Dict[Hashable, groups.FactorGroup] = {}
         print("3", time.time() - start)
 
@@ -121,7 +120,7 @@ def __hash__(self) -> int:
 
     def add_factor(
         self,
-        variable_names: List,
+        variables: List,
         factor_configs: np.ndarray,
         log_potentials: Optional[np.ndarray] = None,
         name: Optional[str] = None,
@@ -129,7 +128,7 @@ def add_factor(
         """Function to add a single factor to the FactorGraph.
 
         Args:
-            variable_names: A list containing the connected variable names.
+            variables: A list containing the connected variable names.
                 Variable names are tuples of the type (variable_group_name, variable_name_within_variable_group)
             factor_configs: Array of shape (num_val_configs, num_variables)
                 An array containing explicit enumeration of all valid configurations.
@@ -143,7 +142,7 @@ def add_factor(
                 initialized.
         """
         factor_group = EnumerationFactorGroup(
-            variables_for_factors=[variable_names],
+            variables_for_factors=[variables],
             factor_configs=factor_configs,
             log_potentials=log_potentials,
         )
@@ -369,7 +368,7 @@ def fg_state(self) -> FactorGraphState:
         )
 
         return FactorGraphState(
-            variable_group=self._variable_group,
+            variable_groups=self._variable_group,
             vars_to_starts=self._vars_to_starts,
             num_var_states=self._num_var_states,
             total_factor_num_states=self._total_factor_num_states,
@@ -417,8 +416,8 @@ class FactorGraphState:
         wiring: Wiring derived for each factor type.
     """
 
-    variable_group: groups.VariableGroup
-    vars_to_starts: Mapping[int, int]
+    variable_groups: Sequence[groups.VariableGroup]
+    vars_to_starts: Mapping[Tuple[int, int], int]
     num_var_states: int
     total_factor_num_states: int
     named_factor_groups: Mapping[Hashable, groups.FactorGroup]
@@ -612,12 +611,11 @@ def update_ftov_msgs(
         (2) provided name is not valid for ftov_msgs updates.
     """
     for names, data in updates.items():
-        if names in fg_state.variable_group.names:
-            variable = fg_state.variable_group[names]
-            if data.shape != (variable.num_states,):
+        if names in fg_state.variable_groups:
+            if data.shape != (names.total_num_states,):
                 raise ValueError(
                     f"Given belief shape {data.shape} does not match expected "
-                    f"shape {(variable.num_states,)} for variable {names}."
+                    f"shape {(names.total_num_states,)} for variable."
                 )
 
             var_states_for_edges = np.concatenate(
@@ -627,13 +625,13 @@ def update_ftov_msgs(
                 ]
             )
 
-            starts = np.nonzero(
-                var_states_for_edges == fg_state.vars_to_starts[variable]
-            )[0]
-            for start in starts:
-                ftov_msgs = ftov_msgs.at[start : start + variable.num_states].set(
-                    data / starts.shape[0]
-                )
+            # starts = np.nonzero(
+            #     var_states_for_edges == fg_state.vars_to_starts[variable]
+            # )[0]
+            # for start in starts:
+            #     ftov_msgs = ftov_msgs.at[start : start + variable.num_states].set(
+            #         data / starts.shape[0]
+            #     )
         else:
             raise ValueError(
                 "Invalid names for setting messages. "
@@ -709,7 +707,7 @@ def __setitem__(self, names, data) -> None:
         if (
             isinstance(names, tuple)
             and len(names) == 2
-            and names[1] in self.fg_state.variable_group.names
+            and names[1] in self.fg_state.variable_groups
         ):
             names = (frozenset(names[0]), names[1])
 
@@ -742,7 +740,7 @@ def update_evidence(
     """
     for name, data in updates.items():
         # Name is a variable_group or a variable
-        if name in fg_state.variable_group:
+        if name in fg_state.variable_groups:
             first_variable = name.variables[0]
             start_index = fg_state.vars_to_starts[first_variable]
             flat_data = name.flatten(data)
@@ -811,7 +809,7 @@ def __setitem__(
                 If name is the name of a variable group, updates are derived by using the variable group to
                 flatten the data.
                 If name is the name of a variable, data should be of an array shape (num_states,)
-                If name is None, updates are derived by using self.fg_state.variable_group to flatten the data.
+                If name is None, updates are derived by using self.fg_state.variable_groups to flatten the data.
             data: Array containing the evidence updates.
         """
         object.__setattr__(
@@ -1097,7 +1095,7 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]:
             use_num_states = True
         else:
             raise ValueError(
-                f"flat_data should be either of shape (num_variables(={len(num_variables)}),), "
+                f"flat_data should be either of shape (num_variables(={num_variables}),), "
                 f"or (num_variable_states(={num_variable_states}),). "
                 f"Got {flat_beliefs.shape}"
             )
@@ -1137,7 +1135,7 @@ def compute_flat_beliefs(bp_arrays, var_states_for_edges):
 
         return unflatten_beliefs(
             compute_flat_beliefs(bp_arrays, var_states_for_edges),
-            bp_state.fg_state.variable_group,
+            bp_state.fg_state.variable_groups,
         )
 
     bp = BeliefPropagation(
diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index 1752a002..06bc80cb 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -31,27 +31,26 @@ class VariableGroup:
     """
 
     def __getitem__(self, val):
-        """Given a name, retrieve the associated Variable.
+        """Given a variable name, retrieve the associated Variable.
 
         Args:
             val: a single name corresponding to a single variable, or a list of such names
 
         Returns:
             A single variable if the name is not a list. A list of variables if name is a list
-
-        Raises:
-            ValueError: if the name is not found in the group
         """
         raise NotImplementedError(
             "Please subclass the VariableGroup class and override this method"
         )
 
     @cached_property
-    def variables_names(self) -> Any:
-        """Function that generates a dictionary mapping names to variables.
+    def variables(self) -> Tuple[Any, int]:
+        """Function that returns the list of variables. Each variable is represented
+        by a tuple of the variable name (which can be a hash or a string) and its
+        number of states.
 
         Returns:
-            a dictionary mapping all possible names to different variables.
+           List of variables in the VariableGroup
         """
         raise NotImplementedError(
             "Please subclass the VariableGroup class and override this method"
@@ -205,7 +204,7 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any:
             "Please subclass the FactorGroup class and override this method"
         )
 
-    def compile_wiring(self, vars_to_starts: Mapping[int, int]) -> Any:
+    def compile_wiring(self, vars_to_starts: Mapping[Tuple[int, int], int]) -> Any:
         """Compile an efficient wiring for the FactorGroup.
 
         Args:
diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py
index 97e6bd3b..67a998ba 100644
--- a/pgmax/fg/nodes.py
+++ b/pgmax/fg/nodes.py
@@ -7,8 +7,6 @@
 import jax.numpy as jnp
 import numpy as np
 
-from pgmax import utils
-
 
 @jax.tree_util.register_pytree_node_class
 @dataclass(frozen=True, eq=False)
@@ -59,15 +57,15 @@ def __post_init__(self):
                 "Please implement compile_wiring in for your factor"
             )
 
-    @utils.cached_property
-    def edges_num_states(self) -> np.ndarray:
-        """Number of states for the variables connected to each edge
+    # @utils.cached_property
+    # def edges_num_states(self) -> np.ndarray:
+    #     """Number of states for the variables connected to each edge
 
-        Returns:
-            Array of shape (num_edges,)
-            Number of states for the variables connected to each edge
-        """
-        return self.vars_to_num_states.values()
+    #     Returns:
+    #         Array of shape (num_edges,)
+    #         Number of states for the variables connected to each edge
+    #     """
+    #     return self.variables.values()
 
     @staticmethod
     def concatenate_wirings(wirings: Sequence) -> Wiring:
diff --git a/pgmax/groups/enumeration.py b/pgmax/groups/enumeration.py
index 8c94ad46..982a8a57 100644
--- a/pgmax/groups/enumeration.py
+++ b/pgmax/groups/enumeration.py
@@ -109,16 +109,12 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
         num_factors = len(self.factors)
         if (
             data.shape != (num_factors, self.factor_configs.shape[0])
-            and data.shape
-            != (
-                num_factors,
-                np.sum(self.factors[0].edges_num_states),
-            )
+            and data.shape != (num_factors, np.prod(self.factor_configs.shape))
             and data.shape != (self.factor_configs.shape[0],)
         ):
             raise ValueError(
                 f"data should be of shape {(num_factors, self.factor_configs.shape[0])} or "
-                f"{(num_factors, np.sum(self.factors[0].edges_num_states))} or "
+                f"{(num_factors, np.prod(self.factor_configs.shape))} or "
                 f"{(self.factor_configs.shape[0],)}. Got {data.shape}."
             )
 
diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index a6f3e8d5..b59e360c 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -3,18 +3,19 @@
 import random
 from dataclasses import dataclass
 from functools import total_ordering
-from typing import List, Tuple, Union
+from typing import Any, List, Tuple, Union
 
 import jax
 import jax.numpy as jnp
 import numpy as np
 
+from pgmax.fg import groups
 from pgmax.utils import cached_property
 
 
 @total_ordering
 @dataclass(frozen=True, eq=False)
-class NDVariableArray:
+class NDVariableArray(groups.VariableGroup):
     """Subclass of VariableGroup for n-dimensional grids of variables.
 
     Args:
@@ -23,11 +24,8 @@ class NDVariableArray:
             the notion of a NumPy ndarray shape)
     """
 
-    # TODO: Variables = (hash, num_states)
-    # TODO: VariableGroup can be deleted
-
     shape: Tuple[int, ...]
-    num_states: Union[int, np.ndarray]
+    num_states: np.ndarray
 
     def __post_init__(self):
         # super().__post_init__()
@@ -52,7 +50,10 @@ def __lt__(self, other):
 
     def __getitem__(self, val):
         # Numpy indexation will throw IndexError for us if out-of-bounds
-        return (self.variable_names[val], self.num_states[val])
+        result = (self.variable_names[val], self.num_states[val])
+        if isinstance(val, slice):
+            return tuple(zip(result))
+        return result
 
     @cached_property
     def variables(self) -> List[Tuple]:
@@ -84,7 +85,7 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
         Raises:
             ValueError: If the data is not of the correct shape.
         """
-        # TODO: what should we do for different number of states -> look at maask_array
+        # TODO: what should we do for different number of states -> look at mask_array
         if data.shape != self.shape and data.shape != self.shape + (
             self.num_states.max(),
         ):
@@ -130,109 +131,105 @@ def unflatten(
         return data
 
 
-# TODO: delete? -- NO
-
-
-# @dataclass(frozen=True, eq=False)
-# class VariableDict():
-#     """A variable dictionary that contains a set of variables of the same size
-
-#     Args:
-#         num_states: The size of the variables in this variable group
-#         num_variables: The number of variables
-
-#     """
-
-#     num_states: int
-#     num_variables: int
-
-#     @cached_property
-#     def variable_names(self) -> np.ndarray:
-#         """Function that generates a dictionary mapping names to variables.
-
-#         Returns:
-#             a dictionary mapping all possible names to different variables.
-#         """
-#         return self.__hash__() + np.arange(self.num_states)
-
-
-#     def flatten(
-#         self, data: Mapping[Hashable, Union[np.ndarray, jnp.ndarray]]
-#     ) -> jnp.ndarray:
-#         """Function that turns meaningful structured data into a flat data array for internal use.
-
-#         Args:
-#             data: Meaningful structured data. Should be a mapping with names from self.variable_names.
-#                 Each value should be an array of shape (1,) (for e.g. MAP decodings) or
-#                 (self.num_states,) (for e.g. evidence, beliefs).
-
-#         Returns:
-#             A flat jnp.array for internal use
-
-#         Raises:
-#             ValueError if:
-#                 (1) data is referring to a non-existing variable
-#                 (2) data is not of the correct shape
-#         """
-#         for name in data:
-#             if name not in self._names_to_variables:
-#                 raise ValueError(
-#                     f"data is referring to a non-existent variable {name}."
-#                 )
-
-#             if data[name].shape != (self.num_states,) and data[name].shape != (1,):
-#                 raise ValueError(
-#                     f"Variable {name} expects a data array of shape "
-#                     f"{(self.num_states,)} or (1,). Got {data[name].shape}."
-#                 )
-
-#         flat_data = jnp.concatenate([data[name].flatten() for name in self.names])
-#         return flat_data
-
-#     def unflatten(
-#         self, flat_data: Union[np.ndarray, jnp.ndarray]
-#     ) -> Dict[Hashable, Union[np.ndarray, jnp.ndarray]]:
-#         """Function that recovers meaningful structured data from internal flat data array
-
-#         Args:
-#             flat_data: Internal flat data array.
-
-#         Returns:
-#             Meaningful structured data. Should be a mapping with names from self.variable_names.
-#                 Each value should be an array of shape (1,) (for e.g. MAP decodings) or
-#                 (self.num_states,) (for e.g. evidence, beliefs).
-
-#         Raises:
-#             ValueError if:
-#                 (1) flat_data is not a 1D array
-#                 (2) flat_data is not of the right shape
-#         """
-#         if flat_data.ndim != 1:
-#             raise ValueError(
-#                 f"Can only unflatten 1D array. Got a {flat_data.ndim}D array."
-#             )
-
-#         num_variables = len(self.variable_names)
-#         num_variable_states = len(self.variable_names) * self.num_states
-#         if flat_data.shape[0] == num_variables:
-#             use_num_states = False
-#         elif flat_data.shape[0] == num_variable_states:
-#             use_num_states = True
-#         else:
-#             raise ValueError(
-#                 f"flat_data should be either of shape (num_variables(={len(self.variables)}),), "
-#                 f"or (num_variable_states(={num_variable_states}),). "
-#                 f"Got {flat_data.shape}"
-#             )
-
-#         start = 0
-#         data = {}
-#         for name in self.variable_names:
-#             if use_num_states:
-#                 data[name] = flat_data[start : start + self.num_states]
-#                 start += self.num_states
-#             else:
-#                 data[name] = flat_data[np.array([start])]
-#                 start += 1
-
-#         return data
+@dataclass(frozen=True, eq=False)
+class VariableDict(groups.VariableGroup):
+    """A variable dictionary that contains a set of variables of the same size
+
+    Args:
+        num_states: The size of the variables in this variable group
+        variable_names: A tuple of all names of the variables in this variable group
+    """
+
+    num_states: int
+    variable_names: Tuple[Any, ...]
+
+    @cached_property
+    def variables(self) -> List[Tuple]:
+        return list(
+            zip(self.variable_names, [self.num_states] * len(self.variable_names))
+        )
+
+    def __getitem__(self, val):
+        # Numpy indexation will throw IndexError for us if out-of-bounds
+        return (val, self.num_states)
+
+    # def flatten(
+    #     self, data: Mapping[Hashable, Union[np.ndarray, jnp.ndarray]]
+    # ) -> jnp.ndarray:
+    #     """Function that turns meaningful structured data into a flat data array for internal use.
+
+    #     Args:
+    #         data: Meaningful structured data. Should be a mapping with names from self.variable_names.
+    #             Each value should be an array of shape (1,) (for e.g. MAP decodings) or
+    #             (self.num_states,) (for e.g. evidence, beliefs).
+
+    #     Returns:
+    #         A flat jnp.array for internal use
+
+    #     Raises:
+    #         ValueError if:
+    #             (1) data is referring to a non-existing variable
+    #             (2) data is not of the correct shape
+    #     """
+    #     for name in data:
+    #         if name not in self.variable_names:
+    #             raise ValueError(
+    #                 f"data is referring to a non-existent variable {name}."
+    #             )
+
+    #         if data[name].shape != (self.num_states,) and data[name].shape != (1,):
+    #             raise ValueError(
+    #                 f"Variable {name} expects a data array of shape "
+    #                 f"{(self.num_states,)} or (1,). Got {data[name].shape}."
+    #             )
+
+    #     flat_data = jnp.concatenate([data[name].flatten() for name in self.variable_names])
+    #     return flat_data
+
+    # def unflatten(
+    #     self, flat_data: Union[np.ndarray, jnp.ndarray]
+    # ) -> Dict[Hashable, Union[np.ndarray, jnp.ndarray]]:
+    #     """Function that recovers meaningful structured data from internal flat data array
+
+    #     Args:
+    #         flat_data: Internal flat data array.
+
+    #     Returns:
+    #         Meaningful structured data. Should be a mapping with names from self.variable_names.
+    #             Each value should be an array of shape (1,) (for e.g. MAP decodings) or
+    #             (self.num_states,) (for e.g. evidence, beliefs).
+
+    #     Raises:
+    #         ValueError if:
+    #             (1) flat_data is not a 1D array
+    #             (2) flat_data is not of the right shape
+    #     """
+    #     if flat_data.ndim != 1:
+    #         raise ValueError(
+    #             f"Can only unflatten 1D array. Got a {flat_data.ndim}D array."
+    #         )
+
+    #     num_variables = len(self.variable_names)
+    #     num_variable_states = len(self.variable_names) * self.num_states
+    #     if flat_data.shape[0] == num_variables:
+    #         use_num_states = False
+    #     elif flat_data.shape[0] == num_variable_states:
+    #         use_num_states = True
+    #     else:
+    #         raise ValueError(
+    #             f"flat_data should be either of shape (num_variables(={len(self.variables)}),), "
+    #             f"or (num_variable_states(={num_variable_states}),). "
+    #             f"Got {flat_data.shape}"
+    #         )
+
+    #     start = 0
+    #     data = {}
+    #     for name in self.variable_names:
+    #         if use_num_states:
+    #             data[name] = flat_data[start : start + self.num_states]
+    #             start += self.num_states
+    #         else:
+    #             data[name] = flat_data[np.array([start])]
+    #             start += 1
+
+    #     return data
diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py
index 9c88ed2a..9185536c 100644
--- a/tests/fg/test_graph.py
+++ b/tests/fg/test_graph.py
@@ -13,11 +13,11 @@
 
 
 def test_factor_graph():
-    variable_group = vgroup.VariableDict(15, (0,))
-    fg = graph.FactorGraph(variable_group)
+    vg = vgroup.VariableDict(15, (0,))
+    fg = graph.FactorGraph(vg)
     fg.add_factor_by_type(
         factor_type=enumeration_factor.EnumerationFactor,
-        variable_names=[0],
+        variables=[vg[0]],
         factor_configs=np.arange(15)[:, None],
         log_potentials=np.zeros(15),
         name="test",
@@ -27,7 +27,7 @@ def test_factor_graph():
         match="A factor group with the name test already exists. Please choose a different name",
     ):
         fg.add_factor(
-            variable_names=[0],
+            variables=[vg[0]],
             factor_configs=np.arange(15)[:, None],
             name="test",
         )
@@ -35,11 +35,11 @@ def test_factor_graph():
     with pytest.raises(
         ValueError,
         match=re.escape(
-            f"A Factor of type {enumeration_factor.EnumerationFactor} involving variables {frozenset([0])} already exists."
+            f"A Factor of type {enumeration_factor.EnumerationFactor} involving variables {frozenset([(0, 15)])} already exists."
         ),
     ):
         fg.add_factor(
-            variable_names=[0],
+            variables=[vg[0]],
             factor_configs=np.arange(10)[:, None],
         )
 
@@ -49,46 +49,43 @@ def test_factor_graph():
             f"Type {groups.FactorGroup} is not one of the supported factor types {FAC_TO_VAR_UPDATES.keys()}"
         ),
     ):
-        fg.add_factor_by_type(variable_names=[0], factor_type=groups.FactorGroup)
+        fg.add_factor_by_type(variables=[vg[0]], factor_type=groups.FactorGroup)
 
 
 def test_factor_adding():
     A = vgroup.NDVariableArray(num_states=2, shape=(10,))
     B = vgroup.NDVariableArray(num_states=2, shape=(10,))
-    fg = graph.FactorGraph(variables=dict(A=A, B=B))
+    fg = graph.FactorGraph(variables=[A, B])
 
     with pytest.raises(ValueError, match="Do not add a factor group with no factors."):
         fg.add_factor_group(
             factory=logical.ORFactorGroup,
-            variable_names_for_factors=[],
+            variables_for_factors=[],
         )
 
-    variables0 = [("A", 0), ("B", 0)]
-    variables1 = [("A", 1), ("B", 1)]
-    ORFactor = logical.ORFactorGroup(
-        fg._variable_group, variable_names_for_factors=[variables0]
-    )
+    variables0 = (A[0], B[0])
+    variables1 = (A[1], B[1])
+    ORFactor = logical.ORFactorGroup(variables_for_factors=[variables0])
     with pytest.raises(
         ValueError, match="SingleFactorGroup should only contain one factor. Got 2"
     ):
         groups.SingleFactorGroup(
-            variable_group=fg._variable_group,
-            variable_names_for_factors=[variables0, variables1],
+            variables_for_factors=[variables0, variables1],
             factor=ORFactor,
         )
 
 
 def test_bp_state():
-    variable_group = vgroup.VariableDict(15, (0,))
-    fg0 = graph.FactorGraph(variable_group)
+    vg = vgroup.VariableDict(15, (0,))
+    fg0 = graph.FactorGraph(vg)
     fg0.add_factor(
-        variable_names=[0],
+        variables=[vg[0]],
         factor_configs=np.arange(10)[:, None],
         name="test",
     )
-    fg1 = graph.FactorGraph(variable_group)
+    fg1 = graph.FactorGraph(vg)
     fg1.add_factor(
-        variable_names=[0],
+        variables=[vg[0]],
         factor_configs=np.arange(15)[:, None],
         name="test",
     )
@@ -103,79 +100,79 @@ def test_bp_state():
         )
 
 
-def test_log_potentials():
-    variable_group = vgroup.VariableDict(15, (0,))
-    fg = graph.FactorGraph(variable_group)
-    fg.add_factor(
-        variable_names=[0],
-        factor_configs=np.arange(10)[:, None],
-        name="test",
-    )
-    with pytest.raises(
-        ValueError,
-        match=re.escape("Expected log potentials shape (10,) for factor group test."),
-    ):
-        fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15))
-
-    with pytest.raises(
-        ValueError,
-        match=re.escape(f"Invalid name {frozenset([0])} for log potentials updates."),
-    ):
-        fg.bp_state.log_potentials[[0]] = np.zeros(10)
-
-    with pytest.raises(
-        ValueError, match=re.escape("Expected log potentials shape (10,). Got (15,)")
-    ):
-        graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15))
-
-    log_potentials = graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10))
-    assert jnp.all(log_potentials["test"] == jnp.zeros(10))
-    with pytest.raises(
-        ValueError,
-        match=re.escape(f"Invalid name {frozenset([1])} for log potentials updates."),
-    ):
-        fg.bp_state.log_potentials[[1]]
-
-
-def test_ftov_msgs():
-    variable_group = vgroup.VariableDict(15, (0,))
-    fg = graph.FactorGraph(variable_group)
-    fg.add_factor(
-        variable_names=[0],
-        factor_configs=np.arange(10)[:, None],
-        name="test",
-    )
-    with pytest.raises(
-        ValueError,
-        match=re.escape("Invalid names for setting messages"),
-    ):
-        fg.bp_state.ftov_msgs[[0], 0] = np.ones(10)
-
-    with pytest.raises(
-        ValueError,
-        match=re.escape(
-            "Given belief shape (10,) does not match expected shape (15,) for variable 0"
-        ),
-    ):
-        fg.bp_state.ftov_msgs[0] = np.ones(10)
-
-    with pytest.raises(
-        ValueError, match=re.escape("Expected messages shape (15,). Got (10,)")
-    ):
-        graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(10))
-
-    ftov_msgs = graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(15))
-    with pytest.raises(
-        TypeError, match=re.escape("'FToVMessages' object is not subscriptable")
-    ):
-        ftov_msgs[(10,)]
+# def test_log_potentials():
+#     vg = vgroup.VariableDict(15, (0,))
+#     fg = graph.FactorGraph(vg)
+#     fg.add_factor(
+#         variables=[vg[0]],
+#         factor_configs=np.arange(10)[:, None],
+#         name="test",
+#     )
+#     with pytest.raises(
+#         ValueError,
+#         match=re.escape("Expected log potentials shape (10,) for factor group test."),
+#     ):
+#         fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15))
+
+#     with pytest.raises(
+#         ValueError,
+#         match=re.escape(f"Invalid name {frozenset([0])} for log potentials updates."),
+#     ):
+#         fg.bp_state.log_potentials[vg[0]] = np.zeros(10)
+
+#     with pytest.raises(
+#         ValueError, match=re.escape("Expected log potentials shape (10,). Got (15,)")
+#     ):
+#         graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15))
+
+#     log_potentials = graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10))
+#     assert jnp.all(log_potentials["test"] == jnp.zeros(10))
+#     with pytest.raises(
+#         ValueError,
+#         match=re.escape(f"Invalid name {frozenset([1])} for log potentials updates."),
+#     ):
+#         fg.bp_state.log_potentials[[1]]
+
+
+# def test_ftov_msgs():
+#     variable_group = vgroup.VariableDict(15, (0,))
+#     fg = graph.FactorGraph(variable_group)
+#     fg.add_factor(
+#         variable_names=[0],
+#         factor_configs=np.arange(10)[:, None],
+#         name="test",
+#     )
+#     with pytest.raises(
+#         ValueError,
+#         match=re.escape("Invalid names for setting messages"),
+#     ):
+#         fg.bp_state.ftov_msgs[[0], 0] = np.ones(10)
+
+#     with pytest.raises(
+#         ValueError,
+#         match=re.escape(
+#             "Given belief shape (10,) does not match expected shape (15,) for variable 0"
+#         ),
+#     ):
+#         fg.bp_state.ftov_msgs[0] = np.ones(10)
+
+#     with pytest.raises(
+#         ValueError, match=re.escape("Expected messages shape (15,). Got (10,)")
+#     ):
+#         graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(10))
+
+#     ftov_msgs = graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(15))
+#     with pytest.raises(
+#         TypeError, match=re.escape("'FToVMessages' object is not subscriptable")
+#     ):
+#         ftov_msgs[(10,)]
 
 
 def test_evidence():
-    variable_group = vgroup.VariableDict(15, (0,))
-    fg = graph.FactorGraph(variable_group)
+    vg = vgroup.VariableDict(15, (0,))
+    fg = graph.FactorGraph(vg)
     fg.add_factor(
-        variable_names=[0],
+        variables=[vg[0]],
         factor_configs=np.arange(10)[:, None],
         name="test",
     )
@@ -189,10 +186,10 @@ def test_evidence():
 
 
 def test_bp():
-    variable_group = vgroup.VariableDict(15, (0,))
-    fg = graph.FactorGraph(variable_group)
+    vg = vgroup.VariableDict(15, (0,))
+    fg = graph.FactorGraph(vg)
     fg.add_factor(
-        variable_names=[0],
+        variables=[vg[0]],
         factor_configs=np.arange(10)[:, None],
         name="test",
     )
@@ -200,7 +197,7 @@ def test_bp():
     bp_arrays = bp.update()
     bp_arrays = bp.update(
         bp_arrays=bp_arrays,
-        ftov_msgs_updates={0: np.zeros(15)},
+        ftov_msgs_updates={vg[0]: np.zeros(15)},
     )
     bp_arrays = bp.run_bp(bp_arrays, num_iters=1)
     bp_arrays = replace(bp_arrays, log_potentials=jnp.zeros((10)))
diff --git a/tests/fg/test_nodes.py b/tests/fg/test_nodes.py
index e9abf9bc..27445726 100644
--- a/tests/fg/test_nodes.py
+++ b/tests/fg/test_nodes.py
@@ -5,36 +5,37 @@
 
 from pgmax.factors import enumeration, logical
 from pgmax.fg import nodes
+from pgmax.groups import variables as vgroup
 
 
 def test_enumeration_factor():
-    variable = nodes.Variable(3)
+    variables = vgroup.NDVariableArray(num_states=3, shape=(1,))
 
     with pytest.raises(
         NotImplementedError, match="Please implement compile_wiring in for your factor"
     ):
         nodes.Factor(
-            variables=(variable,),
+            variables=[variables[0]],
             log_potentials=np.array([0.0]),
         )
 
     with pytest.raises(ValueError, match="Configurations should be integers. Got"):
         enumeration.EnumerationFactor(
-            variables=(variable,),
+            variables=[variables[0]],
             factor_configs=np.array([[1.0]]),
             log_potentials=np.array([0.0]),
         )
 
     with pytest.raises(ValueError, match="Potential should be floats. Got"):
         enumeration.EnumerationFactor(
-            variables=(variable,),
+            variables=[variables[0]],
             factor_configs=np.array([[1]]),
             log_potentials=np.array([0]),
         )
 
     with pytest.raises(ValueError, match="factor_configs should be a 2D array"):
         enumeration.EnumerationFactor(
-            variables=(variable,),
+            variables=[variables[0]],
             factor_configs=np.array([1]),
             log_potentials=np.array([0.0]),
         )
@@ -46,7 +47,7 @@ def test_enumeration_factor():
         ),
     ):
         enumeration.EnumerationFactor(
-            variables=(variable,),
+            variables=[variables[0]],
             factor_configs=np.array([[1, 2]]),
             log_potentials=np.array([0.0]),
         )
@@ -55,27 +56,27 @@ def test_enumeration_factor():
         ValueError, match=re.escape("Expected log potentials of shape (1,)")
     ):
         enumeration.EnumerationFactor(
-            variables=(variable,),
+            variables=[variables[0]],
             factor_configs=np.array([[1]]),
             log_potentials=np.array([0.0, 1.0]),
         )
 
     with pytest.raises(ValueError, match="Invalid configurations for given variables"):
         enumeration.EnumerationFactor(
-            variables=(variable,),
+            variables=[variables[0]],
             factor_configs=np.array([[10]]),
             log_potentials=np.array([0.0]),
         )
 
 
 def test_logical_factor():
-    child = nodes.Variable(2)
-    wrong_parent = nodes.Variable(3)
-    parent = nodes.Variable(2)
+    child = vgroup.NDVariableArray(num_states=2, shape=(1,))[0]
+    wrong_parent = vgroup.NDVariableArray(num_states=3, shape=(1,))[0]
+    parent = vgroup.NDVariableArray(num_states=2, shape=(1,))[0]
 
     with pytest.raises(
         ValueError,
-        match="At least one parent variable and one child variable is required",
+        match="A LogicalFactor requires at least one parent variable and one child variable",
     ):
         logical.LogicalFactor(
             variables=(child,),
@@ -100,7 +101,7 @@ def test_logical_factor():
 
     with pytest.raises(ValueError, match="The highest LogicalFactor index must be 0"):
         logical.LogicalWiring(
-            edges_num_states=logical_factor.edges_num_states,
+            edges_num_states=[2, 2],
             var_states_for_edges=None,
             parents_edge_states=parents_edge_states + np.array([[1, 0]]),
             children_edge_states=child_edge_state,
@@ -112,7 +113,7 @@ def test_logical_factor():
         match="The LogicalWiring must have 1 different LogicalFactor indices",
     ):
         logical.LogicalWiring(
-            edges_num_states=logical_factor.edges_num_states,
+            edges_num_states=[2, 2],
             var_states_for_edges=None,
             parents_edge_states=parents_edge_states + np.array([[0], [1]]),
             children_edge_states=child_edge_state,
@@ -126,7 +127,7 @@ def test_logical_factor():
         ),
     ):
         logical.LogicalWiring(
-            edges_num_states=logical_factor.edges_num_states,
+            edges_num_states=[2, 2],
             var_states_for_edges=None,
             parents_edge_states=parents_edge_states,
             children_edge_states=child_edge_state,

From 7e55e5c425f94bd946f2142f33b20c8c6249905e Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Thu, 21 Apr 2022 19:47:14 +0000
Subject: [PATCH 07/35] Tests passing

---
 pgmax/fg/graph.py           |  49 ++++-------
 pgmax/groups/enumeration.py |  18 ++--
 pgmax/groups/variables.py   | 170 ++++++++++++++++++------------------
 tests/factors/test_and.py   |  14 +--
 tests/factors/test_or.py    |  14 +--
 tests/fg/test_graph.py      | 127 +++++++++++++--------------
 tests/fg/test_groups.py     | 123 +++++---------------------
 tests/test_pgmax.py         |  99 ++++++++++-----------
 8 files changed, 258 insertions(+), 356 deletions(-)

diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index 3e5cb570..264a6ab0 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -61,21 +61,13 @@ def __post_init__(self):
 
         start = time.time()
         if isinstance(self.variables, (vgroup.NDVariableArray, vgroup.VariableDict)):
-            self.variable_groups = [self.variables]
+            self._variable_groups = [self.variables]
         else:
-            self.variable_groups = self.variables
-
-        # TODO: remove?
-        self._variable_group = self.variable_groups
-        # self._variable_group: Mapping[
-        #     int, groups.VariableGroup
-        # ] = collections.OrderedDict()
-        # for variable_group in self.variables:
-        #     self._variable_group[variable_group.__hash__()] = variable_group
+            self._variable_groups = self.variables
 
         self._variables = [
             variable
-            for variable_group in self.variable_groups
+            for variable_group in self._variable_groups
             for variable in variable_group.variables
         ]
         print("1", time.time() - start)
@@ -368,7 +360,7 @@ def fg_state(self) -> FactorGraphState:
         )
 
         return FactorGraphState(
-            variable_groups=self._variable_group,
+            variable_groups=self._variable_groups,
             vars_to_starts=self._vars_to_starts,
             num_var_states=self._num_var_states,
             total_factor_num_states=self._total_factor_num_states,
@@ -610,12 +602,12 @@ def update_ftov_msgs(
         (1) provided ftov_msgs shape does not match the expected ftov_msgs shape.
         (2) provided name is not valid for ftov_msgs updates.
     """
-    for names, data in updates.items():
-        if names in fg_state.variable_groups:
-            if data.shape != (names.total_num_states,):
+    for variable, data in updates.items():
+        if variable in fg_state.vars_to_starts:
+            if data.shape != (variable[1],):
                 raise ValueError(
                     f"Given belief shape {data.shape} does not match expected "
-                    f"shape {(names.total_num_states,)} for variable."
+                    f"shape {(variable[1],)} for variable {variable}."
                 )
 
             var_states_for_edges = np.concatenate(
@@ -625,13 +617,13 @@ def update_ftov_msgs(
                 ]
             )
 
-            # starts = np.nonzero(
-            #     var_states_for_edges == fg_state.vars_to_starts[variable]
-            # )[0]
-            # for start in starts:
-            #     ftov_msgs = ftov_msgs.at[start : start + variable.num_states].set(
-            #         data / starts.shape[0]
-            #     )
+            starts = np.nonzero(
+                var_states_for_edges == fg_state.vars_to_starts[variable]
+            )[0]
+            for start in starts:
+                ftov_msgs = ftov_msgs.at[start : start + variable[1]].set(
+                    data / starts.shape[0]
+                )
         else:
             raise ValueError(
                 "Invalid names for setting messages. "
@@ -691,26 +683,19 @@ def __setitem__(
     @typing.overload
     def __setitem__(
         self,
-        names: Any,
+        names: Tuple[int, int],
         data: Union[np.ndarray, jnp.ndarray],
     ) -> None:
         """Spreading beliefs at a variable to all connected factors
 
         Args:
-            names: The name of the variable
+            variable: A tuple representing a variable
             data: An array containing the beliefs to be spread uniformly
                 across all factor to variable messages involving this
                 variable.
         """
 
     def __setitem__(self, names, data) -> None:
-        if (
-            isinstance(names, tuple)
-            and len(names) == 2
-            and names[1] in self.fg_state.variable_groups
-        ):
-            names = (frozenset(names[0]), names[1])
-
         object.__setattr__(
             self,
             "value",
diff --git a/pgmax/groups/enumeration.py b/pgmax/groups/enumeration.py
index 982a8a57..03f17cb3 100644
--- a/pgmax/groups/enumeration.py
+++ b/pgmax/groups/enumeration.py
@@ -107,14 +107,17 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
             ValueError: if data is not of the right shape.
         """
         num_factors = len(self.factors)
+        factor_edges_num_states = sum(
+            [variable[1] for variable in self.variables_for_factors[0]]
+        )
         if (
             data.shape != (num_factors, self.factor_configs.shape[0])
-            and data.shape != (num_factors, np.prod(self.factor_configs.shape))
+            and data.shape != (num_factors, factor_edges_num_states)
             and data.shape != (self.factor_configs.shape[0],)
         ):
             raise ValueError(
                 f"data should be of shape {(num_factors, self.factor_configs.shape[0])} or "
-                f"{(num_factors, np.prod(self.factor_configs.shape))} or "
+                f"{(num_factors, factor_edges_num_states)} or "
                 f"{(self.factor_configs.shape[0],)}. Got {data.shape}."
             )
 
@@ -149,18 +152,19 @@ def unflatten(
             )
 
         num_factors = len(self.factors)
+        factor_edges_num_states = sum(
+            [variable[1] for variable in self.variables_for_factors[0]]
+        )
         if flat_data.size == num_factors * self.factor_configs.shape[0]:
             data = flat_data.reshape(
                 (num_factors, self.factor_configs.shape[0]),
             )
-        elif flat_data.size == num_factors * np.sum(self.factors[0].edges_num_states):
-            data = flat_data.reshape(
-                (num_factors, np.sum(self.factors[0].edges_num_states))
-            )
+        elif flat_data.size == num_factors * np.sum(factor_edges_num_states):
+            data = flat_data.reshape((num_factors, np.sum(factor_edges_num_states)))
         else:
             raise ValueError(
                 f"flat_data should be compatible with shape {(num_factors, self.factor_configs.shape[0])} "
-                f"or {(num_factors, np.sum(self.factors[0].edges_num_states))}. Got {flat_data.shape}."
+                f"or {(num_factors, np.sum(factor_edges_num_states))}. Got {flat_data.shape}."
             )
 
         return data
diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index b59e360c..28572ae1 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -3,7 +3,7 @@
 import random
 from dataclasses import dataclass
 from functools import total_ordering
-from typing import Any, List, Tuple, Union
+from typing import Any, Dict, Hashable, List, Mapping, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -28,8 +28,6 @@ class NDVariableArray(groups.VariableGroup):
     num_states: np.ndarray
 
     def __post_init__(self):
-        # super().__post_init__()
-
         if isinstance(self.num_states, int):
             num_states = np.full(self.shape, fill_value=self.num_states)
             object.__setattr__(self, "num_states", num_states)
@@ -124,7 +122,7 @@ def unflatten(
             data = flat_data.reshape(self.shape + (self.num_states.max(),))
         else:
             raise ValueError(
-                f"flat_data should be compatible with shape {self.shape} or {self.shape + (self.num_states,)}. "
+                f"flat_data should be compatible with shape {self.shape} or {self.shape + (self.num_states.max(),)}. "
                 f"Got {flat_data.shape}."
             )
 
@@ -140,8 +138,8 @@ class VariableDict(groups.VariableGroup):
         variable_names: A tuple of all names of the variables in this variable group
     """
 
-    num_states: int
     variable_names: Tuple[Any, ...]
+    num_states: int
 
     @cached_property
     def variables(self) -> List[Tuple]:
@@ -153,83 +151,85 @@ def __getitem__(self, val):
         # Numpy indexation will throw IndexError for us if out-of-bounds
         return (val, self.num_states)
 
-    # def flatten(
-    #     self, data: Mapping[Hashable, Union[np.ndarray, jnp.ndarray]]
-    # ) -> jnp.ndarray:
-    #     """Function that turns meaningful structured data into a flat data array for internal use.
-
-    #     Args:
-    #         data: Meaningful structured data. Should be a mapping with names from self.variable_names.
-    #             Each value should be an array of shape (1,) (for e.g. MAP decodings) or
-    #             (self.num_states,) (for e.g. evidence, beliefs).
-
-    #     Returns:
-    #         A flat jnp.array for internal use
-
-    #     Raises:
-    #         ValueError if:
-    #             (1) data is referring to a non-existing variable
-    #             (2) data is not of the correct shape
-    #     """
-    #     for name in data:
-    #         if name not in self.variable_names:
-    #             raise ValueError(
-    #                 f"data is referring to a non-existent variable {name}."
-    #             )
-
-    #         if data[name].shape != (self.num_states,) and data[name].shape != (1,):
-    #             raise ValueError(
-    #                 f"Variable {name} expects a data array of shape "
-    #                 f"{(self.num_states,)} or (1,). Got {data[name].shape}."
-    #             )
-
-    #     flat_data = jnp.concatenate([data[name].flatten() for name in self.variable_names])
-    #     return flat_data
-
-    # def unflatten(
-    #     self, flat_data: Union[np.ndarray, jnp.ndarray]
-    # ) -> Dict[Hashable, Union[np.ndarray, jnp.ndarray]]:
-    #     """Function that recovers meaningful structured data from internal flat data array
-
-    #     Args:
-    #         flat_data: Internal flat data array.
-
-    #     Returns:
-    #         Meaningful structured data. Should be a mapping with names from self.variable_names.
-    #             Each value should be an array of shape (1,) (for e.g. MAP decodings) or
-    #             (self.num_states,) (for e.g. evidence, beliefs).
-
-    #     Raises:
-    #         ValueError if:
-    #             (1) flat_data is not a 1D array
-    #             (2) flat_data is not of the right shape
-    #     """
-    #     if flat_data.ndim != 1:
-    #         raise ValueError(
-    #             f"Can only unflatten 1D array. Got a {flat_data.ndim}D array."
-    #         )
-
-    #     num_variables = len(self.variable_names)
-    #     num_variable_states = len(self.variable_names) * self.num_states
-    #     if flat_data.shape[0] == num_variables:
-    #         use_num_states = False
-    #     elif flat_data.shape[0] == num_variable_states:
-    #         use_num_states = True
-    #     else:
-    #         raise ValueError(
-    #             f"flat_data should be either of shape (num_variables(={len(self.variables)}),), "
-    #             f"or (num_variable_states(={num_variable_states}),). "
-    #             f"Got {flat_data.shape}"
-    #         )
-
-    #     start = 0
-    #     data = {}
-    #     for name in self.variable_names:
-    #         if use_num_states:
-    #             data[name] = flat_data[start : start + self.num_states]
-    #             start += self.num_states
-    #         else:
-    #             data[name] = flat_data[np.array([start])]
-    #             start += 1
-
-    #     return data
+    def flatten(
+        self, data: Mapping[Hashable, Union[np.ndarray, jnp.ndarray]]
+    ) -> jnp.ndarray:
+        """Function that turns meaningful structured data into a flat data array for internal use.
+
+        Args:
+            data: Meaningful structured data. Should be a mapping with names from self.variable_names.
+                Each value should be an array of shape (1,) (for e.g. MAP decodings) or
+                (self.num_states,) (for e.g. evidence, beliefs).
+
+        Returns:
+            A flat jnp.array for internal use
+
+        Raises:
+            ValueError if:
+                (1) data is referring to a non-existing variable
+                (2) data is not of the correct shape
+        """
+        for name in data:
+            if name not in self.variable_names:
+                raise ValueError(
+                    f"data is referring to a non-existent variable {name}."
+                )
+
+            if data[name].shape != (self.num_states,) and data[name].shape != (1,):
+                raise ValueError(
+                    f"Variable {name} expects a data array of shape "
+                    f"{(self.num_states,)} or (1,). Got {data[name].shape}."
+                )
+
+        flat_data = jnp.concatenate(
+            [data[name].flatten() for name in self.variable_names]
+        )
+        return flat_data
+
+    def unflatten(
+        self, flat_data: Union[np.ndarray, jnp.ndarray]
+    ) -> Dict[Hashable, Union[np.ndarray, jnp.ndarray]]:
+        """Function that recovers meaningful structured data from internal flat data array
+
+        Args:
+            flat_data: Internal flat data array.
+
+        Returns:
+            Meaningful structured data. Should be a mapping with names from self.variable_names.
+                Each value should be an array of shape (1,) (for e.g. MAP decodings) or
+                (self.num_states,) (for e.g. evidence, beliefs).
+
+        Raises:
+            ValueError if:
+                (1) flat_data is not a 1D array
+                (2) flat_data is not of the right shape
+        """
+        if flat_data.ndim != 1:
+            raise ValueError(
+                f"Can only unflatten 1D array. Got a {flat_data.ndim}D array."
+            )
+
+        num_variables = len(self.variable_names)
+        num_variable_states = len(self.variable_names) * self.num_states
+        if flat_data.shape[0] == num_variables:
+            use_num_states = False
+        elif flat_data.shape[0] == num_variable_states:
+            use_num_states = True
+        else:
+            raise ValueError(
+                f"flat_data should be either of shape (num_variables(={len(self.variables)}),), "
+                f"or (num_variable_states(={num_variable_states}),). "
+                f"Got {flat_data.shape}"
+            )
+
+        start = 0
+        data = {}
+        for name in self.variable_names:
+            if use_num_states:
+                data[name] = flat_data[start : start + self.num_states]
+                start += self.num_states
+            else:
+                data[name] = flat_data[np.array([start])]
+                start += 1
+
+        return data
diff --git a/tests/factors/test_and.py b/tests/factors/test_and.py
index e95e22a8..19bc08d0 100644
--- a/tests/factors/test_and.py
+++ b/tests/factors/test_and.py
@@ -58,23 +58,23 @@ def test_run_bp_with_ANDFactors():
         variables_for_factors1 = []
         variables_for_factors2 = []
         for factor_idx in range(num_factors):
-            variable_names1 = [
+            variables1 = [
                 parents_variables1[idx]
                 for idx in range(
                     num_parents_cumsum[factor_idx],
                     num_parents_cumsum[factor_idx + 1],
                 )
             ] + [children_variables1[factor_idx]]
-            variables_for_factors1.append(variable_names1)
+            variables_for_factors1.append(variables1)
 
-            variable_names2 = [
+            variables2 = [
                 parents_variables2[idx]
                 for idx in range(
                     num_parents_cumsum[factor_idx],
                     num_parents_cumsum[factor_idx + 1],
                 )
             ] + [children_variables2[factor_idx]]
-            variables_for_factors2.append(variable_names2)
+            variables_for_factors2.append(variables2)
 
         # Option 1: Define EnumerationFactors equivalent to the ANDFactors
         for factor_idx in range(num_factors):
@@ -96,7 +96,7 @@ def test_run_bp_with_ANDFactors():
             if factor_idx < num_factors // 2:
                 # Add the first half of factors to FactorGraph1
                 fg1.add_factor(
-                    variable_names=variables_for_factors1[factor_idx],
+                    variables=variables_for_factors1[factor_idx],
                     factor_configs=valid_configs,
                     log_potentials=np.zeros(valid_configs.shape[0]),
                 )
@@ -104,14 +104,14 @@ def test_run_bp_with_ANDFactors():
                 if idx != 0:
                     # Add the second half of factors to FactorGraph2
                     fg2.add_factor(
-                        variable_names=variables_for_factors2[factor_idx],
+                        variables=variables_for_factors2[factor_idx],
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
                 else:
                     # Add all the EnumerationFactors to FactorGraph1 for the first iter
                     fg1.add_factor(
-                        variable_names=variables_for_factors1[factor_idx],
+                        variables=variables_for_factors1[factor_idx],
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
diff --git a/tests/factors/test_or.py b/tests/factors/test_or.py
index 7a6c0905..162a7338 100644
--- a/tests/factors/test_or.py
+++ b/tests/factors/test_or.py
@@ -58,23 +58,23 @@ def test_run_bp_with_ORFactors():
         variables_for_factors1 = []
         variables_for_factors2 = []
         for factor_idx in range(num_factors):
-            variable_names1 = [
+            variables1 = [
                 parents_variables1[idx]
                 for idx in range(
                     num_parents_cumsum[factor_idx],
                     num_parents_cumsum[factor_idx + 1],
                 )
             ] + [children_variables1[factor_idx]]
-            variables_for_factors1.append(variable_names1)
+            variables_for_factors1.append(variables1)
 
-            variable_names2 = [
+            variables2 = [
                 parents_variables2[idx]
                 for idx in range(
                     num_parents_cumsum[factor_idx],
                     num_parents_cumsum[factor_idx + 1],
                 )
             ] + [children_variables2[factor_idx]]
-            variables_for_factors2.append(variable_names2)
+            variables_for_factors2.append(variables2)
 
         # Option 1: Define EnumerationFactors equivalent to the ORFactors
         for factor_idx in range(num_factors):
@@ -94,7 +94,7 @@ def test_run_bp_with_ORFactors():
             if factor_idx < num_factors // 2:
                 # Add the first half of factors to FactorGraph1
                 fg1.add_factor(
-                    variable_names=variables_for_factors1[factor_idx],
+                    variables=variables_for_factors1[factor_idx],
                     factor_configs=valid_configs,
                     log_potentials=np.zeros(valid_configs.shape[0]),
                 )
@@ -102,14 +102,14 @@ def test_run_bp_with_ORFactors():
                 if idx != 0:
                     # Add the second half of factors to FactorGraph2
                     fg2.add_factor(
-                        variable_names=variables_for_factors2[factor_idx],
+                        variables=variables_for_factors2[factor_idx],
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
                 else:
                     # Add all the EnumerationFactors to FactorGraph1 for the first iter
                     fg1.add_factor(
-                        variable_names=variables_for_factors1[factor_idx],
+                        variables=variables_for_factors1[factor_idx],
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py
index 9185536c..ea2a7500 100644
--- a/tests/fg/test_graph.py
+++ b/tests/fg/test_graph.py
@@ -100,72 +100,67 @@ def test_bp_state():
         )
 
 
-# def test_log_potentials():
-#     vg = vgroup.VariableDict(15, (0,))
-#     fg = graph.FactorGraph(vg)
-#     fg.add_factor(
-#         variables=[vg[0]],
-#         factor_configs=np.arange(10)[:, None],
-#         name="test",
-#     )
-#     with pytest.raises(
-#         ValueError,
-#         match=re.escape("Expected log potentials shape (10,) for factor group test."),
-#     ):
-#         fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15))
-
-#     with pytest.raises(
-#         ValueError,
-#         match=re.escape(f"Invalid name {frozenset([0])} for log potentials updates."),
-#     ):
-#         fg.bp_state.log_potentials[vg[0]] = np.zeros(10)
-
-#     with pytest.raises(
-#         ValueError, match=re.escape("Expected log potentials shape (10,). Got (15,)")
-#     ):
-#         graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15))
-
-#     log_potentials = graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10))
-#     assert jnp.all(log_potentials["test"] == jnp.zeros(10))
-#     with pytest.raises(
-#         ValueError,
-#         match=re.escape(f"Invalid name {frozenset([1])} for log potentials updates."),
-#     ):
-#         fg.bp_state.log_potentials[[1]]
-
-
-# def test_ftov_msgs():
-#     variable_group = vgroup.VariableDict(15, (0,))
-#     fg = graph.FactorGraph(variable_group)
-#     fg.add_factor(
-#         variable_names=[0],
-#         factor_configs=np.arange(10)[:, None],
-#         name="test",
-#     )
-#     with pytest.raises(
-#         ValueError,
-#         match=re.escape("Invalid names for setting messages"),
-#     ):
-#         fg.bp_state.ftov_msgs[[0], 0] = np.ones(10)
-
-#     with pytest.raises(
-#         ValueError,
-#         match=re.escape(
-#             "Given belief shape (10,) does not match expected shape (15,) for variable 0"
-#         ),
-#     ):
-#         fg.bp_state.ftov_msgs[0] = np.ones(10)
-
-#     with pytest.raises(
-#         ValueError, match=re.escape("Expected messages shape (15,). Got (10,)")
-#     ):
-#         graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(10))
-
-#     ftov_msgs = graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(15))
-#     with pytest.raises(
-#         TypeError, match=re.escape("'FToVMessages' object is not subscriptable")
-#     ):
-#         ftov_msgs[(10,)]
+def test_log_potentials():
+    vg = vgroup.VariableDict(15, (0,))
+    fg = graph.FactorGraph(vg)
+    fg.add_factor(
+        variables=[vg[0]],
+        factor_configs=np.arange(10)[:, None],
+        name="test",
+    )
+    with pytest.raises(
+        ValueError,
+        match=re.escape("Expected log potentials shape (10,) for factor group test."),
+    ):
+        fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15))
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape("Invalid name (0, 15) for log potentials updates."),
+    ):
+        fg.bp_state.log_potentials[vg[0]] = np.zeros(10)
+
+    with pytest.raises(
+        ValueError, match=re.escape("Expected log potentials shape (10,). Got (15,)")
+    ):
+        graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15))
+
+    log_potentials = graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10))
+    assert jnp.all(log_potentials["test"] == jnp.zeros(10))
+
+
+def test_ftov_msgs():
+    vg = vgroup.VariableDict(15, (0,))
+    fg = graph.FactorGraph(vg)
+    fg.add_factor(
+        variables=[vg[0]],
+        factor_configs=np.arange(10)[:, None],
+        name="test",
+    )
+    with pytest.raises(
+        ValueError,
+        match=re.escape("Invalid names for setting messages"),
+    ):
+        fg.bp_state.ftov_msgs[0] = np.ones(10)
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Given belief shape (10,) does not match expected shape (15,) for variable (0, 15)."
+        ),
+    ):
+        fg.bp_state.ftov_msgs[vg[0]] = np.ones(10)
+
+    with pytest.raises(
+        ValueError, match=re.escape("Expected messages shape (15,). Got (10,)")
+    ):
+        graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(10))
+
+    ftov_msgs = graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(15))
+    with pytest.raises(
+        TypeError, match=re.escape("'FToVMessages' object is not subscriptable")
+    ):
+        ftov_msgs[(10,)]
 
 
 def test_evidence():
diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py
index 40fe3351..9a0fd312 100644
--- a/tests/fg/test_groups.py
+++ b/tests/fg/test_groups.py
@@ -5,83 +5,10 @@
 import numpy as np
 import pytest
 
-from pgmax.fg import groups, nodes
 from pgmax.groups import enumeration
 from pgmax.groups import variables as vgroup
 
 
-def test_composite_variable_group():
-    variable_dict1 = vgroup.VariableDict(15, tuple([0, 1, 2]))
-    variable_dict2 = vgroup.VariableDict(15, tuple([0, 1, 2]))
-    composite_variable_sequence = groups.CompositeVariableGroup(
-        [variable_dict1, variable_dict2]
-    )
-    composite_variable_dict = groups.CompositeVariableGroup(
-        {(0, 1): variable_dict1, (2, 3): variable_dict2}
-    )
-    with pytest.raises(ValueError, match="The name needs to have at least 2 elements"):
-        composite_variable_sequence[(0,)]
-
-    assert composite_variable_sequence[0, 1] == variable_dict1[1]
-    assert (
-        composite_variable_sequence[[(0, 1), (1, 2)]]
-        == composite_variable_dict[[((0, 1), 1), ((2, 3), 2)]]
-    )
-    assert composite_variable_dict[(0, 1), 0] == variable_dict1[0]
-    assert composite_variable_dict[[((0, 1), 1), ((2, 3), 2)]] == [
-        variable_dict1[1],
-        variable_dict2[2],
-    ]
-    assert jnp.all(
-        composite_variable_sequence.flatten(
-            [{name: np.zeros(15) for name in range(3)} for _ in range(2)]
-        )
-        == composite_variable_dict.flatten(
-            {
-                (0, 1): {name: np.zeros(15) for name in range(3)},
-                (2, 3): {name: np.zeros(15) for name in range(3)},
-            }
-        )
-    )
-    assert jnp.all(
-        jnp.array(
-            jax.tree_util.tree_leaves(
-                jax.tree_util.tree_multimap(
-                    lambda x, y: jnp.all(x == y),
-                    composite_variable_sequence.unflatten(jnp.zeros(15 * 3 * 2)),
-                    [{name: jnp.zeros(15) for name in range(3)} for _ in range(2)],
-                )
-            )
-        )
-    )
-    assert jnp.all(
-        jnp.array(
-            jax.tree_util.tree_leaves(
-                jax.tree_util.tree_multimap(
-                    lambda x, y: jnp.all(x == y),
-                    composite_variable_dict.unflatten(jnp.zeros(3 * 2)),
-                    {
-                        (0, 1): {name: np.zeros(1) for name in range(3)},
-                        (2, 3): {name: np.zeros(1) for name in range(3)},
-                    },
-                )
-            )
-        )
-    )
-    with pytest.raises(
-        ValueError, match="Can only unflatten 1D array. Got a 2D array."
-    ):
-        composite_variable_dict.unflatten(jnp.zeros((10, 20)))
-
-    with pytest.raises(
-        ValueError,
-        match=re.escape(
-            "flat_data should be either of shape (num_variables(=6),), or (num_variable_states(=90),)"
-        ),
-    ):
-        composite_variable_dict.unflatten(jnp.zeros((100)))
-
-
 def test_variable_dict():
     variable_dict = vgroup.VariableDict(15, tuple([0, 1, 2]))
     with pytest.raises(
@@ -123,9 +50,7 @@ def test_variable_dict():
 
 
 def test_nd_variable_array():
-    variable_group = vgroup.NDVariableArray(2, (1,))
-    assert isinstance(variable_group[0], nodes.Variable)
-    variable_group = vgroup.NDVariableArray(3, (2, 2))
+    variable_group = vgroup.NDVariableArray(shape=(2, 2), num_states=3)
     with pytest.raises(
         ValueError,
         match=re.escape("data should be of shape (2, 2) or (2, 2, 3). Got (3, 3)."),
@@ -153,16 +78,15 @@ def test_nd_variable_array():
 
 
 def test_enumeration_factor_group():
-    variable_group = vgroup.NDVariableArray(3, (2, 2))
+    vg = vgroup.NDVariableArray(shape=(2, 2), num_states=3)
     with pytest.raises(
         ValueError,
         match=re.escape("Expected log potentials shape: (1,) or (2, 1). Got (3, 2)"),
     ):
         enumeration_factor_group = enumeration.EnumerationFactorGroup(
-            variable_group=variable_group,
-            variable_names_for_factors=[
-                [(0, 0), (0, 1), (1, 1)],
-                [(0, 1), (1, 0), (1, 1)],
+            variables_for_factors=[
+                [vg[0, 0], vg[0, 1], vg[1, 1]],
+                [vg[0, 1], vg[1, 0], vg[1, 1]],
             ],
             factor_configs=np.zeros((1, 3), dtype=int),
             log_potentials=np.zeros((3, 2)),
@@ -170,21 +94,22 @@ def test_enumeration_factor_group():
 
     with pytest.raises(ValueError, match=re.escape("Potentials should be floats")):
         enumeration_factor_group = enumeration.EnumerationFactorGroup(
-            variable_group=variable_group,
-            variable_names_for_factors=[
-                [(0, 0), (0, 1), (1, 1)],
-                [(0, 1), (1, 0), (1, 1)],
+            variables_for_factors=[
+                [vg[0, 0], vg[0, 1], vg[1, 1]],
+                [vg[0, 1], vg[1, 0], vg[1, 1]],
             ],
             factor_configs=np.zeros((1, 3), dtype=int),
             log_potentials=np.zeros((2, 1), dtype=int),
         )
 
     enumeration_factor_group = enumeration.EnumerationFactorGroup(
-        variable_group=variable_group,
-        variable_names_for_factors=[[(0, 0), (0, 1), (1, 1)], [(0, 1), (1, 0), (1, 1)]],
+        variables_for_factors=[
+            [vg[0, 0], vg[0, 1], vg[1, 1]],
+            [vg[0, 1], vg[1, 0], vg[1, 1]],
+        ],
         factor_configs=np.zeros((1, 3), dtype=int),
     )
-    name = [(0, 0), (1, 1)]
+    name = [vg[0, 0], vg[1, 1]]
     with pytest.raises(
         ValueError,
         match=re.escape(
@@ -194,7 +119,7 @@ def test_enumeration_factor_group():
         enumeration_factor_group[name]
 
     assert (
-        enumeration_factor_group[[(0, 1), (1, 0), (1, 1)]]
+        enumeration_factor_group[[vg[0, 1], vg[1, 0], vg[1, 1]]]
         == enumeration_factor_group.factors[1]
     )
     with pytest.raises(
@@ -227,20 +152,20 @@ def test_enumeration_factor_group():
 
 
 def test_pairwise_factor_group():
-    variable_group = vgroup.NDVariableArray(3, (2, 2))
+    vg = vgroup.NDVariableArray(shape=(2, 2), num_states=3)
 
     with pytest.raises(
         ValueError, match=re.escape("log_potential_matrix should be either a 2D array")
     ):
         enumeration.PairwiseFactorGroup(
-            variable_group, [[(0, 0), (1, 1)]], np.zeros((1,), dtype=float)
+            [[vg[0, 0], vg[1, 1]]], np.zeros((1,), dtype=float)
         )
 
     with pytest.raises(
         ValueError, match=re.escape("Potential matrix should be floats")
     ):
         enumeration.PairwiseFactorGroup(
-            variable_group, [[(0, 0), (1, 1)]], np.zeros((3, 3), dtype=int)
+            [[vg[0, 0], vg[1, 1]]], np.zeros((3, 3), dtype=int)
         )
 
     with pytest.raises(
@@ -250,7 +175,7 @@ def test_pairwise_factor_group():
         ),
     ):
         enumeration.PairwiseFactorGroup(
-            variable_group, [[(0, 0), (1, 1)]], np.zeros((2, 3, 3), dtype=float)
+            [[vg[0, 0], vg[1, 1]]], np.zeros((2, 3, 3), dtype=float)
         )
 
     with pytest.raises(
@@ -260,20 +185,18 @@ def test_pairwise_factor_group():
         ),
     ):
         enumeration.PairwiseFactorGroup(
-            variable_group, [[(0, 0), (1, 1), (0, 1)]], np.zeros((3, 3), dtype=float)
+            [[vg[0, 0], vg[1, 1], vg[0, 1]]], np.zeros((3, 3), dtype=float)
         )
 
+    name = [vg[0, 0], vg[1, 1]]
     with pytest.raises(
         ValueError,
-        match=re.escape("The specified pairwise factor [(0, 0), (1, 1)]"),
+        match=re.escape(f"The specified pairwise factor {name}"),
     ):
-        enumeration.PairwiseFactorGroup(
-            variable_group, [[(0, 0), (1, 1)]], np.zeros((4, 4), dtype=float)
-        )
+        enumeration.PairwiseFactorGroup([name], np.zeros((4, 4), dtype=float))
 
     pairwise_factor_group = enumeration.PairwiseFactorGroup(
-        variable_group,
-        [[(0, 0), (1, 1)], [(1, 0), (0, 1)]],
+        [[vg[0, 0], vg[1, 1]], [vg[1, 0], vg[0, 1]]],
     )
     with pytest.raises(
         ValueError,
diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py
index 270eb7d9..a5e720a8 100644
--- a/tests/test_pgmax.py
+++ b/tests/test_pgmax.py
@@ -214,18 +214,13 @@ def create_valid_suppression_config_arr(suppression_diameter):
     # We create a NDVariableArray such that the [0,i,j] entry corresponds to the vertical cut variable (i.e, the one
     # attached horizontally to the factor) that's at that location in the image, and the [1,i,j] entry corresponds to
     # the horizontal cut variable (i.e, the one attached vertically to the factor) that's at that location
-    grid_vars_group = vgroup.NDVariableArray(3, (2, M - 1, N - 1))
+    grid_vars = vgroup.NDVariableArray(shape=(2, M - 1, N - 1), num_states=3)
 
     # Make a group of additional variables for the edges of the grid
     extra_row_names: List[Tuple[Any, ...]] = [(0, row, N - 1) for row in range(M - 1)]
     extra_col_names: List[Tuple[Any, ...]] = [(1, M - 1, col) for col in range(N - 1)]
     additional_names = tuple(extra_row_names + extra_col_names)
-    additional_names_group = vgroup.VariableDict(3, additional_names)
-
-    # Combine these two VariableGroups into one CompositeVariableGroup
-    composite_grid_group = groups.CompositeVariableGroup(
-        {"grid_vars": grid_vars_group, "additional_vars": additional_names_group}
-    )
+    additional_vars = vgroup.VariableDict(additional_names, num_states=3)
 
     gt_has_cuts = gt_has_cuts.astype(np.int32)
 
@@ -249,17 +244,19 @@ def create_valid_suppression_config_arr(suppression_diameter):
                     size=evidence_vals_arr[1:].shape
                 )  # This adds logistic noise for every evidence entry
                 try:
-                    _ = composite_grid_group["grid_vars", i, row, col]
+                    _ = grid_vars[i, row, col]
                     grid_evidence_arr[i, row, col] = evidence_vals_arr
-                except ValueError:
+                except IndexError:
                     try:
-                        _ = composite_grid_group["additional_vars", i, row, col]
-                        additional_vars_evidence_dict[(i, row, col)] = evidence_vals_arr
-                    except ValueError:
+                        _ = additional_vars[i, row, col]
+                        additional_vars_evidence_dict[
+                            additional_vars[i, row, col]
+                        ] = evidence_vals_arr
+                    except IndexError:
                         pass
 
     # Create the factor graph
-    fg = graph.FactorGraph(variables=composite_grid_group)
+    fg = graph.FactorGraph(variables=[grid_vars, additional_vars])
 
     # Imperatively add EnumerationFactorGroups (each consisting of just one EnumerationFactor) to
     # the graph!
@@ -267,37 +264,37 @@ def create_valid_suppression_config_arr(suppression_diameter):
         for col in range(N - 1):
             if row != M - 2 and col != N - 2:
                 curr_names = [
-                    ("grid_vars", 0, row, col),
-                    ("grid_vars", 1, row, col),
-                    ("grid_vars", 0, row, col + 1),
-                    ("grid_vars", 1, row + 1, col),
+                    grid_vars[0, row, col],
+                    grid_vars[1, row, col],
+                    grid_vars[0, row, col + 1],
+                    grid_vars[1, row + 1, col],
                 ]
             elif row != M - 2:
                 curr_names = [
-                    ("grid_vars", 0, row, col),
-                    ("grid_vars", 1, row, col),
-                    ("additional_vars", 0, row, col + 1),
-                    ("grid_vars", 1, row + 1, col),
+                    grid_vars[0, row, col],
+                    grid_vars[1, row, col],
+                    additional_vars[0, row, col + 1],
+                    grid_vars[1, row + 1, col],
                 ]
 
             elif col != N - 2:
                 curr_names = [
-                    ("grid_vars", 0, row, col),
-                    ("grid_vars", 1, row, col),
-                    ("grid_vars", 0, row, col + 1),
-                    ("additional_vars", 1, row + 1, col),
+                    grid_vars[0, row, col],
+                    grid_vars[1, row, col],
+                    grid_vars[0, row, col + 1],
+                    additional_vars[1, row + 1, col],
                 ]
 
             else:
                 curr_names = [
-                    ("grid_vars", 0, row, col),
-                    ("grid_vars", 1, row, col),
-                    ("additional_vars", 0, row, col + 1),
-                    ("additional_vars", 1, row + 1, col),
+                    grid_vars[0, row, col],
+                    grid_vars[1, row, col],
+                    additional_vars[0, row, col + 1],
+                    additional_vars[1, row + 1, col],
                 ]
             if row % 2 == 0:
                 fg.add_factor(
-                    variable_names=curr_names,
+                    variables=curr_names,
                     factor_configs=valid_configs_non_supp,
                     log_potentials=np.zeros(
                         valid_configs_non_supp.shape[0], dtype=float
@@ -306,7 +303,7 @@ def create_valid_suppression_config_arr(suppression_diameter):
                 )
             else:
                 fg.add_factor(
-                    variable_names=curr_names,
+                    variables=curr_names,
                     factor_configs=valid_configs_non_supp,
                     log_potentials=np.zeros(
                         valid_configs_non_supp.shape[0], dtype=float
@@ -321,14 +318,14 @@ def create_valid_suppression_config_arr(suppression_diameter):
             if col != N - 1:
                 vert_suppression_names.append(
                     [
-                        ("grid_vars", 0, r, col)
+                        grid_vars[0, r, col]
                         for r in range(start_row, start_row + SUPPRESSION_DIAMETER)
                     ]
                 )
             else:
                 vert_suppression_names.append(
                     [
-                        ("additional_vars", 0, r, col)
+                        additional_vars[0, r, col]
                         for r in range(start_row, start_row + SUPPRESSION_DIAMETER)
                     ]
                 )
@@ -339,14 +336,14 @@ def create_valid_suppression_config_arr(suppression_diameter):
             if row != M - 1:
                 horz_suppression_names.append(
                     [
-                        ("grid_vars", 1, row, c)
+                        grid_vars[1, row, c]
                         for c in range(start_col, start_col + SUPPRESSION_DIAMETER)
                     ]
                 )
             else:
                 horz_suppression_names.append(
                     [
-                        ("additional_vars", 1, row, c)
+                        additional_vars[1, row, c]
                         for c in range(start_col, start_col + SUPPRESSION_DIAMETER)
                     ]
                 )
@@ -354,12 +351,12 @@ def create_valid_suppression_config_arr(suppression_diameter):
     # Add the suppression factors to the graph via kwargs
     fg.add_factor_group(
         factory=enumeration.EnumerationFactorGroup,
-        variable_names_for_factors=vert_suppression_names,
+        variables_for_factors=vert_suppression_names,
         factor_configs=valid_configs_supp,
     )
     fg.add_factor_group(
         factory=enumeration.EnumerationFactorGroup,
-        variable_names_for_factors=horz_suppression_names,
+        variables_for_factors=horz_suppression_names,
         factor_configs=valid_configs_supp,
         log_potentials=np.zeros(valid_configs_supp.shape[0], dtype=float),
     )
@@ -373,8 +370,8 @@ def create_valid_suppression_config_arr(suppression_diameter):
             for this_wiring in fg.fg_state.wiring.values()
         ]
     )
-    bp_state.evidence["grid_vars"] = grid_evidence_arr
-    bp_state.evidence["additional_vars"] = additional_vars_evidence_dict
+    bp_state.evidence[grid_vars] = grid_evidence_arr
+    bp_state.evidence[additional_vars] = additional_vars_evidence_dict
     bp = graph.BP(bp_state)
     bp_arrays = bp.run_bp(bp.init(), num_iters=100)
     # Test that the output messages are close to the true messages
@@ -388,15 +385,15 @@ def test_e2e_heretic():
     # Define some global constants
     im_size = (30, 30)
     # Instantiate all the Variables in the factor graph via VariableGroups
-    pixel_vars = vgroup.NDVariableArray(3, im_size)
+    pixel_vars = vgroup.NDVariableArray(shape=im_size, num_states=3)
     hidden_vars = vgroup.NDVariableArray(
-        17, (im_size[0] - 2, im_size[1] - 2)
+        shape=(im_size[0] - 2, im_size[1] - 2), num_states=17
     )  # Each hidden var is connected to a 3x3 patch of pixel vars
 
     bXn = np.zeros((30, 30, 3))
 
     # Create the factor graph
-    fg = graph.FactorGraph((pixel_vars, hidden_vars))
+    fg = graph.FactorGraph([pixel_vars, hidden_vars])
 
     def binary_connected_variables(
         num_hidden_rows, num_hidden_cols, kernel_row, kernel_col
@@ -406,8 +403,8 @@ def binary_connected_variables(
             for h_col in range(num_hidden_cols):
                 ret_list.append(
                     [
-                        (1, h_row, h_col),
-                        (0, h_row + kernel_row, h_col + kernel_col),
+                        hidden_vars[h_row, h_col],
+                        pixel_vars[h_row + kernel_row, h_col + kernel_col],
                     ]
                 )
         return ret_list
@@ -417,22 +414,20 @@ def binary_connected_variables(
         for k_col in range(3):
             fg.add_factor_group(
                 factory=enumeration.PairwiseFactorGroup,
-                variable_names_for_factors=binary_connected_variables(
-                    28, 28, k_row, k_col
-                ),
+                variables_for_factors=binary_connected_variables(28, 28, k_row, k_col),
                 log_potential_matrix=W_pot[:, :, k_row, k_col],
                 name=(k_row, k_col),
             )
 
     # Assign evidence to pixel vars
     bp_state = fg.bp_state
-    bp_state.evidence[0] = np.array(bXn)
-    bp_state.evidence[0, 0, 0] = np.array([0.0, 0.0, 0.0])
-    bp_state.evidence[0, 0, 0]
-    bp_state.evidence[1, 0, 0]
+    bp_state.evidence[pixel_vars] = np.array(bXn)
+    bp_state.evidence[pixel_vars[0, 0]] = np.array([0.0, 0.0, 0.0])
+    bp_state.evidence[pixel_vars[0, 0]]
+    bp_state.evidence[hidden_vars[0, 0]]
     assert isinstance(bp_state.evidence.value, np.ndarray)
     assert len(sum(fg.factors.values(), ())) == 7056
     bp = graph.BP(bp_state, temperature=1.0)
     bp_arrays = bp.run_bp(bp.init(), num_iters=1)
     marginals = graph.get_marginals(bp.get_beliefs(bp_arrays))
-    assert jnp.allclose(jnp.sum(marginals[0], axis=-1), 1.0)
+    assert jnp.allclose(jnp.sum(marginals[pixel_vars], axis=-1), 1.0)

From 9a2dcd215f322029a6ed5173bb5c893b03d99bfd Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Thu, 21 Apr 2022 21:01:09 +0000
Subject: [PATCH 08/35] Variables

---
 pgmax/groups/variables.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index 28572ae1..8dac5c7c 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -148,7 +148,8 @@ def variables(self) -> List[Tuple]:
         )
 
     def __getitem__(self, val):
-        # Numpy indexation will throw IndexError for us if out-of-bounds
+        if val not in self.variable_names:
+            raise ValueError(f"Variable {val} is not in VariableDict")
         return (val, self.num_states)
 
     def flatten(
@@ -170,7 +171,7 @@ def flatten(
                 (2) data is not of the correct shape
         """
         for name in data:
-            if name not in self.variable_names:
+            if name not in self.variables:
                 raise ValueError(
                     f"data is referring to a non-existent variable {name}."
                 )
@@ -181,9 +182,7 @@ def flatten(
                     f"{(self.num_states,)} or (1,). Got {data[name].shape}."
                 )
 
-        flat_data = jnp.concatenate(
-            [data[name].flatten() for name in self.variable_names]
-        )
+        flat_data = jnp.concatenate([data[name].flatten() for name in self.variables])
         return flat_data
 
     def unflatten(

From c6ae8d80948476525c00546f866f6cb49389160e Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Thu, 21 Apr 2022 23:16:37 +0000
Subject: [PATCH 09/35] Tests + mypy

---
 .pre-commit-config.yaml              | 10 ++---
 examples/gmrf.py                     | 31 +++++++++------
 examples/pmp_binary_deconvolution.py |  8 +---
 examples/rbm.py                      | 56 ++--------------------------
 examples/rcn.py                      |  8 ++--
 pgmax/fg/graph.py                    | 38 ++++++++++++-------
 pgmax/fg/groups.py                   |  1 +
 pgmax/groups/variables.py            |  9 ++++-
 tests/fg/test_graph.py               | 12 +++---
 tests/fg/test_groups.py              |  6 +--
 tests/test_pgmax.py                  | 35 ++++++++---------
 11 files changed, 94 insertions(+), 120 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7e553c33..cfec6112 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -27,10 +27,10 @@ repos:
         args: ['--config', '.flake8.config','--exit-zero']
         verbose: true
 
-# -   repo: https://github.com/pre-commit/mirrors-mypy
-#     rev: 'v0.942'  # Use the sha / tag you want to point at
-#     hooks:
-#     -   id: mypy
-#         additional_dependencies: [tokenize-rt==3.2.0]
+-   repo: https://github.com/pre-commit/mirrors-mypy
+    rev: 'v0.942'  # Use the sha / tag you want to point at
+    hooks:
+    -   id: mypy
+        additional_dependencies: [tokenize-rt==3.2.0]
 ci:
   autoupdate_schedule: 'quarterly'
diff --git a/examples/gmrf.py b/examples/gmrf.py
index 5e8dc86b..95ae4a7a 100644
--- a/examples/gmrf.py
+++ b/examples/gmrf.py
@@ -61,31 +61,39 @@
 # Add top-down factors
 fg.add_factor_group(
     factory=enumeration.PairwiseFactorGroup,
-    variable_names_for_factors=[
-        [(ii, jj), (ii + 1, jj)] for ii in range(M - 1) for jj in range(N)
+    variables_for_factors=[
+        [variables[ii, jj], variables[ii + 1, jj]]
+        for ii in range(M - 1)
+        for jj in range(N)
     ],
     name="top_down",
 )
 # Add left-right factors
 fg.add_factor_group(
     factory=enumeration.PairwiseFactorGroup,
-    variable_names_for_factors=[
-        [(ii, jj), (ii, jj + 1)] for ii in range(M) for jj in range(N - 1)
+    variables_for_factors=[
+        [variables[ii, jj], variables[ii, jj + 1]]
+        for ii in range(M)
+        for jj in range(N - 1)
     ],
     name="left_right",
 )
 # Add diagonal factors
 fg.add_factor_group(
     factory=enumeration.PairwiseFactorGroup,
-    variable_names_for_factors=[
-        [(ii, jj), (ii + 1, jj + 1)] for ii in range(M - 1) for jj in range(N - 1)
+    variables_for_factors=[
+        [variables[ii, jj], variables[ii + 1, jj + 1]]
+        for ii in range(M - 1)
+        for jj in range(N - 1)
     ],
     name="diagonal0",
 )
 fg.add_factor_group(
     factory=enumeration.PairwiseFactorGroup,
-    variable_names_for_factors=[
-        [(ii, jj), (ii - 1, jj + 1)] for ii in range(1, M) for jj in range(N - 1)
+    variables_for_factors=[
+        [variables[ii, jj], variables[ii - 1, jj + 1]]
+        for ii in range(1, M)
+        for jj in range(N - 1)
     ],
     name="diagonal1",
 )
@@ -106,7 +114,7 @@
         bp.get_beliefs(
             bp.run_bp(
                 bp.init(
-                    evidence_updates={None: evidence},
+                    evidence_updates={variables: evidence},
                     log_potentials_updates=log_potentials,
                 ),
                 num_iters=15,
@@ -114,6 +122,7 @@
             )
         )
     )
+    marginals = marginals[variables]
     pred_image = np.argmax(
         np.stack(
             [
@@ -153,7 +162,7 @@ def loss(noisy_image, target_image, log_potentials):
         bp.get_beliefs(
             bp.run_bp(
                 bp.init(
-                    evidence_updates={None: evidence},
+                    evidence_updates={variables: evidence},
                     log_potentials_updates=log_potentials,
                 ),
                 num_iters=15,
@@ -161,7 +170,7 @@ def loss(noisy_image, target_image, log_potentials):
             )
         )
     )
-    logp = jnp.mean(jnp.log(jnp.sum(target * marginals, axis=-1)))
+    logp = jnp.mean(jnp.log(jnp.sum(target * marginals[variables], axis=-1)))
     return -logp
 
 
diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py
index bd418a69..6c31a537 100644
--- a/examples/pmp_binary_deconvolution.py
+++ b/examples/pmp_binary_deconvolution.py
@@ -107,13 +107,7 @@ def plot_images(images, display=True, nr=None):
 #
 # See Section 5.6 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) for more details.
 
-import imp
-
 # %%
-from pgmax.fg import graph
-
-imp.reload(graph)
-
 # The dimensions of W used for the generation of X were (4, 5, 5) but we set them to (5, 6, 6)
 # to simulate a more realistic scenario in which we do not know their ground truth values
 n_feat, feat_height, feat_width = 5, 6, 6
@@ -183,7 +177,7 @@ def plot_images(images, display=True, nr=None):
 
                             X_var = X[idx_img, idx_chan, idx_img_height, idx_img_width]
                             variables_for_ORFactors_dict[X_var].append(SW_var)
-print(time.time() - start)
+print("After loop", time.time() - start)
 
 # Add ANDFactorGroup, which is computationally efficient
 fg.add_factor_group(
diff --git a/examples/rbm.py b/examples/rbm.py
index abd7435d..30f5ac2d 100644
--- a/examples/rbm.py
+++ b/examples/rbm.py
@@ -44,14 +44,8 @@
 
 # %% [markdown]
 # We can then initialize the factor graph for the RBM with
-#
-# import imp
 
 # %%
-from pgmax.fg import graph
-
-imp.reload(graph)
-
 import time
 
 start = time.time()
@@ -61,49 +55,6 @@
 fg = graph.FactorGraph(variables=[hidden_variables, visible_variables])
 print("Time", time.time() - start)
 
-# %%
-import itertools
-
-start = time.time()
-variable_names_for_factors = factors = list(
-    map(
-        lambda ij: (
-            hidden_variables.variable_names[ij[0]],
-            visible_variables.variable_names[ij[1]],
-        ),
-        list(itertools.product(range(bh.shape[0]), range(bv.shape[0]))),
-    )
-)
-print("Time", time.time() - start, len(variable_names_for_factors))
-
-import numba as nb
-
-
-@nb.jit(parallel=False, cache=True, fastmath=True, nopython=True)
-def run_numba(h, v, f):
-    for h_idx in nb.prange(h.shape[0]):
-        for v_idx in nb.prange(v.shape[0]):
-            f[h_idx * v.shape[0] + v_idx, 0] = h[h_idx]
-            f[h_idx * v.shape[0] + v_idx, 1] = v[v_idx]
-
-
-start = time.time()
-variable_names_for_factors = np.empty(shape=(bv.shape[0] * bh.shape[0], 2), dtype=int)
-run_numba(
-    hidden_variables.variable_names,
-    visible_variables.variable_names,
-    variable_names_for_factors,
-)
-print("Time", time.time() - start, len(variable_names_for_factors))
-
-start = time.time()
-variable_names_for_factors = [
-    [hidden_variables[ii], visible_variables[jj]]
-    for ii in range(bh.shape[0])
-    for jj in range(bv.shape[0])
-]
-print("Time", time.time() - start, len(variable_names_for_factors))
-
 # %% [markdown]
 # [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup) (e.g. an [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)), or a list/dictionary of [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed.
 #
@@ -111,6 +62,7 @@ def run_numba(h, v, f):
 
 # %%
 start = time.time()
+
 # Add unary factors
 fg.add_factor_group(
     factory=enumeration.EnumerationFactorGroup,
@@ -136,11 +88,10 @@ def run_numba(h, v, f):
         for ii in range(bh.shape[0])
         for jj in range(bv.shape[0])
     ],
-    #    variable_names_for_factors=variable_names_for_factors,
     log_potential_matrix=log_potential_matrix,
 )
 
-# # %snakeviz fg.add_factor_group(factory=enumeration.PairwiseFactorGroup, variable_names_for_factors=v, log_potential_matrix=log_potential_matrix,)
+# # %snakeviz fg.add_factor_group(factory=enumeration.PairwiseFactorGroup, variables_for_factors=v, log_potential_matrix=log_potential_matrix,)
 print("Time", time.time() - start)
 
 
@@ -227,7 +178,8 @@ def run_numba(h, v, f):
 # %%
 fig, ax = plt.subplots(1, 1, figsize=(10, 10))
 ax.imshow(
-    graph.map_states(beliefs)[visible_variables].copy().reshape((28, 28)), cmap="gray"
+    graph.decode_map_states(beliefs)[visible_variables].copy().reshape((28, 28)),
+    cmap="gray",
 )
 ax.axis("off")
 
diff --git a/examples/rcn.py b/examples/rcn.py
index 44fc1b85..97de9fff 100644
--- a/examples/rcn.py
+++ b/examples/rcn.py
@@ -211,11 +211,11 @@ def fetch_mnist_dataset(test_size: int, seed: int = 5) -> tuple[np.ndarray, np.n
     2 * vps + 1
 )  # The number of pool choices for the different variables of the PGM.
 
-variables_all_models = {}
+variables_all_models = []
 for idx in range(frcs.shape[0]):
     frc = frcs[idx]
-    variables_all_models[idx] = vgroup.NDVariableArray(
-        num_states=M, shape=(frc.shape[0],)
+    variables_all_models.append(
+        vgroup.NDVariableArray(num_states=M, shape=(frc.shape[0],))
     )
 
 end = time.time()
@@ -280,7 +280,7 @@ def valid_configs(r: int, hps: int, vps: int) -> np.ndarray:
     for e in edge:
         i1, i2, r = e
         fg.add_factor(
-            [(idx, i1), (idx, i2)],
+            [variables_all_models[idx][i1], variables_all_models[idx][i2]],
             valid_configs_list[r],
         )
 
diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index 264a6ab0..77fc489a 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -83,7 +83,6 @@ def __post_init__(self):
         ] = collections.OrderedDict(
             [(factor_type, set()) for factor_type in FAC_TO_VAR_UPDATES]
         )
-        print("2", time.time() - start)
 
         # Used to add FactorGroups
         vars_num_states = [variable[1] for variable in self._variables]
@@ -92,13 +91,14 @@ def __post_init__(self):
             0,
             0,
         )
+        print("1", time.time() - start)
         # See FactorGraphState docstrings for documentation on the following fields
         self._num_var_states = vars_num_states_cumsum[-1]
-        self._vars_to_starts: OrderedDict[
-            Tuple[int, int], int
-        ] = collections.OrderedDict(zip(self._variables, vars_num_states_cumsum[:-1]))
+        self._vars_to_starts: Dict[Tuple[int, int], int] = collections.OrderedDict(
+            zip(self._variables, vars_num_states_cumsum[:-1])
+        )
         self._named_factor_groups: Dict[Hashable, groups.FactorGroup] = {}
-        print("3", time.time() - start)
+        print("2", time.time() - start)
 
     def __hash__(self) -> int:
         all_factor_groups = tuple(
@@ -1062,6 +1062,7 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]:
         Raises:
             ValueError: if flat_data is not of the right shape
         """
+
         if flat_beliefs.ndim != 1:
             raise ValueError(
                 f"Can only unflatten 1D array. Got a {flat_beliefs.ndim}D array."
@@ -1070,9 +1071,18 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]:
         num_variables = 0
         num_variable_states = 0
         for variable_group in variable_groups:
-            if isinstance(variable_group, vgroup.NDVariableArray):
-                num_variables += variable_group.num_states.size
-                num_variable_states += variable_group.num_states.sum()
+            variables = variable_group.variables
+            num_variables += len(variables)
+            num_variable_states += sum([variable[1] for variable in variables])
+
+            # if isinstance(variable_group, vgroup.NDVariableArray):
+            #     num_variables += variable_group.num_states.size
+            #     num_variable_states += variable_group.num_states.sum()
+            # elif isinstance(variable_group, vgroup.VariableDict):
+            #     num_variables += len(variable_group.variables)
+            #     num_variable_states += (
+            #         len(variable_group.variables) * variable_group.variables[0].num_states
+            #     )
 
         if flat_beliefs.shape[0] == num_variables:
             use_num_states = False
@@ -1088,18 +1098,20 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]:
         beliefs = {}
         start = 0
         for variable_group in variable_groups:
-            if use_num_states:
-                length = variable_group.num_states.sum()
+            variables = variable_group.variables
+            if not use_num_states:
+                length = len(variables)
+                # length = variable_group.num_states.sum()
             else:
-                length = variable_group.num_states.size
-
+                length = sum([variable[1] for variable in variables])
+                # length = variable_group.num_states.size
             beliefs[variable_group] = variable_group.unflatten(
                 flat_beliefs[start : start + length]
             )
             start += length
         return beliefs
 
-    @jax.jit
+    # @jax.jit
     def get_beliefs(bp_arrays: BPArrays) -> Dict[Hashable, Any]:
         """Function to calculate beliefs from a BPArrays
 
diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index 06bc80cb..96cc4d19 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -143,6 +143,7 @@ def _variables_to_factors(self) -> Mapping[FrozenSet, nodes.Factor]:
     @cached_property
     def total_num_states(self) -> int:
         """TODO"""
+        # TODO: this could be returned by the wiring
         return sum(
             [
                 variable[1]
diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index 8dac5c7c..3907b9f8 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -28,12 +28,17 @@ class NDVariableArray(groups.VariableGroup):
     num_states: np.ndarray
 
     def __post_init__(self):
-        if isinstance(self.num_states, int):
+        if np.isscalar(self.num_states) and np.issubdtype(type(np.int64(10)), int):
             num_states = np.full(self.shape, fill_value=self.num_states)
             object.__setattr__(self, "num_states", num_states)
-        elif isinstance(self.num_states, np.ndarray):
+        elif isinstance(self.num_states, np.ndarray) and np.issubdtype(
+            self.num_states.dtype, int
+        ):
             if self.num_states.shape != self.shape:
                 raise ValueError("Should be same shape")
+        else:
+            raise ValueError("num_states entries should be of type np.int")
+
         random_hash = random.randint(0, 2**63)
         object.__setattr__(self, "random_hash", random_hash)
 
diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py
index ea2a7500..52090b1a 100644
--- a/tests/fg/test_graph.py
+++ b/tests/fg/test_graph.py
@@ -13,7 +13,7 @@
 
 
 def test_factor_graph():
-    vg = vgroup.VariableDict(15, (0,))
+    vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
     fg = graph.FactorGraph(vg)
     fg.add_factor_by_type(
         factor_type=enumeration_factor.EnumerationFactor,
@@ -76,7 +76,7 @@ def test_factor_adding():
 
 
 def test_bp_state():
-    vg = vgroup.VariableDict(15, (0,))
+    vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
     fg0 = graph.FactorGraph(vg)
     fg0.add_factor(
         variables=[vg[0]],
@@ -101,7 +101,7 @@ def test_bp_state():
 
 
 def test_log_potentials():
-    vg = vgroup.VariableDict(15, (0,))
+    vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
     fg = graph.FactorGraph(vg)
     fg.add_factor(
         variables=[vg[0]],
@@ -130,7 +130,7 @@ def test_log_potentials():
 
 
 def test_ftov_msgs():
-    vg = vgroup.VariableDict(15, (0,))
+    vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
     fg = graph.FactorGraph(vg)
     fg.add_factor(
         variables=[vg[0]],
@@ -164,7 +164,7 @@ def test_ftov_msgs():
 
 
 def test_evidence():
-    vg = vgroup.VariableDict(15, (0,))
+    vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
     fg = graph.FactorGraph(vg)
     fg.add_factor(
         variables=[vg[0]],
@@ -181,7 +181,7 @@ def test_evidence():
 
 
 def test_bp():
-    vg = vgroup.VariableDict(15, (0,))
+    vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
     fg = graph.FactorGraph(vg)
     fg.add_factor(
         variables=[vg[0]],
diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py
index 9a0fd312..7d4c7a72 100644
--- a/tests/fg/test_groups.py
+++ b/tests/fg/test_groups.py
@@ -10,7 +10,7 @@
 
 
 def test_variable_dict():
-    variable_dict = vgroup.VariableDict(15, tuple([0, 1, 2]))
+    variable_dict = vgroup.VariableDict(variable_names=tuple([0, 1, 2]), num_states=15)
     with pytest.raises(
         ValueError, match="data is referring to a non-existent variable 3"
     ):
@@ -19,10 +19,10 @@ def test_variable_dict():
     with pytest.raises(
         ValueError,
         match=re.escape(
-            "Variable 2 expects a data array of shape (15,) or (1,). Got (10,)"
+            "Variable (2, 15) expects a data array of shape (15,) or (1,). Got (10,)"
         ),
     ):
-        variable_dict.flatten({2: np.zeros(10)})
+        variable_dict.flatten({(2, 15): np.zeros(10)})
 
     with pytest.raises(
         ValueError, match="Can only unflatten 1D array. Got a 2D array."
diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py
index a5e720a8..f1404907 100644
--- a/tests/test_pgmax.py
+++ b/tests/test_pgmax.py
@@ -6,7 +6,7 @@
 from numpy.random import default_rng
 from scipy.ndimage import gaussian_filter
 
-from pgmax.fg import graph, groups, nodes
+from pgmax.fg import graph, nodes
 from pgmax.groups import enumeration
 from pgmax.groups import variables as vgroup
 
@@ -121,20 +121,6 @@ def create_valid_suppression_config_arr(suppression_diameter):
             ]
         )
     )
-    true_map_state_output = {
-        ("grid_vars", 0, 0, 0): 2,
-        ("grid_vars", 0, 0, 1): 0,
-        ("grid_vars", 0, 1, 0): 0,
-        ("grid_vars", 0, 1, 1): 2,
-        ("grid_vars", 1, 0, 0): 1,
-        ("grid_vars", 1, 0, 1): 0,
-        ("grid_vars", 1, 1, 0): 1,
-        ("grid_vars", 1, 1, 1): 0,
-        ("additional_vars", 0, 0, 2): 0,
-        ("additional_vars", 0, 1, 2): 2,
-        ("additional_vars", 1, 2, 0): 1,
-        ("additional_vars", 1, 2, 1): 0,
-    }
 
     # Create a synthetic depth image for testing purposes
     im_size = 3
@@ -222,6 +208,21 @@ def create_valid_suppression_config_arr(suppression_diameter):
     additional_names = tuple(extra_row_names + extra_col_names)
     additional_vars = vgroup.VariableDict(additional_names, num_states=3)
 
+    true_map_state_output = {
+        (grid_vars, (0, 0, 0)): 2,
+        (grid_vars, (0, 0, 1)): 0,
+        (grid_vars, (0, 1, 0)): 0,
+        (grid_vars, (0, 1, 1)): 2,
+        (grid_vars, (1, 0, 0)): 1,
+        (grid_vars, (1, 0, 1)): 0,
+        (grid_vars, (1, 1, 0)): 1,
+        (grid_vars, (1, 1, 1)): 0,
+        (additional_vars, (0, 0, 2)): 0,
+        (additional_vars, (0, 1, 2)): 2,
+        (additional_vars, (1, 2, 0)): 1,
+        (additional_vars, (1, 2, 1)): 0,
+    }
+
     gt_has_cuts = gt_has_cuts.astype(np.int32)
 
     # Now, we use this array along with the gt_has_cuts array computed earlier using the image in order to derive the evidence values
@@ -252,7 +253,7 @@ def create_valid_suppression_config_arr(suppression_diameter):
                         additional_vars_evidence_dict[
                             additional_vars[i, row, col]
                         ] = evidence_vals_arr
-                    except IndexError:
+                    except ValueError:
                         pass
 
     # Create the factor graph
@@ -378,7 +379,7 @@ def create_valid_suppression_config_arr(suppression_diameter):
     assert jnp.allclose(bp_arrays.ftov_msgs, true_final_msgs_output, atol=1e-06)
     decoded_map_states = graph.decode_map_states(bp.get_beliefs(bp_arrays))
     for name in true_map_state_output:
-        assert true_map_state_output[name] == decoded_map_states[name[0]][name[1:]]
+        assert true_map_state_output[name] == decoded_map_states[name[0]][name[1]]
 
 
 def test_e2e_heretic():

From 5c8b381b6ac74060b1988efa4ef057467ec5d05b Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Fri, 22 Apr 2022 02:12:06 +0000
Subject: [PATCH 10/35] Some docstrings

---
 examples/pmp_binary_deconvolution.py |  6 ++--
 pgmax/factors/enumeration.py         |  5 ++-
 pgmax/factors/logical.py             |  5 ++-
 pgmax/fg/graph.py                    | 46 +++++++++++++---------------
 pgmax/fg/groups.py                   | 35 +++++++--------------
 pgmax/fg/nodes.py                    | 14 ++-------
 pgmax/groups/variables.py            | 46 ++++++++++++++++++----------
 tests/factors/test_and.py            |  5 ++-
 tests/factors/test_or.py             |  1 -
 tests/fg/test_groups.py              |  2 +-
 tests/test_pgmax.py                  |  8 ++---
 11 files changed, 80 insertions(+), 93 deletions(-)

diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py
index 6c31a537..e6a37a0a 100644
--- a/examples/pmp_binary_deconvolution.py
+++ b/examples/pmp_binary_deconvolution.py
@@ -224,7 +224,7 @@ def plot_images(images, display=True, nr=None):
 
 # %%
 pW = 0.25
-pS = 1e-72
+pS = 1e-80
 pX = 1e-100
 
 # Sparsity inducing priors for W and S
@@ -242,7 +242,7 @@ def plot_images(images, display=True, nr=None):
 # We draw a batch of samples from the posterior in parallel by transforming `run_bp`/`get_beliefs` with `jax.vmap`
 
 # %%
-np.random.seed(seed=40)
+np.random.seed(seed=42)
 n_samples = 4
 
 start = time.time()
@@ -271,3 +271,5 @@ def plot_images(images, display=True, nr=None):
 
 # %%
 _ = plot_images(map_states[W].reshape(-1, feat_height, feat_width), nr=n_samples)
+
+# %%
diff --git a/pgmax/factors/enumeration.py b/pgmax/factors/enumeration.py
index 49bd8fd3..d461d3ab 100644
--- a/pgmax/factors/enumeration.py
+++ b/pgmax/factors/enumeration.py
@@ -163,9 +163,8 @@ def compile_wiring(
         Internally calls _compile_var_states_numba and _compile_enumeration_wiring_numba for speed.
 
         Args:
-            variables_for_factors: A list of list of variables, where each innermost element is a
-                variable. Each list within the outer list is taken to contain the names of the
-                variables connected to a Factor.
+            variables_for_factors: A list of list of variables. Each list within the outer list contains the
+                variables connected to a Factor. The same variable can be connected to multiple Factors.
             factor_configs: Array of shape (num_val_configs, num_variables) containing an explicit enumeration
                 of all valid configurations.
             vars_to_starts: A dictionary that maps variables to their global starting indices
diff --git a/pgmax/factors/logical.py b/pgmax/factors/logical.py
index a7554334..7f2d1ca6 100644
--- a/pgmax/factors/logical.py
+++ b/pgmax/factors/logical.py
@@ -148,9 +148,8 @@ def compile_wiring(
         Internally calls _compile_var_states_numba and _compile_logical_wiring_numba for speed.
 
         Args:
-            variables_for_factors: A list of list of variables, where each innermost element is a
-                variable. Each list within the outer list is taken to contain the names of the
-                variables connected to a Factor.
+            variables_for_factors: A list of list of variables. Each list within the outer list contains the
+                variables connected to a Factor. The same variable can be connected to multiple Factors.
             vars_to_starts: A dictionary that maps variables to their global starting indices
                 For an n-state variable, a global start index of m means the global indices
                 of its n variable states are m, m + 1, ..., m + n - 1
diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index 77fc489a..9b8f7cc7 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -65,13 +65,6 @@ def __post_init__(self):
         else:
             self._variable_groups = self.variables
 
-        self._variables = [
-            variable
-            for variable_group in self._variable_groups
-            for variable in variable_group.variables
-        ]
-        print("1", time.time() - start)
-
         # Useful objects to build the FactorGraph
         self._factor_types_to_groups: OrderedDict[
             Type, List[groups.FactorGroup]
@@ -84,19 +77,23 @@ def __post_init__(self):
             [(factor_type, set()) for factor_type in FAC_TO_VAR_UPDATES]
         )
 
-        # Used to add FactorGroups
-        vars_num_states = [variable[1] for variable in self._variables]
-        vars_num_states_cumsum = np.insert(
-            np.array(vars_num_states).cumsum(),
-            0,
-            0,
-        )
-        print("1", time.time() - start)
         # See FactorGraphState docstrings for documentation on the following fields
-        self._num_var_states = vars_num_states_cumsum[-1]
-        self._vars_to_starts: Dict[Tuple[int, int], int] = collections.OrderedDict(
-            zip(self._variables, vars_num_states_cumsum[:-1])
-        )
+        self._vars_to_starts: OrderedDict[
+            Tuple[int, int], int
+        ] = collections.OrderedDict()
+        vars_num_states_cumsum = 0
+        for variable_group in self._variable_groups:
+            vg_num_states = variable_group.num_states.flatten()
+            vg_num_states_cumsum = np.insert(np.cumsum(vg_num_states), 0, 0)
+            self._vars_to_starts.update(
+                zip(
+                    variable_group.variables,
+                    vars_num_states_cumsum + vg_num_states_cumsum[:-1],
+                )
+            )
+            vars_num_states_cumsum += vg_num_states_cumsum[-1]
+
+        self._num_var_states = vars_num_states_cumsum
         self._named_factor_groups: Dict[Hashable, groups.FactorGroup] = {}
         print("2", time.time() - start)
 
@@ -112,7 +109,7 @@ def __hash__(self) -> int:
 
     def add_factor(
         self,
-        variables: List,
+        variables: List[Tuple],
         factor_configs: np.ndarray,
         log_potentials: Optional[np.ndarray] = None,
         name: Optional[str] = None,
@@ -120,8 +117,8 @@ def add_factor(
         """Function to add a single factor to the FactorGraph.
 
         Args:
-            variables: A list containing the connected variable names.
-                Variable names are tuples of the type (variable_group_name, variable_name_within_variable_group)
+            variables: A list containing the connected variables.
+                Each variable is represented by a tuple of the form (variable hash/name, number of states)
             factor_configs: Array of shape (num_val_configs, num_variables)
                 An array containing explicit enumeration of all valid configurations.
                 If the connected variables have n1, n2, ... states, 1 <= num_val_configs <= n1 * n2 * ...
@@ -146,8 +143,8 @@ def add_factor_by_type(
         """Function to add a single factor to the FactorGraph.
 
         Args:
-            variable_names: A list containing the connected variable names.
-                Variable names are tuples of the type (variable_group_name, variable_name_within_variable_group)
+            variables: A list containing the connected variables.
+                Each variable is represented by a tuple of the form (variable hash/name, number of states)
             factor_type: Type of factor to be added
             args: Args to be passed to the factor_type.
             kwargs: kwargs to be passed to the factor_type, and an optional "name" argument
@@ -1105,6 +1102,7 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]:
             else:
                 length = sum([variable[1] for variable in variables])
                 # length = variable_group.num_states.size
+
             beliefs[variable_group] = variable_group.unflatten(
                 flat_beliefs[start : start + length]
             )
diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index 96cc4d19..b0340cfa 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -45,12 +45,11 @@ def __getitem__(self, val):
 
     @cached_property
     def variables(self) -> Tuple[Any, int]:
-        """Function that returns the list of variables. Each variable is represented
-        by a tuple of the variable name (which can be a hash or a string) and its
-        number of states.
+        """Function that returns the list of all variables in the VariableGroup.
+        Each variable is represented by a tuple of the form (variable name, number of states)
 
         Returns:
-           List of variables in the VariableGroup
+            List of variables in the VariableGroup
         """
         raise NotImplementedError(
             "Please subclass the VariableGroup class and override this method"
@@ -88,9 +87,8 @@ class FactorGroup:
     """Class to represent a group of Factors.
 
     Args:
-        variables_for_factors: A list of list of variables, where each innermost element is a
-            variable. Each list within the outer list is taken to contain the names of the
-            variables connected to a Factor.
+        variables_for_factors: A list of list of variables. Each list within the outer list contains the
+            variables connected to a Factor. The same variable can be connected to multiple Factors.
         factor_configs: Optional array containing an explicit enumeration of all valid configurations
         log_potentials: Array of log potentials.
 
@@ -142,8 +140,12 @@ def _variables_to_factors(self) -> Mapping[FrozenSet, nodes.Factor]:
 
     @cached_property
     def total_num_states(self) -> int:
-        """TODO"""
-        # TODO: this could be returned by the wiring
+        """Function to return the total number of states for all the variables involved in all the Factors
+
+        Returns:
+            Total number of variable states in the FactorGroup
+        """
+        # TODO: this could be returned by the wiring to loop over variables_for_factors only once
         return sum(
             [
                 variable[1]
@@ -293,18 +295,3 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any:
         raise NotImplementedError(
             "SingleFactorGroup does not support vectorized factor operations."
         )
-
-
-# def get_ndvariable_names_for_factor_groups(
-#     arrays: List[Any]
-
-# ):
-#     "Util function"
-
-# import numba as nb
-# @nb.jit(parallel=False, cache=True, fastmath=True, nopython=True)
-# def run_numba(h, v, f):
-#     for h_idx in nb.prange(h.shape[0]):
-#         for v_idx in nb.prange(v.shape[0]):
-#             f[h_idx * v.shape[0] + v_idx, 0] = h[h_idx]
-#             f[h_idx * v.shape[0] + v_idx, 1] = v[v_idx]
diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py
index 67a998ba..06d7c5d6 100644
--- a/pgmax/fg/nodes.py
+++ b/pgmax/fg/nodes.py
@@ -41,8 +41,8 @@ class Factor:
     """A factor
 
     Args:
-        variables: List of variables in the factors. Each variable is represented
-            by a tuple containing the variable hash and number of states.
+        variables: List of variables connected by the Factor.
+            Each variable is represented by a tuple of the form (variable hash/name, number of states)
 
     Raises:
         NotImplementedError: If compile_wiring is not implemented
@@ -57,16 +57,6 @@ def __post_init__(self):
                 "Please implement compile_wiring in for your factor"
             )
 
-    # @utils.cached_property
-    # def edges_num_states(self) -> np.ndarray:
-    #     """Number of states for the variables connected to each edge
-
-    #     Returns:
-    #         Array of shape (num_edges,)
-    #         Number of states for the variables connected to each edge
-    #     """
-    #     return self.variables.values()
-
     @staticmethod
     def concatenate_wirings(wirings: Sequence) -> Wiring:
         """Concatenate a list of Wirings
diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index 3907b9f8..36f7a4b9 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -52,7 +52,7 @@ def __lt__(self, other):
         return hash(self) < hash(other)
 
     def __getitem__(self, val):
-        # Numpy indexation will throw IndexError for us if out-of-bounds
+        # Relies on numpy indexation to throw IndexError if val is out-of-bounds
         result = (self.variable_names[val], self.num_states[val])
         if isinstance(val, slice):
             return tuple(zip(result))
@@ -60,16 +60,22 @@ def __getitem__(self, val):
 
     @cached_property
     def variables(self) -> List[Tuple]:
+        """Function that returns the list of all variables in the VariableGroup.
+        Each variable is represented by a tuple of the form (variable hash, number of states)
+
+        Returns:
+            List of variables in the VariableGroup
+        """
         vars_names = self.variable_names.flatten()
         vars_num_states = self.num_states.flatten()
         return list(zip(vars_names, vars_num_states))
 
     @cached_property
     def variable_names(self) -> np.ndarray:
-        """Function that generates a dictionary mapping names to variables.
+        """Function that generates all the variables names, in the form of hashes
 
         Returns:
-            a dictionary mapping all possible names to different variables.
+            Array of variables names.
         """
         # Overwite default hash as it does not give enough spacing across consecutive objects
         indices = np.reshape(np.arange(np.product(self.shape)), self.shape)
@@ -123,7 +129,7 @@ def unflatten(
         if flat_data.size == np.product(self.shape):
             data = flat_data.reshape(self.shape)
         elif flat_data.size == self.num_states.sum():
-            # TODO: what should we dot for different number of states
+            # TODO: what should we do for different number of states
             data = flat_data.reshape(self.shape + (self.num_states.max(),))
         else:
             raise ValueError(
@@ -144,21 +150,29 @@ class VariableDict(groups.VariableGroup):
     """
 
     variable_names: Tuple[Any, ...]
-    num_states: int
+    num_states: np.ndarray  # TODO: this should be an int converted to an array in __post_init__
+
+    def __post_init__(self):
+        num_states = np.full((len(self.variable_names),), fill_value=self.num_states)
+        object.__setattr__(self, "num_states", num_states)
 
     @cached_property
     def variables(self) -> List[Tuple]:
-        return list(
-            zip(self.variable_names, [self.num_states] * len(self.variable_names))
-        )
+        """Function that returns the list of all variables in the VariableGroup.
+        Each variable is represented by a tuple of the form (variable name, number of states)
+
+        Returns:
+            List of variables in the VariableGroup
+        """
+        return list(zip(self.variable_names, self.num_states))
 
     def __getitem__(self, val):
         if val not in self.variable_names:
             raise ValueError(f"Variable {val} is not in VariableDict")
-        return (val, self.num_states)
+        return (val, self.num_states[0])
 
     def flatten(
-        self, data: Mapping[Hashable, Union[np.ndarray, jnp.ndarray]]
+        self, data: Mapping[Tuple[int, int], Union[np.ndarray, jnp.ndarray]]
     ) -> jnp.ndarray:
         """Function that turns meaningful structured data into a flat data array for internal use.
 
@@ -181,10 +195,10 @@ def flatten(
                     f"data is referring to a non-existent variable {name}."
                 )
 
-            if data[name].shape != (self.num_states,) and data[name].shape != (1,):
+            if data[name].shape != (name[1],) and data[name].shape != (1,):
                 raise ValueError(
                     f"Variable {name} expects a data array of shape "
-                    f"{(self.num_states,)} or (1,). Got {data[name].shape}."
+                    f"{(name[1],)} or (1,). Got {data[name].shape}."
                 )
 
         flat_data = jnp.concatenate([data[name].flatten() for name in self.variables])
@@ -214,7 +228,7 @@ def unflatten(
             )
 
         num_variables = len(self.variable_names)
-        num_variable_states = len(self.variable_names) * self.num_states
+        num_variable_states = self.num_states.sum()
         if flat_data.shape[0] == num_variables:
             use_num_states = False
         elif flat_data.shape[0] == num_variable_states:
@@ -228,10 +242,10 @@ def unflatten(
 
         start = 0
         data = {}
-        for name in self.variable_names:
+        for name in self.variables:
             if use_num_states:
-                data[name] = flat_data[start : start + self.num_states]
-                start += self.num_states
+                data[name] = flat_data[start : start + name[1]]
+                start += name[1]
             else:
                 data[name] = flat_data[np.array([start])]
                 start += 1
diff --git a/tests/factors/test_and.py b/tests/factors/test_and.py
index 19bc08d0..84e37175 100644
--- a/tests/factors/test_and.py
+++ b/tests/factors/test_and.py
@@ -15,7 +15,7 @@ def test_run_bp_with_ANDFactors():
     (2) the support of several factor types in a FactorGraph and during inference
 
     To do so, observe that an ANDFactor can be defined as an equivalent EnumerationFactor
-    (which list all the valid OR configurations) and define two equivalent FactorGraphs
+    (which list all the valid AND configurations) and define two equivalent FactorGraphs
     FG1: first half of factors are defined as EnumerationFactors, second half are defined as ANDFactors
     FG2: first half of factors are defined as ANDFactors, second half are defined as EnumerationFactors
 
@@ -25,7 +25,6 @@ def test_run_bp_with_ANDFactors():
     Note: for the first seed, add all the EnumerationFactors to FG1 and all the ANDFactors to FG2
     """
     for idx in range(10):
-        print("it", idx)
         np.random.seed(idx)
 
         # Parameters
@@ -54,7 +53,7 @@ def test_run_bp_with_ANDFactors():
         children_variables2 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,))
         fg2 = graph.FactorGraph(variables=[parents_variables2, children_variables2])
 
-        # Variable names for factors
+        # Option 1: Define EnumerationFactors equivalent to the ANDFactors
         variables_for_factors1 = []
         variables_for_factors2 = []
         for factor_idx in range(num_factors):
diff --git a/tests/factors/test_or.py b/tests/factors/test_or.py
index 162a7338..c4bdc758 100644
--- a/tests/factors/test_or.py
+++ b/tests/factors/test_or.py
@@ -25,7 +25,6 @@ def test_run_bp_with_ORFactors():
     Note: for the first seed, add all the EnumerationFactors to FG1 and all the ORFactors to FG2
     """
     for idx in range(10):
-        print("it", idx)
         np.random.seed(idx)
 
         # Parameters
diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py
index 7d4c7a72..d6e18e73 100644
--- a/tests/fg/test_groups.py
+++ b/tests/fg/test_groups.py
@@ -35,7 +35,7 @@ def test_variable_dict():
                 jax.tree_util.tree_multimap(
                     lambda x, y: jnp.all(x == y),
                     variable_dict.unflatten(jnp.zeros(3)),
-                    {name: np.zeros(1) for name in range(3)},
+                    {(name, 15): np.zeros(1) for name in range(3)},
                 )
             )
         )
diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py
index f1404907..5c30e047 100644
--- a/tests/test_pgmax.py
+++ b/tests/test_pgmax.py
@@ -217,10 +217,10 @@ def create_valid_suppression_config_arr(suppression_diameter):
         (grid_vars, (1, 0, 1)): 0,
         (grid_vars, (1, 1, 0)): 1,
         (grid_vars, (1, 1, 1)): 0,
-        (additional_vars, (0, 0, 2)): 0,
-        (additional_vars, (0, 1, 2)): 2,
-        (additional_vars, (1, 2, 0)): 1,
-        (additional_vars, (1, 2, 1)): 0,
+        (additional_vars, ((0, 0, 2), 3)): 0,
+        (additional_vars, ((0, 1, 2), 3)): 2,
+        (additional_vars, ((1, 2, 0), 3)): 1,
+        (additional_vars, ((1, 2, 1), 3)): 0,
     }
 
     gt_has_cuts = gt_has_cuts.astype(np.int32)

From 40ec5190b4b915bbb33b48b5a0cbbef7818505d8 Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Fri, 22 Apr 2022 19:24:00 +0000
Subject: [PATCH 11/35] Stannis first comments

---
 examples/gmrf.py          |  4 +--
 examples/rbm.py           |  6 ++---
 pgmax/fg/graph.py         | 47 +++++++++++----------------------
 pgmax/fg/groups.py        | 27 +++++++++++++++----
 pgmax/groups/variables.py | 55 +++++++++++++++++++++------------------
 5 files changed, 71 insertions(+), 68 deletions(-)

diff --git a/examples/gmrf.py b/examples/gmrf.py
index 95ae4a7a..3520c04d 100644
--- a/examples/gmrf.py
+++ b/examples/gmrf.py
@@ -121,8 +121,8 @@
                 damping=0.0,
             )
         )
-    )
-    marginals = marginals[variables]
+    )[variables]
+
     pred_image = np.argmax(
         np.stack(
             [
diff --git a/examples/rbm.py b/examples/rbm.py
index 30f5ac2d..6d7a1d53 100644
--- a/examples/rbm.py
+++ b/examples/rbm.py
@@ -110,14 +110,14 @@
 # # Add unary factors
 # for ii in range(bh.shape[0]):
 #     fg.add_factor(
-#         variable_names=[("hidden", ii)],
+#         variables=[hidden_variables[ii]],
 #         factor_configs=np.arange(2)[:, None],
 #         log_potentials=np.array([0, bh[ii]]),
 #     )
 #
 # for jj in range(bv.shape[0]):
 #     fg.add_factor(
-#         variable_names=[("visible", jj)],
+#         variables=[visible_variables[jj]],
 #         factor_configs=np.arange(2)[:, None],
 #         log_potentials=np.array([0, bv[jj]]),
 #     )
@@ -127,7 +127,7 @@
 # for ii in tqdm(range(bh.shape[0])):
 #     for jj in range(bv.shape[0]):
 #         fg.add_factor(
-#             variable_names=[("hidden", ii), ("visible", jj)],
+#             variables=[hidden_variables[ii], visible_variables[jj]],
 #             factor_configs=factor_configs,
 #             log_potentials=np.array([0, 0, 0, W[ii, jj]]),
 #         )
diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index 9b8f7cc7..28460c79 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -35,7 +35,6 @@
 from pgmax.bp import infer
 from pgmax.factors import FAC_TO_VAR_UPDATES
 from pgmax.fg import groups, nodes
-from pgmax.groups import variables as vgroup
 from pgmax.groups.enumeration import EnumerationFactorGroup
 from pgmax.utils import cached_property
 
@@ -60,10 +59,8 @@ def __post_init__(self):
         import time
 
         start = time.time()
-        if isinstance(self.variables, (vgroup.NDVariableArray, vgroup.VariableDict)):
-            self._variable_groups = [self.variables]
-        else:
-            self._variable_groups = self.variables
+        if isinstance(self.variables, groups.VariableGroup):
+            self.variables = [self.variables]
 
         # Useful objects to build the FactorGraph
         self._factor_types_to_groups: OrderedDict[
@@ -82,7 +79,7 @@ def __post_init__(self):
             Tuple[int, int], int
         ] = collections.OrderedDict()
         vars_num_states_cumsum = 0
-        for variable_group in self._variable_groups:
+        for variable_group in self.variables:
             vg_num_states = variable_group.num_states.flatten()
             vg_num_states_cumsum = np.insert(np.cumsum(vg_num_states), 0, 0)
             self._vars_to_starts.update(
@@ -355,9 +352,10 @@ def fg_state(self) -> FactorGraphState:
         log_potentials = np.concatenate(
             [self.log_potentials[factor_type] for factor_type in self.log_potentials]
         )
+        assert isinstance(self.variables, list)
 
         return FactorGraphState(
-            variable_groups=self._variable_groups,
+            variable_groups=self.variables,
             vars_to_starts=self._vars_to_starts,
             num_var_states=self._num_var_states,
             total_factor_num_states=self._total_factor_num_states,
@@ -391,7 +389,7 @@ class FactorGraphState:
     """FactorGraphState.
 
     Args:
-        variable_group: A variable group containing all the variables in the FactorGraph.
+        variable_groups: All the variable groups in the FactorGraph.
         vars_to_starts: Maps variables to their starting indices in the flat evidence array.
             flat_evidence[vars_to_starts[variable]: vars_to_starts[variable] + variable.num_var_states]
             contains evidence to the variable.
@@ -1109,7 +1107,7 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]:
             start += length
         return beliefs
 
-    # @jax.jit
+    @jax.jit
     def get_beliefs(bp_arrays: BPArrays) -> Dict[Hashable, Any]:
         """Function to calculate beliefs from a BPArrays
 
@@ -1143,7 +1141,8 @@ def compute_flat_beliefs(bp_arrays, var_states_for_edges):
     return bp
 
 
-def decode_map_states(beliefs: Dict[Hashable, Any]) -> Dict[Hashable, Any]:
+@jax.jit
+def decode_map_states(beliefs: Dict[Hashable, Any]) -> Any:
     """Function to decode MAP states given the calculated beliefs.
 
     Args:
@@ -1152,18 +1151,11 @@ def decode_map_states(beliefs: Dict[Hashable, Any]) -> Dict[Hashable, Any]:
     Returns:
         An array or a PyTree container containing the MAP states for different variables.
     """
-
-    @jax.jit
-    def _decode_map_states(beliefs) -> Any:
-        return jax.tree_util.tree_map(lambda x: jnp.argmax(x, axis=-1), beliefs)
-
-    map_states = {}
-    for variable_group, vgroup_beliefs in beliefs.items():
-        map_states[variable_group] = _decode_map_states(vgroup_beliefs)
-    return map_states
+    return jax.tree_util.tree_map(lambda x: jnp.argmax(x, axis=-1), beliefs)
 
 
-def get_marginals(beliefs: Dict[Hashable, Any]) -> Dict[Hashable, Any]:
+@jax.jit
+def get_marginals(beliefs: Dict[Hashable, Any]) -> Any:
     """Function to get marginal probabilities given the calculated beliefs.
 
     Args:
@@ -1172,15 +1164,6 @@ def get_marginals(beliefs: Dict[Hashable, Any]) -> Dict[Hashable, Any]:
     Returns:
         An array or a PyTree container containing the marginal probabilities different variables.
     """
-
-    @jax.jit
-    def _get_marginals(beliefs) -> Any:
-        return jax.tree_util.tree_map(
-            lambda x: jnp.exp(x - logsumexp(x, axis=-1, keepdims=True)),
-            beliefs,
-        )
-
-    marginals = {}
-    for variable_group, vgroup_beliefs in beliefs.items():
-        marginals[variable_group] = _get_marginals(vgroup_beliefs)
-    return marginals
+    return jax.tree_util.tree_map(
+        lambda x: jnp.exp(x - logsumexp(x, axis=-1, keepdims=True)), beliefs
+    )
diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index b0340cfa..04b22ec4 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -1,7 +1,9 @@
 """A module containing the base classes for variable and factor groups in a Factor Graph."""
 
 import inspect
+import random
 from dataclasses import dataclass, field
+from functools import total_ordering
 from typing import (
     Any,
     FrozenSet,
@@ -21,6 +23,7 @@
 from pgmax.utils import cached_property
 
 
+@total_ordering
 @dataclass(frozen=True, eq=False)
 class VariableGroup:
     """Class to represent a group of variables.
@@ -30,14 +33,28 @@ class VariableGroup:
     a sequence of variable names) of the VariableGroup.
     """
 
+    def __post_init__(self):
+        random_hash = random.randint(0, 2**63)
+        object.__setattr__(self, "random_hash", random_hash)
+
+    def __hash__(self):
+        return self.random_hash
+
+    def __eq__(self, other):
+        return hash(self) == hash(other)
+
+    def __lt__(self, other):
+        return hash(self) < hash(other)
+
     def __getitem__(self, val):
-        """Given a variable name, retrieve the associated Variable.
+        """Given a variable name, index, or a group of variable indices, retrieve the associated variable(s).
+        Each variable is returned via a tuple of the form (variable hash/name, number of states)
 
         Args:
-            val: a single name corresponding to a single variable, or a list of such names
+            val: a variable index, slice, or name
 
         Returns:
-            A single variable if the name is not a list. A list of variables if name is a list
+            A single variable or a list of variables
         """
         raise NotImplementedError(
             "Please subclass the VariableGroup class and override this method"
@@ -46,7 +63,7 @@ def __getitem__(self, val):
     @cached_property
     def variables(self) -> Tuple[Any, int]:
         """Function that returns the list of all variables in the VariableGroup.
-        Each variable is represented by a tuple of the form (variable name, number of states)
+        Each variable is represented by a tuple of the form (variable hash/name, number of states)
 
         Returns:
             List of variables in the VariableGroup
@@ -108,7 +125,7 @@ def __post_init__(self):
         if len(self.variables_for_factors) == 0:
             raise ValueError("Do not add a factor group with no factors.")
 
-    def __getitem__(self, variables: Sequence[int]) -> Any:
+    def __getitem__(self, variables: Sequence[Tuple[int, int]]) -> Any:
         """Function to query individual factors in the factor group
 
         Args:
diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index 36f7a4b9..b7e0d66b 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -1,8 +1,6 @@
 """A module containing the variables group classes inheriting from the base VariableGroup."""
 
-import random
 from dataclasses import dataclass
-from functools import total_ordering
 from typing import Any, Dict, Hashable, List, Mapping, Tuple, Union
 
 import jax
@@ -13,7 +11,6 @@
 from pgmax.utils import cached_property
 
 
-@total_ordering
 @dataclass(frozen=True, eq=False)
 class NDVariableArray(groups.VariableGroup):
     """Subclass of VariableGroup for n-dimensional grids of variables.
@@ -28,6 +25,8 @@ class NDVariableArray(groups.VariableGroup):
     num_states: np.ndarray
 
     def __post_init__(self):
+        super().__post_init__()
+
         if np.isscalar(self.num_states) and np.issubdtype(type(np.int64(10)), int):
             num_states = np.full(self.shape, fill_value=self.num_states)
             object.__setattr__(self, "num_states", num_states)
@@ -39,23 +38,23 @@ def __post_init__(self):
         else:
             raise ValueError("num_states entries should be of type np.int")
 
-        random_hash = random.randint(0, 2**63)
-        object.__setattr__(self, "random_hash", random_hash)
-
-    def __hash__(self):
-        return self.random_hash
+    def __getitem__(
+        self, val: Union[int, tuple, slice]
+    ) -> Union[Tuple[int, int], List[Tuple[int]]]:
+        """Given an index or a slice, retrieve the associated variable(s).
+        Each variable is returned via a tuple of the form (variable hash, number of states)
 
-    def __eq__(self, other):
-        return hash(self) == hash(other)
+        Note: Relies on numpy indexation to throw IndexError if val is out-of-bounds
 
-    def __lt__(self, other):
-        return hash(self) < hash(other)
+        Args:
+            val: a variable index or slice
 
-    def __getitem__(self, val):
-        # Relies on numpy indexation to throw IndexError if val is out-of-bounds
+        Returns:
+            A single variable or a list of variables
+        """
         result = (self.variable_names[val], self.num_states[val])
         if isinstance(val, slice):
-            return tuple(zip(result))
+            return list(zip(result))
         return result
 
     @cached_property
@@ -153,6 +152,8 @@ class VariableDict(groups.VariableGroup):
     num_states: np.ndarray  # TODO: this should be an int converted to an array in __post_init__
 
     def __post_init__(self):
+        super().__post_init__()
+
         num_states = np.full((len(self.variable_names),), fill_value=self.num_states)
         object.__setattr__(self, "num_states", num_states)
 
@@ -189,19 +190,21 @@ def flatten(
                 (1) data is referring to a non-existing variable
                 (2) data is not of the correct shape
         """
-        for name in data:
-            if name not in self.variables:
+        for variable in data:
+            if variable not in self.variables:
                 raise ValueError(
-                    f"data is referring to a non-existent variable {name}."
+                    f"data is referring to a non-existent variable {variable}."
                 )
 
-            if data[name].shape != (name[1],) and data[name].shape != (1,):
+            if data[variable].shape != (variable[1],) and data[variable].shape != (1,):
                 raise ValueError(
-                    f"Variable {name} expects a data array of shape "
-                    f"{(name[1],)} or (1,). Got {data[name].shape}."
+                    f"Variable {variable} expects a data array of shape "
+                    f"{(variable[1],)} or (1,). Got {data[variable].shape}."
                 )
 
-        flat_data = jnp.concatenate([data[name].flatten() for name in self.variables])
+        flat_data = jnp.concatenate(
+            [data[variable].flatten() for variable in self.variables]
+        )
         return flat_data
 
     def unflatten(
@@ -242,12 +245,12 @@ def unflatten(
 
         start = 0
         data = {}
-        for name in self.variables:
+        for variable in self.variables:
             if use_num_states:
-                data[name] = flat_data[start : start + name[1]]
-                start += name[1]
+                data[variable] = flat_data[start : start + variable[1]]
+                start += variable[1]
             else:
-                data[name] = flat_data[np.array([start])]
+                data[variable] = flat_data[np.array([start])]
                 start += 1
 
         return data

From 96c7fe380cd1feb141e4de6a4fcde9f6291086f1 Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Fri, 22 Apr 2022 23:42:02 +0000
Subject: [PATCH 12/35] Remove add_factor

---
 examples/gmrf.py                     |  33 ++--
 examples/ising_model.py              |   5 +-
 examples/pmp_binary_deconvolution.py |  16 +-
 examples/rbm.py                      |  27 +--
 examples/rcn.py                      |   9 +-
 pgmax/factors/enumeration.py         |   3 +-
 pgmax/fg/graph.py                    |  90 ++--------
 pgmax/fg/groups.py                   |   6 +-
 pgmax/groups/logical.py              |   4 +
 tests/factors/test_and.py            |  23 +--
 tests/factors/test_or.py             |  23 +--
 tests/fg/test_graph.py               | 237 +++++++++++++--------------
 tests/fg/test_wiring.py              |  53 +++---
 tests/test_pgmax.py                  |  23 +--
 14 files changed, 246 insertions(+), 306 deletions(-)

diff --git a/examples/gmrf.py b/examples/gmrf.py
index 3520c04d..363aca85 100644
--- a/examples/gmrf.py
+++ b/examples/gmrf.py
@@ -59,44 +59,44 @@
 
 # %%
 # Add top-down factors
-fg.add_factor_group(
-    factory=enumeration.PairwiseFactorGroup,
+factor_group = enumeration.PairwiseFactorGroup(
     variables_for_factors=[
         [variables[ii, jj], variables[ii + 1, jj]]
         for ii in range(M - 1)
         for jj in range(N)
     ],
-    name="top_down",
 )
+fg.add_factor_group(factor_group, name="top_down")
+
 # Add left-right factors
-fg.add_factor_group(
-    factory=enumeration.PairwiseFactorGroup,
+factor_group = enumeration.PairwiseFactorGroup(
     variables_for_factors=[
         [variables[ii, jj], variables[ii, jj + 1]]
         for ii in range(M)
         for jj in range(N - 1)
     ],
-    name="left_right",
 )
+fg.add_factor_group(factor_group, name="left_right")
+
 # Add diagonal factors
-fg.add_factor_group(
-    factory=enumeration.PairwiseFactorGroup,
+factor_group = enumeration.PairwiseFactorGroup(
     variables_for_factors=[
         [variables[ii, jj], variables[ii + 1, jj + 1]]
         for ii in range(M - 1)
         for jj in range(N - 1)
     ],
-    name="diagonal0",
 )
-fg.add_factor_group(
-    factory=enumeration.PairwiseFactorGroup,
+fg.add_factor_group(factor_group, name="diagonal0")
+
+
+factor_group = enumeration.PairwiseFactorGroup(
     variables_for_factors=[
         [variables[ii, jj], variables[ii - 1, jj + 1]]
         for ii in range(1, M)
         for jj in range(N - 1)
     ],
-    name="diagonal1",
 )
+fg.add_factor_group(factor_group, name="diagonal1")
 
 # %%
 bp = graph.BP(fg.bp_state, temperature=1.0)
@@ -227,3 +227,12 @@ def update(step, batch_noisy_images, batch_target_images, opt_state):
             )
             pbar.update()
             pbar.set_postfix(loss=value)
+
+
+batch_indices = indices[idx * batch_size : (idx + 1) * batch_size]
+batch_noisy_images, batch_target_images = (
+    noisy_images_train[:10],
+    target_images_train[:10],
+)
+step = 0
+value, opt_state = update(step, batch_noisy_images, batch_target_images, opt_state)
diff --git a/examples/ising_model.py b/examples/ising_model.py
index 88535c23..b2f5aef7 100644
--- a/examples/ising_model.py
+++ b/examples/ising_model.py
@@ -39,12 +39,11 @@
         variables_for_factors.append([variables[ii, jj], variables[kk, jj]])
         variables_for_factors.append([variables[ii, jj], variables[kk, ll]])
 
-fg.add_factor_group(
-    factory=enumeration.PairwiseFactorGroup,
+factor_group = enumeration.PairwiseFactorGroup(
     variables_for_factors=variables_for_factors,
     log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]),
-    name="factors",
 )
+fg.add_factor_group(factor_group, name="factors")
 
 # %% [markdown]
 # ### Run inference and visualize results
diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py
index e6a37a0a..6cdb84f1 100644
--- a/examples/pmp_binary_deconvolution.py
+++ b/examples/pmp_binary_deconvolution.py
@@ -180,10 +180,8 @@ def plot_images(images, display=True, nr=None):
 print("After loop", time.time() - start)
 
 # Add ANDFactorGroup, which is computationally efficient
-fg.add_factor_group(
-    factory=logical.ANDFactorGroup,
-    variables_for_factors=variables_for_ANDFactors,
-)
+AND_factor_group = logical.ANDFactorGroup(variables_for_ANDFactors)
+fg.add_factor_group(AND_factor_group)
 print(time.time() - start)
 
 # Define the ORFactors
@@ -193,10 +191,8 @@ def plot_images(images, display=True, nr=None):
 ]
 
 # Add ORFactorGroup, which is computationally efficient
-fg.add_factor_group(
-    factory=logical.ORFactorGroup,
-    variables_for_factors=variables_for_ORFactors,
-)
+OR_factor_group = logical.ORFactorGroup(variables_for_ORFactors)
+fg.add_factor_group(OR_factor_group)
 print("Time", time.time() - start)
 
 for factor_type, factor_groups in fg.factor_groups.items():
@@ -224,7 +220,7 @@ def plot_images(images, display=True, nr=None):
 
 # %%
 pW = 0.25
-pS = 1e-80
+pS = 1e-70
 pX = 1e-100
 
 # Sparsity inducing priors for W and S
@@ -271,5 +267,3 @@ def plot_images(images, display=True, nr=None):
 
 # %%
 _ = plot_images(map_states[W].reshape(-1, feat_height, feat_width), nr=n_samples)
-
-# %%
diff --git a/examples/rbm.py b/examples/rbm.py
index 6d7a1d53..876f08ee 100644
--- a/examples/rbm.py
+++ b/examples/rbm.py
@@ -26,6 +26,7 @@
 import matplotlib.pyplot as plt
 import numpy as np
 
+from pgmax.factors import enumeration as enumeration_factor
 from pgmax.fg import graph
 from pgmax.groups import enumeration
 from pgmax.groups import variables as vgroup
@@ -64,25 +65,25 @@
 start = time.time()
 
 # Add unary factors
-fg.add_factor_group(
-    factory=enumeration.EnumerationFactorGroup,
+hidden_unaries = enumeration.EnumerationFactorGroup(
     variables_for_factors=[[hidden_variables[ii]] for ii in range(bh.shape[0])],
     factor_configs=np.arange(2)[:, None],
     log_potentials=np.stack([np.zeros_like(bh), bh], axis=1),
 )
+fg.add_factor_group(hidden_unaries)
 
-fg.add_factor_group(
-    factory=enumeration.EnumerationFactorGroup,
+visible_unaries = enumeration.EnumerationFactorGroup(
     variables_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])],
     factor_configs=np.arange(2)[:, None],
     log_potentials=np.stack([np.zeros_like(bv), bv], axis=1),
 )
+fg.add_factor_group(visible_unaries)
+
 # Add pairwise factors
 log_potential_matrix = np.zeros(W.shape + (2, 2)).reshape((-1, 2, 2))
 log_potential_matrix[:, 1, 1] = W.ravel()
 
-fg.add_factor_group(
-    factory=enumeration.PairwiseFactorGroup,
+pairwise_factors = enumeration.PairwiseFactorGroup(
     variables_for_factors=[
         [hidden_variables[ii], visible_variables[jj]]
         for ii in range(bh.shape[0])
@@ -90,6 +91,7 @@
     ],
     log_potential_matrix=log_potential_matrix,
 )
+fg.add_factor_group(pairwise_factors)
 
 # # %snakeviz fg.add_factor_group(factory=enumeration.PairwiseFactorGroup, variables_for_factors=v, log_potential_matrix=log_potential_matrix,)
 print("Time", time.time() - start)
@@ -109,28 +111,31 @@
 #
 # # Add unary factors
 # for ii in range(bh.shape[0]):
-#     fg.add_factor(
+#     factor = enumeration_factor.EnumerationFactor(
 #         variables=[hidden_variables[ii]],
 #         factor_configs=np.arange(2)[:, None],
 #         log_potentials=np.array([0, bh[ii]]),
 #     )
+#     fg.add_factor(factor)
 #
 # for jj in range(bv.shape[0]):
-#     fg.add_factor(
+#     factor = enumeration_factor.EnumerationFactor(
 #         variables=[visible_variables[jj]],
 #         factor_configs=np.arange(2)[:, None],
 #         log_potentials=np.array([0, bv[jj]]),
 #     )
+#     fg.add_factor(factor)
 #
 # # Add pairwise factors
 # factor_configs = np.array(list(itertools.product(np.arange(2), repeat=2)))
 # for ii in tqdm(range(bh.shape[0])):
 #     for jj in range(bv.shape[0]):
-#         fg.add_factor(
+#         factor = enumeration_factor.EnumerationFactor(
 #             variables=[hidden_variables[ii], visible_variables[jj]],
 #             factor_configs=factor_configs,
 #             log_potentials=np.array([0, 0, 0, W[ii, jj]]),
 #         )
+#         fg.add_factor(factor)
 # ~~~
 #
 # Once we have added the factors, we can run max-product LBP and get MAP decoding by
@@ -194,8 +199,8 @@
 # ~~~python
 # bp_arrays = run_bp(
 #     evidence_updates={
-#         "hidden": np.random.gumbel(size=(bh.shape[0], 2)),
-#         "visible": np.random.gumbel(size=(bv.shape[0], 2)),
+#         hidden_variables: np.random.gumbel(size=(bh.shape[0], 2)),
+#         visible_variables: np.random.gumbel(size=(bv.shape[0], 2)),
 #     },
 #     damping=0.5,
 # )
diff --git a/examples/rcn.py b/examples/rcn.py
index 97de9fff..d7add7d7 100644
--- a/examples/rcn.py
+++ b/examples/rcn.py
@@ -38,6 +38,7 @@
 from scipy.signal import fftconvolve
 from sklearn.datasets import fetch_openml
 
+from pgmax.factors.enumeration import EnumerationFactor
 from pgmax.fg import graph
 from pgmax.groups import variables as vgroup
 
@@ -279,10 +280,12 @@ def valid_configs(r: int, hps: int, vps: int) -> np.ndarray:
 
     for e in edge:
         i1, i2, r = e
-        fg.add_factor(
-            [variables_all_models[idx][i1], variables_all_models[idx][i2]],
-            valid_configs_list[r],
+        factor = EnumerationFactor(
+            variables=[variables_all_models[idx][i1], variables_all_models[idx][i2]],
+            factor_configs=valid_configs_list[r],
+            log_potentials=np.zeros(valid_configs_list[r].shape[0]),
         )
+        fg.add_factor(factor)
 
 end = time.time()
 print(f"Creating factors took {end-start:.3f} seconds.")
diff --git a/pgmax/factors/enumeration.py b/pgmax/factors/enumeration.py
index d461d3ab..b0610c5a 100644
--- a/pgmax/factors/enumeration.py
+++ b/pgmax/factors/enumeration.py
@@ -1,6 +1,5 @@
 """Defines an enumeration factor"""
 
-import functools
 from dataclasses import dataclass
 from typing import List, Mapping, Sequence, Tuple, Union
 
@@ -263,7 +262,7 @@ def _compile_enumeration_wiring_numba(
                 )
 
 
-@functools.partial(jax.jit, static_argnames=("num_val_configs", "temperature"))
+# @functools.partial(jax.jit, static_argnames=("num_val_configs", "temperature"))
 def pass_enum_fac_to_var_messages(
     vtof_msgs: jnp.ndarray,
     factor_configs_edge_states: jnp.ndarray,
diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index 28460c79..aec8b82f 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -104,95 +104,31 @@ def __hash__(self) -> int:
         )
         return hash(all_factor_groups)
 
-    def add_factor(
-        self,
-        variables: List[Tuple],
-        factor_configs: np.ndarray,
-        log_potentials: Optional[np.ndarray] = None,
-        name: Optional[str] = None,
-    ) -> None:
-        """Function to add a single factor to the FactorGraph.
+    def add_factor(self, factor: nodes.Factor, name: Optional[str] = None) -> None:
+        """Function to add a single Factor to the FactorGraph.
 
         Args:
-            variables: A list containing the connected variables.
-                Each variable is represented by a tuple of the form (variable hash/name, number of states)
-            factor_configs: Array of shape (num_val_configs, num_variables)
-                An array containing explicit enumeration of all valid configurations.
-                If the connected variables have n1, n2, ... states, 1 <= num_val_configs <= n1 * n2 * ...
-                factor_configs[config_idx, variable_idx] represents the state of variable_names[variable_idx]
-                in the configuration factor_configs[config_idx].
-            log_potentials: Optional array of shape (num_val_configs,).
-                If specified, log_potentials[config_idx] contains the log of the potential value for
-                the valid configuration factor_configs[config_idx].
-                If None, it is assumed the log potential is uniform 0 and such an array is automatically
-                initialized.
+            factor: The factor to be added to the factor graph.
+            name: Optional name of the FactorGroup.
         """
-        factor_group = EnumerationFactorGroup(
-            variables_for_factors=[variables],
-            factor_configs=factor_configs,
-            log_potentials=log_potentials,
-        )
-        self._register_factor_group(factor_group, name)
-
-    def add_factor_by_type(
-        self, variables: List[int], factor_type: type, *args, **kwargs
-    ) -> None:
-        """Function to add a single factor to the FactorGraph.
-
-        Args:
-            variables: A list containing the connected variables.
-                Each variable is represented by a tuple of the form (variable hash/name, number of states)
-            factor_type: Type of factor to be added
-            args: Args to be passed to the factor_type.
-            kwargs: kwargs to be passed to the factor_type, and an optional "name" argument
-                for specifying the name of a named factor group.
-
-        Example:
-            To add an ORFactor to a FactorGraph fg, run::
-
-                fg.add_factor_by_type(
-                    variables=variables_for_OR_factor,
-                    factor_type=logical.ORFactor
-                )
-        """
-        if factor_type not in FAC_TO_VAR_UPDATES:
-            raise ValueError(
-                f"Type {factor_type} is not one of the supported factor types {FAC_TO_VAR_UPDATES.keys()}"
-            )
-
-        name = kwargs.pop("name", None)
-        factor = factor_type(variables, *args, **kwargs)
         factor_group = groups.SingleFactorGroup(
-            variables_for_factors=[variables],
+            variables_for_factors=[factor.variables],
             factor=factor,
         )
-        self._register_factor_group(factor_group, name)
-
-    def add_factor_group(self, factory: Callable, *args, **kwargs) -> None:
-        """Add a factor group to the factor graph
-
-        Args:
-            factory: Factory function that takes args and kwargs as input and outputs a factor group.
-            args: Args to be passed to the factory function.
-            kwargs: kwargs to be passed to the factory function, and an optional "name" argument
-                for specifying the name of a named factor group.
-        """
-        name = kwargs.pop("name", None)
-        factor_group = factory(*args, **kwargs)
-        self._register_factor_group(factor_group, name)
+        self.add_factor_group(factor_group, name)
 
-    def _register_factor_group(
+    def add_factor_group(
         self, factor_group: groups.FactorGroup, name: Optional[str] = None
     ) -> None:
-        """Register a factor group to the factor graph, by updating the factor graph state.
+        """Add a FactorGroup to the FactorGraph, by updating the FactorGraphState.
 
         Args:
-            factor_group: The factor group to be registered to the factor graph.
-            name: Optional name of the factor group.
+            factor_group: The FactorGroup to be added to the FactorGraph.
+            name: Optional name of the FactorGroup.
 
         Raises:
-            ValueError: If the factor group with the same name or a factor involving the same variables
-                already exists in the factor graph.
+            ValueError: If the factor group with the same name or a Factor involving the same variables
+                already exists in the FactorGraph.
         """
         if name in self._named_factor_groups:
             raise ValueError(
@@ -988,7 +924,7 @@ def run_bp(
             ftov_msgs, edges_num_states, max_msg_size
         )
 
-        @jax.checkpoint
+        # @jax.checkpoint
         def update(msgs: jnp.ndarray, _) -> Tuple[jnp.ndarray, None]:
             # Compute new variable to factor messages by message passing
             vtof_msgs = infer.pass_var_to_fac_messages(
diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index 04b22ec4..b777f5a4 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -123,7 +123,7 @@ class FactorGroup:
 
     def __post_init__(self):
         if len(self.variables_for_factors) == 0:
-            raise ValueError("Do not add a factor group with no factors.")
+            raise ValueError("Cannot create a FactorGroup with no Factor.")
 
     def __getitem__(self, variables: Sequence[Tuple[int, int]]) -> Any:
         """Function to query individual factors in the factor group
@@ -277,6 +277,10 @@ def __post_init__(self):
             if not hasattr(self, key):
                 object.__setattr__(self, key, getattr(self.factor, key))
 
+        object.__setattr__(
+            self, "log_potentials", getattr(self.factor, "log_potentials")
+        )
+
     def _get_variables_to_factors(
         self,
     ) -> OrderedDict[FrozenSet, nodes.Factor]:
diff --git a/pgmax/groups/logical.py b/pgmax/groups/logical.py
index 0f4714d5..83115945 100644
--- a/pgmax/groups/logical.py
+++ b/pgmax/groups/logical.py
@@ -22,6 +22,10 @@ class LogicalFactorGroup(groups.FactorGroup):
 
     edge_states_offset: int = field(init=False)
 
+    def __post_init__(self):
+        super().__post_init__()
+        object.__setattr__(self, "factor_configs", None)
+
     def _get_variables_to_factors(
         self,
     ) -> OrderedDict[FrozenSet, logical.LogicalFactor]:
diff --git a/tests/factors/test_and.py b/tests/factors/test_and.py
index 84e37175..147d57a0 100644
--- a/tests/factors/test_and.py
+++ b/tests/factors/test_and.py
@@ -3,6 +3,7 @@
 import jax
 import numpy as np
 
+from pgmax.factors.enumeration import EnumerationFactor
 from pgmax.fg import graph
 from pgmax.groups import logical
 from pgmax.groups import variables as vgroup
@@ -94,26 +95,29 @@ def test_run_bp_with_ANDFactors():
 
             if factor_idx < num_factors // 2:
                 # Add the first half of factors to FactorGraph1
-                fg1.add_factor(
+                enum_factor = EnumerationFactor(
                     variables=variables_for_factors1[factor_idx],
                     factor_configs=valid_configs,
                     log_potentials=np.zeros(valid_configs.shape[0]),
                 )
+                fg1.add_factor(enum_factor)
             else:
                 if idx != 0:
                     # Add the second half of factors to FactorGraph2
-                    fg2.add_factor(
+                    enum_factor = EnumerationFactor(
                         variables=variables_for_factors2[factor_idx],
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
+                    fg2.add_factor(enum_factor)
                 else:
                     # Add all the EnumerationFactors to FactorGraph1 for the first iter
-                    fg1.add_factor(
+                    enum_factor = EnumerationFactor(
                         variables=variables_for_factors1[factor_idx],
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
+                    fg1.add_factor(enum_factor)
 
         # Option 2: Define the ANDFactors
         num_parents_cumsum = np.insert(np.cumsum(num_parents), 0, 0)
@@ -136,14 +140,11 @@ def test_run_bp_with_ANDFactors():
                         variables_for_factors2[factor_idx]
                     )
         if idx != 0:
-            fg1.add_factor_group(
-                factory=logical.ANDFactorGroup,
-                variables_for_factors=variables_for_ANDFactors_fg1,
-            )
-        fg2.add_factor_group(
-            factory=logical.ANDFactorGroup,
-            variables_for_factors=variables_for_ANDFactors_fg2,
-        )
+            factor_group = logical.ANDFactorGroup(variables_for_ANDFactors_fg1)
+            fg1.add_factor_group(factor_group)
+
+        factor_group = logical.ANDFactorGroup(variables_for_ANDFactors_fg2)
+        fg2.add_factor_group(factor_group)
 
         # Run inference
         bp1 = graph.BP(fg1.bp_state, temperature=temperature)
diff --git a/tests/factors/test_or.py b/tests/factors/test_or.py
index c4bdc758..76cc5107 100644
--- a/tests/factors/test_or.py
+++ b/tests/factors/test_or.py
@@ -3,6 +3,7 @@
 import jax
 import numpy as np
 
+from pgmax.factors.enumeration import EnumerationFactor
 from pgmax.fg import graph
 from pgmax.groups import logical
 from pgmax.groups import variables as vgroup
@@ -92,26 +93,29 @@ def test_run_bp_with_ORFactors():
 
             if factor_idx < num_factors // 2:
                 # Add the first half of factors to FactorGraph1
-                fg1.add_factor(
+                enum_factor = EnumerationFactor(
                     variables=variables_for_factors1[factor_idx],
                     factor_configs=valid_configs,
                     log_potentials=np.zeros(valid_configs.shape[0]),
                 )
+                fg1.add_factor(enum_factor)
             else:
                 if idx != 0:
                     # Add the second half of factors to FactorGraph2
-                    fg2.add_factor(
+                    enum_factor = EnumerationFactor(
                         variables=variables_for_factors2[factor_idx],
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
+                    fg2.add_factor(enum_factor)
                 else:
                     # Add all the EnumerationFactors to FactorGraph1 for the first iter
-                    fg1.add_factor(
+                    enum_factor = EnumerationFactor(
                         variables=variables_for_factors1[factor_idx],
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
+                    fg1.add_factor(enum_factor)
 
         # Option 2: Define the ORFactors
         num_parents_cumsum = np.insert(np.cumsum(num_parents), 0, 0)
@@ -134,14 +138,11 @@ def test_run_bp_with_ORFactors():
                         variables_for_factors2[factor_idx]
                     )
         if idx != 0:
-            fg1.add_factor_group(
-                factory=logical.ORFactorGroup,
-                variables_for_factors=variables_for_ORFactors_fg1,
-            )
-        fg2.add_factor_group(
-            factory=logical.ORFactorGroup,
-            variables_for_factors=variables_for_ORFactors_fg2,
-        )
+            factor_group = logical.ORFactorGroup(variables_for_ORFactors_fg1)
+            fg1.add_factor_group(factor_group)
+
+        factor_group = logical.ORFactorGroup(variables_for_ORFactors_fg2)
+        fg2.add_factor_group(factor_group)
 
         # Run inference
         bp1 = graph.BP(fg1.bp_state, temperature=temperature)
diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py
index 52090b1a..5fdf2815 100644
--- a/tests/fg/test_graph.py
+++ b/tests/fg/test_graph.py
@@ -13,24 +13,22 @@
 
 
 def test_factor_graph():
+    # TODO: remove factor graph name
+
     vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
     fg = graph.FactorGraph(vg)
-    fg.add_factor_by_type(
-        factor_type=enumeration_factor.EnumerationFactor,
+    factor = enumeration_factor.EnumerationFactor(
         variables=[vg[0]],
         factor_configs=np.arange(15)[:, None],
         log_potentials=np.zeros(15),
-        name="test",
     )
+    fg.add_factor(factor, name="test")
+
     with pytest.raises(
         ValueError,
         match="A factor group with the name test already exists. Please choose a different name",
     ):
-        fg.add_factor(
-            variables=[vg[0]],
-            factor_configs=np.arange(15)[:, None],
-            name="test",
-        )
+        fg.add_factor(factor, name="test")
 
     with pytest.raises(
         ValueError,
@@ -38,18 +36,7 @@ def test_factor_graph():
             f"A Factor of type {enumeration_factor.EnumerationFactor} involving variables {frozenset([(0, 15)])} already exists."
         ),
     ):
-        fg.add_factor(
-            variables=[vg[0]],
-            factor_configs=np.arange(10)[:, None],
-        )
-
-    with pytest.raises(
-        ValueError,
-        match=re.escape(
-            f"Type {groups.FactorGroup} is not one of the supported factor types {FAC_TO_VAR_UPDATES.keys()}"
-        ),
-    ):
-        fg.add_factor_by_type(variables=[vg[0]], factor_type=groups.FactorGroup)
+        fg.add_factor(factor)
 
 
 def test_factor_adding():
@@ -57,11 +44,8 @@ def test_factor_adding():
     B = vgroup.NDVariableArray(num_states=2, shape=(10,))
     fg = graph.FactorGraph(variables=[A, B])
 
-    with pytest.raises(ValueError, match="Do not add a factor group with no factors."):
-        fg.add_factor_group(
-            factory=logical.ORFactorGroup,
-            variables_for_factors=[],
-        )
+    with pytest.raises(ValueError, match="Cannot create a FactorGroup with no Factor."):
+        factor_group = logical.ORFactorGroup(variables_for_factors=[])
 
     variables0 = (A[0], B[0])
     variables1 = (A[1], B[1])
@@ -78,17 +62,16 @@ def test_factor_adding():
 def test_bp_state():
     vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
     fg0 = graph.FactorGraph(vg)
-    fg0.add_factor(
-        variables=[vg[0]],
-        factor_configs=np.arange(10)[:, None],
-        name="test",
-    )
-    fg1 = graph.FactorGraph(vg)
-    fg1.add_factor(
+    factor = enumeration_factor.EnumerationFactor(
         variables=[vg[0]],
         factor_configs=np.arange(15)[:, None],
-        name="test",
+        log_potentials=np.zeros(15),
     )
+    fg0.add_factor(factor, name="test")
+
+    fg1 = graph.FactorGraph(vg)
+    fg1.add_factor(factor, name="test")
+
     with pytest.raises(
         ValueError,
         match="log_potentials, ftov_msgs and evidence should be derived from the same fg_state",
@@ -103,98 +86,100 @@ def test_bp_state():
 def test_log_potentials():
     vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
     fg = graph.FactorGraph(vg)
-    fg.add_factor(
-        variables=[vg[0]],
-        factor_configs=np.arange(10)[:, None],
-        name="test",
-    )
-    with pytest.raises(
-        ValueError,
-        match=re.escape("Expected log potentials shape (10,) for factor group test."),
-    ):
-        fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15))
-
-    with pytest.raises(
-        ValueError,
-        match=re.escape("Invalid name (0, 15) for log potentials updates."),
-    ):
-        fg.bp_state.log_potentials[vg[0]] = np.zeros(10)
-
-    with pytest.raises(
-        ValueError, match=re.escape("Expected log potentials shape (10,). Got (15,)")
-    ):
-        graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15))
-
-    log_potentials = graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10))
-    assert jnp.all(log_potentials["test"] == jnp.zeros(10))
-
-
-def test_ftov_msgs():
-    vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
-    fg = graph.FactorGraph(vg)
-    fg.add_factor(
-        variables=[vg[0]],
-        factor_configs=np.arange(10)[:, None],
-        name="test",
-    )
-    with pytest.raises(
-        ValueError,
-        match=re.escape("Invalid names for setting messages"),
-    ):
-        fg.bp_state.ftov_msgs[0] = np.ones(10)
-
-    with pytest.raises(
-        ValueError,
-        match=re.escape(
-            "Given belief shape (10,) does not match expected shape (15,) for variable (0, 15)."
-        ),
-    ):
-        fg.bp_state.ftov_msgs[vg[0]] = np.ones(10)
-
-    with pytest.raises(
-        ValueError, match=re.escape("Expected messages shape (15,). Got (10,)")
-    ):
-        graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(10))
-
-    ftov_msgs = graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(15))
-    with pytest.raises(
-        TypeError, match=re.escape("'FToVMessages' object is not subscriptable")
-    ):
-        ftov_msgs[(10,)]
-
-
-def test_evidence():
-    vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
-    fg = graph.FactorGraph(vg)
-    fg.add_factor(
+    factor = enumeration_factor.EnumerationFactor(
         variables=[vg[0]],
-        factor_configs=np.arange(10)[:, None],
-        name="test",
-    )
-    with pytest.raises(
-        ValueError, match=re.escape("Expected evidence shape (15,). Got (10,).")
-    ):
-        graph.Evidence(fg_state=fg.fg_state, value=np.zeros(10))
-
-    evidence = graph.Evidence(fg_state=fg.fg_state, value=np.zeros(15))
-    assert jnp.all(evidence.value == jnp.zeros(15))
-
-
-def test_bp():
-    vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
-    fg = graph.FactorGraph(vg)
-    fg.add_factor(
-        variables=[vg[0]],
-        factor_configs=np.arange(10)[:, None],
-        name="test",
-    )
-    bp = graph.BP(fg.bp_state, temperature=0)
-    bp_arrays = bp.update()
-    bp_arrays = bp.update(
-        bp_arrays=bp_arrays,
-        ftov_msgs_updates={vg[0]: np.zeros(15)},
+        factor_configs=np.arange(15)[:, None],
+        log_potentials=np.zeros(15),
     )
-    bp_arrays = bp.run_bp(bp_arrays, num_iters=1)
-    bp_arrays = replace(bp_arrays, log_potentials=jnp.zeros((10)))
-    bp_state = bp.to_bp_state(bp_arrays)
-    assert bp_state.fg_state == fg.fg_state
+    fg.add_factor(factor, name="test")
+
+    # with pytest.raises(
+    #     ValueError,
+    #     match=re.escape("Expected log potentials shape (10,) for factor group test."),
+    # ):
+    #     fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15))
+
+    # with pytest.raises(
+    #     ValueError,
+    #     match=re.escape("Invalid name (0, 15) for log potentials updates."),
+    # ):
+    #     fg.bp_state.log_potentials[vg[0]] = np.zeros(10)
+
+    # with pytest.raises(
+    #     ValueError, match=re.escape("Expected log potentials shape (10,). Got (15,)")
+    # ):
+    #     graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15))
+
+    # log_potentials = graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10))
+    # assert jnp.all(log_potentials["test"] == jnp.zeros(10))
+
+
+# def test_ftov_msgs():
+#     vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
+#     fg = graph.FactorGraph(vg)
+#     fg.add_factor(
+#         variables=[vg[0]],
+#         factor_configs=np.arange(10)[:, None],
+#         name="test",
+#     )
+#     with pytest.raises(
+#         ValueError,
+#         match=re.escape("Invalid names for setting messages"),
+#     ):
+#         fg.bp_state.ftov_msgs[0] = np.ones(10)
+
+#     with pytest.raises(
+#         ValueError,
+#         match=re.escape(
+#             "Given belief shape (10,) does not match expected shape (15,) for variable (0, 15)."
+#         ),
+#     ):
+#         fg.bp_state.ftov_msgs[vg[0]] = np.ones(10)
+
+#     with pytest.raises(
+#         ValueError, match=re.escape("Expected messages shape (15,). Got (10,)")
+#     ):
+#         graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(10))
+
+#     ftov_msgs = graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(15))
+#     with pytest.raises(
+#         TypeError, match=re.escape("'FToVMessages' object is not subscriptable")
+#     ):
+#         ftov_msgs[(10,)]
+
+
+# def test_evidence():
+#     vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
+#     fg = graph.FactorGraph(vg)
+#     fg.add_factor(
+#         variables=[vg[0]],
+#         factor_configs=np.arange(10)[:, None],
+#         name="test",
+#     )
+#     with pytest.raises(
+#         ValueError, match=re.escape("Expected evidence shape (15,). Got (10,).")
+#     ):
+#         graph.Evidence(fg_state=fg.fg_state, value=np.zeros(10))
+
+#     evidence = graph.Evidence(fg_state=fg.fg_state, value=np.zeros(15))
+#     assert jnp.all(evidence.value == jnp.zeros(15))
+
+
+# def test_bp():
+#     vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
+#     fg = graph.FactorGraph(vg)
+#     fg.add_factor(
+#         variables=[vg[0]],
+#         factor_configs=np.arange(10)[:, None],
+#         name="test",
+#     )
+#     bp = graph.BP(fg.bp_state, temperature=0)
+#     bp_arrays = bp.update()
+#     bp_arrays = bp.update(
+#         bp_arrays=bp_arrays,
+#         ftov_msgs_updates={vg[0]: np.zeros(15)},
+#     )
+#     bp_arrays = bp.run_bp(bp_arrays, num_iters=1)
+#     bp_arrays = replace(bp_arrays, log_potentials=jnp.zeros((10)))
+#     bp_state = bp.to_bp_state(bp_arrays)
+#     assert bp_state.fg_state == fg.fg_state
diff --git a/tests/fg/test_wiring.py b/tests/fg/test_wiring.py
index 4966a66e..5b68a974 100644
--- a/tests/fg/test_wiring.py
+++ b/tests/fg/test_wiring.py
@@ -20,10 +20,11 @@ def test_wiring_with_PairwiseFactorGroup():
 
     # First test that compile_wiring enforces the correct factor_edges_num_states shape
     fg = graph.FactorGraph(variables=[A, B])
-    fg.add_factor_group(
-        factory=enumeration.PairwiseFactorGroup,
-        variables_for_factors=[[A[idx], B[idx]] for idx in range(10)],
+    factor_group = enumeration.PairwiseFactorGroup(
+        variables_for_factors=[[A[idx], B[idx]] for idx in range(10)]
     )
+    fg.add_factor_group(factor_group)
+
     factor_group = fg.factor_groups[enumeration_factor.EnumerationFactor][0]
     object.__setattr__(
         factor_group, "factor_configs", factor_group.factor_configs[:, :1]
@@ -36,32 +37,30 @@ def test_wiring_with_PairwiseFactorGroup():
 
     # FactorGraph with a single PairwiseFactorGroup
     fg1 = graph.FactorGraph(variables=[A, B])
-    fg1.add_factor_group(
-        factory=enumeration.PairwiseFactorGroup,
-        variables_for_factors=[[A[idx], B[idx]] for idx in range(10)],
+    factor_group = enumeration.PairwiseFactorGroup(
+        variables_for_factors=[[A[idx], B[idx]] for idx in range(10)]
     )
+    fg1.add_factor_group(factor_group)
     assert len(fg1.factor_groups[enumeration_factor.EnumerationFactor]) == 1
 
     # FactorGraph with multiple PairwiseFactorGroup
     fg2 = graph.FactorGraph(variables=[A, B])
     for idx in range(10):
-        fg2.add_factor_group(
-            factory=enumeration.PairwiseFactorGroup,
-            variables_for_factors=[[A[idx], B[idx]]],
+        factor_group = enumeration.PairwiseFactorGroup(
+            variables_for_factors=[[A[idx], B[idx]]]
         )
+        fg2.add_factor_group(factor_group)
     assert len(fg2.factor_groups[enumeration_factor.EnumerationFactor]) == 10
 
     # FactorGraph with multiple SingleFactorGroup
     fg3 = graph.FactorGraph(variables=[A, B])
     for idx in range(10):
-        fg3.add_factor_by_type(
+        factor = enumeration_factor.EnumerationFactor(
             variables=[A[idx], B[idx]],
-            factor_type=enumeration_factor.EnumerationFactor,
-            **{
-                "factor_configs": np.array([[0, 0], [0, 1], [1, 0], [1, 1]]),
-                "log_potentials": np.zeros((4,)),
-            }
+            factor_configs=np.array([[0, 0], [0, 1], [1, 0], [1, 1]]),
+            log_potentials=np.zeros((4,)),
         )
+        fg3.add_factor(factor)
     assert len(fg3.factor_groups[enumeration_factor.EnumerationFactor]) == 10
 
     assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors)
@@ -97,28 +96,28 @@ def test_wiring_with_ORFactorGroup():
 
     # FactorGraph with a single ORFactorGroup
     fg1 = graph.FactorGraph(variables=[A, B, C])
-    fg1.add_factor_group(
-        factory=logical.ORFactorGroup,
+    factor_group = logical.ORFactorGroup(
         variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)],
     )
+    fg1.add_factor_group(factor_group)
     assert len(fg1.factor_groups[logical_factor.ORFactor]) == 1
 
     # FactorGraph with multiple ORFactorGroup
     fg2 = graph.FactorGraph(variables=[A, B, C])
     for idx in range(10):
-        fg2.add_factor_group(
-            factory=logical.ORFactorGroup,
+        factor_group = logical.ORFactorGroup(
             variables_for_factors=[[A[idx], B[idx], C[idx]]],
         )
+        fg2.add_factor_group(factor_group)
     assert len(fg2.factor_groups[logical_factor.ORFactor]) == 10
 
     # FactorGraph with multiple SingleFactorGroup
     fg3 = graph.FactorGraph(variables=[A, B, C])
     for idx in range(10):
-        fg3.add_factor_by_type(
+        factor = logical_factor.ORFactor(
             variables=[A[idx], B[idx], C[idx]],
-            factor_type=logical_factor.ORFactor,
         )
+        fg3.add_factor(factor)
     assert len(fg3.factor_groups[logical_factor.ORFactor]) == 10
 
     assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors)
@@ -152,28 +151,28 @@ def test_wiring_with_ANDFactorGroup():
 
     # FactorGraph with a single ANDFactorGroup
     fg1 = graph.FactorGraph(variables=[A, B, C])
-    fg1.add_factor_group(
-        factory=logical.ANDFactorGroup,
+    factor_group = logical.ANDFactorGroup(
         variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)],
     )
+    fg1.add_factor_group(factor_group)
     assert len(fg1.factor_groups[logical_factor.ANDFactor]) == 1
 
     # FactorGraph with multiple ANDFactorGroup
     fg2 = graph.FactorGraph(variables=[A, B, C])
     for idx in range(10):
-        fg2.add_factor_group(
-            factory=logical.ANDFactorGroup,
+        factor_group = logical.ANDFactorGroup(
             variables_for_factors=[[A[idx], B[idx], C[idx]]],
         )
+        fg2.add_factor_group(factor_group)
     assert len(fg2.factor_groups[logical_factor.ANDFactor]) == 10
 
     # FactorGraph with multiple SingleFactorGroup
     fg3 = graph.FactorGraph(variables=[A, B, C])
     for idx in range(10):
-        fg3.add_factor_by_type(
+        factor = logical_factor.ANDFactor(
             variables=[A[idx], B[idx], C[idx]],
-            factor_type=logical_factor.ANDFactor,
         )
+        fg3.add_factor(factor)
     assert len(fg3.factor_groups[logical_factor.ANDFactor]) == 10
 
     assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors)
diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py
index 5c30e047..e9067abe 100644
--- a/tests/test_pgmax.py
+++ b/tests/test_pgmax.py
@@ -6,6 +6,7 @@
 from numpy.random import default_rng
 from scipy.ndimage import gaussian_filter
 
+from pgmax.factors.enumeration import EnumerationFactor
 from pgmax.fg import graph, nodes
 from pgmax.groups import enumeration
 from pgmax.groups import variables as vgroup
@@ -294,23 +295,23 @@ def create_valid_suppression_config_arr(suppression_diameter):
                     additional_vars[1, row + 1, col],
                 ]
             if row % 2 == 0:
-                fg.add_factor(
+                factor = EnumerationFactor(
                     variables=curr_names,
                     factor_configs=valid_configs_non_supp,
                     log_potentials=np.zeros(
                         valid_configs_non_supp.shape[0], dtype=float
                     ),
-                    name=(row, col),
                 )
+                fg.add_factor(factor, name=(row, col))
             else:
-                fg.add_factor(
+                factor = EnumerationFactor(
                     variables=curr_names,
                     factor_configs=valid_configs_non_supp,
                     log_potentials=np.zeros(
                         valid_configs_non_supp.shape[0], dtype=float
                     ),
-                    name=(row, col),
                 )
+                fg.add_factor(factor, name=(row, col))
 
     # Create an EnumerationFactorGroup for vertical suppression factors
     vert_suppression_names: List[List[Tuple[Any, ...]]] = []
@@ -350,17 +351,18 @@ def create_valid_suppression_config_arr(suppression_diameter):
                 )
 
     # Add the suppression factors to the graph via kwargs
-    fg.add_factor_group(
-        factory=enumeration.EnumerationFactorGroup,
+    factor_group = enumeration.EnumerationFactorGroup(
         variables_for_factors=vert_suppression_names,
         factor_configs=valid_configs_supp,
     )
-    fg.add_factor_group(
-        factory=enumeration.EnumerationFactorGroup,
+    fg.add_factor_group(factor_group)
+
+    factor_group = enumeration.EnumerationFactorGroup(
         variables_for_factors=horz_suppression_names,
         factor_configs=valid_configs_supp,
         log_potentials=np.zeros(valid_configs_supp.shape[0], dtype=float),
     )
+    fg.add_factor_group(factor_group)
 
     # Run BP
     # Set the evidence
@@ -413,12 +415,11 @@ def binary_connected_variables(
     W_pot = np.zeros((17, 3, 3, 3), dtype=float)
     for k_row in range(3):
         for k_col in range(3):
-            fg.add_factor_group(
-                factory=enumeration.PairwiseFactorGroup,
+            factor_group = enumeration.PairwiseFactorGroup(
                 variables_for_factors=binary_connected_variables(28, 28, k_row, k_col),
                 log_potential_matrix=W_pot[:, :, k_row, k_col],
-                name=(k_row, k_col),
             )
+            fg.add_factor_group(factor_group, name=(k_row, k_col))
 
     # Assign evidence to pixel vars
     bp_state = fg.bp_state

From 88f8e230f601bef213a6106f3181ce7e6f8984df Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Sat, 23 Apr 2022 00:51:36 +0000
Subject: [PATCH 13/35] Test

---
 examples/rbm.py        |   1 -
 pgmax/fg/graph.py      |   1 -
 tests/fg/test_graph.py | 204 ++++++++++++++++++++---------------------
 3 files changed, 102 insertions(+), 104 deletions(-)

diff --git a/examples/rbm.py b/examples/rbm.py
index 876f08ee..8ea595ea 100644
--- a/examples/rbm.py
+++ b/examples/rbm.py
@@ -26,7 +26,6 @@
 import matplotlib.pyplot as plt
 import numpy as np
 
-from pgmax.factors import enumeration as enumeration_factor
 from pgmax.fg import graph
 from pgmax.groups import enumeration
 from pgmax.groups import variables as vgroup
diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index aec8b82f..6f804be2 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -35,7 +35,6 @@
 from pgmax.bp import infer
 from pgmax.factors import FAC_TO_VAR_UPDATES
 from pgmax.fg import groups, nodes
-from pgmax.groups.enumeration import EnumerationFactorGroup
 from pgmax.utils import cached_property
 
 
diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py
index 5fdf2815..79ee0895 100644
--- a/tests/fg/test_graph.py
+++ b/tests/fg/test_graph.py
@@ -5,10 +5,9 @@
 import numpy as np
 import pytest
 
-from pgmax.factors import FAC_TO_VAR_UPDATES
 from pgmax.factors import enumeration as enumeration_factor
 from pgmax.fg import graph, groups
-from pgmax.groups import logical
+from pgmax.groups import enumeration, logical
 from pgmax.groups import variables as vgroup
 
 
@@ -39,13 +38,12 @@ def test_factor_graph():
         fg.add_factor(factor)
 
 
-def test_factor_adding():
+def test_single_factor():
+    with pytest.raises(ValueError, match="Cannot create a FactorGroup with no Factor."):
+        logical.ORFactorGroup(variables_for_factors=[])
+
     A = vgroup.NDVariableArray(num_states=2, shape=(10,))
     B = vgroup.NDVariableArray(num_states=2, shape=(10,))
-    fg = graph.FactorGraph(variables=[A, B])
-
-    with pytest.raises(ValueError, match="Cannot create a FactorGroup with no Factor."):
-        factor_group = logical.ORFactorGroup(variables_for_factors=[])
 
     variables0 = (A[0], B[0])
     variables1 = (A[1], B[1])
@@ -86,100 +84,102 @@ def test_bp_state():
 def test_log_potentials():
     vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
     fg = graph.FactorGraph(vg)
-    factor = enumeration_factor.EnumerationFactor(
-        variables=[vg[0]],
-        factor_configs=np.arange(15)[:, None],
-        log_potentials=np.zeros(15),
+    factor_group = enumeration.EnumerationFactorGroup(
+        variables_for_factors=[[vg[0]]],
+        factor_configs=np.arange(10)[:, None],
     )
-    fg.add_factor(factor, name="test")
+    fg.add_factor_group(factor_group, name="test")
 
-    # with pytest.raises(
-    #     ValueError,
-    #     match=re.escape("Expected log potentials shape (10,) for factor group test."),
-    # ):
-    #     fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15))
-
-    # with pytest.raises(
-    #     ValueError,
-    #     match=re.escape("Invalid name (0, 15) for log potentials updates."),
-    # ):
-    #     fg.bp_state.log_potentials[vg[0]] = np.zeros(10)
-
-    # with pytest.raises(
-    #     ValueError, match=re.escape("Expected log potentials shape (10,). Got (15,)")
-    # ):
-    #     graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15))
-
-    # log_potentials = graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10))
-    # assert jnp.all(log_potentials["test"] == jnp.zeros(10))
-
-
-# def test_ftov_msgs():
-#     vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
-#     fg = graph.FactorGraph(vg)
-#     fg.add_factor(
-#         variables=[vg[0]],
-#         factor_configs=np.arange(10)[:, None],
-#         name="test",
-#     )
-#     with pytest.raises(
-#         ValueError,
-#         match=re.escape("Invalid names for setting messages"),
-#     ):
-#         fg.bp_state.ftov_msgs[0] = np.ones(10)
-
-#     with pytest.raises(
-#         ValueError,
-#         match=re.escape(
-#             "Given belief shape (10,) does not match expected shape (15,) for variable (0, 15)."
-#         ),
-#     ):
-#         fg.bp_state.ftov_msgs[vg[0]] = np.ones(10)
-
-#     with pytest.raises(
-#         ValueError, match=re.escape("Expected messages shape (15,). Got (10,)")
-#     ):
-#         graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(10))
-
-#     ftov_msgs = graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(15))
-#     with pytest.raises(
-#         TypeError, match=re.escape("'FToVMessages' object is not subscriptable")
-#     ):
-#         ftov_msgs[(10,)]
-
-
-# def test_evidence():
-#     vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
-#     fg = graph.FactorGraph(vg)
-#     fg.add_factor(
-#         variables=[vg[0]],
-#         factor_configs=np.arange(10)[:, None],
-#         name="test",
-#     )
-#     with pytest.raises(
-#         ValueError, match=re.escape("Expected evidence shape (15,). Got (10,).")
-#     ):
-#         graph.Evidence(fg_state=fg.fg_state, value=np.zeros(10))
-
-#     evidence = graph.Evidence(fg_state=fg.fg_state, value=np.zeros(15))
-#     assert jnp.all(evidence.value == jnp.zeros(15))
-
-
-# def test_bp():
-#     vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
-#     fg = graph.FactorGraph(vg)
-#     fg.add_factor(
-#         variables=[vg[0]],
-#         factor_configs=np.arange(10)[:, None],
-#         name="test",
-#     )
-#     bp = graph.BP(fg.bp_state, temperature=0)
-#     bp_arrays = bp.update()
-#     bp_arrays = bp.update(
-#         bp_arrays=bp_arrays,
-#         ftov_msgs_updates={vg[0]: np.zeros(15)},
-#     )
-#     bp_arrays = bp.run_bp(bp_arrays, num_iters=1)
-#     bp_arrays = replace(bp_arrays, log_potentials=jnp.zeros((10)))
-#     bp_state = bp.to_bp_state(bp_arrays)
-#     assert bp_state.fg_state == fg.fg_state
+    with pytest.raises(
+        ValueError,
+        match=re.escape("Expected log potentials shape (10,) for factor group test."),
+    ):
+        fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15))
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape("Invalid name (0, 15) for log potentials updates."),
+    ):
+        fg.bp_state.log_potentials[vg[0]] = np.zeros(10)
+
+    with pytest.raises(
+        ValueError, match=re.escape("Expected log potentials shape (10,). Got (15,)")
+    ):
+        graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15))
+
+    log_potentials = graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10))
+    assert jnp.all(log_potentials["test"] == jnp.zeros(10))
+
+
+def test_ftov_msgs():
+    vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
+    fg = graph.FactorGraph(vg)
+    factor_group = enumeration.EnumerationFactorGroup(
+        variables_for_factors=[[vg[0]]],
+        factor_configs=np.arange(10)[:, None],
+    )
+    fg.add_factor_group(factor_group, name="test")
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape("Invalid names for setting messages"),
+    ):
+        fg.bp_state.ftov_msgs[0] = np.ones(10)
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Given belief shape (10,) does not match expected shape (15,) for variable (0, 15)."
+        ),
+    ):
+        fg.bp_state.ftov_msgs[vg[0]] = np.ones(10)
+
+    with pytest.raises(
+        ValueError, match=re.escape("Expected messages shape (15,). Got (10,)")
+    ):
+        graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(10))
+
+    ftov_msgs = graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(15))
+    with pytest.raises(
+        TypeError, match=re.escape("'FToVMessages' object is not subscriptable")
+    ):
+        ftov_msgs[(10,)]
+
+
+def test_evidence():
+    vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
+    fg = graph.FactorGraph(vg)
+    factor_group = enumeration.EnumerationFactorGroup(
+        variables_for_factors=[[vg[0]]],
+        factor_configs=np.arange(10)[:, None],
+    )
+    fg.add_factor_group(factor_group, name="test")
+
+    with pytest.raises(
+        ValueError, match=re.escape("Expected evidence shape (15,). Got (10,).")
+    ):
+        graph.Evidence(fg_state=fg.fg_state, value=np.zeros(10))
+
+    evidence = graph.Evidence(fg_state=fg.fg_state, value=np.zeros(15))
+    assert jnp.all(evidence.value == jnp.zeros(15))
+
+
+def test_bp():
+    vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
+    fg = graph.FactorGraph(vg)
+    factor_group = enumeration.EnumerationFactorGroup(
+        variables_for_factors=[[vg[0]]],
+        factor_configs=np.arange(10)[:, None],
+    )
+    fg.add_factor_group(factor_group, name="test")
+
+    bp = graph.BP(fg.bp_state, temperature=0)
+    bp_arrays = bp.update()
+    bp_arrays = bp.update(
+        bp_arrays=bp_arrays,
+        ftov_msgs_updates={vg[0]: np.zeros(15)},
+    )
+    bp_arrays = bp.run_bp(bp_arrays, num_iters=1)
+    bp_arrays = replace(bp_arrays, log_potentials=jnp.zeros((10)))
+    bp_state = bp.to_bp_state(bp_arrays)
+    assert bp_state.fg_state == fg.fg_state

From 033d17661c32f8708910843e547a9f32ca11af49 Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Sat, 23 Apr 2022 01:14:38 +0000
Subject: [PATCH 14/35] Docstring

---
 examples/gmrf.py             |  1 -
 pgmax/factors/enumeration.py |  3 ++-
 pgmax/fg/graph.py            | 35 +++++++++++++++++------------------
 pgmax/fg/groups.py           |  6 +++---
 pgmax/groups/variables.py    |  9 +++++++++
 5 files changed, 31 insertions(+), 23 deletions(-)

diff --git a/examples/gmrf.py b/examples/gmrf.py
index 363aca85..51977d38 100644
--- a/examples/gmrf.py
+++ b/examples/gmrf.py
@@ -88,7 +88,6 @@
 )
 fg.add_factor_group(factor_group, name="diagonal0")
 
-
 factor_group = enumeration.PairwiseFactorGroup(
     variables_for_factors=[
         [variables[ii, jj], variables[ii - 1, jj + 1]]
diff --git a/pgmax/factors/enumeration.py b/pgmax/factors/enumeration.py
index b0610c5a..d461d3ab 100644
--- a/pgmax/factors/enumeration.py
+++ b/pgmax/factors/enumeration.py
@@ -1,5 +1,6 @@
 """Defines an enumeration factor"""
 
+import functools
 from dataclasses import dataclass
 from typing import List, Mapping, Sequence, Tuple, Union
 
@@ -262,7 +263,7 @@ def _compile_enumeration_wiring_numba(
                 )
 
 
-# @functools.partial(jax.jit, static_argnames=("num_val_configs", "temperature"))
+@functools.partial(jax.jit, static_argnames=("num_val_configs", "temperature"))
 def pass_enum_fac_to_var_messages(
     vtof_msgs: jnp.ndarray,
     factor_configs_edge_states: jnp.ndarray,
diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index 6f804be2..5295d3ae 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -108,7 +108,7 @@ def add_factor(self, factor: nodes.Factor, name: Optional[str] = None) -> None:
 
         Args:
             factor: The factor to be added to the factor graph.
-            name: Optional name of the FactorGroup.
+            name: Optional name of the SingleFactorGroup created.
         """
         factor_group = groups.SingleFactorGroup(
             variables_for_factors=[factor.variables],
@@ -350,14 +350,12 @@ class FactorGraphState:
     wiring: OrderedDict[type, nodes.Wiring]
 
     def __post_init__(self):
-        for this_field in self.__dataclass_fields__:
-            if isinstance(getattr(self, this_field), np.ndarray):
-                getattr(self, this_field).flags.writeable = False
+        for field in self.__dataclass_fields__:
+            if isinstance(getattr(self, field), np.ndarray):
+                getattr(self, field).flags.writeable = False
 
-            if isinstance(getattr(self, this_field), Mapping):
-                object.__setattr__(
-                    self, this_field, MappingProxyType(getattr(self, this_field))
-                )
+            if isinstance(getattr(self, field), Mapping):
+                object.__setattr__(self, field, MappingProxyType(getattr(self, field)))
 
 
 @dataclass(frozen=True, eq=False)
@@ -756,9 +754,9 @@ class BPArrays:
     evidence: Union[np.ndarray, jnp.ndarray]
 
     def __post_init__(self):
-        for this_field in self.__dataclass_fields__:
-            if isinstance(getattr(self, this_field), np.ndarray):
-                getattr(self, this_field).flags.writeable = False
+        for field in self.__dataclass_fields__:
+            if isinstance(getattr(self, field), np.ndarray):
+                getattr(self, field).flags.writeable = False
 
     def tree_flatten(self):
         return jax.tree_util.tree_flatten(asdict(self))
@@ -981,16 +979,16 @@ def to_bp_state(bp_arrays: BPArrays) -> BPState:
         )
 
     def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]:
-        """Function that recovers meaningful structured data from internal flat data array
+        """Function that returns unflattened beliefs from the flat beliefs
 
         Args:
-            variable_groups: TODO
-
-        Returns:
-            Meaningful structured data, with structure matching that of self.variable_group_container.
+            flat_beliefs: Flattened array of beliefs
+            variable_groups: All the variable groups in the FactorGraph.
 
         Raises:
-            ValueError: if flat_data is not of the right shape
+            ValueError: If
+                (1) flat_beliefs is not one dimensional
+                (2) flat_beliefs is not of the right shape
         """
 
         if flat_beliefs.ndim != 1:
@@ -998,6 +996,7 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]:
                 f"Can only unflatten 1D array. Got a {flat_beliefs.ndim}D array."
             )
 
+        # TODO: make sure this is not too slow
         num_variables = 0
         num_variable_states = 0
         for variable_group in variable_groups:
@@ -1020,7 +1019,7 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]:
             use_num_states = True
         else:
             raise ValueError(
-                f"flat_data should be either of shape (num_variables(={num_variables}),), "
+                f"flat_beliefs should be either of shape (num_variables(={num_variables}),), "
                 f"or (num_variable_states(={num_variable_states}),). "
                 f"Got {flat_beliefs.shape}"
             )
diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index b777f5a4..62c23af1 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -27,10 +27,10 @@
 @dataclass(frozen=True, eq=False)
 class VariableGroup:
     """Class to represent a group of variables.
+    Each variable is represented via a tuple of the form (variable hash/name, number of states)
 
-    All variables in the group are assumed to have the same size. Additionally, the
-    variables are indexed by a variable name, and can be retrieved by direct indexing (even indexing
-    a sequence of variable names) of the VariableGroup.
+    Attributes:
+        random_hash: Hash of the VariableGroup
     """
 
     def __post_init__(self):
diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index b7e0d66b..3813a364 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -168,6 +168,15 @@ def variables(self) -> List[Tuple]:
         return list(zip(self.variable_names, self.num_states))
 
     def __getitem__(self, val):
+        """Given a variable name retrieve the associated variable, returned via a tuple of the form
+        (variable name, number of states)
+
+        Args:
+            val: a variable index or slice
+
+        Returns:
+            The queried variable
+        """
         if val not in self.variable_names:
             raise ValueError(f"Variable {val} is not in VariableDict")
         return (val, self.num_states[0])

From 89767c81a39c0d6b4c241ab6ea4ac07f60ced1bd Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Mon, 25 Apr 2022 21:18:51 +0000
Subject: [PATCH 15/35] Coverage

---
 pgmax/fg/graph.py         | 46 ++++++++++++++++++---------------------
 pgmax/fg/groups.py        |  2 ++
 pgmax/groups/variables.py | 34 +++++++++++++++++++----------
 tests/fg/test_graph.py    | 28 ++++++++++++++++++++++--
 tests/fg/test_groups.py   | 15 +++++++++++++
 5 files changed, 86 insertions(+), 39 deletions(-)

diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index 5295d3ae..3dd0ebb1 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -35,6 +35,7 @@
 from pgmax.bp import infer
 from pgmax.factors import FAC_TO_VAR_UPDATES
 from pgmax.fg import groups, nodes
+from pgmax.groups import variables
 from pgmax.utils import cached_property
 
 
@@ -61,6 +62,25 @@ def __post_init__(self):
         if isinstance(self.variables, groups.VariableGroup):
             self.variables = [self.variables]
 
+        # Check ids are unique
+        vg_names = []
+        vg_array_names = []
+        for variable_group in self.variables:
+            vg_name = variable_group.__hash__()
+            if vg_name in vg_names:
+                raise ValueError("Two objects have the same name")
+            vg_names.append(vg_name)
+            if isinstance(variable_group, variables.NDVariableArray):
+                start_name, end_name = (
+                    vg_name,
+                    variable_group.variable_names.flatten()[-1],
+                )
+                for var_array_name in vg_array_names:
+                    start_name2, end_name2 = var_array_name
+                    if max(start_name, start_name2) <= min(end_name, end_name2):
+                        raise ValueError("Two NDVariableArrays have overlapping names")
+                vg_array_names.append((start_name, end_name))
+
         # Useful objects to build the FactorGraph
         self._factor_types_to_groups: OrderedDict[
             Type, List[groups.FactorGroup]
@@ -91,7 +111,7 @@ def __post_init__(self):
 
         self._num_var_states = vars_num_states_cumsum
         self._named_factor_groups: Dict[Hashable, groups.FactorGroup] = {}
-        print("2", time.time() - start)
+        print("Init", time.time() - start)
 
     def __hash__(self) -> int:
         all_factor_groups = tuple(
@@ -469,9 +489,6 @@ def __getitem__(self, name: Any) -> np.ndarray:
             The queried log potentials.
         """
         value = cast(np.ndarray, self.value)
-        if not isinstance(name, Hashable):
-            name = frozenset(name)
-
         if name in self.fg_state.named_factor_groups:
             factor_group = self.fg_state.named_factor_groups[name]
             start = self.fg_state.factor_group_to_potentials_starts[factor_group]
@@ -496,9 +513,6 @@ def __setitem__(
             data: Array containing the log potentials for the named factor group
                 or the factor.
         """
-        if not isinstance(name, Hashable):
-            name = frozenset(name)
-
         object.__setattr__(
             self,
             "value",
@@ -996,7 +1010,6 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]:
                 f"Can only unflatten 1D array. Got a {flat_beliefs.ndim}D array."
             )
 
-        # TODO: make sure this is not too slow
         num_variables = 0
         num_variable_states = 0
         for variable_group in variable_groups:
@@ -1004,25 +1017,10 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]:
             num_variables += len(variables)
             num_variable_states += sum([variable[1] for variable in variables])
 
-            # if isinstance(variable_group, vgroup.NDVariableArray):
-            #     num_variables += variable_group.num_states.size
-            #     num_variable_states += variable_group.num_states.sum()
-            # elif isinstance(variable_group, vgroup.VariableDict):
-            #     num_variables += len(variable_group.variables)
-            #     num_variable_states += (
-            #         len(variable_group.variables) * variable_group.variables[0].num_states
-            #     )
-
         if flat_beliefs.shape[0] == num_variables:
             use_num_states = False
         elif flat_beliefs.shape[0] == num_variable_states:
             use_num_states = True
-        else:
-            raise ValueError(
-                f"flat_beliefs should be either of shape (num_variables(={num_variables}),), "
-                f"or (num_variable_states(={num_variable_states}),). "
-                f"Got {flat_beliefs.shape}"
-            )
 
         beliefs = {}
         start = 0
@@ -1030,10 +1028,8 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]:
             variables = variable_group.variables
             if not use_num_states:
                 length = len(variables)
-                # length = variable_group.num_states.sum()
             else:
                 length = sum([variable[1] for variable in variables])
-                # length = variable_group.num_states.size
 
             beliefs[variable_group] = variable_group.unflatten(
                 flat_beliefs[start : start + length]
diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index 62c23af1..192eee84 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -34,6 +34,8 @@ class VariableGroup:
     """
 
     def __post_init__(self):
+        # Overwite default hash to have larger differences
+        random.seed(id(self))
         random_hash = random.randint(0, 2**63)
         object.__setattr__(self, "random_hash", random_hash)
 
diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index 3813a364..feff8467 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -16,31 +16,34 @@ class NDVariableArray(groups.VariableGroup):
     """Subclass of VariableGroup for n-dimensional grids of variables.
 
     Args:
-        num_states: The size of the variables in this variable group
-        shape: a tuple specifying the size of each dimension of the grid (similar to
+        shape: Tuple specifying the size of each dimension of the grid (similar to
             the notion of a NumPy ndarray shape)
+        num_states: An integer or an array specifying the number of states of the
+            variables in this VariableGroup
     """
 
     shape: Tuple[int, ...]
-    num_states: np.ndarray
+    num_states: Union[int, np.ndarray]
 
     def __post_init__(self):
         super().__post_init__()
 
-        if np.isscalar(self.num_states) and np.issubdtype(type(np.int64(10)), int):
+        if np.isscalar(self.num_states):
             num_states = np.full(self.shape, fill_value=self.num_states)
             object.__setattr__(self, "num_states", num_states)
         elif isinstance(self.num_states, np.ndarray) and np.issubdtype(
             self.num_states.dtype, int
         ):
             if self.num_states.shape != self.shape:
-                raise ValueError("Should be same shape")
+                raise ValueError(
+                    f"Expected num_states shape {self.shape}. Got {self.num_states.shape}."
+                )
         else:
             raise ValueError("num_states entries should be of type np.int")
 
     def __getitem__(
-        self, val: Union[int, tuple, slice]
-    ) -> Union[Tuple[int, int], List[Tuple[int]]]:
+        self, val: Union[int, slice, Tuple]
+    ) -> Union[Tuple[int, int], List[Tuple]]:
         """Given an index or a slice, retrieve the associated variable(s).
         Each variable is returned via a tuple of the form (variable hash, number of states)
 
@@ -52,10 +55,13 @@ def __getitem__(
         Returns:
             A single variable or a list of variables
         """
-        result = (self.variable_names[val], self.num_states[val])
-        if isinstance(val, slice):
-            return list(zip(result))
-        return result
+        assert isinstance(self.num_states, np.ndarray)
+        if np.isscalar(self.variable_names[val]):
+            return (self.variable_names[val], self.num_states[val])
+        else:
+            vars_names = self.variable_names[val].flatten()
+            vars_num_states = self.num_states[val].flatten()
+            return list(zip(vars_names, vars_num_states))
 
     @cached_property
     def variables(self) -> List[Tuple]:
@@ -65,6 +71,7 @@ def variables(self) -> List[Tuple]:
         Returns:
             List of variables in the VariableGroup
         """
+        assert isinstance(self.num_states, np.ndarray)
         vars_names = self.variable_names.flatten()
         vars_num_states = self.num_states.flatten()
         return list(zip(vars_names, vars_num_states))
@@ -76,7 +83,6 @@ def variable_names(self) -> np.ndarray:
         Returns:
             Array of variables names.
         """
-        # Overwite default hash as it does not give enough spacing across consecutive objects
         indices = np.reshape(np.arange(np.product(self.shape)), self.shape)
         return self.__hash__() + indices
 
@@ -93,6 +99,8 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
         Raises:
             ValueError: If the data is not of the correct shape.
         """
+        assert isinstance(self.num_states, np.ndarray)
+
         # TODO: what should we do for different number of states -> look at mask_array
         if data.shape != self.shape and data.shape != self.shape + (
             self.num_states.max(),
@@ -120,6 +128,8 @@ def unflatten(
                 (1) flat_data is not a 1D array
                 (2) flat_data is not of the right shape
         """
+        assert isinstance(self.num_states, np.ndarray)
+
         if flat_data.ndim != 1:
             raise ValueError(
                 f"Can only unflatten 1D array. Got a {flat_data.ndim}D array."
diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py
index 79ee0895..09c766a0 100644
--- a/tests/fg/test_graph.py
+++ b/tests/fg/test_graph.py
@@ -1,6 +1,7 @@
 import re
 from dataclasses import replace
 
+import jax
 import jax.numpy as jnp
 import numpy as np
 import pytest
@@ -12,6 +13,16 @@
 
 
 def test_factor_graph():
+    vg = vgroup.NDVariableArray(num_states=2, shape=(10, 10))
+    with pytest.raises(ValueError, match="Two objects have the same name"):
+        fg = graph.FactorGraph(variables=[vg, vg])
+
+    vg = vgroup.NDVariableArray(num_states=2, shape=(10, 10))
+    vg2 = vgroup.NDVariableArray(num_states=2, shape=(10, 10))
+    object.__setattr__(vg2, "random_hash", vg.__hash__() + 10)
+    with pytest.raises(ValueError, match="Two NDVariableArrays have overlapping names"):
+        fg = graph.FactorGraph(variables=[vg, vg2])
+
     # TODO: remove factor graph name
 
     vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
@@ -147,10 +158,10 @@ def test_ftov_msgs():
 
 
 def test_evidence():
-    vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
+    vg = vgroup.VariableDict(variable_names=("a",), num_states=15)
     fg = graph.FactorGraph(vg)
     factor_group = enumeration.EnumerationFactorGroup(
-        variables_for_factors=[[vg[0]]],
+        variables_for_factors=[[vg["a"]]],
         factor_configs=np.arange(10)[:, None],
     )
     fg.add_factor_group(factor_group, name="test")
@@ -163,6 +174,19 @@ def test_evidence():
     evidence = graph.Evidence(fg_state=fg.fg_state, value=np.zeros(15))
     assert jnp.all(evidence.value == jnp.zeros(15))
 
+    vg2 = vgroup.VariableDict(variable_names=("b",), num_states=15)
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Got evidence for a variable or a variable group not in the FactorGraph!"
+        ),
+    ):
+        graph.update_evidence(
+            jax.device_put(evidence.value),
+            {vg2["b"]: jax.device_put(np.zeros(15))},
+            fg.fg_state,
+        )
+
 
 def test_bp():
     vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py
index d6e18e73..b12b49d3 100644
--- a/tests/fg/test_groups.py
+++ b/tests/fg/test_groups.py
@@ -50,6 +50,21 @@ def test_variable_dict():
 
 
 def test_nd_variable_array():
+    num_states = np.full((2, 3), fill_value=2)
+    with pytest.raises(
+        ValueError, match=re.escape("Expected num_states shape (2, 2). Got (2, 3).")
+    ):
+        vgroup.NDVariableArray(shape=(2, 2), num_states=num_states)
+
+    num_states = np.full((2, 3), fill_value=2, dtype=np.float32)
+    with pytest.raises(
+        ValueError, match=re.escape("num_states entries should be of type np.int")
+    ):
+        vgroup.NDVariableArray(shape=(2, 2), num_states=num_states)
+
+    variable_group = vgroup.NDVariableArray(shape=(5, 5), num_states=2)
+    assert len(variable_group[:3, :3]) == 9
+
     variable_group = vgroup.NDVariableArray(shape=(2, 2), num_states=3)
     with pytest.raises(
         ValueError,

From 1ccfcf5de56644688a9cf90900020a575017ef3d Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Mon, 25 Apr 2022 21:20:15 +0000
Subject: [PATCH 16/35] Coverage

---
 pgmax/fg/graph.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index 3dd0ebb1..dd38740a 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -935,7 +935,7 @@ def run_bp(
             ftov_msgs, edges_num_states, max_msg_size
         )
 
-        # @jax.checkpoint
+        @jax.checkpoint
         def update(msgs: jnp.ndarray, _) -> Tuple[jnp.ndarray, None]:
             # Compute new variable to factor messages by message passing
             vtof_msgs = infer.pass_var_to_fac_messages(
@@ -961,8 +961,7 @@ def update(msgs: jnp.ndarray, _) -> Tuple[jnp.ndarray, None]:
             # update the factor to variable messages
             delta_msgs = ftov_msgs - msgs
             msgs = msgs + (1 - damping) * delta_msgs
-            # Normalize and clip these damped, updated messages before
-            # returning them.
+            # Normalize and clip these damped, updated messages before returning them.
             msgs = infer.normalize_and_clip_msgs(msgs, edges_num_states, max_msg_size)
             return msgs, None
 

From ecaab6c8196a4b113f70e258d00d8bc5f553f5b1 Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Mon, 25 Apr 2022 21:35:06 +0000
Subject: [PATCH 17/35] Coverage 100%

---
 pgmax/fg/graph.py      | 28 +---------------------------
 tests/fg/test_graph.py |  6 ++++++
 2 files changed, 7 insertions(+), 27 deletions(-)

diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index dd38740a..6cfbac10 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -997,38 +997,12 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]:
         Args:
             flat_beliefs: Flattened array of beliefs
             variable_groups: All the variable groups in the FactorGraph.
-
-        Raises:
-            ValueError: If
-                (1) flat_beliefs is not one dimensional
-                (2) flat_beliefs is not of the right shape
         """
-
-        if flat_beliefs.ndim != 1:
-            raise ValueError(
-                f"Can only unflatten 1D array. Got a {flat_beliefs.ndim}D array."
-            )
-
-        num_variables = 0
-        num_variable_states = 0
-        for variable_group in variable_groups:
-            variables = variable_group.variables
-            num_variables += len(variables)
-            num_variable_states += sum([variable[1] for variable in variables])
-
-        if flat_beliefs.shape[0] == num_variables:
-            use_num_states = False
-        elif flat_beliefs.shape[0] == num_variable_states:
-            use_num_states = True
-
         beliefs = {}
         start = 0
         for variable_group in variable_groups:
             variables = variable_group.variables
-            if not use_num_states:
-                length = len(variables)
-            else:
-                length = sum([variable[1] for variable in variables])
+            length = sum([variable[1] for variable in variables])
 
             beliefs[variable_group] = variable_group.unflatten(
                 flat_beliefs[start : start + length]
diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py
index 09c766a0..c3b3a853 100644
--- a/tests/fg/test_graph.py
+++ b/tests/fg/test_graph.py
@@ -107,6 +107,12 @@ def test_log_potentials():
     ):
         fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15))
 
+    with pytest.raises(
+        ValueError,
+        match=re.escape("Invalid name new_test for log potentials updates."),
+    ):
+        fg.bp_state.log_potentials["new_test"] = jnp.zeros((1, 15))
+
     with pytest.raises(
         ValueError,
         match=re.escape("Invalid name (0, 15) for log potentials updates."),

From cbd136bf8241430db734668bfa6bb87d96f86a25 Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Mon, 25 Apr 2022 23:00:42 +0000
Subject: [PATCH 18/35] Remove factor group names

---
 examples/gmrf.py                     | 36 +++++++-----
 examples/ising_model.py              |  4 +-
 examples/pmp_binary_deconvolution.py |  2 +-
 pgmax/fg/graph.py                    | 83 +++++++++++-----------------
 pgmax/fg/groups.py                   | 10 ++++
 tests/fg/test_graph.py               | 47 +++++++---------
 tests/fg/test_wiring.py              | 20 +++----
 tests/test_pgmax.py                  | 34 ++++++------
 8 files changed, 114 insertions(+), 122 deletions(-)

diff --git a/examples/gmrf.py b/examples/gmrf.py
index 51977d38..6601e5c7 100644
--- a/examples/gmrf.py
+++ b/examples/gmrf.py
@@ -38,8 +38,8 @@
 
 # %%
 # Load saved log potentials
-log_potentials = dict(**np.load("example_data/gmrf_log_potentials.npz"))
-n_clones = log_potentials.pop("n_clones")
+grmf_log_potentials = dict(**np.load("example_data/gmrf_log_potentials.npz"))
+n_clones = grmf_log_potentials.pop("n_clones")
 p_contour = jax.device_put(np.repeat(data["p_contour"], n_clones))
 prototype_targets = jax.device_put(
     np.array(
@@ -59,48 +59,55 @@
 
 # %%
 # Add top-down factors
-factor_group = enumeration.PairwiseFactorGroup(
+top_down = enumeration.PairwiseFactorGroup(
     variables_for_factors=[
         [variables[ii, jj], variables[ii + 1, jj]]
         for ii in range(M - 1)
         for jj in range(N)
     ],
 )
-fg.add_factor_group(factor_group, name="top_down")
+fg.add_factor_group(top_down)
 
 # Add left-right factors
-factor_group = enumeration.PairwiseFactorGroup(
+left_right = enumeration.PairwiseFactorGroup(
     variables_for_factors=[
         [variables[ii, jj], variables[ii, jj + 1]]
         for ii in range(M)
         for jj in range(N - 1)
     ],
 )
-fg.add_factor_group(factor_group, name="left_right")
+fg.add_factor_group(left_right)
 
 # Add diagonal factors
-factor_group = enumeration.PairwiseFactorGroup(
+diagonal0 = enumeration.PairwiseFactorGroup(
     variables_for_factors=[
         [variables[ii, jj], variables[ii + 1, jj + 1]]
         for ii in range(M - 1)
         for jj in range(N - 1)
     ],
 )
-fg.add_factor_group(factor_group, name="diagonal0")
+fg.add_factor_group(diagonal0)
 
-factor_group = enumeration.PairwiseFactorGroup(
+diagonal1 = enumeration.PairwiseFactorGroup(
     variables_for_factors=[
         [variables[ii, jj], variables[ii - 1, jj + 1]]
         for ii in range(1, M)
         for jj in range(N - 1)
     ],
 )
-fg.add_factor_group(factor_group, name="diagonal1")
+fg.add_factor_group(diagonal1)
 
 # %%
 bp = graph.BP(fg.bp_state, temperature=1.0)
 
 # %%
+log_potentials = {
+    top_down: grmf_log_potentials["top_down"],
+    left_right: grmf_log_potentials["left_right"],
+    diagonal0: grmf_log_potentials["diagonal0"],
+    diagonal1: grmf_log_potentials["diagonal1"],
+}
+
 n_plots = 5
 indices = np.random.permutation(noisy_images.shape[0])[:n_plots]
 fig, ax = plt.subplots(n_plots, 3, figsize=(30, 10 * n_plots))
@@ -199,10 +206,10 @@ def update(step, batch_noisy_images, batch_target_images, opt_state):
 # %%
 opt_state = init_fun(
     {
-        "top_down": np.random.randn(num_states, num_states),
-        "left_right": np.random.randn(num_states, num_states),
-        "diagonal0": np.random.randn(num_states, num_states),
-        "diagonal1": np.random.randn(num_states, num_states),
+        top_down: np.random.randn(num_states, num_states),
+        left_right: np.random.randn(num_states, num_states),
+        diagonal0: np.random.randn(num_states, num_states),
+        diagonal1: np.random.randn(num_states, num_states),
     }
 )
 
@@ -227,7 +234,6 @@ def update(step, batch_noisy_images, batch_target_images, opt_state):
             pbar.update()
             pbar.set_postfix(loss=value)
 
-
 batch_indices = indices[idx * batch_size : (idx + 1) * batch_size]
 batch_noisy_images, batch_target_images = (
     noisy_images_train[:10],
diff --git a/examples/ising_model.py b/examples/ising_model.py
index b2f5aef7..52e3d02c 100644
--- a/examples/ising_model.py
+++ b/examples/ising_model.py
@@ -43,7 +43,7 @@
     variables_for_factors=variables_for_factors,
     log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]),
 )
-fg.add_factor_group(factor_group, name="factors")
+fg.add_factor_group(factor_group)
 
 # %% [markdown]
 # ### Run inference and visualize results
@@ -88,7 +88,7 @@ def loss(log_potentials_updates, evidence_updates):
 
 # %%
 grads = log_potentials_grads(
-    {"factors": jnp.eye(2)},
+    {factor_group: jnp.eye(2)},
     {variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))},
 )
 
diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py
index 6cdb84f1..d37a0c9f 100644
--- a/examples/pmp_binary_deconvolution.py
+++ b/examples/pmp_binary_deconvolution.py
@@ -195,7 +195,7 @@ def plot_images(images, display=True, nr=None):
 fg.add_factor_group(OR_factor_group)
 print("Time", time.time() - start)
 
-for factor_type, factor_groups in fg.factor_groups.items():
+for factor_type, factor_groups in fg._factor_types_to_groups.items():
     if len(factor_groups) > 0:
         assert len(factor_groups) == 1
         print(f"The factor graph contains {factor_groups[0].num_factors} {factor_type}")
diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index 6cfbac10..815d8ba2 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -62,7 +62,7 @@ def __post_init__(self):
         if isinstance(self.variables, groups.VariableGroup):
             self.variables = [self.variables]
 
-        # Check ids are unique
+        # Check variable groups are unique
         vg_names = []
         vg_array_names = []
         for variable_group in self.variables:
@@ -108,52 +108,34 @@ def __post_init__(self):
                 )
             )
             vars_num_states_cumsum += vg_num_states_cumsum[-1]
-
         self._num_var_states = vars_num_states_cumsum
-        self._named_factor_groups: Dict[Hashable, groups.FactorGroup] = {}
         print("Init", time.time() - start)
 
     def __hash__(self) -> int:
-        all_factor_groups = tuple(
-            [
-                factor_group
-                for factor_groups_per_type in self._factor_types_to_groups.values()
-                for factor_group in factor_groups_per_type
-            ]
-        )
-        return hash(all_factor_groups)
+        return hash(self.factor_groups)
 
-    def add_factor(self, factor: nodes.Factor, name: Optional[str] = None) -> None:
+    def add_factor(self, factor: nodes.Factor) -> None:
         """Function to add a single Factor to the FactorGraph.
 
         Args:
             factor: The factor to be added to the factor graph.
-            name: Optional name of the SingleFactorGroup created.
         """
         factor_group = groups.SingleFactorGroup(
             variables_for_factors=[factor.variables],
             factor=factor,
         )
-        self.add_factor_group(factor_group, name)
+        self.add_factor_group(factor_group)
 
-    def add_factor_group(
-        self, factor_group: groups.FactorGroup, name: Optional[str] = None
-    ) -> None:
+    def add_factor_group(self, factor_group: groups.FactorGroup) -> None:
         """Add a FactorGroup to the FactorGraph, by updating the FactorGraphState.
 
         Args:
             factor_group: The FactorGroup to be added to the FactorGraph.
-            name: Optional name of the FactorGroup.
 
         Raises:
             ValueError: If the factor group with the same name or a Factor involving the same variables
                 already exists in the FactorGraph.
         """
-        if name in self._named_factor_groups:
-            raise ValueError(
-                f"A factor group with the name {name} already exists. Please choose a different name!"
-            )
-
         factor_type = factor_group.factor_type
         for var_names_for_factor in factor_group.variables_for_factors:
             var_names = frozenset(var_names_for_factor)
@@ -165,9 +147,6 @@ def add_factor_group(
 
         self._factor_types_to_groups[factor_type].append(factor_group)
 
-        if name is not None:
-            self._named_factor_groups[name] = factor_group
-
     @functools.lru_cache(None)
     def compute_offsets(self) -> None:
         """Compute factor messages offsets for the factor types and factor groups
@@ -284,20 +263,28 @@ def factors(self) -> OrderedDict[Type, Tuple[nodes.Factor, ...]]:
                     tuple(
                         [
                             factor
-                            for factor_group in self.factor_groups[factor_type]
+                            for factor_group in self._factor_types_to_groups[
+                                factor_type
+                            ]
                             for factor in factor_group.factors
                         ]
                     ),
                 )
-                for factor_type in self.factor_groups
+                for factor_type in self._factor_types_to_groups
             ]
         )
         return factors
 
     @property
-    def factor_groups(self) -> OrderedDict[Type, List[groups.FactorGroup]]:
+    def factor_groups(self) -> Tuple[groups.FactorGroup, ...]:
         """Tuple of factor groups in the factor graph"""
-        return self._factor_types_to_groups
+        return tuple(
+            [
+                factor_group
+                for factor_groups_per_type in self._factor_types_to_groups.values()
+                for factor_group in factor_groups_per_type
+            ]
+        )
 
     @cached_property
     def fg_state(self) -> FactorGraphState:
@@ -314,7 +301,7 @@ def fg_state(self) -> FactorGraphState:
             vars_to_starts=self._vars_to_starts,
             num_var_states=self._num_var_states,
             total_factor_num_states=self._total_factor_num_states,
-            named_factor_groups=copy.copy(self._named_factor_groups),
+            factor_groups=self.factor_groups,
             factor_type_to_msgs_range=copy.copy(self._factor_type_to_msgs_range),
             factor_type_to_potentials_range=copy.copy(
                 self._factor_type_to_potentials_range
@@ -350,7 +337,7 @@ class FactorGraphState:
             contains evidence to the variable.
         num_var_states: Total number of variable states.
         total_factor_num_states: Size of the flat ftov messages array.
-        named_factor_groups: Maps the names of named factor groups to the corresponding factor groups.
+        factor_groups: Factor groups in the FactorGraph
         factor_type_to_msgs_range: Maps factors types to their start and end indices in the flat ftov messages.
         factor_type_to_potentials_range: Maps factor types to their start and end indices in the flat log potentials.
         factor_group_to_potentials_starts: Maps factor groups to their starting indices in the flat log potentials.
@@ -362,7 +349,7 @@ class FactorGraphState:
     vars_to_starts: Mapping[Tuple[int, int], int]
     num_var_states: int
     total_factor_num_states: int
-    named_factor_groups: Mapping[Hashable, groups.FactorGroup]
+    factor_groups: Tuple[groups.FactorGroup, ...]
     factor_type_to_msgs_range: OrderedDict[type, Tuple[int, int]]
     factor_type_to_potentials_range: OrderedDict[type, Tuple[int, int]]
     factor_group_to_potentials_starts: OrderedDict[groups.FactorGroup, int]
@@ -430,15 +417,13 @@ def update_log_potentials(
         (1) Provided log_potentials shape does not match the expected log_potentials shape.
         (2) Provided name is not valid for log_potentials updates.
     """
-    for name, data in updates.items():
-        if name in fg_state.named_factor_groups:
-            factor_group = fg_state.named_factor_groups[name]
-
+    for factor_group, data in updates.items():
+        if factor_group in fg_state.factor_groups:
             flat_data = factor_group.flatten(data)
             if flat_data.shape != factor_group.factor_group_log_potentials.shape:
                 raise ValueError(
                     f"Expected log potentials shape {factor_group.factor_group_log_potentials.shape} "
-                    f"for factor group {name}. Got incompatible data shape {data.shape}."
+                    f"for factor group. Got incompatible data shape {data.shape}."
                 )
 
             start = fg_state.factor_group_to_potentials_starts[factor_group]
@@ -446,7 +431,7 @@ def update_log_potentials(
                 flat_data
             )
         else:
-            raise ValueError(f"Invalid name {name} for log potentials updates.")
+            raise ValueError("Invalid FactorGroup for log potentials updates.")
 
     return log_potentials
 
@@ -478,38 +463,34 @@ def __post_init__(self):
 
             object.__setattr__(self, "value", self.value)
 
-    def __getitem__(self, name: Any) -> np.ndarray:
-        """Function to query log potentials for a named factor group or a factor.
+    def __getitem__(self, factor_group: groups.FactorGroup) -> np.ndarray:
+        """Function to query log potentials for a FactorGroup.
 
         Args:
-            name: Name of a named factor group, or a frozenset containing the set
-                of connected variables for the queried factor.
+            factor_group: Queried FactorGroup
 
         Returns:
             The queried log potentials.
         """
         value = cast(np.ndarray, self.value)
-        if name in self.fg_state.named_factor_groups:
-            factor_group = self.fg_state.named_factor_groups[name]
+        if factor_group in self.fg_state.factor_groups:
             start = self.fg_state.factor_group_to_potentials_starts[factor_group]
             log_potentials = value[
                 start : start + factor_group.factor_group_log_potentials.shape[0]
             ]
         else:
-            raise ValueError(f"Invalid name {name} for log potentials updates.")
-
+            raise ValueError("Invalid FactorGroup for log potentials updates.")
         return log_potentials
 
     def __setitem__(
         self,
-        name: Any,
+        factor_group: Any,
         data: Union[np.ndarray, jnp.ndarray],
     ):
         """Set the log potentials for a named factor group or a factor.
 
         Args:
-            name: Name of a named factor group, or a frozenset containing the set
-                of connected variables for the queried factor.
+            factor_group: FactorGroup
             data: Array containing the log potentials for the named factor group
                 or the factor.
         """
@@ -519,7 +500,7 @@ def __setitem__(
             np.asarray(
                 update_log_potentials(
                     jax.device_put(self.value),
-                    {name: jax.device_put(data)},
+                    {factor_group: jax.device_put(data)},
                     self.fg_state,
                 )
             ),
diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index 192eee84..7c56deef 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -101,6 +101,7 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any:
         )
 
 
+@total_ordering
 @dataclass(frozen=True, eq=False)
 class FactorGroup:
     """Class to represent a group of Factors.
@@ -127,6 +128,15 @@ def __post_init__(self):
         if len(self.variables_for_factors) == 0:
             raise ValueError("Cannot create a FactorGroup with no Factor.")
 
+    def __hash__(self):
+        return id(self)
+
+    def __eq__(self, other):
+        return hash(self) == hash(other)
+
+    def __lt__(self, other):
+        return hash(self) < hash(other)
+
     def __getitem__(self, variables: Sequence[Tuple[int, int]]) -> Any:
         """Function to query individual factors in the factor group
 
diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py
index c3b3a853..e80b32a3 100644
--- a/tests/fg/test_graph.py
+++ b/tests/fg/test_graph.py
@@ -13,17 +13,14 @@
 
 
 def test_factor_graph():
-    vg = vgroup.NDVariableArray(num_states=2, shape=(10, 10))
+    vg1 = vgroup.NDVariableArray(num_states=2, shape=(10, 10))
     with pytest.raises(ValueError, match="Two objects have the same name"):
-        fg = graph.FactorGraph(variables=[vg, vg])
+        fg = graph.FactorGraph(variables=[vg1, vg1])
 
-    vg = vgroup.NDVariableArray(num_states=2, shape=(10, 10))
     vg2 = vgroup.NDVariableArray(num_states=2, shape=(10, 10))
-    object.__setattr__(vg2, "random_hash", vg.__hash__() + 10)
+    object.__setattr__(vg2, "random_hash", vg1.__hash__() + 10)
     with pytest.raises(ValueError, match="Two NDVariableArrays have overlapping names"):
-        fg = graph.FactorGraph(variables=[vg, vg2])
-
-    # TODO: remove factor graph name
+        fg = graph.FactorGraph(variables=[vg1, vg2])
 
     vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
     fg = graph.FactorGraph(vg)
@@ -32,13 +29,7 @@ def test_factor_graph():
         factor_configs=np.arange(15)[:, None],
         log_potentials=np.zeros(15),
     )
-    fg.add_factor(factor, name="test")
-
-    with pytest.raises(
-        ValueError,
-        match="A factor group with the name test already exists. Please choose a different name",
-    ):
-        fg.add_factor(factor, name="test")
+    fg.add_factor(factor)
 
     with pytest.raises(
         ValueError,
@@ -76,10 +67,10 @@ def test_bp_state():
         factor_configs=np.arange(15)[:, None],
         log_potentials=np.zeros(15),
     )
-    fg0.add_factor(factor, name="test")
+    fg0.add_factor(factor)
 
     fg1 = graph.FactorGraph(vg)
-    fg1.add_factor(factor, name="test")
+    fg1.add_factor(factor)
 
     with pytest.raises(
         ValueError,
@@ -99,23 +90,27 @@ def test_log_potentials():
         variables_for_factors=[[vg[0]]],
         factor_configs=np.arange(10)[:, None],
     )
-    fg.add_factor_group(factor_group, name="test")
+    fg.add_factor_group(factor_group)
 
     with pytest.raises(
         ValueError,
-        match=re.escape("Expected log potentials shape (10,) for factor group test."),
+        match=re.escape("Expected log potentials shape (10,) for factor group."),
     ):
-        fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15))
+        fg.bp_state.log_potentials[factor_group] = jnp.zeros((1, 15))
 
     with pytest.raises(
         ValueError,
-        match=re.escape("Invalid name new_test for log potentials updates."),
+        match=re.escape("Invalid FactorGroup for log potentials updates."),
     ):
-        fg.bp_state.log_potentials["new_test"] = jnp.zeros((1, 15))
+        factor_group2 = enumeration.EnumerationFactorGroup(
+            variables_for_factors=[[vg[0]]],
+            factor_configs=np.arange(10)[:, None],
+        )
+        fg.bp_state.log_potentials[factor_group2] = jnp.zeros((1, 15))
 
     with pytest.raises(
         ValueError,
-        match=re.escape("Invalid name (0, 15) for log potentials updates."),
+        match=re.escape("Invalid FactorGroup for log potentials updates."),
     ):
         fg.bp_state.log_potentials[vg[0]] = np.zeros(10)
 
@@ -125,7 +120,7 @@ def test_log_potentials():
         graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15))
 
     log_potentials = graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10))
-    assert jnp.all(log_potentials["test"] == jnp.zeros(10))
+    assert jnp.all(log_potentials[factor_group] == jnp.zeros(10))
 
 
 def test_ftov_msgs():
@@ -135,7 +130,7 @@ def test_ftov_msgs():
         variables_for_factors=[[vg[0]]],
         factor_configs=np.arange(10)[:, None],
     )
-    fg.add_factor_group(factor_group, name="test")
+    fg.add_factor_group(factor_group)
 
     with pytest.raises(
         ValueError,
@@ -170,7 +165,7 @@ def test_evidence():
         variables_for_factors=[[vg["a"]]],
         factor_configs=np.arange(10)[:, None],
     )
-    fg.add_factor_group(factor_group, name="test")
+    fg.add_factor_group(factor_group)
 
     with pytest.raises(
         ValueError, match=re.escape("Expected evidence shape (15,). Got (10,).")
@@ -201,7 +196,7 @@ def test_bp():
         variables_for_factors=[[vg[0]]],
         factor_configs=np.arange(10)[:, None],
     )
-    fg.add_factor_group(factor_group, name="test")
+    fg.add_factor_group(factor_group)
 
     bp = graph.BP(fg.bp_state, temperature=0)
     bp_arrays = bp.update()
diff --git a/tests/fg/test_wiring.py b/tests/fg/test_wiring.py
index 5b68a974..3df31cb5 100644
--- a/tests/fg/test_wiring.py
+++ b/tests/fg/test_wiring.py
@@ -25,7 +25,7 @@ def test_wiring_with_PairwiseFactorGroup():
     )
     fg.add_factor_group(factor_group)
 
-    factor_group = fg.factor_groups[enumeration_factor.EnumerationFactor][0]
+    factor_group = fg.factor_groups[0]
     object.__setattr__(
         factor_group, "factor_configs", factor_group.factor_configs[:, :1]
     )
@@ -41,7 +41,7 @@ def test_wiring_with_PairwiseFactorGroup():
         variables_for_factors=[[A[idx], B[idx]] for idx in range(10)]
     )
     fg1.add_factor_group(factor_group)
-    assert len(fg1.factor_groups[enumeration_factor.EnumerationFactor]) == 1
+    assert len(fg1.factor_groups) == 1
 
     # FactorGraph with multiple PairwiseFactorGroup
     fg2 = graph.FactorGraph(variables=[A, B])
@@ -50,7 +50,7 @@ def test_wiring_with_PairwiseFactorGroup():
             variables_for_factors=[[A[idx], B[idx]]]
         )
         fg2.add_factor_group(factor_group)
-    assert len(fg2.factor_groups[enumeration_factor.EnumerationFactor]) == 10
+    assert len(fg2.factor_groups) == 10
 
     # FactorGraph with multiple SingleFactorGroup
     fg3 = graph.FactorGraph(variables=[A, B])
@@ -61,7 +61,7 @@ def test_wiring_with_PairwiseFactorGroup():
             log_potentials=np.zeros((4,)),
         )
         fg3.add_factor(factor)
-    assert len(fg3.factor_groups[enumeration_factor.EnumerationFactor]) == 10
+    assert len(fg3.factor_groups) == 10
 
     assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors)
 
@@ -100,7 +100,7 @@ def test_wiring_with_ORFactorGroup():
         variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)],
     )
     fg1.add_factor_group(factor_group)
-    assert len(fg1.factor_groups[logical_factor.ORFactor]) == 1
+    assert len(fg1.factor_groups) == 1
 
     # FactorGraph with multiple ORFactorGroup
     fg2 = graph.FactorGraph(variables=[A, B, C])
@@ -109,7 +109,7 @@ def test_wiring_with_ORFactorGroup():
             variables_for_factors=[[A[idx], B[idx], C[idx]]],
         )
         fg2.add_factor_group(factor_group)
-    assert len(fg2.factor_groups[logical_factor.ORFactor]) == 10
+    assert len(fg2.factor_groups) == 10
 
     # FactorGraph with multiple SingleFactorGroup
     fg3 = graph.FactorGraph(variables=[A, B, C])
@@ -118,7 +118,7 @@ def test_wiring_with_ORFactorGroup():
             variables=[A[idx], B[idx], C[idx]],
         )
         fg3.add_factor(factor)
-    assert len(fg3.factor_groups[logical_factor.ORFactor]) == 10
+    assert len(fg3.factor_groups) == 10
 
     assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors)
 
@@ -155,7 +155,7 @@ def test_wiring_with_ANDFactorGroup():
         variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)],
     )
     fg1.add_factor_group(factor_group)
-    assert len(fg1.factor_groups[logical_factor.ANDFactor]) == 1
+    assert len(fg1.factor_groups) == 1
 
     # FactorGraph with multiple ANDFactorGroup
     fg2 = graph.FactorGraph(variables=[A, B, C])
@@ -164,7 +164,7 @@ def test_wiring_with_ANDFactorGroup():
             variables_for_factors=[[A[idx], B[idx], C[idx]]],
         )
         fg2.add_factor_group(factor_group)
-    assert len(fg2.factor_groups[logical_factor.ANDFactor]) == 10
+    assert len(fg2.factor_groups) == 10
 
     # FactorGraph with multiple SingleFactorGroup
     fg3 = graph.FactorGraph(variables=[A, B, C])
@@ -173,7 +173,7 @@ def test_wiring_with_ANDFactorGroup():
             variables=[A[idx], B[idx], C[idx]],
         )
         fg3.add_factor(factor)
-    assert len(fg3.factor_groups[logical_factor.ANDFactor]) == 10
+    assert len(fg3.factor_groups) == 10
 
     assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors)
 
diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py
index e9067abe..c6a030d6 100644
--- a/tests/test_pgmax.py
+++ b/tests/test_pgmax.py
@@ -265,14 +265,14 @@ def create_valid_suppression_config_arr(suppression_diameter):
     for row in range(M - 1):
         for col in range(N - 1):
             if row != M - 2 and col != N - 2:
-                curr_names = [
+                curr_vars = [
                     grid_vars[0, row, col],
                     grid_vars[1, row, col],
                     grid_vars[0, row, col + 1],
                     grid_vars[1, row + 1, col],
                 ]
             elif row != M - 2:
-                curr_names = [
+                curr_vars = [
                     grid_vars[0, row, col],
                     grid_vars[1, row, col],
                     additional_vars[0, row, col + 1],
@@ -280,7 +280,7 @@ def create_valid_suppression_config_arr(suppression_diameter):
                 ]
 
             elif col != N - 2:
-                curr_names = [
+                curr_vars = [
                     grid_vars[0, row, col],
                     grid_vars[1, row, col],
                     grid_vars[0, row, col + 1],
@@ -288,7 +288,7 @@ def create_valid_suppression_config_arr(suppression_diameter):
                 ]
 
             else:
-                curr_names = [
+                curr_vars = [
                     grid_vars[0, row, col],
                     grid_vars[1, row, col],
                     additional_vars[0, row, col + 1],
@@ -296,54 +296,54 @@ def create_valid_suppression_config_arr(suppression_diameter):
                 ]
             if row % 2 == 0:
                 factor = EnumerationFactor(
-                    variables=curr_names,
+                    variables=curr_vars,
                     factor_configs=valid_configs_non_supp,
                     log_potentials=np.zeros(
                         valid_configs_non_supp.shape[0], dtype=float
                     ),
                 )
-                fg.add_factor(factor, name=(row, col))
+                fg.add_factor(factor)
             else:
                 factor = EnumerationFactor(
-                    variables=curr_names,
+                    variables=curr_vars,
                     factor_configs=valid_configs_non_supp,
                     log_potentials=np.zeros(
                         valid_configs_non_supp.shape[0], dtype=float
                     ),
                 )
-                fg.add_factor(factor, name=(row, col))
+                fg.add_factor(factor)
 
     # Create an EnumerationFactorGroup for vertical suppression factors
-    vert_suppression_names: List[List[Tuple[Any, ...]]] = []
+    vert_suppression_vars: List[List[Tuple[Any, ...]]] = []
     for col in range(N):
         for start_row in range(M - SUPPRESSION_DIAMETER):
             if col != N - 1:
-                vert_suppression_names.append(
+                vert_suppression_vars.append(
                     [
                         grid_vars[0, r, col]
                         for r in range(start_row, start_row + SUPPRESSION_DIAMETER)
                     ]
                 )
             else:
-                vert_suppression_names.append(
+                vert_suppression_vars.append(
                     [
                         additional_vars[0, r, col]
                         for r in range(start_row, start_row + SUPPRESSION_DIAMETER)
                     ]
                 )
 
-    horz_suppression_names: List[List[Tuple[Any, ...]]] = []
+    horz_suppression_vars: List[List[Tuple[Any, ...]]] = []
     for row in range(M):
         for start_col in range(N - SUPPRESSION_DIAMETER):
             if row != M - 1:
-                horz_suppression_names.append(
+                horz_suppression_vars.append(
                     [
                         grid_vars[1, row, c]
                         for c in range(start_col, start_col + SUPPRESSION_DIAMETER)
                     ]
                 )
             else:
-                horz_suppression_names.append(
+                horz_suppression_vars.append(
                     [
                         additional_vars[1, row, c]
                         for c in range(start_col, start_col + SUPPRESSION_DIAMETER)
@@ -352,13 +352,13 @@ def create_valid_suppression_config_arr(suppression_diameter):
 
     # Add the suppression factors to the graph via kwargs
     factor_group = enumeration.EnumerationFactorGroup(
-        variables_for_factors=vert_suppression_names,
+        variables_for_factors=vert_suppression_vars,
         factor_configs=valid_configs_supp,
     )
     fg.add_factor_group(factor_group)
 
     factor_group = enumeration.EnumerationFactorGroup(
-        variables_for_factors=horz_suppression_names,
+        variables_for_factors=horz_suppression_vars,
         factor_configs=valid_configs_supp,
         log_potentials=np.zeros(valid_configs_supp.shape[0], dtype=float),
     )
@@ -419,7 +419,7 @@ def binary_connected_variables(
                 variables_for_factors=binary_connected_variables(28, 28, k_row, k_col),
                 log_potential_matrix=W_pot[:, :, k_row, k_col],
             )
-            fg.add_factor_group(factor_group, name=(k_row, k_col))
+            fg.add_factor_group(factor_group)
 
     # Assign evidence to pixel vars
     bp_state = fg.bp_state

From 22b604e7a8e78c622e6bbb8ca54ce142e872a40d Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Mon, 25 Apr 2022 23:05:11 +0000
Subject: [PATCH 19/35] Remove factor group names

---
 tests/fg/test_groups.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py
index b12b49d3..24d24400 100644
--- a/tests/fg/test_groups.py
+++ b/tests/fg/test_groups.py
@@ -94,6 +94,9 @@ def test_nd_variable_array():
 
 def test_enumeration_factor_group():
     vg = vgroup.NDVariableArray(shape=(2, 2), num_states=3)
+    vg_bis = vgroup.NDVariableArray(shape=(2, 2), num_states=3)
+    vg < vg_bis
+
     with pytest.raises(
         ValueError,
         match=re.escape("Expected log potentials shape: (1,) or (2, 1). Got (3, 2)"),

From 87fcfd9eb000f308ac00e28e75922385962c9c5b Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Tue, 26 Apr 2022 02:27:18 +0000
Subject: [PATCH 20/35] Modify hash + add_factors

---
 examples/gmrf.py                     |  8 +--
 examples/ising_model.py              |  4 +-
 examples/pmp_binary_deconvolution.py |  8 +--
 examples/rbm.py                      | 16 ++---
 examples/rcn.py                      |  4 +-
 pgmax/factors/logical.py             |  4 +-
 pgmax/fg/graph.py                    | 93 +++++++++++-----------------
 pgmax/fg/groups.py                   | 16 ++---
 pgmax/groups/variables.py            | 39 ++++++++----
 tests/factors/test_and.py            | 18 +++---
 tests/factors/test_or.py             | 14 +++--
 tests/fg/test_graph.py               | 80 ++++++++++++------------
 tests/fg/test_groups.py              | 31 ++++++++--
 tests/fg/test_wiring.py              | 40 ++++++------
 tests/test_pgmax.py                  | 12 ++--
 15 files changed, 202 insertions(+), 185 deletions(-)

diff --git a/examples/gmrf.py b/examples/gmrf.py
index 6601e5c7..eb86a5e9 100644
--- a/examples/gmrf.py
+++ b/examples/gmrf.py
@@ -66,7 +66,7 @@
         for jj in range(N)
     ],
 )
-fg.add_factor_group(top_down)
+fg.add_factors(top_down)
 
 # Add left-right factors
 left_right = enumeration.PairwiseFactorGroup(
@@ -76,7 +76,7 @@
         for jj in range(N - 1)
     ],
 )
-fg.add_factor_group(left_right)
+fg.add_factors(left_right)
 
 # Add diagonal factors
 diagonal0 = enumeration.PairwiseFactorGroup(
@@ -86,7 +86,7 @@
         for jj in range(N - 1)
     ],
 )
-fg.add_factor_group(diagonal0)
+fg.add_factors(diagonal0)
 
 diagonal1 = enumeration.PairwiseFactorGroup(
     variables_for_factors=[
@@ -95,7 +95,7 @@
         for jj in range(N - 1)
     ],
 )
-fg.add_factor_group(diagonal1)
+fg.add_factors(diagonal1)
 
 # %%
 bp = graph.BP(fg.bp_state, temperature=1.0)
diff --git a/examples/ising_model.py b/examples/ising_model.py
index 52e3d02c..3f530a33 100644
--- a/examples/ising_model.py
+++ b/examples/ising_model.py
@@ -29,7 +29,7 @@
 
 # %%
 variables = vgroup.NDVariableArray(num_states=2, shape=(50, 50))
-fg = graph.FactorGraph(variables=variables)
+fg = graph.FactorGraph(variable_groups=variables)
 
 variables_for_factors = []
 for ii in range(50):
@@ -43,7 +43,7 @@
     variables_for_factors=variables_for_factors,
     log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]),
 )
-fg.add_factor_group(factor_group)
+fg.add_factors(factor_group)
 
 # %% [markdown]
 # ### Run inference and visualize results
diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py
index d37a0c9f..a0f24406 100644
--- a/examples/pmp_binary_deconvolution.py
+++ b/examples/pmp_binary_deconvolution.py
@@ -138,12 +138,12 @@ def plot_images(images, display=True, nr=None):
 print("Time", time.time() - start)
 
 # %% [markdown]
-# For computation efficiency, we add large FactorGroups via `fg.add_factor_group` instead of adding individual Factors
+# For computation efficiency, we add large FactorGroups via `fg.add_factors` instead of adding individual Factors
 
 # %%
 start = time.time()
 # Factor graph
-fg = graph.FactorGraph(variables=[S, W, SW, X])
+fg = graph.FactorGraph(variable_groups=[S, W, SW, X])
 print(time.time() - start)
 
 # Define the ANDFactors
@@ -181,7 +181,7 @@ def plot_images(images, display=True, nr=None):
 
 # Add ANDFactorGroup, which is computationally efficient
 AND_factor_group = logical.ANDFactorGroup(variables_for_ANDFactors)
-fg.add_factor_group(AND_factor_group)
+fg.add_factors(AND_factor_group)
 print(time.time() - start)
 
 # Define the ORFactors
@@ -192,7 +192,7 @@ def plot_images(images, display=True, nr=None):
 
 # Add ORFactorGroup, which is computationally efficient
 OR_factor_group = logical.ORFactorGroup(variables_for_ORFactors)
-fg.add_factor_group(OR_factor_group)
+fg.add_factors(OR_factor_group)
 print("Time", time.time() - start)
 
 for factor_type, factor_groups in fg._factor_types_to_groups.items():
diff --git a/examples/rbm.py b/examples/rbm.py
index 8ea595ea..cdc51390 100644
--- a/examples/rbm.py
+++ b/examples/rbm.py
@@ -52,7 +52,7 @@
 # Initialize factor graph
 hidden_variables = vgroup.NDVariableArray(num_states=2, shape=bh.shape)
 visible_variables = vgroup.NDVariableArray(num_states=2, shape=bv.shape)
-fg = graph.FactorGraph(variables=[hidden_variables, visible_variables])
+fg = graph.FactorGraph(variable_groups=[hidden_variables, visible_variables])
 print("Time", time.time() - start)
 
 # %% [markdown]
@@ -69,14 +69,14 @@
     factor_configs=np.arange(2)[:, None],
     log_potentials=np.stack([np.zeros_like(bh), bh], axis=1),
 )
-fg.add_factor_group(hidden_unaries)
+fg.add_factors(hidden_unaries)
 
 visible_unaries = enumeration.EnumerationFactorGroup(
     variables_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])],
     factor_configs=np.arange(2)[:, None],
     log_potentials=np.stack([np.zeros_like(bv), bv], axis=1),
 )
-fg.add_factor_group(visible_unaries)
+fg.add_factors(visible_unaries)
 
 # Add pairwise factors
 log_potential_matrix = np.zeros(W.shape + (2, 2)).reshape((-1, 2, 2))
@@ -90,9 +90,9 @@
     ],
     log_potential_matrix=log_potential_matrix,
 )
-fg.add_factor_group(pairwise_factors)
+fg.add_factors(pairwise_factors)
 
-# # %snakeviz fg.add_factor_group(factory=enumeration.PairwiseFactorGroup, variables_for_factors=v, log_potential_matrix=log_potential_matrix,)
+# # %snakeviz fg.add_factors(factory=enumeration.PairwiseFactorGroup, variables_for_factors=v, log_potential_matrix=log_potential_matrix,)
 print("Time", time.time() - start)
 
 
@@ -115,7 +115,7 @@
 #         factor_configs=np.arange(2)[:, None],
 #         log_potentials=np.array([0, bh[ii]]),
 #     )
-#     fg.add_factor(factor)
+#     fg.add_factors(factor=factor)
 #
 # for jj in range(bv.shape[0]):
 #     factor = enumeration_factor.EnumerationFactor(
@@ -123,7 +123,7 @@
 #         factor_configs=np.arange(2)[:, None],
 #         log_potentials=np.array([0, bv[jj]]),
 #     )
-#     fg.add_factor(factor)
+#     fg.add_factors(factor=factor)
 #
 # # Add pairwise factors
 # factor_configs = np.array(list(itertools.product(np.arange(2), repeat=2)))
@@ -134,7 +134,7 @@
 #             factor_configs=factor_configs,
 #             log_potentials=np.array([0, 0, 0, W[ii, jj]]),
 #         )
-#         fg.add_factor(factor)
+#         fg.add_factors(factor=factor)
 # ~~~
 #
 # Once we have added the factors, we can run max-product LBP and get MAP decoding by
diff --git a/examples/rcn.py b/examples/rcn.py
index d7add7d7..2414c0e9 100644
--- a/examples/rcn.py
+++ b/examples/rcn.py
@@ -272,7 +272,7 @@ def valid_configs(r: int, hps: int, vps: int) -> np.ndarray:
 
 # %%
 start = time.time()
-fg = graph.FactorGraph(variables=variables_all_models)
+fg = graph.FactorGraph(variables_all_models)
 
 # Adding rcn model edges to the pgmax factor graph.
 for idx in range(edges.shape[0]):
@@ -285,7 +285,7 @@ def valid_configs(r: int, hps: int, vps: int) -> np.ndarray:
             factor_configs=valid_configs_list[r],
             log_potentials=np.zeros(valid_configs_list[r].shape[0]),
         )
-        fg.add_factor(factor)
+        fg.add_factors(factor=factor)
 
 end = time.time()
 print(f"Creating factors took {end-start:.3f} seconds.")
diff --git a/pgmax/factors/logical.py b/pgmax/factors/logical.py
index 7f2d1ca6..bf8652cb 100644
--- a/pgmax/factors/logical.py
+++ b/pgmax/factors/logical.py
@@ -2,7 +2,7 @@
 
 import functools
 from dataclasses import dataclass, field
-from typing import List, Mapping, Optional, Sequence, Union
+from typing import List, Mapping, Optional, Sequence, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -141,7 +141,7 @@ def concatenate_wirings(wirings: Sequence[LogicalWiring]) -> LogicalWiring:
     @staticmethod
     def compile_wiring(
         variables_for_factors: Sequence[List],
-        vars_to_starts: Mapping[int, int],
+        vars_to_starts: Mapping[Tuple[int, int], int],
         edge_states_offset: int,
     ) -> LogicalWiring:
         """Compile a LogicalWiring for a LogicalFactor or a FactorGroup with LogicalFactors.
diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index 815d8ba2..9fbe794c 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -35,7 +35,6 @@
 from pgmax.bp import infer
 from pgmax.factors import FAC_TO_VAR_UPDATES
 from pgmax.fg import groups, nodes
-from pgmax.groups import variables
 from pgmax.utils import cached_property
 
 
@@ -53,33 +52,14 @@ class FactorGraph:
             this input, and the individual VariableGroups will need to be accessed by indexing.
     """
 
-    variables: Union[groups.VariableGroup, Sequence[groups.VariableGroup]]
+    variable_groups: Union[groups.VariableGroup, Sequence[groups.VariableGroup]]
 
     def __post_init__(self):
         import time
 
         start = time.time()
-        if isinstance(self.variables, groups.VariableGroup):
-            self.variables = [self.variables]
-
-        # Check variable groups are unique
-        vg_names = []
-        vg_array_names = []
-        for variable_group in self.variables:
-            vg_name = variable_group.__hash__()
-            if vg_name in vg_names:
-                raise ValueError("Two objects have the same name")
-            vg_names.append(vg_name)
-            if isinstance(variable_group, variables.NDVariableArray):
-                start_name, end_name = (
-                    vg_name,
-                    variable_group.variable_names.flatten()[-1],
-                )
-                for var_array_name in vg_array_names:
-                    start_name2, end_name2 = var_array_name
-                    if max(start_name, start_name2) <= min(end_name, end_name2):
-                        raise ValueError("Two NDVariableArrays have overlapping names")
-                vg_array_names.append((start_name, end_name))
+        if isinstance(self.variable_groups, groups.VariableGroup):
+            self.variable_groups = [self.variable_groups]
 
         # Useful objects to build the FactorGraph
         self._factor_types_to_groups: OrderedDict[
@@ -98,7 +78,7 @@ def __post_init__(self):
             Tuple[int, int], int
         ] = collections.OrderedDict()
         vars_num_states_cumsum = 0
-        for variable_group in self.variables:
+        for variable_group in self.variable_groups:
             vg_num_states = variable_group.num_states.flatten()
             vg_num_states_cumsum = np.insert(np.cumsum(vg_num_states), 0, 0)
             self._vars_to_starts.update(
@@ -114,28 +94,35 @@ def __post_init__(self):
     def __hash__(self) -> int:
         return hash(self.factor_groups)
 
-    def add_factor(self, factor: nodes.Factor) -> None:
-        """Function to add a single Factor to the FactorGraph.
-
-        Args:
-            factor: The factor to be added to the factor graph.
-        """
-        factor_group = groups.SingleFactorGroup(
-            variables_for_factors=[factor.variables],
-            factor=factor,
-        )
-        self.add_factor_group(factor_group)
-
-    def add_factor_group(self, factor_group: groups.FactorGroup) -> None:
-        """Add a FactorGroup to the FactorGraph, by updating the FactorGraphState.
+    def add_factors(
+        self,
+        factor_group: Optional[groups.FactorGroup] = None,
+        factor: Optional[nodes.Factor] = None,
+    ) -> None:
+        """Add a FactorGroup or a single Factor to the FactorGraph, by updating the FactorGraphState.
 
         Args:
             factor_group: The FactorGroup to be added to the FactorGraph.
+            factor: The Factor to be added to the factor graph.
 
         Raises:
-            ValueError: If the factor group with the same name or a Factor involving the same variables
-                already exists in the FactorGraph.
+            ValueError: If
+                (1) Both a Factor and a FactorGroup are added
+                (2) The FactorGroup involving the same variables already exists in the FactorGraph.
         """
+        if factor is None and factor_group is None:
+            raise ValueError("A Factor or a FactorGroup is required")
+
+        if factor is not None and factor_group is not None:
+            raise ValueError("Cannot simultaneously add a Factor and a FactorGroup")
+
+        if factor is not None:
+            factor_group = groups.SingleFactorGroup(
+                variables_for_factors=[factor.variables],
+                factor=factor,
+            )
+        assert factor_group is not None
+
         factor_type = factor_group.factor_type
         for var_names_for_factor in factor_group.variables_for_factors:
             var_names = frozenset(var_names_for_factor)
@@ -294,10 +281,10 @@ def fg_state(self) -> FactorGraphState:
         log_potentials = np.concatenate(
             [self.log_potentials[factor_type] for factor_type in self.log_potentials]
         )
-        assert isinstance(self.variables, list)
+        assert isinstance(self.variable_groups, list)
 
         return FactorGraphState(
-            variable_groups=self.variables,
+            variable_groups=self.variable_groups,
             vars_to_starts=self._vars_to_starts,
             num_var_states=self._num_var_states,
             total_factor_num_states=self._total_factor_num_states,
@@ -606,7 +593,7 @@ def __setitem__(
     @typing.overload
     def __setitem__(
         self,
-        names: Tuple[int, int],
+        variable: Tuple[int, int],
         data: Union[np.ndarray, jnp.ndarray],
     ) -> None:
         """Spreading beliefs at a variable to all connected factors
@@ -618,14 +605,14 @@ def __setitem__(
                 variable.
         """
 
-    def __setitem__(self, names, data) -> None:
+    def __setitem__(self, variable, data) -> None:
         object.__setattr__(
             self,
             "value",
             np.asarray(
                 update_ftov_msgs(
                     jax.device_put(self.value),
-                    {names: jax.device_put(data)},
+                    {variable: jax.device_put(data)},
                     self.fg_state,
                 )
             ),
@@ -1002,18 +989,12 @@ def get_beliefs(bp_arrays: BPArrays) -> Dict[Hashable, Any]:
             beliefs: Beliefs returned by belief propagation.
         """
 
-        def compute_flat_beliefs(bp_arrays, var_states_for_edges):
-            flat_beliefs = (
-                jax.device_put(bp_arrays.evidence)
-                .at[jax.device_put(var_states_for_edges)]
-                .add(bp_arrays.ftov_msgs)
-            )
-            return flat_beliefs
-
-        return unflatten_beliefs(
-            compute_flat_beliefs(bp_arrays, var_states_for_edges),
-            bp_state.fg_state.variable_groups,
+        flat_beliefs = (
+            jax.device_put(bp_arrays.evidence)
+            .at[jax.device_put(var_states_for_edges)]
+            .add(bp_arrays.ftov_msgs)
         )
+        return unflatten_beliefs(flat_beliefs, bp_state.fg_state.variable_groups)
 
     bp = BeliefPropagation(
         init=functools.partial(update, None),
diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index 7c56deef..62692ffd 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -1,7 +1,6 @@
 """A module containing the base classes for variable and factor groups in a Factor Graph."""
 
 import inspect
-import random
 from dataclasses import dataclass, field
 from functools import total_ordering
 from typing import (
@@ -22,25 +21,18 @@
 import pgmax.fg.nodes as nodes
 from pgmax.utils import cached_property
 
+MAX_SIZE = 1e16
+
 
 @total_ordering
 @dataclass(frozen=True, eq=False)
 class VariableGroup:
     """Class to represent a group of variables.
     Each variable is represented via a tuple of the form (variable hash/name, number of states)
-
-    Attributes:
-        random_hash: Hash of the VariableGroup
     """
 
-    def __post_init__(self):
-        # Overwite default hash to have larger differences
-        random.seed(id(self))
-        random_hash = random.randint(0, 2**63)
-        object.__setattr__(self, "random_hash", random_hash)
-
     def __hash__(self):
-        return self.random_hash
+        return id(self) * int(MAX_SIZE)
 
     def __eq__(self, other):
         return hash(self) == hash(other)
@@ -63,7 +55,7 @@ def __getitem__(self, val):
         )
 
     @cached_property
-    def variables(self) -> Tuple[Any, int]:
+    def variables(self) -> List[Tuple]:
         """Function that returns the list of all variables in the VariableGroup.
         Each variable is represented by a tuple of the form (variable hash/name, number of states)
 
diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index feff8467..7b50167e 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -26,10 +26,13 @@ class NDVariableArray(groups.VariableGroup):
     num_states: Union[int, np.ndarray]
 
     def __post_init__(self):
-        super().__post_init__()
+        if np.prod(self.shape) > groups.MAX_SIZE:
+            raise ValueError(
+                f"Currently only support NDVariableArray of size smaller than {groups.MAX_SIZE}. Got {np.prod(self.shape)}"
+            )
 
         if np.isscalar(self.num_states):
-            num_states = np.full(self.shape, fill_value=self.num_states)
+            num_states = np.full(self.shape, fill_value=self.num_states, dtype=np.int32)
             object.__setattr__(self, "num_states", num_states)
         elif isinstance(self.num_states, np.ndarray) and np.issubdtype(
             self.num_states.dtype, int
@@ -43,7 +46,7 @@ def __post_init__(self):
 
     def __getitem__(
         self, val: Union[int, slice, Tuple]
-    ) -> Union[Tuple[int, int], List[Tuple]]:
+    ) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
         """Given an index or a slice, retrieve the associated variable(s).
         Each variable is returned via a tuple of the form (variable hash, number of states)
 
@@ -159,14 +162,19 @@ class VariableDict(groups.VariableGroup):
     """
 
     variable_names: Tuple[Any, ...]
-    num_states: np.ndarray  # TODO: this should be an int converted to an array in __post_init__
+    num_states: int
 
     def __post_init__(self):
-        super().__post_init__()
-
-        num_states = np.full((len(self.variable_names),), fill_value=self.num_states)
+        num_states = np.full(
+            (len(self.variable_names),), fill_value=self.num_states, dtype=np.int32
+        )
         object.__setattr__(self, "num_states", num_states)
 
+        hash_and_names = tuple(
+            (self.__hash__(), var_name) for var_name in self.variable_names
+        )
+        object.__setattr__(self, "variable_names", hash_and_names)
+
     @cached_property
     def variables(self) -> List[Tuple]:
         """Function that returns the list of all variables in the VariableGroup.
@@ -175,9 +183,12 @@ def variables(self) -> List[Tuple]:
         Returns:
             List of variables in the VariableGroup
         """
-        return list(zip(self.variable_names, self.num_states))
+        assert isinstance(self.num_states, np.ndarray)
+        vars_names = list(self.variable_names)
+        vars_num_states = self.num_states.flatten()
+        return list(zip(vars_names, vars_num_states))
 
-    def __getitem__(self, val):
+    def __getitem__(self, val: Any) -> Tuple[Any, int]:
         """Given a variable name retrieve the associated variable, returned via a tuple of the form
         (variable name, number of states)
 
@@ -187,9 +198,11 @@ def __getitem__(self, val):
         Returns:
             The queried variable
         """
-        if val not in self.variable_names:
+        assert isinstance(self.num_states, np.ndarray)
+
+        if (self.__hash__(), val) not in self.variable_names:
             raise ValueError(f"Variable {val} is not in VariableDict")
-        return (val, self.num_states[0])
+        return ((self.__hash__(), val), self.num_states[0])
 
     def flatten(
         self, data: Mapping[Tuple[int, int], Union[np.ndarray, jnp.ndarray]]
@@ -209,6 +222,8 @@ def flatten(
                 (1) data is referring to a non-existing variable
                 (2) data is not of the correct shape
         """
+        assert isinstance(self.num_states, np.ndarray)
+
         for variable in data:
             if variable not in self.variables:
                 raise ValueError(
@@ -244,6 +259,8 @@ def unflatten(
                 (1) flat_data is not a 1D array
                 (2) flat_data is not of the right shape
         """
+        assert isinstance(self.num_states, np.ndarray)
+
         if flat_data.ndim != 1:
             raise ValueError(
                 f"Can only unflatten 1D array. Got a {flat_data.ndim}D array."
diff --git a/tests/factors/test_and.py b/tests/factors/test_and.py
index 147d57a0..fcbe6772 100644
--- a/tests/factors/test_and.py
+++ b/tests/factors/test_and.py
@@ -45,14 +45,18 @@ def test_run_bp_with_ANDFactors():
             num_states=2, shape=(num_parents.sum(),)
         )
         children_variables1 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,))
-        fg1 = graph.FactorGraph(variables=[parents_variables1, children_variables1])
+        fg1 = graph.FactorGraph(
+            variable_groups=[parents_variables1, children_variables1]
+        )
 
         # Graph 2
         parents_variables2 = vgroup.NDVariableArray(
             num_states=2, shape=(num_parents.sum(),)
         )
         children_variables2 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,))
-        fg2 = graph.FactorGraph(variables=[parents_variables2, children_variables2])
+        fg2 = graph.FactorGraph(
+            variable_groups=[parents_variables2, children_variables2]
+        )
 
         # Option 1: Define EnumerationFactors equivalent to the ANDFactors
         variables_for_factors1 = []
@@ -100,7 +104,7 @@ def test_run_bp_with_ANDFactors():
                     factor_configs=valid_configs,
                     log_potentials=np.zeros(valid_configs.shape[0]),
                 )
-                fg1.add_factor(enum_factor)
+                fg1.add_factors(factor=enum_factor)
             else:
                 if idx != 0:
                     # Add the second half of factors to FactorGraph2
@@ -109,7 +113,7 @@ def test_run_bp_with_ANDFactors():
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
-                    fg2.add_factor(enum_factor)
+                    fg2.add_factors(factor=enum_factor)
                 else:
                     # Add all the EnumerationFactors to FactorGraph1 for the first iter
                     enum_factor = EnumerationFactor(
@@ -117,7 +121,7 @@ def test_run_bp_with_ANDFactors():
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
-                    fg1.add_factor(enum_factor)
+                    fg1.add_factors(factor=enum_factor)
 
         # Option 2: Define the ANDFactors
         num_parents_cumsum = np.insert(np.cumsum(num_parents), 0, 0)
@@ -141,10 +145,10 @@ def test_run_bp_with_ANDFactors():
                     )
         if idx != 0:
             factor_group = logical.ANDFactorGroup(variables_for_ANDFactors_fg1)
-            fg1.add_factor_group(factor_group)
+            fg1.add_factors(factor_group)
 
         factor_group = logical.ANDFactorGroup(variables_for_ANDFactors_fg2)
-        fg2.add_factor_group(factor_group)
+        fg2.add_factors(factor_group)
 
         # Run inference
         bp1 = graph.BP(fg1.bp_state, temperature=temperature)
diff --git a/tests/factors/test_or.py b/tests/factors/test_or.py
index 76cc5107..16236b35 100644
--- a/tests/factors/test_or.py
+++ b/tests/factors/test_or.py
@@ -45,14 +45,18 @@ def test_run_bp_with_ORFactors():
             num_states=2, shape=(num_parents.sum(),)
         )
         children_variables1 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,))
-        fg1 = graph.FactorGraph(variables=[parents_variables1, children_variables1])
+        fg1 = graph.FactorGraph(
+            variable_groups=[parents_variables1, children_variables1]
+        )
 
         # Graph 2
         parents_variables2 = vgroup.NDVariableArray(
             num_states=2, shape=(num_parents.sum(),)
         )
         children_variables2 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,))
-        fg2 = graph.FactorGraph(variables=[parents_variables2, children_variables2])
+        fg2 = graph.FactorGraph(
+            variable_groups=[parents_variables2, children_variables2]
+        )
 
         # Variable names for factors
         variables_for_factors1 = []
@@ -98,7 +102,7 @@ def test_run_bp_with_ORFactors():
                     factor_configs=valid_configs,
                     log_potentials=np.zeros(valid_configs.shape[0]),
                 )
-                fg1.add_factor(enum_factor)
+                fg1.add_factor(factor=enum_factor)
             else:
                 if idx != 0:
                     # Add the second half of factors to FactorGraph2
@@ -107,7 +111,7 @@ def test_run_bp_with_ORFactors():
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
-                    fg2.add_factor(enum_factor)
+                    fg2.add_factor(factor=enum_factor)
                 else:
                     # Add all the EnumerationFactors to FactorGraph1 for the first iter
                     enum_factor = EnumerationFactor(
@@ -115,7 +119,7 @@ def test_run_bp_with_ORFactors():
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
-                    fg1.add_factor(enum_factor)
+                    fg1.add_factor(factor=enum_factor)
 
         # Option 2: Define the ORFactors
         num_parents_cumsum = np.insert(np.cumsum(num_parents), 0, 0)
diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py
index e80b32a3..bb26ce99 100644
--- a/tests/fg/test_graph.py
+++ b/tests/fg/test_graph.py
@@ -7,56 +7,52 @@
 import pytest
 
 from pgmax.factors import enumeration as enumeration_factor
-from pgmax.fg import graph, groups
-from pgmax.groups import enumeration, logical
+from pgmax.fg import graph
+from pgmax.groups import enumeration
 from pgmax.groups import variables as vgroup
 
 
 def test_factor_graph():
-    vg1 = vgroup.NDVariableArray(num_states=2, shape=(10, 10))
-    with pytest.raises(ValueError, match="Two objects have the same name"):
-        fg = graph.FactorGraph(variables=[vg1, vg1])
-
-    vg2 = vgroup.NDVariableArray(num_states=2, shape=(10, 10))
-    object.__setattr__(vg2, "random_hash", vg1.__hash__() + 10)
-    with pytest.raises(ValueError, match="Two NDVariableArrays have overlapping names"):
-        fg = graph.FactorGraph(variables=[vg1, vg2])
-
     vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
     fg = graph.FactorGraph(vg)
+
+    with pytest.raises(
+        ValueError,
+        match="A Factor or a FactorGroup is required",
+    ):
+        fg.add_factors(factor_group=None, factor=None)
+
     factor = enumeration_factor.EnumerationFactor(
         variables=[vg[0]],
         factor_configs=np.arange(15)[:, None],
         log_potentials=np.zeros(15),
     )
-    fg.add_factor(factor)
 
+    factor_group = enumeration.EnumerationFactorGroup(
+        variables_for_factors=[[vg[0]]],
+        factor_configs=np.arange(15)[:, None],
+        log_potentials=np.zeros(15),
+    )
     with pytest.raises(
         ValueError,
-        match=re.escape(
-            f"A Factor of type {enumeration_factor.EnumerationFactor} involving variables {frozenset([(0, 15)])} already exists."
-        ),
+        match="Cannot simultaneously add a Factor and a FactorGroup",
     ):
-        fg.add_factor(factor)
+        fg.add_factors(factor_group=factor_group, factor=factor)
 
+    fg.add_factors(factor=factor)
 
-def test_single_factor():
-    with pytest.raises(ValueError, match="Cannot create a FactorGroup with no Factor."):
-        logical.ORFactorGroup(variables_for_factors=[])
-
-    A = vgroup.NDVariableArray(num_states=2, shape=(10,))
-    B = vgroup.NDVariableArray(num_states=2, shape=(10,))
-
-    variables0 = (A[0], B[0])
-    variables1 = (A[1], B[1])
-    ORFactor = logical.ORFactorGroup(variables_for_factors=[variables0])
+    factor_group = enumeration.EnumerationFactorGroup(
+        variables_for_factors=[[vg[0]]],
+        factor_configs=np.arange(15)[:, None],
+        log_potentials=np.zeros(15),
+    )
     with pytest.raises(
-        ValueError, match="SingleFactorGroup should only contain one factor. Got 2"
+        ValueError,
+        match=re.escape(
+            f"A Factor of type {enumeration_factor.EnumerationFactor} involving variables {frozenset([((vg.__hash__(), 0), 15)])} already exists."
+        ),
     ):
-        groups.SingleFactorGroup(
-            variables_for_factors=[variables0, variables1],
-            factor=ORFactor,
-        )
+        fg.add_factors(factor_group)
 
 
 def test_bp_state():
@@ -67,10 +63,10 @@ def test_bp_state():
         factor_configs=np.arange(15)[:, None],
         log_potentials=np.zeros(15),
     )
-    fg0.add_factor(factor)
+    fg0.add_factors(factor=factor)
 
     fg1 = graph.FactorGraph(vg)
-    fg1.add_factor(factor)
+    fg1.add_factors(factor=factor)
 
     with pytest.raises(
         ValueError,
@@ -90,7 +86,7 @@ def test_log_potentials():
         variables_for_factors=[[vg[0]]],
         factor_configs=np.arange(10)[:, None],
     )
-    fg.add_factor_group(factor_group)
+    fg.add_factors(factor_group)
 
     with pytest.raises(
         ValueError,
@@ -130,7 +126,7 @@ def test_ftov_msgs():
         variables_for_factors=[[vg[0]]],
         factor_configs=np.arange(10)[:, None],
     )
-    fg.add_factor_group(factor_group)
+    fg.add_factors(factor_group)
 
     with pytest.raises(
         ValueError,
@@ -141,7 +137,7 @@ def test_ftov_msgs():
     with pytest.raises(
         ValueError,
         match=re.escape(
-            "Given belief shape (10,) does not match expected shape (15,) for variable (0, 15)."
+            f"Given belief shape (10,) does not match expected shape (15,) for variable (({vg.__hash__()}, 0), 15)."
         ),
     ):
         fg.bp_state.ftov_msgs[vg[0]] = np.ones(10)
@@ -159,13 +155,13 @@ def test_ftov_msgs():
 
 
 def test_evidence():
-    vg = vgroup.VariableDict(variable_names=("a",), num_states=15)
+    vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
     fg = graph.FactorGraph(vg)
     factor_group = enumeration.EnumerationFactorGroup(
-        variables_for_factors=[[vg["a"]]],
+        variables_for_factors=[[vg[0]]],
         factor_configs=np.arange(10)[:, None],
     )
-    fg.add_factor_group(factor_group)
+    fg.add_factors(factor_group)
 
     with pytest.raises(
         ValueError, match=re.escape("Expected evidence shape (15,). Got (10,).")
@@ -175,7 +171,7 @@ def test_evidence():
     evidence = graph.Evidence(fg_state=fg.fg_state, value=np.zeros(15))
     assert jnp.all(evidence.value == jnp.zeros(15))
 
-    vg2 = vgroup.VariableDict(variable_names=("b",), num_states=15)
+    vg2 = vgroup.VariableDict(variable_names=(0,), num_states=15)
     with pytest.raises(
         ValueError,
         match=re.escape(
@@ -184,7 +180,7 @@ def test_evidence():
     ):
         graph.update_evidence(
             jax.device_put(evidence.value),
-            {vg2["b"]: jax.device_put(np.zeros(15))},
+            {vg2[0]: jax.device_put(np.zeros(15))},
             fg.fg_state,
         )
 
@@ -196,7 +192,7 @@ def test_bp():
         variables_for_factors=[[vg[0]]],
         factor_configs=np.arange(10)[:, None],
     )
-    fg.add_factor_group(factor_group)
+    fg.add_factors(factor_group)
 
     bp = graph.BP(fg.bp_state, temperature=0)
     bp_arrays = bp.update()
diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py
index 24d24400..82a2ea06 100644
--- a/tests/fg/test_groups.py
+++ b/tests/fg/test_groups.py
@@ -5,7 +5,8 @@
 import numpy as np
 import pytest
 
-from pgmax.groups import enumeration
+from pgmax.fg import groups
+from pgmax.groups import enumeration, logical
 from pgmax.groups import variables as vgroup
 
 
@@ -19,10 +20,10 @@ def test_variable_dict():
     with pytest.raises(
         ValueError,
         match=re.escape(
-            "Variable (2, 15) expects a data array of shape (15,) or (1,). Got (10,)"
+            f"Variable (({variable_dict.__hash__()}, 2), 15) expects a data array of shape (15,) or (1,). Got (10,)"
         ),
     ):
-        variable_dict.flatten({(2, 15): np.zeros(10)})
+        variable_dict.flatten({((variable_dict.__hash__(), 2), 15): np.zeros(10)})
 
     with pytest.raises(
         ValueError, match="Can only unflatten 1D array. Got a 2D array."
@@ -35,7 +36,10 @@ def test_variable_dict():
                 jax.tree_util.tree_multimap(
                     lambda x, y: jnp.all(x == y),
                     variable_dict.unflatten(jnp.zeros(3)),
-                    {(name, 15): np.zeros(1) for name in range(3)},
+                    {
+                        ((variable_dict.__hash__(), name), 15): np.zeros(1)
+                        for name in range(3)
+                    },
                 )
             )
         )
@@ -92,6 +96,25 @@ def test_nd_variable_array():
     assert jnp.all(variable_group.unflatten(np.zeros(12)) == jnp.zeros((2, 2, 3)))
 
 
+def test_single_factor():
+    with pytest.raises(ValueError, match="Cannot create a FactorGroup with no Factor."):
+        logical.ORFactorGroup(variables_for_factors=[])
+
+    A = vgroup.NDVariableArray(num_states=2, shape=(10,))
+    B = vgroup.NDVariableArray(num_states=2, shape=(10,))
+
+    variables0 = (A[0], B[0])
+    variables1 = (A[1], B[1])
+    ORFactor = logical.ORFactorGroup(variables_for_factors=[variables0])
+    with pytest.raises(
+        ValueError, match="SingleFactorGroup should only contain one factor. Got 2"
+    ):
+        groups.SingleFactorGroup(
+            variables_for_factors=[variables0, variables1],
+            factor=ORFactor,
+        )
+
+
 def test_enumeration_factor_group():
     vg = vgroup.NDVariableArray(shape=(2, 2), num_states=3)
     vg_bis = vgroup.NDVariableArray(shape=(2, 2), num_states=3)
diff --git a/tests/fg/test_wiring.py b/tests/fg/test_wiring.py
index 3df31cb5..53efeee3 100644
--- a/tests/fg/test_wiring.py
+++ b/tests/fg/test_wiring.py
@@ -19,11 +19,11 @@ def test_wiring_with_PairwiseFactorGroup():
     B = vgroup.NDVariableArray(num_states=2, shape=(10,))
 
     # First test that compile_wiring enforces the correct factor_edges_num_states shape
-    fg = graph.FactorGraph(variables=[A, B])
+    fg = graph.FactorGraph(variable_groups=[A, B])
     factor_group = enumeration.PairwiseFactorGroup(
         variables_for_factors=[[A[idx], B[idx]] for idx in range(10)]
     )
-    fg.add_factor_group(factor_group)
+    fg.add_factors(factor_group)
 
     factor_group = fg.factor_groups[0]
     object.__setattr__(
@@ -36,31 +36,31 @@ def test_wiring_with_PairwiseFactorGroup():
         factor_group.compile_wiring(fg._vars_to_starts)
 
     # FactorGraph with a single PairwiseFactorGroup
-    fg1 = graph.FactorGraph(variables=[A, B])
+    fg1 = graph.FactorGraph(variable_groups=[A, B])
     factor_group = enumeration.PairwiseFactorGroup(
         variables_for_factors=[[A[idx], B[idx]] for idx in range(10)]
     )
-    fg1.add_factor_group(factor_group)
+    fg1.add_factors(factor_group)
     assert len(fg1.factor_groups) == 1
 
     # FactorGraph with multiple PairwiseFactorGroup
-    fg2 = graph.FactorGraph(variables=[A, B])
+    fg2 = graph.FactorGraph(variable_groups=[A, B])
     for idx in range(10):
         factor_group = enumeration.PairwiseFactorGroup(
             variables_for_factors=[[A[idx], B[idx]]]
         )
-        fg2.add_factor_group(factor_group)
+        fg2.add_factors(factor_group)
     assert len(fg2.factor_groups) == 10
 
     # FactorGraph with multiple SingleFactorGroup
-    fg3 = graph.FactorGraph(variables=[A, B])
+    fg3 = graph.FactorGraph(variable_groups=[A, B])
     for idx in range(10):
         factor = enumeration_factor.EnumerationFactor(
             variables=[A[idx], B[idx]],
             factor_configs=np.array([[0, 0], [0, 1], [1, 0], [1, 1]]),
             log_potentials=np.zeros((4,)),
         )
-        fg3.add_factor(factor)
+        fg3.add_factors(factor=factor)
     assert len(fg3.factor_groups) == 10
 
     assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors)
@@ -95,29 +95,29 @@ def test_wiring_with_ORFactorGroup():
     C = vgroup.NDVariableArray(num_states=2, shape=(10,))
 
     # FactorGraph with a single ORFactorGroup
-    fg1 = graph.FactorGraph(variables=[A, B, C])
+    fg1 = graph.FactorGraph(variable_groups=[A, B, C])
     factor_group = logical.ORFactorGroup(
         variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)],
     )
-    fg1.add_factor_group(factor_group)
+    fg1.add_factors(factor_group)
     assert len(fg1.factor_groups) == 1
 
     # FactorGraph with multiple ORFactorGroup
-    fg2 = graph.FactorGraph(variables=[A, B, C])
+    fg2 = graph.FactorGraph(variable_groups=[A, B, C])
     for idx in range(10):
         factor_group = logical.ORFactorGroup(
             variables_for_factors=[[A[idx], B[idx], C[idx]]],
         )
-        fg2.add_factor_group(factor_group)
+        fg2.add_factors(factor_group)
     assert len(fg2.factor_groups) == 10
 
     # FactorGraph with multiple SingleFactorGroup
-    fg3 = graph.FactorGraph(variables=[A, B, C])
+    fg3 = graph.FactorGraph(variable_groups=[A, B, C])
     for idx in range(10):
         factor = logical_factor.ORFactor(
             variables=[A[idx], B[idx], C[idx]],
         )
-        fg3.add_factor(factor)
+        fg3.add_factors(factor=factor)
     assert len(fg3.factor_groups) == 10
 
     assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors)
@@ -150,29 +150,29 @@ def test_wiring_with_ANDFactorGroup():
     C = vgroup.NDVariableArray(num_states=2, shape=(10,))
 
     # FactorGraph with a single ANDFactorGroup
-    fg1 = graph.FactorGraph(variables=[A, B, C])
+    fg1 = graph.FactorGraph(variable_groups=[A, B, C])
     factor_group = logical.ANDFactorGroup(
         variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)],
     )
-    fg1.add_factor_group(factor_group)
+    fg1.add_factors(factor_group)
     assert len(fg1.factor_groups) == 1
 
     # FactorGraph with multiple ANDFactorGroup
-    fg2 = graph.FactorGraph(variables=[A, B, C])
+    fg2 = graph.FactorGraph(variable_groups=[A, B, C])
     for idx in range(10):
         factor_group = logical.ANDFactorGroup(
             variables_for_factors=[[A[idx], B[idx], C[idx]]],
         )
-        fg2.add_factor_group(factor_group)
+        fg2.add_factors(factor_group)
     assert len(fg2.factor_groups) == 10
 
     # FactorGraph with multiple SingleFactorGroup
-    fg3 = graph.FactorGraph(variables=[A, B, C])
+    fg3 = graph.FactorGraph(variable_groups=[A, B, C])
     for idx in range(10):
         factor = logical_factor.ANDFactor(
             variables=[A[idx], B[idx], C[idx]],
         )
-        fg3.add_factor(factor)
+        fg3.add_factors(factor=factor)
     assert len(fg3.factor_groups) == 10
 
     assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors)
diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py
index c6a030d6..643448a3 100644
--- a/tests/test_pgmax.py
+++ b/tests/test_pgmax.py
@@ -258,7 +258,7 @@ def create_valid_suppression_config_arr(suppression_diameter):
                         pass
 
     # Create the factor graph
-    fg = graph.FactorGraph(variables=[grid_vars, additional_vars])
+    fg = graph.FactorGraph(variable_groups=[grid_vars, additional_vars])
 
     # Imperatively add EnumerationFactorGroups (each consisting of just one EnumerationFactor) to
     # the graph!
@@ -302,7 +302,7 @@ def create_valid_suppression_config_arr(suppression_diameter):
                         valid_configs_non_supp.shape[0], dtype=float
                     ),
                 )
-                fg.add_factor(factor)
+                fg.add_factors(factor=factor)
             else:
                 factor = EnumerationFactor(
                     variables=curr_vars,
@@ -311,7 +311,7 @@ def create_valid_suppression_config_arr(suppression_diameter):
                         valid_configs_non_supp.shape[0], dtype=float
                     ),
                 )
-                fg.add_factor(factor)
+                fg.add_factors(factor=factor)
 
     # Create an EnumerationFactorGroup for vertical suppression factors
     vert_suppression_vars: List[List[Tuple[Any, ...]]] = []
@@ -355,14 +355,14 @@ def create_valid_suppression_config_arr(suppression_diameter):
         variables_for_factors=vert_suppression_vars,
         factor_configs=valid_configs_supp,
     )
-    fg.add_factor_group(factor_group)
+    fg.add_factors(factor_group)
 
     factor_group = enumeration.EnumerationFactorGroup(
         variables_for_factors=horz_suppression_vars,
         factor_configs=valid_configs_supp,
         log_potentials=np.zeros(valid_configs_supp.shape[0], dtype=float),
     )
-    fg.add_factor_group(factor_group)
+    fg.add_factors(factor_group)
 
     # Run BP
     # Set the evidence
@@ -419,7 +419,7 @@ def binary_connected_variables(
                 variables_for_factors=binary_connected_variables(28, 28, k_row, k_col),
                 log_potential_matrix=W_pot[:, :, k_row, k_col],
             )
-            fg.add_factor_group(factor_group)
+            fg.add_factors(factor_group)
 
     # Assign evidence to pixel vars
     bp_state = fg.bp_state

From 0f639fe6fcce3422c970e9aeea64a31a100b558c Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Tue, 26 Apr 2022 18:58:24 +0000
Subject: [PATCH 21/35] Stannis' comments

---
 examples/gmrf.py             |  8 -----
 pgmax/factors/enumeration.py |  4 +--
 pgmax/factors/logical.py     |  4 +--
 pgmax/fg/graph.py            | 66 ++++++++++--------------------------
 pgmax/fg/groups.py           |  9 ++---
 pgmax/fg/nodes.py            |  4 +--
 pgmax/groups/logical.py      |  7 ++--
 pgmax/groups/variables.py    | 22 ++++++------
 tests/factors/test_or.py     | 10 +++---
 tests/fg/test_graph.py       |  4 +--
 tests/fg/test_groups.py      | 20 +++++++++--
 tests/test_pgmax.py          |  8 ++---
 12 files changed, 71 insertions(+), 95 deletions(-)

diff --git a/examples/gmrf.py b/examples/gmrf.py
index eb86a5e9..1bf4fad9 100644
--- a/examples/gmrf.py
+++ b/examples/gmrf.py
@@ -233,11 +233,3 @@ def update(step, batch_noisy_images, batch_target_images, opt_state):
             )
             pbar.update()
             pbar.set_postfix(loss=value)
-
-batch_indices = indices[idx * batch_size : (idx + 1) * batch_size]
-batch_noisy_images, batch_target_images = (
-    noisy_images_train[:10],
-    target_images_train[:10],
-)
-step = 0
-value, opt_state = update(step, batch_noisy_images, batch_target_images, opt_state)
diff --git a/pgmax/factors/enumeration.py b/pgmax/factors/enumeration.py
index d461d3ab..5f776ba0 100644
--- a/pgmax/factors/enumeration.py
+++ b/pgmax/factors/enumeration.py
@@ -2,7 +2,7 @@
 
 import functools
 from dataclasses import dataclass
-from typing import List, Mapping, Sequence, Tuple, Union
+from typing import Any, List, Mapping, Sequence, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -156,7 +156,7 @@ def concatenate_wirings(wirings: Sequence[EnumerationWiring]) -> EnumerationWiri
     def compile_wiring(
         variables_for_factors: Sequence[List],
         factor_configs: np.ndarray,
-        vars_to_starts: Mapping[Tuple[int, int], int],
+        vars_to_starts: Mapping[Tuple[Any, int], int],
         num_factors: int,
     ) -> EnumerationWiring:
         """Compile an EnumerationWiring for an EnumerationFactor or a FactorGroup with EnumerationFactors.
diff --git a/pgmax/factors/logical.py b/pgmax/factors/logical.py
index bf8652cb..b78ea206 100644
--- a/pgmax/factors/logical.py
+++ b/pgmax/factors/logical.py
@@ -2,7 +2,7 @@
 
 import functools
 from dataclasses import dataclass, field
-from typing import List, Mapping, Optional, Sequence, Tuple, Union
+from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -141,7 +141,7 @@ def concatenate_wirings(wirings: Sequence[LogicalWiring]) -> LogicalWiring:
     @staticmethod
     def compile_wiring(
         variables_for_factors: Sequence[List],
-        vars_to_starts: Mapping[Tuple[int, int], int],
+        vars_to_starts: Mapping[Tuple[Any, int], int],
         edge_states_offset: int,
     ) -> LogicalWiring:
         """Compile a LogicalWiring for a LogicalFactor or a FactorGroup with LogicalFactors.
diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index 9fbe794c..2f16f2eb 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -6,7 +6,6 @@
 import copy
 import functools
 import inspect
-import typing
 from dataclasses import asdict, dataclass
 from types import MappingProxyType
 from typing import (
@@ -44,12 +43,7 @@ class FactorGraph:
     Factors in a graph are clustered in factor groups, which are grouped according to their factor types.
 
     Args:
-        variables: A single VariableGroup or a container containing variable groups.
-            If not a single VariableGroup, supported containers include mapping and sequence.
-            For a mapping, the keys of the mapping are used to index the variable groups.
-            For a sequence, the indices of the sequence are used to index the variable groups.
-            Note that if not a single VariableGroup, a CompositeVariableGroup will be created from
-            this input, and the individual VariableGroups will need to be accessed by indexing.
+        variable_groups: A single VariableGroup or a list of VariableGroups.
     """
 
     variable_groups: Union[groups.VariableGroup, Sequence[groups.VariableGroup]]
@@ -103,7 +97,7 @@ def add_factors(
 
         Args:
             factor_group: The FactorGroup to be added to the FactorGraph.
-            factor: The Factor to be added to the factor graph.
+            factor: The Factor to be added to the FactorGraph.
 
         Raises:
             ValueError: If
@@ -318,13 +312,13 @@ class FactorGraphState:
     """FactorGraphState.
 
     Args:
-        variable_groups: All the variable groups in the FactorGraph.
+        variable_groups: VariableGroups in the FactorGraph.
         vars_to_starts: Maps variables to their starting indices in the flat evidence array.
             flat_evidence[vars_to_starts[variable]: vars_to_starts[variable] + variable.num_var_states]
             contains evidence to the variable.
         num_var_states: Total number of variable states.
         total_factor_num_states: Size of the flat ftov messages array.
-        factor_groups: Factor groups in the FactorGraph
+        factor_groups: FactorGroups in the FactorGraph
         factor_type_to_msgs_range: Maps factors types to their start and end indices in the flat ftov messages.
         factor_type_to_potentials_range: Maps factor types to their start and end indices in the flat log potentials.
         factor_group_to_potentials_starts: Maps factor groups to their starting indices in the flat log potentials.
@@ -333,7 +327,7 @@ class FactorGraphState:
     """
 
     variable_groups: Sequence[groups.VariableGroup]
-    vars_to_starts: Mapping[Tuple[int, int], int]
+    vars_to_starts: Mapping[Tuple[Any, int], int]
     num_var_states: int
     total_factor_num_states: int
     factor_groups: Tuple[groups.FactorGroup, ...]
@@ -474,12 +468,11 @@ def __setitem__(
         factor_group: Any,
         data: Union[np.ndarray, jnp.ndarray],
     ):
-        """Set the log potentials for a named factor group or a factor.
+        """Set the log potentials for a FactorGroup
 
         Args:
             factor_group: FactorGroup
-            data: Array containing the log potentials for the named factor group
-                or the factor.
+            data: Array containing the log potentials for the FactorGroup
         """
         object.__setattr__(
             self,
@@ -510,7 +503,7 @@ def update_ftov_msgs(
 
     Raises: ValueError if:
         (1) provided ftov_msgs shape does not match the expected ftov_msgs shape.
-        (2) provided name is not valid for ftov_msgs updates.
+        (2) provided variable is not in the FactorGraph.
     """
     for variable, data in updates.items():
         if variable in fg_state.vars_to_starts:
@@ -535,14 +528,7 @@ def update_ftov_msgs(
                     data / starts.shape[0]
                 )
         else:
-            raise ValueError(
-                "Invalid names for setting messages. "
-                "Supported names include a tuple of length 2 with factor "
-                "and variable names for directly setting factor to variable "
-                "messages, or a valid variable name for spreading expected "
-                "beliefs at a variable"
-            )
-
+            raise ValueError("Provided variable is not in the FactorGraph")
     return ftov_msgs
 
 
@@ -574,38 +560,19 @@ def __post_init__(self):
 
             object.__setattr__(self, "value", self.value)
 
-    @typing.overload
-    def __setitem__(
-        self,
-        names: Tuple[Any, Any],
-        data: Union[np.ndarray, jnp.ndarray],
-    ) -> None:
-        """Setting messages from a factor to a variable
-
-        Args:
-            names: A tuple of length 2
-                names[0] is the name of the factor
-                names[1] is the name of the variable
-            data: An array containing messages from factor names[0]
-                to variable names[1]
-        """
-
-    @typing.overload
     def __setitem__(
         self,
-        variable: Tuple[int, int],
+        variable: Tuple[Any, int],
         data: Union[np.ndarray, jnp.ndarray],
     ) -> None:
-        """Spreading beliefs at a variable to all connected factors
+        """Spreading beliefs at a variable to all connected Factors
 
         Args:
             variable: A tuple representing a variable
             data: An array containing the beliefs to be spread uniformly
-                across all factor to variable messages involving this
-                variable.
+                across all factors to variable messages involving this variable.
         """
 
-    def __setitem__(self, variable, data) -> None:
         object.__setattr__(
             self,
             "value",
@@ -647,7 +614,7 @@ def update_evidence(
             evidence = evidence.at[start_index : start_index + name[1]].set(data)
         else:
             raise ValueError(
-                "Got evidence for a variable or a variable group not in the FactorGraph!"
+                "Got evidence for a variable or a VariableGroup not in the FactorGraph!"
             )
     return evidence
 
@@ -678,7 +645,7 @@ def __post_init__(self):
 
             object.__setattr__(self, "value", self.value)
 
-    def __getitem__(self, variable: Tuple[int, int]) -> np.ndarray:
+    def __getitem__(self, variable: Tuple[Any, int]) -> np.ndarray:
         """Function to query evidence for a variable
 
         Args:
@@ -969,8 +936,9 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]:
         beliefs = {}
         start = 0
         for variable_group in variable_groups:
-            variables = variable_group.variables
-            length = sum([variable[1] for variable in variables])
+            num_states = variable_group.num_states
+            assert isinstance(num_states, np.ndarray)
+            length = num_states.sum()
 
             beliefs[variable_group] = variable_group.unflatten(
                 flat_beliefs[start : start + length]
diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index 62692ffd..15f50603 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -5,6 +5,7 @@
 from functools import total_ordering
 from typing import (
     Any,
+    Collection,
     FrozenSet,
     List,
     Mapping,
@@ -21,7 +22,7 @@
 import pgmax.fg.nodes as nodes
 from pgmax.utils import cached_property
 
-MAX_SIZE = 1e16
+MAX_SIZE = 1e10
 
 
 @total_ordering
@@ -40,7 +41,7 @@ def __eq__(self, other):
     def __lt__(self, other):
         return hash(self) < hash(other)
 
-    def __getitem__(self, val):
+    def __getitem__(self, val: Any) -> Union[Tuple[Any, int], List[Tuple[Any, int]]]:
         """Given a variable name, index, or a group of variable indices, retrieve the associated variable(s).
         Each variable is returned via a tuple of the form (variable hash/name, number of states)
 
@@ -129,7 +130,7 @@ def __eq__(self, other):
     def __lt__(self, other):
         return hash(self) < hash(other)
 
-    def __getitem__(self, variables: Sequence[Tuple[int, int]]) -> Any:
+    def __getitem__(self, variables: Union[Sequence, Collection]) -> Any:
         """Function to query individual factors in the factor group
 
         Args:
@@ -228,7 +229,7 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any:
             "Please subclass the FactorGroup class and override this method"
         )
 
-    def compile_wiring(self, vars_to_starts: Mapping[Tuple[int, int], int]) -> Any:
+    def compile_wiring(self, vars_to_starts: Mapping[Tuple[Any, int], int]) -> Any:
         """Compile an efficient wiring for the FactorGroup.
 
         Args:
diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py
index 06d7c5d6..f987b8df 100644
--- a/pgmax/fg/nodes.py
+++ b/pgmax/fg/nodes.py
@@ -1,7 +1,7 @@
 """A module containing classes that specify the basic components of a Factor Graph."""
 
 from dataclasses import asdict, dataclass
-from typing import List, Sequence, Tuple, Union
+from typing import Any, List, Sequence, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -48,7 +48,7 @@ class Factor:
         NotImplementedError: If compile_wiring is not implemented
     """
 
-    variables: List[Tuple[int, int]]
+    variables: List[Tuple[Any, int]]
     log_potentials: np.ndarray
 
     def __post_init__(self):
diff --git a/pgmax/groups/logical.py b/pgmax/groups/logical.py
index 83115945..618a0526 100644
--- a/pgmax/groups/logical.py
+++ b/pgmax/groups/logical.py
@@ -4,6 +4,8 @@
 from dataclasses import dataclass, field
 from typing import FrozenSet, OrderedDict, Type
 
+import numpy as np
+
 from pgmax.factors import logical
 from pgmax.fg import groups
 
@@ -20,12 +22,9 @@ class LogicalFactorGroup(groups.FactorGroup):
             For ORFactors the edge_states_offset is 1, for ANDFactors the edge_states_offset is -1.
     """
 
+    factor_configs: np.ndarray = field(init=False, default=None)
     edge_states_offset: int = field(init=False)
 
-    def __post_init__(self):
-        super().__post_init__()
-        object.__setattr__(self, "factor_configs", None)
-
     def _get_variables_to_factors(
         self,
     ) -> OrderedDict[FrozenSet, logical.LogicalFactor]:
diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index 7b50167e..57570a1a 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -26,9 +26,9 @@ class NDVariableArray(groups.VariableGroup):
     num_states: Union[int, np.ndarray]
 
     def __post_init__(self):
-        if np.prod(self.shape) > groups.MAX_SIZE:
+        if np.prod(self.shape) > int(groups.MAX_SIZE):
             raise ValueError(
-                f"Currently only support NDVariableArray of size smaller than {groups.MAX_SIZE}. Got {np.prod(self.shape)}"
+                f"Currently only support NDVariableArray of size smaller than {int(groups.MAX_SIZE)}. Got {np.prod(self.shape)}"
             )
 
         if np.isscalar(self.num_states):
@@ -42,7 +42,9 @@ def __post_init__(self):
                     f"Expected num_states shape {self.shape}. Got {self.num_states.shape}."
                 )
         else:
-            raise ValueError("num_states entries should be of type np.int")
+            raise ValueError(
+                "num_states should be an integer or a NumPy array of dtype int"
+            )
 
     def __getitem__(
         self, val: Union[int, slice, Tuple]
@@ -94,7 +96,7 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
 
         Args:
             data: Meaningful structured data. Should be an array of shape self.shape (for e.g. MAP decodings)
-                or self.shape + (self.num_states,) (for e.g. evidence, beliefs).
+                or self.shape + (self.num_states.max(),) (for e.g. evidence, beliefs).
 
         Returns:
             A flat jnp.array for internal use
@@ -158,7 +160,8 @@ class VariableDict(groups.VariableGroup):
 
     Args:
         num_states: The size of the variables in this variable group
-        variable_names: A tuple of all names of the variables in this variable group
+        variable_names: A tuple of all names of the variables in this variable group.
+            Note that we overwrite variable_names to add the hash of the VariableDict
     """
 
     variable_names: Tuple[Any, ...]
@@ -176,7 +179,7 @@ def __post_init__(self):
         object.__setattr__(self, "variable_names", hash_and_names)
 
     @cached_property
-    def variables(self) -> List[Tuple]:
+    def variables(self) -> List[Tuple[Tuple[Any, int], int]]:
         """Function that returns the list of all variables in the VariableGroup.
         Each variable is represented by a tuple of the form (variable name, number of states)
 
@@ -188,24 +191,23 @@ def variables(self) -> List[Tuple]:
         vars_num_states = self.num_states.flatten()
         return list(zip(vars_names, vars_num_states))
 
-    def __getitem__(self, val: Any) -> Tuple[Any, int]:
+    def __getitem__(self, val: Any) -> Tuple[Tuple[Any, int], int]:
         """Given a variable name retrieve the associated variable, returned via a tuple of the form
         (variable name, number of states)
 
         Args:
-            val: a variable index or slice
+            val: a variable name
 
         Returns:
             The queried variable
         """
         assert isinstance(self.num_states, np.ndarray)
-
         if (self.__hash__(), val) not in self.variable_names:
             raise ValueError(f"Variable {val} is not in VariableDict")
         return ((self.__hash__(), val), self.num_states[0])
 
     def flatten(
-        self, data: Mapping[Tuple[int, int], Union[np.ndarray, jnp.ndarray]]
+        self, data: Mapping[Tuple[Tuple[int, int], int], Union[np.ndarray, jnp.ndarray]]
     ) -> jnp.ndarray:
         """Function that turns meaningful structured data into a flat data array for internal use.
 
diff --git a/tests/factors/test_or.py b/tests/factors/test_or.py
index 16236b35..f29dccd0 100644
--- a/tests/factors/test_or.py
+++ b/tests/factors/test_or.py
@@ -102,7 +102,7 @@ def test_run_bp_with_ORFactors():
                     factor_configs=valid_configs,
                     log_potentials=np.zeros(valid_configs.shape[0]),
                 )
-                fg1.add_factor(factor=enum_factor)
+                fg1.add_factors(factor=enum_factor)
             else:
                 if idx != 0:
                     # Add the second half of factors to FactorGraph2
@@ -111,7 +111,7 @@ def test_run_bp_with_ORFactors():
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
-                    fg2.add_factor(factor=enum_factor)
+                    fg2.add_factors(factor=enum_factor)
                 else:
                     # Add all the EnumerationFactors to FactorGraph1 for the first iter
                     enum_factor = EnumerationFactor(
@@ -119,7 +119,7 @@ def test_run_bp_with_ORFactors():
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
-                    fg1.add_factor(factor=enum_factor)
+                    fg1.add_factors(factor=enum_factor)
 
         # Option 2: Define the ORFactors
         num_parents_cumsum = np.insert(np.cumsum(num_parents), 0, 0)
@@ -143,10 +143,10 @@ def test_run_bp_with_ORFactors():
                     )
         if idx != 0:
             factor_group = logical.ORFactorGroup(variables_for_ORFactors_fg1)
-            fg1.add_factor_group(factor_group)
+            fg1.add_factors(factor_group)
 
         factor_group = logical.ORFactorGroup(variables_for_ORFactors_fg2)
-        fg2.add_factor_group(factor_group)
+        fg2.add_factors(factor_group)
 
         # Run inference
         bp1 = graph.BP(fg1.bp_state, temperature=temperature)
diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py
index bb26ce99..df734ac2 100644
--- a/tests/fg/test_graph.py
+++ b/tests/fg/test_graph.py
@@ -130,7 +130,7 @@ def test_ftov_msgs():
 
     with pytest.raises(
         ValueError,
-        match=re.escape("Invalid names for setting messages"),
+        match=re.escape("Provided variable is not in the FactorGraph"),
     ):
         fg.bp_state.ftov_msgs[0] = np.ones(10)
 
@@ -175,7 +175,7 @@ def test_evidence():
     with pytest.raises(
         ValueError,
         match=re.escape(
-            "Got evidence for a variable or a variable group not in the FactorGraph!"
+            "Got evidence for a variable or a VariableGroup not in the FactorGraph!"
         ),
     ):
         graph.update_evidence(
diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py
index 82a2ea06..322147c4 100644
--- a/tests/fg/test_groups.py
+++ b/tests/fg/test_groups.py
@@ -54,6 +54,15 @@ def test_variable_dict():
 
 
 def test_nd_variable_array():
+    max_size = int(groups.MAX_SIZE)
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            f"Currently only support NDVariableArray of size smaller than {max_size}. Got {max_size + 1}"
+        ),
+    ):
+        vgroup.NDVariableArray(shape=(max_size + 1,), num_states=2)
+
     num_states = np.full((2, 3), fill_value=2)
     with pytest.raises(
         ValueError, match=re.escape("Expected num_states shape (2, 2). Got (2, 3).")
@@ -62,14 +71,19 @@ def test_nd_variable_array():
 
     num_states = np.full((2, 3), fill_value=2, dtype=np.float32)
     with pytest.raises(
-        ValueError, match=re.escape("num_states entries should be of type np.int")
+        ValueError,
+        match=re.escape(
+            "num_states should be an integer or a NumPy array of dtype int"
+        ),
     ):
         vgroup.NDVariableArray(shape=(2, 2), num_states=num_states)
 
-    variable_group = vgroup.NDVariableArray(shape=(5, 5), num_states=2)
-    assert len(variable_group[:3, :3]) == 9
+    variable_group0 = vgroup.NDVariableArray(shape=(5, 5), num_states=2)
+    assert len(variable_group0[:3, :3]) == 9
 
     variable_group = vgroup.NDVariableArray(shape=(2, 2), num_states=3)
+    variable_group0 < variable_group
+
     with pytest.raises(
         ValueError,
         match=re.escape("data should be of shape (2, 2) or (2, 2, 3). Got (3, 3)."),
diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py
index 643448a3..9e6deb65 100644
--- a/tests/test_pgmax.py
+++ b/tests/test_pgmax.py
@@ -218,10 +218,10 @@ def create_valid_suppression_config_arr(suppression_diameter):
         (grid_vars, (1, 0, 1)): 0,
         (grid_vars, (1, 1, 0)): 1,
         (grid_vars, (1, 1, 1)): 0,
-        (additional_vars, ((0, 0, 2), 3)): 0,
-        (additional_vars, ((0, 1, 2), 3)): 2,
-        (additional_vars, ((1, 2, 0), 3)): 1,
-        (additional_vars, ((1, 2, 1), 3)): 0,
+        (additional_vars, ((additional_vars.__hash__(), (0, 0, 2)), 3)): 0,
+        (additional_vars, ((additional_vars.__hash__(), (0, 1, 2)), 3)): 2,
+        (additional_vars, ((additional_vars.__hash__(), (1, 2, 0)), 3)): 1,
+        (additional_vars, ((additional_vars.__hash__(), (1, 2, 1)), 3)): 0,
     }
 
     gt_has_cuts = gt_has_cuts.astype(np.int32)

From 546b7906e694a29677b065b49781f6251ac049ff Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Tue, 26 Apr 2022 19:39:36 +0000
Subject: [PATCH 22/35] Flattent / unflattent

---
 pgmax/fg/graph.py         |  2 +-
 pgmax/groups/variables.py | 22 ++++++++++++----------
 tests/fg/test_graph.py    |  4 ++--
 tests/fg/test_groups.py   | 23 +++++++++++++----------
 4 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index 2f16f2eb..67ca4477 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -460,7 +460,7 @@ def __getitem__(self, factor_group: groups.FactorGroup) -> np.ndarray:
                 start : start + factor_group.factor_group_log_potentials.shape[0]
             ]
         else:
-            raise ValueError("Invalid FactorGroup for log potentials updates.")
+            raise ValueError("Invalid FactorGroup queried to access log potentials.")
         return log_potentials
 
     def __setitem__(
diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index 57570a1a..ad03e966 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -106,19 +106,19 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
         """
         assert isinstance(self.num_states, np.ndarray)
 
-        # TODO: what should we do for different number of states -> look at mask_array
-        if data.shape != self.shape and data.shape != self.shape + (
-            self.num_states.max(),
-        ):
+        if data.shape == self.shape:
+            return jax.device_put(data).flatten()
+        elif data.shape == self.shape + (self.num_states.max(),):
+            return jax.device_put(
+                data[np.arange(data.shape[-1]) < self.num_states[..., None]]
+            )
+        else:
             raise ValueError(
                 f"data should be of shape {self.shape} or {self.shape + (self.num_states.max(),)}. "
                 f"Got {data.shape}."
             )
-        return jax.device_put(data).flatten()
 
-    def unflatten(
-        self, flat_data: Union[np.ndarray, jnp.ndarray]
-    ) -> Union[np.ndarray, jnp.ndarray]:
+    def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
         """Function that recovers meaningful structured data from internal flat data array
 
         Args:
@@ -143,8 +143,10 @@ def unflatten(
         if flat_data.size == np.product(self.shape):
             data = flat_data.reshape(self.shape)
         elif flat_data.size == self.num_states.sum():
-            # TODO: what should we do for different number of states
-            data = flat_data.reshape(self.shape + (self.num_states.max(),))
+            data = jnp.zeros(self.shape + (self.num_states.max(),))
+            data = data.at[np.arange(data.shape[-1]) < self.num_states[..., None]].set(
+                flat_data
+            )
         else:
             raise ValueError(
                 f"flat_data should be compatible with shape {self.shape} or {self.shape + (self.num_states.max(),)}. "
diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py
index df734ac2..0fdadbb2 100644
--- a/tests/fg/test_graph.py
+++ b/tests/fg/test_graph.py
@@ -106,9 +106,9 @@ def test_log_potentials():
 
     with pytest.raises(
         ValueError,
-        match=re.escape("Invalid FactorGroup for log potentials updates."),
+        match=re.escape("Invalid FactorGroup queried to access log potentials."),
     ):
-        fg.bp_state.log_potentials[vg[0]] = np.zeros(10)
+        fg.bp_state.log_potentials[vg[0]]
 
     with pytest.raises(
         ValueError, match=re.escape("Expected log potentials shape (10,). Got (15,)")
diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py
index 322147c4..8b5d7a44 100644
--- a/tests/fg/test_groups.py
+++ b/tests/fg/test_groups.py
@@ -81,18 +81,22 @@ def test_nd_variable_array():
     variable_group0 = vgroup.NDVariableArray(shape=(5, 5), num_states=2)
     assert len(variable_group0[:3, :3]) == 9
 
-    variable_group = vgroup.NDVariableArray(shape=(2, 2), num_states=3)
+    variable_group = vgroup.NDVariableArray(
+        shape=(2, 2), num_states=np.array([[1, 2], [3, 4]])
+    )
     variable_group0 < variable_group
 
     with pytest.raises(
         ValueError,
-        match=re.escape("data should be of shape (2, 2) or (2, 2, 3). Got (3, 3)."),
+        match=re.escape("data should be of shape (2, 2) or (2, 2, 4). Got (3, 3)."),
     ):
         variable_group.flatten(np.zeros((3, 3)))
 
     assert jnp.all(
         variable_group.flatten(np.array([[1, 2], [3, 4]])) == jnp.array([1, 2, 3, 4])
     )
+    assert jnp.all(variable_group.flatten(np.zeros((2, 2, 4))) == jnp.zeros((10,)))
+
     with pytest.raises(
         ValueError, match="Can only unflatten 1D array. Got a 2D array."
     ):
@@ -101,13 +105,13 @@ def test_nd_variable_array():
     with pytest.raises(
         ValueError,
         match=re.escape(
-            "flat_data should be compatible with shape (2, 2) or (2, 2, 3). Got (10,)."
+            "flat_data should be compatible with shape (2, 2) or (2, 2, 4). Got (12,)."
         ),
     ):
-        variable_group.unflatten(np.zeros((10,)))
+        variable_group.unflatten(np.zeros((12,)))
 
     assert jnp.all(variable_group.unflatten(np.zeros(4)) == jnp.zeros((2, 2)))
-    assert jnp.all(variable_group.unflatten(np.zeros(12)) == jnp.zeros((2, 2, 3)))
+    assert jnp.all(variable_group.unflatten(np.zeros(10)) == jnp.zeros((2, 2, 4)))
 
 
 def test_single_factor():
@@ -119,21 +123,20 @@ def test_single_factor():
 
     variables0 = (A[0], B[0])
     variables1 = (A[1], B[1])
-    ORFactor = logical.ORFactorGroup(variables_for_factors=[variables0])
+    ORFactor0 = logical.ORFactorGroup(variables_for_factors=[variables0])
     with pytest.raises(
         ValueError, match="SingleFactorGroup should only contain one factor. Got 2"
     ):
         groups.SingleFactorGroup(
             variables_for_factors=[variables0, variables1],
-            factor=ORFactor,
+            factor=ORFactor0,
         )
+    ORFactor1 = logical.ORFactorGroup(variables_for_factors=[variables1])
+    ORFactor0 < ORFactor1
 
 
 def test_enumeration_factor_group():
     vg = vgroup.NDVariableArray(shape=(2, 2), num_states=3)
-    vg_bis = vgroup.NDVariableArray(shape=(2, 2), num_states=3)
-    vg < vg_bis
-
     with pytest.raises(
         ValueError,
         match=re.escape("Expected log potentials shape: (1,) or (2, 1). Got (3, 2)"),

From 35ce6c053e8a6940aefc8ad545accfec138cdcd3 Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Tue, 26 Apr 2022 19:48:05 +0000
Subject: [PATCH 23/35] Unflatten with nan

---
 pgmax/groups/variables.py | 4 +++-
 tests/fg/test_groups.py   | 8 +++++++-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index ad03e966..887a32a7 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -143,7 +143,9 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
         if flat_data.size == np.product(self.shape):
             data = flat_data.reshape(self.shape)
         elif flat_data.size == self.num_states.sum():
-            data = jnp.zeros(self.shape + (self.num_states.max(),))
+            data = jnp.full(
+                shape=self.shape + (self.num_states.max(),), fill_value=jnp.nan
+            )
             data = data.at[np.arange(data.shape[-1]) < self.num_states[..., None]].set(
                 flat_data
             )
diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py
index 8b5d7a44..b65a8ce5 100644
--- a/tests/fg/test_groups.py
+++ b/tests/fg/test_groups.py
@@ -111,7 +111,13 @@ def test_nd_variable_array():
         variable_group.unflatten(np.zeros((12,)))
 
     assert jnp.all(variable_group.unflatten(np.zeros(4)) == jnp.zeros((2, 2)))
-    assert jnp.all(variable_group.unflatten(np.zeros(10)) == jnp.zeros((2, 2, 4)))
+    unflattened = jnp.full((2, 2, 4), fill_value=jnp.nan)
+    unflattened = unflattened.at[0, 0, 0].set(0)
+    unflattened = unflattened.at[0, 1, :1].set(0)
+    unflattened = unflattened.at[1, 0, :2].set(0)
+    unflattened = unflattened.at[1, 1].set(0)
+    mask = ~jnp.isnan(unflattened)
+    assert jnp.all(variable_group.unflatten(np.zeros(10))[mask] == unflattened[mask])
 
 
 def test_single_factor():

From 704ee3a6c978f0cbbb8a8195a9db8d95f228db2a Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Wed, 27 Apr 2022 01:30:37 +0000
Subject: [PATCH 24/35] Speeding up

---
 examples/pmp_binary_deconvolution.py |  1 +
 examples/rbm.py                      |  5 +--
 pgmax/factors/enumeration.py         |  9 +++--
 pgmax/factors/logical.py             | 26 +++++++++-----
 pgmax/fg/graph.py                    |  7 ++--
 pgmax/fg/groups.py                   | 52 +++++++++++++++++++---------
 pgmax/groups/enumeration.py          | 22 ++++++++----
 pgmax/groups/variables.py            | 27 +++++++++++----
 8 files changed, 102 insertions(+), 47 deletions(-)

diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py
index a0f24406..fd396f3a 100644
--- a/examples/pmp_binary_deconvolution.py
+++ b/examples/pmp_binary_deconvolution.py
@@ -211,6 +211,7 @@ def plot_images(images, display=True, nr=None):
 # in the same manner does not change X, so this naturally results in multiple equivalent modes.
 
 # %%
+# %load_ext snakeviz
 start = time.time()
 bp = graph.BP(fg.bp_state, temperature=0.0)
 print("Time", time.time() - start)
diff --git a/examples/rbm.py b/examples/rbm.py
index cdc51390..cde32960 100644
--- a/examples/rbm.py
+++ b/examples/rbm.py
@@ -90,9 +90,10 @@
     ],
     log_potential_matrix=log_potential_matrix,
 )
-fg.add_factors(pairwise_factors)
+print("Time", time.time() - start)
+# #%snakeviz pairwise_factors = enumeration.PairwiseFactorGroup(variables_for_factors=[ [hidden_variables[ii], visible_variables[jj]] for ii in range(bh.shape[0]) for jj in range(bv.shape[0])],log_potential_matrix=log_potential_matrix,)
 
-# # %snakeviz fg.add_factors(factory=enumeration.PairwiseFactorGroup, variables_for_factors=v, log_potential_matrix=log_potential_matrix,)
+fg.add_factors(pairwise_factors)
 print("Time", time.time() - start)
 
 
diff --git a/pgmax/factors/enumeration.py b/pgmax/factors/enumeration.py
index 5f776ba0..ef2a99a7 100644
--- a/pgmax/factors/enumeration.py
+++ b/pgmax/factors/enumeration.py
@@ -154,6 +154,7 @@ def concatenate_wirings(wirings: Sequence[EnumerationWiring]) -> EnumerationWiri
 
     @staticmethod
     def compile_wiring(
+        factor_edges_num_states: np.ndarray,
         variables_for_factors: Sequence[List],
         factor_configs: np.ndarray,
         vars_to_starts: Mapping[Tuple[Any, int], int],
@@ -163,10 +164,14 @@ def compile_wiring(
         Internally calls _compile_var_states_numba and _compile_enumeration_wiring_numba for speed.
 
         Args:
+            factor_edges_num_states: An array concatenating the number of states for the variables connected to each
+                Factor of the FactorGroup. Each variable will appear once for each Factor it connects to.
             variables_for_factors: A list of list of variables. Each list within the outer list contains the
                 variables connected to a Factor. The same variable can be connected to multiple Factors.
             factor_configs: Array of shape (num_val_configs, num_variables) containing an explicit enumeration
                 of all valid configurations.
+            factor_edges_num_states: Array concatenating the number of states for the variables connected to each Factor of
+                the FactorGroup. Each variable will appear once for each Factor it connects to.
             vars_to_starts: A dictionary that maps variables to their global starting indices
                 For an n-state variable, a global start index of m means the global indices
                 of its n variable states are m, m + 1, ..., m + n - 1
@@ -178,14 +183,12 @@ def compile_wiring(
         Returns:
             The EnumerationWiring
         """
+        # TODO: Don't use vars_to_starts
         var_states = []
-        factor_edges_num_states = []
         for variables_for_factor in variables_for_factors:
             for variable in variables_for_factor:
                 var_states.append(vars_to_starts[variable])
-                factor_edges_num_states.append(variable[1])
         var_states = np.array(var_states)
-        factor_edges_num_states = np.array(factor_edges_num_states)
 
         num_states_cumsum = np.insert(np.cumsum(factor_edges_num_states), 0, 0)
         var_states_for_edges = np.empty(shape=(num_states_cumsum[-1],), dtype=int)
diff --git a/pgmax/factors/logical.py b/pgmax/factors/logical.py
index b78ea206..c8f93f6b 100644
--- a/pgmax/factors/logical.py
+++ b/pgmax/factors/logical.py
@@ -140,7 +140,9 @@ def concatenate_wirings(wirings: Sequence[LogicalWiring]) -> LogicalWiring:
 
     @staticmethod
     def compile_wiring(
+        factor_edges_num_states: np.ndarray,
         variables_for_factors: Sequence[List],
+        factor_sizes: np.ndarray,
         vars_to_starts: Mapping[Tuple[Any, int], int],
         edge_states_offset: int,
     ) -> LogicalWiring:
@@ -148,8 +150,11 @@ def compile_wiring(
         Internally calls _compile_var_states_numba and _compile_logical_wiring_numba for speed.
 
         Args:
-            variables_for_factors: A list of list of variables. Each list within the outer list contains the
-                variables connected to a Factor. The same variable can be connected to multiple Factors.
+            factor_edges_num_states: An array concatenating the number of states for the variables connected to each
+                Factor of the FactorGroup. Each variable will appear once for each Factor it connects to.
+            variables_for_factors: A tuple of tuples containing variables connected to each Factor of the FactorGroup.
+                Each variable will appear once for each Factor it connects to.
+            factor_sizes: An array containing the different factor sizes.
             vars_to_starts: A dictionary that maps variables to their global starting indices
                 For an n-state variable, a global start index of m means the global indices
                 of its n variable states are m, m + 1, ..., m + n - 1
@@ -159,17 +164,12 @@ def compile_wiring(
         Returns:
             The LogicalWiring
         """
-        factor_sizes = []
+        # TODO: Don't use vars_to_starts
         var_states = []
-        factor_edges_num_states = []
         for variables_for_factor in variables_for_factors:
-            factor_sizes.append(len(variables_for_factor))
             for variable in variables_for_factor:
                 var_states.append(vars_to_starts[variable])
-                factor_edges_num_states.append(variable[1])
-        factor_sizes = np.array(factor_sizes)
         var_states = np.array(var_states)
-        factor_edges_num_states = np.array(factor_edges_num_states)
 
         # Relevant state differs for ANDFactors and ORFactors
         relevant_state = (-edge_states_offset + 1) // 2
@@ -241,6 +241,16 @@ class ANDFactor(LogicalFactor):
     edge_states_offset: int = field(init=False, default=-1)
 
 
+@nb.jit(parallel=False, cache=True, fastmath=True, nopython=True)
+def _compile_utils_numba(var_states, factor_edges_num_states, variables_for_factors):
+    idx = 0
+    for variables_for_factor in variables_for_factors:
+        for variable in variables_for_factor:
+            var_states[idx] = variable
+            factor_edges_num_states[idx] = variable[1]
+            idx += 1
+
+
 @nb.jit(parallel=False, cache=True, fastmath=True, nopython=True)
 def _compile_logical_wiring_numba(
     parents_edge_states: np.ndarray,
diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index 67ca4477..a13fc426 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -68,9 +68,8 @@ def __post_init__(self):
         )
 
         # See FactorGraphState docstrings for documentation on the following fields
-        self._vars_to_starts: OrderedDict[
-            Tuple[int, int], int
-        ] = collections.OrderedDict()
+        self._vars_to_starts: Dict[Tuple[int, int], int] = {}
+
         vars_num_states_cumsum = 0
         for variable_group in self.variable_groups:
             vg_num_states = variable_group.num_states.flatten()
@@ -159,7 +158,7 @@ def compute_offsets(self) -> None:
                     factor_group
                 ] = factor_num_configs_cumsum
 
-                factor_num_states_cumsum += factor_group.total_num_states
+                factor_num_states_cumsum += factor_group.factor_edges_num_states.sum()
                 factor_num_configs_cumsum += (
                     factor_group.factor_group_log_potentials.shape[0]
                 )
diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index 15f50603..1ffc771d 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -17,6 +17,7 @@
 )
 
 import jax.numpy as jnp
+import numba as nb
 import numpy as np
 
 import pgmax.fg.nodes as nodes
@@ -30,10 +31,13 @@
 class VariableGroup:
     """Class to represent a group of variables.
     Each variable is represented via a tuple of the form (variable hash/name, number of states)
+
+    Attributes:
+        this_hash: Hash of the VariableGroup
     """
 
     def __hash__(self):
-        return id(self) * int(MAX_SIZE)
+        return self.this_hash
 
     def __eq__(self, other):
         return hash(self) == hash(other)
@@ -107,6 +111,9 @@ class FactorGroup:
 
     Attributes:
         factor_type: Factor type shared by all the Factors in the FactorGroup.
+        factor_sizes: Array of the different factor sizes.
+        factor_edges_num_states: Array concatenating the number of states for the variables connected to each Factor of
+            the FactorGroup. Each variable will appear once for each Factor it connects to.
 
     Raises:
         ValueError: if the FactorGroup does not contain a Factor
@@ -116,11 +123,27 @@ class FactorGroup:
     factor_configs: np.ndarray = field(init=False)
     log_potentials: np.ndarray = field(init=False, default=np.empty((0,)))
     factor_type: Type = field(init=False)
+    factor_sizes: np.ndarray = field(init=False)
+    factor_edges_num_states: np.ndarray = field(init=False)
 
     def __post_init__(self):
         if len(self.variables_for_factors) == 0:
             raise ValueError("Cannot create a FactorGroup with no Factor.")
 
+        factor_sizes = np.array(
+            [
+                len(variables_for_factor)
+                for variables_for_factor in self.variables_for_factors
+            ]
+        )
+        object.__setattr__(self, "factor_sizes", factor_sizes)
+
+        factor_edges_num_states = np.empty(shape=(self.factor_sizes.sum(),), dtype=int)
+        _compile_edges_num_states_numba(
+            factor_edges_num_states, self.variables_for_factors
+        )
+        object.__setattr__(self, "factor_edges_num_states", factor_edges_num_states)
+
     def __hash__(self):
         return id(self)
 
@@ -160,22 +183,6 @@ def _variables_to_factors(self) -> Mapping[FrozenSet, nodes.Factor]:
         """
         return self._get_variables_to_factors()
 
-    @cached_property
-    def total_num_states(self) -> int:
-        """Function to return the total number of states for all the variables involved in all the Factors
-
-        Returns:
-            Total number of variable states in the FactorGroup
-        """
-        # TODO: this could be returned by the wiring to loop over variables_for_factors only once
-        return sum(
-            [
-                variable[1]
-                for variables_for_factor in self.variables_for_factors
-                for variable in variables_for_factor
-            ]
-        )
-
     @cached_property
     def factor_group_log_potentials(self) -> np.ndarray:
         """Flattened array of log potentials"""
@@ -321,3 +328,14 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any:
         raise NotImplementedError(
             "SingleFactorGroup does not support vectorized factor operations."
         )
+
+
+nb.jit(parallel=False, cache=True, fastmath=True, nopython=False)
+
+
+def _compile_edges_num_states_numba(factor_edges_num_states, variables_for_factors):
+    idx = 0
+    for variables_for_factor in variables_for_factors:
+        for variable in variables_for_factor:
+            factor_edges_num_states[idx] = variable[1]
+            idx += 1
diff --git a/pgmax/groups/enumeration.py b/pgmax/groups/enumeration.py
index 03f17cb3..a09d9f24 100644
--- a/pgmax/groups/enumeration.py
+++ b/pgmax/groups/enumeration.py
@@ -65,7 +65,6 @@ def __post_init__(self):
             raise ValueError(
                 f"Potentials should be floats. Got {log_potentials.dtype}."
             )
-
         object.__setattr__(self, "log_potentials", log_potentials)
 
     def _get_variables_to_factors(
@@ -232,6 +231,10 @@ def __post_init__(self):
                 f"Got log_potential_matrix for {log_potential_matrix.shape[0]} factors."
             )
 
+        import time
+
+        start = time.time()
+        log_potential_shape = log_potential_matrix.shape[-2:]
         for variables_for_factor in self.variables_for_factors:
             if len(variables_for_factor) != 2:
                 raise ValueError(
@@ -239,15 +242,19 @@ def __post_init__(self):
                     f" {len(variables_for_factor)} variables ({variables_for_factor})."
                 )
 
-            num_states0 = variables_for_factor[0][1]
-            num_states1 = variables_for_factor[1][1]
-            if not log_potential_matrix.shape[-2:] == (num_states0, num_states1):
+            factor_num_configs = (
+                variables_for_factor[0][1],
+                variables_for_factor[1][1],
+            )
+            if log_potential_shape != factor_num_configs:
                 raise ValueError(
-                    f"The specified pairwise factor {variables_for_factor} (with {(num_states0, num_states1)}"
+                    f"The specified pairwise factor {variables_for_factor} (with {factor_num_configs}"
                     f"configurations) does not match the specified log_potential_matrix "
-                    f"(with {log_potential_matrix.shape[-2:]} configurations)."
+                    f"(with {log_potential_shape} configurations)."
                 )
-        object.__setattr__(self, "log_potential_matrix", log_potential_matrix)
+            object.__setattr__(self, "log_potential_matrix", log_potential_matrix)
+        print(time.time() - start)
+
         factor_configs = (
             np.mgrid[
                 : log_potential_matrix.shape[-2],
@@ -257,6 +264,7 @@ def __post_init__(self):
             .reshape((-1, 2))
         )
         object.__setattr__(self, "factor_configs", factor_configs)
+
         log_potential_matrix = np.broadcast_to(
             log_potential_matrix,
             (len(self.variables_for_factors),) + log_potential_matrix.shape[-2:],
diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index 887a32a7..92551ffa 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -10,6 +10,8 @@
 from pgmax.fg import groups
 from pgmax.utils import cached_property
 
+MAX_SIZE = 1e9
+
 
 @dataclass(frozen=True, eq=False)
 class NDVariableArray(groups.VariableGroup):
@@ -26,13 +28,13 @@ class NDVariableArray(groups.VariableGroup):
     num_states: Union[int, np.ndarray]
 
     def __post_init__(self):
-        if np.prod(self.shape) > int(groups.MAX_SIZE):
+        if np.prod(self.shape) > MAX_SIZE:
             raise ValueError(
-                f"Currently only support NDVariableArray of size smaller than {int(groups.MAX_SIZE)}. Got {np.prod(self.shape)}"
+                f"Currently only support NDVariableArray of size smaller than {int(MAX_SIZE)}. Got {np.prod(self.shape)}"
             )
 
         if np.isscalar(self.num_states):
-            num_states = np.full(self.shape, fill_value=self.num_states, dtype=np.int32)
+            num_states = np.full(self.shape, fill_value=self.num_states, dtype=np.int64)
             object.__setattr__(self, "num_states", num_states)
         elif isinstance(self.num_states, np.ndarray) and np.issubdtype(
             self.num_states.dtype, int
@@ -46,6 +48,12 @@ def __post_init__(self):
                 "num_states should be an integer or a NumPy array of dtype int"
             )
 
+        # Only compute the hash once, which is guaranteed to be an int64
+        this_id = id(self) % 2**32
+        this_hash = this_id * int(MAX_SIZE)
+        assert this_hash < 2**63
+        object.__setattr__(self, "this_hash", this_hash)
+
     def __getitem__(
         self, val: Union[int, slice, Tuple]
     ) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
@@ -61,13 +69,17 @@ def __getitem__(
             A single variable or a list of variables
         """
         assert isinstance(self.num_states, np.ndarray)
-        if np.isscalar(self.variable_names[val]):
-            return (self.variable_names[val], self.num_states[val])
-        else:
+
+        if isinstance(val, slice) or (
+            isinstance(val, tuple) and isinstance(val[0], slice)
+        ):
+            assert isinstance(self.num_states, np.ndarray)
             vars_names = self.variable_names[val].flatten()
             vars_num_states = self.num_states[val].flatten()
             return list(zip(vars_names, vars_num_states))
 
+        return (self.variable_names[val], self.num_states[val])
+
     @cached_property
     def variables(self) -> List[Tuple]:
         """Function that returns the list of all variables in the VariableGroup.
@@ -177,6 +189,9 @@ def __post_init__(self):
         )
         object.__setattr__(self, "num_states", num_states)
 
+        # Only compute the hash once
+        object.__setattr__(self, "this_hash", id(self))
+
         hash_and_names = tuple(
             (self.__hash__(), var_name) for var_name in self.variable_names
         )

From aaa67ef5f828c89c3450daccd5ed2093c382c691 Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Wed, 27 Apr 2022 01:33:44 +0000
Subject: [PATCH 25/35] max size

---
 pgmax/fg/groups.py      | 2 --
 tests/fg/test_groups.py | 2 +-
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index 1ffc771d..bf1c43d1 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -23,8 +23,6 @@
 import pgmax.fg.nodes as nodes
 from pgmax.utils import cached_property
 
-MAX_SIZE = 1e10
-
 
 @total_ordering
 @dataclass(frozen=True, eq=False)
diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py
index b65a8ce5..b47e3e09 100644
--- a/tests/fg/test_groups.py
+++ b/tests/fg/test_groups.py
@@ -54,7 +54,7 @@ def test_variable_dict():
 
 
 def test_nd_variable_array():
-    max_size = int(groups.MAX_SIZE)
+    max_size = int(vgroup.MAX_SIZE)
     with pytest.raises(
         ValueError,
         match=re.escape(

From 04c4d89b79b7b8c353cdb26faed3eb8e35c70342 Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Wed, 27 Apr 2022 01:54:43 +0000
Subject: [PATCH 26/35] Understand timings

---
 examples/pmp_binary_deconvolution.py |  1 -
 examples/rbm.py                      | 13 +++++++------
 pgmax/factors/logical.py             | 10 ----------
 3 files changed, 7 insertions(+), 17 deletions(-)

diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py
index fd396f3a..a0f24406 100644
--- a/examples/pmp_binary_deconvolution.py
+++ b/examples/pmp_binary_deconvolution.py
@@ -211,7 +211,6 @@ def plot_images(images, display=True, nr=None):
 # in the same manner does not change X, so this naturally results in multiple equivalent modes.
 
 # %%
-# %load_ext snakeviz
 start = time.time()
 bp = graph.BP(fg.bp_state, temperature=0.0)
 print("Time", time.time() - start)
diff --git a/examples/rbm.py b/examples/rbm.py
index cde32960..8565f086 100644
--- a/examples/rbm.py
+++ b/examples/rbm.py
@@ -82,16 +82,17 @@
 log_potential_matrix = np.zeros(W.shape + (2, 2)).reshape((-1, 2, 2))
 log_potential_matrix[:, 1, 1] = W.ravel()
 
+variables_for_factors = [
+    [hidden_variables[ii], visible_variables[jj]]
+    for ii in range(bh.shape[0])
+    for jj in range(bv.shape[0])
+]
+print("Time", time.time() - start)
 pairwise_factors = enumeration.PairwiseFactorGroup(
-    variables_for_factors=[
-        [hidden_variables[ii], visible_variables[jj]]
-        for ii in range(bh.shape[0])
-        for jj in range(bv.shape[0])
-    ],
+    variables_for_factors=variables_for_factors,
     log_potential_matrix=log_potential_matrix,
 )
 print("Time", time.time() - start)
-# #%snakeviz pairwise_factors = enumeration.PairwiseFactorGroup(variables_for_factors=[ [hidden_variables[ii], visible_variables[jj]] for ii in range(bh.shape[0]) for jj in range(bv.shape[0])],log_potential_matrix=log_potential_matrix,)
 
 fg.add_factors(pairwise_factors)
 print("Time", time.time() - start)
diff --git a/pgmax/factors/logical.py b/pgmax/factors/logical.py
index c8f93f6b..8b931e47 100644
--- a/pgmax/factors/logical.py
+++ b/pgmax/factors/logical.py
@@ -241,16 +241,6 @@ class ANDFactor(LogicalFactor):
     edge_states_offset: int = field(init=False, default=-1)
 
 
-@nb.jit(parallel=False, cache=True, fastmath=True, nopython=True)
-def _compile_utils_numba(var_states, factor_edges_num_states, variables_for_factors):
-    idx = 0
-    for variables_for_factor in variables_for_factors:
-        for variable in variables_for_factor:
-            var_states[idx] = variable
-            factor_edges_num_states[idx] = variable[1]
-            idx += 1
-
-
 @nb.jit(parallel=False, cache=True, fastmath=True, nopython=True)
 def _compile_logical_wiring_numba(
     parents_edge_states: np.ndarray,

From a276ce5a45676872eb8471953ec3567f497ca479 Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Wed, 27 Apr 2022 18:04:45 +0000
Subject: [PATCH 27/35] Some comments

---
 examples/ising_model.py              |  4 +--
 examples/pmp_binary_deconvolution.py |  2 +-
 pgmax/fg/graph.py                    |  5 +--
 pgmax/fg/groups.py                   | 13 ++++++--
 pgmax/groups/enumeration.py          |  6 +---
 pgmax/groups/variables.py            | 48 +++++++++++++++-------------
 6 files changed, 40 insertions(+), 38 deletions(-)

diff --git a/examples/ising_model.py b/examples/ising_model.py
index 3f530a33..bf30c7bd 100644
--- a/examples/ising_model.py
+++ b/examples/ising_model.py
@@ -37,7 +37,7 @@
         kk = (ii + 1) % 50
         ll = (jj + 1) % 50
         variables_for_factors.append([variables[ii, jj], variables[kk, jj]])
-        variables_for_factors.append([variables[ii, jj], variables[kk, ll]])
+        variables_for_factors.append([variables[ii, jj], variables[ii, ll]])
 
 factor_group = enumeration.PairwiseFactorGroup(
     variables_for_factors=variables_for_factors,
@@ -52,8 +52,6 @@
 bp = graph.BP(fg.bp_state, temperature=0)
 
 # %%
-# TODO: check\ time for before BP vs time for BP
-# TODO: time PGMAX vs PMP
 bp_arrays = bp.init(
     evidence_updates={variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}
 )
diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py
index a0f24406..131c79e3 100644
--- a/examples/pmp_binary_deconvolution.py
+++ b/examples/pmp_binary_deconvolution.py
@@ -138,7 +138,7 @@ def plot_images(images, display=True, nr=None):
 print("Time", time.time() - start)
 
 # %% [markdown]
-# For computation efficiency, we add large FactorGroups via `fg.add_factors` instead of adding individual Factors
+# For computation efficiency, we construct large FactorGroups instead of individual Factors
 
 # %%
 start = time.time()
diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index a13fc426..29ecc26a 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -49,9 +49,6 @@ class FactorGraph:
     variable_groups: Union[groups.VariableGroup, Sequence[groups.VariableGroup]]
 
     def __post_init__(self):
-        import time
-
-        start = time.time()
         if isinstance(self.variable_groups, groups.VariableGroup):
             self.variable_groups = [self.variable_groups]
 
@@ -68,6 +65,7 @@ def __post_init__(self):
         )
 
         # See FactorGraphState docstrings for documentation on the following fields
+        # TODO: vars_to_starts does not have to be a dict
         self._vars_to_starts: Dict[Tuple[int, int], int] = {}
 
         vars_num_states_cumsum = 0
@@ -82,7 +80,6 @@ def __post_init__(self):
             )
             vars_num_states_cumsum += vg_num_states_cumsum[-1]
         self._num_var_states = vars_num_states_cumsum
-        print("Init", time.time() - start)
 
     def __hash__(self) -> int:
         return hash(self.factor_groups)
diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index bf1c43d1..5f47929b 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -23,6 +23,8 @@
 import pgmax.fg.nodes as nodes
 from pgmax.utils import cached_property
 
+MAX_SIZE = 1e9
+
 
 @total_ordering
 @dataclass(frozen=True, eq=False)
@@ -34,6 +36,13 @@ class VariableGroup:
         this_hash: Hash of the VariableGroup
     """
 
+    def __post_init__(self):
+        # Only compute the hash once, which is guaranteed to be an int64
+        this_id = id(self) % 2**32
+        this_hash = this_id * int(MAX_SIZE)
+        assert this_hash < 2**63
+        object.__setattr__(self, "this_hash", this_hash)
+
     def __hash__(self):
         return self.this_hash
 
@@ -328,9 +337,7 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any:
         )
 
 
-nb.jit(parallel=False, cache=True, fastmath=True, nopython=False)
-
-
+# @nb.jit(parallel=False, cache=True, fastmath=True)
 def _compile_edges_num_states_numba(factor_edges_num_states, variables_for_factors):
     idx = 0
     for variables_for_factor in variables_for_factors:
diff --git a/pgmax/groups/enumeration.py b/pgmax/groups/enumeration.py
index a09d9f24..dab6e14b 100644
--- a/pgmax/groups/enumeration.py
+++ b/pgmax/groups/enumeration.py
@@ -231,9 +231,6 @@ def __post_init__(self):
                 f"Got log_potential_matrix for {log_potential_matrix.shape[0]} factors."
             )
 
-        import time
-
-        start = time.time()
         log_potential_shape = log_potential_matrix.shape[-2:]
         for variables_for_factor in self.variables_for_factors:
             if len(variables_for_factor) != 2:
@@ -252,8 +249,7 @@ def __post_init__(self):
                     f"configurations) does not match the specified log_potential_matrix "
                     f"(with {log_potential_shape} configurations)."
                 )
-            object.__setattr__(self, "log_potential_matrix", log_potential_matrix)
-        print(time.time() - start)
+        object.__setattr__(self, "log_potential_matrix", log_potential_matrix)
 
         factor_configs = (
             np.mgrid[
diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index 92551ffa..7d61fb7a 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -10,8 +10,6 @@
 from pgmax.fg import groups
 from pgmax.utils import cached_property
 
-MAX_SIZE = 1e9
-
 
 @dataclass(frozen=True, eq=False)
 class NDVariableArray(groups.VariableGroup):
@@ -28,9 +26,12 @@ class NDVariableArray(groups.VariableGroup):
     num_states: Union[int, np.ndarray]
 
     def __post_init__(self):
-        if np.prod(self.shape) > MAX_SIZE:
+        super().__post_init__()
+
+        max_size = int(groups.MAX_SIZE)
+        if np.prod(self.shape) > max_size:
             raise ValueError(
-                f"Currently only support NDVariableArray of size smaller than {int(MAX_SIZE)}. Got {np.prod(self.shape)}"
+                f"Currently only support NDVariableArray of size smaller than {max_size}. Got {np.prod(self.shape)}"
             )
 
         if np.isscalar(self.num_states):
@@ -48,12 +49,6 @@ def __post_init__(self):
                 "num_states should be an integer or a NumPy array of dtype int"
             )
 
-        # Only compute the hash once, which is guaranteed to be an int64
-        this_id = id(self) % 2**32
-        this_hash = this_id * int(MAX_SIZE)
-        assert this_hash < 2**63
-        object.__setattr__(self, "this_hash", this_hash)
-
     def __getitem__(
         self, val: Union[int, slice, Tuple]
     ) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
@@ -175,28 +170,35 @@ class VariableDict(groups.VariableGroup):
     """A variable dictionary that contains a set of variables of the same size
 
     Args:
-        num_states: The size of the variables in this variable group
-        variable_names: A tuple of all names of the variables in this variable group.
+        num_states: The size of the variables in this VariableGroup
+        variable_names: A tuple of all names of the variables in this VariableGroup.
             Note that we overwrite variable_names to add the hash of the VariableDict
     """
 
     variable_names: Tuple[Any, ...]
-    num_states: int
+    num_states: Union[int, np.ndarray]
 
     def __post_init__(self):
-        num_states = np.full(
-            (len(self.variable_names),), fill_value=self.num_states, dtype=np.int32
-        )
-        object.__setattr__(self, "num_states", num_states)
-
-        # Only compute the hash once
-        object.__setattr__(self, "this_hash", id(self))
+        super().__post_init__()
 
         hash_and_names = tuple(
             (self.__hash__(), var_name) for var_name in self.variable_names
         )
         object.__setattr__(self, "variable_names", hash_and_names)
 
+        if np.isscalar(self.num_states):
+            num_states = np.full(
+                len(self.variable_names), fill_value=self.num_states, dtype=np.int64
+            )
+            object.__setattr__(self, "num_states", num_states)
+        elif isinstance(self.num_states, np.ndarray) and np.issubdtype(
+            self.num_states.dtype, int
+        ):
+            if self.num_states.shape != len(self.variable_names):
+                raise ValueError(
+                    f"Expected num_states shape {len(self.variable_names)}. Got {self.num_states.shape}."
+                )
+
     @cached_property
     def variables(self) -> List[Tuple[Tuple[Any, int], int]]:
         """Function that returns the list of all variables in the VariableGroup.
@@ -215,7 +217,7 @@ def __getitem__(self, val: Any) -> Tuple[Tuple[Any, int], int]:
         (variable name, number of states)
 
         Args:
-            val: a variable name
+            val: a variable name (without the object hash)
 
         Returns:
             The queried variable
@@ -223,7 +225,9 @@ def __getitem__(self, val: Any) -> Tuple[Tuple[Any, int], int]:
         assert isinstance(self.num_states, np.ndarray)
         if (self.__hash__(), val) not in self.variable_names:
             raise ValueError(f"Variable {val} is not in VariableDict")
-        return ((self.__hash__(), val), self.num_states[0])
+
+        idx = self.variable_names.index((self.__hash__(), val))
+        return ((self.__hash__(), val), self.num_states[idx])
 
     def flatten(
         self, data: Mapping[Tuple[Tuple[int, int], int], Union[np.ndarray, jnp.ndarray]]

From a839be6ee691e755ab81a920e9db1632a587b576 Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Wed, 27 Apr 2022 22:56:11 +0000
Subject: [PATCH 28/35] Comments

---
 examples/pmp_binary_deconvolution.py |   2 +-
 examples/rbm.py                      |  11 ++-
 examples/rcn.py                      |  24 ++++---
 pgmax/fg/graph.py                    |  67 +++++++++--------
 pgmax/fg/groups.py                   |  93 ++++++++++++++----------
 pgmax/groups/variables.py            | 103 ++++++++++++---------------
 tests/factors/test_and.py            |   6 +-
 tests/factors/test_or.py             |   6 +-
 tests/fg/test_graph.py               |  28 ++------
 tests/fg/test_groups.py              |  26 +++++--
 tests/fg/test_wiring.py              |  28 ++++----
 tests/test_pgmax.py                  |  22 +++---
 12 files changed, 212 insertions(+), 204 deletions(-)

diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py
index 131c79e3..a8b992e5 100644
--- a/examples/pmp_binary_deconvolution.py
+++ b/examples/pmp_binary_deconvolution.py
@@ -195,7 +195,7 @@ def plot_images(images, display=True, nr=None):
 fg.add_factors(OR_factor_group)
 print("Time", time.time() - start)
 
-for factor_type, factor_groups in fg._factor_types_to_groups.items():
+for factor_type, factor_groups in fg.factor_groups.items():
     if len(factor_groups) > 0:
         assert len(factor_groups) == 1
         print(f"The factor graph contains {factor_groups[0].num_factors} {factor_type}")
diff --git a/examples/rbm.py b/examples/rbm.py
index 8565f086..4b391a62 100644
--- a/examples/rbm.py
+++ b/examples/rbm.py
@@ -69,14 +69,12 @@
     factor_configs=np.arange(2)[:, None],
     log_potentials=np.stack([np.zeros_like(bh), bh], axis=1),
 )
-fg.add_factors(hidden_unaries)
 
 visible_unaries = enumeration.EnumerationFactorGroup(
     variables_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])],
     factor_configs=np.arange(2)[:, None],
     log_potentials=np.stack([np.zeros_like(bv), bv], axis=1),
 )
-fg.add_factors(visible_unaries)
 
 # Add pairwise factors
 log_potential_matrix = np.zeros(W.shape + (2, 2)).reshape((-1, 2, 2))
@@ -94,7 +92,7 @@
 )
 print("Time", time.time() - start)
 
-fg.add_factors(pairwise_factors)
+fg.add_factors([hidden_unaries, visible_unaries, pairwise_factors])
 print("Time", time.time() - start)
 
 
@@ -107,6 +105,7 @@
 #
 # An alternative way of creating the above factors is to add them iteratively by calling [`fg.add_factor`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph.add_factor) as below. This approach is not recommended as it is not computationally efficient.
 # ~~~python
+# from pgmax.factors import enumeration as enumeration_factor
 # import itertools
 # from tqdm import tqdm
 #
@@ -117,7 +116,7 @@
 #         factor_configs=np.arange(2)[:, None],
 #         log_potentials=np.array([0, bh[ii]]),
 #     )
-#     fg.add_factors(factor=factor)
+#     fg.add_factors(factor)
 #
 # for jj in range(bv.shape[0]):
 #     factor = enumeration_factor.EnumerationFactor(
@@ -125,7 +124,7 @@
 #         factor_configs=np.arange(2)[:, None],
 #         log_potentials=np.array([0, bv[jj]]),
 #     )
-#     fg.add_factors(factor=factor)
+#     fg.add_factors(factor)
 #
 # # Add pairwise factors
 # factor_configs = np.array(list(itertools.product(np.arange(2), repeat=2)))
@@ -136,7 +135,7 @@
 #             factor_configs=factor_configs,
 #             log_potentials=np.array([0, 0, 0, W[ii, jj]]),
 #         )
-#         fg.add_factors(factor=factor)
+#         fg.add_factors(factor)
 # ~~~
 #
 # Once we have added the factors, we can run max-product LBP and get MAP decoding by
diff --git a/examples/rcn.py b/examples/rcn.py
index 2414c0e9..140c1712 100644
--- a/examples/rcn.py
+++ b/examples/rcn.py
@@ -38,9 +38,9 @@
 from scipy.signal import fftconvolve
 from sklearn.datasets import fetch_openml
 
-from pgmax.factors.enumeration import EnumerationFactor
 from pgmax.fg import graph
 from pgmax.groups import variables as vgroup
+from pgmax.groups.enumeration import EnumerationFactorGroup
 
 memory = Memory("./example_data/tmp")
 fetch_openml_cached = memory.cache(fetch_openml)
@@ -280,12 +280,13 @@ def valid_configs(r: int, hps: int, vps: int) -> np.ndarray:
 
     for e in edge:
         i1, i2, r = e
-        factor = EnumerationFactor(
-            variables=[variables_all_models[idx][i1], variables_all_models[idx][i2]],
+        factor_group = EnumerationFactorGroup(
+            variables_for_factors=[
+                [variables_all_models[idx][i1], variables_all_models[idx][i2]]
+            ],
             factor_configs=valid_configs_list[r],
-            log_potentials=np.zeros(valid_configs_list[r].shape[0]),
         )
-        fg.add_factors(factor=factor)
+        fg.add_factors(factor_group)
 
 end = time.time()
 print(f"Creating factors took {end-start:.3f} seconds.")
@@ -384,7 +385,10 @@ def get_evidence(bu_msg: np.ndarray, frc: np.ndarray) -> np.ndarray:
 
 
 # %%
-frcs_dict = {model_idx: frcs[model_idx] for model_idx in range(frcs.shape[0])}
+frcs_dict = {
+    variables_all_models[model_idx]: frcs[model_idx]
+    for model_idx in range(frcs.shape[0])
+}
 bp = graph.BP(fg.bp_state, temperature=0.0)
 scores = np.zeros((len(test_set), frcs.shape[0]))
 map_states_dict = {}
@@ -413,8 +417,8 @@ def get_evidence(bu_msg: np.ndarray, frc: np.ndarray) -> np.ndarray:
         evidence_updates,
         map_states,
     )
-    for ii in score:
-        scores[test_idx, ii] = score[ii]
+    for idx, score in enumerate(score.values()):
+        scores[test_idx, idx] = score
     end = time.time()
     print(f"Computing scores took {end-start:.3f} seconds for image {test_idx}.")
 
@@ -437,7 +441,9 @@ def get_evidence(bu_msg: np.ndarray, frc: np.ndarray) -> np.ndarray:
 fig, ax = plt.subplots(5, 4, figsize=(16, 20))
 for test_idx in range(20):
     idx = np.unravel_index(test_idx, (5, 4))
-    map_state = map_states_dict[test_idx][best_model_idx[test_idx]]
+    map_state = map_states_dict[test_idx][
+        variables_all_models[best_model_idx[test_idx]]
+    ]
     offsets = np.array(
         np.unravel_index(map_state, (2 * hps + 1, 2 * vps + 1))
     ).T - np.array([hps, vps])
diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index 29ecc26a..21ea1f11 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -82,36 +82,44 @@ def __post_init__(self):
         self._num_var_states = vars_num_states_cumsum
 
     def __hash__(self) -> int:
-        return hash(self.factor_groups)
+        all_factor_groups = tuple(
+            [
+                factor_group
+                for factor_groups_per_type in self._factor_types_to_groups.values()
+                for factor_group in factor_groups_per_type
+            ]
+        )
+        return hash(all_factor_groups)
 
     def add_factors(
         self,
-        factor_group: Optional[groups.FactorGroup] = None,
-        factor: Optional[nodes.Factor] = None,
+        factors: Union[
+            nodes.Factor,
+            groups.FactorGroup,
+            Sequence[Union[nodes.Factor, groups.FactorGroup]],
+        ],
     ) -> None:
-        """Add a FactorGroup or a single Factor to the FactorGraph, by updating the FactorGraphState.
+        """Add a single Factor, a FactorGroup or a list of single Factors and FactorGroups to the FactorGraph,
+        by updating the FactorGraphState.
 
         Args:
-            factor_group: The FactorGroup to be added to the FactorGraph.
-            factor: The Factor to be added to the FactorGraph.
+            factors: The Factor, FactorGroup or list of Factors and FactorGroups to be added to the FactorGraph.
 
         Raises:
-            ValueError: If
-                (1) Both a Factor and a FactorGroup are added
-                (2) The FactorGroup involving the same variables already exists in the FactorGraph.
+            ValueError: A FactorGroup involving the same variables already exists in the FactorGraph.
         """
-        if factor is None and factor_group is None:
-            raise ValueError("A Factor or a FactorGroup is required")
-
-        if factor is not None and factor_group is not None:
-            raise ValueError("Cannot simultaneously add a Factor and a FactorGroup")
-
-        if factor is not None:
+        if isinstance(factors, list):
+            for factor in factors:
+                self.add_factors(factor)
+            return None
+
+        if isinstance(factors, groups.FactorGroup):
+            factor_group = factors
+        elif isinstance(factors, nodes.Factor):
             factor_group = groups.SingleFactorGroup(
-                variables_for_factors=[factor.variables],
-                factor=factor,
+                variables_for_factors=[factors.variables],
+                factor=factors,
             )
-        assert factor_group is not None
 
         factor_type = factor_group.factor_type
         for var_names_for_factor in factor_group.variables_for_factors:
@@ -253,15 +261,9 @@ def factors(self) -> OrderedDict[Type, Tuple[nodes.Factor, ...]]:
         return factors
 
     @property
-    def factor_groups(self) -> Tuple[groups.FactorGroup, ...]:
+    def factor_groups(self) -> OrderedDict[Type, List[groups.FactorGroup]]:
         """Tuple of factor groups in the factor graph"""
-        return tuple(
-            [
-                factor_group
-                for factor_groups_per_type in self._factor_types_to_groups.values()
-                for factor_group in factor_groups_per_type
-            ]
-        )
+        return self._factor_types_to_groups
 
     @cached_property
     def fg_state(self) -> FactorGraphState:
@@ -278,7 +280,6 @@ def fg_state(self) -> FactorGraphState:
             vars_to_starts=self._vars_to_starts,
             num_var_states=self._num_var_states,
             total_factor_num_states=self._total_factor_num_states,
-            factor_groups=self.factor_groups,
             factor_type_to_msgs_range=copy.copy(self._factor_type_to_msgs_range),
             factor_type_to_potentials_range=copy.copy(
                 self._factor_type_to_potentials_range
@@ -314,7 +315,6 @@ class FactorGraphState:
             contains evidence to the variable.
         num_var_states: Total number of variable states.
         total_factor_num_states: Size of the flat ftov messages array.
-        factor_groups: FactorGroups in the FactorGraph
         factor_type_to_msgs_range: Maps factors types to their start and end indices in the flat ftov messages.
         factor_type_to_potentials_range: Maps factor types to their start and end indices in the flat log potentials.
         factor_group_to_potentials_starts: Maps factor groups to their starting indices in the flat log potentials.
@@ -326,7 +326,6 @@ class FactorGraphState:
     vars_to_starts: Mapping[Tuple[Any, int], int]
     num_var_states: int
     total_factor_num_states: int
-    factor_groups: Tuple[groups.FactorGroup, ...]
     factor_type_to_msgs_range: OrderedDict[type, Tuple[int, int]]
     factor_type_to_potentials_range: OrderedDict[type, Tuple[int, int]]
     factor_group_to_potentials_starts: OrderedDict[groups.FactorGroup, int]
@@ -395,7 +394,7 @@ def update_log_potentials(
         (2) Provided name is not valid for log_potentials updates.
     """
     for factor_group, data in updates.items():
-        if factor_group in fg_state.factor_groups:
+        if factor_group in fg_state.factor_group_to_potentials_starts:
             flat_data = factor_group.flatten(data)
             if flat_data.shape != factor_group.factor_group_log_potentials.shape:
                 raise ValueError(
@@ -450,7 +449,7 @@ def __getitem__(self, factor_group: groups.FactorGroup) -> np.ndarray:
             The queried log potentials.
         """
         value = cast(np.ndarray, self.value)
-        if factor_group in self.fg_state.factor_groups:
+        if factor_group in self.fg_state.factor_group_to_potentials_starts:
             start = self.fg_state.factor_group_to_potentials_starts[factor_group]
             log_potentials = value[
                 start : start + factor_group.factor_group_log_potentials.shape[0]
@@ -558,7 +557,7 @@ def __post_init__(self):
 
     def __setitem__(
         self,
-        variable: Tuple[Any, int],
+        variable: Tuple[int, int],
         data: Union[np.ndarray, jnp.ndarray],
     ) -> None:
         """Spreading beliefs at a variable to all connected Factors
@@ -641,7 +640,7 @@ def __post_init__(self):
 
             object.__setattr__(self, "value", self.value)
 
-    def __getitem__(self, variable: Tuple[Any, int]) -> np.ndarray:
+    def __getitem__(self, variable: Tuple[int, int]) -> np.ndarray:
         """Function to query evidence for a variable
 
         Args:
diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index 5f47929b..e8079a36 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -17,7 +17,6 @@
 )
 
 import jax.numpy as jnp
-import numba as nb
 import numpy as np
 
 import pgmax.fg.nodes as nodes
@@ -32,10 +31,13 @@ class VariableGroup:
     """Class to represent a group of variables.
     Each variable is represented via a tuple of the form (variable hash/name, number of states)
 
-    Attributes:
-        this_hash: Hash of the VariableGroup
+    Arguments:
+        num_states: An integer or an array specifying the number of states of the variables
+            in this VariableGroup
     """
 
+    num_states: Union[int, np.ndarray] = field(init=False)
+
     def __post_init__(self):
         # Only compute the hash once, which is guaranteed to be an int64
         this_id = id(self) % 2**32
@@ -52,9 +54,9 @@ def __eq__(self, other):
     def __lt__(self, other):
         return hash(self) < hash(other)
 
-    def __getitem__(self, val: Any) -> Union[Tuple[Any, int], List[Tuple[Any, int]]]:
+    def __getitem__(self, val: Any) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
         """Given a variable name, index, or a group of variable indices, retrieve the associated variable(s).
-        Each variable is returned via a tuple of the form (variable hash/name, number of states)
+        Each variable is returned via a tuple of the form (variable hash, number of states)
 
         Args:
             val: a variable index, slice, or name
@@ -67,17 +69,30 @@ def __getitem__(self, val: Any) -> Union[Tuple[Any, int], List[Tuple[Any, int]]]
         )
 
     @cached_property
-    def variables(self) -> List[Tuple]:
-        """Function that returns the list of all variables in the VariableGroup.
-        Each variable is represented by a tuple of the form (variable hash/name, number of states)
+    def variable_hashes(self) -> np.ndarray:
+        """Function that generates a variable hash for each variable
 
         Returns:
-            List of variables in the VariableGroup
+            Array of variables hashes.
         """
         raise NotImplementedError(
             "Please subclass the VariableGroup class and override this method"
         )
 
+    @cached_property
+    def variables(self) -> List[Tuple[int, int]]:
+        """Function that returns the list of all variables in the VariableGroup.
+        Each variable is represented by a tuple of the form (variable hash, number of states)
+
+        Returns:
+            List of variables in the VariableGroup
+        """
+        assert isinstance(self.variable_hashes, np.ndarray)
+        assert isinstance(self.num_states, np.ndarray)
+        vars_hashes = self.variable_hashes.flatten()
+        vars_num_states = self.num_states.flatten()
+        return list(zip(vars_hashes, vars_num_states))
+
     def flatten(self, data: Any) -> jnp.ndarray:
         """Function that turns meaningful structured data into a flat data array for internal use.
 
@@ -118,9 +133,6 @@ class FactorGroup:
 
     Attributes:
         factor_type: Factor type shared by all the Factors in the FactorGroup.
-        factor_sizes: Array of the different factor sizes.
-        factor_edges_num_states: Array concatenating the number of states for the variables connected to each Factor of
-            the FactorGroup. Each variable will appear once for each Factor it connects to.
 
     Raises:
         ValueError: if the FactorGroup does not contain a Factor
@@ -130,27 +142,11 @@ class FactorGroup:
     factor_configs: np.ndarray = field(init=False)
     log_potentials: np.ndarray = field(init=False, default=np.empty((0,)))
     factor_type: Type = field(init=False)
-    factor_sizes: np.ndarray = field(init=False)
-    factor_edges_num_states: np.ndarray = field(init=False)
 
     def __post_init__(self):
         if len(self.variables_for_factors) == 0:
             raise ValueError("Cannot create a FactorGroup with no Factor.")
 
-        factor_sizes = np.array(
-            [
-                len(variables_for_factor)
-                for variables_for_factor in self.variables_for_factors
-            ]
-        )
-        object.__setattr__(self, "factor_sizes", factor_sizes)
-
-        factor_edges_num_states = np.empty(shape=(self.factor_sizes.sum(),), dtype=int)
-        _compile_edges_num_states_numba(
-            factor_edges_num_states, self.variables_for_factors
-        )
-        object.__setattr__(self, "factor_edges_num_states", factor_edges_num_states)
-
     def __hash__(self):
         return id(self)
 
@@ -180,6 +176,38 @@ def __getitem__(self, variables: Union[Sequence, Collection]) -> Any:
             )
         return self._variables_to_factors[variables]
 
+    @cached_property
+    def factor_sizes(self) -> np.ndarray:
+        """Computes the factor sizes
+
+        Returns:
+            Array of the different factor sizes.
+        """
+        return np.array(
+            [
+                len(variables_for_factor)
+                for variables_for_factor in self.variables_for_factors
+            ]
+        )
+
+    @cached_property
+    def factor_edges_num_states(self) -> np.ndarray:
+        """Computes an array concatenating the number of states for the variables connected to each Factor of
+        the FactorGroup. Each variable will appear once for each Factor it connects to.
+
+        Returns:
+            Array with the the number of states for the variables connected to each Factor.
+        """
+        factor_edges_num_states = np.empty(shape=(self.factor_sizes.sum(),), dtype=int)
+
+        # TODO: create variables_for_factors as an array and move this to numba
+        idx = 0
+        for variables_for_factor in self.variables_for_factors:
+            for variable in variables_for_factor:
+                factor_edges_num_states[idx] = variable[1]
+                idx += 1
+        return factor_edges_num_states
+
     @cached_property
     def _variables_to_factors(self) -> Mapping[FrozenSet, nodes.Factor]:
         """Function to compile potential array for the factor group.
@@ -335,12 +363,3 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any:
         raise NotImplementedError(
             "SingleFactorGroup does not support vectorized factor operations."
         )
-
-
-# @nb.jit(parallel=False, cache=True, fastmath=True)
-def _compile_edges_num_states_numba(factor_edges_num_states, variables_for_factors):
-    idx = 0
-    for variables_for_factor in variables_for_factors:
-        for variable in variables_for_factor:
-            factor_edges_num_states[idx] = variable[1]
-            idx += 1
diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index 7d61fb7a..22a7ef0e 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -16,14 +16,14 @@ class NDVariableArray(groups.VariableGroup):
     """Subclass of VariableGroup for n-dimensional grids of variables.
 
     Args:
-        shape: Tuple specifying the size of each dimension of the grid (similar to
-            the notion of a NumPy ndarray shape)
         num_states: An integer or an array specifying the number of states of the
             variables in this VariableGroup
+        shape: Tuple specifying the size of each dimension of the grid (similar to
+            the notion of a NumPy ndarray shape)
     """
 
-    shape: Tuple[int, ...]
     num_states: Union[int, np.ndarray]
+    shape: Tuple[int, ...]
 
     def __post_init__(self):
         super().__post_init__()
@@ -69,31 +69,18 @@ def __getitem__(
             isinstance(val, tuple) and isinstance(val[0], slice)
         ):
             assert isinstance(self.num_states, np.ndarray)
-            vars_names = self.variable_names[val].flatten()
+            vars_names = self.variable_hashes[val].flatten()
             vars_num_states = self.num_states[val].flatten()
             return list(zip(vars_names, vars_num_states))
 
-        return (self.variable_names[val], self.num_states[val])
-
-    @cached_property
-    def variables(self) -> List[Tuple]:
-        """Function that returns the list of all variables in the VariableGroup.
-        Each variable is represented by a tuple of the form (variable hash, number of states)
-
-        Returns:
-            List of variables in the VariableGroup
-        """
-        assert isinstance(self.num_states, np.ndarray)
-        vars_names = self.variable_names.flatten()
-        vars_num_states = self.num_states.flatten()
-        return list(zip(vars_names, vars_num_states))
+        return (self.variable_hashes[val], self.num_states[val])
 
     @cached_property
-    def variable_names(self) -> np.ndarray:
-        """Function that generates all the variables names, in the form of hashes
+    def variable_hashes(self) -> np.ndarray:
+        """Function that generates a variable hash for each variable
 
         Returns:
-            Array of variables names.
+            Array of variables hashes.
         """
         indices = np.reshape(np.arange(np.product(self.shape)), self.shape)
         return self.__hash__() + indices
@@ -167,28 +154,22 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
 
 @dataclass(frozen=True, eq=False)
 class VariableDict(groups.VariableGroup):
-    """A variable dictionary that contains a set of variables of the same size
+    """A variable dictionary that contains a set of variables
 
     Args:
         num_states: The size of the variables in this VariableGroup
-        variable_names: A tuple of all names of the variables in this VariableGroup.
-            Note that we overwrite variable_names to add the hash of the VariableDict
+        variable_names: A tuple of all the names of the variables in this VariableGroup.
     """
 
-    variable_names: Tuple[Any, ...]
     num_states: Union[int, np.ndarray]
+    variable_names: Tuple[Any, ...]
 
     def __post_init__(self):
         super().__post_init__()
 
-        hash_and_names = tuple(
-            (self.__hash__(), var_name) for var_name in self.variable_names
-        )
-        object.__setattr__(self, "variable_names", hash_and_names)
-
         if np.isscalar(self.num_states):
             num_states = np.full(
-                len(self.variable_names), fill_value=self.num_states, dtype=np.int64
+                (len(self.variable_names),), fill_value=self.num_states, dtype=np.int64
             )
             object.__setattr__(self, "num_states", num_states)
         elif isinstance(self.num_states, np.ndarray) and np.issubdtype(
@@ -196,38 +177,39 @@ def __post_init__(self):
         ):
             if self.num_states.shape != len(self.variable_names):
                 raise ValueError(
-                    f"Expected num_states shape {len(self.variable_names)}. Got {self.num_states.shape}."
+                    f"Expected num_states shape ({len(self.variable_names)},). Got {self.num_states.shape}."
                 )
+        else:
+            raise ValueError(
+                "num_states should be an integer or a NumPy array of dtype int"
+            )
 
     @cached_property
-    def variables(self) -> List[Tuple[Tuple[Any, int], int]]:
-        """Function that returns the list of all variables in the VariableGroup.
-        Each variable is represented by a tuple of the form (variable name, number of states)
+    def variable_hashes(self) -> np.ndarray:
+        """Function that generates a variable hash for each variable
 
         Returns:
-            List of variables in the VariableGroup
+            Array of variables hashes.
         """
-        assert isinstance(self.num_states, np.ndarray)
-        vars_names = list(self.variable_names)
-        vars_num_states = self.num_states.flatten()
-        return list(zip(vars_names, vars_num_states))
+        indices = np.arange(len(self.variable_names))
+        return self.__hash__() + indices
 
-    def __getitem__(self, val: Any) -> Tuple[Tuple[Any, int], int]:
+    def __getitem__(self, var_name: Any) -> Tuple[int, int]:
         """Given a variable name retrieve the associated variable, returned via a tuple of the form
-        (variable name, number of states)
+        (variable hash, number of states)
 
         Args:
-            val: a variable name (without the object hash)
+            val: a variable name
 
         Returns:
             The queried variable
         """
         assert isinstance(self.num_states, np.ndarray)
-        if (self.__hash__(), val) not in self.variable_names:
-            raise ValueError(f"Variable {val} is not in VariableDict")
+        if var_name not in self.variable_names:
+            raise ValueError(f"Variable {var_name} is not in VariableDict")
 
-        idx = self.variable_names.index((self.__hash__(), val))
-        return ((self.__hash__(), val), self.num_states[idx])
+        var_idx = self.variable_names.index(var_name)
+        return (self.variable_hashes[var_idx], self.num_states[var_idx])
 
     def flatten(
         self, data: Mapping[Tuple[Tuple[int, int], int], Union[np.ndarray, jnp.ndarray]]
@@ -249,20 +231,23 @@ def flatten(
         """
         assert isinstance(self.num_states, np.ndarray)
 
-        for variable in data:
-            if variable not in self.variables:
+        for var_name in data:
+            if var_name not in self.variable_names:
                 raise ValueError(
-                    f"data is referring to a non-existent variable {variable}."
+                    f"data is referring to a non-existent variable {var_name}."
                 )
 
-            if data[variable].shape != (variable[1],) and data[variable].shape != (1,):
+            var_idx = self.variable_names.index(var_name)
+            if data[var_name].shape != (self.num_states[var_idx],) and data[
+                var_name
+            ].shape != (1,):
                 raise ValueError(
-                    f"Variable {variable} expects a data array of shape "
-                    f"{(variable[1],)} or (1,). Got {data[variable].shape}."
+                    f"Variable {var_name} expects a data array of shape "
+                    f"{(self.num_states[var_idx],)} or (1,). Got {data[var_name].shape}."
                 )
 
         flat_data = jnp.concatenate(
-            [data[variable].flatten() for variable in self.variables]
+            [data[var_name].flatten() for var_name in self.variable_names]
         )
         return flat_data
 
@@ -306,12 +291,14 @@ def unflatten(
 
         start = 0
         data = {}
-        for variable in self.variables:
+        for var_name in self.variable_names:
             if use_num_states:
-                data[variable] = flat_data[start : start + variable[1]]
-                start += variable[1]
+                var_idx = self.variable_names.index(var_name)
+                var_num_states = self.num_states[var_idx]
+                data[var_name] = flat_data[start : start + var_num_states]
+                start += var_num_states
             else:
-                data[variable] = flat_data[np.array([start])]
+                data[var_name] = flat_data[np.array([start])]
                 start += 1
 
         return data
diff --git a/tests/factors/test_and.py b/tests/factors/test_and.py
index fcbe6772..4f2f7612 100644
--- a/tests/factors/test_and.py
+++ b/tests/factors/test_and.py
@@ -104,7 +104,7 @@ def test_run_bp_with_ANDFactors():
                     factor_configs=valid_configs,
                     log_potentials=np.zeros(valid_configs.shape[0]),
                 )
-                fg1.add_factors(factor=enum_factor)
+                fg1.add_factors(enum_factor)
             else:
                 if idx != 0:
                     # Add the second half of factors to FactorGraph2
@@ -113,7 +113,7 @@ def test_run_bp_with_ANDFactors():
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
-                    fg2.add_factors(factor=enum_factor)
+                    fg2.add_factors(enum_factor)
                 else:
                     # Add all the EnumerationFactors to FactorGraph1 for the first iter
                     enum_factor = EnumerationFactor(
@@ -121,7 +121,7 @@ def test_run_bp_with_ANDFactors():
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
-                    fg1.add_factors(factor=enum_factor)
+                    fg1.add_factors(enum_factor)
 
         # Option 2: Define the ANDFactors
         num_parents_cumsum = np.insert(np.cumsum(num_parents), 0, 0)
diff --git a/tests/factors/test_or.py b/tests/factors/test_or.py
index f29dccd0..c615a800 100644
--- a/tests/factors/test_or.py
+++ b/tests/factors/test_or.py
@@ -102,7 +102,7 @@ def test_run_bp_with_ORFactors():
                     factor_configs=valid_configs,
                     log_potentials=np.zeros(valid_configs.shape[0]),
                 )
-                fg1.add_factors(factor=enum_factor)
+                fg1.add_factors(enum_factor)
             else:
                 if idx != 0:
                     # Add the second half of factors to FactorGraph2
@@ -111,7 +111,7 @@ def test_run_bp_with_ORFactors():
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
-                    fg2.add_factors(factor=enum_factor)
+                    fg2.add_factors(enum_factor)
                 else:
                     # Add all the EnumerationFactors to FactorGraph1 for the first iter
                     enum_factor = EnumerationFactor(
@@ -119,7 +119,7 @@ def test_run_bp_with_ORFactors():
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
-                    fg1.add_factors(factor=enum_factor)
+                    fg1.add_factors(enum_factor)
 
         # Option 2: Define the ORFactors
         num_parents_cumsum = np.insert(np.cumsum(num_parents), 0, 0)
diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py
index 0fdadbb2..8b0d72b0 100644
--- a/tests/fg/test_graph.py
+++ b/tests/fg/test_graph.py
@@ -16,30 +16,12 @@ def test_factor_graph():
     vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
     fg = graph.FactorGraph(vg)
 
-    with pytest.raises(
-        ValueError,
-        match="A Factor or a FactorGroup is required",
-    ):
-        fg.add_factors(factor_group=None, factor=None)
-
     factor = enumeration_factor.EnumerationFactor(
         variables=[vg[0]],
         factor_configs=np.arange(15)[:, None],
         log_potentials=np.zeros(15),
     )
-
-    factor_group = enumeration.EnumerationFactorGroup(
-        variables_for_factors=[[vg[0]]],
-        factor_configs=np.arange(15)[:, None],
-        log_potentials=np.zeros(15),
-    )
-    with pytest.raises(
-        ValueError,
-        match="Cannot simultaneously add a Factor and a FactorGroup",
-    ):
-        fg.add_factors(factor_group=factor_group, factor=factor)
-
-    fg.add_factors(factor=factor)
+    fg.add_factors(factor)
 
     factor_group = enumeration.EnumerationFactorGroup(
         variables_for_factors=[[vg[0]]],
@@ -49,7 +31,7 @@ def test_factor_graph():
     with pytest.raises(
         ValueError,
         match=re.escape(
-            f"A Factor of type {enumeration_factor.EnumerationFactor} involving variables {frozenset([((vg.__hash__(), 0), 15)])} already exists."
+            f"A Factor of type {enumeration_factor.EnumerationFactor} involving variables {frozenset([(vg.__hash__(), 15)])} already exists."
         ),
     ):
         fg.add_factors(factor_group)
@@ -63,10 +45,10 @@ def test_bp_state():
         factor_configs=np.arange(15)[:, None],
         log_potentials=np.zeros(15),
     )
-    fg0.add_factors(factor=factor)
+    fg0.add_factors(factor)
 
     fg1 = graph.FactorGraph(vg)
-    fg1.add_factors(factor=factor)
+    fg1.add_factors(factor)
 
     with pytest.raises(
         ValueError,
@@ -137,7 +119,7 @@ def test_ftov_msgs():
     with pytest.raises(
         ValueError,
         match=re.escape(
-            f"Given belief shape (10,) does not match expected shape (15,) for variable (({vg.__hash__()}, 0), 15)."
+            f"Given belief shape (10,) does not match expected shape (15,) for variable ({vg.__hash__()}, 15)."
         ),
     ):
         fg.bp_state.ftov_msgs[vg[0]] = np.ones(10)
diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py
index b47e3e09..7799ffb0 100644
--- a/tests/fg/test_groups.py
+++ b/tests/fg/test_groups.py
@@ -11,6 +11,21 @@
 
 
 def test_variable_dict():
+    num_states = np.full((4,), fill_value=2)
+    with pytest.raises(
+        ValueError, match=re.escape("Expected num_states shape (3,). Got (4,).")
+    ):
+        vgroup.VariableDict(variable_names=tuple([0, 1, 2]), num_states=num_states)
+
+    num_states = np.full((3,), fill_value=2, dtype=np.float32)
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "num_states should be an integer or a NumPy array of dtype int"
+        ),
+    ):
+        vgroup.NDVariableArray(shape=(2, 2), num_states=num_states)
+
     variable_dict = vgroup.VariableDict(variable_names=tuple([0, 1, 2]), num_states=15)
     with pytest.raises(
         ValueError, match="data is referring to a non-existent variable 3"
@@ -20,10 +35,10 @@ def test_variable_dict():
     with pytest.raises(
         ValueError,
         match=re.escape(
-            f"Variable (({variable_dict.__hash__()}, 2), 15) expects a data array of shape (15,) or (1,). Got (10,)"
+            "Variable 2 expects a data array of shape (15,) or (1,). Got (10,)."
         ),
     ):
-        variable_dict.flatten({((variable_dict.__hash__(), 2), 15): np.zeros(10)})
+        variable_dict.flatten({2: np.zeros(10)})
 
     with pytest.raises(
         ValueError, match="Can only unflatten 1D array. Got a 2D array."
@@ -36,10 +51,7 @@ def test_variable_dict():
                 jax.tree_util.tree_multimap(
                     lambda x, y: jnp.all(x == y),
                     variable_dict.unflatten(jnp.zeros(3)),
-                    {
-                        ((variable_dict.__hash__(), name), 15): np.zeros(1)
-                        for name in range(3)
-                    },
+                    {name: np.zeros(1) for name in range(3)},
                 )
             )
         )
@@ -54,7 +66,7 @@ def test_variable_dict():
 
 
 def test_nd_variable_array():
-    max_size = int(vgroup.MAX_SIZE)
+    max_size = int(groups.MAX_SIZE)
     with pytest.raises(
         ValueError,
         match=re.escape(
diff --git a/tests/fg/test_wiring.py b/tests/fg/test_wiring.py
index 53efeee3..af2fb9fc 100644
--- a/tests/fg/test_wiring.py
+++ b/tests/fg/test_wiring.py
@@ -25,7 +25,7 @@ def test_wiring_with_PairwiseFactorGroup():
     )
     fg.add_factors(factor_group)
 
-    factor_group = fg.factor_groups[0]
+    factor_group = fg.factor_groups[enumeration_factor.EnumerationFactor][0]
     object.__setattr__(
         factor_group, "factor_configs", factor_group.factor_configs[:, :1]
     )
@@ -41,7 +41,7 @@ def test_wiring_with_PairwiseFactorGroup():
         variables_for_factors=[[A[idx], B[idx]] for idx in range(10)]
     )
     fg1.add_factors(factor_group)
-    assert len(fg1.factor_groups) == 1
+    assert len(fg1.factor_groups[enumeration_factor.EnumerationFactor]) == 1
 
     # FactorGraph with multiple PairwiseFactorGroup
     fg2 = graph.FactorGraph(variable_groups=[A, B])
@@ -50,18 +50,20 @@ def test_wiring_with_PairwiseFactorGroup():
             variables_for_factors=[[A[idx], B[idx]]]
         )
         fg2.add_factors(factor_group)
-    assert len(fg2.factor_groups) == 10
+    assert len(fg2.factor_groups[enumeration_factor.EnumerationFactor]) == 10
 
     # FactorGraph with multiple SingleFactorGroup
     fg3 = graph.FactorGraph(variable_groups=[A, B])
+    factors = []
     for idx in range(10):
         factor = enumeration_factor.EnumerationFactor(
             variables=[A[idx], B[idx]],
             factor_configs=np.array([[0, 0], [0, 1], [1, 0], [1, 1]]),
             log_potentials=np.zeros((4,)),
         )
-        fg3.add_factors(factor=factor)
-    assert len(fg3.factor_groups) == 10
+        factors.append(factor)
+    fg3.add_factors(factors)
+    assert len(fg3.factor_groups[enumeration_factor.EnumerationFactor]) == 10
 
     assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors)
 
@@ -100,7 +102,7 @@ def test_wiring_with_ORFactorGroup():
         variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)],
     )
     fg1.add_factors(factor_group)
-    assert len(fg1.factor_groups) == 1
+    assert len(fg1.factor_groups[logical_factor.ORFactor]) == 1
 
     # FactorGraph with multiple ORFactorGroup
     fg2 = graph.FactorGraph(variable_groups=[A, B, C])
@@ -109,7 +111,7 @@ def test_wiring_with_ORFactorGroup():
             variables_for_factors=[[A[idx], B[idx], C[idx]]],
         )
         fg2.add_factors(factor_group)
-    assert len(fg2.factor_groups) == 10
+    assert len(fg2.factor_groups[logical_factor.ORFactor]) == 10
 
     # FactorGraph with multiple SingleFactorGroup
     fg3 = graph.FactorGraph(variable_groups=[A, B, C])
@@ -117,8 +119,8 @@ def test_wiring_with_ORFactorGroup():
         factor = logical_factor.ORFactor(
             variables=[A[idx], B[idx], C[idx]],
         )
-        fg3.add_factors(factor=factor)
-    assert len(fg3.factor_groups) == 10
+        fg3.add_factors(factor)
+    assert len(fg3.factor_groups[logical_factor.ORFactor]) == 10
 
     assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors)
 
@@ -155,7 +157,7 @@ def test_wiring_with_ANDFactorGroup():
         variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)],
     )
     fg1.add_factors(factor_group)
-    assert len(fg1.factor_groups) == 1
+    assert len(fg1.factor_groups[logical_factor.ANDFactor]) == 1
 
     # FactorGraph with multiple ANDFactorGroup
     fg2 = graph.FactorGraph(variable_groups=[A, B, C])
@@ -164,7 +166,7 @@ def test_wiring_with_ANDFactorGroup():
             variables_for_factors=[[A[idx], B[idx], C[idx]]],
         )
         fg2.add_factors(factor_group)
-    assert len(fg2.factor_groups) == 10
+    assert len(fg2.factor_groups[logical_factor.ANDFactor]) == 10
 
     # FactorGraph with multiple SingleFactorGroup
     fg3 = graph.FactorGraph(variable_groups=[A, B, C])
@@ -172,8 +174,8 @@ def test_wiring_with_ANDFactorGroup():
         factor = logical_factor.ANDFactor(
             variables=[A[idx], B[idx], C[idx]],
         )
-        fg3.add_factors(factor=factor)
-    assert len(fg3.factor_groups) == 10
+        fg3.add_factors(factor)
+    assert len(fg3.factor_groups[logical_factor.ANDFactor]) == 10
 
     assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors)
 
diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py
index 9e6deb65..f3ee7bfa 100644
--- a/tests/test_pgmax.py
+++ b/tests/test_pgmax.py
@@ -207,7 +207,7 @@ def create_valid_suppression_config_arr(suppression_diameter):
     extra_row_names: List[Tuple[Any, ...]] = [(0, row, N - 1) for row in range(M - 1)]
     extra_col_names: List[Tuple[Any, ...]] = [(1, M - 1, col) for col in range(N - 1)]
     additional_names = tuple(extra_row_names + extra_col_names)
-    additional_vars = vgroup.VariableDict(additional_names, num_states=3)
+    additional_vars = vgroup.VariableDict(variable_names=additional_names, num_states=3)
 
     true_map_state_output = {
         (grid_vars, (0, 0, 0)): 2,
@@ -218,10 +218,14 @@ def create_valid_suppression_config_arr(suppression_diameter):
         (grid_vars, (1, 0, 1)): 0,
         (grid_vars, (1, 1, 0)): 1,
         (grid_vars, (1, 1, 1)): 0,
-        (additional_vars, ((additional_vars.__hash__(), (0, 0, 2)), 3)): 0,
-        (additional_vars, ((additional_vars.__hash__(), (0, 1, 2)), 3)): 2,
-        (additional_vars, ((additional_vars.__hash__(), (1, 2, 0)), 3)): 1,
-        (additional_vars, ((additional_vars.__hash__(), (1, 2, 1)), 3)): 0,
+        # (additional_vars, ((additional_vars.__hash__(), (0, 0, 2)), 3)): 0,
+        # (additional_vars, ((additional_vars.__hash__(), (0, 1, 2)), 3)): 2,
+        # (additional_vars, ((additional_vars.__hash__(), (1, 2, 0)), 3)): 1,
+        # (additional_vars, ((additional_vars.__hash__(), (1, 2, 1)), 3)): 0,
+        (additional_vars, (0, 0, 2)): 0,
+        (additional_vars, (0, 1, 2)): 2,
+        (additional_vars, (1, 2, 0)): 1,
+        (additional_vars, (1, 2, 1)): 0,
     }
 
     gt_has_cuts = gt_has_cuts.astype(np.int32)
@@ -251,9 +255,7 @@ def create_valid_suppression_config_arr(suppression_diameter):
                 except IndexError:
                     try:
                         _ = additional_vars[i, row, col]
-                        additional_vars_evidence_dict[
-                            additional_vars[i, row, col]
-                        ] = evidence_vals_arr
+                        additional_vars_evidence_dict[i, row, col] = evidence_vals_arr
                     except ValueError:
                         pass
 
@@ -302,7 +304,7 @@ def create_valid_suppression_config_arr(suppression_diameter):
                         valid_configs_non_supp.shape[0], dtype=float
                     ),
                 )
-                fg.add_factors(factor=factor)
+                fg.add_factors(factor)
             else:
                 factor = EnumerationFactor(
                     variables=curr_vars,
@@ -311,7 +313,7 @@ def create_valid_suppression_config_arr(suppression_diameter):
                         valid_configs_non_supp.shape[0], dtype=float
                     ),
                 )
-                fg.add_factors(factor=factor)
+                fg.add_factors(factor)
 
     # Create an EnumerationFactorGroup for vertical suppression factors
     vert_suppression_vars: List[List[Tuple[Any, ...]]] = []

From 05de350e4d20010690f8a3a1fb760cdcccaee13d Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Wed, 27 Apr 2022 23:20:34 +0000
Subject: [PATCH 29/35] Minor

---
 pgmax/factors/enumeration.py | 5 ++---
 pgmax/factors/logical.py     | 5 ++---
 2 files changed, 4 insertions(+), 6 deletions(-)

diff --git a/pgmax/factors/enumeration.py b/pgmax/factors/enumeration.py
index ef2a99a7..2340d0c1 100644
--- a/pgmax/factors/enumeration.py
+++ b/pgmax/factors/enumeration.py
@@ -2,7 +2,7 @@
 
 import functools
 from dataclasses import dataclass
-from typing import Any, List, Mapping, Sequence, Tuple, Union
+from typing import List, Mapping, Sequence, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -157,7 +157,7 @@ def compile_wiring(
         factor_edges_num_states: np.ndarray,
         variables_for_factors: Sequence[List],
         factor_configs: np.ndarray,
-        vars_to_starts: Mapping[Tuple[Any, int], int],
+        vars_to_starts: Mapping[Tuple[int, int], int],
         num_factors: int,
     ) -> EnumerationWiring:
         """Compile an EnumerationWiring for an EnumerationFactor or a FactorGroup with EnumerationFactors.
@@ -183,7 +183,6 @@ def compile_wiring(
         Returns:
             The EnumerationWiring
         """
-        # TODO: Don't use vars_to_starts
         var_states = []
         for variables_for_factor in variables_for_factors:
             for variable in variables_for_factor:
diff --git a/pgmax/factors/logical.py b/pgmax/factors/logical.py
index 8b931e47..482865b2 100644
--- a/pgmax/factors/logical.py
+++ b/pgmax/factors/logical.py
@@ -2,7 +2,7 @@
 
 import functools
 from dataclasses import dataclass, field
-from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
+from typing import List, Mapping, Optional, Sequence, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -143,7 +143,7 @@ def compile_wiring(
         factor_edges_num_states: np.ndarray,
         variables_for_factors: Sequence[List],
         factor_sizes: np.ndarray,
-        vars_to_starts: Mapping[Tuple[Any, int], int],
+        vars_to_starts: Mapping[Tuple[int, int], int],
         edge_states_offset: int,
     ) -> LogicalWiring:
         """Compile a LogicalWiring for a LogicalFactor or a FactorGroup with LogicalFactors.
@@ -164,7 +164,6 @@ def compile_wiring(
         Returns:
             The LogicalWiring
         """
-        # TODO: Don't use vars_to_starts
         var_states = []
         for variables_for_factor in variables_for_factors:
             for variable in variables_for_factor:

From 1e0c93d6675d6eae41fba7c03ff25b67ab0b8021 Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Wed, 27 Apr 2022 23:50:35 +0000
Subject: [PATCH 30/35] Docstring

---
 pgmax/fg/groups.py        | 2 +-
 pgmax/fg/nodes.py         | 4 ++--
 pgmax/groups/variables.py | 2 +-
 tests/fg/test_groups.py   | 2 +-
 4 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index e8079a36..e705f477 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -271,7 +271,7 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any:
             "Please subclass the FactorGroup class and override this method"
         )
 
-    def compile_wiring(self, vars_to_starts: Mapping[Tuple[Any, int], int]) -> Any:
+    def compile_wiring(self, vars_to_starts: Mapping[Tuple[int, int], int]) -> Any:
         """Compile an efficient wiring for the FactorGroup.
 
         Args:
diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py
index f987b8df..06d7c5d6 100644
--- a/pgmax/fg/nodes.py
+++ b/pgmax/fg/nodes.py
@@ -1,7 +1,7 @@
 """A module containing classes that specify the basic components of a Factor Graph."""
 
 from dataclasses import asdict, dataclass
-from typing import Any, List, Sequence, Tuple, Union
+from typing import List, Sequence, Tuple, Union
 
 import jax
 import jax.numpy as jnp
@@ -48,7 +48,7 @@ class Factor:
         NotImplementedError: If compile_wiring is not implemented
     """
 
-    variables: List[Tuple[Any, int]]
+    variables: List[Tuple[int, int]]
     log_potentials: np.ndarray
 
     def __post_init__(self):
diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index 22a7ef0e..87371c95 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -212,7 +212,7 @@ def __getitem__(self, var_name: Any) -> Tuple[int, int]:
         return (self.variable_hashes[var_idx], self.num_states[var_idx])
 
     def flatten(
-        self, data: Mapping[Tuple[Tuple[int, int], int], Union[np.ndarray, jnp.ndarray]]
+        self, data: Mapping[Any, Union[np.ndarray, jnp.ndarray]]
     ) -> jnp.ndarray:
         """Function that turns meaningful structured data into a flat data array for internal use.
 
diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py
index 7799ffb0..e39fbda3 100644
--- a/tests/fg/test_groups.py
+++ b/tests/fg/test_groups.py
@@ -24,7 +24,7 @@ def test_variable_dict():
             "num_states should be an integer or a NumPy array of dtype int"
         ),
     ):
-        vgroup.NDVariableArray(shape=(2, 2), num_states=num_states)
+        vgroup.VariableDict(variable_names=tuple([0, 1, 2]), num_states=num_states)
 
     variable_dict = vgroup.VariableDict(variable_names=tuple([0, 1, 2]), num_states=15)
     with pytest.raises(

From 63c6738ef19d1dfcb5ca54b852bd9a6e8f99e87f Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Thu, 28 Apr 2022 18:16:35 +0000
Subject: [PATCH 31/35] Minor changes

---
 examples/pmp_binary_deconvolution.py | 2 +-
 examples/rcn.py                      | 4 ++--
 pgmax/fg/graph.py                    | 6 +++---
 pgmax/fg/groups.py                   | 8 ++++----
 pgmax/fg/nodes.py                    | 2 +-
 pgmax/groups/variables.py            | 2 --
 6 files changed, 11 insertions(+), 13 deletions(-)

diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py
index a8b992e5..190abe36 100644
--- a/examples/pmp_binary_deconvolution.py
+++ b/examples/pmp_binary_deconvolution.py
@@ -220,7 +220,7 @@ def plot_images(images, display=True, nr=None):
 
 # %%
 pW = 0.25
-pS = 1e-70
+pS = 1e-100
 pX = 1e-100
 
 # Sparsity inducing priors for W and S
diff --git a/examples/rcn.py b/examples/rcn.py
index 140c1712..e83cf10c 100644
--- a/examples/rcn.py
+++ b/examples/rcn.py
@@ -417,8 +417,8 @@ def get_evidence(bu_msg: np.ndarray, frc: np.ndarray) -> np.ndarray:
         evidence_updates,
         map_states,
     )
-    for idx, score in enumerate(score.values()):
-        scores[test_idx, idx] = score
+    for model_idx in range(frcs.shape[0]):
+        scores[test_idx, model_idx] = score[variables_all_models[model_idx]]
     end = time.time()
     print(f"Computing scores took {end-start:.3f} seconds for image {test_idx}.")
 
diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index 21ea1f11..77a51f71 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -323,7 +323,7 @@ class FactorGraphState:
     """
 
     variable_groups: Sequence[groups.VariableGroup]
-    vars_to_starts: Mapping[Tuple[Any, int], int]
+    vars_to_starts: Mapping[Tuple[int, int], int]
     num_var_states: int
     total_factor_num_states: int
     factor_type_to_msgs_range: OrderedDict[type, Tuple[int, int]]
@@ -460,7 +460,7 @@ def __getitem__(self, factor_group: groups.FactorGroup) -> np.ndarray:
 
     def __setitem__(
         self,
-        factor_group: Any,
+        factor_group: groups.FactorGroup,
         data: Union[np.ndarray, jnp.ndarray],
     ):
         """Set the log potentials for a FactorGroup
@@ -563,7 +563,7 @@ def __setitem__(
         """Spreading beliefs at a variable to all connected Factors
 
         Args:
-            variable: A tuple representing a variable
+            variable: Variable queried
             data: An array containing the beliefs to be spread uniformly
                 across all factors to variable messages involving this variable.
         """
diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index e705f477..311d53b1 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -29,14 +29,14 @@
 @dataclass(frozen=True, eq=False)
 class VariableGroup:
     """Class to represent a group of variables.
-    Each variable is represented via a tuple of the form (variable hash/name, number of states)
+    Each variable is represented via a tuple of the form (variable hash, variable num_states)
 
     Arguments:
         num_states: An integer or an array specifying the number of states of the variables
             in this VariableGroup
     """
 
-    num_states: Union[int, np.ndarray] = field(init=False)
+    num_states: Union[int, np.ndarray]
 
     def __post_init__(self):
         # Only compute the hash once, which is guaranteed to be an int64
@@ -56,7 +56,7 @@ def __lt__(self, other):
 
     def __getitem__(self, val: Any) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
         """Given a variable name, index, or a group of variable indices, retrieve the associated variable(s).
-        Each variable is returned via a tuple of the form (variable hash, number of states)
+        Each variable is returned via a tuple of the form (variable hash, variable num_states)
 
         Args:
             val: a variable index, slice, or name
@@ -82,7 +82,7 @@ def variable_hashes(self) -> np.ndarray:
     @cached_property
     def variables(self) -> List[Tuple[int, int]]:
         """Function that returns the list of all variables in the VariableGroup.
-        Each variable is represented by a tuple of the form (variable hash, number of states)
+        Each variable is represented by a tuple of the form (variable hash, variable num_states)
 
         Returns:
             List of variables in the VariableGroup
diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py
index 06d7c5d6..37cfce13 100644
--- a/pgmax/fg/nodes.py
+++ b/pgmax/fg/nodes.py
@@ -42,7 +42,7 @@ class Factor:
 
     Args:
         variables: List of variables connected by the Factor.
-            Each variable is represented by a tuple of the form (variable hash/name, number of states)
+            Each variable is represented by a tuple (variable hash, variable num_ states)
 
     Raises:
         NotImplementedError: If compile_wiring is not implemented
diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
index 87371c95..357d681a 100644
--- a/pgmax/groups/variables.py
+++ b/pgmax/groups/variables.py
@@ -22,7 +22,6 @@ class NDVariableArray(groups.VariableGroup):
             the notion of a NumPy ndarray shape)
     """
 
-    num_states: Union[int, np.ndarray]
     shape: Tuple[int, ...]
 
     def __post_init__(self):
@@ -161,7 +160,6 @@ class VariableDict(groups.VariableGroup):
         variable_names: A tuple of all the names of the variables in this VariableGroup.
     """
 
-    num_states: Union[int, np.ndarray]
     variable_names: Tuple[Any, ...]
 
     def __post_init__(self):

From 0397f9b788c95666bab3dad6680d380731ec59f7 Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Thu, 28 Apr 2022 18:24:52 +0000
Subject: [PATCH 32/35] Doc

---
 pgmax/fg/nodes.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py
index 37cfce13..8388a61e 100644
--- a/pgmax/fg/nodes.py
+++ b/pgmax/fg/nodes.py
@@ -42,7 +42,7 @@ class Factor:
 
     Args:
         variables: List of variables connected by the Factor.
-            Each variable is represented by a tuple (variable hash, variable num_ states)
+            Each variable is represented by a tuple (variable hash, variable num_states)
 
     Raises:
         NotImplementedError: If compile_wiring is not implemented

From 8b3d60e131e9d31d86db9b1647b94f9b81244620 Mon Sep 17 00:00:00 2001
From: stannis <stannis@vicarious.com>
Date: Fri, 29 Apr 2022 22:51:02 -0700
Subject: [PATCH 33/35] Rename this_hash

---
 pgmax/fg/groups.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index 311d53b1..2a3d691a 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -41,12 +41,12 @@ class VariableGroup:
     def __post_init__(self):
         # Only compute the hash once, which is guaranteed to be an int64
         this_id = id(self) % 2**32
-        this_hash = this_id * int(MAX_SIZE)
-        assert this_hash < 2**63
-        object.__setattr__(self, "this_hash", this_hash)
+        _hash = this_id * int(MAX_SIZE)
+        assert _hash < 2**63
+        object.__setattr__(self, "_hash", _hash)
 
     def __hash__(self):
-        return self.this_hash
+        return self._hash
 
     def __eq__(self, other):
         return hash(self) == hash(other)

From 8751e90c243564f4e43ec62e2f68cf4d1523b5b7 Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Mon, 2 May 2022 18:37:53 +0000
Subject: [PATCH 34/35] Final comments

---
 examples/gmrf.py                     | 14 +++----
 examples/pmp_binary_deconvolution.py | 24 +++--------
 examples/rbm.py                      | 59 +++++++++-------------------
 pgmax/fg/graph.py                    |  1 -
 pgmax/fg/groups.py                   |  1 -
 5 files changed, 30 insertions(+), 69 deletions(-)

diff --git a/examples/gmrf.py b/examples/gmrf.py
index 1bf4fad9..50e95dc7 100644
--- a/examples/gmrf.py
+++ b/examples/gmrf.py
@@ -58,7 +58,7 @@
 fg = graph.FactorGraph(variables)
 
 # %%
-# Add top-down factors
+# Create top-down factors
 top_down = enumeration.PairwiseFactorGroup(
     variables_for_factors=[
         [variables[ii, jj], variables[ii + 1, jj]]
@@ -66,9 +66,8 @@
         for jj in range(N)
     ],
 )
-fg.add_factors(top_down)
 
-# Add left-right factors
+# Create left-right factors
 left_right = enumeration.PairwiseFactorGroup(
     variables_for_factors=[
         [variables[ii, jj], variables[ii, jj + 1]]
@@ -76,9 +75,8 @@
         for jj in range(N - 1)
     ],
 )
-fg.add_factors(left_right)
 
-# Add diagonal factors
+# Create diagonal factors
 diagonal0 = enumeration.PairwiseFactorGroup(
     variables_for_factors=[
         [variables[ii, jj], variables[ii + 1, jj + 1]]
@@ -86,8 +84,6 @@
         for jj in range(N - 1)
     ],
 )
-fg.add_factors(diagonal0)
-
 diagonal1 = enumeration.PairwiseFactorGroup(
     variables_for_factors=[
         [variables[ii, jj], variables[ii - 1, jj + 1]]
@@ -95,7 +91,9 @@
         for jj in range(N - 1)
     ],
 )
-fg.add_factors(diagonal1)
+
+# Add factors
+fg.add_factors([top_down, left_right, diagonal0, diagonal1])
 
 # %%
 bp = graph.BP(fg.bp_state, temperature=1.0)
diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py
index 190abe36..43e628af 100644
--- a/examples/pmp_binary_deconvolution.py
+++ b/examples/pmp_binary_deconvolution.py
@@ -116,9 +116,6 @@ def plot_images(images, display=True, nr=None):
 s_height = im_height - feat_height + 1
 s_width = im_width - feat_width + 1
 
-import time
-
-start = time.time()
 # Binary features
 W = vgroup.NDVariableArray(
     num_states=2, shape=(n_chan, n_feat, feat_height, feat_width)
@@ -135,16 +132,13 @@ def plot_images(images, display=True, nr=None):
 
 # Binary images obtained by convolution
 X = vgroup.NDVariableArray(num_states=2, shape=X_gt.shape)
-print("Time", time.time() - start)
 
 # %% [markdown]
-# For computation efficiency, we construct large FactorGroups instead of individual Factors
+# For computation efficiency, we construct large FactorGroups instead of individual factors
 
 # %%
-start = time.time()
 # Factor graph
 fg = graph.FactorGraph(variable_groups=[S, W, SW, X])
-print(time.time() - start)
 
 # Define the ANDFactors
 variables_for_ANDFactors = []
@@ -177,12 +171,10 @@ def plot_images(images, display=True, nr=None):
 
                             X_var = X[idx_img, idx_chan, idx_img_height, idx_img_width]
                             variables_for_ORFactors_dict[X_var].append(SW_var)
-print("After loop", time.time() - start)
 
 # Add ANDFactorGroup, which is computationally efficient
 AND_factor_group = logical.ANDFactorGroup(variables_for_ANDFactors)
 fg.add_factors(AND_factor_group)
-print(time.time() - start)
 
 # Define the ORFactors
 variables_for_ORFactors = [
@@ -193,7 +185,6 @@ def plot_images(images, display=True, nr=None):
 # Add ORFactorGroup, which is computationally efficient
 OR_factor_group = logical.ORFactorGroup(variables_for_ORFactors)
 fg.add_factors(OR_factor_group)
-print("Time", time.time() - start)
 
 for factor_type, factor_groups in fg.factor_groups.items():
     if len(factor_groups) > 0:
@@ -211,16 +202,14 @@ def plot_images(images, display=True, nr=None):
 # in the same manner does not change X, so this naturally results in multiple equivalent modes.
 
 # %%
-start = time.time()
 bp = graph.BP(fg.bp_state, temperature=0.0)
-print("Time", time.time() - start)
 
 # %% [markdown]
 # We first compute the evidence without perturbation, similar to the PMP paper.
 
 # %%
 pW = 0.25
-pS = 1e-100
+pS = 1e-75
 pX = 1e-100
 
 # Sparsity inducing priors for W and S
@@ -235,13 +224,12 @@ def plot_images(images, display=True, nr=None):
 uX[..., 0] = (2 * X_gt - 1) * logit(pX)
 
 # %% [markdown]
-# We draw a batch of samples from the posterior in parallel by transforming `run_bp`/`get_beliefs` with `jax.vmap`
+# We draw a batch of samples from the posterior in parallel by transforming `bp.init`/`bp.run_bp`/`bp.get_beliefs` with `jax.vmap`
 
 # %%
-np.random.seed(seed=42)
+np.random.seed(seed=0)
 n_samples = 4
 
-start = time.time()
 bp_arrays = jax.vmap(bp.init, in_axes=0, out_axes=0)(
     evidence_updates={
         S: uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape),
@@ -250,13 +238,13 @@ def plot_images(images, display=True, nr=None):
         X: uX[None] + np.zeros(shape=(n_samples,) + uX.shape),
     },
 )
-print("Time", time.time() - start)
+
 bp_arrays = jax.vmap(
     functools.partial(bp.run_bp, num_iters=100, damping=0.5),
     in_axes=0,
     out_axes=0,
 )(bp_arrays)
-print("Time", time.time() - start)
+
 beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
 map_states = graph.decode_map_states(beliefs)
 
diff --git a/examples/rbm.py b/examples/rbm.py
index 4b391a62..c29efbd7 100644
--- a/examples/rbm.py
+++ b/examples/rbm.py
@@ -46,37 +46,30 @@
 # We can then initialize the factor graph for the RBM with
 
 # %%
-import time
-
-start = time.time()
 # Initialize factor graph
 hidden_variables = vgroup.NDVariableArray(num_states=2, shape=bh.shape)
 visible_variables = vgroup.NDVariableArray(num_states=2, shape=bv.shape)
 fg = graph.FactorGraph(variable_groups=[hidden_variables, visible_variables])
-print("Time", time.time() - start)
 
 # %% [markdown]
 # [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup) (e.g. an [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)), or a list/dictionary of [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed.
 #
-# After initialization, `fg` does not have any factors. PGMax supports imperatively adding factors to a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph). We can add the unary and pairwise factors by grouping them using
+# After initialization, `fg` does not have any factors. PGMax supports imperatively adding factors to a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph). We can add the unary and pairwise factors by grouping them using [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup)
 
 # %%
-start = time.time()
-
-# Add unary factors
+# Create unary factors
 hidden_unaries = enumeration.EnumerationFactorGroup(
     variables_for_factors=[[hidden_variables[ii]] for ii in range(bh.shape[0])],
     factor_configs=np.arange(2)[:, None],
     log_potentials=np.stack([np.zeros_like(bh), bh], axis=1),
 )
-
 visible_unaries = enumeration.EnumerationFactorGroup(
     variables_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])],
     factor_configs=np.arange(2)[:, None],
     log_potentials=np.stack([np.zeros_like(bv), bv], axis=1),
 )
 
-# Add pairwise factors
+# Create pairwise factors
 log_potential_matrix = np.zeros(W.shape + (2, 2)).reshape((-1, 2, 2))
 log_potential_matrix[:, 1, 1] = W.ravel()
 
@@ -85,25 +78,23 @@
     for ii in range(bh.shape[0])
     for jj in range(bv.shape[0])
 ]
-print("Time", time.time() - start)
 pairwise_factors = enumeration.PairwiseFactorGroup(
     variables_for_factors=variables_for_factors,
     log_potential_matrix=log_potential_matrix,
 )
-print("Time", time.time() - start)
 
+# Add factors to the FactorGraph
 fg.add_factors([hidden_unaries, visible_unaries, pairwise_factors])
-print("Time", time.time() - start)
 
 
 # %% [markdown]
-# PGMax implements convenient and computationally efficient [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) for representing Groups of similar factors. The code above makes use of [`EnumerationFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.EnumerationFactorGroup.html#pgmax.fg.groups.EnumerationFactorGroup) and [`PairwiseFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.PairwiseFactorGroup.html#pgmax.fg.groups.PairwiseFactorGroup), two [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup)s implemented in the [`pgmax.fg.groups`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.html#module-pgmax.fg.graph) module.
+# PGMax implements convenient and computationally efficient [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) for representing groups of similar factors. The code above makes use of [`EnumerationFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.EnumerationFactorGroup.html#pgmax.fg.groups.EnumerationFactorGroup) and [`PairwiseFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.PairwiseFactorGroup.html#pgmax.fg.groups.PairwiseFactorGroup).
 #
-# A [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) is created by calling [`fg.add_factor_group`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph.add_factor_group), which takes 2 arguments: `factory` which specifies the [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) subclass, `variable_names_for_factors` which is a list of lists containing the name of the involved variables in the different factors, and additional arguments for the [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) (e.g. `factor_configs` or `log_potential_matrix` here).
+# A [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) takes as argument `variables_for_factors` which is a list of lists of the variables involved in the different factors, and additional arguments specific to each [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) (e.g. `factor_configs` or `log_potential_matrix` here).
 #
-# In this example, since we construct `fg` with variables `dict(hidden=hidden_variables, visible=visible_variables)`, where `hidden_variables` and `visible_variables` are [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)s, we can refer to the `ii`th hidden variable as `("hidden", ii)` and the `jj`th visible variable as `("visible", jj)`. In general, PGMax implements an intuitive scheme for automatically assigning names to the variables in a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph).
+# In this example, since we construct `fg` with variables `hidden_variables` and `visible_variables`, which are both [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)s, we can refer to the `ii`th hidden variable as `hidden_variables[ii]` and the `jj`th visible variable as `visible_variables[jj]`.
 #
-# An alternative way of creating the above factors is to add them iteratively by calling [`fg.add_factor`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph.add_factor) as below. This approach is not recommended as it is not computationally efficient.
+# An alternative way of creating the above factors is to add them iteratively without building the [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup)s as below. This approach is not recommended as it is not computationally efficient.
 # ~~~python
 # from pgmax.factors import enumeration as enumeration_factor
 # import itertools
@@ -155,30 +146,28 @@
 # More generally, PGMax implements LBP with temperature, with `temperature=0.0` and `temperature=1.0` corresponding to the commonly used max/sum-product LBP respectively.
 #
 # Now we are ready to demonstrate PMP sampling from RBM. PMP perturbs the model with [Gumbel](https://numpy.org/doc/stable/reference/random/generated/numpy.random.gumbel.html) unary potentials, and draws a sample from the RBM as the MAP decoding from running max-product LBP on the perturbed model
+#
+# import itertools
+#
+# from tqdm import tqdm
 
 # %%
-start = time.time()
 bp = graph.BP(fg.bp_state, temperature=0.0)
-print("Time", time.time() - start)
 
 # %%
-start = time.time()
 bp_arrays = bp.init(
     evidence_updates={
         hidden_variables: np.random.gumbel(size=(bh.shape[0], 2)),
         visible_variables: np.random.gumbel(size=(bv.shape[0], 2)),
     }
 )
-print("Time", time.time() - start)
 bp_arrays = bp.run_bp(bp_arrays, num_iters=100, damping=0.5)
-print("Time", time.time() - start)
 beliefs = bp.get_beliefs(bp_arrays)
-print("Time", time.time() - start)
 
 # %% [markdown]
-# Here we use the `evidence_updates` argument of `run_bp` to perturb the model with Gumbel unary potentials. In general, `evidence_updates` can be used to incorporate evidence in the form of externally applied unary potentials in PGM inference.
+# Here we use the `evidence_updates` argument of `bp.init` to perturb the model with Gumbel unary potentials. In general, `evidence_updates` can be used to incorporate evidence in the form of externally applied unary potentials in PGM inference.
 #
-# Visualizing the MAP decoding (Figure [fig:rbm_single_digit]), we see that we have sampled an MNIST digit!
+# Visualizing the MAP decoding, we see that we have sampled an MNIST digit!
 
 # %%
 fig, ax = plt.subplots(1, 1, figsize=(10, 10))
@@ -191,23 +180,11 @@
 # %% [markdown]
 # PGMax adopts a functional interface for implementing LBP: running LBP in PGMax starts with
 # ~~~python
-# run_bp, get_bp_state, get_beliefs = graph.BP(fg.bp_state, num_iters=NUM_ITERS, temperature=T)
+# bp = graph.BP(fg.bp_state, temperature=T)
 # ~~~
-# where `run_bp` and `get_beliefs` are pure functions with no side-effects. This design choice means that we can easily apply JAX transformations like `jit`/`vmap`/`grad`, etc., to these functions, and additionally allows PGMax to seamlessly interact with other packages in the rapidly growing JAX ecosystem (see [here](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) and [here](https://github.com/n2cholas/awesome-jax)). In what follows we demonstrate an example on applying `jax.vmap`, a convenient transformation for automatically vectorizing functions.
+# where the arguments of the `bp` are several useful functions to run LBP. In particular, `bp.init`, `bp.run_bp`, `bp.get_beliefs` are pure functions with no side-effects. This design choice means that we can easily apply JAX transformations like `jit`/`vmap`/`grad`, etc., to these functions, and additionally allows PGMax to seamlessly interact with other packages in the rapidly growing JAX ecosystem (see [here](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) and [here](https://github.com/n2cholas/awesome-jax)).
 #
-# Since we implement `run_bp`/`get_beliefs` as a pure function, we can apply `jax.vmap` to `run_bp`/`get_beliefs` to process a batch of samples/models in parallel. As an example, consider the PGMax implementation of PMP sampling from the RBM trained on MNIST images in Section [Tutorial: implementing LBP inference for RBMs with PGMax]. Instead of drawing one sample at a time
-# ~~~python
-# bp_arrays = run_bp(
-#     evidence_updates={
-#         hidden_variables: np.random.gumbel(size=(bh.shape[0], 2)),
-#         visible_variables: np.random.gumbel(size=(bv.shape[0], 2)),
-#     },
-#     damping=0.5,
-# )
-# beliefs = get_beliefs(bp_arrays)
-# map_states = graph.decode_map_states(beliefs)
-# ~~~
-# we can draw a batch of samples in parallel by transforming `run_bp`/`get_beliefs` with `jax.vmap`
+# As an example of applying `jax.vmap` to `bp.init`/`bp.run_bp`/`bp.get_beliefs` to process a batch of samples/models in parallel, instead of drawing one sample at a time as above, we can draw a batch of samples in parallel as follows:
 
 # %%
 n_samples = 10
@@ -227,7 +204,7 @@
 map_states = graph.decode_map_states(beliefs)
 
 # %% [markdown]
-# Visualizing the MAP decodings (Figure [fig:rbm_multiple_digits]), we see that we have sampled 10 MNIST digits in parallel!
+# Visualizing the MAP decodings, we see that we have sampled 10 MNIST digits in parallel!
 
 # %%
 fig, ax = plt.subplots(2, 5, figsize=(20, 8))
diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
index 77a51f71..16088e9f 100644
--- a/pgmax/fg/graph.py
+++ b/pgmax/fg/graph.py
@@ -65,7 +65,6 @@ def __post_init__(self):
         )
 
         # See FactorGraphState docstrings for documentation on the following fields
-        # TODO: vars_to_starts does not have to be a dict
         self._vars_to_starts: Dict[Tuple[int, int], int] = {}
 
         vars_num_states_cumsum = 0
diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py
index 2a3d691a..e772095a 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fg/groups.py
@@ -200,7 +200,6 @@ def factor_edges_num_states(self) -> np.ndarray:
         """
         factor_edges_num_states = np.empty(shape=(self.factor_sizes.sum(),), dtype=int)
 
-        # TODO: create variables_for_factors as an array and move this to numba
         idx = 0
         for variables_for_factor in self.variables_for_factors:
             for variable in variables_for_factor:

From b265f14c42b657c88103f35e97047f92426575e7 Mon Sep 17 00:00:00 2001
From: Antoine Dedieu <antoine@vicarious.com>
Date: Mon, 2 May 2022 18:39:20 +0000
Subject: [PATCH 35/35] Minor

---
 examples/rbm.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/examples/rbm.py b/examples/rbm.py
index c29efbd7..5e751f44 100644
--- a/examples/rbm.py
+++ b/examples/rbm.py
@@ -146,10 +146,6 @@
 # More generally, PGMax implements LBP with temperature, with `temperature=0.0` and `temperature=1.0` corresponding to the commonly used max/sum-product LBP respectively.
 #
 # Now we are ready to demonstrate PMP sampling from RBM. PMP perturbs the model with [Gumbel](https://numpy.org/doc/stable/reference/random/generated/numpy.random.gumbel.html) unary potentials, and draws a sample from the RBM as the MAP decoding from running max-product LBP on the perturbed model
-#
-# import itertools
-#
-# from tqdm import tqdm
 
 # %%
 bp = graph.BP(fg.bp_state, temperature=0.0)