Skip to content

Commit e7ba030

Browse files
committed
add new dynamic dataset
1 parent 96d9ad4 commit e7ba030

File tree

9 files changed

+1323
-186
lines changed

9 files changed

+1323
-186
lines changed
+3-4
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
defaults:
2-
- datasets: QH9-stable
2+
- datasets: QH9-dynamic
33
ex_name: 'ex_name'
4-
device: 0
4+
device: 2
55
ckpt_dir: 'checkpoints'
66
split_seed: 42
77
optimizer: adam
88
ema_start_epoch: -1
99

10+
trained_model: '.'
1011
# For evaluating trained model
11-
trained_model: './' # path to your trained model
12-
1312
seed: 0

OpenDFT/QHBench/QH9/config/datasets/QH9-dynamic.yaml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
dataset_name: QH9Dynamic
2-
split: geometry #[geometry, mol]
2+
split: mol #[geometry, mol]
33

44
train_batch_size: 32
55
valid_batch_size: 32
@@ -13,7 +13,7 @@ pin_memory: True
1313
num_workers: 8
1414

1515
warmup_steps: 1000
16-
total_steps: 250000
16+
total_steps: 260000
1717
lr_end: 1e-7
1818

1919
train_batch_interval: 100

OpenDFT/QHBench/QH9/datasets.py

+151-113
Large diffs are not rendered by default.

OpenDFT/QHBench/QH9/main.py

+21-22
Original file line numberDiff line numberDiff line change
@@ -4,46 +4,45 @@
44
import logging
55

66
from models import QHNet
7-
from torchvision.transforms import Compose
87
from torch_geometric.loader import DataLoader
8+
from torch_scatter import scatter_sum
99

1010
from datasets import QH9Stable, QH9Dynamic
1111
from torch_ema import ExponentialMovingAverage
1212
from transformers import get_polynomial_decay_schedule_with_warmup
1313
logger = logging.getLogger()
1414

1515

16-
def recorder_pos(data):
17-
data.pos = torch.stack(
18-
[data.pos[:, 2], data.pos[:, 0], data.pos[:, 1]],
19-
dim=1
20-
)
21-
return data
22-
23-
2416
def criterion(outputs, target, loss_weights):
2517
error_dict = {}
2618
keys = loss_weights.keys()
27-
# the diagonal and non-diagonal should be considered with the mask
2819
try:
2920
for key in keys:
21+
row = target.edge_index[0]
22+
edge_batch = target.batch[row]
3023
diff_diagonal = outputs[f'{key}_diagonal_blocks']-target[f'diagonal_{key}']
31-
mse_diagonal = torch.sum(diff_diagonal**2 * target[f"diagonal_{key}_mask"])
32-
mae_diagonal = torch.sum(torch.abs(diff_diagonal) * target[f"diagonal_{key}_mask"])
33-
count_sum_diagonal = torch.sum(target[f"diagonal_{key}_mask"])
24+
mse_diagonal = torch.sum(diff_diagonal**2 * target[f"diagonal_{key}_mask"], dim=[1, 2])
25+
mae_diagonal = torch.sum(torch.abs(diff_diagonal) * target[f"diagonal_{key}_mask"], dim=[1, 2])
26+
count_sum_diagonal = torch.sum(target[f"diagonal_{key}_mask"], dim=[1, 2])
27+
mse_diagonal = scatter_sum(mse_diagonal, target.batch)
28+
mae_diagonal = scatter_sum(mae_diagonal, target.batch)
29+
count_sum_diagonal = scatter_sum(count_sum_diagonal, target.batch)
3430

3531
diff_non_diagonal = outputs[f'{key}_non_diagonal_blocks']-target[f'non_diagonal_{key}']
36-
mse_non_diagonal = torch.sum(diff_non_diagonal**2 * target[f"non_diagonal_{key}_mask"])
37-
mae_non_diagonal = torch.sum(torch.abs(diff_non_diagonal) * target[f"non_diagonal_{key}_mask"])
38-
count_sum_non_diagonal = torch.sum(target[f"non_diagonal_{key}_mask"])
32+
mse_non_diagonal = torch.sum(diff_non_diagonal**2 * target[f"non_diagonal_{key}_mask"], dim=[1, 2])
33+
mae_non_diagonal = torch.sum(torch.abs(diff_non_diagonal) * target[f"non_diagonal_{key}_mask"], dim=[1, 2])
34+
count_sum_non_diagonal = torch.sum(target[f"non_diagonal_{key}_mask"], dim=[1, 2])
35+
mse_non_diagonal = scatter_sum(mse_non_diagonal, edge_batch)
36+
mae_non_diagonal = scatter_sum(mae_non_diagonal, edge_batch)
37+
count_sum_non_diagonal = scatter_sum(count_sum_non_diagonal, edge_batch)
3938

