diff --git a/src/autoplex/auto/rss/jobs.py b/src/autoplex/auto/rss/jobs.py index 60ed679f5..6bc224eb0 100644 --- a/src/autoplex/auto/rss/jobs.py +++ b/src/autoplex/auto/rss/jobs.py @@ -1,6 +1,7 @@ """RSS Jobs include the generation of the initial potential model as well as iterative RSS exploration.""" import logging +from typing import Literal from jobflow import Flow, Response, job @@ -54,7 +55,7 @@ def initial_rss( force_max: float | None = None, force_label: str = "REF_forces", pre_database_dir: str | None = None, - mlip_type: str = "GAP", + mlip_type: Literal["GAP", "J-ACE", "NEP", "NEQUIP", "M3GNET", "MACE"] = "GAP", ref_energy_name: str = "REF_energy", ref_force_name: str = "REF_forces", ref_virial_name: str = "REF_virial", @@ -137,9 +138,8 @@ def initial_rss( The label of force values to use for distillation. Default is 'REF_forces'. pre_database_dir: str | None Directory where the previous database was saved. Default is None. - mlip_type: str - Choose one specific MLIP type to be fitted: 'GAP' | 'J-ACE' | 'NEQUIP' | 'M3GNET' | 'MACE'. - Default is 'GAP'. + mlip_type: Literal["GAP", "J-ACE", "NEP", "NEQUIP", "M3GNET", "MACE"] + Choose one specific MLIP type to be fitted. Default is 'GAP'. ref_energy_name: str Reference energy name. Default is 'REF_energy'. ref_force_name: str @@ -286,7 +286,7 @@ def do_rss_iterations( distillation: bool = True, force_max: float = 200, force_label: str = "REF_forces", - mlip_type: str = "GAP", + mlip_type: Literal["GAP", "J-ACE", "NEP", "NEQUIP", "M3GNET", "MACE"] = "GAP", ref_energy_name: str = "REF_energy", ref_force_name: str = "REF_forces", ref_virial_name: str = "REF_virial", @@ -409,8 +409,8 @@ def do_rss_iterations( Maximum force value to exclude structures. Default is 200. force_label: str The label of force values to use for distillation. Default is 'REF_forces'. - mlip_type: str - Choose one specific MLIP type: 'GAP' | 'J-ACE' | 'NequIP' | 'M3GNet' | 'MACE'. Default is 'GAP'. + mlip_type: Literal["GAP", "J-ACE", "NEP", "NEQUIP", "M3GNET", "MACE"] + Choose one specific MLIP type to be fitted. Default is 'GAP'. ref_energy_name: str Reference energy name. Default is 'REF_energy'. ref_force_name: str diff --git a/src/autoplex/data/rss/jobs.py b/src/autoplex/data/rss/jobs.py index b4a1e853c..42ec3468a 100644 --- a/src/autoplex/data/rss/jobs.py +++ b/src/autoplex/data/rss/jobs.py @@ -7,6 +7,7 @@ from pathlib import Path from shutil import which from subprocess import run +from typing import Literal import ase.io import numpy as np @@ -413,7 +414,7 @@ def _parallel_process( @job def do_rss_single_node( - mlip_type: str, + mlip_type: Literal["GAP", "J-ACE", "NEP", "NEQUIP", "M3GNET", "MACE"], mlip_path: str, iteration_index: str, structures: list[Structure], @@ -441,9 +442,8 @@ def do_rss_single_node( Parameters ---------- - mlip_type: str - Choose one specific MLIP type: - 'GAP' | 'J-ACE' | 'NequIP' | 'M3GNet' | 'MACE'. + mlip_type: Literal["GAP", "J-ACE", "NEP", "NEQUIP", "M3GNET", "MACE"] + Choose one specific MLIP type to be fitted. mlip_path: str Path to the MLIP model. iteration_index: str @@ -521,7 +521,7 @@ def do_rss_single_node( @job def do_rss_multi_node( - mlip_type: str, + mlip_type: Literal["GAP", "J-ACE", "NEP", "NEQUIP", "M3GNET", "MACE"], mlip_path: str, iteration_index: str, structure: list[Structure] | list[list[Structure]] | None = None, @@ -550,9 +550,8 @@ def do_rss_multi_node( Parameters ---------- - mlip_type: str - Choose one specific MLIP type: - 'GAP' | 'J-ACE' | 'NequIP' | 'M3GNet' | 'MACE'. + mlip_type: Literal["GAP", "J-ACE", "NEP", "NEQUIP", "M3GNET", "MACE"] + Choose one specific MLIP type to be fitted. mlip_path: str Path to the MLIP model. iteration_index: str diff --git a/src/autoplex/data/rss/utils.py b/src/autoplex/data/rss/utils.py index d0101f44b..d2f99a019 100644 --- a/src/autoplex/data/rss/utils.py +++ b/src/autoplex/data/rss/utils.py @@ -553,7 +553,7 @@ def build_traj(): def minimize_structures( - mlip_type: str, + mlip_type: Literal["GAP", "J-ACE", "NEP", "NEQUIP", "M3GNET", "MACE"], mlip_path: str, iteration_index: str, structures: list[Structure], @@ -580,9 +580,8 @@ def minimize_structures( Parameters ---------- - mlip_type: str - Choose one specific MLIP type: - 'GAP' | 'J-ACE' | 'NequIP' | 'M3GNet' | 'MACE'. + mlip_type: Literal["GAP", "J-ACE", "NEP", "NEQUIP", "M3GNET", "MACE"] + Choose one specific MLIP type to be fitted. mlip_path: str Path to the MLIP model. iteration_index: str diff --git a/src/autoplex/fitting/common/flows.py b/src/autoplex/fitting/common/flows.py index a17fc444b..9cbca388d 100644 --- a/src/autoplex/fitting/common/flows.py +++ b/src/autoplex/fitting/common/flows.py @@ -40,9 +40,8 @@ class MLIPFitMaker(Maker): ---------- name : str Name of the flows produced by this maker. - mlip_type: str - Choose one specific MLIP type to be fitted: - 'GAP' | 'J-ACE' | 'NEQUIP' | 'M3GNET' | 'MACE' + mlip_type: Literal["GAP", "J-ACE", "NEP", "NEQUIP", "M3GNET", "MACE"] + Choose one specific MLIP type to be fitted. hyperpara_opt: bool Perform hyperparameter optimization using XPOT (XPOT: https://pubs.aip.org/aip/jcp/article/159/2/024803/2901815)