-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path01b-check_training.py
53 lines (41 loc) · 1.32 KB
/
01b-check_training.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import logging
import pathlib
import pickle
import jax
from absl import app, flags
from jax import numpy as jnp
# Command-line flags. --folder is the root directory that main() scans
# recursively for posterior and params pickle files.
FLAGS = flags.FLAGS
flags.DEFINE_string("folder", None, "folder with posterior and params files")
flags.mark_flags_as_required(["folder"])
def open_pickle(fl):
    """Unpickle and return the object stored in the file at path ``fl``.

    Args:
      fl: path (str or pathlib.Path) of a pickle file.

    Returns:
      The unpickled object.

    NOTE: ``pickle`` can execute arbitrary code while loading, so only point
    this at files you produced yourself.
    """
    with open(fl, "rb") as stream:
        return pickle.load(stream)
def check_posterior(file):
    """Warn if the posterior samples stored in ``file`` contain NaNs.

    The unpickled object must have a ``"samples"`` entry. That entry may be a
    single array or any JAX pytree of arrays (e.g. a dict of per-variable
    samples); NaNs are counted over all leaves together. When any NaN is
    found, a warning reporting the *fraction* of NaN entries is logged.

    Args:
      file: path of the pickle file to inspect.

    Raises:
      KeyError: if the unpickled object has no ``"samples"`` entry.
    """
    dic = open_pickle(file)
    # tree_leaves on a bare array returns [array], so the single-array case
    # behaves exactly as before; pytrees of arrays are now supported too.
    leaves = jax.tree_util.tree_leaves(dic["samples"])
    nan_masks = [jnp.isnan(leaf) for leaf in leaves]
    num_nans = sum(jnp.sum(mask) for mask in nan_masks)
    if num_nans != 0:
        # Only computed when needed; num_nans != 0 guarantees num_values > 0.
        num_values = sum(mask.size for mask in nan_masks)
        ratio_nans = num_nans / num_values
        logging.warning(f"file {file} contains {ratio_nans} nans")
def check_params(file):
    """Warn if the training parameters stored in ``file`` contain NaNs.

    The unpickled object is looked up for a ``"params"`` entry, which may be
    any JAX pytree of arrays. Files without a ``"params"`` entry are skipped
    silently.

    Args:
      file: path of the pickle file to inspect.
    """
    dic = open_pickle(file)
    if "params" not in dic:
        return
    # Count NaNs per leaf in a single pass. jax.tree_util.tree_map replaces
    # the deprecated jax.tree_map alias, which recent JAX versions removed.
    nan_counts = jax.tree_util.tree_map(
        lambda leaf: jnp.sum(jnp.isnan(leaf)), dic["params"]
    )
    # The initializer makes an empty params tree count as "no NaNs" instead of
    # raising from reducing an empty sequence.
    num_nans = jax.tree_util.tree_reduce(jnp.add, nan_counts, 0)
    if num_nans != 0:
        logging.warning(f"file {file} contains nans in training parameters")
def main(argv):
    """Recursively scan ``--folder`` and NaN-check posterior/params files.

    Files whose stem contains ``"posteriors"`` go to ``check_posterior``.
    Otherwise, files whose stem contains ``"params"`` go to ``check_params``.
    All other files are ignored.

    Args:
      argv: positional command-line arguments (unused).

    Raises:
      NotADirectoryError: if ``--folder`` is not an existing directory.
    """
    del argv  # No positional arguments are used.
    folder = pathlib.Path(FLAGS.folder)
    # Without this check a mistyped path makes rglob yield nothing, and the
    # run would look like "no NaNs found".
    if not folder.is_dir():
        raise NotADirectoryError(
            f"--folder={FLAGS.folder!r} is not an existing directory"
        )
    # Sort so that warnings come out in a reproducible order.
    for path in sorted(folder.rglob("*")):
        if not path.is_file():
            continue
        if "posteriors" in path.stem:
            check_posterior(path)
        elif "params" in path.stem:
            check_params(path)
if __name__ == "__main__":
    # NOTE(review): config_with_absl presumably exposes JAX's configuration
    # options (e.g. --jax_enable_x64) as absl flags so that app.run parses
    # them; it is called before app.run for that reason. Confirm it still
    # exists in the installed JAX version.
    jax.config.config_with_absl()
    # Parse command-line flags (including the required --folder), then run main.
    app.run(main)