-
Notifications
You must be signed in to change notification settings - Fork 529
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(jax): passing mapping from LAMMPS to DPA-2 #4316
Changes from all commits
147400e
8c6d522
140f3e1
a0b8074
e6bf59f
6c10e8e
2aa6deb
f16dd92
297ae26
e64e06a
8fefce8
261c7bd
4d5ccc5
d365bbc
21fc045
660171e
d552821
0461248
f26f3fe
713d065
ccb182d
8ccead6
904042d
0f9d5c5
bad564b
2b165d7
e717ba3
f075075
58dcf2b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -297,7 +297,9 @@ If one does not need to use DeePMD-kit with LAMMPS or i-PI, then the python inte | |
|
||
::::{tab-set} | ||
|
||
:::{tab-item} TensorFlow {{ tensorflow_icon }} | ||
:::{tab-item} TensorFlow {{ tensorflow_icon }} / JAX {{ jax_icon }} | ||
|
||
The C++ interfaces of both TensorFlow and JAX backends are based on the TensorFlow C++ library. | ||
|
||
Since TensorFlow 2.12, TensorFlow C++ library (`libtensorflow_cc`) is packaged inside the Python library. Thus, you can skip building TensorFlow C++ library manually. If that does not work for you, you can still build it manually. | ||
|
||
|
@@ -338,7 +340,7 @@ We recommend using [conda packages](https://docs.deepmodeling.org/faq/conda.html | |
|
||
::::{tab-set} | ||
|
||
:::{tab-item} TensorFlow {{ tensorflow_icon }} | ||
:::{tab-item} TensorFlow {{ tensorflow_icon }} / JAX {{ jax_icon }} | ||
|
||
I assume you have activated the TensorFlow Python environment and want to install DeePMD-kit into path `$deepmd_root`, then execute CMake | ||
|
||
|
@@ -375,7 +377,7 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value | |
|
||
**Type**: `BOOL` (`ON`/`OFF`), Default: `OFF` | ||
|
||
{{ tensorflow_icon }} Whether building the TensorFlow backend. | ||
{{ tensorflow_icon }} {{ jax_icon }} Whether building the TensorFlow backend and the JAX backend. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Consider separating backend enable flags. The current Consider either:
|
||
|
||
::: | ||
|
||
|
@@ -391,7 +393,7 @@ One may add the following CMake variables to `cmake` using the [`-D <var>=<value | |
|
||
**Type**: `PATH` | ||
|
||
{{ tensorflow_icon }} The Path to TensorFlow's C++ interface. | ||
{{ tensorflow_icon }} {{ jax_icon }} The Path to TensorFlow's C++ interface. | ||
|
||
::: | ||
|
||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -12,7 +12,7 @@ extern "C" { | |||||||||||||||||||||||||||||||||||||
/** C API version. Bumped whenever the API is changed. | ||||||||||||||||||||||||||||||||||||||
* @since API version 22 | ||||||||||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||||||||||
#define DP_C_API_VERSION 23 | ||||||||||||||||||||||||||||||||||||||
#define DP_C_API_VERSION 24 | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||||||||||
* @brief Neighbor list. | ||||||||||||||||||||||||||||||||||||||
|
@@ -31,7 +31,7 @@ extern DP_Nlist* DP_NewNlist(int inum_, | |||||||||||||||||||||||||||||||||||||
int* ilist_, | ||||||||||||||||||||||||||||||||||||||
int* numneigh_, | ||||||||||||||||||||||||||||||||||||||
int** firstneigh_); | ||||||||||||||||||||||||||||||||||||||
/* | ||||||||||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||||||||||
* @brief Create a new neighbor list with communication capabilities. | ||||||||||||||||||||||||||||||||||||||
* @details This function extends DP_NewNlist by adding support for parallel | ||||||||||||||||||||||||||||||||||||||
* communication, allowing the neighbor list to be used in distributed | ||||||||||||||||||||||||||||||||||||||
|
@@ -68,7 +68,7 @@ extern DP_Nlist* DP_NewNlist_comm(int inum_, | |||||||||||||||||||||||||||||||||||||
int* recvproc, | ||||||||||||||||||||||||||||||||||||||
void* world); | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
/* | ||||||||||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||||||||||
* @brief Set mask for a neighbor list. | ||||||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||||||
* @param nl Neighbor list. | ||||||||||||||||||||||||||||||||||||||
|
@@ -78,6 +78,16 @@ extern DP_Nlist* DP_NewNlist_comm(int inum_, | |||||||||||||||||||||||||||||||||||||
**/ | ||||||||||||||||||||||||||||||||||||||
extern void DP_NlistSetMask(DP_Nlist* nl, int mask); | ||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||||||||||
* @brief Set mapping for a neighbor list. | ||||||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||||||
* @param nl Neighbor list. | ||||||||||||||||||||||||||||||||||||||
* @param mapping mapping from all atoms to real atoms, in size nall. | ||||||||||||||||||||||||||||||||||||||
* @since API version 24 | ||||||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||||||
**/ | ||||||||||||||||||||||||||||||||||||||
extern void DP_NlistSetMapping(DP_Nlist* nl, int* mapping); | ||||||||||||||||||||||||||||||||||||||
Comment on lines
+81
to
+89
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Consider adding size parameter for safety. While the function signature and documentation are well-structured, consider adding a size parameter for the mapping array to prevent potential buffer overflows. This would allow runtime validation of the mapping array size against the expected size ( -extern void DP_NlistSetMapping(DP_Nlist* nl, int* mapping);
+extern void DP_NlistSetMapping(DP_Nlist* nl, int* mapping, int mapping_size); 📝 Committable suggestion
Suggested change
💡 Codebase verification Based on the code inspection, I can now provide the final response: Buffer overflow risk confirmed - no size validation for mapping array The implementation in both C API ( Key findings:
🔗 Analysis chainAdd validation for mapping array size. The function should validate that the mapping array size matches the expected size ( Let's verify if size validation is implemented in the source: Let me check the C++ implementation of the Let me check the Let me try a different approach to find the NeighborList implementation. 🏁 Scripts executedThe following scripts were executed for the analysis: Script: #!/bin/bash
# Description: Check if the implementation validates the mapping array size
# Test: Search for size validation in the implementation
rg -A 5 "DP_NlistSetMapping.*\{" source/api_c/src/
Length of output: 373 Script: #!/bin/bash
# Search for set_mapping implementation in the C++ code
ast-grep --pattern 'set_mapping($$$) { $$$ }'
# Also search for any references to mapping size validation
rg -i "set_mapping|mapping.*size" source/
Length of output: 2662 Script: #!/bin/bash
# Search for NeighborList class implementation and usage of mapping
rg -A 10 "class.*NeighborList" source/lib/include/
rg -B 2 -A 5 "mapping\s*=" source/lib/src/neighbor_list.cc
Length of output: 113 Script: #!/bin/bash
# Search for the NeighborList class definition and implementation
rg -A 10 "class.*NeighborList" source/
cat source/lib/include/neighbor_list.h
Length of output: 13092 |
||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||
/** | ||||||||||||||||||||||||||||||||||||||
* @brief Delete a neighbor list. | ||||||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -43,6 +43,9 @@ DP_Nlist* DP_NewNlist_comm(int inum_, | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
return new_nl; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
void DP_NlistSetMask(DP_Nlist* nl, int mask) { nl->nl.set_mask(mask); } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
nl->nl.set_mapping(mapping); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+46
to
+48
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add error handling and documentation. The function implementation should include:
Consider applying this diff: +/**
+ * @brief Set mapping from all atoms to real atoms for neighbor list
+ * @param[in] nl Pointer to DP_Nlist object
+ * @param[in] mapping Integer array containing the mapping indices
+ * @note The mapping array size must match the total number of atoms
+ */
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
+ if (nl == nullptr || mapping == nullptr) {
+ nl->exception = std::string("null pointer in DP_NlistSetMapping");
+ return;
+ }
nl->nl.set_mapping(mapping);
} 📝 Committable suggestion
Suggested change
Comment on lines
+46
to
+48
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add safety checks and documentation for the new mapping function. The implementation needs several improvements for robustness and clarity:
Consider applying this diff: +/**
+ * @brief Set the mapping array for the neighbor list
+ * @param[in] nl The neighbor list object
+ * @param[in] mapping The mapping array that will be used by the neighbor list.
+ * The array must remain valid for the lifetime of the neighbor list
+ * or until a new mapping is set.
+ * @note The function does not take ownership of the mapping array.
+ */
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
+ if (nl == nullptr || mapping == nullptr) {
+ return;
+ }
nl->nl.set_mapping(mapping);
} 📝 Committable suggestion
Suggested change
Comment on lines
+46
to
+48
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Add parameter validation and documentation. While the implementation is correct, consider adding:
+// Set the mapping for the neighbor list
+// @param nl: Pointer to the neighbor list
+// @param mapping: Array of integers defining the mapping. Must not be null.
+// @return void
void DP_NlistSetMapping(DP_Nlist* nl, int* mapping) {
+ if (nl == nullptr || mapping == nullptr) {
+ nl->exception = "Invalid null pointer in DP_NlistSetMapping";
+ return;
+ }
nl->nl.set_mapping(mapping);
} 📝 Committable suggestion
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
void DP_DeleteNlist(DP_Nlist* nl) { delete nl; } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
DP_DeepPot::DP_DeepPot() {} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,249 @@ | ||
// SPDX-License-Identifier: LGPL-3.0-or-later | ||
#pragma once | ||
|
||
#include <tensorflow/c/c_api.h> | ||
#include <tensorflow/c/eager/c_api.h> | ||
|
||
#include "DeepPot.h" | ||
#include "common.h" | ||
#include "neighbor_list.h" | ||
|
||
namespace deepmd { | ||
/** | ||
* @brief TensorFlow implementation for Deep Potential. | ||
**/ | ||
class DeepPotJAX : public DeepPotBase { | ||
public: | ||
/** | ||
* @brief DP constructor without initialization. | ||
**/ | ||
DeepPotJAX(); | ||
virtual ~DeepPotJAX(); | ||
/** | ||
* @brief DP constructor with initialization. | ||
* @param[in] model The name of the frozen model file. | ||
* @param[in] gpu_rank The GPU rank. Default is 0. | ||
* @param[in] file_content The content of the model file. If it is not empty, | ||
*DP will read from the string instead of the file. | ||
**/ | ||
DeepPotJAX(const std::string& model, | ||
const int& gpu_rank = 0, | ||
const std::string& file_content = ""); | ||
/** | ||
* @brief Initialize the DP. | ||
* @param[in] model The name of the frozen model file. | ||
* @param[in] gpu_rank The GPU rank. Default is 0. | ||
* @param[in] file_content The content of the model file. If it is not empty, | ||
*DP will read from the string instead of the file. | ||
**/ | ||
void init(const std::string& model, | ||
const int& gpu_rank = 0, | ||
const std::string& file_content = ""); | ||
/** | ||
* @brief Get the cutoff radius. | ||
* @return The cutoff radius. | ||
**/ | ||
double cutoff() const { | ||
assert(inited); | ||
return rcut; | ||
}; | ||
/** | ||
* @brief Get the number of types. | ||
* @return The number of types. | ||
**/ | ||
int numb_types() const { | ||
assert(inited); | ||
return ntypes; | ||
}; | ||
/** | ||
* @brief Get the number of types with spin. | ||
* @return The number of types with spin. | ||
**/ | ||
int numb_types_spin() const { | ||
assert(inited); | ||
return 0; | ||
}; | ||
/** | ||
* @brief Get the dimension of the frame parameter. | ||
* @return The dimension of the frame parameter. | ||
**/ | ||
int dim_fparam() const { | ||
assert(inited); | ||
return dfparam; | ||
}; | ||
/** | ||
* @brief Get the dimension of the atomic parameter. | ||
* @return The dimension of the atomic parameter. | ||
**/ | ||
int dim_aparam() const { | ||
assert(inited); | ||
return daparam; | ||
}; | ||
/** | ||
* @brief Get the type map (element name of the atom types) of this model. | ||
* @param[out] type_map The type map of this model. | ||
**/ | ||
void get_type_map(std::string& type_map); | ||
|
||
/** | ||
* @brief Get whether the atom dimension of aparam is nall instead of fparam. | ||
* @param[out] aparam_nall whether the atom dimension of aparam is nall | ||
*instead of fparam. | ||
**/ | ||
bool is_aparam_nall() const { | ||
assert(inited); | ||
return false; | ||
}; | ||
|
||
// forward to template class | ||
void computew(std::vector<double>& ener, | ||
std::vector<double>& force, | ||
std::vector<double>& virial, | ||
std::vector<double>& atom_energy, | ||
std::vector<double>& atom_virial, | ||
const std::vector<double>& coord, | ||
const std::vector<int>& atype, | ||
const std::vector<double>& box, | ||
const std::vector<double>& fparam, | ||
const std::vector<double>& aparam, | ||
const bool atomic); | ||
void computew(std::vector<double>& ener, | ||
std::vector<float>& force, | ||
std::vector<float>& virial, | ||
std::vector<float>& atom_energy, | ||
std::vector<float>& atom_virial, | ||
const std::vector<float>& coord, | ||
const std::vector<int>& atype, | ||
const std::vector<float>& box, | ||
const std::vector<float>& fparam, | ||
const std::vector<float>& aparam, | ||
const bool atomic); | ||
void computew(std::vector<double>& ener, | ||
std::vector<double>& force, | ||
std::vector<double>& virial, | ||
std::vector<double>& atom_energy, | ||
std::vector<double>& atom_virial, | ||
const std::vector<double>& coord, | ||
const std::vector<int>& atype, | ||
const std::vector<double>& box, | ||
const int nghost, | ||
const InputNlist& inlist, | ||
const int& ago, | ||
const std::vector<double>& fparam, | ||
const std::vector<double>& aparam, | ||
const bool atomic); | ||
void computew(std::vector<double>& ener, | ||
std::vector<float>& force, | ||
std::vector<float>& virial, | ||
std::vector<float>& atom_energy, | ||
std::vector<float>& atom_virial, | ||
const std::vector<float>& coord, | ||
const std::vector<int>& atype, | ||
const std::vector<float>& box, | ||
const int nghost, | ||
const InputNlist& inlist, | ||
const int& ago, | ||
const std::vector<float>& fparam, | ||
const std::vector<float>& aparam, | ||
const bool atomic); | ||
void computew_mixed_type(std::vector<double>& ener, | ||
std::vector<double>& force, | ||
std::vector<double>& virial, | ||
std::vector<double>& atom_energy, | ||
std::vector<double>& atom_virial, | ||
const int& nframes, | ||
const std::vector<double>& coord, | ||
const std::vector<int>& atype, | ||
const std::vector<double>& box, | ||
const std::vector<double>& fparam, | ||
const std::vector<double>& aparam, | ||
const bool atomic); | ||
void computew_mixed_type(std::vector<double>& ener, | ||
std::vector<float>& force, | ||
std::vector<float>& virial, | ||
std::vector<float>& atom_energy, | ||
std::vector<float>& atom_virial, | ||
const int& nframes, | ||
const std::vector<float>& coord, | ||
const std::vector<int>& atype, | ||
const std::vector<float>& box, | ||
const std::vector<float>& fparam, | ||
const std::vector<float>& aparam, | ||
const bool atomic); | ||
|
||
private: | ||
bool inited; | ||
// device | ||
std::string device; | ||
// the cutoff radius | ||
double rcut; | ||
// the number of types | ||
int ntypes; | ||
// the dimension of the frame parameter | ||
int dfparam; | ||
// the dimension of the atomic parameter | ||
int daparam; | ||
// type map | ||
std::string type_map; | ||
// sel | ||
std::vector<int64_t> sel; | ||
// number of neighbors | ||
int nnei; | ||
/** TF C API objects. | ||
* @{ | ||
*/ | ||
TF_Graph* graph; | ||
TF_Status* status; | ||
TF_Session* session; | ||
TF_SessionOptions* sessionopts; | ||
TFE_ContextOptions* ctx_opts; | ||
TFE_Context* ctx; | ||
std::vector<TF_Function*> func_vector; | ||
Comment on lines
+195
to
+201
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ensure proper cleanup of TensorFlow C API objects in destructor The class utilizes several TensorFlow C API objects ( Consider adding cleanup code in the destructor +DeepPotJAX::~DeepPotJAX() {
+ if (session) TF_DeleteSession(session, status);
+ if (graph) TF_DeleteGraph(graph);
+ if (status) TF_DeleteStatus(status);
+ if (sessionopts) TF_DeleteSessionOptions(sessionopts);
+ if (ctx_opts) TFE_DeleteContextOptions(ctx_opts);
+ if (ctx) TFE_DeleteContext(ctx);
+ for (auto func : func_vector) {
+ TF_DeleteFunction(func);
+ }
+} Ensure that all TensorFlow objects are properly deleted and that error checking is implemented where necessary.
|
||
/** | ||
* @} | ||
*/ | ||
// neighbor list data | ||
NeighborListData nlist_data; | ||
/** | ||
* @brief Evaluate the energy, force, virial, atomic energy, and atomic virial | ||
*by using this DP. | ||
* @param[out] ener The system energy. | ||
* @param[out] force The force on each atom. | ||
* @param[out] virial The virial. | ||
* @param[out] atom_energy The atomic energy. | ||
* @param[out] atom_virial The atomic virial. | ||
* @param[in] coord The coordinates of atoms. The array should be of size | ||
*nframes x natoms x 3. | ||
* @param[in] atype The atom types. The list should contain natoms ints. | ||
* @param[in] box The cell of the region. The array should be of size nframes | ||
*x 9. | ||
* @param[in] nghost The number of ghost atoms. | ||
* @param[in] lmp_list The input neighbour list. | ||
* @param[in] ago Update the internal neighbour list if ago is 0. | ||
* @param[in] fparam The frame parameter. The array can be of size : | ||
* nframes x dim_fparam. | ||
* dim_fparam. Then all frames are assumed to be provided with the same | ||
*fparam. | ||
* @param[in] aparam The atomic parameter The array can be of size : | ||
* nframes x natoms x dim_aparam. | ||
* natoms x dim_aparam. Then all frames are assumed to be provided with the | ||
*same aparam. | ||
* @param[in] atomic Whether to compute atomic energy and virial. | ||
**/ | ||
template <typename VALUETYPE> | ||
void compute(std::vector<ENERGYTYPE>& ener, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Undefined type The template function Please ensure that |
||
std::vector<VALUETYPE>& force, | ||
std::vector<VALUETYPE>& virial, | ||
std::vector<VALUETYPE>& atom_energy, | ||
std::vector<VALUETYPE>& atom_virial, | ||
const std::vector<VALUETYPE>& coord, | ||
const std::vector<int>& atype, | ||
const std::vector<VALUETYPE>& box, | ||
const int nghost, | ||
const InputNlist& lmp_list, | ||
const int& ago, | ||
const std::vector<VALUETYPE>& fparam, | ||
const std::vector<VALUETYPE>& aparam, | ||
const bool atomic); | ||
}; | ||
} // namespace deepmd |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,7 +13,7 @@ | |
namespace deepmd { | ||
|
||
typedef double ENERGYTYPE; | ||
enum DPBackend { TensorFlow, PyTorch, Paddle, Unknown }; | ||
enum DPBackend { TensorFlow, PyTorch, Paddle, JAX, Unknown }; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💡 Codebase verification Missing JAX case in some backend handlers The verification reveals that JAX backend case is missing in two files:
While 🔗 Analysis chainLGTM! Verify switch statements for the new backend. The addition of JAX to the DPBackend enum is clean and properly placed before the Unknown value. Let's verify that all switch statements handling DPBackend are updated to include the JAX case: 🏁 Scripts executedThe following scripts were executed for the analysis: Script: #!/bin/bash
# Search for switch statements or if-else chains handling DPBackend
rg -A 10 "switch.*DPBackend|if.*DPBackend.*==" source/
Length of output: 4502 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💡 Codebase verification Update if-else chains in multiple files to handle JAX backend The codebase has several if-else chains that need to be updated to handle the new JAX backend:
Each of these files needs to add a new condition 🔗 Analysis chainLGTM! Verify enum usage across codebase. The addition of Let's verify the enum usage across the codebase: 🏁 Scripts executedThe following scripts were executed for the analysis: Script: #!/bin/bash
# Description: Check for DPBackend enum usage and potential switch statements that need updating
# Search for switch statements on DPBackend that might need updating
rg -A 10 "switch.*DPBackend"
# Search for direct enum value comparisons
rg "DPBackend::(TensorFlow|PyTorch|Paddle|Unknown)"
Length of output: 1157 |
||
|
||
struct NeighborListData { | ||
/// Array stores the core region atom's index | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
💡 Codebase verification
Additional large model files need exclusion
The scan revealed several model files exceeding the 1024KB limit that are not currently excluded:
source/tests/infer/deepdipole_new.pbtxt
(916KB)source/tests/infer/deepdipole_fake.pbtxt
(1002KB)source/tests/infer/deeppot_sea.savedmodel/saved_model.pb
(630KB)These files should also be added to the exclusion list in
.pre-commit-config.yaml
to maintain consistency with the current approach of excluding large model files.🔗 Analysis chain
LGTM! Consider checking for additional model files.
The exclusion patterns for the new model files are consistent with the existing pattern. However, since this PR is part of the JAX backend integration, we should verify if there are other similar model files that might need exclusion.
Let me try a different approach to find potential large model files by using
fd
and checking file sizes directly.🏁 Scripts executed
The following scripts were executed for the analysis:
Script:
Length of output: 292
Script:
Length of output: 2875