Skip to content

Commit

Permalink
Replace vesin.torch, as it doesn’t support Windows OS.
Browse files Browse the repository at this point in the history
  • Loading branch information
E-Rum committed Jan 31, 2025
1 parent f90a036 commit 99929d0
Showing 1 changed file with 41 additions and 10 deletions.
51 changes: 41 additions & 10 deletions tests/test_magnetostatics.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,36 @@
import pytest
import torch
from ase.io import read
from helpers import DIPOLES_TEST_FRAMES
from vesin.torch import NeighborList
from helpers import DIPOLES_TEST_FRAMES, neighbor_list

from torchpme.calculators import CalculatorDipole
from torchpme.potentials import PotentialDipole
from torchpme.prefactors import eV_A


def compute_distance_vectors(
positions, neighbor_indices, cell=None, neighbor_shifts=None
):
"""Compute pairwise distance vectors."""
atom_is = neighbor_indices[:, 0]
atom_js = neighbor_indices[:, 1]

pos_is = positions[atom_is]
pos_js = positions[atom_js]

distance_vectors = pos_js - pos_is

if cell is not None and neighbor_shifts is not None:
shifts = neighbor_shifts.type(cell.dtype)
distance_vectors += shifts @ cell
elif cell is not None and neighbor_shifts is None:
raise ValueError("Provided `cell` but no `neighbor_shifts`.")
elif cell is None and neighbor_shifts is not None:
raise ValueError("Provided `neighbor_shifts` but no `cell`.")

return distance_vectors


class System:
def __init__(self):
self.cell = torch.tensor(
Expand Down Expand Up @@ -122,21 +144,30 @@ def test_magnetostatic_ewald_crystal(frame, cutoff, alpha, energy, force):
dtype=torch.float64,
prefactor=eV_A,
)
positions = torch.tensor(
frame.get_positions(), requires_grad=True, dtype=torch.float64
)
positions = torch.tensor(frame.get_positions(), dtype=torch.float64)
dipoles = torch.tensor(frame.get_array("dipoles"), dtype=torch.float64)
cell = torch.tensor(frame.get_cell().array, dtype=torch.float64)
calculator = NeighborList(cutoff=cutoff, full_list=False)
p, d = calculator.compute(
points=positions, box=cell, periodic=True, quantities="PD"
neighbor_indices, neighbor_shifts = neighbor_list(
positions=positions,
periodic=True,
box=cell,
cutoff=cutoff,
full_neighbor_list=False,
neighbor_shifts=True,
)
positions.requires_grad = True
neighbor_distance_vectors = compute_distance_vectors(
positions=positions,
neighbor_indices=neighbor_indices,
cell=cell,
neighbor_shifts=neighbor_shifts,
)
pot = calc(
dipoles=dipoles,
cell=cell,
positions=positions,
neighbor_indices=p,
neighbor_vectors=d,
neighbor_indices=neighbor_indices,
neighbor_vectors=neighbor_distance_vectors,
)

result = torch.einsum("ij,ij->", pot, dipoles)
Expand Down

0 comments on commit 99929d0

Please sign in to comment.