Skip to content

Commit

Permalink
i-pi driver now automatically imports all modules within pes/
Browse files Browse the repository at this point in the history
  • Loading branch information
ceriottm committed Nov 14, 2023
1 parent c6d1ec6 commit fa0a756
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 12 deletions.
4 changes: 4 additions & 0 deletions drivers/py/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import argparse
import numpy as np

from pes.dummy import Dummy_driver
from pes import Dummy_driver
from pes import *

try:
from pes import *
except ImportError:
Expand Down
40 changes: 28 additions & 12 deletions drivers/py/pes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,30 @@
""" Small functions/classes providing access to driver PES to be called from driver.py """

from .dummy import Dummy_driver
from .harmonic import Harmonic_driver
from .rascal import Rascal_driver

__all__ = ["__drivers__", "Dummy_driver", "Harmonic_driver", "Rascal_driver"]

# dictionary linking strings
__drivers__ = {
"dummy": Dummy_driver,
"harmonic": Harmonic_driver,
"rascal": Rascal_driver,
}
import pkgutil
import importlib

__all__ = []

# Dictionary to store driver name to class mapping
__drivers__ = {}

# Iterate through all modules in the current package folder
for loader, module_name, is_pkg in pkgutil.iter_modules(__path__):
# Import the module
module = importlib.import_module("." + module_name, __package__)

# Get the driver class and name from the module
driver_class = getattr(module, "__DRIVER_CLASS__", None)
driver_name = getattr(module, "__DRIVER_NAME__", None)

# If both class and name are defined, update __all__ and __drivers__
if driver_class and driver_name:
__all__.append(driver_class)
__drivers__[driver_name] = getattr(module, driver_class)
globals()[driver_class] = getattr(module, driver_class) # add class to globals
else:
raise ImportError(
f"PES module {module_name} does not define __DRIVER_CLASS__ and __DRIVER_NAME__"
)

__all__.append("__drivers__")
8 changes: 8 additions & 0 deletions drivers/py/pes/dummy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
__DRIVER_NAME__ = (
"dummy" # this is how the driver will be referred to in the input files
)
__DRIVER_CLASS__ = "Dummy_driver"


class Dummy_driver(object):
"""A dummy class providing the structure of an PES for the python driver."""

def __init__(self, args=None):
"""Initialized dummy drivers"""
self.args = args
Expand Down
5 changes: 5 additions & 0 deletions drivers/py/pes/harmonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
import sys
from .dummy import Dummy_driver

__DRIVER_NAME__ = (
"harmonic" # this is how the driver will be referred to in the input files
)
__DRIVER_CLASS__ = "Harmonic_driver"


class Harmonic_driver(Dummy_driver):
def __init__(self, args=None):
Expand Down
68 changes: 68 additions & 0 deletions drivers/py/pes/mace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import numpy as np
from ipi.utils.units import unit_to_internal, unit_to_user
from .dummy import Dummy_driver

from ase import Atoms
from ase.io import read

from mace.calculators import MACECalculator

__DRIVER_NAME__ = (
"mace" # this is how the driver will be referred to in the input files
)
__DRIVER_CLASS__ = "MACE_driver"


class MACE_driver(Dummy_driver):
def __init__(self, args=None):
self.error_msg = """Rascal driver requires specification of a .json model file fitted with librascal,
and a template file that describes the chemical makeup of the structure.
Example: python driver.py -m rascal -u -o model.json,template.xyz"""

super().__init__(args)

def check_arguments(self):
"""Check the arguments requMACECalculatorred to run the driver
This loads the potential and atoms template in librascal
"""
try:
arglist = self.args.split(",")
except ValueError:
sys.exit(self.error_msg)

self.model_atoms = read(arglist[0])
self.driver_example_atoms = arglist[0]
self.driver_model_path = arglist[1]

def __call__(self, cell, pos):
"""Get energies, forces, and stresses from the librascal model"""
pos_calc = unit_to_user("length", "angstrom", pos)
cell_calc = unit_to_user("length", "angstrom", cell.T)

atoms = read(self.driver_example_atoms)
atoms.set_pbc([True, True, True])
atoms.set_cell(cell_calc, scale_atoms=True)
atoms.set_positions(pos_calc)

mace_calculator = MACECalculator(
model_path=self.driver_model_path, device="cpu"
)
atoms.set_calculator(mace_calculator)

pot = atoms.get_potential_energy()
force = atoms.get_forces()
stress = atoms.get_stress(voigt=False)

pot_ipi = float(unit_to_internal("energy", "electronvolt", pot))
force_ipi = np.array(
unit_to_internal("force", "ev/ang", force.reshape(-1, 3)), dtype=np.float64
)

vir_calc = -stress * self.model_atoms.get_volume()
vir_ipi = np.array(
unit_to_internal("energy", "electronvolt", vir_calc.T), dtype=np.float64
)
extras = ""

return pot_ipi, force_ipi, vir_ipi, extras
5 changes: 5 additions & 0 deletions drivers/py/pes/rascal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,11 @@
except:
RascalCalc = None

__DRIVER_NAME__ = (
"rascal" # this is how the driver will be referred to in the input files
)
__DRIVER_CLASS__ = "Rascal_driver"


class Rascal_driver(Dummy_driver):
def __init__(self, args=None):
Expand Down

0 comments on commit fa0a756

Please sign in to comment.