Handle out of order samples
Luthaf committed Apr 16, 2024
1 parent 514164b commit 893fe99
Showing 1 changed file with 30 additions and 13 deletions.
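Context for the change: the per-atom output block returned by the model carries ("system", "atom") sample labels, and its rows are not guaranteed to follow atom order. The diff below therefore reads the atom index out of column 1 of the samples instead of assuming that row i corresponds to atom i. A minimal standalone sketch of that remapping (plain C++ with hypothetical names, not the plugin code itself):

// remap_sketch.cpp: write model output rows, identified by ("system", "atom")
// sample labels that may arrive in any order, into per-atom storage.
// Hypothetical standalone example, not part of the PLUMED metatensor module.
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // one (system, atom) pair per output row, deliberately out of atom order
    std::vector<std::array<int64_t, 2>> samples = {{0, 2}, {0, 0}, {0, 1}};
    // one value per row (n_properties == 1 in this sketch)
    std::vector<double> values = {30.0, 10.0, 20.0};

    std::vector<double> per_atom(samples.size(), 0.0);
    for (std::size_t i = 0; i < samples.size(); i++) {
        auto atom_i = static_cast<std::size_t>(samples[i][1]);  // column 1 is the atom index
        per_atom[atom_i] = values[i];                            // store in atom order
    }

    for (std::size_t a = 0; a < per_atom.size(); a++) {
        std::printf("atom %zu -> %g\n", a, per_atom[a]);
    }
    return 0;
}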
src/metatensor/metatensor.cpp (30 additions, 13 deletions)
@@ -79,7 +79,7 @@ class MetatensorPlumedAction: public ActionAtomistic, public ActionWithValue {
     );

     // execute the model for the given system
-    torch::Tensor executeModel(metatensor_torch::System system);
+    metatensor_torch::TorchTensorBlock executeModel(metatensor_torch::System system);

     torch::jit::Module model_;

@@ -349,7 +349,9 @@ MetatensorPlumedAction::MetatensorPlumedAction(const ActionOptions& options):
         this->n_samples_ = 1;
     }

-    this->n_properties_ = static_cast<unsigned>(this->executeModel(dummy_system).size(1));
+    this->n_properties_ = static_cast<unsigned>(
+        this->executeModel(dummy_system)->properties()->count()
+    );

     if (n_samples_ == 1 && n_properties_ == 1) {
         log.printf(" the output of this model is a scalar\n");
@@ -531,7 +533,7 @@ metatensor_torch::TorchTensorBlock MetatensorPlumedAction::computeNeighbors(
     return neighbors;
 }

-torch::Tensor MetatensorPlumedAction::executeModel(metatensor_torch::System system) {
+metatensor_torch::TorchTensorBlock MetatensorPlumedAction::executeModel(metatensor_torch::System system) {
     try {
         auto ivalue_output = this->model_.forward({
             std::vector<metatensor_torch::System>{system},
@@ -550,14 +552,15 @@ torch::Tensor MetatensorPlumedAction::executeModel(metatensor_torch::System syst
     auto block = metatensor_torch::TensorMapHolder::block_by_id(this->output_, 0);
     plumed_massert(block->components().empty(), "components are not yet supported in the output");

-    return block->values().to(torch::kCPU).to(torch::kFloat64);
+    return block;
 }


 void MetatensorPlumedAction::calculate() {
     this->createSystem();

-    auto torch_values = executeModel(this->system_);
+    auto block = this->executeModel(this->system_);
+    auto torch_values = block->values().to(torch::kCPU).to(torch::kFloat64);

     if (static_cast<unsigned>(torch_values.size(0)) != this->n_samples_) {
         this->error(
Expand Down Expand Up @@ -585,18 +588,25 @@ void MetatensorPlumedAction::calculate() {
}
}
} else {
auto samples = block->samples()->as_metatensor();
plumed_assert(samples.names().size() == 2);
plumed_assert(samples.names()[0] == std::string("system"));
plumed_assert(samples.names()[1] == std::string("atom"));

auto& samples_values = samples.values();

if (n_properties_ == 1) {
// we have a single CV describing multiple things (i.e. atoms)
for (unsigned i=0; i<n_samples_; i++) {
// TODO: check sample order
value->set(i, torch_values[i][0].item<double>());
auto atom_i = static_cast<size_t>(samples_values(i, 1));
value->set(atom_i, torch_values[i][0].item<double>());
}
} else {
// the CV is a matrix
for (unsigned i=0; i<n_samples_; i++) {
// TODO: check sample order
auto atom_i = static_cast<size_t>(samples_values(i, 1));
for (unsigned j=0; j<n_properties_; j++) {
value->set(i * n_properties_ + j, torch_values[i][j].item<double>());
value->set(atom_i * n_properties_ + j, torch_values[i][j].item<double>());
}
}
}
@@ -623,16 +633,23 @@ void MetatensorPlumedAction::apply() {
             }
         }
     } else {
+        auto samples = block->samples()->as_metatensor();
+        plumed_assert(samples.names().size() == 2);
+        plumed_assert(samples.names()[0] == std::string("system"));
+        plumed_assert(samples.names()[1] == std::string("atom"));
+
+        auto& samples_values = samples.values();
+
         if (n_properties_ == 1) {
-            // TODO: check sample order?
             for (unsigned i=0; i<n_samples_; i++) {
-                output_grad[i][0] = value->getForce(i);
+                auto atom_i = static_cast<size_t>(samples_values(i, 1));
+                output_grad[i][0] = value->getForce(atom_i);
             }
         } else {
-            // TODO: check sample order?
             for (unsigned i=0; i<n_samples_; i++) {
+                auto atom_i = static_cast<size_t>(samples_values(i, 1));
                 for (unsigned j=0; j<n_properties_; j++) {
-                    output_grad[i][j] = value->getForce(i * n_properties_ + j);
+                    output_grad[i][j] = value->getForce(atom_i * n_properties_ + j);
                 }
             }
         }
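For the force path in apply(), the same labels are used in the opposite direction: PLUMED stores forces per atom, while the gradient tensor handed back to the model must be laid out per sample row. A minimal sketch of that gather, again with hypothetical standalone names rather than the plugin's own types:

// gather_forces_sketch.cpp: collect per-atom forces into the model's
// sample-row order. Hypothetical standalone example, not the plugin code.
#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<double> gather_forces(
    const std::vector<std::array<int64_t, 2>>& samples,  // ("system", "atom") per output row
    const std::vector<double>& force_per_atom,            // flattened atom-major: atom * n_properties + property
    std::size_t n_properties
) {
    std::vector<double> output_grad(samples.size() * n_properties, 0.0);
    for (std::size_t i = 0; i < samples.size(); i++) {
        auto atom_i = static_cast<std::size_t>(samples[i][1]);
        for (std::size_t j = 0; j < n_properties; j++) {
            // mirror the diff's indexing: look up the force by atom, store by sample row
            output_grad[i * n_properties + j] = force_per_atom[atom_i * n_properties + j];
        }
    }
    return output_grad;
}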
