Skip to content

Commit

Permalink
Added a PET driver
Browse files Browse the repository at this point in the history
  • Loading branch information
ceriottm committed Nov 14, 2023
1 parent fa0a756 commit 3861466
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 13 deletions.
2 changes: 1 addition & 1 deletion drivers/py/pes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
globals()[driver_class] = getattr(module, driver_class) # add class to globals
else:
raise ImportError(
f"PES module {module_name} does not define __DRIVER_CLASS__ and __DRIVER_NAME__"
f"PES module `{module_name}` does not define __DRIVER_CLASS__ and __DRIVER_NAME__"
)

__all__.append("__drivers__")
5 changes: 2 additions & 3 deletions drivers/py/pes/dummy.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
__DRIVER_NAME__ = (
"dummy" # this is how the driver will be referred to in the input files
)
# this is how the driver will be referred to in the input files
__DRIVER_NAME__ = "dummy"
__DRIVER_CLASS__ = "Dummy_driver"


Expand Down
4 changes: 1 addition & 3 deletions drivers/py/pes/harmonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
import sys
from .dummy import Dummy_driver

__DRIVER_NAME__ = (
"harmonic" # this is how the driver will be referred to in the input files
)
__DRIVER_NAME__ = "harmonic"
__DRIVER_CLASS__ = "Harmonic_driver"


Expand Down
4 changes: 1 addition & 3 deletions drivers/py/pes/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,7 @@

from mace.calculators import MACECalculator

__DRIVER_NAME__ = (
"mace" # this is how the driver will be referred to in the input files
)
__DRIVER_NAME__ = "mace"
__DRIVER_CLASS__ = "MACE_driver"


Expand Down
76 changes: 76 additions & 0 deletions drivers/py/pes/pet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Interface with librascal to run machine learning potentials"""

import sys, os
import numpy as np
from .dummy import Dummy_driver

from ipi.utils.mathtools import det_ut3x3
from ipi.utils.units import unit_to_internal, unit_to_user

try:
from ase.io import read
except:
raise ImportError("The PET driver has an ASE dependency")

try:
sys.path.append(os.getcwd()+'/pet/src')
from single_struct_calculator import SingleStructCalculator as PETCalc
except:
PETCalc = None

__DRIVER_NAME__ = "pet"
__DRIVER_CLASS__ = "PET_driver"

class PET_driver(Dummy_driver):
def __init__(self, args=None):
self.error_msg = """
The PET driver requires specification of a .json model file fitted with librascal,
and a template file that describes the chemical makeup of the structure.
Example: python driver.py -m pet -u -o model.json,template.xyz
"""

super().__init__(args)

if PETCalc is None:
raise ImportError("Couldn't load PET bindings")

def check_arguments(self):
"""Check the arguments required to run the driver
This loads the potential and atoms template in librascal
"""
try:
arglist = self.args.split(",")
except ValueError:
sys.exit(self.error_msg)

if len(arglist) == 2:
self.model = arglist[0]
self.template = arglist[1]
else:
sys.exit(self.error_msg)

self.template_ase = read(self.template)
self.template_ase.arrays['forces']=np.zeros_like(self.template_ase.positions)
self.pet_calc = PETCalc(self.model)

def __call__(self, cell, pos):
"""Get energies, forces, and stresses from the PET model"""
pos_pet = unit_to_user("length", "angstrom", pos)
# librascal expects ASE-format, cell-vectors-as-rows
cell_pet = unit_to_user("length", "angstrom", cell.T)
# applies the cell and positions to the template
pet_structure = self.template_ase.copy()
pet_structure.positions = pos_pet
pet_structure.cell = cell_pet

# Do the actual calculation
pot, force = self.pet_calc.forward(pet_structure)
pot_ipi = np.asarray(unit_to_internal("energy", "electronvolt", pot), np.float64)
force_ipi = np.asarray(unit_to_internal("force", "ev/ang", force), np.float64)
# PET does not yet compute stress
vir_pet = 0*np.eye(3)
vir_ipi = unit_to_internal("energy", "electronvolt", vir_pet.T)
extras = ""
return pot_ipi, force_ipi, vir_ipi, extras
4 changes: 1 addition & 3 deletions drivers/py/pes/rascal.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
except:
RascalCalc = None

__DRIVER_NAME__ = (
"rascal" # this is how the driver will be referred to in the input files
)
__DRIVER_NAME__ = "rascal"
__DRIVER_CLASS__ = "Rascal_driver"


Expand Down

0 comments on commit 3861466

Please sign in to comment.