Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.

Numba speedup for wiring + log potentials #133

Merged
merged 19 commits into from
Apr 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/pmp_binary_deconvolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def plot_images(images, display=True, nr=None):

# %%
pW = 0.25
pS = 1e-72
pS = 1e-100
pX = 1e-100

# Sparsity inducing priors for W and S
Expand All @@ -240,7 +240,7 @@ def plot_images(images, display=True, nr=None):
# We draw a batch of samples from the posterior in parallel by transforming `run_bp`/`get_beliefs` with `jax.vmap`

# %%
np.random.seed(seed=42)
np.random.seed(seed=40)
n_samples = 4

bp_arrays = jax.vmap(functools.partial(run_bp, damping=0.5), in_axes=0, out_axes=0)(
Expand Down
4 changes: 2 additions & 2 deletions examples/rbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# ---

# %% [markdown]
# [Restricted Boltzmann Machine (RBM)](https://en.wikipedia.org/wiki/Restricted_Boltzmann_machine) is a well-known and widely used PGM for learning probabilistic distributions over binary data. we demonstrate how we can easily implement [perturb-and-max-product (PMP)](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) sampling from an RBM trained on MNIST digits using PGMax. PMP is a recently proposed method for approximately sampling from a PGM by computing the maximum-a-posteriori (MAP) configuration (using max-product LBP) of a perturbed version of the model.
# [Restricted Boltzmann Machine (RBM)](https://en.wikipedia.org/wiki/Restricted_Boltzmann_machine) is a well-known and widely used PGM for learning probabilistic distributions over binary data. We demonstrate how we can easily implement [perturb-and-max-product (PMP)](https://proceedings.neurips.cc/paper/2021/hash/07b1c04a30f798b5506c1ec5acfb9031-Abstract.html) sampling from an RBM trained on MNIST digits using PGMax. PMP is a recently proposed method for approximately sampling from a PGM by computing the maximum-a-posteriori (MAP) configuration (using max-product LBP) of a perturbed version of the model.
#
# We start by making some necessary imports.

Expand Down Expand Up @@ -56,7 +56,7 @@
# %% [markdown]
# [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray) is a convenient class for specifying a group of variables living on a multidimensional grid with the same number of states, and shares some similarities with [`numpy.ndarray`](https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html). The [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph) `fg` is initialized with a set of variables, which can be either a single [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup) (e.g. an [`NDVariableArray`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.NDVariableArray.html#pgmax.fg.groups.NDVariableArray)), or a list/dictionary of [`VariableGroup`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.groups.VariableGroup.html#pgmax.fg.groups.VariableGroup)s. Once initialized, the set of variables in `fg` is fixed and cannot be changed.
#
# After initialization, `fg` does not have any factors. PGMax supports imperatively adding factors to a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph). We can add the unary and pairwise factors one at a time to `fg` by
# After initialization, `fg` does not have any factors. PGMax supports imperatively adding factors to a [`FactorGraph`](https://pgmax.readthedocs.io/en/latest/_autosummary/pgmax.fg.graph.FactorGraph.html#pgmax.fg.graph.FactorGraph). We can add the unary and pairwise factors by grouping them using

# %%
# Add unary factors
Expand Down
225 changes: 128 additions & 97 deletions pgmax/factors/enumeration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import jax
import jax.numpy as jnp
import numba as nb
import numpy as np

from pgmax.bp import bp_utils
Expand All @@ -23,106 +24,81 @@ class EnumerationWiring(nodes.Wiring):
factor_configs_edge_states[ii, 0] contains the global EnumerationFactor config index,
factor_configs_edge_states[ii, 1] contains the corresponding global edge_state index.
Both indices only take into account the EnumerationFactors of the FactorGraph

Attributes:
num_val_configs: Number of valid configurations for this wiring
"""

factor_configs_edge_states: Union[np.ndarray, jnp.ndarray]

@property
def inference_arguments(self) -> Mapping[str, Union[np.ndarray, int]]:
"""
Returns:
            A dictionary of elements used to run belief propagation.
"""
def __post_init__(self):
super().__post_init__()

if self.factor_configs_edge_states.shape[0] == 0:
num_val_configs = 0
else:
num_val_configs = int(self.factor_configs_edge_states[-1, 0]) + 1

return {
"factor_configs_edge_states": self.factor_configs_edge_states,
"num_val_configs": num_val_configs,
}
object.__setattr__(self, "num_val_configs", num_val_configs)


@dataclass(frozen=True, eq=False)
class EnumerationFactor(nodes.Factor):
"""An enumeration factor

Args:
configs: Array of shape (num_val_configs, num_variables)
factor_configs: Array of shape (num_val_configs, num_variables)
An array containing an explicit enumeration of all valid configurations
log_potentials: Array of shape (num_val_configs,)
An array containing the log of the potential value for each valid configuration

Raises:
ValueError: If:
(1) The dtype of the configs array is not int
(1) The dtype of the factor_configs array is not int
(2) The dtype of the potential array is not float
(3) Configs does not have the correct shape
(3) factor_configs does not have the correct shape
(4) The potential array does not have the correct shape
(5) The configs array contains invalid values
(5) The factor_configs array contains invalid values
"""

configs: np.ndarray
factor_configs: np.ndarray
log_potentials: np.ndarray

def __post_init__(self):
self.configs.flags.writeable = False
if not np.issubdtype(self.configs.dtype, np.integer):
self.factor_configs.flags.writeable = False
if not np.issubdtype(self.factor_configs.dtype, np.integer):
raise ValueError(
f"Configurations should be integers. Got {self.configs.dtype}."
f"Configurations should be integers. Got {self.factor_configs.dtype}."
)

if not np.issubdtype(self.log_potentials.dtype, np.floating):
raise ValueError(
f"Potential should be floats. Got {self.log_potentials.dtype}."
)

if self.configs.ndim != 2:
if self.factor_configs.ndim != 2:
raise ValueError(
"configs should be a 2D array containing a list of valid configurations for "
f"EnumerationFactor. Got a configs array of shape {self.configs.shape}."
"factor_configs should be a 2D array containing a list of valid configurations for "
f"EnumerationFactor. Got a factor_configs array of shape {self.factor_configs.shape}."
)

if len(self.variables) != self.configs.shape[1]:
if len(self.variables) != self.factor_configs.shape[1]:
raise ValueError(
f"Number of variables {len(self.variables)} doesn't match given configurations {self.configs.shape}"
f"Number of variables {len(self.variables)} doesn't match given configurations {self.factor_configs.shape}"
)

if self.log_potentials.shape != (self.configs.shape[0],):
if self.log_potentials.shape != (self.factor_configs.shape[0],):
raise ValueError(
f"Expected log potentials of shape {(self.configs.shape[0],)} for "
f"({self.configs.shape[0]}) valid configurations. Got log potentials of "
f"Expected log potentials of shape {(self.factor_configs.shape[0],)} for "
f"({self.factor_configs.shape[0]}) valid configurations. Got log potentials of "
f"shape {self.log_potentials.shape}."
)

vars_num_states = np.array([variable.num_states for variable in self.variables])
if not np.logical_and(
self.configs >= 0, self.configs < vars_num_states[None]
self.factor_configs >= 0, self.factor_configs < vars_num_states[None]
).all():
raise ValueError("Invalid configurations for given variables")

def compile_wiring(
self, vars_to_starts: Mapping[nodes.Variable, int]
) -> EnumerationWiring:
"""Compile EnumerationWiring for the EnumerationFactor

Args:
vars_to_starts: A dictionary that maps variables to their global starting indices
For an n-state variable, a global start index of m means the global indices
of its n variable states are m, m + 1, ..., m + n - 1

Returns:
EnumerationWiring for the EnumerationFactor
"""
return compile_enumeration_wiring(
factor_edges_num_states=self.edges_num_states,
variables_for_factors=tuple([self.variables]),
factor_configs=self.configs,
vars_to_starts=vars_to_starts,
num_factors=1,
)

@staticmethod
def concatenate_wirings(wirings: Sequence[EnumerationWiring]) -> EnumerationWiring:
"""Concatenate a list of EnumerationWirings
Expand Down Expand Up @@ -176,60 +152,115 @@ def concatenate_wirings(wirings: Sequence[EnumerationWiring]) -> EnumerationWiri
),
)

