Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.

Fix offsets computation for messages/potentials; Support AND factor #126

Merged
merged 8 commits into from
Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pgmax/factors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
] = collections.OrderedDict(
[
(enumeration.EnumerationFactor, enumeration.pass_enum_fac_to_var_messages),
(logical.ORFactor, logical.pass_OR_fac_to_var_messages),
(logical.ORFactor, logical.pass_logical_fac_to_var_messages),
(logical.ANDFactor, logical.pass_logical_fac_to_var_messages),
]
)
168 changes: 98 additions & 70 deletions pgmax/factors/logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import numpy as np
from jax.nn import log_sigmoid, sigmoid

from pgmax import utils
from pgmax.bp import bp_utils
from pgmax.fg import nodes

Expand All @@ -31,15 +30,19 @@ class LogicalWiring(nodes.Wiring):
children_edge_states[ii] contains the message index of the child variable's state 0,
which takes into account all the LogicalFactors of the same subtype (OR/AND) of the FactorGraph.
The child variable's state 1 is children_edge_states[ii] + 1.
edge_states_offset: Offset to go from a variable's relevant state to its other state
For ORFactors the edge_states_offset is 1, for ANDFactors the edge_states_offset is -1.

Raises:
ValueError: If:
(1) There are not num_logical_factors different factor indices
(2) There is a factor index higher than num_logical_factors - 1
(3) The edge_states_offset is not 1 or -1
"""

parents_edge_states: Union[np.ndarray, jnp.ndarray]
children_edge_states: Union[np.ndarray, jnp.ndarray]
edge_states_offset: int

def __post_init__(self):
if self.children_edge_states.shape[0] > 0:
Expand All @@ -56,6 +59,11 @@ def __post_init__(self):
f"The highest LogicalFactor index must be {num_logical_factors - 1}"
)

if self.edge_states_offset != 1 and self.edge_states_offset != -1:
raise ValueError(
f"The LogicalWiring's edge_states_offset must be 1 (for OR) and -1 (for AND), but is {self.edge_states_offset}"
)

@property
def inference_arguments(self) -> Mapping[str, np.ndarray]:
"""
Expand All @@ -65,6 +73,7 @@ def inference_arguments(self) -> Mapping[str, np.ndarray]:
return {
"parents_edge_states": self.parents_edge_states,
"children_edge_states": self.children_edge_states,
"edge_states_offset": self.edge_states_offset,
}


Expand All @@ -73,13 +82,18 @@ class LogicalFactor(nodes.Factor):
"""A logical OR/AND factor of the form (p1,...,pn, c)
where p1,...,pn are the parents variables and c is the child variable.

Args:
edge_states_offset: Offset to go from a variable's relevant state to its other state
For ORFactors the edge_states_offset is 1, for ANDFactors the edge_states_offset is -1.

Raises:
ValueError: If:
(1) There are less than 2 variables
(2) The variables are not all binary
"""

log_potentials: np.ndarray = field(init=False, default=np.empty((0,)))
edge_states_offset: int = field(init=False)

def __post_init__(self):
if len(self.variables) < 2:
Expand All @@ -90,59 +104,6 @@ def __post_init__(self):
if not np.all([variable.num_states == 2 for variable in self.variables]):
raise ValueError("All variables should all be binary")

@utils.cached_property
def parents_edge_states(self) -> np.ndarray:
"""
Returns:
Array of shape (num_parents, 2)
parents_edge_states[ii, 0] contains the local ORFactor index,
parents_edge_states[ii, 1] contains the message index of the parent variable's state 0.
"""
num_parents = len(self.variables) - 1

parents_edge_states = np.vstack(
[
np.zeros(num_parents, dtype=int),
np.arange(0, 2 * num_parents, 2, dtype=int),
],
).T
return parents_edge_states

