(fix) Make bias statistics complete for all elements #4496

Open
wants to merge 118 commits into base: devel

Conversation

@SumGuo-88 (Collaborator) commented Dec 23, 2024

Summary by CodeRabbit

  • New Features

    • Introduced a method to identify and count unique element types in datasets.
    • Added new parameters for enhanced control over statistics collection in training configurations.
    • Expanded dataset of chemical elements for improved mixed-type data processing.
    • Added command-line options to skip element checks and specify minimum frames during statistical data retrieval.
  • Bug Fixes

    • Improved error handling and reporting for type mapping issues in dataset loading.
  • Tests

    • Added unit tests for the make_stat_input function to ensure accurate processing of atomic types.
    • Created a new test file for comprehensive testing of statistical input functionality.
    • Enhanced testing framework for better coverage of element completion scenarios.

@coderabbitai bot (Contributor) commented Dec 23, 2024

📝 Walkthrough

The pull request introduces modifications in the DeepMD-kit's PyTorch utility modules. A new public method get_frame_index_for_elements is added to the DeepmdDataSetForLoader class in dataset.py, which retrieves frame indices and counts for each unique element in the dataset. The make_stat_input function in stat.py is updated to include new parameters for enhanced handling of atomic types, and a new test file is created to validate the functionality of make_stat_input. Additionally, attributes related to statistical calculations are added to the Trainer class, and error handling is improved in the data.py methods.
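From the walkthrough above, the new dataset method maps each element type to the frames it occurs in. A minimal sketch of that idea (the function name, input shape, and return shape here are assumptions based on this summary, not the PR's actual code):

```python
from collections import defaultdict

def get_frame_index_for_elements(frames):
    """Sketch: map element type -> list of frame indices containing it,
    plus a total occurrence count per element. `frames` is a list of
    per-frame atom-type lists; this mirrors the description only."""
    frame_index = defaultdict(list)
    counts = defaultdict(int)
    for frame_idx, atypes in enumerate(frames):
        for elem in set(atypes):      # record each frame once per element
            frame_index[elem].append(frame_idx)
        for elem in atypes:           # count every atom occurrence
            counts[elem] += 1
    return dict(frame_index), dict(counts)

frame_index, counts = get_frame_index_for_elements([[0, 0, 1], [1, 2]])
```

In the PR itself this logic lives on DeepmdDataSetForLoader, where make_stat_input can use it to locate frames for under-represented elements.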

Changes

  • deepmd/pt/utils/dataset.py: Added public method get_frame_index_for_elements() to retrieve frame indices and counts for unique elements. Minor corrections made in the constructor's docstring.
  • deepmd/pt/utils/stat.py: Updated make_stat_input() to include min_frames_per_element_forstat and enable_element_completion parameters, with enhanced logic for atomic types and statistics handling.
  • source/tests/pt/test_make_stat_input.py: Introduced unit tests for make_stat_input, including class TestMakeStatInput with relevant test methods.
  • deepmd/pt/train/training.py: Added attributes min_frames_per_element_forstat and enable_element_completion to the Trainer class, initialized with default values.
  • deepmd/utils/argcheck.py: Added optional argument min_frames_per_element_forstat and required argument enable_element_completion to the training configuration.
  • deepmd/utils/data.py: Modified error handling in _load_set and _load_type_mix methods to improve robustness and error reporting. Added new method build_reidx_to_name_map().
  • source/tests/pt/mixed_type_data/sys.000000/type_map.raw: Added new entries for chemical elements, expanding the dataset for mixed-type data processing.
  • deepmd/main.py: Introduced new command-line argument --skip-elementcheck to the change-bias parser for enhanced flexibility.
  • deepmd/pt/entrypoints/main.py: Updated change_bias function to include new parameters skip_elem_check and min_frames for controlling element completion during bias adjustments.

Possibly related PRs

  • feat pt : Support property fitting #3867: The changes in the main PR, specifically the addition of parameters for statistical calculations in the make_stat_input function, are directly related to the modifications in the retrieved PR, which also involve enhancements to statistical handling in the deepmd/pt/utils/stat.py file.
  • Add pt compress commad line #4300: The changes in the main PR, specifically the addition of parameters to the make_stat_input function and their integration into the change_bias function, are related to modifications in the retrieved PR, which also involves updates to the change_bias function to accommodate new command-line arguments for model compression.

Suggested labels

Docs, breaking change

Suggested reviewers

  • njzjz
  • iProzd

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Nitpick comments (1)
deepmd/pt/utils/stat.py (1)

110-119: Avoid potential index-out-of-range errors when concatenating Tensors.
The code checks “isinstance(sys_stat[key], list) and isinstance(sys_stat[key][0], torch.Tensor)” without verifying that the list is non-empty. This might raise an IndexError if the list is empty. Consider adding a length check before accessing sys_stat[key][0].

Possible safe check:

-if isinstance(sys_stat[key], list) and isinstance(sys_stat[key][0], torch.Tensor):
+if isinstance(sys_stat[key], list) and len(sys_stat[key]) > 0 \
+   and isinstance(sys_stat[key][0], torch.Tensor):
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cfe17a3 and 32da243.

📒 Files selected for processing (2)
  • deepmd/pt/utils/dataset.py (1 hunks)
  • deepmd/pt/utils/stat.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/utils/dataset.py

44-44: SyntaxError: Expected an indented block after function definition

🔇 Additional comments (3)
deepmd/pt/utils/stat.py (3)

91-96: Validate 'atype' presence and handle empty arrays gracefully.
While updating the collected elements, consider explicitly checking that each 'atype' array is non-empty. If a dataset yields an empty array, it might bypass your missing-elements logic.

✅ Verification successful

Based on the code analysis, I can now generate the final response:

Empty arrays are handled safely by numpy.unique and set.update
The code uses numpy.unique() which safely handles empty arrays by returning an empty array, and set.update() which safely handles empty iterables. The logic for missing elements works correctly even with empty arrays since an empty array's unique values would simply not add any elements to the set.
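The claim above is easy to verify in isolation: numpy.unique on an empty array returns an empty array, and set.update with an empty iterable is a no-op, so an empty "atype" frame contributes nothing to the collected set. A standalone check, independent of the PR's code:

```python
import numpy as np

collected = set()
frames = [np.array([0, 1, 1]), np.array([], dtype=int), np.array([2])]
for atype in frames:
    # np.unique on an empty array yields an empty array;
    # set.update with an empty iterable changes nothing
    collected.update(np.unique(atype).tolist())

print(collected)  # {0, 1, 2}
```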

🏁 Scripts executed

The following scripts were executed for the analysis:

Error: No shell scripts found


Script:

#!/bin/bash
# Search for patterns related to 'atype' handling and array emptiness checks
rg -A 3 "atype" --type python

# Also look for the specific class/function containing this code
ast-grep --pattern $'def $_(self, $_):
    $$$
    collected_elements = set()
    for sys_stat in lst:
        if "atype" in sys_stat:
            collected_elements.update($$$)
    $$$'

Length of output: 248


Script:

#!/bin/bash
# Search for patterns related to 'atype' handling and array emptiness checks
rg -A 3 "atype" -t py

# Look for similar array handling patterns in the codebase
rg -t py "numpy\.unique|\.update.*numpy" -A 2

# Find the specific file we're analyzing
fd "stat.py" --type f

Length of output: 89317


85-90: Consider using logging and verifying that all datasets are included.
You are only collecting the set of elements from the first dataset and simply printing them. If your intent is to gather possible elements across all datasets, you may want to iterate over all datasets. Also, using "log.info" would be more consistent than a raw "print" in a library setting.

Possible fix for the print statement:

-        print('we want', all_elements)
+        log.info(f"We want elements: {all_elements}")

97-109: Ensure partial coverage logic for missing elements is intended.
Here, frames are appended until the number of records reaches 'nbatches'. Verify that the short-circuiting with “break” actually meets your coverage needs for missing elements. If your goal is to gather as many frames as possible across all missing elements, you might need a more explicit limit or a different approach.

✅ Verification successful

Based on the gathered context, I can now provide a final response about the coverage logic for missing elements.

The current partial coverage logic is intentional and correct.

The code's behavior is appropriate because:

  1. nbatches is a parameter that defines the desired number of batches for collecting statistics, used consistently throughout the make_stat_input function
  2. The break condition if len(lst[i]['atype']) >= nbatches ensures each system collects up to the specified number of batches, preventing over-sampling
  3. This is consistent with how regular batches are collected earlier in the function where numb_batches = min(nbatches, len(dataloaders[i])) is used
  4. The goal is to gather a statistically significant sample (controlled by nbatches) rather than exhaustively processing all frames containing missing elements

The current implementation correctly balances between gathering sufficient data for statistics while maintaining consistent batch sizes across all elements.
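The break condition described in point 2 can be sketched as follows (collect_missing_frames, lst_entry, and the list-of-dicts dataset are illustrative stand-ins, not the PR's code):

```python
def collect_missing_frames(lst_entry, frame_indices, dataset, nbatches):
    """Append frames containing a missing element until the per-system
    record count reaches nbatches, then stop (prevents over-sampling)."""
    for frame_idx in frame_indices:
        if len(lst_entry["atype"]) >= nbatches:
            break
        frame_data = dataset[frame_idx]
        for key, value in frame_data.items():
            lst_entry.setdefault(key, []).append(value)

dataset = [{"atype": [0]}, {"atype": [1]}, {"atype": [1]}]
entry = {"atype": []}
collect_missing_frames(entry, [0, 1, 2], dataset, nbatches=2)
# entry["atype"] now holds exactly 2 frames' worth of types
```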

🏁 Scripts executed

The following scripts were executed for the analysis:

Error: No shell scripts found


Script:

#!/bin/bash
# Let's first look at the broader context of this code
rg -B 10 -A 5 "for missing_element in missing_elements" deepmd/pt/utils/stat.py

# Also search for any related tests that might help understand the intended behavior
fd "test.*stat.*py" -x rg -l "missing_element"

# Look for the initialization of nbatches and lst variables
rg -B 5 "nbatches.*=|lst.*=" deepmd/pt/utils/stat.py

Length of output: 1724


Script:

#!/bin/bash
# Search for the function definition and its parameters
ast-grep --pattern 'def $FUNC($PARAMS) {
  $$$
  nbatches
  $$$
}'  deepmd/pt/utils/stat.py

# Look for any documentation or comments about nbatches
rg -B 3 -A 3 "nbatches" deepmd/pt/utils/stat.py

# Check if there are any tests that verify this behavior
fd "test.*stat.*py" -x cat {} 

Length of output: 66294


Script:

#!/bin/bash
# Look for the function definition and documentation
rg -B 10 "def make_stat_input" deepmd/pt/utils/stat.py

# Search for any tests that verify the behavior with missing elements
rg -B 5 -A 5 "missing_element" deepmd/pt/utils/stat.py

Length of output: 1079


Script:

#!/bin/bash
# Look for the full function implementation and its docstring
rg -B 20 -A 50 "def make_stat_input" deepmd/pt/utils/stat.py

# Search for any tests that verify the statistics collection
fd "test.*stat.*py" -x rg -A 10 "make_stat_input"

# Look for the initialization of lst and how frames are collected
rg -B 5 "lst.*=.*\[\]" deepmd/pt/utils/stat.py

Length of output: 4539

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 1

🧹 Nitpick comments (2)
deepmd/pt/utils/stat.py (2)

89-89: Remove debug print statement.

The print statement print("we want", all_elements) appears to be debug code that should be removed or replaced with proper logging.

-        print("we want", all_elements)
+        log.debug(f"Required elements for statistics: {all_elements}")

97-111: Optimize nested loops and add error handling.

The nested loops for handling missing elements could be optimized, and error handling should be added for invalid frame indices.

         for missing_element in missing_elements:
             for i, dataset in enumerate(datasets):
                 if hasattr(dataset, "element_to_frames"):
-                    frame_indices = dataset.element_to_frames.get(
-                        missing_element, []
-                    )
+                    try:
+                        frame_indices = dataset.element_to_frames.get(missing_element, [])
+                        if not frame_indices:
+                            continue
+                            
+                        # Pre-check if we need more frames
+                        if len(lst[i]["atype"]) >= nbatches:
+                            break
+                            
+                        # Process frames in batch
+                        for frame_idx in frame_indices:
+                            frame_data = dataset[frame_idx]
+                            if any(key not in lst[i] for key in frame_data):
+                                lst[i].update({key: [] for key in frame_data if key not in lst[i]})
+                            for key in frame_data:
+                                lst[i][key].append(frame_data[key])
+                            if len(lst[i]["atype"]) >= nbatches:
+                                break
+                    except Exception as e:
+                        log.warning(f"Error processing frames for element {missing_element}: {e}")
+                        continue
-                    for frame_idx in frame_indices:
-                        if len(lst[i]["atype"]) >= nbatches:
-                            break
-                        frame_data = dataset[frame_idx]
-                        for key in frame_data:
-                            if key not in lst[i]:
-                                lst[i][key] = []
-                            lst[i][key].append(frame_data[key])
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 32da243 and adf2315.

📒 Files selected for processing (2)
  • deepmd/pt/utils/dataset.py (1 hunks)
  • deepmd/pt/utils/stat.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/utils/dataset.py

44-44: SyntaxError: Expected an indented block after function definition

🔇 Additional comments (2)
deepmd/pt/utils/dataset.py (2)

43-44: ⚠️ Potential issue

Fix the indentation error in method definition.

The method definition has incorrect indentation which will cause a SyntaxError. It should be aligned with other class methods.

-        def _build_element_to_frames(self):
-        """Mapping element types to frame indexes"""
+    def _build_element_to_frames(self):
+        """Mapping element types to frame indexes"""

Likely invalid or redundant comment.

🧰 Tools
🪛 Ruff (0.8.2)

44-44: SyntaxError: Expected an indented block after function definition


43-53: 🛠️ Refactor suggestion

Make frame limit configurable and enhance documentation.

  1. The hard-coded limit of 10 frames per element should be configurable.
  2. The docstring should be more descriptive about the method's purpose and return value.
-    def _build_element_to_frames(self):
-        """Mapping element types to frame indexes"""
+    def _build_element_to_frames(self, max_frames_per_element: int = 10) -> dict[int, list[int]]:
+        """Build a mapping of element types to their corresponding frame indices.
+        
+        Args:
+            max_frames_per_element: Maximum number of frames to store per element type.
+            
+        Returns:
+            A dictionary mapping element types (int) to lists of frame indices (list[int])
+            where each element type appears.
+        """
         element_to_frames = {element: [] for element in range(self._ntypes)}
         for frame_idx in range(len(self)):
             frame_data = self._data_system.get_item_torch(frame_idx)
 
             elements = frame_data["atype"]
             for element in set(elements):
-                if len(element_to_frames[element]) < 10:
+                if len(element_to_frames[element]) < max_frames_per_element:
                     element_to_frames[element].append(frame_idx)
         return element_to_frames

Likely invalid or redundant comment.

🧰 Tools
🪛 Ruff (0.8.2)

44-44: SyntaxError: Expected an indented block after function definition

@iProzd marked this pull request as draft December 24, 2024 14:37
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 4

🧹 Nitpick comments (4)
source/tests/pt/test_make_stat_input.py (4)

14-23: Consider using collections.defaultdict for element_to_frames.
You can simplify the nested checks for element presence in the dictionary by using a defaultdict(list), which would eliminate the need for the explicit if atype not in self.element_to_frames: condition.

-from collections import defaultdict

class TestDataset:
    def __init__(self, samples):
        self.samples = samples
-        self.element_to_frames = {}
+        from collections import defaultdict
+        self.element_to_frames = defaultdict(list)
        for idx, sample in enumerate(samples):
            atypes = sample["atype"]
            for atype in atypes:
-                if atype not in self.element_to_frames:
-                    self.element_to_frames[atype] = []
                self.element_to_frames[atype].append(idx)

25-28: Rename the property to better reflect usage.
Using @property but naming it get_all_atype can be confusing. Consider a more descriptive name like all_atypes, since Python properties typically avoid "get_" prefixes.


53-59: Remove or use the assigned lst variable.
The variable lst is assigned but never used, according to static analysis hints. Consider removing it or using it for additional assertions.

 def test_make_stat_input(self):
     nbatches = 1
-    lst = make_stat_input(self.datasets, self.dataloaders, nbatches=nbatches)
+    _ = make_stat_input(self.datasets, self.dataloaders, nbatches=nbatches)
     all_elements = self.system.get_all_atype
     unique_elements = {1, 2}
     self.assertEqual(unique_elements, all_elements, "make_stat_input miss elements")
🧰 Tools
🪛 Ruff (0.8.2)

55-55: Local variable lst is assigned to but never used

Remove assignment to unused variable lst

(F841)


61-62: Optional test runner inclusion.
Having the if __name__ == "__main__": unittest.main() block is fine. You could remove it if tests are run by a dedicated test runner.

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between adf2315 and dc64307.

📒 Files selected for processing (3)
  • deepmd/pt/utils/dataset.py (4 hunks)
  • deepmd/pt/utils/stat.py (1 hunks)
  • source/tests/pt/test_make_stat_input.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
source/tests/pt/test_make_stat_input.py

55-55: Local variable lst is assigned to but never used

Remove assignment to unused variable lst

(F841)

🔇 Additional comments (4)
source/tests/pt/test_make_stat_input.py (1)

40-52: Test setup looks good.
The dataset creation for testing is straightforward and clear. No issues found.

deepmd/pt/utils/dataset.py (2)

21-24: Docstring clarity is sufficient.
The docstring effectively describes constructor parameters. No corrections needed.


34-34: Initialization of element frames is a good approach.
Storing the result of _build_element_to_frames() in self.element_to_frames and self.get_all_atype reduces redundancy.

deepmd/pt/utils/stat.py (1)

86-94: No immediate issues with collection of atomic types.
Collecting and updating sets is correct.

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

🧹 Nitpick comments (4)
deepmd/utils/data.py (1)

667-674: Consider providing more informative error details when the type map lookup fails.
While raising IndexError is appropriate, developers might benefit from including the failing value(s). You could, for instance, collect and display the out-of-range types to provide immediate troubleshooting clues. A custom exception or a more descriptive error message can significantly improve clarity and debuggability.

Here's an example of how you might refine the exception:

            except IndexError as e:
-                raise IndexError(
-                    f"some types in 'real_atom_types.npy' of set {set_name} are not contained in {self.get_ntypes()} types!"
-                ) from e
+                # Gather all invalid elements
+                invalid_types = np.unique(real_type[(real_type < 0) | (real_type >= len(self.type_idx_map))])
+                raise IndexError(
+                    f"One or more invalid types found in 'real_atom_types.npy' of set {set_name}: {invalid_types}. "
+                    f"Ensure all types are within [0, {self.get_ntypes()-1}]."
+                ) from e
deepmd/utils/argcheck.py (1)

2829-2834: Ensure user awareness of the new argument.

The new argument min_frames_per_element_forstat is useful for controlling statistic completeness. It might be helpful to document the expected range (e.g., it must be ≥ 1) and how large values affect memory use and performance.

source/tests/pt/test_make_stat_input.py (1)

68-68: Remove or utilize the unused variable.

The variable lst is assigned with the result of make_stat_input(...) but never used. If no further checks are applied, remove it to keep the code clean.

-        lst = make_stat_input(
+        make_stat_input(
🧰 Tools
🪛 Ruff (0.8.2)

68-68: Local variable lst is assigned to but never used

Remove assignment to unused variable lst

(F841)

deepmd/pt/utils/stat.py (1)

188-197: Double-check sets for collected vs. missing elements.

This code block re-checks missing elements with:

missing_element = all_element - collect_elements

Confirm that the logic aligns with the earlier missing_elements sets in lines 110–111 to avoid confusion or duplication.

🧰 Tools
🪛 Ruff (0.8.2)

188-188: SyntaxError: unindent does not match any outer indentation level


189-189: SyntaxError: Unexpected indentation

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between dc64307 and 725f1dd.

📒 Files selected for processing (6)
  • deepmd/pt/train/training.py (2 hunks)
  • deepmd/pt/utils/dataset.py (3 hunks)
  • deepmd/pt/utils/stat.py (3 hunks)
  • deepmd/utils/argcheck.py (1 hunks)
  • deepmd/utils/data.py (1 hunks)
  • source/tests/pt/test_make_stat_input.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/utils/dataset.py
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/utils/stat.py

134-134: SyntaxError: Expected a statement


134-134: SyntaxError: Expected a statement


134-134: SyntaxError: Expected a statement


134-134: SyntaxError: Expected a statement


135-135: SyntaxError: Unexpected indentation


144-144: SyntaxError: unindent does not match any outer indentation level


144-144: SyntaxError: Expected a statement


144-144: SyntaxError: Expected a statement


144-145: SyntaxError: Expected a statement


145-145: SyntaxError: Unexpected indentation


176-176: SyntaxError: Expected a statement


176-176: SyntaxError: Expected a statement


176-176: SyntaxError: Expected a statement


176-176: SyntaxError: Expected a statement


188-188: SyntaxError: unindent does not match any outer indentation level


189-189: SyntaxError: Unexpected indentation


231-231: SyntaxError: Expected a statement


231-231: SyntaxError: Expected a statement


231-231: SyntaxError: Expected a statement


231-231: SyntaxError: Expected a statement

source/tests/pt/test_make_stat_input.py

43-43: Loop control variable idx not used within loop body

Rename unused idx to _idx

(B007)


68-68: Local variable lst is assigned to but never used

Remove assignment to unused variable lst

(F841)

🔇 Additional comments (4)
deepmd/pt/train/training.py (2)

145-147: Add type check or validation for the statistic threshold.

While setting self.min_frames_per_element_forstat, consider ensuring it's a strictly positive integer. If a negative or zero value is passed, it may cause runtime issues or meaningless statistics.


232-232: Parameter usage looks correct.

Passing self.min_frames_per_element_forstat to make_stat_input aligns with the newly introduced functionality. Just be sure to verify that all call sites expect this parameter and properly handle out-of-range values.

source/tests/pt/test_make_stat_input.py (1)

41-50: Optimize loop variable usage.

The static analysis hint suggests renaming idx to _idx, but idx is used here to assign frames, so the unused-loop-variable warning is a false positive. You can safely ignore it.

🧰 Tools
🪛 Ruff (0.8.2)

43-43: Loop control variable idx not used within loop body

Rename unused idx to _idx

(B007)

deepmd/pt/utils/stat.py (1)

52-110: Confirm correctness of cumulative element count logic.

You aggregate total_element_counts[elem]["count"], but also track indices. Ensure you don’t exceed list boundaries when collecting indices for up to min_frames_per_element_forstat. If more frames exist, consider whether you need them to fulfill certain statistics.
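One bounded way to track both a running count and a capped index list per element, matching the concern above (illustrative names only; min_frames plays the role of min_frames_per_element_forstat):

```python
from collections import defaultdict

def record_element(stats, sys_idx, frame_idx, atypes, min_frames):
    """Increment the per-element count, but keep at most min_frames
    (system, frame) index pairs so the index list never grows unbounded."""
    for elem in set(atypes):
        entry = stats[elem]
        entry["count"] += 1
        if len(entry["indices"]) < min_frames:
            entry["indices"].append((sys_idx, frame_idx))

stats = defaultdict(lambda: {"count": 0, "indices": []})
for f_idx, atypes in enumerate([[0, 1], [1], [1, 2]]):
    record_element(stats, 0, f_idx, atypes, min_frames=2)
```

With this cap, element 1 above is seen in three frames but only two index pairs are retained, which is the boundary behavior the comment asks to verify.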

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 5

🧹 Nitpick comments (2)
deepmd/pt/utils/stat.py (2)

38-58: Improve docstring and add input validation.

The docstring should provide more details about the new parameters:

  • What are valid values for min_frames_per_element_forstat?
  • What are the implications of enabling/disabling element completion?

Apply this diff to improve the docstring and add input validation:

 def make_stat_input(
     datasets,
     dataloaders,
     nbatches,
     min_frames_per_element_forstat=10,
     enable_element_completion=True,
 ):
     """Pack data for statistics.
        Element checking is only enabled with mixed_type.

     Args:
     - datasets: A list of datasets to analyze.
     - dataloaders: Corresponding dataloaders for the datasets.
     - nbatches: Batch count for collecting stats.
-    - min_frames_per_element_forstat: Minimum frames required for statistics.
-    - enable_element_completion: Whether to perform missing element completion (default: True).
+    - min_frames_per_element_forstat: Minimum number of frames required per element for statistics.
+        Must be a positive integer. Default is 10.
+    - enable_element_completion: Whether to perform missing element completion.
+        If True, ensures each element has at least min_frames_per_element_forstat frames.
+        If False, skips missing element handling. Default is True.

     Returns
     -------
     - A list of dicts, each of which contains data from a system.
     """
+    if not datasets:
+        raise ValueError("No datasets provided")
+    if len(datasets) != len(dataloaders):
+        raise ValueError("Number of datasets does not match number of dataloaders")
+    if min_frames_per_element_forstat < 1:
+        raise ValueError("min_frames_per_element_forstat must be positive")

61-74: Remove unused variable and improve logging messages.

The variable global_element_counts is initialized but never used. Also, the logging messages could be more informative.

Apply this diff:

     total_element_types = set()
-    global_element_counts = {}
     global_type_name = {}
     collect_ele = defaultdict(int)
     if datasets[0].mixed_type:
         if enable_element_completion:
             log.info(
-                f"Element check enabled. "
-                f"Verifying if frames with elements meet the set of {min_frames_per_element_forstat}."
+                f"Element completion enabled. "
+                f"Ensuring each element has at least {min_frames_per_element_forstat} frames."
             )
         else:
             log.info(
-                "Element completion is disabled. Skipping missing element handling."
+                "Element completion disabled. Elements with insufficient frames will be skipped."
             )
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 8763165 and 9f389ad.

📒 Files selected for processing (3)
  • deepmd/pt/utils/dataset.py (3 hunks)
  • deepmd/pt/utils/stat.py (1 hunks)
  • deepmd/utils/data.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • deepmd/pt/utils/dataset.py
  • deepmd/utils/data.py
🧰 Additional context used
🪛 Ruff (0.8.2)
deepmd/pt/utils/stat.py

165-165: Undefined name type_name

(F821)

⏰ Context from checks skipped due to timeout of 90000ms (19)
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Analyze (c-cpp)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9f389ad and 58647f3.

📒 Files selected for processing (2)
  • deepmd/pt/utils/dataset.py (3 hunks)
  • deepmd/utils/data.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/pt/utils/dataset.py
⏰ Context from checks skipped due to timeout of 90000ms (19)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Analyze (c-cpp)
🔇 Additional comments (1)
deepmd/utils/data.py (1)

699-706: LGTM! Improved error handling for type mapping.

The error handling is well-implemented with:

  • Clear error message including context (set name and available types)
  • Proper exception chaining using from e
  • Appropriate exception type (IndexError)
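The chaining pattern being praised can be sketched as follows (an illustrative example, not the actual `data.py` code; the function and message text here are hypothetical):

```python
def map_type_to_index(type_map, name, set_name="set.000"):
    """Look up an element name in a type map, re-raising with context."""
    try:
        return type_map.index(name)
    except ValueError as e:
        # 'from e' chains the original error, so the traceback shows both
        # the failed lookup and the added context about the data set
        raise IndexError(
            f"element {name!r} not found in type_map {type_map} "
            f"while loading {set_name}"
        ) from e
```

With chaining, the caller can still inspect the underlying cause via `err.__cause__`.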

@SumGuo-88 SumGuo-88 requested a review from iProzd January 10, 2025 07:43
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 6

🧹 Nitpick comments (2)
deepmd/pt/utils/stat.py (2)

38-58: Enhance docstring with more details.

The docstring could be improved with:

  • Return type annotation
  • More detailed parameter descriptions including valid ranges and constraints
  • Documentation of exceptions that might be raised
    """Pack data for statistics.
       Element checking is only enabled with mixed_type.

    Args:
-    - datasets: A list of datasets to analyze.
-    - dataloaders: Corresponding dataloaders for the datasets.
-    - nbatches: Batch count for collecting stats.
-    - min_frames_per_element_forstat: Minimum frames required for statistics.
-    - enable_element_completion: Whether to perform missing element completion (default: True).
+    datasets: list
+        A list of datasets to analyze. Must not be empty.
+    dataloaders: list
+        Corresponding dataloaders for the datasets. Must match length of datasets.
+    nbatches: int
+        Batch count for collecting stats. Must be positive.
+    min_frames_per_element_forstat: int, optional
+        Minimum frames required for statistics per element. Must be positive.
+        Defaults to 10.
+    enable_element_completion: bool, optional
+        Whether to perform missing element completion. Only applies when mixed_type=True.
+        Defaults to True.

    Returns
    -------
-    - A list of dicts, each of which contains data from a system.
+    list[dict]
+        A list of dictionaries, each containing statistical data from a system.
+        Each dict contains tensor data for various properties.

+    Raises
+    ------
+    ValueError
+        If datasets is empty or if datasets and dataloaders lengths don't match.
+        If min_frames_per_element_forstat is less than 1.
+    AssertionError
+        If element check fails during frame processing.
    """

199-221: Improve error handling for missing elements.

Add validation for empty element sets and improve warning messages.

    if datasets[0].mixed_type and enable_element_completion:
+        if not total_element_types:
+            log.warning("No elements found in any dataset")
+            return lst
        for elem, data in global_element_counts.items():
            indices_count = data["count"]
            if indices_count < min_frames_per_element_forstat:
                log.warning(
-                    f"The number of frames in your datasets with element {element_name} is {indices_count}, "
-                    f"which is less than the set {min_frames_per_element_forstat}"
+                    f"Insufficient frames for element {element_name}: found {indices_count}, "
+                    f"required {min_frames_per_element_forstat}. This may affect model accuracy."
                )
+                if indices_count == 0:
+                    log.error(f"No frames found for element {element_name}")
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 58647f3 and 4ce9cfb.

📒 Files selected for processing (1)
  • deepmd/pt/utils/stat.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (10)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Analyze (python)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)

@iProzd iProzd requested review from njzjz and wanghan-iapcm January 15, 2025 06:01
@wanghan-iapcm wanghan-iapcm left a comment

I suggest refactoring make_stat_input:

  • split it into subfunctions
  • write a unit test for each subfunction
  • construct make_stat_input from the subfunctions
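A minimal sketch of that decomposition (function names, signatures, and the frame representation here are illustrative, not the actual deepmd API):

```python
from collections import defaultdict


def count_frames_per_element(frames):
    """Count, for each element type, how many frames contain it."""
    counts = defaultdict(int)
    for frame in frames:
        for elem in set(frame["atype"]):
            counts[elem] += 1
    return dict(counts)


def find_underrepresented(counts, all_elements, min_frames):
    """Elements that are absent, or seen in fewer than min_frames frames."""
    return {e for e in all_elements if counts.get(e, 0) < min_frames}


def make_stat_input(frames, all_elements, min_frames=10):
    """Compose the small, individually testable pieces."""
    counts = count_frames_per_element(frames)
    missing = find_underrepresented(counts, all_elements, min_frames)
    return counts, missing
```

Each helper then gets a focused unit test, and the top-level function stays a thin composition of the pieces.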

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (3)
source/tests/pt/test_make_stat_input.py (1)

62-72: 🛠️ Refactor suggestion

Test with multiple batch sizes.

Using a fixed batch size of 1 may not catch batch-related issues.

Apply this diff to test with different batch sizes:

-            cls.dataloaders = []
-            for dataset in cls.datasets:
-                dataloader = DataLoader(
-                    dataset,
-                    batch_size=1,
-                    num_workers=0,
-                    drop_last=False,
-                    collate_fn=collate_fn,
-                    pin_memory=False,
-                )
-                cls.dataloaders.append(dataloader)
+            # Test with different batch sizes
+            cls.batch_sizes = [1, 4, 16]
+            cls.dataloaders = {}
+            for batch_size in cls.batch_sizes:
+                loaders = []
+                for dataset in cls.datasets:
+                    loader = DataLoader(
+                        dataset,
+                        batch_size=batch_size,
+                        num_workers=0,
+                        drop_last=False,
+                        collate_fn=collate_fn,
+                        pin_memory=False,
+                    )
+                    loaders.append(loader)
+                cls.dataloaders[batch_size] = loaders
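The rationale for sweeping batch sizes is that batching with drop_last=False must cover every frame exactly once regardless of the split; a toy check of that invariant (plain Python, purely illustrative):

```python
def chunk(data, batch_size):
    """Consecutive batches, keeping the trailing partial batch
    (mirrors DataLoader(..., drop_last=False))."""
    return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]


def flatten(batches):
    return [x for batch in batches for x in batch]


# batch_size=1 alone would hide errors in partial-batch handling;
# 4 and 16 exercise both a ragged last batch and a single oversized batch.
for batch_size in (1, 4, 16):
    batches = chunk(list(range(10)), batch_size)
    assert sorted(flatten(batches)) == list(range(10))
```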
deepmd/pt/utils/stat.py (2)

38-44: 🛠️ Refactor suggestion

Add input validation for new parameters.

The function should validate the new parameters to ensure they meet requirements.

Add this validation at the start of the function:

def make_stat_input(
    datasets,
    dataloaders,
    nbatches,
    min_frames_per_element_forstat=10,
    enable_element_completion=True,
):
+    if not datasets:
+        raise ValueError("No datasets provided")
+    if len(datasets) != len(dataloaders):
+        raise ValueError("Number of datasets does not match number of dataloaders")
+    if min_frames_per_element_forstat < 1:
+        raise ValueError("min_frames_per_element_forstat must be positive")

166-172: 🛠️ Refactor suggestion

Replace assertion with proper exception handling.

Based on feedback from @wanghan-iapcm, assertions should be replaced with proper exceptions.

-                    assert miss in frame_data["atype"], (
-                        "Element check failed. "
-                        "If you are running in 'change-bias' mode, use '--skip-elementcheck' to disable this check. "
-                        "If you encountered this error during model training, set 'enable_element_completion' to False "
-                        "in the 'training' section of your input file."
-                    )
+                    if miss not in frame_data["atype"]:
+                        raise ValueError(
+                            f"Element {miss} not found in frame data. This could happen if:\n"
+                            "1. The dataset is incomplete or corrupted\n"
+                            "2. You're running in 'change-bias' mode (use '--skip-elementcheck' to proceed)\n"
+                            "3. You're training a model (set 'enable_element_completion=False' in training config)"
+                        )
🧹 Nitpick comments (2)
source/tests/pt/test_make_stat_input.py (1)

194-196: Missing documentation for test purpose.

The test comment "missing element:13,31,37" is insufficient. Need proper documentation explaining why these elements are missing and what the test verifies.

Add proper documentation:

     def test_with_nomissing(self):
-        # missing element:13,31,37
-        # only one frame would be count
+        """Test make_stat_input with no missing elements.
+        
+        This test verifies that:
+        1. Elements 13, 31, and 37 are missing in the test data
+        2. Only one frame is counted due to [explain why]
+        3. The bias values should be identical with and without element completion
+        """
deepmd/pt/utils/stat.py (1)

62-64: Remove unused variable.

The variable global_element_counts is initialized but never used.

     total_element_types = set()
-    global_element_counts = {}
     global_type_name = {}
     collect_ele = defaultdict(int)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4ce9cfb and 324037d.

📒 Files selected for processing (2)
  • deepmd/pt/utils/stat.py (1 hunks)
  • source/tests/pt/test_make_stat_input.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
source/tests/pt/test_make_stat_input.py

30-30: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

⏰ Context from checks skipped due to timeout of 90000ms (20)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Test Python (6, 3.12)
  • GitHub Check: Test Python (6, 3.9)
  • GitHub Check: Test Python (5, 3.12)
  • GitHub Check: Test Python (5, 3.9)
  • GitHub Check: Test Python (4, 3.12)
  • GitHub Check: Test Python (4, 3.9)
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (javascript-typescript)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Analyze (c-cpp)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (1)
deepmd/pt/entrypoints/main.py (1)

389-390: 🛠️ Refactor suggestion

Consider parameter naming consistency and default behavior.

  1. Parameter names should be consistent between function signature and CLI:
    • skip_elem_check vs skip_elementcheck
    • min_frames vs minframes_eachelement
  2. Setting skip_elem_check=True by default means element checks are skipped, which might not align with the PR's goal of making bias statistics complete for all elements.

Apply this diff to improve consistency and default behavior:

-    skip_elem_check: bool = True,
-    min_frames: int = 10,
+    skip_elementcheck: bool = False,
+    minframes_eachelement: int = 10,
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 324037d and 5ea3594.

📒 Files selected for processing (2)
  • deepmd/main.py (1 hunks)
  • deepmd/pt/entrypoints/main.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • deepmd/main.py
⏰ Context from checks skipped due to timeout of 90000ms (18)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Test Python (3, 3.12)
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Test Python (3, 3.9)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Test Python (2, 3.9)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Analyze (python)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Test Python (1, 3.9)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Analyze (c-cpp)

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (1)
deepmd/pt/utils/stat.py (1)

76-96: 🛠️ Refactor suggestion

Add error handling for unknown data types.

The function silently passes on unknown data types, which could hide potential issues.

Apply this diff:

    def process_batches(dataloader, sys_stat):
        """Process batches from a dataloader to collect statistics."""
        iterator = iter(dataloader)
        numb_batches = min(nbatches, len(dataloader))
        for _ in range(numb_batches):
            try:
                stat_data = next(iterator)
            except StopIteration:
                iterator = iter(dataloader)
                stat_data = next(iterator)
-            for dd in stat_data:
-                if stat_data[dd] is None:
+            for dd, value in stat_data.items():
+                try:
+                    if value is None:
                         sys_stat[dd] = None
-                elif isinstance(stat_data[dd], torch.Tensor):
+                    elif isinstance(value, torch.Tensor):
                         if dd not in sys_stat:
                             sys_stat[dd] = []
-                        sys_stat[dd].append(stat_data[dd])
-                elif isinstance(stat_data[dd], np.float32):
-                        sys_stat[dd] = stat_data[dd]
+                        sys_stat[dd].append(value)
+                    elif isinstance(value, np.float32):
+                        sys_stat[dd] = value
                     else:
-                        pass
+                        log.warning(f"Unexpected data type {type(value)} for key {dd}")
+                except Exception as e:
+                    log.error(f"Error processing key {dd}: {str(e)}")
+                    raise
🧹 Nitpick comments (4)
deepmd/pt/entrypoints/main.py (3)

391-392: Align parameter names with function signature.

The parameter names in change_bias differ from the CLI flags and make_stat_input function:

  • skip_elem_checkskip_elementcheck
  • min_framesminframes_eachelement

Apply this diff to align the parameter names:

-    skip_elem_check: bool = True,
-    min_frames: int = 10,
+    skip_elementcheck: bool = True,
+    minframes_eachelement: int = 10,

479-481: Simplify parameter naming and logic.

The parameter names and logic could be simplified:

  1. Parameter names differ across layers:
    • Function: min_framesmin_frames_per_element_forstat
    • Function: skip_elem_checknot enable_element_completion
  2. The negation not skip_elem_check adds unnecessary complexity.

Apply this diff to improve clarity:

     sampled_data = make_stat_input(
         data_single.systems,
         data_single.dataloaders,
         nbatches,
-        min_frames_per_element_forstat=min_frames,
-        enable_element_completion=not skip_elem_check,
+        min_frames_per_element_forstat=minframes_eachelement,
+        enable_element_completion=not skip_elementcheck,
     )

564-565: Align parameter names with function signature.

The parameter names in main match the CLI flags but differ from the change_bias function signature:

  • FLAGS.skip_elementcheckskip_elem_check
  • FLAGS.minframes_eachelementmin_frames

After updating the function signature as suggested earlier, apply this diff:

             model_branch=FLAGS.model_branch,
             output=FLAGS.output,
-            skip_elem_check=FLAGS.skip_elementcheck,
-            min_frames=FLAGS.minframes_eachelement,
+            skip_elementcheck=FLAGS.skip_elementcheck,
+            minframes_eachelement=FLAGS.minframes_eachelement,
         )
deepmd/pt/utils/stat.py (1)

136-155: Improve error messages for missing elements.

The missing element handling could be more informative.

Apply this diff:

    def process_missing_elements(
        min_frames_per_element_forstat,
        global_element_counts,
        total_element_types,
        collect_ele,
    ):
        """Handle missing elements and check element completeness."""
+        if not total_element_types:
+            log.warning("No elements found in any dataset")
+            return
        collect_elements = collect_ele.keys()
        missing_elements = total_element_types - collect_elements
        collect_miss_element = set()
        for ele, count in collect_ele.items():
            if count < min_frames_per_element_forstat:
+                log.warning(
+                    f"Insufficient frames for element {ele}: found {count}, "
+                    f"required {min_frames_per_element_forstat}. This may affect model accuracy."
+                )
                collect_miss_element.add(ele)
                missing_elements.add(ele)
+        if missing_elements:
+            log.warning(f"Missing elements detected: {missing_elements}")
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5ea3594 and 719360b.

📒 Files selected for processing (6)
  • deepmd/main.py (1 hunks)
  • deepmd/pt/entrypoints/main.py (3 hunks)
  • deepmd/pt/train/training.py (2 hunks)
  • deepmd/pt/utils/stat.py (1 hunks)
  • deepmd/utils/argcheck.py (2 hunks)
  • deepmd/utils/data.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
  • deepmd/pt/train/training.py
  • deepmd/utils/data.py
  • deepmd/main.py
  • deepmd/utils/argcheck.py
🔇 Additional comments (1)
deepmd/pt/utils/stat.py (1)

38-58: Implementation aligns with PR objectives.

The changes successfully enhance bias statistics completeness by:

  1. Adding control over element completion
  2. Introducing minimum frame requirements
  3. Improving error handling and validation

The implementation looks good with the suggested improvements.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (4)
deepmd/pt/utils/stat.py (4)

61-64: 🛠️ Refactor suggestion

Remove unused variable.

The variable global_element_counts is initialized but never used in the code.

    total_element_types = set()
-    global_element_counts = {}
    global_type_name = {}
    collect_ele = defaultdict(int)

38-58: 🛠️ Refactor suggestion

Add validation for input parameters.

The function should validate its input parameters to ensure they meet requirements:

  • datasets should not be empty
  • dataloaders length should match datasets length
  • min_frames_per_element_forstat should be positive
def make_stat_input(
    datasets,
    dataloaders,
    nbatches,
    min_frames_per_element_forstat=10,
    enable_element_completion=True,
):
+    if not datasets:
+        raise ValueError("No datasets provided")
+    if len(datasets) != len(dataloaders):
+        raise ValueError("Number of datasets does not match number of dataloaders")
+    if min_frames_per_element_forstat < 1:
+        raise ValueError("min_frames_per_element_forstat must be positive")

76-96: 🛠️ Refactor suggestion

Improve error handling in process_batches.

The function silently passes on unknown data types and may raise KeyError.

    def process_batches(dataloader, sys_stat):
        """Process batches from a dataloader to collect statistics."""
        iterator = iter(dataloader)
        numb_batches = min(nbatches, len(dataloader))
        for _ in range(numb_batches):
            try:
                stat_data = next(iterator)
            except StopIteration:
                iterator = iter(dataloader)
                stat_data = next(iterator)
-            for dd in stat_data:
-                if stat_data[dd] is None:
+            for dd, value in stat_data.items():
+                try:
+                    if value is None:
                         sys_stat[dd] = None
-                elif isinstance(stat_data[dd], torch.Tensor):
+                    elif isinstance(value, torch.Tensor):
                         if dd not in sys_stat:
                             sys_stat[dd] = []
-                        sys_stat[dd].append(stat_data[dd])
-                elif isinstance(stat_data[dd], np.float32):
-                        sys_stat[dd] = stat_data[dd]
+                        sys_stat[dd].append(value)
+                    elif isinstance(value, np.float32):
+                        sys_stat[dd] = value
                     else:
-                        pass
+                        log.warning(f"Unexpected data type {type(value)} for key {dd}")
+                except Exception as e:
+                    log.error(f"Error processing key {dd}: {str(e)}")
+                    raise

157-173: 🛠️ Refactor suggestion

Replace assertion with proper exception handling.

As suggested in the past review comments, replace the assertion with a proper exception that provides clear guidance:

    def process_with_new_frame(sys_indices, newele_counter, miss):
        """Process frames with missing elements."""
+        if not sys_indices:
+            log.warning(f"No system indices provided for element {miss}")
+            return
        for sys_info in sys_indices:
            sys_index = sys_info["sys_index"]
            frames = sys_info["frames"]
+            if not frames:
+                log.warning(f"No frames found for system {sys_index}")
+                continue
            sys = datasets[sys_index]
            for frame in frames:
                newele_counter += 1
                if newele_counter <= min_frames_per_element_forstat:
-                    frame_data = sys.__getitem__(frame)
-                    if miss not in frame_data["atype"]:
-                        raise ValueError(
-                            "Element check failed. "
-                            "If you are running in 'change-bias' mode, use '--skip-elementcheck' to disable this check. "
-                            "If you encountered this error during model training, set 'enable_element_completion' to False "
-                            "in the 'training' section of your input file."
-                        )
+                    try:
+                        frame_data = sys.__getitem__(frame)
+                        if "atype" not in frame_data:
+                            raise ValueError(f"Frame {frame} does not contain type information")
+                        if miss not in frame_data["atype"]:
+                            raise ValueError(
+                                f"Element {miss} not found in frame {frame}.\n"
+                                "To proceed without element completion:\n"
+                                "1. For change-bias mode: Use '--skip-elementcheck'\n"
+                                "2. For model training: Set 'enable_element_completion' to False"
+                            )
+                    except Exception as e:
+                        log.error(f"Failed to process frame {frame} from system {sys_index}: {e}")
+                        raise
🧹 Nitpick comments (2)
deepmd/pt/utils/stat.py (2)

175-190: Extract common tensor processing logic.

The tensor processing logic is duplicated from the main processing loop. Extract it into a helper function:

+                def process_tensor_data(data):
+                    """Process tensor data with proper error handling."""
+                    try:
+                        if data is None:
+                            return None
+                        if isinstance(data, np.ndarray):
+                            tensor_data = torch.from_numpy(data)
+                            return tensor_data.unsqueeze(0)
+                        if isinstance(data, np.float32):
+                            return data
+                        return None
+                    except Exception as e:
+                        log.error(f"Failed to process tensor data: {e}")
+                        raise

                     sys_stat_new = {}
                     for dd in frame_data:
                         if dd == "type":
                             continue
-                        if frame_data[dd] is None:
-                            sys_stat_new[dd] = None
-                        elif isinstance(frame_data[dd], np.ndarray):
+                        result = process_tensor_data(frame_data[dd])
+                        if result is not None:
                             if dd not in sys_stat_new:
                                 sys_stat_new[dd] = []
-                            tensor_data = torch.from_numpy(frame_data[dd])
-                            tensor_data = tensor_data.unsqueeze(0)
-                            sys_stat_new[dd].append(tensor_data)
-                        elif isinstance(frame_data[dd], np.float32):
-                            sys_stat_new[dd] = frame_data[dd]
-                        else:
-                            pass
+                            if isinstance(result, torch.Tensor):
+                                sys_stat_new[dd].append(result)
+                            else:
+                                sys_stat_new[dd] = result

136-156: Improve error handling for missing elements.

The missing element handling could be more robust with better error messages:

    def process_missing_elements(
        min_frames_per_element_forstat,
        global_element_counts,
        total_element_types,
        collect_ele,
    ):
        """Handle missing elements and check element completeness."""
+        if not total_element_types:
+            log.warning("No elements found in any dataset")
+            return
        collect_elements = collect_ele.keys()
        missing_elements = total_element_types - collect_elements
+        if not missing_elements:
+            log.info("All elements are present in the datasets")
+            return
        collect_miss_element = set()
        for ele, count in collect_ele.items():
            if count < min_frames_per_element_forstat:
                collect_miss_element.add(ele)
                missing_elements.add(ele)
+                log.warning(
+                    f"Insufficient frames for element {ele}: found {count}, "
+                    f"required {min_frames_per_element_forstat}. This may affect model accuracy."
+                )
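The completeness check sketched in this diff reduces to set arithmetic over per-element frame counts. A minimal, framework-free sketch of that logic (the function name `find_deficient_elements` and the sample counts are illustrative, not from the PR):

```python
def find_deficient_elements(frame_counts, all_elements, min_frames):
    """Return (missing, insufficient): elements absent from the collected
    counts, and elements seen in fewer than min_frames frames."""
    missing = set(all_elements) - set(frame_counts)
    insufficient = {e for e, n in frame_counts.items() if n < min_frames}
    return missing, insufficient

# Example: element 31 never appears; element 13 appears in too few frames.
counts = {1: 25, 8: 40, 13: 3}
missing, insufficient = find_deficient_elements(counts, {1, 8, 13, 31}, min_frames=10)
print(missing)       # {31}
print(insufficient)  # {13}
```

In the PR itself the insufficient elements are additionally merged into the missing set before completion, which is why the diff adds them to `missing_elements` as well as logging a warning.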
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 719360b and 9fcfffe.

📒 Files selected for processing (1)
  • deepmd/pt/utils/stat.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (6)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C++ (cpu, cpu)

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

♻️ Duplicate comments (1)
deepmd/pt/utils/stat.py (1)

157-184: ⚠️ Potential issue

Replace assertion with proper exception handling.

The function uses an assertion for validation, which should be replaced with proper exception handling as suggested by the reviewer.

     def process_with_new_frame(sys_indices, newele_counter):
         """Process frames with missing elements."""
+        if not sys_indices:
+            log.warning("No system indices provided")
+            return
         for sys_info in sys_indices:
             sys_index = sys_info["sys_index"]
             frames = sys_info["frames"]
+            if not frames:
+                log.warning(f"No frames found for system {sys_index}")
+                continue
             sys = datasets[sys_index]
             for frame in frames:
                 newele_counter += 1
                 if newele_counter <= min_frames_per_element_forstat:
-                    frame_data = sys.__getitem__(frame)
+                    try:
+                        frame_data = sys.__getitem__(frame)
+                        if "type" not in frame_data:
+                            raise ValueError(f"Frame {frame} does not contain type information")
+                        if miss not in frame_data["type"]:
+                            raise ValueError(
+                                f"Element {miss} not found in frame {frame}. "
+                                "To proceed without element completion:\n"
+                                "1. For change-bias mode: Use '--skip-elementcheck'\n"
+                                "2. For model training: Set 'enable_element_completion' to False"
+                            )
+                    except Exception as e:
+                        log.error(f"Failed to process frame {frame} from system {sys_index}: {e}")
+                        raise
🧹 Nitpick comments (2)
deepmd/pt/utils/stat.py (2)

38-58: Add input validation and improve docstring.

The docstring should provide more details about the new parameters:

  • What constitutes a valid value for min_frames_per_element_forstat?
  • What are the implications of enabling/disabling element completion?
 def make_stat_input(
     datasets,
     dataloaders,
     nbatches,
     min_frames_per_element_forstat=10,
     enable_element_completion=True,
 ):
     """Pack data for statistics.
        Element checking is only enabled with mixed_type.

     Args:
     - datasets: A list of datasets to analyze.
     - dataloaders: Corresponding dataloaders for the datasets.
     - nbatches: Batch count for collecting stats.
-    - min_frames_per_element_forstat: Minimum frames required for statistics.
-    - enable_element_completion: Whether to perform missing element completion (default: True).
+    - min_frames_per_element_forstat: Minimum number of frames required per element for accurate statistics.
+                                     Must be a positive integer. Default: 10.
+    - enable_element_completion: If True, attempts to complete missing elements by searching for frames
+                                containing those elements. If False, skips this process. Default: True.

     Returns
     -------
     - A list of dicts, each of which contains data from a system.
     """
+    if not datasets:
+        raise ValueError("No datasets provided")
+    if len(datasets) != len(dataloaders):
+        raise ValueError("Number of datasets does not match number of dataloaders")
+    if min_frames_per_element_forstat < 1:
+        raise ValueError("min_frames_per_element_forstat must be positive")

61-74: Remove unused variable and improve logging messages.

The variable global_element_counts is initialized but never used. Also, the logging messages could be more specific.

     total_element_types = set()
-    global_element_counts = {}
     global_type_name = {}
     collect_ele = defaultdict(int)
     if datasets[0].mixed_type:
         if enable_element_completion:
             log.info(
-                f"Element check enabled. "
-                f"Verifying if frames with elements meet the set of {min_frames_per_element_forstat}."
+                f"Element completion enabled. "
+                f"Each element requires at least {min_frames_per_element_forstat} frames for statistics."
             )
         else:
             log.info(
-                "Element completion is disabled. Skipping missing element handling."
+                "Element completion disabled. Statistics may be incomplete for elements with insufficient frames."
             )
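Since `collect_ele` is a `collections.defaultdict(int)`, per-element frame counts can be accumulated without pre-registering keys. A small sketch of that accumulation pattern (the frame data is invented for illustration; each frame is counted once per element it contains):

```python
from collections import defaultdict

collect_ele = defaultdict(int)
# Each "frame" lists the element types it contains (illustrative data).
frames = [[1, 1, 8], [1, 8, 8], [13, 1]]
for frame in frames:
    for ele in set(frame):  # deduplicate so a frame counts once per element
        collect_ele[ele] += 1

print(dict(collect_ele))  # {1: 3, 8: 2, 13: 1}
```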
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9fcfffe and e943950.

📒 Files selected for processing (1)
  • deepmd/pt/utils/stat.py (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms (7)
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Test C++ (false)
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Test C++ (true)
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (javascript-typescript)
  • GitHub Check: Analyze (c-cpp)

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 4

♻️ Duplicate comments (3)
source/tests/pt/test_make_stat_input.py (3)

62-72: ⚠️ Potential issue

Test with multiple batch sizes.

Using a fixed batch size of 1 may not catch batch-related issues.

-            cls.dataloaders = []
-            for dataset in cls.datasets:
-                dataloader = DataLoader(
-                    dataset,
-                    batch_size=1,
-                    num_workers=0,
-                    drop_last=False,
-                    collate_fn=collate_fn,
-                    pin_memory=False,
-                )
-                cls.dataloaders.append(dataloader)
+            # Test with different batch sizes
+            cls.batch_sizes = [1, 4, 16]
+            cls.dataloaders = {}
+            for batch_size in cls.batch_sizes:
+                loaders = []
+                for dataset in cls.datasets:
+                    loader = DataLoader(
+                        dataset,
+                        batch_size=batch_size,
+                        num_workers=0,
+                        drop_last=False,
+                        collate_fn=collate_fn,
+                        pin_memory=False,
+                    )
+                    loaders.append(loader)
+                cls.dataloaders[batch_size] = loaders
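The point of exercising several batch sizes is that, with `drop_last=False`, every frame must be seen exactly once regardless of batch size. A framework-free sketch of that invariant (plain lists stand in for the datasets and `DataLoader`):

```python
def batches(items, batch_size):
    """Yield consecutive chunks; the last chunk may be short (drop_last=False)."""
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]

frames = list(range(10))
for bs in (1, 4, 16):
    seen = [x for batch in batches(frames, bs) for x in batch]
    assert seen == frames  # every frame covered for every batch size
```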

171-172: ⚠️ Potential issue

Use self.real_ntypes instead of hard-coded ntypes value.

The hard-coded ntypes=57 should use the derived value from the dataset.

-        bias_ori, _ = compute_output_stats(lst_ori, ntypes=57)
-        bias_all, _ = compute_output_stats(lst_all, ntypes=57)
+        bias_ori, _ = compute_output_stats(lst_ori, ntypes=self.real_ntypes)
+        bias_all, _ = compute_output_stats(lst_all, ntypes=self.real_ntypes)

192-193: 🛠️ Refactor suggestion

Document missing elements and improve test method.

The test method needs better documentation of missing elements and their impact.

-        # missing element:13,31,37
-        # only one frame would be count
+        """Test make_stat_input with element completion enabled/disabled.
+        
+        Missing elements in test data:
+        - Element 13: Missing due to [reason]
+        - Element 31: Missing due to [reason]
+        - Element 37: Missing due to [reason]
+        
+        When element completion is disabled:
+        - Only one frame is counted because [explain why]
+        - Missing elements should have zero bias
+        """
🧹 Nitpick comments (1)
source/tests/pt/test_make_stat_input.py (1)

25-44: Optimize dictionary key lookup in collate function.

The dictionary key lookup in the loop can be optimized.

-    for key in batch[0].keys():
+    for key in batch[0]:
🧰 Tools
🪛 Ruff (0.8.2)

30-30: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)
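The Ruff hint (SIM118) applies because iterating a dict already yields its keys. A minimal dict-of-lists collate sketch using the preferred form (torch-free and purely illustrative, unlike the real `collate_fn` in the test file, which stacks tensors):

```python
def collate_fn(batch):
    """Group a list of per-frame dicts into one dict of lists."""
    out = {}
    for key in batch[0]:  # preferred over `for key in batch[0].keys()`
        out[key] = [frame[key] for frame in batch]
    return out

batch = [{"coord": [0.0], "type": 1}, {"coord": [1.0], "type": 8}]
print(collate_fn(batch))  # {'coord': [[0.0], [1.0]], 'type': [1, 8]}
```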

📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between e943950 and 61671af.

📒 Files selected for processing (1)
  • source/tests/pt/test_make_stat_input.py (1 hunks)
🧰 Additional context used
🪛 Ruff (0.8.2)
source/tests/pt/test_make_stat_input.py

30-30: Use key in dict instead of key in dict.keys()

Remove .keys()

(SIM118)

⏰ Context from checks skipped due to timeout of 90000ms (19)
  • GitHub Check: Test Python (2, 3.12)
  • GitHub Check: Test Python (1, 3.12)
  • GitHub Check: Build wheels for cp310-manylinux_aarch64
  • GitHub Check: Build wheels for cp311-win_amd64
  • GitHub Check: Build wheels for cp311-macosx_arm64
  • GitHub Check: Build wheels for cp311-macosx_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build wheels for cp311-manylinux_x86_64
  • GitHub Check: Build C++ (clang, clang)
  • GitHub Check: Build C++ (rocm, rocm)
  • GitHub Check: Build C++ (cuda120, cuda)
  • GitHub Check: Build C++ (cuda, cuda)
  • GitHub Check: Build C++ (cpu, cpu)
  • GitHub Check: Build C library (2.14, >=2.5.0rc0,<2.15, libdeepmd_c_cu11.tar.gz)
  • GitHub Check: Test C++ (false)
  • GitHub Check: Build C library (2.18, libdeepmd_c.tar.gz)
  • GitHub Check: Test C++ (true)
  • GitHub Check: Analyze (python)
  • GitHub Check: Analyze (c-cpp)


Successfully merging this pull request may close these issues.

[BUG] Incomplete and risky bias statistics
4 participants