From 981c1f100b1b089563665f9077837386b1ad0bb3 Mon Sep 17 00:00:00 2001 From: Sebastian Kehl Date: Mon, 28 Nov 2022 14:54:30 +0000 Subject: [PATCH] Fix bug in cell list with rlist larger then data. --- src/nlists/cell_list.h | 2 +- tests/test_neighbours.py | 16 +++++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/src/nlists/cell_list.h b/src/nlists/cell_list.h index 6aac2d4..e7b8ac5 100644 --- a/src/nlists/cell_list.h +++ b/src/nlists/cell_list.h @@ -60,7 +60,7 @@ struct CellList : NeighbourListT { real dim = box_max[k] - box_min[k] + box_eps; int nk = int(dim/rlist); - r_list[k] = (nk > 0) ? dim/nk : box_eps; + r_list[k] = (nk > 0) ? dim/nk : dim; shape[k] = (nk > 0) ? nk : 1; } strides = { 1, shape[0], shape[0]*shape[1] }; diff --git a/tests/test_neighbours.py b/tests/test_neighbours.py index ea2572a..d36f6ba 100644 --- a/tests/test_neighbours.py +++ b/tests/test_neighbours.py @@ -40,6 +40,11 @@ class Data: d.nn = request.param[1] return d +@pytest.fixture(params=[0.2, 0.5, 1.2]) +def rlist(request): + """Parametrize test on rlist.""" + return request.param + def get_nlist(mesh, ltype, rlist, excl): """Get a neighbour list.""" @@ -53,24 +58,25 @@ def get_nlist(mesh, ltype, rlist, excl): # ----------------------------------------------------------------------------- # test -- # ----------------------------------------------------------------------------- -def test_distance_matrix(data): +def test_distance_matrix(data, rlist): """Verify distance matrix againt kdtee implementation.""" mesh = data.mesh x = mesh.points() - nl = get_nlist(mesh, data.ltype, 0.2, 0) + nl = get_nlist(mesh, data.ltype, rlist, 0) # compute distance matrix d,i,j = nl.distance_matrix(mesh, 0.123) - A = coo_matrix((d,(i,j)), shape=(len(x),len(x))) - M = A + A.T # kdtree gives full matrix + M = coo_matrix((d,(i,j)), shape=(len(x),len(x))) + if data.ltype == "cell-list": + M = M + M.T # kdtree and verlet-lists gives full matrix # compare against kd-tree distance computation tree = KDTree(x) C = tree.sparse_distance_matrix(tree, 0.123) - assert (C-M).max() == 0.0 + assert np.allclose(C.todense(), M.todense()) def test_exclusion(data, excl): """Test neighbour lists exclusion level."""