Skip to content
This repository was archived by the owner on Dec 5, 2024. It is now read-only.

Commit 282ac42

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 787b573 commit 282ac42

File tree

3 files changed

+24
-29
lines changed

3 files changed

+24
-29
lines changed

examples/rcn/helpers.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ def get_number_of_states(hps, vps):
1818

1919
def initialize_evidences(inf_img, frcs, hps, vps, neg_inf=-1000):
2020
M = get_number_of_states(hps, vps)
21-
21+
2222
preproc_layer = Preproc(cross_channel_pooling=True)
2323
bu_msg = preproc_layer.fwd_infer(inf_img)
24-
24+
2525
evidence_updates = {}
2626

2727
for idx in range(frcs.shape[0]):

examples/rcn/inference_pgmax_small.py

+21-26
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717
# %matplotlib inline
1818
import os
1919
import time
20-
from jax import jit, tree_util
21-
from jax import numpy as jnp
20+
2221
import matplotlib.pyplot as plt
2322
import numpy as np
24-
23+
from jax import jit
24+
from jax import numpy as jnp
25+
from jax import tree_util
2526
from load_data import get_mnist_data_iters
2627
from preproc import Preproc
2728

@@ -53,17 +54,16 @@
5354
# -
5455

5556

56-
5757
# %% [markdown]
5858
# ## 2. Load the model
5959
#
6060
#
6161

6262
# %%
6363
directory = f"/storage/users/skushagra/pgmax_rcn_artifacts/model_science_{train_size}_{hps}_{vps}"
64-
frcs = np.load(f"{directory}/frcs.npy", allow_pickle=True, encoding='latin1')
65-
edges = np.load(f"{directory}/edges.npy", allow_pickle=True, encoding='latin1')
66-
M = (2 * hps + 1) * (2 * vps + 1) + 1
64+
frcs = np.load(f"{directory}/frcs.npy", allow_pickle=True, encoding="latin1")
65+
edges = np.load(f"{directory}/edges.npy", allow_pickle=True, encoding="latin1")
66+
M = (2 * hps + 1) * (2 * vps + 1) + 1
6767

6868
# %% [markdown]
6969
# ## 3. Visualize loaded model.
@@ -88,7 +88,6 @@
8888
plt.imshow(img, cmap="gray")
8989

9090

91-
9291
# %% [markdown]
9392
# ## 4. Make pgmax graph
9493
#
@@ -123,7 +122,7 @@
123122
# %%
124123

125124
# %% [markdown]
126-
# ## 3.2.1 Pre-compute the valid configs for different perturb radii.
125+
# ## 3.2.1 Pre-compute the valid configs for different perturb radii.
127126
#
128127
#
129128

@@ -150,6 +149,7 @@ def valid_configs(r):
150149

151150
return np.stack([rows, cols], axis=1)
152151

152+
153153
max_perturb_radii = 25
154154
phis = []
155155
for r in range(max_perturb_radii):
@@ -181,7 +181,6 @@ def valid_configs(r):
181181
print(f"Creating factors took {end-start:.3f} seconds.")
182182

183183

184-
185184
# %% [markdown]
186185
# ## 4. Run inference
187186
#
@@ -196,7 +195,7 @@ def valid_configs(r):
196195
def initialize_evidences(test_img, frcs, hps, vps):
197196
preproc_layer = Preproc(cross_channel_pooling=True)
198197
bu_msg = preproc_layer.fwd_infer(test_img)
199-
198+
200199
evidence_updates = {}
201200
for idx in range(frcs.shape[0]):
202201
frc = frcs[idx]
@@ -237,25 +236,25 @@ def initialize_evidences(test_img, frcs, hps, vps):
237236
print(f"Initializing evidences took {end-start:.3f} seconds for image {test_idx}.")
238237

239238
start = end
240-
map_states = graph.decode_map_states(get_beliefs_fn(run_bp_fn(evidence_updates=evidence_updates)))
239+
map_states = graph.decode_map_states(
240+
get_beliefs_fn(run_bp_fn(evidence_updates=evidence_updates))
241+
)
241242
end = time.time()
242243
print(f"Max product inference took {end-start:.3f} seconds for image {test_idx}.")
243244

244245
map_states_dict[test_idx] = map_states
245246
start = end
246247
score = tree_util.tree_multimap(
247-
lambda evidence, map: jnp.sum(
248-
evidence[jnp.arange(map.shape[0]), map]
249-
),
250-
evidence_updates,
251-
map_states
252-
)
253-
for ii in score: scores[test_idx, ii] = score[ii]
248+
lambda evidence, map: jnp.sum(evidence[jnp.arange(map.shape[0]), map]),
249+
evidence_updates,
250+
map_states,
251+
)
252+
for ii in score:
253+
scores[test_idx, ii] = score[ii]
254254
end = time.time()
255255
print(f"Computing scores took {end-start:.3f} seconds for image {test_idx}.")
256256

257-
#scores[test_idx, :] = score
258-
257+
# scores[test_idx, :] = score
259258

260259

261260
# %% [markdown]
@@ -270,7 +269,6 @@ def initialize_evidences(test_img, frcs, hps, vps):
270269
print(f"accuracy = {accuracy}")
271270

272271

273-
274272
# %% [markdown]
275273
# ## 6. Visualize predictions (backtrace)
276274
#
@@ -281,12 +279,10 @@ def initialize_evidences(test_img, frcs, hps, vps):
281279
plt.imshow(test_set[test_idx][0], cmap="gray")
282280

283281

284-
285282
# %%
286283
# ## 6.1 Backtrace of some models on this test image
287284

288285

289-
290286
# %%
291287
# +
292288
map_states = map_states_dict[test_idx]
@@ -306,11 +302,10 @@ def initialize_evidences(test_img, frcs, hps, vps):
306302
plt.figure(figsize=(15, 15))
307303

308304
for k, index in enumerate(range(0, len(train_set), 5)):
309-
plt.subplot(1, 4, 1+k)
305+
plt.subplot(1, 4, 1 + k)
310306
plt.title(f" Model {int(train_labels[index])}")
311307
plt.imshow(imgs[index, :, :], cmap="gray")
312308
# -
313309

314310

315-
316311
# %%

examples/rcn/model_create/save_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
import time
1111

1212
import numpy as np
13-
from learning import Model
1413
from helpers import get_number_of_states, index_to_rc, rc_to_index
14+
from learning import Model
1515
from load_data import get_mnist_data_iters
1616

1717

0 commit comments

Comments
 (0)