@utils.cached_property
def child_edge_state(self) -> np.ndarray:
"""
Returns:
Array of shape (num_factors,)
children_edge_states[ii] contains the message index of the child variable's state 0.
"""
return np.array([2 * (len(self.variables) - 1)], dtype=int)

def compile_wiring(
self, vars_to_starts: Mapping[nodes.Variable, int]
) -> LogicalWiring:
"""Compile LogicalWiring for the LogicalFactor

Args:
vars_to_starts: A dictionary that maps variables to their global starting indices
For an n-state variable, a global start index of m means the global indices
of its n variable states are m, m + 1, ..., m + n - 1

Returns:
LogicalWiring for the LogicalFactor
"""
var_states_for_edges = np.concatenate(
[
np.arange(variable.num_states) + vars_to_starts[variable]
for variable in self.variables
]
)
return LogicalWiring(
edges_num_states=self.edges_num_states,
var_states_for_edges=var_states_for_edges,
parents_edge_states=self.parents_edge_states,
children_edge_states=self.child_edge_state,
)

@staticmethod
def concatenate_wirings(wirings: Sequence[LogicalWiring]) -> LogicalWiring:
"""Concatenate a list of LogicalWirings
Expand All @@ -159,6 +120,7 @@ def concatenate_wirings(wirings: Sequence[LogicalWiring]) -> LogicalWiring:
var_states_for_edges=np.empty((0,), dtype=int),
parents_edge_states=np.empty((0, 2), dtype=int),
children_edge_states=np.empty((0,), dtype=int),
edge_states_offset=1,
)

# Note: this corresponds to all the factor_to_msgs_starts for the LogicalFactors
Expand All @@ -167,7 +129,6 @@ def concatenate_wirings(wirings: Sequence[LogicalWiring]) -> LogicalWiring:
0,
0,
)[:-1]

parents_edge_states = []
children_edge_states = []
for ww, or_wiring in enumerate(wirings):
Expand All @@ -184,6 +145,43 @@ def concatenate_wirings(wirings: Sequence[LogicalWiring]) -> LogicalWiring:
),
parents_edge_states=np.concatenate(parents_edge_states, axis=0),
children_edge_states=np.concatenate(children_edge_states, axis=0),
edge_states_offset=wirings[0].edge_states_offset,
)

def compile_wiring(
    self, vars_to_starts: Mapping[nodes.Variable, int]
) -> LogicalWiring:
    """Compile LogicalWiring for the LogicalFactor

    Args:
        vars_to_starts: A dictionary that maps variables to their global starting indices
            For an n-state variable, a global start index of m means the global indices
            of its n variable states are m, m + 1, ..., m + n - 1

    Returns:
        LogicalWiring for the LogicalFactor, where
            parents_edge_states has shape (num_parents, 2): column 0 is the local
            factor index (always 0 here), column 1 is the message index of each
            parent's relevant state;
            children_edge_states has shape (1,) and holds the message index of the
            child's relevant state.
    """
    # Global state indices of every variable connected to this factor, in order.
    var_states_for_edges = np.concatenate(
        [
            np.arange(variable.num_states) + vars_to_starts[variable]
            for variable in self.variables
        ]
    )

    # All variables but the last one are parents; the last one is the child.
    num_parents = len(self.variables) - 1

    # Maps edge_states_offset = 1 to relevant state 0 and edge_states_offset = -1
    # to relevant state 1. Floor division keeps the result an int, since it is
    # used below to build integer message indices.
    relevant_state = (-self.edge_states_offset + 1) // 2

    # Each variable occupies two consecutive message slots, hence the stride of 2.
    parents_edge_states = np.vstack(
        [
            np.zeros(num_parents, dtype=int),
            np.arange(relevant_state, 2 * num_parents, 2, dtype=int),
        ],
    ).T
    # The child's two slots come right after those of the parents.
    child_edge_state = np.array([2 * num_parents + relevant_state], dtype=int)

    return LogicalWiring(
        edges_num_states=self.edges_num_states,
        var_states_for_edges=var_states_for_edges,
        parents_edge_states=parents_edge_states,
        children_edge_states=child_edge_state,
        edge_states_offset=self.edge_states_offset,
    )


Expand All @@ -195,48 +193,74 @@ class ORFactor(LogicalFactor):
An OR factor is defined as:
F(p1, p2, ..., pn, c) = 0 <=> c = OR(p1, p2, ..., pn)
F(p1, p2, ..., pn, c) = -inf o.w.

Args:
edge_states_offset: Offset to go from a variable's relevant state to its other state
For ORFactors the edge_states_offset is 1.
"""

