From 63574600f3ffbcd37ce01c26e8679b5a4b1c98c2 Mon Sep 17 00:00:00 2001 From: Benjamin Trent Date: Wed, 29 Jan 2025 10:12:36 -0500 Subject: [PATCH] Add knn result consistency test (#14167) --- .../search/BaseKnnVectorQueryTestCase.java | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index fe1f614dfcc..efe9500635a 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -24,6 +24,7 @@ import java.io.IOException; import java.util.HashSet; +import java.util.Random; import java.util.Set; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.document.Document; @@ -41,7 +42,9 @@ import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.NoMergePolicy; import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.index.SerialMergeScheduler; import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorEncoding; @@ -61,6 +64,7 @@ /** Test cases for AbstractKnnVectorQuery objects. */ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase { + /** Delta passed to assertEquals when comparing float scores in the consistency test. */ static final float EPSILON = 0.001f; abstract AbstractKnnVectorQuery getKnnVectorQuery( String field, float[] query, int k, Query queryFilter); @@ -482,6 +486,62 @@ public void testSkewedIndex() throws IOException { } } + /** Checks that re-running one KNN query against an unchanged index returns identical results, calling the shared helper with multiThreaded = false. The index is built with a plain IndexWriter, not RandomIndexWriter. 
*/ + public void testRandomConsistencySingleThreaded() throws IOException { + assertRandomConsistency(false); + } + + /** Same consistency check with multiThreaded = true; currently marked AwaitsFix pending the issue linked in the annotation. */ @AwaitsFix(bugUrl = "https://github.com/apache/lucene/issues/14180") + public void testRandomConsistencyMultiThreaded() throws IOException { + assertRandomConsistency(true); + } + + /** Indexes 100 documents (4-dimensional random vectors; some documents may have no vector) with merging disabled and occasional explicit flushes, runs one random KNN query to capture a baseline, then repeats the same search 10 times and asserts identical total hit count, result length, doc ids in order, and scores within EPSILON. The multiThreaded flag is forwarded as the last argument of newSearcher; presumably it selects a threaded searcher (confirm in LuceneTestCase). */ private void assertRandomConsistency(boolean multiThreaded) throws IOException { + int numDocs = 100; + int dimension = 4; + int numIters = 10; + boolean everyDocHasAVector = random().nextBoolean(); + Random r = random(); + try (Directory d = newDirectoryForTest()) { + // Pin the merge and flush settings so a given seed reproduces the same index + IndexWriterConfig iwc = new IndexWriterConfig(new MockAnalyzer(random())); + iwc.setMergeScheduler(new SerialMergeScheduler()); + iwc.setMergePolicy(NoMergePolicy.INSTANCE); + iwc.setMaxBufferedDocs(numDocs); + iwc.setRAMBufferSizeMB(IndexWriterConfig.DISABLE_AUTO_FLUSH); + try (IndexWriter w = new IndexWriter(d, iwc)) { + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + if (everyDocHasAVector || random().nextInt(10) != 2) { + doc.add(getKnnVectorField("field", randomVector(dimension))); + } + w.addDocument(doc); + if (r.nextBoolean() && i % 50 == 0) { + w.flush(); + } + } + } + try (IndexReader reader = DirectoryReader.open(d)) { + IndexSearcher searcher = newSearcher(reader, true, true, multiThreaded); + // Run the query once to capture the baseline; every later run of the same query must + // match it + int k = random().nextInt(80) + 1; + AbstractKnnVectorQuery query = getKnnVectorQuery("field", randomVector(dimension), k); + int n = random().nextInt(100) + 1; + TopDocs expectedResults = searcher.search(query, n); + for (int i = 0; i < numIters; i++) { + TopDocs results = searcher.search(query, n); + assertEquals(expectedResults.totalHits.value(), results.totalHits.value()); + assertEquals(expectedResults.scoreDocs.length, results.scoreDocs.length); + for (int j = 0; j < results.scoreDocs.length; j++) { 
assertEquals(expectedResults.scoreDocs[j].doc, results.scoreDocs[j].doc); + assertEquals(expectedResults.scoreDocs[j].score, results.scoreDocs[j].score, EPSILON); + } + } + } + } + } + /** Tests with random vectors, number of documents, etc. Uses RandomIndexWriter. */ public void testRandom() throws IOException { int numDocs = atLeast(100);