Skip to content

Commit

Permalink
Implement selected_atoms
Browse files Browse the repository at this point in the history
  • Loading branch information
Luthaf committed May 22, 2024
1 parent 86ca791 commit 5e33e80
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 45 deletions.
14 changes: 14 additions & 0 deletions regtest/metatensor/rt-soap/plumed.dat
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,23 @@ soap: METATENSOR ...
SPECIES_TO_TYPES=6,1,8
...

soap_selected: METATENSOR ...
MODEL=soap_cv.pt
EXTENSIONS_DIRECTORY=extensions

SPECIES1=1-26
SPECIES2=27-62
SPECIES3=63-76
SPECIES_TO_TYPES=6,1,8

# select out of order to make sure this is respected in the output
SELECTED_ATOMS=2,3,1
...


scalar: SUM ARG=soap PERIODIC=NO
BIASVALUE ARG=scalar


PRINT ARG=soap FILE=soap_data STRIDE=1 FMT=%8.4f
PRINT ARG=soap_selected FILE=soap_selected_data STRIDE=1 FMT=%8.4f
2 changes: 2 additions & 0 deletions regtest/metatensor/rt-soap/soap_selected_data.reference
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#! FIELDS time soap_selected.1.1 soap_selected.1.2 soap_selected.1.3 soap_selected.2.1 soap_selected.2.2 soap_selected.2.3 soap_selected.3.1 soap_selected.3.2 soap_selected.3.3
0.000000 6.0785 6.3903 6.9409 5.2246 4.6212 5.9061 5.3739 5.3189 6.4924
183 changes: 138 additions & 45 deletions src/metatensor/metatensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,18 +35,11 @@ directory defines a custom machine learning CV that can be used with PLUMED.
\par Examples
The following input shows how you can call metatensor and evaluate the model
that is described in the file soap_cv.pt from PLUMED. To evaluate this model
plumed is required to use code that is included in the directory extensions
which has been specified using the `EXTENSIONS_DIRECTORY` flag. Numbered
`SPECIES` labels are used to indicate the list of indices that belong to each
atomic species in the model. The `SPECIES_TO_TYPE` keyword then provides
information on the atom type for each species. The first number here is the
atomic number of the atoms that have been specified using the `SPECIES1` flag,
the second number is the atomic number of the atoms that have been specified
using the `SPECIES2` flag and so on.
that is described in the file `custom_cv.pt` from PLUMED.
\plumedfile soap: METATENSOR ... MODEL=soap_cv.pt
EXTENSIONS_DIRECTORY=extensions
\plumedfile
metatensor_cv: METATENSOR ...
MODEL=custom_cv.pt
SPECIES1=1-26
SPECIES2=27-62
Expand All @@ -55,6 +48,47 @@ EXTENSIONS_DIRECTORY=extensions
...
\endplumedfile
The numbered `SPECIES` labels are used to indicate the list of atoms that belong
to each atomic species in the system. The `SPECIES_TO_TYPE` keyword then
provides information on the atom type for each species. The first number here is
the atomic type of the atoms that have been specified using the `SPECIES1` flag,
the second number is the atomic number of the atoms that have been specified
using the `SPECIES2` flag and so on.
`METATENSOR` action also accepts the following options:
- `EXTENSIONS_DIRECTORY` should be the path to a directory containing
TorchScript extensions (as shared libraries) that are required to load and
execute the model. This matches the `collect_extensions` argument to
`MetatensorAtomisticModel.export` in Python.
- `NO_CONSISTENCY_CHECK` can be used to disable internal consistency checks;
- `SELECTED_ATOMS` can be used to signal the metatensor models that it should
only run its calculation for the selected subset of atoms. The model still
need to know about all the atoms in the system (through the `SPECIES`
keyword); but this can be used to reduce the calculation cost. Note that the
indices of the selected atoms should start at 1 in the PLUMED input file, but
they will be translated to start at 0 when given to the model (i.e. in
Python/TorchScript, the `forward` method will receive a `selected_atoms` which
starts at 0)
Here is another example with all the possible keywords:
\plumedfile
soap: METATENSOR ...
MODEL=soap.pt
EXTENSION_DIRECTORY=extensions
NO_CONSISTENCY_CHECK
SPECIES1=1-10
SPECIES2=11-20
SPECIES_TO_TYPES=8,13
# only run the calculation for the Aluminium (type 13) atoms, but
# include the Oxygen (type 8) as potential neighbors.
SELECTED_ATOMS=11-20
...
\endplumedfile
\par Collective variables and metatensor models
Collective variables are not yet part of the [known outputs][mts_outputs] for
Expand Down Expand Up @@ -309,8 +343,6 @@ MetatensorPlumedAction::MetatensorPlumedAction(const ActionOptions& options):
}

