-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathpool.py
474 lines (408 loc) · 16.9 KB
/
pool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
# Copyright 2023 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines a pool factor."""
import dataclasses
import functools
from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union
import warnings
import jax
import jax.numpy as jnp
import numpy as np
from pgmax.factor import factor
from pgmax.factor import logical
from pgmax.factor import update_utils
from pgmax.utils import NEG_INF
# pylint: disable=unexpected-keyword-arg
@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True, eq=False)
class PoolWiring(factor.Wiring):
  """Wiring for PoolFactors.

  Attributes:
    pool_choices_edge_states: Array of shape (num_pool_choices, 2)
      pool_choices_edge_states[ii, 0] contains the global index of the
      PoolFactor the ii-th pool choice variable is connected to.
      pool_choices_edge_states[ii, 1] contains the message index of the pool
      choice variable's state 0.
      The message index of the pool choice variable's state 1 is
      pool_choices_edge_states[ii, 1] + 1
      Both indices only take into account the PoolFactors of the FactorGraph
    pool_indicators_edge_states: Array of shape (num_pool_factors,)
      pool_indicators_edge_states[ii] contains the message index of the pool
      indicator variable's state 0
      The message index of the pool indicator variable's state 1 is
      pool_indicators_edge_states[ii] + 1
      Only takes into account the PoolFactors of the FactorGraph

  Raises:
    ValueError: Raised by get_inference_arguments when there is at least one
      pool choice and either:
      (1) There are not exactly num_pool_factors distinct PoolFactor indices
      (2) There is a PoolFactor index higher than num_pool_factors - 1
  """

  pool_choices_edge_states: Union[np.ndarray, jnp.ndarray]
  pool_indicators_edge_states: Union[np.ndarray, jnp.ndarray]

  def get_inference_arguments(self) -> Dict[str, Any]:
    """Return the arguments needed to run BP with PoolWirings.

    Returns:
      A dictionary with keys "pool_choices_factor_indices" (column 0 of
      pool_choices_edge_states), "pool_choices_msg_indices" (column 1 of
      pool_choices_edge_states) and "pool_indicators_edge_states".

    Raises:
      ValueError: If there are pool choices and either (1) they do not
        reference exactly num_pool_factors distinct PoolFactor indices, or
        (2) a PoolFactor index is higher than num_pool_factors - 1.
    """
    # Validation only applies when there is at least one pool choice.
    if self.pool_choices_edge_states.shape[0] > 0:
      pool_factor_indices = self.pool_choices_edge_states[:, 0]
      num_pool_factors = self.pool_indicators_edge_states.shape[0]

      # Every PoolFactor (one per pool indicator) must own at least one pool
      # choice, so the number of distinct factor indices must match.
      if np.unique(pool_factor_indices).shape[0] != num_pool_factors:
        raise ValueError(
            f"The PoolWiring must have {num_pool_factors} different"
            " PoolFactor indices"
        )

      if pool_factor_indices.max() >= num_pool_factors:
        raise ValueError(
            f"The highest PoolFactor index must be {num_pool_factors - 1}"
        )

    return {
        "pool_choices_factor_indices": self.pool_choices_edge_states[..., 0],
        "pool_choices_msg_indices": self.pool_choices_edge_states[..., 1],
        "pool_indicators_edge_states": self.pool_indicators_edge_states,
    }