@staticmethod
def compile_wiring(
factor_edges_num_states: np.ndarray,
variables_for_factors: Tuple[nodes.Variable, ...],
factor_configs: np.ndarray,
vars_to_starts: Mapping[nodes.Variable, int],
num_factors: int,
) -> EnumerationWiring:
"""Compile an EnumerationWiring for an EnumerationFactor or a FactorGroup with EnumerationFactors.
Internally calls _compile_var_states_numba and _compile_enumeration_wiring_numba for speed.

def compile_enumeration_wiring(
factor_edges_num_states: np.ndarray,
variables_for_factors: Tuple[Tuple[nodes.Variable, ...], ...],
factor_configs: np.ndarray,
vars_to_starts: Mapping[nodes.Variable, int],
num_factors: int,
) -> EnumerationWiring:
"""Compile an EnumerationWiring for an EnumerationFactor or a FactorGroup with EnumerationFactors.
Args:
factor_edges_num_states: An array concatenating the number of states for the variables connected to each
Factor of the FactorGroup. Each variable will appear once for each Factor it connects to.
variables_for_factors: A tuple of tuples containing variables connected to each Factor of the FactorGroup.
Each variable will appear once for each Factor it connects to.
factor_configs: Array of shape (num_val_configs, num_variables) containing an explicit enumeration
of all valid configurations.
vars_to_starts: A dictionary that maps variables to their global starting indices
For an n-state variable, a global start index of m means the global indices
of its n variable states are m, m + 1, ..., m + n - 1
num_factors: Number of Factors in the FactorGroup.

