diff --git a/examples/dipoles_test_frames.xyz b/examples/dipoles_test_frames.xyz new file mode 100644 index 00000000..9cdf0576 --- /dev/null +++ b/examples/dipoles_test_frames.xyz @@ -0,0 +1,30 @@ +8 +Lattice="8.460000038146973 0.0 0.0 0.0 8.460000038146973 0.0 0.0 0.0 8.460000038146973" Properties=species:S:1:pos:R:3:dipoles:R:3:forces:R:3 energy=1.5323538047012266 pbc="T T T" +Na 6.40999985 1.92000008 7.67000008 0.87830178 0.44712111 0.83810067 0.57283619 1.14204645 1.33681563 +Na 4.18000031 4.96000004 8.02999973 0.73324216 0.97949808 -0.26166893 -0.03507047 0.24607122 1.13457851 +Na 4.30999994 7.90000010 3.68999982 -0.31266377 -0.51245950 -0.18034898 -0.01967271 0.0084999 -0.36205072 +Na 8.23999977 2.98000002 3.89999986 -0.12426795 0.51958488 0.66561609 -0.02508925 -0.11864547 -0.19522136 +Cl 4.57999992 0.88000000 7.42999983 -0.41149415 0.34275195 0.52485166 -0.51002499 -1.5905667 -1.21172538 +Cl 7.16000032 5.13000011 1.26999998 0.93048734 -0.30992820 0.32975380 -0.00883947 -0.11365123 -0.21224626 +Cl 8.27999973 8.02999973 5.61999989 0.17311888 -0.83212290 -0.09616741 0.34479173 0.24309117 0.15871536 +Cl 4.19000006 3.82999992 5.23000002 0.31179980 0.71290202 -0.87487990 -0.31893102 0.18315466 -0.64886577 +8 +Lattice="8.0 0.0 0.0 0.0 8.0 0.0 0.0 0.0 8.0" Properties=species:S:1:pos:R:3:dipoles:R:3:forces:R:3 energy=2.005822828064277 pbc="T T T" +Na 7.50000000 1.50999999 0.05000000 -0.85723323 0.56525094 -0.66685193 -1.46823136 1.56818289 1.1018063 +Na 6.37000036 6.44000006 5.28999996 0.83409694 -0.89774275 0.74302336 1.85867637 1.43423158 -0.04219309 +Na 1.15999997 2.56000018 3.88999987 -0.44331833 0.29179266 -0.94754402 0.06170197 -0.30920673 -0.13200583 +Na 5.46000004 7.88000011 6.57999992 0.70186080 0.98230338 -0.06109690 -0.98041385 -1.39992847 1.33130465 +Cl 3.61999989 5.91000032 3.55999994 -0.93308686 -0.88801660 -0.49452546 -0.49279598 -0.71824789 -0.67401955 +Cl 6.17000008 4.02999973 2.26999998 0.10955457 0.48917852 0.99768683 -0.00349212 0.44927656 -0.29142061 +Cl 4.38000011 0.11000000 0.82000005 -0.30364667 0.26808690 -0.68324974 -0.30267466 -0.5973787 -0.8010298 +Cl 0.71000004 2.49000001 0.93999994 -0.56214647 0.24127451 0.28667338 1.32722962 -0.42692926 -0.49244207 +8 +Lattice="10.0 0.0 0.0 0.0 10.0 0.0 0.0 0.0 10.0" Properties=species:S:1:pos:R:3:dipoles:R:3:forces:R:3 energy=2.261072327985546 pbc="T T T" +Na 9.36999989 1.88999999 0.06000000 0.14120932 0.48341832 0.45537058 -2.12882348 -0.97937658 -0.75803323 +Na 7.96000004 8.05000019 6.61999989 -0.45102672 -0.03732401 -0.74260234 0.54009361 -0.42119106 -0.60862247 +Na 1.44999993 3.19999981 4.86999989 0.25118867 -0.29834069 -0.69578594 0.00429438 0.08823258 -0.29035364 +Na 6.83000040 9.85000038 8.23000050 -0.44058746 0.57551755 -0.88325601 -0.0410207 0.05467702 0.84370787 +Cl 4.53000021 7.37999964 4.46000004 0.96973183 -0.29250836 0.83184721 -0.02518924 -0.13902849 -0.24158766 +Cl 7.71000004 5.03999996 2.84000015 0.24131227 -0.80615721 -0.57806687 -0.26878176 0.42816378 0.41324843 +Cl 5.46999979 0.14000000 1.02999997 0.02369047 -0.50338626 0.70727541 -0.37825837 0.20451106 -0.10303815 +Cl 0.88000000 3.10999990 1.18000007 -0.74464657 -0.57116991 -0.90829042 2.29768557 0.7640117 0.74467885 \ No newline at end of file diff --git a/src/torchpme/calculators/calculator_dipole.py b/src/torchpme/calculators/calculator_dipole.py index f127ece3..377bdedd 100644 --- a/src/torchpme/calculators/calculator_dipole.py +++ b/src/torchpme/calculators/calculator_dipole.py @@ -3,7 +3,7 @@ import torch from torch import profiler -from .._utils import _validate_parameters +from .._utils import _get_device, _get_dtype, _validate_parameters from ..lib import generate_kvectors_for_ewald from ..potentials import PotentialDipole @@ -24,7 +24,7 @@ def __init__( if not isinstance(potential, PotentialDipole): raise TypeError( - f"Potential must be an instance of Potential, got {type(potential)}" + f"Potential must be an instance of PotentialDipole, got {type(potential)}" ) self.potential = potential @@ -36,8 +36,8 @@ def __init__( or (self.lr_wavelength is None and self.potential.smearing is None) ), "Either both `lr_wavelength` and `smearing` must be set or both must be None" - self.device = torch.get_default_device() if device is None else device - self.dtype = torch.get_default_dtype() if dtype is None else dtype + self.device = _get_device(device) + self.dtype = _get_dtype(dtype) if self.dtype != potential.dtype: raise TypeError( @@ -76,7 +76,9 @@ def _compute_rspace( atom_is = neighbor_indices[:, 0] atom_js = neighbor_indices[:, 1] with profiler.record_function("compute real potential"): - contributions_is = torch.bmm(potentials_bare,dipoles[atom_js].unsqueeze(-1)).squeeze(-1) + contributions_is = torch.bmm( + potentials_bare, dipoles[atom_js].unsqueeze(-1) + ).squeeze(-1) # For each atom i, add up all contributions of the form q_j*V(r_ij) for j # ranging over all of its neighbors. @@ -86,7 +88,9 @@ def _compute_rspace( # If we are using a half neighbor list, we need to add the contributions # from the "inverse" pairs (j, i) to the atoms i if not self.full_neighbor_list: - contributions_js = torch.bmm(potentials_bare,dipoles[atom_is].unsqueeze(-1)).squeeze(-1) + contributions_js = torch.bmm( + potentials_bare, dipoles[atom_is].unsqueeze(-1) + ).squeeze(-1) potential.index_add_(0, atom_js, contributions_js) # Compensate for double counting of pairs (i,j) and (j,i) @@ -124,15 +128,14 @@ def _compute_kspace( c = torch.cos(trig_args) # [k, i] s = torch.sin(trig_args) # [k, i] sc = torch.stack([c, s], dim=0) # [2 "f", k, i] - mu_k = dipoles @ kvectors.T # [i, k] - print(mu_k) + mu_k = dipoles @ kvectors.T # [i, k] sc_summed_G = torch.einsum("fki, ik, k->fk", sc, mu_k, G) - energy = torch.einsum( - "fk, fki, kc->ic", sc_summed_G, sc, kvectors - ) + energy = torch.einsum("fk, fki, kc->ic", sc_summed_G, sc, kvectors) energy /= torch.abs(cell.det()) energy -= dipoles * self.potential.self_contribution() - energy += self.potential.background_correction(torch.abs(cell.det())) * dipoles.sum(dim = 0) + energy += self.potential.background_correction( + torch.abs(cell.det()) + ) * dipoles.sum(dim=0) return energy / 2 def forward( diff --git a/src/torchpme/potentials/potential_dipole.py b/src/torchpme/potentials/potential_dipole.py index 8c2d2736..1b062c73 100644 --- a/src/torchpme/potentials/potential_dipole.py +++ b/src/torchpme/potentials/potential_dipole.py @@ -2,6 +2,7 @@ import torch +from .._utils import _get_device, _get_dtype from .potential import Potential @@ -17,8 +18,9 @@ def __init__( device: Union[None, str, torch.device] = None, ): super().__init__() - self.dtype = torch.get_default_dtype() if dtype is None else dtype - self.device = torch.get_default_device() if device is None else device + + self.dtype = _get_dtype(dtype) + self.device = _get_device(device) if smearing is not None: self.register_buffer( "smearing", torch.tensor(smearing, device=self.device, dtype=self.dtype) diff --git a/tests/helpers.py b/tests/helpers.py index 4682b906..2204f9bf 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -14,6 +14,7 @@ DIR_PATH = Path(__file__).parent EXAMPLES = DIR_PATH / ".." / "examples" COULOMB_TEST_FRAMES = EXAMPLES / "coulomb_test_frames.xyz" +DIPOLES_TEST_FRAMES = EXAMPLES / "dipoles_test_frames.xyz" def define_crystal(crystal_name="CsCl", dtype=None, device=None): diff --git a/tests/test_magnetostatics.py b/tests/test_magnetostatics.py index d2a3f19a..fae2ff5f 100644 --- a/tests/test_magnetostatics.py +++ b/tests/test_magnetostatics.py @@ -1,8 +1,12 @@ import pytest import torch +from ase.io import read +from helpers import DIPOLES_TEST_FRAMES +from vesin.torch import NeighborList from torchpme.calculators import CalculatorDipole from torchpme.potentials import PotentialDipole +from torchpme.prefactors import eV_A class System: @@ -96,3 +100,53 @@ def test_magnetostatic_ewald(): assert torch.isclose(result, expected_result, atol=1e-4), ( f"Expected {expected_result}, but got {result}" ) + + +frames = read(DIPOLES_TEST_FRAMES, ":3") +cutoffs = [3.9986718930, 4.0000000000, 4.7363281250] +alphas = [0.8819831493, 0.8956299559, 0.7215211182] +energies = [frame.get_potential_energy() for frame in frames] +forces = [frame.get_forces() for frame in frames] + + +@pytest.mark.parametrize( + ("frame", "cutoff", "alpha", "energy", "force"), + zip(frames, cutoffs, alphas, energies, forces), +) +def test_magnetostatic_ewald_crystal(frame, cutoff, alpha, energy, force): + smearing = (1 / (2 * alpha**2)) ** 0.5 + calc = CalculatorDipole( + potential=PotentialDipole(smearing=smearing, dtype=torch.float64), + full_neighbor_list=False, + lr_wavelength=0.1, + dtype=torch.float64, + prefactor=eV_A, + ) + positions = torch.tensor( + frame.get_positions(), requires_grad=True, dtype=torch.float64 + ) + dipoles = torch.tensor(frame.get_array("dipoles"), dtype=torch.float64) + cell = torch.tensor(frame.get_cell().array, dtype=torch.float64) + calculator = NeighborList(cutoff=cutoff, full_list=False) + p, d = calculator.compute( + points=positions, box=cell, periodic=True, quantities="PD" + ) + pot = calc( + dipoles=dipoles, + cell=cell, + positions=positions, + neighbor_indices=p, + neighbor_vectors=d, + ) + + result = torch.einsum("ij,ij->", pot, dipoles) + expected_result = torch.tensor(energy, dtype=torch.float64) + assert torch.isclose(result, expected_result, atol=1e-4), ( + f"Expected {expected_result}, but got {result}" + ) + + forces = -torch.autograd.grad(result, positions)[0] + expected_forces = torch.tensor(force, dtype=torch.float64) + assert torch.allclose(forces, expected_forces, atol=1e-4), ( + f"Expected {expected_forces}, but got {forces}" + )