Skip to content

Commit

Permalink
add use_raft to knn_gpu (torch) (facebookresearch#3509)
Browse files Browse the repository at this point in the history
Summary:
Add support for `use_raft` in the torch version of `knn_gpu`. The numpy version already supports this; see the existing implementation at https://github.com/facebookresearch/faiss/blob/main/faiss/python/gpu_wrappers.py#L59

Pull Request resolved: facebookresearch#3509

Reviewed By: mlomeli1, junjieqi

Differential Revision: D58489851

Pulled By: algoriddle

fbshipit-source-id: cfad722fefd4809b135b765d0d43587cfd782d0e
  • Loading branch information
algoriddle authored and facebook-github-bot committed Jun 13, 2024
1 parent f71d5b9 commit 3d32330
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
4 changes: 3 additions & 1 deletion contrib/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,8 +492,9 @@ def torch_replacement_sa_decode(self, codes, x=None):
if issubclass(the_class, faiss.Index):
handle_torch_Index(the_class)


# allows torch tensor usage with bfKnn
def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRIC_L2, device=-1):
def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRIC_L2, device=-1, use_raft=False):
if type(xb) is np.ndarray:
# Forward to faiss __init__.py base method
return faiss.knn_gpu_numpy(res, xq, xb, k, D, I, metric, device)
Expand Down Expand Up @@ -574,6 +575,7 @@ def torch_replacement_knn_gpu(res, xq, xb, k, D=None, I=None, metric=faiss.METRI
args.outIndices = I_ptr
args.outIndicesType = I_type
args.device = device
args.use_raft = use_raft

with using_stream(res):
faiss.bfKnn(res, args)
Expand Down
20 changes: 13 additions & 7 deletions faiss/gpu/test/torch_test_contrib_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def test_sa_encode_decode(self):
return

class TestTorchUtilsKnnGpu(unittest.TestCase):
def test_knn_gpu(self):
def test_knn_gpu(self, use_raft=False):
torch.manual_seed(10)
d = 32
nb = 1024
Expand Down Expand Up @@ -286,7 +286,7 @@ def test_knn_gpu(self):
else:
xb_c = xb_np

D, I = faiss.knn_gpu(res, xq_c, xb_c, k)
D, I = faiss.knn_gpu(res, xq_c, xb_c, k, use_raft=use_raft)

self.assertTrue(torch.equal(torch.from_numpy(I), gt_I))
self.assertLess((torch.from_numpy(D) - gt_D).abs().max(), 1e-4)
Expand All @@ -312,15 +312,15 @@ def test_knn_gpu(self):
xb_c = to_column_major_torch(xb)
assert not xb_c.is_contiguous()

D, I = faiss.knn_gpu(res, xq_c, xb_c, k)
D, I = faiss.knn_gpu(res, xq_c, xb_c, k, use_raft=use_raft)

self.assertTrue(torch.equal(I.cpu(), gt_I))
self.assertLess((D.cpu() - gt_D).abs().max(), 1e-4)

# test on subset
try:
# This internally uses the current pytorch stream
D, I = faiss.knn_gpu(res, xq_c[6:8], xb_c, k)
D, I = faiss.knn_gpu(res, xq_c[6:8], xb_c, k, use_raft=use_raft)
except TypeError:
if not xq_row_major:
# then it is expected
Expand All @@ -331,7 +331,13 @@ def test_knn_gpu(self):
self.assertTrue(torch.equal(I.cpu(), gt_I[6:8]))
self.assertLess((D.cpu() - gt_D[6:8]).abs().max(), 1e-4)

def test_knn_gpu_datatypes(self):
# Re-run the full brute-force kNN test suite above with the RAFT backend
# enabled (use_raft=True). Skipped unless this faiss build was compiled
# with RAFT support, as reported by faiss.get_compile_options().
@unittest.skipUnless(
"RAFT" in faiss.get_compile_options(),
"only if RAFT is compiled in")
def test_knn_gpu_raft(self):
self.test_knn_gpu(use_raft=True)

def test_knn_gpu_datatypes(self, use_raft=False):
torch.manual_seed(10)
d = 10
nb = 1024
Expand All @@ -354,7 +360,7 @@ def test_knn_gpu_datatypes(self):
D = torch.zeros(nq, k, device=xb_c.device, dtype=torch.float32)
I = torch.zeros(nq, k, device=xb_c.device, dtype=torch.int32)

faiss.knn_gpu(res, xq_c, xb_c, k, D, I)
faiss.knn_gpu(res, xq_c, xb_c, k, D, I, use_raft=use_raft)

self.assertTrue(torch.equal(I.long().cpu(), gt_I))
self.assertLess((D.float().cpu() - gt_D).abs().max(), 1.5e-3)
Expand All @@ -366,7 +372,7 @@ def test_knn_gpu_datatypes(self):
xb_c = xb.half().numpy()
xq_c = xq.half().numpy()

faiss.knn_gpu(res, xq_c, xb_c, k, D, I)
faiss.knn_gpu(res, xq_c, xb_c, k, D, I, use_raft=use_raft)

self.assertTrue(torch.equal(torch.from_numpy(I).long(), gt_I))
self.assertLess((torch.from_numpy(D) - gt_D).abs().max(), 1.5e-3)
Expand Down

0 comments on commit 3d32330

Please sign in to comment.