Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.

Add support for sum-product with temperature #104

Merged
merged 5 commits into from
Dec 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 28 additions & 4 deletions pgmax/bp/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,20 @@ def pass_var_to_fac_messages(
return vtof_msgs


@functools.partial(jax.jit, static_argnames=("num_val_configs"))
@functools.partial(jax.jit, static_argnames=("num_val_configs", "temperature"))
def pass_fac_to_var_messages(
vtof_msgs: jnp.ndarray,
factor_configs_edge_states: jnp.ndarray,
log_potentials: jnp.ndarray,
num_val_configs: int,
temperature: float,
) -> jnp.ndarray:

"""Passes messages from Factors to Variables.

The update is performed in two steps. First, a "summary" array is generated that has an entry for every valid
configuration for every factor. The elements of this array are simply the sums of messages across each valid
config. Then, the info from edge_vals_to_config_summary_indices is used to apply the scattering operation and
config. Then, the info from factor_configs_edge_states is used to apply the scattering operation and
generate a flat set of output messages.

Args:
Expand All @@ -60,6 +61,8 @@ def pass_fac_to_var_messages(
log_potentials: Array of shape (num_val_configs, ). An entry at index i is the log potential
function value for the configuration with global factor config index i.
num_val_configs: the total number of valid configurations for factors in the factor graph.
temperature: Temperature for loopy belief propagation.
1.0 corresponds to sum-product, 0.0 corresponds to max-product.

Returns:
Array of shape (num_edge_state,). This holds all the flattened factor to variable messages.
Expand All @@ -69,11 +72,32 @@ def pass_fac_to_var_messages(
.at[factor_configs_edge_states[..., 0]]
.add(vtof_msgs[factor_configs_edge_states[..., 1]])
) + log_potentials
ftov_msgs = (
max_factor_config_summary_for_edge_states = (
jnp.full(shape=(vtof_msgs.shape[0],), fill_value=NEG_INF)
.at[factor_configs_edge_states[..., 1]]
.max(fac_config_summary_sum[factor_configs_edge_states[..., 0]])
) - vtof_msgs
)
ftov_msgs = max_factor_config_summary_for_edge_states - vtof_msgs
if temperature != 0.0:
ftov_msgs = ftov_msgs + (
temperature
* jnp.log(
jnp.full(shape=(vtof_msgs.shape[0],), fill_value=jnp.exp(NEG_INF))
.at[factor_configs_edge_states[..., 1]]
.add(
jnp.exp(
(
fac_config_summary_sum[factor_configs_edge_states[..., 0]]
- max_factor_config_summary_for_edge_states[
factor_configs_edge_states[..., 1]
]
)
/ temperature
)
)
)
)

return ftov_msgs


Expand Down
25 changes: 24 additions & 1 deletion pgmax/fg/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.special import logsumexp

from pgmax.bp import infer
from pgmax.fg import fg_utils, groups, nodes
Expand Down Expand Up @@ -781,12 +782,16 @@ def tree_unflatten(cls, aux_data, children):
return cls(**aux_data.unflatten(children))


def BP(bp_state: BPState, num_iters: int) -> Tuple[Callable, Callable, Callable]:
def BP(
bp_state: BPState, num_iters: int, temperature: float = 0.0
) -> Tuple[Callable, Callable, Callable]:
"""Function for generating belief propagation functions.

Args:
bp_state: Belief propagation state.
num_iters: Number of belief propagation iterations.
temperature: Temperature for loopy belief propagation.
1.0 corresponds to sum-product, 0.0 corresponds to max-product.

Returns:
Tuple containing\n
Expand Down Expand Up @@ -856,6 +861,7 @@ def update(msgs: jnp.ndarray, _) -> Tuple[jnp.ndarray, None]:
wiring.factor_configs_edge_states,
log_potentials,
num_val_configs,
temperature,
)
# Use the results of message passing to perform damping and
# update the factor to variable messages
Expand Down Expand Up @@ -930,3 +936,20 @@ def decode_map_states(beliefs: Any) -> Any:
beliefs,
)
return map_states


@jax.jit
def get_marginals(beliefs: Any) -> Any:
"""Function to get marginal probabilities given the calculated beliefs.

Args:
beliefs: An array or a PyTree container containing beliefs for different variables.

Returns:
An array or a PyTree container containing the marginal probabilities different variables.
"""
marginals = jax.tree_util.tree_map(
lambda x: jnp.exp(x - logsumexp(x, axis=-1, keepdims=True)),
beliefs,
)
return marginals
3 changes: 3 additions & 0 deletions tests/test_pgmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,3 +425,6 @@ def binary_connected_variables(
bp_state.evidence[1, 0, 0]
assert isinstance(bp_state.evidence.value, np.ndarray)
assert len(fg.factors) == 7056
run_bp, _, get_beliefs = graph.BP(bp_state, 1, 1.0)
marginals = graph.get_marginals(get_beliefs(run_bp()))
assert jnp.allclose(jnp.sum(marginals[0], axis=-1), 1.0)