diff --git a/examples/gmrf.py b/examples/gmrf.py
index 50e95dc7..5228d0d7 100644
--- a/examples/gmrf.py
+++ b/examples/gmrf.py
@@ -22,9 +22,7 @@
 from jax.example_libraries import optimizers
 from tqdm.notebook import tqdm
 
-from pgmax.fg import graph
-from pgmax.groups import enumeration
-from pgmax.groups import variables as vgroup
+from pgmax import fgraph, fgroup, infer, vgroup
 
 # %% [markdown]
 # # Visualize a trained GMRF
@@ -54,12 +52,12 @@
 # %%
 M, N = target_images.shape[-2:]
 num_states = np.sum(n_clones)
-variables = vgroup.NDVariableArray(num_states=num_states, shape=(M, N))
-fg = graph.FactorGraph(variables)
+variables = vgroup.NDVarArray(num_states=num_states, shape=(M, N))
+fg = fgraph.FactorGraph(variables)
 
 # %%
 # Create top-down factors
-top_down = enumeration.PairwiseFactorGroup(
+top_down = fgroup.PairwiseFactorGroup(
     variables_for_factors=[
         [variables[ii, jj], variables[ii + 1, jj]]
         for ii in range(M - 1)
@@ -68,7 +66,7 @@
 )
 
 # Create left-right factors
-left_right = enumeration.PairwiseFactorGroup(
+left_right = fgroup.PairwiseFactorGroup(
     variables_for_factors=[
         [variables[ii, jj], variables[ii, jj + 1]]
         for ii in range(M)
@@ -77,14 +75,14 @@
 )
 
 # Create diagonal factors
-diagonal0 = enumeration.PairwiseFactorGroup(
+diagonal0 = fgroup.PairwiseFactorGroup(
     variables_for_factors=[
         [variables[ii, jj], variables[ii + 1, jj + 1]]
         for ii in range(M - 1)
         for jj in range(N - 1)
     ],
 )
-diagonal1 = enumeration.PairwiseFactorGroup(
+diagonal1 = fgroup.PairwiseFactorGroup(
     variables_for_factors=[
         [variables[ii, jj], variables[ii - 1, jj + 1]]
         for ii in range(1, M)
@@ -96,7 +94,7 @@
 fg.add_factors([top_down, left_right, diagonal0, diagonal1])
 
 # %%
-bp = graph.BP(fg.bp_state, temperature=1.0)
+bp = infer.BP(fg.bp_state, temperature=1.0)
 
 # %%
 log_potentials = {
@@ -114,7 +112,7 @@
     target_image = target_images[idx]
     evidence = jnp.log(jnp.where(noisy_image[..., None] == 0, p_contour, 1 - p_contour))
     target = prototype_targets[target_image]
-    marginals = graph.get_marginals(
+    marginals = infer.get_marginals(
         bp.get_beliefs(
             bp.run_bp(
                 bp.init(
@@ -162,7 +160,7 @@
 def loss(noisy_image, target_image, log_potentials):
     evidence = jnp.log(jnp.where(noisy_image[..., None] == 0, p_contour, 1 - p_contour))
     target = prototype_targets[target_image]
-    marginals = graph.get_marginals(
+    marginals = infer.get_marginals(
         bp.get_beliefs(
             bp.run_bp(
                 bp.init(
diff --git a/examples/ising_model.py b/examples/ising_model.py
index bf30c7bd..1befd272 100644
--- a/examples/ising_model.py
+++ b/examples/ising_model.py
@@ -20,16 +20,14 @@
 import matplotlib.pyplot as plt
 import numpy as np
 
-from pgmax.fg import graph
-from pgmax.groups import enumeration
-from pgmax.groups import variables as vgroup
+from pgmax import fgraph, fgroup, infer, vgroup
 
 # %% [markdown]
 # ### Construct variable grid, initialize factor graph, and add factors
 
 # %%
-variables = vgroup.NDVariableArray(num_states=2, shape=(50, 50))
-fg = graph.FactorGraph(variable_groups=variables)
+variables = vgroup.NDVarArray(num_states=2, shape=(50, 50))
+fg = fgraph.FactorGraph(variable_groups=variables)
 
 variables_for_factors = []
 for ii in range(50):
@@ -39,7 +37,7 @@
         variables_for_factors.append([variables[ii, jj], variables[kk, jj]])
         variables_for_factors.append([variables[ii, jj], variables[ii, ll]])
 
-factor_group = enumeration.PairwiseFactorGroup(
+factor_group = fgroup.PairwiseFactorGroup(
     variables_for_factors=variables_for_factors,
     log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]),
 )
@@ -49,7 +47,7 @@
 # ### Run inference and visualize results
 
 # %%
-bp = graph.BP(fg.bp_state, temperature=0)
+bp = infer.BP(fg.bp_state, temperature=0)
 
 # %%
 bp_arrays = bp.init(
@@ -59,7 +57,7 @@
 beliefs = bp.get_beliefs(bp_arrays)
 
 # %%
-img = graph.decode_map_states(beliefs)[variables]
+img = infer.decode_map_states(beliefs)[variables]
 fig, ax = plt.subplots(1, 1, figsize=(10, 10))
 ax.imshow(img)
 
diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py
index 43e628af..d8ec02a5 100644
--- a/examples/pmp_binary_deconvolution.py
+++ b/examples/pmp_binary_deconvolution.py
@@ -28,9 +28,7 @@
 from scipy.special import logit
 from tqdm.notebook import tqdm
 
-from pgmax.fg import graph
-from pgmax.groups import logical
-from pgmax.groups import variables as vgroup
+from pgmax import fgraph, fgroup, infer, vgroup
 
 
 # %%
@@ -117,28 +115,26 @@ def plot_images(images, display=True, nr=None):
 s_width = im_width - feat_width + 1
 
 # Binary features
-W = vgroup.NDVariableArray(
-    num_states=2, shape=(n_chan, n_feat, feat_height, feat_width)
-)
+W = vgroup.NDVarArray(num_states=2, shape=(n_chan, n_feat, feat_height, feat_width))
 
 # Binary indicators of features locations
-S = vgroup.NDVariableArray(num_states=2, shape=(n_images, n_feat, s_height, s_width))
+S = vgroup.NDVarArray(num_states=2, shape=(n_images, n_feat, s_height, s_width))
 
 # Auxiliary binary variables combining W and S
-SW = vgroup.NDVariableArray(
+SW = vgroup.NDVarArray(
     num_states=2,
     shape=(n_images, n_chan, im_height, im_width, n_feat, feat_height, feat_width),
 )
 
 # Binary images obtained by convolution
-X = vgroup.NDVariableArray(num_states=2, shape=X_gt.shape)
+X = vgroup.NDVarArray(num_states=2, shape=X_gt.shape)
 
 # %% [markdown]
 # For computation efficiency, we construct large FactorGroups instead of individual factors
 
 # %%
 # Factor graph
-fg = graph.FactorGraph(variable_groups=[S, W, SW, X])
+fg = fgraph.FactorGraph(variable_groups=[S, W, SW, X])
 
 # Define the ANDFactors
 variables_for_ANDFactors = []
@@ -173,7 +169,7 @@ def plot_images(images, display=True, nr=None):
                             variables_for_ORFactors_dict[X_var].append(SW_var)
 
 # Add ANDFactorGroup, which is computationally efficient
-AND_factor_group = logical.ANDFactorGroup(variables_for_ANDFactors)
+AND_factor_group = fgroup.ANDFactorGroup(variables_for_ANDFactors)
 fg.add_factors(AND_factor_group)
 
 # Define the ORFactors
@@ -183,7 +179,7 @@ def plot_images(images, display=True, nr=None):
 ]
 
 # Add ORFactorGroup, which is computationally efficient
-OR_factor_group = logical.ORFactorGroup(variables_for_ORFactors)
+OR_factor_group = fgroup.ORFactorGroup(variables_for_ORFactors)
 fg.add_factors(OR_factor_group)
 
 for factor_type, factor_groups in fg.factor_groups.items():
@@ -202,7 +198,7 @@ def plot_images(images, display=True, nr=None):
 # in the same manner does not change X, so this naturally results in multiple equivalent modes.
 
 # %%
-bp = graph.BP(fg.bp_state, temperature=0.0)
+bp = infer.BP(fg.bp_state, temperature=0.0)
 
 # %% [markdown]
 # We first compute the evidence without perturbation, similar to the PMP paper.
@@ -246,7 +242,7 @@ def plot_images(images, display=True, nr=None):
 )(bp_arrays)
 
 beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
-map_states = graph.decode_map_states(beliefs)
+map_states = infer.decode_map_states(beliefs)
 
 # %% [markdown]
 # Visualizing the MAP decoding, we see that we have 4 good random samples (one per row) from the posterior!
diff --git a/examples/rbm.py b/examples/rbm.py
index 5e751f44..a70a0486 100644
--- a/examples/rbm.py
+++ b/examples/rbm.py
@@ -26,12 +26,10 @@
 import matplotlib.pyplot as plt
 import numpy as np
 
-from pgmax.fg import graph
-from pgmax.groups import enumeration
-from pgmax.groups import variables as vgroup
+from pgmax import fgraph, fgroup, infer, vgroup
 
 # %% [markdown]
-# The [`pgmax.fg.graph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.html#module-pgmax.fg.graph) module contains core classes for specifying factor graphs and implementing LBP, while the [`pgmax.fg.groups`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.html#module-pgmax.fg.graph) module contains classes for specifying groups of variables/factors.
+# The [`pgmax.fgraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgraph.html#module-pgmax.fgraph) module contains classes for specifying factor graphs, the [`pgmax.fgroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fgroup.html#module-pgmax.fgroup) module contains classes for specifying groups of factors, the [`pgmax.vgroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.vgroup.html#module-pgmax.vgroup) module contains classes for specifying groups of variables, and the [`pgmax.infer`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.infer.html#module-pgmax.infer) module contains core functions to perform LBP.
 #
 # We next load the RBM trained in Sec. 5.5 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) on MNIST digits.
 
@@ -47,23 +45,23 @@
 
 # %%
 # Initialize factor graph
-hidden_variables = vgroup.NDVariableArray(num_states=2, shape=bh.shape)
-visible_variables = vgroup.NDVariableArray(num_states=2, shape=bv.shape)
-fg = graph.FactorGraph(variable_groups=[hidden_variables, visible_variables])
+hidden_variables = vgroup.NDVarArray(num_states=2, shape=bh.shape)
+visible_variables = vgroup.NDVarArray(num_states=2, shape=bv.shape)
+fg = fgraph.FactorGraph(variable_groups=[hidden_variables, visible_variables])
 
 # %% [markdown]
-# [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup) (e.g. an [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)), or a list/dictionary of [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed.
+# [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVarArray.html#pgmax.fg.groups.NDVarArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VarGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VarGroup.html#pgmax.fg.groups.VarGroup) (e.g. an [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVarArray.html#pgmax.fg.groups.NDVarArray)), or a list of [`VarGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VarGroup.html#pgmax.fg.groups.VarGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed.
 #
 # After initialization, `fg` does not have any factors. PGMax supports imperatively adding factors to a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph). We can add the unary and pairwise factors by grouping them using [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup)
 
 # %%
 # Create unary factors
-hidden_unaries = enumeration.EnumerationFactorGroup(
+hidden_unaries = fgroup.EnumFactorGroup(
     variables_for_factors=[[hidden_variables[ii]] for ii in range(bh.shape[0])],
     factor_configs=np.arange(2)[:, None],
     log_potentials=np.stack([np.zeros_like(bh), bh], axis=1),
 )
-visible_unaries = enumeration.EnumerationFactorGroup(
+visible_unaries = fgroup.EnumFactorGroup(
     variables_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])],
     factor_configs=np.arange(2)[:, None],
     log_potentials=np.stack([np.zeros_like(bv), bv], axis=1),
@@ -78,7 +76,7 @@
     for ii in range(bh.shape[0])
     for jj in range(bv.shape[0])
 ]
-pairwise_factors = enumeration.PairwiseFactorGroup(
+pairwise_factors = fgroup.PairwiseFactorGroup(
     variables_for_factors=variables_for_factors,
     log_potential_matrix=log_potential_matrix,
 )
@@ -88,67 +86,67 @@
 
 
 # %% [markdown]
-# PGMax implements convenient and computationally efficient [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) for representing groups of similar factors. The code above makes use of [`EnumerationFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.EnumerationFactorGroup.html#pgmax.fg.groups.EnumerationFactorGroup) and [`PairwiseFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.PairwiseFactorGroup.html#pgmax.fg.groups.PairwiseFactorGroup).
+# PGMax implements convenient and computationally efficient [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) for representing groups of similar factors. The code above makes use of [`EnumFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.EnumFactorGroup.html#pgmax.fg.groups.EnumFactorGroup) and [`PairwiseFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.PairwiseFactorGroup.html#pgmax.fg.groups.PairwiseFactorGroup).
 #
 # A [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) takes as argument `variables_for_factors` which is a list of lists of the variables involved in the different factors, and additional arguments specific to each [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) (e.g. `factor_configs` or `log_potential_matrix` here).
 #
-# In this example, since we construct `fg` with variables `hidden_variables` and `visible_variables`, which are both [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)s, we can refer to the `ii`th hidden variable as `hidden_variables[ii]` and the `jj`th visible variable as `visible_variables[jj]`.
+# In this example, since we construct `fg` with variables `hidden_variables` and `visible_variables`, which are both [`NDVarArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVarArray.html#pgmax.fg.groups.NDVarArray)s, we can refer to the `ii`th hidden variable as `hidden_variables[ii]` and the `jj`th visible variable as `visible_variables[jj]`.
 #
 # An alternative way of creating the above factors is to add them iteratively without building the [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup)s as below. This approach is not recommended as it is not computationally efficient.
 # ~~~python
-# from pgmax.factors import enumeration as enumeration_factor
+# from pgmax import factor
 # import itertools
 # from tqdm import tqdm
 #
 # # Add unary factors
 # for ii in range(bh.shape[0]):
-#     factor = enumeration_factor.EnumerationFactor(
+#     unary_factor = factor.EnumFactor(
 #         variables=[hidden_variables[ii]],
 #         factor_configs=np.arange(2)[:, None],
 #         log_potentials=np.array([0, bh[ii]]),
 #     )
-#     fg.add_factors(factor)
+#     fg.add_factors(unary_factor)
 #
 # for jj in range(bv.shape[0]):
-#     factor = enumeration_factor.EnumerationFactor(
+#     unary_factor = factor.EnumFactor(
 #         variables=[visible_variables[jj]],
 #         factor_configs=np.arange(2)[:, None],
 #         log_potentials=np.array([0, bv[jj]]),
 #     )
-#     fg.add_factors(factor)
+#     fg.add_factors(unary_factor)
 #
 # # Add pairwise factors
 # factor_configs = np.array(list(itertools.product(np.arange(2), repeat=2)))
 # for ii in tqdm(range(bh.shape[0])):
 #     for jj in range(bv.shape[0]):
-#         factor = enumeration_factor.EnumerationFactor(
+#         pairwise_factor = factor.EnumFactor(
 #             variables=[hidden_variables[ii], visible_variables[jj]],
 #             factor_configs=factor_configs,
 #             log_potentials=np.array([0, 0, 0, W[ii, jj]]),
 #         )
-#         fg.add_factors(factor)
+#         fg.add_factors(pairwise_factor)
 # ~~~
 #
 # Once we have added the factors, we can run max-product LBP and get MAP decoding by
 # ~~~python
-# bp = graph.BP(fg.bp_state, temperature=0.0)
+# bp = infer.BP(fg.bp_state, temperature=0.0)
 # bp_arrays = bp.run_bp(bp.init(), num_iters=100, damping=0.5)
 # beliefs = bp.get_beliefs(bp_arrays)
-# map_states = graph.decode_map_states(beliefs)
+# map_states = infer.decode_map_states(beliefs)
 # ~~~
 # and run sum-product LBP and get estimated marginals by
 # ~~~python
-# bp = graph.BP(fg.bp_state, temperature=1.0)
+# bp = infer.BP(fg.bp_state, temperature=1.0)
 # bp_arrays = bp.run_bp(bp.init(), num_iters=100, damping=0.5)
 # beliefs = bp.get_beliefs(bp_arrays)
-# marginals = graph.get_marginals(beliefs)
+# marginals = infer.get_marginals(beliefs)
 # ~~~
 # More generally, PGMax implements LBP with temperature, with `temperature=0.0` and `temperature=1.0` corresponding to the commonly used max/sum-product LBP respectively.
 #
 # Now we are ready to demonstrate PMP sampling from RBM. PMP perturbs the model with [Gumbel](https://numpy.org/doc/stable/reference/random/generated/numpy.random.gumbel.html) unary potentials, and draws a sample from the RBM as the MAP decoding from running max-product LBP on the perturbed model
 
 # %%
-bp = graph.BP(fg.bp_state, temperature=0.0)
+bp = infer.BP(fg.bp_state, temperature=0.0)
 
 # %%
 bp_arrays = bp.init(
@@ -168,7 +166,7 @@
 # %%
 fig, ax = plt.subplots(1, 1, figsize=(10, 10))
 ax.imshow(
-    graph.decode_map_states(beliefs)[visible_variables].copy().reshape((28, 28)),
+    infer.decode_map_states(beliefs)[visible_variables].copy().reshape((28, 28)),
     cmap="gray",
 )
 ax.axis("off")
@@ -176,9 +174,9 @@
 # %% [markdown]
 # PGMax adopts a functional interface for implementing LBP: running LBP in PGMax starts with
 # ~~~python
-# bp = graph.BP(fg.bp_state, temperature=T)
+# bp = infer.BP(fg.bp_state, temperature=T)
 # ~~~
-# where the arguments of the `bp` are several useful functions to run LBP. In particular, `bp.init`, `bp.run_bp`, `bp.get_beliefs` are pure functions with no side-effects. This design choice means that we can easily apply JAX transformations like `jit`/`vmap`/`grad`, etc., to these functions, and additionally allows PGMax to seamlessly interact with other packages in the rapidly growing JAX ecosystem (see [here](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) and [here](https://github.com/n2cholas/awesome-jax)).
+# where the attributes of `bp` are several useful functions to run LBP. In particular, `bp.init`, `bp.run_bp`, `bp.get_beliefs` are pure functions with no side-effects. This design choice means that we can easily apply JAX transformations like `jit`/`vmap`/`grad`, etc., to these functions, and additionally allows PGMax to seamlessly interact with other packages in the rapidly growing JAX ecosystem (see [here](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) and [here](https://github.com/n2cholas/awesome-jax)).
 #
 # As an example of applying `jax.vmap` to `bp.init`/`bp.run_bp`/`bp.get_beliefs` to process a batch of samples/models in parallel, instead of drawing one sample at a time as above, we can draw a batch of samples in parallel as follows:
 
@@ -197,7 +195,7 @@
 )(bp_arrays)
 
 beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays)
-map_states = graph.decode_map_states(beliefs)
+map_states = infer.decode_map_states(beliefs)
 
 # %% [markdown]
 # Visualizing the MAP decodings, we see that we have sampled 10 MNIST digits in parallel!
diff --git a/examples/rcn.py b/examples/rcn.py
index e83cf10c..df5a30dc 100644
--- a/examples/rcn.py
+++ b/examples/rcn.py
@@ -38,9 +38,7 @@
 from scipy.signal import fftconvolve
 from sklearn.datasets import fetch_openml
 
-from pgmax.fg import graph
-from pgmax.groups import variables as vgroup
-from pgmax.groups.enumeration import EnumerationFactorGroup
+from pgmax import fgraph, fgroup, infer, vgroup
 
 memory = Memory("./example_data/tmp")
 fetch_openml_cached = memory.cache(fetch_openml)
@@ -215,9 +213,7 @@ def fetch_mnist_dataset(test_size: int, seed: int = 5) -> tuple[np.ndarray, np.n
 variables_all_models = []
 for idx in range(frcs.shape[0]):
     frc = frcs[idx]
-    variables_all_models.append(
-        vgroup.NDVariableArray(num_states=M, shape=(frc.shape[0],))
-    )
+    variables_all_models.append(vgroup.NDVarArray(num_states=M, shape=(frc.shape[0],)))
 
 end = time.time()
 print(f"Creating variables took {end-start:.3f} seconds.")
@@ -272,7 +268,7 @@ def valid_configs(r: int, hps: int, vps: int) -> np.ndarray:
 
 # %%
 start = time.time()
-fg = graph.FactorGraph(variables_all_models)
+fg = fgraph.FactorGraph(variables_all_models)
 
 # Adding rcn model edges to the pgmax factor graph.
 for idx in range(edges.shape[0]):
@@ -280,7 +276,7 @@ def valid_configs(r: int, hps: int, vps: int) -> np.ndarray:
 
     for e in edge:
         i1, i2, r = e
-        factor_group = EnumerationFactorGroup(
+        factor_group = fgroup.EnumFactorGroup(
             variables_for_factors=[
                 [variables_all_models[idx][i1], variables_all_models[idx][i2]]
             ],
@@ -389,7 +385,7 @@ def get_evidence(bu_msg: np.ndarray, frc: np.ndarray) -> np.ndarray:
     variables_all_models[model_idx]: frcs[model_idx]
     for model_idx in range(frcs.shape[0])
 }
-bp = graph.BP(fg.bp_state, temperature=0.0)
+bp = infer.BP(fg.bp_state, temperature=0.0)
 scores = np.zeros((len(test_set), frcs.shape[0]))
 map_states_dict = {}
 
@@ -406,7 +402,7 @@ def get_evidence(bu_msg: np.ndarray, frc: np.ndarray) -> np.ndarray:
 
     start = end
     bp_arrays = bp.run_bp(bp.init(evidence_updates=evidence_updates), num_iters=30)
-    map_states = graph.decode_map_states(bp.get_beliefs(bp_arrays))
+    map_states = infer.decode_map_states(bp.get_beliefs(bp_arrays))
     end = time.time()
     print(f"Max product inference took {end-start:.3f} seconds for image {test_idx}.")
 
diff --git a/pgmax/bp/__init__.py b/pgmax/bp/__init__.py
deleted file mode 100644
index 4286090d..00000000
--- a/pgmax/bp/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""A sub-package containing functions to perform belief propagation."""
diff --git a/pgmax/bp/infer.py b/pgmax/bp/infer.py
deleted file mode 100644
index 3f549812..00000000
--- a/pgmax/bp/infer.py
+++ /dev/null
@@ -1,64 +0,0 @@
-"""A module containing the core message-passing functions for belief propagation"""
-
-import functools
-
-import jax
-import jax.numpy as jnp
-
-from pgmax.bp import bp_utils
-
-
-@jax.jit
-def pass_var_to_fac_messages(
-    ftov_msgs: jnp.array,
-    evidence: jnp.array,
-    var_states_for_edges: jnp.array,
-) -> jnp.array:
-    """Passes messages from Variables to Factors.
-
-    The update works by first summing the evidence and neighboring factor to variable messages for
-    each variable. Next, it subtracts messages from the correct elements of this sum to yield the
-    correct updated messages.
-
-    Args:
-        ftov_msgs: Array of shape (num_edge_state,). This holds all the flattened factor to variable
-            messages.
-        evidence: Array of shape (num_var_states,) representing the flattened evidence for each variable
-        var_states_for_edges: Array of shape (num_edge_states,)
-            Global variable state indices for each edge state
-    Returns:
-        Array of shape (num_edge_state,). This holds all the flattened variable to factor messages.
-    """
-    var_sums_arr = evidence.at[var_states_for_edges].add(ftov_msgs)
-    vtof_msgs = var_sums_arr[var_states_for_edges] - ftov_msgs
-    return vtof_msgs
-
-
-@functools.partial(jax.jit, static_argnames=("max_msg_size"))
-def normalize_and_clip_msgs(
-    msgs: jnp.ndarray,
-    edges_num_states: jnp.ndarray,
-    max_msg_size: int,
-) -> jnp.ndarray:
-    """Performs normalization and clipping of flattened messages
-
-    Normalization is done by subtracting the maximum value of every message from every element of every message,
-    clipping is done to keep every message value in the range [-1000, 0].
-
-    Args:
-        msgs: Array of shape (num_edge_state,). This holds all the flattened factor to variable messages.
-        edges_num_states: Array of shape (num_edges,). Number of states for the variables connected to each edge
-        max_msg_size: the max of edges_num_states
-
-    Returns:
-        Array of shape (num_edge_state,). This holds all the flattened factor to variable messages
-            after normalization and clipping
-    """
-    msgs = msgs - jnp.repeat(
-        bp_utils.segment_max_opt(msgs, edges_num_states, max_msg_size),
-        edges_num_states,
-        total_repeat_length=msgs.shape[0],
-    )
-    # Clip message values to be always greater than -1000
-    msgs = jnp.clip(msgs, -1000, None)
-    return msgs
diff --git a/pgmax/factor/__init__.py b/pgmax/factor/__init__.py
new file mode 100644
index 00000000..12808516
--- /dev/null
+++ b/pgmax/factor/__init__.py
@@ -0,0 +1,21 @@
+"""A sub-package defining the different types of factors."""
+
+import collections
+from typing import Callable, OrderedDict, Type
+
+import jax.numpy as jnp
+
+from . import enum, logical
+from .enum import EnumFactor
+from .factor import Factor, Wiring
+from .logical import ANDFactor, ORFactor
+
+FAC_TO_VAR_UPDATES: OrderedDict[
+    Type, Callable[..., jnp.ndarray]
+] = collections.OrderedDict(
+    [
+        (EnumFactor, enum.pass_enum_fac_to_var_messages),
+        (ORFactor, logical.pass_logical_fac_to_var_messages),
+        (ANDFactor, logical.pass_logical_fac_to_var_messages),
+    ]
+)
diff --git a/pgmax/factors/enumeration.py b/pgmax/factor/enum.py
similarity index 82%
rename from pgmax/factors/enumeration.py
rename to pgmax/factor/enum.py
index 2340d0c1..856965e1 100644
--- a/pgmax/factors/enumeration.py
+++ b/pgmax/factor/enum.py
@@ -9,21 +9,22 @@
 import numba as nb
 import numpy as np
 
-from pgmax.bp import bp_utils
-from pgmax.fg import nodes
+from pgmax.utils import NEG_INF
+
+from . import factor
 
 
 @jax.tree_util.register_pytree_node_class
 @dataclass(frozen=True, eq=False)
-class EnumerationWiring(nodes.Wiring):
-    """Wiring for EnumerationFactors.
+class EnumWiring(factor.Wiring):
+    """Wiring for EnumFactors.
 
     Args:
         factor_configs_edge_states: Array of shape (num_factor_configs, 2)
             factor_configs_edge_states[ii] contains a pair of global enumeration factor_config and global edge_state indices
-            factor_configs_edge_states[ii, 0] contains the global EnumerationFactor config index,
+            factor_configs_edge_states[ii, 0] contains the global EnumFactor config index,
             factor_configs_edge_states[ii, 1] contains the corresponding global edge_state index.
-            Both indices only take into account the EnumerationFactors of the FactorGraph
+            Both indices only take into account the EnumFactors of the FactorGraph
 
     Attributes:
         num_val_configs: Number of valid configurations for this wiring
@@ -42,7 +43,7 @@ def __post_init__(self):
 
 
 @dataclass(frozen=True, eq=False)
-class EnumerationFactor(nodes.Factor):
+class EnumFactor(factor.Factor):
     """An enumeration factor
 
     Args:
@@ -78,7 +79,7 @@ def __post_init__(self):
         if self.factor_configs.ndim != 2:
             raise ValueError(
                 "factor_configs should be a 2D array containing a list of valid configurations for "
-                f"EnumerationFactor. Got a factor_configs array of shape {self.factor_configs.shape}."
+                f"EnumFactor. Got a factor_configs array of shape {self.factor_configs.shape}."
             )
 
         if len(self.variables) != self.factor_configs.shape[1]:
@@ -100,17 +101,17 @@ def __post_init__(self):
             raise ValueError("Invalid configurations for given variables")
 
     @staticmethod
-    def concatenate_wirings(wirings: Sequence[EnumerationWiring]) -> EnumerationWiring:
-        """Concatenate a list of EnumerationWirings
+    def concatenate_wirings(wirings: Sequence[EnumWiring]) -> EnumWiring:
+        """Concatenate a list of EnumWirings
 
         Args:
-            wirings: A list of EnumerationWirings
+            wirings: A list of EnumWirings
 
         Returns:
-            Concatenated EnumerationWiring
+            Concatenated EnumWiring
         """
         if len(wirings) == 0:
-            return EnumerationWiring(
+            return EnumWiring(
                 edges_num_states=np.empty((0,), dtype=int),
                 var_states_for_edges=np.empty((0,), dtype=int),
                 factor_configs_edge_states=np.empty((0, 2), dtype=int),
@@ -124,7 +125,7 @@ def concatenate_wirings(wirings: Sequence[EnumerationWiring]) -> EnumerationWiri
             0,
         )[:-1]
 
-        # Note: this correspomds to all the factor_to_msgs_starts of the EnumerationFactors
+        # Note: this corresponds to all the factor_to_msgs_starts of the EnumFactors
         num_edge_states_cumsum = np.insert(
             np.array([wiring.edges_num_states.sum() for wiring in wirings]).cumsum(),
             0,
@@ -140,7 +141,7 @@ def concatenate_wirings(wirings: Sequence[EnumerationWiring]) -> EnumerationWiri
                 )
             )
 
-        return EnumerationWiring(
+        return EnumWiring(
             edges_num_states=np.concatenate(
                 [wiring.edges_num_states for wiring in wirings]
             ),
@@ -159,8 +160,8 @@ def compile_wiring(
         factor_configs: np.ndarray,
         vars_to_starts: Mapping[Tuple[int, int], int],
         num_factors: int,
-    ) -> EnumerationWiring:
-        """Compile an EnumerationWiring for an EnumerationFactor or a FactorGroup with EnumerationFactors.
+    ) -> EnumWiring:
+        """Compile an EnumWiring for an EnumFactor or a FactorGroup with EnumFactors.
         Internally calls _compile_var_states_numba and _compile_enumeration_wiring_numba for speed.
 
         Args:
@@ -181,7 +182,7 @@ def compile_wiring(
             ValueError: if factor_edges_num_states is not of shape (num_factors * num_variables, )
 
         Returns:
-            The EnumerationWiring
+            The EnumWiring
         """
         var_states = []
         for variables_for_factor in variables_for_factors:
@@ -191,7 +192,9 @@ def compile_wiring(
 
         num_states_cumsum = np.insert(np.cumsum(factor_edges_num_states), 0, 0)
         var_states_for_edges = np.empty(shape=(num_states_cumsum[-1],), dtype=int)
-        _compile_var_states_numba(var_states_for_edges, num_states_cumsum, var_states)
+        factor._compile_var_states_numba(
+            var_states_for_edges, num_states_cumsum, var_states
+        )
 
         num_configs, num_variables = factor_configs.shape
         if not factor_edges_num_states.shape == (num_factors * num_variables,):
@@ -206,33 +209,13 @@ def compile_wiring(
             factor_configs_edge_states, factor_configs, factor_edges_starts, num_factors
         )
 
-        return EnumerationWiring(
+        return EnumWiring(
             edges_num_states=factor_edges_num_states,
             var_states_for_edges=var_states_for_edges,
             factor_configs_edge_states=factor_configs_edge_states,
         )
 
 
-@nb.jit(parallel=False, cache=True, fastmath=True, nopython=True)
-def _compile_var_states_numba(
-    var_states_for_edges: np.ndarray,
-    num_states_cumsum: np.ndarray,
-    var_states: np.ndarray,
-) -> np.ndarray:
-    """Fast numba computation of the var_states_for_edges of a Wiring.
-    var_states_for_edges is updated in-place.
-    """
-
-    for variable_idx in nb.prange(num_states_cumsum.shape[0] - 1):
-        start_variable, end_variable = (
-            num_states_cumsum[variable_idx],
-            num_states_cumsum[variable_idx + 1],
-        )
-        var_states_for_edges[start_variable:end_variable] = var_states[
-            variable_idx
-        ] + np.arange(end_variable - start_variable)
-
-
 @nb.jit(parallel=False, cache=True, fastmath=True, nopython=True)
 def _compile_enumeration_wiring_numba(
     factor_configs_edge_states: np.ndarray,
@@ -240,7 +223,7 @@ def _compile_enumeration_wiring_numba(
     factor_edges_starts: np.ndarray,
     num_factors: int,
 ) -> np.ndarray:
-    """Fast numba computation of the factor_configs_edge_states of an EnumerationWiring.
+    """Fast numba computation of the factor_configs_edge_states of an EnumWiring.
     factor_edges_starts is updated in-place.
     """
 
@@ -274,30 +257,30 @@ def pass_enum_fac_to_var_messages(
     temperature: float,
 ) -> jnp.ndarray:
 
-    """Passes messages from EnumerationFactors to Variables.
+    """Passes messages from EnumFactors to Variables.
 
     The update is performed in two steps. First, a "summary" array is generated that has an entry for every valid
-    configuration for every EnumerationFactor. The elements of this array are simply the sums of messages across
+    configuration for every EnumFactor. The elements of this array are simply the sums of messages across
     each valid config. Then, the info from factor_configs_edge_states is used to apply the scattering operation and
     generate a flat set of output messages.
 
     Args:
         vtof_msgs: Array of shape (num_edge_state,). This holds all the flattened variable
-            to all the EnumerationFactors messages
+            to all the EnumFactors messages
         factor_configs_edge_states: Array of shape (num_factor_configs, 2)
             factor_configs_edge_states[ii] contains a pair of global factor_config and edge_state indices
-            factor_configs_edge_states[ii, 0] contains the global EnumerationFactor config index,
+            factor_configs_edge_states[ii, 0] contains the global EnumFactor config index,
             factor_configs_edge_states[ii, 1] contains the corresponding global edge_state index.
-            Both indices only take into account the EnumerationFactors of the FactorGraph
+            Both indices only take into account the EnumFactors of the FactorGraph
         log_potentials: Array of shape (num_val_configs, ). An entry at index i is the log potential
-            function value for the configuration with global EnumerationFactor config index i.
-        num_val_configs: the total number of valid configurations for all the EnumerationFactors
+            function value for the configuration with global EnumFactor config index i.
+        num_val_configs: the total number of valid configurations for all the EnumFactors
             in the factor graph.
         temperature: Temperature for loopy belief propagation.
             1.0 corresponds to sum-product, 0.0 corresponds to max-product.
 
     Returns:
-        Array of shape (num_edge_state,). This holds all the flattened EnumerationFactors to variable messages.
+        Array of shape (num_edge_state,). This holds all the flattened EnumFactors to variable messages.
     """
     fac_config_summary_sum = (
         jnp.zeros(shape=(num_val_configs,))
@@ -305,7 +288,7 @@ def pass_enum_fac_to_var_messages(
         .add(vtof_msgs[factor_configs_edge_states[..., 1]])
     ) + log_potentials
     max_factor_config_summary_for_edge_states = (
-        jnp.full(shape=(vtof_msgs.shape[0],), fill_value=bp_utils.NEG_INF)
+        jnp.full(shape=(vtof_msgs.shape[0],), fill_value=NEG_INF)
         .at[factor_configs_edge_states[..., 1]]
         .max(fac_config_summary_sum[factor_configs_edge_states[..., 0]])
     )
@@ -314,9 +297,7 @@ def pass_enum_fac_to_var_messages(
         ftov_msgs = ftov_msgs + (
             temperature
             * jnp.log(
-                jnp.full(
-                    shape=(vtof_msgs.shape[0],), fill_value=jnp.exp(bp_utils.NEG_INF)
-                )
+                jnp.full(shape=(vtof_msgs.shape[0],), fill_value=jnp.exp(NEG_INF))
                 .at[factor_configs_edge_states[..., 1]]
                 .add(
                     jnp.exp(
diff --git a/pgmax/fg/nodes.py b/pgmax/factor/factor.py
similarity index 73%
rename from pgmax/fg/nodes.py
rename to pgmax/factor/factor.py
index 8388a61e..090333b1 100644
--- a/pgmax/fg/nodes.py
+++ b/pgmax/factor/factor.py
@@ -1,10 +1,11 @@
-"""A module containing classes that specify the basic components of a Factor Graph."""
+"""A module containing classes that specify the basic components of a factor."""
 
 from dataclasses import asdict, dataclass
 from typing import List, Sequence, Tuple, Union
 
 import jax
 import jax.numpy as jnp
+import numba as nb
 import numpy as np
 
 
@@ -70,3 +71,23 @@ def concatenate_wirings(wirings: Sequence) -> Wiring:
         raise NotImplementedError(
             "Please subclass the Wiring class and override this method."
         )
+
+
+@nb.jit(parallel=False, cache=True, fastmath=True, nopython=True)
+def _compile_var_states_numba(
+    var_states_for_edges: np.ndarray,
+    num_states_cumsum: np.ndarray,
+    var_states: np.ndarray,
+) -> np.ndarray:
+    """Fast numba computation of the var_states_for_edges of a Wiring.
+    var_states_for_edges is updated in-place.
+    """
+
+    for variable_idx in nb.prange(num_states_cumsum.shape[0] - 1):
+        start_variable, end_variable = (
+            num_states_cumsum[variable_idx],
+            num_states_cumsum[variable_idx + 1],
+        )
+        var_states_for_edges[start_variable:end_variable] = var_states[
+            variable_idx
+        ] + np.arange(end_variable - start_variable)
diff --git a/pgmax/factors/logical.py b/pgmax/factor/logical.py
similarity index 92%
rename from pgmax/factors/logical.py
rename to pgmax/factor/logical.py
index 482865b2..54b64392 100644
--- a/pgmax/factors/logical.py
+++ b/pgmax/factor/logical.py
@@ -10,14 +10,14 @@
 import numpy as np
 from jax.nn import log_sigmoid, sigmoid
 
-from pgmax.bp import bp_utils
-from pgmax.factors import enumeration
-from pgmax.fg import nodes
+from pgmax.utils import NEG_INF
+
+from . import factor
 
 
 @jax.tree_util.register_pytree_node_class
 @dataclass(frozen=True, eq=False)
-class LogicalWiring(nodes.Wiring):
+class LogicalWiring(factor.Wiring):
     """Wiring for LogicalFactors.
 
     Args:
@@ -68,7 +68,7 @@ def __post_init__(self):
 
 
 @dataclass(frozen=True, eq=False)
-class LogicalFactor(nodes.Factor):
+class LogicalFactor(factor.Factor):
     """A logical OR/AND factor of the form (p1,...,pn, c)
     where p1,...,pn are the parents variables and c is the child variable.
 
@@ -176,7 +176,7 @@ def compile_wiring(
         # Note: all the variables in a LogicalFactorGroup are binary
         num_states_cumsum = np.arange(0, 2 * var_states.shape[0] + 2, 2)
         var_states_for_edges = np.empty(shape=(2 * var_states.shape[0],), dtype=int)
-        enumeration._compile_var_states_numba(
+        factor._compile_var_states_numba(
             var_states_for_edges, num_states_cumsum, var_states
         )
 
@@ -271,6 +271,38 @@ def _compile_logical_wiring_numba(
         )
 
 
+@functools.partial(jax.jit, static_argnames="num_labels")
+def get_maxes_and_argmaxes(
+    data: jnp.array, labels: jnp.array, num_labels: int
+) -> Tuple[jnp.ndarray, jnp.ndarray]:
+    """
+    Given a flattened sequence of elements and their corresponding labels,
+    returns the maxes and argmaxes of each label.
+
+    Args:
+        data: Array of shape (a_len,) where a_len is an arbitrary integer.
+        labels: Label array of shape (a_len,), assigning a label to each entry.
+            Labels must be 0,..., num_labels - 1.
+        num_labels: Number of different labels.
+
+    Returns:
+        Maxes and argmaxes arrays
+    """
+    num_obs = data.shape[0]
+
+    maxes = jnp.full(shape=(num_labels,), fill_value=NEG_INF).at[labels].max(data)
+    only_maxes_pos = jnp.arange(num_obs) - num_obs * jnp.where(
+        data != maxes[labels], 1, 0
+    )
+
+    argmaxes = (
+        jnp.full(shape=(num_labels,), fill_value=NEG_INF, dtype=jnp.int32)
+        .at[labels]
+        .max(only_maxes_pos)
+    )
+    return maxes, argmaxes
+
+
 @functools.partial(jax.jit, static_argnames=("temperature"))
 def pass_logical_fac_to_var_messages(
     vtof_msgs: jnp.ndarray,
@@ -319,11 +351,11 @@ def pass_logical_fac_to_var_messages(
     # See https://arxiv.org/pdf/2111.02458.pdf, Appendix C.3
     if temperature == 0.0:
         # Get the first and second argmaxes for the incoming parents messages of each factor
-        _, first_parents_argmaxes = bp_utils.get_maxes_and_argmaxes(
+        _, first_parents_argmaxes = get_maxes_and_argmaxes(
             parents_tof_msgs, factor_indices, num_factors
         )
-        _, second_parents_argmaxes = bp_utils.get_maxes_and_argmaxes(
-            parents_tof_msgs.at[first_parents_argmaxes].set(bp_utils.NEG_INF),
+        _, second_parents_argmaxes = get_maxes_and_argmaxes(
+            parents_tof_msgs.at[first_parents_argmaxes].set(NEG_INF),
             factor_indices,
             num_factors,
         )
diff --git a/pgmax/factors/__init__.py b/pgmax/factors/__init__.py
deleted file mode 100644
index acccf3bd..00000000
--- a/pgmax/factors/__init__.py
+++ /dev/null
@@ -1,18 +0,0 @@
-"""A sub-package containing different types of factors."""
-
-import collections
-from typing import Callable, OrderedDict, Type
-
-import jax.numpy as jnp
-
-from pgmax.factors import enumeration, logical
-
-FAC_TO_VAR_UPDATES: OrderedDict[
-    Type, Callable[..., jnp.ndarray]
-] = collections.OrderedDict(
-    [
-        (enumeration.EnumerationFactor, enumeration.pass_enum_fac_to_var_messages),
-        (logical.ORFactor, logical.pass_logical_fac_to_var_messages),
-        (logical.ANDFactor, logical.pass_logical_fac_to_var_messages),
-    ]
-)
diff --git a/pgmax/fg/__init__.py b/pgmax/fg/__init__.py
deleted file mode 100644
index 947adda2..00000000
--- a/pgmax/fg/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""A sub-package containing classes and functions used to specify Factor Graphs."""
diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py
deleted file mode 100644
index 16088e9f..00000000
--- a/pgmax/fg/graph.py
+++ /dev/null
@@ -1,996 +0,0 @@
-from __future__ import annotations
-
-"""A module containing the core class to specify a Factor Graph."""
-
-import collections
-import copy
-import functools
-import inspect
-from dataclasses import asdict, dataclass
-from types import MappingProxyType
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    FrozenSet,
-    Hashable,
-    List,
-    Mapping,
-    Optional,
-    OrderedDict,
-    Sequence,
-    Set,
-    Tuple,
-    Type,
-    Union,
-    cast,
-)
-
-import jax
-import jax.numpy as jnp
-import numpy as np
-from jax.scipy.special import logsumexp
-
-from pgmax.bp import infer
-from pgmax.factors import FAC_TO_VAR_UPDATES
-from pgmax.fg import groups, nodes
-from pgmax.utils import cached_property
-
-
-@dataclass
-class FactorGraph:
-    """Class for representing a factor graph.
-    Factors in a graph are clustered in factor groups, which are grouped according to their factor types.
-
-    Args:
-        variable_groups: A single VariableGroup or a list of VariableGroups.
-    """
-
-    variable_groups: Union[groups.VariableGroup, Sequence[groups.VariableGroup]]
-
-    def __post_init__(self):
-        if isinstance(self.variable_groups, groups.VariableGroup):
-            self.variable_groups = [self.variable_groups]
-
-        # Useful objects to build the FactorGraph
-        self._factor_types_to_groups: OrderedDict[
-            Type, List[groups.FactorGroup]
-        ] = collections.OrderedDict(
-            [(factor_type, []) for factor_type in FAC_TO_VAR_UPDATES]
-        )
-        self._factor_types_to_variables_for_factors: OrderedDict[
-            Type, Set[FrozenSet]
-        ] = collections.OrderedDict(
-            [(factor_type, set()) for factor_type in FAC_TO_VAR_UPDATES]
-        )
-
-        # See FactorGraphState docstrings for documentation on the following fields
-        self._vars_to_starts: Dict[Tuple[int, int], int] = {}
-
-        vars_num_states_cumsum = 0
-        for variable_group in self.variable_groups:
-            vg_num_states = variable_group.num_states.flatten()
-            vg_num_states_cumsum = np.insert(np.cumsum(vg_num_states), 0, 0)
-            self._vars_to_starts.update(
-                zip(
-                    variable_group.variables,
-                    vars_num_states_cumsum + vg_num_states_cumsum[:-1],
-                )
-            )
-            vars_num_states_cumsum += vg_num_states_cumsum[-1]
-        self._num_var_states = vars_num_states_cumsum
-
-    def __hash__(self) -> int:
-        all_factor_groups = tuple(
-            [
-                factor_group
-                for factor_groups_per_type in self._factor_types_to_groups.values()
-                for factor_group in factor_groups_per_type
-            ]
-        )
-        return hash(all_factor_groups)
-
-    def add_factors(
-        self,
-        factors: Union[
-            nodes.Factor,
-            groups.FactorGroup,
-            Sequence[Union[nodes.Factor, groups.FactorGroup]],
-        ],
-    ) -> None:
-        """Add a single Factor, a FactorGroup or a list of single Factors and FactorGroups to the FactorGraph,
-        by updating the FactorGraphState.
-
-        Args:
-            factors: The Factor, FactorGroup or list of Factors and FactorGroups to be added to the FactorGraph.
-
-        Raises:
-            ValueError: A FactorGroup involving the same variables already exists in the FactorGraph.
-        """
-        if isinstance(factors, list):
-            for factor in factors:
-                self.add_factors(factor)
-            return None
-
-        if isinstance(factors, groups.FactorGroup):
-            factor_group = factors
-        elif isinstance(factors, nodes.Factor):
-            factor_group = groups.SingleFactorGroup(
-                variables_for_factors=[factors.variables],
-                factor=factors,
-            )
-
-        factor_type = factor_group.factor_type
-        for var_names_for_factor in factor_group.variables_for_factors:
-            var_names = frozenset(var_names_for_factor)
-            if var_names in self._factor_types_to_variables_for_factors[factor_type]:
-                raise ValueError(
-                    f"A Factor of type {factor_type} involving variables {var_names} already exists. Please merge the corresponding factors."
-                )
-            self._factor_types_to_variables_for_factors[factor_type].add(var_names)
-
-        self._factor_types_to_groups[factor_type].append(factor_group)
-
-    @functools.lru_cache(None)
-    def compute_offsets(self) -> None:
-        """Compute factor messages offsets for the factor types and factor groups
-        in the flattened array of message.
-        Also compute log potentials offsets for factor groups.
-
-        See FactorGraphState for documentation on the following fields
-
-        If offsets have already beeen compiled, do nothing.
-        """
-        # Message offsets for ftov messages
-        self._factor_type_to_msgs_range = collections.OrderedDict()
-        self._factor_group_to_msgs_starts = collections.OrderedDict()
-        factor_num_states_cumsum = 0
-
-        # Log potentials offsets
-        self._factor_type_to_potentials_range = collections.OrderedDict()
-        self._factor_group_to_potentials_starts = collections.OrderedDict()
-        factor_num_configs_cumsum = 0
-
-        for factor_type, factors_groups_by_type in self._factor_types_to_groups.items():
-            factor_type_num_states_start = factor_num_states_cumsum
-            factor_type_num_configs_start = factor_num_configs_cumsum
-            for factor_group in factors_groups_by_type:
-                self._factor_group_to_msgs_starts[
-                    factor_group
-                ] = factor_num_states_cumsum
-                self._factor_group_to_potentials_starts[
-                    factor_group
-                ] = factor_num_configs_cumsum
-
-                factor_num_states_cumsum += factor_group.factor_edges_num_states.sum()
-                factor_num_configs_cumsum += (
-                    factor_group.factor_group_log_potentials.shape[0]
-                )
-
-            self._factor_type_to_msgs_range[factor_type] = (
-                factor_type_num_states_start,
-                factor_num_states_cumsum,
-            )
-            self._factor_type_to_potentials_range[factor_type] = (
-                factor_type_num_configs_start,
-                factor_num_configs_cumsum,
-            )
-
-        self._total_factor_num_states = factor_num_states_cumsum
-        self._total_factor_num_configs = factor_num_configs_cumsum
-
-    @cached_property
-    def wiring(self) -> OrderedDict[Type, nodes.Wiring]:
-        """Function to compile wiring for belief propagation.
-
-        If wiring has already beeen compiled, do nothing.
-
-        Returns:
-            A dictionnary mapping each factor type to its wiring.
-        """
-        wiring = collections.OrderedDict(
-            [
-                (
-                    factor_type,
-                    [
-                        factor_group.compile_wiring(self._vars_to_starts)
-                        for factor_group in self._factor_types_to_groups[factor_type]
-                    ],
-                )
-                for factor_type in self._factor_types_to_groups
-            ]
-        )
-        wiring = collections.OrderedDict(
-            [
-                (factor_type, factor_type.concatenate_wirings(wiring[factor_type]))
-                for factor_type in wiring
-            ]
-        )
-        return wiring
-
-    @cached_property
-    def log_potentials(self) -> OrderedDict[Type, np.ndarray]:
-        """Function to compile potential array for belief propagation.
-
-        If potential array has already been compiled, do nothing.
-
-        Returns:
-            A dictionnary mapping each factor type to the array of the log of the potential
-                function for each valid configuration
-        """
-        log_potentials = collections.OrderedDict()
-        for factor_type, factors_groups_by_type in self._factor_types_to_groups.items():
-            if len(factors_groups_by_type) == 0:
-                log_potentials[factor_type] = np.empty((0,))
-            else:
-                log_potentials[factor_type] = np.concatenate(
-                    [
-                        factor_group.factor_group_log_potentials
-                        for factor_group in factors_groups_by_type
-                    ]
-                )
-
-        return log_potentials
-
-    @cached_property
-    def factors(self) -> OrderedDict[Type, Tuple[nodes.Factor, ...]]:
-        """Mapping factor type to individual factors in the factor graph.
-        This function is only called on demand when the user requires it."""
-        print(
-            "Factors have not been added to the factor graph yet, this may take a while..."
-        )
-
-        factors: OrderedDict[Type, Tuple[nodes.Factor, ...]] = collections.OrderedDict(
-            [
-                (
-                    factor_type,
-                    tuple(
-                        [
-                            factor
-                            for factor_group in self._factor_types_to_groups[
-                                factor_type
-                            ]
-                            for factor in factor_group.factors
-                        ]
-                    ),
-                )
-                for factor_type in self._factor_types_to_groups
-            ]
-        )
-        return factors
-
-    @property
-    def factor_groups(self) -> OrderedDict[Type, List[groups.FactorGroup]]:
-        """Tuple of factor groups in the factor graph"""
-        return self._factor_types_to_groups
-
-    @cached_property
-    def fg_state(self) -> FactorGraphState:
-        """Current factor graph state given the added factors."""
-        # Preliminary computations
-        self.compute_offsets()
-        log_potentials = np.concatenate(
-            [self.log_potentials[factor_type] for factor_type in self.log_potentials]
-        )
-        assert isinstance(self.variable_groups, list)
-
-        return FactorGraphState(
-            variable_groups=self.variable_groups,
-            vars_to_starts=self._vars_to_starts,
-            num_var_states=self._num_var_states,
-            total_factor_num_states=self._total_factor_num_states,
-            factor_type_to_msgs_range=copy.copy(self._factor_type_to_msgs_range),
-            factor_type_to_potentials_range=copy.copy(
-                self._factor_type_to_potentials_range
-            ),
-            factor_group_to_potentials_starts=copy.copy(
-                self._factor_group_to_potentials_starts
-            ),
-            log_potentials=log_potentials,
-            wiring=self.wiring,
-        )
-
-    @property
-    def bp_state(self) -> BPState:
-        """Relevant information for doing belief propagation."""
-        # Preliminary computations
-        self.compute_offsets()
-
-        return BPState(
-            log_potentials=LogPotentials(fg_state=self.fg_state),
-            ftov_msgs=FToVMessages(fg_state=self.fg_state),
-            evidence=Evidence(fg_state=self.fg_state),
-        )
-
-
-@dataclass(frozen=True, eq=False)
-class FactorGraphState:
-    """FactorGraphState.
-
-    Args:
-        variable_groups: VariableGroups in the FactorGraph.
-        vars_to_starts: Maps variables to their starting indices in the flat evidence array.
-            flat_evidence[vars_to_starts[variable]: vars_to_starts[variable] + variable.num_var_states]
-            contains evidence to the variable.
-        num_var_states: Total number of variable states.
-        total_factor_num_states: Size of the flat ftov messages array.
-        factor_type_to_msgs_range: Maps factors types to their start and end indices in the flat ftov messages.
-        factor_type_to_potentials_range: Maps factor types to their start and end indices in the flat log potentials.
-        factor_group_to_potentials_starts: Maps factor groups to their starting indices in the flat log potentials.
-        log_potentials: Flat log potentials array concatenated for each factor type.
-        wiring: Wiring derived for each factor type.
-    """
-
-    variable_groups: Sequence[groups.VariableGroup]
-    vars_to_starts: Mapping[Tuple[int, int], int]
-    num_var_states: int
-    total_factor_num_states: int
-    factor_type_to_msgs_range: OrderedDict[type, Tuple[int, int]]
-    factor_type_to_potentials_range: OrderedDict[type, Tuple[int, int]]
-    factor_group_to_potentials_starts: OrderedDict[groups.FactorGroup, int]
-    log_potentials: OrderedDict[type, None | np.ndarray]
-    wiring: OrderedDict[type, nodes.Wiring]
-
-    def __post_init__(self):
-        for field in self.__dataclass_fields__:
-            if isinstance(getattr(self, field), np.ndarray):
-                getattr(self, field).flags.writeable = False
-
-            if isinstance(getattr(self, field), Mapping):
-                object.__setattr__(self, field, MappingProxyType(getattr(self, field)))
-
-
-@dataclass(frozen=True, eq=False)
-class BPState:
-    """Container class for belief propagation states, including log potentials,
-    ftov messages and evidence (unary log potentials).
-
-    Args:
-        log_potentials: log potentials of the model
-        ftov_msgs: factor to variable messages
-        evidence: evidence (unary log potentials) for variables.
-
-    Raises:
-        ValueError: If log_potentials, ftov_msgs or evidence are not derived from the same
-            FactorGraphState.
-    """
-
-    log_potentials: LogPotentials
-    ftov_msgs: FToVMessages
-    evidence: Evidence
-
-    def __post_init__(self):
-        if (self.log_potentials.fg_state != self.ftov_msgs.fg_state) or (
-            self.ftov_msgs.fg_state != self.evidence.fg_state
-        ):
-            raise ValueError(
-                "log_potentials, ftov_msgs and evidence should be derived from the same fg_state."
-            )
-
-    @property
-    def fg_state(self) -> FactorGraphState:
-        return self.log_potentials.fg_state
-
-
-@functools.partial(jax.jit, static_argnames="fg_state")
-def update_log_potentials(
-    log_potentials: jnp.ndarray,
-    updates: Dict[Any, jnp.ndarray],
-    fg_state: FactorGraphState,
-) -> jnp.ndarray:
-    """Function to update log_potentials.
-
-    Args:
-        log_potentials: A flat jnp array containing log_potentials.
-        updates: A dictionary containing updates for log_potentials
-        fg_state: Factor graph state
-
-    Returns:
-        A flat jnp array containing updated log_potentials.
-
-    Raises: ValueError if
-        (1) Provided log_potentials shape does not match the expected log_potentials shape.
-        (2) Provided name is not valid for log_potentials updates.
-    """
-    for factor_group, data in updates.items():
-        if factor_group in fg_state.factor_group_to_potentials_starts:
-            flat_data = factor_group.flatten(data)
-            if flat_data.shape != factor_group.factor_group_log_potentials.shape:
-                raise ValueError(
-                    f"Expected log potentials shape {factor_group.factor_group_log_potentials.shape} "
-                    f"for factor group. Got incompatible data shape {data.shape}."
-                )
-
-            start = fg_state.factor_group_to_potentials_starts[factor_group]
-            log_potentials = log_potentials.at[start : start + flat_data.shape[0]].set(
-                flat_data
-            )
-        else:
-            raise ValueError("Invalid FactorGroup for log potentials updates.")
-
-    return log_potentials
-
-
-@dataclass(frozen=True, eq=False)
-class LogPotentials:
-    """Class for storing and manipulating log potentials.
-
-    Args:
-        fg_state: Factor graph state
-        value: Optionally specify an initial value
-
-    Raises:
-        ValueError: If provided value shape does not match the expected log_potentials shape.
-    """
-
-    fg_state: FactorGraphState
-    value: Optional[np.ndarray] = None
-
-    def __post_init__(self):
-        if self.value is None:
-            object.__setattr__(self, "value", self.fg_state.log_potentials)
-        else:
-            if not self.value.shape == self.fg_state.log_potentials.shape:
-                raise ValueError(
-                    f"Expected log potentials shape {self.fg_state.log_potentials.shape}. "
-                    f"Got {self.value.shape}."
-                )
-
-            object.__setattr__(self, "value", self.value)
-
-    def __getitem__(self, factor_group: groups.FactorGroup) -> np.ndarray:
-        """Function to query log potentials for a FactorGroup.
-
-        Args:
-            factor_group: Queried FactorGroup
-
-        Returns:
-            The queried log potentials.
-        """
-        value = cast(np.ndarray, self.value)
-        if factor_group in self.fg_state.factor_group_to_potentials_starts:
-            start = self.fg_state.factor_group_to_potentials_starts[factor_group]
-            log_potentials = value[
-                start : start + factor_group.factor_group_log_potentials.shape[0]
-            ]
-        else:
-            raise ValueError("Invalid FactorGroup queried to access log potentials.")
-        return log_potentials
-
-    def __setitem__(
-        self,
-        factor_group: groups.FactorGroup,
-        data: Union[np.ndarray, jnp.ndarray],
-    ):
-        """Set the log potentials for a FactorGroup
-
-        Args:
-            factor_group: FactorGroup
-            data: Array containing the log potentials for the FactorGroup
-        """
-        object.__setattr__(
-            self,
-            "value",
-            np.asarray(
-                update_log_potentials(
-                    jax.device_put(self.value),
-                    {factor_group: jax.device_put(data)},
-                    self.fg_state,
-                )
-            ),
-        )
-
-
-@functools.partial(jax.jit, static_argnames="fg_state")
-def update_ftov_msgs(
-    ftov_msgs: jnp.ndarray, updates: Dict[Any, jnp.ndarray], fg_state: FactorGraphState
-) -> jnp.ndarray:
-    """Function to update ftov_msgs.
-
-    Args:
-        ftov_msgs: A flat jnp array containing ftov_msgs.
-        updates: A dictionary containing updates for ftov_msgs
-        fg_state: Factor graph state
-
-    Returns:
-        A flat jnp array containing updated ftov_msgs.
-
-    Raises: ValueError if:
-        (1) provided ftov_msgs shape does not match the expected ftov_msgs shape.
-        (2) provided variable is not in the FactorGraph.
-    """
-    for variable, data in updates.items():
-        if variable in fg_state.vars_to_starts:
-            if data.shape != (variable[1],):
-                raise ValueError(
-                    f"Given belief shape {data.shape} does not match expected "
-                    f"shape {(variable[1],)} for variable {variable}."
-                )
-
-            var_states_for_edges = np.concatenate(
-                [
-                    wiring_by_type.var_states_for_edges
-                    for wiring_by_type in fg_state.wiring.values()
-                ]
-            )
-
-            starts = np.nonzero(
-                var_states_for_edges == fg_state.vars_to_starts[variable]
-            )[0]
-            for start in starts:
-                ftov_msgs = ftov_msgs.at[start : start + variable[1]].set(
-                    data / starts.shape[0]
-                )
-        else:
-            raise ValueError("Provided variable is not in the FactorGraph")
-    return ftov_msgs
-
-
-@dataclass(frozen=True, eq=False)
-class FToVMessages:
-    """Class for storing and manipulating factor to variable messages.
-
-    Args:
-        fg_state: Factor graph state
-        value: Optionally specify initial value for ftov messages
-
-    Raises: ValueError if provided value does not match expected ftov messages shape.
-    """
-
-    fg_state: FactorGraphState
-    value: Optional[np.ndarray] = None
-
-    def __post_init__(self):
-        if self.value is None:
-            object.__setattr__(
-                self, "value", np.zeros(self.fg_state.total_factor_num_states)
-            )
-        else:
-            if not self.value.shape == (self.fg_state.total_factor_num_states,):
-                raise ValueError(
-                    f"Expected messages shape {(self.fg_state.total_factor_num_states,)}. "
-                    f"Got {self.value.shape}."
-                )
-
-            object.__setattr__(self, "value", self.value)
-
-    def __setitem__(
-        self,
-        variable: Tuple[int, int],
-        data: Union[np.ndarray, jnp.ndarray],
-    ) -> None:
-        """Spreading beliefs at a variable to all connected Factors
-
-        Args:
-            variable: Variable queried
-            data: An array containing the beliefs to be spread uniformly
-                across all factors to variable messages involving this variable.
-        """
-
-        object.__setattr__(
-            self,
-            "value",
-            np.asarray(
-                update_ftov_msgs(
-                    jax.device_put(self.value),
-                    {variable: jax.device_put(data)},
-                    self.fg_state,
-                )
-            ),
-        )
-
-
-@functools.partial(jax.jit, static_argnames="fg_state")
-def update_evidence(
-    evidence: jnp.ndarray, updates: Dict[Any, jnp.ndarray], fg_state: FactorGraphState
-) -> jnp.ndarray:
-    """Function to update evidence.
-
-    Args:
-        evidence: A flat jnp array containing evidence.
-        updates: A dictionary containing updates for evidence
-        fg_state: Factor graph state
-
-    Returns:
-        A flat jnp array containing updated evidence.
-    """
-    for name, data in updates.items():
-        # Name is a variable_group or a variable
-        if name in fg_state.variable_groups:
-            first_variable = name.variables[0]
-            start_index = fg_state.vars_to_starts[first_variable]
-            flat_data = name.flatten(data)
-            evidence = evidence.at[start_index : start_index + flat_data.shape[0]].set(
-                flat_data
-            )
-        elif name in fg_state.vars_to_starts:
-            start_index = fg_state.vars_to_starts[name]
-            evidence = evidence.at[start_index : start_index + name[1]].set(data)
-        else:
-            raise ValueError(
-                "Got evidence for a variable or a VariableGroup not in the FactorGraph!"
-            )
-    return evidence
-
-
-@dataclass(frozen=True, eq=False)
-class Evidence:
-    """Class for storing and manipulating evidence
-
-    Args:
-        fg_state: Factor graph state
-        value: Optionally specify initial value for evidence
-
-    Raises: ValueError if provided value does not match expected evidence shape.
-    """
-
-    fg_state: FactorGraphState
-    value: Optional[np.ndarray] = None
-
-    def __post_init__(self):
-        if self.value is None:
-            object.__setattr__(self, "value", np.zeros(self.fg_state.num_var_states))
-        else:
-            if self.value.shape != (self.fg_state.num_var_states,):
-                raise ValueError(
-                    f"Expected evidence shape {(self.fg_state.num_var_states,)}. "
-                    f"Got {self.value.shape}."
-                )
-
-            object.__setattr__(self, "value", self.value)
-
-    def __getitem__(self, variable: Tuple[int, int]) -> np.ndarray:
-        """Function to query evidence for a variable
-
-        Args:
-            variable: Variable queried
-
-        Returns:
-            evidence for the queried variable
-        """
-        value = cast(np.ndarray, self.value)
-        start = self.fg_state.vars_to_starts[variable]
-        evidence = value[start : start + variable[1]]
-        return evidence
-
-    def __setitem__(
-        self,
-        name: Any,
-        data: np.ndarray,
-    ) -> None:
-        """Function to update the evidence for variables
-
-        Args:
-            name: The name of a variable group or a single variable.
-                If name is the name of a variable group, updates are derived by using the variable group to
-                flatten the data.
-                If name is the name of a variable, data should be of an array shape (num_states,)
-                If name is None, updates are derived by using self.fg_state.variable_groups to flatten the data.
-            data: Array containing the evidence updates.
-        """
-        object.__setattr__(
-            self,
-            "value",
-            np.asarray(
-                update_evidence(
-                    jax.device_put(self.value),
-                    {name: jax.device_put(data)},
-                    self.fg_state,
-                ),
-            ),
-        )
-
-
-@jax.tree_util.register_pytree_node_class
-@dataclass(frozen=True, eq=False)
-class BPArrays:
-    """Container for the relevant flat arrays used in belief propagation.
-
-    Args:
-        log_potentials: Flat log potentials array.
-        ftov_msgs: Flat factor to variable messages array.
-        evidence: Flat evidence array.
-    """
-
-    log_potentials: Union[np.ndarray, jnp.ndarray]
-    ftov_msgs: Union[np.ndarray, jnp.ndarray]
-    evidence: Union[np.ndarray, jnp.ndarray]
-
-    def __post_init__(self):
-        for field in self.__dataclass_fields__:
-            if isinstance(getattr(self, field), np.ndarray):
-                getattr(self, field).flags.writeable = False
-
-    def tree_flatten(self):
-        return jax.tree_util.tree_flatten(asdict(self))
-
-    @classmethod
-    def tree_unflatten(cls, aux_data, children):
-        return cls(**aux_data.unflatten(children))
-
-
-@dataclass(frozen=True, eq=False)
-class BeliefPropagation:
-    """Belief propagation functions.
-
-    Arguments:
-        init: Function to create log_potentials, ftov_msgs and evidence.
-            Args:
-                log_potentials_updates: Optional dictionary containing log_potentials updates.
-                ftov_msgs_updates: Optional dictionary containing ftov_msgs updates.
-                evidence_updates: Optional dictionary containing evidence updates.
-            Returns:
-                A BPArrays with the log_potentials, ftov_msgs and evidence.
-
-        update: Function to update log_potentials, ftov_msgs and evidence.
-            Args:
-                bp_arrays: Optional arrays of log_potentials, ftov_msgs, evidence.
-                log_potentials_updates: Optional dictionary containing log_potentials updates.
-                ftov_msgs_updates: Optional dictionary containing ftov_msgs updates.
-                evidence_updates: Optional dictionary containing evidence updates.
-            Returns:
-                A BPArrays with the updated log_potentials, ftov_msgs and evidence.
-
-        run_bp: Function to run belief propagation for num_iters with a damping_factor.
-            Args:
-                bp_arrays: Initial arrays of log_potentials, ftov_msgs, evidence.
-                num_iters: Number of belief propagation iterations.
-                damping: The damping factor to use for message updates between one timestep and the next.
-            Returns:
-                A BPArrays containing the updated ftov_msgs.
-
-        get_bp_state: Function to reconstruct the BPState from a BPArrays.
-            Args:
-                bp_arrays: A BPArrays containing log_potentials, ftov_msgs, evidence.
-            Returns:
-                The reconstructed BPState
-
-        get_beliefs: Function to calculate beliefs from a BPArrays.
-            Args:
-                bp_arrays: A BPArrays containing log_potentials, ftov_msgs, evidence.
-            Returns:
-                beliefs: Beliefs returned by belief propagation.
-    """
-
-    init: Callable
-    update: Callable
-    run_bp: Callable
-    to_bp_state: Callable
-    get_beliefs: Callable
-
-
-def BP(bp_state: BPState, temperature: float = 0.0) -> BeliefPropagation:
-    """Function for generating belief propagation functions.
-
-    Args:
-        bp_state: Belief propagation state.
-        temperature: Temperature for loopy belief propagation.
-            1.0 corresponds to sum-product, 0.0 corresponds to max-product.
-
-    Returns:
-        Belief propagation functions.
-    """
-    wiring = bp_state.fg_state.wiring
-    edges_num_states = np.concatenate(
-        [wiring[factor_type].edges_num_states for factor_type in FAC_TO_VAR_UPDATES]
-    )
-    max_msg_size = int(np.max(edges_num_states))
-
-    var_states_for_edges = np.concatenate(
-        [wiring[factor_type].var_states_for_edges for factor_type in FAC_TO_VAR_UPDATES]
-    )
-
-    # Inference argumnets per factor type
-    inference_arguments: Dict[type, Mapping] = {}
-    for factor_type in FAC_TO_VAR_UPDATES:
-        this_inference_arguments = inspect.getfullargspec(
-            FAC_TO_VAR_UPDATES[factor_type]
-        ).args
-        this_inference_arguments.remove("vtof_msgs")
-        this_inference_arguments.remove("log_potentials")
-        this_inference_arguments.remove("temperature")
-        this_inference_arguments = {
-            key: getattr(wiring[factor_type], key) for key in this_inference_arguments
-        }
-        inference_arguments[factor_type] = this_inference_arguments
-
-    factor_type_to_msgs_range = bp_state.fg_state.factor_type_to_msgs_range
-    factor_type_to_potentials_range = bp_state.fg_state.factor_type_to_potentials_range
-
-    def update(
-        bp_arrays: Optional[BPArrays] = None,
-        log_potentials_updates: Optional[Dict[Any, jnp.ndarray]] = None,
-        ftov_msgs_updates: Optional[Dict[Any, jnp.ndarray]] = None,
-        evidence_updates: Optional[Dict[Any, jnp.ndarray]] = None,
-    ) -> BPArrays:
-        """Function to update belief propagation log_potentials, ftov_msgs, evidence.
-
-        Args:
-            bp_arrays: Optional arrays of log_potentials, ftov_msgs, evidence.
-            log_potentials_updates: Optional dictionary containing log_potentials updates.
-            ftov_msgs_updates: Optional dictionary containing ftov_msgs updates.
-            evidence_updates: Optional dictionary containing evidence updates.
-
-        Returns:
-            A BPArrays with the updated log_potentials, ftov_msgs and evidence.
-        """
-        if bp_arrays is not None:
-            log_potentials = bp_arrays.log_potentials
-            evidence = bp_arrays.evidence
-            ftov_msgs = bp_arrays.ftov_msgs
-        else:
-            log_potentials = jax.device_put(bp_state.log_potentials.value)
-            ftov_msgs = bp_state.ftov_msgs.value
-            evidence = bp_state.evidence.value
-
-        if log_potentials_updates is not None:
-            log_potentials = update_log_potentials(
-                log_potentials, log_potentials_updates, bp_state.fg_state
-            )
-
-        if ftov_msgs_updates is not None:
-            ftov_msgs = update_ftov_msgs(
-                ftov_msgs, ftov_msgs_updates, bp_state.fg_state
-            )
-
-        if evidence_updates is not None:
-            evidence = update_evidence(evidence, evidence_updates, bp_state.fg_state)
-
-        return BPArrays(
-            log_potentials=log_potentials, ftov_msgs=ftov_msgs, evidence=evidence
-        )
-
-    def run_bp(
-        bp_arrays: BPArrays,
-        num_iters: int,
-        damping: float = 0.5,
-    ) -> BPArrays:
-        """Function to run belief propagation for num_iters with a damping_factor.
-
-        Args:
-            bp_arrays: Initial arrays of log_potentials, ftov_msgs, evidence.
-            num_iters: Number of belief propagation iterations.
-            damping: The damping factor to use for message updates between one timestep and the next.
-
-        Returns:
-            A BPArrays containing the updated ftov_msgs.
-        """
-        log_potentials = bp_arrays.log_potentials
-        evidence = bp_arrays.evidence
-        ftov_msgs = bp_arrays.ftov_msgs
-
-        # Normalize the messages to ensure the maximum value is 0.
-        ftov_msgs = infer.normalize_and_clip_msgs(
-            ftov_msgs, edges_num_states, max_msg_size
-        )
-
-        @jax.checkpoint
-        def update(msgs: jnp.ndarray, _) -> Tuple[jnp.ndarray, None]:
-            # Compute new variable to factor messages by message passing
-            vtof_msgs = infer.pass_var_to_fac_messages(
-                msgs,
-                evidence,
-                var_states_for_edges,
-            )
-            ftov_msgs = jnp.zeros_like(vtof_msgs)
-            for factor_type in FAC_TO_VAR_UPDATES:
-                msgs_start, msgs_end = factor_type_to_msgs_range[factor_type]
-                potentials_start, potentials_end = factor_type_to_potentials_range[
-                    factor_type
-                ]
-                ftov_msgs_type = FAC_TO_VAR_UPDATES[factor_type](
-                    vtof_msgs=vtof_msgs[msgs_start:msgs_end],
-                    log_potentials=log_potentials[potentials_start:potentials_end],
-                    temperature=temperature,
-                    **inference_arguments[factor_type],
-                )
-                ftov_msgs = ftov_msgs.at[msgs_start:msgs_end].set(ftov_msgs_type)
-
-            # Use the results of message passing to perform damping and
-            # update the factor to variable messages
-            delta_msgs = ftov_msgs - msgs
-            msgs = msgs + (1 - damping) * delta_msgs
-            # Normalize and clip these damped, updated messages before returning them.
-            msgs = infer.normalize_and_clip_msgs(msgs, edges_num_states, max_msg_size)
-            return msgs, None
-
-        ftov_msgs, _ = jax.lax.scan(update, ftov_msgs, None, num_iters)
-
-        return BPArrays(
-            log_potentials=log_potentials, ftov_msgs=ftov_msgs, evidence=evidence
-        )
-
-    def to_bp_state(bp_arrays: BPArrays) -> BPState:
-        """Function to reconstruct the BPState from a BPArrays
-
-        Args:
-            bp_arrays: A BPArrays containing log_potentials, ftov_msgs, evidence.
-
-        Returns:
-            The reconstructed BPState
-        """
-        return BPState(
-            log_potentials=LogPotentials(
-                fg_state=bp_state.fg_state, value=bp_arrays.log_potentials
-            ),
-            ftov_msgs=FToVMessages(
-                fg_state=bp_state.fg_state,
-                value=bp_arrays.ftov_msgs,
-            ),
-            evidence=Evidence(fg_state=bp_state.fg_state, value=bp_arrays.evidence),
-        )
-
-    def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]:
-        """Function that returns unflattened beliefs from the flat beliefs
-
-        Args:
-            flat_beliefs: Flattened array of beliefs
-            variable_groups: All the variable groups in the FactorGraph.
-        """
-        beliefs = {}
-        start = 0
-        for variable_group in variable_groups:
-            num_states = variable_group.num_states
-            assert isinstance(num_states, np.ndarray)
-            length = num_states.sum()
-
-            beliefs[variable_group] = variable_group.unflatten(
-                flat_beliefs[start : start + length]
-            )
-            start += length
-        return beliefs
-
-    @jax.jit
-    def get_beliefs(bp_arrays: BPArrays) -> Dict[Hashable, Any]:
-        """Function to calculate beliefs from a BPArrays
-
-        Args:
-            bp_arrays: A BPArrays containing log_potentials, ftov_msgs, evidence.
-
-        Returns:
-            beliefs: Beliefs returned by belief propagation.
-        """
-
-        flat_beliefs = (
-            jax.device_put(bp_arrays.evidence)
-            .at[jax.device_put(var_states_for_edges)]
-            .add(bp_arrays.ftov_msgs)
-        )
-        return unflatten_beliefs(flat_beliefs, bp_state.fg_state.variable_groups)
-
-    bp = BeliefPropagation(
-        init=functools.partial(update, None),
-        update=update,
-        run_bp=run_bp,
-        to_bp_state=to_bp_state,
-        get_beliefs=get_beliefs,
-    )
-    return bp
-
-
-@jax.jit
-def decode_map_states(beliefs: Dict[Hashable, Any]) -> Any:
-    """Function to decode MAP states given the calculated beliefs.
-
-    Args:
-        beliefs: An array or a PyTree container containing beliefs for different variables.
-
-    Returns:
-        An array or a PyTree container containing the MAP states for different variables.
-    """
-    return jax.tree_util.tree_map(lambda x: jnp.argmax(x, axis=-1), beliefs)
-
-
-@jax.jit
-def get_marginals(beliefs: Dict[Hashable, Any]) -> Any:
-    """Function to get marginal probabilities given the calculated beliefs.
-
-    Args:
-        beliefs: An array or a PyTree container containing beliefs for different variables.
-
-    Returns:
-        An array or a PyTree container containing the marginal probabilities different variables.
-    """
-    return jax.tree_util.tree_map(
-        lambda x: jnp.exp(x - logsumexp(x, axis=-1, keepdims=True)), beliefs
-    )
diff --git a/pgmax/fgraph/__init__.py b/pgmax/fgraph/__init__.py
new file mode 100644
index 00000000..d85e1834
--- /dev/null
+++ b/pgmax/fgraph/__init__.py
@@ -0,0 +1,3 @@
+"""A sub-package containing functions to represent a factor graph."""
+
+from .fgraph import FactorGraph, FactorGraphState
diff --git a/pgmax/fgraph/fgraph.py b/pgmax/fgraph/fgraph.py
new file mode 100644
index 00000000..2d4e7ffe
--- /dev/null
+++ b/pgmax/fgraph/fgraph.py
@@ -0,0 +1,333 @@
+from __future__ import annotations
+
+"""A module containing the core class to specify a Factor Graph."""
+
+import collections
+import copy
+import functools
+from dataclasses import dataclass
+from types import MappingProxyType
+from typing import (
+    Any,
+    Dict,
+    FrozenSet,
+    List,
+    Mapping,
+    OrderedDict,
+    Sequence,
+    Set,
+    Tuple,
+    Type,
+    Union,
+)
+
+import numpy as np
+
+from pgmax import factor, fgroup, vgroup
+from pgmax.factor import FAC_TO_VAR_UPDATES
+from pgmax.utils import cached_property
+
+
+@dataclass
+class FactorGraph:
+    """Class for representing a factor graph.
+    Factors in a graph are clustered in factor groups, which are grouped according to their factor types.
+
+    Args:
+        variable_groups: A single VarGroup or a list of VarGroups.
+    """
+
+    variable_groups: Union[vgroup.VarGroup, Sequence[vgroup.VarGroup]]
+
+    def __post_init__(self):
+        if isinstance(self.variable_groups, vgroup.VarGroup):
+            self.variable_groups = [self.variable_groups]
+
+        # Useful objects to build the FactorGraph
+        self._factor_types_to_groups: OrderedDict[
+            Type, List[fgroup.FactorGroup]
+        ] = collections.OrderedDict(
+            [(factor_type, []) for factor_type in FAC_TO_VAR_UPDATES]
+        )
+        self._factor_types_to_variables_for_factors: OrderedDict[
+            Type, Set[FrozenSet]
+        ] = collections.OrderedDict(
+            [(factor_type, set()) for factor_type in FAC_TO_VAR_UPDATES]
+        )
+
+        # See FactorGraphState docstrings for documentation on the following fields
+        self._vars_to_starts: Dict[Tuple[int, int], int] = {}
+
+        vars_num_states_cumsum = 0
+        for variable_group in self.variable_groups:
+            vg_num_states = variable_group.num_states.flatten()
+            vg_num_states_cumsum = np.insert(np.cumsum(vg_num_states), 0, 0)
+            self._vars_to_starts.update(
+                zip(
+                    variable_group.variables,
+                    vars_num_states_cumsum + vg_num_states_cumsum[:-1],
+                )
+            )
+            vars_num_states_cumsum += vg_num_states_cumsum[-1]
+        self._num_var_states = vars_num_states_cumsum
+
+    def __hash__(self) -> int:
+        all_factor_groups = tuple(
+            [
+                factor_group
+                for factor_groups_per_type in self._factor_types_to_groups.values()
+                for factor_group in factor_groups_per_type
+            ]
+        )
+        return hash(all_factor_groups)
+
+    def add_factors(
+        self,
+        factors: Union[
+            factor.Factor,
+            fgroup.FactorGroup,
+            Sequence[Union[factor.Factor, fgroup.FactorGroup]],
+        ],
+    ) -> None:
+        """Add a single Factor, a FactorGroup or a list of single Factors and FactorGroups to the FactorGraph,
+        by updating the FactorGraphState.
+
+        Args:
+            factors: The Factor, FactorGroup or list of Factors and FactorGroups to be added to the FactorGraph.
+
+        Raises:
+            ValueError: A Factor of the same type involving the same variables already exists in the FactorGraph.
+        """
+        if isinstance(factors, list):
+            for this_factor in factors:
+                self.add_factors(this_factor)
+            return None
+
+        if isinstance(factors, fgroup.FactorGroup):
+            factor_group = factors
+        elif isinstance(factors, factor.Factor):
+            factor_group = fgroup.SingleFactorGroup(
+                variables_for_factors=[factors.variables],
+                factor=factors,
+            )
+
+        factor_type = factor_group.factor_type
+        for var_names_for_factor in factor_group.variables_for_factors:
+            var_names = frozenset(var_names_for_factor)
+            if var_names in self._factor_types_to_variables_for_factors[factor_type]:
+                raise ValueError(
+                    f"A Factor of type {factor_type} involving variables {var_names} already exists. Please merge the corresponding factors."
+                )
+            self._factor_types_to_variables_for_factors[factor_type].add(var_names)
+
+        self._factor_types_to_groups[factor_type].append(factor_group)
+
+    @functools.lru_cache(None)
+    def compute_offsets(self) -> None:
+        """Compute factor messages offsets for the factor types and factor groups
+        in the flattened array of messages.
+        Also compute log potentials offsets for factor groups.
+
+        See FactorGraphState for documentation on the following fields
+
+        If offsets have already been computed, do nothing.
+        """
+        # Message offsets for ftov messages
+        self._factor_type_to_msgs_range = collections.OrderedDict()
+        self._factor_group_to_msgs_starts = collections.OrderedDict()
+        factor_num_states_cumsum = 0
+
+        # Log potentials offsets
+        self._factor_type_to_potentials_range = collections.OrderedDict()
+        self._factor_group_to_potentials_starts = collections.OrderedDict()
+        factor_num_configs_cumsum = 0
+
+        for factor_type, factors_groups_by_type in self._factor_types_to_groups.items():
+            factor_type_num_states_start = factor_num_states_cumsum
+            factor_type_num_configs_start = factor_num_configs_cumsum
+            for factor_group in factors_groups_by_type:
+                self._factor_group_to_msgs_starts[
+                    factor_group
+                ] = factor_num_states_cumsum
+                self._factor_group_to_potentials_starts[
+                    factor_group
+                ] = factor_num_configs_cumsum
+
+                factor_num_states_cumsum += factor_group.factor_edges_num_states.sum()
+                factor_num_configs_cumsum += (
+                    factor_group.factor_group_log_potentials.shape[0]
+                )
+
+            self._factor_type_to_msgs_range[factor_type] = (
+                factor_type_num_states_start,
+                factor_num_states_cumsum,
+            )
+            self._factor_type_to_potentials_range[factor_type] = (
+                factor_type_num_configs_start,
+                factor_num_configs_cumsum,
+            )
+
+        self._total_factor_num_states = factor_num_states_cumsum
+        self._total_factor_num_configs = factor_num_configs_cumsum
+
+    @cached_property
+    def wiring(self) -> OrderedDict[Type, factor.Wiring]:
+        """Function to compile wiring for belief propagation.
+
+        If wiring has already been compiled, do nothing.
+
+        Returns:
+            A dictionary mapping each factor type to its wiring.
+        """
+        wiring = collections.OrderedDict(
+            [
+                (
+                    factor_type,
+                    [
+                        factor_group.compile_wiring(self._vars_to_starts)
+                        for factor_group in self._factor_types_to_groups[factor_type]
+                    ],
+                )
+                for factor_type in self._factor_types_to_groups
+            ]
+        )
+        wiring = collections.OrderedDict(
+            [
+                (factor_type, factor_type.concatenate_wirings(wiring[factor_type]))
+                for factor_type in wiring
+            ]
+        )
+        return wiring
+
+    @cached_property
+    def log_potentials(self) -> OrderedDict[Type, np.ndarray]:
+        """Function to compile potential array for belief propagation.
+
+        If potential array has already been compiled, do nothing.
+
+        Returns:
+            A dictionary mapping each factor type to the array of the log of the potential
+                function for each valid configuration
+        """
+        log_potentials = collections.OrderedDict()
+        for factor_type, factors_groups_by_type in self._factor_types_to_groups.items():
+            if len(factors_groups_by_type) == 0:
+                log_potentials[factor_type] = np.empty((0,))
+            else:
+                log_potentials[factor_type] = np.concatenate(
+                    [
+                        factor_group.factor_group_log_potentials
+                        for factor_group in factors_groups_by_type
+                    ]
+                )
+
+        return log_potentials
+
+    @cached_property
+    def factors(self) -> OrderedDict[Type, Tuple[factor.Factor, ...]]:
+        """Mapping factor type to individual factors in the factor graph.
+        This function is only called on demand when the user requires it."""
+        print(
+            "Factors have not been added to the factor graph yet, this may take a while..."
+        )
+
+        factors: OrderedDict[Type, Tuple[factor.Factor, ...]] = collections.OrderedDict(
+            [
+                (
+                    factor_type,
+                    tuple(
+                        [
+                            factor
+                            for factor_group in self._factor_types_to_groups[
+                                factor_type
+                            ]
+                            for factor in factor_group.factors
+                        ]
+                    ),
+                )
+                for factor_type in self._factor_types_to_groups
+            ]
+        )
+        return factors
+
+    @property
+    def factor_groups(self) -> OrderedDict[Type, List[fgroup.FactorGroup]]:
+        """Mapping from each factor type to the list of factor groups of that type in the factor graph"""
+        return self._factor_types_to_groups
+
+    @cached_property
+    def fg_state(self) -> FactorGraphState:
+        """Current factor graph state given the added factors."""
+        # Preliminary computations
+        self.compute_offsets()
+        log_potentials = np.concatenate(
+            [self.log_potentials[factor_type] for factor_type in self.log_potentials]
+        )
+        assert isinstance(self.variable_groups, list)
+
+        return FactorGraphState(
+            variable_groups=self.variable_groups,
+            vars_to_starts=self._vars_to_starts,
+            num_var_states=self._num_var_states,
+            total_factor_num_states=self._total_factor_num_states,
+            factor_type_to_msgs_range=copy.copy(self._factor_type_to_msgs_range),
+            factor_type_to_potentials_range=copy.copy(
+                self._factor_type_to_potentials_range
+            ),
+            factor_group_to_potentials_starts=copy.copy(
+                self._factor_group_to_potentials_starts
+            ),
+            log_potentials=log_potentials,
+            wiring=self.wiring,
+        )
+
+    @property
+    def bp_state(self) -> Any:
+        """Relevant information for doing belief propagation."""
+        # Preliminary computations
+        self.compute_offsets()
+
+        from pgmax.infer import bp_state
+
+        return bp_state.BPState(
+            log_potentials=bp_state.LogPotentials(fg_state=self.fg_state),
+            ftov_msgs=bp_state.FToVMessages(fg_state=self.fg_state),
+            evidence=bp_state.Evidence(fg_state=self.fg_state),
+        )
+
+
+@dataclass(frozen=True, eq=False)
+class FactorGraphState:
+    """FactorGraphState.
+
+    Args:
+        variable_groups: VarGroups in the FactorGraph.
+        vars_to_starts: Maps variables to their starting indices in the flat evidence array.
+            flat_evidence[vars_to_starts[variable]: vars_to_starts[variable] + variable[1]]
+            contains evidence to the variable.
+        num_var_states: Total number of variable states.
+        total_factor_num_states: Size of the flat ftov messages array.
+        factor_type_to_msgs_range: Maps factors types to their start and end indices in the flat ftov messages.
+        factor_type_to_potentials_range: Maps factor types to their start and end indices in the flat log potentials.
+        factor_group_to_potentials_starts: Maps factor groups to their starting indices in the flat log potentials.
+        log_potentials: Flat log potentials array concatenated for each factor type.
+        wiring: Wiring derived for each factor type.
+    """
+
+    variable_groups: Sequence[vgroup.VarGroup]
+    vars_to_starts: Mapping[Tuple[int, int], int]
+    num_var_states: int
+    total_factor_num_states: int
+    factor_type_to_msgs_range: OrderedDict[type, Tuple[int, int]]
+    factor_type_to_potentials_range: OrderedDict[type, Tuple[int, int]]
+    factor_group_to_potentials_starts: OrderedDict[fgroup.FactorGroup, int]
+    log_potentials: OrderedDict[type, None | np.ndarray]
+    wiring: OrderedDict[type, factor.Wiring]
+
+    def __post_init__(self):
+        for field in self.__dataclass_fields__:
+            if isinstance(getattr(self, field), np.ndarray):
+                getattr(self, field).flags.writeable = False
+
+            if isinstance(getattr(self, field), Mapping):
+                object.__setattr__(self, field, MappingProxyType(getattr(self, field)))
diff --git a/pgmax/fgroup/__init__.py b/pgmax/fgroup/__init__.py
new file mode 100644
index 00000000..d9750c0d
--- /dev/null
+++ b/pgmax/fgroup/__init__.py
@@ -0,0 +1,5 @@
+"""A sub-package defining factor groups and containing different types of factor groups."""
+
+from .enum import EnumFactorGroup, PairwiseFactorGroup
+from .fgroup import FactorGroup, SingleFactorGroup
+from .logical import ANDFactorGroup, ORFactorGroup
diff --git a/pgmax/groups/enumeration.py b/pgmax/fgroup/enum.py
similarity index 93%
rename from pgmax/groups/enumeration.py
rename to pgmax/fgroup/enum.py
index dab6e14b..8a6e64c8 100644
--- a/pgmax/groups/enumeration.py
+++ b/pgmax/fgroup/enum.py
@@ -1,4 +1,4 @@
-"""Defines EnumerationFactorGroup and PairwiseFactorGroup."""
+"""Defines EnumFactorGroup and PairwiseFactorGroup."""
 
 import collections
 from dataclasses import dataclass, field
@@ -9,13 +9,14 @@
 import numba as nb
 import numpy as np
 
-from pgmax.factors import enumeration
-from pgmax.fg import groups
+from pgmax.factor import enum
+
+from .fgroup import FactorGroup
 
 
 @dataclass(frozen=True, eq=False)
-class EnumerationFactorGroup(groups.FactorGroup):
-    """Class to represent a group of EnumerationFactors.
+class EnumFactorGroup(FactorGroup):
+    """Class to represent a group of EnumFactors.
 
     All factors in the group are assumed to have the same set of valid configurations and
     the same potential function. Note that the log potential function is assumed to be
@@ -38,7 +39,7 @@ class EnumerationFactorGroup(groups.FactorGroup):
 
     factor_configs: np.ndarray
     log_potentials: Optional[np.ndarray] = None
-    factor_type: Type = field(init=False, default=enumeration.EnumerationFactor)
+    factor_type: Type = field(init=False, default=enum.EnumFactor)
 
     def __post_init__(self):
         super().__post_init__()
@@ -69,7 +70,7 @@ def __post_init__(self):
 
     def _get_variables_to_factors(
         self,
-    ) -> OrderedDict[FrozenSet, enumeration.EnumerationFactor]:
+    ) -> OrderedDict[FrozenSet, enum.EnumFactor]:
         """Function that generates a dictionary mapping set of connected variables to factors.
         This function is only called on demand when the user requires it.
 
@@ -80,7 +81,7 @@ def _get_variables_to_factors(
             [
                 (
                     frozenset(variables_for_factor),
-                    enumeration.EnumerationFactor(
+                    enum.EnumFactor(
                         variables=variables_for_factor,
                         factor_configs=self.factor_configs,
                         log_potentials=np.array(self.log_potentials)[ii],
@@ -170,19 +171,19 @@ def unflatten(
 
 
 @dataclass(frozen=True, eq=False)
-class PairwiseFactorGroup(groups.FactorGroup):
-    """Class to represent a group of EnumerationFactors where each factor connects to
+class PairwiseFactorGroup(FactorGroup):
+    """Class to represent a group of EnumFactors where each factor connects to
     two different variables.
 
     All factors in the group are assumed to be such that all possible configuration of the two
     variable's states are valid. Additionally, all factors in the group are assumed to share
-    the same potential function and to be connected to variables from VariableGroups within
-    one CompositeVariableGroup.
+    the same potential function and to be connected to variables from VarGroups within
+    one CompositeVarGroup.
 
     Args:
         log_potential_matrix: array of shape (var1.num_states, var2.num_states),
-            where var1 and var2 are the 2 VariableGroups (that may refer to the same
-            VariableGroup) whose names are present in each sub-list from self.variables_for_factors.
+            where var1 and var2 are the 2 VarGroups (that may refer to the same
+            VarGroup) whose names are present in each sub-list from self.variables_for_factors.
         factor_type: Factor type shared by all the Factors in the FactorGroup.
 
     Raises:
@@ -196,7 +197,7 @@ class PairwiseFactorGroup(groups.FactorGroup):
     """
 
     log_potential_matrix: Optional[np.ndarray] = None
-    factor_type: Type = field(init=False, default=enumeration.EnumerationFactor)
+    factor_type: Type = field(init=False, default=enum.EnumFactor)
 
     def __post_init__(self):
         super().__post_init__()
@@ -275,7 +276,7 @@ def __post_init__(self):
 
     def _get_variables_to_factors(
         self,
-    ) -> OrderedDict[FrozenSet, enumeration.EnumerationFactor]:
+    ) -> OrderedDict[FrozenSet, enum.EnumFactor]:
         """Function that generates a dictionary mapping set of connected variables to factors.
         This function is only called on demand when the user requires it.
 
@@ -286,7 +287,7 @@ def _get_variables_to_factors(
             [
                 (
                     frozenset(variable_for_factor),
-                    enumeration.EnumerationFactor(
+                    enum.EnumFactor(
                         variables=variable_for_factor,
                         factor_configs=self.factor_configs,
                         log_potentials=self.log_potentials[ii],
diff --git a/pgmax/fg/groups.py b/pgmax/fgroup/fgroup.py
similarity index 71%
rename from pgmax/fg/groups.py
rename to pgmax/fgroup/fgroup.py
index e772095a..cbb240e8 100644
--- a/pgmax/fg/groups.py
+++ b/pgmax/fgroup/fgroup.py
@@ -1,4 +1,4 @@
-"""A module containing the base classes for variable and factor groups in a Factor Graph."""
+"""A module containing the base classes for factor groups in a Factor Graph."""
 
 import inspect
 from dataclasses import dataclass, field
@@ -19,106 +19,9 @@
 import jax.numpy as jnp
 import numpy as np
 
-import pgmax.fg.nodes as nodes
+from pgmax import factor
 from pgmax.utils import cached_property
 
-MAX_SIZE = 1e9
-
-
-@total_ordering
-@dataclass(frozen=True, eq=False)
-class VariableGroup:
-    """Class to represent a group of variables.
-    Each variable is represented via a tuple of the form (variable hash, variable num_states)
-
-    Arguments:
-        num_states: An integer or an array specifying the number of states of the variables
-            in this VariableGroup
-    """
-
-    num_states: Union[int, np.ndarray]
-
-    def __post_init__(self):
-        # Only compute the hash once, which is guaranteed to be an int64
-        this_id = id(self) % 2**32
-        _hash = this_id * int(MAX_SIZE)
-        assert _hash < 2**63
-        object.__setattr__(self, "_hash", _hash)
-
-    def __hash__(self):
-        return self._hash
-
-    def __eq__(self, other):
-        return hash(self) == hash(other)
-
-    def __lt__(self, other):
-        return hash(self) < hash(other)
-
-    def __getitem__(self, val: Any) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
-        """Given a variable name, index, or a group of variable indices, retrieve the associated variable(s).
-        Each variable is returned via a tuple of the form (variable hash, variable num_states)
-
-        Args:
-            val: a variable index, slice, or name
-
-        Returns:
-            A single variable or a list of variables
-        """
-        raise NotImplementedError(
-            "Please subclass the VariableGroup class and override this method"
-        )
-
-    @cached_property
-    def variable_hashes(self) -> np.ndarray:
-        """Function that generates a variable hash for each variable
-
-        Returns:
-            Array of variables hashes.
-        """
-        raise NotImplementedError(
-            "Please subclass the VariableGroup class and override this method"
-        )
-
-    @cached_property
-    def variables(self) -> List[Tuple[int, int]]:
-        """Function that returns the list of all variables in the VariableGroup.
-        Each variable is represented by a tuple of the form (variable hash, variable num_states)
-
-        Returns:
-            List of variables in the VariableGroup
-        """
-        assert isinstance(self.variable_hashes, np.ndarray)
-        assert isinstance(self.num_states, np.ndarray)
-        vars_hashes = self.variable_hashes.flatten()
-        vars_num_states = self.num_states.flatten()
-        return list(zip(vars_hashes, vars_num_states))
-
-    def flatten(self, data: Any) -> jnp.ndarray:
-        """Function that turns meaningful structured data into a flat data array for internal use.
-
-        Args:
-            data: Meaningful structured data
-
-        Returns:
-            A flat jnp.array for internal use
-        """
-        raise NotImplementedError(
-            "Please subclass the VariableGroup class and override this method"
-        )
-
-    def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any:
-        """Function that recovers meaningful structured data from internal flat data array
-
-        Args:
-            flat_data: Internal flat data array.
-
-        Returns:
-            Meaningful structured data
-        """
-        raise NotImplementedError(
-            "Please subclass the VariableGroup class and override this method"
-        )
-
 
 @total_ordering
 @dataclass(frozen=True, eq=False)
@@ -208,7 +111,7 @@ def factor_edges_num_states(self) -> np.ndarray:
         return factor_edges_num_states
 
     @cached_property
-    def _variables_to_factors(self) -> Mapping[FrozenSet, nodes.Factor]:
+    def _variables_to_factors(self) -> Mapping[FrozenSet, factor.Factor]:
         """Function to compile potential array for the factor group.
         This function is only called on demand when the user requires it.
 
@@ -223,7 +126,7 @@ def factor_group_log_potentials(self) -> np.ndarray:
         return self.log_potentials.flatten()
 
     @cached_property
-    def factors(self) -> Tuple[nodes.Factor, ...]:
+    def factors(self) -> Tuple[factor.Factor, ...]:
         """Returns all factors in the factor group.
         This function is only called on demand when the user requires it."""
         return tuple(self._variables_to_factors.values())
@@ -304,7 +207,7 @@ class SingleFactorGroup(FactorGroup):
         factor: the single factor in the SingleFactorGroup
     """
 
-    factor: nodes.Factor
+    factor: factor.Factor
 
     def __post_init__(self):
         super().__post_init__()
@@ -329,7 +232,7 @@ def __post_init__(self):
 
     def _get_variables_to_factors(
         self,
-    ) -> OrderedDict[FrozenSet, nodes.Factor]:
+    ) -> OrderedDict[FrozenSet, factor.Factor]:
         """Function that generates a dictionary mapping names to factors.
 
         Returns:
diff --git a/pgmax/groups/logical.py b/pgmax/fgroup/logical.py
similarity index 95%
rename from pgmax/groups/logical.py
rename to pgmax/fgroup/logical.py
index 618a0526..afb895ff 100644
--- a/pgmax/groups/logical.py
+++ b/pgmax/fgroup/logical.py
@@ -6,12 +6,13 @@
 
 import numpy as np
 
-from pgmax.factors import logical
-from pgmax.fg import groups
+from pgmax.factor import logical
+
+from .fgroup import FactorGroup
 
 
 @dataclass(frozen=True, eq=False)
-class LogicalFactorGroup(groups.FactorGroup):
+class LogicalFactorGroup(FactorGroup):
     """Class to represent a group of LogicalFactors.
 
     All factors in the group are assumed to have the same edge_states_offset.
diff --git a/pgmax/groups/__init__.py b/pgmax/groups/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py
deleted file mode 100644
index 357d681a..00000000
--- a/pgmax/groups/variables.py
+++ /dev/null
@@ -1,302 +0,0 @@
-"""A module containing the variables group classes inheriting from the base VariableGroup."""
-
-from dataclasses import dataclass
-from typing import Any, Dict, Hashable, List, Mapping, Tuple, Union
-
-import jax
-import jax.numpy as jnp
-import numpy as np
-
-from pgmax.fg import groups
-from pgmax.utils import cached_property
-
-
-@dataclass(frozen=True, eq=False)
-class NDVariableArray(groups.VariableGroup):
-    """Subclass of VariableGroup for n-dimensional grids of variables.
-
-    Args:
-        num_states: An integer or an array specifying the number of states of the
-            variables in this VariableGroup
-        shape: Tuple specifying the size of each dimension of the grid (similar to
-            the notion of a NumPy ndarray shape)
-    """
-
-    shape: Tuple[int, ...]
-
-    def __post_init__(self):
-        super().__post_init__()
-
-        max_size = int(groups.MAX_SIZE)
-        if np.prod(self.shape) > max_size:
-            raise ValueError(
-                f"Currently only support NDVariableArray of size smaller than {max_size}. Got {np.prod(self.shape)}"
-            )
-
-        if np.isscalar(self.num_states):
-            num_states = np.full(self.shape, fill_value=self.num_states, dtype=np.int64)
-            object.__setattr__(self, "num_states", num_states)
-        elif isinstance(self.num_states, np.ndarray) and np.issubdtype(
-            self.num_states.dtype, int
-        ):
-            if self.num_states.shape != self.shape:
-                raise ValueError(
-                    f"Expected num_states shape {self.shape}. Got {self.num_states.shape}."
-                )
-        else:
-            raise ValueError(
-                "num_states should be an integer or a NumPy array of dtype int"
-            )
-
-    def __getitem__(
-        self, val: Union[int, slice, Tuple]
-    ) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
-        """Given an index or a slice, retrieve the associated variable(s).
-        Each variable is returned via a tuple of the form (variable hash, number of states)
-
-        Note: Relies on numpy indexation to throw IndexError if val is out-of-bounds
-
-        Args:
-            val: a variable index or slice
-
-        Returns:
-            A single variable or a list of variables
-        """
-        assert isinstance(self.num_states, np.ndarray)
-
-        if isinstance(val, slice) or (
-            isinstance(val, tuple) and isinstance(val[0], slice)
-        ):
-            assert isinstance(self.num_states, np.ndarray)
-            vars_names = self.variable_hashes[val].flatten()
-            vars_num_states = self.num_states[val].flatten()
-            return list(zip(vars_names, vars_num_states))
-
-        return (self.variable_hashes[val], self.num_states[val])
-
-    @cached_property
-    def variable_hashes(self) -> np.ndarray:
-        """Function that generates a variable hash for each variable
-
-        Returns:
-            Array of variables hashes.
-        """
-        indices = np.reshape(np.arange(np.product(self.shape)), self.shape)
-        return self.__hash__() + indices
-
-    def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
-        """Function that turns meaningful structured data into a flat data array for internal use.
-
-        Args:
-            data: Meaningful structured data. Should be an array of shape self.shape (for e.g. MAP decodings)
-                or self.shape + (self.num_states.max(),) (for e.g. evidence, beliefs).
-
-        Returns:
-            A flat jnp.array for internal use
-
-        Raises:
-            ValueError: If the data is not of the correct shape.
-        """
-        assert isinstance(self.num_states, np.ndarray)
-
-        if data.shape == self.shape:
-            return jax.device_put(data).flatten()
-        elif data.shape == self.shape + (self.num_states.max(),):
-            return jax.device_put(
-                data[np.arange(data.shape[-1]) < self.num_states[..., None]]
-            )
-        else:
-            raise ValueError(
-                f"data should be of shape {self.shape} or {self.shape + (self.num_states.max(),)}. "
-                f"Got {data.shape}."
-            )
-
-    def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
-        """Function that recovers meaningful structured data from internal flat data array
-
-        Args:
-            flat_data: Internal flat data array.
-
-        Returns:
-            Meaningful structured data. An array of shape self.shape (for e.g. MAP decodings)
-                or an array of shape self.shape + (self.num_states,) (for e.g. evidence, beliefs).
-
-        Raises:
-            ValueError if:
-                (1) flat_data is not a 1D array
-                (2) flat_data is not of the right shape
-        """
-        assert isinstance(self.num_states, np.ndarray)
-
-        if flat_data.ndim != 1:
-            raise ValueError(
-                f"Can only unflatten 1D array. Got a {flat_data.ndim}D array."
-            )
-
-        if flat_data.size == np.product(self.shape):
-            data = flat_data.reshape(self.shape)
-        elif flat_data.size == self.num_states.sum():
-            data = jnp.full(
-                shape=self.shape + (self.num_states.max(),), fill_value=jnp.nan
-            )
-            data = data.at[np.arange(data.shape[-1]) < self.num_states[..., None]].set(
-                flat_data
-            )
-        else:
-            raise ValueError(
-                f"flat_data should be compatible with shape {self.shape} or {self.shape + (self.num_states.max(),)}. "
-                f"Got {flat_data.shape}."
-            )
-
-        return data
-
-
-@dataclass(frozen=True, eq=False)
-class VariableDict(groups.VariableGroup):
-    """A variable dictionary that contains a set of variables
-
-    Args:
-        num_states: The size of the variables in this VariableGroup
-        variable_names: A tuple of all the names of the variables in this VariableGroup.
-    """
-
-    variable_names: Tuple[Any, ...]
-
-    def __post_init__(self):
-        super().__post_init__()
-
-        if np.isscalar(self.num_states):
-            num_states = np.full(
-                (len(self.variable_names),), fill_value=self.num_states, dtype=np.int64
-            )
-            object.__setattr__(self, "num_states", num_states)
-        elif isinstance(self.num_states, np.ndarray) and np.issubdtype(
-            self.num_states.dtype, int
-        ):
-            if self.num_states.shape != len(self.variable_names):
-                raise ValueError(
-                    f"Expected num_states shape ({len(self.variable_names)},). Got {self.num_states.shape}."
-                )
-        else:
-            raise ValueError(
-                "num_states should be an integer or a NumPy array of dtype int"
-            )
-
-    @cached_property
-    def variable_hashes(self) -> np.ndarray:
-        """Function that generates a variable hash for each variable
-
-        Returns:
-            Array of variables hashes.
-        """
-        indices = np.arange(len(self.variable_names))
-        return self.__hash__() + indices
-
-    def __getitem__(self, var_name: Any) -> Tuple[int, int]:
-        """Given a variable name retrieve the associated variable, returned via a tuple of the form
-        (variable hash, number of states)
-
-        Args:
-            val: a variable name
-
-        Returns:
-            The queried variable
-        """
-        assert isinstance(self.num_states, np.ndarray)
-        if var_name not in self.variable_names:
-            raise ValueError(f"Variable {var_name} is not in VariableDict")
-
-        var_idx = self.variable_names.index(var_name)
-        return (self.variable_hashes[var_idx], self.num_states[var_idx])
-
-    def flatten(
-        self, data: Mapping[Any, Union[np.ndarray, jnp.ndarray]]
-    ) -> jnp.ndarray:
-        """Function that turns meaningful structured data into a flat data array for internal use.
-
-        Args:
-            data: Meaningful structured data. Should be a mapping with names from self.variable_names.
-                Each value should be an array of shape (1,) (for e.g. MAP decodings) or
-                (self.num_states,) (for e.g. evidence, beliefs).
-
-        Returns:
-            A flat jnp.array for internal use
-
-        Raises:
-            ValueError if:
-                (1) data is referring to a non-existing variable
-                (2) data is not of the correct shape
-        """
-        assert isinstance(self.num_states, np.ndarray)
-
-        for var_name in data:
-            if var_name not in self.variable_names:
-                raise ValueError(
-                    f"data is referring to a non-existent variable {var_name}."
-                )
-
-            var_idx = self.variable_names.index(var_name)
-            if data[var_name].shape != (self.num_states[var_idx],) and data[
-                var_name
-            ].shape != (1,):
-                raise ValueError(
-                    f"Variable {var_name} expects a data array of shape "
-                    f"{(self.num_states[var_idx],)} or (1,). Got {data[var_name].shape}."
-                )
-
-        flat_data = jnp.concatenate(
-            [data[var_name].flatten() for var_name in self.variable_names]
-        )
-        return flat_data
-
-    def unflatten(
-        self, flat_data: Union[np.ndarray, jnp.ndarray]
-    ) -> Dict[Hashable, Union[np.ndarray, jnp.ndarray]]:
-        """Function that recovers meaningful structured data from internal flat data array
-
-        Args:
-            flat_data: Internal flat data array.
-
-        Returns:
-            Meaningful structured data. Should be a mapping with names from self.variable_names.
-                Each value should be an array of shape (1,) (for e.g. MAP decodings) or
-                (self.num_states,) (for e.g. evidence, beliefs).
-
-        Raises:
-            ValueError if:
-                (1) flat_data is not a 1D array
-                (2) flat_data is not of the right shape
-        """
-        assert isinstance(self.num_states, np.ndarray)
-
-        if flat_data.ndim != 1:
-            raise ValueError(
-                f"Can only unflatten 1D array. Got a {flat_data.ndim}D array."
-            )
-
-        num_variables = len(self.variable_names)
-        num_variable_states = self.num_states.sum()
-        if flat_data.shape[0] == num_variables:
-            use_num_states = False
-        elif flat_data.shape[0] == num_variable_states:
-            use_num_states = True
-        else:
-            raise ValueError(
-                f"flat_data should be either of shape (num_variables(={len(self.variables)}),), "
-                f"or (num_variable_states(={num_variable_states}),). "
-                f"Got {flat_data.shape}"
-            )
-
-        start = 0
-        data = {}
-        for var_name in self.variable_names:
-            if use_num_states:
-                var_idx = self.variable_names.index(var_name)
-                var_num_states = self.num_states[var_idx]
-                data[var_name] = flat_data[start : start + var_num_states]
-                start += var_num_states
-            else:
-                data[var_name] = flat_data[np.array([start])]
-                start += 1
-
-        return data
diff --git a/pgmax/infer/__init__.py b/pgmax/infer/__init__.py
new file mode 100644
index 00000000..fb663754
--- /dev/null
+++ b/pgmax/infer/__init__.py
@@ -0,0 +1,4 @@
+"""A sub-package containing functions to perform belief propagation."""
+
+from .bp import BP, decode_map_states, get_marginals
+from .bp_state import BPArrays, BPState, Evidence, FToVMessages, LogPotentials
diff --git a/pgmax/infer/bp.py b/pgmax/infer/bp.py
new file mode 100644
index 00000000..67aceb3a
--- /dev/null
+++ b/pgmax/infer/bp.py
@@ -0,0 +1,360 @@
+"""A module containing the core message-passing functions for belief propagation."""
+
+import functools
+import inspect
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Hashable, Mapping, Optional, Tuple
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from jax.scipy.special import logsumexp
+
+from pgmax.factor import FAC_TO_VAR_UPDATES
+
+from . import bp_state as bpstate
+from . import bp_utils
+from .bp_state import BPArrays, BPState, Evidence, FToVMessages, LogPotentials
+
+
+@dataclass(frozen=True, eq=False)
+class BeliefPropagation:
+    """Belief propagation functions.
+
+    Arguments:
+        init: Function to create log_potentials, ftov_msgs and evidence.
+            Args:
+                log_potentials_updates: Optional dictionary containing log_potentials updates.
+                ftov_msgs_updates: Optional dictionary containing ftov_msgs updates.
+                evidence_updates: Optional dictionary containing evidence updates.
+            Returns:
+                A BPArrays with the log_potentials, ftov_msgs and evidence.
+
+        update: Function to update log_potentials, ftov_msgs and evidence.
+            Args:
+                bp_arrays: Optional arrays of log_potentials, ftov_msgs, evidence.
+                log_potentials_updates: Optional dictionary containing log_potentials updates.
+                ftov_msgs_updates: Optional dictionary containing ftov_msgs updates.
+                evidence_updates: Optional dictionary containing evidence updates.
+            Returns:
+                A BPArrays with the updated log_potentials, ftov_msgs and evidence.
+
+        run_bp: Function to run belief propagation for num_iters iterations with a given damping factor.
+            Args:
+                bp_arrays: Initial arrays of log_potentials, ftov_msgs, evidence.
+                num_iters: Number of belief propagation iterations.
+                damping: The damping factor to use for message updates between one timestep and the next.
+            Returns:
+                A BPArrays containing the updated ftov_msgs.
+
+        to_bp_state: Function to reconstruct the BPState from a BPArrays.
+            Args:
+                bp_arrays: A BPArrays containing log_potentials, ftov_msgs, evidence.
+            Returns:
+                The reconstructed BPState
+
+        get_beliefs: Function to calculate beliefs from a BPArrays.
+            Args:
+                bp_arrays: A BPArrays containing log_potentials, ftov_msgs, evidence.
+            Returns:
+                beliefs: Beliefs returned by belief propagation.
+    """
+
+    init: Callable
+    update: Callable
+    run_bp: Callable
+    to_bp_state: Callable
+    get_beliefs: Callable
+
+
+def BP(bp_state: BPState, temperature: float = 0.0) -> BeliefPropagation:
+    """Function for generating belief propagation functions.
+
+    Args:
+        bp_state: Belief propagation state.
+        temperature: Temperature for loopy belief propagation.
+            1.0 corresponds to sum-product, 0.0 corresponds to max-product.
+
+    Returns:
+        Belief propagation functions.
+    """
+    wiring = bp_state.fg_state.wiring
+    edges_num_states = np.concatenate(
+        [wiring[factor_type].edges_num_states for factor_type in FAC_TO_VAR_UPDATES]
+    )
+    max_msg_size = int(np.max(edges_num_states))
+
+    var_states_for_edges = np.concatenate(
+        [wiring[factor_type].var_states_for_edges for factor_type in FAC_TO_VAR_UPDATES]
+    )
+
+    # Inference arguments per factor type
+    inference_arguments: Dict[type, Mapping] = {}
+    for factor_type in FAC_TO_VAR_UPDATES:
+        this_inference_arguments = inspect.getfullargspec(
+            FAC_TO_VAR_UPDATES[factor_type]
+        ).args
+        this_inference_arguments.remove("vtof_msgs")
+        this_inference_arguments.remove("log_potentials")
+        this_inference_arguments.remove("temperature")
+        this_inference_arguments = {
+            key: getattr(wiring[factor_type], key) for key in this_inference_arguments
+        }
+        inference_arguments[factor_type] = this_inference_arguments
+
+    factor_type_to_msgs_range = bp_state.fg_state.factor_type_to_msgs_range
+    factor_type_to_potentials_range = bp_state.fg_state.factor_type_to_potentials_range
+
+    def update(
+        bp_arrays: Optional[BPArrays] = None,
+        log_potentials_updates: Optional[Dict[Any, jnp.ndarray]] = None,
+        ftov_msgs_updates: Optional[Dict[Any, jnp.ndarray]] = None,
+        evidence_updates: Optional[Dict[Any, jnp.ndarray]] = None,
+    ) -> BPArrays:
+        """Function to update belief propagation log_potentials, ftov_msgs, evidence.
+
+        Args:
+            bp_arrays: Optional arrays of log_potentials, ftov_msgs, evidence.
+            log_potentials_updates: Optional dictionary containing log_potentials updates.
+            ftov_msgs_updates: Optional dictionary containing ftov_msgs updates.
+            evidence_updates: Optional dictionary containing evidence updates.
+
+        Returns:
+            A BPArrays with the updated log_potentials, ftov_msgs and evidence.
+        """
+        if bp_arrays is not None:
+            log_potentials = bp_arrays.log_potentials
+            evidence = bp_arrays.evidence
+            ftov_msgs = bp_arrays.ftov_msgs
+        else:
+            log_potentials = jax.device_put(bp_state.log_potentials.value)
+            ftov_msgs = bp_state.ftov_msgs.value
+            evidence = bp_state.evidence.value
+
+        if log_potentials_updates is not None:
+            log_potentials = bpstate.update_log_potentials(
+                log_potentials, log_potentials_updates, bp_state.fg_state
+            )
+
+        if ftov_msgs_updates is not None:
+            ftov_msgs = bpstate.update_ftov_msgs(
+                ftov_msgs, ftov_msgs_updates, bp_state.fg_state
+            )
+
+        if evidence_updates is not None:
+            evidence = bpstate.update_evidence(
+                evidence, evidence_updates, bp_state.fg_state
+            )
+
+        return BPArrays(
+            log_potentials=log_potentials, ftov_msgs=ftov_msgs, evidence=evidence
+        )
+
+    def run_bp(
+        bp_arrays: BPArrays,
+        num_iters: int,
+        damping: float = 0.5,
+    ) -> BPArrays:
+        """Function to run belief propagation for num_iters iterations with a given damping factor.
+
+        Args:
+            bp_arrays: Initial arrays of log_potentials, ftov_msgs, evidence.
+            num_iters: Number of belief propagation iterations.
+            damping: The damping factor to use for message updates between one timestep and the next.
+
+        Returns:
+            A BPArrays containing the updated ftov_msgs.
+        """
+        log_potentials = bp_arrays.log_potentials
+        evidence = bp_arrays.evidence
+        ftov_msgs = bp_arrays.ftov_msgs
+
+        # Normalize the messages to ensure the maximum value is 0.
+        ftov_msgs = normalize_and_clip_msgs(ftov_msgs, edges_num_states, max_msg_size)
+
+        @jax.checkpoint
+        def update(msgs: jnp.ndarray, _) -> Tuple[jnp.ndarray, None]:
+            # Compute new variable to factor messages by message passing
+            vtof_msgs = pass_var_to_fac_messages(
+                msgs,
+                evidence,
+                var_states_for_edges,
+            )
+            ftov_msgs = jnp.zeros_like(vtof_msgs)
+            for factor_type in FAC_TO_VAR_UPDATES:
+                msgs_start, msgs_end = factor_type_to_msgs_range[factor_type]
+                potentials_start, potentials_end = factor_type_to_potentials_range[
+                    factor_type
+                ]
+                ftov_msgs_type = FAC_TO_VAR_UPDATES[factor_type](
+                    vtof_msgs=vtof_msgs[msgs_start:msgs_end],
+                    log_potentials=log_potentials[potentials_start:potentials_end],
+                    temperature=temperature,
+                    **inference_arguments[factor_type],
+                )
+                ftov_msgs = ftov_msgs.at[msgs_start:msgs_end].set(ftov_msgs_type)
+
+            # Use the results of message passing to perform damping and
+            # update the factor to variable messages
+            delta_msgs = ftov_msgs - msgs
+            msgs = msgs + (1 - damping) * delta_msgs
+            # Normalize and clip these damped, updated messages before returning them.
+            msgs = normalize_and_clip_msgs(msgs, edges_num_states, max_msg_size)
+            return msgs, None
+
+        ftov_msgs, _ = jax.lax.scan(update, ftov_msgs, None, num_iters)
+
+        return BPArrays(
+            log_potentials=log_potentials, ftov_msgs=ftov_msgs, evidence=evidence
+        )
+
+    def to_bp_state(bp_arrays: BPArrays) -> BPState:
+        """Function to reconstruct the BPState from a BPArrays
+
+        Args:
+            bp_arrays: A BPArrays containing log_potentials, ftov_msgs, evidence.
+
+        Returns:
+            The reconstructed BPState
+        """
+        return BPState(
+            log_potentials=LogPotentials(
+                fg_state=bp_state.fg_state, value=bp_arrays.log_potentials
+            ),
+            ftov_msgs=FToVMessages(
+                fg_state=bp_state.fg_state,
+                value=bp_arrays.ftov_msgs,
+            ),
+            evidence=Evidence(fg_state=bp_state.fg_state, value=bp_arrays.evidence),
+        )
+
+    def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]:
+        """Function that returns unflattened beliefs from the flat beliefs
+
+        Args:
+            flat_beliefs: Flattened array of beliefs
+            variable_groups: All the variable groups in the FactorGraph.
+        """
+        beliefs = {}
+        start = 0
+        for variable_group in variable_groups:
+            num_states = variable_group.num_states
+            assert isinstance(num_states, np.ndarray)
+            length = num_states.sum()
+
+            beliefs[variable_group] = variable_group.unflatten(
+                flat_beliefs[start : start + length]
+            )
+            start += length
+        return beliefs
+
+    @jax.jit
+    def get_beliefs(bp_arrays: BPArrays) -> Dict[Hashable, Any]:
+        """Compute beliefs by adding all factor to variable messages onto the evidence.
+
+        Args:
+            bp_arrays: A BPArrays containing log_potentials, ftov_msgs, evidence.
+
+        Returns:
+            beliefs: Unflattened beliefs, one entry per variable group of bp_state.fg_state.
+        """
+
+        flat_beliefs = (
+            jax.device_put(bp_arrays.evidence)
+            .at[jax.device_put(var_states_for_edges)]
+            .add(bp_arrays.ftov_msgs)
+        )
+        return unflatten_beliefs(flat_beliefs, bp_state.fg_state.variable_groups)
+
+    bp = BeliefPropagation(
+        init=functools.partial(update, None),
+        update=update,
+        run_bp=run_bp,
+        to_bp_state=to_bp_state,
+        get_beliefs=get_beliefs,
+    )
+    return bp
+
+
+@jax.jit
+def pass_var_to_fac_messages(
+    ftov_msgs: jnp.array,
+    evidence: jnp.array,
+    var_states_for_edges: jnp.array,
+) -> jnp.array:
+    """Passes messages from Variables to Factors.
+
+    The update works by first summing the evidence and neighboring factor to variable messages for
+    each variable. Next, each edge state's own incoming message is subtracted from the sum gathered
+    for its variable state, which yields the variable to factor message for that edge state.
+
+    Args:
+        ftov_msgs: Array of shape (num_edge_states,). This holds all the flattened factor to variable
+            messages.
+        evidence: Array of shape (num_var_states,) representing the flattened evidence for each variable
+        var_states_for_edges: Array of shape (num_edge_states,)
+            Global variable state indices for each edge state
+    Returns:
+        Array of shape (num_edge_states,). This holds all the flattened variable to factor messages.
+    """
+    var_sums_arr = evidence.at[var_states_for_edges].add(ftov_msgs)  # scatter-add per var state
+    vtof_msgs = var_sums_arr[var_states_for_edges] - ftov_msgs  # gather, then drop own message
+    return vtof_msgs
+
+
+@functools.partial(jax.jit, static_argnames=("max_msg_size"))
+def normalize_and_clip_msgs(
+    msgs: jnp.ndarray,
+    edges_num_states: jnp.ndarray,
+    max_msg_size: int,
+) -> jnp.ndarray:
+    """Performs normalization and clipping of flattened messages
+
+    Normalization subtracts from every message its own maximum, so the largest entry of each message is 0;
+    clipping then bounds every value from below by -1000, keeping all values in the range [-1000, 0].
+
+    Args:
+        msgs: Array of shape (num_edge_states,). This holds all the flattened factor to variable messages.
+        edges_num_states: Array of shape (num_edges,). Number of states for the variables connected to each edge
+        max_msg_size: the max of edges_num_states
+
+    Returns:
+        Array of shape (num_edge_states,). This holds all the flattened factor to variable messages
+            after normalization and clipping
+    """
+    msgs = msgs - jnp.repeat(
+        bp_utils.segment_max_opt(msgs, edges_num_states, max_msg_size),
+        edges_num_states,
+        total_repeat_length=msgs.shape[0],
+    )
+    # Only a lower bound is needed: after the subtraction above no value exceeds 0
+    msgs = jnp.clip(msgs, -1000, None)
+    return msgs
+
+
+@jax.jit
+def decode_map_states(beliefs: Dict[Hashable, Any]) -> Any:
+    """Decode MAP states by taking, for every leaf of beliefs, the argmax over its last axis.
+
+    Args:
+        beliefs: An array or a PyTree container containing beliefs for different variables.
+
+    Returns:
+        An array or a PyTree container containing the MAP states for different variables.
+    """
+    return jax.tree_util.tree_map(lambda x: jnp.argmax(x, axis=-1), beliefs)
+
+
+@jax.jit
+def get_marginals(beliefs: Dict[Hashable, Any]) -> Any:
+    """Turn beliefs into marginal probabilities with a softmax over the last axis of every leaf.
+
+    Args:
+        beliefs: An array or a PyTree container containing beliefs for different variables.
+
+    Returns:
+        An array or a PyTree container containing the marginal probabilities for different variables.
+    """
+    return jax.tree_util.tree_map(
+        lambda x: jnp.exp(x - logsumexp(x, axis=-1, keepdims=True)), beliefs
+    )
diff --git a/pgmax/infer/bp_state.py b/pgmax/infer/bp_state.py
new file mode 100644
index 00000000..3d3b4d79
--- /dev/null
+++ b/pgmax/infer/bp_state.py
@@ -0,0 +1,384 @@
+"Defines container classes for belief propagation states, and for the relevant flat arrays used in belief propagation."
+
+import functools
+from dataclasses import asdict, dataclass
+from typing import Any, Dict, Optional, Tuple, Union, cast
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+from pgmax import fgraph, fgroup
+
+
+@jax.tree_util.register_pytree_node_class
+@dataclass(frozen=True, eq=False)
+class BPArrays:
+    """Container (registered as a JAX pytree) for the flat arrays used in belief propagation.
+
+    Args:
+        log_potentials: Flat log potentials array.
+        ftov_msgs: Flat factor to variable messages array.
+        evidence: Flat evidence array.
+    """
+
+    log_potentials: Union[np.ndarray, jnp.ndarray]
+    ftov_msgs: Union[np.ndarray, jnp.ndarray]
+    evidence: Union[np.ndarray, jnp.ndarray]
+
+    def __post_init__(self):
+        for field in self.__dataclass_fields__:
+            if isinstance(getattr(self, field), np.ndarray):
+                getattr(self, field).flags.writeable = False  # NumPy fields become read-only
+
+    def tree_flatten(self):
+        return jax.tree_util.tree_flatten(asdict(self))  # flattens the dict of fields
+
+    @classmethod
+    def tree_unflatten(cls, aux_data, children):
+        return cls(**aux_data.unflatten(children))  # rebuilds the field dict, then the class
+
+
+@functools.partial(jax.jit, static_argnames="fg_state")
+def update_log_potentials(
+    log_potentials: jnp.ndarray,
+    updates: Dict[Any, jnp.ndarray],
+    fg_state: fgraph.FactorGraphState,
+) -> jnp.ndarray:
+    """Function to update log_potentials.
+
+    Args:
+        log_potentials: A flat jnp array containing log_potentials.
+        updates: A dictionary mapping each factor group to the new log potentials for that group
+        fg_state: Factor graph state
+
+    Returns:
+        A flat jnp array containing updated log_potentials.
+
+    Raises: ValueError if
+        (1) The flattened data for a factor group does not match that group's log potentials shape.
+        (2) A key of updates is not in fg_state.factor_group_to_potentials_starts.
+    """
+    for factor_group, data in updates.items():
+        if factor_group in fg_state.factor_group_to_potentials_starts:
+            flat_data = factor_group.flatten(data)
+            if flat_data.shape != factor_group.factor_group_log_potentials.shape:
+                raise ValueError(
+                    f"Expected log potentials shape {factor_group.factor_group_log_potentials.shape} "
+                    f"for factor group. Got incompatible data shape {data.shape}."
+                )
+
+            start = fg_state.factor_group_to_potentials_starts[factor_group]
+            log_potentials = log_potentials.at[start : start + flat_data.shape[0]].set(
+                flat_data
+            )
+        else:
+            raise ValueError("Invalid FactorGroup for log potentials updates.")
+
+    return log_potentials
+
+
+@dataclass(frozen=True, eq=False)
+class LogPotentials:
+    """Class for storing and manipulating log potentials.
+
+    Args:
+        fg_state: Factor graph state
+        value: Optionally specify an initial value; defaults to fg_state.log_potentials
+
+    Raises:
+        ValueError: If provided value shape does not match the expected log_potentials shape.
+    """
+
+    fg_state: fgraph.FactorGraphState
+    value: Optional[np.ndarray] = None
+
+    def __post_init__(self):
+        if self.value is None:
+            object.__setattr__(self, "value", self.fg_state.log_potentials)
+        else:
+            if not self.value.shape == self.fg_state.log_potentials.shape:
+                raise ValueError(
+                    f"Expected log potentials shape {self.fg_state.log_potentials.shape}. "
+                    f"Got {self.value.shape}."
+                )
+
+            object.__setattr__(self, "value", self.value)  # no-op: re-assigns the same object
+
+    def __getitem__(self, factor_group: fgroup.FactorGroup) -> np.ndarray:
+        """Function to query log potentials for a FactorGroup.
+
+        Args:
+            factor_group: Queried FactorGroup
+
+        Returns:
+            The queried log potentials, as a flat slice of the stored value.
+        """
+        value = cast(np.ndarray, self.value)
+        if factor_group in self.fg_state.factor_group_to_potentials_starts:
+            start = self.fg_state.factor_group_to_potentials_starts[factor_group]
+            log_potentials = value[
+                start : start + factor_group.factor_group_log_potentials.shape[0]
+            ]
+        else:
+            raise ValueError("Invalid FactorGroup queried to access log potentials.")
+        return log_potentials
+
+    def __setitem__(
+        self,
+        factor_group: fgroup.FactorGroup,
+        data: Union[np.ndarray, jnp.ndarray],
+    ):
+        """Set the log potentials for a FactorGroup
+
+        Args:
+            factor_group: FactorGroup
+            data: Array containing the log potentials for the FactorGroup
+        """
+        object.__setattr__(  # the dataclass is frozen, so plain attribute assignment would raise
+            self,
+            "value",
+            np.asarray(
+                update_log_potentials(
+                    jax.device_put(self.value),
+                    {factor_group: jax.device_put(data)},
+                    self.fg_state,
+                )
+            ),
+        )
+
+
+@functools.partial(jax.jit, static_argnames="fg_state")
+def update_ftov_msgs(
+    ftov_msgs: jnp.ndarray,
+    updates: Dict[Any, jnp.ndarray],
+    fg_state: fgraph.FactorGraphState,
+) -> jnp.ndarray:
+    """Function to update ftov_msgs.
+
+    Args:
+        ftov_msgs: A flat jnp array containing ftov_msgs.
+        updates: A dictionary mapping each variable to the data to spread over its ftov messages
+        fg_state: Factor graph state
+
+    Returns:
+        A flat jnp array containing updated ftov_msgs.
+
+    Raises: ValueError if:
+        (1) the data provided for a variable is not of shape (variable[1],).
+        (2) provided variable is not in the FactorGraph.
+    """
+    for variable, data in updates.items():
+        if variable in fg_state.vars_to_starts:
+            if data.shape != (variable[1],):
+                raise ValueError(
+                    f"Given belief shape {data.shape} does not match expected "
+                    f"shape {(variable[1],)} for variable {variable}."
+                )
+
+            var_states_for_edges = np.concatenate(
+                [
+                    wiring_by_type.var_states_for_edges
+                    for wiring_by_type in fg_state.wiring.values()
+                ]
+            )
+
+            starts = np.nonzero(  # positions whose entry equals this variable's start index
+                var_states_for_edges == fg_state.vars_to_starts[variable]
+            )[0]
+            for start in starts:
+                ftov_msgs = ftov_msgs.at[start : start + variable[1]].set(
+                    data / starts.shape[0]  # equal share of data for every matching position
+                )
+        else:
+            raise ValueError("Provided variable is not in the FactorGraph")
+    return ftov_msgs
+
+
+@dataclass(frozen=True, eq=False)
+class FToVMessages:
+    """Class for storing and manipulating factor to variable messages.
+
+    Args:
+        fg_state: Factor graph state
+        value: Optionally specify initial value for ftov messages; defaults to all zeros
+
+    Raises: ValueError if provided value does not match expected ftov messages shape.
+    """
+
+    fg_state: fgraph.FactorGraphState
+    value: Optional[np.ndarray] = None
+
+    def __post_init__(self):
+        if self.value is None:
+            object.__setattr__(
+                self, "value", np.zeros(self.fg_state.total_factor_num_states)
+            )
+        else:
+            if not self.value.shape == (self.fg_state.total_factor_num_states,):
+                raise ValueError(
+                    f"Expected messages shape {(self.fg_state.total_factor_num_states,)}. "
+                    f"Got {self.value.shape}."
+                )
+
+            object.__setattr__(self, "value", self.value)  # no-op: re-assigns the same object
+
+    def __setitem__(
+        self,
+        variable: Tuple[int, int],
+        data: Union[np.ndarray, jnp.ndarray],
+    ) -> None:
+        """Spreading beliefs at a variable to all connected Factors
+
+        Args:
+            variable: Variable whose incoming factor to variable messages are overwritten
+            data: An array containing the beliefs to be spread uniformly
+                across all factors to variable messages involving this variable.
+        """
+
+        object.__setattr__(  # the dataclass is frozen, so plain attribute assignment would raise
+            self,
+            "value",
+            np.asarray(
+                update_ftov_msgs(
+                    jax.device_put(self.value),
+                    {variable: jax.device_put(data)},
+                    self.fg_state,
+                )
+            ),
+        )
+
+
+@functools.partial(jax.jit, static_argnames="fg_state")
+def update_evidence(
+    evidence: jnp.ndarray,
+    updates: Dict[Any, jnp.ndarray],
+    fg_state: fgraph.FactorGraphState,
+) -> jnp.ndarray:
+    """Function to update evidence.
+
+    Args:
+        evidence: A flat jnp array containing evidence.
+        updates: Maps a variable group or a single variable to its evidence; other keys raise ValueError
+        fg_state: Factor graph state
+
+    Returns:
+        A flat jnp array containing updated evidence.
+    """
+    for name, data in updates.items():
+        # A variable group is written as one block starting at its first variable
+        if name in fg_state.variable_groups:
+            first_variable = name.variables[0]
+            start_index = fg_state.vars_to_starts[first_variable]
+            flat_data = name.flatten(data)
+            evidence = evidence.at[start_index : start_index + flat_data.shape[0]].set(
+                flat_data
+            )
+        elif name in fg_state.vars_to_starts:
+            start_index = fg_state.vars_to_starts[name]
+            evidence = evidence.at[start_index : start_index + name[1]].set(data)
+        else:
+            raise ValueError(
+                "Got evidence for a variable or a VarGroup not in the FactorGraph!"
+            )
+    return evidence
+
+
+@dataclass(frozen=True, eq=False)
+class Evidence:
+    """Class for storing and manipulating evidence
+
+    Args:
+        fg_state: Factor graph state
+        value: Optionally specify initial value for evidence; defaults to all zeros
+
+    Raises: ValueError if provided value does not match expected evidence shape.
+    """
+
+    fg_state: fgraph.FactorGraphState
+    value: Optional[np.ndarray] = None
+
+    def __post_init__(self):
+        if self.value is None:
+            object.__setattr__(self, "value", np.zeros(self.fg_state.num_var_states))
+        else:
+            if self.value.shape != (self.fg_state.num_var_states,):
+                raise ValueError(
+                    f"Expected evidence shape {(self.fg_state.num_var_states,)}. "
+                    f"Got {self.value.shape}."
+                )
+
+            object.__setattr__(self, "value", self.value)  # no-op: re-assigns the same object
+
+    def __getitem__(self, variable: Tuple[int, int]) -> np.ndarray:
+        """Function to query evidence for a variable
+
+        Args:
+            variable: Variable queried
+
+        Returns:
+            evidence for the queried variable
+        """
+        value = cast(np.ndarray, self.value)
+        start = self.fg_state.vars_to_starts[variable]
+        evidence = value[start : start + variable[1]]  # variable[1] entries from its start
+        return evidence
+
+    def __setitem__(
+        self,
+        name: Any,
+        data: np.ndarray,
+    ) -> None:
+        """Function to update the evidence for variables
+
+        Args:
+            name: A variable group or a single variable.
+                If name is a variable group, updates are derived by using the variable group to
+                flatten the data.
+                If name is a variable, data should be of an array shape (num_states,)
+                There is no special handling of None: any other name is rejected by update_evidence.
+            data: Array containing the evidence updates.
+        """
+        object.__setattr__(  # the dataclass is frozen, so plain attribute assignment would raise
+            self,
+            "value",
+            np.asarray(
+                update_evidence(
+                    jax.device_put(self.value),
+                    {name: jax.device_put(data)},
+                    self.fg_state,
+                ),
+            ),
+        )
+
+
+@dataclass(frozen=True, eq=False)
+class BPState:
+    """Container class for belief propagation states, including log potentials,
+    ftov messages and evidence (unary log potentials).
+
+    Args:
+        log_potentials: log potentials of the model
+        ftov_msgs: factor to variable messages
+        evidence: evidence (unary log potentials) for variables.
+
+    Raises:
+        ValueError: If log_potentials, ftov_msgs or evidence are not derived from the same
+            FactorGraphState.
+    """
+
+    log_potentials: LogPotentials
+    ftov_msgs: FToVMessages
+    evidence: Evidence
+
+    def __post_init__(self):
+        if (self.log_potentials.fg_state != self.ftov_msgs.fg_state) or (
+            self.ftov_msgs.fg_state != self.evidence.fg_state
+        ):
+            raise ValueError(
+                "log_potentials, ftov_msgs and evidence should be derived from the same fg_state."
+            )
+
+    @property
+    def fg_state(self) -> fgraph.FactorGraphState:
+        return self.log_potentials.fg_state  # __post_init__ compared all three fg_state values
diff --git a/pgmax/bp/bp_utils.py b/pgmax/infer/bp_utils.py
similarity index 58%
rename from pgmax/bp/bp_utils.py
rename to pgmax/infer/bp_utils.py
index 0bdebc5a..8173ffa7 100644
--- a/pgmax/bp/bp_utils.py
+++ b/pgmax/infer/bp_utils.py
@@ -1,14 +1,11 @@
 """A module containing helper functions used for belief propagation."""
 
 import functools
-from typing import Tuple
 
 import jax
 import jax.numpy as jnp
 
-NEG_INF = (
-    -100000.0
-)  # A large negative value to use as -inf for numerical stability reasons
+from pgmax.utils import NEG_INF
 
 
 @functools.partial(jax.jit, static_argnames="max_segment_length")
@@ -50,35 +47,3 @@ def get_max(data, start_index, segment_length):
     )[:-1]
     expanded_data = jnp.concatenate([data, jnp.zeros(max_segment_length)])
     return get_max(expanded_data, start_indices, segments_lengths)
-
-
-@functools.partial(jax.jit, static_argnames="num_labels")
-def get_maxes_and_argmaxes(
-    data: jnp.array, labels: jnp.array, num_labels: int
-) -> Tuple[jnp.ndarray, jnp.ndarray]:
-    """
-    Given a flattened sequence of elements and their corresponding labels,
-    returns the maxes and argmaxes of each label.
-
-    Args:
-        data: Array of shape (a_len,) where a_len is an arbitrary integer.
-        labels: Label array of shape (a_len,), assigning a label to each entry.
-            Labels must be 0,..., num_labels - 1.
-        num_labels: Number of different labels.
-
-    Returns:
-        Maxes and argmaxes arrays
-    """
-    num_obs = data.shape[0]
-
-    maxes = jnp.full(shape=(num_labels,), fill_value=NEG_INF).at[labels].max(data)
-    only_maxes_pos = jnp.arange(num_obs) - num_obs * jnp.where(
-        data != maxes[labels], 1, 0
-    )
-
-    argmaxes = (
-        jnp.full(shape=(num_labels,), fill_value=NEG_INF, dtype=jnp.int32)
-        .at[labels]
-        .max(only_maxes_pos)
-    )
-    return maxes, argmaxes
diff --git a/pgmax/utils.py b/pgmax/utils.py
index cba3f375..15a6114d 100644
--- a/pgmax/utils.py
+++ b/pgmax/utils.py
@@ -3,6 +3,10 @@
 import functools
 from typing import Callable
 
+NEG_INF = (
+    -100000.0
+)  # A large negative value to use as -inf for numerical stability reasons
+
 
 def cached_property(func: Callable) -> property:
     """Customized cached property decorator
diff --git a/pgmax/vgroup/__init__.py b/pgmax/vgroup/__init__.py
new file mode 100644
index 00000000..4ead3270
--- /dev/null
+++ b/pgmax/vgroup/__init__.py
@@ -0,0 +1,5 @@
+"""A sub-package defining variable groups and containing different types of variable groups."""
+
+from .varray import NDVarArray
+from .vdict import VarDict
+from .vgroup import VarGroup
diff --git a/pgmax/vgroup/varray.py b/pgmax/vgroup/varray.py
new file mode 100644
index 00000000..ddc825fd
--- /dev/null
+++ b/pgmax/vgroup/varray.py
@@ -0,0 +1,152 @@
+"""A module containing a subclass of VarGroup for n-dimensional grids of variables."""
+
+from dataclasses import dataclass
+from typing import List, Tuple, Union
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+from pgmax.utils import cached_property
+
+from . import vgroup
+
+
+@dataclass(frozen=True, eq=False)
+class NDVarArray(vgroup.VarGroup):
+    """Subclass of VarGroup for n-dimensional grids of variables.
+
+    Args:
+        num_states: An integer or an array specifying the number of states of the
+            variables in this VarGroup
+        shape: Tuple specifying the size of each dimension of the grid (similar to
+            the notion of a NumPy ndarray shape)
+    """
+
+    shape: Tuple[int, ...]
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        max_size = int(vgroup.MAX_SIZE)
+        if np.prod(self.shape) > max_size:
+            raise ValueError(
+                f"Currently only support NDVarArray of size smaller than {max_size}. Got {np.prod(self.shape)}"
+            )
+
+        if np.isscalar(self.num_states):
+            num_states = np.full(self.shape, fill_value=self.num_states, dtype=np.int64)
+            object.__setattr__(self, "num_states", num_states)
+        elif isinstance(self.num_states, np.ndarray) and np.issubdtype(
+            self.num_states.dtype, int
+        ):
+            if self.num_states.shape != self.shape:
+                raise ValueError(
+                    f"Expected num_states shape {self.shape}. Got {self.num_states.shape}."
+                )
+        else:
+            raise ValueError(
+                "num_states should be an integer or a NumPy array of dtype int"
+            )
+
+    def __getitem__(
+        self, val: Union[int, slice, Tuple]
+    ) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
+        """Given an index or a slice, retrieve the associated variable(s).
+        Each variable is returned via a tuple of the form (variable hash, number of states)
+
+        Note: Relies on numpy indexation to throw IndexError if val is out-of-bounds
+
+        Args:
+            val: a variable index or slice
+
+        Returns:
+            A single variable or a list of variables
+        """
+        assert isinstance(self.num_states, np.ndarray)
+
+        vars_names = self.variable_hashes[val]
+        vars_num_states = self.num_states[val]
+        # NumPy returns an array whenever val selects several variables (a slice in
+        # any position, a partial index, ...) and a scalar for a single variable.
+        if isinstance(vars_names, np.ndarray):
+            return list(zip(vars_names.flatten(), vars_num_states.flatten()))
+
+        # Fully-specified index: a single (variable hash, number of states) pair.
+        return (vars_names, vars_num_states)
+
+    @cached_property
+    def variable_hashes(self) -> np.ndarray:
+        """Function that generates a variable hash for each variable
+
+        Returns:
+            Array of variables hashes.
+        """
+        indices = np.reshape(np.arange(np.prod(self.shape)), self.shape)
+        return self.__hash__() + indices
+
+    def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
+        """Function that turns meaningful structured data into a flat data array for internal use.
+
+        Args:
+            data: Meaningful structured data. Should be an array of shape self.shape (for e.g. MAP decodings)
+                or self.shape + (self.num_states.max(),) (for e.g. evidence, beliefs).
+
+        Returns:
+            A flat jnp.array for internal use
+
+        Raises:
+            ValueError: If the data is not of the correct shape.
+        """
+        assert isinstance(self.num_states, np.ndarray)
+
+        if data.shape == self.shape:
+            return jax.device_put(data).flatten()
+        elif data.shape == self.shape + (self.num_states.max(),):
+            return jax.device_put(
+                data[np.arange(data.shape[-1]) < self.num_states[..., None]]
+            )
+        else:
+            raise ValueError(
+                f"data should be of shape {self.shape} or {self.shape + (self.num_states.max(),)}. "
+                f"Got {data.shape}."
+            )
+
+    def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray:
+        """Function that recovers meaningful structured data from internal flat data array
+
+        Args:
+            flat_data: Internal flat data array.
+
+        Returns:
+            Meaningful structured data. An array of shape self.shape (for e.g. MAP decodings)
+                or self.shape + (self.num_states.max(),), NaN-padded (for e.g. evidence, beliefs).
+
+        Raises:
+            ValueError if:
+                (1) flat_data is not a 1D array
+                (2) flat_data is not of the right shape
+        """
+        assert isinstance(self.num_states, np.ndarray)
+
+        if flat_data.ndim != 1:
+            raise ValueError(
+                f"Can only unflatten 1D array. Got a {flat_data.ndim}D array."
+            )
+
+        if flat_data.size == np.prod(self.shape):
+            data = flat_data.reshape(self.shape)
+        elif flat_data.size == self.num_states.sum():
+            data = jnp.full(
+                shape=self.shape + (self.num_states.max(),), fill_value=jnp.nan
+            )
+            data = data.at[np.arange(data.shape[-1]) < self.num_states[..., None]].set(
+                flat_data
+            )
+        else:
+            raise ValueError(
+                f"flat_data should be compatible with shape {self.shape} or {self.shape + (self.num_states.max(),)}. "
+                f"Got {flat_data.shape}."
+            )
+
+        return data
diff --git a/pgmax/vgroup/vdict.py b/pgmax/vgroup/vdict.py
new file mode 100644
index 00000000..7541d971
--- /dev/null
+++ b/pgmax/vgroup/vdict.py
@@ -0,0 +1,162 @@
+"""A module containing a variable dictionary class inheriting from the base VarGroup."""
+
+from dataclasses import dataclass
+from typing import Any, Dict, Hashable, Mapping, Tuple, Union
+
+import jax.numpy as jnp
+import numpy as np
+
+from pgmax.utils import cached_property
+
+from . import vgroup
+
+
+@dataclass(frozen=True, eq=False)
+class VarDict(vgroup.VarGroup):
+    """A variable dictionary that contains a set of variables
+
+    Args:
+        num_states: The size of the variables in this VarGroup
+        variable_names: A tuple of all the names of the variables in this VarGroup.
+    """
+
+    variable_names: Tuple[Any, ...]
+
+    def __post_init__(self):
+        super().__post_init__()
+
+        if np.isscalar(self.num_states):
+            num_states = np.full(
+                (len(self.variable_names),), fill_value=self.num_states, dtype=np.int64
+            )
+            object.__setattr__(self, "num_states", num_states)
+        elif isinstance(self.num_states, np.ndarray) and np.issubdtype(
+            self.num_states.dtype, int
+        ):
+            if self.num_states.shape != (len(self.variable_names),):
+                raise ValueError(
+                    f"Expected num_states shape ({len(self.variable_names)},). Got {self.num_states.shape}."
+                )
+        else:
+            raise ValueError(
+                "num_states should be an integer or a NumPy array of dtype int"
+            )
+
+    @cached_property
+    def variable_hashes(self) -> np.ndarray:
+        """Function that generates a variable hash for each variable
+
+        Returns:
+            Array of variables hashes.
+        """
+        indices = np.arange(len(self.variable_names))
+        return self.__hash__() + indices
+
+    def __getitem__(self, var_name: Any) -> Tuple[int, int]:
+        """Given a variable name retrieve the associated variable, returned via a tuple of the form
+        (variable hash, number of states)
+
+        Args:
+            var_name: a variable name
+
+        Returns:
+            The queried variable
+        """
+        assert isinstance(self.num_states, np.ndarray)
+        if var_name not in self.variable_names:
+            raise ValueError(f"Variable {var_name} is not in VarDict")
+
+        var_idx = self.variable_names.index(var_name)
+        return (self.variable_hashes[var_idx], self.num_states[var_idx])
+
+    def flatten(
+        self, data: Mapping[Any, Union[np.ndarray, jnp.ndarray]]
+    ) -> jnp.ndarray:
+        """Function that turns meaningful structured data into a flat data array for internal use.
+
+        Args:
+            data: Meaningful structured data. Should be a mapping with names from self.variable_names.
+                Each value should be an array of shape (1,) (for e.g. MAP decodings) or
+                (self.num_states,) (for e.g. evidence, beliefs).
+
+        Returns:
+            A flat jnp.array for internal use
+
+        Raises:
+            ValueError if:
+                (1) data is referring to a non-existing variable
+                (2) data is not of the correct shape
+        """
+        assert isinstance(self.num_states, np.ndarray)
+
+        for var_name in data:
+            if var_name not in self.variable_names:
+                raise ValueError(
+                    f"data is referring to a non-existent variable {var_name}."
+                )
+
+            var_idx = self.variable_names.index(var_name)
+            if data[var_name].shape != (self.num_states[var_idx],) and data[
+                var_name
+            ].shape != (1,):
+                raise ValueError(
+                    f"Variable {var_name} expects a data array of shape "
+                    f"{(self.num_states[var_idx],)} or (1,). Got {data[var_name].shape}."
+                )
+
+        flat_data = jnp.concatenate(
+            [data[var_name].flatten() for var_name in self.variable_names]
+        )
+        return flat_data
+
+    def unflatten(
+        self, flat_data: Union[np.ndarray, jnp.ndarray]
+    ) -> Dict[Hashable, Union[np.ndarray, jnp.ndarray]]:
+        """Function that recovers meaningful structured data from internal flat data array
+
+        Args:
+            flat_data: Internal flat data array.
+
+        Returns:
+            Meaningful structured data. Should be a mapping with names from self.variable_names.
+                Each value should be an array of shape (1,) (for e.g. MAP decodings) or
+                (self.num_states,) (for e.g. evidence, beliefs).
+
+        Raises:
+            ValueError if:
+                (1) flat_data is not a 1D array
+                (2) flat_data is not of the right shape
+        """
+        assert isinstance(self.num_states, np.ndarray)
+
+        if flat_data.ndim != 1:
+            raise ValueError(
+                f"Can only unflatten 1D array. Got a {flat_data.ndim}D array."
+            )
+
+        num_variables = len(self.variable_names)
+        num_variable_states = self.num_states.sum()
+        if flat_data.shape[0] == num_variables:
+            use_num_states = False
+        elif flat_data.shape[0] == num_variable_states:
+            use_num_states = True
+        else:
+            raise ValueError(
+                f"flat_data should be either of shape (num_variables(={len(self.variables)}),), "
+                f"or (num_variable_states(={num_variable_states}),). "
+                f"Got {flat_data.shape}"
+            )
+
+        start = 0
+        data = {}
+        for var_idx, var_name in enumerate(self.variable_names):
+            if use_num_states:
+                # var_idx comes from enumerate: no O(n) tuple search per variable
+                var_num_states = self.num_states[var_idx]
+                data[var_name] = flat_data[start : start + var_num_states]
+                start += var_num_states
+            else:
+                data[var_name] = flat_data[np.array([start])]
+                start += 1
+
+        return data
diff --git a/pgmax/vgroup/vgroup.py b/pgmax/vgroup/vgroup.py
new file mode 100644
index 00000000..80471980
--- /dev/null
+++ b/pgmax/vgroup/vgroup.py
@@ -0,0 +1,107 @@
+"""A module containing the base class for variable groups in a Factor Graph."""
+
+from dataclasses import dataclass
+from functools import total_ordering
+from typing import Any, List, Tuple, Union
+
+import jax.numpy as jnp
+import numpy as np
+
+from pgmax.utils import cached_property
+
+MAX_SIZE = 1e9
+
+
+@total_ordering
+@dataclass(frozen=True, eq=False)
+class VarGroup:
+    """Base class representing a group of variables; hashed, compared and ordered by a per-instance hash.
+    Each variable is represented via a tuple of the form (variable hash, variable num_states)
+
+    Arguments:
+        num_states: An integer or an array specifying the number of states of the variables
+            in this VarGroup
+    """
+
+    num_states: Union[int, np.ndarray]
+
+    def __post_init__(self):
+        # Compute the hash once at construction; the assert below checks it fits in a signed 64-bit int
+        this_id = id(self) % 2**32  # fold the object id into 32 bits
+        _hash = this_id * int(MAX_SIZE)  # NOTE(review): presumably leaves room for per-variable offsets - confirm
+        assert _hash < 2**63
+        object.__setattr__(self, "_hash", _hash)  # plain assignment is blocked by frozen=True
+
+    def __hash__(self):
+        return self._hash  # value cached by __post_init__
+
+    def __eq__(self, other):
+        return hash(self) == hash(other)  # equality is decided by hash only, not by field values
+
+    def __lt__(self, other):
+        return hash(self) < hash(other)  # ordering by hash; total_ordering derives the other comparisons
+
+    def __getitem__(self, val: Any) -> Union[Tuple[int, int], List[Tuple[int, int]]]:
+        """Retrieve the variable(s) associated with a variable name, index, or group of variable indices.
+        Each variable is a (variable hash, variable num_states) tuple. The base class raises NotImplementedError.
+
+        Args:
+            val: a variable index, slice, or name
+
+        Returns:
+            A single variable or a list of variables
+        """
+        raise NotImplementedError(
+            "Please subclass the VarGroup class and override this method"
+        )
+
+    @cached_property
+    def variable_hashes(self) -> np.ndarray:
+        """Hash of each variable in the group; the base class raises NotImplementedError.
+
+        Returns:
+            Array of variable hashes.
+        """
+        raise NotImplementedError(
+            "Please subclass the VarGroup class and override this method"
+        )
+
+    @cached_property
+    def variables(self) -> List[Tuple[int, int]]:
+        """Return the list of all variables in the VarGroup.
+        Each variable is a (variable hash, variable num_states) tuple, paired from the two flattened arrays.
+
+        Returns:
+            List of variables in the VarGroup
+        """
+        assert isinstance(self.variable_hashes, np.ndarray)
+        assert isinstance(self.num_states, np.ndarray)  # an int num_states is not accepted here
+        vars_hashes = self.variable_hashes.flatten()
+        vars_num_states = self.num_states.flatten()
+        return list(zip(vars_hashes, vars_num_states))
+
+    def flatten(self, data: Any) -> jnp.ndarray:
+        """Turn meaningful structured data into a flat data array for internal use (not implemented here).
+
+        Args:
+            data: Meaningful structured data
+
+        Returns:
+            A flat jnp.array for internal use
+        """
+        raise NotImplementedError(
+            "Please subclass the VarGroup class and override this method"
+        )
+
+    def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any:
+        """Recover meaningful structured data from an internal flat data array (not implemented here).
+
+        Args:
+            flat_data: Internal flat data array.
+
+        Returns:
+            Meaningful structured data
+        """
+        raise NotImplementedError(
+            "Please subclass the VarGroup class and override this method"
+        )
diff --git a/tests/factors/test_and.py b/tests/factor/test_and.py
similarity index 77%
rename from tests/factors/test_and.py
rename to tests/factor/test_and.py
index 4f2f7612..f36b5989 100644
--- a/tests/factors/test_and.py
+++ b/tests/factor/test_and.py
@@ -3,10 +3,7 @@
 import jax
 import numpy as np
 
-from pgmax.factors.enumeration import EnumerationFactor
-from pgmax.fg import graph
-from pgmax.groups import logical
-from pgmax.groups import variables as vgroup
+from pgmax import factor, fgraph, fgroup, infer, vgroup
 
 
 def test_run_bp_with_ANDFactors():
@@ -15,15 +12,15 @@ def test_run_bp_with_ANDFactors():
     (1) the support of ANDFactors in a FactorGraph and their specialized inference for different temperatures
     (2) the support of several factor types in a FactorGraph and during inference
 
-    To do so, observe that an ANDFactor can be defined as an equivalent EnumerationFactor
+    To do so, observe that an ANDFactor can be defined as an equivalent EnumFactor
     (which list all the valid AND configurations) and define two equivalent FactorGraphs
-    FG1: first half of factors are defined as EnumerationFactors, second half are defined as ANDFactors
-    FG2: first half of factors are defined as ANDFactors, second half are defined as EnumerationFactors
+    FG1: first half of factors are defined as EnumFactors, second half are defined as ANDFactors
+    FG2: first half of factors are defined as ANDFactors, second half are defined as EnumFactors
 
-    Inference for the EnumerationFactors will be run with pass_enum_fac_to_var_messages while
+    Inference for the EnumFactors will be run with pass_enum_fac_to_var_messages while
     inference for the ANDFactors will be run with pass_logical_fac_to_var_messages.
 
-    Note: for the first seed, add all the EnumerationFactors to FG1 and all the ANDFactors to FG2
+    Note: for the first seed, add all the EnumFactors to FG1 and all the ANDFactors to FG2
     """
     for idx in range(10):
         np.random.seed(idx)
@@ -41,24 +38,20 @@ def test_run_bp_with_ANDFactors():
             temperature = np.random.uniform(low=0.5, high=1.0)
 
         # Graph 1
-        parents_variables1 = vgroup.NDVariableArray(
-            num_states=2, shape=(num_parents.sum(),)
-        )
-        children_variables1 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,))
-        fg1 = graph.FactorGraph(
+        parents_variables1 = vgroup.NDVarArray(num_states=2, shape=(num_parents.sum(),))
+        children_variables1 = vgroup.NDVarArray(num_states=2, shape=(num_factors,))
+        fg1 = fgraph.FactorGraph(
             variable_groups=[parents_variables1, children_variables1]
         )
 
         # Graph 2
-        parents_variables2 = vgroup.NDVariableArray(
-            num_states=2, shape=(num_parents.sum(),)
-        )
-        children_variables2 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,))
-        fg2 = graph.FactorGraph(
+        parents_variables2 = vgroup.NDVarArray(num_states=2, shape=(num_parents.sum(),))
+        children_variables2 = vgroup.NDVarArray(num_states=2, shape=(num_factors,))
+        fg2 = fgraph.FactorGraph(
             variable_groups=[parents_variables2, children_variables2]
         )
 
-        # Option 1: Define EnumerationFactors equivalent to the ANDFactors
+        # Option 1: Define EnumFactors equivalent to the ANDFactors
         variables_for_factors1 = []
         variables_for_factors2 = []
         for factor_idx in range(num_factors):
@@ -80,7 +73,7 @@ def test_run_bp_with_ANDFactors():
             ] + [children_variables2[factor_idx]]
             variables_for_factors2.append(variables2)
 
-        # Option 1: Define EnumerationFactors equivalent to the ANDFactors
+        # Option 1: Define EnumFactors equivalent to the ANDFactors
         for factor_idx in range(num_factors):
             this_num_parents = num_parents[factor_idx]
 
@@ -99,7 +92,7 @@ def test_run_bp_with_ANDFactors():
 
             if factor_idx < num_factors // 2:
                 # Add the first half of factors to FactorGraph1
-                enum_factor = EnumerationFactor(
+                enum_factor = factor.EnumFactor(
                     variables=variables_for_factors1[factor_idx],
                     factor_configs=valid_configs,
                     log_potentials=np.zeros(valid_configs.shape[0]),
@@ -108,15 +101,15 @@ def test_run_bp_with_ANDFactors():
             else:
                 if idx != 0:
                     # Add the second half of factors to FactorGraph2
-                    enum_factor = EnumerationFactor(
+                    enum_factor = factor.EnumFactor(
                         variables=variables_for_factors2[factor_idx],
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
                     fg2.add_factors(enum_factor)
                 else:
-                    # Add all the EnumerationFactors to FactorGraph1 for the first iter
-                    enum_factor = EnumerationFactor(
+                    # Add all the EnumFactors to FactorGraph1 for the first iter
+                    enum_factor = factor.EnumFactor(
                         variables=variables_for_factors1[factor_idx],
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
@@ -144,15 +137,15 @@ def test_run_bp_with_ANDFactors():
                         variables_for_factors2[factor_idx]
                     )
         if idx != 0:
-            factor_group = logical.ANDFactorGroup(variables_for_ANDFactors_fg1)
+            factor_group = fgroup.ANDFactorGroup(variables_for_ANDFactors_fg1)
             fg1.add_factors(factor_group)
 
-        factor_group = logical.ANDFactorGroup(variables_for_ANDFactors_fg2)
+        factor_group = fgroup.ANDFactorGroup(variables_for_ANDFactors_fg2)
         fg2.add_factors(factor_group)
 
         # Run inference
-        bp1 = graph.BP(fg1.bp_state, temperature=temperature)
-        bp2 = graph.BP(fg2.bp_state, temperature=temperature)
+        bp1 = infer.BP(fg1.bp_state, temperature=temperature)
+        bp2 = infer.BP(fg2.bp_state, temperature=temperature)
 
         evidence_parents = jax.device_put(np.random.gumbel(size=(sum(num_parents), 2)))
         evidence_children = jax.device_put(np.random.gumbel(size=(num_factors, 2)))
diff --git a/tests/fg/test_nodes.py b/tests/factor/test_factor.py
similarity index 80%
rename from tests/fg/test_nodes.py
rename to tests/factor/test_factor.py
index 27445726..79f5141f 100644
--- a/tests/fg/test_nodes.py
+++ b/tests/factor/test_factor.py
@@ -3,38 +3,36 @@
 import numpy as np
 import pytest
 
-from pgmax.factors import enumeration, logical
-from pgmax.fg import nodes
-from pgmax.groups import variables as vgroup
+from pgmax import factor, vgroup
 
 
 def test_enumeration_factor():
-    variables = vgroup.NDVariableArray(num_states=3, shape=(1,))
+    variables = vgroup.NDVarArray(num_states=3, shape=(1,))
 
     with pytest.raises(
         NotImplementedError, match="Please implement compile_wiring in for your factor"
     ):
-        nodes.Factor(
+        factor.Factor(
             variables=[variables[0]],
             log_potentials=np.array([0.0]),
         )
 
     with pytest.raises(ValueError, match="Configurations should be integers. Got"):
-        enumeration.EnumerationFactor(
+        factor.EnumFactor(
             variables=[variables[0]],
             factor_configs=np.array([[1.0]]),
             log_potentials=np.array([0.0]),
         )
 
     with pytest.raises(ValueError, match="Potential should be floats. Got"):
-        enumeration.EnumerationFactor(
+        factor.EnumFactor(
             variables=[variables[0]],
             factor_configs=np.array([[1]]),
             log_potentials=np.array([0]),
         )
 
     with pytest.raises(ValueError, match="factor_configs should be a 2D array"):
-        enumeration.EnumerationFactor(
+        factor.EnumFactor(
             variables=[variables[0]],
             factor_configs=np.array([1]),
             log_potentials=np.array([0.0]),
@@ -46,7 +44,7 @@ def test_enumeration_factor():
             "Number of variables 1 doesn't match given configurations (1, 2)"
         ),
     ):
-        enumeration.EnumerationFactor(
+        factor.EnumFactor(
             variables=[variables[0]],
             factor_configs=np.array([[1, 2]]),
             log_potentials=np.array([0.0]),
@@ -55,14 +53,14 @@ def test_enumeration_factor():
     with pytest.raises(
         ValueError, match=re.escape("Expected log potentials of shape (1,)")
     ):
-        enumeration.EnumerationFactor(
+        factor.EnumFactor(
             variables=[variables[0]],
             factor_configs=np.array([[1]]),
             log_potentials=np.array([0.0, 1.0]),
         )
 
     with pytest.raises(ValueError, match="Invalid configurations for given variables"):
-        enumeration.EnumerationFactor(
+        factor.EnumFactor(
             variables=[variables[0]],
             factor_configs=np.array([[10]]),
             log_potentials=np.array([0.0]),
@@ -70,24 +68,24 @@ def test_enumeration_factor():
 
 
 def test_logical_factor():
-    child = vgroup.NDVariableArray(num_states=2, shape=(1,))[0]
-    wrong_parent = vgroup.NDVariableArray(num_states=3, shape=(1,))[0]
-    parent = vgroup.NDVariableArray(num_states=2, shape=(1,))[0]
+    child = vgroup.NDVarArray(num_states=2, shape=(1,))[0]
+    wrong_parent = vgroup.NDVarArray(num_states=3, shape=(1,))[0]
+    parent = vgroup.NDVarArray(num_states=2, shape=(1,))[0]
 
     with pytest.raises(
         ValueError,
         match="A LogicalFactor requires at least one parent variable and one child variable",
     ):
-        logical.LogicalFactor(
+        factor.logical.LogicalFactor(
             variables=(child,),
         )
 
     with pytest.raises(ValueError, match="All variables should all be binary"):
-        logical.LogicalFactor(
+        factor.logical.LogicalFactor(
             variables=(wrong_parent, child),
         )
 
-    logical_factor = logical.LogicalFactor(
+    logical_factor = factor.logical.LogicalFactor(
         variables=(parent, child),
     )
     num_parents = len(logical_factor.variables) - 1
@@ -100,7 +98,7 @@ def test_logical_factor():
     child_edge_state = np.array([2 * num_parents], dtype=int)
 
     with pytest.raises(ValueError, match="The highest LogicalFactor index must be 0"):
-        logical.LogicalWiring(
+        factor.logical.LogicalWiring(
             edges_num_states=[2, 2],
             var_states_for_edges=None,
             parents_edge_states=parents_edge_states + np.array([[1, 0]]),
@@ -112,7 +110,7 @@ def test_logical_factor():
         ValueError,
         match="The LogicalWiring must have 1 different LogicalFactor indices",
     ):
-        logical.LogicalWiring(
+        factor.logical.LogicalWiring(
             edges_num_states=[2, 2],
             var_states_for_edges=None,
             parents_edge_states=parents_edge_states + np.array([[0], [1]]),
@@ -126,7 +124,7 @@ def test_logical_factor():
             "The LogicalWiring's edge_states_offset must be 1 (for OR) and -1 (for AND), but is 0"
         ),
     ):
-        logical.LogicalWiring(
+        factor.logical.LogicalWiring(
             edges_num_states=[2, 2],
             var_states_for_edges=None,
             parents_edge_states=parents_edge_states,
diff --git a/tests/factors/test_or.py b/tests/factor/test_or.py
similarity index 78%
rename from tests/factors/test_or.py
rename to tests/factor/test_or.py
index c615a800..7e2995b8 100644
--- a/tests/factors/test_or.py
+++ b/tests/factor/test_or.py
@@ -3,10 +3,7 @@
 import jax
 import numpy as np
 
-from pgmax.factors.enumeration import EnumerationFactor
-from pgmax.fg import graph
-from pgmax.groups import logical
-from pgmax.groups import variables as vgroup
+from pgmax import factor, fgraph, fgroup, infer, vgroup
 
 
 def test_run_bp_with_ORFactors():
@@ -15,15 +12,15 @@ def test_run_bp_with_ORFactors():
     (1) the support of ORFactors in a FactorGraph and their specialized inference for different temperatures
     (2) the support of several factor types in a FactorGraph and during inference
 
-    To do so, observe that an ORFactor can be defined as an equivalent EnumerationFactor
+    To do so, observe that an ORFactor can be defined as an equivalent EnumFactor
     (which list all the valid OR configurations) and define two equivalent FactorGraphs
-    FG1: first half of factors are defined as EnumerationFactors, second half are defined as ORFactors
-    FG2: first half of factors are defined as ORFactors, second half are defined as EnumerationFactors
+    FG1: first half of factors are defined as EnumFactors, second half are defined as ORFactors
+    FG2: first half of factors are defined as ORFactors, second half are defined as EnumFactors
 
-    Inference for the EnumerationFactors will be run with pass_enum_fac_to_var_messages while
+    Inference for the EnumFactors will be run with pass_enum_fac_to_var_messages while
     inference for the ORFactors will be run with pass_logical_fac_to_var_messages.
 
-    Note: for the first seed, add all the EnumerationFactors to FG1 and all the ORFactors to FG2
+    Note: for the first seed, add all the EnumFactors to FG1 and all the ORFactors to FG2
     """
     for idx in range(10):
         np.random.seed(idx)
@@ -41,20 +38,16 @@ def test_run_bp_with_ORFactors():
             temperature = np.random.uniform(low=0.5, high=1.0)
 
         # Graph 1
-        parents_variables1 = vgroup.NDVariableArray(
-            num_states=2, shape=(num_parents.sum(),)
-        )
-        children_variables1 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,))
-        fg1 = graph.FactorGraph(
+        parents_variables1 = vgroup.NDVarArray(num_states=2, shape=(num_parents.sum(),))
+        children_variables1 = vgroup.NDVarArray(num_states=2, shape=(num_factors,))
+        fg1 = fgraph.FactorGraph(
             variable_groups=[parents_variables1, children_variables1]
         )
 
         # Graph 2
-        parents_variables2 = vgroup.NDVariableArray(
-            num_states=2, shape=(num_parents.sum(),)
-        )
-        children_variables2 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,))
-        fg2 = graph.FactorGraph(
+        parents_variables2 = vgroup.NDVarArray(num_states=2, shape=(num_parents.sum(),))
+        children_variables2 = vgroup.NDVarArray(num_states=2, shape=(num_factors,))
+        fg2 = fgraph.FactorGraph(
             variable_groups=[parents_variables2, children_variables2]
         )
 
@@ -80,7 +73,7 @@ def test_run_bp_with_ORFactors():
             ] + [children_variables2[factor_idx]]
             variables_for_factors2.append(variables2)
 
-        # Option 1: Define EnumerationFactors equivalent to the ORFactors
+        # Option 1: Define EnumFactors equivalent to the ORFactors
         for factor_idx in range(num_factors):
             this_num_parents = num_parents[factor_idx]
 
@@ -97,7 +90,7 @@ def test_run_bp_with_ORFactors():
 
             if factor_idx < num_factors // 2:
                 # Add the first half of factors to FactorGraph1
-                enum_factor = EnumerationFactor(
+                enum_factor = factor.EnumFactor(
                     variables=variables_for_factors1[factor_idx],
                     factor_configs=valid_configs,
                     log_potentials=np.zeros(valid_configs.shape[0]),
@@ -106,15 +99,15 @@ def test_run_bp_with_ORFactors():
             else:
                 if idx != 0:
                     # Add the second half of factors to FactorGraph2
-                    enum_factor = EnumerationFactor(
+                    enum_factor = factor.EnumFactor(
                         variables=variables_for_factors2[factor_idx],
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
                     )
                     fg2.add_factors(enum_factor)
                 else:
-                    # Add all the EnumerationFactors to FactorGraph1 for the first iter
-                    enum_factor = EnumerationFactor(
+                    # Add all the EnumFactors to FactorGraph1 for the first iter
+                    enum_factor = factor.EnumFactor(
                         variables=variables_for_factors1[factor_idx],
                         factor_configs=valid_configs,
                         log_potentials=np.zeros(valid_configs.shape[0]),
@@ -142,15 +135,15 @@ def test_run_bp_with_ORFactors():
                         variables_for_factors2[factor_idx]
                     )
         if idx != 0:
-            factor_group = logical.ORFactorGroup(variables_for_ORFactors_fg1)
+            factor_group = fgroup.ORFactorGroup(variables_for_ORFactors_fg1)
             fg1.add_factors(factor_group)
 
-        factor_group = logical.ORFactorGroup(variables_for_ORFactors_fg2)
+        factor_group = fgroup.ORFactorGroup(variables_for_ORFactors_fg2)
         fg2.add_factors(factor_group)
 
         # Run inference
-        bp1 = graph.BP(fg1.bp_state, temperature=temperature)
-        bp2 = graph.BP(fg2.bp_state, temperature=temperature)
+        bp1 = infer.BP(fg1.bp_state, temperature=temperature)
+        bp2 = infer.BP(fg2.bp_state, temperature=temperature)
 
         evidence_parents = jax.device_put(np.random.gumbel(size=(sum(num_parents), 2)))
         evidence_children = jax.device_put(np.random.gumbel(size=(num_factors, 2)))
diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py
deleted file mode 100644
index e39fbda3..00000000
--- a/tests/fg/test_groups.py
+++ /dev/null
@@ -1,304 +0,0 @@
-import re
-
-import jax
-import jax.numpy as jnp
-import numpy as np
-import pytest
-
-from pgmax.fg import groups
-from pgmax.groups import enumeration, logical
-from pgmax.groups import variables as vgroup
-
-
-def test_variable_dict():
-    num_states = np.full((4,), fill_value=2)
-    with pytest.raises(
-        ValueError, match=re.escape("Expected num_states shape (3,). Got (4,).")
-    ):
-        vgroup.VariableDict(variable_names=tuple([0, 1, 2]), num_states=num_states)
-
-    num_states = np.full((3,), fill_value=2, dtype=np.float32)
-    with pytest.raises(
-        ValueError,
-        match=re.escape(
-            "num_states should be an integer or a NumPy array of dtype int"
-        ),
-    ):
-        vgroup.VariableDict(variable_names=tuple([0, 1, 2]), num_states=num_states)
-
-    variable_dict = vgroup.VariableDict(variable_names=tuple([0, 1, 2]), num_states=15)
-    with pytest.raises(
-        ValueError, match="data is referring to a non-existent variable 3"
-    ):
-        variable_dict.flatten({3: np.zeros(10)})
-
-    with pytest.raises(
-        ValueError,
-        match=re.escape(
-            "Variable 2 expects a data array of shape (15,) or (1,). Got (10,)."
-        ),
-    ):
-        variable_dict.flatten({2: np.zeros(10)})
-
-    with pytest.raises(
-        ValueError, match="Can only unflatten 1D array. Got a 2D array."
-    ):
-        variable_dict.unflatten(jnp.zeros((10, 20)))
-
-    assert jnp.all(
-        jnp.array(
-            jax.tree_util.tree_leaves(
-                jax.tree_util.tree_multimap(
-                    lambda x, y: jnp.all(x == y),
-                    variable_dict.unflatten(jnp.zeros(3)),
-                    {name: np.zeros(1) for name in range(3)},
-                )
-            )
-        )
-    )
-    with pytest.raises(
-        ValueError,
-        match=re.escape(
-            "flat_data should be either of shape (num_variables(=3),), or (num_variable_states(=45),)"
-        ),
-    ):
-        variable_dict.unflatten(jnp.zeros((100)))
-
-
-def test_nd_variable_array():
-    max_size = int(groups.MAX_SIZE)
-    with pytest.raises(
-        ValueError,
-        match=re.escape(
-            f"Currently only support NDVariableArray of size smaller than {max_size}. Got {max_size + 1}"
-        ),
-    ):
-        vgroup.NDVariableArray(shape=(max_size + 1,), num_states=2)
-
-    num_states = np.full((2, 3), fill_value=2)
-    with pytest.raises(
-        ValueError, match=re.escape("Expected num_states shape (2, 2). Got (2, 3).")
-    ):
-        vgroup.NDVariableArray(shape=(2, 2), num_states=num_states)
-
-    num_states = np.full((2, 3), fill_value=2, dtype=np.float32)
-    with pytest.raises(
-        ValueError,
-        match=re.escape(
-            "num_states should be an integer or a NumPy array of dtype int"
-        ),
-    ):
-        vgroup.NDVariableArray(shape=(2, 2), num_states=num_states)
-
-    variable_group0 = vgroup.NDVariableArray(shape=(5, 5), num_states=2)
-    assert len(variable_group0[:3, :3]) == 9
-
-    variable_group = vgroup.NDVariableArray(
-        shape=(2, 2), num_states=np.array([[1, 2], [3, 4]])
-    )
-    variable_group0 < variable_group
-
-    with pytest.raises(
-        ValueError,
-        match=re.escape("data should be of shape (2, 2) or (2, 2, 4). Got (3, 3)."),
-    ):
-        variable_group.flatten(np.zeros((3, 3)))
-
-    assert jnp.all(
-        variable_group.flatten(np.array([[1, 2], [3, 4]])) == jnp.array([1, 2, 3, 4])
-    )
-    assert jnp.all(variable_group.flatten(np.zeros((2, 2, 4))) == jnp.zeros((10,)))
-
-    with pytest.raises(
-        ValueError, match="Can only unflatten 1D array. Got a 2D array."
-    ):
-        variable_group.unflatten(np.zeros((10, 20)))
-
-    with pytest.raises(
-        ValueError,
-        match=re.escape(
-            "flat_data should be compatible with shape (2, 2) or (2, 2, 4). Got (12,)."
-        ),
-    ):
-        variable_group.unflatten(np.zeros((12,)))
-
-    assert jnp.all(variable_group.unflatten(np.zeros(4)) == jnp.zeros((2, 2)))
-    unflattened = jnp.full((2, 2, 4), fill_value=jnp.nan)
-    unflattened = unflattened.at[0, 0, 0].set(0)
-    unflattened = unflattened.at[0, 1, :1].set(0)
-    unflattened = unflattened.at[1, 0, :2].set(0)
-    unflattened = unflattened.at[1, 1].set(0)
-    mask = ~jnp.isnan(unflattened)
-    assert jnp.all(variable_group.unflatten(np.zeros(10))[mask] == unflattened[mask])
-
-
-def test_single_factor():
-    with pytest.raises(ValueError, match="Cannot create a FactorGroup with no Factor."):
-        logical.ORFactorGroup(variables_for_factors=[])
-
-    A = vgroup.NDVariableArray(num_states=2, shape=(10,))
-    B = vgroup.NDVariableArray(num_states=2, shape=(10,))
-
-    variables0 = (A[0], B[0])
-    variables1 = (A[1], B[1])
-    ORFactor0 = logical.ORFactorGroup(variables_for_factors=[variables0])
-    with pytest.raises(
-        ValueError, match="SingleFactorGroup should only contain one factor. Got 2"
-    ):
-        groups.SingleFactorGroup(
-            variables_for_factors=[variables0, variables1],
-            factor=ORFactor0,
-        )
-    ORFactor1 = logical.ORFactorGroup(variables_for_factors=[variables1])
-    ORFactor0 < ORFactor1
-
-
-def test_enumeration_factor_group():
-    vg = vgroup.NDVariableArray(shape=(2, 2), num_states=3)
-    with pytest.raises(
-        ValueError,
-        match=re.escape("Expected log potentials shape: (1,) or (2, 1). Got (3, 2)"),
-    ):
-        enumeration_factor_group = enumeration.EnumerationFactorGroup(
-            variables_for_factors=[
-                [vg[0, 0], vg[0, 1], vg[1, 1]],
-                [vg[0, 1], vg[1, 0], vg[1, 1]],
-            ],
-            factor_configs=np.zeros((1, 3), dtype=int),
-            log_potentials=np.zeros((3, 2)),
-        )
-
-    with pytest.raises(ValueError, match=re.escape("Potentials should be floats")):
-        enumeration_factor_group = enumeration.EnumerationFactorGroup(
-            variables_for_factors=[
-                [vg[0, 0], vg[0, 1], vg[1, 1]],
-                [vg[0, 1], vg[1, 0], vg[1, 1]],
-            ],
-            factor_configs=np.zeros((1, 3), dtype=int),
-            log_potentials=np.zeros((2, 1), dtype=int),
-        )
-
-    enumeration_factor_group = enumeration.EnumerationFactorGroup(
-        variables_for_factors=[
-            [vg[0, 0], vg[0, 1], vg[1, 1]],
-            [vg[0, 1], vg[1, 0], vg[1, 1]],
-        ],
-        factor_configs=np.zeros((1, 3), dtype=int),
-    )
-    name = [vg[0, 0], vg[1, 1]]
-    with pytest.raises(
-        ValueError,
-        match=re.escape(
-            f"The queried factor connected to the set of variables {frozenset(name)} is not present in the factor group."
-        ),
-    ):
-        enumeration_factor_group[name]
-
-    assert (
-        enumeration_factor_group[[vg[0, 1], vg[1, 0], vg[1, 1]]]
-        == enumeration_factor_group.factors[1]
-    )
-    with pytest.raises(
-        ValueError,
-        match=re.escape(
-            "data should be of shape (2, 1) or (2, 9) or (1,). Got (4, 5)."
-        ),
-    ):
-        enumeration_factor_group.flatten(np.zeros((4, 5)))
-
-    assert jnp.all(enumeration_factor_group.flatten(np.ones(1)) == jnp.ones(2))
-    assert jnp.all(enumeration_factor_group.flatten(np.ones((2, 9))) == jnp.ones(18))
-    with pytest.raises(
-        ValueError, match=re.escape("Can only unflatten 1D array. Got a 3D array.")
-    ):
-        enumeration_factor_group.unflatten(jnp.ones((1, 2, 3)))
-
-    with pytest.raises(
-        ValueError,
-        match=re.escape(
-            "flat_data should be compatible with shape (2, 1) or (2, 9). Got (30,)"
-        ),
-    ):
-        enumeration_factor_group.unflatten(jnp.zeros(30))
-
-    assert jnp.all(
-        enumeration_factor_group.unflatten(jnp.arange(2)) == jnp.array([[0], [1]])
-    )
-    assert jnp.all(enumeration_factor_group.unflatten(jnp.ones(18)) == jnp.ones((2, 9)))
-
-
-def test_pairwise_factor_group():
-    vg = vgroup.NDVariableArray(shape=(2, 2), num_states=3)
-
-    with pytest.raises(
-        ValueError, match=re.escape("log_potential_matrix should be either a 2D array")
-    ):
-        enumeration.PairwiseFactorGroup(
-            [[vg[0, 0], vg[1, 1]]], np.zeros((1,), dtype=float)
-        )
-
-    with pytest.raises(
-        ValueError, match=re.escape("Potential matrix should be floats")
-    ):
-        enumeration.PairwiseFactorGroup(
-            [[vg[0, 0], vg[1, 1]]], np.zeros((3, 3), dtype=int)
-        )
-
-    with pytest.raises(
-        ValueError,
-        match=re.escape(
-            "Expected log_potential_matrix for 1 factors. Got log_potential_matrix for 2 factors."
-        ),
-    ):
-        enumeration.PairwiseFactorGroup(
-            [[vg[0, 0], vg[1, 1]]], np.zeros((2, 3, 3), dtype=float)
-        )
-
-    with pytest.raises(
-        ValueError,
-        match=re.escape(
-            "All pairwise factors should connect to exactly 2 variables. Got a factor connecting to 3 variables"
-        ),
-    ):
-        enumeration.PairwiseFactorGroup(
-            [[vg[0, 0], vg[1, 1], vg[0, 1]]], np.zeros((3, 3), dtype=float)
-        )
-
-    name = [vg[0, 0], vg[1, 1]]
-    with pytest.raises(
-        ValueError,
-        match=re.escape(f"The specified pairwise factor {name}"),
-    ):
-        enumeration.PairwiseFactorGroup([name], np.zeros((4, 4), dtype=float))
-
-    pairwise_factor_group = enumeration.PairwiseFactorGroup(
-        [[vg[0, 0], vg[1, 1]], [vg[1, 0], vg[0, 1]]],
-    )
-    with pytest.raises(
-        ValueError,
-        match=re.escape(
-            "data should be of shape (2, 3, 3) or (2, 6) or (3, 3). Got (4, 4)."
-        ),
-    ):
-        pairwise_factor_group.flatten(np.zeros((4, 4)))
-
-    assert jnp.all(
-        pairwise_factor_group.flatten(np.zeros((3, 3))) == jnp.zeros(2 * 3 * 3)
-    )
-    assert jnp.all(pairwise_factor_group.flatten(np.zeros((2, 6))) == jnp.zeros(12))
-    with pytest.raises(ValueError, match="Can only unflatten 1D array. Got a 2D array"):
-        pairwise_factor_group.unflatten(np.zeros((10, 20)))
-
-    assert jnp.all(
-        pairwise_factor_group.unflatten(np.zeros(2 * 3 * 3)) == jnp.zeros((2, 3, 3))
-    )
-    assert jnp.all(
-        pairwise_factor_group.unflatten(np.zeros(2 * 6)) == jnp.zeros((2, 6))
-    )
-    with pytest.raises(
-        ValueError,
-        match=re.escape(
-            "flat_data should be compatible with shape (2, 3, 3) or (2, 6). Got (10,)."
-        ),
-    ):
-        pairwise_factor_group.unflatten(np.zeros(10))
diff --git a/tests/fg/test_graph.py b/tests/fgraph/test_fgraph.py
similarity index 65%
rename from tests/fg/test_graph.py
rename to tests/fgraph/test_fgraph.py
index 8b0d72b0..0ad7dbe4 100644
--- a/tests/fg/test_graph.py
+++ b/tests/fgraph/test_fgraph.py
@@ -6,24 +6,21 @@
 import numpy as np
 import pytest
 
-from pgmax.factors import enumeration as enumeration_factor
-from pgmax.fg import graph
-from pgmax.groups import enumeration
-from pgmax.groups import variables as vgroup
+from pgmax import factor, fgraph, fgroup, infer, vgroup
 
 
 def test_factor_graph():
-    vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
-    fg = graph.FactorGraph(vg)
+    vg = vgroup.VarDict(variable_names=(0,), num_states=15)
+    fg = fgraph.FactorGraph(vg)
 
-    factor = enumeration_factor.EnumerationFactor(
+    enum_factor = factor.EnumFactor(
         variables=[vg[0]],
         factor_configs=np.arange(15)[:, None],
         log_potentials=np.zeros(15),
     )
-    fg.add_factors(factor)
+    fg.add_factors(enum_factor)
 
-    factor_group = enumeration.EnumerationFactorGroup(
+    factor_group = fgroup.EnumFactorGroup(
         variables_for_factors=[[vg[0]]],
         factor_configs=np.arange(15)[:, None],
         log_potentials=np.zeros(15),
@@ -31,30 +28,30 @@ def test_factor_graph():
     with pytest.raises(
         ValueError,
         match=re.escape(
-            f"A Factor of type {enumeration_factor.EnumerationFactor} involving variables {frozenset([(vg.__hash__(), 15)])} already exists."
+            f"A Factor of type {factor.EnumFactor} involving variables {frozenset([(vg.__hash__(), 15)])} already exists."
         ),
     ):
         fg.add_factors(factor_group)
 
 
 def test_bp_state():
-    vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
-    fg0 = graph.FactorGraph(vg)
-    factor = enumeration_factor.EnumerationFactor(
+    vg = vgroup.VarDict(variable_names=(0,), num_states=15)
+    fg0 = fgraph.FactorGraph(vg)
+    enum_factor = factor.EnumFactor(
         variables=[vg[0]],
         factor_configs=np.arange(15)[:, None],
         log_potentials=np.zeros(15),
     )
-    fg0.add_factors(factor)
+    fg0.add_factors(enum_factor)
 
-    fg1 = graph.FactorGraph(vg)
-    fg1.add_factors(factor)
+    fg1 = fgraph.FactorGraph(vg)
+    fg1.add_factors(enum_factor)
 
     with pytest.raises(
         ValueError,
         match="log_potentials, ftov_msgs and evidence should be derived from the same fg_state",
     ):
-        graph.BPState(
+        infer.BPState(
             log_potentials=fg0.bp_state.log_potentials,
             ftov_msgs=fg1.bp_state.ftov_msgs,
             evidence=fg1.bp_state.evidence,
@@ -62,9 +59,9 @@ def test_bp_state():
 
 
 def test_log_potentials():
-    vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
-    fg = graph.FactorGraph(vg)
-    factor_group = enumeration.EnumerationFactorGroup(
+    vg = vgroup.VarDict(variable_names=(0,), num_states=15)
+    fg = fgraph.FactorGraph(vg)
+    factor_group = fgroup.EnumFactorGroup(
         variables_for_factors=[[vg[0]]],
         factor_configs=np.arange(10)[:, None],
     )
@@ -80,7 +77,7 @@ def test_log_potentials():
         ValueError,
         match=re.escape("Invalid FactorGroup for log potentials updates."),
     ):
-        factor_group2 = enumeration.EnumerationFactorGroup(
+        factor_group2 = fgroup.EnumFactorGroup(
             variables_for_factors=[[vg[0]]],
             factor_configs=np.arange(10)[:, None],
         )
@@ -95,16 +92,16 @@ def test_log_potentials():
     with pytest.raises(
         ValueError, match=re.escape("Expected log potentials shape (10,). Got (15,)")
     ):
-        graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15))
+        infer.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15))
 
-    log_potentials = graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10))
+    log_potentials = infer.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10))
     assert jnp.all(log_potentials[factor_group] == jnp.zeros(10))
 
 
 def test_ftov_msgs():
-    vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
-    fg = graph.FactorGraph(vg)
-    factor_group = enumeration.EnumerationFactorGroup(
+    vg = vgroup.VarDict(variable_names=(0,), num_states=15)
+    fg = fgraph.FactorGraph(vg)
+    factor_group = fgroup.EnumFactorGroup(
         variables_for_factors=[[vg[0]]],
         factor_configs=np.arange(10)[:, None],
     )
@@ -127,9 +124,9 @@ def test_ftov_msgs():
     with pytest.raises(
         ValueError, match=re.escape("Expected messages shape (15,). Got (10,)")
     ):
-        graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(10))
+        infer.FToVMessages(fg_state=fg.fg_state, value=np.zeros(10))
 
-    ftov_msgs = graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(15))
+    ftov_msgs = infer.FToVMessages(fg_state=fg.fg_state, value=np.zeros(15))
     with pytest.raises(
         TypeError, match=re.escape("'FToVMessages' object is not subscriptable")
     ):
@@ -137,9 +134,9 @@ def test_ftov_msgs():
 
 
 def test_evidence():
-    vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
-    fg = graph.FactorGraph(vg)
-    factor_group = enumeration.EnumerationFactorGroup(
+    vg = vgroup.VarDict(variable_names=(0,), num_states=15)
+    fg = fgraph.FactorGraph(vg)
+    factor_group = fgroup.EnumFactorGroup(
         variables_for_factors=[[vg[0]]],
         factor_configs=np.arange(10)[:, None],
     )
@@ -148,19 +145,19 @@ def test_evidence():
     with pytest.raises(
         ValueError, match=re.escape("Expected evidence shape (15,). Got (10,).")
     ):
-        graph.Evidence(fg_state=fg.fg_state, value=np.zeros(10))
+        infer.Evidence(fg_state=fg.fg_state, value=np.zeros(10))
 
-    evidence = graph.Evidence(fg_state=fg.fg_state, value=np.zeros(15))
+    evidence = infer.Evidence(fg_state=fg.fg_state, value=np.zeros(15))
     assert jnp.all(evidence.value == jnp.zeros(15))
 
-    vg2 = vgroup.VariableDict(variable_names=(0,), num_states=15)
+    vg2 = vgroup.VarDict(variable_names=(0,), num_states=15)
     with pytest.raises(
         ValueError,
         match=re.escape(
-            "Got evidence for a variable or a VariableGroup not in the FactorGraph!"
+            "Got evidence for a variable or a VarGroup not in the FactorGraph!"
         ),
     ):
-        graph.update_evidence(
+        infer.bp_state.update_evidence(
             jax.device_put(evidence.value),
             {vg2[0]: jax.device_put(np.zeros(15))},
             fg.fg_state,
@@ -168,15 +165,15 @@ def test_evidence():
 
 
 def test_bp():
-    vg = vgroup.VariableDict(variable_names=(0,), num_states=15)
-    fg = graph.FactorGraph(vg)
-    factor_group = enumeration.EnumerationFactorGroup(
+    vg = vgroup.VarDict(variable_names=(0,), num_states=15)
+    fg = fgraph.FactorGraph(vg)
+    factor_group = fgroup.EnumFactorGroup(
         variables_for_factors=[[vg[0]]],
         factor_configs=np.arange(10)[:, None],
     )
     fg.add_factors(factor_group)
 
-    bp = graph.BP(fg.bp_state, temperature=0)
+    bp = infer.BP(fg.bp_state, temperature=0)
     bp_arrays = bp.update()
     bp_arrays = bp.update(
         bp_arrays=bp_arrays,
diff --git a/tests/fgroup/test_fgroup.py b/tests/fgroup/test_fgroup.py
new file mode 100644
index 00000000..39609bb7
--- /dev/null
+++ b/tests/fgroup/test_fgroup.py
@@ -0,0 +1,175 @@
+import re
+
+import jax.numpy as jnp
+import numpy as np
+import pytest
+
+from pgmax import fgroup, vgroup
+
+
+def test_single_factor():  # Error paths of ORFactorGroup / SingleFactorGroup construction, plus `<` on groups.
+    with pytest.raises(ValueError, match="Cannot create a FactorGroup with no Factor."):
+        fgroup.ORFactorGroup(variables_for_factors=[])
+
+    A = vgroup.NDVarArray(num_states=2, shape=(10,))
+    B = vgroup.NDVarArray(num_states=2, shape=(10,))
+
+    variables0 = (A[0], B[0])
+    variables1 = (A[1], B[1])
+    ORFactor0 = fgroup.ORFactorGroup(variables_for_factors=[variables0])
+    with pytest.raises(
+        ValueError, match="SingleFactorGroup should only contain one factor. Got 2"
+    ):
+        fgroup.SingleFactorGroup(
+            variables_for_factors=[variables0, variables1],
+            factor=ORFactor0,
+        )
+    ORFactor1 = fgroup.ORFactorGroup(variables_for_factors=[variables1])
+    ORFactor0 < ORFactor1  # Result is discarded; this only checks that comparing two groups does not raise.
+
+
+def test_enumeration_factor_group():
+    vg = vgroup.NDVarArray(shape=(2, 2), num_states=3)
+    with pytest.raises(
+        ValueError,
+        match=re.escape("Expected log potentials shape: (1,) or (2, 1). Got (3, 2)"),
+    ):
+        enumeration_factor_group = fgroup.EnumFactorGroup(
+            variables_for_factors=[
+                [vg[0, 0], vg[0, 1], vg[1, 1]],
+                [vg[0, 1], vg[1, 0], vg[1, 1]],
+            ],
+            factor_configs=np.zeros((1, 3), dtype=int),
+            log_potentials=np.zeros((3, 2)),
+        )
+
+    with pytest.raises(ValueError, match=re.escape("Potentials should be floats")):
+        enumeration_factor_group = fgroup.EnumFactorGroup(
+            variables_for_factors=[
+                [vg[0, 0], vg[0, 1], vg[1, 1]],
+                [vg[0, 1], vg[1, 0], vg[1, 1]],
+            ],
+            factor_configs=np.zeros((1, 3), dtype=int),
+            log_potentials=np.zeros((2, 1), dtype=int),
+        )
+
+    enumeration_factor_group = fgroup.EnumFactorGroup(
+        variables_for_factors=[
+            [vg[0, 0], vg[0, 1], vg[1, 1]],
+            [vg[0, 1], vg[1, 0], vg[1, 1]],
+        ],
+        factor_configs=np.zeros((1, 3), dtype=int),
+    )
+    name = [vg[0, 0], vg[1, 1]]
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            f"The queried factor connected to the set of variables {frozenset(name)} is not present in the factor group."
+        ),
+    ):
+        enumeration_factor_group[name]
+
+    assert (
+        enumeration_factor_group[[vg[0, 1], vg[1, 0], vg[1, 1]]]
+        == enumeration_factor_group.factors[1]
+    )
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "data should be of shape (2, 1) or (2, 9) or (1,). Got (4, 5)."
+        ),
+    ):
+        enumeration_factor_group.flatten(np.zeros((4, 5)))
+
+    assert jnp.all(enumeration_factor_group.flatten(np.ones(1)) == jnp.ones(2))
+    assert jnp.all(enumeration_factor_group.flatten(np.ones((2, 9))) == jnp.ones(18))
+    with pytest.raises(
+        ValueError, match=re.escape("Can only unflatten 1D array. Got a 3D array.")
+    ):
+        enumeration_factor_group.unflatten(jnp.ones((1, 2, 3)))
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "flat_data should be compatible with shape (2, 1) or (2, 9). Got (30,)"
+        ),
+    ):
+        enumeration_factor_group.unflatten(jnp.zeros(30))
+
+    assert jnp.all(
+        enumeration_factor_group.unflatten(jnp.arange(2)) == jnp.array([[0], [1]])
+    )
+    assert jnp.all(enumeration_factor_group.unflatten(jnp.ones(18)) == jnp.ones((2, 9)))
+
+
+def test_pairwise_factor_group():
+    vg = vgroup.NDVarArray(shape=(2, 2), num_states=3)
+
+    with pytest.raises(
+        ValueError, match=re.escape("log_potential_matrix should be either a 2D array")
+    ):
+        fgroup.PairwiseFactorGroup([[vg[0, 0], vg[1, 1]]], np.zeros((1,), dtype=float))
+
+    with pytest.raises(
+        ValueError, match=re.escape("Potential matrix should be floats")
+    ):
+        fgroup.PairwiseFactorGroup([[vg[0, 0], vg[1, 1]]], np.zeros((3, 3), dtype=int))
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Expected log_potential_matrix for 1 factors. Got log_potential_matrix for 2 factors."
+        ),
+    ):
+        fgroup.PairwiseFactorGroup(
+            [[vg[0, 0], vg[1, 1]]], np.zeros((2, 3, 3), dtype=float)
+        )
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "All pairwise factors should connect to exactly 2 variables. Got a factor connecting to 3 variables"
+        ),
+    ):
+        fgroup.PairwiseFactorGroup(
+            [[vg[0, 0], vg[1, 1], vg[0, 1]]], np.zeros((3, 3), dtype=float)
+        )
+
+    name = [vg[0, 0], vg[1, 1]]
+    with pytest.raises(
+        ValueError,
+        match=re.escape(f"The specified pairwise factor {name}"),
+    ):
+        fgroup.PairwiseFactorGroup([name], np.zeros((4, 4), dtype=float))
+
+    pairwise_factor_group = fgroup.PairwiseFactorGroup(
+        [[vg[0, 0], vg[1, 1]], [vg[1, 0], vg[0, 1]]],
+    )
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "data should be of shape (2, 3, 3) or (2, 6) or (3, 3). Got (4, 4)."
+        ),
+    ):
+        pairwise_factor_group.flatten(np.zeros((4, 4)))
+
+    assert jnp.all(
+        pairwise_factor_group.flatten(np.zeros((3, 3))) == jnp.zeros(2 * 3 * 3)
+    )
+    assert jnp.all(pairwise_factor_group.flatten(np.zeros((2, 6))) == jnp.zeros(12))
+    with pytest.raises(ValueError, match="Can only unflatten 1D array. Got a 2D array"):
+        pairwise_factor_group.unflatten(np.zeros((10, 20)))
+
+    assert jnp.all(
+        pairwise_factor_group.unflatten(np.zeros(2 * 3 * 3)) == jnp.zeros((2, 3, 3))
+    )
+    assert jnp.all(
+        pairwise_factor_group.unflatten(np.zeros(2 * 6)) == jnp.zeros((2, 6))
+    )
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "flat_data should be compatible with shape (2, 3, 3) or (2, 6). Got (10,)."
+        ),
+    ):
+        pairwise_factor_group.unflatten(np.zeros(10))
diff --git a/tests/fg/test_wiring.py b/tests/fgroup/test_wiring.py
similarity index 61%
rename from tests/fg/test_wiring.py
rename to tests/fgroup/test_wiring.py
index af2fb9fc..f10b1ebd 100644
--- a/tests/fg/test_wiring.py
+++ b/tests/fgroup/test_wiring.py
@@ -3,29 +3,25 @@
 import numpy as np
 import pytest
 
-from pgmax.factors import enumeration as enumeration_factor
-from pgmax.factors import logical as logical_factor
-from pgmax.fg import graph
-from pgmax.groups import enumeration, logical
-from pgmax.groups import variables as vgroup
+from pgmax import factor, fgraph, fgroup, vgroup
 
 
 def test_wiring_with_PairwiseFactorGroup():
     """
     Test the equivalence of the wiring compiled at the PairwiseFactorGroup level
-    vs at the individual EnumerationFactor level (which is called from SingleFactorGroup)
+    vs at the individual EnumFactor level (which is called from SingleFactorGroup)
     """
-    A = vgroup.NDVariableArray(num_states=2, shape=(10,))
-    B = vgroup.NDVariableArray(num_states=2, shape=(10,))
+    A = vgroup.NDVarArray(num_states=2, shape=(10,))
+    B = vgroup.NDVarArray(num_states=2, shape=(10,))
 
     # First test that compile_wiring enforces the correct factor_edges_num_states shape
-    fg = graph.FactorGraph(variable_groups=[A, B])
-    factor_group = enumeration.PairwiseFactorGroup(
+    fg = fgraph.FactorGraph(variable_groups=[A, B])
+    factor_group = fgroup.PairwiseFactorGroup(
         variables_for_factors=[[A[idx], B[idx]] for idx in range(10)]
     )
     fg.add_factors(factor_group)
 
-    factor_group = fg.factor_groups[enumeration_factor.EnumerationFactor][0]
+    factor_group = fg.factor_groups[factor.EnumFactor][0]
     object.__setattr__(
         factor_group, "factor_configs", factor_group.factor_configs[:, :1]
     )
@@ -36,43 +32,43 @@ def test_wiring_with_PairwiseFactorGroup():
         factor_group.compile_wiring(fg._vars_to_starts)
 
     # FactorGraph with a single PairwiseFactorGroup
-    fg1 = graph.FactorGraph(variable_groups=[A, B])
-    factor_group = enumeration.PairwiseFactorGroup(
+    fg1 = fgraph.FactorGraph(variable_groups=[A, B])
+    factor_group = fgroup.PairwiseFactorGroup(
         variables_for_factors=[[A[idx], B[idx]] for idx in range(10)]
     )
     fg1.add_factors(factor_group)
-    assert len(fg1.factor_groups[enumeration_factor.EnumerationFactor]) == 1
+    assert len(fg1.factor_groups[factor.EnumFactor]) == 1
 
     # FactorGraph with multiple PairwiseFactorGroup
-    fg2 = graph.FactorGraph(variable_groups=[A, B])
+    fg2 = fgraph.FactorGraph(variable_groups=[A, B])
     for idx in range(10):
-        factor_group = enumeration.PairwiseFactorGroup(
+        factor_group = fgroup.PairwiseFactorGroup(
             variables_for_factors=[[A[idx], B[idx]]]
         )
         fg2.add_factors(factor_group)
-    assert len(fg2.factor_groups[enumeration_factor.EnumerationFactor]) == 10
+    assert len(fg2.factor_groups[factor.EnumFactor]) == 10
 
     # FactorGraph with multiple SingleFactorGroup
-    fg3 = graph.FactorGraph(variable_groups=[A, B])
+    fg3 = fgraph.FactorGraph(variable_groups=[A, B])
     factors = []
     for idx in range(10):
-        factor = enumeration_factor.EnumerationFactor(
+        enum_factor = factor.EnumFactor(
             variables=[A[idx], B[idx]],
             factor_configs=np.array([[0, 0], [0, 1], [1, 0], [1, 1]]),
             log_potentials=np.zeros((4,)),
         )
-        factors.append(factor)
+        factors.append(enum_factor)
     fg3.add_factors(factors)
-    assert len(fg3.factor_groups[enumeration_factor.EnumerationFactor]) == 10
+    assert len(fg3.factor_groups[factor.EnumFactor]) == 10
 
     assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors)
 
     # Compile wiring via factor_group.compile_wiring
-    wiring1 = fg1.wiring[enumeration_factor.EnumerationFactor]
-    wiring2 = fg2.wiring[enumeration_factor.EnumerationFactor]
+    wiring1 = fg1.wiring[factor.EnumFactor]
+    wiring2 = fg2.wiring[factor.EnumFactor]
 
     # Compile wiring via factor.compile_wiring
-    wiring3 = fg3.wiring[enumeration_factor.EnumerationFactor]
+    wiring3 = fg3.wiring[factor.EnumFactor]
 
     assert np.all(wiring1.edges_num_states == wiring2.edges_num_states)
     assert np.all(wiring1.var_states_for_edges == wiring2.var_states_for_edges)
@@ -92,44 +88,44 @@ def test_wiring_with_ORFactorGroup():
     Test the equivalence of the wiring compiled at the ORFactorGroup level
     vs at the individual ORFactor level (which is called from SingleFactorGroup)
     """
-    A = vgroup.NDVariableArray(num_states=2, shape=(10,))
-    B = vgroup.NDVariableArray(num_states=2, shape=(10,))
-    C = vgroup.NDVariableArray(num_states=2, shape=(10,))
+    A = vgroup.NDVarArray(num_states=2, shape=(10,))
+    B = vgroup.NDVarArray(num_states=2, shape=(10,))
+    C = vgroup.NDVarArray(num_states=2, shape=(10,))
 
     # FactorGraph with a single ORFactorGroup
-    fg1 = graph.FactorGraph(variable_groups=[A, B, C])
-    factor_group = logical.ORFactorGroup(
+    fg1 = fgraph.FactorGraph(variable_groups=[A, B, C])
+    factor_group = fgroup.ORFactorGroup(
         variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)],
     )
     fg1.add_factors(factor_group)
-    assert len(fg1.factor_groups[logical_factor.ORFactor]) == 1
+    assert len(fg1.factor_groups[factor.ORFactor]) == 1
 
     # FactorGraph with multiple ORFactorGroup
-    fg2 = graph.FactorGraph(variable_groups=[A, B, C])
+    fg2 = fgraph.FactorGraph(variable_groups=[A, B, C])
     for idx in range(10):
-        factor_group = logical.ORFactorGroup(
+        factor_group = fgroup.ORFactorGroup(
             variables_for_factors=[[A[idx], B[idx], C[idx]]],
         )
         fg2.add_factors(factor_group)
-    assert len(fg2.factor_groups[logical_factor.ORFactor]) == 10
+    assert len(fg2.factor_groups[factor.ORFactor]) == 10
 
     # FactorGraph with multiple SingleFactorGroup
-    fg3 = graph.FactorGraph(variable_groups=[A, B, C])
+    fg3 = fgraph.FactorGraph(variable_groups=[A, B, C])
     for idx in range(10):
-        factor = logical_factor.ORFactor(
+        or_factor = factor.ORFactor(
             variables=[A[idx], B[idx], C[idx]],
         )
-        fg3.add_factors(factor)
-    assert len(fg3.factor_groups[logical_factor.ORFactor]) == 10
+        fg3.add_factors(or_factor)
+    assert len(fg3.factor_groups[factor.ORFactor]) == 10
 
     assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors)
 
     # Compile wiring via factor_group.compile_wiring
-    wiring1 = fg1.wiring[logical_factor.ORFactor]
-    wiring2 = fg2.wiring[logical_factor.ORFactor]
+    wiring1 = fg1.wiring[factor.ORFactor]
+    wiring2 = fg2.wiring[factor.ORFactor]
 
     # Compile wiring via factor.compile_wiring
-    wiring3 = fg3.wiring[logical_factor.ORFactor]
+    wiring3 = fg3.wiring[factor.ORFactor]
 
     assert np.all(wiring1.edges_num_states == wiring2.edges_num_states)
     assert np.all(wiring1.var_states_for_edges == wiring2.var_states_for_edges)
@@ -147,44 +143,44 @@ def test_wiring_with_ANDFactorGroup():
     Test the equivalence of the wiring compiled at the ANDFactorGroup level
     vs at the individual ANDFactor level (which is called from SingleFactorGroup)
     """
-    A = vgroup.NDVariableArray(num_states=2, shape=(10,))
-    B = vgroup.NDVariableArray(num_states=2, shape=(10,))
-    C = vgroup.NDVariableArray(num_states=2, shape=(10,))
+    A = vgroup.NDVarArray(num_states=2, shape=(10,))
+    B = vgroup.NDVarArray(num_states=2, shape=(10,))
+    C = vgroup.NDVarArray(num_states=2, shape=(10,))
 
     # FactorGraph with a single ANDFactorGroup
-    fg1 = graph.FactorGraph(variable_groups=[A, B, C])
-    factor_group = logical.ANDFactorGroup(
+    fg1 = fgraph.FactorGraph(variable_groups=[A, B, C])
+    factor_group = fgroup.ANDFactorGroup(
         variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)],
     )
     fg1.add_factors(factor_group)
-    assert len(fg1.factor_groups[logical_factor.ANDFactor]) == 1
+    assert len(fg1.factor_groups[factor.ANDFactor]) == 1
 
     # FactorGraph with multiple ANDFactorGroup
-    fg2 = graph.FactorGraph(variable_groups=[A, B, C])
+    fg2 = fgraph.FactorGraph(variable_groups=[A, B, C])
     for idx in range(10):
-        factor_group = logical.ANDFactorGroup(
+        factor_group = fgroup.ANDFactorGroup(
             variables_for_factors=[[A[idx], B[idx], C[idx]]],
         )
         fg2.add_factors(factor_group)
-    assert len(fg2.factor_groups[logical_factor.ANDFactor]) == 10
+    assert len(fg2.factor_groups[factor.ANDFactor]) == 10
 
     # FactorGraph with multiple SingleFactorGroup
-    fg3 = graph.FactorGraph(variable_groups=[A, B, C])
+    fg3 = fgraph.FactorGraph(variable_groups=[A, B, C])
     for idx in range(10):
-        factor = logical_factor.ANDFactor(
+        and_factor = factor.ANDFactor(
             variables=[A[idx], B[idx], C[idx]],
         )
-        fg3.add_factors(factor)
-    assert len(fg3.factor_groups[logical_factor.ANDFactor]) == 10
+        fg3.add_factors(and_factor)
+    assert len(fg3.factor_groups[factor.ANDFactor]) == 10
 
     assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors)
 
     # Compile wiring via factor_group.compile_wiring
-    wiring1 = fg1.wiring[logical_factor.ANDFactor]
-    wiring2 = fg2.wiring[logical_factor.ANDFactor]
+    wiring1 = fg1.wiring[factor.ANDFactor]
+    wiring2 = fg2.wiring[factor.ANDFactor]
 
     # Compile wiring via factor.compile_wiring
-    wiring3 = fg3.wiring[logical_factor.ANDFactor]
+    wiring3 = fg3.wiring[factor.ANDFactor]
 
     assert np.all(wiring1.edges_num_states == wiring2.edges_num_states)
     assert np.all(wiring1.var_states_for_edges == wiring2.var_states_for_edges)
diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py
index f3ee7bfa..211f1482 100644
--- a/tests/test_pgmax.py
+++ b/tests/test_pgmax.py
@@ -6,10 +6,7 @@
 from numpy.random import default_rng
 from scipy.ndimage import gaussian_filter
 
-from pgmax.factors.enumeration import EnumerationFactor
-from pgmax.fg import graph, nodes
-from pgmax.groups import enumeration
-from pgmax.groups import variables as vgroup
+from pgmax import factor, fgraph, fgroup, infer, vgroup
 
 # Set random seed for rng
 rng = default_rng(23)
@@ -195,19 +192,19 @@ def create_valid_suppression_config_arr(suppression_diameter):
     # Now, we specify the valid configurations for all the suppression factors
     SUPPRESSION_DIAMETER = 2
     valid_configs_supp = create_valid_suppression_config_arr(SUPPRESSION_DIAMETER)
-    # We create a NDVariableArray such that the [0,i,j] entry corresponds to the vertical cut variable (i.e, the one
+    # We create a NDVarArray such that the [0,i,j] entry corresponds to the vertical cut variable (i.e, the one
     # attached horizontally to the factor) that's at that location in the image, and the [1,i,j] entry corresponds to
     # the horizontal cut variable (i.e, the one attached vertically to the factor) that's at that location
-    # We create a NDVariableArray such that the [0,i,j] entry corresponds to the vertical cut variable (i.e, the one
+    # We create a NDVarArray such that the [0,i,j] entry corresponds to the vertical cut variable (i.e, the one
     # attached horizontally to the factor) that's at that location in the image, and the [1,i,j] entry corresponds to
     # the horizontal cut variable (i.e, the one attached vertically to the factor) that's at that location
-    grid_vars = vgroup.NDVariableArray(shape=(2, M - 1, N - 1), num_states=3)
+    grid_vars = vgroup.NDVarArray(shape=(2, M - 1, N - 1), num_states=3)
 
     # Make a group of additional variables for the edges of the grid
     extra_row_names: List[Tuple[Any, ...]] = [(0, row, N - 1) for row in range(M - 1)]
     extra_col_names: List[Tuple[Any, ...]] = [(1, M - 1, col) for col in range(N - 1)]
     additional_names = tuple(extra_row_names + extra_col_names)
-    additional_vars = vgroup.VariableDict(variable_names=additional_names, num_states=3)
+    additional_vars = vgroup.VarDict(variable_names=additional_names, num_states=3)
 
     true_map_state_output = {
         (grid_vars, (0, 0, 0)): 2,
@@ -260,9 +257,9 @@ def create_valid_suppression_config_arr(suppression_diameter):
                         pass
 
     # Create the factor graph
-    fg = graph.FactorGraph(variable_groups=[grid_vars, additional_vars])
+    fg = fgraph.FactorGraph(variable_groups=[grid_vars, additional_vars])
 
-    # Imperatively add EnumerationFactorGroups (each consisting of just one EnumerationFactor) to
+    # Imperatively add EnumFactorGroups (each consisting of just one EnumFactor) to
     # the graph!
     for row in range(M - 1):
         for col in range(N - 1):
@@ -297,25 +294,25 @@ def create_valid_suppression_config_arr(suppression_diameter):
                     additional_vars[1, row + 1, col],
                 ]
             if row % 2 == 0:
-                factor = EnumerationFactor(
+                enum_factor = factor.EnumFactor(
                     variables=curr_vars,
                     factor_configs=valid_configs_non_supp,
                     log_potentials=np.zeros(
                         valid_configs_non_supp.shape[0], dtype=float
                     ),
                 )
-                fg.add_factors(factor)
+                fg.add_factors(enum_factor)
             else:
-                factor = EnumerationFactor(
+                enum_factor = factor.EnumFactor(
                     variables=curr_vars,
                     factor_configs=valid_configs_non_supp,
                     log_potentials=np.zeros(
                         valid_configs_non_supp.shape[0], dtype=float
                     ),
                 )
-                fg.add_factors(factor)
+                fg.add_factors(enum_factor)
 
-    # Create an EnumerationFactorGroup for vertical suppression factors
+    # Create an EnumFactorGroup for vertical suppression factors
     vert_suppression_vars: List[List[Tuple[Any, ...]]] = []
     for col in range(N):
         for start_row in range(M - SUPPRESSION_DIAMETER):
@@ -353,13 +350,13 @@ def create_valid_suppression_config_arr(suppression_diameter):
                 )
 
     # Add the suppression factors to the graph via kwargs
-    factor_group = enumeration.EnumerationFactorGroup(
+    factor_group = fgroup.EnumFactorGroup(
         variables_for_factors=vert_suppression_vars,
         factor_configs=valid_configs_supp,
     )
     fg.add_factors(factor_group)
 
-    factor_group = enumeration.EnumerationFactorGroup(
+    factor_group = fgroup.EnumFactorGroup(
         variables_for_factors=horz_suppression_vars,
         factor_configs=valid_configs_supp,
         log_potentials=np.zeros(valid_configs_supp.shape[0], dtype=float),
@@ -371,17 +368,17 @@ def create_valid_suppression_config_arr(suppression_diameter):
     bp_state = fg.bp_state
     assert np.all(
         [
-            isinstance(jax.device_put(this_wiring), nodes.Wiring)
+            isinstance(jax.device_put(this_wiring), factor.Wiring)
             for this_wiring in fg.fg_state.wiring.values()
         ]
     )
     bp_state.evidence[grid_vars] = grid_evidence_arr
     bp_state.evidence[additional_vars] = additional_vars_evidence_dict
-    bp = graph.BP(bp_state)
+    bp = infer.BP(bp_state)
     bp_arrays = bp.run_bp(bp.init(), num_iters=100)
     # Test that the output messages are close to the true messages
     assert jnp.allclose(bp_arrays.ftov_msgs, true_final_msgs_output, atol=1e-06)
-    decoded_map_states = graph.decode_map_states(bp.get_beliefs(bp_arrays))
+    decoded_map_states = infer.decode_map_states(bp.get_beliefs(bp_arrays))
     for name in true_map_state_output:
         assert true_map_state_output[name] == decoded_map_states[name[0]][name[1]]
 
@@ -389,16 +386,16 @@ def create_valid_suppression_config_arr(suppression_diameter):
 def test_e2e_heretic():
     # Define some global constants
     im_size = (30, 30)
-    # Instantiate all the Variables in the factor graph via VariableGroups
-    pixel_vars = vgroup.NDVariableArray(shape=im_size, num_states=3)
-    hidden_vars = vgroup.NDVariableArray(
+    # Instantiate all the Variables in the factor graph via VarGroups
+    pixel_vars = vgroup.NDVarArray(shape=im_size, num_states=3)
+    hidden_vars = vgroup.NDVarArray(
         shape=(im_size[0] - 2, im_size[1] - 2), num_states=17
     )  # Each hidden var is connected to a 3x3 patch of pixel vars
 
     bXn = np.zeros((30, 30, 3))
 
     # Create the factor graph
-    fg = graph.FactorGraph([pixel_vars, hidden_vars])
+    fg = fgraph.FactorGraph([pixel_vars, hidden_vars])
 
     def binary_connected_variables(
         num_hidden_rows, num_hidden_cols, kernel_row, kernel_col
@@ -417,7 +414,7 @@ def binary_connected_variables(
     W_pot = np.zeros((17, 3, 3, 3), dtype=float)
     for k_row in range(3):
         for k_col in range(3):
-            factor_group = enumeration.PairwiseFactorGroup(
+            factor_group = fgroup.PairwiseFactorGroup(
                 variables_for_factors=binary_connected_variables(28, 28, k_row, k_col),
                 log_potential_matrix=W_pot[:, :, k_row, k_col],
             )
@@ -431,7 +428,7 @@ def binary_connected_variables(
     bp_state.evidence[hidden_vars[0, 0]]
     assert isinstance(bp_state.evidence.value, np.ndarray)
     assert len(sum(fg.factors.values(), ())) == 7056
-    bp = graph.BP(bp_state, temperature=1.0)
+    bp = infer.BP(bp_state, temperature=1.0)
     bp_arrays = bp.run_bp(bp.init(), num_iters=1)
-    marginals = graph.get_marginals(bp.get_beliefs(bp_arrays))
+    marginals = infer.get_marginals(bp.get_beliefs(bp_arrays))
     assert jnp.allclose(jnp.sum(marginals[pixel_vars], axis=-1), 1.0)
diff --git a/tests/vgroup/test_vgroup.py b/tests/vgroup/test_vgroup.py
new file mode 100644
index 00000000..11d0db47
--- /dev/null
+++ b/tests/vgroup/test_vgroup.py
@@ -0,0 +1,130 @@
+import re
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+
+from pgmax import vgroup
+
+
+def test_variable_dict():
+    """Check VarDict argument validation and its flatten/unflatten errors."""
+    num_states = np.full((4,), fill_value=2)
+    with pytest.raises(
+        ValueError, match=re.escape("Expected num_states shape (3,). Got (4,).")
+    ):
+        vgroup.VarDict(variable_names=tuple([0, 1, 2]), num_states=num_states)
+
+    num_states = np.full((3,), fill_value=2, dtype=np.float32)
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "num_states should be an integer or a NumPy array of dtype int"
+        ),
+    ):
+        vgroup.VarDict(variable_names=tuple([0, 1, 2]), num_states=num_states)
+
+    variable_dict = vgroup.VarDict(variable_names=tuple([0, 1, 2]), num_states=15)
+    with pytest.raises(
+        ValueError, match="data is referring to a non-existent variable 3"
+    ):
+        variable_dict.flatten({3: np.zeros(10)})
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "Variable 2 expects a data array of shape (15,) or (1,). Got (10,)."
+        ),
+    ):
+        variable_dict.flatten({2: np.zeros(10)})
+
+    with pytest.raises(
+        ValueError, match="Can only unflatten 1D array. Got a 2D array."
+    ):
+        variable_dict.unflatten(jnp.zeros((10, 20)))
+
+    # One value per variable: compare leaf by leaf with zeros keyed by name.
+    # tree_map takes several trees; tree_multimap was removed from JAX.
+    unflattened = variable_dict.unflatten(jnp.zeros(3))
+    expected = {name: np.zeros(1) for name in range(3)}
+    matches = jax.tree_util.tree_map(
+        lambda x, y: jnp.all(x == y), unflattened, expected
+    )
+    assert jnp.all(jnp.array(jax.tree_util.tree_leaves(matches)))
+
+    # A length of 100 matches neither 3 variables nor 45 variable states.
+    with pytest.raises(
+        ValueError,
+        match=re.escape(
+            "flat_data should be either of shape (num_variables(=3),), or (num_variable_states(=45),)"
+        ),
+    ):
+        variable_dict.unflatten(jnp.zeros((100)))
+
+
+def test_nd_variable_array():
+    """Exercise NDVarArray validation, slicing, flatten and unflatten."""
+    # An array with more variables than MAX_SIZE must be rejected.
+    max_size = int(vgroup.vgroup.MAX_SIZE)
+    size_error = re.escape(
+        f"Currently only support NDVarArray of size smaller than {max_size}. Got {max_size + 1}"
+    )
+    with pytest.raises(ValueError, match=size_error):
+        vgroup.NDVarArray(shape=(max_size + 1,), num_states=2)
+
+    # A num_states array must have the same shape as the variable array...
+    wrong_shape = np.full((2, 3), fill_value=2)
+    shape_error = re.escape("Expected num_states shape (2, 2). Got (2, 3).")
+    with pytest.raises(ValueError, match=shape_error):
+        vgroup.NDVarArray(shape=(2, 2), num_states=wrong_shape)
+
+    # ...and an integer dtype.
+    wrong_dtype = np.full((2, 3), fill_value=2, dtype=np.float32)
+    dtype_error = re.escape(
+        "num_states should be an integer or a NumPy array of dtype int"
+    )
+    with pytest.raises(ValueError, match=dtype_error):
+        vgroup.NDVarArray(shape=(2, 2), num_states=wrong_dtype)
+
+    # Slicing a 3x3 window selects 9 variables.
+    uniform_group = vgroup.NDVarArray(shape=(5, 5), num_states=2)
+    top_left = uniform_group[:3, :3]
+    assert len(top_left) == 9
+
+    ragged_states = np.array([[1, 2], [3, 4]])
+    ragged_group = vgroup.NDVarArray(shape=(2, 2), num_states=ragged_states)
+    # The comparison result is discarded; the statement only has to run.
+    uniform_group < ragged_group
+
+    # flatten rejects data that is neither per-variable nor per-state shaped.
+    data_error = re.escape("data should be of shape (2, 2) or (2, 2, 4). Got (3, 3).")
+    with pytest.raises(ValueError, match=data_error):
+        ragged_group.flatten(np.zeros((3, 3)))
+
+    per_variable = ragged_group.flatten(np.array([[1, 2], [3, 4]]))
+    assert jnp.all(per_variable == jnp.array([1, 2, 3, 4]))
+    per_state = ragged_group.flatten(np.zeros((2, 2, 4)))
+    assert jnp.all(per_state == jnp.zeros((10,)))
+
+    rank_error = "Can only unflatten 1D array. Got a 2D array."
+    with pytest.raises(ValueError, match=rank_error):
+        ragged_group.unflatten(np.zeros((10, 20)))
+
+    size_mismatch = re.escape(
+        "flat_data should be compatible with shape (2, 2) or (2, 2, 4). Got (12,)."
+    )
+    with pytest.raises(ValueError, match=size_mismatch):
+        ragged_group.unflatten(np.zeros((12,)))
+
+    assert jnp.all(ragged_group.unflatten(np.zeros(4)) == jnp.zeros((2, 2)))
+    # Reference padded with NaN; only its non-NaN entries are compared.
+    # NOTE(review): slots [0, 1, 1] and [1, 0, 2] stay NaN although num_states
+    # there is 2 and 3, so they are never checked - confirm this is intended.
+    reference = jnp.full((2, 2, 4), fill_value=jnp.nan)
+    reference = reference.at[0, 0, 0].set(0)
+    reference = reference.at[0, 1, :1].set(0)
+    reference = reference.at[1, 0, :2].set(0)
+    reference = reference.at[1, 1].set(0)
+    valid = ~jnp.isnan(reference)
+    assert jnp.all(ragged_group.unflatten(np.zeros(10))[valid] == reference[valid])