Skip to content

Commit

Permalink
Fix bug in cell list with rlist larger then data.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastian Kehl committed Nov 28, 2022
1 parent 831012c commit 981c1f1
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/nlists/cell_list.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ struct CellList : NeighbourListT<exclusion>
{
real dim = box_max[k] - box_min[k] + box_eps;
int nk = int(dim/rlist);
r_list[k] = (nk > 0) ? dim/nk : box_eps;
r_list[k] = (nk > 0) ? dim/nk : dim;
shape[k] = (nk > 0) ? nk : 1;
}
strides = { 1, shape[0], shape[0]*shape[1] };
Expand Down
16 changes: 11 additions & 5 deletions tests/test_neighbours.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ class Data:
d.nn = request.param[1]
return d

@pytest.fixture(params=[0.2, 0.5, 1.2])
def rlist(request):
"""Parametrize test on rlist."""
return request.param

def get_nlist(mesh, ltype, rlist, excl):
"""Get a neighbour list."""

Expand All @@ -53,24 +58,25 @@ def get_nlist(mesh, ltype, rlist, excl):
# -----------------------------------------------------------------------------
# test --
# -----------------------------------------------------------------------------
def test_distance_matrix(data):
def test_distance_matrix(data, rlist):
"""Verify distance matrix againt kdtee implementation."""

mesh = data.mesh
x = mesh.points()

nl = get_nlist(mesh, data.ltype, 0.2, 0)
nl = get_nlist(mesh, data.ltype, rlist, 0)

# compute distance matrix
d,i,j = nl.distance_matrix(mesh, 0.123)
A = coo_matrix((d,(i,j)), shape=(len(x),len(x)))
M = A + A.T # kdtree gives full matrix
M = coo_matrix((d,(i,j)), shape=(len(x),len(x)))
if data.ltype == "cell-list":
M = M + M.T # kdtree and verlet-lists gives full matrix

# compare against kd-tree distance computation
tree = KDTree(x)
C = tree.sparse_distance_matrix(tree, 0.123)

assert (C-M).max() == 0.0
assert np.allclose(C.todense(), M.todense())

def test_exclusion(data, excl):
"""Test neighbour lists exclusion level."""
Expand Down

0 comments on commit 981c1f1

Please sign in to comment.