17
17
# %matplotlib inline
18
18
import os
19
19
import time
20
- from jax import jit , tree_util
21
- from jax import numpy as jnp
20
+
22
21
import matplotlib .pyplot as plt
23
22
import numpy as np
24
-
23
+ from jax import jit
24
+ from jax import numpy as jnp
25
+ from jax import tree_util
25
26
from load_data import get_mnist_data_iters
26
27
from preproc import Preproc
27
28
53
54
# -
54
55
55
56
56
-
57
57
# %% [markdown]
58
58
# ## 2. Load the model
59
59
#
60
60
#
61
61
62
62
# %%
63
63
directory = f"/storage/users/skushagra/pgmax_rcn_artifacts/model_science_{ train_size } _{ hps } _{ vps } "
64
- frcs = np .load (f"{ directory } /frcs.npy" , allow_pickle = True , encoding = ' latin1' )
65
- edges = np .load (f"{ directory } /edges.npy" , allow_pickle = True , encoding = ' latin1' )
66
- M = (2 * hps + 1 ) * (2 * vps + 1 ) + 1
64
+ frcs = np .load (f"{ directory } /frcs.npy" , allow_pickle = True , encoding = " latin1" )
65
+ edges = np .load (f"{ directory } /edges.npy" , allow_pickle = True , encoding = " latin1" )
66
+ M = (2 * hps + 1 ) * (2 * vps + 1 ) + 1
67
67
68
68
# %% [markdown]
69
69
# ## 3. Visualize loaded model.
88
88
plt .imshow (img , cmap = "gray" )
89
89
90
90
91
-
92
91
# %% [markdown]
93
92
# ## 4. Make pgmax graph
94
93
#
123
122
# %%
124
123
125
124
# %% [markdown]
126
- # ## 3.2.1 Pre-compute the valid configs for different perturb radii.
125
+ # ## 3.2.1 Pre-compute the valid configs for different perturb radii.
127
126
#
128
127
#
129
128
@@ -150,6 +149,7 @@ def valid_configs(r):
150
149
151
150
return np .stack ([rows , cols ], axis = 1 )
152
151
152
+
153
153
max_perturb_radii = 25
154
154
phis = []
155
155
for r in range (max_perturb_radii ):
@@ -181,7 +181,6 @@ def valid_configs(r):
181
181
print (f"Creating factors took { end - start :.3f} seconds." )
182
182
183
183
184
-
185
184
# %% [markdown]
186
185
# ## 4. Run inference
187
186
#
@@ -196,7 +195,7 @@ def valid_configs(r):
196
195
def initialize_evidences (test_img , frcs , hps , vps ):
197
196
preproc_layer = Preproc (cross_channel_pooling = True )
198
197
bu_msg = preproc_layer .fwd_infer (test_img )
199
-
198
+
200
199
evidence_updates = {}
201
200
for idx in range (frcs .shape [0 ]):
202
201
frc = frcs [idx ]
@@ -237,25 +236,25 @@ def initialize_evidences(test_img, frcs, hps, vps):
237
236
print (f"Initializing evidences took { end - start :.3f} seconds for image { test_idx } ." )
238
237
239
238
start = end
240
- map_states = graph .decode_map_states (get_beliefs_fn (run_bp_fn (evidence_updates = evidence_updates )))
239
+ map_states = graph .decode_map_states (
240
+ get_beliefs_fn (run_bp_fn (evidence_updates = evidence_updates ))
241
+ )
241
242
end = time .time ()
242
243
print (f"Max product inference took { end - start :.3f} seconds for image { test_idx } ." )
243
244
244
245
map_states_dict [test_idx ] = map_states
245
246
start = end
246
247
score = tree_util .tree_multimap (
247
- lambda evidence , map : jnp .sum (
248
- evidence [jnp .arange (map .shape [0 ]), map ]
249
- ),
250
- evidence_updates ,
251
- map_states
252
- )
253
- for ii in score : scores [test_idx , ii ] = score [ii ]
248
+ lambda evidence , map : jnp .sum (evidence [jnp .arange (map .shape [0 ]), map ]),
249
+ evidence_updates ,
250
+ map_states ,
251
+ )
252
+ for ii in score :
253
+ scores [test_idx , ii ] = score [ii ]
254
254
end = time .time ()
255
255
print (f"Computing scores took { end - start :.3f} seconds for image { test_idx } ." )
256
256
257
- #scores[test_idx, :] = score
258
-
257
+ # scores[test_idx, :] = score
259
258
260
259
261
260
# %% [markdown]
@@ -270,7 +269,6 @@ def initialize_evidences(test_img, frcs, hps, vps):
270
269
print (f"accuracy = { accuracy } " )
271
270
272
271
273
-
274
272
# %% [markdown]
275
273
# ## 6. Visualize predictions (backtrace)
276
274
#
@@ -281,12 +279,10 @@ def initialize_evidences(test_img, frcs, hps, vps):
281
279
plt .imshow (test_set [test_idx ][0 ], cmap = "gray" )
282
280
283
281
284
-
285
282
# %%
286
283
# ## 6.1 Backtrace of some models on this test image
287
284
288
285
289
-
290
286
# %%
291
287
# +
292
288
map_states = map_states_dict [test_idx ]
@@ -306,11 +302,10 @@ def initialize_evidences(test_img, frcs, hps, vps):
306
302
plt .figure (figsize = (15 , 15 ))
307
303
308
304
for k , index in enumerate (range (0 , len (train_set ), 5 )):
309
- plt .subplot (1 , 4 , 1 + k )
305
+ plt .subplot (1 , 4 , 1 + k )
310
306
plt .title (f" Model { int (train_labels [index ])} " )
311
307
plt .imshow (imgs [index , :, :], cmap = "gray" )
312
308
# -
313
309
314
310
315
-
316
311
# %%
0 commit comments