edge_states_offset: int = field(init=False, default=1)


@dataclass(frozen=True, eq=False)
class ANDFactor(LogicalFactor):
    """An AND factor of the form (p1,...,pn, c)
    where p1,...,pn are the parents variables and c is the child variable.

    An AND factor is defined as:
    F(p1, p2, ..., pn, c) = 0 <=> c = AND(p1, p2, ..., pn)
    F(p1, p2, ..., pn, c) = -inf o.w.

    Args:
        edge_states_offset: Offset to go from a variable's relevant state to its other state
            For ANDFactors the edge_states_offset is -1.
    """

    # Not an __init__ argument (init=False): the offset is fixed to -1 for AND factors.
    edge_states_offset: int = field(init=False, default=-1)


@functools.partial(jax.jit, static_argnames=("temperature"))
def pass_OR_fac_to_var_messages(
def pass_logical_fac_to_var_messages(
vtof_msgs: jnp.ndarray,
parents_edge_states: jnp.ndarray,
children_edge_states: jnp.ndarray,
edge_states_offset: int,
temperature: float,
log_potentials: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:

"""Passes messages from ORFactors to Variables.
"""Passes messages from LogicalFactors to Variables.

Args:
vtof_msgs: Array of shape (num_edge_state,). This holds all the flattened variable to all the ORFactors messages.
vtof_msgs: Array of shape (num_edge_state,). This holds all the flattened variable to all the LogicalFactors messages.
parents_edge_states: Array of shape (num_parents, 2)
parents_edge_states[ii, 0] contains the global ORFactor index,
parents_edge_states[ii, 1] contains the message index of the parent variable's state 0.
Both indices only take into account the ORFactors of the FactorGraph
The parent variable's state 1 is parents_edge_states[ii, 2] + 1
parents_edge_states[ii, 0] contains the global LogicalFactor index,
parents_edge_states[ii, 1] contains the message index of the parent variable's relevant state.
For ORFactors the relevant state is 0, for ANDFactors the relevant state is 1.
Both indices only take into account the LogicalFactors of the FactorGraph
The parent variable's other state is parents_edge_states[ii, 2] + edge_states_offset
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
The parent variable's other state is parents_edge_states[ii, 2] + edge_states_offset
The parent variable's other state is parents_edge_states[ii, 1] + edge_states_offset

children_edge_states: Array of shape (num_factors,)
children_edge_states[ii] contains the message index of the child variable's state 0
The child variable's state 1 is children_edge_states[ii, 1] + 1
children_edge_states[ii] contains the message index of the child variable's relevant state.
For ORFactors the relevant state is 0, for ANDFactors the relevant state is 1.
The child variable's other state is children_edge_states[ii] + edge_states_offset
edge_states_offset: Offset to go from a variable's relevant state to its other state
For ORFactors the edge_states_offset is 1, for ANDFactors the edge_states_offset is -1.
temperature: Temperature for loopy belief propagation.
1.0 corresponds to sum-product, 0.0 corresponds to max-product.

Returns:
Array of shape (num_edge_state,). This holds all the flattened LogicalFactors to variable messages.
"""
num_factors = children_edge_states.shape[0]

factor_indices = parents_edge_states[..., 0]

parents_tof_msgs = (
vtof_msgs[parents_edge_states[..., 1] + 1]
vtof_msgs[parents_edge_states[..., 1] + edge_states_offset]
- vtof_msgs[parents_edge_states[..., 1]]
)
children_tof_msgs = (
vtof_msgs[children_edge_states + 1] - vtof_msgs[children_edge_states]
vtof_msgs[children_edge_states + edge_states_offset]
- vtof_msgs[children_edge_states]
)

# Consider the max-product case separately.
Expand Down Expand Up @@ -321,6 +345,10 @@ def g(x):
)

ftov_msgs = jnp.zeros_like(vtof_msgs)
ftov_msgs = ftov_msgs.at[parents_edge_states[..., 1] + 1].set(parents_msgs)
ftov_msgs = ftov_msgs.at[children_edge_states + 1].set(children_msgs)
ftov_msgs = ftov_msgs.at[parents_edge_states[..., 1] + edge_states_offset].set(
parents_msgs
)
ftov_msgs = ftov_msgs.at[children_edge_states + edge_states_offset].set(
children_msgs
)
return ftov_msgs
40 changes: 13 additions & 27 deletions pgmax/fg/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,53 +228,39 @@ def compute_offsets(self) -> None:
self._factor_type_to_potentials_range = collections.OrderedDict()
self._factor_group_to_potentials_starts = collections.OrderedDict()
self._factor_to_potentials_starts = collections.OrderedDict()
factor_group_num_configs_cumsum = 0
factor_num_configs_cumsum = 0

for factor_type, factors_groups_by_type in self.factor_groups.items():
factor_num_states_start = factor_num_states_cumsum
factor_group_num_configs_start = factor_group_num_configs_cumsum

# As inference will be run by chunking the flattened arrays of messages from variables
# to factors according to their factor types, this resets the offsets to 0 within a type
factor_num_states_cumsum_by_type = 0
factor_group_num_configs_cumsum_by_type = 0

factor_type_num_states_start = factor_num_states_cumsum
factor_type_num_configs_start = factor_num_configs_cumsum
for factor_group in factors_groups_by_type:
self._factor_group_to_msgs_starts[
factor_group
] = factor_num_states_cumsum_by_type
] = factor_num_states_cumsum
self._factor_group_to_potentials_starts[
factor_group
] = factor_group_num_configs_cumsum_by_type
] = factor_num_configs_cumsum

for factor in factor_group.factors:
self._factor_to_msgs_starts[
factor
] = factor_num_states_cumsum_by_type
self._factor_to_msgs_starts[factor] = factor_num_states_cumsum
self._factor_to_potentials_starts[
factor
] = factor_group_num_configs_cumsum_by_type
] = factor_num_configs_cumsum

factor_num_states_cumsum_by_type += np.sum(factor.edges_num_states)
if factor.log_potentials is not None:
factor_group_num_configs_cumsum_by_type += (
factor.log_potentials.shape[0]
)
factor_num_states_cumsum += np.sum(factor.edges_num_states)
factor_num_configs_cumsum += factor.log_potentials.shape[0]

# Add global offsets
factor_num_states_cumsum += factor_num_states_cumsum_by_type
factor_group_num_configs_cumsum += factor_group_num_configs_cumsum_by_type
self._factor_type_to_msgs_range[factor_type] = (
factor_num_states_start,
factor_type_num_states_start,
factor_num_states_cumsum,
)
self._factor_type_to_potentials_range[factor_type] = (
factor_group_num_configs_start,
factor_group_num_configs_cumsum,
factor_type_num_configs_start,
factor_num_configs_cumsum,
)

self._total_factor_num_states = factor_num_states_cumsum
self._total_factor_num_configs = factor_group_num_configs_cumsum
self._total_factor_num_configs = factor_num_configs_cumsum

@cached_property
def wiring(self) -> OrderedDict[Type, nodes.Wiring]:
Expand Down
Loading