From fa0a756ee68821132bcedba97a61f30fe2621fff Mon Sep 17 00:00:00 2001 From: Michele Ceriotti Date: Tue, 14 Nov 2023 12:43:30 -0800 Subject: [PATCH] i-pi driver now automatically imports all modules within pes/ --- drivers/py/driver.py | 4 +++ drivers/py/pes/__init__.py | 40 +++++++++++++++------- drivers/py/pes/dummy.py | 8 +++++ drivers/py/pes/harmonic.py | 5 +++ drivers/py/pes/mace.py | 68 ++++++++++++++++++++++++++++++++++++++ drivers/py/pes/rascal.py | 5 +++ 6 files changed, 118 insertions(+), 12 deletions(-) create mode 100644 drivers/py/pes/mace.py diff --git a/drivers/py/driver.py b/drivers/py/driver.py index 14b5ef014..bd6d565aa 100755 --- a/drivers/py/driver.py +++ b/drivers/py/driver.py @@ -3,6 +3,10 @@ import argparse import numpy as np +from pes.dummy import Dummy_driver +from pes import Dummy_driver +from pes import * + try: from pes import * except ImportError: diff --git a/drivers/py/pes/__init__.py b/drivers/py/pes/__init__.py index bd391184a..79306da43 100644 --- a/drivers/py/pes/__init__.py +++ b/drivers/py/pes/__init__.py @@ -1,14 +1,30 @@ """ Small functions/classes providing access to driver PES to be called from driver.py """ -from .dummy import Dummy_driver -from .harmonic import Harmonic_driver -from .rascal import Rascal_driver - -__all__ = ["__drivers__", "Dummy_driver", "Harmonic_driver", "Rascal_driver"] - -# dictionary linking strings -__drivers__ = { - "dummy": Dummy_driver, - "harmonic": Harmonic_driver, - "rascal": Rascal_driver, -} +import pkgutil +import importlib + +__all__ = [] + +# Dictionary to store driver name to class mapping +__drivers__ = {} + +# Iterate through all modules in the current package folder +for loader, module_name, is_pkg in pkgutil.iter_modules(__path__): + # Import the module + module = importlib.import_module("." + module_name, __package__) + + # Get the driver class and name from the module + driver_class = getattr(module, "__DRIVER_CLASS__", None) + driver_name = getattr(module, "__DRIVER_NAME__", None) + + # If both class and name are defined, update __all__ and __drivers__ + if driver_class and driver_name: + __all__.append(driver_class) + __drivers__[driver_name] = getattr(module, driver_class) + globals()[driver_class] = getattr(module, driver_class) # add class to globals + else: + raise ImportError( + f"PES module {module_name} does not define __DRIVER_CLASS__ and __DRIVER_NAME__" + ) + +__all__.append("__drivers__") diff --git a/drivers/py/pes/dummy.py b/drivers/py/pes/dummy.py index 2b49f678e..2a1213c83 100644 --- a/drivers/py/pes/dummy.py +++ b/drivers/py/pes/dummy.py @@ -1,4 +1,12 @@ +__DRIVER_NAME__ = ( + "dummy" # this is how the driver will be referred to in the input files +) +__DRIVER_CLASS__ = "Dummy_driver" + + class Dummy_driver(object): + """A dummy class providing the structure of an PES for the python driver.""" + def __init__(self, args=None): """Initialized dummy drivers""" self.args = args diff --git a/drivers/py/pes/harmonic.py b/drivers/py/pes/harmonic.py index 3c2271b9d..0ad16648d 100644 --- a/drivers/py/pes/harmonic.py +++ b/drivers/py/pes/harmonic.py @@ -3,6 +3,11 @@ import sys from .dummy import Dummy_driver +__DRIVER_NAME__ = ( + "harmonic" # this is how the driver will be referred to in the input files +) +__DRIVER_CLASS__ = "Harmonic_driver" + class Harmonic_driver(Dummy_driver): def __init__(self, args=None): diff --git a/drivers/py/pes/mace.py b/drivers/py/pes/mace.py new file mode 100644 index 000000000..b84a1286a --- /dev/null +++ b/drivers/py/pes/mace.py @@ -0,0 +1,68 @@ +import numpy as np +from ipi.utils.units import unit_to_internal, unit_to_user +from .dummy import Dummy_driver + +from ase import Atoms +from ase.io import read + +from mace.calculators import MACECalculator + +__DRIVER_NAME__ = ( + "mace" # this is how the driver will be referred to in the input files +) +__DRIVER_CLASS__ = "MACE_driver" + + +class MACE_driver(Dummy_driver): + def __init__(self, args=None): + self.error_msg = """Rascal driver requires specification of a .json model file fitted with librascal, + and a template file that describes the chemical makeup of the structure. + Example: python driver.py -m rascal -u -o model.json,template.xyz""" + + super().__init__(args) + + def check_arguments(self): + """Check the arguments requMACECalculatorred to run the driver + + This loads the potential and atoms template in librascal + """ + try: + arglist = self.args.split(",") + except ValueError: + sys.exit(self.error_msg) + + self.model_atoms = read(arglist[0]) + self.driver_example_atoms = arglist[0] + self.driver_model_path = arglist[1] + + def __call__(self, cell, pos): + """Get energies, forces, and stresses from the librascal model""" + pos_calc = unit_to_user("length", "angstrom", pos) + cell_calc = unit_to_user("length", "angstrom", cell.T) + + atoms = read(self.driver_example_atoms) + atoms.set_pbc([True, True, True]) + atoms.set_cell(cell_calc, scale_atoms=True) + atoms.set_positions(pos_calc) + + mace_calculator = MACECalculator( + model_path=self.driver_model_path, device="cpu" + ) + atoms.set_calculator(mace_calculator) + + pot = atoms.get_potential_energy() + force = atoms.get_forces() + stress = atoms.get_stress(voigt=False) + + pot_ipi = float(unit_to_internal("energy", "electronvolt", pot)) + force_ipi = np.array( + unit_to_internal("force", "ev/ang", force.reshape(-1, 3)), dtype=np.float64 + ) + + vir_calc = -stress * self.model_atoms.get_volume() + vir_ipi = np.array( + unit_to_internal("energy", "electronvolt", vir_calc.T), dtype=np.float64 + ) + extras = "" + + return pot_ipi, force_ipi, vir_ipi, extras diff --git a/drivers/py/pes/rascal.py b/drivers/py/pes/rascal.py index bd4b6c208..d21b726b6 100644 --- a/drivers/py/pes/rascal.py +++ b/drivers/py/pes/rascal.py @@ -11,6 +11,11 @@ except: RascalCalc = None +__DRIVER_NAME__ = ( + "rascal" # this is how the driver will be referred to in the input files +) +__DRIVER_CLASS__ = "Rascal_driver" + class Rascal_driver(Dummy_driver): def __init__(self, args=None):