Args:
factor_edges_num_states: An array concatenating the number of states for the variables connected to each
Factor of the FactorGroup. Each variable will appear once for each Factor it connects to.
variables_for_factors: A tuple of tuples containing variables connected to each Factor of the FactorGroup.
Each variable will appear once for each Factor it connects to.
factor_configs: Array of shape (num_val_configs, num_variables) containing an explicit enumeration
of all valid configurations.
vars_to_starts: A dictionary that maps variables to their global starting indices
For an n-state variable, a global start index of m means the global indices
of its n variable states are m, m + 1, ..., m + n - 1
num_factors: Number of Factors in the FactorGroup.
Raises:
ValueError: if factor_edges_num_states is not of shape (num_factors * num_variables, )

Returns:
The EnumerationWiring
"""
var_states_for_edges = []
for variables_for_factor in variables_for_factors:
for variable in variables_for_factor:
num_states = variable.num_states
this_var_states_for_edges = np.arange(
vars_to_starts[variable], vars_to_starts[variable] + num_states
Returns:
The EnumerationWiring
"""
var_states = np.array(
[vars_to_starts[variable] for variable in variables_for_factors]
)
num_states = np.array(
[variable.num_states for variable in variables_for_factors]
)
num_states_cumsum = np.insert(np.cumsum(num_states), 0, 0)
var_states_for_edges = np.empty(shape=(num_states_cumsum[-1],), dtype=int)
_compile_var_states_numba(var_states_for_edges, num_states_cumsum, var_states)

num_configs, num_variables = factor_configs.shape
if not factor_edges_num_states.shape == (num_factors * num_variables,):
raise ValueError(
f"Expected factor_edges_num_states shape is {(num_factors * num_variables,)}. Got {factor_edges_num_states.shape}."
)
var_states_for_edges.append(this_var_states_for_edges)
factor_configs_edge_states = np.empty(
(num_factors * num_configs * num_variables, 2), dtype=int
)
factor_edges_starts = np.insert(np.cumsum(factor_edges_num_states), 0, 0)
_compile_enumeration_wiring_numba(
factor_configs_edge_states, factor_configs, factor_edges_starts, num_factors
)

# Note: edges_starts corresponds to the factor_to_msgs_start for the LogicalFactors
edges_starts = np.insert(factor_edges_num_states.cumsum(), 0, 0)[:-1].reshape(
-1, factor_configs.shape[1]
)
return EnumerationWiring(
edges_num_states=factor_edges_num_states,
var_states_for_edges=var_states_for_edges,
factor_configs_edge_states=factor_configs_edge_states,
)

