Skip to content

Commit

Permalink
Merge branch 'devel' of https://github.com/SumGuo-88/deepmd-kit into …
Browse files Browse the repository at this point in the history
…devel
  • Loading branch information
SumGuo-88 committed Feb 12, 2025
2 parents a235f71 + 324037d commit 5ea3594
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 25 deletions.
29 changes: 22 additions & 7 deletions deepmd/pt/utils/stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def finalize_stats(sys_stat):
elif isinstance(sys_stat[key][0], torch.Tensor):
sys_stat[key] = torch.cat(sys_stat[key], dim=0)
dict_to_device(sys_stat)

def process_element_counts(sys_index, dataset, min_frames_per_element_forstat):
"""Process and update global element counts."""
element_counts, type_name = dataset.get_frame_index_for_elements()
Expand All @@ -124,12 +124,21 @@ def process_element_counts(sys_index, dataset, min_frames_per_element_forstat):
if count > min_frames_per_element_forstat:
global_element_counts[elem]["count"] += min_frames_per_element_forstat
indices = indices[:min_frames_per_element_forstat]
global_element_counts[elem]["indices"].append({"sys_index": sys_index, "frames": indices})
global_element_counts[elem]["indices"].append(
{"sys_index": sys_index, "frames": indices}
)
else:
global_element_counts[elem]["count"] += count
global_element_counts[elem]["indices"].append({"sys_index": sys_index, "frames": indices})
global_element_counts[elem]["indices"].append(
{"sys_index": sys_index, "frames": indices}
)

def process_missing_elements(min_frames_per_element_forstat, global_element_counts, total_element_types, collect_ele):
def process_missing_elements(
min_frames_per_element_forstat,
global_element_counts,
total_element_types,
collect_ele,
):
"""Handle missing elements and check element completeness."""
collect_elements = collect_ele.keys()
missing_elements = total_element_types - collect_elements
Expand All @@ -140,7 +149,9 @@ def process_missing_elements(min_frames_per_element_forstat, global_element_coun
missing_elements.add(ele)
for miss in missing_elements:
sys_indices = global_element_counts[miss].get("indices", [])
newele_counter = collect_ele.get(miss, 0) if miss in collect_miss_element else 0
newele_counter = (
collect_ele.get(miss, 0) if miss in collect_miss_element else 0
)
process_with_new_frame(sys_indices, newele_counter, miss)

def process_with_new_frame(sys_indices, newele_counter, miss):
Expand Down Expand Up @@ -196,12 +207,16 @@ def process_with_new_frame(sys_indices, newele_counter, miss):
process_element_counts(sys_index, dataset, min_frames_per_element_forstat)

if datasets[0].mixed_type and enable_element_completion:
process_missing_elements(min_frames_per_element_forstat, global_element_counts, total_element_types, collect_ele)
process_missing_elements(
min_frames_per_element_forstat,
global_element_counts,
total_element_types,
collect_ele,
)

return lst



def _restore_from_file(
stat_file_path: DPPath,
keys: list[str] = ["energy"],
Expand Down
55 changes: 37 additions & 18 deletions source/tests/pt/test_make_stat_input.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,26 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import unittest
from pathlib import Path
from pathlib import (
Path,
)

import numpy as np
import torch
from torch.utils.data import DataLoader
from deepmd.pt.utils.dataset import DeepmdDataSetForLoader
from deepmd.pt.utils.stat import compute_output_stats, make_stat_input
from deepmd.utils.data import DataRequirementItem
from torch.utils.data import (
DataLoader,
)

from deepmd.pt.utils.dataset import (
DeepmdDataSetForLoader,
)
from deepmd.pt.utils.stat import (
compute_output_stats,
make_stat_input,
)
from deepmd.utils.data import (
DataRequirementItem,
)


def collate_fn(batch):
if isinstance(batch, dict):
Expand All @@ -27,6 +42,7 @@ def collate_fn(batch):

return out


class TestMakeStatInput(unittest.TestCase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -71,13 +87,13 @@ def test_make_stat_input_with_element_counts(self):
energy = bias.get("energy")
non_zero_count = self.count_non_zero_elements(energy)
self.assertEqual(
non_zero_count,
non_zero_count,
self.real_ntypes,
f"Expected exactly {self.real_ntypes} non-zero elements in energy, but got {non_zero_count}."
f"Expected exactly {self.real_ntypes} non-zero elements in energy, but got {non_zero_count}.",
)

def test_process_missing_elements(self):
#3 frames would be count
# 3 frames would be count
lst = make_stat_input(
datasets=self.datasets,
dataloaders=self.dataloaders,
Expand All @@ -91,9 +107,9 @@ def test_process_missing_elements(self):
energy = sys_stat["energy"]
non_zero_count = self.count_non_zero_elements(energy)
self.assertLess(
non_zero_count,
self.real_ntypes,
f"Expected fewer than {self.real_ntypes} non-zero elements due to missing elements."
non_zero_count,
self.real_ntypes,
f"Expected fewer than {self.real_ntypes} non-zero elements due to missing elements.",
)

def test_with_missing_elements_and_new_frames(self):
Expand All @@ -116,9 +132,11 @@ def test_with_missing_elements_and_new_frames(self):
energy = sys_stat["energy"]
missing_elements.append(self.count_non_zero_elements(energy))

#
self.assertGreater(len(missing_elements), 0, "Expected missing elements to be processed.")

#
self.assertGreater(
len(missing_elements), 0, "Expected missing elements to be processed."
)

lst_new = make_stat_input(
datasets=self.datasets,
dataloaders=self.dataloaders,
Expand All @@ -127,13 +145,13 @@ def test_with_missing_elements_and_new_frames(self):
enable_element_completion=True,
)

#
#
for original, new in zip(lst, lst_new):
energy_ori = np.array(original["energy"].cpu()).flatten()
energy_new = np.array(new["energy"].cpu()).flatten()
self.assertTrue(
np.allclose(energy_ori, energy_new),
msg=f"Energy values don't match. Original: {energy_ori}, New: {energy_new}"
msg=f"Energy values don't match. Original: {energy_ori}, New: {energy_new}",
)

def test_bias(self):
Expand Down Expand Up @@ -173,8 +191,8 @@ def test_bias(self):
)

def test_with_nomissing(self):
#missing element:13,31,37
#only one frame would be count
# missing element:13,31,37
# only one frame would be count
lst_ori = make_stat_input(
datasets=self.datasets,
dataloaders=self.dataloaders,
Expand Down Expand Up @@ -211,5 +229,6 @@ def test_with_nomissing(self):
f"energy_ori = {energy_ori}\nenergy_new = {energy_new}",
)


if __name__ == "__main__":
unittest.main()

0 comments on commit 5ea3594

Please sign in to comment.