Skip to content

Commit

Permalink
Add regtest without any dependency, using pre-generated model
Browse files Browse the repository at this point in the history
  • Loading branch information
Luthaf committed Apr 25, 2024
1 parent 9dccf6c commit 33546c6
Show file tree
Hide file tree
Showing 15 changed files with 266 additions and 1 deletion.
1 change: 1 addition & 0 deletions regtest/metatensor/rt-basic/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
include ../../scripts/test.make
6 changes: 6 additions & 0 deletions regtest/metatensor/rt-basic/config
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
plumed_modules=metatensor
plumed_needs=metatensor
type=driver

# NOTE: to enable --debug-forces, also change the dtype of the models to float64
arg="--plumed plumed.dat --ixyz structure.xyz --length-units A --dump-forces forces --dump-forces-fmt %8.2f" # --debug-forces forces.num"
164 changes: 164 additions & 0 deletions regtest/metatensor/rt-basic/cv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from typing import Dict, List, Optional

import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.atomistic import (
MetatensorAtomisticModel,
ModelCapabilities,
ModelMetadata,
ModelOutput,
NeighborsListOptions,
System,
)


class TestCollectiveVariable(torch.nn.Module):
r"""
This class computes a simple CV which is then used to test metatensor integration
with PLUMED.
The per-atom CV is defined as a sum over all pairs for an atom:
CV^1_i = \sum_j 1/r_ij
CV^2_i = \sum_j 1/r_ij^2
The global CV is a sum over all atoms of the per-atom CV:
CV^1 = \sum_i CV^1_i
CV^2 = \sum_i CV^2_i
If ``multiple_properties=True``, a only CV^1 is returned, otherwise both CV^1 and
CV^2 are returned.
"""

def __init__(self, cutoff, multiple_properties):
super().__init__()

self._nl_request = NeighborsListOptions(cutoff=cutoff, full_list=True)
self._multiple_properties = multiple_properties

def forward(
self,
systems: List[System],
outputs: Dict[str, ModelOutput],
selected_atoms: Optional[Labels],
) -> Dict[str, TensorMap]:

if "plumed::cv" not in outputs:
return {}

device = torch.device("cpu")
if len(systems) > 0:
device = systems[0].positions.device

if selected_atoms is not None:
raise ValueError("selected atoms is not supported")

output = outputs["plumed::cv"]

if output.per_atom:
samples_list: List[List[int]] = []
for s, system in enumerate(systems):
for i in range(len(system)):
samples_list.append([s, i])

samples = Labels(
["system", "atom"],
torch.tensor(samples_list, device=device),
)
else:
samples = Labels(
"system", torch.arange(len(systems), device=device).reshape(-1, 1)
)

if self._multiple_properties:
properties = Labels("cv", torch.tensor([[0], [1]], device=device))
else:
properties = Labels("cv", torch.tensor([[0]], device=device))

values = torch.zeros(
(len(samples), len(properties)), dtype=torch.float32, device=device
)
system_start = 0
for system_i, system in enumerate(systems):
system_stop = system_start + len(system)

neighbors = system.get_neighbors_list(self._nl_request)

atom_index = neighbors.samples.column("first_atom")
distances = torch.linalg.vector_norm(neighbors.values.reshape(-1, 3), dim=1)
inv_dist = 1.0 / distances

if distances.shape[0] != 0:
if output.per_atom:
sliced = values[system_start:system_stop, 0]
sliced += sliced.index_add(0, atom_index, inv_dist)
else:
values[system_i, 0] += inv_dist.sum()

if self._multiple_properties:
if output.per_atom:
sliced = values[system_start:system_stop, 1]
sliced += sliced.index_add(0, atom_index, inv_dist**2)
else:
values[system_i, 1] += inv_dist.sum() ** 2

system_start = system_stop

block = TensorBlock(
values=values,
samples=samples,
components=[],
properties=properties,
)
cv = TensorMap(
keys=Labels("_", torch.tensor([[0]], device=device)),
blocks=[block],
)

return {"plumed::cv": cv}

def requested_neighbors_lists(self) -> List[NeighborsListOptions]:
return [self._nl_request]


CUTOFF = 3.5

capabilities = ModelCapabilities(
outputs={"plumed::cv": ModelOutput(per_atom=True)},
interaction_range=CUTOFF,
supported_devices=["cpu", "mps", "cuda"],
length_unit="A",
atomic_types=[6],
dtype="float32",
)

# export all variations of the model
cv = TestCollectiveVariable(cutoff=CUTOFF, multiple_properties=False)
cv.eval()
model = MetatensorAtomisticModel(cv, ModelMetadata(), capabilities)
model.export("scalar-per-atom.pt")

cv = TestCollectiveVariable(cutoff=CUTOFF, multiple_properties=True)
cv.eval()
model = MetatensorAtomisticModel(cv, ModelMetadata(), capabilities)
model.export("vector-per-atom.pt")

capabilities = ModelCapabilities(
outputs={"plumed::cv": ModelOutput(per_atom=False)},
interaction_range=CUTOFF,
supported_devices=["cpu", "mps", "cuda"],
length_unit="A",
atomic_types=[6],
dtype="float32",
)

cv = TestCollectiveVariable(cutoff=CUTOFF, multiple_properties=False)
cv.eval()
model = MetatensorAtomisticModel(cv, ModelMetadata(), capabilities)
model.export("scalar-global.pt")