factor_configs_edge_states = np.stack(
[
np.repeat(
np.arange(factor_configs.shape[0] * num_factors),
factor_configs.shape[1],
),
(factor_configs[None] + edges_starts[:, None, :]).flatten(),
],
axis=1,
)
return EnumerationWiring(
edges_num_states=factor_edges_num_states,
var_states_for_edges=np.concatenate(var_states_for_edges),
factor_configs_edge_states=factor_configs_edge_states,
)

@nb.jit(parallel=False, cache=True, fastmath=True, nopython=True)
def _compile_var_states_numba(
    var_states_for_edges: np.ndarray,
    num_states_cumsum: np.ndarray,
    var_states: np.ndarray,
) -> np.ndarray:
    """Fast numba computation of the var_states_for_edges of a Wiring.

    var_states_for_edges is updated in-place.

    Args:
        var_states_for_edges: Output array of shape (num_states_cumsum[-1],).
            Filled in-place with the global state indices for every edge state:
            for each variable, the contiguous range
            var_states[idx], var_states[idx] + 1, ...
        num_states_cumsum: Array of shape (num_variables + 1,). Cumulative sum
            of the number of states per variable, with a leading 0, so that
            entry idx/idx+1 delimit variable idx's slice of the output.
        var_states: Array of shape (num_variables,). Global starting state
            index of each variable.
    """
    # NOTE(review): parallel=False in the decorator, so nb.prange runs
    # sequentially here; it only enables parallelism if that flag is flipped.
    for variable_idx in nb.prange(num_states_cumsum.shape[0] - 1):
        # Slice of the output owned by this variable.
        start_variable, end_variable = (
            num_states_cumsum[variable_idx],
            num_states_cumsum[variable_idx + 1],
        )
        # Consecutive global state indices starting at the variable's start.
        var_states_for_edges[start_variable:end_variable] = var_states[
            variable_idx
        ] + np.arange(end_variable - start_variable)


@nb.jit(parallel=False, cache=True, fastmath=True, nopython=True)
def _compile_enumeration_wiring_numba(
    factor_configs_edge_states: np.ndarray,
    factor_configs: np.ndarray,
    factor_edges_starts: np.ndarray,
    num_factors: int,
) -> np.ndarray:
    """Fast numba computation of the factor_configs_edge_states of an EnumerationWiring.

    factor_configs_edge_states is updated in-place.

    Args:
        factor_configs_edge_states: Output array of shape
            (num_factors * num_configs * num_variables, 2), filled in-place.
            Column 0 receives the global config index (num_configs * factor_idx
            + config_idx); column 1 receives the corresponding global edge
            state index (the factor's edge start offset plus the variable's
            state in that config).
        factor_configs: Array of shape (num_configs, num_variables) enumerating
            the valid configurations, shared by all factors in the group.
        factor_edges_starts: Read-only array of per-edge start offsets, indexed
            as factor_edges_starts[num_variables * factor_idx + var_idx].
        num_factors: Number of factors in the group.
    """

    num_configs, num_variables = factor_configs.shape

    # NOTE(review): parallel=False in the decorator, so nb.prange is sequential
    # here unless that flag is enabled.
    for factor_idx in nb.prange(num_factors):
        for config_idx in range(num_configs):
            # Global index of this (factor, config) pair.
            factor_config_idx = num_configs * factor_idx + config_idx
            # Column 0: every edge state of this config points back to it.
            factor_configs_edge_states[
                num_variables
                * factor_config_idx : num_variables
                * (factor_config_idx + 1),
                0,
            ] = factor_config_idx

            # Column 1: global edge state = edge start + state of the variable
            # in this configuration.
            for var_idx in range(num_variables):
                factor_configs_edge_states[
                    num_variables * factor_config_idx + var_idx, 1
                ] = (
                    factor_edges_starts[num_variables * factor_idx + var_idx]
                    + factor_configs[config_idx, var_idx]
                )


@functools.partial(jax.jit, static_argnames=("num_val_configs", "temperature"))
Expand Down
Loading