Skip to content

Commit

Permalink
Add dipoles test frames and enhance tests for magnetostatic Ewald cal…
Browse files Browse the repository at this point in the history
…culations
  • Loading branch information
E-Rum committed Jan 31, 2025
1 parent ec6d55e commit f90a036
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 14 deletions.
30 changes: 30 additions & 0 deletions examples/dipoles_test_frames.xyz
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
8
Lattice="8.460000038146973 0.0 0.0 0.0 8.460000038146973 0.0 0.0 0.0 8.460000038146973" Properties=species:S:1:pos:R:3:dipoles:R:3:forces:R:3 energy=1.5323538047012266 pbc="T T T"
Na 6.40999985 1.92000008 7.67000008 0.87830178 0.44712111 0.83810067 0.57283619 1.14204645 1.33681563
Na 4.18000031 4.96000004 8.02999973 0.73324216 0.97949808 -0.26166893 -0.03507047 0.24607122 1.13457851
Na 4.30999994 7.90000010 3.68999982 -0.31266377 -0.51245950 -0.18034898 -0.01967271 0.0084999 -0.36205072
Na 8.23999977 2.98000002 3.89999986 -0.12426795 0.51958488 0.66561609 -0.02508925 -0.11864547 -0.19522136
Cl 4.57999992 0.88000000 7.42999983 -0.41149415 0.34275195 0.52485166 -0.51002499 -1.5905667 -1.21172538
Cl 7.16000032 5.13000011 1.26999998 0.93048734 -0.30992820 0.32975380 -0.00883947 -0.11365123 -0.21224626
Cl 8.27999973 8.02999973 5.61999989 0.17311888 -0.83212290 -0.09616741 0.34479173 0.24309117 0.15871536
Cl 4.19000006 3.82999992 5.23000002 0.31179980 0.71290202 -0.87487990 -0.31893102 0.18315466 -0.64886577
8
Lattice="8.0 0.0 0.0 0.0 8.0 0.0 0.0 0.0 8.0" Properties=species:S:1:pos:R:3:dipoles:R:3:forces:R:3 energy=2.005822828064277 pbc="T T T"
Na 7.50000000 1.50999999 0.05000000 -0.85723323 0.56525094 -0.66685193 -1.46823136 1.56818289 1.1018063
Na 6.37000036 6.44000006 5.28999996 0.83409694 -0.89774275 0.74302336 1.85867637 1.43423158 -0.04219309
Na 1.15999997 2.56000018 3.88999987 -0.44331833 0.29179266 -0.94754402 0.06170197 -0.30920673 -0.13200583
Na 5.46000004 7.88000011 6.57999992 0.70186080 0.98230338 -0.06109690 -0.98041385 -1.39992847 1.33130465
Cl 3.61999989 5.91000032 3.55999994 -0.93308686 -0.88801660 -0.49452546 -0.49279598 -0.71824789 -0.67401955
Cl 6.17000008 4.02999973 2.26999998 0.10955457 0.48917852 0.99768683 -0.00349212 0.44927656 -0.29142061
Cl 4.38000011 0.11000000 0.82000005 -0.30364667 0.26808690 -0.68324974 -0.30267466 -0.5973787 -0.8010298
Cl 0.71000004 2.49000001 0.93999994 -0.56214647 0.24127451 0.28667338 1.32722962 -0.42692926 -0.49244207
8
Lattice="10.0 0.0 0.0 0.0 10.0 0.0 0.0 0.0 10.0" Properties=species:S:1:pos:R:3:dipoles:R:3:forces:R:3 energy=2.261072327985546 pbc="T T T"
Na 9.36999989 1.88999999 0.06000000 0.14120932 0.48341832 0.45537058 -2.12882348 -0.97937658 -0.75803323
Na 7.96000004 8.05000019 6.61999989 -0.45102672 -0.03732401 -0.74260234 0.54009361 -0.42119106 -0.60862247
Na 1.44999993 3.19999981 4.86999989 0.25118867 -0.29834069 -0.69578594 0.00429438 0.08823258 -0.29035364
Na 6.83000040 9.85000038 8.23000050 -0.44058746 0.57551755 -0.88325601 -0.0410207 0.05467702 0.84370787
Cl 4.53000021 7.37999964 4.46000004 0.96973183 -0.29250836 0.83184721 -0.02518924 -0.13902849 -0.24158766
Cl 7.71000004 5.03999996 2.84000015 0.24131227 -0.80615721 -0.57806687 -0.26878176 0.42816378 0.41324843
Cl 5.46999979 0.14000000 1.02999997 0.02369047 -0.50338626 0.70727541 -0.37825837 0.20451106 -0.10303815
Cl 0.88000000 3.10999990 1.18000007 -0.74464657 -0.57116991 -0.90829042 2.29768557 0.7640117 0.74467885
27 changes: 15 additions & 12 deletions src/torchpme/calculators/calculator_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch
from torch import profiler