@dataclasses.dataclass(frozen=True, eq=False)
class PoolFactor(factor.Factor):
  """A Pool factor of the form (pc1, ...,pcn, pi) where (pc1,...,pcn) are the pool choices and pi is the pool indicator.

  A Pool factor is defined as:
  F(pc1, ...,pcn, pi) = 0 <=> (pc1=...=pcn=pi=0) OR (pi=1 AND pc1 +...+ pcn=1)
  F(pc1, ...,pcn, pi) = -inf o.w.

  i.e. either (a) all the variables are set to 0, or (b) the pool indicator
  variable is set to 1 and exactly one of the pool choices variables is set to 1

  The pool indicator must be the last variable in `variables`; all the other
  variables are pool choices.

  Note: placing the pool indicator at the end allows us to reuse our
  existing infrastructure for wiring logical factors
  """

  # Not a constructor argument: a PoolFactor's potential is fully determined
  # by the validity rule above, so log_potentials is always an empty array.
  log_potentials: np.ndarray = dataclasses.field(
      init=False,
      default_factory=lambda: np.empty((0,)),
  )

  def __post_init__(self):
    """Validates the variables connected by the PoolFactor.

    Raises:
      ValueError: If there are fewer than two variables (at least one pool
        choice and one pool indicator are required), or if any variable is
        not binary.
    """
    if len(self.variables) < 2:
      raise ValueError(
          "A PoolFactor requires at least one pool choice and one pool "
          "indicator."
      )

    # variable[1] holds the variable's number of states.
    if not np.all([variable[1] == 2 for variable in self.variables]):
      raise ValueError("All the variables in a PoolFactor should all be binary")

  @staticmethod
  def concatenate_wirings(wirings: Sequence[PoolWiring]) -> PoolWiring:
    """Concatenate a list of PoolWirings.

    Each PoolWiring is mapped onto a LogicalWiring (pool choices as parents,
    pool indicators as children) so that LogicalFactor's concatenation logic
    can be reused; the result is then mapped back to a PoolWiring.

    Args:
      wirings: A list of PoolWirings

    Returns:
      Concatenated PoolWiring. An empty PoolWiring is returned when `wirings`
      is empty.
    """
    if not wirings:
      return PoolWiring(
          var_states_for_edges=np.empty((0, 3), dtype=int),
          pool_choices_edge_states=np.empty((0, 2), dtype=int),
          pool_indicators_edge_states=np.empty((0,), dtype=int),
      )

    # edge_states_offset=1 matches the value used in compile_wiring below.
    logical_wirings = [
        logical.LogicalWiring(
            var_states_for_edges=wiring.var_states_for_edges,
            parents_edge_states=wiring.pool_choices_edge_states,
            children_edge_states=wiring.pool_indicators_edge_states,
            edge_states_offset=1,
        )
        for wiring in wirings
    ]

    logical_wiring = logical.LogicalFactor.concatenate_wirings(logical_wirings)

    return PoolWiring(
        var_states_for_edges=logical_wiring.var_states_for_edges,
        pool_choices_edge_states=logical_wiring.parents_edge_states,
        pool_indicators_edge_states=logical_wiring.children_edge_states,
    )

  # pylint: disable=g-doc-args
  @staticmethod
  def compile_wiring(
      factor_edges_num_states: np.ndarray,
      variables_for_factors: Sequence[List[Tuple[int, int]]],
      factor_sizes: np.ndarray,
      vars_to_starts: Mapping[Tuple[int, int], int],
  ) -> PoolWiring:
    """Compile a PoolWiring for a PoolFactor or for a FactorGroup with PoolFactors.

    Internally uses the logical factor compile_wiring: the pool choices play
    the role of the logical parents and the pool indicator (last variable)
    the role of the logical child.

    Args: See LogicalFactor.compile_wiring docstring.

    Returns:
      The PoolWiring
    """
    logical_wiring = logical.LogicalFactor.compile_wiring(
        factor_edges_num_states=factor_edges_num_states,
        variables_for_factors=variables_for_factors,
        factor_sizes=factor_sizes,
        vars_to_starts=vars_to_starts,
        edge_states_offset=1,
    )

    return PoolWiring(
        var_states_for_edges=logical_wiring.var_states_for_edges,
        pool_choices_edge_states=logical_wiring.parents_edge_states,
        pool_indicators_edge_states=logical_wiring.children_edge_states,
    )

  @staticmethod
  @jax.jit
  def compute_energy(
      edge_states_one_hot_decoding: jnp.ndarray,
      pool_choices_factor_indices: jnp.ndarray,
      pool_choices_msg_indices: jnp.ndarray,
      pool_indicators_edge_states: jnp.ndarray,
      log_potentials: Optional[jnp.ndarray] = None,
  ) -> float:
    """Returns the contribution to the energy of several PoolFactors.

    Args:
      edge_states_one_hot_decoding: Array of shape (num_edge_states,)
        Flattened array of one-hot decoding of the edge states connected to the
        PoolFactors
      pool_choices_factor_indices: Array of shape (num_pool_choices,)
        pool_choices_factor_indices[ii] contains the global index of the
        PoolFactor the ii-th pool choice variable is connected to
        Only takes into account the PoolFactors of the FactorGraph
      pool_choices_msg_indices: Array of shape (num_pool_choices,)
        pool_choices_msg_indices[ii] contains the message index of the pool
        choice variable's state 0
        The message index of the pool choice variable's state 1 is
        pool_choices_msg_indices[ii] + 1
        Only takes into account the PoolFactors of the FactorGraph
      pool_indicators_edge_states: Array of shape (num_pool_factors,)
        pool_indicators_edge_states[ii] contains the message index of the pool
        indicator variable's state 0
        The message index of the pool indicator variable's state 1 is
        pool_indicators_edge_states[ii] + 1
        Only takes into account the PoolFactors of the FactorGraph
      log_potentials: Optional array of log potentials. Not used, since
        PoolFactors have no log potentials.

    Returns:
      A scalar: 0.0 if the decoding satisfies every PoolFactor, inf otherwise.
    """
    num_factors = pool_indicators_edge_states.shape[0]

    # pool_choices_msg_indices + 1 and pool_indicators_edge_states + 1 index
    # the state-1 entries of the one-hot decoding.
    # A PoolFactor is satisfied iff either all its pool choices and its pool
    # indicator are 0, or exactly one pool choice and the pool indicator are 1,
    # i.e. iff the number of active pool choices equals the pool indicator.
    pool_choices_decoded = (
        jnp.zeros(shape=(num_factors,))
        .at[pool_choices_factor_indices]
        .add(edge_states_one_hot_decoding[pool_choices_msg_indices + 1])
    )
    pool_indicators_decoded = edge_states_one_hot_decoding[
        pool_indicators_edge_states + 1
    ]

    energy = jnp.where(
        jnp.any(pool_choices_decoded != pool_indicators_decoded),
        jnp.inf,  # invalid decoding
        0.0,
    )
    return energy

  @staticmethod
  def compute_factor_energy(
      variables: List[Hashable],
      vars_to_map_states: Dict[Hashable, Any],
      **kwargs,
  ) -> float:
    """Returns the contribution to the energy of a single PoolFactor.

    Args:
      variables: List of variables connected by the PoolFactor, with the pool
        indicator last
      vars_to_map_states: A dictionary mapping each individual variable to
        its MAP state.
      **kwargs: Other parameters, not used

    Returns:
      0.0 if the MAP states satisfy the PoolFactor, np.inf otherwise (a
      warning is emitted in that case).
    """
    vars_decoded_states = np.array(
        [vars_to_map_states[var] for var in variables]
    )
    pool_choices_decoded_states = vars_decoded_states[:-1]
    pool_indicators_decoded_states = vars_decoded_states[-1]

    # Valid iff the number of active pool choices equals the (binary) pool
    # indicator: both are 0, or exactly one pool choice is active and the
    # indicator is 1.
    if (
        np.sum(pool_choices_decoded_states)
        != pool_indicators_decoded_states
    ):
      warnings.warn(
          f"Invalid decoding for Pool factor {variables} "
          f"with pool choices set to {pool_choices_decoded_states} "
          f"and pool indicators set to {pool_indicators_decoded_states}!"
      )
      factor_energy = np.inf
    else:
      factor_energy = 0.0
    return factor_energy
