Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(jax): passing mapping from LAMMPS to DPA-2 #4316

Closed
wants to merge 29 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ repos:
exclude: |
(?x)^(
source/tests/infer/dipolecharge_e.pbtxt|
source/tests/infer/deeppolar_new.pbtxt
source/tests/infer/deeppolar_new.pbtxt|
source/tests/infer/deeppot_dpa.savedmodel/saved_model.pb
Comment on lines +18 to +19
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codebase verification

Additional large model files need exclusion

The scan revealed several model files exceeding the 1024KB limit that are not currently excluded:

  • source/tests/infer/deepdipole_new.pbtxt (916KB)
  • source/tests/infer/deepdipole_fake.pbtxt (1002KB)
  • source/tests/infer/deeppot_sea.savedmodel/saved_model.pb (630KB)

These files should also be added to the exclusion list in .pre-commit-config.yaml to maintain consistency with the current approach of excluding large model files.

🔗 Analysis chain

LGTM! Consider checking for additional model files.

The exclusion patterns for the new model files are consistent with the existing pattern. However, since this PR is part of the JAX backend integration, we should verify if there are other similar model files that might need exclusion.


Let me try a different approach to find potential large model files by using fd and checking file sizes directly.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for other potential large model files that might need exclusion
# Look for similar files in the test directories that might exceed the size limit

echo "Searching for .pb and .pbtxt files in test directories..."
find source/tests -type f \( -name "*.pb" -o -name "*.pbtxt" \) -exec ls -lh {} \; | \
  awk '{if($5+0 > 1024) print "Warning: Large file found:", $9, "Size:", $5}'

Length of output: 292


Script:

#!/bin/bash
# Search for model files and their sizes in the source/tests directory
echo "Searching for model files in tests directory..."
fd -t f "\.(pb|pbtxt|savedmodel)$" source/tests -x ls -lh {}

# Also check the pre-commit config for current exclusions
echo -e "\nCurrent exclusions in pre-commit config:"
grep -A 10 "check-added-large-files" .pre-commit-config.yaml

Length of output: 2875

)$
- id: check-merge-conflict
- id: check-symlinks
Expand Down
4 changes: 3 additions & 1 deletion doc/backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ While `.pth` and `.pt` are the same in the PyTorch package, they have different
[JAX](https://jax.readthedocs.io/) 0.4.33 (which requires Python 3.10 or above) or above is required.
Both `.xlo` and `.jax` are customized format extensions defined in DeePMD-kit, since JAX has no convention for file extensions.
`.savedmodel` is the TensorFlow [SavedModel format](https://www.tensorflow.org/guide/saved_model) generated by [JAX2TF](https://www.tensorflow.org/guide/jax2tf), which needs the installation of TensorFlow.
Currently, this backend is developed actively, and has no support for training and the C++ interface.
Only the `.savedmodel` format supports C++ inference, which needs the TensorFlow C++ interface.
The model is device-specific, so that the model generated on the GPU device cannot be run on the CPUs.
Currently, this backend is developed actively, and has no support for training.

### DP {{ dpmodel_icon }}

Expand Down
10 changes: 6 additions & 4 deletions doc/install/install-from-source.md
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,9 @@ If one does not need to use DeePMD-kit with LAMMPS or i-PI, then the python inte

::::{tab-set}

:::{tab-item} TensorFlow {{ tensorflow_icon }}
:::{tab-item} TensorFlow {{ tensorflow_icon }} / JAX {{ jax_icon }}

The C++ interfaces of both TensorFlow and JAX backends are based on the TensorFlow C++ library.

Since TensorFlow 2.12, TensorFlow C++ library (`libtensorflow_cc`) is packaged inside the Python library. Thus, you can skip building TensorFlow C++ library manually. If that does not work for you, you can still build it manually.

Expand Down Expand Up @@ -338,7 +340,7 @@ We recommend using [conda packages](https://docs.deepmodeling.org/faq/conda.html

::::{tab-set}

:::{tab-item} TensorFlow {{ tensorflow_icon }}
:::{tab-item} TensorFlow {{ tensorflow_icon }} / JAX {{ jax_icon }}

I assume you have activated the TensorFlow Python environment and want to install DeePMD-kit into path `$deepmd_root`, then execute CMake

Expand Down Expand Up @@ -375,7 +377,7 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value

**Type**: `BOOL` (`ON`/`OFF`), Default: `OFF`

{{ tensorflow_icon }} Whether building the TensorFlow backend.
{{ tensorflow_icon }} {{ jax_icon }} Whether building the TensorFlow backend and the JAX backend.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider separating backend enable flags.

The current ENABLE_TENSORFLOW variable name doesn't accurately reflect that it enables both TensorFlow and JAX backends. This could be confusing for users who only want to use JAX.

Consider either:

  1. Renaming to something more inclusive like ENABLE_TF_JAX_BACKEND
  2. Creating a separate ENABLE_JAX flag


:::

Expand All @@ -391,7 +393,7 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value

**Type**: `PATH`

{{ tensorflow_icon }} The Path to TensorFlow's C++ interface.
{{ tensorflow_icon }} {{ jax_icon }} The Path to TensorFlow's C++ interface.

:::

Expand Down
8 changes: 8 additions & 0 deletions doc/model/dpa2.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ If one runs LAMMPS with MPI, the customized OP library for the C++ interface sho
If one runs LAMMPS with MPI and CUDA devices, it is recommended to compile the customized OP library for the C++ interface with a [CUDA-Aware MPI](https://developer.nvidia.com/mpi-solutions-gpus) library and CUDA,
otherwise the communication between GPU cards falls back to the slower CPU implementation.

## Limiations of the JAX backend with LAMMPS {{ jax_icon }}

When using the JAX backend, 2 or more MPI ranks are not supported. One must set `map` to `yes` using the [`atom_modify`](https://docs.lammps.org/atom_modify.html) command.

```lammps
atom_modify map yes
```

## Data format

DPA-2 supports both the [standard data format](../data/system.md) and the [mixed type data format](../data/system.md#mixed-type).
16 changes: 13 additions & 3 deletions source/api_c/include/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ extern "C" {
/** C API version. Bumped whenever the API is changed.
* @since API version 22
*/
#define DP_C_API_VERSION 23
#define DP_C_API_VERSION 24

/**
* @brief Neighbor list.
Expand All @@ -31,7 +31,7 @@ extern DP_Nlist* DP_NewNlist(int inum_,
int* ilist_,
int* numneigh_,
int** firstneigh_);
/*
/**
* @brief Create a new neighbor list with communication capabilities.
* @details This function extends DP_NewNlist by adding support for parallel
* communication, allowing the neighbor list to be used in distributed
Expand Down Expand Up @@ -68,7 +68,7 @@ extern DP_Nlist* DP_NewNlist_comm(int inum_,
int* recvproc,
void* world);

/*
/**
* @brief Set mask for a neighbor list.
*
* @param nl Neighbor list.
Expand All @@ -78,6 +78,16 @@ extern DP_Nlist* DP_NewNlist_comm(int inum_,
**/
extern void DP_NlistSetMask(DP_Nlist* nl, int mask);

/**
* @brief Set mapping for a neighbor list.
*
* @param nl Neighbor list.
* @param mapping mapping from all atoms to real atoms, in size nall.
* @since API version 24
*
**/
extern void DP_NlistSetMapping(DP_Nlist* nl, int* mapping);
Comment on lines +81 to +89
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Consider adding size parameter for safety.

While the function signature and documentation are well-structured, consider adding a size parameter for the mapping array to prevent potential buffer overflows. This would allow runtime validation of the mapping array size against the expected size (nall).

-extern void DP_NlistSetMapping(DP_Nlist* nl, int* mapping);
+extern void DP_NlistSetMapping(DP_Nlist* nl, int* mapping, int mapping_size);
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
/**
* @brief Set mapping for a neighbor list.
*
* @param nl Neighbor list.
* @param mapping mapping from all atoms to real atoms, in size nall.
* @since API version 24
*
**/
extern void DP_NlistSetMapping(DP_Nlist* nl, int* mapping);
/**
* @brief Set mapping for a neighbor list.
*
* @param nl Neighbor list.
* @param mapping mapping from all atoms to real atoms, in size nall.
* @since API version 24
*
**/
extern void DP_NlistSetMapping(DP_Nlist* nl, int* mapping, int mapping_size);

💡 Codebase verification

⚠️ Potential issue

Based on the code inspection, I can now provide the final response:

Buffer overflow risk confirmed - no size validation for mapping array

The implementation in both C API (source/api_c/src/c_api.cc) and the underlying C++ class (source/lib/include/neighbor_list.h) shows that the set_mapping function directly assigns the pointer without validating its size against nall. This could lead to buffer overflows if the provided array is smaller than expected.

Key findings:

  • C API implementation simply forwards the pointer: nl->nl.set_mapping(mapping);
  • C++ implementation in InputNlist class also directly assigns without validation: void set_mapping(int* mapping_) { mapping = mapping_; };
  • The mapping array is documented to require size nall but this is not enforced
🔗 Analysis chain

Add validation for mapping array size.

The function should validate that the mapping array size matches the expected size (nall) to prevent buffer overflows.

Let's verify if size validation is implemented in the source:


Let me check the C++ implementation of the set_mapping method to see if size validation is performed there.


Let me check the NeighborList class implementation to understand how the mapping is used and if there are any size validations.


Let me try a different approach to find the NeighborList implementation.

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check if the implementation validates the mapping array size
# Test: Search for size validation in the implementation
rg -A 5 "DP_NlistSetMapping.*\{" source/api_c/src/

Length of output: 373


Script:

#!/bin/bash
# Search for set_mapping implementation in the C++ code
ast-grep --pattern 'set_mapping($$$) { $$$ }'

# Also search for any references to mapping size validation
rg -i "set_mapping|mapping.*size" source/

Length of output: 2662


Script:

#!/bin/bash
# Search for NeighborList class implementation and usage of mapping
rg -A 10 "class.*NeighborList" source/lib/include/
rg -B 2 -A 5 "mapping\s*=" source/lib/src/neighbor_list.cc

Length of output: 113


Script:

#!/bin/bash
# Search for the NeighborList class definition and implementation
rg -A 10 "class.*NeighborList" source/
cat source/lib/include/neighbor_list.h

Length of output: 13092


/**
* @brief Delete a neighbor list.
*
Expand Down
5 changes: 5 additions & 0 deletions source/api_c/include/deepmd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,11 @@ struct InputNlist {
* @brief Set mask for this neighbor list.
*/
void set_mask(int mask) { DP_NlistSetMask(nl, mask); };
/**
* @brief Set mapping for this neighbor list.
* @param mapping mapping from all atoms to real atoms, in size nall.
*/
void set_mapping(int *mapping) { DP_NlistSetMapping(nl, mapping); };
};

/**
Expand Down
3 changes: 3 additions & 0 deletions source/api_c/src/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ DP_Nlist* DP_NewNlist_comm(int inum_,
return new_nl;
}
void DP_NlistSetMask(DP_Nlist* nl, int mask) { nl->nl.set_mask(mask); }
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
nl->nl.set_mapping(mapping);
}
Comment on lines +46 to +48
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add error handling and documentation.

The function implementation should include:

  1. Error handling for null pointers
  2. Documentation explaining the purpose, parameters, and expected mapping array size
  3. Validation of the mapping array to prevent buffer overflows

Consider applying this diff:

+/**
+ * @brief Set mapping from all atoms to real atoms for neighbor list
+ * @param[in] nl Pointer to DP_Nlist object
+ * @param[in] mapping Integer array containing the mapping indices
+ * @note The mapping array size must match the total number of atoms
+ */
 void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
+  if (nl == nullptr || mapping == nullptr) {
+    nl->exception = std::string("null pointer in DP_NlistSetMapping");
+    return;
+  }
   nl->nl.set_mapping(mapping);
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
nl->nl.set_mapping(mapping);
}
/**
* @brief Set mapping from all atoms to real atoms for neighbor list
* @param[in] nl Pointer to DP_Nlist object
* @param[in] mapping Integer array containing the mapping indices
* @note The mapping array size must match the total number of atoms
*/
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
if (nl == nullptr || mapping == nullptr) {
nl->exception = std::string("null pointer in DP_NlistSetMapping");
return;
}
nl->nl.set_mapping(mapping);
}

Comment on lines +46 to +48
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Add safety checks and documentation for the new mapping function.

The implementation needs several improvements for robustness and clarity:

  1. Add null pointer validation for both parameters
  2. Document the requirements for the mapping array (size, lifetime, ownership)
  3. Consider adding error handling similar to other functions in the codebase

Consider applying this diff:

+/**
+ * @brief Set the mapping array for the neighbor list
+ * @param[in] nl The neighbor list object
+ * @param[in] mapping The mapping array that will be used by the neighbor list.
+ *                    The array must remain valid for the lifetime of the neighbor list
+ *                    or until a new mapping is set.
+ * @note The function does not take ownership of the mapping array.
+ */
 void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
+  if (nl == nullptr || mapping == nullptr) {
+    return;
+  }
   nl->nl.set_mapping(mapping);
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
nl->nl.set_mapping(mapping);
}
/**
* @brief Set the mapping array for the neighbor list
* @param[in] nl The neighbor list object
* @param[in] mapping The mapping array that will be used by the neighbor list.
* The array must remain valid for the lifetime of the neighbor list
* or until a new mapping is set.
* @note The function does not take ownership of the mapping array.
*/
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
if (nl == nullptr || mapping == nullptr) {
return;
}
nl->nl.set_mapping(mapping);
}

Comment on lines +46 to +48
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Add parameter validation and documentation.

While the implementation is correct, consider adding:

  1. Null pointer validation for parameters
  2. Documentation explaining the expected format and size of the mapping array
  3. Error handling for invalid inputs
+// Set the mapping for the neighbor list
+// @param nl: Pointer to the neighbor list
+// @param mapping: Array of integers defining the mapping. Must not be null.
+// @return void
 void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
+  if (nl == nullptr || mapping == nullptr) {
+    nl->exception = "Invalid null pointer in DP_NlistSetMapping";
+    return;
+  }
   nl->nl.set_mapping(mapping);
 }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
nl->nl.set_mapping(mapping);
}
// Set the mapping for the neighbor list
// @param nl: Pointer to the neighbor list
// @param mapping: Array of integers defining the mapping. Must not be null.
// @return void
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
if (nl == nullptr || mapping == nullptr) {
nl->exception = "Invalid null pointer in DP_NlistSetMapping";
return;
}
nl->nl.set_mapping(mapping);
}

