Skip to content

Commit

Permalink
Input parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
ceriottm committed Nov 11, 2024
1 parent d514555 commit 59130b9
Show file tree
Hide file tree
Showing 19 changed files with 216 additions and 203 deletions.
8 changes: 6 additions & 2 deletions drivers/py/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,14 @@ def run_driver(

args = parser.parse_args()

driver_args, driver_kwargs = parse_args_kwargs(args.param)

if args.mode in __drivers__:
d_f = __drivers__[args.mode](args.param, args.verbose)
d_f = __drivers__[args.mode](
*driver_args, verbose=args.verbose, **driver_kwargs
)
elif args.mode == "dummy":
d_f = Dummy_driver(args.param, args.verbose)
d_f = Dummy_driver(verbose=args.verbose)
else:
raise ValueError("Unsupported driver mode ", args.mode)

Expand Down
37 changes: 36 additions & 1 deletion ipi/pes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import importlib
import traceback

__all__ = []
__all__ = ["parse_args_kwargs"]

# Dictionary to store driver name to class mapping
__drivers__ = {}
Expand Down Expand Up @@ -33,4 +33,39 @@
f"PES module `{module_name}` does not define __DRIVER_CLASS__ and __DRIVER_NAME__"
)


def _parse_value(s):
"""Attempt to parse a string to int or float; fallback to string."""
s = s.strip()
for cast in (int, float):
try:
return cast(s)
except ValueError:
continue
return s


def parse_args_kwargs(input_str):
"""
Parses a string into positional arguments and keyword arguments.
Args:
input_str (str): The input string containing comma-separated values and key-value pairs.
Returns:
tuple: A tuple containing a list of positional arguments and a dictionary of keyword arguments.
"""
args = []
kwargs = {}
tokens = input_str.split(",")
for token in tokens:
token = token.strip()
if "=" in token:
key, value = token.split("=", 1)
kwargs[key.strip()] = _parse_value(value)
elif len(token) > 0:
args.append(_parse_value(token))
return args, kwargs


__all__.append("__drivers__")
42 changes: 26 additions & 16 deletions ipi/pes/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,8 @@
from ipi.utils.units import unit_to_internal, unit_to_user
from ipi.utils.messages import warning

try:
from ase.io import read
except ImportError:
warning("Could not find or import the ASE module")

read = None

__DRIVER_NAME__ = "ase"
__DRIVER_CLASS__ = "ASEDriver"
Expand All @@ -24,24 +22,36 @@


class ASEDriver(Dummy_driver):
"""Abstract base class using an arbitrary ASE calculator as i-pi driver"""

def __init__(self, args=None, verbose=False, error_msg=ERROR_MSG):
super().__init__(args, verbose, error_msg=error_msg)

def check_arguments(self):
"""Abstract base class using an arbitrary ASE calculator as i-pi driver.
Parameters:
:param verbose: bool, whether to print verbose output
:param template: string, where to get the structure from
"""

_error_msg = """
ASEDriver has two arguments:
verbose (a bool flag) and template (a string holding the name of an ASE-readable
structure to initialize the calculator)
"""

def __init__(self, template, *args, **kwargs):
global read
try:
from ase.io import read
except ImportError:
warning("Could not find or import the ASE module")

self.template = template
super().__init__(*args, **kwargs)

def check_parameters(self):
"""Check the arguments required to run the driver
This loads the potential and atoms template in metatensor
"""

if len(self.args) >= 1:
self.template = self.args[0]
else:
sys.exit(self.error_msg)

self.template_ase = read(self.template)

self.ase_calculator = None

def __call__(self, cell, pos):
Expand Down
5 changes: 4 additions & 1 deletion ipi/pes/bath.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
class Harmonic_Bath_explicit(object):
"""Explicit description of an Harmonic bath"""

def __init__(self, nbath, parameters):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# (self, nbath, parameters):
self.nbath = nbath

Check warning on line 16 in ipi/pes/bath.py

View workflow job for this annotation

GitHub Actions / lint

F821 undefined name 'nbath'
self.m = parameters["m"]

Check warning on line 17 in ipi/pes/bath.py

View workflow job for this annotation

GitHub Actions / lint

F821 undefined name 'parameters'
# self.delta = parameters["delta"]
Expand Down
15 changes: 8 additions & 7 deletions ipi/pes/doubledoublewell.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,17 @@ class DDW_with_explicit_bath_driver(Dummy_driver):
! DDW(q1,q2) = DW(q1) + DW(q2) + C(q1q2)^2
"""

def __init__(self, args=None, verbose=None):
self.error_msg = r"""\nDW+explicit_bath driver expects 11 arguments.\n
_error_msg = r"""\nDW+explicit_bath driver expects 11 arguments.\n
Example: python driver.py -m DoubleWell_with_explicit_bath -o wb1 (cm^-1) V1 (cm^-1) wb2 (cm^-1) V2 (cm^-1) coupling(au) mass delta(\AA) eta0 eps1 eps2 deltaQ omega_c(cm^-1) \n
python driver.py -m DoubleWell -o 500,2085,500,2085,0.1,1837,0.00,1,0,0,1,500\n"""
super(DDW_with_explicit_bath_driver, self).__init__(
args, error_msg=self.error_msg
)

