forked from plumed/plumed2
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added basic regression test for the CV based on metatensor that Ali a…
…nd co are using
- Loading branch information
Gareth Aneurin Tribello
authored and
Gareth Aneurin Tribello
committed
Apr 24, 2024
1 parent
e3369d4
commit 993d301
Showing
6 changed files
with
9,184 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
include ../../scripts/test.make |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
plumed_modules=metatensor | ||
plumed_needs=metatensor | ||
type=driver | ||
arg="--plumed plumed.dat --ixyz traj.xyz --length-units A" # --debug-forces forces.num | ||
|
||
function plumed_regtest_before(){ | ||
python soap_cv.py | ||
echo Generated model using soap_cv.py | ||
} |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
spex_A: METATENSOR ... | ||
MODEL=soap_cv.pt | ||
EXTENSIONS_DIRECTORY=extensions | ||
SPECIES=1-900:3 | ||
SPECIES_TO_TYPES=8 | ||
... | ||
|
||
# spex_A: SPHERICAL_INVARIANTS ... | ||
# SPECIES1=1-900:3 | ||
# HYPERPARAMS={ | ||
# "max_radial": 8, | ||
# "max_angular": 6, | ||
# "compute_gradients": false, | ||
# "normalize": false, | ||
# "soap_type": "PowerSpectrum", | ||
# "cutoff_function": {"type": "ShiftedCosine", "cutoff": {"value": 3.7, "unit": "AA"}, "smooth_width": {"value": 0.1, "unit": "AA"}}, | ||
# "gaussian_density": {"type": "Constant", "gaussian_sigma": {"value": 1.0, "unit": "AA"}}, | ||
# "radial_contribution": {"type": "GTO"} | ||
# } | ||
# ... | ||
|
||
spex_AT: TRANSPOSE ARG=spex_A | ||
|
||
ones: ONES SIZE=300 | ||
sum_A: MATRIX_VECTOR_PRODUCT ARG=spex_AT,ones | ||
num_O: SUM ARG=ones PERIODIC=NO #count number of oxygens | ||
av_soap: CUSTOM ARG=sum_A,num_O FUNC=x/y PERIODIC=NO #divide SUM_A by 300 to obtain the snapshot average soap descriptor | ||
av_soapT: TRANSPOSE ARG=av_soap | ||
|
||
PRINT ARG=spex_A STRIDE=1 FILE=final FMT=%8.4f | ||
PRINT ARG=av_soapT STRIDE=1 FILE=global_vec FMT=%8.4f |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
from typing import Dict, List, Optional | ||
|
||
import torch | ||
from metatensor.torch import Labels, TensorBlock, TensorMap | ||
from metatensor.torch.atomistic import ( | ||
MetatensorAtomisticModel, | ||
ModelCapabilities, | ||
ModelMetadata, | ||
ModelOutput, | ||
System, | ||
) | ||
from rascaline.torch import SoapPowerSpectrum | ||
|
||
|
||
class SOAP_CV(torch.nn.Module): | ||
def __init__(self, species): | ||
super().__init__() | ||
|
||
self.neighbor_type_pairs = Labels( | ||
names=["neighbor_1_type", "neighbor_2_type"], | ||
values=torch.tensor( | ||
[[t1, t2] for t1 in species for t2 in species if t1 <= t2] | ||
), | ||
) | ||
self.calculator = SoapPowerSpectrum( | ||
cutoff=0.37, | ||
max_angular=6, | ||
max_radial=8, | ||
radial_basis={"Gto": {}}, | ||
cutoff_function={"ShiftedCosine": {"width": 0.01}}, | ||
center_atom_weight=1.0, | ||
atomic_gaussian_width=0.1, | ||
) | ||
|
||
def forward( | ||
self, | ||
systems: List[System], | ||
outputs: Dict[str, ModelOutput], | ||
selected_atoms: Optional[Labels], | ||
) -> Dict[str, TensorMap]: | ||
|
||
if "plumed::cv" not in outputs: | ||
return {} | ||
|
||
output = outputs["plumed::cv"] | ||
|
||
device = torch.device("cpu") | ||
if len(systems) > 0: | ||
device = systems[0].positions.device | ||
|
||
soap = self.calculator(systems, selected_samples=selected_atoms) | ||
soap = soap.keys_to_samples("center_type") | ||
soap = soap.keys_to_properties(self.neighbor_type_pairs) | ||
|
||
if not output.per_atom: | ||
raise ValueError("per_atom=False is not supported") | ||
|
||
soap_block = soap.block() | ||
#projected = soap_block.values @ self.pca_projection | ||
|
||
samples = soap_block.samples.remove("center_type") | ||
|
||
block = TensorBlock( | ||
values=soap_block.values, | ||
samples=samples, | ||
components=[], | ||
properties=soap_block.properties, | ||
) | ||
cv = TensorMap( | ||
keys=Labels("_", torch.tensor([[0]], device=device)), | ||
blocks=[block], | ||
) | ||
|
||
return {"plumed::cv": cv} | ||
|
||
|
||
cv = SOAP_CV(species=[1]) | ||
cv.eval() | ||
|
||
|
||
capabilities = ModelCapabilities( | ||
outputs={ | ||
"plumed::cv": ModelOutput( | ||
quantity="", | ||
unit="", | ||
per_atom=True, | ||
explicit_gradients=["postions"], | ||
) | ||
}, | ||
interaction_range=0.37, | ||
supported_devices=["cpu", "mps", "cuda"], | ||
length_unit="nm", | ||
atomic_types=[8], | ||
dtype="float64", | ||
) | ||
|
||
metadata = ModelMetadata( | ||
name="Soap component retrieval", | ||
description=""" | ||
Retrieving all soap components for testing purposes | ||
""", | ||
authors=["Gareth Tribello"], | ||
references={ | ||
"implementation": ["ref to SOAP code"], | ||
"architecture": ["ref to SOAP"], | ||
"model": ["ref to paper"], | ||
}, | ||
) | ||
|
||
|
||
model = MetatensorAtomisticModel(cv, metadata, capabilities) | ||
model.export("soap_cv.pt", collect_extensions="extensions") |
Oops, something went wrong.