from .._utils import _validate_parameters
from .._utils import _get_device, _get_dtype, _validate_parameters
from ..lib import generate_kvectors_for_ewald
from ..potentials import PotentialDipole

Expand All @@ -24,7 +24,7 @@ def __init__(

if not isinstance(potential, PotentialDipole):
raise TypeError(
f"Potential must be an instance of Potential, got {type(potential)}"
f"Potential must be an instance of PotentialDipole, got {type(potential)}"
)

self.potential = potential
Expand All @@ -36,8 +36,8 @@ def __init__(
or (self.lr_wavelength is None and self.potential.smearing is None)
), "Either both `lr_wavelength` and `smearing` must be set or both must be None"

self.device = torch.get_default_device() if device is None else device
self.dtype = torch.get_default_dtype() if dtype is None else dtype
self.device = _get_device(device)
self.dtype = _get_dtype(dtype)

if self.dtype != potential.dtype:
raise TypeError(
Expand Down Expand Up @@ -76,7 +76,9 @@ def _compute_rspace(
atom_is = neighbor_indices[:, 0]
atom_js = neighbor_indices[:, 1]
with profiler.record_function("compute real potential"):
contributions_is = torch.bmm(potentials_bare,dipoles[atom_js].unsqueeze(-1)).squeeze(-1)
contributions_is = torch.bmm(
potentials_bare, dipoles[atom_js].unsqueeze(-1)
).squeeze(-1)

# For each atom i, add up all contributions of the form q_j*V(r_ij) for j
# ranging over all of its neighbors.
Expand All @@ -86,7 +88,9 @@ def _compute_rspace(
# If we are using a half neighbor list, we need to add the contributions
# from the "inverse" pairs (j, i) to the atoms i
if not self.full_neighbor_list:
contributions_js = torch.bmm(potentials_bare,dipoles[atom_is].unsqueeze(-1)).squeeze(-1)
contributions_js = torch.bmm(
potentials_bare, dipoles[atom_is].unsqueeze(-1)
).squeeze(-1)
potential.index_add_(0, atom_js, contributions_js)

# Compensate for double counting of pairs (i,j) and (j,i)
Expand Down Expand Up @@ -124,15 +128,14 @@ def _compute_kspace(
c = torch.cos(trig_args) # [k, i]
s = torch.sin(trig_args) # [k, i]
sc = torch.stack([c, s], dim=0) # [2 "f", k, i]
mu_k = dipoles @ kvectors.T # [i, k]
print(mu_k)
mu_k = dipoles @ kvectors.T # [i, k]
sc_summed_G = torch.einsum("fki, ik, k->fk", sc, mu_k, G)
energy = torch.einsum(
"fk, fki, kc->ic", sc_summed_G, sc, kvectors
)
energy = torch.einsum("fk, fki, kc->ic", sc_summed_G, sc, kvectors)
energy /= torch.abs(cell.det())
energy -= dipoles * self.potential.self_contribution()
energy += self.potential.background_correction(torch.abs(cell.det())) * dipoles.sum(dim = 0)
energy += self.potential.background_correction(
torch.abs(cell.det())
) * dipoles.sum(dim=0)
return energy / 2

def forward(
Expand Down
6 changes: 4 additions & 2 deletions src/torchpme/potentials/potential_dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import torch

from .._utils import _get_device, _get_dtype
from .potential import Potential


Expand All @@ -17,8 +18,9 @@ def __init__(
device: Union[None, str, torch.device] = None,
):
super().__init__()
self.dtype = torch.get_default_dtype() if dtype is None else dtype
self.device = torch.get_default_device() if device is None else device

self.dtype = _get_dtype(dtype)
self.device = _get_device(device)
if smearing is not None:
self.register_buffer(
"smearing", torch.tensor(smearing, device=self.device, dtype=self.dtype)
Expand Down
1 change: 1 addition & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
DIR_PATH = Path(__file__).parent
EXAMPLES = DIR_PATH / ".." / "examples"
COULOMB_TEST_FRAMES = EXAMPLES / "coulomb_test_frames.xyz"
DIPOLES_TEST_FRAMES = EXAMPLES / "dipoles_test_frames.xyz"


def define_crystal(crystal_name="CsCl", dtype=None, device=None):
Expand Down
54 changes: 54 additions & 0 deletions tests/test_magnetostatics.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
import pytest
import torch
from ase.io import read
from helpers import DIPOLES_TEST_FRAMES
from vesin.torch import NeighborList

from torchpme.calculators import CalculatorDipole
from torchpme.potentials import PotentialDipole
from torchpme.prefactors import eV_A


class System:
Expand Down Expand Up @@ -96,3 +100,53 @@ def test_magnetostatic_ewald():
assert torch.isclose(result, expected_result, atol=1e-4), (
f"Expected {expected_result}, but got {result}"
)


frames = read(DIPOLES_TEST_FRAMES, ":3")
cutoffs = [3.9986718930, 4.0000000000, 4.7363281250]
alphas = [0.8819831493, 0.8956299559, 0.7215211182]
energies = [frame.get_potential_energy() for frame in frames]
forces = [frame.get_forces() for frame in frames]


@pytest.mark.parametrize(
("frame", "cutoff", "alpha", "energy", "force"),
zip(frames, cutoffs, alphas, energies, forces),
)
def test_magnetostatic_ewald_crystal(frame, cutoff, alpha, energy, force):
smearing = (1 / (2 * alpha**2)) ** 0.5
calc = CalculatorDipole(
potential=PotentialDipole(smearing=smearing, dtype=torch.float64),
full_neighbor_list=False,
lr_wavelength=0.1,
dtype=torch.float64,
prefactor=eV_A,
)
positions = torch.tensor(
frame.get_positions(), requires_grad=True, dtype=torch.float64
)
dipoles = torch.tensor(frame.get_array("dipoles"), dtype=torch.float64)
cell = torch.tensor(frame.get_cell().array, dtype=torch.float64)
calculator = NeighborList(cutoff=cutoff, full_list=False)
p, d = calculator.compute(
points=positions, box=cell, periodic=True, quantities="PD"
)
pot = calc(
dipoles=dipoles,
cell=cell,
positions=positions,
neighbor_indices=p,
neighbor_vectors=d,
)

result = torch.einsum("ij,ij->", pot, dipoles)
expected_result = torch.tensor(energy, dtype=torch.float64)
assert torch.isclose(result, expected_result, atol=1e-4), (
f"Expected {expected_result}, but got {result}"
)

forces = -torch.autograd.grad(result, positions)[0]
expected_forces = torch.tensor(force, dtype=torch.float64)
assert torch.allclose(forces, expected_forces, atol=1e-4), (
f"Expected {expected_forces}, but got {forces}"
)

0 comments on commit f90a036

Please sign in to comment.