diff --git a/pyxtal_ff/descriptors/SO3.py b/pyxtal_ff/descriptors/SO3.py index be3e3a4..33d14a6 100644 --- a/pyxtal_ff/descriptors/SO3.py +++ b/pyxtal_ff/descriptors/SO3.py @@ -1,6 +1,6 @@ from __future__ import division import numpy as np -from ase.neighborlist import NeighborList +from ase.neighborlist import NeighborList, NewPrimitiveNeighborList from optparse import OptionParser from scipy.special import sph_harm, spherical_in from ase import Atoms @@ -18,9 +18,10 @@ class SO3: alpha: float, gaussian width parameter derivative: bool, whether to calculate the gradient of not weight_on: bool, if True, the neighbors with different type will be counted as negative + primitive: bool, use the asePrimitiveNeighborList ''' - def __init__(self, nmax=3, lmax=3, rcut=3.5, alpha=2.0, derivative=True, stress=False, cutoff_function='cosine', weight_on=False): + def __init__(self, nmax=3, lmax=3, rcut=3.5, alpha=2.0, derivative=True, stress=False, cutoff_function='cosine', weight_on=False, primitive=False): # populate attributes self.nmax = nmax self.lmax = lmax @@ -31,6 +32,7 @@ def __init__(self, nmax=3, lmax=3, rcut=3.5, alpha=2.0, derivative=True, stress= self._type = "SO3" self.cutoff_function = cutoff_function self.weight_on = weight_on + self.PrimitiveNeighborList = primitive return def __str__(self): @@ -322,7 +324,10 @@ def build_neighbor_list(self, atom_ids=None): atom_ids = range(len(atoms)) cutoffs = [self.rcut/2]*len(atoms) - nl = NeighborList(cutoffs, self_interaction=False, bothways=True, skin=0.0) + if self.PrimitiveNeighborList: + nl = NewPrimitiveNeighborList(cutoffs, self_interaction=False, bothways=True, skin=0.0) + else: + nl = NeighborList(cutoffs, self_interaction=False, bothways=True, skin=0.0) nl.update(atoms) center_atoms = []