40-
mae = (mae_diagonal + mae_non_diagonal) / (count_sum_diagonal + count_sum_non_diagonal)
41-
mse = (mse_diagonal + mse_non_diagonal) / (count_sum_diagonal + count_sum_non_diagonal)
39+
mae = ((mae_diagonal + mae_non_diagonal) / (count_sum_diagonal + count_sum_non_diagonal)).mean()
40+
mse = ((mse_diagonal + mse_non_diagonal) / (count_sum_diagonal + count_sum_non_diagonal)).mean()
4241

4342
error_dict[key+'_mae'] = mae
4443
error_dict[key+'_rmse'] = torch.sqrt(mse)
45-
error_dict[key + '_diagonal_mae'] = mae_diagonal / count_sum_diagonal
46-
error_dict[key + '_non_diagonal_mae'] = mae_non_diagonal / count_sum_non_diagonal
44+
error_dict[key + '_diagonal_mae'] = (mae_diagonal / count_sum_diagonal).mean()
45+
error_dict[key + '_non_diagonal_mae'] = (mae_non_diagonal / count_sum_non_diagonal).mean()
4746
loss = mse + mae
4847
error_dict[key] = loss
4948
if 'loss' in error_dict.keys():
@@ -100,6 +99,7 @@ def main(conf):
10099
if torch.cuda.is_available():
101100
torch.cuda.manual_seed_all(0)
102101

102+
# root_path = '/data/meng/QC_features'
103103
root_path = os.path.join(os.sep.join(os.getcwd().split(os.sep)[:-3]))
104104
# determine whether GPU is used for training
105105
if torch.cuda.is_available():
@@ -130,7 +130,6 @@ def main(conf):
130130
test_dataset, batch_size=conf.datasets.test_batch_size, shuffle=False,
131131
num_workers=conf.datasets.num_workers, pin_memory=conf.datasets.pin_memory)
132132
train_iterator = iter(train_data_loader)
133-
134133
# define model
135134
model = QHNet(
136135
in_node_features=1,
@@ -237,7 +236,7 @@ def main(conf):
237236

238237
def post_processing(batch, default_type):
239238
for key in batch.keys:
240-
if torch.is_floating_point(batch[key]):
239+
if torch.is_tensor(batch[key]) and torch.is_floating_point(batch[key]):
241240
batch[key] = batch[key].type(default_type)
242241
return batch
243242

OpenDFT/QHBench/QH9/models/QHNet.py

+8-5
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,9 @@ def forward(self, data, keep_blocks=True):
611611
data.node_attr, data.edge_index, data.edge_attr, data.edge_sh = \
612612
node_attr, edge_index, rbf_new, edge_sh
613613

614-
_, full_edge_index, full_edge_attr, full_edge_sh, transpose_edge_index = self.build_graph(data, 10000)
614+
_, full_edge_index, full_edge_attr, full_edge_sh, transpose_edge_index = \
615+
self.build_graph(data, max_radius=10000)
616+
615617
data.full_edge_index, data.full_edge_attr, data.full_edge_sh = \
616618
full_edge_index, full_edge_attr, full_edge_sh
617619

@@ -652,11 +654,12 @@ def forward(self, data, keep_blocks=True):
652654

653655
return results
654656

655-
def build_graph(self, data, max_radius):
657+
def build_graph(self, data, max_radius, edge_index=None):
656658
node_attr = data.atoms.squeeze()
657-
658-
659-
radius_edges = radius_graph(data.pos, max_radius, data.batch, max_num_neighbors=data.num_nodes)
659+
if edge_index is None:
660+
radius_edges = radius_graph(data.pos, max_radius, data.batch, max_num_neighbors=data.num_nodes)
661+
else:
662+
radius_edges = data.full_edge_index
660663

661664
dst, src = radius_edges
662665
edge_vec = data.pos[dst.long()] - data.pos[src.long()]
+2-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1-
from .QHNet import QHNet
1+
# from .QHNet import QHNet
2+
from .ori_QHNet_with_bias import QHNet
23

0 commit comments

Comments
 (0)