this->atomic_types_ = torch::tensor(std::move(atomic_types));

// Request the atoms and check we have read in everything
this->requestAtoms(all_atoms);

bool no_consistency_check = false;
Expand Down Expand Up @@ -346,10 +378,6 @@ MetatensorPlumedAction::MetatensorPlumedAction(const ActionOptions& options):
output->explicit_gradients = {};
evaluations_options_->outputs.insert("plumed::cv", output);

// TODO: selected_atoms
// evaluations_options_->set_selected_atoms()


// Determine which device we should use based on user input, what the model
// supports and what's available
auto available_devices = std::vector<torch::Device>();
Expand Down Expand Up @@ -435,9 +463,8 @@ MetatensorPlumedAction::MetatensorPlumedAction(const ActionOptions& options):
auto tensor_options = torch::TensorOptions().dtype(this->dtype_).device(this->device_);
this->strain_ = torch::eye(3, tensor_options.requires_grad(true));

// setup storage for the computed CV: we need to run the model once to know
// the shape of the output, so we use a dummy system with one since atom for
// this
// determine how many properties there will be in the output by running the
// model once on a dummy system
auto dummy_system = torch::make_intrusive<metatensor_torch::SystemHolder>(
/*types = */ torch::zeros({0}, tensor_options.dtype(torch::kInt32)),
/*positions = */ torch::zeros({0, 3}, tensor_options),
Expand All @@ -461,16 +488,52 @@ MetatensorPlumedAction::MetatensorPlumedAction(const ActionOptions& options):
dummy_system->add_neighbor_list(request, neighbors);
}

this->n_properties_ = static_cast<unsigned>(
this->executeModel(dummy_system)->properties()->count()
);

// parse and handle atom sub-selection. This is done AFTER determining the
// output size, since the selection might not be valid for the dummy system
std::vector<int32_t> selected_atoms;
this->parseVector("SELECTED_ATOMS", selected_atoms);
if (!selected_atoms.empty()) {
auto selection_value = torch::zeros(
{static_cast<int64_t>(selected_atoms.size()), 2},
torch::TensorOptions().dtype(torch::kInt32).device(this->device_)
);

for (unsigned i=0; i<selected_atoms.size(); i++) {
auto n_atoms = static_cast<int32_t>(this->atomic_types_.size(0));
if (selected_atoms[i] <= 0 || selected_atoms[i] > n_atoms) {
this->error(
"Values in metatensor's SELECTED_ATOMS should be between 1 "
"and the number of atoms (" + std::to_string(n_atoms) + "), "
"got " + std::to_string(selected_atoms[i]));
}
// PLUMED input uses 1-based indexes, but metatensor wants 0-based
selection_value[i][1] = selected_atoms[i] - 1;
}

evaluations_options_->set_selected_atoms(
torch::make_intrusive<metatensor_torch::LabelsHolder>(
std::vector<std::string>{"system", "atom"}, selection_value
)
);
}

// Now that we now both n_samples and n_properties, we can setup the
// PLUMED-side storage for the computed CV
if (output->per_atom) {
this->n_samples_ = static_cast<unsigned>(this->atomic_types_.size(0));
auto selected_atoms = this->evaluations_options_->get_selected_atoms();
if (selected_atoms.has_value()) {
this->n_samples_ = static_cast<unsigned>(selected_atoms.value()->count());
} else {
this->n_samples_ = static_cast<unsigned>(this->atomic_types_.size(0));
}
} else {
this->n_samples_ = 1;
}

this->n_properties_ = static_cast<unsigned>(
this->executeModel(dummy_system)->properties()->count()
);