def __init__(self, *args, **kwargs):

self.param = list(map(str, args))
self.init = False
super().__init__(*args, **kwargs)

def check_arguments(self):
def check_parameters(self):
"""Function that checks the arguments required to run the driver"""

try:
Expand All @@ -83,7 +84,7 @@ def check_arguments(self):

except:
print("Received arguments:")
sys.exit(self.error_msg)
sys.exit(self._error_msg)

self.A1 = -0.5 * self.m * (wb1) ** 2
self.B1 = ((self.m**2) * (wb1) ** 4) / (16 * v1)
Expand Down
11 changes: 6 additions & 5 deletions ipi/pes/doublewell.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@


class DoubleWell_driver(Dummy_driver):
def __init__(self, args=None, verbose=None):
self.error_msg = r"""\nDW driver accepts 0 or 4 arguments.\nExample: python driver.py -m DoubleWell -o omega_b (cm^-1) V0 (cm^-1) mass(a.u) delta(angs) \n
_error_msg = r"""\nDW driver accepts 0 or 4 arguments.\nExample: python driver.py -m DoubleWell -o omega_b (cm^-1) V0 (cm^-1) mass(a.u) delta(angs) \n
python driver.py -m DoubleWell -o 500,2085,1837,0.00 \n"""
super(DoubleWell_driver, self).__init__(args, error_msg=self.error_msg)

def check_arguments(self):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def check_parameters(self):
"""Function that checks the arguments required to run the driver"""
self.k = 1837.36223469 * (3800.0 / 219323.0) ** 2
if self.args == "":
Expand All @@ -51,7 +52,7 @@ def check_arguments(self):
m = param[2]
self.delta = param[3] * A2au
except:
sys.exit(self.error_msg)
sys.exit(self._error_msg)

self.A = -0.5 * m * (w_b) ** 2
self.B = ((m**2) * (w_b) ** 4) / (16 * v0)
Expand Down
13 changes: 6 additions & 7 deletions ipi/pes/doublewell_with_bath.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,17 +39,16 @@ class DoubleWell_with_explicit_bath_driver(Dummy_driver):
! If eps1=eps2=0 then sd(q) =1 and s(q) = q --->Spatially independent bath
"""

def __init__(self, args=None, verbose=None):
self.error_msg = r"""\nDW+explicit_bath driver expects 9 arguments.\n
_error_msg = r"""\nDW+explicit_bath driver expects 9 arguments.\n
Example: python driver.py -m DoubleWell_with_explicit_bath -o omega_b (cm^-1) V0 (cm^-1) mass delta(\AA) eta0 eps1 eps2 deltaQ omega_c(cm^-1) \n
python driver.py -m DoubleWell -o 500,2085,1837,0.00,1,0,0,1,500\n"""
super(DoubleWell_with_explicit_bath_driver, self).__init__(
args, error_msg=self.error_msg
)

def __init__(self, *args, **kwargs):

self.init = False
super().__init__(*args, **kwargs)

def check_arguments(self):
def check_parameters(self):
"""Function that checks the arguments required to run the driver"""

try:
Expand All @@ -70,7 +69,7 @@ def check_arguments(self):
self.bath_parameters["w_c"] = param[8] * invcm2au

except:
sys.exit(self.error_msg)
sys.exit(self._error_msg)

self.A = -0.5 * self.m * (w_b) ** 2
self.B = ((self.m**2) * (w_b) ** 4) / (16 * v0)
Expand Down
14 changes: 7 additions & 7 deletions ipi/pes/doublewell_with_friction.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,23 +29,23 @@


class DoubleWell_with_friction_driver(DoubleWell_driver):
"""Adds to the double well potential the calculation of the friction tensor.
r"""Adds to the double well potential the calculation of the friction tensor.
friction(q) = eta0 [\partial sd(q) \partial q ]^2
with
q = position, and
sd(q) = [1+eps1 exp( (q-0)^2 / (2deltaQ^2) ) ] + eps2 tanh(q/deltaQ)
"""

