-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit fef7226
Showing
4 changed files
with
166 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
# -*- coding: utf-8 -*- | ||
"""Tools for interfacing with `ASE`_. | ||
.. _ASE: | ||
https://wiki.fysik.dtu.dk/ase | ||
""" | ||
|
||
import torch | ||
import ase.calculators.calculator | ||
|
||
|
||
def map2central(cell, coordinates, pbc): | ||
"""Map atoms outside the unit cell into the cell using PBC. | ||
Arguments: | ||
cell (:class:`torch.Tensor`): tensor of shape (3, 3) of the three | ||
vectors defining unit cell: | ||
.. code-block:: python | ||
tensor([[x1, y1, z1], | ||
[x2, y2, z2], | ||
[x3, y3, z3]]) | ||
coordinates (:class:`torch.Tensor`): Tensor of shape | ||
``(molecules, atoms, 3)``. | ||
pbc (:class:`torch.Tensor`): boolean vector of size 3 storing | ||
if pbc is enabled for that direction. | ||
Returns: | ||
:class:`torch.Tensor`: coordinates of atoms mapped back to unit cell. | ||
""" | ||
# Step 1: convert coordinates from standard cartesian coordinate to unit | ||
# cell coordinates | ||
inv_cell = torch.inverse(cell) | ||
coordinates_cell = torch.matmul(coordinates, inv_cell) | ||
# Step 2: wrap cell coordinates into [0, 1) | ||
coordinates_cell -= coordinates_cell.floor() * pbc.to(coordinates_cell.dtype) | ||
# Step 3: convert from cell coordinates back to standard cartesian | ||
# coordinate | ||
return torch.matmul(coordinates_cell, cell) | ||
|
||
|
||
class Calculator(ase.calculators.calculator.Calculator): | ||
"""ASE Calculator that wraps a neural network potential | ||
Arguments: | ||
model (:class:`torch.nn.Module`): neural network potential model | ||
that convert coordinates into energies. | ||
overwrite (bool): After wrapping atoms into central box, whether | ||
to replace the original positions stored in :class:`ase.Atoms` | ||
object with the wrapped positions. | ||
""" | ||
|
||
implemented_properties = ['energy', 'forces', 'stress', 'free_energy'] | ||
|
||
def __init__(self, model, overwrite=False): | ||
super(Calculator, self).__init__() | ||
self.model = model | ||
self.overwrite = overwrite | ||
|
||
def calculate(self, atoms=None, properties=['energy'], | ||
system_changes=ase.calculators.calculator.all_changes): | ||
super(Calculator, self).calculate(atoms, properties, system_changes) | ||
cell = torch.from_numpy(self.atoms.get_cell(complete=True)) | ||
pbc = torch.tensor(self.atoms.get_pbc(), dtype=torch.bool) | ||
coordinates = torch.from_numpy(self.atoms.get_positions()).requires_grad_('forces' in properties) | ||
pbc_enabled = pbc.any().item() | ||
|
||
if pbc_enabled: | ||
coordinates = map2central(cell, coordinates, pbc) | ||
if self.overwrite and atoms is not None: | ||
atoms.set_positions(coordinates.detach().cpu().numpy()) | ||
|
||
if 'stress' in properties: | ||
scaling = torch.eye(3, requires_grad=True, dtype=self.dtype, device=self.device) | ||
coordinates = coordinates @ scaling | ||
cell = cell @ scaling | ||
energy = self.model(atoms.get_chemical_symbols(), coordinates, cell, pbc) | ||
|
||
self.results['energy'] = energy.item() | ||
self.results['free_energy'] = energy.item() | ||
|
||
if 'forces' in properties: | ||
forces = -torch.autograd.grad(energy, coordinates)[0] | ||
self.results['forces'] = forces.cpu().numpy() | ||
|
||
if 'stress' in properties: | ||
volume = self.atoms.get_volume() | ||
stress = torch.autograd.grad(energy.squeeze(), scaling)[0] / volume | ||
self.results['stress'] = stress.cpu().numpy() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import torch | ||
|
||
|
||
def hessian(coordinates, energies=None, forces=None): | ||
"""Compute analytical hessian from the energy graph or force graph. | ||
Arguments: | ||
coordinates (:class:`torch.Tensor`): Tensor of shape `(molecules, atoms, 3)` | ||
energies (:class:`torch.Tensor`): Tensor of shape `(molecules,)`, if specified, | ||
then `forces` must be `None`. This energies must be computed from | ||
`coordinates` in a graph. | ||
forces (:class:`torch.Tensor`): Tensor of shape `(molecules, atoms, 3)`, if specified, | ||
then `energies` must be `None`. This forces must be computed from | ||
`coordinates` in a graph. | ||
Returns: | ||
:class:`torch.Tensor`: Tensor of shape `(molecules, 3A, 3A)` where A is the number of | ||
atoms in each molecule | ||
""" | ||
if energies is None and forces is None: | ||
raise ValueError('Energies or forces must be specified') | ||
if energies is not None and forces is not None: | ||
raise ValueError('Energies or forces can not be specified at the same time') | ||
if forces is None: | ||
forces = -torch.autograd.grad(energies.sum(), coordinates, create_graph=True)[0] | ||
flattened_force = forces.flatten(start_dim=1) | ||
force_components = flattened_force.unbind(dim=1) | ||
return -torch.stack([ | ||
torch.autograd.grad(f.sum(), coordinates, retain_graph=True)[0].flatten(start_dim=1) | ||
for f in force_components | ||
], dim=1) | ||
|
||
|
||
class FreqsModes(NamedTuple): | ||
freqs: Tensor | ||
modes: Tensor | ||
|
||
|
||
def vibrational_analysis(masses, hessian, unit='cm^-1'): | ||
"""Computing the vibrational wavenumbers from hessian.""" | ||
if unit != 'cm^-1': | ||
raise ValueError('Only cm^-1 are supported right now') | ||
assert hessian.shape[0] == 1, 'Currently only supporting computing one molecule a time' | ||
# Solving the eigenvalue problem: Hq = w^2 * T q | ||
# where H is the Hessian matrix, q is the normal coordinates, | ||
# T = diag(m1, m1, m1, m2, m2, m2, ....) is the mass | ||
# We solve this eigenvalue problem through Lowdin diagnolization: | ||
# Hq = w^2 * Tq ==> Hq = w^2 * T^(1/2) T^(1/2) q | ||
# Letting q' = T^(1/2) q, we then have | ||
# T^(-1/2) H T^(-1/2) q' = w^2 * q' | ||
inv_sqrt_mass = (1 / masses.sqrt()).repeat_interleave(3, dim=1) # shape (molecule, 3 * atoms) | ||
mass_scaled_hessian = hessian * inv_sqrt_mass.unsqueeze(1) * inv_sqrt_mass.unsqueeze(2) | ||
if mass_scaled_hessian.shape[0] != 1: | ||
raise ValueError('The input should contain only one molecule') | ||
mass_scaled_hessian = mass_scaled_hessian.squeeze(0) | ||
eigenvalues, eigenvectors = torch.symeig(mass_scaled_hessian, eigenvectors=True) | ||
angular_frequencies = eigenvalues.sqrt() | ||
frequencies = angular_frequencies / (2 * math.pi) | ||
# converting from sqrt(hartree / (amu * angstrom^2)) to cm^-1 | ||
wavenumbers = frequencies * 17092 | ||
modes = (eigenvectors.t() * inv_sqrt_mass).reshape(frequencies.numel(), -1, 3) | ||
return FreqsModes(wavenumbers, modes) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from setuptools import setup, find_packages | ||
|
||
with open("README.md", "r") as fh: | ||
long_description = fh.read() | ||
|
||
setup( | ||
name='nnp', | ||
description='Common tools for PyTorch based of neural network potentials', | ||
long_description=long_description, | ||
long_description_content_type="text/markdown", | ||
url='https://github.com/aiqm/nnp', | ||
author='Xiang Gao', | ||
author_email='[email protected]', | ||
license='MIT', | ||
packages=find_packages(), | ||
include_package_data=True, | ||
use_scm_version=True, | ||
setup_requires=['setuptools_scm'], | ||
install_requires=[ 'torch' ], | ||
) |