From 66bbcbcd47d83d34278866ce2deae870838c7f39 Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Tue, 12 Apr 2022 18:01:04 +0000 Subject: [PATCH 01/35] Rewrite NDVariableArray --- .pre-commit-config.yaml | 10 +- examples/ising_model.py | 25 ++- examples/pmp_binary_deconvolution.py | 61 ++++-- examples/rbm.py | 47 +++-- pgmax/factors/enumeration.py | 20 +- pgmax/factors/logical.py | 10 +- pgmax/fg/graph.py | 290 ++++++++++++++++--------- pgmax/fg/groups.py | 305 ++------------------------- pgmax/fg/nodes.py | 25 +-- pgmax/groups/enumeration.py | 43 ++-- pgmax/groups/logical.py | 9 +- pgmax/groups/variables.py | 281 +++++++++++++----------- tests/fg/test_wiring.py | 44 ++-- 13 files changed, 539 insertions(+), 631 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cfec6112..7e553c33 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,10 +27,10 @@ repos: args: ['--config', '.flake8.config','--exit-zero'] verbose: true -- repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v0.942' # Use the sha / tag you want to point at - hooks: - - id: mypy - additional_dependencies: [tokenize-rt==3.2.0] +# - repo: https://github.com/pre-commit/mirrors-mypy +# rev: 'v0.942' # Use the sha / tag you want to point at +# hooks: +# - id: mypy +# additional_dependencies: [tokenize-rt==3.2.0] ci: autoupdate_schedule: 'quarterly' diff --git a/examples/ising_model.py b/examples/ising_model.py index 5b805c1a..ef5a267b 100644 --- a/examples/ising_model.py +++ b/examples/ising_model.py @@ -30,35 +30,50 @@ # %% variables = vgroup.NDVariableArray(num_states=2, shape=(50, 50)) fg = graph.FactorGraph(variables=variables) + +# TODO: rename variable_for_factors? variable_names_for_factors = [] for ii in range(50): for jj in range(50): kk = (ii + 1) % 50 ll = (jj + 1) % 50 - variable_names_for_factors.append([(ii, jj), (kk, jj)]) - variable_names_for_factors.append([(ii, jj), (ii, ll)]) + variable_names_for_factors.append([variables[ii, jj], variables[kk, jj]]) + variable_names_for_factors.append([variables[ii, jj], variables[kk, ll]]) fg.add_factor_group( factory=enumeration.PairwiseFactorGroup, variable_names_for_factors=variable_names_for_factors, log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]), - name="factors", ) +# %% +from pgmax.factors import enumeration as enumeration_factor + +factors = fg.factor_groups[enumeration_factor.EnumerationFactor][0].factors + # %% [markdown] # ### Run inference and visualize results +import imp + # %% +from pgmax.fg import graph + +imp.reload(graph) bp = graph.BP(fg.bp_state, temperature=0) # %% +# TODO: check\ time for before BP vs time for BP +# TODO: why bug when done twice? +# TODO: time PGMAX vs PMP bp_arrays = bp.init( - evidence_updates={None: jax.device_put(np.random.gumbel(size=(50, 50, 2)))} + evidence_updates={variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))} ) bp_arrays = bp.run_bp(bp_arrays, num_iters=3000) +output = bp.get_bp_output(bp_arrays) # %% -img = graph.decode_map_states(bp.get_beliefs(bp_arrays)) +img = output.map_states[variables] fig, ax = plt.subplots(1, 1, figsize=(10, 10)) ax.imshow(img) diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py index c00aac38..b10f5688 100644 --- a/examples/pmp_binary_deconvolution.py +++ b/examples/pmp_binary_deconvolution.py @@ -107,7 +107,13 @@ def plot_images(images, display=True, nr=None): # # See Section 5.6 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) for more details. +import imp + # %% +from pgmax.fg import graph + +imp.reload(graph) + # The dimensions of W used for the generation of X were (4, 5, 5) but we set them to (5, 6, 6) # to simulate a more realistic scenario in which we do not know their ground truth values n_feat, feat_height, feat_width = 5, 6, 6 @@ -116,6 +122,9 @@ def plot_images(images, display=True, nr=None): s_height = im_height - feat_height + 1 s_width = im_width - feat_width + 1 +import time + +start = time.time() # Binary features W = vgroup.NDVariableArray( num_states=2, shape=(n_chan, n_feat, feat_height, feat_width) @@ -132,13 +141,16 @@ def plot_images(images, display=True, nr=None): # Binary images obtained by convolution X = vgroup.NDVariableArray(num_states=2, shape=X_gt.shape) +print("Time", time.time() - start) # %% [markdown] # For computation efficiency, we add large FactorGroups via `fg.add_factor_group` instead of adding individual Factors # %% +start = time.time() # Factor graph -fg = graph.FactorGraph(variables=dict(S=S, W=W, SW=SW, X=X)) +fg = graph.FactorGraph(variables=[S, W, SW, X]) +print("x", time.time() - start) # Define the ANDFactors variable_names_for_ANDFactors = [] @@ -152,8 +164,7 @@ def plot_images(images, display=True, nr=None): for idx_feat_width in range(feat_width): idx_img_height = idx_feat_height + idx_s_height idx_img_width = idx_feat_width + idx_s_width - SW_var = ( - "SW", + SW_var = SW[ idx_img, idx_chan, idx_img_height, @@ -161,35 +172,31 @@ def plot_images(images, display=True, nr=None): idx_feat, idx_feat_height, idx_feat_width, - ) + ] variable_names_for_ANDFactor = [ - ("S", idx_img, idx_feat, idx_s_height, idx_s_width), - ( - "W", - idx_chan, - idx_feat, - idx_feat_height, - idx_feat_width, - ), + S[idx_img, idx_feat, idx_s_height, idx_s_width], + W[idx_chan, idx_feat, idx_feat_height, idx_feat_width], SW_var, ] variable_names_for_ANDFactors.append( variable_names_for_ANDFactor ) - X_var = (idx_img, idx_chan, idx_img_height, idx_img_width) + X_var = X[idx_img, idx_chan, idx_img_height, idx_img_width] variable_names_for_ORFactors_dict[X_var].append(SW_var) +print(time.time() - start) # Add ANDFactorGroup, which is computationally efficient fg.add_factor_group( factory=logical.ANDFactorGroup, variable_names_for_factors=variable_names_for_ANDFactors, ) +print(time.time() - start) # Define the ORFactors variable_names_for_ORFactors = [ - list(tuple(variable_names_for_ORFactors_dict[X_var]) + (("X",) + X_var,)) + list(tuple(variable_names_for_ORFactors_dict[X_var]) + (X_var,)) for X_var in variable_names_for_ORFactors_dict ] @@ -198,6 +205,7 @@ def plot_images(images, display=True, nr=None): factory=logical.ORFactorGroup, variable_names_for_factors=variable_names_for_ORFactors, ) +print("Time", time.time() - start) for factor_type, factor_groups in fg.factor_groups.items(): if len(factor_groups) > 0: @@ -215,7 +223,9 @@ def plot_images(images, display=True, nr=None): # in the same manner does not change X, so this naturally results in multiple equivalent modes. # %% +start = time.time() bp = graph.BP(fg.bp_state, temperature=0.0) +print("Time", time.time() - start) # %% [markdown] # We first compute the evidence without perturbation, similar to the PMP paper. @@ -236,6 +246,29 @@ def plot_images(images, display=True, nr=None): uX = np.zeros((X_gt.shape) + (2,)) uX[..., 0] = (2 * X_gt - 1) * logit(pX) +# %% +np.random.seed(seed=40) +n_samples = 1 + +start = time.time() +bp_arrays = jax.vmap(bp.init, in_axes=0, out_axes=0)( + evidence_updates={ + "S": uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape), + "W": uW[None] + np.random.gumbel(size=(n_samples,) + uW.shape), + "SW": np.zeros(shape=(n_samples,) + SW.shape), + "X": uX[None] + np.zeros(shape=(n_samples,) + uX.shape), + }, +) +print("Time", time.time() - start) +bp_arrays = jax.vmap( + functools.partial(bp.run_bp, num_iters=100, damping=0.5), + in_axes=0, + out_axes=0, +)(bp_arrays) +print("Time", time.time() - start) +# beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays) +# map_states = graph.decode_map_states(beliefs) + # %% [markdown] # We draw a batch of samples from the posterior in parallel by transforming `run_bp`/`get_beliefs` with `jax.vmap` diff --git a/examples/rbm.py b/examples/rbm.py index d0faa91a..e4d26a83 100644 --- a/examples/rbm.py +++ b/examples/rbm.py @@ -45,13 +45,21 @@ # %% [markdown] # We can then initialize the factor graph for the RBM with +import imp + # %% +from pgmax.fg import graph + +imp.reload(graph) + +import time + +start = time.time() # Initialize factor graph hidden_variables = vgroup.NDVariableArray(num_states=2, shape=bh.shape) visible_variables = vgroup.NDVariableArray(num_states=2, shape=bv.shape) -fg = graph.FactorGraph( - variables=dict(hidden=hidden_variables, visible=visible_variables), -) +fg = graph.FactorGraph(variables=[hidden_variables, visible_variables]) +print("Time", time.time() - start) # %% [markdown] # [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup) (e.g. an [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)), or a list/dictionary of [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed. @@ -59,17 +67,18 @@ # After initialization, `fg` does not have any factors. PGMax supports imperatively adding factors to a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph). We can add the unary and pairwise factors by grouping them using # %% +start = time.time() # Add unary factors fg.add_factor_group( factory=enumeration.EnumerationFactorGroup, - variable_names_for_factors=[[("hidden", ii)] for ii in range(bh.shape[0])], + variable_names_for_factors=[[hidden_variables[ii]] for ii in range(bh.shape[0])], factor_configs=np.arange(2)[:, None], log_potentials=np.stack([np.zeros_like(bh), bh], axis=1), ) fg.add_factor_group( factory=enumeration.EnumerationFactorGroup, - variable_names_for_factors=[[("visible", jj)] for jj in range(bv.shape[0])], + variable_names_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])], factor_configs=np.arange(2)[:, None], log_potentials=np.stack([np.zeros_like(bv), bv], axis=1), ) @@ -81,13 +90,16 @@ fg.add_factor_group( factory=enumeration.PairwiseFactorGroup, variable_names_for_factors=[ - [("hidden", ii), ("visible", jj)] + [hidden_variables[ii], visible_variables[jj]] for ii in range(bh.shape[0]) for jj in range(bv.shape[0]) ], log_potential_matrix=log_potential_matrix, ) +# fg.add_factor_group(factory=enumeration.PairwiseFactorGroup, variable_names_for_factors=[[hidden_variables[ii], visible_variables[jj]]for ii in range(bh.shape[0])for jj in range(bv.shape[0])], log_potential_matrix=log_potential_matrix,) +print("Time", time.time() - start) + # %% [markdown] # PGMax implements convenient and computationally efficient [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) for representing Groups of similar factors. The code above makes use of [`EnumerationFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.EnumerationFactorGroup.html#pgmax.fg.groups.EnumerationFactorGroup) and [`PairwiseFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.PairwiseFactorGroup.html#pgmax.fg.groups.PairwiseFactorGroup), two [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup)s implemented in the [`pgmax.fg.groups`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.html#module-pgmax.fg.graph) module. @@ -146,18 +158,23 @@ # Now we are ready to demonstrate PMP sampling from RBM. PMP perturbs the model with [Gumbel](https://numpy.org/doc/stable/reference/random/generated/numpy.random.gumbel.html) unary potentials, and draws a sample from the RBM as the MAP decoding from running max-product LBP on the perturbed model # %% +start = time.time() bp = graph.BP(fg.bp_state, temperature=0.0) +print("Time", time.time() - start) # %% +start = time.time() bp_arrays = bp.init( evidence_updates={ - "hidden": np.random.gumbel(size=(bh.shape[0], 2)), - "visible": np.random.gumbel(size=(bv.shape[0], 2)), + hidden_variables: np.random.gumbel(size=(bh.shape[0], 2)), + visible_variables: np.random.gumbel(size=(bv.shape[0], 2)), }, ) +print("Time", time.time() - start) bp_arrays = bp.run_bp(bp_arrays, num_iters=100, damping=0.5) -beliefs = bp.get_beliefs(bp_arrays) -map_states = graph.decode_map_states(beliefs) +print("Time", time.time() - start) +output = bp.get_bp_output(bp_arrays) +print("Time", time.time() - start) # %% [markdown] # Here we use the `evidence_updates` argument of `run_bp` to perturb the model with Gumbel unary potentials. In general, `evidence_updates` can be used to incorporate evidence in the form of externally applied unary potentials in PGM inference. @@ -166,7 +183,7 @@ # %% fig, ax = plt.subplots(1, 1, figsize=(10, 10)) -ax.imshow(map_states["visible"].copy().reshape((28, 28)), cmap="gray") +ax.imshow(output.map_states[visible_variables].copy().reshape((28, 28)), cmap="gray") ax.axis("off") # %% [markdown] @@ -203,8 +220,12 @@ in_axes=0, out_axes=0, )(bp_arrays) -beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays) -map_states = graph.decode_map_states(beliefs) + +outputs = jax.vmap(bp.get_bp_output, in_axes=0, out_axes=0)(bp_arrays) +# map_states = graph.decode_map_states(beliefs) + +# %% +O # %% [markdown] # Visualizing the MAP decodings (Figure [fig:rbm_multiple_digits]), we see that we have sampled 10 MNIST digits in parallel! diff --git a/pgmax/factors/enumeration.py b/pgmax/factors/enumeration.py index 61d349d5..e2521597 100644 --- a/pgmax/factors/enumeration.py +++ b/pgmax/factors/enumeration.py @@ -81,9 +81,9 @@ def __post_init__(self): f"EnumerationFactor. Got a factor_configs array of shape {self.factor_configs.shape}." ) - if len(self.variables) != self.factor_configs.shape[1]: + if len(self.vars_to_num_states.keys()) != self.factor_configs.shape[1]: raise ValueError( - f"Number of variables {len(self.variables)} doesn't match given configurations {self.factor_configs.shape}" + f"Number of variables {len(self.vars_to_num_states.keys())} doesn't match given configurations {self.factor_configs.shape}" ) if self.log_potentials.shape != (self.factor_configs.shape[0],): @@ -93,7 +93,7 @@ def __post_init__(self): f"shape {self.log_potentials.shape}." ) - vars_num_states = np.array([variable.num_states for variable in self.variables]) + vars_num_states = np.array([list(self.vars_to_num_states.values())]) if not np.logical_and( self.factor_configs >= 0, self.factor_configs < vars_num_states[None] ).all(): @@ -155,9 +155,9 @@ def concatenate_wirings(wirings: Sequence[EnumerationWiring]) -> EnumerationWiri @staticmethod def compile_wiring( factor_edges_num_states: np.ndarray, - variables_for_factors: Tuple[nodes.Variable, ...], + variables_for_factors: Tuple[int, ...], # TODO: rename factor_configs: np.ndarray, - vars_to_starts: Mapping[nodes.Variable, int], + vars_to_starts: Mapping[int, int], num_factors: int, ) -> EnumerationWiring: """Compile an EnumerationWiring for an EnumerationFactor or a FactorGroup with EnumerationFactors. @@ -166,7 +166,7 @@ def compile_wiring( Args: factor_edges_num_states: An array concatenating the number of states for the variables connected to each Factor of the FactorGroup. Each variable will appear once for each Factor it connects to. - variables_for_factors: A tuple of tuples containing variables connected to each Factor of the FactorGroup. + variables_for_factors: A tuple containing variables connected to each Factor of the FactorGroup. Each variable will appear once for each Factor it connects to. factor_configs: Array of shape (num_val_configs, num_variables) containing an explicit enumeration of all valid configurations. @@ -184,10 +184,10 @@ def compile_wiring( var_states = np.array( [vars_to_starts[variable] for variable in variables_for_factors] ) - num_states = np.array( - [variable.num_states for variable in variables_for_factors] - ) - num_states_cumsum = np.insert(np.cumsum(num_states), 0, 0) + # num_states = np.array( + # [variable.num_states for variable in variables_for_factors] + # ) + num_states_cumsum = np.insert(np.cumsum(factor_edges_num_states), 0, 0) var_states_for_edges = np.empty(shape=(num_states_cumsum[-1],), dtype=int) _compile_var_states_numba(var_states_for_edges, num_states_cumsum, var_states) diff --git a/pgmax/factors/logical.py b/pgmax/factors/logical.py index 8a27d0be..c53a5fac 100644 --- a/pgmax/factors/logical.py +++ b/pgmax/factors/logical.py @@ -86,12 +86,14 @@ class LogicalFactor(nodes.Factor): edge_states_offset: int = field(init=False) def __post_init__(self): - if len(self.variables) < 2: + if len(self.vars_to_num_states.keys()) < 2: raise ValueError( "At least one parent variable and one child variable is required" ) - if not np.all([variable.num_states == 2 for variable in self.variables]): + if not np.all( + [num_states == 2 for num_states in self.vars_to_num_states.values()] + ): raise ValueError("All variables should all be binary") @staticmethod @@ -141,9 +143,9 @@ def concatenate_wirings(wirings: Sequence[LogicalWiring]) -> LogicalWiring: @staticmethod def compile_wiring( factor_edges_num_states: np.ndarray, - variables_for_factors: Tuple[nodes.Variable, ...], + variables_for_factors: Tuple[int, ...], # notsure factor_sizes: np.ndarray, - vars_to_starts: Mapping[nodes.Variable, int], + vars_to_starts: Mapping[int, int], edge_states_offset: int, ) -> LogicalWiring: """Compile a LogicalWiring for a LogicalFactor or a FactorGroup with LogicalFactors. diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index c0b4f096..f0a0aefe 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -1,5 +1,7 @@ from __future__ import annotations +from this import d + """A module containing the core class to specify a Factor Graph.""" import collections @@ -7,7 +9,7 @@ import functools import inspect import typing -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from types import MappingProxyType from typing import ( Any, @@ -35,6 +37,7 @@ from pgmax.bp import infer from pgmax.factors import FAC_TO_VAR_UPDATES from pgmax.fg import groups, nodes +from pgmax.groups import variables as vgroup from pgmax.groups.enumeration import EnumerationFactorGroup from pgmax.utils import cached_property @@ -53,23 +56,34 @@ class FactorGraph: this input, and the individual VariableGroups will need to be accessed by indexing. """ - variables: Union[ - Mapping[Any, groups.VariableGroup], - Sequence[groups.VariableGroup], - groups.VariableGroup, - ] + variables: Union[groups.VariableGroup, Sequence[groups.VariableGroup]] def __post_init__(self): - if isinstance(self.variables, groups.VariableGroup): - self._variable_group = self.variables - else: - self._variable_group = groups.CompositeVariableGroup(self.variables) - + import time + + start = time.time() + # if isinstance(self.variables, groups.VariableGroup): + if isinstance(self.variables, vgroup.NDVariableArray): + self.variables = [self.variables] + + self._variable_group: Mapping[ + int, groups.VariableGroup + ] = collections.OrderedDict() + for variable_group in self.variables: + self._variable_group[variable_group.__hash__()] = variable_group + + vars_names = [] + vars_num_states = [] + for variable_group in self.variables: + if isinstance(variable_group, vgroup.NDVariableArray): + vars_names.append(variable_group.variable_names.flatten()) + vars_num_states.append(variable_group.num_states.flatten()) + print("1", time.time() - start) + + vars_names = np.concatenate(vars_names) + vars_num_states = np.concatenate(vars_num_states) vars_num_states_cumsum = np.insert( - np.array( - [variable.num_states for variable in self._variable_group.variables], - dtype=int, - ).cumsum(), + np.array(vars_num_states).cumsum(), 0, 0, ) @@ -85,16 +99,22 @@ def __post_init__(self): ] = collections.OrderedDict( [(factor_type, set()) for factor_type in FAC_TO_VAR_UPDATES] ) - + print("2", time.time() - start) + + # Used to add FactorGroups + # TODO: dict is faster + aa = zip(vars_names, vars_num_states) + print("3", time.time() - start) + self._vars_to_num_states: OrderedDict[int, int] = collections.OrderedDict( + zip(vars_names, vars_num_states) + ) # See FactorGraphState docstrings for documentation on the following fields self._num_var_states = vars_num_states_cumsum[-1] - self._vars_to_starts = MappingProxyType( - { - variable: vars_num_states_cumsum[vv] - for vv, variable in enumerate(self._variable_group.variables) - } + self._vars_to_starts: OrderedDict[int, int] = collections.OrderedDict( + zip(vars_names, vars_num_states_cumsum[:-1]) ) self._named_factor_groups: Dict[Hashable, groups.FactorGroup] = {} + print("3", time.time() - start) def __hash__(self) -> int: all_factor_groups = tuple( @@ -130,7 +150,7 @@ def add_factor( initialized. """ factor_group = EnumerationFactorGroup( - self._variable_group, + self._vars_to_num_states, variable_names_for_factors=[variable_names], factor_configs=factor_configs, log_potentials=log_potentials, @@ -138,7 +158,7 @@ def add_factor( self._register_factor_group(factor_group, name) def add_factor_by_type( - self, variable_names: List, factor_type: type, *args, **kwargs + self, variable_names: List[int], factor_type: type, *args, **kwargs ) -> None: """Function to add a single factor to the FactorGraph. @@ -163,15 +183,16 @@ def add_factor_by_type( f"Type {factor_type} is not one of the supported factor types {FAC_TO_VAR_UPDATES.keys()}" ) - name = kwargs.pop("name", None) - variables = tuple(self._variable_group[variable_names]) - factor = factor_type(variables, *args, **kwargs) + vars_to_num_states = collections.OrderedDict( + (var, self._vars_to_num_states[var]) for var in variable_names + ) + factor = factor_type(vars_to_num_states, *args, **kwargs) factor_group = groups.SingleFactorGroup( - variable_group=self._variable_group, + vars_to_num_states=self._vars_to_num_states, variable_names_for_factors=[variable_names], factor=factor, ) - self._register_factor_group(factor_group, name) + self._register_factor_group(factor_group) def add_factor_group(self, factory: Callable, *args, **kwargs) -> None: """Add a factor group to the factor graph @@ -182,13 +203,10 @@ def add_factor_group(self, factory: Callable, *args, **kwargs) -> None: kwargs: kwargs to be passed to the factory function, and an optional "name" argument for specifying the name of a named factor group. """ - name = kwargs.pop("name", None) - factor_group = factory(self._variable_group, *args, **kwargs) - self._register_factor_group(factor_group, name) + factor_group = factory(self._vars_to_num_states, *args, **kwargs) + self._register_factor_group(factor_group) - def _register_factor_group( - self, factor_group: groups.FactorGroup, name: Optional[str] = None - ) -> None: + def _register_factor_group(self, factor_group: groups.FactorGroup) -> None: """Register a factor group to the factor graph, by updating the factor graph state. Args: @@ -199,10 +217,6 @@ def _register_factor_group( ValueError: If the factor group with the same name or a factor involving the same variables already exists in the factor graph. """ - if name in self._named_factor_groups: - raise ValueError( - f"A factor group with the name {name} already exists. Please choose a different name!" - ) factor_type = factor_group.factor_type for var_names_for_factor in factor_group.variable_names_for_factors: @@ -211,15 +225,13 @@ def _register_factor_group( var_names in self._factor_types_to_variable_names_for_factors[factor_type] ): + print(len(var_names_for_factor)) raise ValueError( f"A Factor of type {factor_type} involving variables {var_names} already exists. Please merge the corresponding factors." ) - self._factor_types_to_variable_names_for_factors[factor_type].add(var_names) self._factor_types_to_groups[factor_type].append(factor_group) - if name is not None: - self._named_factor_groups[name] = factor_group @functools.lru_cache(None) def compute_offsets(self) -> None: @@ -411,7 +423,7 @@ class FactorGraphState: """ variable_group: groups.VariableGroup - vars_to_starts: Mapping[nodes.Variable, int] + vars_to_starts: Mapping[int, int] num_var_states: int total_factor_num_states: int named_factor_groups: Mapping[Hashable, groups.FactorGroup] @@ -719,7 +731,7 @@ def __setitem__(self, names, data) -> None: ) -@functools.partial(jax.jit, static_argnames="fg_state") +# @functools.partial(jax.jit, static_argnames="fg_state") def update_evidence( evidence: jnp.ndarray, updates: Dict[Any, jnp.ndarray], fg_state: FactorGraphState ) -> jnp.ndarray: @@ -733,26 +745,21 @@ def update_evidence( Returns: A flat jnp array containing updated evidence. """ - for name in updates: - data = updates[name] - if name in fg_state.variable_group.container_names: - if name is None: - variable_group = fg_state.variable_group - else: - assert isinstance( - fg_state.variable_group, groups.CompositeVariableGroup - ) - variable_group = fg_state.variable_group.variable_group_container[name] - - start_index = fg_state.vars_to_starts[variable_group.variables[0]] + for var_group_name in updates: + data = updates[var_group_name] + print(data.shape) + if var_group_name.__hash__() in fg_state.variable_group: + variable_group = fg_state.variable_group[var_group_name.__hash__()] + first_variable = variable_group.variable_names.flatten()[0] + start_index = fg_state.vars_to_starts[first_variable] flat_data = variable_group.flatten(data) evidence = evidence.at[start_index : start_index + flat_data.shape[0]].set( flat_data ) - else: - var = fg_state.variable_group[name] - start_index = fg_state.vars_to_starts[var] - evidence = evidence.at[start_index : start_index + var.num_states].set(data) + # else: + # var = fg_state.variable_group[name] + # start_index = fg_state.vars_to_starts[var] + # evidence = evidence.at[start_index : start_index + var.num_states].set(data) return evidence @@ -889,18 +896,18 @@ class BeliefPropagation: Returns: The reconstructed BPState - get_beliefs: Function to calculate beliefs from a BPArrays. + get_bp_output: Function to calculate beliefs from a BPArrays. Args: bp_arrays: A BPArrays containing log_potentials, ftov_msgs, evidence. Returns: - beliefs: An array or a PyTree container containing the beliefs for the variables. + bp_output: Belief propagation output. """ init: Callable update: Callable run_bp: Callable to_bp_state: Callable - get_beliefs: Callable + get_bp_output: Callable def BP(bp_state: BPState, temperature: float = 0.0) -> BeliefPropagation: @@ -978,6 +985,7 @@ def update( ) if evidence_updates is not None: + print(type(evidence), type(evidence_updates)) evidence = update_evidence(evidence, evidence_updates, bp_state.fg_state) return BPArrays( @@ -1065,62 +1073,150 @@ def to_bp_state(bp_arrays: BPArrays) -> BPState: evidence=Evidence(fg_state=bp_state.fg_state, value=bp_arrays.evidence), ) - @jax.jit - def get_beliefs(bp_arrays: BPArrays) -> Any: + def get_bp_output(bp_arrays: BPArrays) -> Any: """Function to calculate beliefs from a BPArrays Args: bp_arrays: A BPArrays containing log_potentials, ftov_msgs, evidence. Returns: - beliefs: An array or a PyTree container containing the beliefs for the variables. + bp_output: Belief propagation output. """ - beliefs = bp_state.fg_state.variable_group.unflatten( - jax.device_put(bp_arrays.evidence) - .at[jax.device_put(var_states_for_edges)] - .add(bp_arrays.ftov_msgs) + + @jax.jit + def compute_flat_beliefs(bp_arrays): + flat_beliefs = ( + jax.device_put(bp_arrays.evidence) + .at[jax.device_put(var_states_for_edges)] + .add(bp_arrays.ftov_msgs) + ) + return flat_beliefs + + return BeliefPropagationOutputs( + compute_flat_beliefs(bp_arrays), bp_state.fg_state.variable_group ) - return beliefs bp = BeliefPropagation( init=functools.partial(update, None), update=update, run_bp=run_bp, to_bp_state=to_bp_state, - get_beliefs=get_beliefs, + get_bp_output=get_bp_output, ) return bp -@jax.jit -def decode_map_states(beliefs: Any) -> Any: - """Function to decode MAP states given the calculated beliefs. +@dataclass(frozen=True, eq=False) +class HashableDict: + d: Dict = field(default_factory=dict) - Args: - beliefs: An array or a PyTree container containing beliefs for different variables. + def __setitem__(self, key, value): + self.d[key] = value - Returns: - An array or a PyTree container containing the MAP states for different variables. - """ - map_states = jax.tree_util.tree_map( - lambda x: jnp.argmax(x, axis=-1), - beliefs, - ) - return map_states + def __getitem__(self, value): + return self.d[value.__hash__()] -@jax.jit -def get_marginals(beliefs: Any) -> Any: - """Function to get marginal probabilities given the calculated beliefs. +@dataclass(frozen=True, eq=False) +class BeliefPropagationOutputs: + # beliefs: An array or a PyTree container containing the beliefs for the variables. + flat_beliefs: jnp.ndarray + variable_groups: Mapping[int, groups.VariableGroup] + beliefs: HashableDict = field(init=False, default_factory=dict) - Args: - beliefs: An array or a PyTree container containing beliefs for different variables. + def __post_init__(self): + if self.flat_beliefs.ndim != 1: + raise ValueError( + f"Can only unflatten 1D array. Got a {self.flat_beliefs.ndim}D array." + ) + beliefs = self.unflatten() + object.__setattr__(self, "beliefs", beliefs) - Returns: - An array or a PyTree container containing the marginal probabilities different variables. - """ - marginals = jax.tree_util.tree_map( - lambda x: jnp.exp(x - logsumexp(x, axis=-1, keepdims=True)), - beliefs, - ) - return marginals + @functools.lru_cache + def unflatten(self) -> None: + """Function that recovers meaningful structured data from internal flat data array + + Args: + variable_groups: TODO + + Returns: + Meaningful structured data, with structure matching that of self.variable_group_container. + + Raises: + ValueError: if flat_data is not of the right shape + """ + # Note: this is a reimplementation of CompositeVariableGroup.unflatten + num_variables = 0 + num_variable_states = 0 + for variable_group in self.variable_groups.values(): + if isinstance(variable_group, vgroup.NDVariableArray): + num_variables += variable_group.num_states.size + num_variable_states += variable_group.num_states.sum() + + if self.flat_beliefs.shape[0] == num_variables: + use_num_states = False + elif self.flat_beliefs.shape[0] == num_variable_states: + use_num_states = True + else: + raise ValueError( + f"flat_data should be either of shape (num_variables(={len(self.variables)}),), " + f"or (num_variable_states(={num_variable_states}),). " + f"Got {self.flat_beliefs.shape}" + ) + + beliefs = {} + start = 0 + for name, variable_group in self.variable_groups.items(): + if use_num_states: + length = variable_group.num_states.sum() + else: + length = variable_group.num_states.size + + beliefs[name] = variable_group.unflatten( + self.flat_beliefs[start : start + length] + ) + start += length + return beliefs + + @cached_property + def map_states(self) -> Any: + """Function to decode MAP states given the calculated beliefs. + + Args: + beliefs: An array or a PyTree container containing beliefs for different variables. + + Returns: + An array or a PyTree container containing the MAP states for different variables. + """ + + @jax.jit + def _decode_map_states(beliefs) -> Any: + return jax.tree_util.tree_map(lambda x: jnp.argmax(x, axis=-1), beliefs) + + map_states = HashableDict() + for name, beliefs in self.beliefs.items(): + map_states[name] = _decode_map_states(beliefs) + return map_states + + @cached_property + def get_marginals(self) -> Any: + """Function to get marginal probabilities given the calculated beliefs. + + Args: + beliefs: An array or a PyTree container containing beliefs for different variables. + + Returns: + An array or a PyTree container containing the marginal probabilities different variables. + """ + + @jax.jit + def _get_marginals(beliefs) -> Any: + return jax.tree_util.tree_map( + lambda x: jnp.exp(x - logsumexp(x, axis=-1, keepdims=True)), + beliefs, + ) + + marginals = HashableDict() + for name, beliefs in self.beliefs.items(): + marginals[name] = _get_marginals(beliefs) + return marginals diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 7b0f767d..c90fe106 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -33,34 +33,13 @@ class VariableGroup: All variables in the group are assumed to have the same size. Additionally, the variables are indexed by a variable name, and can be retrieved by direct indexing (even indexing a sequence of variable names) of the VariableGroup. - - Attributes: - _names_to_variables: A private, immutable mapping from variable names to variables """ - _names_to_variables: Mapping[Hashable, nodes.Variable] = field(init=False) - - def __post_init__(self) -> None: - """Initialize a private, immutable mapping from variable names to variables.""" - object.__setattr__( - self, - "_names_to_variables", - MappingProxyType(self._get_names_to_variables()), - ) - - @typing.overload - def __getitem__(self, name: Hashable) -> nodes.Variable: - """This function is a typing overload and is overwritten by the implemented __getitem__""" - - @typing.overload - def __getitem__(self, name: List) -> List[nodes.Variable]: - """This function is a typing overload and is overwritten by the implemented __getitem__""" - - def __getitem__(self, name): + def __getitem__(self, val): """Given a name, retrieve the associated Variable. Args: - name: a single name corresponding to a single variable, or a list of such names + val: a single name corresponding to a single variable, or a list of such names Returns: A single variable if the name is not a list. A list of variables if name is a list @@ -68,61 +47,20 @@ def __getitem__(self, name): Raises: ValueError: if the name is not found in the group """ - if isinstance(name, List): - names_list = name - else: - names_list = [name] - - vars_list = [] - for curr_name in names_list: - var = self._names_to_variables.get(curr_name) - if var is None: - raise ValueError( - f"The name {curr_name} is not present in the VariableGroup {type(self)}; please ensure " - "it's been added to the VariableGroup before trying to query it." - ) - - vars_list.append(var) - - if isinstance(name, List): - return vars_list - else: - return vars_list[0] - - def _get_names_to_variables(self) -> OrderedDict[Any, nodes.Variable]: - """Function that generates a dictionary mapping names to variables. - - Returns: - a dictionary mapping all possible names to different variables. - """ raise NotImplementedError( "Please subclass the VariableGroup class and override this method" ) @cached_property - def names(self) -> Tuple[Any, ...]: - """Function to return a tuple of all names in the group. - - Returns: - tuple of all names that are part of this VariableGroup - """ - return tuple(self._names_to_variables.keys()) - - @cached_property - def variables(self) -> Tuple[nodes.Variable, ...]: - """Function to return a tuple of all variables in the group. + def variables_names(self) -> Any: + """Function that generates a dictionary mapping names to variables. Returns: - tuple of all variable that are part of this VariableGroup - """ - return tuple(self._names_to_variables.values()) - - @cached_property - def container_names(self) -> Tuple: - """Placeholder function. Returns a tuple containing None for all variable groups - other than a composite variable group + a dictionary mapping all possible names to different variables. """ - return (None,) + raise NotImplementedError( + "Please subclass the VariableGroup class and override this method" + ) def flatten(self, data: Any) -> jnp.ndarray: """Function that turns meaningful structured data into a flat data array for internal use. @@ -151,220 +89,21 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any: ) -@dataclass(frozen=True, eq=False) -class CompositeVariableGroup(VariableGroup): - """A class to encapsulate a collection of instantiated VariableGroups. - - This class enables users to wrap various different VariableGroups and then index - them in a straightforward manner. To index into a CompositeVariableGroup, simply - provide the name of the VariableGroup within this CompositeVariableGroup followed - by the name to be indexed within the VariableGroup. - - Args: - variable_group_container: A container containing multiple variable groups. - Supported containers include mapping and sequence. - For a mapping, the names of the mapping are used to index the variable groups. - For a sequence, the indices of the sequence are used to index the variable groups. - - Attributes: - _names_to_variables: A private, immutable mapping from names to variables - """ - - variable_group_container: Union[ - Mapping[Hashable, VariableGroup], Sequence[VariableGroup] - ] - - def __post_init__(self): - object.__setattr__( - self, - "_names_to_variables", - MappingProxyType(self._get_names_to_variables()), - ) - - @typing.overload - def __getitem__(self, name: Hashable) -> nodes.Variable: - """This function is a typing overload and is overwritten by the implemented __getitem__""" - - @typing.overload - def __getitem__(self, name: List) -> List[nodes.Variable]: - """This function is a typing overload and is overwritten by the implemented __getitem__""" - - def __getitem__(self, name): - """Given a name, retrieve the associated Variable from the associated VariableGroup. - - Args: - name: a single name corresponding to a single Variable within a VariableGroup, or a list - of such names - - Returns: - A single variable if the name is not a list. A list of variables if name is a list - - Raises: - ValueError: if the name does not have the right format (tuples with at least two elements). - """ - if isinstance(name, List): - names_list = name - else: - names_list = [name] - - vars_list = [] - for curr_name in names_list: - if len(curr_name) < 2: - raise ValueError( - "The name needs to have at least 2 elements to index from a composite variable group." - ) - - variable_group = self.variable_group_container[curr_name[0]] - if len(curr_name) == 2: - vars_list.append(variable_group[curr_name[1]]) - else: - vars_list.append(variable_group[curr_name[1:]]) - - if isinstance(name, List): - return vars_list - else: - return vars_list[0] - - def _get_names_to_variables(self) -> OrderedDict[Hashable, nodes.Variable]: - """Function that generates a dictionary mapping names to variables. - - Returns: - a dictionary mapping all possible names to different variables. - """ - names_to_variables: OrderedDict[ - Hashable, nodes.Variable - ] = collections.OrderedDict() - for container_name in self.container_names: - for variable_group_name in self.variable_group_container[ - container_name - ].names: - if isinstance(variable_group_name, tuple): - names_to_variables[ - (container_name,) + variable_group_name - ] = self.variable_group_container[container_name][ - variable_group_name - ] - else: - names_to_variables[ - (container_name, variable_group_name) - ] = self.variable_group_container[container_name][ - variable_group_name - ] - - return names_to_variables - - def flatten(self, data: Union[Mapping, Sequence]) -> jnp.ndarray: - """Function that turns meaningful structured data into a flat data array for internal use. - - Args: - data: Meaningful structured data. - The structure of data should match self.variable_group_container. - - Returns: - A flat jnp.array for internal use - """ - flat_data = jnp.concatenate( - [ - self.variable_group_container[name].flatten(data[name]) - for name in self.container_names - ] - ) - return flat_data - - def unflatten( - self, flat_data: Union[np.ndarray, jnp.ndarray] - ) -> Union[Mapping, Sequence]: - """Function that recovers meaningful structured data from internal flat data array - - Args: - flat_data: Internal flat data array. - - Returns: - Meaningful structured data, with structure matching that of self.variable_group_container. - - Raises: - ValueError if: - (1) flat_data is not a 1D array - (2) flat_data is not of the right shape - """ - if flat_data.ndim != 1: - raise ValueError( - f"Can only unflatten 1D array. Got a {flat_data.ndim}D array." - ) - - num_variables = 0 - num_variable_states = 0 - for name in self.container_names: - variable_group = self.variable_group_container[name] - num_variables += len(variable_group.variables) - num_variable_states += ( - len(variable_group.variables) * variable_group.variables[0].num_states - ) - - if flat_data.shape[0] == num_variables: - use_num_states = False - elif flat_data.shape[0] == num_variable_states: - use_num_states = True - else: - raise ValueError( - f"flat_data should be either of shape (num_variables(={len(self.variables)}),), " - f"or (num_variable_states(={num_variable_states}),). " - f"Got {flat_data.shape}" - ) - - data: List[np.ndarray] = [] - start = 0 - for name in self.container_names: - variable_group = self.variable_group_container[name] - length = len(variable_group.variables) - if use_num_states: - length *= variable_group.variables[0].num_states - - data.append(variable_group.unflatten(flat_data[start : start + length])) - start += length - if isinstance(self.variable_group_container, Mapping): - return dict( - [(name, data[kk]) for kk, name in enumerate(self.container_names)] - ) - else: - return data - - @cached_property - def container_names(self) -> Tuple: - """Function to get names referring to the variable groups within this - CompositeVariableGroup. - - Returns: - a tuple of the names referring to the variable groups within this - CompositeVariableGroup. - """ - if isinstance(self.variable_group_container, Mapping): - container_names = tuple(self.variable_group_container.keys()) - else: - container_names = tuple(range(len(self.variable_group_container))) - - return container_names - - @dataclass(frozen=True, eq=False) class FactorGroup: """Class to represent a group of Factors. Args: - variable_group: either a VariableGroup or - if the elements of more than one VariableGroup - are connected to this FactorGroup - then a CompositeVariableGroup. This holds - all the variables that are connected to this FactorGroup + vars_to_num_states: TODO variable_names_for_factors: A list of list of variable names, where each innermost element is the - name of a variable in variable_group. Each list within the outer list is taken to contain - the names of the variables connected to a Factor. + name of a variable. Each list within the outer list is taken to contain the names of the + variables connected to a Factor. factor_configs: Optional array containing an explicit enumeration of all valid configurations log_potentials: Array of log potentials. Attributes: factor_type: Factor type shared by all the Factors in the FactorGroup. factor_sizes: Array of the different factor sizes. - variables_for_factors: Tuple concatenating the variables connected to each factor in the FactorGroup. - Each variable will appear once for each Factor it connects to. factor_edges_num_states: Array concatenating the number of states for the variables connected to each Factor of the FactorGroup. Each variable will appear once for each Factor it connects to. @@ -372,36 +111,39 @@ class FactorGroup: ValueError: if the FactorGroup does not contain a Factor """ - variable_group: Union[CompositeVariableGroup, VariableGroup] + vars_to_num_states: Mapping[int, int] variable_names_for_factors: Sequence[List] factor_configs: np.ndarray = field(init=False) log_potentials: np.ndarray = field(init=False, default=np.empty((0,))) factor_type: Type = field(init=False) factor_sizes: np.ndarray = field(init=False) - variables_for_factors: Tuple[nodes.Variable, ...] = field(init=False) + variables_for_factors: Tuple[Tuple[int], ...] = field(init=False) factor_edges_num_states: np.ndarray = field(init=False) def __post_init__(self): if len(self.variable_names_for_factors) == 0: raise ValueError("Do not add a factor group with no factors.") + # Note: variable_names_for_factors contains the HASHes + # Note: this can probably be sped up by numba factor_sizes = [] - variables_for_factors = [] + flat_var_names_for_factors = [] factor_edges_num_states = [] for variable_names_for_factor in self.variable_names_for_factors: for variable_name in variable_names_for_factor: - variable = self.variable_group._names_to_variables[variable_name] - variables_for_factors.append(variable) - factor_edges_num_states.append(variable.num_states) + factor_edges_num_states.append(self.vars_to_num_states[variable_name]) + flat_var_names_for_factors.append(variable_name) factor_sizes.append(len(variable_names_for_factor)) object.__setattr__(self, "factor_sizes", np.array(factor_sizes)) - object.__setattr__(self, "variables_for_factors", tuple(variables_for_factors)) + object.__setattr__( + self, "variables_for_factors", np.array(flat_var_names_for_factors) + ) object.__setattr__( self, "factor_edges_num_states", np.array(factor_edges_num_states) ) - def __getitem__(self, variables: Union[Sequence, Collection]) -> Any: + def __getitem__(self, variables: Sequence[int]) -> Any: """Function to query individual factors in the factor group Args: @@ -419,7 +161,6 @@ def __getitem__(self, variables: Union[Sequence, Collection]) -> Any: raise ValueError( f"The queried factor connected to the set of variables {variables} is not present in the factor group." ) - return self._variables_to_factors[variables] @cached_property @@ -485,7 +226,7 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any: "Please subclass the FactorGroup class and override this method" ) - def compile_wiring(self, vars_to_starts: Mapping[nodes.Variable, int]) -> Any: + def compile_wiring(self, vars_to_starts: Mapping[int, int]) -> Any: """Compile an efficient wiring for the FactorGroup. Args: diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py index 32fa29ad..97fe7c5e 100644 --- a/pgmax/fg/nodes.py +++ b/pgmax/fg/nodes.py @@ -1,7 +1,7 @@ """A module containing classes that specify the basic components of a Factor Graph.""" from dataclasses import asdict, dataclass -from typing import Sequence, Tuple, Union +from typing import OrderedDict, Sequence, Union import jax import jax.numpy as jnp @@ -10,19 +10,6 @@ from pgmax import utils -@dataclass(frozen=True, eq=False) -class Variable: - """Base class for variables. - If desired, this can be sub-classed to add additional concrete - meta-information - - Args: - num_states: an int representing the number of states this variable has. - """ - - num_states: int - - @jax.tree_util.register_pytree_node_class @dataclass(frozen=True, eq=False) class Wiring: @@ -56,13 +43,14 @@ class Factor: """A factor Args: - variables: List of connected variables + vars_to_num_states: Dictionnary mapping the variables names, represented + in the form of a hash, to the variables number of states. Raises: NotImplementedError: If compile_wiring is not implemented """ - variables: Tuple[Variable, ...] + vars_to_num_states: OrderedDict[int, int] log_potentials: np.ndarray def __post_init__(self): @@ -79,10 +67,7 @@ def edges_num_states(self) -> np.ndarray: Array of shape (num_edges,) Number of states for the variables connected to each edge """ - edges_num_states = np.array( - [variable.num_states for variable in self.variables], dtype=int - ) - return edges_num_states + return self.vars_to_num_states.values() @staticmethod def concatenate_wirings(wirings: Sequence) -> Wiring: diff --git a/pgmax/groups/enumeration.py b/pgmax/groups/enumeration.py index 84a9c033..f091c53d 100644 --- a/pgmax/groups/enumeration.py +++ b/pgmax/groups/enumeration.py @@ -80,16 +80,19 @@ def _get_variables_to_factors( variables_to_factors = collections.OrderedDict( [ ( - frozenset(self.variable_names_for_factors[ii]), + frozenset(variable_names_for_factor), enumeration.EnumerationFactor( - variables=tuple( - self.variable_group[self.variable_names_for_factors[ii]] + vars_to_num_states=collections.OrderedDict( + (var, self.vars_to_num_states[var]) + for var in variable_names_for_factor ), factor_configs=self.factor_configs, log_potentials=np.array(self.log_potentials)[ii], ), ) - for ii in range(len(self.variable_names_for_factors)) + for ii, variable_names_for_factor in enumerate( + self.variable_names_for_factors + ) ] ) return variables_to_factors @@ -207,12 +210,8 @@ def __post_init__(self): if self.log_potential_matrix is None: log_potential_matrix = np.zeros( ( - self.variable_group[ - self.variable_names_for_factors[0][0] - ].num_states, - self.variable_group[ - self.variable_names_for_factors[0][1] - ].num_states, + self.vars_to_num_states[self.variable_names_for_factors[0][0]], + self.vars_to_num_states[self.variable_names_for_factors[0][1]], ) ) else: @@ -245,18 +244,11 @@ def __post_init__(self): f" {len(fac_list)} variables ({fac_list})." ) - # Note: num_states0 = self.variable_group[fac_list[0]] is 2x slower - num_states0 = self.variable_group._names_to_variables[ - fac_list[0] - ].num_states - num_states1 = self.variable_group._names_to_variables[ - fac_list[1] - ].num_states - + num_states0 = self.vars_to_num_states[fac_list[0]] + num_states1 = self.vars_to_num_states[fac_list[1]] if not log_potential_matrix.shape[-2:] == (num_states0, num_states1): raise ValueError( - f"The specified pairwise factor {fac_list} (with " - f"{(self.variable_group[fac_list[0]].num_states, self.variable_group[fac_list[1]].num_states)} " + f"The specified pairwise factor {fac_list} (with {(num_states0, num_states1)}" f"configurations) does not match the specified log_potential_matrix " f"(with {log_potential_matrix.shape[-2:]} configurations)." ) @@ -294,16 +286,19 @@ def _get_variables_to_factors( variables_to_factors = collections.OrderedDict( [ ( - frozenset(self.variable_names_for_factors[ii]), + frozenset(variable_names_for_factor), enumeration.EnumerationFactor( - variables=tuple( - self.variable_group[self.variable_names_for_factors[ii]] + vars_to_num_states=collections.OrderedDict( + (var, self.vars_to_num_states[var]) + for var in variable_names_for_factor ), factor_configs=self.factor_configs, log_potentials=self.log_potentials[ii], ), ) - for ii in range(len(self.variable_names_for_factors)) + for ii, variable_names_for_factor in enumerate( + self.variable_names_for_factors + ) ] ) return variables_to_factors diff --git a/pgmax/groups/logical.py b/pgmax/groups/logical.py index 5003df02..325a4313 100644 --- a/pgmax/groups/logical.py +++ b/pgmax/groups/logical.py @@ -34,14 +34,15 @@ def _get_variables_to_factors( variables_to_factors = collections.OrderedDict( [ ( - frozenset(self.variable_names_for_factors[ii]), + frozenset(variable_names_for_factor), self.factor_type( - variables=tuple( - self.variable_group[self.variable_names_for_factors[ii]] + vars_to_num_states=collections.OrderedDict( + (var, self.vars_to_num_states[var]) + for var in variable_names_for_factor ), ), ) - for ii in range(len(self.variable_names_for_factors)) + for variable_names_for_factor in self.variable_names_for_factors ] ) return variables_to_factors diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index 959aabbc..dd381233 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -3,17 +3,27 @@ import collections import itertools from dataclasses import dataclass -from typing import Any, Dict, Hashable, Mapping, OrderedDict, Tuple, Union +from typing import ( + Any, + Dict, + Hashable, + Mapping, + Optional, + OrderedDict, + Set, + Tuple, + Union, +) import jax import jax.numpy as jnp import numpy as np -from pgmax.fg import groups, nodes +from pgmax.utils import cached_property @dataclass(frozen=True, eq=False) -class NDVariableArray(groups.VariableGroup): +class NDVariableArray: """Subclass of VariableGroup for n-dimensional grids of variables. Args: @@ -22,27 +32,37 @@ class NDVariableArray(groups.VariableGroup): the notion of a NumPy ndarray shape) """ - num_states: int shape: Tuple[int, ...] + num_states: Union[int, np.ndarray] - def _get_names_to_variables( - self, - ) -> OrderedDict[Union[int, Tuple[int, ...]], nodes.Variable]: + def __post_init__(self): + # super().__post_init__() + + if isinstance(self.num_states, int): + num_states = np.full(self.shape, fill_value=self.num_states) + object.__setattr__(self, "num_states", num_states) + elif isinstance(self.num_states, np.ndarray): + if self.num_states.shape != self.shape: + raise ValueError("Should be same shape") + + @cached_property + def variable_names(self) -> np.ndarray: """Function that generates a dictionary mapping names to variables. Returns: a dictionary mapping all possible names to different variables. """ - names_to_variables: OrderedDict[ - Union[int, Tuple[int, ...]], nodes.Variable - ] = collections.OrderedDict() - for name in itertools.product(*[list(range(k)) for k in self.shape]): - if len(name) == 1: - names_to_variables[name[0]] = nodes.Variable(self.num_states) - else: - names_to_variables[name] = nodes.Variable(self.num_states) - - return names_to_variables + variable_names = np.empty(self.shape, dtype=int) + self_hash = self.__hash__() + for index in itertools.product(*[list(range(k)) for k in self.shape]): + name = hash((self_hash, index)) + variable_names[index] = name + return variable_names + + def __getitem__(self, val): + # Numpy indexation will throw IndexError is out-of-bounds + # This will be used to add FactorGroups + return self.variable_names[val] def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: """Function that turns meaningful structured data into a flat data array for internal use. @@ -57,12 +77,14 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: Raises: ValueError: If the data is not of the correct shape. """ - if data.shape != self.shape and data.shape != self.shape + (self.num_states,): + # TODO: what should we do for different number of states + if data.shape != self.shape and data.shape != self.shape + ( + self.num_states.max(), + ): raise ValueError( - f"data should be of shape {self.shape} or {self.shape + (self.num_states,)}. " + f"data should be of shape {self.shape} or {self.shape + (self.num_states.max(),)}. " f"Got {data.shape}." ) - return jax.device_put(data).flatten() def unflatten( @@ -89,8 +111,9 @@ def unflatten( if flat_data.size == np.product(self.shape): data = flat_data.reshape(self.shape) - elif flat_data.size == np.product(self.shape) * self.num_states: - data = flat_data.reshape(self.shape + (self.num_states,)) + elif flat_data.size == self.num_states.sum(): + # TODO: what should we dot for different number of states + data = flat_data.reshape(self.shape + (self.num_states.max(),)) else: raise ValueError( f"flat_data should be compatible with shape {self.shape} or {self.shape + (self.num_states,)}. " @@ -100,110 +123,110 @@ def unflatten( return data -@dataclass(frozen=True, eq=False) -class VariableDict(groups.VariableGroup): - """A variable dictionary that contains a set of variables of the same size - - Args: - num_states: The size of the variables in this variable group - variable_names: A tuple of all names of the variables in this variable group - - """ - - num_states: int - variable_names: Tuple[Any, ...] - - def _get_names_to_variables(self) -> OrderedDict[Tuple[int, ...], nodes.Variable]: - """Function that generates a dictionary mapping names to variables. - - Returns: - a dictionary mapping all possible names to different variables. - """ - names_to_variables: OrderedDict[ - Tuple[Any, ...], nodes.Variable - ] = collections.OrderedDict() - for name in self.variable_names: - names_to_variables[name] = nodes.Variable(self.num_states) - - return names_to_variables - - def flatten( - self, data: Mapping[Hashable, Union[np.ndarray, jnp.ndarray]] - ) -> jnp.ndarray: - """Function that turns meaningful structured data into a flat data array for internal use. - - Args: - data: Meaningful structured data. Should be a mapping with names from self.variable_names. - Each value should be an array of shape (1,) (for e.g. MAP decodings) or - (self.num_states,) (for e.g. evidence, beliefs). - - Returns: - A flat jnp.array for internal use - - Raises: - ValueError if: - (1) data is referring to a non-existing variable - (2) data is not of the correct shape - """ - for name in data: - if name not in self._names_to_variables: - raise ValueError( - f"data is referring to a non-existent variable {name}." - ) - - if data[name].shape != (self.num_states,) and data[name].shape != (1,): - raise ValueError( - f"Variable {name} expects a data array of shape " - f"{(self.num_states,)} or (1,). Got {data[name].shape}." - ) - - flat_data = jnp.concatenate([data[name].flatten() for name in self.names]) - return flat_data - - def unflatten( - self, flat_data: Union[np.ndarray, jnp.ndarray] - ) -> Dict[Hashable, Union[np.ndarray, jnp.ndarray]]: - """Function that recovers meaningful structured data from internal flat data array - - Args: - flat_data: Internal flat data array. - - Returns: - Meaningful structured data. Should be a mapping with names from self.variable_names. - Each value should be an array of shape (1,) (for e.g. MAP decodings) or - (self.num_states,) (for e.g. evidence, beliefs). - - Raises: - ValueError if: - (1) flat_data is not a 1D array - (2) flat_data is not of the right shape - """ - if flat_data.ndim != 1: - raise ValueError( - f"Can only unflatten 1D array. Got a {flat_data.ndim}D array." - ) - - num_variables = len(self.variable_names) - num_variable_states = len(self.variable_names) * self.num_states - if flat_data.shape[0] == num_variables: - use_num_states = False - elif flat_data.shape[0] == num_variable_states: - use_num_states = True - else: - raise ValueError( - f"flat_data should be either of shape (num_variables(={len(self.variables)}),), " - f"or (num_variable_states(={num_variable_states}),). " - f"Got {flat_data.shape}" - ) - - start = 0 - data = {} - for name in self.variable_names: - if use_num_states: - data[name] = flat_data[start : start + self.num_states] - start += self.num_states - else: - data[name] = flat_data[np.array([start])] - start += 1 - - return data +# @dataclass(frozen=True, eq=False) +# class VariableDict(groups.VariableGroup): +# """A variable dictionary that contains a set of variables of the same size + +# Args: +# num_states: The size of the variables in this variable group +# variable_names: A tuple of all names of the variables in this variable group + +# """ + +# num_states: int +# variable_names: Tuple[Any, ...] + +# def _get_names_to_variables(self) -> OrderedDict[Tuple[int, ...], nodes.Variable]: +# """Function that generates a dictionary mapping names to variables. + +# Returns: +# a dictionary mapping all possible names to different variables. +# """ +# names_to_variables: OrderedDict[ +# Tuple[Any, ...], nodes.Variable +# ] = collections.OrderedDict() +# for name in self.variable_names: +# names_to_variables[name] = nodes.Variable(self.num_states) + +# return names_to_variables + +# def flatten( +# self, data: Mapping[Hashable, Union[np.ndarray, jnp.ndarray]] +# ) -> jnp.ndarray: +# """Function that turns meaningful structured data into a flat data array for internal use. + +# Args: +# data: Meaningful structured data. Should be a mapping with names from self.variable_names. +# Each value should be an array of shape (1,) (for e.g. MAP decodings) or +# (self.num_states,) (for e.g. evidence, beliefs). + +# Returns: +# A flat jnp.array for internal use + +# Raises: +# ValueError if: +# (1) data is referring to a non-existing variable +# (2) data is not of the correct shape +# """ +# for name in data: +# if name not in self._names_to_variables: +# raise ValueError( +# f"data is referring to a non-existent variable {name}." +# ) + +# if data[name].shape != (self.num_states,) and data[name].shape != (1,): +# raise ValueError( +# f"Variable {name} expects a data array of shape " +# f"{(self.num_states,)} or (1,). Got {data[name].shape}." +# ) + +# flat_data = jnp.concatenate([data[name].flatten() for name in self.names]) +# return flat_data + +# def unflatten( +# self, flat_data: Union[np.ndarray, jnp.ndarray] +# ) -> Dict[Hashable, Union[np.ndarray, jnp.ndarray]]: +# """Function that recovers meaningful structured data from internal flat data array + +# Args: +# flat_data: Internal flat data array. + +# Returns: +# Meaningful structured data. Should be a mapping with names from self.variable_names. +# Each value should be an array of shape (1,) (for e.g. MAP decodings) or +# (self.num_states,) (for e.g. evidence, beliefs). + +# Raises: +# ValueError if: +# (1) flat_data is not a 1D array +# (2) flat_data is not of the right shape +# """ +# if flat_data.ndim != 1: +# raise ValueError( +# f"Can only unflatten 1D array. Got a {flat_data.ndim}D array." +# ) + +# num_variables = len(self.variable_names) +# num_variable_states = len(self.variable_names) * self.num_states +# if flat_data.shape[0] == num_variables: +# use_num_states = False +# elif flat_data.shape[0] == num_variable_states: +# use_num_states = True +# else: +# raise ValueError( +# f"flat_data should be either of shape (num_variables(={len(self.variables)}),), " +# f"or (num_variable_states(={num_variable_states}),). " +# f"Got {flat_data.shape}" +# ) + +# start = 0 +# data = {} +# for name in self.variable_names: +# if use_num_states: +# data[name] = flat_data[start : start + self.num_states] +# start += self.num_states +# else: +# data[name] = flat_data[np.array([start])] +# start += 1 + +# return data diff --git a/tests/fg/test_wiring.py b/tests/fg/test_wiring.py index 232309ee..a3b6ae2b 100644 --- a/tests/fg/test_wiring.py +++ b/tests/fg/test_wiring.py @@ -19,10 +19,10 @@ def test_wiring_with_PairwiseFactorGroup(): B = vgroup.NDVariableArray(num_states=2, shape=(10,)) # First test that compile_wiring enforces the correct factor_edges_num_states shape - fg = graph.FactorGraph(variables=dict(A=A, B=B)) + fg = graph.FactorGraph(variables=[A, B]) fg.add_factor_group( factory=enumeration.PairwiseFactorGroup, - variable_names_for_factors=[[("A", idx), ("B", idx)] for idx in range(10)], + variable_names_for_factors=[[A[idx], B[idx]] for idx in range(10)], ) factor_group = fg.factor_groups[enumeration_factor.EnumerationFactor][0] object.__setattr__( @@ -37,27 +37,27 @@ def test_wiring_with_PairwiseFactorGroup(): factor_group.compile_wiring(fg._vars_to_starts) # FactorGraph with a single PairwiseFactorGroup - fg1 = graph.FactorGraph(variables=dict(A=A, B=B)) + fg1 = graph.FactorGraph(variables=[A, B]) fg1.add_factor_group( factory=enumeration.PairwiseFactorGroup, - variable_names_for_factors=[[("A", idx), ("B", idx)] for idx in range(10)], + variable_names_for_factors=[[A[idx], B[idx]] for idx in range(10)], ) assert len(fg1.factor_groups[enumeration_factor.EnumerationFactor]) == 1 # FactorGraph with multiple PairwiseFactorGroup - fg2 = graph.FactorGraph(variables=dict(A=A, B=B)) + fg2 = graph.FactorGraph(variables=[A, B]) for idx in range(10): fg2.add_factor_group( factory=enumeration.PairwiseFactorGroup, - variable_names_for_factors=[[("A", idx), ("B", idx)]], + variable_names_for_factors=[[A[idx], B[idx]]], ) assert len(fg2.factor_groups[enumeration_factor.EnumerationFactor]) == 10 # FactorGraph with multiple SingleFactorGroup - fg3 = graph.FactorGraph(variables=dict(A=A, B=B)) + fg3 = graph.FactorGraph(variables=[A, B]) for idx in range(10): fg3.add_factor_by_type( - variable_names=[("A", idx), ("B", idx)], + variable_names=[A[idx], B[idx]], factor_type=enumeration_factor.EnumerationFactor, **{ "factor_configs": np.array([[0, 0], [0, 1], [1, 0], [1, 1]]), @@ -98,29 +98,27 @@ def test_wiring_with_ORFactorGroup(): C = vgroup.NDVariableArray(num_states=2, shape=(10,)) # FactorGraph with a single ORFactorGroup - fg1 = graph.FactorGraph(variables=dict(A=A, B=B, C=C)) + fg1 = graph.FactorGraph(variables=[A, B, C]) fg1.add_factor_group( factory=logical.ORFactorGroup, - variable_names_for_factors=[ - [("A", idx), ("B", idx), ("C", idx)] for idx in range(10) - ], + variable_names_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)], ) assert len(fg1.factor_groups[logical_factor.ORFactor]) == 1 # FactorGraph with multiple ORFactorGroup - fg2 = graph.FactorGraph(variables=dict(A=A, B=B, C=C)) + fg2 = graph.FactorGraph(variables=[A, B, C]) for idx in range(10): fg2.add_factor_group( factory=logical.ORFactorGroup, - variable_names_for_factors=[[("A", idx), ("B", idx), ("C", idx)]], + variable_names_for_factors=[[A[idx], B[idx], C[idx]]], ) assert len(fg2.factor_groups[logical_factor.ORFactor]) == 10 # FactorGraph with multiple SingleFactorGroup - fg3 = graph.FactorGraph(variables=dict(A=A, B=B, C=C)) + fg3 = graph.FactorGraph(variables=[A, B, C]) for idx in range(10): fg3.add_factor_by_type( - variable_names=[("A", idx), ("B", idx), ("C", idx)], + variable_names=[A[idx], B[idx], C[idx]], factor_type=logical_factor.ORFactor, ) assert len(fg3.factor_groups[logical_factor.ORFactor]) == 10 @@ -155,29 +153,27 @@ def test_wiring_with_ANDFactorGroup(): C = vgroup.NDVariableArray(num_states=2, shape=(10,)) # FactorGraph with a single ANDFactorGroup - fg1 = graph.FactorGraph(variables=dict(A=A, B=B, C=C)) + fg1 = graph.FactorGraph(variables=[A, B, C]) fg1.add_factor_group( factory=logical.ANDFactorGroup, - variable_names_for_factors=[ - [("A", idx), ("B", idx), ("C", idx)] for idx in range(10) - ], + variable_names_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)], ) assert len(fg1.factor_groups[logical_factor.ANDFactor]) == 1 # FactorGraph with multiple ANDFactorGroup - fg2 = graph.FactorGraph(variables=dict(A=A, B=B, C=C)) + fg2 = graph.FactorGraph(variables=[A, B, C]) for idx in range(10): fg2.add_factor_group( factory=logical.ANDFactorGroup, - variable_names_for_factors=[[("A", idx), ("B", idx), ("C", idx)]], + variable_names_for_factors=[[A[idx], B[idx], C[idx]]], ) assert len(fg2.factor_groups[logical_factor.ANDFactor]) == 10 # FactorGraph with multiple SingleFactorGroup - fg3 = graph.FactorGraph(variables=dict(A=A, B=B, C=C)) + fg3 = graph.FactorGraph(variables=[A, B, C]) for idx in range(10): fg3.add_factor_by_type( - variable_names=[("A", idx), ("B", idx), ("C", idx)], + variable_names=[A[idx], B[idx], C[idx]], factor_type=logical_factor.ANDFactor, ) assert len(fg3.factor_groups[logical_factor.ANDFactor]) == 10 From d816f63a1f74ddf018d1d9fa30f83d094a7a2c57 Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Tue, 12 Apr 2022 18:10:02 +0000 Subject: [PATCH 02/35] Falke8 --- pgmax/fg/graph.py | 24 +++++++++++------------- pgmax/fg/groups.py | 5 ----- pgmax/groups/variables.py | 13 +------------ 3 files changed, 12 insertions(+), 30 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index f0a0aefe..ec03714c 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -1,7 +1,5 @@ from __future__ import annotations -from this import d - """A module containing the core class to specify a Factor Graph.""" import collections @@ -102,9 +100,7 @@ def __post_init__(self): print("2", time.time() - start) # Used to add FactorGroups - # TODO: dict is faster - aa = zip(vars_names, vars_num_states) - print("3", time.time() - start) + # TODO: move to dict, which is faster self._vars_to_num_states: OrderedDict[int, int] = collections.OrderedDict( zip(vars_names, vars_num_states) ) @@ -434,12 +430,14 @@ class FactorGraphState: wiring: OrderedDict[type, nodes.Wiring] def __post_init__(self): - for field in self.__dataclass_fields__: - if isinstance(getattr(self, field), np.ndarray): - getattr(self, field).flags.writeable = False + for this_field in self.__dataclass_fields__: + if isinstance(getattr(self, this_field), np.ndarray): + getattr(self, this_field).flags.writeable = False - if isinstance(getattr(self, field), Mapping): - object.__setattr__(self, field, MappingProxyType(getattr(self, field))) + if isinstance(getattr(self, this_field), Mapping): + object.__setattr__( + self, this_field, MappingProxyType(getattr(self, this_field)) + ) @dataclass(frozen=True, eq=False) @@ -848,9 +846,9 @@ class BPArrays: evidence: Union[np.ndarray, jnp.ndarray] def __post_init__(self): - for field in self.__dataclass_fields__: - if isinstance(getattr(self, field), np.ndarray): - getattr(self, field).flags.writeable = False + for this_field in self.__dataclass_fields__: + if isinstance(getattr(self, this_field), np.ndarray): + getattr(self, this_field).flags.writeable = False def tree_flatten(self): return jax.tree_util.tree_flatten(asdict(self)) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index c90fe106..a08541b0 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -1,15 +1,10 @@ """A module containing the base classes for variable and factor groups in a Factor Graph.""" -import collections import inspect -import typing from dataclasses import dataclass, field -from types import MappingProxyType from typing import ( Any, - Collection, FrozenSet, - Hashable, List, Mapping, OrderedDict, diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index dd381233..1ca5d4bf 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -1,19 +1,8 @@ """A module containing the variables group classes inheriting from the base VariableGroup.""" -import collections import itertools from dataclasses import dataclass -from typing import ( - Any, - Dict, - Hashable, - Mapping, - Optional, - OrderedDict, - Set, - Tuple, - Union, -) +from typing import Tuple, Union import jax import jax.numpy as jnp From 58b011535233f905e91c84a2251d57e65a49335b Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Tue, 12 Apr 2022 23:54:05 +0000 Subject: [PATCH 03/35] Test + HashableDict --- examples/ising_model.py | 16 ++--- examples/pmp_binary_deconvolution.py | 41 +++++++------ examples/rbm.py | 51 ++++++++++++++-- pgmax/factors/enumeration.py | 3 - pgmax/fg/graph.py | 71 +++++++++++++++------- pgmax/fg/groups.py | 15 +++++ pgmax/groups/variables.py | 37 +++++------- tests/factors/test_or.py | 89 +++++++++++++++++----------- 8 files changed, 210 insertions(+), 113 deletions(-) diff --git a/examples/ising_model.py b/examples/ising_model.py index ef5a267b..837c764e 100644 --- a/examples/ising_model.py +++ b/examples/ising_model.py @@ -46,25 +46,19 @@ log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]), ) -# %% -from pgmax.factors import enumeration as enumeration_factor - -factors = fg.factor_groups[enumeration_factor.EnumerationFactor][0].factors - # %% [markdown] # ### Run inference and visualize results -import imp - # %% -from pgmax.fg import graph - -imp.reload(graph) bp = graph.BP(fg.bp_state, temperature=0) +# %% +d = {variables: 1, variables.__hash__(): 2} +hd = graph.HashableDict(d) +hd[variables] + # %% # TODO: check\ time for before BP vs time for BP -# TODO: why bug when done twice? # TODO: time PGMAX vs PMP bp_arrays = bp.init( evidence_updates={variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))} diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py index b10f5688..a5c91c22 100644 --- a/examples/pmp_binary_deconvolution.py +++ b/examples/pmp_binary_deconvolution.py @@ -106,6 +106,8 @@ def plot_images(images, display=True, nr=None): # - a second set of ORFactors, which maps SW to X and model (binary) features overlapping. # # See Section 5.6 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) for more details. +# +# import imp import imp @@ -150,7 +152,7 @@ def plot_images(images, display=True, nr=None): start = time.time() # Factor graph fg = graph.FactorGraph(variables=[S, W, SW, X]) -print("x", time.time() - start) +print(time.time() - start) # Define the ANDFactors variable_names_for_ANDFactors = [] @@ -232,7 +234,7 @@ def plot_images(images, display=True, nr=None): # %% pW = 0.25 -pS = 1e-100 +pS = 1e-70 pX = 1e-100 # Sparsity inducing priors for W and S @@ -248,26 +250,28 @@ def plot_images(images, display=True, nr=None): # %% np.random.seed(seed=40) -n_samples = 1 start = time.time() -bp_arrays = jax.vmap(bp.init, in_axes=0, out_axes=0)( +bp_arrays = bp.init( evidence_updates={ - "S": uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape), - "W": uW[None] + np.random.gumbel(size=(n_samples,) + uW.shape), - "SW": np.zeros(shape=(n_samples,) + SW.shape), - "X": uX[None] + np.zeros(shape=(n_samples,) + uX.shape), + S: uS + np.random.gumbel(size=uS.shape), + W: uW + np.random.gumbel(size=uW.shape), + SW: np.zeros(SW.shape), + X: uX + np.zeros(shape=uX.shape), }, ) print("Time", time.time() - start) -bp_arrays = jax.vmap( - functools.partial(bp.run_bp, num_iters=100, damping=0.5), - in_axes=0, - out_axes=0, -)(bp_arrays) +bp_arrays = bp.run_bp(bp_arrays, num_iters=100, damping=0.5) print("Time", time.time() - start) -# beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays) -# map_states = graph.decode_map_states(beliefs) +output = bp.get_bp_output(bp_arrays) +print("Time", time.time() - start) + +# %% +_ = plot_images(output.map_states[W].reshape(-1, feat_height, feat_width), nr=1) + +# %% + +# %% # %% [markdown] # We draw a batch of samples from the posterior in parallel by transforming `run_bp`/`get_beliefs` with `jax.vmap` @@ -276,6 +280,7 @@ def plot_images(images, display=True, nr=None): np.random.seed(seed=40) n_samples = 4 +start = time.time() bp_arrays = jax.vmap(bp.init, in_axes=0, out_axes=0)( evidence_updates={ "S": uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape), @@ -284,13 +289,15 @@ def plot_images(images, display=True, nr=None): "X": uX[None] + np.zeros(shape=(n_samples,) + uX.shape), }, ) +print("Time", time.time() - start) bp_arrays = jax.vmap( functools.partial(bp.run_bp, num_iters=100, damping=0.5), in_axes=0, out_axes=0, )(bp_arrays) -beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays) -map_states = graph.decode_map_states(beliefs) +print("Time", time.time() - start) +# beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays) +# map_states = graph.decode_map_states(beliefs) # %% [markdown] # Visualizing the MAP decoding, we see that we have 4 good random samples (one per row) from the posterior! diff --git a/examples/rbm.py b/examples/rbm.py index e4d26a83..0f820851 100644 --- a/examples/rbm.py +++ b/examples/rbm.py @@ -61,6 +61,49 @@ fg = graph.FactorGraph(variables=[hidden_variables, visible_variables]) print("Time", time.time() - start) +# %% +import itertools + +start = time.time() +variable_names_for_factors = factors = list( + map( + lambda ij: ( + hidden_variables.variable_names[ij[0]], + visible_variables.variable_names[ij[1]], + ), + list(itertools.product(range(bh.shape[0]), range(bv.shape[0]))), + ) +) +print("Time", time.time() - start, len(variable_names_for_factors)) + +import numba as nb + + +@nb.jit(parallel=False, cache=True, fastmath=True, nopython=True) +def run_numba(h, v, f): + for h_idx in nb.prange(h.shape[0]): + for v_idx in nb.prange(v.shape[0]): + f[h_idx * v.shape[0] + v_idx, 0] = h[h_idx] + f[h_idx * v.shape[0] + v_idx, 1] = v[v_idx] + + +start = time.time() +variable_names_for_factors = np.empty(shape=(bv.shape[0] * bh.shape[0], 2), dtype=int) +run_numba( + hidden_variables.variable_names, + visible_variables.variable_names, + variable_names_for_factors, +) +print("Time", time.time() - start, len(variable_names_for_factors)) + +start = time.time() +variable_names_for_factors = [ + [hidden_variables[ii], visible_variables[jj]] + for ii in range(bh.shape[0]) + for jj in range(bv.shape[0]) +] +print("Time", time.time() - start, len(variable_names_for_factors)) + # %% [markdown] # [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup) (e.g. an [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)), or a list/dictionary of [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed. # @@ -82,7 +125,6 @@ factor_configs=np.arange(2)[:, None], log_potentials=np.stack([np.zeros_like(bv), bv], axis=1), ) - # Add pairwise factors log_potential_matrix = np.zeros(W.shape + (2, 2)).reshape((-1, 2, 2)) log_potential_matrix[:, 1, 1] = W.ravel() @@ -94,10 +136,11 @@ for ii in range(bh.shape[0]) for jj in range(bv.shape[0]) ], + # variable_names_for_factors=variable_names_for_factors, log_potential_matrix=log_potential_matrix, ) -# fg.add_factor_group(factory=enumeration.PairwiseFactorGroup, variable_names_for_factors=[[hidden_variables[ii], visible_variables[jj]]for ii in range(bh.shape[0])for jj in range(bv.shape[0])], log_potential_matrix=log_potential_matrix,) +# # %snakeviz fg.add_factor_group(factory=enumeration.PairwiseFactorGroup, variable_names_for_factors=v, log_potential_matrix=log_potential_matrix,) print("Time", time.time() - start) @@ -168,7 +211,7 @@ evidence_updates={ hidden_variables: np.random.gumbel(size=(bh.shape[0], 2)), visible_variables: np.random.gumbel(size=(bv.shape[0], 2)), - }, + } ) print("Time", time.time() - start) bp_arrays = bp.run_bp(bp_arrays, num_iters=100, damping=0.5) @@ -221,11 +264,11 @@ out_axes=0, )(bp_arrays) +# TODO: problem outputs = jax.vmap(bp.get_bp_output, in_axes=0, out_axes=0)(bp_arrays) # map_states = graph.decode_map_states(beliefs) # %% -O # %% [markdown] # Visualizing the MAP decodings (Figure [fig:rbm_multiple_digits]), we see that we have sampled 10 MNIST digits in parallel! diff --git a/pgmax/factors/enumeration.py b/pgmax/factors/enumeration.py index e2521597..450cc5ab 100644 --- a/pgmax/factors/enumeration.py +++ b/pgmax/factors/enumeration.py @@ -184,9 +184,6 @@ def compile_wiring( var_states = np.array( [vars_to_starts[variable] for variable in variables_for_factors] ) - # num_states = np.array( - # [variable.num_states for variable in variables_for_factors] - # ) num_states_cumsum = np.insert(np.cumsum(factor_edges_num_states), 0, 0) var_states_for_edges = np.empty(shape=(num_states_cumsum[-1],), dtype=int) _compile_var_states_numba(var_states_for_edges, num_states_cumsum, var_states) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index ec03714c..c144660f 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -127,7 +127,6 @@ def add_factor( variable_names: List, factor_configs: np.ndarray, log_potentials: Optional[np.ndarray] = None, - name: Optional[str] = None, ) -> None: """Function to add a single factor to the FactorGraph. @@ -151,7 +150,7 @@ def add_factor( factor_configs=factor_configs, log_potentials=log_potentials, ) - self._register_factor_group(factor_group, name) + self._register_factor_group(factor_group) def add_factor_by_type( self, variable_names: List[int], factor_type: type, *args, **kwargs @@ -221,7 +220,6 @@ def _register_factor_group(self, factor_group: groups.FactorGroup) -> None: var_names in self._factor_types_to_variable_names_for_factors[factor_type] ): - print(len(var_names_for_factor)) raise ValueError( f"A Factor of type {factor_type} involving variables {var_names} already exists. Please merge the corresponding factors." ) @@ -492,8 +490,7 @@ def update_log_potentials( (1) Provided log_potentials shape does not match the expected log_potentials shape. (2) Provided name is not valid for log_potentials updates. """ - for name in updates: - data = updates[name] + for name, data in updates.items(): if name in fg_state.named_factor_groups: factor_group = fg_state.named_factor_groups[name] @@ -613,8 +610,7 @@ def update_ftov_msgs( (1) provided ftov_msgs shape does not match the expected ftov_msgs shape. (2) provided name is not valid for ftov_msgs updates. """ - for names in updates: - data = updates[names] + for names, data in updates.items(): if names in fg_state.variable_group.names: variable = fg_state.variable_group[names] if data.shape != (variable.num_states,): @@ -729,7 +725,7 @@ def __setitem__(self, names, data) -> None: ) -# @functools.partial(jax.jit, static_argnames="fg_state") +@functools.partial(jax.jit, static_argnames="fg_state") def update_evidence( evidence: jnp.ndarray, updates: Dict[Any, jnp.ndarray], fg_state: FactorGraphState ) -> jnp.ndarray: @@ -743,11 +739,9 @@ def update_evidence( Returns: A flat jnp array containing updated evidence. """ - for var_group_name in updates: - data = updates[var_group_name] - print(data.shape) - if var_group_name.__hash__() in fg_state.variable_group: - variable_group = fg_state.variable_group[var_group_name.__hash__()] + for var_group_name, data in updates.items(): + if var_group_name in fg_state.variable_group: + variable_group = fg_state.variable_group[var_group_name] first_variable = variable_group.variable_names.flatten()[0] start_index = fg_state.vars_to_starts[first_variable] flat_data = variable_group.flatten(data) @@ -974,17 +968,19 @@ def update( if log_potentials_updates is not None: log_potentials = update_log_potentials( - log_potentials, log_potentials_updates, bp_state.fg_state + log_potentials, HashableDict(log_potentials_updates), bp_state.fg_state ) if ftov_msgs_updates is not None: ftov_msgs = update_ftov_msgs( - ftov_msgs, ftov_msgs_updates, bp_state.fg_state + ftov_msgs, HashableDict(ftov_msgs_updates), bp_state.fg_state ) if evidence_updates is not None: - print(type(evidence), type(evidence_updates)) - evidence = update_evidence(evidence, evidence_updates, bp_state.fg_state) + # Note: if we overwrite variables.__hash__ then hash may change when we call jit + evidence = update_evidence( + evidence, HashableDict(evidence_updates), bp_state.fg_state + ) return BPArrays( log_potentials=log_potentials, ftov_msgs=ftov_msgs, evidence=evidence @@ -1082,7 +1078,7 @@ def get_bp_output(bp_arrays: BPArrays) -> Any: """ @jax.jit - def compute_flat_beliefs(bp_arrays): + def compute_flat_beliefs(bp_arrays, var_states_for_edges): flat_beliefs = ( jax.device_put(bp_arrays.evidence) .at[jax.device_put(var_states_for_edges)] @@ -1091,7 +1087,8 @@ def compute_flat_beliefs(bp_arrays): return flat_beliefs return BeliefPropagationOutputs( - compute_flat_beliefs(bp_arrays), bp_state.fg_state.variable_group + compute_flat_beliefs(bp_arrays, var_states_for_edges), + bp_state.fg_state.variable_group, ) bp = BeliefPropagation( @@ -1104,23 +1101,47 @@ def compute_flat_beliefs(bp_arrays): return bp +@jax.tree_util.register_pytree_node_class @dataclass(frozen=True, eq=False) class HashableDict: + "A convenient class" d: Dict = field(default_factory=dict) + def __post_init__(self): + # Allows to initialize from dict without knowing hash + new_d = {k.__hash__(): v for k, v in self.d.items()} + object.__setattr__(self, "d", new_d) + + @functools.lru_cache() + def keys(self): + return self.d.keys() + + def items(self): + return self.d.items() + def __setitem__(self, key, value): + # Allows to copy keys from another HashableDict self.d[key] = value def __getitem__(self, value): + # Allows to retrieve from dict without knowing hash return self.d[value.__hash__()] + def tree_flatten(self): + return jax.tree_util.tree_flatten(asdict(self)) + @classmethod + def tree_unflatten(cls, aux_data, children): + return cls(**aux_data.unflatten(children)) + + +@jax.tree_util.register_pytree_node_class @dataclass(frozen=True, eq=False) class BeliefPropagationOutputs: # beliefs: An array or a PyTree container containing the beliefs for the variables. flat_beliefs: jnp.ndarray variable_groups: Mapping[int, groups.VariableGroup] - beliefs: HashableDict = field(init=False, default_factory=dict) + beliefs: HashableDict = field(init=False) def __post_init__(self): if self.flat_beliefs.ndim != 1: @@ -1162,7 +1183,7 @@ def unflatten(self) -> None: f"Got {self.flat_beliefs.shape}" ) - beliefs = {} + beliefs = HashableDict() start = 0 for name, variable_group in self.variable_groups.items(): if use_num_states: @@ -1218,3 +1239,11 @@ def _get_marginals(beliefs) -> Any: for name, beliefs in self.beliefs.items(): marginals[name] = _get_marginals(beliefs) return marginals + + def tree_flatten(self): + return jax.tree_util.tree_flatten(asdict(self)) + + @classmethod + def tree_unflatten(cls, aux_data, children): + # TODO: fix this + return cls(**aux_data.unflatten(children)) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index a08541b0..f9f3dd41 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -311,3 +311,18 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any: raise NotImplementedError( "SingleFactorGroup does not support vectorized factor operations." ) + + +# def get_ndvariable_names_for_factor_groups( +# arrays: List[Any] + +# ): +# "Util function" + +# import numba as nb +# @nb.jit(parallel=False, cache=True, fastmath=True, nopython=True) +# def run_numba(h, v, f): +# for h_idx in nb.prange(h.shape[0]): +# for v_idx in nb.prange(v.shape[0]): +# f[h_idx * v.shape[0] + v_idx, 0] = h[h_idx] +# f[h_idx * v.shape[0] + v_idx, 1] = v[v_idx] diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index 1ca5d4bf..4c6587f3 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -1,6 +1,7 @@ """A module containing the variables group classes inheriting from the base VariableGroup.""" import itertools +import random from dataclasses import dataclass from typing import Tuple, Union @@ -34,6 +35,10 @@ def __post_init__(self): if self.num_states.shape != self.shape: raise ValueError("Should be same shape") + def __getitem__(self, val): + # Numpy indexation will throw IndexError for us if out-of-bounds + return self.variable_names[val] + @cached_property def variable_names(self) -> np.ndarray: """Function that generates a dictionary mapping names to variables. @@ -41,17 +46,10 @@ def variable_names(self) -> np.ndarray: Returns: a dictionary mapping all possible names to different variables. """ - variable_names = np.empty(self.shape, dtype=int) - self_hash = self.__hash__() - for index in itertools.product(*[list(range(k)) for k in self.shape]): - name = hash((self_hash, index)) - variable_names[index] = name - return variable_names - - def __getitem__(self, val): - # Numpy indexation will throw IndexError is out-of-bounds - # This will be used to add FactorGroups - return self.variable_names[val] + # Overwite default hash as it does not give enough spacing across consecutive objects + this_hash = random.randint(0, 2**63) + indices = np.reshape(np.arange(np.product(self.shape)), self.shape) + return this_hash + indices def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: """Function that turns meaningful structured data into a flat data array for internal use. @@ -112,32 +110,29 @@ def unflatten( return data +# TODO: delete? # @dataclass(frozen=True, eq=False) -# class VariableDict(groups.VariableGroup): +# class VariableDict(): # """A variable dictionary that contains a set of variables of the same size # Args: # num_states: The size of the variables in this variable group -# variable_names: A tuple of all names of the variables in this variable group +# num_variables: The number of variables # """ # num_states: int -# variable_names: Tuple[Any, ...] +# num_variables: int -# def _get_names_to_variables(self) -> OrderedDict[Tuple[int, ...], nodes.Variable]: +# @cached_property +# def variable_names(self) -> np.ndarray: # """Function that generates a dictionary mapping names to variables. # Returns: # a dictionary mapping all possible names to different variables. # """ -# names_to_variables: OrderedDict[ -# Tuple[Any, ...], nodes.Variable -# ] = collections.OrderedDict() -# for name in self.variable_names: -# names_to_variables[name] = nodes.Variable(self.num_states) +# return self.__hash__() + np.arange(self.num_states) -# return names_to_variables # def flatten( # self, data: Mapping[Hashable, Union[np.ndarray, jnp.ndarray]] diff --git a/tests/factors/test_or.py b/tests/factors/test_or.py index 1e538a8d..e634fbab 100644 --- a/tests/factors/test_or.py +++ b/tests/factors/test_or.py @@ -25,6 +25,7 @@ def test_run_bp_with_OR_factors(): Note: for the first seed, add all the EnumerationFactors to FG1 and all the ORFactors to FG2 """ for idx in range(10): + print("it", idx) np.random.seed(idx) # Parameters @@ -43,30 +44,41 @@ def test_run_bp_with_OR_factors(): parents_variables1 = vgroup.NDVariableArray( num_states=2, shape=(num_parents.sum(),) ) - children_variable1 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,)) - fg1 = graph.FactorGraph( - variables=dict(parents=parents_variables1, children=children_variable1) - ) + children_variables1 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,)) + fg1 = graph.FactorGraph(variables=[parents_variables1, children_variables1]) # Graph 2 parents_variables2 = vgroup.NDVariableArray( num_states=2, shape=(num_parents.sum(),) ) - children_variable2 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,)) - fg2 = graph.FactorGraph( - variables=dict(parents=parents_variables2, children=children_variable2) - ) + children_variables2 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,)) + fg2 = graph.FactorGraph(variables=[parents_variables2, children_variables2]) - # Option 1: Define EnumerationFactors equivalent to the ORFactors + # Variable names for factors + variable_names_for_factors1 = [] + variable_names_for_factors2 = [] for factor_idx in range(num_factors): - this_num_parents = num_parents[factor_idx] - variable_names = [ - ("parents", idx) + variable_names1 = [ + parents_variables1[idx] + for idx in range( + num_parents_cumsum[factor_idx], + num_parents_cumsum[factor_idx + 1], + ) + ] + [children_variables1[factor_idx]] + variable_names_for_factors1.append(variable_names1) + + variable_names2 = [ + parents_variables2[idx] for idx in range( num_parents_cumsum[factor_idx], num_parents_cumsum[factor_idx + 1], ) - ] + [("children", factor_idx)] + ] + [children_variables2[factor_idx]] + variable_names_for_factors2.append(variable_names2) + + # Option 1: Define EnumerationFactors equivalent to the ORFactors + for factor_idx in range(num_factors): + this_num_parents = num_parents[factor_idx] configs = np.array(list(product([0, 1], repeat=this_num_parents + 1))) # Children state is last @@ -82,7 +94,7 @@ def test_run_bp_with_OR_factors(): if factor_idx < num_factors // 2: # Add the first half of factors to FactorGraph1 fg1.add_factor( - variable_names=variable_names, + variable_names=variable_names_for_factors1[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) @@ -90,14 +102,14 @@ def test_run_bp_with_OR_factors(): if idx != 0: # Add the second half of factors to FactorGraph2 fg2.add_factor( - variable_names=variable_names, + variable_names=variable_names_for_factors2[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) else: # Add all the EnumerationFactors to FactorGraph1 for the first iter fg1.add_factor( - variable_names=variable_names, + variable_names=variable_names_for_factors1[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) @@ -108,28 +120,22 @@ def test_run_bp_with_OR_factors(): variable_names_for_ORFactors_fg2 = [] for factor_idx in range(num_factors): - variables_names_for_ORFactor = [ - ("parents", idx) - for idx in range( - num_parents_cumsum[factor_idx], - num_parents_cumsum[factor_idx + 1], - ) - ] + [("children", factor_idx)] if factor_idx < num_factors // 2: # Add the first half of factors to FactorGraph2 - variable_names_for_ORFactors_fg2.append(variables_names_for_ORFactor) + variable_names_for_ORFactors_fg2.append( + variable_names_for_factors2[factor_idx] + ) else: if idx != 0: # Add the second half of factors to FactorGraph1 variable_names_for_ORFactors_fg1.append( - variables_names_for_ORFactor + variable_names_for_factors1[factor_idx] ) else: # Add all the ORFactors to FactorGraph2 for the first iter variable_names_for_ORFactors_fg2.append( - variables_names_for_ORFactor + variable_names_for_factors2[factor_idx] ) - if idx != 0: fg1.add_factor_group( factory=logical.ORFactorGroup, @@ -144,19 +150,30 @@ def test_run_bp_with_OR_factors(): bp1 = graph.BP(fg1.bp_state, temperature=temperature) bp2 = graph.BP(fg2.bp_state, temperature=temperature) - evidence_updates = { - "parents": jax.device_put(np.random.gumbel(size=(sum(num_parents), 2))), - "children": jax.device_put(np.random.gumbel(size=(num_factors, 2))), + evidence_parents = jax.device_put(np.random.gumbel(size=(sum(num_parents), 2))) + evidence_children = jax.device_put(np.random.gumbel(size=(num_factors, 2))) + + evidence_updates1 = { + parents_variables1: evidence_parents, + children_variables1: evidence_children, + } + evidence_updates2 = { + parents_variables2: evidence_parents, + children_variables2: evidence_children, } - bp_arrays1 = bp1.init(evidence_updates=evidence_updates) + bp_arrays1 = bp1.init(evidence_updates=evidence_updates1) bp_arrays1 = bp1.run_bp(bp_arrays1, num_iters=5) - bp_arrays2 = bp2.init(evidence_updates=evidence_updates) + bp_arrays2 = bp2.init(evidence_updates=evidence_updates2) bp_arrays2 = bp2.run_bp(bp_arrays2, num_iters=5) # Get beliefs - beliefs1 = bp1.get_beliefs(bp_arrays1) - beliefs2 = bp2.get_beliefs(bp_arrays2) + beliefs1 = bp1.get_bp_output(bp_arrays1).beliefs + beliefs2 = bp2.get_bp_output(bp_arrays2).beliefs - assert np.allclose(beliefs1["children"], beliefs2["children"], atol=1e-4) - assert np.allclose(beliefs1["parents"], beliefs2["parents"], atol=1e-4) + assert np.allclose( + beliefs1[children_variables1], beliefs2[children_variables2], atol=1e-4 + ) + assert np.allclose( + beliefs1[parents_variables1], beliefs2[parents_variables2], atol=1e-4 + ) From c8f2c5e22f2327a81c386440cd2901d2497445fd Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Wed, 13 Apr 2022 01:00:51 +0000 Subject: [PATCH 04/35] Minbor --- examples/pmp_binary_deconvolution.py | 2 ++ examples/rbm.py | 2 ++ pgmax/fg/graph.py | 2 +- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py index a5c91c22..c61fefcf 100644 --- a/examples/pmp_binary_deconvolution.py +++ b/examples/pmp_binary_deconvolution.py @@ -108,6 +108,8 @@ def plot_images(images, display=True, nr=None): # See Section 5.6 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) for more details. # # import imp +# +# import imp import imp diff --git a/examples/rbm.py b/examples/rbm.py index 0f820851..3dc6a61f 100644 --- a/examples/rbm.py +++ b/examples/rbm.py @@ -44,6 +44,8 @@ # %% [markdown] # We can then initialize the factor graph for the RBM with +# +# import imp import imp diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index c144660f..be4c026a 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -1104,7 +1104,7 @@ def compute_flat_beliefs(bp_arrays, var_states_for_edges): @jax.tree_util.register_pytree_node_class @dataclass(frozen=True, eq=False) class HashableDict: - "A convenient class" + "Represents a dictionnary where stored keys are the hash of the elements" d: Dict = field(default_factory=dict) def __post_init__(self): From 39b546a9fa88d8779e8b01ceea2858ec1c71761b Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Wed, 20 Apr 2022 22:31:42 +0000 Subject: [PATCH 05/35] Variables as tuple + Remove BPOuputs/HashableDict --- examples/ising_model.py | 38 ++- examples/pmp_binary_deconvolution.py | 67 ++--- examples/rbm.py | 29 +-- pgmax/factors/enumeration.py | 32 +-- pgmax/factors/logical.py | 36 +-- pgmax/fg/graph.py | 352 +++++++++++---------------- pgmax/fg/groups.py | 56 ++--- pgmax/fg/nodes.py | 8 +- pgmax/groups/enumeration.py | 46 ++-- pgmax/groups/logical.py | 11 +- pgmax/groups/variables.py | 36 ++- tests/factors/test_and.py | 99 ++++---- tests/factors/test_or.py | 40 ++- tests/fg/test_wiring.py | 26 +- 14 files changed, 394 insertions(+), 482 deletions(-) diff --git a/examples/ising_model.py b/examples/ising_model.py index 837c764e..88535c23 100644 --- a/examples/ising_model.py +++ b/examples/ising_model.py @@ -31,19 +31,19 @@ variables = vgroup.NDVariableArray(num_states=2, shape=(50, 50)) fg = graph.FactorGraph(variables=variables) -# TODO: rename variable_for_factors? -variable_names_for_factors = [] +variables_for_factors = [] for ii in range(50): for jj in range(50): kk = (ii + 1) % 50 ll = (jj + 1) % 50 - variable_names_for_factors.append([variables[ii, jj], variables[kk, jj]]) - variable_names_for_factors.append([variables[ii, jj], variables[kk, ll]]) + variables_for_factors.append([variables[ii, jj], variables[kk, jj]]) + variables_for_factors.append([variables[ii, jj], variables[kk, ll]]) fg.add_factor_group( factory=enumeration.PairwiseFactorGroup, - variable_names_for_factors=variable_names_for_factors, + variables_for_factors=variables_for_factors, log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]), + name="factors", ) # %% [markdown] @@ -52,11 +52,6 @@ # %% bp = graph.BP(fg.bp_state, temperature=0) -# %% -d = {variables: 1, variables.__hash__(): 2} -hd = graph.HashableDict(d) -hd[variables] - # %% # TODO: check\ time for before BP vs time for BP # TODO: time PGMAX vs PMP @@ -64,10 +59,10 @@ evidence_updates={variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))} ) bp_arrays = bp.run_bp(bp_arrays, num_iters=3000) -output = bp.get_bp_output(bp_arrays) +beliefs = bp.get_beliefs(bp_arrays) # %% -img = output.map_states[variables] +img = graph.decode_map_states(beliefs)[variables] fig, ax = plt.subplots(1, 1, figsize=(10, 10)) ax.imshow(img) @@ -82,19 +77,20 @@ def loss(log_potentials_updates, evidence_updates): ) bp_arrays = bp.run_bp(bp_arrays, num_iters=3000) beliefs = bp.get_beliefs(bp_arrays) - loss = -jnp.sum(beliefs) + loss = -jnp.sum(beliefs[variables]) return loss -batch_loss = jax.jit(jax.vmap(loss, in_axes=(None, {None: 0}), out_axes=0)) +batch_loss = jax.jit(jax.vmap(loss, in_axes=(None, {variables: 0}), out_axes=0)) log_potentials_grads = jax.jit(jax.grad(loss, argnums=0)) # %% -batch_loss(None, {None: jax.device_put(np.random.gumbel(size=(10, 50, 50, 2)))}) +batch_loss(None, {variables: jax.device_put(np.random.gumbel(size=(10, 50, 50, 2)))}) # %% grads = log_potentials_grads( - {"factors": jnp.eye(2)}, {None: jax.device_put(np.random.gumbel(size=(50, 50, 2)))} + {"factors": jnp.eye(2)}, + {variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}, ) # %% [markdown] @@ -104,15 +100,15 @@ def loss(log_potentials_updates, evidence_updates): bp_state = bp.to_bp_state(bp_arrays) # Query evidence for variable (0, 0) -bp_state.evidence[0, 0] +bp_state.evidence[variables[0, 0]] # %% # Set evidence for variable (0, 0) -bp_state.evidence[0, 0] = np.array([1.0, 1.0]) -bp_state.evidence[0, 0] +bp_state.evidence[variables[0, 0]] = np.array([1.0, 1.0]) +bp_state.evidence[variables[0, 0]] # %% # Set evidence for all variables using an array evidence = np.random.randn(50, 50, 2) -bp_state.evidence[None] = evidence -bp_state.evidence[10, 10] == evidence[10, 10] +bp_state.evidence[variables] = evidence +np.allclose(bp_state.evidence[variables[10, 10]], evidence[10, 10]) diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py index c61fefcf..bd418a69 100644 --- a/examples/pmp_binary_deconvolution.py +++ b/examples/pmp_binary_deconvolution.py @@ -106,10 +106,6 @@ def plot_images(images, display=True, nr=None): # - a second set of ORFactors, which maps SW to X and model (binary) features overlapping. # # See Section 5.6 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) for more details. -# -# import imp -# -# import imp import imp @@ -157,8 +153,8 @@ def plot_images(images, display=True, nr=None): print(time.time() - start) # Define the ANDFactors -variable_names_for_ANDFactors = [] -variable_names_for_ORFactors_dict = defaultdict(list) +variables_for_ANDFactors = [] +variables_for_ORFactors_dict = defaultdict(list) for idx_img in tqdm(range(n_images)): for idx_chan in range(n_chan): for idx_s_height in range(s_height): @@ -178,36 +174,34 @@ def plot_images(images, display=True, nr=None): idx_feat_width, ] - variable_names_for_ANDFactor = [ + variables_for_ANDFactor = [ S[idx_img, idx_feat, idx_s_height, idx_s_width], W[idx_chan, idx_feat, idx_feat_height, idx_feat_width], SW_var, ] - variable_names_for_ANDFactors.append( - variable_names_for_ANDFactor - ) + variables_for_ANDFactors.append(variables_for_ANDFactor) X_var = X[idx_img, idx_chan, idx_img_height, idx_img_width] - variable_names_for_ORFactors_dict[X_var].append(SW_var) + variables_for_ORFactors_dict[X_var].append(SW_var) print(time.time() - start) # Add ANDFactorGroup, which is computationally efficient fg.add_factor_group( factory=logical.ANDFactorGroup, - variable_names_for_factors=variable_names_for_ANDFactors, + variables_for_factors=variables_for_ANDFactors, ) print(time.time() - start) # Define the ORFactors -variable_names_for_ORFactors = [ - list(tuple(variable_names_for_ORFactors_dict[X_var]) + (X_var,)) - for X_var in variable_names_for_ORFactors_dict +variables_for_ORFactors = [ + list(tuple(variables_for_ORFactors_dict[X_var]) + (X_var,)) + for X_var in variables_for_ORFactors_dict ] # Add ORFactorGroup, which is computationally efficient fg.add_factor_group( factory=logical.ORFactorGroup, - variable_names_for_factors=variable_names_for_ORFactors, + variables_for_factors=variables_for_ORFactors, ) print("Time", time.time() - start) @@ -236,7 +230,7 @@ def plot_images(images, display=True, nr=None): # %% pW = 0.25 -pS = 1e-70 +pS = 1e-72 pX = 1e-100 # Sparsity inducing priors for W and S @@ -250,31 +244,6 @@ def plot_images(images, display=True, nr=None): uX = np.zeros((X_gt.shape) + (2,)) uX[..., 0] = (2 * X_gt - 1) * logit(pX) -# %% -np.random.seed(seed=40) - -start = time.time() -bp_arrays = bp.init( - evidence_updates={ - S: uS + np.random.gumbel(size=uS.shape), - W: uW + np.random.gumbel(size=uW.shape), - SW: np.zeros(SW.shape), - X: uX + np.zeros(shape=uX.shape), - }, -) -print("Time", time.time() - start) -bp_arrays = bp.run_bp(bp_arrays, num_iters=100, damping=0.5) -print("Time", time.time() - start) -output = bp.get_bp_output(bp_arrays) -print("Time", time.time() - start) - -# %% -_ = plot_images(output.map_states[W].reshape(-1, feat_height, feat_width), nr=1) - -# %% - -# %% - # %% [markdown] # We draw a batch of samples from the posterior in parallel by transforming `run_bp`/`get_beliefs` with `jax.vmap` @@ -285,10 +254,10 @@ def plot_images(images, display=True, nr=None): start = time.time() bp_arrays = jax.vmap(bp.init, in_axes=0, out_axes=0)( evidence_updates={ - "S": uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape), - "W": uW[None] + np.random.gumbel(size=(n_samples,) + uW.shape), - "SW": np.zeros(shape=(n_samples,) + SW.shape), - "X": uX[None] + np.zeros(shape=(n_samples,) + uX.shape), + S: uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape), + W: uW[None] + np.random.gumbel(size=(n_samples,) + uW.shape), + SW: np.zeros(shape=(n_samples,) + SW.shape), + X: uX[None] + np.zeros(shape=(n_samples,) + uX.shape), }, ) print("Time", time.time() - start) @@ -298,8 +267,8 @@ def plot_images(images, display=True, nr=None): out_axes=0, )(bp_arrays) print("Time", time.time() - start) -# beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays) -# map_states = graph.decode_map_states(beliefs) +beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays) +map_states = graph.decode_map_states(beliefs) # %% [markdown] # Visualizing the MAP decoding, we see that we have 4 good random samples (one per row) from the posterior! @@ -307,4 +276,4 @@ def plot_images(images, display=True, nr=None): # Because we have used one extra feature for inference, each posterior sample recovers the 4 basic features used to generate the images, and includes an extra symbol. # %% -_ = plot_images(map_states["W"].reshape(-1, feat_height, feat_width), nr=n_samples) +_ = plot_images(map_states[W].reshape(-1, feat_height, feat_width), nr=n_samples) diff --git a/examples/rbm.py b/examples/rbm.py index 3dc6a61f..9a46c0dc 100644 --- a/examples/rbm.py +++ b/examples/rbm.py @@ -44,8 +44,6 @@ # %% [markdown] # We can then initialize the factor graph for the RBM with -# -# import imp import imp @@ -116,14 +114,14 @@ def run_numba(h, v, f): # Add unary factors fg.add_factor_group( factory=enumeration.EnumerationFactorGroup, - variable_names_for_factors=[[hidden_variables[ii]] for ii in range(bh.shape[0])], + variables_for_factors=[[hidden_variables[ii]] for ii in range(bh.shape[0])], factor_configs=np.arange(2)[:, None], log_potentials=np.stack([np.zeros_like(bh), bh], axis=1), ) fg.add_factor_group( factory=enumeration.EnumerationFactorGroup, - variable_names_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])], + variables_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])], factor_configs=np.arange(2)[:, None], log_potentials=np.stack([np.zeros_like(bv), bv], axis=1), ) @@ -133,7 +131,7 @@ def run_numba(h, v, f): fg.add_factor_group( factory=enumeration.PairwiseFactorGroup, - variable_names_for_factors=[ + variables_for_factors=[ [hidden_variables[ii], visible_variables[jj]] for ii in range(bh.shape[0]) for jj in range(bv.shape[0]) @@ -218,7 +216,7 @@ def run_numba(h, v, f): print("Time", time.time() - start) bp_arrays = bp.run_bp(bp_arrays, num_iters=100, damping=0.5) print("Time", time.time() - start) -output = bp.get_bp_output(bp_arrays) +beliefs = bp.get_beliefs(bp_arrays) print("Time", time.time() - start) # %% [markdown] @@ -228,7 +226,9 @@ def run_numba(h, v, f): # %% fig, ax = plt.subplots(1, 1, figsize=(10, 10)) -ax.imshow(output.map_states[visible_variables].copy().reshape((28, 28)), cmap="gray") +ax.imshow( + graph.map_states(beliefs)[visible_variables].copy().reshape((28, 28)), cmap="gray" +) ax.axis("off") # %% [markdown] @@ -256,8 +256,8 @@ def run_numba(h, v, f): n_samples = 10 bp_arrays = jax.vmap(bp.init, in_axes=0, out_axes=0)( evidence_updates={ - "hidden": np.random.gumbel(size=(n_samples, bh.shape[0], 2)), - "visible": np.random.gumbel(size=(n_samples, bv.shape[0], 2)), + hidden_variables: np.random.gumbel(size=(n_samples, bh.shape[0], 2)), + visible_variables: np.random.gumbel(size=(n_samples, bv.shape[0], 2)), }, ) bp_arrays = jax.vmap( @@ -266,11 +266,8 @@ def run_numba(h, v, f): out_axes=0, )(bp_arrays) -# TODO: problem -outputs = jax.vmap(bp.get_bp_output, in_axes=0, out_axes=0)(bp_arrays) -# map_states = graph.decode_map_states(beliefs) - -# %% +beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays) +map_states = graph.decode_map_states(beliefs) # %% [markdown] # Visualizing the MAP decodings (Figure [fig:rbm_multiple_digits]), we see that we have sampled 10 MNIST digits in parallel! @@ -279,10 +276,8 @@ def run_numba(h, v, f): fig, ax = plt.subplots(2, 5, figsize=(20, 8)) for ii in range(10): ax[np.unravel_index(ii, (2, 5))].imshow( - map_states["visible"][ii].copy().reshape((28, 28)), cmap="gray" + map_states[visible_variables][ii].copy().reshape((28, 28)), cmap="gray" ) ax[np.unravel_index(ii, (2, 5))].axis("off") fig.tight_layout() - -# %% diff --git a/pgmax/factors/enumeration.py b/pgmax/factors/enumeration.py index 450cc5ab..a8078e78 100644 --- a/pgmax/factors/enumeration.py +++ b/pgmax/factors/enumeration.py @@ -21,7 +21,7 @@ class EnumerationWiring(nodes.Wiring): Args: factor_configs_edge_states: Array of shape (num_factor_configs, 2) factor_configs_edge_states[ii] contains a pair of global enumeration factor_config and global edge_state indices - factor_configs_edge_states[ii, 0] contains the global EnumerationFactor config index, + factor_configs_edge_states[ii, 0] contains the global EnumerExpected factor_edges_num_statesationFactor config index, factor_configs_edge_states[ii, 1] contains the corresponding global edge_state index. Both indices only take into account the EnumerationFactors of the FactorGraph @@ -81,9 +81,9 @@ def __post_init__(self): f"EnumerationFactor. Got a factor_configs array of shape {self.factor_configs.shape}." ) - if len(self.vars_to_num_states.keys()) != self.factor_configs.shape[1]: + if len(self.variables) != self.factor_configs.shape[1]: raise ValueError( - f"Number of variables {len(self.vars_to_num_states.keys())} doesn't match given configurations {self.factor_configs.shape}" + f"Number of variables {len(self.variables)} doesn't match given configurations {self.factor_configs.shape}" ) if self.log_potentials.shape != (self.factor_configs.shape[0],): @@ -93,7 +93,7 @@ def __post_init__(self): f"shape {self.log_potentials.shape}." ) - vars_num_states = np.array([list(self.vars_to_num_states.values())]) + vars_num_states = np.array([variable[1] for variable in self.variables]) if not np.logical_and( self.factor_configs >= 0, self.factor_configs < vars_num_states[None] ).all(): @@ -154,20 +154,18 @@ def concatenate_wirings(wirings: Sequence[EnumerationWiring]) -> EnumerationWiri @staticmethod def compile_wiring( - factor_edges_num_states: np.ndarray, - variables_for_factors: Tuple[int, ...], # TODO: rename + variables_for_factors: Tuple[Tuple[int, int], ...], factor_configs: np.ndarray, - vars_to_starts: Mapping[int, int], + vars_to_starts: Mapping[Tuple[int, int], int], num_factors: int, ) -> EnumerationWiring: """Compile an EnumerationWiring for an EnumerationFactor or a FactorGroup with EnumerationFactors. Internally calls _compile_var_states_numba and _compile_enumeration_wiring_numba for speed. Args: - factor_edges_num_states: An array concatenating the number of states for the variables connected to each - Factor of the FactorGroup. Each variable will appear once for each Factor it connects to. - variables_for_factors: A tuple containing variables connected to each Factor of the FactorGroup. - Each variable will appear once for each Factor it connects to. + variables_for_factors: A list of list of variables, where each innermost element is a + variable. Each list within the outer list is taken to contain the names of the + variables connected to a Factor. factor_configs: Array of shape (num_val_configs, num_variables) containing an explicit enumeration of all valid configurations. vars_to_starts: A dictionary that maps variables to their global starting indices @@ -181,9 +179,15 @@ def compile_wiring( Returns: The EnumerationWiring """ - var_states = np.array( - [vars_to_starts[variable] for variable in variables_for_factors] - ) + var_states = [] + factor_edges_num_states = [] + for variables_for_factor in variables_for_factors: + for variable in variables_for_factor: + var_states.append(vars_to_starts[variable]) + factor_edges_num_states.append(variable[1]) + var_states = np.array(var_states) + factor_edges_num_states = np.array(factor_edges_num_states) + num_states_cumsum = np.insert(np.cumsum(factor_edges_num_states), 0, 0) var_states_for_edges = np.empty(shape=(num_states_cumsum[-1],), dtype=int) _compile_var_states_numba(var_states_for_edges, num_states_cumsum, var_states) diff --git a/pgmax/factors/logical.py b/pgmax/factors/logical.py index c53a5fac..81eab93d 100644 --- a/pgmax/factors/logical.py +++ b/pgmax/factors/logical.py @@ -86,14 +86,12 @@ class LogicalFactor(nodes.Factor): edge_states_offset: int = field(init=False) def __post_init__(self): - if len(self.vars_to_num_states.keys()) < 2: + if len(self.variables) < 2: raise ValueError( - "At least one parent variable and one child variable is required" + "A LogicalFactor requires at least one parent variable and one child variable " ) - if not np.all( - [num_states == 2 for num_states in self.vars_to_num_states.values()] - ): + if not np.all([variable[1] == 2 for variable in self.variables]): raise ValueError("All variables should all be binary") @staticmethod @@ -142,9 +140,7 @@ def concatenate_wirings(wirings: Sequence[LogicalWiring]) -> LogicalWiring: @staticmethod def compile_wiring( - factor_edges_num_states: np.ndarray, - variables_for_factors: Tuple[int, ...], # notsure - factor_sizes: np.ndarray, + variables_for_factors: Tuple[Tuple[int, int], ...], vars_to_starts: Mapping[int, int], edge_states_offset: int, ) -> LogicalWiring: @@ -152,11 +148,9 @@ def compile_wiring( Internally calls _compile_var_states_numba and _compile_logical_wiring_numba for speed. Args: - factor_edges_num_states: An array concatenating the number of states for the variables connected to each - Factor of the FactorGroup. Each variable will appear once for each Factor it connects to. - variables_for_factors: A tuple of tuples containing variables connected to each Factor of the FactorGroup. - Each variable will appear once for each Factor it connects to. - factor_sizes: An array containing the different factor sizes. + variables_for_factors: A list of list of variables, where each innermost element is a + variable. Each list within the outer list is taken to contain the names of the + variables connected to a Factor. vars_to_starts: A dictionary that maps variables to their global starting indices For an n-state variable, a global start index of m means the global indices of its n variable states are m, m + 1, ..., m + n - 1 @@ -166,11 +160,21 @@ def compile_wiring( Returns: The LogicalWiring """ + factor_sizes = [] + var_states = [] + factor_edges_num_states = [] + for variables_for_factor in variables_for_factors: + factor_sizes.append(len(variables_for_factor)) + for variable in variables_for_factor: + var_states.append(vars_to_starts[variable]) + factor_edges_num_states.append(variable[1]) + factor_sizes = np.array(factor_sizes) + var_states = np.array(var_states) + factor_edges_num_states = np.array(factor_edges_num_states) + + # Relevant state differs for ANDFactors and ORFactors relevant_state = (-edge_states_offset + 1) // 2 - var_states = np.array( - [vars_to_starts[variable] for variable in variables_for_factors] - ) # Note: all the variables in a LogicalFactorGroup are binary num_states_cumsum = np.arange(0, 2 * var_states.shape[0] + 2, 2) var_states_for_edges = np.empty(shape=(2 * var_states.shape[0],), dtype=int) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index be4c026a..1a8a9833 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -7,7 +7,7 @@ import functools import inspect import typing -from dataclasses import asdict, dataclass, field +from dataclasses import asdict, dataclass from types import MappingProxyType from typing import ( Any, @@ -62,37 +62,32 @@ def __post_init__(self): start = time.time() # if isinstance(self.variables, groups.VariableGroup): if isinstance(self.variables, vgroup.NDVariableArray): - self.variables = [self.variables] - - self._variable_group: Mapping[ - int, groups.VariableGroup - ] = collections.OrderedDict() - for variable_group in self.variables: - self._variable_group[variable_group.__hash__()] = variable_group - - vars_names = [] - vars_num_states = [] - for variable_group in self.variables: - if isinstance(variable_group, vgroup.NDVariableArray): - vars_names.append(variable_group.variable_names.flatten()) - vars_num_states.append(variable_group.num_states.flatten()) + self.variable_groups = [self.variables] + else: + self.variable_groups = self.variables + + # TODO: remove? + self._variable_group = self.variable_groups + # self._variable_group: Mapping[ + # int, groups.VariableGroup + # ] = collections.OrderedDict() + # for variable_group in self.variables: + # self._variable_group[variable_group.__hash__()] = variable_group + + self._variables = [ + variable + for variable_group in self.variable_groups + for variable in variable_group.variables + ] print("1", time.time() - start) - vars_names = np.concatenate(vars_names) - vars_num_states = np.concatenate(vars_num_states) - vars_num_states_cumsum = np.insert( - np.array(vars_num_states).cumsum(), - 0, - 0, - ) - # Useful objects to build the FactorGraph self._factor_types_to_groups: OrderedDict[ Type, List[groups.FactorGroup] ] = collections.OrderedDict( [(factor_type, []) for factor_type in FAC_TO_VAR_UPDATES] ) - self._factor_types_to_variable_names_for_factors: OrderedDict[ + self._factor_types_to_variables_for_factors: OrderedDict[ Type, Set[FrozenSet] ] = collections.OrderedDict( [(factor_type, set()) for factor_type in FAC_TO_VAR_UPDATES] @@ -100,14 +95,16 @@ def __post_init__(self): print("2", time.time() - start) # Used to add FactorGroups - # TODO: move to dict, which is faster - self._vars_to_num_states: OrderedDict[int, int] = collections.OrderedDict( - zip(vars_names, vars_num_states) + vars_num_states = [variable[1] for variable in self._variables] + vars_num_states_cumsum = np.insert( + np.array(vars_num_states).cumsum(), + 0, + 0, ) # See FactorGraphState docstrings for documentation on the following fields self._num_var_states = vars_num_states_cumsum[-1] self._vars_to_starts: OrderedDict[int, int] = collections.OrderedDict( - zip(vars_names, vars_num_states_cumsum[:-1]) + zip(self._variables, vars_num_states_cumsum[:-1]) ) self._named_factor_groups: Dict[Hashable, groups.FactorGroup] = {} print("3", time.time() - start) @@ -127,6 +124,7 @@ def add_factor( variable_names: List, factor_configs: np.ndarray, log_potentials: Optional[np.ndarray] = None, + name: Optional[str] = None, ) -> None: """Function to add a single factor to the FactorGraph. @@ -145,15 +143,14 @@ def add_factor( initialized. """ factor_group = EnumerationFactorGroup( - self._vars_to_num_states, - variable_names_for_factors=[variable_names], + variables_for_factors=[variable_names], factor_configs=factor_configs, log_potentials=log_potentials, ) - self._register_factor_group(factor_group) + self._register_factor_group(factor_group, name) def add_factor_by_type( - self, variable_names: List[int], factor_type: type, *args, **kwargs + self, variables: List[int], factor_type: type, *args, **kwargs ) -> None: """Function to add a single factor to the FactorGraph. @@ -169,7 +166,7 @@ def add_factor_by_type( To add an ORFactor to a FactorGraph fg, run:: fg.add_factor_by_type( - variable_names=variables_names_for_OR_factor, + variables=variables_for_OR_factor, factor_type=logical.ORFactor ) """ @@ -178,16 +175,13 @@ def add_factor_by_type( f"Type {factor_type} is not one of the supported factor types {FAC_TO_VAR_UPDATES.keys()}" ) - vars_to_num_states = collections.OrderedDict( - (var, self._vars_to_num_states[var]) for var in variable_names - ) - factor = factor_type(vars_to_num_states, *args, **kwargs) + name = kwargs.pop("name", None) + factor = factor_type(variables, *args, **kwargs) factor_group = groups.SingleFactorGroup( - vars_to_num_states=self._vars_to_num_states, - variable_names_for_factors=[variable_names], + variables_for_factors=[variables], factor=factor, ) - self._register_factor_group(factor_group) + self._register_factor_group(factor_group, name) def add_factor_group(self, factory: Callable, *args, **kwargs) -> None: """Add a factor group to the factor graph @@ -198,10 +192,13 @@ def add_factor_group(self, factory: Callable, *args, **kwargs) -> None: kwargs: kwargs to be passed to the factory function, and an optional "name" argument for specifying the name of a named factor group. """ - factor_group = factory(self._vars_to_num_states, *args, **kwargs) - self._register_factor_group(factor_group) + name = kwargs.pop("name", None) + factor_group = factory(*args, **kwargs) + self._register_factor_group(factor_group, name) - def _register_factor_group(self, factor_group: groups.FactorGroup) -> None: + def _register_factor_group( + self, factor_group: groups.FactorGroup, name: Optional[str] = None + ) -> None: """Register a factor group to the factor graph, by updating the factor graph state. Args: @@ -212,21 +209,25 @@ def _register_factor_group(self, factor_group: groups.FactorGroup) -> None: ValueError: If the factor group with the same name or a factor involving the same variables already exists in the factor graph. """ + if name in self._named_factor_groups: + raise ValueError( + f"A factor group with the name {name} already exists. Please choose a different name!" + ) factor_type = factor_group.factor_type - for var_names_for_factor in factor_group.variable_names_for_factors: + for var_names_for_factor in factor_group.variables_for_factors: var_names = frozenset(var_names_for_factor) - if ( - var_names - in self._factor_types_to_variable_names_for_factors[factor_type] - ): + if var_names in self._factor_types_to_variables_for_factors[factor_type]: raise ValueError( f"A Factor of type {factor_type} involving variables {var_names} already exists. Please merge the corresponding factors." ) - self._factor_types_to_variable_names_for_factors[factor_type].add(var_names) + self._factor_types_to_variables_for_factors[factor_type].add(var_names) self._factor_types_to_groups[factor_type].append(factor_group) + if name is not None: + self._named_factor_groups[name] = factor_group + @functools.lru_cache(None) def compute_offsets(self) -> None: """Compute factor messages offsets for the factor types and factor groups @@ -258,7 +259,7 @@ def compute_offsets(self) -> None: factor_group ] = factor_num_configs_cumsum - factor_num_states_cumsum += sum(factor_group.factor_edges_num_states) + factor_num_states_cumsum += factor_group.total_num_states factor_num_configs_cumsum += ( factor_group.factor_group_log_potentials.shape[0] ) @@ -739,19 +740,22 @@ def update_evidence( Returns: A flat jnp array containing updated evidence. """ - for var_group_name, data in updates.items(): - if var_group_name in fg_state.variable_group: - variable_group = fg_state.variable_group[var_group_name] - first_variable = variable_group.variable_names.flatten()[0] + for name, data in updates.items(): + # Name is a variable_group or a variable + if name in fg_state.variable_group: + first_variable = name.variables[0] start_index = fg_state.vars_to_starts[first_variable] - flat_data = variable_group.flatten(data) + flat_data = name.flatten(data) evidence = evidence.at[start_index : start_index + flat_data.shape[0]].set( flat_data ) - # else: - # var = fg_state.variable_group[name] - # start_index = fg_state.vars_to_starts[var] - # evidence = evidence.at[start_index : start_index + var.num_states].set(data) + elif name in fg_state.vars_to_starts: + start_index = fg_state.vars_to_starts[name] + evidence = evidence.at[start_index : start_index + name[1]].set(data) + else: + raise ValueError( + "Got evidence for a variable or a variable group not in the FactorGraph!" + ) return evidence @@ -781,19 +785,18 @@ def __post_init__(self): object.__setattr__(self, "value", self.value) - def __getitem__(self, name: Any) -> np.ndarray: + def __getitem__(self, variable: Tuple[int, int]) -> np.ndarray: """Function to query evidence for a variable Args: - name: name for the variable + variable: Variable queried Returns: evidence for the queried variable """ value = cast(np.ndarray, self.value) - variable = self.fg_state.variable_group[name] start = self.fg_state.vars_to_starts[variable] - evidence = value[start : start + variable.num_states] + evidence = value[start : start + variable[1]] return evidence def __setitem__( @@ -888,18 +891,18 @@ class BeliefPropagation: Returns: The reconstructed BPState - get_bp_output: Function to calculate beliefs from a BPArrays. + get_beliefs: Function to calculate beliefs from a BPArrays. Args: bp_arrays: A BPArrays containing log_potentials, ftov_msgs, evidence. Returns: - bp_output: Belief propagation output. + beliefs: Beliefs returned by belief propagation. """ init: Callable update: Callable run_bp: Callable to_bp_state: Callable - get_bp_output: Callable + get_beliefs: Callable def BP(bp_state: BPState, temperature: float = 0.0) -> BeliefPropagation: @@ -968,19 +971,16 @@ def update( if log_potentials_updates is not None: log_potentials = update_log_potentials( - log_potentials, HashableDict(log_potentials_updates), bp_state.fg_state + log_potentials, log_potentials_updates, bp_state.fg_state ) if ftov_msgs_updates is not None: ftov_msgs = update_ftov_msgs( - ftov_msgs, HashableDict(ftov_msgs_updates), bp_state.fg_state + ftov_msgs, ftov_msgs_updates, bp_state.fg_state ) if evidence_updates is not None: - # Note: if we overwrite variables.__hash__ then hash may change when we call jit - evidence = update_evidence( - evidence, HashableDict(evidence_updates), bp_state.fg_state - ) + evidence = update_evidence(evidence, evidence_updates, bp_state.fg_state) return BPArrays( log_potentials=log_potentials, ftov_msgs=ftov_msgs, evidence=evidence @@ -1067,92 +1067,7 @@ def to_bp_state(bp_arrays: BPArrays) -> BPState: evidence=Evidence(fg_state=bp_state.fg_state, value=bp_arrays.evidence), ) - def get_bp_output(bp_arrays: BPArrays) -> Any: - """Function to calculate beliefs from a BPArrays - - Args: - bp_arrays: A BPArrays containing log_potentials, ftov_msgs, evidence. - - Returns: - bp_output: Belief propagation output. - """ - - @jax.jit - def compute_flat_beliefs(bp_arrays, var_states_for_edges): - flat_beliefs = ( - jax.device_put(bp_arrays.evidence) - .at[jax.device_put(var_states_for_edges)] - .add(bp_arrays.ftov_msgs) - ) - return flat_beliefs - - return BeliefPropagationOutputs( - compute_flat_beliefs(bp_arrays, var_states_for_edges), - bp_state.fg_state.variable_group, - ) - - bp = BeliefPropagation( - init=functools.partial(update, None), - update=update, - run_bp=run_bp, - to_bp_state=to_bp_state, - get_bp_output=get_bp_output, - ) - return bp - - -@jax.tree_util.register_pytree_node_class -@dataclass(frozen=True, eq=False) -class HashableDict: - "Represents a dictionnary where stored keys are the hash of the elements" - d: Dict = field(default_factory=dict) - - def __post_init__(self): - # Allows to initialize from dict without knowing hash - new_d = {k.__hash__(): v for k, v in self.d.items()} - object.__setattr__(self, "d", new_d) - - @functools.lru_cache() - def keys(self): - return self.d.keys() - - def items(self): - return self.d.items() - - def __setitem__(self, key, value): - # Allows to copy keys from another HashableDict - self.d[key] = value - - def __getitem__(self, value): - # Allows to retrieve from dict without knowing hash - return self.d[value.__hash__()] - - def tree_flatten(self): - return jax.tree_util.tree_flatten(asdict(self)) - - @classmethod - def tree_unflatten(cls, aux_data, children): - return cls(**aux_data.unflatten(children)) - - -@jax.tree_util.register_pytree_node_class -@dataclass(frozen=True, eq=False) -class BeliefPropagationOutputs: - # beliefs: An array or a PyTree container containing the beliefs for the variables. - flat_beliefs: jnp.ndarray - variable_groups: Mapping[int, groups.VariableGroup] - beliefs: HashableDict = field(init=False) - - def __post_init__(self): - if self.flat_beliefs.ndim != 1: - raise ValueError( - f"Can only unflatten 1D array. Got a {self.flat_beliefs.ndim}D array." - ) - beliefs = self.unflatten() - object.__setattr__(self, "beliefs", beliefs) - - @functools.lru_cache - def unflatten(self) -> None: + def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]: """Function that recovers meaningful structured data from internal flat data array Args: @@ -1164,86 +1079,115 @@ def unflatten(self) -> None: Raises: ValueError: if flat_data is not of the right shape """ - # Note: this is a reimplementation of CompositeVariableGroup.unflatten + if flat_beliefs.ndim != 1: + raise ValueError( + f"Can only unflatten 1D array. Got a {flat_beliefs.ndim}D array." + ) + num_variables = 0 num_variable_states = 0 - for variable_group in self.variable_groups.values(): + for variable_group in variable_groups: if isinstance(variable_group, vgroup.NDVariableArray): num_variables += variable_group.num_states.size num_variable_states += variable_group.num_states.sum() - if self.flat_beliefs.shape[0] == num_variables: + if flat_beliefs.shape[0] == num_variables: use_num_states = False - elif self.flat_beliefs.shape[0] == num_variable_states: + elif flat_beliefs.shape[0] == num_variable_states: use_num_states = True else: raise ValueError( - f"flat_data should be either of shape (num_variables(={len(self.variables)}),), " + f"flat_data should be either of shape (num_variables(={len(num_variables)}),), " f"or (num_variable_states(={num_variable_states}),). " - f"Got {self.flat_beliefs.shape}" + f"Got {flat_beliefs.shape}" ) - beliefs = HashableDict() + beliefs = {} start = 0 - for name, variable_group in self.variable_groups.items(): + for variable_group in variable_groups: if use_num_states: length = variable_group.num_states.sum() else: length = variable_group.num_states.size - beliefs[name] = variable_group.unflatten( - self.flat_beliefs[start : start + length] + beliefs[variable_group] = variable_group.unflatten( + flat_beliefs[start : start + length] ) start += length return beliefs - @cached_property - def map_states(self) -> Any: - """Function to decode MAP states given the calculated beliefs. + @jax.jit + def get_beliefs(bp_arrays: BPArrays) -> Dict[Hashable, Any]: + """Function to calculate beliefs from a BPArrays Args: - beliefs: An array or a PyTree container containing beliefs for different variables. + bp_arrays: A BPArrays containing log_potentials, ftov_msgs, evidence. Returns: - An array or a PyTree container containing the MAP states for different variables. + beliefs: Beliefs returned by belief propagation. """ - @jax.jit - def _decode_map_states(beliefs) -> Any: - return jax.tree_util.tree_map(lambda x: jnp.argmax(x, axis=-1), beliefs) + def compute_flat_beliefs(bp_arrays, var_states_for_edges): + flat_beliefs = ( + jax.device_put(bp_arrays.evidence) + .at[jax.device_put(var_states_for_edges)] + .add(bp_arrays.ftov_msgs) + ) + return flat_beliefs + + return unflatten_beliefs( + compute_flat_beliefs(bp_arrays, var_states_for_edges), + bp_state.fg_state.variable_group, + ) + + bp = BeliefPropagation( + init=functools.partial(update, None), + update=update, + run_bp=run_bp, + to_bp_state=to_bp_state, + get_beliefs=get_beliefs, + ) + return bp - map_states = HashableDict() - for name, beliefs in self.beliefs.items(): - map_states[name] = _decode_map_states(beliefs) - return map_states - @cached_property - def get_marginals(self) -> Any: - """Function to get marginal probabilities given the calculated beliefs. +def decode_map_states(beliefs: Dict[Hashable, Any]) -> Dict[Hashable, Any]: + """Function to decode MAP states given the calculated beliefs. - Args: - beliefs: An array or a PyTree container containing beliefs for different variables. + Args: + beliefs: An array or a PyTree container containing beliefs for different variables. - Returns: - An array or a PyTree container containing the marginal probabilities different variables. - """ + Returns: + An array or a PyTree container containing the MAP states for different variables. + """ - @jax.jit - def _get_marginals(beliefs) -> Any: - return jax.tree_util.tree_map( - lambda x: jnp.exp(x - logsumexp(x, axis=-1, keepdims=True)), - beliefs, - ) + @jax.jit + def _decode_map_states(beliefs) -> Any: + return jax.tree_util.tree_map(lambda x: jnp.argmax(x, axis=-1), beliefs) - marginals = HashableDict() - for name, beliefs in self.beliefs.items(): - marginals[name] = _get_marginals(beliefs) - return marginals + map_states = {} + for variable_group, vgroup_beliefs in beliefs.items(): + map_states[variable_group] = _decode_map_states(vgroup_beliefs) + return map_states - def tree_flatten(self): - return jax.tree_util.tree_flatten(asdict(self)) - @classmethod - def tree_unflatten(cls, aux_data, children): - # TODO: fix this - return cls(**aux_data.unflatten(children)) +def get_marginals(beliefs: Dict[Hashable, Any]) -> Dict[Hashable, Any]: + """Function to get marginal probabilities given the calculated beliefs. + + Args: + beliefs: An array or a PyTree container containing beliefs for different variables. + + Returns: + An array or a PyTree container containing the marginal probabilities different variables. + """ + + @jax.jit + def _get_marginals(beliefs) -> Any: + return jax.tree_util.tree_map( + lambda x: jnp.exp(x - logsumexp(x, axis=-1, keepdims=True)), + beliefs, + ) + + marginals = {} + for variable_group, vgroup_beliefs in beliefs.items(): + marginals[variable_group] = _get_marginals(vgroup_beliefs) + return marginals diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index f9f3dd41..1752a002 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -89,55 +89,28 @@ class FactorGroup: """Class to represent a group of Factors. Args: - vars_to_num_states: TODO - variable_names_for_factors: A list of list of variable names, where each innermost element is the - name of a variable. Each list within the outer list is taken to contain the names of the + variables_for_factors: A list of list of variables, where each innermost element is a + variable. Each list within the outer list is taken to contain the names of the variables connected to a Factor. factor_configs: Optional array containing an explicit enumeration of all valid configurations log_potentials: Array of log potentials. Attributes: factor_type: Factor type shared by all the Factors in the FactorGroup. - factor_sizes: Array of the different factor sizes. - factor_edges_num_states: Array concatenating the number of states for the variables connected to each Factor of - the FactorGroup. Each variable will appear once for each Factor it connects to. Raises: ValueError: if the FactorGroup does not contain a Factor """ - vars_to_num_states: Mapping[int, int] - variable_names_for_factors: Sequence[List] + variables_for_factors: Sequence[List] factor_configs: np.ndarray = field(init=False) log_potentials: np.ndarray = field(init=False, default=np.empty((0,))) factor_type: Type = field(init=False) - factor_sizes: np.ndarray = field(init=False) - variables_for_factors: Tuple[Tuple[int], ...] = field(init=False) - factor_edges_num_states: np.ndarray = field(init=False) def __post_init__(self): - if len(self.variable_names_for_factors) == 0: + if len(self.variables_for_factors) == 0: raise ValueError("Do not add a factor group with no factors.") - # Note: variable_names_for_factors contains the HASHes - # Note: this can probably be sped up by numba - factor_sizes = [] - flat_var_names_for_factors = [] - factor_edges_num_states = [] - for variable_names_for_factor in self.variable_names_for_factors: - for variable_name in variable_names_for_factor: - factor_edges_num_states.append(self.vars_to_num_states[variable_name]) - flat_var_names_for_factors.append(variable_name) - factor_sizes.append(len(variable_names_for_factor)) - - object.__setattr__(self, "factor_sizes", np.array(factor_sizes)) - object.__setattr__( - self, "variables_for_factors", np.array(flat_var_names_for_factors) - ) - object.__setattr__( - self, "factor_edges_num_states", np.array(factor_edges_num_states) - ) - def __getitem__(self, variables: Sequence[int]) -> Any: """Function to query individual factors in the factor group @@ -168,6 +141,17 @@ def _variables_to_factors(self) -> Mapping[FrozenSet, nodes.Factor]: """ return self._get_variables_to_factors() + @cached_property + def total_num_states(self) -> int: + """TODO""" + return sum( + [ + variable[1] + for variables_for_factor in self.variables_for_factors + for variable in variables_for_factor + ] + ) + @cached_property def factor_group_log_potentials(self) -> np.ndarray: """Flattened array of log potentials""" @@ -182,7 +166,7 @@ def factors(self) -> Tuple[nodes.Factor, ...]: @cached_property def num_factors(self) -> int: """Returns the number of factors in the FactorGroup.""" - return len(self.variable_names_for_factors) + return len(self.variables_for_factors) def _get_variables_to_factors(self) -> OrderedDict[FrozenSet, Any]: """Function that generates a dictionary mapping names to factors. @@ -260,9 +244,9 @@ class SingleFactorGroup(FactorGroup): def __post_init__(self): super().__post_init__() - if not len(self.variable_names_for_factors) == 1: + if not len(self.variables_for_factors) == 1: raise ValueError( - f"SingleFactorGroup should only contain one factor. Got {len(self.variable_names_for_factors)}" + f"SingleFactorGroup should only contain one factor. Got {len(self.variables_for_factors)}" ) object.__setattr__(self, "factor_type", type(self.factor)) @@ -282,9 +266,7 @@ def _get_variables_to_factors( Returns: A dictionary mapping all possible names to different factors. """ - return OrderedDict( - [(frozenset(self.variable_names_for_factors[0]), self.factor)] - ) + return OrderedDict([(frozenset(self.variables_for_factors[0]), self.factor)]) def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: """Function that turns meaningful structured data into a flat data array for internal use. diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py index 97fe7c5e..97e6bd3b 100644 --- a/pgmax/fg/nodes.py +++ b/pgmax/fg/nodes.py @@ -1,7 +1,7 @@ """A module containing classes that specify the basic components of a Factor Graph.""" from dataclasses import asdict, dataclass -from typing import OrderedDict, Sequence, Union +from typing import List, Sequence, Tuple, Union import jax import jax.numpy as jnp @@ -43,14 +43,14 @@ class Factor: """A factor Args: - vars_to_num_states: Dictionnary mapping the variables names, represented - in the form of a hash, to the variables number of states. + variables: List of variables in the factors. Each variable is represented + by a tuple containing the variable hash and number of states. Raises: NotImplementedError: If compile_wiring is not implemented """ - vars_to_num_states: OrderedDict[int, int] + variables: List[Tuple[int, int]] log_potentials: np.ndarray def __post_init__(self): diff --git a/pgmax/groups/enumeration.py b/pgmax/groups/enumeration.py index f091c53d..8c94ad46 100644 --- a/pgmax/groups/enumeration.py +++ b/pgmax/groups/enumeration.py @@ -80,19 +80,14 @@ def _get_variables_to_factors( variables_to_factors = collections.OrderedDict( [ ( - frozenset(variable_names_for_factor), + frozenset(variables_for_factor), enumeration.EnumerationFactor( - vars_to_num_states=collections.OrderedDict( - (var, self.vars_to_num_states[var]) - for var in variable_names_for_factor - ), + variables=variables_for_factor, factor_configs=self.factor_configs, log_potentials=np.array(self.log_potentials)[ii], ), ) - for ii, variable_names_for_factor in enumerate( - self.variable_names_for_factors - ) + for ii, variables_for_factor in enumerate(self.variables_for_factors) ] ) return variables_to_factors @@ -188,7 +183,7 @@ class PairwiseFactorGroup(groups.FactorGroup): Args: log_potential_matrix: array of shape (var1.num_states, var2.num_states), where var1 and var2 are the 2 VariableGroups (that may refer to the same - VariableGroup) whose names are present in each sub-list from self.variable_names_for_factors. + VariableGroup) whose names are present in each sub-list from self.variables_for_factors. factor_type: Factor type shared by all the Factors in the FactorGroup. Raises: @@ -210,8 +205,8 @@ def __post_init__(self): if self.log_potential_matrix is None: log_potential_matrix = np.zeros( ( - self.vars_to_num_states[self.variable_names_for_factors[0][0]], - self.vars_to_num_states[self.variable_names_for_factors[0][1]], + self.variables_for_factors[0][0][1], + self.variables_for_factors[0][1][1], ) ) else: @@ -230,25 +225,25 @@ def __post_init__(self): ) if log_potential_matrix.ndim == 3 and log_potential_matrix.shape[0] != len( - self.variable_names_for_factors + self.variables_for_factors ): raise ValueError( - f"Expected log_potential_matrix for {len(self.variable_names_for_factors)} factors. " + f"Expected log_potential_matrix for {len(self.variables_for_factors)} factors. " f"Got log_potential_matrix for {log_potential_matrix.shape[0]} factors." ) - for fac_list in self.variable_names_for_factors: - if len(fac_list) != 2: + for variables_for_factor in self.variables_for_factors: + if len(variables_for_factor) != 2: raise ValueError( "All pairwise factors should connect to exactly 2 variables. Got a factor connecting to" - f" {len(fac_list)} variables ({fac_list})." + f" {len(variables_for_factor)} variables ({variables_for_factor})." ) - num_states0 = self.vars_to_num_states[fac_list[0]] - num_states1 = self.vars_to_num_states[fac_list[1]] + num_states0 = variables_for_factor[0][1] + num_states1 = variables_for_factor[1][1] if not log_potential_matrix.shape[-2:] == (num_states0, num_states1): raise ValueError( - f"The specified pairwise factor {fac_list} (with {(num_states0, num_states1)}" + f"The specified pairwise factor {variables_for_factor} (with {(num_states0, num_states1)}" f"configurations) does not match the specified log_potential_matrix " f"(with {log_potential_matrix.shape[-2:]} configurations)." ) @@ -264,7 +259,7 @@ def __post_init__(self): object.__setattr__(self, "factor_configs", factor_configs) log_potential_matrix = np.broadcast_to( log_potential_matrix, - (len(self.variable_names_for_factors),) + log_potential_matrix.shape[-2:], + (len(self.variables_for_factors),) + log_potential_matrix.shape[-2:], ) log_potentials = np.empty( shape=(self.num_factors, self.factor_configs.shape[0]) @@ -286,19 +281,14 @@ def _get_variables_to_factors( variables_to_factors = collections.OrderedDict( [ ( - frozenset(variable_names_for_factor), + frozenset(variable_for_factor), enumeration.EnumerationFactor( - vars_to_num_states=collections.OrderedDict( - (var, self.vars_to_num_states[var]) - for var in variable_names_for_factor - ), + variables=variable_for_factor, factor_configs=self.factor_configs, log_potentials=self.log_potentials[ii], ), ) - for ii, variable_names_for_factor in enumerate( - self.variable_names_for_factors - ) + for ii, variable_for_factor in enumerate(self.variables_for_factors) ] ) return variables_to_factors diff --git a/pgmax/groups/logical.py b/pgmax/groups/logical.py index 325a4313..0f4714d5 100644 --- a/pgmax/groups/logical.py +++ b/pgmax/groups/logical.py @@ -34,15 +34,10 @@ def _get_variables_to_factors( variables_to_factors = collections.OrderedDict( [ ( - frozenset(variable_names_for_factor), - self.factor_type( - vars_to_num_states=collections.OrderedDict( - (var, self.vars_to_num_states[var]) - for var in variable_names_for_factor - ), - ), + frozenset(variables_for_factor), + self.factor_type(variables=variables_for_factor), ) - for variable_names_for_factor in self.variable_names_for_factors + for variables_for_factor in self.variables_for_factors ] ) return variables_to_factors diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index 4c6587f3..a6f3e8d5 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -1,9 +1,9 @@ """A module containing the variables group classes inheriting from the base VariableGroup.""" -import itertools import random from dataclasses import dataclass -from typing import Tuple, Union +from functools import total_ordering +from typing import List, Tuple, Union import jax import jax.numpy as jnp @@ -12,6 +12,7 @@ from pgmax.utils import cached_property +@total_ordering @dataclass(frozen=True, eq=False) class NDVariableArray: """Subclass of VariableGroup for n-dimensional grids of variables. @@ -22,6 +23,9 @@ class NDVariableArray: the notion of a NumPy ndarray shape) """ + # TODO: Variables = (hash, num_states) + # TODO: VariableGroup can be deleted + shape: Tuple[int, ...] num_states: Union[int, np.ndarray] @@ -34,10 +38,27 @@ def __post_init__(self): elif isinstance(self.num_states, np.ndarray): if self.num_states.shape != self.shape: raise ValueError("Should be same shape") + random_hash = random.randint(0, 2**63) + object.__setattr__(self, "random_hash", random_hash) + + def __hash__(self): + return self.random_hash + + def __eq__(self, other): + return hash(self) == hash(other) + + def __lt__(self, other): + return hash(self) < hash(other) def __getitem__(self, val): # Numpy indexation will throw IndexError for us if out-of-bounds - return self.variable_names[val] + return (self.variable_names[val], self.num_states[val]) + + @cached_property + def variables(self) -> List[Tuple]: + vars_names = self.variable_names.flatten() + vars_num_states = self.num_states.flatten() + return list(zip(vars_names, vars_num_states)) @cached_property def variable_names(self) -> np.ndarray: @@ -47,9 +68,8 @@ def variable_names(self) -> np.ndarray: a dictionary mapping all possible names to different variables. """ # Overwite default hash as it does not give enough spacing across consecutive objects - this_hash = random.randint(0, 2**63) indices = np.reshape(np.arange(np.product(self.shape)), self.shape) - return this_hash + indices + return self.__hash__() + indices def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: """Function that turns meaningful structured data into a flat data array for internal use. @@ -64,7 +84,7 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: Raises: ValueError: If the data is not of the correct shape. """ - # TODO: what should we do for different number of states + # TODO: what should we do for different number of states -> look at maask_array if data.shape != self.shape and data.shape != self.shape + ( self.num_states.max(), ): @@ -110,7 +130,9 @@ def unflatten( return data -# TODO: delete? +# TODO: delete? -- NO + + # @dataclass(frozen=True, eq=False) # class VariableDict(): # """A variable dictionary that contains a set of variables of the same size diff --git a/tests/factors/test_and.py b/tests/factors/test_and.py index 0af04db1..e95e22a8 100644 --- a/tests/factors/test_and.py +++ b/tests/factors/test_and.py @@ -8,14 +8,14 @@ from pgmax.groups import variables as vgroup -def test_run_bp_with_AND_factors(): +def test_run_bp_with_ANDFactors(): """ Simultaneously test (1) the support of ANDFactors in a FactorGraph and their specialized inference for different temperatures (2) the support of several factor types in a FactorGraph and during inference To do so, observe that an ANDFactor can be defined as an equivalent EnumerationFactor - (which list all the valid AND configurations) and define two equivalent FactorGraphs + (which list all the valid OR configurations) and define two equivalent FactorGraphs FG1: first half of factors are defined as EnumerationFactors, second half are defined as ANDFactors FG2: first half of factors are defined as ANDFactors, second half are defined as EnumerationFactors @@ -25,6 +25,7 @@ def test_run_bp_with_AND_factors(): Note: for the first seed, add all the EnumerationFactors to FG1 and all the ANDFactors to FG2 """ for idx in range(10): + print("it", idx) np.random.seed(idx) # Parameters @@ -43,30 +44,41 @@ def test_run_bp_with_AND_factors(): parents_variables1 = vgroup.NDVariableArray( num_states=2, shape=(num_parents.sum(),) ) - children_variable1 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,)) - fg1 = graph.FactorGraph( - variables=dict(parents=parents_variables1, children=children_variable1) - ) + children_variables1 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,)) + fg1 = graph.FactorGraph(variables=[parents_variables1, children_variables1]) # Graph 2 parents_variables2 = vgroup.NDVariableArray( num_states=2, shape=(num_parents.sum(),) ) - children_variable2 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,)) - fg2 = graph.FactorGraph( - variables=dict(parents=parents_variables2, children=children_variable2) - ) + children_variables2 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,)) + fg2 = graph.FactorGraph(variables=[parents_variables2, children_variables2]) - # Option 1: Define EnumerationFactors equivalent to the ANDFactors + # Variable names for factors + variables_for_factors1 = [] + variables_for_factors2 = [] for factor_idx in range(num_factors): - this_num_parents = num_parents[factor_idx] - variable_names = [ - ("parents", idx) + variable_names1 = [ + parents_variables1[idx] for idx in range( num_parents_cumsum[factor_idx], num_parents_cumsum[factor_idx + 1], ) - ] + [("children", factor_idx)] + ] + [children_variables1[factor_idx]] + variables_for_factors1.append(variable_names1) + + variable_names2 = [ + parents_variables2[idx] + for idx in range( + num_parents_cumsum[factor_idx], + num_parents_cumsum[factor_idx + 1], + ) + ] + [children_variables2[factor_idx]] + variables_for_factors2.append(variable_names2) + + # Option 1: Define EnumerationFactors equivalent to the ANDFactors + for factor_idx in range(num_factors): + this_num_parents = num_parents[factor_idx] configs = np.array(list(product([0, 1], repeat=this_num_parents + 1))) # Children state is last @@ -84,7 +96,7 @@ def test_run_bp_with_AND_factors(): if factor_idx < num_factors // 2: # Add the first half of factors to FactorGraph1 fg1.add_factor( - variable_names=variable_names, + variable_names=variables_for_factors1[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) @@ -92,73 +104,76 @@ def test_run_bp_with_AND_factors(): if idx != 0: # Add the second half of factors to FactorGraph2 fg2.add_factor( - variable_names=variable_names, + variable_names=variables_for_factors2[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) else: # Add all the EnumerationFactors to FactorGraph1 for the first iter fg1.add_factor( - variable_names=variable_names, + variable_names=variables_for_factors1[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) # Option 2: Define the ANDFactors num_parents_cumsum = np.insert(np.cumsum(num_parents), 0, 0) - variable_names_for_ANDFactors_fg1 = [] - variable_names_for_ANDFactors_fg2 = [] + variables_for_ANDFactors_fg1 = [] + variables_for_ANDFactors_fg2 = [] for factor_idx in range(num_factors): - variables_names_for_ANDFactor = [ - ("parents", idx) - for idx in range( - num_parents_cumsum[factor_idx], - num_parents_cumsum[factor_idx + 1], - ) - ] + [("children", factor_idx)] if factor_idx < num_factors // 2: # Add the first half of factors to FactorGraph2 - variable_names_for_ANDFactors_fg2.append(variables_names_for_ANDFactor) + variables_for_ANDFactors_fg2.append(variables_for_factors2[factor_idx]) else: if idx != 0: # Add the second half of factors to FactorGraph1 - variable_names_for_ANDFactors_fg1.append( - variables_names_for_ANDFactor + variables_for_ANDFactors_fg1.append( + variables_for_factors1[factor_idx] ) else: # Add all the ANDFactors to FactorGraph2 for the first iter - variable_names_for_ANDFactors_fg2.append( - variables_names_for_ANDFactor + variables_for_ANDFactors_fg2.append( + variables_for_factors2[factor_idx] ) - if idx != 0: fg1.add_factor_group( factory=logical.ANDFactorGroup, - variable_names_for_factors=variable_names_for_ANDFactors_fg1, + variables_for_factors=variables_for_ANDFactors_fg1, ) fg2.add_factor_group( factory=logical.ANDFactorGroup, - variable_names_for_factors=variable_names_for_ANDFactors_fg2, + variables_for_factors=variables_for_ANDFactors_fg2, ) # Run inference bp1 = graph.BP(fg1.bp_state, temperature=temperature) bp2 = graph.BP(fg2.bp_state, temperature=temperature) - evidence_updates = { - "parents": jax.device_put(np.random.gumbel(size=(sum(num_parents), 2))), - "children": jax.device_put(np.random.gumbel(size=(num_factors, 2))), + evidence_parents = jax.device_put(np.random.gumbel(size=(sum(num_parents), 2))) + evidence_children = jax.device_put(np.random.gumbel(size=(num_factors, 2))) + + evidence_updates1 = { + parents_variables1: evidence_parents, + children_variables1: evidence_children, + } + evidence_updates2 = { + parents_variables2: evidence_parents, + children_variables2: evidence_children, } - bp_arrays1 = bp1.init(evidence_updates=evidence_updates) + bp_arrays1 = bp1.init(evidence_updates=evidence_updates1) bp_arrays1 = bp1.run_bp(bp_arrays1, num_iters=5) - bp_arrays2 = bp2.init(evidence_updates=evidence_updates) + bp_arrays2 = bp2.init(evidence_updates=evidence_updates2) bp_arrays2 = bp2.run_bp(bp_arrays2, num_iters=5) # Get beliefs beliefs1 = bp1.get_beliefs(bp_arrays1) beliefs2 = bp2.get_beliefs(bp_arrays2) - assert np.allclose(beliefs1["children"], beliefs2["children"], atol=1e-4) - assert np.allclose(beliefs1["parents"], beliefs2["parents"], atol=1e-4) + assert np.allclose( + beliefs1[children_variables1], beliefs2[children_variables2], atol=1e-4 + ) + assert np.allclose( + beliefs1[parents_variables1], beliefs2[parents_variables2], atol=1e-4 + ) diff --git a/tests/factors/test_or.py b/tests/factors/test_or.py index e634fbab..7a6c0905 100644 --- a/tests/factors/test_or.py +++ b/tests/factors/test_or.py @@ -8,7 +8,7 @@ from pgmax.groups import variables as vgroup -def test_run_bp_with_OR_factors(): +def test_run_bp_with_ORFactors(): """ Simultaneously test (1) the support of ORFactors in a FactorGraph and their specialized inference for different temperatures @@ -55,8 +55,8 @@ def test_run_bp_with_OR_factors(): fg2 = graph.FactorGraph(variables=[parents_variables2, children_variables2]) # Variable names for factors - variable_names_for_factors1 = [] - variable_names_for_factors2 = [] + variables_for_factors1 = [] + variables_for_factors2 = [] for factor_idx in range(num_factors): variable_names1 = [ parents_variables1[idx] @@ -65,7 +65,7 @@ def test_run_bp_with_OR_factors(): num_parents_cumsum[factor_idx + 1], ) ] + [children_variables1[factor_idx]] - variable_names_for_factors1.append(variable_names1) + variables_for_factors1.append(variable_names1) variable_names2 = [ parents_variables2[idx] @@ -74,7 +74,7 @@ def test_run_bp_with_OR_factors(): num_parents_cumsum[factor_idx + 1], ) ] + [children_variables2[factor_idx]] - variable_names_for_factors2.append(variable_names2) + variables_for_factors2.append(variable_names2) # Option 1: Define EnumerationFactors equivalent to the ORFactors for factor_idx in range(num_factors): @@ -94,7 +94,7 @@ def test_run_bp_with_OR_factors(): if factor_idx < num_factors // 2: # Add the first half of factors to FactorGraph1 fg1.add_factor( - variable_names=variable_names_for_factors1[factor_idx], + variable_names=variables_for_factors1[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) @@ -102,48 +102,46 @@ def test_run_bp_with_OR_factors(): if idx != 0: # Add the second half of factors to FactorGraph2 fg2.add_factor( - variable_names=variable_names_for_factors2[factor_idx], + variable_names=variables_for_factors2[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) else: # Add all the EnumerationFactors to FactorGraph1 for the first iter fg1.add_factor( - variable_names=variable_names_for_factors1[factor_idx], + variable_names=variables_for_factors1[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) # Option 2: Define the ORFactors num_parents_cumsum = np.insert(np.cumsum(num_parents), 0, 0) - variable_names_for_ORFactors_fg1 = [] - variable_names_for_ORFactors_fg2 = [] + variables_for_ORFactors_fg1 = [] + variables_for_ORFactors_fg2 = [] for factor_idx in range(num_factors): if factor_idx < num_factors // 2: # Add the first half of factors to FactorGraph2 - variable_names_for_ORFactors_fg2.append( - variable_names_for_factors2[factor_idx] - ) + variables_for_ORFactors_fg2.append(variables_for_factors2[factor_idx]) else: if idx != 0: # Add the second half of factors to FactorGraph1 - variable_names_for_ORFactors_fg1.append( - variable_names_for_factors1[factor_idx] + variables_for_ORFactors_fg1.append( + variables_for_factors1[factor_idx] ) else: # Add all the ORFactors to FactorGraph2 for the first iter - variable_names_for_ORFactors_fg2.append( - variable_names_for_factors2[factor_idx] + variables_for_ORFactors_fg2.append( + variables_for_factors2[factor_idx] ) if idx != 0: fg1.add_factor_group( factory=logical.ORFactorGroup, - variable_names_for_factors=variable_names_for_ORFactors_fg1, + variables_for_factors=variables_for_ORFactors_fg1, ) fg2.add_factor_group( factory=logical.ORFactorGroup, - variable_names_for_factors=variable_names_for_ORFactors_fg2, + variables_for_factors=variables_for_ORFactors_fg2, ) # Run inference @@ -168,8 +166,8 @@ def test_run_bp_with_OR_factors(): bp_arrays2 = bp2.run_bp(bp_arrays2, num_iters=5) # Get beliefs - beliefs1 = bp1.get_bp_output(bp_arrays1).beliefs - beliefs2 = bp2.get_bp_output(bp_arrays2).beliefs + beliefs1 = bp1.get_beliefs(bp_arrays1) + beliefs2 = bp2.get_beliefs(bp_arrays2) assert np.allclose( beliefs1[children_variables1], beliefs2[children_variables2], atol=1e-4 diff --git a/tests/fg/test_wiring.py b/tests/fg/test_wiring.py index a3b6ae2b..4966a66e 100644 --- a/tests/fg/test_wiring.py +++ b/tests/fg/test_wiring.py @@ -22,17 +22,15 @@ def test_wiring_with_PairwiseFactorGroup(): fg = graph.FactorGraph(variables=[A, B]) fg.add_factor_group( factory=enumeration.PairwiseFactorGroup, - variable_names_for_factors=[[A[idx], B[idx]] for idx in range(10)], + variables_for_factors=[[A[idx], B[idx]] for idx in range(10)], ) factor_group = fg.factor_groups[enumeration_factor.EnumerationFactor][0] object.__setattr__( - factor_group, - "factor_edges_num_states", - factor_group.factor_edges_num_states[:-1], + factor_group, "factor_configs", factor_group.factor_configs[:, :1] ) with pytest.raises( ValueError, - match=re.escape("Expected factor_edges_num_states shape is (20,). Got (19,)."), + match=re.escape("Expected factor_edges_num_states shape is (10,). Got (20,)."), ): factor_group.compile_wiring(fg._vars_to_starts) @@ -40,7 +38,7 @@ def test_wiring_with_PairwiseFactorGroup(): fg1 = graph.FactorGraph(variables=[A, B]) fg1.add_factor_group( factory=enumeration.PairwiseFactorGroup, - variable_names_for_factors=[[A[idx], B[idx]] for idx in range(10)], + variables_for_factors=[[A[idx], B[idx]] for idx in range(10)], ) assert len(fg1.factor_groups[enumeration_factor.EnumerationFactor]) == 1 @@ -49,7 +47,7 @@ def test_wiring_with_PairwiseFactorGroup(): for idx in range(10): fg2.add_factor_group( factory=enumeration.PairwiseFactorGroup, - variable_names_for_factors=[[A[idx], B[idx]]], + variables_for_factors=[[A[idx], B[idx]]], ) assert len(fg2.factor_groups[enumeration_factor.EnumerationFactor]) == 10 @@ -57,7 +55,7 @@ def test_wiring_with_PairwiseFactorGroup(): fg3 = graph.FactorGraph(variables=[A, B]) for idx in range(10): fg3.add_factor_by_type( - variable_names=[A[idx], B[idx]], + variables=[A[idx], B[idx]], factor_type=enumeration_factor.EnumerationFactor, **{ "factor_configs": np.array([[0, 0], [0, 1], [1, 0], [1, 1]]), @@ -101,7 +99,7 @@ def test_wiring_with_ORFactorGroup(): fg1 = graph.FactorGraph(variables=[A, B, C]) fg1.add_factor_group( factory=logical.ORFactorGroup, - variable_names_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)], + variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)], ) assert len(fg1.factor_groups[logical_factor.ORFactor]) == 1 @@ -110,7 +108,7 @@ def test_wiring_with_ORFactorGroup(): for idx in range(10): fg2.add_factor_group( factory=logical.ORFactorGroup, - variable_names_for_factors=[[A[idx], B[idx], C[idx]]], + variables_for_factors=[[A[idx], B[idx], C[idx]]], ) assert len(fg2.factor_groups[logical_factor.ORFactor]) == 10 @@ -118,7 +116,7 @@ def test_wiring_with_ORFactorGroup(): fg3 = graph.FactorGraph(variables=[A, B, C]) for idx in range(10): fg3.add_factor_by_type( - variable_names=[A[idx], B[idx], C[idx]], + variables=[A[idx], B[idx], C[idx]], factor_type=logical_factor.ORFactor, ) assert len(fg3.factor_groups[logical_factor.ORFactor]) == 10 @@ -156,7 +154,7 @@ def test_wiring_with_ANDFactorGroup(): fg1 = graph.FactorGraph(variables=[A, B, C]) fg1.add_factor_group( factory=logical.ANDFactorGroup, - variable_names_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)], + variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)], ) assert len(fg1.factor_groups[logical_factor.ANDFactor]) == 1 @@ -165,7 +163,7 @@ def test_wiring_with_ANDFactorGroup(): for idx in range(10): fg2.add_factor_group( factory=logical.ANDFactorGroup, - variable_names_for_factors=[[A[idx], B[idx], C[idx]]], + variables_for_factors=[[A[idx], B[idx], C[idx]]], ) assert len(fg2.factor_groups[logical_factor.ANDFactor]) == 10 @@ -173,7 +171,7 @@ def test_wiring_with_ANDFactorGroup(): fg3 = graph.FactorGraph(variables=[A, B, C]) for idx in range(10): fg3.add_factor_by_type( - variable_names=[A[idx], B[idx], C[idx]], + variables=[A[idx], B[idx], C[idx]], factor_type=logical_factor.ANDFactor, ) assert len(fg3.factor_groups[logical_factor.ANDFactor]) == 10 From ef92d1b8482168d76197cd4e8dca2f433fedfbe4 Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Thu, 21 Apr 2022 00:39:47 +0000 Subject: [PATCH 06/35] Start tests + mypy --- examples/rbm.py | 4 +- pgmax/factors/enumeration.py | 6 +- pgmax/factors/logical.py | 6 +- pgmax/fg/graph.py | 52 ++++---- pgmax/fg/groups.py | 15 ++- pgmax/fg/nodes.py | 18 ++- pgmax/groups/enumeration.py | 8 +- pgmax/groups/variables.py | 225 +++++++++++++++++------------------ tests/fg/test_graph.py | 185 ++++++++++++++-------------- tests/fg/test_nodes.py | 31 ++--- 10 files changed, 268 insertions(+), 282 deletions(-) diff --git a/examples/rbm.py b/examples/rbm.py index 9a46c0dc..abd7435d 100644 --- a/examples/rbm.py +++ b/examples/rbm.py @@ -44,8 +44,8 @@ # %% [markdown] # We can then initialize the factor graph for the RBM with - -import imp +# +# import imp # %% from pgmax.fg import graph diff --git a/pgmax/factors/enumeration.py b/pgmax/factors/enumeration.py index a8078e78..49bd8fd3 100644 --- a/pgmax/factors/enumeration.py +++ b/pgmax/factors/enumeration.py @@ -2,7 +2,7 @@ import functools from dataclasses import dataclass -from typing import Mapping, Sequence, Tuple, Union +from typing import List, Mapping, Sequence, Tuple, Union import jax import jax.numpy as jnp @@ -21,7 +21,7 @@ class EnumerationWiring(nodes.Wiring): Args: factor_configs_edge_states: Array of shape (num_factor_configs, 2) factor_configs_edge_states[ii] contains a pair of global enumeration factor_config and global edge_state indices - factor_configs_edge_states[ii, 0] contains the global EnumerExpected factor_edges_num_statesationFactor config index, + factor_configs_edge_states[ii, 0] contains the global EnumerationFactor config index, factor_configs_edge_states[ii, 1] contains the corresponding global edge_state index. Both indices only take into account the EnumerationFactors of the FactorGraph @@ -154,7 +154,7 @@ def concatenate_wirings(wirings: Sequence[EnumerationWiring]) -> EnumerationWiri @staticmethod def compile_wiring( - variables_for_factors: Tuple[Tuple[int, int], ...], + variables_for_factors: Sequence[List], factor_configs: np.ndarray, vars_to_starts: Mapping[Tuple[int, int], int], num_factors: int, diff --git a/pgmax/factors/logical.py b/pgmax/factors/logical.py index 81eab93d..a7554334 100644 --- a/pgmax/factors/logical.py +++ b/pgmax/factors/logical.py @@ -2,7 +2,7 @@ import functools from dataclasses import dataclass, field -from typing import Mapping, Optional, Sequence, Tuple, Union +from typing import List, Mapping, Optional, Sequence, Union import jax import jax.numpy as jnp @@ -88,7 +88,7 @@ class LogicalFactor(nodes.Factor): def __post_init__(self): if len(self.variables) < 2: raise ValueError( - "A LogicalFactor requires at least one parent variable and one child variable " + "A LogicalFactor requires at least one parent variable and one child variable" ) if not np.all([variable[1] == 2 for variable in self.variables]): @@ -140,7 +140,7 @@ def concatenate_wirings(wirings: Sequence[LogicalWiring]) -> LogicalWiring: @staticmethod def compile_wiring( - variables_for_factors: Tuple[Tuple[int, int], ...], + variables_for_factors: Sequence[List], vars_to_starts: Mapping[int, int], edge_states_offset: int, ) -> LogicalWiring: diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 1a8a9833..3e5cb570 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -60,8 +60,7 @@ def __post_init__(self): import time start = time.time() - # if isinstance(self.variables, groups.VariableGroup): - if isinstance(self.variables, vgroup.NDVariableArray): + if isinstance(self.variables, (vgroup.NDVariableArray, vgroup.VariableDict)): self.variable_groups = [self.variables] else: self.variable_groups = self.variables @@ -103,9 +102,9 @@ def __post_init__(self): ) # See FactorGraphState docstrings for documentation on the following fields self._num_var_states = vars_num_states_cumsum[-1] - self._vars_to_starts: OrderedDict[int, int] = collections.OrderedDict( - zip(self._variables, vars_num_states_cumsum[:-1]) - ) + self._vars_to_starts: OrderedDict[ + Tuple[int, int], int + ] = collections.OrderedDict(zip(self._variables, vars_num_states_cumsum[:-1])) self._named_factor_groups: Dict[Hashable, groups.FactorGroup] = {} print("3", time.time() - start) @@ -121,7 +120,7 @@ def __hash__(self) -> int: def add_factor( self, - variable_names: List, + variables: List, factor_configs: np.ndarray, log_potentials: Optional[np.ndarray] = None, name: Optional[str] = None, @@ -129,7 +128,7 @@ def add_factor( """Function to add a single factor to the FactorGraph. Args: - variable_names: A list containing the connected variable names. + variables: A list containing the connected variable names. Variable names are tuples of the type (variable_group_name, variable_name_within_variable_group) factor_configs: Array of shape (num_val_configs, num_variables) An array containing explicit enumeration of all valid configurations. @@ -143,7 +142,7 @@ def add_factor( initialized. """ factor_group = EnumerationFactorGroup( - variables_for_factors=[variable_names], + variables_for_factors=[variables], factor_configs=factor_configs, log_potentials=log_potentials, ) @@ -369,7 +368,7 @@ def fg_state(self) -> FactorGraphState: ) return FactorGraphState( - variable_group=self._variable_group, + variable_groups=self._variable_group, vars_to_starts=self._vars_to_starts, num_var_states=self._num_var_states, total_factor_num_states=self._total_factor_num_states, @@ -417,8 +416,8 @@ class FactorGraphState: wiring: Wiring derived for each factor type. """ - variable_group: groups.VariableGroup - vars_to_starts: Mapping[int, int] + variable_groups: Sequence[groups.VariableGroup] + vars_to_starts: Mapping[Tuple[int, int], int] num_var_states: int total_factor_num_states: int named_factor_groups: Mapping[Hashable, groups.FactorGroup] @@ -612,12 +611,11 @@ def update_ftov_msgs( (2) provided name is not valid for ftov_msgs updates. """ for names, data in updates.items(): - if names in fg_state.variable_group.names: - variable = fg_state.variable_group[names] - if data.shape != (variable.num_states,): + if names in fg_state.variable_groups: + if data.shape != (names.total_num_states,): raise ValueError( f"Given belief shape {data.shape} does not match expected " - f"shape {(variable.num_states,)} for variable {names}." + f"shape {(names.total_num_states,)} for variable." ) var_states_for_edges = np.concatenate( @@ -627,13 +625,13 @@ def update_ftov_msgs( ] ) - starts = np.nonzero( - var_states_for_edges == fg_state.vars_to_starts[variable] - )[0] - for start in starts: - ftov_msgs = ftov_msgs.at[start : start + variable.num_states].set( - data / starts.shape[0] - ) + # starts = np.nonzero( + # var_states_for_edges == fg_state.vars_to_starts[variable] + # )[0] + # for start in starts: + # ftov_msgs = ftov_msgs.at[start : start + variable.num_states].set( + # data / starts.shape[0] + # ) else: raise ValueError( "Invalid names for setting messages. " @@ -709,7 +707,7 @@ def __setitem__(self, names, data) -> None: if ( isinstance(names, tuple) and len(names) == 2 - and names[1] in self.fg_state.variable_group.names + and names[1] in self.fg_state.variable_groups ): names = (frozenset(names[0]), names[1]) @@ -742,7 +740,7 @@ def update_evidence( """ for name, data in updates.items(): # Name is a variable_group or a variable - if name in fg_state.variable_group: + if name in fg_state.variable_groups: first_variable = name.variables[0] start_index = fg_state.vars_to_starts[first_variable] flat_data = name.flatten(data) @@ -811,7 +809,7 @@ def __setitem__( If name is the name of a variable group, updates are derived by using the variable group to flatten the data. If name is the name of a variable, data should be of an array shape (num_states,) - If name is None, updates are derived by using self.fg_state.variable_group to flatten the data. + If name is None, updates are derived by using self.fg_state.variable_groups to flatten the data. data: Array containing the evidence updates. """ object.__setattr__( @@ -1097,7 +1095,7 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]: use_num_states = True else: raise ValueError( - f"flat_data should be either of shape (num_variables(={len(num_variables)}),), " + f"flat_data should be either of shape (num_variables(={num_variables}),), " f"or (num_variable_states(={num_variable_states}),). " f"Got {flat_beliefs.shape}" ) @@ -1137,7 +1135,7 @@ def compute_flat_beliefs(bp_arrays, var_states_for_edges): return unflatten_beliefs( compute_flat_beliefs(bp_arrays, var_states_for_edges), - bp_state.fg_state.variable_group, + bp_state.fg_state.variable_groups, ) bp = BeliefPropagation( diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 1752a002..06bc80cb 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -31,27 +31,26 @@ class VariableGroup: """ def __getitem__(self, val): - """Given a name, retrieve the associated Variable. + """Given a variable name, retrieve the associated Variable. Args: val: a single name corresponding to a single variable, or a list of such names Returns: A single variable if the name is not a list. A list of variables if name is a list - - Raises: - ValueError: if the name is not found in the group """ raise NotImplementedError( "Please subclass the VariableGroup class and override this method" ) @cached_property - def variables_names(self) -> Any: - """Function that generates a dictionary mapping names to variables. + def variables(self) -> Tuple[Any, int]: + """Function that returns the list of variables. Each variable is represented + by a tuple of the variable name (which can be a hash or a string) and its + number of states. Returns: - a dictionary mapping all possible names to different variables. + List of variables in the VariableGroup """ raise NotImplementedError( "Please subclass the VariableGroup class and override this method" @@ -205,7 +204,7 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any: "Please subclass the FactorGroup class and override this method" ) - def compile_wiring(self, vars_to_starts: Mapping[int, int]) -> Any: + def compile_wiring(self, vars_to_starts: Mapping[Tuple[int, int], int]) -> Any: """Compile an efficient wiring for the FactorGroup. Args: diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py index 97e6bd3b..67a998ba 100644 --- a/pgmax/fg/nodes.py +++ b/pgmax/fg/nodes.py @@ -7,8 +7,6 @@ import jax.numpy as jnp import numpy as np -from pgmax import utils - @jax.tree_util.register_pytree_node_class @dataclass(frozen=True, eq=False) @@ -59,15 +57,15 @@ def __post_init__(self): "Please implement compile_wiring in for your factor" ) - @utils.cached_property - def edges_num_states(self) -> np.ndarray: - """Number of states for the variables connected to each edge + # @utils.cached_property + # def edges_num_states(self) -> np.ndarray: + # """Number of states for the variables connected to each edge - Returns: - Array of shape (num_edges,) - Number of states for the variables connected to each edge - """ - return self.vars_to_num_states.values() + # Returns: + # Array of shape (num_edges,) + # Number of states for the variables connected to each edge + # """ + # return self.variables.values() @staticmethod def concatenate_wirings(wirings: Sequence) -> Wiring: diff --git a/pgmax/groups/enumeration.py b/pgmax/groups/enumeration.py index 8c94ad46..982a8a57 100644 --- a/pgmax/groups/enumeration.py +++ b/pgmax/groups/enumeration.py @@ -109,16 +109,12 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: num_factors = len(self.factors) if ( data.shape != (num_factors, self.factor_configs.shape[0]) - and data.shape - != ( - num_factors, - np.sum(self.factors[0].edges_num_states), - ) + and data.shape != (num_factors, np.prod(self.factor_configs.shape)) and data.shape != (self.factor_configs.shape[0],) ): raise ValueError( f"data should be of shape {(num_factors, self.factor_configs.shape[0])} or " - f"{(num_factors, np.sum(self.factors[0].edges_num_states))} or " + f"{(num_factors, np.prod(self.factor_configs.shape))} or " f"{(self.factor_configs.shape[0],)}. Got {data.shape}." ) diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index a6f3e8d5..b59e360c 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -3,18 +3,19 @@ import random from dataclasses import dataclass from functools import total_ordering -from typing import List, Tuple, Union +from typing import Any, List, Tuple, Union import jax import jax.numpy as jnp import numpy as np +from pgmax.fg import groups from pgmax.utils import cached_property @total_ordering @dataclass(frozen=True, eq=False) -class NDVariableArray: +class NDVariableArray(groups.VariableGroup): """Subclass of VariableGroup for n-dimensional grids of variables. Args: @@ -23,11 +24,8 @@ class NDVariableArray: the notion of a NumPy ndarray shape) """ - # TODO: Variables = (hash, num_states) - # TODO: VariableGroup can be deleted - shape: Tuple[int, ...] - num_states: Union[int, np.ndarray] + num_states: np.ndarray def __post_init__(self): # super().__post_init__() @@ -52,7 +50,10 @@ def __lt__(self, other): def __getitem__(self, val): # Numpy indexation will throw IndexError for us if out-of-bounds - return (self.variable_names[val], self.num_states[val]) + result = (self.variable_names[val], self.num_states[val]) + if isinstance(val, slice): + return tuple(zip(result)) + return result @cached_property def variables(self) -> List[Tuple]: @@ -84,7 +85,7 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: Raises: ValueError: If the data is not of the correct shape. """ - # TODO: what should we do for different number of states -> look at maask_array + # TODO: what should we do for different number of states -> look at mask_array if data.shape != self.shape and data.shape != self.shape + ( self.num_states.max(), ): @@ -130,109 +131,105 @@ def unflatten( return data -# TODO: delete? -- NO - - -# @dataclass(frozen=True, eq=False) -# class VariableDict(): -# """A variable dictionary that contains a set of variables of the same size - -# Args: -# num_states: The size of the variables in this variable group -# num_variables: The number of variables - -# """ - -# num_states: int -# num_variables: int - -# @cached_property -# def variable_names(self) -> np.ndarray: -# """Function that generates a dictionary mapping names to variables. - -# Returns: -# a dictionary mapping all possible names to different variables. -# """ -# return self.__hash__() + np.arange(self.num_states) - - -# def flatten( -# self, data: Mapping[Hashable, Union[np.ndarray, jnp.ndarray]] -# ) -> jnp.ndarray: -# """Function that turns meaningful structured data into a flat data array for internal use. - -# Args: -# data: Meaningful structured data. Should be a mapping with names from self.variable_names. -# Each value should be an array of shape (1,) (for e.g. MAP decodings) or -# (self.num_states,) (for e.g. evidence, beliefs). - -# Returns: -# A flat jnp.array for internal use - -# Raises: -# ValueError if: -# (1) data is referring to a non-existing variable -# (2) data is not of the correct shape -# """ -# for name in data: -# if name not in self._names_to_variables: -# raise ValueError( -# f"data is referring to a non-existent variable {name}." -# ) - -# if data[name].shape != (self.num_states,) and data[name].shape != (1,): -# raise ValueError( -# f"Variable {name} expects a data array of shape " -# f"{(self.num_states,)} or (1,). Got {data[name].shape}." -# ) - -# flat_data = jnp.concatenate([data[name].flatten() for name in self.names]) -# return flat_data - -# def unflatten( -# self, flat_data: Union[np.ndarray, jnp.ndarray] -# ) -> Dict[Hashable, Union[np.ndarray, jnp.ndarray]]: -# """Function that recovers meaningful structured data from internal flat data array - -# Args: -# flat_data: Internal flat data array. - -# Returns: -# Meaningful structured data. Should be a mapping with names from self.variable_names. -# Each value should be an array of shape (1,) (for e.g. MAP decodings) or -# (self.num_states,) (for e.g. evidence, beliefs). - -# Raises: -# ValueError if: -# (1) flat_data is not a 1D array -# (2) flat_data is not of the right shape -# """ -# if flat_data.ndim != 1: -# raise ValueError( -# f"Can only unflatten 1D array. Got a {flat_data.ndim}D array." -# ) - -# num_variables = len(self.variable_names) -# num_variable_states = len(self.variable_names) * self.num_states -# if flat_data.shape[0] == num_variables: -# use_num_states = False -# elif flat_data.shape[0] == num_variable_states: -# use_num_states = True -# else: -# raise ValueError( -# f"flat_data should be either of shape (num_variables(={len(self.variables)}),), " -# f"or (num_variable_states(={num_variable_states}),). " -# f"Got {flat_data.shape}" -# ) - -# start = 0 -# data = {} -# for name in self.variable_names: -# if use_num_states: -# data[name] = flat_data[start : start + self.num_states] -# start += self.num_states -# else: -# data[name] = flat_data[np.array([start])] -# start += 1 - -# return data +@dataclass(frozen=True, eq=False) +class VariableDict(groups.VariableGroup): + """A variable dictionary that contains a set of variables of the same size + + Args: + num_states: The size of the variables in this variable group + variable_names: A tuple of all names of the variables in this variable group + """ + + num_states: int + variable_names: Tuple[Any, ...] + + @cached_property + def variables(self) -> List[Tuple]: + return list( + zip(self.variable_names, [self.num_states] * len(self.variable_names)) + ) + + def __getitem__(self, val): + # Numpy indexation will throw IndexError for us if out-of-bounds + return (val, self.num_states) + + # def flatten( + # self, data: Mapping[Hashable, Union[np.ndarray, jnp.ndarray]] + # ) -> jnp.ndarray: + # """Function that turns meaningful structured data into a flat data array for internal use. + + # Args: + # data: Meaningful structured data. Should be a mapping with names from self.variable_names. + # Each value should be an array of shape (1,) (for e.g. MAP decodings) or + # (self.num_states,) (for e.g. evidence, beliefs). + + # Returns: + # A flat jnp.array for internal use + + # Raises: + # ValueError if: + # (1) data is referring to a non-existing variable + # (2) data is not of the correct shape + # """ + # for name in data: + # if name not in self.variable_names: + # raise ValueError( + # f"data is referring to a non-existent variable {name}." + # ) + + # if data[name].shape != (self.num_states,) and data[name].shape != (1,): + # raise ValueError( + # f"Variable {name} expects a data array of shape " + # f"{(self.num_states,)} or (1,). Got {data[name].shape}." + # ) + + # flat_data = jnp.concatenate([data[name].flatten() for name in self.variable_names]) + # return flat_data + + # def unflatten( + # self, flat_data: Union[np.ndarray, jnp.ndarray] + # ) -> Dict[Hashable, Union[np.ndarray, jnp.ndarray]]: + # """Function that recovers meaningful structured data from internal flat data array + + # Args: + # flat_data: Internal flat data array. + + # Returns: + # Meaningful structured data. Should be a mapping with names from self.variable_names. + # Each value should be an array of shape (1,) (for e.g. MAP decodings) or + # (self.num_states,) (for e.g. evidence, beliefs). + + # Raises: + # ValueError if: + # (1) flat_data is not a 1D array + # (2) flat_data is not of the right shape + # """ + # if flat_data.ndim != 1: + # raise ValueError( + # f"Can only unflatten 1D array. Got a {flat_data.ndim}D array." + # ) + + # num_variables = len(self.variable_names) + # num_variable_states = len(self.variable_names) * self.num_states + # if flat_data.shape[0] == num_variables: + # use_num_states = False + # elif flat_data.shape[0] == num_variable_states: + # use_num_states = True + # else: + # raise ValueError( + # f"flat_data should be either of shape (num_variables(={len(self.variables)}),), " + # f"or (num_variable_states(={num_variable_states}),). " + # f"Got {flat_data.shape}" + # ) + + # start = 0 + # data = {} + # for name in self.variable_names: + # if use_num_states: + # data[name] = flat_data[start : start + self.num_states] + # start += self.num_states + # else: + # data[name] = flat_data[np.array([start])] + # start += 1 + + # return data diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py index 9c88ed2a..9185536c 100644 --- a/tests/fg/test_graph.py +++ b/tests/fg/test_graph.py @@ -13,11 +13,11 @@ def test_factor_graph(): - variable_group = vgroup.VariableDict(15, (0,)) - fg = graph.FactorGraph(variable_group) + vg = vgroup.VariableDict(15, (0,)) + fg = graph.FactorGraph(vg) fg.add_factor_by_type( factor_type=enumeration_factor.EnumerationFactor, - variable_names=[0], + variables=[vg[0]], factor_configs=np.arange(15)[:, None], log_potentials=np.zeros(15), name="test", @@ -27,7 +27,7 @@ def test_factor_graph(): match="A factor group with the name test already exists. Please choose a different name", ): fg.add_factor( - variable_names=[0], + variables=[vg[0]], factor_configs=np.arange(15)[:, None], name="test", ) @@ -35,11 +35,11 @@ def test_factor_graph(): with pytest.raises( ValueError, match=re.escape( - f"A Factor of type {enumeration_factor.EnumerationFactor} involving variables {frozenset([0])} already exists." + f"A Factor of type {enumeration_factor.EnumerationFactor} involving variables {frozenset([(0, 15)])} already exists." ), ): fg.add_factor( - variable_names=[0], + variables=[vg[0]], factor_configs=np.arange(10)[:, None], ) @@ -49,46 +49,43 @@ def test_factor_graph(): f"Type {groups.FactorGroup} is not one of the supported factor types {FAC_TO_VAR_UPDATES.keys()}" ), ): - fg.add_factor_by_type(variable_names=[0], factor_type=groups.FactorGroup) + fg.add_factor_by_type(variables=[vg[0]], factor_type=groups.FactorGroup) def test_factor_adding(): A = vgroup.NDVariableArray(num_states=2, shape=(10,)) B = vgroup.NDVariableArray(num_states=2, shape=(10,)) - fg = graph.FactorGraph(variables=dict(A=A, B=B)) + fg = graph.FactorGraph(variables=[A, B]) with pytest.raises(ValueError, match="Do not add a factor group with no factors."): fg.add_factor_group( factory=logical.ORFactorGroup, - variable_names_for_factors=[], + variables_for_factors=[], ) - variables0 = [("A", 0), ("B", 0)] - variables1 = [("A", 1), ("B", 1)] - ORFactor = logical.ORFactorGroup( - fg._variable_group, variable_names_for_factors=[variables0] - ) + variables0 = (A[0], B[0]) + variables1 = (A[1], B[1]) + ORFactor = logical.ORFactorGroup(variables_for_factors=[variables0]) with pytest.raises( ValueError, match="SingleFactorGroup should only contain one factor. Got 2" ): groups.SingleFactorGroup( - variable_group=fg._variable_group, - variable_names_for_factors=[variables0, variables1], + variables_for_factors=[variables0, variables1], factor=ORFactor, ) def test_bp_state(): - variable_group = vgroup.VariableDict(15, (0,)) - fg0 = graph.FactorGraph(variable_group) + vg = vgroup.VariableDict(15, (0,)) + fg0 = graph.FactorGraph(vg) fg0.add_factor( - variable_names=[0], + variables=[vg[0]], factor_configs=np.arange(10)[:, None], name="test", ) - fg1 = graph.FactorGraph(variable_group) + fg1 = graph.FactorGraph(vg) fg1.add_factor( - variable_names=[0], + variables=[vg[0]], factor_configs=np.arange(15)[:, None], name="test", ) @@ -103,79 +100,79 @@ def test_bp_state(): ) -def test_log_potentials(): - variable_group = vgroup.VariableDict(15, (0,)) - fg = graph.FactorGraph(variable_group) - fg.add_factor( - variable_names=[0], - factor_configs=np.arange(10)[:, None], - name="test", - ) - with pytest.raises( - ValueError, - match=re.escape("Expected log potentials shape (10,) for factor group test."), - ): - fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15)) - - with pytest.raises( - ValueError, - match=re.escape(f"Invalid name {frozenset([0])} for log potentials updates."), - ): - fg.bp_state.log_potentials[[0]] = np.zeros(10) - - with pytest.raises( - ValueError, match=re.escape("Expected log potentials shape (10,). Got (15,)") - ): - graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15)) - - log_potentials = graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10)) - assert jnp.all(log_potentials["test"] == jnp.zeros(10)) - with pytest.raises( - ValueError, - match=re.escape(f"Invalid name {frozenset([1])} for log potentials updates."), - ): - fg.bp_state.log_potentials[[1]] - - -def test_ftov_msgs(): - variable_group = vgroup.VariableDict(15, (0,)) - fg = graph.FactorGraph(variable_group) - fg.add_factor( - variable_names=[0], - factor_configs=np.arange(10)[:, None], - name="test", - ) - with pytest.raises( - ValueError, - match=re.escape("Invalid names for setting messages"), - ): - fg.bp_state.ftov_msgs[[0], 0] = np.ones(10) - - with pytest.raises( - ValueError, - match=re.escape( - "Given belief shape (10,) does not match expected shape (15,) for variable 0" - ), - ): - fg.bp_state.ftov_msgs[0] = np.ones(10) - - with pytest.raises( - ValueError, match=re.escape("Expected messages shape (15,). Got (10,)") - ): - graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(10)) - - ftov_msgs = graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(15)) - with pytest.raises( - TypeError, match=re.escape("'FToVMessages' object is not subscriptable") - ): - ftov_msgs[(10,)] +# def test_log_potentials(): +# vg = vgroup.VariableDict(15, (0,)) +# fg = graph.FactorGraph(vg) +# fg.add_factor( +# variables=[vg[0]], +# factor_configs=np.arange(10)[:, None], +# name="test", +# ) +# with pytest.raises( +# ValueError, +# match=re.escape("Expected log potentials shape (10,) for factor group test."), +# ): +# fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15)) + +# with pytest.raises( +# ValueError, +# match=re.escape(f"Invalid name {frozenset([0])} for log potentials updates."), +# ): +# fg.bp_state.log_potentials[vg[0]] = np.zeros(10) + +# with pytest.raises( +# ValueError, match=re.escape("Expected log potentials shape (10,). Got (15,)") +# ): +# graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15)) + +# log_potentials = graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10)) +# assert jnp.all(log_potentials["test"] == jnp.zeros(10)) +# with pytest.raises( +# ValueError, +# match=re.escape(f"Invalid name {frozenset([1])} for log potentials updates."), +# ): +# fg.bp_state.log_potentials[[1]] + + +# def test_ftov_msgs(): +# variable_group = vgroup.VariableDict(15, (0,)) +# fg = graph.FactorGraph(variable_group) +# fg.add_factor( +# variable_names=[0], +# factor_configs=np.arange(10)[:, None], +# name="test", +# ) +# with pytest.raises( +# ValueError, +# match=re.escape("Invalid names for setting messages"), +# ): +# fg.bp_state.ftov_msgs[[0], 0] = np.ones(10) + +# with pytest.raises( +# ValueError, +# match=re.escape( +# "Given belief shape (10,) does not match expected shape (15,) for variable 0" +# ), +# ): +# fg.bp_state.ftov_msgs[0] = np.ones(10) + +# with pytest.raises( +# ValueError, match=re.escape("Expected messages shape (15,). Got (10,)") +# ): +# graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(10)) + +# ftov_msgs = graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(15)) +# with pytest.raises( +# TypeError, match=re.escape("'FToVMessages' object is not subscriptable") +# ): +# ftov_msgs[(10,)] def test_evidence(): - variable_group = vgroup.VariableDict(15, (0,)) - fg = graph.FactorGraph(variable_group) + vg = vgroup.VariableDict(15, (0,)) + fg = graph.FactorGraph(vg) fg.add_factor( - variable_names=[0], + variables=[vg[0]], factor_configs=np.arange(10)[:, None], name="test", ) @@ -189,10 +186,10 @@ def test_evidence(): def test_bp(): - variable_group = vgroup.VariableDict(15, (0,)) - fg = graph.FactorGraph(variable_group) + vg = vgroup.VariableDict(15, (0,)) + fg = graph.FactorGraph(vg) fg.add_factor( - variable_names=[0], + variables=[vg[0]], factor_configs=np.arange(10)[:, None], name="test", ) @@ -200,7 +197,7 @@ def test_bp(): bp_arrays = bp.update() bp_arrays = bp.update( bp_arrays=bp_arrays, - ftov_msgs_updates={0: np.zeros(15)}, + ftov_msgs_updates={vg[0]: np.zeros(15)}, ) bp_arrays = bp.run_bp(bp_arrays, num_iters=1) bp_arrays = replace(bp_arrays, log_potentials=jnp.zeros((10))) diff --git a/tests/fg/test_nodes.py b/tests/fg/test_nodes.py index e9abf9bc..27445726 100644 --- a/tests/fg/test_nodes.py +++ b/tests/fg/test_nodes.py @@ -5,36 +5,37 @@ from pgmax.factors import enumeration, logical from pgmax.fg import nodes +from pgmax.groups import variables as vgroup def test_enumeration_factor(): - variable = nodes.Variable(3) + variables = vgroup.NDVariableArray(num_states=3, shape=(1,)) with pytest.raises( NotImplementedError, match="Please implement compile_wiring in for your factor" ): nodes.Factor( - variables=(variable,), + variables=[variables[0]], log_potentials=np.array([0.0]), ) with pytest.raises(ValueError, match="Configurations should be integers. Got"): enumeration.EnumerationFactor( - variables=(variable,), + variables=[variables[0]], factor_configs=np.array([[1.0]]), log_potentials=np.array([0.0]), ) with pytest.raises(ValueError, match="Potential should be floats. Got"): enumeration.EnumerationFactor( - variables=(variable,), + variables=[variables[0]], factor_configs=np.array([[1]]), log_potentials=np.array([0]), ) with pytest.raises(ValueError, match="factor_configs should be a 2D array"): enumeration.EnumerationFactor( - variables=(variable,), + variables=[variables[0]], factor_configs=np.array([1]), log_potentials=np.array([0.0]), ) @@ -46,7 +47,7 @@ def test_enumeration_factor(): ), ): enumeration.EnumerationFactor( - variables=(variable,), + variables=[variables[0]], factor_configs=np.array([[1, 2]]), log_potentials=np.array([0.0]), ) @@ -55,27 +56,27 @@ def test_enumeration_factor(): ValueError, match=re.escape("Expected log potentials of shape (1,)") ): enumeration.EnumerationFactor( - variables=(variable,), + variables=[variables[0]], factor_configs=np.array([[1]]), log_potentials=np.array([0.0, 1.0]), ) with pytest.raises(ValueError, match="Invalid configurations for given variables"): enumeration.EnumerationFactor( - variables=(variable,), + variables=[variables[0]], factor_configs=np.array([[10]]), log_potentials=np.array([0.0]), ) def test_logical_factor(): - child = nodes.Variable(2) - wrong_parent = nodes.Variable(3) - parent = nodes.Variable(2) + child = vgroup.NDVariableArray(num_states=2, shape=(1,))[0] + wrong_parent = vgroup.NDVariableArray(num_states=3, shape=(1,))[0] + parent = vgroup.NDVariableArray(num_states=2, shape=(1,))[0] with pytest.raises( ValueError, - match="At least one parent variable and one child variable is required", + match="A LogicalFactor requires at least one parent variable and one child variable", ): logical.LogicalFactor( variables=(child,), @@ -100,7 +101,7 @@ def test_logical_factor(): with pytest.raises(ValueError, match="The highest LogicalFactor index must be 0"): logical.LogicalWiring( - edges_num_states=logical_factor.edges_num_states, + edges_num_states=[2, 2], var_states_for_edges=None, parents_edge_states=parents_edge_states + np.array([[1, 0]]), children_edge_states=child_edge_state, @@ -112,7 +113,7 @@ def test_logical_factor(): match="The LogicalWiring must have 1 different LogicalFactor indices", ): logical.LogicalWiring( - edges_num_states=logical_factor.edges_num_states, + edges_num_states=[2, 2], var_states_for_edges=None, parents_edge_states=parents_edge_states + np.array([[0], [1]]), children_edge_states=child_edge_state, @@ -126,7 +127,7 @@ def test_logical_factor(): ), ): logical.LogicalWiring( - edges_num_states=logical_factor.edges_num_states, + edges_num_states=[2, 2], var_states_for_edges=None, parents_edge_states=parents_edge_states, children_edge_states=child_edge_state, From 7e55e5c425f94bd946f2142f33b20c8c6249905e Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Thu, 21 Apr 2022 19:47:14 +0000 Subject: [PATCH 07/35] Tests passing --- pgmax/fg/graph.py | 49 ++++------- pgmax/groups/enumeration.py | 18 ++-- pgmax/groups/variables.py | 170 ++++++++++++++++++------------------ tests/factors/test_and.py | 14 +-- tests/factors/test_or.py | 14 +-- tests/fg/test_graph.py | 127 +++++++++++++-------------- tests/fg/test_groups.py | 123 +++++--------------------- tests/test_pgmax.py | 99 ++++++++++----------- 8 files changed, 258 insertions(+), 356 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 3e5cb570..264a6ab0 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -61,21 +61,13 @@ def __post_init__(self): start = time.time() if isinstance(self.variables, (vgroup.NDVariableArray, vgroup.VariableDict)): - self.variable_groups = [self.variables] + self._variable_groups = [self.variables] else: - self.variable_groups = self.variables - - # TODO: remove? - self._variable_group = self.variable_groups - # self._variable_group: Mapping[ - # int, groups.VariableGroup - # ] = collections.OrderedDict() - # for variable_group in self.variables: - # self._variable_group[variable_group.__hash__()] = variable_group + self._variable_groups = self.variables self._variables = [ variable - for variable_group in self.variable_groups + for variable_group in self._variable_groups for variable in variable_group.variables ] print("1", time.time() - start) @@ -368,7 +360,7 @@ def fg_state(self) -> FactorGraphState: ) return FactorGraphState( - variable_groups=self._variable_group, + variable_groups=self._variable_groups, vars_to_starts=self._vars_to_starts, num_var_states=self._num_var_states, total_factor_num_states=self._total_factor_num_states, @@ -610,12 +602,12 @@ def update_ftov_msgs( (1) provided ftov_msgs shape does not match the expected ftov_msgs shape. (2) provided name is not valid for ftov_msgs updates. """ - for names, data in updates.items(): - if names in fg_state.variable_groups: - if data.shape != (names.total_num_states,): + for variable, data in updates.items(): + if variable in fg_state.vars_to_starts: + if data.shape != (variable[1],): raise ValueError( f"Given belief shape {data.shape} does not match expected " - f"shape {(names.total_num_states,)} for variable." + f"shape {(variable[1],)} for variable {variable}." ) var_states_for_edges = np.concatenate( @@ -625,13 +617,13 @@ def update_ftov_msgs( ] ) - # starts = np.nonzero( - # var_states_for_edges == fg_state.vars_to_starts[variable] - # )[0] - # for start in starts: - # ftov_msgs = ftov_msgs.at[start : start + variable.num_states].set( - # data / starts.shape[0] - # ) + starts = np.nonzero( + var_states_for_edges == fg_state.vars_to_starts[variable] + )[0] + for start in starts: + ftov_msgs = ftov_msgs.at[start : start + variable[1]].set( + data / starts.shape[0] + ) else: raise ValueError( "Invalid names for setting messages. " @@ -691,26 +683,19 @@ def __setitem__( @typing.overload def __setitem__( self, - names: Any, + names: Tuple[int, int], data: Union[np.ndarray, jnp.ndarray], ) -> None: """Spreading beliefs at a variable to all connected factors Args: - names: The name of the variable + variable: A tuple representing a variable data: An array containing the beliefs to be spread uniformly across all factor to variable messages involving this variable. """ def __setitem__(self, names, data) -> None: - if ( - isinstance(names, tuple) - and len(names) == 2 - and names[1] in self.fg_state.variable_groups - ): - names = (frozenset(names[0]), names[1]) - object.__setattr__( self, "value", diff --git a/pgmax/groups/enumeration.py b/pgmax/groups/enumeration.py index 982a8a57..03f17cb3 100644 --- a/pgmax/groups/enumeration.py +++ b/pgmax/groups/enumeration.py @@ -107,14 +107,17 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: ValueError: if data is not of the right shape. """ num_factors = len(self.factors) + factor_edges_num_states = sum( + [variable[1] for variable in self.variables_for_factors[0]] + ) if ( data.shape != (num_factors, self.factor_configs.shape[0]) - and data.shape != (num_factors, np.prod(self.factor_configs.shape)) + and data.shape != (num_factors, factor_edges_num_states) and data.shape != (self.factor_configs.shape[0],) ): raise ValueError( f"data should be of shape {(num_factors, self.factor_configs.shape[0])} or " - f"{(num_factors, np.prod(self.factor_configs.shape))} or " + f"{(num_factors, factor_edges_num_states)} or " f"{(self.factor_configs.shape[0],)}. Got {data.shape}." ) @@ -149,18 +152,19 @@ def unflatten( ) num_factors = len(self.factors) + factor_edges_num_states = sum( + [variable[1] for variable in self.variables_for_factors[0]] + ) if flat_data.size == num_factors * self.factor_configs.shape[0]: data = flat_data.reshape( (num_factors, self.factor_configs.shape[0]), ) - elif flat_data.size == num_factors * np.sum(self.factors[0].edges_num_states): - data = flat_data.reshape( - (num_factors, np.sum(self.factors[0].edges_num_states)) - ) + elif flat_data.size == num_factors * np.sum(factor_edges_num_states): + data = flat_data.reshape((num_factors, np.sum(factor_edges_num_states))) else: raise ValueError( f"flat_data should be compatible with shape {(num_factors, self.factor_configs.shape[0])} " - f"or {(num_factors, np.sum(self.factors[0].edges_num_states))}. Got {flat_data.shape}." + f"or {(num_factors, np.sum(factor_edges_num_states))}. Got {flat_data.shape}." ) return data diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index b59e360c..28572ae1 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -3,7 +3,7 @@ import random from dataclasses import dataclass from functools import total_ordering -from typing import Any, List, Tuple, Union +from typing import Any, Dict, Hashable, List, Mapping, Tuple, Union import jax import jax.numpy as jnp @@ -28,8 +28,6 @@ class NDVariableArray(groups.VariableGroup): num_states: np.ndarray def __post_init__(self): - # super().__post_init__() - if isinstance(self.num_states, int): num_states = np.full(self.shape, fill_value=self.num_states) object.__setattr__(self, "num_states", num_states) @@ -124,7 +122,7 @@ def unflatten( data = flat_data.reshape(self.shape + (self.num_states.max(),)) else: raise ValueError( - f"flat_data should be compatible with shape {self.shape} or {self.shape + (self.num_states,)}. " + f"flat_data should be compatible with shape {self.shape} or {self.shape + (self.num_states.max(),)}. " f"Got {flat_data.shape}." ) @@ -140,8 +138,8 @@ class VariableDict(groups.VariableGroup): variable_names: A tuple of all names of the variables in this variable group """ - num_states: int variable_names: Tuple[Any, ...] + num_states: int @cached_property def variables(self) -> List[Tuple]: @@ -153,83 +151,85 @@ def __getitem__(self, val): # Numpy indexation will throw IndexError for us if out-of-bounds return (val, self.num_states) - # def flatten( - # self, data: Mapping[Hashable, Union[np.ndarray, jnp.ndarray]] - # ) -> jnp.ndarray: - # """Function that turns meaningful structured data into a flat data array for internal use. - - # Args: - # data: Meaningful structured data. Should be a mapping with names from self.variable_names. - # Each value should be an array of shape (1,) (for e.g. MAP decodings) or - # (self.num_states,) (for e.g. evidence, beliefs). - - # Returns: - # A flat jnp.array for internal use - - # Raises: - # ValueError if: - # (1) data is referring to a non-existing variable - # (2) data is not of the correct shape - # """ - # for name in data: - # if name not in self.variable_names: - # raise ValueError( - # f"data is referring to a non-existent variable {name}." - # ) - - # if data[name].shape != (self.num_states,) and data[name].shape != (1,): - # raise ValueError( - # f"Variable {name} expects a data array of shape " - # f"{(self.num_states,)} or (1,). Got {data[name].shape}." - # ) - - # flat_data = jnp.concatenate([data[name].flatten() for name in self.variable_names]) - # return flat_data - - # def unflatten( - # self, flat_data: Union[np.ndarray, jnp.ndarray] - # ) -> Dict[Hashable, Union[np.ndarray, jnp.ndarray]]: - # """Function that recovers meaningful structured data from internal flat data array - - # Args: - # flat_data: Internal flat data array. - - # Returns: - # Meaningful structured data. Should be a mapping with names from self.variable_names. - # Each value should be an array of shape (1,) (for e.g. MAP decodings) or - # (self.num_states,) (for e.g. evidence, beliefs). - - # Raises: - # ValueError if: - # (1) flat_data is not a 1D array - # (2) flat_data is not of the right shape - # """ - # if flat_data.ndim != 1: - # raise ValueError( - # f"Can only unflatten 1D array. Got a {flat_data.ndim}D array." - # ) - - # num_variables = len(self.variable_names) - # num_variable_states = len(self.variable_names) * self.num_states - # if flat_data.shape[0] == num_variables: - # use_num_states = False - # elif flat_data.shape[0] == num_variable_states: - # use_num_states = True - # else: - # raise ValueError( - # f"flat_data should be either of shape (num_variables(={len(self.variables)}),), " - # f"or (num_variable_states(={num_variable_states}),). " - # f"Got {flat_data.shape}" - # ) - - # start = 0 - # data = {} - # for name in self.variable_names: - # if use_num_states: - # data[name] = flat_data[start : start + self.num_states] - # start += self.num_states - # else: - # data[name] = flat_data[np.array([start])] - # start += 1 - - # return data + def flatten( + self, data: Mapping[Hashable, Union[np.ndarray, jnp.ndarray]] + ) -> jnp.ndarray: + """Function that turns meaningful structured data into a flat data array for internal use. + + Args: + data: Meaningful structured data. Should be a mapping with names from self.variable_names. + Each value should be an array of shape (1,) (for e.g. MAP decodings) or + (self.num_states,) (for e.g. evidence, beliefs). + + Returns: + A flat jnp.array for internal use + + Raises: + ValueError if: + (1) data is referring to a non-existing variable + (2) data is not of the correct shape + """ + for name in data: + if name not in self.variable_names: + raise ValueError( + f"data is referring to a non-existent variable {name}." + ) + + if data[name].shape != (self.num_states,) and data[name].shape != (1,): + raise ValueError( + f"Variable {name} expects a data array of shape " + f"{(self.num_states,)} or (1,). Got {data[name].shape}." + ) + + flat_data = jnp.concatenate( + [data[name].flatten() for name in self.variable_names] + ) + return flat_data + + def unflatten( + self, flat_data: Union[np.ndarray, jnp.ndarray] + ) -> Dict[Hashable, Union[np.ndarray, jnp.ndarray]]: + """Function that recovers meaningful structured data from internal flat data array + + Args: + flat_data: Internal flat data array. + + Returns: + Meaningful structured data. Should be a mapping with names from self.variable_names. + Each value should be an array of shape (1,) (for e.g. MAP decodings) or + (self.num_states,) (for e.g. evidence, beliefs). + + Raises: + ValueError if: + (1) flat_data is not a 1D array + (2) flat_data is not of the right shape + """ + if flat_data.ndim != 1: + raise ValueError( + f"Can only unflatten 1D array. Got a {flat_data.ndim}D array." + ) + + num_variables = len(self.variable_names) + num_variable_states = len(self.variable_names) * self.num_states + if flat_data.shape[0] == num_variables: + use_num_states = False + elif flat_data.shape[0] == num_variable_states: + use_num_states = True + else: + raise ValueError( + f"flat_data should be either of shape (num_variables(={len(self.variables)}),), " + f"or (num_variable_states(={num_variable_states}),). " + f"Got {flat_data.shape}" + ) + + start = 0 + data = {} + for name in self.variable_names: + if use_num_states: + data[name] = flat_data[start : start + self.num_states] + start += self.num_states + else: + data[name] = flat_data[np.array([start])] + start += 1 + + return data diff --git a/tests/factors/test_and.py b/tests/factors/test_and.py index e95e22a8..19bc08d0 100644 --- a/tests/factors/test_and.py +++ b/tests/factors/test_and.py @@ -58,23 +58,23 @@ def test_run_bp_with_ANDFactors(): variables_for_factors1 = [] variables_for_factors2 = [] for factor_idx in range(num_factors): - variable_names1 = [ + variables1 = [ parents_variables1[idx] for idx in range( num_parents_cumsum[factor_idx], num_parents_cumsum[factor_idx + 1], ) ] + [children_variables1[factor_idx]] - variables_for_factors1.append(variable_names1) + variables_for_factors1.append(variables1) - variable_names2 = [ + variables2 = [ parents_variables2[idx] for idx in range( num_parents_cumsum[factor_idx], num_parents_cumsum[factor_idx + 1], ) ] + [children_variables2[factor_idx]] - variables_for_factors2.append(variable_names2) + variables_for_factors2.append(variables2) # Option 1: Define EnumerationFactors equivalent to the ANDFactors for factor_idx in range(num_factors): @@ -96,7 +96,7 @@ def test_run_bp_with_ANDFactors(): if factor_idx < num_factors // 2: # Add the first half of factors to FactorGraph1 fg1.add_factor( - variable_names=variables_for_factors1[factor_idx], + variables=variables_for_factors1[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) @@ -104,14 +104,14 @@ def test_run_bp_with_ANDFactors(): if idx != 0: # Add the second half of factors to FactorGraph2 fg2.add_factor( - variable_names=variables_for_factors2[factor_idx], + variables=variables_for_factors2[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) else: # Add all the EnumerationFactors to FactorGraph1 for the first iter fg1.add_factor( - variable_names=variables_for_factors1[factor_idx], + variables=variables_for_factors1[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) diff --git a/tests/factors/test_or.py b/tests/factors/test_or.py index 7a6c0905..162a7338 100644 --- a/tests/factors/test_or.py +++ b/tests/factors/test_or.py @@ -58,23 +58,23 @@ def test_run_bp_with_ORFactors(): variables_for_factors1 = [] variables_for_factors2 = [] for factor_idx in range(num_factors): - variable_names1 = [ + variables1 = [ parents_variables1[idx] for idx in range( num_parents_cumsum[factor_idx], num_parents_cumsum[factor_idx + 1], ) ] + [children_variables1[factor_idx]] - variables_for_factors1.append(variable_names1) + variables_for_factors1.append(variables1) - variable_names2 = [ + variables2 = [ parents_variables2[idx] for idx in range( num_parents_cumsum[factor_idx], num_parents_cumsum[factor_idx + 1], ) ] + [children_variables2[factor_idx]] - variables_for_factors2.append(variable_names2) + variables_for_factors2.append(variables2) # Option 1: Define EnumerationFactors equivalent to the ORFactors for factor_idx in range(num_factors): @@ -94,7 +94,7 @@ def test_run_bp_with_ORFactors(): if factor_idx < num_factors // 2: # Add the first half of factors to FactorGraph1 fg1.add_factor( - variable_names=variables_for_factors1[factor_idx], + variables=variables_for_factors1[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) @@ -102,14 +102,14 @@ def test_run_bp_with_ORFactors(): if idx != 0: # Add the second half of factors to FactorGraph2 fg2.add_factor( - variable_names=variables_for_factors2[factor_idx], + variables=variables_for_factors2[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) else: # Add all the EnumerationFactors to FactorGraph1 for the first iter fg1.add_factor( - variable_names=variables_for_factors1[factor_idx], + variables=variables_for_factors1[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py index 9185536c..ea2a7500 100644 --- a/tests/fg/test_graph.py +++ b/tests/fg/test_graph.py @@ -100,72 +100,67 @@ def test_bp_state(): ) -# def test_log_potentials(): -# vg = vgroup.VariableDict(15, (0,)) -# fg = graph.FactorGraph(vg) -# fg.add_factor( -# variables=[vg[0]], -# factor_configs=np.arange(10)[:, None], -# name="test", -# ) -# with pytest.raises( -# ValueError, -# match=re.escape("Expected log potentials shape (10,) for factor group test."), -# ): -# fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15)) - -# with pytest.raises( -# ValueError, -# match=re.escape(f"Invalid name {frozenset([0])} for log potentials updates."), -# ): -# fg.bp_state.log_potentials[vg[0]] = np.zeros(10) - -# with pytest.raises( -# ValueError, match=re.escape("Expected log potentials shape (10,). Got (15,)") -# ): -# graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15)) - -# log_potentials = graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10)) -# assert jnp.all(log_potentials["test"] == jnp.zeros(10)) -# with pytest.raises( -# ValueError, -# match=re.escape(f"Invalid name {frozenset([1])} for log potentials updates."), -# ): -# fg.bp_state.log_potentials[[1]] - - -# def test_ftov_msgs(): -# variable_group = vgroup.VariableDict(15, (0,)) -# fg = graph.FactorGraph(variable_group) -# fg.add_factor( -# variable_names=[0], -# factor_configs=np.arange(10)[:, None], -# name="test", -# ) -# with pytest.raises( -# ValueError, -# match=re.escape("Invalid names for setting messages"), -# ): -# fg.bp_state.ftov_msgs[[0], 0] = np.ones(10) - -# with pytest.raises( -# ValueError, -# match=re.escape( -# "Given belief shape (10,) does not match expected shape (15,) for variable 0" -# ), -# ): -# fg.bp_state.ftov_msgs[0] = np.ones(10) - -# with pytest.raises( -# ValueError, match=re.escape("Expected messages shape (15,). Got (10,)") -# ): -# graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(10)) - -# ftov_msgs = graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(15)) -# with pytest.raises( -# TypeError, match=re.escape("'FToVMessages' object is not subscriptable") -# ): -# ftov_msgs[(10,)] +def test_log_potentials(): + vg = vgroup.VariableDict(15, (0,)) + fg = graph.FactorGraph(vg) + fg.add_factor( + variables=[vg[0]], + factor_configs=np.arange(10)[:, None], + name="test", + ) + with pytest.raises( + ValueError, + match=re.escape("Expected log potentials shape (10,) for factor group test."), + ): + fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15)) + + with pytest.raises( + ValueError, + match=re.escape("Invalid name (0, 15) for log potentials updates."), + ): + fg.bp_state.log_potentials[vg[0]] = np.zeros(10) + + with pytest.raises( + ValueError, match=re.escape("Expected log potentials shape (10,). Got (15,)") + ): + graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15)) + + log_potentials = graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10)) + assert jnp.all(log_potentials["test"] == jnp.zeros(10)) + + +def test_ftov_msgs(): + vg = vgroup.VariableDict(15, (0,)) + fg = graph.FactorGraph(vg) + fg.add_factor( + variables=[vg[0]], + factor_configs=np.arange(10)[:, None], + name="test", + ) + with pytest.raises( + ValueError, + match=re.escape("Invalid names for setting messages"), + ): + fg.bp_state.ftov_msgs[0] = np.ones(10) + + with pytest.raises( + ValueError, + match=re.escape( + "Given belief shape (10,) does not match expected shape (15,) for variable (0, 15)." + ), + ): + fg.bp_state.ftov_msgs[vg[0]] = np.ones(10) + + with pytest.raises( + ValueError, match=re.escape("Expected messages shape (15,). Got (10,)") + ): + graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(10)) + + ftov_msgs = graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(15)) + with pytest.raises( + TypeError, match=re.escape("'FToVMessages' object is not subscriptable") + ): + ftov_msgs[(10,)] def test_evidence(): diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index 40fe3351..9a0fd312 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -5,83 +5,10 @@ import numpy as np import pytest -from pgmax.fg import groups, nodes from pgmax.groups import enumeration from pgmax.groups import variables as vgroup -def test_composite_variable_group(): - variable_dict1 = vgroup.VariableDict(15, tuple([0, 1, 2])) - variable_dict2 = vgroup.VariableDict(15, tuple([0, 1, 2])) - composite_variable_sequence = groups.CompositeVariableGroup( - [variable_dict1, variable_dict2] - ) - composite_variable_dict = groups.CompositeVariableGroup( - {(0, 1): variable_dict1, (2, 3): variable_dict2} - ) - with pytest.raises(ValueError, match="The name needs to have at least 2 elements"): - composite_variable_sequence[(0,)] - - assert composite_variable_sequence[0, 1] == variable_dict1[1] - assert ( - composite_variable_sequence[[(0, 1), (1, 2)]] - == composite_variable_dict[[((0, 1), 1), ((2, 3), 2)]] - ) - assert composite_variable_dict[(0, 1), 0] == variable_dict1[0] - assert composite_variable_dict[[((0, 1), 1), ((2, 3), 2)]] == [ - variable_dict1[1], - variable_dict2[2], - ] - assert jnp.all( - composite_variable_sequence.flatten( - [{name: np.zeros(15) for name in range(3)} for _ in range(2)] - ) - == composite_variable_dict.flatten( - { - (0, 1): {name: np.zeros(15) for name in range(3)}, - (2, 3): {name: np.zeros(15) for name in range(3)}, - } - ) - ) - assert jnp.all( - jnp.array( - jax.tree_util.tree_leaves( - jax.tree_util.tree_multimap( - lambda x, y: jnp.all(x == y), - composite_variable_sequence.unflatten(jnp.zeros(15 * 3 * 2)), - [{name: jnp.zeros(15) for name in range(3)} for _ in range(2)], - ) - ) - ) - ) - assert jnp.all( - jnp.array( - jax.tree_util.tree_leaves( - jax.tree_util.tree_multimap( - lambda x, y: jnp.all(x == y), - composite_variable_dict.unflatten(jnp.zeros(3 * 2)), - { - (0, 1): {name: np.zeros(1) for name in range(3)}, - (2, 3): {name: np.zeros(1) for name in range(3)}, - }, - ) - ) - ) - ) - with pytest.raises( - ValueError, match="Can only unflatten 1D array. Got a 2D array." - ): - composite_variable_dict.unflatten(jnp.zeros((10, 20))) - - with pytest.raises( - ValueError, - match=re.escape( - "flat_data should be either of shape (num_variables(=6),), or (num_variable_states(=90),)" - ), - ): - composite_variable_dict.unflatten(jnp.zeros((100))) - - def test_variable_dict(): variable_dict = vgroup.VariableDict(15, tuple([0, 1, 2])) with pytest.raises( @@ -123,9 +50,7 @@ def test_variable_dict(): def test_nd_variable_array(): - variable_group = vgroup.NDVariableArray(2, (1,)) - assert isinstance(variable_group[0], nodes.Variable) - variable_group = vgroup.NDVariableArray(3, (2, 2)) + variable_group = vgroup.NDVariableArray(shape=(2, 2), num_states=3) with pytest.raises( ValueError, match=re.escape("data should be of shape (2, 2) or (2, 2, 3). Got (3, 3)."), @@ -153,16 +78,15 @@ def test_nd_variable_array(): def test_enumeration_factor_group(): - variable_group = vgroup.NDVariableArray(3, (2, 2)) + vg = vgroup.NDVariableArray(shape=(2, 2), num_states=3) with pytest.raises( ValueError, match=re.escape("Expected log potentials shape: (1,) or (2, 1). Got (3, 2)"), ): enumeration_factor_group = enumeration.EnumerationFactorGroup( - variable_group=variable_group, - variable_names_for_factors=[ - [(0, 0), (0, 1), (1, 1)], - [(0, 1), (1, 0), (1, 1)], + variables_for_factors=[ + [vg[0, 0], vg[0, 1], vg[1, 1]], + [vg[0, 1], vg[1, 0], vg[1, 1]], ], factor_configs=np.zeros((1, 3), dtype=int), log_potentials=np.zeros((3, 2)), @@ -170,21 +94,22 @@ def test_enumeration_factor_group(): with pytest.raises(ValueError, match=re.escape("Potentials should be floats")): enumeration_factor_group = enumeration.EnumerationFactorGroup( - variable_group=variable_group, - variable_names_for_factors=[ - [(0, 0), (0, 1), (1, 1)], - [(0, 1), (1, 0), (1, 1)], + variables_for_factors=[ + [vg[0, 0], vg[0, 1], vg[1, 1]], + [vg[0, 1], vg[1, 0], vg[1, 1]], ], factor_configs=np.zeros((1, 3), dtype=int), log_potentials=np.zeros((2, 1), dtype=int), ) enumeration_factor_group = enumeration.EnumerationFactorGroup( - variable_group=variable_group, - variable_names_for_factors=[[(0, 0), (0, 1), (1, 1)], [(0, 1), (1, 0), (1, 1)]], + variables_for_factors=[ + [vg[0, 0], vg[0, 1], vg[1, 1]], + [vg[0, 1], vg[1, 0], vg[1, 1]], + ], factor_configs=np.zeros((1, 3), dtype=int), ) - name = [(0, 0), (1, 1)] + name = [vg[0, 0], vg[1, 1]] with pytest.raises( ValueError, match=re.escape( @@ -194,7 +119,7 @@ def test_enumeration_factor_group(): enumeration_factor_group[name] assert ( - enumeration_factor_group[[(0, 1), (1, 0), (1, 1)]] + enumeration_factor_group[[vg[0, 1], vg[1, 0], vg[1, 1]]] == enumeration_factor_group.factors[1] ) with pytest.raises( @@ -227,20 +152,20 @@ def test_enumeration_factor_group(): def test_pairwise_factor_group(): - variable_group = vgroup.NDVariableArray(3, (2, 2)) + vg = vgroup.NDVariableArray(shape=(2, 2), num_states=3) with pytest.raises( ValueError, match=re.escape("log_potential_matrix should be either a 2D array") ): enumeration.PairwiseFactorGroup( - variable_group, [[(0, 0), (1, 1)]], np.zeros((1,), dtype=float) + [[vg[0, 0], vg[1, 1]]], np.zeros((1,), dtype=float) ) with pytest.raises( ValueError, match=re.escape("Potential matrix should be floats") ): enumeration.PairwiseFactorGroup( - variable_group, [[(0, 0), (1, 1)]], np.zeros((3, 3), dtype=int) + [[vg[0, 0], vg[1, 1]]], np.zeros((3, 3), dtype=int) ) with pytest.raises( @@ -250,7 +175,7 @@ def test_pairwise_factor_group(): ), ): enumeration.PairwiseFactorGroup( - variable_group, [[(0, 0), (1, 1)]], np.zeros((2, 3, 3), dtype=float) + [[vg[0, 0], vg[1, 1]]], np.zeros((2, 3, 3), dtype=float) ) with pytest.raises( @@ -260,20 +185,18 @@ def test_pairwise_factor_group(): ), ): enumeration.PairwiseFactorGroup( - variable_group, [[(0, 0), (1, 1), (0, 1)]], np.zeros((3, 3), dtype=float) + [[vg[0, 0], vg[1, 1], vg[0, 1]]], np.zeros((3, 3), dtype=float) ) + name = [vg[0, 0], vg[1, 1]] with pytest.raises( ValueError, - match=re.escape("The specified pairwise factor [(0, 0), (1, 1)]"), + match=re.escape(f"The specified pairwise factor {name}"), ): - enumeration.PairwiseFactorGroup( - variable_group, [[(0, 0), (1, 1)]], np.zeros((4, 4), dtype=float) - ) + enumeration.PairwiseFactorGroup([name], np.zeros((4, 4), dtype=float)) pairwise_factor_group = enumeration.PairwiseFactorGroup( - variable_group, - [[(0, 0), (1, 1)], [(1, 0), (0, 1)]], + [[vg[0, 0], vg[1, 1]], [vg[1, 0], vg[0, 1]]], ) with pytest.raises( ValueError, diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py index 270eb7d9..a5e720a8 100644 --- a/tests/test_pgmax.py +++ b/tests/test_pgmax.py @@ -214,18 +214,13 @@ def create_valid_suppression_config_arr(suppression_diameter): # We create a NDVariableArray such that the [0,i,j] entry corresponds to the vertical cut variable (i.e, the one # attached horizontally to the factor) that's at that location in the image, and the [1,i,j] entry corresponds to # the horizontal cut variable (i.e, the one attached vertically to the factor) that's at that location - grid_vars_group = vgroup.NDVariableArray(3, (2, M - 1, N - 1)) + grid_vars = vgroup.NDVariableArray(shape=(2, M - 1, N - 1), num_states=3) # Make a group of additional variables for the edges of the grid extra_row_names: List[Tuple[Any, ...]] = [(0, row, N - 1) for row in range(M - 1)] extra_col_names: List[Tuple[Any, ...]] = [(1, M - 1, col) for col in range(N - 1)] additional_names = tuple(extra_row_names + extra_col_names) - additional_names_group = vgroup.VariableDict(3, additional_names) - - # Combine these two VariableGroups into one CompositeVariableGroup - composite_grid_group = groups.CompositeVariableGroup( - {"grid_vars": grid_vars_group, "additional_vars": additional_names_group} - ) + additional_vars = vgroup.VariableDict(additional_names, num_states=3) gt_has_cuts = gt_has_cuts.astype(np.int32) @@ -249,17 +244,19 @@ def create_valid_suppression_config_arr(suppression_diameter): size=evidence_vals_arr[1:].shape ) # This adds logistic noise for every evidence entry try: - _ = composite_grid_group["grid_vars", i, row, col] + _ = grid_vars[i, row, col] grid_evidence_arr[i, row, col] = evidence_vals_arr - except ValueError: + except IndexError: try: - _ = composite_grid_group["additional_vars", i, row, col] - additional_vars_evidence_dict[(i, row, col)] = evidence_vals_arr - except ValueError: + _ = additional_vars[i, row, col] + additional_vars_evidence_dict[ + additional_vars[i, row, col] + ] = evidence_vals_arr + except IndexError: pass # Create the factor graph - fg = graph.FactorGraph(variables=composite_grid_group) + fg = graph.FactorGraph(variables=[grid_vars, additional_vars]) # Imperatively add EnumerationFactorGroups (each consisting of just one EnumerationFactor) to # the graph! @@ -267,37 +264,37 @@ def create_valid_suppression_config_arr(suppression_diameter): for col in range(N - 1): if row != M - 2 and col != N - 2: curr_names = [ - ("grid_vars", 0, row, col), - ("grid_vars", 1, row, col), - ("grid_vars", 0, row, col + 1), - ("grid_vars", 1, row + 1, col), + grid_vars[0, row, col], + grid_vars[1, row, col], + grid_vars[0, row, col + 1], + grid_vars[1, row + 1, col], ] elif row != M - 2: curr_names = [ - ("grid_vars", 0, row, col), - ("grid_vars", 1, row, col), - ("additional_vars", 0, row, col + 1), - ("grid_vars", 1, row + 1, col), + grid_vars[0, row, col], + grid_vars[1, row, col], + additional_vars[0, row, col + 1], + grid_vars[1, row + 1, col], ] elif col != N - 2: curr_names = [ - ("grid_vars", 0, row, col), - ("grid_vars", 1, row, col), - ("grid_vars", 0, row, col + 1), - ("additional_vars", 1, row + 1, col), + grid_vars[0, row, col], + grid_vars[1, row, col], + grid_vars[0, row, col + 1], + additional_vars[1, row + 1, col], ] else: curr_names = [ - ("grid_vars", 0, row, col), - ("grid_vars", 1, row, col), - ("additional_vars", 0, row, col + 1), - ("additional_vars", 1, row + 1, col), + grid_vars[0, row, col], + grid_vars[1, row, col], + additional_vars[0, row, col + 1], + additional_vars[1, row + 1, col], ] if row % 2 == 0: fg.add_factor( - variable_names=curr_names, + variables=curr_names, factor_configs=valid_configs_non_supp, log_potentials=np.zeros( valid_configs_non_supp.shape[0], dtype=float @@ -306,7 +303,7 @@ def create_valid_suppression_config_arr(suppression_diameter): ) else: fg.add_factor( - variable_names=curr_names, + variables=curr_names, factor_configs=valid_configs_non_supp, log_potentials=np.zeros( valid_configs_non_supp.shape[0], dtype=float @@ -321,14 +318,14 @@ def create_valid_suppression_config_arr(suppression_diameter): if col != N - 1: vert_suppression_names.append( [ - ("grid_vars", 0, r, col) + grid_vars[0, r, col] for r in range(start_row, start_row + SUPPRESSION_DIAMETER) ] ) else: vert_suppression_names.append( [ - ("additional_vars", 0, r, col) + additional_vars[0, r, col] for r in range(start_row, start_row + SUPPRESSION_DIAMETER) ] ) @@ -339,14 +336,14 @@ def create_valid_suppression_config_arr(suppression_diameter): if row != M - 1: horz_suppression_names.append( [ - ("grid_vars", 1, row, c) + grid_vars[1, row, c] for c in range(start_col, start_col + SUPPRESSION_DIAMETER) ] ) else: horz_suppression_names.append( [ - ("additional_vars", 1, row, c) + additional_vars[1, row, c] for c in range(start_col, start_col + SUPPRESSION_DIAMETER) ] ) @@ -354,12 +351,12 @@ def create_valid_suppression_config_arr(suppression_diameter): # Add the suppression factors to the graph via kwargs fg.add_factor_group( factory=enumeration.EnumerationFactorGroup, - variable_names_for_factors=vert_suppression_names, + variables_for_factors=vert_suppression_names, factor_configs=valid_configs_supp, ) fg.add_factor_group( factory=enumeration.EnumerationFactorGroup, - variable_names_for_factors=horz_suppression_names, + variables_for_factors=horz_suppression_names, factor_configs=valid_configs_supp, log_potentials=np.zeros(valid_configs_supp.shape[0], dtype=float), ) @@ -373,8 +370,8 @@ def create_valid_suppression_config_arr(suppression_diameter): for this_wiring in fg.fg_state.wiring.values() ] ) - bp_state.evidence["grid_vars"] = grid_evidence_arr - bp_state.evidence["additional_vars"] = additional_vars_evidence_dict + bp_state.evidence[grid_vars] = grid_evidence_arr + bp_state.evidence[additional_vars] = additional_vars_evidence_dict bp = graph.BP(bp_state) bp_arrays = bp.run_bp(bp.init(), num_iters=100) # Test that the output messages are close to the true messages @@ -388,15 +385,15 @@ def test_e2e_heretic(): # Define some global constants im_size = (30, 30) # Instantiate all the Variables in the factor graph via VariableGroups - pixel_vars = vgroup.NDVariableArray(3, im_size) + pixel_vars = vgroup.NDVariableArray(shape=im_size, num_states=3) hidden_vars = vgroup.NDVariableArray( - 17, (im_size[0] - 2, im_size[1] - 2) + shape=(im_size[0] - 2, im_size[1] - 2), num_states=17 ) # Each hidden var is connected to a 3x3 patch of pixel vars bXn = np.zeros((30, 30, 3)) # Create the factor graph - fg = graph.FactorGraph((pixel_vars, hidden_vars)) + fg = graph.FactorGraph([pixel_vars, hidden_vars]) def binary_connected_variables( num_hidden_rows, num_hidden_cols, kernel_row, kernel_col @@ -406,8 +403,8 @@ def binary_connected_variables( for h_col in range(num_hidden_cols): ret_list.append( [ - (1, h_row, h_col), - (0, h_row + kernel_row, h_col + kernel_col), + hidden_vars[h_row, h_col], + pixel_vars[h_row + kernel_row, h_col + kernel_col], ] ) return ret_list @@ -417,22 +414,20 @@ def binary_connected_variables( for k_col in range(3): fg.add_factor_group( factory=enumeration.PairwiseFactorGroup, - variable_names_for_factors=binary_connected_variables( - 28, 28, k_row, k_col - ), + variables_for_factors=binary_connected_variables(28, 28, k_row, k_col), log_potential_matrix=W_pot[:, :, k_row, k_col], name=(k_row, k_col), ) # Assign evidence to pixel vars bp_state = fg.bp_state - bp_state.evidence[0] = np.array(bXn) - bp_state.evidence[0, 0, 0] = np.array([0.0, 0.0, 0.0]) - bp_state.evidence[0, 0, 0] - bp_state.evidence[1, 0, 0] + bp_state.evidence[pixel_vars] = np.array(bXn) + bp_state.evidence[pixel_vars[0, 0]] = np.array([0.0, 0.0, 0.0]) + bp_state.evidence[pixel_vars[0, 0]] + bp_state.evidence[hidden_vars[0, 0]] assert isinstance(bp_state.evidence.value, np.ndarray) assert len(sum(fg.factors.values(), ())) == 7056 bp = graph.BP(bp_state, temperature=1.0) bp_arrays = bp.run_bp(bp.init(), num_iters=1) marginals = graph.get_marginals(bp.get_beliefs(bp_arrays)) - assert jnp.allclose(jnp.sum(marginals[0], axis=-1), 1.0) + assert jnp.allclose(jnp.sum(marginals[pixel_vars], axis=-1), 1.0) From 9a2dcd215f322029a6ed5173bb5c893b03d99bfd Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Thu, 21 Apr 2022 21:01:09 +0000 Subject: [PATCH 08/35] Variables --- pgmax/groups/variables.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index 28572ae1..8dac5c7c 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -148,7 +148,8 @@ def variables(self) -> List[Tuple]: ) def __getitem__(self, val): - # Numpy indexation will throw IndexError for us if out-of-bounds + if val not in self.variable_names: + raise ValueError(f"Variable {val} is not in VariableDict") return (val, self.num_states) def flatten( @@ -170,7 +171,7 @@ def flatten( (2) data is not of the correct shape """ for name in data: - if name not in self.variable_names: + if name not in self.variables: raise ValueError( f"data is referring to a non-existent variable {name}." ) @@ -181,9 +182,7 @@ def flatten( f"{(self.num_states,)} or (1,). Got {data[name].shape}." ) - flat_data = jnp.concatenate( - [data[name].flatten() for name in self.variable_names] - ) + flat_data = jnp.concatenate([data[name].flatten() for name in self.variables]) return flat_data def unflatten( From c6ae8d80948476525c00546f866f6cb49389160e Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Thu, 21 Apr 2022 23:16:37 +0000 Subject: [PATCH 09/35] Tests + mypy --- .pre-commit-config.yaml | 10 ++--- examples/gmrf.py | 31 +++++++++------ examples/pmp_binary_deconvolution.py | 8 +--- examples/rbm.py | 56 ++-------------------------- examples/rcn.py | 8 ++-- pgmax/fg/graph.py | 38 ++++++++++++------- pgmax/fg/groups.py | 1 + pgmax/groups/variables.py | 9 ++++- tests/fg/test_graph.py | 12 +++--- tests/fg/test_groups.py | 6 +-- tests/test_pgmax.py | 35 ++++++++--------- 11 files changed, 94 insertions(+), 120 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7e553c33..cfec6112 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,10 +27,10 @@ repos: args: ['--config', '.flake8.config','--exit-zero'] verbose: true -# - repo: https://github.com/pre-commit/mirrors-mypy -# rev: 'v0.942' # Use the sha / tag you want to point at -# hooks: -# - id: mypy -# additional_dependencies: [tokenize-rt==3.2.0] +- repo: https://github.com/pre-commit/mirrors-mypy + rev: 'v0.942' # Use the sha / tag you want to point at + hooks: + - id: mypy + additional_dependencies: [tokenize-rt==3.2.0] ci: autoupdate_schedule: 'quarterly' diff --git a/examples/gmrf.py b/examples/gmrf.py index 5e8dc86b..95ae4a7a 100644 --- a/examples/gmrf.py +++ b/examples/gmrf.py @@ -61,31 +61,39 @@ # Add top-down factors fg.add_factor_group( factory=enumeration.PairwiseFactorGroup, - variable_names_for_factors=[ - [(ii, jj), (ii + 1, jj)] for ii in range(M - 1) for jj in range(N) + variables_for_factors=[ + [variables[ii, jj], variables[ii + 1, jj]] + for ii in range(M - 1) + for jj in range(N) ], name="top_down", ) # Add left-right factors fg.add_factor_group( factory=enumeration.PairwiseFactorGroup, - variable_names_for_factors=[ - [(ii, jj), (ii, jj + 1)] for ii in range(M) for jj in range(N - 1) + variables_for_factors=[ + [variables[ii, jj], variables[ii, jj + 1]] + for ii in range(M) + for jj in range(N - 1) ], name="left_right", ) # Add diagonal factors fg.add_factor_group( factory=enumeration.PairwiseFactorGroup, - variable_names_for_factors=[ - [(ii, jj), (ii + 1, jj + 1)] for ii in range(M - 1) for jj in range(N - 1) + variables_for_factors=[ + [variables[ii, jj], variables[ii + 1, jj + 1]] + for ii in range(M - 1) + for jj in range(N - 1) ], name="diagonal0", ) fg.add_factor_group( factory=enumeration.PairwiseFactorGroup, - variable_names_for_factors=[ - [(ii, jj), (ii - 1, jj + 1)] for ii in range(1, M) for jj in range(N - 1) + variables_for_factors=[ + [variables[ii, jj], variables[ii - 1, jj + 1]] + for ii in range(1, M) + for jj in range(N - 1) ], name="diagonal1", ) @@ -106,7 +114,7 @@ bp.get_beliefs( bp.run_bp( bp.init( - evidence_updates={None: evidence}, + evidence_updates={variables: evidence}, log_potentials_updates=log_potentials, ), num_iters=15, @@ -114,6 +122,7 @@ ) ) ) + marginals = marginals[variables] pred_image = np.argmax( np.stack( [ @@ -153,7 +162,7 @@ def loss(noisy_image, target_image, log_potentials): bp.get_beliefs( bp.run_bp( bp.init( - evidence_updates={None: evidence}, + evidence_updates={variables: evidence}, log_potentials_updates=log_potentials, ), num_iters=15, @@ -161,7 +170,7 @@ def loss(noisy_image, target_image, log_potentials): ) ) ) - logp = jnp.mean(jnp.log(jnp.sum(target * marginals, axis=-1))) + logp = jnp.mean(jnp.log(jnp.sum(target * marginals[variables], axis=-1))) return -logp diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py index bd418a69..6c31a537 100644 --- a/examples/pmp_binary_deconvolution.py +++ b/examples/pmp_binary_deconvolution.py @@ -107,13 +107,7 @@ def plot_images(images, display=True, nr=None): # # See Section 5.6 of the [PMP paper](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) for more details. -import imp - # %% -from pgmax.fg import graph - -imp.reload(graph) - # The dimensions of W used for the generation of X were (4, 5, 5) but we set them to (5, 6, 6) # to simulate a more realistic scenario in which we do not know their ground truth values n_feat, feat_height, feat_width = 5, 6, 6 @@ -183,7 +177,7 @@ def plot_images(images, display=True, nr=None): X_var = X[idx_img, idx_chan, idx_img_height, idx_img_width] variables_for_ORFactors_dict[X_var].append(SW_var) -print(time.time() - start) +print("After loop", time.time() - start) # Add ANDFactorGroup, which is computationally efficient fg.add_factor_group( diff --git a/examples/rbm.py b/examples/rbm.py index abd7435d..30f5ac2d 100644 --- a/examples/rbm.py +++ b/examples/rbm.py @@ -44,14 +44,8 @@ # %% [markdown] # We can then initialize the factor graph for the RBM with -# -# import imp # %% -from pgmax.fg import graph - -imp.reload(graph) - import time start = time.time() @@ -61,49 +55,6 @@ fg = graph.FactorGraph(variables=[hidden_variables, visible_variables]) print("Time", time.time() - start) -# %% -import itertools - -start = time.time() -variable_names_for_factors = factors = list( - map( - lambda ij: ( - hidden_variables.variable_names[ij[0]], - visible_variables.variable_names[ij[1]], - ), - list(itertools.product(range(bh.shape[0]), range(bv.shape[0]))), - ) -) -print("Time", time.time() - start, len(variable_names_for_factors)) - -import numba as nb - - -@nb.jit(parallel=False, cache=True, fastmath=True, nopython=True) -def run_numba(h, v, f): - for h_idx in nb.prange(h.shape[0]): - for v_idx in nb.prange(v.shape[0]): - f[h_idx * v.shape[0] + v_idx, 0] = h[h_idx] - f[h_idx * v.shape[0] + v_idx, 1] = v[v_idx] - - -start = time.time() -variable_names_for_factors = np.empty(shape=(bv.shape[0] * bh.shape[0], 2), dtype=int) -run_numba( - hidden_variables.variable_names, - visible_variables.variable_names, - variable_names_for_factors, -) -print("Time", time.time() - start, len(variable_names_for_factors)) - -start = time.time() -variable_names_for_factors = [ - [hidden_variables[ii], visible_variables[jj]] - for ii in range(bh.shape[0]) - for jj in range(bv.shape[0]) -] -print("Time", time.time() - start, len(variable_names_for_factors)) - # %% [markdown] # [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup) (e.g. an [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)), or a list/dictionary of [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed. # @@ -111,6 +62,7 @@ def run_numba(h, v, f): # %% start = time.time() + # Add unary factors fg.add_factor_group( factory=enumeration.EnumerationFactorGroup, @@ -136,11 +88,10 @@ def run_numba(h, v, f): for ii in range(bh.shape[0]) for jj in range(bv.shape[0]) ], - # variable_names_for_factors=variable_names_for_factors, log_potential_matrix=log_potential_matrix, ) -# # %snakeviz fg.add_factor_group(factory=enumeration.PairwiseFactorGroup, variable_names_for_factors=v, log_potential_matrix=log_potential_matrix,) +# # %snakeviz fg.add_factor_group(factory=enumeration.PairwiseFactorGroup, variables_for_factors=v, log_potential_matrix=log_potential_matrix,) print("Time", time.time() - start) @@ -227,7 +178,8 @@ def run_numba(h, v, f): # %% fig, ax = plt.subplots(1, 1, figsize=(10, 10)) ax.imshow( - graph.map_states(beliefs)[visible_variables].copy().reshape((28, 28)), cmap="gray" + graph.decode_map_states(beliefs)[visible_variables].copy().reshape((28, 28)), + cmap="gray", ) ax.axis("off") diff --git a/examples/rcn.py b/examples/rcn.py index 44fc1b85..97de9fff 100644 --- a/examples/rcn.py +++ b/examples/rcn.py @@ -211,11 +211,11 @@ def fetch_mnist_dataset(test_size: int, seed: int = 5) -> tuple[np.ndarray, np.n 2 * vps + 1 ) # The number of pool choices for the different variables of the PGM. -variables_all_models = {} +variables_all_models = [] for idx in range(frcs.shape[0]): frc = frcs[idx] - variables_all_models[idx] = vgroup.NDVariableArray( - num_states=M, shape=(frc.shape[0],) + variables_all_models.append( + vgroup.NDVariableArray(num_states=M, shape=(frc.shape[0],)) ) end = time.time() @@ -280,7 +280,7 @@ def valid_configs(r: int, hps: int, vps: int) -> np.ndarray: for e in edge: i1, i2, r = e fg.add_factor( - [(idx, i1), (idx, i2)], + [variables_all_models[idx][i1], variables_all_models[idx][i2]], valid_configs_list[r], ) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 264a6ab0..77fc489a 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -83,7 +83,6 @@ def __post_init__(self): ] = collections.OrderedDict( [(factor_type, set()) for factor_type in FAC_TO_VAR_UPDATES] ) - print("2", time.time() - start) # Used to add FactorGroups vars_num_states = [variable[1] for variable in self._variables] @@ -92,13 +91,14 @@ def __post_init__(self): 0, 0, ) + print("1", time.time() - start) # See FactorGraphState docstrings for documentation on the following fields self._num_var_states = vars_num_states_cumsum[-1] - self._vars_to_starts: OrderedDict[ - Tuple[int, int], int - ] = collections.OrderedDict(zip(self._variables, vars_num_states_cumsum[:-1])) + self._vars_to_starts: Dict[Tuple[int, int], int] = collections.OrderedDict( + zip(self._variables, vars_num_states_cumsum[:-1]) + ) self._named_factor_groups: Dict[Hashable, groups.FactorGroup] = {} - print("3", time.time() - start) + print("2", time.time() - start) def __hash__(self) -> int: all_factor_groups = tuple( @@ -1062,6 +1062,7 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]: Raises: ValueError: if flat_data is not of the right shape """ + if flat_beliefs.ndim != 1: raise ValueError( f"Can only unflatten 1D array. Got a {flat_beliefs.ndim}D array." @@ -1070,9 +1071,18 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]: num_variables = 0 num_variable_states = 0 for variable_group in variable_groups: - if isinstance(variable_group, vgroup.NDVariableArray): - num_variables += variable_group.num_states.size - num_variable_states += variable_group.num_states.sum() + variables = variable_group.variables + num_variables += len(variables) + num_variable_states += sum([variable[1] for variable in variables]) + + # if isinstance(variable_group, vgroup.NDVariableArray): + # num_variables += variable_group.num_states.size + # num_variable_states += variable_group.num_states.sum() + # elif isinstance(variable_group, vgroup.VariableDict): + # num_variables += len(variable_group.variables) + # num_variable_states += ( + # len(variable_group.variables) * variable_group.variables[0].num_states + # ) if flat_beliefs.shape[0] == num_variables: use_num_states = False @@ -1088,18 +1098,20 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]: beliefs = {} start = 0 for variable_group in variable_groups: - if use_num_states: - length = variable_group.num_states.sum() + variables = variable_group.variables + if not use_num_states: + length = len(variables) + # length = variable_group.num_states.sum() else: - length = variable_group.num_states.size - + length = sum([variable[1] for variable in variables]) + # length = variable_group.num_states.size beliefs[variable_group] = variable_group.unflatten( flat_beliefs[start : start + length] ) start += length return beliefs - @jax.jit + # @jax.jit def get_beliefs(bp_arrays: BPArrays) -> Dict[Hashable, Any]: """Function to calculate beliefs from a BPArrays diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 06bc80cb..96cc4d19 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -143,6 +143,7 @@ def _variables_to_factors(self) -> Mapping[FrozenSet, nodes.Factor]: @cached_property def total_num_states(self) -> int: """TODO""" + # TODO: this could be returned by the wiring return sum( [ variable[1] diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index 8dac5c7c..3907b9f8 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -28,12 +28,17 @@ class NDVariableArray(groups.VariableGroup): num_states: np.ndarray def __post_init__(self): - if isinstance(self.num_states, int): + if np.isscalar(self.num_states) and np.issubdtype(type(np.int64(10)), int): num_states = np.full(self.shape, fill_value=self.num_states) object.__setattr__(self, "num_states", num_states) - elif isinstance(self.num_states, np.ndarray): + elif isinstance(self.num_states, np.ndarray) and np.issubdtype( + self.num_states.dtype, int + ): if self.num_states.shape != self.shape: raise ValueError("Should be same shape") + else: + raise ValueError("num_states entries should be of type np.int") + random_hash = random.randint(0, 2**63) object.__setattr__(self, "random_hash", random_hash) diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py index ea2a7500..52090b1a 100644 --- a/tests/fg/test_graph.py +++ b/tests/fg/test_graph.py @@ -13,7 +13,7 @@ def test_factor_graph(): - vg = vgroup.VariableDict(15, (0,)) + vg = vgroup.VariableDict(variable_names=(0,), num_states=15) fg = graph.FactorGraph(vg) fg.add_factor_by_type( factor_type=enumeration_factor.EnumerationFactor, @@ -76,7 +76,7 @@ def test_factor_adding(): def test_bp_state(): - vg = vgroup.VariableDict(15, (0,)) + vg = vgroup.VariableDict(variable_names=(0,), num_states=15) fg0 = graph.FactorGraph(vg) fg0.add_factor( variables=[vg[0]], @@ -101,7 +101,7 @@ def test_bp_state(): def test_log_potentials(): - vg = vgroup.VariableDict(15, (0,)) + vg = vgroup.VariableDict(variable_names=(0,), num_states=15) fg = graph.FactorGraph(vg) fg.add_factor( variables=[vg[0]], @@ -130,7 +130,7 @@ def test_log_potentials(): def test_ftov_msgs(): - vg = vgroup.VariableDict(15, (0,)) + vg = vgroup.VariableDict(variable_names=(0,), num_states=15) fg = graph.FactorGraph(vg) fg.add_factor( variables=[vg[0]], @@ -164,7 +164,7 @@ def test_ftov_msgs(): def test_evidence(): - vg = vgroup.VariableDict(15, (0,)) + vg = vgroup.VariableDict(variable_names=(0,), num_states=15) fg = graph.FactorGraph(vg) fg.add_factor( variables=[vg[0]], @@ -181,7 +181,7 @@ def test_evidence(): def test_bp(): - vg = vgroup.VariableDict(15, (0,)) + vg = vgroup.VariableDict(variable_names=(0,), num_states=15) fg = graph.FactorGraph(vg) fg.add_factor( variables=[vg[0]], diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index 9a0fd312..7d4c7a72 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -10,7 +10,7 @@ def test_variable_dict(): - variable_dict = vgroup.VariableDict(15, tuple([0, 1, 2])) + variable_dict = vgroup.VariableDict(variable_names=tuple([0, 1, 2]), num_states=15) with pytest.raises( ValueError, match="data is referring to a non-existent variable 3" ): @@ -19,10 +19,10 @@ def test_variable_dict(): with pytest.raises( ValueError, match=re.escape( - "Variable 2 expects a data array of shape (15,) or (1,). Got (10,)" + "Variable (2, 15) expects a data array of shape (15,) or (1,). Got (10,)" ), ): - variable_dict.flatten({2: np.zeros(10)}) + variable_dict.flatten({(2, 15): np.zeros(10)}) with pytest.raises( ValueError, match="Can only unflatten 1D array. Got a 2D array." diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py index a5e720a8..f1404907 100644 --- a/tests/test_pgmax.py +++ b/tests/test_pgmax.py @@ -6,7 +6,7 @@ from numpy.random import default_rng from scipy.ndimage import gaussian_filter -from pgmax.fg import graph, groups, nodes +from pgmax.fg import graph, nodes from pgmax.groups import enumeration from pgmax.groups import variables as vgroup @@ -121,20 +121,6 @@ def create_valid_suppression_config_arr(suppression_diameter): ] ) ) - true_map_state_output = { - ("grid_vars", 0, 0, 0): 2, - ("grid_vars", 0, 0, 1): 0, - ("grid_vars", 0, 1, 0): 0, - ("grid_vars", 0, 1, 1): 2, - ("grid_vars", 1, 0, 0): 1, - ("grid_vars", 1, 0, 1): 0, - ("grid_vars", 1, 1, 0): 1, - ("grid_vars", 1, 1, 1): 0, - ("additional_vars", 0, 0, 2): 0, - ("additional_vars", 0, 1, 2): 2, - ("additional_vars", 1, 2, 0): 1, - ("additional_vars", 1, 2, 1): 0, - } # Create a synthetic depth image for testing purposes im_size = 3 @@ -222,6 +208,21 @@ def create_valid_suppression_config_arr(suppression_diameter): additional_names = tuple(extra_row_names + extra_col_names) additional_vars = vgroup.VariableDict(additional_names, num_states=3) + true_map_state_output = { + (grid_vars, (0, 0, 0)): 2, + (grid_vars, (0, 0, 1)): 0, + (grid_vars, (0, 1, 0)): 0, + (grid_vars, (0, 1, 1)): 2, + (grid_vars, (1, 0, 0)): 1, + (grid_vars, (1, 0, 1)): 0, + (grid_vars, (1, 1, 0)): 1, + (grid_vars, (1, 1, 1)): 0, + (additional_vars, (0, 0, 2)): 0, + (additional_vars, (0, 1, 2)): 2, + (additional_vars, (1, 2, 0)): 1, + (additional_vars, (1, 2, 1)): 0, + } + gt_has_cuts = gt_has_cuts.astype(np.int32) # Now, we use this array along with the gt_has_cuts array computed earlier using the image in order to derive the evidence values @@ -252,7 +253,7 @@ def create_valid_suppression_config_arr(suppression_diameter): additional_vars_evidence_dict[ additional_vars[i, row, col] ] = evidence_vals_arr - except IndexError: + except ValueError: pass # Create the factor graph @@ -378,7 +379,7 @@ def create_valid_suppression_config_arr(suppression_diameter): assert jnp.allclose(bp_arrays.ftov_msgs, true_final_msgs_output, atol=1e-06) decoded_map_states = graph.decode_map_states(bp.get_beliefs(bp_arrays)) for name in true_map_state_output: - assert true_map_state_output[name] == decoded_map_states[name[0]][name[1:]] + assert true_map_state_output[name] == decoded_map_states[name[0]][name[1]] def test_e2e_heretic(): From 5c8b381b6ac74060b1988efa4ef057467ec5d05b Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Fri, 22 Apr 2022 02:12:06 +0000 Subject: [PATCH 10/35] Some docstrings --- examples/pmp_binary_deconvolution.py | 6 ++-- pgmax/factors/enumeration.py | 5 ++- pgmax/factors/logical.py | 5 ++- pgmax/fg/graph.py | 46 +++++++++++++--------------- pgmax/fg/groups.py | 35 +++++++-------------- pgmax/fg/nodes.py | 14 ++------- pgmax/groups/variables.py | 46 ++++++++++++++++++---------- tests/factors/test_and.py | 5 ++- tests/factors/test_or.py | 1 - tests/fg/test_groups.py | 2 +- tests/test_pgmax.py | 8 ++--- 11 files changed, 80 insertions(+), 93 deletions(-) diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py index 6c31a537..e6a37a0a 100644 --- a/examples/pmp_binary_deconvolution.py +++ b/examples/pmp_binary_deconvolution.py @@ -224,7 +224,7 @@ def plot_images(images, display=True, nr=None): # %% pW = 0.25 -pS = 1e-72 +pS = 1e-80 pX = 1e-100 # Sparsity inducing priors for W and S @@ -242,7 +242,7 @@ def plot_images(images, display=True, nr=None): # We draw a batch of samples from the posterior in parallel by transforming `run_bp`/`get_beliefs` with `jax.vmap` # %% -np.random.seed(seed=40) +np.random.seed(seed=42) n_samples = 4 start = time.time() @@ -271,3 +271,5 @@ def plot_images(images, display=True, nr=None): # %% _ = plot_images(map_states[W].reshape(-1, feat_height, feat_width), nr=n_samples) + +# %% diff --git a/pgmax/factors/enumeration.py b/pgmax/factors/enumeration.py index 49bd8fd3..d461d3ab 100644 --- a/pgmax/factors/enumeration.py +++ b/pgmax/factors/enumeration.py @@ -163,9 +163,8 @@ def compile_wiring( Internally calls _compile_var_states_numba and _compile_enumeration_wiring_numba for speed. Args: - variables_for_factors: A list of list of variables, where each innermost element is a - variable. Each list within the outer list is taken to contain the names of the - variables connected to a Factor. + variables_for_factors: A list of list of variables. Each list within the outer list contains the + variables connected to a Factor. The same variable can be connected to multiple Factors. factor_configs: Array of shape (num_val_configs, num_variables) containing an explicit enumeration of all valid configurations. vars_to_starts: A dictionary that maps variables to their global starting indices diff --git a/pgmax/factors/logical.py b/pgmax/factors/logical.py index a7554334..7f2d1ca6 100644 --- a/pgmax/factors/logical.py +++ b/pgmax/factors/logical.py @@ -148,9 +148,8 @@ def compile_wiring( Internally calls _compile_var_states_numba and _compile_logical_wiring_numba for speed. Args: - variables_for_factors: A list of list of variables, where each innermost element is a - variable. Each list within the outer list is taken to contain the names of the - variables connected to a Factor. + variables_for_factors: A list of list of variables. Each list within the outer list contains the + variables connected to a Factor. The same variable can be connected to multiple Factors. vars_to_starts: A dictionary that maps variables to their global starting indices For an n-state variable, a global start index of m means the global indices of its n variable states are m, m + 1, ..., m + n - 1 diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 77fc489a..9b8f7cc7 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -65,13 +65,6 @@ def __post_init__(self): else: self._variable_groups = self.variables - self._variables = [ - variable - for variable_group in self._variable_groups - for variable in variable_group.variables - ] - print("1", time.time() - start) - # Useful objects to build the FactorGraph self._factor_types_to_groups: OrderedDict[ Type, List[groups.FactorGroup] @@ -84,19 +77,23 @@ def __post_init__(self): [(factor_type, set()) for factor_type in FAC_TO_VAR_UPDATES] ) - # Used to add FactorGroups - vars_num_states = [variable[1] for variable in self._variables] - vars_num_states_cumsum = np.insert( - np.array(vars_num_states).cumsum(), - 0, - 0, - ) - print("1", time.time() - start) # See FactorGraphState docstrings for documentation on the following fields - self._num_var_states = vars_num_states_cumsum[-1] - self._vars_to_starts: Dict[Tuple[int, int], int] = collections.OrderedDict( - zip(self._variables, vars_num_states_cumsum[:-1]) - ) + self._vars_to_starts: OrderedDict[ + Tuple[int, int], int + ] = collections.OrderedDict() + vars_num_states_cumsum = 0 + for variable_group in self._variable_groups: + vg_num_states = variable_group.num_states.flatten() + vg_num_states_cumsum = np.insert(np.cumsum(vg_num_states), 0, 0) + self._vars_to_starts.update( + zip( + variable_group.variables, + vars_num_states_cumsum + vg_num_states_cumsum[:-1], + ) + ) + vars_num_states_cumsum += vg_num_states_cumsum[-1] + + self._num_var_states = vars_num_states_cumsum self._named_factor_groups: Dict[Hashable, groups.FactorGroup] = {} print("2", time.time() - start) @@ -112,7 +109,7 @@ def __hash__(self) -> int: def add_factor( self, - variables: List, + variables: List[Tuple], factor_configs: np.ndarray, log_potentials: Optional[np.ndarray] = None, name: Optional[str] = None, @@ -120,8 +117,8 @@ def add_factor( """Function to add a single factor to the FactorGraph. Args: - variables: A list containing the connected variable names. - Variable names are tuples of the type (variable_group_name, variable_name_within_variable_group) + variables: A list containing the connected variables. + Each variable is represented by a tuple of the form (variable hash/name, number of states) factor_configs: Array of shape (num_val_configs, num_variables) An array containing explicit enumeration of all valid configurations. If the connected variables have n1, n2, ... states, 1 <= num_val_configs <= n1 * n2 * ... @@ -146,8 +143,8 @@ def add_factor_by_type( """Function to add a single factor to the FactorGraph. Args: - variable_names: A list containing the connected variable names. - Variable names are tuples of the type (variable_group_name, variable_name_within_variable_group) + variables: A list containing the connected variables. + Each variable is represented by a tuple of the form (variable hash/name, number of states) factor_type: Type of factor to be added args: Args to be passed to the factor_type. kwargs: kwargs to be passed to the factor_type, and an optional "name" argument @@ -1105,6 +1102,7 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]: else: length = sum([variable[1] for variable in variables]) # length = variable_group.num_states.size + beliefs[variable_group] = variable_group.unflatten( flat_beliefs[start : start + length] ) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 96cc4d19..b0340cfa 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -45,12 +45,11 @@ def __getitem__(self, val): @cached_property def variables(self) -> Tuple[Any, int]: - """Function that returns the list of variables. Each variable is represented - by a tuple of the variable name (which can be a hash or a string) and its - number of states. + """Function that returns the list of all variables in the VariableGroup. + Each variable is represented by a tuple of the form (variable name, number of states) Returns: - List of variables in the VariableGroup + List of variables in the VariableGroup """ raise NotImplementedError( "Please subclass the VariableGroup class and override this method" @@ -88,9 +87,8 @@ class FactorGroup: """Class to represent a group of Factors. Args: - variables_for_factors: A list of list of variables, where each innermost element is a - variable. Each list within the outer list is taken to contain the names of the - variables connected to a Factor. + variables_for_factors: A list of list of variables. Each list within the outer list contains the + variables connected to a Factor. The same variable can be connected to multiple Factors. factor_configs: Optional array containing an explicit enumeration of all valid configurations log_potentials: Array of log potentials. @@ -142,8 +140,12 @@ def _variables_to_factors(self) -> Mapping[FrozenSet, nodes.Factor]: @cached_property def total_num_states(self) -> int: - """TODO""" - # TODO: this could be returned by the wiring + """Function to return the total number of states for all the variables involved in all the Factors + + Returns: + Total number of variable states in the FactorGroup + """ + # TODO: this could be returned by the wiring to loop over variables_for_factors only once return sum( [ variable[1] @@ -293,18 +295,3 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any: raise NotImplementedError( "SingleFactorGroup does not support vectorized factor operations." ) - - -# def get_ndvariable_names_for_factor_groups( -# arrays: List[Any] - -# ): -# "Util function" - -# import numba as nb -# @nb.jit(parallel=False, cache=True, fastmath=True, nopython=True) -# def run_numba(h, v, f): -# for h_idx in nb.prange(h.shape[0]): -# for v_idx in nb.prange(v.shape[0]): -# f[h_idx * v.shape[0] + v_idx, 0] = h[h_idx] -# f[h_idx * v.shape[0] + v_idx, 1] = v[v_idx] diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py index 67a998ba..06d7c5d6 100644 --- a/pgmax/fg/nodes.py +++ b/pgmax/fg/nodes.py @@ -41,8 +41,8 @@ class Factor: """A factor Args: - variables: List of variables in the factors. Each variable is represented - by a tuple containing the variable hash and number of states. + variables: List of variables connected by the Factor. + Each variable is represented by a tuple of the form (variable hash/name, number of states) Raises: NotImplementedError: If compile_wiring is not implemented @@ -57,16 +57,6 @@ def __post_init__(self): "Please implement compile_wiring in for your factor" ) - # @utils.cached_property - # def edges_num_states(self) -> np.ndarray: - # """Number of states for the variables connected to each edge - - # Returns: - # Array of shape (num_edges,) - # Number of states for the variables connected to each edge - # """ - # return self.variables.values() - @staticmethod def concatenate_wirings(wirings: Sequence) -> Wiring: """Concatenate a list of Wirings diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index 3907b9f8..36f7a4b9 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -52,7 +52,7 @@ def __lt__(self, other): return hash(self) < hash(other) def __getitem__(self, val): - # Numpy indexation will throw IndexError for us if out-of-bounds + # Relies on numpy indexation to throw IndexError if val is out-of-bounds result = (self.variable_names[val], self.num_states[val]) if isinstance(val, slice): return tuple(zip(result)) @@ -60,16 +60,22 @@ def __getitem__(self, val): @cached_property def variables(self) -> List[Tuple]: + """Function that returns the list of all variables in the VariableGroup. + Each variable is represented by a tuple of the form (variable hash, number of states) + + Returns: + List of variables in the VariableGroup + """ vars_names = self.variable_names.flatten() vars_num_states = self.num_states.flatten() return list(zip(vars_names, vars_num_states)) @cached_property def variable_names(self) -> np.ndarray: - """Function that generates a dictionary mapping names to variables. + """Function that generates all the variables names, in the form of hashes Returns: - a dictionary mapping all possible names to different variables. + Array of variables names. """ # Overwite default hash as it does not give enough spacing across consecutive objects indices = np.reshape(np.arange(np.product(self.shape)), self.shape) @@ -123,7 +129,7 @@ def unflatten( if flat_data.size == np.product(self.shape): data = flat_data.reshape(self.shape) elif flat_data.size == self.num_states.sum(): - # TODO: what should we dot for different number of states + # TODO: what should we do for different number of states data = flat_data.reshape(self.shape + (self.num_states.max(),)) else: raise ValueError( @@ -144,21 +150,29 @@ class VariableDict(groups.VariableGroup): """ variable_names: Tuple[Any, ...] - num_states: int + num_states: np.ndarray # TODO: this should be an int converted to an array in __post_init__ + + def __post_init__(self): + num_states = np.full((len(self.variable_names),), fill_value=self.num_states) + object.__setattr__(self, "num_states", num_states) @cached_property def variables(self) -> List[Tuple]: - return list( - zip(self.variable_names, [self.num_states] * len(self.variable_names)) - ) + """Function that returns the list of all variables in the VariableGroup. + Each variable is represented by a tuple of the form (variable name, number of states) + + Returns: + List of variables in the VariableGroup + """ + return list(zip(self.variable_names, self.num_states)) def __getitem__(self, val): if val not in self.variable_names: raise ValueError(f"Variable {val} is not in VariableDict") - return (val, self.num_states) + return (val, self.num_states[0]) def flatten( - self, data: Mapping[Hashable, Union[np.ndarray, jnp.ndarray]] + self, data: Mapping[Tuple[int, int], Union[np.ndarray, jnp.ndarray]] ) -> jnp.ndarray: """Function that turns meaningful structured data into a flat data array for internal use. @@ -181,10 +195,10 @@ def flatten( f"data is referring to a non-existent variable {name}." ) - if data[name].shape != (self.num_states,) and data[name].shape != (1,): + if data[name].shape != (name[1],) and data[name].shape != (1,): raise ValueError( f"Variable {name} expects a data array of shape " - f"{(self.num_states,)} or (1,). Got {data[name].shape}." + f"{(name[1],)} or (1,). Got {data[name].shape}." ) flat_data = jnp.concatenate([data[name].flatten() for name in self.variables]) @@ -214,7 +228,7 @@ def unflatten( ) num_variables = len(self.variable_names) - num_variable_states = len(self.variable_names) * self.num_states + num_variable_states = self.num_states.sum() if flat_data.shape[0] == num_variables: use_num_states = False elif flat_data.shape[0] == num_variable_states: @@ -228,10 +242,10 @@ def unflatten( start = 0 data = {} - for name in self.variable_names: + for name in self.variables: if use_num_states: - data[name] = flat_data[start : start + self.num_states] - start += self.num_states + data[name] = flat_data[start : start + name[1]] + start += name[1] else: data[name] = flat_data[np.array([start])] start += 1 diff --git a/tests/factors/test_and.py b/tests/factors/test_and.py index 19bc08d0..84e37175 100644 --- a/tests/factors/test_and.py +++ b/tests/factors/test_and.py @@ -15,7 +15,7 @@ def test_run_bp_with_ANDFactors(): (2) the support of several factor types in a FactorGraph and during inference To do so, observe that an ANDFactor can be defined as an equivalent EnumerationFactor - (which list all the valid OR configurations) and define two equivalent FactorGraphs + (which list all the valid AND configurations) and define two equivalent FactorGraphs FG1: first half of factors are defined as EnumerationFactors, second half are defined as ANDFactors FG2: first half of factors are defined as ANDFactors, second half are defined as EnumerationFactors @@ -25,7 +25,6 @@ def test_run_bp_with_ANDFactors(): Note: for the first seed, add all the EnumerationFactors to FG1 and all the ANDFactors to FG2 """ for idx in range(10): - print("it", idx) np.random.seed(idx) # Parameters @@ -54,7 +53,7 @@ def test_run_bp_with_ANDFactors(): children_variables2 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,)) fg2 = graph.FactorGraph(variables=[parents_variables2, children_variables2]) - # Variable names for factors + # Option 1: Define EnumerationFactors equivalent to the ANDFactors variables_for_factors1 = [] variables_for_factors2 = [] for factor_idx in range(num_factors): diff --git a/tests/factors/test_or.py b/tests/factors/test_or.py index 162a7338..c4bdc758 100644 --- a/tests/factors/test_or.py +++ b/tests/factors/test_or.py @@ -25,7 +25,6 @@ def test_run_bp_with_ORFactors(): Note: for the first seed, add all the EnumerationFactors to FG1 and all the ORFactors to FG2 """ for idx in range(10): - print("it", idx) np.random.seed(idx) # Parameters diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index 7d4c7a72..d6e18e73 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -35,7 +35,7 @@ def test_variable_dict(): jax.tree_util.tree_multimap( lambda x, y: jnp.all(x == y), variable_dict.unflatten(jnp.zeros(3)), - {name: np.zeros(1) for name in range(3)}, + {(name, 15): np.zeros(1) for name in range(3)}, ) ) ) diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py index f1404907..5c30e047 100644 --- a/tests/test_pgmax.py +++ b/tests/test_pgmax.py @@ -217,10 +217,10 @@ def create_valid_suppression_config_arr(suppression_diameter): (grid_vars, (1, 0, 1)): 0, (grid_vars, (1, 1, 0)): 1, (grid_vars, (1, 1, 1)): 0, - (additional_vars, (0, 0, 2)): 0, - (additional_vars, (0, 1, 2)): 2, - (additional_vars, (1, 2, 0)): 1, - (additional_vars, (1, 2, 1)): 0, + (additional_vars, ((0, 0, 2), 3)): 0, + (additional_vars, ((0, 1, 2), 3)): 2, + (additional_vars, ((1, 2, 0), 3)): 1, + (additional_vars, ((1, 2, 1), 3)): 0, } gt_has_cuts = gt_has_cuts.astype(np.int32) From 40ec5190b4b915bbb33b48b5a0cbbef7818505d8 Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Fri, 22 Apr 2022 19:24:00 +0000 Subject: [PATCH 11/35] Stannis first comments --- examples/gmrf.py | 4 +-- examples/rbm.py | 6 ++--- pgmax/fg/graph.py | 47 +++++++++++---------------------- pgmax/fg/groups.py | 27 +++++++++++++++---- pgmax/groups/variables.py | 55 +++++++++++++++++++++------------------ 5 files changed, 71 insertions(+), 68 deletions(-) diff --git a/examples/gmrf.py b/examples/gmrf.py index 95ae4a7a..3520c04d 100644 --- a/examples/gmrf.py +++ b/examples/gmrf.py @@ -121,8 +121,8 @@ damping=0.0, ) ) - ) - marginals = marginals[variables] + )[variables] + pred_image = np.argmax( np.stack( [ diff --git a/examples/rbm.py b/examples/rbm.py index 30f5ac2d..6d7a1d53 100644 --- a/examples/rbm.py +++ b/examples/rbm.py @@ -110,14 +110,14 @@ # # Add unary factors # for ii in range(bh.shape[0]): # fg.add_factor( -# variable_names=[("hidden", ii)], +# variables=[hidden_variables[ii]], # factor_configs=np.arange(2)[:, None], # log_potentials=np.array([0, bh[ii]]), # ) # # for jj in range(bv.shape[0]): # fg.add_factor( -# variable_names=[("visible", jj)], +# variables=[visible_variables[jj]], # factor_configs=np.arange(2)[:, None], # log_potentials=np.array([0, bv[jj]]), # ) @@ -127,7 +127,7 @@ # for ii in tqdm(range(bh.shape[0])): # for jj in range(bv.shape[0]): # fg.add_factor( -# variable_names=[("hidden", ii), ("visible", jj)], +# variables=[hidden_variables[ii], visible_variables[jj]], # factor_configs=factor_configs, # log_potentials=np.array([0, 0, 0, W[ii, jj]]), # ) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 9b8f7cc7..28460c79 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -35,7 +35,6 @@ from pgmax.bp import infer from pgmax.factors import FAC_TO_VAR_UPDATES from pgmax.fg import groups, nodes -from pgmax.groups import variables as vgroup from pgmax.groups.enumeration import EnumerationFactorGroup from pgmax.utils import cached_property @@ -60,10 +59,8 @@ def __post_init__(self): import time start = time.time() - if isinstance(self.variables, (vgroup.NDVariableArray, vgroup.VariableDict)): - self._variable_groups = [self.variables] - else: - self._variable_groups = self.variables + if isinstance(self.variables, groups.VariableGroup): + self.variables = [self.variables] # Useful objects to build the FactorGraph self._factor_types_to_groups: OrderedDict[ @@ -82,7 +79,7 @@ def __post_init__(self): Tuple[int, int], int ] = collections.OrderedDict() vars_num_states_cumsum = 0 - for variable_group in self._variable_groups: + for variable_group in self.variables: vg_num_states = variable_group.num_states.flatten() vg_num_states_cumsum = np.insert(np.cumsum(vg_num_states), 0, 0) self._vars_to_starts.update( @@ -355,9 +352,10 @@ def fg_state(self) -> FactorGraphState: log_potentials = np.concatenate( [self.log_potentials[factor_type] for factor_type in self.log_potentials] ) + assert isinstance(self.variables, list) return FactorGraphState( - variable_groups=self._variable_groups, + variable_groups=self.variables, vars_to_starts=self._vars_to_starts, num_var_states=self._num_var_states, total_factor_num_states=self._total_factor_num_states, @@ -391,7 +389,7 @@ class FactorGraphState: """FactorGraphState. Args: - variable_group: A variable group containing all the variables in the FactorGraph. + variable_groups: All the variable groups in the FactorGraph. vars_to_starts: Maps variables to their starting indices in the flat evidence array. flat_evidence[vars_to_starts[variable]: vars_to_starts[variable] + variable.num_var_states] contains evidence to the variable. @@ -1109,7 +1107,7 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]: start += length return beliefs - # @jax.jit + @jax.jit def get_beliefs(bp_arrays: BPArrays) -> Dict[Hashable, Any]: """Function to calculate beliefs from a BPArrays @@ -1143,7 +1141,8 @@ def compute_flat_beliefs(bp_arrays, var_states_for_edges): return bp -def decode_map_states(beliefs: Dict[Hashable, Any]) -> Dict[Hashable, Any]: +@jax.jit +def decode_map_states(beliefs: Dict[Hashable, Any]) -> Any: """Function to decode MAP states given the calculated beliefs. Args: @@ -1152,18 +1151,11 @@ def decode_map_states(beliefs: Dict[Hashable, Any]) -> Dict[Hashable, Any]: Returns: An array or a PyTree container containing the MAP states for different variables. """ - - @jax.jit - def _decode_map_states(beliefs) -> Any: - return jax.tree_util.tree_map(lambda x: jnp.argmax(x, axis=-1), beliefs) - - map_states = {} - for variable_group, vgroup_beliefs in beliefs.items(): - map_states[variable_group] = _decode_map_states(vgroup_beliefs) - return map_states + return jax.tree_util.tree_map(lambda x: jnp.argmax(x, axis=-1), beliefs) -def get_marginals(beliefs: Dict[Hashable, Any]) -> Dict[Hashable, Any]: +@jax.jit +def get_marginals(beliefs: Dict[Hashable, Any]) -> Any: """Function to get marginal probabilities given the calculated beliefs. Args: @@ -1172,15 +1164,6 @@ def get_marginals(beliefs: Dict[Hashable, Any]) -> Dict[Hashable, Any]: Returns: An array or a PyTree container containing the marginal probabilities different variables. """ - - @jax.jit - def _get_marginals(beliefs) -> Any: - return jax.tree_util.tree_map( - lambda x: jnp.exp(x - logsumexp(x, axis=-1, keepdims=True)), - beliefs, - ) - - marginals = {} - for variable_group, vgroup_beliefs in beliefs.items(): - marginals[variable_group] = _get_marginals(vgroup_beliefs) - return marginals + return jax.tree_util.tree_map( + lambda x: jnp.exp(x - logsumexp(x, axis=-1, keepdims=True)), beliefs + ) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index b0340cfa..04b22ec4 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -1,7 +1,9 @@ """A module containing the base classes for variable and factor groups in a Factor Graph.""" import inspect +import random from dataclasses import dataclass, field +from functools import total_ordering from typing import ( Any, FrozenSet, @@ -21,6 +23,7 @@ from pgmax.utils import cached_property +@total_ordering @dataclass(frozen=True, eq=False) class VariableGroup: """Class to represent a group of variables. @@ -30,14 +33,28 @@ class VariableGroup: a sequence of variable names) of the VariableGroup. """ + def __post_init__(self): + random_hash = random.randint(0, 2**63) + object.__setattr__(self, "random_hash", random_hash) + + def __hash__(self): + return self.random_hash + + def __eq__(self, other): + return hash(self) == hash(other) + + def __lt__(self, other): + return hash(self) < hash(other) + def __getitem__(self, val): - """Given a variable name, retrieve the associated Variable. + """Given a variable name, index, or a group of variable indices, retrieve the associated variable(s). + Each variable is returned via a tuple of the form (variable hash/name, number of states) Args: - val: a single name corresponding to a single variable, or a list of such names + val: a variable index, slice, or name Returns: - A single variable if the name is not a list. A list of variables if name is a list + A single variable or a list of variables """ raise NotImplementedError( "Please subclass the VariableGroup class and override this method" @@ -46,7 +63,7 @@ def __getitem__(self, val): @cached_property def variables(self) -> Tuple[Any, int]: """Function that returns the list of all variables in the VariableGroup. - Each variable is represented by a tuple of the form (variable name, number of states) + Each variable is represented by a tuple of the form (variable hash/name, number of states) Returns: List of variables in the VariableGroup @@ -108,7 +125,7 @@ def __post_init__(self): if len(self.variables_for_factors) == 0: raise ValueError("Do not add a factor group with no factors.") - def __getitem__(self, variables: Sequence[int]) -> Any: + def __getitem__(self, variables: Sequence[Tuple[int, int]]) -> Any: """Function to query individual factors in the factor group Args: diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index 36f7a4b9..b7e0d66b 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -1,8 +1,6 @@ """A module containing the variables group classes inheriting from the base VariableGroup.""" -import random from dataclasses import dataclass -from functools import total_ordering from typing import Any, Dict, Hashable, List, Mapping, Tuple, Union import jax @@ -13,7 +11,6 @@ from pgmax.utils import cached_property -@total_ordering @dataclass(frozen=True, eq=False) class NDVariableArray(groups.VariableGroup): """Subclass of VariableGroup for n-dimensional grids of variables. @@ -28,6 +25,8 @@ class NDVariableArray(groups.VariableGroup): num_states: np.ndarray def __post_init__(self): + super().__post_init__() + if np.isscalar(self.num_states) and np.issubdtype(type(np.int64(10)), int): num_states = np.full(self.shape, fill_value=self.num_states) object.__setattr__(self, "num_states", num_states) @@ -39,23 +38,23 @@ def __post_init__(self): else: raise ValueError("num_states entries should be of type np.int") - random_hash = random.randint(0, 2**63) - object.__setattr__(self, "random_hash", random_hash) - - def __hash__(self): - return self.random_hash + def __getitem__( + self, val: Union[int, tuple, slice] + ) -> Union[Tuple[int, int], List[Tuple[int]]]: + """Given an index or a slice, retrieve the associated variable(s). + Each variable is returned via a tuple of the form (variable hash, number of states) - def __eq__(self, other): - return hash(self) == hash(other) + Note: Relies on numpy indexation to throw IndexError if val is out-of-bounds - def __lt__(self, other): - return hash(self) < hash(other) + Args: + val: a variable index or slice - def __getitem__(self, val): - # Relies on numpy indexation to throw IndexError if val is out-of-bounds + Returns: + A single variable or a list of variables + """ result = (self.variable_names[val], self.num_states[val]) if isinstance(val, slice): - return tuple(zip(result)) + return list(zip(result)) return result @cached_property @@ -153,6 +152,8 @@ class VariableDict(groups.VariableGroup): num_states: np.ndarray # TODO: this should be an int converted to an array in __post_init__ def __post_init__(self): + super().__post_init__() + num_states = np.full((len(self.variable_names),), fill_value=self.num_states) object.__setattr__(self, "num_states", num_states) @@ -189,19 +190,21 @@ def flatten( (1) data is referring to a non-existing variable (2) data is not of the correct shape """ - for name in data: - if name not in self.variables: + for variable in data: + if variable not in self.variables: raise ValueError( - f"data is referring to a non-existent variable {name}." + f"data is referring to a non-existent variable {variable}." ) - if data[name].shape != (name[1],) and data[name].shape != (1,): + if data[variable].shape != (variable[1],) and data[variable].shape != (1,): raise ValueError( - f"Variable {name} expects a data array of shape " - f"{(name[1],)} or (1,). Got {data[name].shape}." + f"Variable {variable} expects a data array of shape " + f"{(variable[1],)} or (1,). Got {data[variable].shape}." ) - flat_data = jnp.concatenate([data[name].flatten() for name in self.variables]) + flat_data = jnp.concatenate( + [data[variable].flatten() for variable in self.variables] + ) return flat_data def unflatten( @@ -242,12 +245,12 @@ def unflatten( start = 0 data = {} - for name in self.variables: + for variable in self.variables: if use_num_states: - data[name] = flat_data[start : start + name[1]] - start += name[1] + data[variable] = flat_data[start : start + variable[1]] + start += variable[1] else: - data[name] = flat_data[np.array([start])] + data[variable] = flat_data[np.array([start])] start += 1 return data From 96c7fe380cd1feb141e4de6a4fcde9f6291086f1 Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Fri, 22 Apr 2022 23:42:02 +0000 Subject: [PATCH 12/35] Remove add_factor --- examples/gmrf.py | 33 ++-- examples/ising_model.py | 5 +- examples/pmp_binary_deconvolution.py | 16 +- examples/rbm.py | 27 +-- examples/rcn.py | 9 +- pgmax/factors/enumeration.py | 3 +- pgmax/fg/graph.py | 90 ++-------- pgmax/fg/groups.py | 6 +- pgmax/groups/logical.py | 4 + tests/factors/test_and.py | 23 +-- tests/factors/test_or.py | 23 +-- tests/fg/test_graph.py | 237 +++++++++++++-------------- tests/fg/test_wiring.py | 53 +++--- tests/test_pgmax.py | 23 +-- 14 files changed, 246 insertions(+), 306 deletions(-) diff --git a/examples/gmrf.py b/examples/gmrf.py index 3520c04d..363aca85 100644 --- a/examples/gmrf.py +++ b/examples/gmrf.py @@ -59,44 +59,44 @@ # %% # Add top-down factors -fg.add_factor_group( - factory=enumeration.PairwiseFactorGroup, +factor_group = enumeration.PairwiseFactorGroup( variables_for_factors=[ [variables[ii, jj], variables[ii + 1, jj]] for ii in range(M - 1) for jj in range(N) ], - name="top_down", ) +fg.add_factor_group(factor_group, name="top_down") + # Add left-right factors -fg.add_factor_group( - factory=enumeration.PairwiseFactorGroup, +factor_group = enumeration.PairwiseFactorGroup( variables_for_factors=[ [variables[ii, jj], variables[ii, jj + 1]] for ii in range(M) for jj in range(N - 1) ], - name="left_right", ) +fg.add_factor_group(factor_group, name="left_right") + # Add diagonal factors -fg.add_factor_group( - factory=enumeration.PairwiseFactorGroup, +factor_group = enumeration.PairwiseFactorGroup( variables_for_factors=[ [variables[ii, jj], variables[ii + 1, jj + 1]] for ii in range(M - 1) for jj in range(N - 1) ], - name="diagonal0", ) -fg.add_factor_group( - factory=enumeration.PairwiseFactorGroup, +fg.add_factor_group(factor_group, name="diagonal0") + + +factor_group = enumeration.PairwiseFactorGroup( variables_for_factors=[ [variables[ii, jj], variables[ii - 1, jj + 1]] for ii in range(1, M) for jj in range(N - 1) ], - name="diagonal1", ) +fg.add_factor_group(factor_group, name="diagonal1") # %% bp = graph.BP(fg.bp_state, temperature=1.0) @@ -227,3 +227,12 @@ def update(step, batch_noisy_images, batch_target_images, opt_state): ) pbar.update() pbar.set_postfix(loss=value) + + +batch_indices = indices[idx * batch_size : (idx + 1) * batch_size] +batch_noisy_images, batch_target_images = ( + noisy_images_train[:10], + target_images_train[:10], +) +step = 0 +value, opt_state = update(step, batch_noisy_images, batch_target_images, opt_state) diff --git a/examples/ising_model.py b/examples/ising_model.py index 88535c23..b2f5aef7 100644 --- a/examples/ising_model.py +++ b/examples/ising_model.py @@ -39,12 +39,11 @@ variables_for_factors.append([variables[ii, jj], variables[kk, jj]]) variables_for_factors.append([variables[ii, jj], variables[kk, ll]]) -fg.add_factor_group( - factory=enumeration.PairwiseFactorGroup, +factor_group = enumeration.PairwiseFactorGroup( variables_for_factors=variables_for_factors, log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]), - name="factors", ) +fg.add_factor_group(factor_group, name="factors") # %% [markdown] # ### Run inference and visualize results diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py index e6a37a0a..6cdb84f1 100644 --- a/examples/pmp_binary_deconvolution.py +++ b/examples/pmp_binary_deconvolution.py @@ -180,10 +180,8 @@ def plot_images(images, display=True, nr=None): print("After loop", time.time() - start) # Add ANDFactorGroup, which is computationally efficient -fg.add_factor_group( - factory=logical.ANDFactorGroup, - variables_for_factors=variables_for_ANDFactors, -) +AND_factor_group = logical.ANDFactorGroup(variables_for_ANDFactors) +fg.add_factor_group(AND_factor_group) print(time.time() - start) # Define the ORFactors @@ -193,10 +191,8 @@ def plot_images(images, display=True, nr=None): ] # Add ORFactorGroup, which is computationally efficient -fg.add_factor_group( - factory=logical.ORFactorGroup, - variables_for_factors=variables_for_ORFactors, -) +OR_factor_group = logical.ORFactorGroup(variables_for_ORFactors) +fg.add_factor_group(OR_factor_group) print("Time", time.time() - start) for factor_type, factor_groups in fg.factor_groups.items(): @@ -224,7 +220,7 @@ def plot_images(images, display=True, nr=None): # %% pW = 0.25 -pS = 1e-80 +pS = 1e-70 pX = 1e-100 # Sparsity inducing priors for W and S @@ -271,5 +267,3 @@ def plot_images(images, display=True, nr=None): # %% _ = plot_images(map_states[W].reshape(-1, feat_height, feat_width), nr=n_samples) - -# %% diff --git a/examples/rbm.py b/examples/rbm.py index 6d7a1d53..876f08ee 100644 --- a/examples/rbm.py +++ b/examples/rbm.py @@ -26,6 +26,7 @@ import matplotlib.pyplot as plt import numpy as np +from pgmax.factors import enumeration as enumeration_factor from pgmax.fg import graph from pgmax.groups import enumeration from pgmax.groups import variables as vgroup @@ -64,25 +65,25 @@ start = time.time() # Add unary factors -fg.add_factor_group( - factory=enumeration.EnumerationFactorGroup, +hidden_unaries = enumeration.EnumerationFactorGroup( variables_for_factors=[[hidden_variables[ii]] for ii in range(bh.shape[0])], factor_configs=np.arange(2)[:, None], log_potentials=np.stack([np.zeros_like(bh), bh], axis=1), ) +fg.add_factor_group(hidden_unaries) -fg.add_factor_group( - factory=enumeration.EnumerationFactorGroup, +visible_unaries = enumeration.EnumerationFactorGroup( variables_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])], factor_configs=np.arange(2)[:, None], log_potentials=np.stack([np.zeros_like(bv), bv], axis=1), ) +fg.add_factor_group(visible_unaries) + # Add pairwise factors log_potential_matrix = np.zeros(W.shape + (2, 2)).reshape((-1, 2, 2)) log_potential_matrix[:, 1, 1] = W.ravel() -fg.add_factor_group( - factory=enumeration.PairwiseFactorGroup, +pairwise_factors = enumeration.PairwiseFactorGroup( variables_for_factors=[ [hidden_variables[ii], visible_variables[jj]] for ii in range(bh.shape[0]) @@ -90,6 +91,7 @@ ], log_potential_matrix=log_potential_matrix, ) +fg.add_factor_group(pairwise_factors) # # %snakeviz fg.add_factor_group(factory=enumeration.PairwiseFactorGroup, variables_for_factors=v, log_potential_matrix=log_potential_matrix,) print("Time", time.time() - start) @@ -109,28 +111,31 @@ # # # Add unary factors # for ii in range(bh.shape[0]): -# fg.add_factor( +# factor = enumeration_factor.EnumerationFactor( # variables=[hidden_variables[ii]], # factor_configs=np.arange(2)[:, None], # log_potentials=np.array([0, bh[ii]]), # ) +# fg.add_factor(factor) # # for jj in range(bv.shape[0]): -# fg.add_factor( +# factor = enumeration_factor.EnumerationFactor( # variables=[visible_variables[jj]], # factor_configs=np.arange(2)[:, None], # log_potentials=np.array([0, bv[jj]]), # ) +# fg.add_factor(factor) # # # Add pairwise factors # factor_configs = np.array(list(itertools.product(np.arange(2), repeat=2))) # for ii in tqdm(range(bh.shape[0])): # for jj in range(bv.shape[0]): -# fg.add_factor( +# factor = enumeration_factor.EnumerationFactor( # variables=[hidden_variables[ii], visible_variables[jj]], # factor_configs=factor_configs, # log_potentials=np.array([0, 0, 0, W[ii, jj]]), # ) +# fg.add_factor(factor) # ~~~ # # Once we have added the factors, we can run max-product LBP and get MAP decoding by @@ -194,8 +199,8 @@ # ~~~python # bp_arrays = run_bp( # evidence_updates={ -# "hidden": np.random.gumbel(size=(bh.shape[0], 2)), -# "visible": np.random.gumbel(size=(bv.shape[0], 2)), +# hidden_variables: np.random.gumbel(size=(bh.shape[0], 2)), +# visible_variables: np.random.gumbel(size=(bv.shape[0], 2)), # }, # damping=0.5, # ) diff --git a/examples/rcn.py b/examples/rcn.py index 97de9fff..d7add7d7 100644 --- a/examples/rcn.py +++ b/examples/rcn.py @@ -38,6 +38,7 @@ from scipy.signal import fftconvolve from sklearn.datasets import fetch_openml +from pgmax.factors.enumeration import EnumerationFactor from pgmax.fg import graph from pgmax.groups import variables as vgroup @@ -279,10 +280,12 @@ def valid_configs(r: int, hps: int, vps: int) -> np.ndarray: for e in edge: i1, i2, r = e - fg.add_factor( - [variables_all_models[idx][i1], variables_all_models[idx][i2]], - valid_configs_list[r], + factor = EnumerationFactor( + variables=[variables_all_models[idx][i1], variables_all_models[idx][i2]], + factor_configs=valid_configs_list[r], + log_potentials=np.zeros(valid_configs_list[r].shape[0]), ) + fg.add_factor(factor) end = time.time() print(f"Creating factors took {end-start:.3f} seconds.") diff --git a/pgmax/factors/enumeration.py b/pgmax/factors/enumeration.py index d461d3ab..b0610c5a 100644 --- a/pgmax/factors/enumeration.py +++ b/pgmax/factors/enumeration.py @@ -1,6 +1,5 @@ """Defines an enumeration factor""" -import functools from dataclasses import dataclass from typing import List, Mapping, Sequence, Tuple, Union @@ -263,7 +262,7 @@ def _compile_enumeration_wiring_numba( ) -@functools.partial(jax.jit, static_argnames=("num_val_configs", "temperature")) +# @functools.partial(jax.jit, static_argnames=("num_val_configs", "temperature")) def pass_enum_fac_to_var_messages( vtof_msgs: jnp.ndarray, factor_configs_edge_states: jnp.ndarray, diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 28460c79..aec8b82f 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -104,95 +104,31 @@ def __hash__(self) -> int: ) return hash(all_factor_groups) - def add_factor( - self, - variables: List[Tuple], - factor_configs: np.ndarray, - log_potentials: Optional[np.ndarray] = None, - name: Optional[str] = None, - ) -> None: - """Function to add a single factor to the FactorGraph. + def add_factor(self, factor: nodes.Factor, name: Optional[str] = None) -> None: + """Function to add a single Factor to the FactorGraph. Args: - variables: A list containing the connected variables. - Each variable is represented by a tuple of the form (variable hash/name, number of states) - factor_configs: Array of shape (num_val_configs, num_variables) - An array containing explicit enumeration of all valid configurations. - If the connected variables have n1, n2, ... states, 1 <= num_val_configs <= n1 * n2 * ... - factor_configs[config_idx, variable_idx] represents the state of variable_names[variable_idx] - in the configuration factor_configs[config_idx]. - log_potentials: Optional array of shape (num_val_configs,). - If specified, log_potentials[config_idx] contains the log of the potential value for - the valid configuration factor_configs[config_idx]. - If None, it is assumed the log potential is uniform 0 and such an array is automatically - initialized. + factor: The factor to be added to the factor graph. + name: Optional name of the FactorGroup. """ - factor_group = EnumerationFactorGroup( - variables_for_factors=[variables], - factor_configs=factor_configs, - log_potentials=log_potentials, - ) - self._register_factor_group(factor_group, name) - - def add_factor_by_type( - self, variables: List[int], factor_type: type, *args, **kwargs - ) -> None: - """Function to add a single factor to the FactorGraph. - - Args: - variables: A list containing the connected variables. - Each variable is represented by a tuple of the form (variable hash/name, number of states) - factor_type: Type of factor to be added - args: Args to be passed to the factor_type. - kwargs: kwargs to be passed to the factor_type, and an optional "name" argument - for specifying the name of a named factor group. - - Example: - To add an ORFactor to a FactorGraph fg, run:: - - fg.add_factor_by_type( - variables=variables_for_OR_factor, - factor_type=logical.ORFactor - ) - """ - if factor_type not in FAC_TO_VAR_UPDATES: - raise ValueError( - f"Type {factor_type} is not one of the supported factor types {FAC_TO_VAR_UPDATES.keys()}" - ) - - name = kwargs.pop("name", None) - factor = factor_type(variables, *args, **kwargs) factor_group = groups.SingleFactorGroup( - variables_for_factors=[variables], + variables_for_factors=[factor.variables], factor=factor, ) - self._register_factor_group(factor_group, name) - - def add_factor_group(self, factory: Callable, *args, **kwargs) -> None: - """Add a factor group to the factor graph - - Args: - factory: Factory function that takes args and kwargs as input and outputs a factor group. - args: Args to be passed to the factory function. - kwargs: kwargs to be passed to the factory function, and an optional "name" argument - for specifying the name of a named factor group. - """ - name = kwargs.pop("name", None) - factor_group = factory(*args, **kwargs) - self._register_factor_group(factor_group, name) + self.add_factor_group(factor_group, name) - def _register_factor_group( + def add_factor_group( self, factor_group: groups.FactorGroup, name: Optional[str] = None ) -> None: - """Register a factor group to the factor graph, by updating the factor graph state. + """Add a FactorGroup to the FactorGraph, by updating the FactorGraphState. Args: - factor_group: The factor group to be registered to the factor graph. - name: Optional name of the factor group. + factor_group: The FactorGroup to be added to the FactorGraph. + name: Optional name of the FactorGroup. Raises: - ValueError: If the factor group with the same name or a factor involving the same variables - already exists in the factor graph. + ValueError: If the factor group with the same name or a Factor involving the same variables + already exists in the FactorGraph. """ if name in self._named_factor_groups: raise ValueError( @@ -988,7 +924,7 @@ def run_bp( ftov_msgs, edges_num_states, max_msg_size ) - @jax.checkpoint + # @jax.checkpoint def update(msgs: jnp.ndarray, _) -> Tuple[jnp.ndarray, None]: # Compute new variable to factor messages by message passing vtof_msgs = infer.pass_var_to_fac_messages( diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 04b22ec4..b777f5a4 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -123,7 +123,7 @@ class FactorGroup: def __post_init__(self): if len(self.variables_for_factors) == 0: - raise ValueError("Do not add a factor group with no factors.") + raise ValueError("Cannot create a FactorGroup with no Factor.") def __getitem__(self, variables: Sequence[Tuple[int, int]]) -> Any: """Function to query individual factors in the factor group @@ -277,6 +277,10 @@ def __post_init__(self): if not hasattr(self, key): object.__setattr__(self, key, getattr(self.factor, key)) + object.__setattr__( + self, "log_potentials", getattr(self.factor, "log_potentials") + ) + def _get_variables_to_factors( self, ) -> OrderedDict[FrozenSet, nodes.Factor]: diff --git a/pgmax/groups/logical.py b/pgmax/groups/logical.py index 0f4714d5..83115945 100644 --- a/pgmax/groups/logical.py +++ b/pgmax/groups/logical.py @@ -22,6 +22,10 @@ class LogicalFactorGroup(groups.FactorGroup): edge_states_offset: int = field(init=False) + def __post_init__(self): + super().__post_init__() + object.__setattr__(self, "factor_configs", None) + def _get_variables_to_factors( self, ) -> OrderedDict[FrozenSet, logical.LogicalFactor]: diff --git a/tests/factors/test_and.py b/tests/factors/test_and.py index 84e37175..147d57a0 100644 --- a/tests/factors/test_and.py +++ b/tests/factors/test_and.py @@ -3,6 +3,7 @@ import jax import numpy as np +from pgmax.factors.enumeration import EnumerationFactor from pgmax.fg import graph from pgmax.groups import logical from pgmax.groups import variables as vgroup @@ -94,26 +95,29 @@ def test_run_bp_with_ANDFactors(): if factor_idx < num_factors // 2: # Add the first half of factors to FactorGraph1 - fg1.add_factor( + enum_factor = EnumerationFactor( variables=variables_for_factors1[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) + fg1.add_factor(enum_factor) else: if idx != 0: # Add the second half of factors to FactorGraph2 - fg2.add_factor( + enum_factor = EnumerationFactor( variables=variables_for_factors2[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) + fg2.add_factor(enum_factor) else: # Add all the EnumerationFactors to FactorGraph1 for the first iter - fg1.add_factor( + enum_factor = EnumerationFactor( variables=variables_for_factors1[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) + fg1.add_factor(enum_factor) # Option 2: Define the ANDFactors num_parents_cumsum = np.insert(np.cumsum(num_parents), 0, 0) @@ -136,14 +140,11 @@ def test_run_bp_with_ANDFactors(): variables_for_factors2[factor_idx] ) if idx != 0: - fg1.add_factor_group( - factory=logical.ANDFactorGroup, - variables_for_factors=variables_for_ANDFactors_fg1, - ) - fg2.add_factor_group( - factory=logical.ANDFactorGroup, - variables_for_factors=variables_for_ANDFactors_fg2, - ) + factor_group = logical.ANDFactorGroup(variables_for_ANDFactors_fg1) + fg1.add_factor_group(factor_group) + + factor_group = logical.ANDFactorGroup(variables_for_ANDFactors_fg2) + fg2.add_factor_group(factor_group) # Run inference bp1 = graph.BP(fg1.bp_state, temperature=temperature) diff --git a/tests/factors/test_or.py b/tests/factors/test_or.py index c4bdc758..76cc5107 100644 --- a/tests/factors/test_or.py +++ b/tests/factors/test_or.py @@ -3,6 +3,7 @@ import jax import numpy as np +from pgmax.factors.enumeration import EnumerationFactor from pgmax.fg import graph from pgmax.groups import logical from pgmax.groups import variables as vgroup @@ -92,26 +93,29 @@ def test_run_bp_with_ORFactors(): if factor_idx < num_factors // 2: # Add the first half of factors to FactorGraph1 - fg1.add_factor( + enum_factor = EnumerationFactor( variables=variables_for_factors1[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) + fg1.add_factor(enum_factor) else: if idx != 0: # Add the second half of factors to FactorGraph2 - fg2.add_factor( + enum_factor = EnumerationFactor( variables=variables_for_factors2[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) + fg2.add_factor(enum_factor) else: # Add all the EnumerationFactors to FactorGraph1 for the first iter - fg1.add_factor( + enum_factor = EnumerationFactor( variables=variables_for_factors1[factor_idx], factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) + fg1.add_factor(enum_factor) # Option 2: Define the ORFactors num_parents_cumsum = np.insert(np.cumsum(num_parents), 0, 0) @@ -134,14 +138,11 @@ def test_run_bp_with_ORFactors(): variables_for_factors2[factor_idx] ) if idx != 0: - fg1.add_factor_group( - factory=logical.ORFactorGroup, - variables_for_factors=variables_for_ORFactors_fg1, - ) - fg2.add_factor_group( - factory=logical.ORFactorGroup, - variables_for_factors=variables_for_ORFactors_fg2, - ) + factor_group = logical.ORFactorGroup(variables_for_ORFactors_fg1) + fg1.add_factor_group(factor_group) + + factor_group = logical.ORFactorGroup(variables_for_ORFactors_fg2) + fg2.add_factor_group(factor_group) # Run inference bp1 = graph.BP(fg1.bp_state, temperature=temperature) diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py index 52090b1a..5fdf2815 100644 --- a/tests/fg/test_graph.py +++ b/tests/fg/test_graph.py @@ -13,24 +13,22 @@ def test_factor_graph(): + # TODO: remove factor graph name + vg = vgroup.VariableDict(variable_names=(0,), num_states=15) fg = graph.FactorGraph(vg) - fg.add_factor_by_type( - factor_type=enumeration_factor.EnumerationFactor, + factor = enumeration_factor.EnumerationFactor( variables=[vg[0]], factor_configs=np.arange(15)[:, None], log_potentials=np.zeros(15), - name="test", ) + fg.add_factor(factor, name="test") + with pytest.raises( ValueError, match="A factor group with the name test already exists. Please choose a different name", ): - fg.add_factor( - variables=[vg[0]], - factor_configs=np.arange(15)[:, None], - name="test", - ) + fg.add_factor(factor, name="test") with pytest.raises( ValueError, @@ -38,18 +36,7 @@ def test_factor_graph(): f"A Factor of type {enumeration_factor.EnumerationFactor} involving variables {frozenset([(0, 15)])} already exists." ), ): - fg.add_factor( - variables=[vg[0]], - factor_configs=np.arange(10)[:, None], - ) - - with pytest.raises( - ValueError, - match=re.escape( - f"Type {groups.FactorGroup} is not one of the supported factor types {FAC_TO_VAR_UPDATES.keys()}" - ), - ): - fg.add_factor_by_type(variables=[vg[0]], factor_type=groups.FactorGroup) + fg.add_factor(factor) def test_factor_adding(): @@ -57,11 +44,8 @@ def test_factor_adding(): B = vgroup.NDVariableArray(num_states=2, shape=(10,)) fg = graph.FactorGraph(variables=[A, B]) - with pytest.raises(ValueError, match="Do not add a factor group with no factors."): - fg.add_factor_group( - factory=logical.ORFactorGroup, - variables_for_factors=[], - ) + with pytest.raises(ValueError, match="Cannot create a FactorGroup with no Factor."): + factor_group = logical.ORFactorGroup(variables_for_factors=[]) variables0 = (A[0], B[0]) variables1 = (A[1], B[1]) @@ -78,17 +62,16 @@ def test_factor_adding(): def test_bp_state(): vg = vgroup.VariableDict(variable_names=(0,), num_states=15) fg0 = graph.FactorGraph(vg) - fg0.add_factor( - variables=[vg[0]], - factor_configs=np.arange(10)[:, None], - name="test", - ) - fg1 = graph.FactorGraph(vg) - fg1.add_factor( + factor = enumeration_factor.EnumerationFactor( variables=[vg[0]], factor_configs=np.arange(15)[:, None], - name="test", + log_potentials=np.zeros(15), ) + fg0.add_factor(factor, name="test") + + fg1 = graph.FactorGraph(vg) + fg1.add_factor(factor, name="test") + with pytest.raises( ValueError, match="log_potentials, ftov_msgs and evidence should be derived from the same fg_state", @@ -103,98 +86,100 @@ def test_bp_state(): def test_log_potentials(): vg = vgroup.VariableDict(variable_names=(0,), num_states=15) fg = graph.FactorGraph(vg) - fg.add_factor( - variables=[vg[0]], - factor_configs=np.arange(10)[:, None], - name="test", - ) - with pytest.raises( - ValueError, - match=re.escape("Expected log potentials shape (10,) for factor group test."), - ): - fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15)) - - with pytest.raises( - ValueError, - match=re.escape("Invalid name (0, 15) for log potentials updates."), - ): - fg.bp_state.log_potentials[vg[0]] = np.zeros(10) - - with pytest.raises( - ValueError, match=re.escape("Expected log potentials shape (10,). Got (15,)") - ): - graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15)) - - log_potentials = graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10)) - assert jnp.all(log_potentials["test"] == jnp.zeros(10)) - - -def test_ftov_msgs(): - vg = vgroup.VariableDict(variable_names=(0,), num_states=15) - fg = graph.FactorGraph(vg) - fg.add_factor( - variables=[vg[0]], - factor_configs=np.arange(10)[:, None], - name="test", - ) - with pytest.raises( - ValueError, - match=re.escape("Invalid names for setting messages"), - ): - fg.bp_state.ftov_msgs[0] = np.ones(10) - - with pytest.raises( - ValueError, - match=re.escape( - "Given belief shape (10,) does not match expected shape (15,) for variable (0, 15)." - ), - ): - fg.bp_state.ftov_msgs[vg[0]] = np.ones(10) - - with pytest.raises( - ValueError, match=re.escape("Expected messages shape (15,). Got (10,)") - ): - graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(10)) - - ftov_msgs = graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(15)) - with pytest.raises( - TypeError, match=re.escape("'FToVMessages' object is not subscriptable") - ): - ftov_msgs[(10,)] - - -def test_evidence(): - vg = vgroup.VariableDict(variable_names=(0,), num_states=15) - fg = graph.FactorGraph(vg) - fg.add_factor( + factor = enumeration_factor.EnumerationFactor( variables=[vg[0]], - factor_configs=np.arange(10)[:, None], - name="test", - ) - with pytest.raises( - ValueError, match=re.escape("Expected evidence shape (15,). Got (10,).") - ): - graph.Evidence(fg_state=fg.fg_state, value=np.zeros(10)) - - evidence = graph.Evidence(fg_state=fg.fg_state, value=np.zeros(15)) - assert jnp.all(evidence.value == jnp.zeros(15)) - - -def test_bp(): - vg = vgroup.VariableDict(variable_names=(0,), num_states=15) - fg = graph.FactorGraph(vg) - fg.add_factor( - variables=[vg[0]], - factor_configs=np.arange(10)[:, None], - name="test", - ) - bp = graph.BP(fg.bp_state, temperature=0) - bp_arrays = bp.update() - bp_arrays = bp.update( - bp_arrays=bp_arrays, - ftov_msgs_updates={vg[0]: np.zeros(15)}, + factor_configs=np.arange(15)[:, None], + log_potentials=np.zeros(15), ) - bp_arrays = bp.run_bp(bp_arrays, num_iters=1) - bp_arrays = replace(bp_arrays, log_potentials=jnp.zeros((10))) - bp_state = bp.to_bp_state(bp_arrays) - assert bp_state.fg_state == fg.fg_state + fg.add_factor(factor, name="test") + + # with pytest.raises( + # ValueError, + # match=re.escape("Expected log potentials shape (10,) for factor group test."), + # ): + # fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15)) + + # with pytest.raises( + # ValueError, + # match=re.escape("Invalid name (0, 15) for log potentials updates."), + # ): + # fg.bp_state.log_potentials[vg[0]] = np.zeros(10) + + # with pytest.raises( + # ValueError, match=re.escape("Expected log potentials shape (10,). Got (15,)") + # ): + # graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15)) + + # log_potentials = graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10)) + # assert jnp.all(log_potentials["test"] == jnp.zeros(10)) + + +# def test_ftov_msgs(): +# vg = vgroup.VariableDict(variable_names=(0,), num_states=15) +# fg = graph.FactorGraph(vg) +# fg.add_factor( +# variables=[vg[0]], +# factor_configs=np.arange(10)[:, None], +# name="test", +# ) +# with pytest.raises( +# ValueError, +# match=re.escape("Invalid names for setting messages"), +# ): +# fg.bp_state.ftov_msgs[0] = np.ones(10) + +# with pytest.raises( +# ValueError, +# match=re.escape( +# "Given belief shape (10,) does not match expected shape (15,) for variable (0, 15)." +# ), +# ): +# fg.bp_state.ftov_msgs[vg[0]] = np.ones(10) + +# with pytest.raises( +# ValueError, match=re.escape("Expected messages shape (15,). Got (10,)") +# ): +# graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(10)) + +# ftov_msgs = graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(15)) +# with pytest.raises( +# TypeError, match=re.escape("'FToVMessages' object is not subscriptable") +# ): +# ftov_msgs[(10,)] + + +# def test_evidence(): +# vg = vgroup.VariableDict(variable_names=(0,), num_states=15) +# fg = graph.FactorGraph(vg) +# fg.add_factor( +# variables=[vg[0]], +# factor_configs=np.arange(10)[:, None], +# name="test", +# ) +# with pytest.raises( +# ValueError, match=re.escape("Expected evidence shape (15,). Got (10,).") +# ): +# graph.Evidence(fg_state=fg.fg_state, value=np.zeros(10)) + +# evidence = graph.Evidence(fg_state=fg.fg_state, value=np.zeros(15)) +# assert jnp.all(evidence.value == jnp.zeros(15)) + + +# def test_bp(): +# vg = vgroup.VariableDict(variable_names=(0,), num_states=15) +# fg = graph.FactorGraph(vg) +# fg.add_factor( +# variables=[vg[0]], +# factor_configs=np.arange(10)[:, None], +# name="test", +# ) +# bp = graph.BP(fg.bp_state, temperature=0) +# bp_arrays = bp.update() +# bp_arrays = bp.update( +# bp_arrays=bp_arrays, +# ftov_msgs_updates={vg[0]: np.zeros(15)}, +# ) +# bp_arrays = bp.run_bp(bp_arrays, num_iters=1) +# bp_arrays = replace(bp_arrays, log_potentials=jnp.zeros((10))) +# bp_state = bp.to_bp_state(bp_arrays) +# assert bp_state.fg_state == fg.fg_state diff --git a/tests/fg/test_wiring.py b/tests/fg/test_wiring.py index 4966a66e..5b68a974 100644 --- a/tests/fg/test_wiring.py +++ b/tests/fg/test_wiring.py @@ -20,10 +20,11 @@ def test_wiring_with_PairwiseFactorGroup(): # First test that compile_wiring enforces the correct factor_edges_num_states shape fg = graph.FactorGraph(variables=[A, B]) - fg.add_factor_group( - factory=enumeration.PairwiseFactorGroup, - variables_for_factors=[[A[idx], B[idx]] for idx in range(10)], + factor_group = enumeration.PairwiseFactorGroup( + variables_for_factors=[[A[idx], B[idx]] for idx in range(10)] ) + fg.add_factor_group(factor_group) + factor_group = fg.factor_groups[enumeration_factor.EnumerationFactor][0] object.__setattr__( factor_group, "factor_configs", factor_group.factor_configs[:, :1] @@ -36,32 +37,30 @@ def test_wiring_with_PairwiseFactorGroup(): # FactorGraph with a single PairwiseFactorGroup fg1 = graph.FactorGraph(variables=[A, B]) - fg1.add_factor_group( - factory=enumeration.PairwiseFactorGroup, - variables_for_factors=[[A[idx], B[idx]] for idx in range(10)], + factor_group = enumeration.PairwiseFactorGroup( + variables_for_factors=[[A[idx], B[idx]] for idx in range(10)] ) + fg1.add_factor_group(factor_group) assert len(fg1.factor_groups[enumeration_factor.EnumerationFactor]) == 1 # FactorGraph with multiple PairwiseFactorGroup fg2 = graph.FactorGraph(variables=[A, B]) for idx in range(10): - fg2.add_factor_group( - factory=enumeration.PairwiseFactorGroup, - variables_for_factors=[[A[idx], B[idx]]], + factor_group = enumeration.PairwiseFactorGroup( + variables_for_factors=[[A[idx], B[idx]]] ) + fg2.add_factor_group(factor_group) assert len(fg2.factor_groups[enumeration_factor.EnumerationFactor]) == 10 # FactorGraph with multiple SingleFactorGroup fg3 = graph.FactorGraph(variables=[A, B]) for idx in range(10): - fg3.add_factor_by_type( + factor = enumeration_factor.EnumerationFactor( variables=[A[idx], B[idx]], - factor_type=enumeration_factor.EnumerationFactor, - **{ - "factor_configs": np.array([[0, 0], [0, 1], [1, 0], [1, 1]]), - "log_potentials": np.zeros((4,)), - } + factor_configs=np.array([[0, 0], [0, 1], [1, 0], [1, 1]]), + log_potentials=np.zeros((4,)), ) + fg3.add_factor(factor) assert len(fg3.factor_groups[enumeration_factor.EnumerationFactor]) == 10 assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors) @@ -97,28 +96,28 @@ def test_wiring_with_ORFactorGroup(): # FactorGraph with a single ORFactorGroup fg1 = graph.FactorGraph(variables=[A, B, C]) - fg1.add_factor_group( - factory=logical.ORFactorGroup, + factor_group = logical.ORFactorGroup( variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)], ) + fg1.add_factor_group(factor_group) assert len(fg1.factor_groups[logical_factor.ORFactor]) == 1 # FactorGraph with multiple ORFactorGroup fg2 = graph.FactorGraph(variables=[A, B, C]) for idx in range(10): - fg2.add_factor_group( - factory=logical.ORFactorGroup, + factor_group = logical.ORFactorGroup( variables_for_factors=[[A[idx], B[idx], C[idx]]], ) + fg2.add_factor_group(factor_group) assert len(fg2.factor_groups[logical_factor.ORFactor]) == 10 # FactorGraph with multiple SingleFactorGroup fg3 = graph.FactorGraph(variables=[A, B, C]) for idx in range(10): - fg3.add_factor_by_type( + factor = logical_factor.ORFactor( variables=[A[idx], B[idx], C[idx]], - factor_type=logical_factor.ORFactor, ) + fg3.add_factor(factor) assert len(fg3.factor_groups[logical_factor.ORFactor]) == 10 assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors) @@ -152,28 +151,28 @@ def test_wiring_with_ANDFactorGroup(): # FactorGraph with a single ANDFactorGroup fg1 = graph.FactorGraph(variables=[A, B, C]) - fg1.add_factor_group( - factory=logical.ANDFactorGroup, + factor_group = logical.ANDFactorGroup( variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)], ) + fg1.add_factor_group(factor_group) assert len(fg1.factor_groups[logical_factor.ANDFactor]) == 1 # FactorGraph with multiple ANDFactorGroup fg2 = graph.FactorGraph(variables=[A, B, C]) for idx in range(10): - fg2.add_factor_group( - factory=logical.ANDFactorGroup, + factor_group = logical.ANDFactorGroup( variables_for_factors=[[A[idx], B[idx], C[idx]]], ) + fg2.add_factor_group(factor_group) assert len(fg2.factor_groups[logical_factor.ANDFactor]) == 10 # FactorGraph with multiple SingleFactorGroup fg3 = graph.FactorGraph(variables=[A, B, C]) for idx in range(10): - fg3.add_factor_by_type( + factor = logical_factor.ANDFactor( variables=[A[idx], B[idx], C[idx]], - factor_type=logical_factor.ANDFactor, ) + fg3.add_factor(factor) assert len(fg3.factor_groups[logical_factor.ANDFactor]) == 10 assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors) diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py index 5c30e047..e9067abe 100644 --- a/tests/test_pgmax.py +++ b/tests/test_pgmax.py @@ -6,6 +6,7 @@ from numpy.random import default_rng from scipy.ndimage import gaussian_filter +from pgmax.factors.enumeration import EnumerationFactor from pgmax.fg import graph, nodes from pgmax.groups import enumeration from pgmax.groups import variables as vgroup @@ -294,23 +295,23 @@ def create_valid_suppression_config_arr(suppression_diameter): additional_vars[1, row + 1, col], ] if row % 2 == 0: - fg.add_factor( + factor = EnumerationFactor( variables=curr_names, factor_configs=valid_configs_non_supp, log_potentials=np.zeros( valid_configs_non_supp.shape[0], dtype=float ), - name=(row, col), ) + fg.add_factor(factor, name=(row, col)) else: - fg.add_factor( + factor = EnumerationFactor( variables=curr_names, factor_configs=valid_configs_non_supp, log_potentials=np.zeros( valid_configs_non_supp.shape[0], dtype=float ), - name=(row, col), ) + fg.add_factor(factor, name=(row, col)) # Create an EnumerationFactorGroup for vertical suppression factors vert_suppression_names: List[List[Tuple[Any, ...]]] = [] @@ -350,17 +351,18 @@ def create_valid_suppression_config_arr(suppression_diameter): ) # Add the suppression factors to the graph via kwargs - fg.add_factor_group( - factory=enumeration.EnumerationFactorGroup, + factor_group = enumeration.EnumerationFactorGroup( variables_for_factors=vert_suppression_names, factor_configs=valid_configs_supp, ) - fg.add_factor_group( - factory=enumeration.EnumerationFactorGroup, + fg.add_factor_group(factor_group) + + factor_group = enumeration.EnumerationFactorGroup( variables_for_factors=horz_suppression_names, factor_configs=valid_configs_supp, log_potentials=np.zeros(valid_configs_supp.shape[0], dtype=float), ) + fg.add_factor_group(factor_group) # Run BP # Set the evidence @@ -413,12 +415,11 @@ def binary_connected_variables( W_pot = np.zeros((17, 3, 3, 3), dtype=float) for k_row in range(3): for k_col in range(3): - fg.add_factor_group( - factory=enumeration.PairwiseFactorGroup, + factor_group = enumeration.PairwiseFactorGroup( variables_for_factors=binary_connected_variables(28, 28, k_row, k_col), log_potential_matrix=W_pot[:, :, k_row, k_col], - name=(k_row, k_col), ) + fg.add_factor_group(factor_group, name=(k_row, k_col)) # Assign evidence to pixel vars bp_state = fg.bp_state From 88f8e230f601bef213a6106f3181ce7e6f8984df Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Sat, 23 Apr 2022 00:51:36 +0000 Subject: [PATCH 13/35] Test --- examples/rbm.py | 1 - pgmax/fg/graph.py | 1 - tests/fg/test_graph.py | 204 ++++++++++++++++++++--------------------- 3 files changed, 102 insertions(+), 104 deletions(-) diff --git a/examples/rbm.py b/examples/rbm.py index 876f08ee..8ea595ea 100644 --- a/examples/rbm.py +++ b/examples/rbm.py @@ -26,7 +26,6 @@ import matplotlib.pyplot as plt import numpy as np -from pgmax.factors import enumeration as enumeration_factor from pgmax.fg import graph from pgmax.groups import enumeration from pgmax.groups import variables as vgroup diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index aec8b82f..6f804be2 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -35,7 +35,6 @@ from pgmax.bp import infer from pgmax.factors import FAC_TO_VAR_UPDATES from pgmax.fg import groups, nodes -from pgmax.groups.enumeration import EnumerationFactorGroup from pgmax.utils import cached_property diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py index 5fdf2815..79ee0895 100644 --- a/tests/fg/test_graph.py +++ b/tests/fg/test_graph.py @@ -5,10 +5,9 @@ import numpy as np import pytest -from pgmax.factors import FAC_TO_VAR_UPDATES from pgmax.factors import enumeration as enumeration_factor from pgmax.fg import graph, groups -from pgmax.groups import logical +from pgmax.groups import enumeration, logical from pgmax.groups import variables as vgroup @@ -39,13 +38,12 @@ def test_factor_graph(): fg.add_factor(factor) -def test_factor_adding(): +def test_single_factor(): + with pytest.raises(ValueError, match="Cannot create a FactorGroup with no Factor."): + logical.ORFactorGroup(variables_for_factors=[]) + A = vgroup.NDVariableArray(num_states=2, shape=(10,)) B = vgroup.NDVariableArray(num_states=2, shape=(10,)) - fg = graph.FactorGraph(variables=[A, B]) - - with pytest.raises(ValueError, match="Cannot create a FactorGroup with no Factor."): - factor_group = logical.ORFactorGroup(variables_for_factors=[]) variables0 = (A[0], B[0]) variables1 = (A[1], B[1]) @@ -86,100 +84,102 @@ def test_bp_state(): def test_log_potentials(): vg = vgroup.VariableDict(variable_names=(0,), num_states=15) fg = graph.FactorGraph(vg) - factor = enumeration_factor.EnumerationFactor( - variables=[vg[0]], - factor_configs=np.arange(15)[:, None], - log_potentials=np.zeros(15), + factor_group = enumeration.EnumerationFactorGroup( + variables_for_factors=[[vg[0]]], + factor_configs=np.arange(10)[:, None], ) - fg.add_factor(factor, name="test") + fg.add_factor_group(factor_group, name="test") - # with pytest.raises( - # ValueError, - # match=re.escape("Expected log potentials shape (10,) for factor group test."), - # ): - # fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15)) - - # with pytest.raises( - # ValueError, - # match=re.escape("Invalid name (0, 15) for log potentials updates."), - # ): - # fg.bp_state.log_potentials[vg[0]] = np.zeros(10) - - # with pytest.raises( - # ValueError, match=re.escape("Expected log potentials shape (10,). Got (15,)") - # ): - # graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15)) - - # log_potentials = graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10)) - # assert jnp.all(log_potentials["test"] == jnp.zeros(10)) - - -# def test_ftov_msgs(): -# vg = vgroup.VariableDict(variable_names=(0,), num_states=15) -# fg = graph.FactorGraph(vg) -# fg.add_factor( -# variables=[vg[0]], -# factor_configs=np.arange(10)[:, None], -# name="test", -# ) -# with pytest.raises( -# ValueError, -# match=re.escape("Invalid names for setting messages"), -# ): -# fg.bp_state.ftov_msgs[0] = np.ones(10) - -# with pytest.raises( -# ValueError, -# match=re.escape( -# "Given belief shape (10,) does not match expected shape (15,) for variable (0, 15)." -# ), -# ): -# fg.bp_state.ftov_msgs[vg[0]] = np.ones(10) - -# with pytest.raises( -# ValueError, match=re.escape("Expected messages shape (15,). Got (10,)") -# ): -# graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(10)) - -# ftov_msgs = graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(15)) -# with pytest.raises( -# TypeError, match=re.escape("'FToVMessages' object is not subscriptable") -# ): -# ftov_msgs[(10,)] - - -# def test_evidence(): -# vg = vgroup.VariableDict(variable_names=(0,), num_states=15) -# fg = graph.FactorGraph(vg) -# fg.add_factor( -# variables=[vg[0]], -# factor_configs=np.arange(10)[:, None], -# name="test", -# ) -# with pytest.raises( -# ValueError, match=re.escape("Expected evidence shape (15,). Got (10,).") -# ): -# graph.Evidence(fg_state=fg.fg_state, value=np.zeros(10)) - -# evidence = graph.Evidence(fg_state=fg.fg_state, value=np.zeros(15)) -# assert jnp.all(evidence.value == jnp.zeros(15)) - - -# def test_bp(): -# vg = vgroup.VariableDict(variable_names=(0,), num_states=15) -# fg = graph.FactorGraph(vg) -# fg.add_factor( -# variables=[vg[0]], -# factor_configs=np.arange(10)[:, None], -# name="test", -# ) -# bp = graph.BP(fg.bp_state, temperature=0) -# bp_arrays = bp.update() -# bp_arrays = bp.update( -# bp_arrays=bp_arrays, -# ftov_msgs_updates={vg[0]: np.zeros(15)}, -# ) -# bp_arrays = bp.run_bp(bp_arrays, num_iters=1) -# bp_arrays = replace(bp_arrays, log_potentials=jnp.zeros((10))) -# bp_state = bp.to_bp_state(bp_arrays) -# assert bp_state.fg_state == fg.fg_state + with pytest.raises( + ValueError, + match=re.escape("Expected log potentials shape (10,) for factor group test."), + ): + fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15)) + + with pytest.raises( + ValueError, + match=re.escape("Invalid name (0, 15) for log potentials updates."), + ): + fg.bp_state.log_potentials[vg[0]] = np.zeros(10) + + with pytest.raises( + ValueError, match=re.escape("Expected log potentials shape (10,). Got (15,)") + ): + graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15)) + + log_potentials = graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10)) + assert jnp.all(log_potentials["test"] == jnp.zeros(10)) + + +def test_ftov_msgs(): + vg = vgroup.VariableDict(variable_names=(0,), num_states=15) + fg = graph.FactorGraph(vg) + factor_group = enumeration.EnumerationFactorGroup( + variables_for_factors=[[vg[0]]], + factor_configs=np.arange(10)[:, None], + ) + fg.add_factor_group(factor_group, name="test") + + with pytest.raises( + ValueError, + match=re.escape("Invalid names for setting messages"), + ): + fg.bp_state.ftov_msgs[0] = np.ones(10) + + with pytest.raises( + ValueError, + match=re.escape( + "Given belief shape (10,) does not match expected shape (15,) for variable (0, 15)." + ), + ): + fg.bp_state.ftov_msgs[vg[0]] = np.ones(10) + + with pytest.raises( + ValueError, match=re.escape("Expected messages shape (15,). Got (10,)") + ): + graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(10)) + + ftov_msgs = graph.FToVMessages(fg_state=fg.fg_state, value=np.zeros(15)) + with pytest.raises( + TypeError, match=re.escape("'FToVMessages' object is not subscriptable") + ): + ftov_msgs[(10,)] + + +def test_evidence(): + vg = vgroup.VariableDict(variable_names=(0,), num_states=15) + fg = graph.FactorGraph(vg) + factor_group = enumeration.EnumerationFactorGroup( + variables_for_factors=[[vg[0]]], + factor_configs=np.arange(10)[:, None], + ) + fg.add_factor_group(factor_group, name="test") + + with pytest.raises( + ValueError, match=re.escape("Expected evidence shape (15,). Got (10,).") + ): + graph.Evidence(fg_state=fg.fg_state, value=np.zeros(10)) + + evidence = graph.Evidence(fg_state=fg.fg_state, value=np.zeros(15)) + assert jnp.all(evidence.value == jnp.zeros(15)) + + +def test_bp(): + vg = vgroup.VariableDict(variable_names=(0,), num_states=15) + fg = graph.FactorGraph(vg) + factor_group = enumeration.EnumerationFactorGroup( + variables_for_factors=[[vg[0]]], + factor_configs=np.arange(10)[:, None], + ) + fg.add_factor_group(factor_group, name="test") + + bp = graph.BP(fg.bp_state, temperature=0) + bp_arrays = bp.update() + bp_arrays = bp.update( + bp_arrays=bp_arrays, + ftov_msgs_updates={vg[0]: np.zeros(15)}, + ) + bp_arrays = bp.run_bp(bp_arrays, num_iters=1) + bp_arrays = replace(bp_arrays, log_potentials=jnp.zeros((10))) + bp_state = bp.to_bp_state(bp_arrays) + assert bp_state.fg_state == fg.fg_state From 033d17661c32f8708910843e547a9f32ca11af49 Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Sat, 23 Apr 2022 01:14:38 +0000 Subject: [PATCH 14/35] Docstring --- examples/gmrf.py | 1 - pgmax/factors/enumeration.py | 3 ++- pgmax/fg/graph.py | 35 +++++++++++++++++------------------ pgmax/fg/groups.py | 6 +++--- pgmax/groups/variables.py | 9 +++++++++ 5 files changed, 31 insertions(+), 23 deletions(-) diff --git a/examples/gmrf.py b/examples/gmrf.py index 363aca85..51977d38 100644 --- a/examples/gmrf.py +++ b/examples/gmrf.py @@ -88,7 +88,6 @@ ) fg.add_factor_group(factor_group, name="diagonal0") - factor_group = enumeration.PairwiseFactorGroup( variables_for_factors=[ [variables[ii, jj], variables[ii - 1, jj + 1]] diff --git a/pgmax/factors/enumeration.py b/pgmax/factors/enumeration.py index b0610c5a..d461d3ab 100644 --- a/pgmax/factors/enumeration.py +++ b/pgmax/factors/enumeration.py @@ -1,5 +1,6 @@ """Defines an enumeration factor""" +import functools from dataclasses import dataclass from typing import List, Mapping, Sequence, Tuple, Union @@ -262,7 +263,7 @@ def _compile_enumeration_wiring_numba( ) -# @functools.partial(jax.jit, static_argnames=("num_val_configs", "temperature")) +@functools.partial(jax.jit, static_argnames=("num_val_configs", "temperature")) def pass_enum_fac_to_var_messages( vtof_msgs: jnp.ndarray, factor_configs_edge_states: jnp.ndarray, diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 6f804be2..5295d3ae 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -108,7 +108,7 @@ def add_factor(self, factor: nodes.Factor, name: Optional[str] = None) -> None: Args: factor: The factor to be added to the factor graph. - name: Optional name of the FactorGroup. + name: Optional name of the SingleFactorGroup created. """ factor_group = groups.SingleFactorGroup( variables_for_factors=[factor.variables], @@ -350,14 +350,12 @@ class FactorGraphState: wiring: OrderedDict[type, nodes.Wiring] def __post_init__(self): - for this_field in self.__dataclass_fields__: - if isinstance(getattr(self, this_field), np.ndarray): - getattr(self, this_field).flags.writeable = False + for field in self.__dataclass_fields__: + if isinstance(getattr(self, field), np.ndarray): + getattr(self, field).flags.writeable = False - if isinstance(getattr(self, this_field), Mapping): - object.__setattr__( - self, this_field, MappingProxyType(getattr(self, this_field)) - ) + if isinstance(getattr(self, field), Mapping): + object.__setattr__(self, field, MappingProxyType(getattr(self, field))) @dataclass(frozen=True, eq=False) @@ -756,9 +754,9 @@ class BPArrays: evidence: Union[np.ndarray, jnp.ndarray] def __post_init__(self): - for this_field in self.__dataclass_fields__: - if isinstance(getattr(self, this_field), np.ndarray): - getattr(self, this_field).flags.writeable = False + for field in self.__dataclass_fields__: + if isinstance(getattr(self, field), np.ndarray): + getattr(self, field).flags.writeable = False def tree_flatten(self): return jax.tree_util.tree_flatten(asdict(self)) @@ -981,16 +979,16 @@ def to_bp_state(bp_arrays: BPArrays) -> BPState: ) def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]: - """Function that recovers meaningful structured data from internal flat data array + """Function that returns unflattened beliefs from the flat beliefs Args: - variable_groups: TODO - - Returns: - Meaningful structured data, with structure matching that of self.variable_group_container. + flat_beliefs: Flattened array of beliefs + variable_groups: All the variable groups in the FactorGraph. Raises: - ValueError: if flat_data is not of the right shape + ValueError: If + (1) flat_beliefs is not one dimensional + (2) flat_beliefs is not of the right shape """ if flat_beliefs.ndim != 1: @@ -998,6 +996,7 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]: f"Can only unflatten 1D array. Got a {flat_beliefs.ndim}D array." ) + # TODO: make sure this is not too slow num_variables = 0 num_variable_states = 0 for variable_group in variable_groups: @@ -1020,7 +1019,7 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]: use_num_states = True else: raise ValueError( - f"flat_data should be either of shape (num_variables(={num_variables}),), " + f"flat_beliefs should be either of shape (num_variables(={num_variables}),), " f"or (num_variable_states(={num_variable_states}),). " f"Got {flat_beliefs.shape}" ) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index b777f5a4..62c23af1 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -27,10 +27,10 @@ @dataclass(frozen=True, eq=False) class VariableGroup: """Class to represent a group of variables. + Each variable is represented via a tuple of the form (variable hash/name, number of states) - All variables in the group are assumed to have the same size. Additionally, the - variables are indexed by a variable name, and can be retrieved by direct indexing (even indexing - a sequence of variable names) of the VariableGroup. + Attributes: + random_hash: Hash of the VariableGroup """ def __post_init__(self): diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index b7e0d66b..3813a364 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -168,6 +168,15 @@ def variables(self) -> List[Tuple]: return list(zip(self.variable_names, self.num_states)) def __getitem__(self, val): + """Given a variable name retrieve the associated variable, returned via a tuple of the form + (variable name, number of states) + + Args: + val: a variable index or slice + + Returns: + The queried variable + """ if val not in self.variable_names: raise ValueError(f"Variable {val} is not in VariableDict") return (val, self.num_states[0]) From 89767c81a39c0d6b4c241ab6ea4ac07f60ced1bd Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Mon, 25 Apr 2022 21:18:51 +0000 Subject: [PATCH 15/35] Coverage --- pgmax/fg/graph.py | 46 ++++++++++++++++++--------------------- pgmax/fg/groups.py | 2 ++ pgmax/groups/variables.py | 34 +++++++++++++++++++---------- tests/fg/test_graph.py | 28 ++++++++++++++++++++++-- tests/fg/test_groups.py | 15 +++++++++++++ 5 files changed, 86 insertions(+), 39 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 5295d3ae..3dd0ebb1 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -35,6 +35,7 @@ from pgmax.bp import infer from pgmax.factors import FAC_TO_VAR_UPDATES from pgmax.fg import groups, nodes +from pgmax.groups import variables from pgmax.utils import cached_property @@ -61,6 +62,25 @@ def __post_init__(self): if isinstance(self.variables, groups.VariableGroup): self.variables = [self.variables] + # Check ids are unique + vg_names = [] + vg_array_names = [] + for variable_group in self.variables: + vg_name = variable_group.__hash__() + if vg_name in vg_names: + raise ValueError("Two objects have the same name") + vg_names.append(vg_name) + if isinstance(variable_group, variables.NDVariableArray): + start_name, end_name = ( + vg_name, + variable_group.variable_names.flatten()[-1], + ) + for var_array_name in vg_array_names: + start_name2, end_name2 = var_array_name + if max(start_name, start_name2) <= min(end_name, end_name2): + raise ValueError("Two NDVariableArrays have overlapping names") + vg_array_names.append((start_name, end_name)) + # Useful objects to build the FactorGraph self._factor_types_to_groups: OrderedDict[ Type, List[groups.FactorGroup] @@ -91,7 +111,7 @@ def __post_init__(self): self._num_var_states = vars_num_states_cumsum self._named_factor_groups: Dict[Hashable, groups.FactorGroup] = {} - print("2", time.time() - start) + print("Init", time.time() - start) def __hash__(self) -> int: all_factor_groups = tuple( @@ -469,9 +489,6 @@ def __getitem__(self, name: Any) -> np.ndarray: The queried log potentials. """ value = cast(np.ndarray, self.value) - if not isinstance(name, Hashable): - name = frozenset(name) - if name in self.fg_state.named_factor_groups: factor_group = self.fg_state.named_factor_groups[name] start = self.fg_state.factor_group_to_potentials_starts[factor_group] @@ -496,9 +513,6 @@ def __setitem__( data: Array containing the log potentials for the named factor group or the factor. """ - if not isinstance(name, Hashable): - name = frozenset(name) - object.__setattr__( self, "value", @@ -996,7 +1010,6 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]: f"Can only unflatten 1D array. Got a {flat_beliefs.ndim}D array." ) - # TODO: make sure this is not too slow num_variables = 0 num_variable_states = 0 for variable_group in variable_groups: @@ -1004,25 +1017,10 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]: num_variables += len(variables) num_variable_states += sum([variable[1] for variable in variables]) - # if isinstance(variable_group, vgroup.NDVariableArray): - # num_variables += variable_group.num_states.size - # num_variable_states += variable_group.num_states.sum() - # elif isinstance(variable_group, vgroup.VariableDict): - # num_variables += len(variable_group.variables) - # num_variable_states += ( - # len(variable_group.variables) * variable_group.variables[0].num_states - # ) - if flat_beliefs.shape[0] == num_variables: use_num_states = False elif flat_beliefs.shape[0] == num_variable_states: use_num_states = True - else: - raise ValueError( - f"flat_beliefs should be either of shape (num_variables(={num_variables}),), " - f"or (num_variable_states(={num_variable_states}),). " - f"Got {flat_beliefs.shape}" - ) beliefs = {} start = 0 @@ -1030,10 +1028,8 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]: variables = variable_group.variables if not use_num_states: length = len(variables) - # length = variable_group.num_states.sum() else: length = sum([variable[1] for variable in variables]) - # length = variable_group.num_states.size beliefs[variable_group] = variable_group.unflatten( flat_beliefs[start : start + length] diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 62c23af1..192eee84 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -34,6 +34,8 @@ class VariableGroup: """ def __post_init__(self): + # Overwite default hash to have larger differences + random.seed(id(self)) random_hash = random.randint(0, 2**63) object.__setattr__(self, "random_hash", random_hash) diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index 3813a364..feff8467 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -16,31 +16,34 @@ class NDVariableArray(groups.VariableGroup): """Subclass of VariableGroup for n-dimensional grids of variables. Args: - num_states: The size of the variables in this variable group - shape: a tuple specifying the size of each dimension of the grid (similar to + shape: Tuple specifying the size of each dimension of the grid (similar to the notion of a NumPy ndarray shape) + num_states: An integer or an array specifying the number of states of the + variables in this VariableGroup """ shape: Tuple[int, ...] - num_states: np.ndarray + num_states: Union[int, np.ndarray] def __post_init__(self): super().__post_init__() - if np.isscalar(self.num_states) and np.issubdtype(type(np.int64(10)), int): + if np.isscalar(self.num_states): num_states = np.full(self.shape, fill_value=self.num_states) object.__setattr__(self, "num_states", num_states) elif isinstance(self.num_states, np.ndarray) and np.issubdtype( self.num_states.dtype, int ): if self.num_states.shape != self.shape: - raise ValueError("Should be same shape") + raise ValueError( + f"Expected num_states shape {self.shape}. Got {self.num_states.shape}." + ) else: raise ValueError("num_states entries should be of type np.int") def __getitem__( - self, val: Union[int, tuple, slice] - ) -> Union[Tuple[int, int], List[Tuple[int]]]: + self, val: Union[int, slice, Tuple] + ) -> Union[Tuple[int, int], List[Tuple]]: """Given an index or a slice, retrieve the associated variable(s). Each variable is returned via a tuple of the form (variable hash, number of states) @@ -52,10 +55,13 @@ def __getitem__( Returns: A single variable or a list of variables """ - result = (self.variable_names[val], self.num_states[val]) - if isinstance(val, slice): - return list(zip(result)) - return result + assert isinstance(self.num_states, np.ndarray) + if np.isscalar(self.variable_names[val]): + return (self.variable_names[val], self.num_states[val]) + else: + vars_names = self.variable_names[val].flatten() + vars_num_states = self.num_states[val].flatten() + return list(zip(vars_names, vars_num_states)) @cached_property def variables(self) -> List[Tuple]: @@ -65,6 +71,7 @@ def variables(self) -> List[Tuple]: Returns: List of variables in the VariableGroup """ + assert isinstance(self.num_states, np.ndarray) vars_names = self.variable_names.flatten() vars_num_states = self.num_states.flatten() return list(zip(vars_names, vars_num_states)) @@ -76,7 +83,6 @@ def variable_names(self) -> np.ndarray: Returns: Array of variables names. """ - # Overwite default hash as it does not give enough spacing across consecutive objects indices = np.reshape(np.arange(np.product(self.shape)), self.shape) return self.__hash__() + indices @@ -93,6 +99,8 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: Raises: ValueError: If the data is not of the correct shape. """ + assert isinstance(self.num_states, np.ndarray) + # TODO: what should we do for different number of states -> look at mask_array if data.shape != self.shape and data.shape != self.shape + ( self.num_states.max(), @@ -120,6 +128,8 @@ def unflatten( (1) flat_data is not a 1D array (2) flat_data is not of the right shape """ + assert isinstance(self.num_states, np.ndarray) + if flat_data.ndim != 1: raise ValueError( f"Can only unflatten 1D array. Got a {flat_data.ndim}D array." diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py index 79ee0895..09c766a0 100644 --- a/tests/fg/test_graph.py +++ b/tests/fg/test_graph.py @@ -1,6 +1,7 @@ import re from dataclasses import replace +import jax import jax.numpy as jnp import numpy as np import pytest @@ -12,6 +13,16 @@ def test_factor_graph(): + vg = vgroup.NDVariableArray(num_states=2, shape=(10, 10)) + with pytest.raises(ValueError, match="Two objects have the same name"): + fg = graph.FactorGraph(variables=[vg, vg]) + + vg = vgroup.NDVariableArray(num_states=2, shape=(10, 10)) + vg2 = vgroup.NDVariableArray(num_states=2, shape=(10, 10)) + object.__setattr__(vg2, "random_hash", vg.__hash__() + 10) + with pytest.raises(ValueError, match="Two NDVariableArrays have overlapping names"): + fg = graph.FactorGraph(variables=[vg, vg2]) + # TODO: remove factor graph name vg = vgroup.VariableDict(variable_names=(0,), num_states=15) @@ -147,10 +158,10 @@ def test_ftov_msgs(): def test_evidence(): - vg = vgroup.VariableDict(variable_names=(0,), num_states=15) + vg = vgroup.VariableDict(variable_names=("a",), num_states=15) fg = graph.FactorGraph(vg) factor_group = enumeration.EnumerationFactorGroup( - variables_for_factors=[[vg[0]]], + variables_for_factors=[[vg["a"]]], factor_configs=np.arange(10)[:, None], ) fg.add_factor_group(factor_group, name="test") @@ -163,6 +174,19 @@ def test_evidence(): evidence = graph.Evidence(fg_state=fg.fg_state, value=np.zeros(15)) assert jnp.all(evidence.value == jnp.zeros(15)) + vg2 = vgroup.VariableDict(variable_names=("b",), num_states=15) + with pytest.raises( + ValueError, + match=re.escape( + "Got evidence for a variable or a variable group not in the FactorGraph!" + ), + ): + graph.update_evidence( + jax.device_put(evidence.value), + {vg2["b"]: jax.device_put(np.zeros(15))}, + fg.fg_state, + ) + def test_bp(): vg = vgroup.VariableDict(variable_names=(0,), num_states=15) diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index d6e18e73..b12b49d3 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -50,6 +50,21 @@ def test_variable_dict(): def test_nd_variable_array(): + num_states = np.full((2, 3), fill_value=2) + with pytest.raises( + ValueError, match=re.escape("Expected num_states shape (2, 2). Got (2, 3).") + ): + vgroup.NDVariableArray(shape=(2, 2), num_states=num_states) + + num_states = np.full((2, 3), fill_value=2, dtype=np.float32) + with pytest.raises( + ValueError, match=re.escape("num_states entries should be of type np.int") + ): + vgroup.NDVariableArray(shape=(2, 2), num_states=num_states) + + variable_group = vgroup.NDVariableArray(shape=(5, 5), num_states=2) + assert len(variable_group[:3, :3]) == 9 + variable_group = vgroup.NDVariableArray(shape=(2, 2), num_states=3) with pytest.raises( ValueError, From 1ccfcf5de56644688a9cf90900020a575017ef3d Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Mon, 25 Apr 2022 21:20:15 +0000 Subject: [PATCH 16/35] Coverage --- pgmax/fg/graph.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 3dd0ebb1..dd38740a 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -935,7 +935,7 @@ def run_bp( ftov_msgs, edges_num_states, max_msg_size ) - # @jax.checkpoint + @jax.checkpoint def update(msgs: jnp.ndarray, _) -> Tuple[jnp.ndarray, None]: # Compute new variable to factor messages by message passing vtof_msgs = infer.pass_var_to_fac_messages( @@ -961,8 +961,7 @@ def update(msgs: jnp.ndarray, _) -> Tuple[jnp.ndarray, None]: # update the factor to variable messages delta_msgs = ftov_msgs - msgs msgs = msgs + (1 - damping) * delta_msgs - # Normalize and clip these damped, updated messages before - # returning them. + # Normalize and clip these damped, updated messages before returning them. msgs = infer.normalize_and_clip_msgs(msgs, edges_num_states, max_msg_size) return msgs, None From ecaab6c8196a4b113f70e258d00d8bc5f553f5b1 Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Mon, 25 Apr 2022 21:35:06 +0000 Subject: [PATCH 17/35] Coverage 100% --- pgmax/fg/graph.py | 28 +--------------------------- tests/fg/test_graph.py | 6 ++++++ 2 files changed, 7 insertions(+), 27 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index dd38740a..6cfbac10 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -997,38 +997,12 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]: Args: flat_beliefs: Flattened array of beliefs variable_groups: All the variable groups in the FactorGraph. - - Raises: - ValueError: If - (1) flat_beliefs is not one dimensional - (2) flat_beliefs is not of the right shape """ - - if flat_beliefs.ndim != 1: - raise ValueError( - f"Can only unflatten 1D array. Got a {flat_beliefs.ndim}D array." - ) - - num_variables = 0 - num_variable_states = 0 - for variable_group in variable_groups: - variables = variable_group.variables - num_variables += len(variables) - num_variable_states += sum([variable[1] for variable in variables]) - - if flat_beliefs.shape[0] == num_variables: - use_num_states = False - elif flat_beliefs.shape[0] == num_variable_states: - use_num_states = True - beliefs = {} start = 0 for variable_group in variable_groups: variables = variable_group.variables - if not use_num_states: - length = len(variables) - else: - length = sum([variable[1] for variable in variables]) + length = sum([variable[1] for variable in variables]) beliefs[variable_group] = variable_group.unflatten( flat_beliefs[start : start + length] diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py index 09c766a0..c3b3a853 100644 --- a/tests/fg/test_graph.py +++ b/tests/fg/test_graph.py @@ -107,6 +107,12 @@ def test_log_potentials(): ): fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15)) + with pytest.raises( + ValueError, + match=re.escape("Invalid name new_test for log potentials updates."), + ): + fg.bp_state.log_potentials["new_test"] = jnp.zeros((1, 15)) + with pytest.raises( ValueError, match=re.escape("Invalid name (0, 15) for log potentials updates."), From cbd136bf8241430db734668bfa6bb87d96f86a25 Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Mon, 25 Apr 2022 23:00:42 +0000 Subject: [PATCH 18/35] Remove factor group names --- examples/gmrf.py | 36 +++++++----- examples/ising_model.py | 4 +- examples/pmp_binary_deconvolution.py | 2 +- pgmax/fg/graph.py | 83 +++++++++++----------------- pgmax/fg/groups.py | 10 ++++ tests/fg/test_graph.py | 47 +++++++--------- tests/fg/test_wiring.py | 20 +++---- tests/test_pgmax.py | 34 ++++++------ 8 files changed, 114 insertions(+), 122 deletions(-) diff --git a/examples/gmrf.py b/examples/gmrf.py index 51977d38..6601e5c7 100644 --- a/examples/gmrf.py +++ b/examples/gmrf.py @@ -38,8 +38,8 @@ # %% # Load saved log potentials -log_potentials = dict(**np.load("example_data/gmrf_log_potentials.npz")) -n_clones = log_potentials.pop("n_clones") +grmf_log_potentials = dict(**np.load("example_data/gmrf_log_potentials.npz")) +n_clones = grmf_log_potentials.pop("n_clones") p_contour = jax.device_put(np.repeat(data["p_contour"], n_clones)) prototype_targets = jax.device_put( np.array( @@ -59,48 +59,55 @@ # %% # Add top-down factors -factor_group = enumeration.PairwiseFactorGroup( +top_down = enumeration.PairwiseFactorGroup( variables_for_factors=[ [variables[ii, jj], variables[ii + 1, jj]] for ii in range(M - 1) for jj in range(N) ], ) -fg.add_factor_group(factor_group, name="top_down") +fg.add_factor_group(top_down) # Add left-right factors -factor_group = enumeration.PairwiseFactorGroup( +left_right = enumeration.PairwiseFactorGroup( variables_for_factors=[ [variables[ii, jj], variables[ii, jj + 1]] for ii in range(M) for jj in range(N - 1) ], ) -fg.add_factor_group(factor_group, name="left_right") +fg.add_factor_group(left_right) # Add diagonal factors -factor_group = enumeration.PairwiseFactorGroup( +diagonal0 = enumeration.PairwiseFactorGroup( variables_for_factors=[ [variables[ii, jj], variables[ii + 1, jj + 1]] for ii in range(M - 1) for jj in range(N - 1) ], ) -fg.add_factor_group(factor_group, name="diagonal0") +fg.add_factor_group(diagonal0) -factor_group = enumeration.PairwiseFactorGroup( +diagonal1 = enumeration.PairwiseFactorGroup( variables_for_factors=[ [variables[ii, jj], variables[ii - 1, jj + 1]] for ii in range(1, M) for jj in range(N - 1) ], ) -fg.add_factor_group(factor_group, name="diagonal1") +fg.add_factor_group(diagonal1) # %% bp = graph.BP(fg.bp_state, temperature=1.0) # %% +log_potentials = { + top_down: grmf_log_potentials["top_down"], + left_right: grmf_log_potentials["left_right"], + diagonal0: grmf_log_potentials["diagonal0"], + diagonal1: grmf_log_potentials["diagonal1"], +} + n_plots = 5 indices = np.random.permutation(noisy_images.shape[0])[:n_plots] fig, ax = plt.subplots(n_plots, 3, figsize=(30, 10 * n_plots)) @@ -199,10 +206,10 @@ def update(step, batch_noisy_images, batch_target_images, opt_state): # %% opt_state = init_fun( { - "top_down": np.random.randn(num_states, num_states), - "left_right": np.random.randn(num_states, num_states), - "diagonal0": np.random.randn(num_states, num_states), - "diagonal1": np.random.randn(num_states, num_states), + top_down: np.random.randn(num_states, num_states), + left_right: np.random.randn(num_states, num_states), + diagonal0: np.random.randn(num_states, num_states), + diagonal1: np.random.randn(num_states, num_states), } ) @@ -227,7 +234,6 @@ def update(step, batch_noisy_images, batch_target_images, opt_state): pbar.update() pbar.set_postfix(loss=value) - batch_indices = indices[idx * batch_size : (idx + 1) * batch_size] batch_noisy_images, batch_target_images = ( noisy_images_train[:10], diff --git a/examples/ising_model.py b/examples/ising_model.py index b2f5aef7..52e3d02c 100644 --- a/examples/ising_model.py +++ b/examples/ising_model.py @@ -43,7 +43,7 @@ variables_for_factors=variables_for_factors, log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]), ) -fg.add_factor_group(factor_group, name="factors") +fg.add_factor_group(factor_group) # %% [markdown] # ### Run inference and visualize results @@ -88,7 +88,7 @@ def loss(log_potentials_updates, evidence_updates): # %% grads = log_potentials_grads( - {"factors": jnp.eye(2)}, + {factor_group: jnp.eye(2)}, {variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))}, ) diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py index 6cdb84f1..d37a0c9f 100644 --- a/examples/pmp_binary_deconvolution.py +++ b/examples/pmp_binary_deconvolution.py @@ -195,7 +195,7 @@ def plot_images(images, display=True, nr=None): fg.add_factor_group(OR_factor_group) print("Time", time.time() - start) -for factor_type, factor_groups in fg.factor_groups.items(): +for factor_type, factor_groups in fg._factor_types_to_groups.items(): if len(factor_groups) > 0: assert len(factor_groups) == 1 print(f"The factor graph contains {factor_groups[0].num_factors} {factor_type}") diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 6cfbac10..815d8ba2 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -62,7 +62,7 @@ def __post_init__(self): if isinstance(self.variables, groups.VariableGroup): self.variables = [self.variables] - # Check ids are unique + # Check variable groups are unique vg_names = [] vg_array_names = [] for variable_group in self.variables: @@ -108,52 +108,34 @@ def __post_init__(self): ) ) vars_num_states_cumsum += vg_num_states_cumsum[-1] - self._num_var_states = vars_num_states_cumsum - self._named_factor_groups: Dict[Hashable, groups.FactorGroup] = {} print("Init", time.time() - start) def __hash__(self) -> int: - all_factor_groups = tuple( - [ - factor_group - for factor_groups_per_type in self._factor_types_to_groups.values() - for factor_group in factor_groups_per_type - ] - ) - return hash(all_factor_groups) + return hash(self.factor_groups) - def add_factor(self, factor: nodes.Factor, name: Optional[str] = None) -> None: + def add_factor(self, factor: nodes.Factor) -> None: """Function to add a single Factor to the FactorGraph. Args: factor: The factor to be added to the factor graph. - name: Optional name of the SingleFactorGroup created. """ factor_group = groups.SingleFactorGroup( variables_for_factors=[factor.variables], factor=factor, ) - self.add_factor_group(factor_group, name) + self.add_factor_group(factor_group) - def add_factor_group( - self, factor_group: groups.FactorGroup, name: Optional[str] = None - ) -> None: + def add_factor_group(self, factor_group: groups.FactorGroup) -> None: """Add a FactorGroup to the FactorGraph, by updating the FactorGraphState. Args: factor_group: The FactorGroup to be added to the FactorGraph. - name: Optional name of the FactorGroup. Raises: ValueError: If the factor group with the same name or a Factor involving the same variables already exists in the FactorGraph. """ - if name in self._named_factor_groups: - raise ValueError( - f"A factor group with the name {name} already exists. Please choose a different name!" - ) - factor_type = factor_group.factor_type for var_names_for_factor in factor_group.variables_for_factors: var_names = frozenset(var_names_for_factor) @@ -165,9 +147,6 @@ def add_factor_group( self._factor_types_to_groups[factor_type].append(factor_group) - if name is not None: - self._named_factor_groups[name] = factor_group - @functools.lru_cache(None) def compute_offsets(self) -> None: """Compute factor messages offsets for the factor types and factor groups @@ -284,20 +263,28 @@ def factors(self) -> OrderedDict[Type, Tuple[nodes.Factor, ...]]: tuple( [ factor - for factor_group in self.factor_groups[factor_type] + for factor_group in self._factor_types_to_groups[ + factor_type + ] for factor in factor_group.factors ] ), ) - for factor_type in self.factor_groups + for factor_type in self._factor_types_to_groups ] ) return factors @property - def factor_groups(self) -> OrderedDict[Type, List[groups.FactorGroup]]: + def factor_groups(self) -> Tuple[groups.FactorGroup, ...]: """Tuple of factor groups in the factor graph""" - return self._factor_types_to_groups + return tuple( + [ + factor_group + for factor_groups_per_type in self._factor_types_to_groups.values() + for factor_group in factor_groups_per_type + ] + ) @cached_property def fg_state(self) -> FactorGraphState: @@ -314,7 +301,7 @@ def fg_state(self) -> FactorGraphState: vars_to_starts=self._vars_to_starts, num_var_states=self._num_var_states, total_factor_num_states=self._total_factor_num_states, - named_factor_groups=copy.copy(self._named_factor_groups), + factor_groups=self.factor_groups, factor_type_to_msgs_range=copy.copy(self._factor_type_to_msgs_range), factor_type_to_potentials_range=copy.copy( self._factor_type_to_potentials_range @@ -350,7 +337,7 @@ class FactorGraphState: contains evidence to the variable. num_var_states: Total number of variable states. total_factor_num_states: Size of the flat ftov messages array. - named_factor_groups: Maps the names of named factor groups to the corresponding factor groups. + factor_groups: Factor groups in the FactorGraph factor_type_to_msgs_range: Maps factors types to their start and end indices in the flat ftov messages. factor_type_to_potentials_range: Maps factor types to their start and end indices in the flat log potentials. factor_group_to_potentials_starts: Maps factor groups to their starting indices in the flat log potentials. @@ -362,7 +349,7 @@ class FactorGraphState: vars_to_starts: Mapping[Tuple[int, int], int] num_var_states: int total_factor_num_states: int - named_factor_groups: Mapping[Hashable, groups.FactorGroup] + factor_groups: Tuple[groups.FactorGroup, ...] factor_type_to_msgs_range: OrderedDict[type, Tuple[int, int]] factor_type_to_potentials_range: OrderedDict[type, Tuple[int, int]] factor_group_to_potentials_starts: OrderedDict[groups.FactorGroup, int] @@ -430,15 +417,13 @@ def update_log_potentials( (1) Provided log_potentials shape does not match the expected log_potentials shape. (2) Provided name is not valid for log_potentials updates. """ - for name, data in updates.items(): - if name in fg_state.named_factor_groups: - factor_group = fg_state.named_factor_groups[name] - + for factor_group, data in updates.items(): + if factor_group in fg_state.factor_groups: flat_data = factor_group.flatten(data) if flat_data.shape != factor_group.factor_group_log_potentials.shape: raise ValueError( f"Expected log potentials shape {factor_group.factor_group_log_potentials.shape} " - f"for factor group {name}. Got incompatible data shape {data.shape}." + f"for factor group. Got incompatible data shape {data.shape}." ) start = fg_state.factor_group_to_potentials_starts[factor_group] @@ -446,7 +431,7 @@ def update_log_potentials( flat_data ) else: - raise ValueError(f"Invalid name {name} for log potentials updates.") + raise ValueError("Invalid FactorGroup for log potentials updates.") return log_potentials @@ -478,38 +463,34 @@ def __post_init__(self): object.__setattr__(self, "value", self.value) - def __getitem__(self, name: Any) -> np.ndarray: - """Function to query log potentials for a named factor group or a factor. + def __getitem__(self, factor_group: groups.FactorGroup) -> np.ndarray: + """Function to query log potentials for a FactorGroup. Args: - name: Name of a named factor group, or a frozenset containing the set - of connected variables for the queried factor. + factor_group: Queried FactorGroup Returns: The queried log potentials. """ value = cast(np.ndarray, self.value) - if name in self.fg_state.named_factor_groups: - factor_group = self.fg_state.named_factor_groups[name] + if factor_group in self.fg_state.factor_groups: start = self.fg_state.factor_group_to_potentials_starts[factor_group] log_potentials = value[ start : start + factor_group.factor_group_log_potentials.shape[0] ] else: - raise ValueError(f"Invalid name {name} for log potentials updates.") - + raise ValueError("Invalid FactorGroup for log potentials updates.") return log_potentials def __setitem__( self, - name: Any, + factor_group: Any, data: Union[np.ndarray, jnp.ndarray], ): """Set the log potentials for a named factor group or a factor. Args: - name: Name of a named factor group, or a frozenset containing the set - of connected variables for the queried factor. + factor_group: FactorGroup data: Array containing the log potentials for the named factor group or the factor. """ @@ -519,7 +500,7 @@ def __setitem__( np.asarray( update_log_potentials( jax.device_put(self.value), - {name: jax.device_put(data)}, + {factor_group: jax.device_put(data)}, self.fg_state, ) ), diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 192eee84..7c56deef 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -101,6 +101,7 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any: ) +@total_ordering @dataclass(frozen=True, eq=False) class FactorGroup: """Class to represent a group of Factors. @@ -127,6 +128,15 @@ def __post_init__(self): if len(self.variables_for_factors) == 0: raise ValueError("Cannot create a FactorGroup with no Factor.") + def __hash__(self): + return id(self) + + def __eq__(self, other): + return hash(self) == hash(other) + + def __lt__(self, other): + return hash(self) < hash(other) + def __getitem__(self, variables: Sequence[Tuple[int, int]]) -> Any: """Function to query individual factors in the factor group diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py index c3b3a853..e80b32a3 100644 --- a/tests/fg/test_graph.py +++ b/tests/fg/test_graph.py @@ -13,17 +13,14 @@ def test_factor_graph(): - vg = vgroup.NDVariableArray(num_states=2, shape=(10, 10)) + vg1 = vgroup.NDVariableArray(num_states=2, shape=(10, 10)) with pytest.raises(ValueError, match="Two objects have the same name"): - fg = graph.FactorGraph(variables=[vg, vg]) + fg = graph.FactorGraph(variables=[vg1, vg1]) - vg = vgroup.NDVariableArray(num_states=2, shape=(10, 10)) vg2 = vgroup.NDVariableArray(num_states=2, shape=(10, 10)) - object.__setattr__(vg2, "random_hash", vg.__hash__() + 10) + object.__setattr__(vg2, "random_hash", vg1.__hash__() + 10) with pytest.raises(ValueError, match="Two NDVariableArrays have overlapping names"): - fg = graph.FactorGraph(variables=[vg, vg2]) - - # TODO: remove factor graph name + fg = graph.FactorGraph(variables=[vg1, vg2]) vg = vgroup.VariableDict(variable_names=(0,), num_states=15) fg = graph.FactorGraph(vg) @@ -32,13 +29,7 @@ def test_factor_graph(): factor_configs=np.arange(15)[:, None], log_potentials=np.zeros(15), ) - fg.add_factor(factor, name="test") - - with pytest.raises( - ValueError, - match="A factor group with the name test already exists. Please choose a different name", - ): - fg.add_factor(factor, name="test") + fg.add_factor(factor) with pytest.raises( ValueError, @@ -76,10 +67,10 @@ def test_bp_state(): factor_configs=np.arange(15)[:, None], log_potentials=np.zeros(15), ) - fg0.add_factor(factor, name="test") + fg0.add_factor(factor) fg1 = graph.FactorGraph(vg) - fg1.add_factor(factor, name="test") + fg1.add_factor(factor) with pytest.raises( ValueError, @@ -99,23 +90,27 @@ def test_log_potentials(): variables_for_factors=[[vg[0]]], factor_configs=np.arange(10)[:, None], ) - fg.add_factor_group(factor_group, name="test") + fg.add_factor_group(factor_group) with pytest.raises( ValueError, - match=re.escape("Expected log potentials shape (10,) for factor group test."), + match=re.escape("Expected log potentials shape (10,) for factor group."), ): - fg.bp_state.log_potentials["test"] = jnp.zeros((1, 15)) + fg.bp_state.log_potentials[factor_group] = jnp.zeros((1, 15)) with pytest.raises( ValueError, - match=re.escape("Invalid name new_test for log potentials updates."), + match=re.escape("Invalid FactorGroup for log potentials updates."), ): - fg.bp_state.log_potentials["new_test"] = jnp.zeros((1, 15)) + factor_group2 = enumeration.EnumerationFactorGroup( + variables_for_factors=[[vg[0]]], + factor_configs=np.arange(10)[:, None], + ) + fg.bp_state.log_potentials[factor_group2] = jnp.zeros((1, 15)) with pytest.raises( ValueError, - match=re.escape("Invalid name (0, 15) for log potentials updates."), + match=re.escape("Invalid FactorGroup for log potentials updates."), ): fg.bp_state.log_potentials[vg[0]] = np.zeros(10) @@ -125,7 +120,7 @@ def test_log_potentials(): graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(15)) log_potentials = graph.LogPotentials(fg_state=fg.fg_state, value=np.zeros(10)) - assert jnp.all(log_potentials["test"] == jnp.zeros(10)) + assert jnp.all(log_potentials[factor_group] == jnp.zeros(10)) def test_ftov_msgs(): @@ -135,7 +130,7 @@ def test_ftov_msgs(): variables_for_factors=[[vg[0]]], factor_configs=np.arange(10)[:, None], ) - fg.add_factor_group(factor_group, name="test") + fg.add_factor_group(factor_group) with pytest.raises( ValueError, @@ -170,7 +165,7 @@ def test_evidence(): variables_for_factors=[[vg["a"]]], factor_configs=np.arange(10)[:, None], ) - fg.add_factor_group(factor_group, name="test") + fg.add_factor_group(factor_group) with pytest.raises( ValueError, match=re.escape("Expected evidence shape (15,). Got (10,).") @@ -201,7 +196,7 @@ def test_bp(): variables_for_factors=[[vg[0]]], factor_configs=np.arange(10)[:, None], ) - fg.add_factor_group(factor_group, name="test") + fg.add_factor_group(factor_group) bp = graph.BP(fg.bp_state, temperature=0) bp_arrays = bp.update() diff --git a/tests/fg/test_wiring.py b/tests/fg/test_wiring.py index 5b68a974..3df31cb5 100644 --- a/tests/fg/test_wiring.py +++ b/tests/fg/test_wiring.py @@ -25,7 +25,7 @@ def test_wiring_with_PairwiseFactorGroup(): ) fg.add_factor_group(factor_group) - factor_group = fg.factor_groups[enumeration_factor.EnumerationFactor][0] + factor_group = fg.factor_groups[0] object.__setattr__( factor_group, "factor_configs", factor_group.factor_configs[:, :1] ) @@ -41,7 +41,7 @@ def test_wiring_with_PairwiseFactorGroup(): variables_for_factors=[[A[idx], B[idx]] for idx in range(10)] ) fg1.add_factor_group(factor_group) - assert len(fg1.factor_groups[enumeration_factor.EnumerationFactor]) == 1 + assert len(fg1.factor_groups) == 1 # FactorGraph with multiple PairwiseFactorGroup fg2 = graph.FactorGraph(variables=[A, B]) @@ -50,7 +50,7 @@ def test_wiring_with_PairwiseFactorGroup(): variables_for_factors=[[A[idx], B[idx]]] ) fg2.add_factor_group(factor_group) - assert len(fg2.factor_groups[enumeration_factor.EnumerationFactor]) == 10 + assert len(fg2.factor_groups) == 10 # FactorGraph with multiple SingleFactorGroup fg3 = graph.FactorGraph(variables=[A, B]) @@ -61,7 +61,7 @@ def test_wiring_with_PairwiseFactorGroup(): log_potentials=np.zeros((4,)), ) fg3.add_factor(factor) - assert len(fg3.factor_groups[enumeration_factor.EnumerationFactor]) == 10 + assert len(fg3.factor_groups) == 10 assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors) @@ -100,7 +100,7 @@ def test_wiring_with_ORFactorGroup(): variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)], ) fg1.add_factor_group(factor_group) - assert len(fg1.factor_groups[logical_factor.ORFactor]) == 1 + assert len(fg1.factor_groups) == 1 # FactorGraph with multiple ORFactorGroup fg2 = graph.FactorGraph(variables=[A, B, C]) @@ -109,7 +109,7 @@ def test_wiring_with_ORFactorGroup(): variables_for_factors=[[A[idx], B[idx], C[idx]]], ) fg2.add_factor_group(factor_group) - assert len(fg2.factor_groups[logical_factor.ORFactor]) == 10 + assert len(fg2.factor_groups) == 10 # FactorGraph with multiple SingleFactorGroup fg3 = graph.FactorGraph(variables=[A, B, C]) @@ -118,7 +118,7 @@ def test_wiring_with_ORFactorGroup(): variables=[A[idx], B[idx], C[idx]], ) fg3.add_factor(factor) - assert len(fg3.factor_groups[logical_factor.ORFactor]) == 10 + assert len(fg3.factor_groups) == 10 assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors) @@ -155,7 +155,7 @@ def test_wiring_with_ANDFactorGroup(): variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)], ) fg1.add_factor_group(factor_group) - assert len(fg1.factor_groups[logical_factor.ANDFactor]) == 1 + assert len(fg1.factor_groups) == 1 # FactorGraph with multiple ANDFactorGroup fg2 = graph.FactorGraph(variables=[A, B, C]) @@ -164,7 +164,7 @@ def test_wiring_with_ANDFactorGroup(): variables_for_factors=[[A[idx], B[idx], C[idx]]], ) fg2.add_factor_group(factor_group) - assert len(fg2.factor_groups[logical_factor.ANDFactor]) == 10 + assert len(fg2.factor_groups) == 10 # FactorGraph with multiple SingleFactorGroup fg3 = graph.FactorGraph(variables=[A, B, C]) @@ -173,7 +173,7 @@ def test_wiring_with_ANDFactorGroup(): variables=[A[idx], B[idx], C[idx]], ) fg3.add_factor(factor) - assert len(fg3.factor_groups[logical_factor.ANDFactor]) == 10 + assert len(fg3.factor_groups) == 10 assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors) diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py index e9067abe..c6a030d6 100644 --- a/tests/test_pgmax.py +++ b/tests/test_pgmax.py @@ -265,14 +265,14 @@ def create_valid_suppression_config_arr(suppression_diameter): for row in range(M - 1): for col in range(N - 1): if row != M - 2 and col != N - 2: - curr_names = [ + curr_vars = [ grid_vars[0, row, col], grid_vars[1, row, col], grid_vars[0, row, col + 1], grid_vars[1, row + 1, col], ] elif row != M - 2: - curr_names = [ + curr_vars = [ grid_vars[0, row, col], grid_vars[1, row, col], additional_vars[0, row, col + 1], @@ -280,7 +280,7 @@ def create_valid_suppression_config_arr(suppression_diameter): ] elif col != N - 2: - curr_names = [ + curr_vars = [ grid_vars[0, row, col], grid_vars[1, row, col], grid_vars[0, row, col + 1], @@ -288,7 +288,7 @@ def create_valid_suppression_config_arr(suppression_diameter): ] else: - curr_names = [ + curr_vars = [ grid_vars[0, row, col], grid_vars[1, row, col], additional_vars[0, row, col + 1], @@ -296,54 +296,54 @@ def create_valid_suppression_config_arr(suppression_diameter): ] if row % 2 == 0: factor = EnumerationFactor( - variables=curr_names, + variables=curr_vars, factor_configs=valid_configs_non_supp, log_potentials=np.zeros( valid_configs_non_supp.shape[0], dtype=float ), ) - fg.add_factor(factor, name=(row, col)) + fg.add_factor(factor) else: factor = EnumerationFactor( - variables=curr_names, + variables=curr_vars, factor_configs=valid_configs_non_supp, log_potentials=np.zeros( valid_configs_non_supp.shape[0], dtype=float ), ) - fg.add_factor(factor, name=(row, col)) + fg.add_factor(factor) # Create an EnumerationFactorGroup for vertical suppression factors - vert_suppression_names: List[List[Tuple[Any, ...]]] = [] + vert_suppression_vars: List[List[Tuple[Any, ...]]] = [] for col in range(N): for start_row in range(M - SUPPRESSION_DIAMETER): if col != N - 1: - vert_suppression_names.append( + vert_suppression_vars.append( [ grid_vars[0, r, col] for r in range(start_row, start_row + SUPPRESSION_DIAMETER) ] ) else: - vert_suppression_names.append( + vert_suppression_vars.append( [ additional_vars[0, r, col] for r in range(start_row, start_row + SUPPRESSION_DIAMETER) ] ) - horz_suppression_names: List[List[Tuple[Any, ...]]] = [] + horz_suppression_vars: List[List[Tuple[Any, ...]]] = [] for row in range(M): for start_col in range(N - SUPPRESSION_DIAMETER): if row != M - 1: - horz_suppression_names.append( + horz_suppression_vars.append( [ grid_vars[1, row, c] for c in range(start_col, start_col + SUPPRESSION_DIAMETER) ] ) else: - horz_suppression_names.append( + horz_suppression_vars.append( [ additional_vars[1, row, c] for c in range(start_col, start_col + SUPPRESSION_DIAMETER) @@ -352,13 +352,13 @@ def create_valid_suppression_config_arr(suppression_diameter): # Add the suppression factors to the graph via kwargs factor_group = enumeration.EnumerationFactorGroup( - variables_for_factors=vert_suppression_names, + variables_for_factors=vert_suppression_vars, factor_configs=valid_configs_supp, ) fg.add_factor_group(factor_group) factor_group = enumeration.EnumerationFactorGroup( - variables_for_factors=horz_suppression_names, + variables_for_factors=horz_suppression_vars, factor_configs=valid_configs_supp, log_potentials=np.zeros(valid_configs_supp.shape[0], dtype=float), ) @@ -419,7 +419,7 @@ def binary_connected_variables( variables_for_factors=binary_connected_variables(28, 28, k_row, k_col), log_potential_matrix=W_pot[:, :, k_row, k_col], ) - fg.add_factor_group(factor_group, name=(k_row, k_col)) + fg.add_factor_group(factor_group) # Assign evidence to pixel vars bp_state = fg.bp_state From 22b604e7a8e78c622e6bbb8ca54ce142e872a40d Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Mon, 25 Apr 2022 23:05:11 +0000 Subject: [PATCH 19/35] Remove factor group names --- tests/fg/test_groups.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index b12b49d3..24d24400 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -94,6 +94,9 @@ def test_nd_variable_array(): def test_enumeration_factor_group(): vg = vgroup.NDVariableArray(shape=(2, 2), num_states=3) + vg_bis = vgroup.NDVariableArray(shape=(2, 2), num_states=3) + vg < vg_bis + with pytest.raises( ValueError, match=re.escape("Expected log potentials shape: (1,) or (2, 1). Got (3, 2)"), From 87fcfd9eb000f308ac00e28e75922385962c9c5b Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Tue, 26 Apr 2022 02:27:18 +0000 Subject: [PATCH 20/35] Modify hash + add_factors --- examples/gmrf.py | 8 +-- examples/ising_model.py | 4 +- examples/pmp_binary_deconvolution.py | 8 +-- examples/rbm.py | 16 ++--- examples/rcn.py | 4 +- pgmax/factors/logical.py | 4 +- pgmax/fg/graph.py | 93 +++++++++++----------------- pgmax/fg/groups.py | 16 ++--- pgmax/groups/variables.py | 39 ++++++++---- tests/factors/test_and.py | 18 +++--- tests/factors/test_or.py | 14 +++-- tests/fg/test_graph.py | 80 ++++++++++++------------ tests/fg/test_groups.py | 31 ++++++++-- tests/fg/test_wiring.py | 40 ++++++------ tests/test_pgmax.py | 12 ++-- 15 files changed, 202 insertions(+), 185 deletions(-) diff --git a/examples/gmrf.py b/examples/gmrf.py index 6601e5c7..eb86a5e9 100644 --- a/examples/gmrf.py +++ b/examples/gmrf.py @@ -66,7 +66,7 @@ for jj in range(N) ], ) -fg.add_factor_group(top_down) +fg.add_factors(top_down) # Add left-right factors left_right = enumeration.PairwiseFactorGroup( @@ -76,7 +76,7 @@ for jj in range(N - 1) ], ) -fg.add_factor_group(left_right) +fg.add_factors(left_right) # Add diagonal factors diagonal0 = enumeration.PairwiseFactorGroup( @@ -86,7 +86,7 @@ for jj in range(N - 1) ], ) -fg.add_factor_group(diagonal0) +fg.add_factors(diagonal0) diagonal1 = enumeration.PairwiseFactorGroup( variables_for_factors=[ @@ -95,7 +95,7 @@ for jj in range(N - 1) ], ) -fg.add_factor_group(diagonal1) +fg.add_factors(diagonal1) # %% bp = graph.BP(fg.bp_state, temperature=1.0) diff --git a/examples/ising_model.py b/examples/ising_model.py index 52e3d02c..3f530a33 100644 --- a/examples/ising_model.py +++ b/examples/ising_model.py @@ -29,7 +29,7 @@ # %% variables = vgroup.NDVariableArray(num_states=2, shape=(50, 50)) -fg = graph.FactorGraph(variables=variables) +fg = graph.FactorGraph(variable_groups=variables) variables_for_factors = [] for ii in range(50): @@ -43,7 +43,7 @@ variables_for_factors=variables_for_factors, log_potential_matrix=0.8 * np.array([[1.0, -1.0], [-1.0, 1.0]]), ) -fg.add_factor_group(factor_group) +fg.add_factors(factor_group) # %% [markdown] # ### Run inference and visualize results diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py index d37a0c9f..a0f24406 100644 --- a/examples/pmp_binary_deconvolution.py +++ b/examples/pmp_binary_deconvolution.py @@ -138,12 +138,12 @@ def plot_images(images, display=True, nr=None): print("Time", time.time() - start) # %% [markdown] -# For computation efficiency, we add large FactorGroups via `fg.add_factor_group` instead of adding individual Factors +# For computation efficiency, we add large FactorGroups via `fg.add_factors` instead of adding individual Factors # %% start = time.time() # Factor graph -fg = graph.FactorGraph(variables=[S, W, SW, X]) +fg = graph.FactorGraph(variable_groups=[S, W, SW, X]) print(time.time() - start) # Define the ANDFactors @@ -181,7 +181,7 @@ def plot_images(images, display=True, nr=None): # Add ANDFactorGroup, which is computationally efficient AND_factor_group = logical.ANDFactorGroup(variables_for_ANDFactors) -fg.add_factor_group(AND_factor_group) +fg.add_factors(AND_factor_group) print(time.time() - start) # Define the ORFactors @@ -192,7 +192,7 @@ def plot_images(images, display=True, nr=None): # Add ORFactorGroup, which is computationally efficient OR_factor_group = logical.ORFactorGroup(variables_for_ORFactors) -fg.add_factor_group(OR_factor_group) +fg.add_factors(OR_factor_group) print("Time", time.time() - start) for factor_type, factor_groups in fg._factor_types_to_groups.items(): diff --git a/examples/rbm.py b/examples/rbm.py index 8ea595ea..cdc51390 100644 --- a/examples/rbm.py +++ b/examples/rbm.py @@ -52,7 +52,7 @@ # Initialize factor graph hidden_variables = vgroup.NDVariableArray(num_states=2, shape=bh.shape) visible_variables = vgroup.NDVariableArray(num_states=2, shape=bv.shape) -fg = graph.FactorGraph(variables=[hidden_variables, visible_variables]) +fg = graph.FactorGraph(variable_groups=[hidden_variables, visible_variables]) print("Time", time.time() - start) # %% [markdown] @@ -69,14 +69,14 @@ factor_configs=np.arange(2)[:, None], log_potentials=np.stack([np.zeros_like(bh), bh], axis=1), ) -fg.add_factor_group(hidden_unaries) +fg.add_factors(hidden_unaries) visible_unaries = enumeration.EnumerationFactorGroup( variables_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])], factor_configs=np.arange(2)[:, None], log_potentials=np.stack([np.zeros_like(bv), bv], axis=1), ) -fg.add_factor_group(visible_unaries) +fg.add_factors(visible_unaries) # Add pairwise factors log_potential_matrix = np.zeros(W.shape + (2, 2)).reshape((-1, 2, 2)) @@ -90,9 +90,9 @@ ], log_potential_matrix=log_potential_matrix, ) -fg.add_factor_group(pairwise_factors) +fg.add_factors(pairwise_factors) -# # %snakeviz fg.add_factor_group(factory=enumeration.PairwiseFactorGroup, variables_for_factors=v, log_potential_matrix=log_potential_matrix,) +# # %snakeviz fg.add_factors(factory=enumeration.PairwiseFactorGroup, variables_for_factors=v, log_potential_matrix=log_potential_matrix,) print("Time", time.time() - start) @@ -115,7 +115,7 @@ # factor_configs=np.arange(2)[:, None], # log_potentials=np.array([0, bh[ii]]), # ) -# fg.add_factor(factor) +# fg.add_factors(factor=factor) # # for jj in range(bv.shape[0]): # factor = enumeration_factor.EnumerationFactor( @@ -123,7 +123,7 @@ # factor_configs=np.arange(2)[:, None], # log_potentials=np.array([0, bv[jj]]), # ) -# fg.add_factor(factor) +# fg.add_factors(factor=factor) # # # Add pairwise factors # factor_configs = np.array(list(itertools.product(np.arange(2), repeat=2))) @@ -134,7 +134,7 @@ # factor_configs=factor_configs, # log_potentials=np.array([0, 0, 0, W[ii, jj]]), # ) -# fg.add_factor(factor) +# fg.add_factors(factor=factor) # ~~~ # # Once we have added the factors, we can run max-product LBP and get MAP decoding by diff --git a/examples/rcn.py b/examples/rcn.py index d7add7d7..2414c0e9 100644 --- a/examples/rcn.py +++ b/examples/rcn.py @@ -272,7 +272,7 @@ def valid_configs(r: int, hps: int, vps: int) -> np.ndarray: # %% start = time.time() -fg = graph.FactorGraph(variables=variables_all_models) +fg = graph.FactorGraph(variables_all_models) # Adding rcn model edges to the pgmax factor graph. for idx in range(edges.shape[0]): @@ -285,7 +285,7 @@ def valid_configs(r: int, hps: int, vps: int) -> np.ndarray: factor_configs=valid_configs_list[r], log_potentials=np.zeros(valid_configs_list[r].shape[0]), ) - fg.add_factor(factor) + fg.add_factors(factor=factor) end = time.time() print(f"Creating factors took {end-start:.3f} seconds.") diff --git a/pgmax/factors/logical.py b/pgmax/factors/logical.py index 7f2d1ca6..bf8652cb 100644 --- a/pgmax/factors/logical.py +++ b/pgmax/factors/logical.py @@ -2,7 +2,7 @@ import functools from dataclasses import dataclass, field -from typing import List, Mapping, Optional, Sequence, Union +from typing import List, Mapping, Optional, Sequence, Tuple, Union import jax import jax.numpy as jnp @@ -141,7 +141,7 @@ def concatenate_wirings(wirings: Sequence[LogicalWiring]) -> LogicalWiring: @staticmethod def compile_wiring( variables_for_factors: Sequence[List], - vars_to_starts: Mapping[int, int], + vars_to_starts: Mapping[Tuple[int, int], int], edge_states_offset: int, ) -> LogicalWiring: """Compile a LogicalWiring for a LogicalFactor or a FactorGroup with LogicalFactors. diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 815d8ba2..9fbe794c 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -35,7 +35,6 @@ from pgmax.bp import infer from pgmax.factors import FAC_TO_VAR_UPDATES from pgmax.fg import groups, nodes -from pgmax.groups import variables from pgmax.utils import cached_property @@ -53,33 +52,14 @@ class FactorGraph: this input, and the individual VariableGroups will need to be accessed by indexing. """ - variables: Union[groups.VariableGroup, Sequence[groups.VariableGroup]] + variable_groups: Union[groups.VariableGroup, Sequence[groups.VariableGroup]] def __post_init__(self): import time start = time.time() - if isinstance(self.variables, groups.VariableGroup): - self.variables = [self.variables] - - # Check variable groups are unique - vg_names = [] - vg_array_names = [] - for variable_group in self.variables: - vg_name = variable_group.__hash__() - if vg_name in vg_names: - raise ValueError("Two objects have the same name") - vg_names.append(vg_name) - if isinstance(variable_group, variables.NDVariableArray): - start_name, end_name = ( - vg_name, - variable_group.variable_names.flatten()[-1], - ) - for var_array_name in vg_array_names: - start_name2, end_name2 = var_array_name - if max(start_name, start_name2) <= min(end_name, end_name2): - raise ValueError("Two NDVariableArrays have overlapping names") - vg_array_names.append((start_name, end_name)) + if isinstance(self.variable_groups, groups.VariableGroup): + self.variable_groups = [self.variable_groups] # Useful objects to build the FactorGraph self._factor_types_to_groups: OrderedDict[ @@ -98,7 +78,7 @@ def __post_init__(self): Tuple[int, int], int ] = collections.OrderedDict() vars_num_states_cumsum = 0 - for variable_group in self.variables: + for variable_group in self.variable_groups: vg_num_states = variable_group.num_states.flatten() vg_num_states_cumsum = np.insert(np.cumsum(vg_num_states), 0, 0) self._vars_to_starts.update( @@ -114,28 +94,35 @@ def __post_init__(self): def __hash__(self) -> int: return hash(self.factor_groups) - def add_factor(self, factor: nodes.Factor) -> None: - """Function to add a single Factor to the FactorGraph. - - Args: - factor: The factor to be added to the factor graph. - """ - factor_group = groups.SingleFactorGroup( - variables_for_factors=[factor.variables], - factor=factor, - ) - self.add_factor_group(factor_group) - - def add_factor_group(self, factor_group: groups.FactorGroup) -> None: - """Add a FactorGroup to the FactorGraph, by updating the FactorGraphState. + def add_factors( + self, + factor_group: Optional[groups.FactorGroup] = None, + factor: Optional[nodes.Factor] = None, + ) -> None: + """Add a FactorGroup or a single Factor to the FactorGraph, by updating the FactorGraphState. Args: factor_group: The FactorGroup to be added to the FactorGraph. + factor: The Factor to be added to the factor graph. Raises: - ValueError: If the factor group with the same name or a Factor involving the same variables - already exists in the FactorGraph. + ValueError: If + (1) Both a Factor and a FactorGroup are added + (2) The FactorGroup involving the same variables already exists in the FactorGraph. """ + if factor is None and factor_group is None: + raise ValueError("A Factor or a FactorGroup is required") + + if factor is not None and factor_group is not None: + raise ValueError("Cannot simultaneously add a Factor and a FactorGroup") + + if factor is not None: + factor_group = groups.SingleFactorGroup( + variables_for_factors=[factor.variables], + factor=factor, + ) + assert factor_group is not None + factor_type = factor_group.factor_type for var_names_for_factor in factor_group.variables_for_factors: var_names = frozenset(var_names_for_factor) @@ -294,10 +281,10 @@ def fg_state(self) -> FactorGraphState: log_potentials = np.concatenate( [self.log_potentials[factor_type] for factor_type in self.log_potentials] ) - assert isinstance(self.variables, list) + assert isinstance(self.variable_groups, list) return FactorGraphState( - variable_groups=self.variables, + variable_groups=self.variable_groups, vars_to_starts=self._vars_to_starts, num_var_states=self._num_var_states, total_factor_num_states=self._total_factor_num_states, @@ -606,7 +593,7 @@ def __setitem__( @typing.overload def __setitem__( self, - names: Tuple[int, int], + variable: Tuple[int, int], data: Union[np.ndarray, jnp.ndarray], ) -> None: """Spreading beliefs at a variable to all connected factors @@ -618,14 +605,14 @@ def __setitem__( variable. """ - def __setitem__(self, names, data) -> None: + def __setitem__(self, variable, data) -> None: object.__setattr__( self, "value", np.asarray( update_ftov_msgs( jax.device_put(self.value), - {names: jax.device_put(data)}, + {variable: jax.device_put(data)}, self.fg_state, ) ), @@ -1002,18 +989,12 @@ def get_beliefs(bp_arrays: BPArrays) -> Dict[Hashable, Any]: beliefs: Beliefs returned by belief propagation. """ - def compute_flat_beliefs(bp_arrays, var_states_for_edges): - flat_beliefs = ( - jax.device_put(bp_arrays.evidence) - .at[jax.device_put(var_states_for_edges)] - .add(bp_arrays.ftov_msgs) - ) - return flat_beliefs - - return unflatten_beliefs( - compute_flat_beliefs(bp_arrays, var_states_for_edges), - bp_state.fg_state.variable_groups, + flat_beliefs = ( + jax.device_put(bp_arrays.evidence) + .at[jax.device_put(var_states_for_edges)] + .add(bp_arrays.ftov_msgs) ) + return unflatten_beliefs(flat_beliefs, bp_state.fg_state.variable_groups) bp = BeliefPropagation( init=functools.partial(update, None), diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 7c56deef..62692ffd 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -1,7 +1,6 @@ """A module containing the base classes for variable and factor groups in a Factor Graph.""" import inspect -import random from dataclasses import dataclass, field from functools import total_ordering from typing import ( @@ -22,25 +21,18 @@ import pgmax.fg.nodes as nodes from pgmax.utils import cached_property +MAX_SIZE = 1e16 + @total_ordering @dataclass(frozen=True, eq=False) class VariableGroup: """Class to represent a group of variables. Each variable is represented via a tuple of the form (variable hash/name, number of states) - - Attributes: - random_hash: Hash of the VariableGroup """ - def __post_init__(self): - # Overwite default hash to have larger differences - random.seed(id(self)) - random_hash = random.randint(0, 2**63) - object.__setattr__(self, "random_hash", random_hash) - def __hash__(self): - return self.random_hash + return id(self) * int(MAX_SIZE) def __eq__(self, other): return hash(self) == hash(other) @@ -63,7 +55,7 @@ def __getitem__(self, val): ) @cached_property - def variables(self) -> Tuple[Any, int]: + def variables(self) -> List[Tuple]: """Function that returns the list of all variables in the VariableGroup. Each variable is represented by a tuple of the form (variable hash/name, number of states) diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index feff8467..7b50167e 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -26,10 +26,13 @@ class NDVariableArray(groups.VariableGroup): num_states: Union[int, np.ndarray] def __post_init__(self): - super().__post_init__() + if np.prod(self.shape) > groups.MAX_SIZE: + raise ValueError( + f"Currently only support NDVariableArray of size smaller than {groups.MAX_SIZE}. Got {np.prod(self.shape)}" + ) if np.isscalar(self.num_states): - num_states = np.full(self.shape, fill_value=self.num_states) + num_states = np.full(self.shape, fill_value=self.num_states, dtype=np.int32) object.__setattr__(self, "num_states", num_states) elif isinstance(self.num_states, np.ndarray) and np.issubdtype( self.num_states.dtype, int @@ -43,7 +46,7 @@ def __post_init__(self): def __getitem__( self, val: Union[int, slice, Tuple] - ) -> Union[Tuple[int, int], List[Tuple]]: + ) -> Union[Tuple[int, int], List[Tuple[int, int]]]: """Given an index or a slice, retrieve the associated variable(s). Each variable is returned via a tuple of the form (variable hash, number of states) @@ -159,14 +162,19 @@ class VariableDict(groups.VariableGroup): """ variable_names: Tuple[Any, ...] - num_states: np.ndarray # TODO: this should be an int converted to an array in __post_init__ + num_states: int def __post_init__(self): - super().__post_init__() - - num_states = np.full((len(self.variable_names),), fill_value=self.num_states) + num_states = np.full( + (len(self.variable_names),), fill_value=self.num_states, dtype=np.int32 + ) object.__setattr__(self, "num_states", num_states) + hash_and_names = tuple( + (self.__hash__(), var_name) for var_name in self.variable_names + ) + object.__setattr__(self, "variable_names", hash_and_names) + @cached_property def variables(self) -> List[Tuple]: """Function that returns the list of all variables in the VariableGroup. @@ -175,9 +183,12 @@ def variables(self) -> List[Tuple]: Returns: List of variables in the VariableGroup """ - return list(zip(self.variable_names, self.num_states)) + assert isinstance(self.num_states, np.ndarray) + vars_names = list(self.variable_names) + vars_num_states = self.num_states.flatten() + return list(zip(vars_names, vars_num_states)) - def __getitem__(self, val): + def __getitem__(self, val: Any) -> Tuple[Any, int]: """Given a variable name retrieve the associated variable, returned via a tuple of the form (variable name, number of states) @@ -187,9 +198,11 @@ def __getitem__(self, val): Returns: The queried variable """ - if val not in self.variable_names: + assert isinstance(self.num_states, np.ndarray) + + if (self.__hash__(), val) not in self.variable_names: raise ValueError(f"Variable {val} is not in VariableDict") - return (val, self.num_states[0]) + return ((self.__hash__(), val), self.num_states[0]) def flatten( self, data: Mapping[Tuple[int, int], Union[np.ndarray, jnp.ndarray]] @@ -209,6 +222,8 @@ def flatten( (1) data is referring to a non-existing variable (2) data is not of the correct shape """ + assert isinstance(self.num_states, np.ndarray) + for variable in data: if variable not in self.variables: raise ValueError( @@ -244,6 +259,8 @@ def unflatten( (1) flat_data is not a 1D array (2) flat_data is not of the right shape """ + assert isinstance(self.num_states, np.ndarray) + if flat_data.ndim != 1: raise ValueError( f"Can only unflatten 1D array. Got a {flat_data.ndim}D array." diff --git a/tests/factors/test_and.py b/tests/factors/test_and.py index 147d57a0..fcbe6772 100644 --- a/tests/factors/test_and.py +++ b/tests/factors/test_and.py @@ -45,14 +45,18 @@ def test_run_bp_with_ANDFactors(): num_states=2, shape=(num_parents.sum(),) ) children_variables1 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,)) - fg1 = graph.FactorGraph(variables=[parents_variables1, children_variables1]) + fg1 = graph.FactorGraph( + variable_groups=[parents_variables1, children_variables1] + ) # Graph 2 parents_variables2 = vgroup.NDVariableArray( num_states=2, shape=(num_parents.sum(),) ) children_variables2 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,)) - fg2 = graph.FactorGraph(variables=[parents_variables2, children_variables2]) + fg2 = graph.FactorGraph( + variable_groups=[parents_variables2, children_variables2] + ) # Option 1: Define EnumerationFactors equivalent to the ANDFactors variables_for_factors1 = [] @@ -100,7 +104,7 @@ def test_run_bp_with_ANDFactors(): factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) - fg1.add_factor(enum_factor) + fg1.add_factors(factor=enum_factor) else: if idx != 0: # Add the second half of factors to FactorGraph2 @@ -109,7 +113,7 @@ def test_run_bp_with_ANDFactors(): factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) - fg2.add_factor(enum_factor) + fg2.add_factors(factor=enum_factor) else: # Add all the EnumerationFactors to FactorGraph1 for the first iter enum_factor = EnumerationFactor( @@ -117,7 +121,7 @@ def test_run_bp_with_ANDFactors(): factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) - fg1.add_factor(enum_factor) + fg1.add_factors(factor=enum_factor) # Option 2: Define the ANDFactors num_parents_cumsum = np.insert(np.cumsum(num_parents), 0, 0) @@ -141,10 +145,10 @@ def test_run_bp_with_ANDFactors(): ) if idx != 0: factor_group = logical.ANDFactorGroup(variables_for_ANDFactors_fg1) - fg1.add_factor_group(factor_group) + fg1.add_factors(factor_group) factor_group = logical.ANDFactorGroup(variables_for_ANDFactors_fg2) - fg2.add_factor_group(factor_group) + fg2.add_factors(factor_group) # Run inference bp1 = graph.BP(fg1.bp_state, temperature=temperature) diff --git a/tests/factors/test_or.py b/tests/factors/test_or.py index 76cc5107..16236b35 100644 --- a/tests/factors/test_or.py +++ b/tests/factors/test_or.py @@ -45,14 +45,18 @@ def test_run_bp_with_ORFactors(): num_states=2, shape=(num_parents.sum(),) ) children_variables1 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,)) - fg1 = graph.FactorGraph(variables=[parents_variables1, children_variables1]) + fg1 = graph.FactorGraph( + variable_groups=[parents_variables1, children_variables1] + ) # Graph 2 parents_variables2 = vgroup.NDVariableArray( num_states=2, shape=(num_parents.sum(),) ) children_variables2 = vgroup.NDVariableArray(num_states=2, shape=(num_factors,)) - fg2 = graph.FactorGraph(variables=[parents_variables2, children_variables2]) + fg2 = graph.FactorGraph( + variable_groups=[parents_variables2, children_variables2] + ) # Variable names for factors variables_for_factors1 = [] @@ -98,7 +102,7 @@ def test_run_bp_with_ORFactors(): factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) - fg1.add_factor(enum_factor) + fg1.add_factor(factor=enum_factor) else: if idx != 0: # Add the second half of factors to FactorGraph2 @@ -107,7 +111,7 @@ def test_run_bp_with_ORFactors(): factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) - fg2.add_factor(enum_factor) + fg2.add_factor(factor=enum_factor) else: # Add all the EnumerationFactors to FactorGraph1 for the first iter enum_factor = EnumerationFactor( @@ -115,7 +119,7 @@ def test_run_bp_with_ORFactors(): factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) - fg1.add_factor(enum_factor) + fg1.add_factor(factor=enum_factor) # Option 2: Define the ORFactors num_parents_cumsum = np.insert(np.cumsum(num_parents), 0, 0) diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py index e80b32a3..bb26ce99 100644 --- a/tests/fg/test_graph.py +++ b/tests/fg/test_graph.py @@ -7,56 +7,52 @@ import pytest from pgmax.factors import enumeration as enumeration_factor -from pgmax.fg import graph, groups -from pgmax.groups import enumeration, logical +from pgmax.fg import graph +from pgmax.groups import enumeration from pgmax.groups import variables as vgroup def test_factor_graph(): - vg1 = vgroup.NDVariableArray(num_states=2, shape=(10, 10)) - with pytest.raises(ValueError, match="Two objects have the same name"): - fg = graph.FactorGraph(variables=[vg1, vg1]) - - vg2 = vgroup.NDVariableArray(num_states=2, shape=(10, 10)) - object.__setattr__(vg2, "random_hash", vg1.__hash__() + 10) - with pytest.raises(ValueError, match="Two NDVariableArrays have overlapping names"): - fg = graph.FactorGraph(variables=[vg1, vg2]) - vg = vgroup.VariableDict(variable_names=(0,), num_states=15) fg = graph.FactorGraph(vg) + + with pytest.raises( + ValueError, + match="A Factor or a FactorGroup is required", + ): + fg.add_factors(factor_group=None, factor=None) + factor = enumeration_factor.EnumerationFactor( variables=[vg[0]], factor_configs=np.arange(15)[:, None], log_potentials=np.zeros(15), ) - fg.add_factor(factor) + factor_group = enumeration.EnumerationFactorGroup( + variables_for_factors=[[vg[0]]], + factor_configs=np.arange(15)[:, None], + log_potentials=np.zeros(15), + ) with pytest.raises( ValueError, - match=re.escape( - f"A Factor of type {enumeration_factor.EnumerationFactor} involving variables {frozenset([(0, 15)])} already exists." - ), + match="Cannot simultaneously add a Factor and a FactorGroup", ): - fg.add_factor(factor) + fg.add_factors(factor_group=factor_group, factor=factor) + fg.add_factors(factor=factor) -def test_single_factor(): - with pytest.raises(ValueError, match="Cannot create a FactorGroup with no Factor."): - logical.ORFactorGroup(variables_for_factors=[]) - - A = vgroup.NDVariableArray(num_states=2, shape=(10,)) - B = vgroup.NDVariableArray(num_states=2, shape=(10,)) - - variables0 = (A[0], B[0]) - variables1 = (A[1], B[1]) - ORFactor = logical.ORFactorGroup(variables_for_factors=[variables0]) + factor_group = enumeration.EnumerationFactorGroup( + variables_for_factors=[[vg[0]]], + factor_configs=np.arange(15)[:, None], + log_potentials=np.zeros(15), + ) with pytest.raises( - ValueError, match="SingleFactorGroup should only contain one factor. Got 2" + ValueError, + match=re.escape( + f"A Factor of type {enumeration_factor.EnumerationFactor} involving variables {frozenset([((vg.__hash__(), 0), 15)])} already exists." + ), ): - groups.SingleFactorGroup( - variables_for_factors=[variables0, variables1], - factor=ORFactor, - ) + fg.add_factors(factor_group) def test_bp_state(): @@ -67,10 +63,10 @@ def test_bp_state(): factor_configs=np.arange(15)[:, None], log_potentials=np.zeros(15), ) - fg0.add_factor(factor) + fg0.add_factors(factor=factor) fg1 = graph.FactorGraph(vg) - fg1.add_factor(factor) + fg1.add_factors(factor=factor) with pytest.raises( ValueError, @@ -90,7 +86,7 @@ def test_log_potentials(): variables_for_factors=[[vg[0]]], factor_configs=np.arange(10)[:, None], ) - fg.add_factor_group(factor_group) + fg.add_factors(factor_group) with pytest.raises( ValueError, @@ -130,7 +126,7 @@ def test_ftov_msgs(): variables_for_factors=[[vg[0]]], factor_configs=np.arange(10)[:, None], ) - fg.add_factor_group(factor_group) + fg.add_factors(factor_group) with pytest.raises( ValueError, @@ -141,7 +137,7 @@ def test_ftov_msgs(): with pytest.raises( ValueError, match=re.escape( - "Given belief shape (10,) does not match expected shape (15,) for variable (0, 15)." + f"Given belief shape (10,) does not match expected shape (15,) for variable (({vg.__hash__()}, 0), 15)." ), ): fg.bp_state.ftov_msgs[vg[0]] = np.ones(10) @@ -159,13 +155,13 @@ def test_ftov_msgs(): def test_evidence(): - vg = vgroup.VariableDict(variable_names=("a",), num_states=15) + vg = vgroup.VariableDict(variable_names=(0,), num_states=15) fg = graph.FactorGraph(vg) factor_group = enumeration.EnumerationFactorGroup( - variables_for_factors=[[vg["a"]]], + variables_for_factors=[[vg[0]]], factor_configs=np.arange(10)[:, None], ) - fg.add_factor_group(factor_group) + fg.add_factors(factor_group) with pytest.raises( ValueError, match=re.escape("Expected evidence shape (15,). Got (10,).") @@ -175,7 +171,7 @@ def test_evidence(): evidence = graph.Evidence(fg_state=fg.fg_state, value=np.zeros(15)) assert jnp.all(evidence.value == jnp.zeros(15)) - vg2 = vgroup.VariableDict(variable_names=("b",), num_states=15) + vg2 = vgroup.VariableDict(variable_names=(0,), num_states=15) with pytest.raises( ValueError, match=re.escape( @@ -184,7 +180,7 @@ def test_evidence(): ): graph.update_evidence( jax.device_put(evidence.value), - {vg2["b"]: jax.device_put(np.zeros(15))}, + {vg2[0]: jax.device_put(np.zeros(15))}, fg.fg_state, ) @@ -196,7 +192,7 @@ def test_bp(): variables_for_factors=[[vg[0]]], factor_configs=np.arange(10)[:, None], ) - fg.add_factor_group(factor_group) + fg.add_factors(factor_group) bp = graph.BP(fg.bp_state, temperature=0) bp_arrays = bp.update() diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index 24d24400..82a2ea06 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -5,7 +5,8 @@ import numpy as np import pytest -from pgmax.groups import enumeration +from pgmax.fg import groups +from pgmax.groups import enumeration, logical from pgmax.groups import variables as vgroup @@ -19,10 +20,10 @@ def test_variable_dict(): with pytest.raises( ValueError, match=re.escape( - "Variable (2, 15) expects a data array of shape (15,) or (1,). Got (10,)" + f"Variable (({variable_dict.__hash__()}, 2), 15) expects a data array of shape (15,) or (1,). Got (10,)" ), ): - variable_dict.flatten({(2, 15): np.zeros(10)}) + variable_dict.flatten({((variable_dict.__hash__(), 2), 15): np.zeros(10)}) with pytest.raises( ValueError, match="Can only unflatten 1D array. Got a 2D array." @@ -35,7 +36,10 @@ def test_variable_dict(): jax.tree_util.tree_multimap( lambda x, y: jnp.all(x == y), variable_dict.unflatten(jnp.zeros(3)), - {(name, 15): np.zeros(1) for name in range(3)}, + { + ((variable_dict.__hash__(), name), 15): np.zeros(1) + for name in range(3) + }, ) ) ) @@ -92,6 +96,25 @@ def test_nd_variable_array(): assert jnp.all(variable_group.unflatten(np.zeros(12)) == jnp.zeros((2, 2, 3))) +def test_single_factor(): + with pytest.raises(ValueError, match="Cannot create a FactorGroup with no Factor."): + logical.ORFactorGroup(variables_for_factors=[]) + + A = vgroup.NDVariableArray(num_states=2, shape=(10,)) + B = vgroup.NDVariableArray(num_states=2, shape=(10,)) + + variables0 = (A[0], B[0]) + variables1 = (A[1], B[1]) + ORFactor = logical.ORFactorGroup(variables_for_factors=[variables0]) + with pytest.raises( + ValueError, match="SingleFactorGroup should only contain one factor. Got 2" + ): + groups.SingleFactorGroup( + variables_for_factors=[variables0, variables1], + factor=ORFactor, + ) + + def test_enumeration_factor_group(): vg = vgroup.NDVariableArray(shape=(2, 2), num_states=3) vg_bis = vgroup.NDVariableArray(shape=(2, 2), num_states=3) diff --git a/tests/fg/test_wiring.py b/tests/fg/test_wiring.py index 3df31cb5..53efeee3 100644 --- a/tests/fg/test_wiring.py +++ b/tests/fg/test_wiring.py @@ -19,11 +19,11 @@ def test_wiring_with_PairwiseFactorGroup(): B = vgroup.NDVariableArray(num_states=2, shape=(10,)) # First test that compile_wiring enforces the correct factor_edges_num_states shape - fg = graph.FactorGraph(variables=[A, B]) + fg = graph.FactorGraph(variable_groups=[A, B]) factor_group = enumeration.PairwiseFactorGroup( variables_for_factors=[[A[idx], B[idx]] for idx in range(10)] ) - fg.add_factor_group(factor_group) + fg.add_factors(factor_group) factor_group = fg.factor_groups[0] object.__setattr__( @@ -36,31 +36,31 @@ def test_wiring_with_PairwiseFactorGroup(): factor_group.compile_wiring(fg._vars_to_starts) # FactorGraph with a single PairwiseFactorGroup - fg1 = graph.FactorGraph(variables=[A, B]) + fg1 = graph.FactorGraph(variable_groups=[A, B]) factor_group = enumeration.PairwiseFactorGroup( variables_for_factors=[[A[idx], B[idx]] for idx in range(10)] ) - fg1.add_factor_group(factor_group) + fg1.add_factors(factor_group) assert len(fg1.factor_groups) == 1 # FactorGraph with multiple PairwiseFactorGroup - fg2 = graph.FactorGraph(variables=[A, B]) + fg2 = graph.FactorGraph(variable_groups=[A, B]) for idx in range(10): factor_group = enumeration.PairwiseFactorGroup( variables_for_factors=[[A[idx], B[idx]]] ) - fg2.add_factor_group(factor_group) + fg2.add_factors(factor_group) assert len(fg2.factor_groups) == 10 # FactorGraph with multiple SingleFactorGroup - fg3 = graph.FactorGraph(variables=[A, B]) + fg3 = graph.FactorGraph(variable_groups=[A, B]) for idx in range(10): factor = enumeration_factor.EnumerationFactor( variables=[A[idx], B[idx]], factor_configs=np.array([[0, 0], [0, 1], [1, 0], [1, 1]]), log_potentials=np.zeros((4,)), ) - fg3.add_factor(factor) + fg3.add_factors(factor=factor) assert len(fg3.factor_groups) == 10 assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors) @@ -95,29 +95,29 @@ def test_wiring_with_ORFactorGroup(): C = vgroup.NDVariableArray(num_states=2, shape=(10,)) # FactorGraph with a single ORFactorGroup - fg1 = graph.FactorGraph(variables=[A, B, C]) + fg1 = graph.FactorGraph(variable_groups=[A, B, C]) factor_group = logical.ORFactorGroup( variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)], ) - fg1.add_factor_group(factor_group) + fg1.add_factors(factor_group) assert len(fg1.factor_groups) == 1 # FactorGraph with multiple ORFactorGroup - fg2 = graph.FactorGraph(variables=[A, B, C]) + fg2 = graph.FactorGraph(variable_groups=[A, B, C]) for idx in range(10): factor_group = logical.ORFactorGroup( variables_for_factors=[[A[idx], B[idx], C[idx]]], ) - fg2.add_factor_group(factor_group) + fg2.add_factors(factor_group) assert len(fg2.factor_groups) == 10 # FactorGraph with multiple SingleFactorGroup - fg3 = graph.FactorGraph(variables=[A, B, C]) + fg3 = graph.FactorGraph(variable_groups=[A, B, C]) for idx in range(10): factor = logical_factor.ORFactor( variables=[A[idx], B[idx], C[idx]], ) - fg3.add_factor(factor) + fg3.add_factors(factor=factor) assert len(fg3.factor_groups) == 10 assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors) @@ -150,29 +150,29 @@ def test_wiring_with_ANDFactorGroup(): C = vgroup.NDVariableArray(num_states=2, shape=(10,)) # FactorGraph with a single ANDFactorGroup - fg1 = graph.FactorGraph(variables=[A, B, C]) + fg1 = graph.FactorGraph(variable_groups=[A, B, C]) factor_group = logical.ANDFactorGroup( variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)], ) - fg1.add_factor_group(factor_group) + fg1.add_factors(factor_group) assert len(fg1.factor_groups) == 1 # FactorGraph with multiple ANDFactorGroup - fg2 = graph.FactorGraph(variables=[A, B, C]) + fg2 = graph.FactorGraph(variable_groups=[A, B, C]) for idx in range(10): factor_group = logical.ANDFactorGroup( variables_for_factors=[[A[idx], B[idx], C[idx]]], ) - fg2.add_factor_group(factor_group) + fg2.add_factors(factor_group) assert len(fg2.factor_groups) == 10 # FactorGraph with multiple SingleFactorGroup - fg3 = graph.FactorGraph(variables=[A, B, C]) + fg3 = graph.FactorGraph(variable_groups=[A, B, C]) for idx in range(10): factor = logical_factor.ANDFactor( variables=[A[idx], B[idx], C[idx]], ) - fg3.add_factor(factor) + fg3.add_factors(factor=factor) assert len(fg3.factor_groups) == 10 assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors) diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py index c6a030d6..643448a3 100644 --- a/tests/test_pgmax.py +++ b/tests/test_pgmax.py @@ -258,7 +258,7 @@ def create_valid_suppression_config_arr(suppression_diameter): pass # Create the factor graph - fg = graph.FactorGraph(variables=[grid_vars, additional_vars]) + fg = graph.FactorGraph(variable_groups=[grid_vars, additional_vars]) # Imperatively add EnumerationFactorGroups (each consisting of just one EnumerationFactor) to # the graph! @@ -302,7 +302,7 @@ def create_valid_suppression_config_arr(suppression_diameter): valid_configs_non_supp.shape[0], dtype=float ), ) - fg.add_factor(factor) + fg.add_factors(factor=factor) else: factor = EnumerationFactor( variables=curr_vars, @@ -311,7 +311,7 @@ def create_valid_suppression_config_arr(suppression_diameter): valid_configs_non_supp.shape[0], dtype=float ), ) - fg.add_factor(factor) + fg.add_factors(factor=factor) # Create an EnumerationFactorGroup for vertical suppression factors vert_suppression_vars: List[List[Tuple[Any, ...]]] = [] @@ -355,14 +355,14 @@ def create_valid_suppression_config_arr(suppression_diameter): variables_for_factors=vert_suppression_vars, factor_configs=valid_configs_supp, ) - fg.add_factor_group(factor_group) + fg.add_factors(factor_group) factor_group = enumeration.EnumerationFactorGroup( variables_for_factors=horz_suppression_vars, factor_configs=valid_configs_supp, log_potentials=np.zeros(valid_configs_supp.shape[0], dtype=float), ) - fg.add_factor_group(factor_group) + fg.add_factors(factor_group) # Run BP # Set the evidence @@ -419,7 +419,7 @@ def binary_connected_variables( variables_for_factors=binary_connected_variables(28, 28, k_row, k_col), log_potential_matrix=W_pot[:, :, k_row, k_col], ) - fg.add_factor_group(factor_group) + fg.add_factors(factor_group) # Assign evidence to pixel vars bp_state = fg.bp_state From 0f639fe6fcce3422c970e9aeea64a31a100b558c Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Tue, 26 Apr 2022 18:58:24 +0000 Subject: [PATCH 21/35] Stannis' comments --- examples/gmrf.py | 8 ----- pgmax/factors/enumeration.py | 4 +-- pgmax/factors/logical.py | 4 +-- pgmax/fg/graph.py | 66 ++++++++++-------------------------- pgmax/fg/groups.py | 9 ++--- pgmax/fg/nodes.py | 4 +-- pgmax/groups/logical.py | 7 ++-- pgmax/groups/variables.py | 22 ++++++------ tests/factors/test_or.py | 10 +++--- tests/fg/test_graph.py | 4 +-- tests/fg/test_groups.py | 20 +++++++++-- tests/test_pgmax.py | 8 ++--- 12 files changed, 71 insertions(+), 95 deletions(-) diff --git a/examples/gmrf.py b/examples/gmrf.py index eb86a5e9..1bf4fad9 100644 --- a/examples/gmrf.py +++ b/examples/gmrf.py @@ -233,11 +233,3 @@ def update(step, batch_noisy_images, batch_target_images, opt_state): ) pbar.update() pbar.set_postfix(loss=value) - -batch_indices = indices[idx * batch_size : (idx + 1) * batch_size] -batch_noisy_images, batch_target_images = ( - noisy_images_train[:10], - target_images_train[:10], -) -step = 0 -value, opt_state = update(step, batch_noisy_images, batch_target_images, opt_state) diff --git a/pgmax/factors/enumeration.py b/pgmax/factors/enumeration.py index d461d3ab..5f776ba0 100644 --- a/pgmax/factors/enumeration.py +++ b/pgmax/factors/enumeration.py @@ -2,7 +2,7 @@ import functools from dataclasses import dataclass -from typing import List, Mapping, Sequence, Tuple, Union +from typing import Any, List, Mapping, Sequence, Tuple, Union import jax import jax.numpy as jnp @@ -156,7 +156,7 @@ def concatenate_wirings(wirings: Sequence[EnumerationWiring]) -> EnumerationWiri def compile_wiring( variables_for_factors: Sequence[List], factor_configs: np.ndarray, - vars_to_starts: Mapping[Tuple[int, int], int], + vars_to_starts: Mapping[Tuple[Any, int], int], num_factors: int, ) -> EnumerationWiring: """Compile an EnumerationWiring for an EnumerationFactor or a FactorGroup with EnumerationFactors. diff --git a/pgmax/factors/logical.py b/pgmax/factors/logical.py index bf8652cb..b78ea206 100644 --- a/pgmax/factors/logical.py +++ b/pgmax/factors/logical.py @@ -2,7 +2,7 @@ import functools from dataclasses import dataclass, field -from typing import List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union import jax import jax.numpy as jnp @@ -141,7 +141,7 @@ def concatenate_wirings(wirings: Sequence[LogicalWiring]) -> LogicalWiring: @staticmethod def compile_wiring( variables_for_factors: Sequence[List], - vars_to_starts: Mapping[Tuple[int, int], int], + vars_to_starts: Mapping[Tuple[Any, int], int], edge_states_offset: int, ) -> LogicalWiring: """Compile a LogicalWiring for a LogicalFactor or a FactorGroup with LogicalFactors. diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 9fbe794c..2f16f2eb 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -6,7 +6,6 @@ import copy import functools import inspect -import typing from dataclasses import asdict, dataclass from types import MappingProxyType from typing import ( @@ -44,12 +43,7 @@ class FactorGraph: Factors in a graph are clustered in factor groups, which are grouped according to their factor types. Args: - variables: A single VariableGroup or a container containing variable groups. - If not a single VariableGroup, supported containers include mapping and sequence. - For a mapping, the keys of the mapping are used to index the variable groups. - For a sequence, the indices of the sequence are used to index the variable groups. - Note that if not a single VariableGroup, a CompositeVariableGroup will be created from - this input, and the individual VariableGroups will need to be accessed by indexing. + variable_groups: A single VariableGroup or a list of VariableGroups. """ variable_groups: Union[groups.VariableGroup, Sequence[groups.VariableGroup]] @@ -103,7 +97,7 @@ def add_factors( Args: factor_group: The FactorGroup to be added to the FactorGraph. - factor: The Factor to be added to the factor graph. + factor: The Factor to be added to the FactorGraph. Raises: ValueError: If @@ -318,13 +312,13 @@ class FactorGraphState: """FactorGraphState. Args: - variable_groups: All the variable groups in the FactorGraph. + variable_groups: VariableGroups in the FactorGraph. vars_to_starts: Maps variables to their starting indices in the flat evidence array. flat_evidence[vars_to_starts[variable]: vars_to_starts[variable] + variable.num_var_states] contains evidence to the variable. num_var_states: Total number of variable states. total_factor_num_states: Size of the flat ftov messages array. - factor_groups: Factor groups in the FactorGraph + factor_groups: FactorGroups in the FactorGraph factor_type_to_msgs_range: Maps factors types to their start and end indices in the flat ftov messages. factor_type_to_potentials_range: Maps factor types to their start and end indices in the flat log potentials. factor_group_to_potentials_starts: Maps factor groups to their starting indices in the flat log potentials. @@ -333,7 +327,7 @@ class FactorGraphState: """ variable_groups: Sequence[groups.VariableGroup] - vars_to_starts: Mapping[Tuple[int, int], int] + vars_to_starts: Mapping[Tuple[Any, int], int] num_var_states: int total_factor_num_states: int factor_groups: Tuple[groups.FactorGroup, ...] @@ -474,12 +468,11 @@ def __setitem__( factor_group: Any, data: Union[np.ndarray, jnp.ndarray], ): - """Set the log potentials for a named factor group or a factor. + """Set the log potentials for a FactorGroup Args: factor_group: FactorGroup - data: Array containing the log potentials for the named factor group - or the factor. + data: Array containing the log potentials for the FactorGroup """ object.__setattr__( self, @@ -510,7 +503,7 @@ def update_ftov_msgs( Raises: ValueError if: (1) provided ftov_msgs shape does not match the expected ftov_msgs shape. - (2) provided name is not valid for ftov_msgs updates. + (2) provided variable is not in the FactorGraph. """ for variable, data in updates.items(): if variable in fg_state.vars_to_starts: @@ -535,14 +528,7 @@ def update_ftov_msgs( data / starts.shape[0] ) else: - raise ValueError( - "Invalid names for setting messages. " - "Supported names include a tuple of length 2 with factor " - "and variable names for directly setting factor to variable " - "messages, or a valid variable name for spreading expected " - "beliefs at a variable" - ) - + raise ValueError("Provided variable is not in the FactorGraph") return ftov_msgs @@ -574,38 +560,19 @@ def __post_init__(self): object.__setattr__(self, "value", self.value) - @typing.overload - def __setitem__( - self, - names: Tuple[Any, Any], - data: Union[np.ndarray, jnp.ndarray], - ) -> None: - """Setting messages from a factor to a variable - - Args: - names: A tuple of length 2 - names[0] is the name of the factor - names[1] is the name of the variable - data: An array containing messages from factor names[0] - to variable names[1] - """ - - @typing.overload def __setitem__( self, - variable: Tuple[int, int], + variable: Tuple[Any, int], data: Union[np.ndarray, jnp.ndarray], ) -> None: - """Spreading beliefs at a variable to all connected factors + """Spreading beliefs at a variable to all connected Factors Args: variable: A tuple representing a variable data: An array containing the beliefs to be spread uniformly - across all factor to variable messages involving this - variable. + across all factors to variable messages involving this variable. """ - def __setitem__(self, variable, data) -> None: object.__setattr__( self, "value", @@ -647,7 +614,7 @@ def update_evidence( evidence = evidence.at[start_index : start_index + name[1]].set(data) else: raise ValueError( - "Got evidence for a variable or a variable group not in the FactorGraph!" + "Got evidence for a variable or a VariableGroup not in the FactorGraph!" ) return evidence @@ -678,7 +645,7 @@ def __post_init__(self): object.__setattr__(self, "value", self.value) - def __getitem__(self, variable: Tuple[int, int]) -> np.ndarray: + def __getitem__(self, variable: Tuple[Any, int]) -> np.ndarray: """Function to query evidence for a variable Args: @@ -969,8 +936,9 @@ def unflatten_beliefs(flat_beliefs, variable_groups) -> Dict[Hashable, Any]: beliefs = {} start = 0 for variable_group in variable_groups: - variables = variable_group.variables - length = sum([variable[1] for variable in variables]) + num_states = variable_group.num_states + assert isinstance(num_states, np.ndarray) + length = num_states.sum() beliefs[variable_group] = variable_group.unflatten( flat_beliefs[start : start + length] diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 62692ffd..15f50603 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -5,6 +5,7 @@ from functools import total_ordering from typing import ( Any, + Collection, FrozenSet, List, Mapping, @@ -21,7 +22,7 @@ import pgmax.fg.nodes as nodes from pgmax.utils import cached_property -MAX_SIZE = 1e16 +MAX_SIZE = 1e10 @total_ordering @@ -40,7 +41,7 @@ def __eq__(self, other): def __lt__(self, other): return hash(self) < hash(other) - def __getitem__(self, val): + def __getitem__(self, val: Any) -> Union[Tuple[Any, int], List[Tuple[Any, int]]]: """Given a variable name, index, or a group of variable indices, retrieve the associated variable(s). Each variable is returned via a tuple of the form (variable hash/name, number of states) @@ -129,7 +130,7 @@ def __eq__(self, other): def __lt__(self, other): return hash(self) < hash(other) - def __getitem__(self, variables: Sequence[Tuple[int, int]]) -> Any: + def __getitem__(self, variables: Union[Sequence, Collection]) -> Any: """Function to query individual factors in the factor group Args: @@ -228,7 +229,7 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any: "Please subclass the FactorGroup class and override this method" ) - def compile_wiring(self, vars_to_starts: Mapping[Tuple[int, int], int]) -> Any: + def compile_wiring(self, vars_to_starts: Mapping[Tuple[Any, int], int]) -> Any: """Compile an efficient wiring for the FactorGroup. Args: diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py index 06d7c5d6..f987b8df 100644 --- a/pgmax/fg/nodes.py +++ b/pgmax/fg/nodes.py @@ -1,7 +1,7 @@ """A module containing classes that specify the basic components of a Factor Graph.""" from dataclasses import asdict, dataclass -from typing import List, Sequence, Tuple, Union +from typing import Any, List, Sequence, Tuple, Union import jax import jax.numpy as jnp @@ -48,7 +48,7 @@ class Factor: NotImplementedError: If compile_wiring is not implemented """ - variables: List[Tuple[int, int]] + variables: List[Tuple[Any, int]] log_potentials: np.ndarray def __post_init__(self): diff --git a/pgmax/groups/logical.py b/pgmax/groups/logical.py index 83115945..618a0526 100644 --- a/pgmax/groups/logical.py +++ b/pgmax/groups/logical.py @@ -4,6 +4,8 @@ from dataclasses import dataclass, field from typing import FrozenSet, OrderedDict, Type +import numpy as np + from pgmax.factors import logical from pgmax.fg import groups @@ -20,12 +22,9 @@ class LogicalFactorGroup(groups.FactorGroup): For ORFactors the edge_states_offset is 1, for ANDFactors the edge_states_offset is -1. """ + factor_configs: np.ndarray = field(init=False, default=None) edge_states_offset: int = field(init=False) - def __post_init__(self): - super().__post_init__() - object.__setattr__(self, "factor_configs", None) - def _get_variables_to_factors( self, ) -> OrderedDict[FrozenSet, logical.LogicalFactor]: diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index 7b50167e..57570a1a 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -26,9 +26,9 @@ class NDVariableArray(groups.VariableGroup): num_states: Union[int, np.ndarray] def __post_init__(self): - if np.prod(self.shape) > groups.MAX_SIZE: + if np.prod(self.shape) > int(groups.MAX_SIZE): raise ValueError( - f"Currently only support NDVariableArray of size smaller than {groups.MAX_SIZE}. Got {np.prod(self.shape)}" + f"Currently only support NDVariableArray of size smaller than {int(groups.MAX_SIZE)}. Got {np.prod(self.shape)}" ) if np.isscalar(self.num_states): @@ -42,7 +42,9 @@ def __post_init__(self): f"Expected num_states shape {self.shape}. Got {self.num_states.shape}." ) else: - raise ValueError("num_states entries should be of type np.int") + raise ValueError( + "num_states should be an integer or a NumPy array of dtype int" + ) def __getitem__( self, val: Union[int, slice, Tuple] @@ -94,7 +96,7 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: Args: data: Meaningful structured data. Should be an array of shape self.shape (for e.g. MAP decodings) - or self.shape + (self.num_states,) (for e.g. evidence, beliefs). + or self.shape + (self.num_states.max(),) (for e.g. evidence, beliefs). Returns: A flat jnp.array for internal use @@ -158,7 +160,8 @@ class VariableDict(groups.VariableGroup): Args: num_states: The size of the variables in this variable group - variable_names: A tuple of all names of the variables in this variable group + variable_names: A tuple of all names of the variables in this variable group. + Note that we overwrite variable_names to add the hash of the VariableDict """ variable_names: Tuple[Any, ...] @@ -176,7 +179,7 @@ def __post_init__(self): object.__setattr__(self, "variable_names", hash_and_names) @cached_property - def variables(self) -> List[Tuple]: + def variables(self) -> List[Tuple[Tuple[Any, int], int]]: """Function that returns the list of all variables in the VariableGroup. Each variable is represented by a tuple of the form (variable name, number of states) @@ -188,24 +191,23 @@ def variables(self) -> List[Tuple]: vars_num_states = self.num_states.flatten() return list(zip(vars_names, vars_num_states)) - def __getitem__(self, val: Any) -> Tuple[Any, int]: + def __getitem__(self, val: Any) -> Tuple[Tuple[Any, int], int]: """Given a variable name retrieve the associated variable, returned via a tuple of the form (variable name, number of states) Args: - val: a variable index or slice + val: a variable name Returns: The queried variable """ assert isinstance(self.num_states, np.ndarray) - if (self.__hash__(), val) not in self.variable_names: raise ValueError(f"Variable {val} is not in VariableDict") return ((self.__hash__(), val), self.num_states[0]) def flatten( - self, data: Mapping[Tuple[int, int], Union[np.ndarray, jnp.ndarray]] + self, data: Mapping[Tuple[Tuple[int, int], int], Union[np.ndarray, jnp.ndarray]] ) -> jnp.ndarray: """Function that turns meaningful structured data into a flat data array for internal use. diff --git a/tests/factors/test_or.py b/tests/factors/test_or.py index 16236b35..f29dccd0 100644 --- a/tests/factors/test_or.py +++ b/tests/factors/test_or.py @@ -102,7 +102,7 @@ def test_run_bp_with_ORFactors(): factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) - fg1.add_factor(factor=enum_factor) + fg1.add_factors(factor=enum_factor) else: if idx != 0: # Add the second half of factors to FactorGraph2 @@ -111,7 +111,7 @@ def test_run_bp_with_ORFactors(): factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) - fg2.add_factor(factor=enum_factor) + fg2.add_factors(factor=enum_factor) else: # Add all the EnumerationFactors to FactorGraph1 for the first iter enum_factor = EnumerationFactor( @@ -119,7 +119,7 @@ def test_run_bp_with_ORFactors(): factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) - fg1.add_factor(factor=enum_factor) + fg1.add_factors(factor=enum_factor) # Option 2: Define the ORFactors num_parents_cumsum = np.insert(np.cumsum(num_parents), 0, 0) @@ -143,10 +143,10 @@ def test_run_bp_with_ORFactors(): ) if idx != 0: factor_group = logical.ORFactorGroup(variables_for_ORFactors_fg1) - fg1.add_factor_group(factor_group) + fg1.add_factors(factor_group) factor_group = logical.ORFactorGroup(variables_for_ORFactors_fg2) - fg2.add_factor_group(factor_group) + fg2.add_factors(factor_group) # Run inference bp1 = graph.BP(fg1.bp_state, temperature=temperature) diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py index bb26ce99..df734ac2 100644 --- a/tests/fg/test_graph.py +++ b/tests/fg/test_graph.py @@ -130,7 +130,7 @@ def test_ftov_msgs(): with pytest.raises( ValueError, - match=re.escape("Invalid names for setting messages"), + match=re.escape("Provided variable is not in the FactorGraph"), ): fg.bp_state.ftov_msgs[0] = np.ones(10) @@ -175,7 +175,7 @@ def test_evidence(): with pytest.raises( ValueError, match=re.escape( - "Got evidence for a variable or a variable group not in the FactorGraph!" + "Got evidence for a variable or a VariableGroup not in the FactorGraph!" ), ): graph.update_evidence( diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index 82a2ea06..322147c4 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -54,6 +54,15 @@ def test_variable_dict(): def test_nd_variable_array(): + max_size = int(groups.MAX_SIZE) + with pytest.raises( + ValueError, + match=re.escape( + f"Currently only support NDVariableArray of size smaller than {max_size}. Got {max_size + 1}" + ), + ): + vgroup.NDVariableArray(shape=(max_size + 1,), num_states=2) + num_states = np.full((2, 3), fill_value=2) with pytest.raises( ValueError, match=re.escape("Expected num_states shape (2, 2). Got (2, 3).") @@ -62,14 +71,19 @@ def test_nd_variable_array(): num_states = np.full((2, 3), fill_value=2, dtype=np.float32) with pytest.raises( - ValueError, match=re.escape("num_states entries should be of type np.int") + ValueError, + match=re.escape( + "num_states should be an integer or a NumPy array of dtype int" + ), ): vgroup.NDVariableArray(shape=(2, 2), num_states=num_states) - variable_group = vgroup.NDVariableArray(shape=(5, 5), num_states=2) - assert len(variable_group[:3, :3]) == 9 + variable_group0 = vgroup.NDVariableArray(shape=(5, 5), num_states=2) + assert len(variable_group0[:3, :3]) == 9 variable_group = vgroup.NDVariableArray(shape=(2, 2), num_states=3) + variable_group0 < variable_group + with pytest.raises( ValueError, match=re.escape("data should be of shape (2, 2) or (2, 2, 3). Got (3, 3)."), diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py index 643448a3..9e6deb65 100644 --- a/tests/test_pgmax.py +++ b/tests/test_pgmax.py @@ -218,10 +218,10 @@ def create_valid_suppression_config_arr(suppression_diameter): (grid_vars, (1, 0, 1)): 0, (grid_vars, (1, 1, 0)): 1, (grid_vars, (1, 1, 1)): 0, - (additional_vars, ((0, 0, 2), 3)): 0, - (additional_vars, ((0, 1, 2), 3)): 2, - (additional_vars, ((1, 2, 0), 3)): 1, - (additional_vars, ((1, 2, 1), 3)): 0, + (additional_vars, ((additional_vars.__hash__(), (0, 0, 2)), 3)): 0, + (additional_vars, ((additional_vars.__hash__(), (0, 1, 2)), 3)): 2, + (additional_vars, ((additional_vars.__hash__(), (1, 2, 0)), 3)): 1, + (additional_vars, ((additional_vars.__hash__(), (1, 2, 1)), 3)): 0, } gt_has_cuts = gt_has_cuts.astype(np.int32) From 546b7906e694a29677b065b49781f6251ac049ff Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Tue, 26 Apr 2022 19:39:36 +0000 Subject: [PATCH 22/35] Flattent / unflattent --- pgmax/fg/graph.py | 2 +- pgmax/groups/variables.py | 22 ++++++++++++---------- tests/fg/test_graph.py | 4 ++-- tests/fg/test_groups.py | 23 +++++++++++++---------- 4 files changed, 28 insertions(+), 23 deletions(-) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 2f16f2eb..67ca4477 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -460,7 +460,7 @@ def __getitem__(self, factor_group: groups.FactorGroup) -> np.ndarray: start : start + factor_group.factor_group_log_potentials.shape[0] ] else: - raise ValueError("Invalid FactorGroup for log potentials updates.") + raise ValueError("Invalid FactorGroup queried to access log potentials.") return log_potentials def __setitem__( diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index 57570a1a..ad03e966 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -106,19 +106,19 @@ def flatten(self, data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: """ assert isinstance(self.num_states, np.ndarray) - # TODO: what should we do for different number of states -> look at mask_array - if data.shape != self.shape and data.shape != self.shape + ( - self.num_states.max(), - ): + if data.shape == self.shape: + return jax.device_put(data).flatten() + elif data.shape == self.shape + (self.num_states.max(),): + return jax.device_put( + data[np.arange(data.shape[-1]) < self.num_states[..., None]] + ) + else: raise ValueError( f"data should be of shape {self.shape} or {self.shape + (self.num_states.max(),)}. " f"Got {data.shape}." ) - return jax.device_put(data).flatten() - def unflatten( - self, flat_data: Union[np.ndarray, jnp.ndarray] - ) -> Union[np.ndarray, jnp.ndarray]: + def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: """Function that recovers meaningful structured data from internal flat data array Args: @@ -143,8 +143,10 @@ def unflatten( if flat_data.size == np.product(self.shape): data = flat_data.reshape(self.shape) elif flat_data.size == self.num_states.sum(): - # TODO: what should we do for different number of states - data = flat_data.reshape(self.shape + (self.num_states.max(),)) + data = jnp.zeros(self.shape + (self.num_states.max(),)) + data = data.at[np.arange(data.shape[-1]) < self.num_states[..., None]].set( + flat_data + ) else: raise ValueError( f"flat_data should be compatible with shape {self.shape} or {self.shape + (self.num_states.max(),)}. " diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py index df734ac2..0fdadbb2 100644 --- a/tests/fg/test_graph.py +++ b/tests/fg/test_graph.py @@ -106,9 +106,9 @@ def test_log_potentials(): with pytest.raises( ValueError, - match=re.escape("Invalid FactorGroup for log potentials updates."), + match=re.escape("Invalid FactorGroup queried to access log potentials."), ): - fg.bp_state.log_potentials[vg[0]] = np.zeros(10) + fg.bp_state.log_potentials[vg[0]] with pytest.raises( ValueError, match=re.escape("Expected log potentials shape (10,). Got (15,)") diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index 322147c4..8b5d7a44 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -81,18 +81,22 @@ def test_nd_variable_array(): variable_group0 = vgroup.NDVariableArray(shape=(5, 5), num_states=2) assert len(variable_group0[:3, :3]) == 9 - variable_group = vgroup.NDVariableArray(shape=(2, 2), num_states=3) + variable_group = vgroup.NDVariableArray( + shape=(2, 2), num_states=np.array([[1, 2], [3, 4]]) + ) variable_group0 < variable_group with pytest.raises( ValueError, - match=re.escape("data should be of shape (2, 2) or (2, 2, 3). Got (3, 3)."), + match=re.escape("data should be of shape (2, 2) or (2, 2, 4). Got (3, 3)."), ): variable_group.flatten(np.zeros((3, 3))) assert jnp.all( variable_group.flatten(np.array([[1, 2], [3, 4]])) == jnp.array([1, 2, 3, 4]) ) + assert jnp.all(variable_group.flatten(np.zeros((2, 2, 4))) == jnp.zeros((10,))) + with pytest.raises( ValueError, match="Can only unflatten 1D array. Got a 2D array." ): @@ -101,13 +105,13 @@ def test_nd_variable_array(): with pytest.raises( ValueError, match=re.escape( - "flat_data should be compatible with shape (2, 2) or (2, 2, 3). Got (10,)." + "flat_data should be compatible with shape (2, 2) or (2, 2, 4). Got (12,)." ), ): - variable_group.unflatten(np.zeros((10,))) + variable_group.unflatten(np.zeros((12,))) assert jnp.all(variable_group.unflatten(np.zeros(4)) == jnp.zeros((2, 2))) - assert jnp.all(variable_group.unflatten(np.zeros(12)) == jnp.zeros((2, 2, 3))) + assert jnp.all(variable_group.unflatten(np.zeros(10)) == jnp.zeros((2, 2, 4))) def test_single_factor(): @@ -119,21 +123,20 @@ def test_single_factor(): variables0 = (A[0], B[0]) variables1 = (A[1], B[1]) - ORFactor = logical.ORFactorGroup(variables_for_factors=[variables0]) + ORFactor0 = logical.ORFactorGroup(variables_for_factors=[variables0]) with pytest.raises( ValueError, match="SingleFactorGroup should only contain one factor. Got 2" ): groups.SingleFactorGroup( variables_for_factors=[variables0, variables1], - factor=ORFactor, + factor=ORFactor0, ) + ORFactor1 = logical.ORFactorGroup(variables_for_factors=[variables1]) + ORFactor0 < ORFactor1 def test_enumeration_factor_group(): vg = vgroup.NDVariableArray(shape=(2, 2), num_states=3) - vg_bis = vgroup.NDVariableArray(shape=(2, 2), num_states=3) - vg < vg_bis - with pytest.raises( ValueError, match=re.escape("Expected log potentials shape: (1,) or (2, 1). Got (3, 2)"), From 35ce6c053e8a6940aefc8ad545accfec138cdcd3 Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Tue, 26 Apr 2022 19:48:05 +0000 Subject: [PATCH 23/35] Unflatten with nan --- pgmax/groups/variables.py | 4 +++- tests/fg/test_groups.py | 8 +++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index ad03e966..887a32a7 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -143,7 +143,9 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: if flat_data.size == np.product(self.shape): data = flat_data.reshape(self.shape) elif flat_data.size == self.num_states.sum(): - data = jnp.zeros(self.shape + (self.num_states.max(),)) + data = jnp.full( + shape=self.shape + (self.num_states.max(),), fill_value=jnp.nan + ) data = data.at[np.arange(data.shape[-1]) < self.num_states[..., None]].set( flat_data ) diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index 8b5d7a44..b65a8ce5 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -111,7 +111,13 @@ def test_nd_variable_array(): variable_group.unflatten(np.zeros((12,))) assert jnp.all(variable_group.unflatten(np.zeros(4)) == jnp.zeros((2, 2))) - assert jnp.all(variable_group.unflatten(np.zeros(10)) == jnp.zeros((2, 2, 4))) + unflattened = jnp.full((2, 2, 4), fill_value=jnp.nan) + unflattened = unflattened.at[0, 0, 0].set(0) + unflattened = unflattened.at[0, 1, :1].set(0) + unflattened = unflattened.at[1, 0, :2].set(0) + unflattened = unflattened.at[1, 1].set(0) + mask = ~jnp.isnan(unflattened) + assert jnp.all(variable_group.unflatten(np.zeros(10))[mask] == unflattened[mask]) def test_single_factor(): From 704ee3a6c978f0cbbb8a8195a9db8d95f228db2a Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Wed, 27 Apr 2022 01:30:37 +0000 Subject: [PATCH 24/35] Speeding up --- examples/pmp_binary_deconvolution.py | 1 + examples/rbm.py | 5 +-- pgmax/factors/enumeration.py | 9 +++-- pgmax/factors/logical.py | 26 +++++++++----- pgmax/fg/graph.py | 7 ++-- pgmax/fg/groups.py | 52 +++++++++++++++++++--------- pgmax/groups/enumeration.py | 22 ++++++++---- pgmax/groups/variables.py | 27 +++++++++++---- 8 files changed, 102 insertions(+), 47 deletions(-) diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py index a0f24406..fd396f3a 100644 --- a/examples/pmp_binary_deconvolution.py +++ b/examples/pmp_binary_deconvolution.py @@ -211,6 +211,7 @@ def plot_images(images, display=True, nr=None): # in the same manner does not change X, so this naturally results in multiple equivalent modes. # %% +# %load_ext snakeviz start = time.time() bp = graph.BP(fg.bp_state, temperature=0.0) print("Time", time.time() - start) diff --git a/examples/rbm.py b/examples/rbm.py index cdc51390..cde32960 100644 --- a/examples/rbm.py +++ b/examples/rbm.py @@ -90,9 +90,10 @@ ], log_potential_matrix=log_potential_matrix, ) -fg.add_factors(pairwise_factors) +print("Time", time.time() - start) +# #%snakeviz pairwise_factors = enumeration.PairwiseFactorGroup(variables_for_factors=[ [hidden_variables[ii], visible_variables[jj]] for ii in range(bh.shape[0]) for jj in range(bv.shape[0])],log_potential_matrix=log_potential_matrix,) -# # %snakeviz fg.add_factors(factory=enumeration.PairwiseFactorGroup, variables_for_factors=v, log_potential_matrix=log_potential_matrix,) +fg.add_factors(pairwise_factors) print("Time", time.time() - start) diff --git a/pgmax/factors/enumeration.py b/pgmax/factors/enumeration.py index 5f776ba0..ef2a99a7 100644 --- a/pgmax/factors/enumeration.py +++ b/pgmax/factors/enumeration.py @@ -154,6 +154,7 @@ def concatenate_wirings(wirings: Sequence[EnumerationWiring]) -> EnumerationWiri @staticmethod def compile_wiring( + factor_edges_num_states: np.ndarray, variables_for_factors: Sequence[List], factor_configs: np.ndarray, vars_to_starts: Mapping[Tuple[Any, int], int], @@ -163,10 +164,14 @@ def compile_wiring( Internally calls _compile_var_states_numba and _compile_enumeration_wiring_numba for speed. Args: + factor_edges_num_states: An array concatenating the number of states for the variables connected to each + Factor of the FactorGroup. Each variable will appear once for each Factor it connects to. variables_for_factors: A list of list of variables. Each list within the outer list contains the variables connected to a Factor. The same variable can be connected to multiple Factors. factor_configs: Array of shape (num_val_configs, num_variables) containing an explicit enumeration of all valid configurations. + factor_edges_num_states: Array concatenating the number of states for the variables connected to each Factor of + the FactorGroup. Each variable will appear once for each Factor it connects to. vars_to_starts: A dictionary that maps variables to their global starting indices For an n-state variable, a global start index of m means the global indices of its n variable states are m, m + 1, ..., m + n - 1 @@ -178,14 +183,12 @@ def compile_wiring( Returns: The EnumerationWiring """ + # TODO: Don't use vars_to_starts var_states = [] - factor_edges_num_states = [] for variables_for_factor in variables_for_factors: for variable in variables_for_factor: var_states.append(vars_to_starts[variable]) - factor_edges_num_states.append(variable[1]) var_states = np.array(var_states) - factor_edges_num_states = np.array(factor_edges_num_states) num_states_cumsum = np.insert(np.cumsum(factor_edges_num_states), 0, 0) var_states_for_edges = np.empty(shape=(num_states_cumsum[-1],), dtype=int) diff --git a/pgmax/factors/logical.py b/pgmax/factors/logical.py index b78ea206..c8f93f6b 100644 --- a/pgmax/factors/logical.py +++ b/pgmax/factors/logical.py @@ -140,7 +140,9 @@ def concatenate_wirings(wirings: Sequence[LogicalWiring]) -> LogicalWiring: @staticmethod def compile_wiring( + factor_edges_num_states: np.ndarray, variables_for_factors: Sequence[List], + factor_sizes: np.ndarray, vars_to_starts: Mapping[Tuple[Any, int], int], edge_states_offset: int, ) -> LogicalWiring: @@ -148,8 +150,11 @@ def compile_wiring( Internally calls _compile_var_states_numba and _compile_logical_wiring_numba for speed. Args: - variables_for_factors: A list of list of variables. Each list within the outer list contains the - variables connected to a Factor. The same variable can be connected to multiple Factors. + factor_edges_num_states: An array concatenating the number of states for the variables connected to each + Factor of the FactorGroup. Each variable will appear once for each Factor it connects to. + variables_for_factors: A tuple of tuples containing variables connected to each Factor of the FactorGroup. + Each variable will appear once for each Factor it connects to. + factor_sizes: An array containing the different factor sizes. vars_to_starts: A dictionary that maps variables to their global starting indices For an n-state variable, a global start index of m means the global indices of its n variable states are m, m + 1, ..., m + n - 1 @@ -159,17 +164,12 @@ def compile_wiring( Returns: The LogicalWiring """ - factor_sizes = [] + # TODO: Don't use vars_to_starts var_states = [] - factor_edges_num_states = [] for variables_for_factor in variables_for_factors: - factor_sizes.append(len(variables_for_factor)) for variable in variables_for_factor: var_states.append(vars_to_starts[variable]) - factor_edges_num_states.append(variable[1]) - factor_sizes = np.array(factor_sizes) var_states = np.array(var_states) - factor_edges_num_states = np.array(factor_edges_num_states) # Relevant state differs for ANDFactors and ORFactors relevant_state = (-edge_states_offset + 1) // 2 @@ -241,6 +241,16 @@ class ANDFactor(LogicalFactor): edge_states_offset: int = field(init=False, default=-1) +@nb.jit(parallel=False, cache=True, fastmath=True, nopython=True) +def _compile_utils_numba(var_states, factor_edges_num_states, variables_for_factors): + idx = 0 + for variables_for_factor in variables_for_factors: + for variable in variables_for_factor: + var_states[idx] = variable + factor_edges_num_states[idx] = variable[1] + idx += 1 + + @nb.jit(parallel=False, cache=True, fastmath=True, nopython=True) def _compile_logical_wiring_numba( parents_edge_states: np.ndarray, diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 67ca4477..a13fc426 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -68,9 +68,8 @@ def __post_init__(self): ) # See FactorGraphState docstrings for documentation on the following fields - self._vars_to_starts: OrderedDict[ - Tuple[int, int], int - ] = collections.OrderedDict() + self._vars_to_starts: Dict[Tuple[int, int], int] = {} + vars_num_states_cumsum = 0 for variable_group in self.variable_groups: vg_num_states = variable_group.num_states.flatten() @@ -159,7 +158,7 @@ def compute_offsets(self) -> None: factor_group ] = factor_num_configs_cumsum - factor_num_states_cumsum += factor_group.total_num_states + factor_num_states_cumsum += factor_group.factor_edges_num_states.sum() factor_num_configs_cumsum += ( factor_group.factor_group_log_potentials.shape[0] ) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 15f50603..1ffc771d 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -17,6 +17,7 @@ ) import jax.numpy as jnp +import numba as nb import numpy as np import pgmax.fg.nodes as nodes @@ -30,10 +31,13 @@ class VariableGroup: """Class to represent a group of variables. Each variable is represented via a tuple of the form (variable hash/name, number of states) + + Attributes: + this_hash: Hash of the VariableGroup """ def __hash__(self): - return id(self) * int(MAX_SIZE) + return self.this_hash def __eq__(self, other): return hash(self) == hash(other) @@ -107,6 +111,9 @@ class FactorGroup: Attributes: factor_type: Factor type shared by all the Factors in the FactorGroup. + factor_sizes: Array of the different factor sizes. + factor_edges_num_states: Array concatenating the number of states for the variables connected to each Factor of + the FactorGroup. Each variable will appear once for each Factor it connects to. Raises: ValueError: if the FactorGroup does not contain a Factor @@ -116,11 +123,27 @@ class FactorGroup: factor_configs: np.ndarray = field(init=False) log_potentials: np.ndarray = field(init=False, default=np.empty((0,))) factor_type: Type = field(init=False) + factor_sizes: np.ndarray = field(init=False) + factor_edges_num_states: np.ndarray = field(init=False) def __post_init__(self): if len(self.variables_for_factors) == 0: raise ValueError("Cannot create a FactorGroup with no Factor.") + factor_sizes = np.array( + [ + len(variables_for_factor) + for variables_for_factor in self.variables_for_factors + ] + ) + object.__setattr__(self, "factor_sizes", factor_sizes) + + factor_edges_num_states = np.empty(shape=(self.factor_sizes.sum(),), dtype=int) + _compile_edges_num_states_numba( + factor_edges_num_states, self.variables_for_factors + ) + object.__setattr__(self, "factor_edges_num_states", factor_edges_num_states) + def __hash__(self): return id(self) @@ -160,22 +183,6 @@ def _variables_to_factors(self) -> Mapping[FrozenSet, nodes.Factor]: """ return self._get_variables_to_factors() - @cached_property - def total_num_states(self) -> int: - """Function to return the total number of states for all the variables involved in all the Factors - - Returns: - Total number of variable states in the FactorGroup - """ - # TODO: this could be returned by the wiring to loop over variables_for_factors only once - return sum( - [ - variable[1] - for variables_for_factor in self.variables_for_factors - for variable in variables_for_factor - ] - ) - @cached_property def factor_group_log_potentials(self) -> np.ndarray: """Flattened array of log potentials""" @@ -321,3 +328,14 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any: raise NotImplementedError( "SingleFactorGroup does not support vectorized factor operations." ) + + +nb.jit(parallel=False, cache=True, fastmath=True, nopython=False) + + +def _compile_edges_num_states_numba(factor_edges_num_states, variables_for_factors): + idx = 0 + for variables_for_factor in variables_for_factors: + for variable in variables_for_factor: + factor_edges_num_states[idx] = variable[1] + idx += 1 diff --git a/pgmax/groups/enumeration.py b/pgmax/groups/enumeration.py index 03f17cb3..a09d9f24 100644 --- a/pgmax/groups/enumeration.py +++ b/pgmax/groups/enumeration.py @@ -65,7 +65,6 @@ def __post_init__(self): raise ValueError( f"Potentials should be floats. Got {log_potentials.dtype}." ) - object.__setattr__(self, "log_potentials", log_potentials) def _get_variables_to_factors( @@ -232,6 +231,10 @@ def __post_init__(self): f"Got log_potential_matrix for {log_potential_matrix.shape[0]} factors." ) + import time + + start = time.time() + log_potential_shape = log_potential_matrix.shape[-2:] for variables_for_factor in self.variables_for_factors: if len(variables_for_factor) != 2: raise ValueError( @@ -239,15 +242,19 @@ def __post_init__(self): f" {len(variables_for_factor)} variables ({variables_for_factor})." ) - num_states0 = variables_for_factor[0][1] - num_states1 = variables_for_factor[1][1] - if not log_potential_matrix.shape[-2:] == (num_states0, num_states1): + factor_num_configs = ( + variables_for_factor[0][1], + variables_for_factor[1][1], + ) + if log_potential_shape != factor_num_configs: raise ValueError( - f"The specified pairwise factor {variables_for_factor} (with {(num_states0, num_states1)}" + f"The specified pairwise factor {variables_for_factor} (with {factor_num_configs}" f"configurations) does not match the specified log_potential_matrix " - f"(with {log_potential_matrix.shape[-2:]} configurations)." + f"(with {log_potential_shape} configurations)." ) - object.__setattr__(self, "log_potential_matrix", log_potential_matrix) + object.__setattr__(self, "log_potential_matrix", log_potential_matrix) + print(time.time() - start) + factor_configs = ( np.mgrid[ : log_potential_matrix.shape[-2], @@ -257,6 +264,7 @@ def __post_init__(self): .reshape((-1, 2)) ) object.__setattr__(self, "factor_configs", factor_configs) + log_potential_matrix = np.broadcast_to( log_potential_matrix, (len(self.variables_for_factors),) + log_potential_matrix.shape[-2:], diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index 887a32a7..92551ffa 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -10,6 +10,8 @@ from pgmax.fg import groups from pgmax.utils import cached_property +MAX_SIZE = 1e9 + @dataclass(frozen=True, eq=False) class NDVariableArray(groups.VariableGroup): @@ -26,13 +28,13 @@ class NDVariableArray(groups.VariableGroup): num_states: Union[int, np.ndarray] def __post_init__(self): - if np.prod(self.shape) > int(groups.MAX_SIZE): + if np.prod(self.shape) > MAX_SIZE: raise ValueError( - f"Currently only support NDVariableArray of size smaller than {int(groups.MAX_SIZE)}. Got {np.prod(self.shape)}" + f"Currently only support NDVariableArray of size smaller than {int(MAX_SIZE)}. Got {np.prod(self.shape)}" ) if np.isscalar(self.num_states): - num_states = np.full(self.shape, fill_value=self.num_states, dtype=np.int32) + num_states = np.full(self.shape, fill_value=self.num_states, dtype=np.int64) object.__setattr__(self, "num_states", num_states) elif isinstance(self.num_states, np.ndarray) and np.issubdtype( self.num_states.dtype, int @@ -46,6 +48,12 @@ def __post_init__(self): "num_states should be an integer or a NumPy array of dtype int" ) + # Only compute the hash once, which is guaranteed to be an int64 + this_id = id(self) % 2**32 + this_hash = this_id * int(MAX_SIZE) + assert this_hash < 2**63 + object.__setattr__(self, "this_hash", this_hash) + def __getitem__( self, val: Union[int, slice, Tuple] ) -> Union[Tuple[int, int], List[Tuple[int, int]]]: @@ -61,13 +69,17 @@ def __getitem__( A single variable or a list of variables """ assert isinstance(self.num_states, np.ndarray) - if np.isscalar(self.variable_names[val]): - return (self.variable_names[val], self.num_states[val]) - else: + + if isinstance(val, slice) or ( + isinstance(val, tuple) and isinstance(val[0], slice) + ): + assert isinstance(self.num_states, np.ndarray) vars_names = self.variable_names[val].flatten() vars_num_states = self.num_states[val].flatten() return list(zip(vars_names, vars_num_states)) + return (self.variable_names[val], self.num_states[val]) + @cached_property def variables(self) -> List[Tuple]: """Function that returns the list of all variables in the VariableGroup. @@ -177,6 +189,9 @@ def __post_init__(self): ) object.__setattr__(self, "num_states", num_states) + # Only compute the hash once + object.__setattr__(self, "this_hash", id(self)) + hash_and_names = tuple( (self.__hash__(), var_name) for var_name in self.variable_names ) From aaa67ef5f828c89c3450daccd5ed2093c382c691 Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Wed, 27 Apr 2022 01:33:44 +0000 Subject: [PATCH 25/35] max size --- pgmax/fg/groups.py | 2 -- tests/fg/test_groups.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 1ffc771d..bf1c43d1 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -23,8 +23,6 @@ import pgmax.fg.nodes as nodes from pgmax.utils import cached_property -MAX_SIZE = 1e10 - @total_ordering @dataclass(frozen=True, eq=False) diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index b65a8ce5..b47e3e09 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -54,7 +54,7 @@ def test_variable_dict(): def test_nd_variable_array(): - max_size = int(groups.MAX_SIZE) + max_size = int(vgroup.MAX_SIZE) with pytest.raises( ValueError, match=re.escape( From 04c4d89b79b7b8c353cdb26faed3eb8e35c70342 Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Wed, 27 Apr 2022 01:54:43 +0000 Subject: [PATCH 26/35] Understand timings --- examples/pmp_binary_deconvolution.py | 1 - examples/rbm.py | 13 +++++++------ pgmax/factors/logical.py | 10 ---------- 3 files changed, 7 insertions(+), 17 deletions(-) diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py index fd396f3a..a0f24406 100644 --- a/examples/pmp_binary_deconvolution.py +++ b/examples/pmp_binary_deconvolution.py @@ -211,7 +211,6 @@ def plot_images(images, display=True, nr=None): # in the same manner does not change X, so this naturally results in multiple equivalent modes. # %% -# %load_ext snakeviz start = time.time() bp = graph.BP(fg.bp_state, temperature=0.0) print("Time", time.time() - start) diff --git a/examples/rbm.py b/examples/rbm.py index cde32960..8565f086 100644 --- a/examples/rbm.py +++ b/examples/rbm.py @@ -82,16 +82,17 @@ log_potential_matrix = np.zeros(W.shape + (2, 2)).reshape((-1, 2, 2)) log_potential_matrix[:, 1, 1] = W.ravel() +variables_for_factors = [ + [hidden_variables[ii], visible_variables[jj]] + for ii in range(bh.shape[0]) + for jj in range(bv.shape[0]) +] +print("Time", time.time() - start) pairwise_factors = enumeration.PairwiseFactorGroup( - variables_for_factors=[ - [hidden_variables[ii], visible_variables[jj]] - for ii in range(bh.shape[0]) - for jj in range(bv.shape[0]) - ], + variables_for_factors=variables_for_factors, log_potential_matrix=log_potential_matrix, ) print("Time", time.time() - start) -# #%snakeviz pairwise_factors = enumeration.PairwiseFactorGroup(variables_for_factors=[ [hidden_variables[ii], visible_variables[jj]] for ii in range(bh.shape[0]) for jj in range(bv.shape[0])],log_potential_matrix=log_potential_matrix,) fg.add_factors(pairwise_factors) print("Time", time.time() - start) diff --git a/pgmax/factors/logical.py b/pgmax/factors/logical.py index c8f93f6b..8b931e47 100644 --- a/pgmax/factors/logical.py +++ b/pgmax/factors/logical.py @@ -241,16 +241,6 @@ class ANDFactor(LogicalFactor): edge_states_offset: int = field(init=False, default=-1) -@nb.jit(parallel=False, cache=True, fastmath=True, nopython=True) -def _compile_utils_numba(var_states, factor_edges_num_states, variables_for_factors): - idx = 0 - for variables_for_factor in variables_for_factors: - for variable in variables_for_factor: - var_states[idx] = variable - factor_edges_num_states[idx] = variable[1] - idx += 1 - - @nb.jit(parallel=False, cache=True, fastmath=True, nopython=True) def _compile_logical_wiring_numba( parents_edge_states: np.ndarray, From a276ce5a45676872eb8471953ec3567f497ca479 Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Wed, 27 Apr 2022 18:04:45 +0000 Subject: [PATCH 27/35] Some comments --- examples/ising_model.py | 4 +-- examples/pmp_binary_deconvolution.py | 2 +- pgmax/fg/graph.py | 5 +-- pgmax/fg/groups.py | 13 ++++++-- pgmax/groups/enumeration.py | 6 +--- pgmax/groups/variables.py | 48 +++++++++++++++------------- 6 files changed, 40 insertions(+), 38 deletions(-) diff --git a/examples/ising_model.py b/examples/ising_model.py index 3f530a33..bf30c7bd 100644 --- a/examples/ising_model.py +++ b/examples/ising_model.py @@ -37,7 +37,7 @@ kk = (ii + 1) % 50 ll = (jj + 1) % 50 variables_for_factors.append([variables[ii, jj], variables[kk, jj]]) - variables_for_factors.append([variables[ii, jj], variables[kk, ll]]) + variables_for_factors.append([variables[ii, jj], variables[ii, ll]]) factor_group = enumeration.PairwiseFactorGroup( variables_for_factors=variables_for_factors, @@ -52,8 +52,6 @@ bp = graph.BP(fg.bp_state, temperature=0) # %% -# TODO: check\ time for before BP vs time for BP -# TODO: time PGMAX vs PMP bp_arrays = bp.init( evidence_updates={variables: jax.device_put(np.random.gumbel(size=(50, 50, 2)))} ) diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py index a0f24406..131c79e3 100644 --- a/examples/pmp_binary_deconvolution.py +++ b/examples/pmp_binary_deconvolution.py @@ -138,7 +138,7 @@ def plot_images(images, display=True, nr=None): print("Time", time.time() - start) # %% [markdown] -# For computation efficiency, we add large FactorGroups via `fg.add_factors` instead of adding individual Factors +# For computation efficiency, we construct large FactorGroups instead of individual Factors # %% start = time.time() diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index a13fc426..29ecc26a 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -49,9 +49,6 @@ class FactorGraph: variable_groups: Union[groups.VariableGroup, Sequence[groups.VariableGroup]] def __post_init__(self): - import time - - start = time.time() if isinstance(self.variable_groups, groups.VariableGroup): self.variable_groups = [self.variable_groups] @@ -68,6 +65,7 @@ def __post_init__(self): ) # See FactorGraphState docstrings for documentation on the following fields + # TODO: vars_to_starts does not have to be a dict self._vars_to_starts: Dict[Tuple[int, int], int] = {} vars_num_states_cumsum = 0 @@ -82,7 +80,6 @@ def __post_init__(self): ) vars_num_states_cumsum += vg_num_states_cumsum[-1] self._num_var_states = vars_num_states_cumsum - print("Init", time.time() - start) def __hash__(self) -> int: return hash(self.factor_groups) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index bf1c43d1..5f47929b 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -23,6 +23,8 @@ import pgmax.fg.nodes as nodes from pgmax.utils import cached_property +MAX_SIZE = 1e9 + @total_ordering @dataclass(frozen=True, eq=False) @@ -34,6 +36,13 @@ class VariableGroup: this_hash: Hash of the VariableGroup """ + def __post_init__(self): + # Only compute the hash once, which is guaranteed to be an int64 + this_id = id(self) % 2**32 + this_hash = this_id * int(MAX_SIZE) + assert this_hash < 2**63 + object.__setattr__(self, "this_hash", this_hash) + def __hash__(self): return self.this_hash @@ -328,9 +337,7 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any: ) -nb.jit(parallel=False, cache=True, fastmath=True, nopython=False) - - +# @nb.jit(parallel=False, cache=True, fastmath=True) def _compile_edges_num_states_numba(factor_edges_num_states, variables_for_factors): idx = 0 for variables_for_factor in variables_for_factors: diff --git a/pgmax/groups/enumeration.py b/pgmax/groups/enumeration.py index a09d9f24..dab6e14b 100644 --- a/pgmax/groups/enumeration.py +++ b/pgmax/groups/enumeration.py @@ -231,9 +231,6 @@ def __post_init__(self): f"Got log_potential_matrix for {log_potential_matrix.shape[0]} factors." ) - import time - - start = time.time() log_potential_shape = log_potential_matrix.shape[-2:] for variables_for_factor in self.variables_for_factors: if len(variables_for_factor) != 2: @@ -252,8 +249,7 @@ def __post_init__(self): f"configurations) does not match the specified log_potential_matrix " f"(with {log_potential_shape} configurations)." ) - object.__setattr__(self, "log_potential_matrix", log_potential_matrix) - print(time.time() - start) + object.__setattr__(self, "log_potential_matrix", log_potential_matrix) factor_configs = ( np.mgrid[ diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index 92551ffa..7d61fb7a 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -10,8 +10,6 @@ from pgmax.fg import groups from pgmax.utils import cached_property -MAX_SIZE = 1e9 - @dataclass(frozen=True, eq=False) class NDVariableArray(groups.VariableGroup): @@ -28,9 +26,12 @@ class NDVariableArray(groups.VariableGroup): num_states: Union[int, np.ndarray] def __post_init__(self): - if np.prod(self.shape) > MAX_SIZE: + super().__post_init__() + + max_size = int(groups.MAX_SIZE) + if np.prod(self.shape) > max_size: raise ValueError( - f"Currently only support NDVariableArray of size smaller than {int(MAX_SIZE)}. Got {np.prod(self.shape)}" + f"Currently only support NDVariableArray of size smaller than {max_size}. Got {np.prod(self.shape)}" ) if np.isscalar(self.num_states): @@ -48,12 +49,6 @@ def __post_init__(self): "num_states should be an integer or a NumPy array of dtype int" ) - # Only compute the hash once, which is guaranteed to be an int64 - this_id = id(self) % 2**32 - this_hash = this_id * int(MAX_SIZE) - assert this_hash < 2**63 - object.__setattr__(self, "this_hash", this_hash) - def __getitem__( self, val: Union[int, slice, Tuple] ) -> Union[Tuple[int, int], List[Tuple[int, int]]]: @@ -175,28 +170,35 @@ class VariableDict(groups.VariableGroup): """A variable dictionary that contains a set of variables of the same size Args: - num_states: The size of the variables in this variable group - variable_names: A tuple of all names of the variables in this variable group. + num_states: The size of the variables in this VariableGroup + variable_names: A tuple of all names of the variables in this VariableGroup. Note that we overwrite variable_names to add the hash of the VariableDict """ variable_names: Tuple[Any, ...] - num_states: int + num_states: Union[int, np.ndarray] def __post_init__(self): - num_states = np.full( - (len(self.variable_names),), fill_value=self.num_states, dtype=np.int32 - ) - object.__setattr__(self, "num_states", num_states) - - # Only compute the hash once - object.__setattr__(self, "this_hash", id(self)) + super().__post_init__() hash_and_names = tuple( (self.__hash__(), var_name) for var_name in self.variable_names ) object.__setattr__(self, "variable_names", hash_and_names) + if np.isscalar(self.num_states): + num_states = np.full( + len(self.variable_names), fill_value=self.num_states, dtype=np.int64 + ) + object.__setattr__(self, "num_states", num_states) + elif isinstance(self.num_states, np.ndarray) and np.issubdtype( + self.num_states.dtype, int + ): + if self.num_states.shape != len(self.variable_names): + raise ValueError( + f"Expected num_states shape {len(self.variable_names)}. Got {self.num_states.shape}." + ) + @cached_property def variables(self) -> List[Tuple[Tuple[Any, int], int]]: """Function that returns the list of all variables in the VariableGroup. @@ -215,7 +217,7 @@ def __getitem__(self, val: Any) -> Tuple[Tuple[Any, int], int]: (variable name, number of states) Args: - val: a variable name + val: a variable name (without the object hash) Returns: The queried variable @@ -223,7 +225,9 @@ def __getitem__(self, val: Any) -> Tuple[Tuple[Any, int], int]: assert isinstance(self.num_states, np.ndarray) if (self.__hash__(), val) not in self.variable_names: raise ValueError(f"Variable {val} is not in VariableDict") - return ((self.__hash__(), val), self.num_states[0]) + + idx = self.variable_names.index((self.__hash__(), val)) + return ((self.__hash__(), val), self.num_states[idx]) def flatten( self, data: Mapping[Tuple[Tuple[int, int], int], Union[np.ndarray, jnp.ndarray]] From a839be6ee691e755ab81a920e9db1632a587b576 Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Wed, 27 Apr 2022 22:56:11 +0000 Subject: [PATCH 28/35] Comments --- examples/pmp_binary_deconvolution.py | 2 +- examples/rbm.py | 11 ++- examples/rcn.py | 24 ++++--- pgmax/fg/graph.py | 67 +++++++++-------- pgmax/fg/groups.py | 93 ++++++++++++++---------- pgmax/groups/variables.py | 103 ++++++++++++--------------- tests/factors/test_and.py | 6 +- tests/factors/test_or.py | 6 +- tests/fg/test_graph.py | 28 ++------ tests/fg/test_groups.py | 26 +++++-- tests/fg/test_wiring.py | 28 ++++---- tests/test_pgmax.py | 22 +++--- 12 files changed, 212 insertions(+), 204 deletions(-) diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py index 131c79e3..a8b992e5 100644 --- a/examples/pmp_binary_deconvolution.py +++ b/examples/pmp_binary_deconvolution.py @@ -195,7 +195,7 @@ def plot_images(images, display=True, nr=None): fg.add_factors(OR_factor_group) print("Time", time.time() - start) -for factor_type, factor_groups in fg._factor_types_to_groups.items(): +for factor_type, factor_groups in fg.factor_groups.items(): if len(factor_groups) > 0: assert len(factor_groups) == 1 print(f"The factor graph contains {factor_groups[0].num_factors} {factor_type}") diff --git a/examples/rbm.py b/examples/rbm.py index 8565f086..4b391a62 100644 --- a/examples/rbm.py +++ b/examples/rbm.py @@ -69,14 +69,12 @@ factor_configs=np.arange(2)[:, None], log_potentials=np.stack([np.zeros_like(bh), bh], axis=1), ) -fg.add_factors(hidden_unaries) visible_unaries = enumeration.EnumerationFactorGroup( variables_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])], factor_configs=np.arange(2)[:, None], log_potentials=np.stack([np.zeros_like(bv), bv], axis=1), ) -fg.add_factors(visible_unaries) # Add pairwise factors log_potential_matrix = np.zeros(W.shape + (2, 2)).reshape((-1, 2, 2)) @@ -94,7 +92,7 @@ ) print("Time", time.time() - start) -fg.add_factors(pairwise_factors) +fg.add_factors([hidden_unaries, visible_unaries, pairwise_factors]) print("Time", time.time() - start) @@ -107,6 +105,7 @@ # # An alternative way of creating the above factors is to add them iteratively by calling [`fg.add_factor`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph.add_factor) as below. This approach is not recommended as it is not computationally efficient. # ~~~python +# from pgmax.factors import enumeration as enumeration_factor # import itertools # from tqdm import tqdm # @@ -117,7 +116,7 @@ # factor_configs=np.arange(2)[:, None], # log_potentials=np.array([0, bh[ii]]), # ) -# fg.add_factors(factor=factor) +# fg.add_factors(factor) # # for jj in range(bv.shape[0]): # factor = enumeration_factor.EnumerationFactor( @@ -125,7 +124,7 @@ # factor_configs=np.arange(2)[:, None], # log_potentials=np.array([0, bv[jj]]), # ) -# fg.add_factors(factor=factor) +# fg.add_factors(factor) # # # Add pairwise factors # factor_configs = np.array(list(itertools.product(np.arange(2), repeat=2))) @@ -136,7 +135,7 @@ # factor_configs=factor_configs, # log_potentials=np.array([0, 0, 0, W[ii, jj]]), # ) -# fg.add_factors(factor=factor) +# fg.add_factors(factor) # ~~~ # # Once we have added the factors, we can run max-product LBP and get MAP decoding by diff --git a/examples/rcn.py b/examples/rcn.py index 2414c0e9..140c1712 100644 --- a/examples/rcn.py +++ b/examples/rcn.py @@ -38,9 +38,9 @@ from scipy.signal import fftconvolve from sklearn.datasets import fetch_openml -from pgmax.factors.enumeration import EnumerationFactor from pgmax.fg import graph from pgmax.groups import variables as vgroup +from pgmax.groups.enumeration import EnumerationFactorGroup memory = Memory("./example_data/tmp") fetch_openml_cached = memory.cache(fetch_openml) @@ -280,12 +280,13 @@ def valid_configs(r: int, hps: int, vps: int) -> np.ndarray: for e in edge: i1, i2, r = e - factor = EnumerationFactor( - variables=[variables_all_models[idx][i1], variables_all_models[idx][i2]], + factor_group = EnumerationFactorGroup( + variables_for_factors=[ + [variables_all_models[idx][i1], variables_all_models[idx][i2]] + ], factor_configs=valid_configs_list[r], - log_potentials=np.zeros(valid_configs_list[r].shape[0]), ) - fg.add_factors(factor=factor) + fg.add_factors(factor_group) end = time.time() print(f"Creating factors took {end-start:.3f} seconds.") @@ -384,7 +385,10 @@ def get_evidence(bu_msg: np.ndarray, frc: np.ndarray) -> np.ndarray: # %% -frcs_dict = {model_idx: frcs[model_idx] for model_idx in range(frcs.shape[0])} +frcs_dict = { + variables_all_models[model_idx]: frcs[model_idx] + for model_idx in range(frcs.shape[0]) +} bp = graph.BP(fg.bp_state, temperature=0.0) scores = np.zeros((len(test_set), frcs.shape[0])) map_states_dict = {} @@ -413,8 +417,8 @@ def get_evidence(bu_msg: np.ndarray, frc: np.ndarray) -> np.ndarray: evidence_updates, map_states, ) - for ii in score: - scores[test_idx, ii] = score[ii] + for idx, score in enumerate(score.values()): + scores[test_idx, idx] = score end = time.time() print(f"Computing scores took {end-start:.3f} seconds for image {test_idx}.") @@ -437,7 +441,9 @@ def get_evidence(bu_msg: np.ndarray, frc: np.ndarray) -> np.ndarray: fig, ax = plt.subplots(5, 4, figsize=(16, 20)) for test_idx in range(20): idx = np.unravel_index(test_idx, (5, 4)) - map_state = map_states_dict[test_idx][best_model_idx[test_idx]] + map_state = map_states_dict[test_idx][ + variables_all_models[best_model_idx[test_idx]] + ] offsets = np.array( np.unravel_index(map_state, (2 * hps + 1, 2 * vps + 1)) ).T - np.array([hps, vps]) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 29ecc26a..21ea1f11 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -82,36 +82,44 @@ def __post_init__(self): self._num_var_states = vars_num_states_cumsum def __hash__(self) -> int: - return hash(self.factor_groups) + all_factor_groups = tuple( + [ + factor_group + for factor_groups_per_type in self._factor_types_to_groups.values() + for factor_group in factor_groups_per_type + ] + ) + return hash(all_factor_groups) def add_factors( self, - factor_group: Optional[groups.FactorGroup] = None, - factor: Optional[nodes.Factor] = None, + factors: Union[ + nodes.Factor, + groups.FactorGroup, + Sequence[Union[nodes.Factor, groups.FactorGroup]], + ], ) -> None: - """Add a FactorGroup or a single Factor to the FactorGraph, by updating the FactorGraphState. + """Add a single Factor, a FactorGroup or a list of single Factors and FactorGroups to the FactorGraph, + by updating the FactorGraphState. Args: - factor_group: The FactorGroup to be added to the FactorGraph. - factor: The Factor to be added to the FactorGraph. + factors: The Factor, FactorGroup or list of Factors and FactorGroups to be added to the FactorGraph. Raises: - ValueError: If - (1) Both a Factor and a FactorGroup are added - (2) The FactorGroup involving the same variables already exists in the FactorGraph. + ValueError: A FactorGroup involving the same variables already exists in the FactorGraph. """ - if factor is None and factor_group is None: - raise ValueError("A Factor or a FactorGroup is required") - - if factor is not None and factor_group is not None: - raise ValueError("Cannot simultaneously add a Factor and a FactorGroup") - - if factor is not None: + if isinstance(factors, list): + for factor in factors: + self.add_factors(factor) + return None + + if isinstance(factors, groups.FactorGroup): + factor_group = factors + elif isinstance(factors, nodes.Factor): factor_group = groups.SingleFactorGroup( - variables_for_factors=[factor.variables], - factor=factor, + variables_for_factors=[factors.variables], + factor=factors, ) - assert factor_group is not None factor_type = factor_group.factor_type for var_names_for_factor in factor_group.variables_for_factors: @@ -253,15 +261,9 @@ def factors(self) -> OrderedDict[Type, Tuple[nodes.Factor, ...]]: return factors @property - def factor_groups(self) -> Tuple[groups.FactorGroup, ...]: + def factor_groups(self) -> OrderedDict[Type, List[groups.FactorGroup]]: """Tuple of factor groups in the factor graph""" - return tuple( - [ - factor_group - for factor_groups_per_type in self._factor_types_to_groups.values() - for factor_group in factor_groups_per_type - ] - ) + return self._factor_types_to_groups @cached_property def fg_state(self) -> FactorGraphState: @@ -278,7 +280,6 @@ def fg_state(self) -> FactorGraphState: vars_to_starts=self._vars_to_starts, num_var_states=self._num_var_states, total_factor_num_states=self._total_factor_num_states, - factor_groups=self.factor_groups, factor_type_to_msgs_range=copy.copy(self._factor_type_to_msgs_range), factor_type_to_potentials_range=copy.copy( self._factor_type_to_potentials_range @@ -314,7 +315,6 @@ class FactorGraphState: contains evidence to the variable. num_var_states: Total number of variable states. total_factor_num_states: Size of the flat ftov messages array. - factor_groups: FactorGroups in the FactorGraph factor_type_to_msgs_range: Maps factors types to their start and end indices in the flat ftov messages. factor_type_to_potentials_range: Maps factor types to their start and end indices in the flat log potentials. factor_group_to_potentials_starts: Maps factor groups to their starting indices in the flat log potentials. @@ -326,7 +326,6 @@ class FactorGraphState: vars_to_starts: Mapping[Tuple[Any, int], int] num_var_states: int total_factor_num_states: int - factor_groups: Tuple[groups.FactorGroup, ...] factor_type_to_msgs_range: OrderedDict[type, Tuple[int, int]] factor_type_to_potentials_range: OrderedDict[type, Tuple[int, int]] factor_group_to_potentials_starts: OrderedDict[groups.FactorGroup, int] @@ -395,7 +394,7 @@ def update_log_potentials( (2) Provided name is not valid for log_potentials updates. """ for factor_group, data in updates.items(): - if factor_group in fg_state.factor_groups: + if factor_group in fg_state.factor_group_to_potentials_starts: flat_data = factor_group.flatten(data) if flat_data.shape != factor_group.factor_group_log_potentials.shape: raise ValueError( @@ -450,7 +449,7 @@ def __getitem__(self, factor_group: groups.FactorGroup) -> np.ndarray: The queried log potentials. """ value = cast(np.ndarray, self.value) - if factor_group in self.fg_state.factor_groups: + if factor_group in self.fg_state.factor_group_to_potentials_starts: start = self.fg_state.factor_group_to_potentials_starts[factor_group] log_potentials = value[ start : start + factor_group.factor_group_log_potentials.shape[0] @@ -558,7 +557,7 @@ def __post_init__(self): def __setitem__( self, - variable: Tuple[Any, int], + variable: Tuple[int, int], data: Union[np.ndarray, jnp.ndarray], ) -> None: """Spreading beliefs at a variable to all connected Factors @@ -641,7 +640,7 @@ def __post_init__(self): object.__setattr__(self, "value", self.value) - def __getitem__(self, variable: Tuple[Any, int]) -> np.ndarray: + def __getitem__(self, variable: Tuple[int, int]) -> np.ndarray: """Function to query evidence for a variable Args: diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 5f47929b..e8079a36 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -17,7 +17,6 @@ ) import jax.numpy as jnp -import numba as nb import numpy as np import pgmax.fg.nodes as nodes @@ -32,10 +31,13 @@ class VariableGroup: """Class to represent a group of variables. Each variable is represented via a tuple of the form (variable hash/name, number of states) - Attributes: - this_hash: Hash of the VariableGroup + Arguments: + num_states: An integer or an array specifying the number of states of the variables + in this VariableGroup """ + num_states: Union[int, np.ndarray] = field(init=False) + def __post_init__(self): # Only compute the hash once, which is guaranteed to be an int64 this_id = id(self) % 2**32 @@ -52,9 +54,9 @@ def __eq__(self, other): def __lt__(self, other): return hash(self) < hash(other) - def __getitem__(self, val: Any) -> Union[Tuple[Any, int], List[Tuple[Any, int]]]: + def __getitem__(self, val: Any) -> Union[Tuple[int, int], List[Tuple[int, int]]]: """Given a variable name, index, or a group of variable indices, retrieve the associated variable(s). - Each variable is returned via a tuple of the form (variable hash/name, number of states) + Each variable is returned via a tuple of the form (variable hash, number of states) Args: val: a variable index, slice, or name @@ -67,17 +69,30 @@ def __getitem__(self, val: Any) -> Union[Tuple[Any, int], List[Tuple[Any, int]]] ) @cached_property - def variables(self) -> List[Tuple]: - """Function that returns the list of all variables in the VariableGroup. - Each variable is represented by a tuple of the form (variable hash/name, number of states) + def variable_hashes(self) -> np.ndarray: + """Function that generates a variable hash for each variable Returns: - List of variables in the VariableGroup + Array of variables hashes. """ raise NotImplementedError( "Please subclass the VariableGroup class and override this method" ) + @cached_property + def variables(self) -> List[Tuple[int, int]]: + """Function that returns the list of all variables in the VariableGroup. + Each variable is represented by a tuple of the form (variable hash, number of states) + + Returns: + List of variables in the VariableGroup + """ + assert isinstance(self.variable_hashes, np.ndarray) + assert isinstance(self.num_states, np.ndarray) + vars_hashes = self.variable_hashes.flatten() + vars_num_states = self.num_states.flatten() + return list(zip(vars_hashes, vars_num_states)) + def flatten(self, data: Any) -> jnp.ndarray: """Function that turns meaningful structured data into a flat data array for internal use. @@ -118,9 +133,6 @@ class FactorGroup: Attributes: factor_type: Factor type shared by all the Factors in the FactorGroup. - factor_sizes: Array of the different factor sizes. - factor_edges_num_states: Array concatenating the number of states for the variables connected to each Factor of - the FactorGroup. Each variable will appear once for each Factor it connects to. Raises: ValueError: if the FactorGroup does not contain a Factor @@ -130,27 +142,11 @@ class FactorGroup: factor_configs: np.ndarray = field(init=False) log_potentials: np.ndarray = field(init=False, default=np.empty((0,))) factor_type: Type = field(init=False) - factor_sizes: np.ndarray = field(init=False) - factor_edges_num_states: np.ndarray = field(init=False) def __post_init__(self): if len(self.variables_for_factors) == 0: raise ValueError("Cannot create a FactorGroup with no Factor.") - factor_sizes = np.array( - [ - len(variables_for_factor) - for variables_for_factor in self.variables_for_factors - ] - ) - object.__setattr__(self, "factor_sizes", factor_sizes) - - factor_edges_num_states = np.empty(shape=(self.factor_sizes.sum(),), dtype=int) - _compile_edges_num_states_numba( - factor_edges_num_states, self.variables_for_factors - ) - object.__setattr__(self, "factor_edges_num_states", factor_edges_num_states) - def __hash__(self): return id(self) @@ -180,6 +176,38 @@ def __getitem__(self, variables: Union[Sequence, Collection]) -> Any: ) return self._variables_to_factors[variables] + @cached_property + def factor_sizes(self) -> np.ndarray: + """Computes the factor sizes + + Returns: + Array of the different factor sizes. + """ + return np.array( + [ + len(variables_for_factor) + for variables_for_factor in self.variables_for_factors + ] + ) + + @cached_property + def factor_edges_num_states(self) -> np.ndarray: + """Computes an array concatenating the number of states for the variables connected to each Factor of + the FactorGroup. Each variable will appear once for each Factor it connects to. + + Returns: + Array with the the number of states for the variables connected to each Factor. + """ + factor_edges_num_states = np.empty(shape=(self.factor_sizes.sum(),), dtype=int) + + # TODO: create variables_for_factors as an array and move this to numba + idx = 0 + for variables_for_factor in self.variables_for_factors: + for variable in variables_for_factor: + factor_edges_num_states[idx] = variable[1] + idx += 1 + return factor_edges_num_states + @cached_property def _variables_to_factors(self) -> Mapping[FrozenSet, nodes.Factor]: """Function to compile potential array for the factor group. @@ -335,12 +363,3 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any: raise NotImplementedError( "SingleFactorGroup does not support vectorized factor operations." ) - - -# @nb.jit(parallel=False, cache=True, fastmath=True) -def _compile_edges_num_states_numba(factor_edges_num_states, variables_for_factors): - idx = 0 - for variables_for_factor in variables_for_factors: - for variable in variables_for_factor: - factor_edges_num_states[idx] = variable[1] - idx += 1 diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index 7d61fb7a..22a7ef0e 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -16,14 +16,14 @@ class NDVariableArray(groups.VariableGroup): """Subclass of VariableGroup for n-dimensional grids of variables. Args: - shape: Tuple specifying the size of each dimension of the grid (similar to - the notion of a NumPy ndarray shape) num_states: An integer or an array specifying the number of states of the variables in this VariableGroup + shape: Tuple specifying the size of each dimension of the grid (similar to + the notion of a NumPy ndarray shape) """ - shape: Tuple[int, ...] num_states: Union[int, np.ndarray] + shape: Tuple[int, ...] def __post_init__(self): super().__post_init__() @@ -69,31 +69,18 @@ def __getitem__( isinstance(val, tuple) and isinstance(val[0], slice) ): assert isinstance(self.num_states, np.ndarray) - vars_names = self.variable_names[val].flatten() + vars_names = self.variable_hashes[val].flatten() vars_num_states = self.num_states[val].flatten() return list(zip(vars_names, vars_num_states)) - return (self.variable_names[val], self.num_states[val]) - - @cached_property - def variables(self) -> List[Tuple]: - """Function that returns the list of all variables in the VariableGroup. - Each variable is represented by a tuple of the form (variable hash, number of states) - - Returns: - List of variables in the VariableGroup - """ - assert isinstance(self.num_states, np.ndarray) - vars_names = self.variable_names.flatten() - vars_num_states = self.num_states.flatten() - return list(zip(vars_names, vars_num_states)) + return (self.variable_hashes[val], self.num_states[val]) @cached_property - def variable_names(self) -> np.ndarray: - """Function that generates all the variables names, in the form of hashes + def variable_hashes(self) -> np.ndarray: + """Function that generates a variable hash for each variable Returns: - Array of variables names. + Array of variables hashes. """ indices = np.reshape(np.arange(np.product(self.shape)), self.shape) return self.__hash__() + indices @@ -167,28 +154,22 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> jnp.ndarray: @dataclass(frozen=True, eq=False) class VariableDict(groups.VariableGroup): - """A variable dictionary that contains a set of variables of the same size + """A variable dictionary that contains a set of variables Args: num_states: The size of the variables in this VariableGroup - variable_names: A tuple of all names of the variables in this VariableGroup. - Note that we overwrite variable_names to add the hash of the VariableDict + variable_names: A tuple of all the names of the variables in this VariableGroup. """ - variable_names: Tuple[Any, ...] num_states: Union[int, np.ndarray] + variable_names: Tuple[Any, ...] def __post_init__(self): super().__post_init__() - hash_and_names = tuple( - (self.__hash__(), var_name) for var_name in self.variable_names - ) - object.__setattr__(self, "variable_names", hash_and_names) - if np.isscalar(self.num_states): num_states = np.full( - len(self.variable_names), fill_value=self.num_states, dtype=np.int64 + (len(self.variable_names),), fill_value=self.num_states, dtype=np.int64 ) object.__setattr__(self, "num_states", num_states) elif isinstance(self.num_states, np.ndarray) and np.issubdtype( @@ -196,38 +177,39 @@ def __post_init__(self): ): if self.num_states.shape != len(self.variable_names): raise ValueError( - f"Expected num_states shape {len(self.variable_names)}. Got {self.num_states.shape}." + f"Expected num_states shape ({len(self.variable_names)},). Got {self.num_states.shape}." ) + else: + raise ValueError( + "num_states should be an integer or a NumPy array of dtype int" + ) @cached_property - def variables(self) -> List[Tuple[Tuple[Any, int], int]]: - """Function that returns the list of all variables in the VariableGroup. - Each variable is represented by a tuple of the form (variable name, number of states) + def variable_hashes(self) -> np.ndarray: + """Function that generates a variable hash for each variable Returns: - List of variables in the VariableGroup + Array of variables hashes. """ - assert isinstance(self.num_states, np.ndarray) - vars_names = list(self.variable_names) - vars_num_states = self.num_states.flatten() - return list(zip(vars_names, vars_num_states)) + indices = np.arange(len(self.variable_names)) + return self.__hash__() + indices - def __getitem__(self, val: Any) -> Tuple[Tuple[Any, int], int]: + def __getitem__(self, var_name: Any) -> Tuple[int, int]: """Given a variable name retrieve the associated variable, returned via a tuple of the form - (variable name, number of states) + (variable hash, number of states) Args: - val: a variable name (without the object hash) + val: a variable name Returns: The queried variable """ assert isinstance(self.num_states, np.ndarray) - if (self.__hash__(), val) not in self.variable_names: - raise ValueError(f"Variable {val} is not in VariableDict") + if var_name not in self.variable_names: + raise ValueError(f"Variable {var_name} is not in VariableDict") - idx = self.variable_names.index((self.__hash__(), val)) - return ((self.__hash__(), val), self.num_states[idx]) + var_idx = self.variable_names.index(var_name) + return (self.variable_hashes[var_idx], self.num_states[var_idx]) def flatten( self, data: Mapping[Tuple[Tuple[int, int], int], Union[np.ndarray, jnp.ndarray]] @@ -249,20 +231,23 @@ def flatten( """ assert isinstance(self.num_states, np.ndarray) - for variable in data: - if variable not in self.variables: + for var_name in data: + if var_name not in self.variable_names: raise ValueError( - f"data is referring to a non-existent variable {variable}." + f"data is referring to a non-existent variable {var_name}." ) - if data[variable].shape != (variable[1],) and data[variable].shape != (1,): + var_idx = self.variable_names.index(var_name) + if data[var_name].shape != (self.num_states[var_idx],) and data[ + var_name + ].shape != (1,): raise ValueError( - f"Variable {variable} expects a data array of shape " - f"{(variable[1],)} or (1,). Got {data[variable].shape}." + f"Variable {var_name} expects a data array of shape " + f"{(self.num_states[var_idx],)} or (1,). Got {data[var_name].shape}." ) flat_data = jnp.concatenate( - [data[variable].flatten() for variable in self.variables] + [data[var_name].flatten() for var_name in self.variable_names] ) return flat_data @@ -306,12 +291,14 @@ def unflatten( start = 0 data = {} - for variable in self.variables: + for var_name in self.variable_names: if use_num_states: - data[variable] = flat_data[start : start + variable[1]] - start += variable[1] + var_idx = self.variable_names.index(var_name) + var_num_states = self.num_states[var_idx] + data[var_name] = flat_data[start : start + var_num_states] + start += var_num_states else: - data[variable] = flat_data[np.array([start])] + data[var_name] = flat_data[np.array([start])] start += 1 return data diff --git a/tests/factors/test_and.py b/tests/factors/test_and.py index fcbe6772..4f2f7612 100644 --- a/tests/factors/test_and.py +++ b/tests/factors/test_and.py @@ -104,7 +104,7 @@ def test_run_bp_with_ANDFactors(): factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) - fg1.add_factors(factor=enum_factor) + fg1.add_factors(enum_factor) else: if idx != 0: # Add the second half of factors to FactorGraph2 @@ -113,7 +113,7 @@ def test_run_bp_with_ANDFactors(): factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) - fg2.add_factors(factor=enum_factor) + fg2.add_factors(enum_factor) else: # Add all the EnumerationFactors to FactorGraph1 for the first iter enum_factor = EnumerationFactor( @@ -121,7 +121,7 @@ def test_run_bp_with_ANDFactors(): factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) - fg1.add_factors(factor=enum_factor) + fg1.add_factors(enum_factor) # Option 2: Define the ANDFactors num_parents_cumsum = np.insert(np.cumsum(num_parents), 0, 0) diff --git a/tests/factors/test_or.py b/tests/factors/test_or.py index f29dccd0..c615a800 100644 --- a/tests/factors/test_or.py +++ b/tests/factors/test_or.py @@ -102,7 +102,7 @@ def test_run_bp_with_ORFactors(): factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) - fg1.add_factors(factor=enum_factor) + fg1.add_factors(enum_factor) else: if idx != 0: # Add the second half of factors to FactorGraph2 @@ -111,7 +111,7 @@ def test_run_bp_with_ORFactors(): factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) - fg2.add_factors(factor=enum_factor) + fg2.add_factors(enum_factor) else: # Add all the EnumerationFactors to FactorGraph1 for the first iter enum_factor = EnumerationFactor( @@ -119,7 +119,7 @@ def test_run_bp_with_ORFactors(): factor_configs=valid_configs, log_potentials=np.zeros(valid_configs.shape[0]), ) - fg1.add_factors(factor=enum_factor) + fg1.add_factors(enum_factor) # Option 2: Define the ORFactors num_parents_cumsum = np.insert(np.cumsum(num_parents), 0, 0) diff --git a/tests/fg/test_graph.py b/tests/fg/test_graph.py index 0fdadbb2..8b0d72b0 100644 --- a/tests/fg/test_graph.py +++ b/tests/fg/test_graph.py @@ -16,30 +16,12 @@ def test_factor_graph(): vg = vgroup.VariableDict(variable_names=(0,), num_states=15) fg = graph.FactorGraph(vg) - with pytest.raises( - ValueError, - match="A Factor or a FactorGroup is required", - ): - fg.add_factors(factor_group=None, factor=None) - factor = enumeration_factor.EnumerationFactor( variables=[vg[0]], factor_configs=np.arange(15)[:, None], log_potentials=np.zeros(15), ) - - factor_group = enumeration.EnumerationFactorGroup( - variables_for_factors=[[vg[0]]], - factor_configs=np.arange(15)[:, None], - log_potentials=np.zeros(15), - ) - with pytest.raises( - ValueError, - match="Cannot simultaneously add a Factor and a FactorGroup", - ): - fg.add_factors(factor_group=factor_group, factor=factor) - - fg.add_factors(factor=factor) + fg.add_factors(factor) factor_group = enumeration.EnumerationFactorGroup( variables_for_factors=[[vg[0]]], @@ -49,7 +31,7 @@ def test_factor_graph(): with pytest.raises( ValueError, match=re.escape( - f"A Factor of type {enumeration_factor.EnumerationFactor} involving variables {frozenset([((vg.__hash__(), 0), 15)])} already exists." + f"A Factor of type {enumeration_factor.EnumerationFactor} involving variables {frozenset([(vg.__hash__(), 15)])} already exists." ), ): fg.add_factors(factor_group) @@ -63,10 +45,10 @@ def test_bp_state(): factor_configs=np.arange(15)[:, None], log_potentials=np.zeros(15), ) - fg0.add_factors(factor=factor) + fg0.add_factors(factor) fg1 = graph.FactorGraph(vg) - fg1.add_factors(factor=factor) + fg1.add_factors(factor) with pytest.raises( ValueError, @@ -137,7 +119,7 @@ def test_ftov_msgs(): with pytest.raises( ValueError, match=re.escape( - f"Given belief shape (10,) does not match expected shape (15,) for variable (({vg.__hash__()}, 0), 15)." + f"Given belief shape (10,) does not match expected shape (15,) for variable ({vg.__hash__()}, 15)." ), ): fg.bp_state.ftov_msgs[vg[0]] = np.ones(10) diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index b47e3e09..7799ffb0 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -11,6 +11,21 @@ def test_variable_dict(): + num_states = np.full((4,), fill_value=2) + with pytest.raises( + ValueError, match=re.escape("Expected num_states shape (3,). Got (4,).") + ): + vgroup.VariableDict(variable_names=tuple([0, 1, 2]), num_states=num_states) + + num_states = np.full((3,), fill_value=2, dtype=np.float32) + with pytest.raises( + ValueError, + match=re.escape( + "num_states should be an integer or a NumPy array of dtype int" + ), + ): + vgroup.NDVariableArray(shape=(2, 2), num_states=num_states) + variable_dict = vgroup.VariableDict(variable_names=tuple([0, 1, 2]), num_states=15) with pytest.raises( ValueError, match="data is referring to a non-existent variable 3" @@ -20,10 +35,10 @@ def test_variable_dict(): with pytest.raises( ValueError, match=re.escape( - f"Variable (({variable_dict.__hash__()}, 2), 15) expects a data array of shape (15,) or (1,). Got (10,)" + "Variable 2 expects a data array of shape (15,) or (1,). Got (10,)." ), ): - variable_dict.flatten({((variable_dict.__hash__(), 2), 15): np.zeros(10)}) + variable_dict.flatten({2: np.zeros(10)}) with pytest.raises( ValueError, match="Can only unflatten 1D array. Got a 2D array." @@ -36,10 +51,7 @@ def test_variable_dict(): jax.tree_util.tree_multimap( lambda x, y: jnp.all(x == y), variable_dict.unflatten(jnp.zeros(3)), - { - ((variable_dict.__hash__(), name), 15): np.zeros(1) - for name in range(3) - }, + {name: np.zeros(1) for name in range(3)}, ) ) ) @@ -54,7 +66,7 @@ def test_variable_dict(): def test_nd_variable_array(): - max_size = int(vgroup.MAX_SIZE) + max_size = int(groups.MAX_SIZE) with pytest.raises( ValueError, match=re.escape( diff --git a/tests/fg/test_wiring.py b/tests/fg/test_wiring.py index 53efeee3..af2fb9fc 100644 --- a/tests/fg/test_wiring.py +++ b/tests/fg/test_wiring.py @@ -25,7 +25,7 @@ def test_wiring_with_PairwiseFactorGroup(): ) fg.add_factors(factor_group) - factor_group = fg.factor_groups[0] + factor_group = fg.factor_groups[enumeration_factor.EnumerationFactor][0] object.__setattr__( factor_group, "factor_configs", factor_group.factor_configs[:, :1] ) @@ -41,7 +41,7 @@ def test_wiring_with_PairwiseFactorGroup(): variables_for_factors=[[A[idx], B[idx]] for idx in range(10)] ) fg1.add_factors(factor_group) - assert len(fg1.factor_groups) == 1 + assert len(fg1.factor_groups[enumeration_factor.EnumerationFactor]) == 1 # FactorGraph with multiple PairwiseFactorGroup fg2 = graph.FactorGraph(variable_groups=[A, B]) @@ -50,18 +50,20 @@ def test_wiring_with_PairwiseFactorGroup(): variables_for_factors=[[A[idx], B[idx]]] ) fg2.add_factors(factor_group) - assert len(fg2.factor_groups) == 10 + assert len(fg2.factor_groups[enumeration_factor.EnumerationFactor]) == 10 # FactorGraph with multiple SingleFactorGroup fg3 = graph.FactorGraph(variable_groups=[A, B]) + factors = [] for idx in range(10): factor = enumeration_factor.EnumerationFactor( variables=[A[idx], B[idx]], factor_configs=np.array([[0, 0], [0, 1], [1, 0], [1, 1]]), log_potentials=np.zeros((4,)), ) - fg3.add_factors(factor=factor) - assert len(fg3.factor_groups) == 10 + factors.append(factor) + fg3.add_factors(factors) + assert len(fg3.factor_groups[enumeration_factor.EnumerationFactor]) == 10 assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors) @@ -100,7 +102,7 @@ def test_wiring_with_ORFactorGroup(): variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)], ) fg1.add_factors(factor_group) - assert len(fg1.factor_groups) == 1 + assert len(fg1.factor_groups[logical_factor.ORFactor]) == 1 # FactorGraph with multiple ORFactorGroup fg2 = graph.FactorGraph(variable_groups=[A, B, C]) @@ -109,7 +111,7 @@ def test_wiring_with_ORFactorGroup(): variables_for_factors=[[A[idx], B[idx], C[idx]]], ) fg2.add_factors(factor_group) - assert len(fg2.factor_groups) == 10 + assert len(fg2.factor_groups[logical_factor.ORFactor]) == 10 # FactorGraph with multiple SingleFactorGroup fg3 = graph.FactorGraph(variable_groups=[A, B, C]) @@ -117,8 +119,8 @@ def test_wiring_with_ORFactorGroup(): factor = logical_factor.ORFactor( variables=[A[idx], B[idx], C[idx]], ) - fg3.add_factors(factor=factor) - assert len(fg3.factor_groups) == 10 + fg3.add_factors(factor) + assert len(fg3.factor_groups[logical_factor.ORFactor]) == 10 assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors) @@ -155,7 +157,7 @@ def test_wiring_with_ANDFactorGroup(): variables_for_factors=[[A[idx], B[idx], C[idx]] for idx in range(10)], ) fg1.add_factors(factor_group) - assert len(fg1.factor_groups) == 1 + assert len(fg1.factor_groups[logical_factor.ANDFactor]) == 1 # FactorGraph with multiple ANDFactorGroup fg2 = graph.FactorGraph(variable_groups=[A, B, C]) @@ -164,7 +166,7 @@ def test_wiring_with_ANDFactorGroup(): variables_for_factors=[[A[idx], B[idx], C[idx]]], ) fg2.add_factors(factor_group) - assert len(fg2.factor_groups) == 10 + assert len(fg2.factor_groups[logical_factor.ANDFactor]) == 10 # FactorGraph with multiple SingleFactorGroup fg3 = graph.FactorGraph(variable_groups=[A, B, C]) @@ -172,8 +174,8 @@ def test_wiring_with_ANDFactorGroup(): factor = logical_factor.ANDFactor( variables=[A[idx], B[idx], C[idx]], ) - fg3.add_factors(factor=factor) - assert len(fg3.factor_groups) == 10 + fg3.add_factors(factor) + assert len(fg3.factor_groups[logical_factor.ANDFactor]) == 10 assert len(fg1.factors) == len(fg2.factors) == len(fg3.factors) diff --git a/tests/test_pgmax.py b/tests/test_pgmax.py index 9e6deb65..f3ee7bfa 100644 --- a/tests/test_pgmax.py +++ b/tests/test_pgmax.py @@ -207,7 +207,7 @@ def create_valid_suppression_config_arr(suppression_diameter): extra_row_names: List[Tuple[Any, ...]] = [(0, row, N - 1) for row in range(M - 1)] extra_col_names: List[Tuple[Any, ...]] = [(1, M - 1, col) for col in range(N - 1)] additional_names = tuple(extra_row_names + extra_col_names) - additional_vars = vgroup.VariableDict(additional_names, num_states=3) + additional_vars = vgroup.VariableDict(variable_names=additional_names, num_states=3) true_map_state_output = { (grid_vars, (0, 0, 0)): 2, @@ -218,10 +218,14 @@ def create_valid_suppression_config_arr(suppression_diameter): (grid_vars, (1, 0, 1)): 0, (grid_vars, (1, 1, 0)): 1, (grid_vars, (1, 1, 1)): 0, - (additional_vars, ((additional_vars.__hash__(), (0, 0, 2)), 3)): 0, - (additional_vars, ((additional_vars.__hash__(), (0, 1, 2)), 3)): 2, - (additional_vars, ((additional_vars.__hash__(), (1, 2, 0)), 3)): 1, - (additional_vars, ((additional_vars.__hash__(), (1, 2, 1)), 3)): 0, + # (additional_vars, ((additional_vars.__hash__(), (0, 0, 2)), 3)): 0, + # (additional_vars, ((additional_vars.__hash__(), (0, 1, 2)), 3)): 2, + # (additional_vars, ((additional_vars.__hash__(), (1, 2, 0)), 3)): 1, + # (additional_vars, ((additional_vars.__hash__(), (1, 2, 1)), 3)): 0, + (additional_vars, (0, 0, 2)): 0, + (additional_vars, (0, 1, 2)): 2, + (additional_vars, (1, 2, 0)): 1, + (additional_vars, (1, 2, 1)): 0, } gt_has_cuts = gt_has_cuts.astype(np.int32) @@ -251,9 +255,7 @@ def create_valid_suppression_config_arr(suppression_diameter): except IndexError: try: _ = additional_vars[i, row, col] - additional_vars_evidence_dict[ - additional_vars[i, row, col] - ] = evidence_vals_arr + additional_vars_evidence_dict[i, row, col] = evidence_vals_arr except ValueError: pass @@ -302,7 +304,7 @@ def create_valid_suppression_config_arr(suppression_diameter): valid_configs_non_supp.shape[0], dtype=float ), ) - fg.add_factors(factor=factor) + fg.add_factors(factor) else: factor = EnumerationFactor( variables=curr_vars, @@ -311,7 +313,7 @@ def create_valid_suppression_config_arr(suppression_diameter): valid_configs_non_supp.shape[0], dtype=float ), ) - fg.add_factors(factor=factor) + fg.add_factors(factor) # Create an EnumerationFactorGroup for vertical suppression factors vert_suppression_vars: List[List[Tuple[Any, ...]]] = [] From 05de350e4d20010690f8a3a1fb760cdcccaee13d Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Wed, 27 Apr 2022 23:20:34 +0000 Subject: [PATCH 29/35] Minor --- pgmax/factors/enumeration.py | 5 ++--- pgmax/factors/logical.py | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/pgmax/factors/enumeration.py b/pgmax/factors/enumeration.py index ef2a99a7..2340d0c1 100644 --- a/pgmax/factors/enumeration.py +++ b/pgmax/factors/enumeration.py @@ -2,7 +2,7 @@ import functools from dataclasses import dataclass -from typing import Any, List, Mapping, Sequence, Tuple, Union +from typing import List, Mapping, Sequence, Tuple, Union import jax import jax.numpy as jnp @@ -157,7 +157,7 @@ def compile_wiring( factor_edges_num_states: np.ndarray, variables_for_factors: Sequence[List], factor_configs: np.ndarray, - vars_to_starts: Mapping[Tuple[Any, int], int], + vars_to_starts: Mapping[Tuple[int, int], int], num_factors: int, ) -> EnumerationWiring: """Compile an EnumerationWiring for an EnumerationFactor or a FactorGroup with EnumerationFactors. @@ -183,7 +183,6 @@ def compile_wiring( Returns: The EnumerationWiring """ - # TODO: Don't use vars_to_starts var_states = [] for variables_for_factor in variables_for_factors: for variable in variables_for_factor: diff --git a/pgmax/factors/logical.py b/pgmax/factors/logical.py index 8b931e47..482865b2 100644 --- a/pgmax/factors/logical.py +++ b/pgmax/factors/logical.py @@ -2,7 +2,7 @@ import functools from dataclasses import dataclass, field -from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union +from typing import List, Mapping, Optional, Sequence, Tuple, Union import jax import jax.numpy as jnp @@ -143,7 +143,7 @@ def compile_wiring( factor_edges_num_states: np.ndarray, variables_for_factors: Sequence[List], factor_sizes: np.ndarray, - vars_to_starts: Mapping[Tuple[Any, int], int], + vars_to_starts: Mapping[Tuple[int, int], int], edge_states_offset: int, ) -> LogicalWiring: """Compile a LogicalWiring for a LogicalFactor or a FactorGroup with LogicalFactors. @@ -164,7 +164,6 @@ def compile_wiring( Returns: The LogicalWiring """ - # TODO: Don't use vars_to_starts var_states = [] for variables_for_factor in variables_for_factors: for variable in variables_for_factor: From 1e0c93d6675d6eae41fba7c03ff25b67ab0b8021 Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Wed, 27 Apr 2022 23:50:35 +0000 Subject: [PATCH 30/35] Docstring --- pgmax/fg/groups.py | 2 +- pgmax/fg/nodes.py | 4 ++-- pgmax/groups/variables.py | 2 +- tests/fg/test_groups.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index e8079a36..e705f477 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -271,7 +271,7 @@ def unflatten(self, flat_data: Union[np.ndarray, jnp.ndarray]) -> Any: "Please subclass the FactorGroup class and override this method" ) - def compile_wiring(self, vars_to_starts: Mapping[Tuple[Any, int], int]) -> Any: + def compile_wiring(self, vars_to_starts: Mapping[Tuple[int, int], int]) -> Any: """Compile an efficient wiring for the FactorGroup. Args: diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py index f987b8df..06d7c5d6 100644 --- a/pgmax/fg/nodes.py +++ b/pgmax/fg/nodes.py @@ -1,7 +1,7 @@ """A module containing classes that specify the basic components of a Factor Graph.""" from dataclasses import asdict, dataclass -from typing import Any, List, Sequence, Tuple, Union +from typing import List, Sequence, Tuple, Union import jax import jax.numpy as jnp @@ -48,7 +48,7 @@ class Factor: NotImplementedError: If compile_wiring is not implemented """ - variables: List[Tuple[Any, int]] + variables: List[Tuple[int, int]] log_potentials: np.ndarray def __post_init__(self): diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index 22a7ef0e..87371c95 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -212,7 +212,7 @@ def __getitem__(self, var_name: Any) -> Tuple[int, int]: return (self.variable_hashes[var_idx], self.num_states[var_idx]) def flatten( - self, data: Mapping[Tuple[Tuple[int, int], int], Union[np.ndarray, jnp.ndarray]] + self, data: Mapping[Any, Union[np.ndarray, jnp.ndarray]] ) -> jnp.ndarray: """Function that turns meaningful structured data into a flat data array for internal use. diff --git a/tests/fg/test_groups.py b/tests/fg/test_groups.py index 7799ffb0..e39fbda3 100644 --- a/tests/fg/test_groups.py +++ b/tests/fg/test_groups.py @@ -24,7 +24,7 @@ def test_variable_dict(): "num_states should be an integer or a NumPy array of dtype int" ), ): - vgroup.NDVariableArray(shape=(2, 2), num_states=num_states) + vgroup.VariableDict(variable_names=tuple([0, 1, 2]), num_states=num_states) variable_dict = vgroup.VariableDict(variable_names=tuple([0, 1, 2]), num_states=15) with pytest.raises( From 63c6738ef19d1dfcb5ca54b852bd9a6e8f99e87f Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Thu, 28 Apr 2022 18:16:35 +0000 Subject: [PATCH 31/35] Minor changes --- examples/pmp_binary_deconvolution.py | 2 +- examples/rcn.py | 4 ++-- pgmax/fg/graph.py | 6 +++--- pgmax/fg/groups.py | 8 ++++---- pgmax/fg/nodes.py | 2 +- pgmax/groups/variables.py | 2 -- 6 files changed, 11 insertions(+), 13 deletions(-) diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py index a8b992e5..190abe36 100644 --- a/examples/pmp_binary_deconvolution.py +++ b/examples/pmp_binary_deconvolution.py @@ -220,7 +220,7 @@ def plot_images(images, display=True, nr=None): # %% pW = 0.25 -pS = 1e-70 +pS = 1e-100 pX = 1e-100 # Sparsity inducing priors for W and S diff --git a/examples/rcn.py b/examples/rcn.py index 140c1712..e83cf10c 100644 --- a/examples/rcn.py +++ b/examples/rcn.py @@ -417,8 +417,8 @@ def get_evidence(bu_msg: np.ndarray, frc: np.ndarray) -> np.ndarray: evidence_updates, map_states, ) - for idx, score in enumerate(score.values()): - scores[test_idx, idx] = score + for model_idx in range(frcs.shape[0]): + scores[test_idx, model_idx] = score[variables_all_models[model_idx]] end = time.time() print(f"Computing scores took {end-start:.3f} seconds for image {test_idx}.") diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 21ea1f11..77a51f71 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -323,7 +323,7 @@ class FactorGraphState: """ variable_groups: Sequence[groups.VariableGroup] - vars_to_starts: Mapping[Tuple[Any, int], int] + vars_to_starts: Mapping[Tuple[int, int], int] num_var_states: int total_factor_num_states: int factor_type_to_msgs_range: OrderedDict[type, Tuple[int, int]] @@ -460,7 +460,7 @@ def __getitem__(self, factor_group: groups.FactorGroup) -> np.ndarray: def __setitem__( self, - factor_group: Any, + factor_group: groups.FactorGroup, data: Union[np.ndarray, jnp.ndarray], ): """Set the log potentials for a FactorGroup @@ -563,7 +563,7 @@ def __setitem__( """Spreading beliefs at a variable to all connected Factors Args: - variable: A tuple representing a variable + variable: Variable queried data: An array containing the beliefs to be spread uniformly across all factors to variable messages involving this variable. """ diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index e705f477..311d53b1 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -29,14 +29,14 @@ @dataclass(frozen=True, eq=False) class VariableGroup: """Class to represent a group of variables. - Each variable is represented via a tuple of the form (variable hash/name, number of states) + Each variable is represented via a tuple of the form (variable hash, variable num_states) Arguments: num_states: An integer or an array specifying the number of states of the variables in this VariableGroup """ - num_states: Union[int, np.ndarray] = field(init=False) + num_states: Union[int, np.ndarray] def __post_init__(self): # Only compute the hash once, which is guaranteed to be an int64 @@ -56,7 +56,7 @@ def __lt__(self, other): def __getitem__(self, val: Any) -> Union[Tuple[int, int], List[Tuple[int, int]]]: """Given a variable name, index, or a group of variable indices, retrieve the associated variable(s). - Each variable is returned via a tuple of the form (variable hash, number of states) + Each variable is returned via a tuple of the form (variable hash, variable num_states) Args: val: a variable index, slice, or name @@ -82,7 +82,7 @@ def variable_hashes(self) -> np.ndarray: @cached_property def variables(self) -> List[Tuple[int, int]]: """Function that returns the list of all variables in the VariableGroup. - Each variable is represented by a tuple of the form (variable hash, number of states) + Each variable is represented by a tuple of the form (variable hash, variable num_states) Returns: List of variables in the VariableGroup diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py index 06d7c5d6..37cfce13 100644 --- a/pgmax/fg/nodes.py +++ b/pgmax/fg/nodes.py @@ -42,7 +42,7 @@ class Factor: Args: variables: List of variables connected by the Factor. - Each variable is represented by a tuple of the form (variable hash/name, number of states) + Each variable is represented by a tuple (variable hash, variable num_ states) Raises: NotImplementedError: If compile_wiring is not implemented diff --git a/pgmax/groups/variables.py b/pgmax/groups/variables.py index 87371c95..357d681a 100644 --- a/pgmax/groups/variables.py +++ b/pgmax/groups/variables.py @@ -22,7 +22,6 @@ class NDVariableArray(groups.VariableGroup): the notion of a NumPy ndarray shape) """ - num_states: Union[int, np.ndarray] shape: Tuple[int, ...] def __post_init__(self): @@ -161,7 +160,6 @@ class VariableDict(groups.VariableGroup): variable_names: A tuple of all the names of the variables in this VariableGroup. """ - num_states: Union[int, np.ndarray] variable_names: Tuple[Any, ...] def __post_init__(self): From 0397f9b788c95666bab3dad6680d380731ec59f7 Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Thu, 28 Apr 2022 18:24:52 +0000 Subject: [PATCH 32/35] Doc --- pgmax/fg/nodes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pgmax/fg/nodes.py b/pgmax/fg/nodes.py index 37cfce13..8388a61e 100644 --- a/pgmax/fg/nodes.py +++ b/pgmax/fg/nodes.py @@ -42,7 +42,7 @@ class Factor: Args: variables: List of variables connected by the Factor. - Each variable is represented by a tuple (variable hash, variable num_ states) + Each variable is represented by a tuple (variable hash, variable num_states) Raises: NotImplementedError: If compile_wiring is not implemented From 8b3d60e131e9d31d86db9b1647b94f9b81244620 Mon Sep 17 00:00:00 2001 From: stannis <stannis@vicarious.com> Date: Fri, 29 Apr 2022 22:51:02 -0700 Subject: [PATCH 33/35] Rename this_hash --- pgmax/fg/groups.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 311d53b1..2a3d691a 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -41,12 +41,12 @@ class VariableGroup: def __post_init__(self): # Only compute the hash once, which is guaranteed to be an int64 this_id = id(self) % 2**32 - this_hash = this_id * int(MAX_SIZE) - assert this_hash < 2**63 - object.__setattr__(self, "this_hash", this_hash) + _hash = this_id * int(MAX_SIZE) + assert _hash < 2**63 + object.__setattr__(self, "_hash", _hash) def __hash__(self): - return self.this_hash + return self._hash def __eq__(self, other): return hash(self) == hash(other) From 8751e90c243564f4e43ec62e2f68cf4d1523b5b7 Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Mon, 2 May 2022 18:37:53 +0000 Subject: [PATCH 34/35] Final comments --- examples/gmrf.py | 14 +++---- examples/pmp_binary_deconvolution.py | 24 +++-------- examples/rbm.py | 59 +++++++++------------------- pgmax/fg/graph.py | 1 - pgmax/fg/groups.py | 1 - 5 files changed, 30 insertions(+), 69 deletions(-) diff --git a/examples/gmrf.py b/examples/gmrf.py index 1bf4fad9..50e95dc7 100644 --- a/examples/gmrf.py +++ b/examples/gmrf.py @@ -58,7 +58,7 @@ fg = graph.FactorGraph(variables) # %% -# Add top-down factors +# Create top-down factors top_down = enumeration.PairwiseFactorGroup( variables_for_factors=[ [variables[ii, jj], variables[ii + 1, jj]] @@ -66,9 +66,8 @@ for jj in range(N) ], ) -fg.add_factors(top_down) -# Add left-right factors +# Create left-right factors left_right = enumeration.PairwiseFactorGroup( variables_for_factors=[ [variables[ii, jj], variables[ii, jj + 1]] @@ -76,9 +75,8 @@ for jj in range(N - 1) ], ) -fg.add_factors(left_right) -# Add diagonal factors +# Create diagonal factors diagonal0 = enumeration.PairwiseFactorGroup( variables_for_factors=[ [variables[ii, jj], variables[ii + 1, jj + 1]] @@ -86,8 +84,6 @@ for jj in range(N - 1) ], ) -fg.add_factors(diagonal0) - diagonal1 = enumeration.PairwiseFactorGroup( variables_for_factors=[ [variables[ii, jj], variables[ii - 1, jj + 1]] @@ -95,7 +91,9 @@ for jj in range(N - 1) ], ) -fg.add_factors(diagonal1) + +# Add factors +fg.add_factors([top_down, left_right, diagonal0, diagonal1]) # %% bp = graph.BP(fg.bp_state, temperature=1.0) diff --git a/examples/pmp_binary_deconvolution.py b/examples/pmp_binary_deconvolution.py index 190abe36..43e628af 100644 --- a/examples/pmp_binary_deconvolution.py +++ b/examples/pmp_binary_deconvolution.py @@ -116,9 +116,6 @@ def plot_images(images, display=True, nr=None): s_height = im_height - feat_height + 1 s_width = im_width - feat_width + 1 -import time - -start = time.time() # Binary features W = vgroup.NDVariableArray( num_states=2, shape=(n_chan, n_feat, feat_height, feat_width) @@ -135,16 +132,13 @@ def plot_images(images, display=True, nr=None): # Binary images obtained by convolution X = vgroup.NDVariableArray(num_states=2, shape=X_gt.shape) -print("Time", time.time() - start) # %% [markdown] -# For computation efficiency, we construct large FactorGroups instead of individual Factors +# For computation efficiency, we construct large FactorGroups instead of individual factors # %% -start = time.time() # Factor graph fg = graph.FactorGraph(variable_groups=[S, W, SW, X]) -print(time.time() - start) # Define the ANDFactors variables_for_ANDFactors = [] @@ -177,12 +171,10 @@ def plot_images(images, display=True, nr=None): X_var = X[idx_img, idx_chan, idx_img_height, idx_img_width] variables_for_ORFactors_dict[X_var].append(SW_var) -print("After loop", time.time() - start) # Add ANDFactorGroup, which is computationally efficient AND_factor_group = logical.ANDFactorGroup(variables_for_ANDFactors) fg.add_factors(AND_factor_group) -print(time.time() - start) # Define the ORFactors variables_for_ORFactors = [ @@ -193,7 +185,6 @@ def plot_images(images, display=True, nr=None): # Add ORFactorGroup, which is computationally efficient OR_factor_group = logical.ORFactorGroup(variables_for_ORFactors) fg.add_factors(OR_factor_group) -print("Time", time.time() - start) for factor_type, factor_groups in fg.factor_groups.items(): if len(factor_groups) > 0: @@ -211,16 +202,14 @@ def plot_images(images, display=True, nr=None): # in the same manner does not change X, so this naturally results in multiple equivalent modes. # %% -start = time.time() bp = graph.BP(fg.bp_state, temperature=0.0) -print("Time", time.time() - start) # %% [markdown] # We first compute the evidence without perturbation, similar to the PMP paper. # %% pW = 0.25 -pS = 1e-100 +pS = 1e-75 pX = 1e-100 # Sparsity inducing priors for W and S @@ -235,13 +224,12 @@ def plot_images(images, display=True, nr=None): uX[..., 0] = (2 * X_gt - 1) * logit(pX) # %% [markdown] -# We draw a batch of samples from the posterior in parallel by transforming `run_bp`/`get_beliefs` with `jax.vmap` +# We draw a batch of samples from the posterior in parallel by transforming `bp.init`/`bp.run_bp`/`bp.get_beliefs` with `jax.vmap` # %% -np.random.seed(seed=42) +np.random.seed(seed=0) n_samples = 4 -start = time.time() bp_arrays = jax.vmap(bp.init, in_axes=0, out_axes=0)( evidence_updates={ S: uS[None] + np.random.gumbel(size=(n_samples,) + uS.shape), @@ -250,13 +238,13 @@ def plot_images(images, display=True, nr=None): X: uX[None] + np.zeros(shape=(n_samples,) + uX.shape), }, ) -print("Time", time.time() - start) + bp_arrays = jax.vmap( functools.partial(bp.run_bp, num_iters=100, damping=0.5), in_axes=0, out_axes=0, )(bp_arrays) -print("Time", time.time() - start) + beliefs = jax.vmap(bp.get_beliefs, in_axes=0, out_axes=0)(bp_arrays) map_states = graph.decode_map_states(beliefs) diff --git a/examples/rbm.py b/examples/rbm.py index 4b391a62..c29efbd7 100644 --- a/examples/rbm.py +++ b/examples/rbm.py @@ -46,37 +46,30 @@ # We can then initialize the factor graph for the RBM with # %% -import time - -start = time.time() # Initialize factor graph hidden_variables = vgroup.NDVariableArray(num_states=2, shape=bh.shape) visible_variables = vgroup.NDVariableArray(num_states=2, shape=bv.shape) fg = graph.FactorGraph(variable_groups=[hidden_variables, visible_variables]) -print("Time", time.time() - start) # %% [markdown] # [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup) (e.g. an [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)), or a list/dictionary of [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed. # -# After initialization, `fg` does not have any factors. PGMax supports imperatively adding factors to a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph). We can add the unary and pairwise factors by grouping them using +# After initialization, `fg` does not have any factors. PGMax supports imperatively adding factors to a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph). We can add the unary and pairwise factors by grouping them using [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) # %% -start = time.time() - -# Add unary factors +# Create unary factors hidden_unaries = enumeration.EnumerationFactorGroup( variables_for_factors=[[hidden_variables[ii]] for ii in range(bh.shape[0])], factor_configs=np.arange(2)[:, None], log_potentials=np.stack([np.zeros_like(bh), bh], axis=1), ) - visible_unaries = enumeration.EnumerationFactorGroup( variables_for_factors=[[visible_variables[jj]] for jj in range(bv.shape[0])], factor_configs=np.arange(2)[:, None], log_potentials=np.stack([np.zeros_like(bv), bv], axis=1), ) -# Add pairwise factors +# Create pairwise factors log_potential_matrix = np.zeros(W.shape + (2, 2)).reshape((-1, 2, 2)) log_potential_matrix[:, 1, 1] = W.ravel() @@ -85,25 +78,23 @@ for ii in range(bh.shape[0]) for jj in range(bv.shape[0]) ] -print("Time", time.time() - start) pairwise_factors = enumeration.PairwiseFactorGroup( variables_for_factors=variables_for_factors, log_potential_matrix=log_potential_matrix, ) -print("Time", time.time() - start) +# Add factors to the FactorGraph fg.add_factors([hidden_unaries, visible_unaries, pairwise_factors]) -print("Time", time.time() - start) # %% [markdown] -# PGMax implements convenient and computationally efficient [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) for representing Groups of similar factors. The code above makes use of [`EnumerationFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.EnumerationFactorGroup.html#pgmax.fg.groups.EnumerationFactorGroup) and [`PairwiseFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.PairwiseFactorGroup.html#pgmax.fg.groups.PairwiseFactorGroup), two [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup)s implemented in the [`pgmax.fg.groups`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.html#module-pgmax.fg.graph) module. +# PGMax implements convenient and computationally efficient [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) for representing groups of similar factors. The code above makes use of [`EnumerationFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.EnumerationFactorGroup.html#pgmax.fg.groups.EnumerationFactorGroup) and [`PairwiseFactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.PairwiseFactorGroup.html#pgmax.fg.groups.PairwiseFactorGroup). # -# A [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) is created by calling [`fg.add_factor_group`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph.add_factor_group), which takes 2 arguments: `factory` which specifies the [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) subclass, `variable_names_for_factors` which is a list of lists containing the name of the involved variables in the different factors, and additional arguments for the [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) (e.g. `factor_configs` or `log_potential_matrix` here). +# A [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) takes as argument `variables_for_factors` which is a list of lists of the variables involved in the different factors, and additional arguments specific to each [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup) (e.g. `factor_configs` or `log_potential_matrix` here). # -# In this example, since we construct `fg` with variables `dict(hidden=hidden_variables, visible=visible_variables)`, where `hidden_variables` and `visible_variables` are [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)s, we can refer to the `ii`th hidden variable as `("hidden", ii)` and the `jj`th visible variable as `("visible", jj)`. In general, PGMax implements an intuitive scheme for automatically assigning names to the variables in a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph). +# In this example, since we construct `fg` with variables `hidden_variables` and `visible_variables`, which are both [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)s, we can refer to the `ii`th hidden variable as `hidden_variables[ii]` and the `jj`th visible variable as `visible_variables[jj]`. # -# An alternative way of creating the above factors is to add them iteratively by calling [`fg.add_factor`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph.add_factor) as below. This approach is not recommended as it is not computationally efficient. +# An alternative way of creating the above factors is to add them iteratively without building the [`FactorGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.FactorGroup.html#pgmax.fg.groups.FactorGroup)s as below. This approach is not recommended as it is not computationally efficient. # ~~~python # from pgmax.factors import enumeration as enumeration_factor # import itertools @@ -155,30 +146,28 @@ # More generally, PGMax implements LBP with temperature, with `temperature=0.0` and `temperature=1.0` corresponding to the commonly used max/sum-product LBP respectively. # # Now we are ready to demonstrate PMP sampling from RBM. PMP perturbs the model with [Gumbel](https://numpy.org/doc/stable/reference/random/generated/numpy.random.gumbel.html) unary potentials, and draws a sample from the RBM as the MAP decoding from running max-product LBP on the perturbed model +# +# import itertools +# +# from tqdm import tqdm # %% -start = time.time() bp = graph.BP(fg.bp_state, temperature=0.0) -print("Time", time.time() - start) # %% -start = time.time() bp_arrays = bp.init( evidence_updates={ hidden_variables: np.random.gumbel(size=(bh.shape[0], 2)), visible_variables: np.random.gumbel(size=(bv.shape[0], 2)), } ) -print("Time", time.time() - start) bp_arrays = bp.run_bp(bp_arrays, num_iters=100, damping=0.5) -print("Time", time.time() - start) beliefs = bp.get_beliefs(bp_arrays) -print("Time", time.time() - start) # %% [markdown] -# Here we use the `evidence_updates` argument of `run_bp` to perturb the model with Gumbel unary potentials. In general, `evidence_updates` can be used to incorporate evidence in the form of externally applied unary potentials in PGM inference. +# Here we use the `evidence_updates` argument of `bp.init` to perturb the model with Gumbel unary potentials. In general, `evidence_updates` can be used to incorporate evidence in the form of externally applied unary potentials in PGM inference. # -# Visualizing the MAP decoding (Figure [fig:rbm_single_digit]), we see that we have sampled an MNIST digit! +# Visualizing the MAP decoding, we see that we have sampled an MNIST digit! # %% fig, ax = plt.subplots(1, 1, figsize=(10, 10)) @@ -191,23 +180,11 @@ # %% [markdown] # PGMax adopts a functional interface for implementing LBP: running LBP in PGMax starts with # ~~~python -# run_bp, get_bp_state, get_beliefs = graph.BP(fg.bp_state, num_iters=NUM_ITERS, temperature=T) +# bp = graph.BP(fg.bp_state, temperature=T) # ~~~ -# where `run_bp` and `get_beliefs` are pure functions with no side-effects. This design choice means that we can easily apply JAX transformations like `jit`/`vmap`/`grad`, etc., to these functions, and additionally allows PGMax to seamlessly interact with other packages in the rapidly growing JAX ecosystem (see [here](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) and [here](https://github.com/n2cholas/awesome-jax)). In what follows we demonstrate an example on applying `jax.vmap`, a convenient transformation for automatically vectorizing functions. +# where the arguments of the `bp` are several useful functions to run LBP. In particular, `bp.init`, `bp.run_bp`, `bp.get_beliefs` are pure functions with no side-effects. This design choice means that we can easily apply JAX transformations like `jit`/`vmap`/`grad`, etc., to these functions, and additionally allows PGMax to seamlessly interact with other packages in the rapidly growing JAX ecosystem (see [here](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research) and [here](https://github.com/n2cholas/awesome-jax)). # -# Since we implement `run_bp`/`get_beliefs` as a pure function, we can apply `jax.vmap` to `run_bp`/`get_beliefs` to process a batch of samples/models in parallel. As an example, consider the PGMax implementation of PMP sampling from the RBM trained on MNIST images in Section [Tutorial: implementing LBP inference for RBMs with PGMax]. Instead of drawing one sample at a time -# ~~~python -# bp_arrays = run_bp( -# evidence_updates={ -# hidden_variables: np.random.gumbel(size=(bh.shape[0], 2)), -# visible_variables: np.random.gumbel(size=(bv.shape[0], 2)), -# }, -# damping=0.5, -# ) -# beliefs = get_beliefs(bp_arrays) -# map_states = graph.decode_map_states(beliefs) -# ~~~ -# we can draw a batch of samples in parallel by transforming `run_bp`/`get_beliefs` with `jax.vmap` +# As an example of applying `jax.vmap` to `bp.init`/`bp.run_bp`/`bp.get_beliefs` to process a batch of samples/models in parallel, instead of drawing one sample at a time as above, we can draw a batch of samples in parallel as follows: # %% n_samples = 10 @@ -227,7 +204,7 @@ map_states = graph.decode_map_states(beliefs) # %% [markdown] -# Visualizing the MAP decodings (Figure [fig:rbm_multiple_digits]), we see that we have sampled 10 MNIST digits in parallel! +# Visualizing the MAP decodings, we see that we have sampled 10 MNIST digits in parallel! # %% fig, ax = plt.subplots(2, 5, figsize=(20, 8)) diff --git a/pgmax/fg/graph.py b/pgmax/fg/graph.py index 77a51f71..16088e9f 100644 --- a/pgmax/fg/graph.py +++ b/pgmax/fg/graph.py @@ -65,7 +65,6 @@ def __post_init__(self): ) # See FactorGraphState docstrings for documentation on the following fields - # TODO: vars_to_starts does not have to be a dict self._vars_to_starts: Dict[Tuple[int, int], int] = {} vars_num_states_cumsum = 0 diff --git a/pgmax/fg/groups.py b/pgmax/fg/groups.py index 2a3d691a..e772095a 100644 --- a/pgmax/fg/groups.py +++ b/pgmax/fg/groups.py @@ -200,7 +200,6 @@ def factor_edges_num_states(self) -> np.ndarray: """ factor_edges_num_states = np.empty(shape=(self.factor_sizes.sum(),), dtype=int) - # TODO: create variables_for_factors as an array and move this to numba idx = 0 for variables_for_factor in self.variables_for_factors: for variable in variables_for_factor: From b265f14c42b657c88103f35e97047f92426575e7 Mon Sep 17 00:00:00 2001 From: Antoine Dedieu <antoine@vicarious.com> Date: Mon, 2 May 2022 18:39:20 +0000 Subject: [PATCH 35/35] Minor --- examples/rbm.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/examples/rbm.py b/examples/rbm.py index c29efbd7..5e751f44 100644 --- a/examples/rbm.py +++ b/examples/rbm.py @@ -146,10 +146,6 @@ # More generally, PGMax implements LBP with temperature, with `temperature=0.0` and `temperature=1.0` corresponding to the commonly used max/sum-product LBP respectively. # # Now we are ready to demonstrate PMP sampling from RBM. PMP perturbs the model with [Gumbel](https://numpy.org/doc/stable/reference/random/generated/numpy.random.gumbel.html) unary potentials, and draws a sample from the RBM as the MAP decoding from running max-product LBP on the perturbed model -# -# import itertools -# -# from tqdm import tqdm # %% bp = graph.BP(fg.bp_state, temperature=0.0)