|
4 | 4 | import logging
|
5 | 5 |
|
6 | 6 | from models import QHNet
|
7 |
| -from torchvision.transforms import Compose |
8 | 7 | from torch_geometric.loader import DataLoader
|
| 8 | +from torch_scatter import scatter_sum |
9 | 9 |
|
10 | 10 | from datasets import QH9Stable, QH9Dynamic
|
11 | 11 | from torch_ema import ExponentialMovingAverage
|
12 | 12 | from transformers import get_polynomial_decay_schedule_with_warmup
|
13 | 13 | logger = logging.getLogger()
|
14 | 14 |
|
15 | 15 |
|
16 |
| -def recorder_pos(data): |
17 |
| - data.pos = torch.stack( |
18 |
| - [data.pos[:, 2], data.pos[:, 0], data.pos[:, 1]], |
19 |
| - dim=1 |
20 |
| - ) |
21 |
| - return data |
22 |
| - |
23 |
| - |
24 | 16 | def criterion(outputs, target, loss_weights):
|
25 | 17 | error_dict = {}
|
26 | 18 | keys = loss_weights.keys()
|
27 |
| - # the diagonal and non-diagonal should be considered with the mask |
28 | 19 | try:
|
29 | 20 | for key in keys:
|
| 21 | + row = target.edge_index[0] |
| 22 | + edge_batch = target.batch[row] |
30 | 23 | diff_diagonal = outputs[f'{key}_diagonal_blocks']-target[f'diagonal_{key}']
|
31 |
| - mse_diagonal = torch.sum(diff_diagonal**2 * target[f"diagonal_{key}_mask"]) |
32 |
| - mae_diagonal = torch.sum(torch.abs(diff_diagonal) * target[f"diagonal_{key}_mask"]) |
33 |
| - count_sum_diagonal = torch.sum(target[f"diagonal_{key}_mask"]) |
| 24 | + mse_diagonal = torch.sum(diff_diagonal**2 * target[f"diagonal_{key}_mask"], dim=[1, 2]) |
| 25 | + mae_diagonal = torch.sum(torch.abs(diff_diagonal) * target[f"diagonal_{key}_mask"], dim=[1, 2]) |
| 26 | + count_sum_diagonal = torch.sum(target[f"diagonal_{key}_mask"], dim=[1, 2]) |
| 27 | + mse_diagonal = scatter_sum(mse_diagonal, target.batch) |
| 28 | + mae_diagonal = scatter_sum(mae_diagonal, target.batch) |
| 29 | + count_sum_diagonal = scatter_sum(count_sum_diagonal, target.batch) |
34 | 30 |
|
35 | 31 | diff_non_diagonal = outputs[f'{key}_non_diagonal_blocks']-target[f'non_diagonal_{key}']
|
36 |
| - mse_non_diagonal = torch.sum(diff_non_diagonal**2 * target[f"non_diagonal_{key}_mask"]) |
37 |
| - mae_non_diagonal = torch.sum(torch.abs(diff_non_diagonal) * target[f"non_diagonal_{key}_mask"]) |
38 |
| - count_sum_non_diagonal = torch.sum(target[f"non_diagonal_{key}_mask"]) |
| 32 | + mse_non_diagonal = torch.sum(diff_non_diagonal**2 * target[f"non_diagonal_{key}_mask"], dim=[1, 2]) |
| 33 | + mae_non_diagonal = torch.sum(torch.abs(diff_non_diagonal) * target[f"non_diagonal_{key}_mask"], dim=[1, 2]) |
| 34 | + count_sum_non_diagonal = torch.sum(target[f"non_diagonal_{key}_mask"], dim=[1, 2]) |
| 35 | + mse_non_diagonal = scatter_sum(mse_non_diagonal, edge_batch) |
| 36 | + mae_non_diagonal = scatter_sum(mae_non_diagonal, edge_batch) |
| 37 | + count_sum_non_diagonal = scatter_sum(count_sum_non_diagonal, edge_batch) |
39 | 38 |
|
40 |
| - mae = (mae_diagonal + mae_non_diagonal) / (count_sum_diagonal + count_sum_non_diagonal) |
41 |
| - mse = (mse_diagonal + mse_non_diagonal) / (count_sum_diagonal + count_sum_non_diagonal) |
| 39 | + mae = ((mae_diagonal + mae_non_diagonal) / (count_sum_diagonal + count_sum_non_diagonal)).mean() |
| 40 | + mse = ((mse_diagonal + mse_non_diagonal) / (count_sum_diagonal + count_sum_non_diagonal)).mean() |
42 | 41 |
|
43 | 42 | error_dict[key+'_mae'] = mae
|
44 | 43 | error_dict[key+'_rmse'] = torch.sqrt(mse)
|
45 |
| - error_dict[key + '_diagonal_mae'] = mae_diagonal / count_sum_diagonal |
46 |
| - error_dict[key + '_non_diagonal_mae'] = mae_non_diagonal / count_sum_non_diagonal |
| 44 | + error_dict[key + '_diagonal_mae'] = (mae_diagonal / count_sum_diagonal).mean() |
| 45 | + error_dict[key + '_non_diagonal_mae'] = (mae_non_diagonal / count_sum_non_diagonal).mean() |
47 | 46 | loss = mse + mae
|
48 | 47 | error_dict[key] = loss
|
49 | 48 | if 'loss' in error_dict.keys():
|
@@ -100,6 +99,7 @@ def main(conf):
|
100 | 99 | if torch.cuda.is_available():
|
101 | 100 | torch.cuda.manual_seed_all(0)
|
102 | 101 |
|
| 102 | + # root_path = '/data/meng/QC_features' |
103 | 103 | root_path = os.path.join(os.sep.join(os.getcwd().split(os.sep)[:-3]))
|
104 | 104 | # determine whether GPU is used for training
|
105 | 105 | if torch.cuda.is_available():
|
@@ -130,7 +130,6 @@ def main(conf):
|
130 | 130 | test_dataset, batch_size=conf.datasets.test_batch_size, shuffle=False,
|
131 | 131 | num_workers=conf.datasets.num_workers, pin_memory=conf.datasets.pin_memory)
|
132 | 132 | train_iterator = iter(train_data_loader)
|
133 |
| - |
134 | 133 | # define model
|
135 | 134 | model = QHNet(
|
136 | 135 | in_node_features=1,
|
@@ -237,7 +236,7 @@ def main(conf):
|
237 | 236 |
|
def post_processing(batch, default_type):
    """Cast every floating-point tensor attribute of ``batch`` to ``default_type``.

    Non-tensor attributes and integer tensors (e.g. indices) are left untouched.
    The batch is modified in place and also returned for convenience.

    NOTE(review): iterates ``batch.keys`` as a property, matching the older
    torch_geometric API this file targets — on PyG >= 2.4 this would need
    ``batch.keys()``; confirm against the pinned dependency version.
    """
    for key in batch.keys:
        value = batch[key]
        # Only floating tensors are cast; integer index tensors must keep their dtype.
        if torch.is_tensor(value) and torch.is_floating_point(value):
            batch[key] = value.type(default_type)
    return batch
|
243 | 242 |
|
|
0 commit comments