void DP_DeleteNlist(DP_Nlist* nl) { delete nl; }

DP_DeepPot::DP_DeepPot() {}
Expand Down
249 changes: 249 additions & 0 deletions source/api_cc/include/DeepPotJAX.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
// SPDX-License-Identifier: LGPL-3.0-or-later
#pragma once

#include <tensorflow/c/c_api.h>
#include <tensorflow/c/eager/c_api.h>

#include "DeepPot.h"
#include "common.h"
#include "neighbor_list.h"

namespace deepmd {
/**
* @brief TensorFlow implementation for Deep Potential.
**/
class DeepPotJAX : public DeepPotBase {
public:
/**
* @brief DP constructor without initialization.
**/
DeepPotJAX();
virtual ~DeepPotJAX();
/**
* @brief DP constructor with initialization.
* @param[in] model The name of the frozen model file.
* @param[in] gpu_rank The GPU rank. Default is 0.
* @param[in] file_content The content of the model file. If it is not empty,
*DP will read from the string instead of the file.
**/
DeepPotJAX(const std::string& model,
const int& gpu_rank = 0,
const std::string& file_content = "");
/**
* @brief Initialize the DP.
* @param[in] model The name of the frozen model file.
* @param[in] gpu_rank The GPU rank. Default is 0.
* @param[in] file_content The content of the model file. If it is not empty,
*DP will read from the string instead of the file.
**/
void init(const std::string& model,
const int& gpu_rank = 0,
const std::string& file_content = "");
/**
* @brief Get the cutoff radius.
* @return The cutoff radius.
**/
double cutoff() const {
assert(inited);
return rcut;
};
/**
* @brief Get the number of types.
* @return The number of types.
**/
int numb_types() const {
assert(inited);
return ntypes;
};
/**
* @brief Get the number of types with spin.
* @return The number of types with spin.
**/
int numb_types_spin() const {
assert(inited);
return 0;
};
/**
* @brief Get the dimension of the frame parameter.
* @return The dimension of the frame parameter.
**/
int dim_fparam() const {
assert(inited);
return dfparam;
};
/**
* @brief Get the dimension of the atomic parameter.
* @return The dimension of the atomic parameter.
**/
int dim_aparam() const {
assert(inited);
return daparam;
};
/**
* @brief Get the type map (element name of the atom types) of this model.
* @param[out] type_map The type map of this model.
**/
void get_type_map(std::string& type_map);

/**
* @brief Get whether the atom dimension of aparam is nall instead of fparam.
* @param[out] aparam_nall whether the atom dimension of aparam is nall
*instead of fparam.
**/
bool is_aparam_nall() const {
assert(inited);
return false;
};

// forward to template class
void computew(std::vector<double>& ener,
std::vector<double>& force,
std::vector<double>& virial,
std::vector<double>& atom_energy,
std::vector<double>& atom_virial,
const std::vector<double>& coord,
const std::vector<int>& atype,
const std::vector<double>& box,
const std::vector<double>& fparam,
const std::vector<double>& aparam,
const bool atomic);
void computew(std::vector<double>& ener,
std::vector<float>& force,
std::vector<float>& virial,
std::vector<float>& atom_energy,
std::vector<float>& atom_virial,
const std::vector<float>& coord,
const std::vector<int>& atype,
const std::vector<float>& box,
const std::vector<float>& fparam,
const std::vector<float>& aparam,
const bool atomic);
void computew(std::vector<double>& ener,
std::vector<double>& force,
std::vector<double>& virial,
std::vector<double>& atom_energy,
std::vector<double>& atom_virial,
const std::vector<double>& coord,
const std::vector<int>& atype,
const std::vector<double>& box,
const int nghost,
const InputNlist& inlist,
const int& ago,
const std::vector<double>& fparam,
const std::vector<double>& aparam,
const bool atomic);
void computew(std::vector<double>& ener,
std::vector<float>& force,
std::vector<float>& virial,
std::vector<float>& atom_energy,
std::vector<float>& atom_virial,
const std::vector<float>& coord,
const std::vector<int>& atype,
const std::vector<float>& box,
const int nghost,
const InputNlist& inlist,
const int& ago,
const std::vector<float>& fparam,
const std::vector<float>& aparam,
const bool atomic);
void computew_mixed_type(std::vector<double>& ener,
std::vector<double>& force,
std::vector<double>& virial,
std::vector<double>& atom_energy,
std::vector<double>& atom_virial,
const int& nframes,
const std::vector<double>& coord,
const std::vector<int>& atype,
const std::vector<double>& box,
const std::vector<double>& fparam,
const std::vector<double>& aparam,
const bool atomic);
void computew_mixed_type(std::vector<double>& ener,
std::vector<float>& force,
std::vector<float>& virial,
std::vector<float>& atom_energy,
std::vector<float>& atom_virial,
const int& nframes,
const std::vector<float>& coord,
const std::vector<int>& atype,
const std::vector<float>& box,
const std::vector<float>& fparam,
const std::vector<float>& aparam,
const bool atomic);

private:
bool inited;
// device
std::string device;
// the cutoff radius
double rcut;
// the number of types
int ntypes;
// the dimension of the frame parameter
int dfparam;
// the dimension of the atomic parameter
int daparam;
// type map
std::string type_map;
// sel
std::vector<int64_t> sel;
// number of neighbors
int nnei;
/** TF C API objects.
* @{
*/
TF_Graph* graph;
TF_Status* status;
TF_Session* session;
TF_SessionOptions* sessionopts;
TFE_ContextOptions* ctx_opts;
TFE_Context* ctx;
std::vector<TF_Function*> func_vector;
Comment on lines +195 to +201
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Ensure proper cleanup of TensorFlow C API objects in destructor

The class utilizes several TensorFlow C API objects (graph, status, session, sessionopts, ctx_opts, ctx, func_vector). To prevent memory leaks, it's crucial to release these resources appropriately in the destructor.

Consider adding cleanup code in the destructor ~DeepPotJAX():

+DeepPotJAX::~DeepPotJAX() {
+    if (session) TF_DeleteSession(session, status);
+    if (graph) TF_DeleteGraph(graph);
+    if (status) TF_DeleteStatus(status);
+    if (sessionopts) TF_DeleteSessionOptions(sessionopts);
+    if (ctx_opts) TFE_DeleteContextOptions(ctx_opts);
+    if (ctx) TFE_DeleteContext(ctx);
+    for (auto func : func_vector) {
+        TF_DeleteFunction(func);
+    }
+}

Ensure that all TensorFlow objects are properly deleted and that error checking is implemented where necessary.

Committable suggestion skipped: line range outside the PR's diff.

/**
* @}
*/
// neighbor list data
NeighborListData nlist_data;
/**
* @brief Evaluate the energy, force, virial, atomic energy, and atomic virial
*by using this DP.
* @param[out] ener The system energy.
* @param[out] force The force on each atom.
* @param[out] virial The virial.
* @param[out] atom_energy The atomic energy.
* @param[out] atom_virial The atomic virial.
* @param[in] coord The coordinates of atoms. The array should be of size
*nframes x natoms x 3.
* @param[in] atype The atom types. The list should contain natoms ints.
* @param[in] box The cell of the region. The array should be of size nframes
*x 9.
* @param[in] nghost The number of ghost atoms.
* @param[in] lmp_list The input neighbour list.
* @param[in] ago Update the internal neighbour list if ago is 0.
* @param[in] fparam The frame parameter. The array can be of size :
* nframes x dim_fparam.
* dim_fparam. Then all frames are assumed to be provided with the same
*fparam.
* @param[in] aparam The atomic parameter The array can be of size :
* nframes x natoms x dim_aparam.
* natoms x dim_aparam. Then all frames are assumed to be provided with the
*same aparam.
* @param[in] atomic Whether to compute atomic energy and virial.
**/
template <typename VALUETYPE>
void compute(std::vector<ENERGYTYPE>& ener,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Undefined type ENERGYTYPE in template function

The template function compute uses ENERGYTYPE for the ener parameter, but ENERGYTYPE is not defined within this header file. This may lead to compilation errors.

Please ensure that ENERGYTYPE is defined or include the appropriate header file where ENERGYTYPE is declared.

std::vector<VALUETYPE>& force,
std::vector<VALUETYPE>& virial,
std::vector<VALUETYPE>& atom_energy,
std::vector<VALUETYPE>& atom_virial,
const std::vector<VALUETYPE>& coord,
const std::vector<int>& atype,
const std::vector<VALUETYPE>& box,
const int nghost,
const InputNlist& lmp_list,
const int& ago,
const std::vector<VALUETYPE>& fparam,
const std::vector<VALUETYPE>& aparam,
const bool atomic);
};
} // namespace deepmd
2 changes: 1 addition & 1 deletion source/api_cc/include/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
namespace deepmd {

typedef double ENERGYTYPE;
enum DPBackend { TensorFlow, PyTorch, Paddle, Unknown };
enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown };
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codebase verification

