From 324037d8a5a9468642110988d1c9f2157abf4cfd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Feb 2025 08:30:36 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/utils/stat.py | 29 +++++++++---- source/tests/pt/test_make_stat_input.py | 55 +++++++++++++++++-------- 2 files changed, 59 insertions(+), 25 deletions(-) diff --git a/deepmd/pt/utils/stat.py b/deepmd/pt/utils/stat.py index 9da91e3e8e..7d5bff3b60 100644 --- a/deepmd/pt/utils/stat.py +++ b/deepmd/pt/utils/stat.py @@ -108,7 +108,7 @@ def finalize_stats(sys_stat): elif isinstance(sys_stat[key][0], torch.Tensor): sys_stat[key] = torch.cat(sys_stat[key], dim=0) dict_to_device(sys_stat) - + def process_element_counts(sys_index, dataset, min_frames_per_element_forstat): """Process and update global element counts.""" element_counts, type_name = dataset.get_frame_index_for_elements() @@ -124,12 +124,21 @@ def process_element_counts(sys_index, dataset, min_frames_per_element_forstat): if count > min_frames_per_element_forstat: global_element_counts[elem]["count"] += min_frames_per_element_forstat indices = indices[:min_frames_per_element_forstat] - global_element_counts[elem]["indices"].append({"sys_index": sys_index, "frames": indices}) + global_element_counts[elem]["indices"].append( + {"sys_index": sys_index, "frames": indices} + ) else: global_element_counts[elem]["count"] += count - global_element_counts[elem]["indices"].append({"sys_index": sys_index, "frames": indices}) + global_element_counts[elem]["indices"].append( + {"sys_index": sys_index, "frames": indices} + ) - def process_missing_elements(min_frames_per_element_forstat, global_element_counts, total_element_types, collect_ele): + def process_missing_elements( + min_frames_per_element_forstat, + global_element_counts, + total_element_types, + collect_ele, + ): """Handle missing elements and check element completeness.""" collect_elements = collect_ele.keys() missing_elements = total_element_types - collect_elements @@ -140,7 +149,9 @@ def process_missing_elements(min_frames_per_element_forstat, global_element_coun missing_elements.add(ele) for miss in missing_elements: sys_indices = global_element_counts[miss].get("indices", []) - newele_counter = collect_ele.get(miss, 0) if miss in collect_miss_element else 0 + newele_counter = ( + collect_ele.get(miss, 0) if miss in collect_miss_element else 0 + ) process_with_new_frame(sys_indices, newele_counter, miss) def process_with_new_frame(sys_indices, newele_counter, miss): @@ -196,12 +207,16 @@ def process_with_new_frame(sys_indices, newele_counter, miss): process_element_counts(sys_index, dataset, min_frames_per_element_forstat) if datasets[0].mixed_type and enable_element_completion: - process_missing_elements(min_frames_per_element_forstat, global_element_counts, total_element_types, collect_ele) + process_missing_elements( + min_frames_per_element_forstat, + global_element_counts, + total_element_types, + collect_ele, + ) return lst - def _restore_from_file( stat_file_path: DPPath, keys: list[str] = ["energy"], diff --git a/source/tests/pt/test_make_stat_input.py b/source/tests/pt/test_make_stat_input.py index dbcb3e84ea..42f1650de0 100644 --- a/source/tests/pt/test_make_stat_input.py +++ b/source/tests/pt/test_make_stat_input.py @@ -1,11 +1,26 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later import unittest -from pathlib import Path +from pathlib import ( + Path, +) + import numpy as np import torch -from torch.utils.data import DataLoader -from deepmd.pt.utils.dataset import DeepmdDataSetForLoader -from deepmd.pt.utils.stat import compute_output_stats, make_stat_input -from deepmd.utils.data import DataRequirementItem +from torch.utils.data import ( + DataLoader, +) + +from deepmd.pt.utils.dataset import ( + DeepmdDataSetForLoader, +) +from deepmd.pt.utils.stat import ( + compute_output_stats, + make_stat_input, +) +from deepmd.utils.data import ( + DataRequirementItem, +) + def collate_fn(batch): if isinstance(batch, dict): @@ -27,6 +42,7 @@ def collate_fn(batch): return out + class TestMakeStatInput(unittest.TestCase): @classmethod def setUpClass(cls): @@ -71,13 +87,13 @@ def test_make_stat_input_with_element_counts(self): energy = bias.get("energy") non_zero_count = self.count_non_zero_elements(energy) self.assertEqual( - non_zero_count, + non_zero_count, self.real_ntypes, - f"Expected exactly {self.real_ntypes} non-zero elements in energy, but got {non_zero_count}." + f"Expected exactly {self.real_ntypes} non-zero elements in energy, but got {non_zero_count}.", ) def test_process_missing_elements(self): - #3 frames would be count + # 3 frames would be count lst = make_stat_input( datasets=self.datasets, dataloaders=self.dataloaders, @@ -91,9 +107,9 @@ def test_process_missing_elements(self): energy = sys_stat["energy"] non_zero_count = self.count_non_zero_elements(energy) self.assertLess( - non_zero_count, - self.real_ntypes, - f"Expected fewer than {self.real_ntypes} non-zero elements due to missing elements." + non_zero_count, + self.real_ntypes, + f"Expected fewer than {self.real_ntypes} non-zero elements due to missing elements.", ) def test_with_missing_elements_and_new_frames(self): @@ -116,9 +132,11 @@ def test_with_missing_elements_and_new_frames(self): energy = sys_stat["energy"] missing_elements.append(self.count_non_zero_elements(energy)) - # - self.assertGreater(len(missing_elements), 0, "Expected missing elements to be processed.") - + # + self.assertGreater( + len(missing_elements), 0, "Expected missing elements to be processed." + ) + lst_new = make_stat_input( datasets=self.datasets, dataloaders=self.dataloaders, @@ -127,13 +145,13 @@ def test_with_missing_elements_and_new_frames(self): enable_element_completion=True, ) - # + # for original, new in zip(lst, lst_new): energy_ori = np.array(original["energy"].cpu()).flatten() energy_new = np.array(new["energy"].cpu()).flatten() self.assertTrue( np.allclose(energy_ori, energy_new), - msg=f"Energy values don't match. Original: {energy_ori}, New: {energy_new}" + msg=f"Energy values don't match. Original: {energy_ori}, New: {energy_new}", ) def test_bias(self): @@ -173,8 +191,8 @@ def test_bias(self): ) def test_with_nomissing(self): - #missing element:13,31,37 - #only one frame would be count + # missing element:13,31,37 + # only one frame would be count lst_ori = make_stat_input( datasets=self.datasets, dataloaders=self.dataloaders, @@ -211,5 +229,6 @@ def test_with_nomissing(self): f"energy_ori = {energy_ori}\nenergy_new = {energy_new}", ) + if __name__ == "__main__": unittest.main()