def __init__(self, args=None, verbose=None):
self.error_msg = r"""\nDW+fric driver expects 8 arguments.\n
_error_msg = r"""\nDW+fric driver expects 8 arguments.\n
Example: python driver.py -m DoubleWell_with_fric -o omega_b (cm^-1) V0 (cm^-1) mass delta(\AA) eta0 eps1 eps2 deltaQ \n
python driver.py -m DoubleWell -o 500,2085,1837,0.00,1,0,0,1\n"""

def __init__(self, *args, **kwargs):
self.args = args.split(",")
self.verbose = verbose
self.check_arguments()
super().__init__(*args, **kwargs)

def check_arguments(self):
def check_parameters(self):
"""Function that checks the arguments required to run the driver"""

self.k = 1837.36223469 * (3800.0 / 219323.0) ** 2
Expand All @@ -61,7 +61,7 @@ def check_arguments(self):
self.eps2 = param[6]
self.deltaQ = param[7]
except:
sys.exit(self.error_msg)
sys.exit(self._error_msg)

self.A = -0.5 * m * (w_b) ** 2
self.B = ((m**2) * (w_b) ** 4) / (16 * v0)
Expand Down
13 changes: 7 additions & 6 deletions ipi/pes/driverdipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,19 +131,20 @@ class driverdipole_driver(Dummy_driver):
"restart": False, # whether remove the files (if already existing) where the dipole and BEC will be saved.
}

def __init__(self, args=None):
self.error_msg = """The parameters of 'driverdipole_driver' are not correctly formatted. \
_error_msg = """The parameters of 'driverdipole_driver' are not correctly formatted. \
They should be two or three strings, separated by a comma."""

def __init__(self, *args, **kwargs):
self.opts = dict()
self.count = 0
super().__init__(args)
super().__init__(*args, **kwargs)

def check_arguments(self):
def check_parameters(self):
"""Check the arguments required to run the driver."""
try:
arglist = self.args.split(",")
except ValueError:
sys.exit(self.error_msg)
sys.exit(self._error_msg)

if len(arglist) >= 2:
info_file = arglist[0] # json file to properly allocate a 'model' object
Expand All @@ -154,7 +155,7 @@ def check_arguments(self):
print("\tNo options file provided: using the default values")
opts_file = None
else:
sys.exit(self.error_msg) # to be modified
sys.exit(self._error_msg) # to be modified

print("\n\tThe driver is 'driverdipole_driver'")
print("\tLoading model ...")
Expand Down
21 changes: 13 additions & 8 deletions ipi/pes/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,23 @@


class Dummy_driver(object):
"""Base class providing the structure of a PES for the python driver."""
"""Base class providing the structure of a PES for the python driver.
def __init__(
self, args="", verbose=False, error_msg="Invalid arguments for the PES"
):
Init arguments:
param verbose: bool to determine whether the PES should output verbose info.
"""

_error_msg = "Invalid arguments for the PES"

def __init__(self, verbose=False, *args, **kwargs):
"""Initialized dummy drivers"""
self.error_msg = error_msg
self.args = args.split(",")
self.verbose = verbose
self.check_arguments()
self.args = args
self.kwargs = kwargs

self.check_parameters()

def check_arguments(self):
def check_parameters(self):
"""Dummy function that checks the arguments required to run the driver"""
pass

Expand Down
2 changes: 1 addition & 1 deletion ipi/pes/elphmod.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
class ModelIIIDriver(Dummy_driver):
"""Wrapper around elphmod MD driver."""

def check_arguments(self):
def check_parameters(self):
"""Check arguments and load driver instance."""

import elphmod
Expand Down
25 changes: 11 additions & 14 deletions ipi/pes/harmonic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,27 +7,24 @@
__DRIVER_NAME__ = "harmonic"
__DRIVER_CLASS__ = "Harmonic_driver"

ERROR_MSG = """

class Harmonic_driver(Dummy_driver):

_error_msg = """
Harmonic driver requires specification of force constant.
Example: python driver.py -m harmonic -u -o 1.3
"""

def __init__(self, k1, k2=None, k3=None, *args, **kwargs):

class Harmonic_driver(Dummy_driver):
def __init__(self, args=None, verbose=False):
super(Harmonic_driver, self).__init__(args, verbose, error_msg=ERROR_MSG)

def check_arguments(self):
"""Function that checks the arguments required to run the driver"""

if len(self.args) == 1:
self.k = float(self.args[0])
if k2 == None or k3 == None:
self.k = k1
self.type = "isotropic"
elif len(self.args) == 3:
self.k = np.asarray(list(map(float, self.args)))
self.type = "non-isotropic"
else:
sys.exit(self.error_msg)
self.k = np.asarray([k1, k2, k3])
self.type = "non-isotropic"

super().__init__(*args, **kwargs)

def __call__(self, cell, pos):
"""Silly harmonic potential"""
Expand Down
Loading

0 comments on commit 59130b9

Please sign in to comment.