# pylint: disable=unused-argument
@functools.partial(jax.jit, static_argnames=("temperature", "normalize"))
def pass_pool_fac_to_var_messages(
    vtof_msgs: jnp.ndarray,
    pool_choices_factor_indices: jnp.ndarray,
    pool_choices_msg_indices: jnp.ndarray,
    pool_indicators_edge_states: jnp.ndarray,
    temperature: float,
    normalize: bool,
    log_potentials: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
  """Passes messages from PoolFactors to Variables.

  Args:
    vtof_msgs: Array of shape (num_edge_states,). This holds all the flattened
      variable to all the PoolFactors messages.
    pool_choices_factor_indices: Array of shape (num_pool_choices,)
      pool_choices_factor_indices[ii] contains the global index of the
      PoolFactor the ii-th pool choice variable is connected to
      Only takes into account the PoolFactors of the FactorGraph
    pool_choices_msg_indices: Array of shape (num_pool_choices,)
      pool_choices_msg_indices[ii] contains the message index of the pool choice
      variable's state 0
      The message index of the pool choice variable's state 1 is
      pool_choices_msg_indices[ii] + 1
      Only takes into account the PoolFactors of the FactorGraph
    pool_indicators_edge_states: Array of shape (num_pool_factors,)
      pool_indicators_edge_states[ii] contains the message index of the pool
      indicator variable's state 0
      The message index of the pool indicator variable's state 1 is
      pool_indicators_edge_states[ii] + 1
      Only takes into account the PoolFactors of the FactorGraph
    temperature: Temperature for loopy belief propagation. 1.0 corresponds to
      sum-product, 0.0 corresponds to max-product.
    normalize: Whether we normalize the outgoing messages. Set to True for BP
      and to False for Smooth Dual LP.
    log_potentials: Optional array of log potentials. Not used by this
      function (PoolFactors have no log potentials).

  Returns:
    Array of shape (num_edge_states,). This holds all the flattened PoolFactors
    to variable messages.

  Note: The updates below are derived to be numerically stable at low
  temperatures in (0, 0.1].
  logaddexp_with_temp, logminusexp_with_temp and logsumexps_with_temp are also
  derived to be numerically stable at these low temperatures.
  """
  num_factors = pool_indicators_edge_states.shape[0]

  # Incoming messages, summarized as (state 1 - state 0) differences, plus the
  # raw state-0 messages of the pool choices and the raw state-1 messages of
  # the pool indicators.
  pool_choices_tof_msgs_diffs = (
      vtof_msgs[pool_choices_msg_indices + 1]
      - vtof_msgs[pool_choices_msg_indices]
  )
  pool_choices_tof_msgs_zeros = vtof_msgs[pool_choices_msg_indices]
  pool_indicators_tof_msgs_diffs = (
      vtof_msgs[pool_indicators_edge_states + 1]
      - vtof_msgs[pool_indicators_edge_states]
  )
  pool_indicators_tof_msgs_ones = vtof_msgs[pool_indicators_edge_states + 1]

  # First, get the easier outgoing messages to pool choices and pool indicators
  # Per-factor sum of the pool choices' incoming state-0 messages.
  sums_pool_choices_tof_msgs_zeros = (
      jnp.zeros((num_factors,))
      .at[pool_choices_factor_indices]
      .add(pool_choices_tof_msgs_zeros)
  )
  # The only valid configuration with a given pool choice at 1 has the pool
  # indicator at 1 and every other pool choice of the factor at 0. Its score
  # is the sum of the other pool choices' state-0 messages (factor sum minus
  # this choice's own state-0 message) plus the indicator's state-1 message.
  pool_choices_msgs_ones = (
      sums_pool_choices_tof_msgs_zeros[pool_choices_factor_indices]
      + pool_indicators_tof_msgs_ones[pool_choices_factor_indices]
      - pool_choices_tof_msgs_zeros
  )
  # The only valid configuration with the pool indicator at 0 has all the
  # pool choices at 0.
  pool_indicators_msgs_zeros = sums_pool_choices_tof_msgs_zeros

  # Second derive the other outgoing messages
  # Get the maxes and argmaxes of pool_choices_tof_msgs_diffs per factor
  (
      pool_choices_diffs_maxes,
      pool_choices_diffs_argmaxes,
  ) = update_utils.get_maxes_and_argmaxes(
      pool_choices_tof_msgs_diffs, pool_choices_factor_indices, num_factors
  )

  # Consider the max-product case separately.
  # temperature is a static argument of the jitted function, so this Python
  # branch is resolved at trace time.
  if temperature == 0.0:
    # Get the second maxes and argmaxes per factor
    # (mask each factor's argmax with NEG_INF, then take the per-factor max).
    pool_choices_diffs_wo_maxes = pool_choices_tof_msgs_diffs.at[
        pool_choices_diffs_argmaxes
    ].set(NEG_INF)
    pool_choices_diffs_second_maxes = (
        jnp.full(shape=(num_factors,), fill_value=NEG_INF)
        .at[pool_choices_factor_indices]
        .max(pool_choices_diffs_wo_maxes)
    )

    # Get the difference between the outgoing messages
    # With a pool choice at 0, the best valid configuration either sets all
    # variables to 0 or activates the best *other* pool choice (with the
    # indicator at 1). Hence the (state 1 - state 0) difference is
    # min(indicator diff, -max of the other pool choices' diffs): the factor
    # max for non-argmax choices, the factor second max for the argmax.
    pool_choices_msgs_diffs = jnp.minimum(
        pool_indicators_tof_msgs_diffs, -pool_choices_diffs_maxes
    )[pool_choices_factor_indices]
    pool_choices_msgs_diffs = pool_choices_msgs_diffs.at[
        pool_choices_diffs_argmaxes
    ].set(
        jnp.minimum(
            pool_indicators_tof_msgs_diffs, -pool_choices_diffs_second_maxes,
        )
    )
    # Turning the indicator on means activating the single best pool choice.
    pool_indicators_msgs_diffs = pool_choices_diffs_maxes
  else:
    # Stable difference between the pool indicators outgoing messages
    # (temperature-scaled logsumexp of the pool choices' diffs per factor).
    pool_indicators_msgs_diffs = update_utils.logsumexps_with_temp(
        data=pool_choices_tof_msgs_diffs,
        labels=pool_choices_factor_indices,
        num_labels=num_factors,
        temperature=temperature,
        maxes=pool_choices_diffs_maxes
    )
    # Per-factor logsumexp over all the pool choices' diffs and minus the pool
    # indicator diff.
    factor_logsumexp_msgs_diffs = update_utils.logaddexp_with_temp(
        pool_indicators_msgs_diffs,
        -pool_indicators_tof_msgs_diffs,
        temperature
    )

    # Stable difference between the pool choices outgoing messages
    # Except for the pool_choices_tof_msgs_diffs argmaxes
    # (each pool choice removes its own term from the factor logsumexp).
    pool_choices_msgs_diffs = - update_utils.logminusexp_with_temp(
        factor_logsumexp_msgs_diffs[pool_choices_factor_indices],
        pool_choices_tof_msgs_diffs,
        temperature,
    )

    # pool_choices_msgs_diffs above is not numerically stable for the
    # pool_choices_diffs_argmaxes. The stable update is derived below
    # For the argmax, recompute the logsumexp directly with the argmax's own
    # term replaced by minus the pool indicator diff, instead of subtracting
    # the dominant term.
    pool_choices_indicators_diffs = pool_choices_tof_msgs_diffs.at[
        pool_choices_diffs_argmaxes
    ].set(-pool_indicators_tof_msgs_diffs)
    pool_choices_msgs_diffs_argmaxes = - update_utils.logsumexps_with_temp(
        data=pool_choices_indicators_diffs,
        labels=pool_choices_factor_indices,
        num_labels=num_factors,
        temperature=temperature,
    )
    pool_choices_msgs_diffs = pool_choices_msgs_diffs.at[
        pool_choices_diffs_argmaxes
    ].set(pool_choices_msgs_diffs_argmaxes)

  # Special case: factors with a single pool choice
  # Such a factor only allows (pool choice, indicator) = (0, 0) or (1, 1), so
  # the pool choice simply receives the pool indicator's incoming message.
  # NOTE(review): first_pool_choices assumes the pool choices of each factor
  # are stored contiguously and sorted by factor index; confirm against
  # LogicalFactor.compile_wiring.
  num_pool_choices = jnp.bincount(
      pool_choices_factor_indices, length=num_factors
  )
  first_pool_choices = jnp.concatenate(
      [jnp.zeros(1, dtype=int), jnp.cumsum(num_pool_choices)]
  )[:-1]
  pool_choices_msgs_diffs = pool_choices_msgs_diffs.at[first_pool_choices].set(
      jnp.where(
          num_pool_choices == 1,
          pool_indicators_tof_msgs_diffs,
          pool_choices_msgs_diffs[first_pool_choices],
      ),
  )
  pool_choices_msgs_ones = pool_choices_msgs_ones.at[first_pool_choices].set(
      jnp.where(
          num_pool_choices == 1,
          pool_indicators_tof_msgs_ones,
          pool_choices_msgs_ones[first_pool_choices],
      ),
  )

  # Outgoing messages
  ftov_msgs = jnp.zeros_like(vtof_msgs)
  if normalize:
    # Normalized messages: state 0 stays at 0 and state 1 holds the
    # (state 1 - state 0) difference.
    ftov_msgs = ftov_msgs.at[pool_choices_msg_indices + 1].set(
        pool_choices_msgs_diffs
    )
    ftov_msgs = ftov_msgs.at[pool_indicators_edge_states + 1].set(
        pool_indicators_msgs_diffs
    )
  else:
    # Unnormalized messages: both states are written explicitly, with
    # state 0 = state 1 - diff for pool choices and
    # state 1 = state 0 + diff for pool indicators.
    ftov_msgs = ftov_msgs.at[pool_choices_msg_indices + 1].set(
        pool_choices_msgs_ones
    )
    ftov_msgs = ftov_msgs.at[pool_choices_msg_indices].set(
        pool_choices_msgs_ones - pool_choices_msgs_diffs
    )
    ftov_msgs = ftov_msgs.at[pool_indicators_edge_states + 1].set(
        pool_indicators_msgs_zeros + pool_indicators_msgs_diffs
    )
    ftov_msgs = ftov_msgs.at[pool_indicators_edge_states].set(
        pool_indicators_msgs_zeros
    )
  return ftov_msgs