-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
183 lines (163 loc) · 7.23 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import argparse
import pickle
import yaml
import torch
from glob import glob
from tqdm.auto import tqdm
from easydict import EasyDict
from models.epsnet import *
from utils.datasets import *
from utils.transforms import *
from utils.misc import *
def num_confs(num:str):
if num.endswith('x'):
return lambda x:x*int(num[:-1])
elif int(num) > 0:
return lambda x:int(num)
else:
raise ValueError()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--ckpt', type=str, help='path for loading the checkpoint')
parser.add_argument('--config', type=str, default=None)
parser.add_argument('--save_traj', action='store_true', default=False,
help='whether store the whole trajectory for sampling')
parser.add_argument('--not_use_mask', action='store_true', default=False,
help='whether use mask in sampling')
parser.add_argument('--resume', type=str, default=None)
parser.add_argument('--tag', type=str, default=None)
parser.add_argument('--num_confs', type=num_confs, default=num_confs('2x'))
parser.add_argument('--test_set', type=str, default=None)
parser.add_argument('--start_idx', type=int, default=0)
parser.add_argument('--end_idx', type=int, default=200)
parser.add_argument('--out_dir', type=str, default=None)
parser.add_argument('--device', type=str, default='cuda:0')
parser.add_argument('--clip', type=float, default=1000.0)
parser.add_argument('--clip_local', type=float, default=20)
parser.add_argument('--n_steps', type=int, default=5000,
help='sampling num steps; for DSM framework, this means num steps for each noise scale')
parser.add_argument('--global_start_sigma', type=float, default=0.5,
help='enable global gradients only when noise is low')
parser.add_argument('--w_global', type=float, default=0.3,
help='weight for global gradients')
# Parameters for DDPM
parser.add_argument('--sampling_type', type=str, default='ddpm_noisy',
help='generalized, ddpm_noisy, ld: sampling method for DDIM, DDPM or Langevin Dynamics')
parser.add_argument('--eta', type=float, default=1.0,
help='weight for DDIM and DDPM: 0->DDIM, 1->DDPM')
args = parser.parse_args()
# Load checkpoint
ckpt = torch.load(args.ckpt, map_location=torch.device('cpu'))
if args.config is None:
config_path = glob(os.path.join(os.path.dirname(os.path.dirname(args.ckpt)), '*.yml'))[0]
else:
config_path =args.config
with open(config_path, 'r') as f:
config = EasyDict(yaml.safe_load(f))
seed_all(config.train.seed)
log_dir = os.path.dirname(os.path.dirname(args.ckpt))
# Logging
if args.tag is None:
args.tag=os.path.basename(os.path.dirname(args.test_set))
output_dir = get_new_log_dir(log_dir, 'sample', tag=args.tag)
logger = get_logger('test', output_dir)
# Datasets and loaders
logger.info(args)
logger.info('Loading datasets...')
transforms = Compose([
CountNodesPerGraph(),
AddHigherOrderEdges(order=config.model.edge_order), # Offline edge augmentation
])
if args.test_set is None:
test_set = PackedConformationDataset(config.dataset.test, transform=transforms)
logger.info(f'Loading {config.dataset.test}')
else:
test_set = PackedConformationDataset(args.test_set, transform=transforms)
logger.info(f'Loading {args.test_set}')
# Model
logger.info('Loading model...')
model = get_model(ckpt['config'].model).to(args.device)
model.load_state_dict(ckpt['model'])
test_set_selected = []
for i, data in enumerate(test_set):
if not (args.start_idx <= i < args.end_idx): continue
if args.tag=="qm92drugs":
atom_set = set(atom.GetSymbol() for atom in data.rdmol.GetAtoms())
qm9_set={'H','C', 'N', 'O', 'F'}
if atom_set <= qm9_set:
test_set_selected.append(data)
else:
test_set_selected.append(data)
done_smiles = set()
results = []
if args.resume is not None:
with open(args.resume, 'rb') as f:
results = pickle.load(f)
for data in results:
done_smiles.add(data.smiles)
for i, data in enumerate(tqdm(test_set_selected)):
if data.smiles in done_smiles:
logger.info('Molecule#%d is already done.' % i)
continue
num_refs = data.pos_ref.size(0) // data.num_nodes
num_samples = args.num_confs(num_refs)
data_input = data.clone()
data_input['pos_ref'] = None
batch = repeat_data(data_input, num_samples).to(args.device)
clip_local = None
for try_n in range(3): # Maximum number of retry
try:
pos_init = torch.randn(batch.num_nodes, 3).to(args.device)
pos_gen, pos_gen_traj = model.langevin_dynamics_sample_diffusion_subgraph(
atom_type=batch.atom_type,
pos_init=pos_init,
bond_index=batch.edge_index,
bond_type=batch.edge_type,
batch=batch.batch,
num_graphs=batch.num_graphs,
extend_order=False, # Done in transforms.
n_steps=args.n_steps,
step_lr=1e-6,
w_global=args.w_global,
global_start_sigma=args.global_start_sigma,
clip=args.clip,
clip_local=clip_local,
sampling_type=args.sampling_type,
eta=args.eta,
use_mask=not args.not_use_mask
)
pos_gen = pos_gen.cpu()
if args.save_traj:
data.pos_gen = torch.stack(pos_gen_traj)
else:
data.pos_gen = pos_gen
results.append(data)
done_smiles.add(data.smiles)
save_path = os.path.join(output_dir, 'samples_%d.pkl' % i)
logger.info('Saving samples to: %s' % save_path)
with open(save_path, 'wb') as f:
pickle.dump(results, f)
break # No errors occured, break the retry loop
except FloatingPointError:
# clip_local = 100
if try_n==1:
clip_local = args.clip_local
logger.warning(f'Retrying with local clipping. clip_local={clip_local}')
if try_n==2:
clip_local = args.clip_local//2
logger.warning(f'Retrying with local clipping. clip_local={clip_local}')
seed_all(config.train.seed + try_n)
pass
save_path = os.path.join(output_dir, 'samples_all.pkl')
logger.info('Saving samples to: %s' % save_path)
logger.info(output_dir)
def get_mol_key(data):
for i, d in enumerate(test_set_selected):
if d.smiles == data.smiles:
return i
return -1
results.sort(key=get_mol_key)
with open(save_path, 'wb') as f:
pickle.dump(results, f)