-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
i-pi driver now automatically imports all modules within pes/
- Loading branch information
Showing
6 changed files
with
118 additions
and
12 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,30 @@ | ||
""" Small functions/classes providing access to driver PES to be called from driver.py """ | ||
|
||
from .dummy import Dummy_driver | ||
from .harmonic import Harmonic_driver | ||
from .rascal import Rascal_driver | ||
|
||
__all__ = ["__drivers__", "Dummy_driver", "Harmonic_driver", "Rascal_driver"] | ||
|
||
# dictionary linking strings | ||
__drivers__ = { | ||
"dummy": Dummy_driver, | ||
"harmonic": Harmonic_driver, | ||
"rascal": Rascal_driver, | ||
} | ||
import pkgutil | ||
import importlib | ||
|
||
__all__ = [] | ||
|
||
# Dictionary to store driver name to class mapping | ||
__drivers__ = {} | ||
|
||
# Iterate through all modules in the current package folder | ||
for loader, module_name, is_pkg in pkgutil.iter_modules(__path__): | ||
# Import the module | ||
module = importlib.import_module("." + module_name, __package__) | ||
|
||
# Get the driver class and name from the module | ||
driver_class = getattr(module, "__DRIVER_CLASS__", None) | ||
driver_name = getattr(module, "__DRIVER_NAME__", None) | ||
|
||
# If both class and name are defined, update __all__ and __drivers__ | ||
if driver_class and driver_name: | ||
__all__.append(driver_class) | ||
__drivers__[driver_name] = getattr(module, driver_class) | ||
globals()[driver_class] = getattr(module, driver_class) # add class to globals | ||
else: | ||
raise ImportError( | ||
f"PES module {module_name} does not define __DRIVER_CLASS__ and __DRIVER_NAME__" | ||
) | ||
|
||
__all__.append("__drivers__") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import numpy as np | ||
from ipi.utils.units import unit_to_internal, unit_to_user | ||
from .dummy import Dummy_driver | ||
|
||
from ase import Atoms | ||
from ase.io import read | ||
|
||
from mace.calculators import MACECalculator | ||
|
||
__DRIVER_NAME__ = ( | ||
"mace" # this is how the driver will be referred to in the input files | ||
) | ||
__DRIVER_CLASS__ = "MACE_driver" | ||
|
||
|
||
class MACE_driver(Dummy_driver): | ||
def __init__(self, args=None): | ||
self.error_msg = """Rascal driver requires specification of a .json model file fitted with librascal, | ||
and a template file that describes the chemical makeup of the structure. | ||
Example: python driver.py -m rascal -u -o model.json,template.xyz""" | ||
|
||
super().__init__(args) | ||
|
||
def check_arguments(self): | ||
"""Check the arguments requMACECalculatorred to run the driver | ||
This loads the potential and atoms template in librascal | ||
""" | ||
try: | ||
arglist = self.args.split(",") | ||
except ValueError: | ||
sys.exit(self.error_msg) | ||
|
||
self.model_atoms = read(arglist[0]) | ||
self.driver_example_atoms = arglist[0] | ||
self.driver_model_path = arglist[1] | ||
|
||
def __call__(self, cell, pos): | ||
"""Get energies, forces, and stresses from the librascal model""" | ||
pos_calc = unit_to_user("length", "angstrom", pos) | ||
cell_calc = unit_to_user("length", "angstrom", cell.T) | ||
|
||
atoms = read(self.driver_example_atoms) | ||
atoms.set_pbc([True, True, True]) | ||
atoms.set_cell(cell_calc, scale_atoms=True) | ||
atoms.set_positions(pos_calc) | ||
|
||
mace_calculator = MACECalculator( | ||
model_path=self.driver_model_path, device="cpu" | ||
) | ||
atoms.set_calculator(mace_calculator) | ||
|
||
pot = atoms.get_potential_energy() | ||
force = atoms.get_forces() | ||
stress = atoms.get_stress(voigt=False) | ||
|
||
pot_ipi = float(unit_to_internal("energy", "electronvolt", pot)) | ||
force_ipi = np.array( | ||
unit_to_internal("force", "ev/ang", force.reshape(-1, 3)), dtype=np.float64 | ||
) | ||
|
||
vir_calc = -stress * self.model_atoms.get_volume() | ||
vir_ipi = np.array( | ||
unit_to_internal("energy", "electronvolt", vir_calc.T), dtype=np.float64 | ||
) | ||
extras = "" | ||
|
||
return pot_ipi, force_ipi, vir_ipi, extras |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters