Skip to content

Commit

Permalink
output ForceFieldStructureTaskDocument or ForceFieldMoleculeTaskDocum…
Browse files Browse the repository at this point in the history
…ent based on the input type of mol_or_struct.

change name ForceFieldTaskDocument => ForceFieldStructureTaskDocument

output ForceFieldStructureTaskDocument or ForceFieldMoleculeTaskDocument based on type of mol_or_struct

update ForceFieldTaskDocument => ForceFieldStructureTaskDocument in the tests

import Union from typing

include Union in forcefield/md.py

take the suggestions from the formatter

take ruff's suggestions

try again with ruff format

ruff format again

try again ruff

ruff again

fix the mypy error

Take inputs of both Molecule and Structure

update docstring

add molecule test for forcefield
  • Loading branch information
yaoyi92 committed Feb 18, 2025
1 parent 4f4b607 commit 6e8c8e0
Show file tree
Hide file tree
Showing 7 changed files with 200 additions and 79 deletions.
21 changes: 0 additions & 21 deletions src/atomate2/ase/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,27 +233,6 @@ class AseStructureTaskDoc(StructureMetadata):

tags: Optional[list[str]] = Field(None, description="List of tags for the task.")

@classmethod
def from_ase_task_doc(
cls, ase_task_doc: AseTaskDoc, **task_document_kwargs
) -> AseStructureTaskDoc:
"""Create an AseStructureTaskDoc for a task that has ASE-compatible outputs.
Parameters
----------
ase_task_doc : AseTaskDoc
Task doc for the calculation
task_document_kwargs : dict
Additional keyword args passed to :obj:`.AseStructureTaskDoc()`.
"""
task_document_kwargs.update(
{k: getattr(ase_task_doc, k) for k in _task_doc_translation_keys},
structure=ase_task_doc.mol_or_struct,
)
return cls.from_structure(
meta_structure=ase_task_doc.mol_or_struct, **task_document_kwargs
)


class AseMoleculeTaskDoc(MoleculeMetadata):
"""Document containing information on molecule manipulation using ASE."""
Expand Down
57 changes: 30 additions & 27 deletions src/atomate2/forcefields/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,19 @@

from atomate2.ase.jobs import AseRelaxMaker
from atomate2.forcefields import MLFF, _get_formatted_ff_name
from atomate2.forcefields.schemas import ForceFieldTaskDocument
from atomate2.forcefields.schemas import (
ForceFieldMoleculeTaskDocument,
ForceFieldStructureTaskDocument,
ForceFieldTaskDocument,
)
from atomate2.forcefields.utils import ase_calculator, revert_default_dtype

if TYPE_CHECKING:
from collections.abc import Callable
from pathlib import Path

from ase.calculators.calculator import Calculator
from pymatgen.core.structure import Structure
from pymatgen.core.structure import Molecule, Structure

logger = logging.getLogger(__name__)

Expand All @@ -48,7 +52,8 @@ def forcefield_job(method: Callable) -> job:
This is a thin wrapper around :obj:`~jobflow.core.job.Job` that configures common
settings for all forcefield jobs. For example, it ensures that large data objects
(currently only trajectories) are all stored in the atomate2 data store.
It also configures the output schema to be a ForceFieldTaskDocument :obj:`.TaskDoc`.
It also configures the output schema to be a
ForceFieldStructureTaskDocument :obj:`.TaskDoc`.
Any makers that return forcefield jobs (not flows) should decorate the
``make`` method with @forcefield_job. For example:
Expand All @@ -72,9 +77,7 @@ def make(structure):
callable
A decorated version of the make function that will generate forcefield jobs.
"""
return job(
method, data=_FORCEFIELD_DATA_OBJECTS, output_schema=ForceFieldTaskDocument
)
return job(method, data=_FORCEFIELD_DATA_OBJECTS)


@dataclass
Expand Down Expand Up @@ -118,7 +121,7 @@ class ForceFieldRelaxMaker(AseRelaxMaker):
tags : list[str] or None
A list of tags for the task.
task_document_kwargs : dict (deprecated)
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
"""

name: str = "Force field relax"
Expand Down Expand Up @@ -146,15 +149,15 @@ def __post_init__(self) -> None:

@forcefield_job
def make(
self, structure: Structure, prev_dir: str | Path | None = None
) -> ForceFieldTaskDocument:
self, structure: Molecule | Structure, prev_dir: str | Path | None = None
) -> ForceFieldStructureTaskDocument | ForceFieldMoleculeTaskDocument:
"""
Perform a relaxation of a structure using a force field.
Parameters
----------
structure: .Structure
pymatgen structure.
structure: .Structure or Molecule
pymatgen structure or molecule.
prev_dir : str or Path or None
A previous calculation directory to copy output files from. Unused, just
added to match the method signature of other makers.
Expand All @@ -170,7 +173,7 @@ def make(
stacklevel=1,
)

return ForceFieldTaskDocument.from_ase_compatible_result(
return ForceFieldTaskDocument.from_ase_compatible_result_forcefield(
str(self.force_field_name), # make mypy happy
ase_result,
self.steps,
Expand Down Expand Up @@ -212,7 +215,7 @@ class ForceFieldStaticMaker(ForceFieldRelaxMaker):
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict (deprecated)
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
"""

name: str = "Force field static"
Expand Down Expand Up @@ -255,7 +258,7 @@ class CHGNetRelaxMaker(ForceFieldRelaxMaker):
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict (deprecated)
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
"""

name: str = f"{MLFF.CHGNet} relax"
Expand Down Expand Up @@ -291,7 +294,7 @@ class CHGNetStaticMaker(ForceFieldStaticMaker):
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict (deprecated)
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
"""

name: str = f"{MLFF.CHGNet} static"
Expand Down Expand Up @@ -334,7 +337,7 @@ class M3GNetRelaxMaker(ForceFieldRelaxMaker):
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict (deprecated)
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
"""

name: str = f"{MLFF.M3GNet} relax"
Expand Down Expand Up @@ -372,7 +375,7 @@ class M3GNetStaticMaker(ForceFieldStaticMaker):
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict (deprecated)
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
"""

name: str = f"{MLFF.M3GNet} static"
Expand Down Expand Up @@ -415,7 +418,7 @@ class NEPRelaxMaker(ForceFieldRelaxMaker):
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict (deprecated)
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
"""

name: str = f"{MLFF.NEP} relax"
Expand Down Expand Up @@ -451,7 +454,7 @@ class NEPStaticMaker(ForceFieldStaticMaker):
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict (deprecated)
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
"""

name: str = f"{MLFF.NEP} static"
Expand Down Expand Up @@ -494,7 +497,7 @@ class NequipRelaxMaker(ForceFieldRelaxMaker):
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict (deprecated)
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
"""

name: str = f"{MLFF.Nequip} relax"
Expand Down Expand Up @@ -529,7 +532,7 @@ class NequipStaticMaker(ForceFieldStaticMaker):
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict (deprecated)
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
"""

name: str = f"{MLFF.Nequip} static"
Expand Down Expand Up @@ -576,7 +579,7 @@ class MACERelaxMaker(ForceFieldRelaxMaker):
trained for Matbench Discovery on the MPtrj dataset available at
https://figshare.com/articles/dataset/22715158.
task_document_kwargs : dict (deprecated)
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
"""

name: str = f"{MLFF.MACE_MP_0} relax"
Expand Down Expand Up @@ -616,7 +619,7 @@ class MACEStaticMaker(ForceFieldStaticMaker):
trained for Matbench Discovery on the MPtrj dataset available at
https://figshare.com/articles/dataset/22715158.
task_document_kwargs : dict (deprecated)
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
"""

name: str = f"{MLFF.MACE_MP_0} static"
Expand Down Expand Up @@ -665,7 +668,7 @@ class SevenNetRelaxMaker(ForceFieldRelaxMaker):
trained for Matbench Discovery on the MPtrj dataset available at
https://figshare.com/articles/dataset/22715158.
task_document_kwargs : dict (deprecated)
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
"""

name: str = f"{MLFF.SevenNet} relax"
Expand Down Expand Up @@ -707,7 +710,7 @@ class SevenNetStaticMaker(ForceFieldStaticMaker):
trained for Matbench Discovery on the MPtrj dataset available at
https://figshare.com/articles/dataset/22715158.
task_document_kwargs : dict (deprecated)
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
"""

name: str = f"{MLFF.SevenNet} static"
Expand Down Expand Up @@ -747,7 +750,7 @@ class GAPRelaxMaker(ForceFieldRelaxMaker):
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict (deprecated)
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
"""

name: str = f"{MLFF.GAP} relax"
Expand Down Expand Up @@ -783,7 +786,7 @@ class GAPStaticMaker(ForceFieldStaticMaker):
calculator_kwargs : dict
Keyword arguments that will get passed to the ASE calculator.
task_document_kwargs : dict (deprecated)
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
Additional keyword args passed to :obj:`.ForceFieldStructureTaskDocument()`.
"""

name: str = f"{MLFF.GAP} static"
Expand Down
17 changes: 10 additions & 7 deletions src/atomate2/forcefields/md.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,18 @@
_DEFAULT_CALCULATOR_KWARGS,
_FORCEFIELD_DATA_OBJECTS,
)
from atomate2.forcefields.schemas import ForceFieldTaskDocument
from atomate2.forcefields.schemas import (
ForceFieldMoleculeTaskDocument,
ForceFieldStructureTaskDocument,
ForceFieldTaskDocument,
)
from atomate2.forcefields.utils import ase_calculator, revert_default_dtype

if TYPE_CHECKING:
from pathlib import Path

from ase.calculators.calculator import Calculator
from pymatgen.core.structure import Structure
from pymatgen.core.structure import Molecule, Structure


@dataclass
Expand Down Expand Up @@ -126,19 +130,18 @@ def __post_init__(self) -> None:

@job(
data=[*_FORCEFIELD_DATA_OBJECTS, "ionic_steps"],
output_schema=ForceFieldTaskDocument,
)
def make(
self,
structure: Structure,
structure: Molecule | Structure,
prev_dir: str | Path | None = None,
) -> ForceFieldTaskDocument:
) -> ForceFieldStructureTaskDocument | ForceFieldMoleculeTaskDocument:
"""
Perform MD on a structure using forcefields and jobflow.
Parameters
----------
structure: .Structure
structure: .Structure or Molecule
pymatgen structure.
prev_dir : str or Path or None
A previous calculation directory to copy output files from. Unused, just
Expand All @@ -156,7 +159,7 @@ def make(
stacklevel=1,
)

return ForceFieldTaskDocument.from_ase_compatible_result(
return ForceFieldTaskDocument.from_ase_compatible_result_forcefield(
str(self.force_field_name), # make mypy happy
md_result,
relax_cell=(self.ensemble == MDEnsemble.npt),
Expand Down
Loading

0 comments on commit 6e8c8e0

Please sign in to comment.