Missing JAX case in some backend handlers

The verification reveals that JAX backend case is missing in two files:

  • source/api_cc/src/DeepTensor.cc: JAX case not handled in if-else chain
  • source/api_cc/src/DataModifier.cc: JAX case not handled in if-else chain

While DeepPot.cc properly handles the JAX backend, the other backend handlers need to be updated for consistency.

🔗 Analysis chain

LGTM! Verify switch statements for the new backend.

The addition of JAX to the DPBackend enum is clean and properly placed before the Unknown value.

Let's verify that all switch statements handling DPBackend are updated to include the JAX case:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Search for switch statements or if-else chains handling DPBackend
rg -A 10 "switch.*DPBackend|if.*DPBackend.*==" source/

Length of output: 4502

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💡 Codebase verification

Update if-else chains in multiple files to handle JAX backend

The codebase has several if-else chains that need to be updated to handle the new JAX backend:

  • source/api_cc/src/DeepPot.cc: Missing JAX case in backend comparison
  • source/api_cc/src/DataModifier.cc: Missing JAX case in backend comparison
  • source/api_cc/src/DeepTensor.cc: Missing JAX case in backend comparison

Each of these files needs to add a new condition else if (deepmd::DPBackend::JAX == backend) before the final else clause to properly handle the JAX backend.

🔗 Analysis chain

LGTM! Verify enum usage across codebase.

The addition of JAX to the DPBackend enum before Unknown is correct and maintains backward compatibility.

Let's verify the enum usage across the codebase:

🏁 Scripts executed

The following scripts were executed for the analysis:

Script:

#!/bin/bash
# Description: Check for DPBackend enum usage and potential switch statements that need updating

# Search for switch statements on DPBackend that might need updating
rg -A 10 "switch.*DPBackend" 

# Search for direct enum value comparisons
rg "DPBackend::(TensorFlow|PyTorch|Paddle|Unknown)"

Length of output: 1157


struct NeighborListData {
/// Array stores the core region atom's index
Expand Down
Loading