cv = TestCollectiveVariable(cutoff=CUTOFF, multiple_properties=True)
cv.eval()
model = MetatensorAtomisticModel(cv, ModelMetadata(), capabilities)
model.export("vector-global.pt")
11 changes: 11 additions & 0 deletions regtest/metatensor/rt-basic/forces.reference
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
9
-4809.71 -5697.44 -5071.73
X 147.68 -79.47 -178.64
X -174.52 -55.59 124.40
X -216.07 78.68 152.02
X -425.43 203.97 277.68
X 71.52 -44.53 35.34
X 231.25 -151.21 -440.05
X 141.26 -55.91 -67.49
X 208.93 -18.51 211.77
X 15.37 122.58 -115.02
64 changes: 64 additions & 0 deletions regtest/metatensor/rt-basic/plumed.dat
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
scalar_global: METATENSOR ...
MODEL=scalar-global.pt
DEVICE=cpu
# DEVICE=cuda
# DEVICE=mps

SPECIES1=1-9
SPECIES_TO_TYPES=6
...

PRINT ARG=scalar_global FILE=scalar_global FMT=%8.2f


scalar_per_atom: METATENSOR ...
MODEL=scalar-per-atom.pt
DEVICE=cpu
# DEVICE=cuda
# DEVICE=mps

SPECIES1=1-9
SPECIES_TO_TYPES=6
...

PRINT ARG=scalar_per_atom FILE=scalar_per_atom FMT=%8.2f


vector_global: METATENSOR ...
MODEL=vector-global.pt
DEVICE=cpu
# DEVICE=cuda
# DEVICE=mps

SPECIES1=1-9
SPECIES_TO_TYPES=6
...

PRINT ARG=vector_global FILE=vector_global FMT=%8.2f


vector_per_atom: METATENSOR ...
MODEL=vector-per-atom.pt
DEVICE=cpu
# DEVICE=cuda
# DEVICE=mps

SPECIES1=1-9
SPECIES_TO_TYPES=6
...

PRINT ARG=vector_per_atom FILE=vector_per_atom FMT=%8.2f


scalar_per_atom_sum: SUM ARG=scalar_per_atom PERIODIC=NO
vector_global_sum: SUM ARG=vector_global PERIODIC=NO
vector_per_atom_sum: SUM ARG=vector_per_atom PERIODIC=NO

summed: CUSTOM ...
ARG=scalar_global,scalar_per_atom_sum,vector_global_sum,vector_per_atom_sum
VAR=x,y,z,t
FUNC=x+y+z+t
PERIODIC=NO
...

BIASVALUE ARG=summed
Binary file added regtest/metatensor/rt-basic/scalar-global.pt
Binary file not shown.
Binary file added regtest/metatensor/rt-basic/scalar-per-atom.pt
Binary file not shown.
2 changes: 2 additions & 0 deletions regtest/metatensor/rt-basic/scalar_global.reference
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#! FIELDS time scalar_global
0.000000 87.04
2 changes: 2 additions & 0 deletions regtest/metatensor/rt-basic/scalar_per_atom.reference
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#! FIELDS time scalar_per_atom.1 scalar_per_atom.2 scalar_per_atom.3 scalar_per_atom.4 scalar_per_atom.5 scalar_per_atom.6 scalar_per_atom.7 scalar_per_atom.8 scalar_per_atom.9
0.000000 9.73 9.34 9.93 9.06 8.87 10.28 9.32 9.93 10.58
11 changes: 11 additions & 0 deletions regtest/metatensor/rt-basic/structure.xyz
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
9
4.0 4.0 4.0
C 2.3535 5.10858 4.48679
C 2.83322 3.38203 3.004
C 3.66083 1.24405 2.65283
C 3.81093 3.40544 1.76377
C 4.22474 0.188689 3.59337
C 4.29399 7.00701 1.25775
C 4.52552 4.72654 1.62947
C 4.63227 2.20046 2.00559
C 5.4412 1.75131 0.827934
Binary file added regtest/metatensor/rt-basic/vector-global.pt
Binary file not shown.
Binary file added regtest/metatensor/rt-basic/vector-per-atom.pt
Binary file not shown.
2 changes: 2 additions & 0 deletions regtest/metatensor/rt-basic/vector_global.reference
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#! FIELDS time vector_global.1 vector_global.2
0.000000 87.04 7575.59
2 changes: 2 additions & 0 deletions regtest/metatensor/rt-basic/vector_per_atom.reference
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#! FIELDS time vector_per_atom.1.1 vector_per_atom.1.2 vector_per_atom.2.1 vector_per_atom.2.2 vector_per_atom.3.1 vector_per_atom.3.2 vector_per_atom.4.1 vector_per_atom.4.2 vector_per_atom.5.1 vector_per_atom.5.2 vector_per_atom.6.1 vector_per_atom.6.2 vector_per_atom.7.1 vector_per_atom.7.2 vector_per_atom.8.1 vector_per_atom.8.2 vector_per_atom.9.1 vector_per_atom.9.2
0.000000 9.73 3.95 9.34 3.68 9.93 4.29 9.06 5.04 8.87 3.78 10.28 5.47 9.32 4.34 9.93 4.72 10.58 4.50
2 changes: 1 addition & 1 deletion src/metatensor/metatensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ MetatensorPlumedAction::MetatensorPlumedAction(const ActionOptions& options):
if (n_samples_ == 1 && n_properties_ == 1) {
log.printf(" the output of this model is a scalar\n");

this->addValue({this->n_samples_, this->n_properties_});
this->addValue();
} else if (n_samples_ == 1) {
log.printf(" the output of this model is 1x%d vector\n", n_properties_);

Expand Down

0 comments on commit 33546c6

Please sign in to comment.