Skip to content

Commit 8ba868b

Browse files
authored
Update test_dft_acceleration.py
1 parent 2089c8e commit 8ba868b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

OpenDFT/QHBench/QH9/test_dft_acceleration.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,8 @@ def get_stable_dataset_split(root_path):
235235
os.path.join(root_path, 'datasets', 'QH9Stable', 'processed', 'processed_QH9Stable_random.pt')
236236
processed_ood = \
237237
os.path.join(root_path, 'datasets', 'QH9Stable', 'processed', 'processed_QH9Stable_size_ood.pt')
238-
split_idx_iid_test_mask = torch.load(processed_random)[4]
239-
split_idx_ood_test_mask = torch.load(processed_ood)[4]
238+
split_idx_iid_test_mask = torch.load(processed_random)[2]
239+
split_idx_ood_test_mask = torch.load(processed_ood)[2]
240240
# test_data_mask = np.logical_and(split_idx_iid_test_mask, split_idx_ood_test_mask)
241241
# test_data_indices = np.where(test_data_mask)[0]
242242
test_data_indices = np.intersect1d(split_idx_iid_test_mask, split_idx_ood_test_mask)

0 commit comments

Comments
 (0)