if (n_samples_ == 1 && n_properties_ == 1) {
log.printf(" the output of this model is a scalar\n");

Expand Down Expand Up @@ -511,6 +574,8 @@ void MetatensorPlumedAction::createSystem() {
plumed_merror(oss.str());
}

// this->getTotAtoms()

const auto& cell = this->getPbc().getBox();

auto cpu_f64_tensor = torch::TensorOptions().dtype(torch::kFloat64).device(torch::kCPU);
Expand Down Expand Up @@ -714,25 +779,39 @@ void MetatensorPlumedAction::calculate() {
}
}
} else {
auto samples = block->samples()->as_metatensor();
plumed_assert(samples.names().size() == 2);
plumed_assert(samples.names()[0] == std::string("system"));
plumed_assert(samples.names()[1] == std::string("atom"));

auto& samples_values = samples.values();
auto samples = block->samples();
plumed_assert((samples->names() == std::vector<std::string>{"system", "atom"}));

auto samples_values = samples->values().to(torch::kCPU);
auto selected_atoms = this->evaluations_options_->get_selected_atoms();

// handle the possibility that samples are returned in
// a non-sorted order.
auto get_output_location = [&](unsigned i) {
if (selected_atoms.has_value()) {
// If the users picked some selected atoms, then we store the
// output in the same order as the selection was given
auto sample = samples_values.index({static_cast<int64_t>(i), torch::indexing::Slice()});
auto position = selected_atoms.value()->position(sample);
plumed_assert(position.has_value());
return static_cast<unsigned>(position.value());
} else {
return static_cast<unsigned>(samples_values[i][1].item<int32_t>());
}
};

if (n_properties_ == 1) {
// we have a single CV describing multiple things (i.e. atoms)
for (unsigned i=0; i<n_samples_; i++) {
auto atom_i = static_cast<size_t>(samples_values(i, 1));
value->set(atom_i, torch_values[i][0].item<double>());
auto output_i = get_output_location(i);
value->set(output_i, torch_values[i][0].item<double>());
}
} else {
// the CV is a matrix
for (unsigned i=0; i<n_samples_; i++) {
auto atom_i = static_cast<size_t>(samples_values(i, 1));
auto output_i = get_output_location(i);
for (unsigned j=0; j<n_properties_; j++) {
value->set(atom_i * n_properties_ + j, torch_values[i][j].item<double>());
value->set(output_i * n_properties_ + j, torch_values[i][j].item<double>());
}
}
}
Expand All @@ -759,23 +838,34 @@ void MetatensorPlumedAction::apply() {
}
}
} else {
auto samples = block->samples()->as_metatensor();
plumed_assert(samples.names().size() == 2);
plumed_assert(samples.names()[0] == std::string("system"));
plumed_assert(samples.names()[1] == std::string("atom"));

auto& samples_values = samples.values();
auto samples = block->samples();
plumed_assert((samples->names() == std::vector<std::string>{"system", "atom"}));

auto samples_values = samples->values().to(torch::kCPU);
auto selected_atoms = this->evaluations_options_->get_selected_atoms();

// see above for an explanation of why we use this function
auto get_output_location = [&](unsigned i) {
if (selected_atoms.has_value()) {
auto sample = samples_values.index({static_cast<int64_t>(i), torch::indexing::Slice()});
auto position = selected_atoms.value()->position(sample);
plumed_assert(position.has_value());
return static_cast<unsigned>(position.value());
} else {
return static_cast<unsigned>(samples_values[i][1].item<int32_t>());
}
};

if (n_properties_ == 1) {
for (unsigned i=0; i<n_samples_; i++) {
auto atom_i = static_cast<size_t>(samples_values(i, 1));
output_grad[i][0] = value->getForce(atom_i);
auto output_i = get_output_location(i);
output_grad[i][0] = value->getForce(output_i);
}
} else {
for (unsigned i=0; i<n_samples_; i++) {
auto atom_i = static_cast<size_t>(samples_values(i, 1));
auto output_i = get_output_location(i);
for (unsigned j=0; j<n_properties_; j++) {
output_grad[i][j] = value->getForce(atom_i * n_properties_ + j);
output_grad[i][j] = value->getForce(output_i * n_properties_ + j);
}
}
}
Expand Down Expand Up @@ -842,6 +932,9 @@ namespace PLMD { namespace metatensor {
keys.add("numbered", "SPECIES", "the atoms in each PLUMED species");
keys.reset_style("SPECIES", "atoms");

keys.add("optional", "SELECTED_ATOMS", "subset of atoms that should be used for the calculation");
keys.reset_style("SELECTED_ATOMS", "atoms");

keys.add("optional", "SPECIES_TO_TYPES", "mapping from PLUMED SPECIES to metatensor's atomic types");
}

Expand Down

0 comments on commit 5e33e80

Please sign in to comment.