From cd9c486f5137d5eca0feba9337b3a588e2cf50f5 Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Sun, 1 Sep 2024 16:31:31 -0400 Subject: [PATCH 01/25] compiles! --- .../synonym/word2vec/Word2VecModel.java | 4 +- .../lucene90/Lucene90HnswGraphBuilder.java | 11 +- .../lucene90/Lucene90HnswVectorsReader.java | 45 +--- .../lucene90/Lucene90OnHeapHnswGraph.java | 4 +- .../lucene91/Lucene91HnswVectorsReader.java | 48 +--- .../lucene92/OffHeapFloatVectorValues.java | 75 +----- .../lucene94/OffHeapByteVectorValues.java | 74 +----- .../lucene94/OffHeapFloatVectorValues.java | 75 +----- .../lucene90/Lucene90HnswVectorsWriter.java | 12 +- .../lucene91/Lucene91HnswGraphBuilder.java | 18 +- .../lucene91/Lucene91HnswVectorsWriter.java | 10 +- .../lucene92/Lucene92HnswVectorsWriter.java | 9 +- .../lucene94/Lucene94HnswVectorsWriter.java | 20 +- .../lucene95/Lucene95HnswVectorsWriter.java | 30 ++- .../TestBasicBackwardsCompatibility.java | 9 +- .../benchmark/jmh/VectorScorerBenchmark.java | 16 +- .../bitvectors/FlatBitVectorsScorer.java | 40 ++-- .../SimpleTextKnnVectorsReader.java | 86 ++----- .../SimpleTextKnnVectorsWriter.java | 26 +-- .../codecs/BufferingKnnVectorsWriter.java | 218 +++++------------ .../lucene/codecs/KnnVectorsWriter.java | 106 +++++---- .../codecs/hnsw/DefaultFlatVectorScorer.java | 60 ++--- .../lucene/codecs/hnsw/FlatVectorsScorer.java | 13 +- .../hnsw/ScalarQuantizedVectorScorer.java | 26 +-- .../lucene95/OffHeapByteVectorValues.java | 75 +----- .../lucene95/OffHeapFloatVectorValues.java | 81 +------ .../lucene99/Lucene99FlatVectorsWriter.java | 17 +- .../lucene99/Lucene99HnswVectorsWriter.java | 17 +- .../Lucene99ScalarQuantizedVectorScorer.java | 48 ++-- .../Lucene99ScalarQuantizedVectorsReader.java | 25 +- .../Lucene99ScalarQuantizedVectorsWriter.java | 220 +++++------------- .../OffHeapQuantizedByteVectorValues.java | 109 +++------ .../apache/lucene/index/ByteVectorValues.java | 79 +++++-- .../org/apache/lucene/index/CheckIndex.java | 38 +-- 
.../lucene/index/ExitableDirectoryReader.java | 86 +++---- .../lucene/index/FloatVectorValues.java | 81 +++++-- .../apache/lucene/index/KnnVectorValues.java | 133 +++++++++++ .../SlowCompositeCodecReaderWrapper.java | 74 +++--- .../lucene/index/SortingCodecReader.java | 92 +++----- .../lucene/search/FieldExistsQuery.java | 4 +- .../util/hnsw/ConcurrentHnswMerger.java | 6 +- .../lucene/util/hnsw/HnswGraphMerger.java | 6 +- .../util/hnsw/IncrementalHnswGraphMerger.java | 39 ++-- .../util/hnsw/RandomAccessVectorValues.java | 146 +----------- .../lucene/util/hnsw/RandomVectorScorer.java | 5 +- .../QuantizedByteVectorValues.java | 35 +-- ...RandomAccessQuantizedByteVectorValues.java | 36 --- .../util/quantization/ScalarQuantizer.java | 58 ++--- ...Lucene99MemorySegmentByteVectorScorer.java | 19 +- ...MemorySegmentByteVectorScorerSupplier.java | 16 +- ...ucene99MemorySegmentFlatVectorsScorer.java | 30 ++- .../codecs/hnsw/TestFlatVectorScorer.java | 9 +- ...estLucene99HnswQuantizedVectorsFormat.java | 14 +- ...stLucene99ScalarQuantizedVectorScorer.java | 8 +- ...tLucene99ScalarQuantizedVectorsFormat.java | 8 +- ...tLucene99ScalarQuantizedVectorsWriter.java | 2 +- .../org/apache/lucene/document/TestField.java | 19 +- .../index/TestExitableDirectoryReader.java | 34 +-- .../org/apache/lucene/index/TestKnnGraph.java | 8 +- .../lucene/index/TestSortingCodecReader.java | 4 +- .../vectorization/TestVectorScorer.java | 6 +- .../search/BaseKnnVectorQueryTestCase.java | 5 +- .../util/hnsw/AbstractMockVectorValues.java | 3 +- .../lucene/util/hnsw/HnswGraphTestCase.java | 34 +-- .../util/hnsw/MockByteVectorValues.java | 8 +- .../lucene/util/hnsw/MockVectorValues.java | 8 +- .../util/hnsw/TestHnswByteVectorGraph.java | 7 +- .../util/hnsw/TestHnswFloatVectorGraph.java | 9 +- .../TestScalarQuantizedVectorSimilarity.java | 12 +- .../quantization/TestScalarQuantizer.java | 30 +-- .../lucene/index/memory/MemoryIndex.java | 70 +----- .../lucene/index/memory/TestMemoryIndex.java | 16 +- 
.../valuesource/ByteKnnVectorFieldSource.java | 4 +- .../FloatKnnVectorFieldSource.java | 4 +- .../sandbox/codecs/quantization/KMeans.java | 20 +- .../codecs/quantization/SampleReader.java | 14 +- .../codecs/quantization/TestKMeans.java | 15 +- .../asserting/AssertingKnnVectorsFormat.java | 4 +- .../index/BaseKnnVectorsFormatTestCase.java | 146 ++++++------ 79 files changed, 1127 insertions(+), 2004 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java delete mode 100644 lucene/core/src/java/org/apache/lucene/util/quantization/RandomAccessQuantizedByteVectorValues.java diff --git a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java index f43d33ca2747..53b0dec10114 100644 --- a/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java +++ b/lucene/analysis/common/src/java/org/apache/lucene/analysis/synonym/word2vec/Word2VecModel.java @@ -18,10 +18,10 @@ package org.apache.lucene.analysis.synonym.word2vec; import java.io.IOException; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRefHash; import org.apache.lucene.util.TermAndVector; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; /** * Word2VecModel is a class representing the parsed Word2Vec model containing the vectors for each @@ -29,7 +29,7 @@ * * @lucene.experimental */ -public class Word2VecModel implements RandomAccessVectorValues.Floats { +public class Word2VecModel extends FloatVectorValues { private final int dictionarySize; private final int vectorDimension; diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java index 
52972e9dcda4..30b49f64b542 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java @@ -22,6 +22,7 @@ import java.util.Objects; import java.util.SplittableRandom; import java.util.concurrent.TimeUnit; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.InfoStream; import org.apache.lucene.util.hnsw.NeighborQueue; @@ -49,7 +50,7 @@ public final class Lucene90HnswGraphBuilder { private final Lucene90NeighborArray scratch; private final VectorSimilarityFunction similarityFunction; - private final RandomAccessVectorValues.Floats vectorValues; + private final FloatVectorValues vectorValues; private final SplittableRandom random; private final Lucene90BoundsChecker bound; final Lucene90OnHeapHnswGraph hnsw; @@ -58,7 +59,7 @@ public final class Lucene90HnswGraphBuilder { // we need two sources of vectors in order to perform diversity check comparisons without // colliding - private final RandomAccessVectorValues.Floats buildVectors; + private final FloatVectorValues buildVectors; /** * Reads all the vectors from vector values, builds a graph connecting them by their dense @@ -73,7 +74,7 @@ public final class Lucene90HnswGraphBuilder { * to ensure repeatable construction. */ public Lucene90HnswGraphBuilder( - RandomAccessVectorValues.Floats vectors, + FloatVectorValues vectors, VectorSimilarityFunction similarityFunction, int maxConn, int beamWidth, @@ -104,7 +105,7 @@ public Lucene90HnswGraphBuilder( * @param vectors the vectors for which to build a nearest neighbors graph. 
Must be an independet * accessor for the vectors */ - public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues.Floats vectors) throws IOException { + public Lucene90OnHeapHnswGraph build(FloatVectorValues vectors) throws IOException { if (vectors == vectorValues) { throw new IllegalArgumentException( "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()"); @@ -230,7 +231,7 @@ private boolean diversityCheck( float[] candidate, float score, Lucene90NeighborArray neighbors, - RandomAccessVectorValues.Floats vectorValues) + FloatVectorValues vectorValues) throws IOException { bound.set(score); for (int i = 0; i < neighbors.size(); i++) { diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index 665d31403214..e9a74b29726c 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -20,7 +20,6 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; -import java.util.Arrays; import java.util.HashMap; import java.util.Map; import java.util.SplittableRandom; @@ -32,9 +31,9 @@ import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.ChecksumIndexInput; @@ -394,42 +393,6 @@ public int size() { return ordToDoc.length; } - @Override - public 
float[] vectorValue() throws IOException { - return vectorValue(ord); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() { - if (++ord >= size()) { - doc = NO_MORE_DOCS; - } else { - doc = ordToDoc[ord]; - } - return doc; - } - - @Override - public int advance(int target) { - assert docID() < target; - ord = Arrays.binarySearch(ordToDoc, ord + 1, ordToDoc.length, target); - if (ord < 0) { - ord = -(ord + 1); - } - assert ord <= ordToDoc.length; - if (ord == ordToDoc.length) { - doc = NO_MORE_DOCS; - } else { - doc = ordToDoc[ord]; - } - return doc; - } - @Override public OffHeapFloatVectorValues copy() { return new OffHeapFloatVectorValues(dimension, ordToDoc, similarityFunction, dataIn.clone()); @@ -455,12 +418,12 @@ public VectorScorer scorer(float[] target) { return new VectorScorer() { @Override public float score() throws IOException { - return values.similarityFunction.compare(values.vectorValue(), target); + return values.similarityFunction.compare(values.vectorValue(iterator().index()), target); } @Override - public DocIdSetIterator iterator() { - return values; + public KnnVectorValues.KnnValuesDocIterator iterator() { + return values.iterator(); } }; } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java index 52f2146e836b..845987c2957c 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90OnHeapHnswGraph.java @@ -23,12 +23,12 @@ import java.util.ArrayList; import java.util.List; import java.util.SplittableRandom; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.Bits; import 
org.apache.lucene.util.SparseFixedBitSet; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.NeighborQueue; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; /** * An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to @@ -74,7 +74,7 @@ public static NeighborQueue search( float[] query, int topK, int numSeed, - RandomAccessVectorValues.Floats vectors, + FloatVectorValues vectors, VectorSimilarityFunction similarityFunction, HnswGraph graphValues, Bits acceptOrds, diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index 81f8d97a9a0c..80fccb5a7ad8 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -439,49 +439,6 @@ public int size() { return size; } - @Override - public float[] vectorValue() throws IOException { - dataIn.seek((long) ord * byteSize); - dataIn.readFloats(value, 0, value.length); - return value; - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() { - if (++ord >= size) { - doc = NO_MORE_DOCS; - } else { - doc = ordToDocOperator.applyAsInt(ord); - } - return doc; - } - - @Override - public int advance(int target) { - assert docID() < target; - - if (ordToDoc == null) { - ord = target; - } else { - ord = Arrays.binarySearch(ordToDoc, ord + 1, ordToDoc.length, target); - if (ord < 0) { - ord = -(ord + 1); - } - } - - if (ord < size) { - doc = ordToDocOperator.applyAsInt(ord); - } else { - doc = NO_MORE_DOCS; - } - return doc; - } - @Override public OffHeapFloatVectorValues copy() { return new OffHeapFloatVectorValues( @@ -504,12 +461,13 @@ public VectorScorer scorer(float[] 
target) { return new VectorScorer() { @Override public float score() throws IOException { - return values.similarityFunction.compare(values.vectorValue(), target); + return values.similarityFunction.compare( + values.vectorValue(values.iterator().index()), target); } @Override public DocIdSetIterator iterator() { - return values; + return values.iterator(); } }; } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java index 19dc82cc46d5..5cdfc5d7ad2c 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java @@ -105,30 +105,6 @@ public DenseOffHeapVectorValues( super(dimension, size, vectorSimilarityFunction, slice); } - @Override - public float[] vectorValue() throws IOException { - return vectorValue(doc); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - if (target >= size) { - return doc = NO_MORE_DOCS; - } - return doc = target; - } - @Override public DenseOffHeapVectorValues copy() throws IOException { return new DenseOffHeapVectorValues(dimension, size, vectorSimilarityFunction, slice.clone()); @@ -145,12 +121,13 @@ public VectorScorer scorer(float[] query) throws IOException { return new VectorScorer() { @Override public float score() throws IOException { - return values.vectorSimilarityFunction.compare(values.vectorValue(), query); + return values.vectorSimilarityFunction.compare( + values.vectorValue(values.iterator().index()), query); } @Override public DocIdSetIterator iterator() { - return values; + return values.iterator(); } 
}; } @@ -186,27 +163,6 @@ public SparseOffHeapVectorValues( fieldEntry.size()); } - @Override - public float[] vectorValue() throws IOException { - return vectorValue(disi.index()); - } - - @Override - public int docID() { - return disi.docID(); - } - - @Override - public int nextDoc() throws IOException { - return disi.nextDoc(); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - return disi.advance(target); - } - @Override public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( @@ -242,12 +198,13 @@ public VectorScorer scorer(float[] query) throws IOException { return new VectorScorer() { @Override public float score() throws IOException { - return values.vectorSimilarityFunction.compare(values.vectorValue(), query); + return values.vectorSimilarityFunction.compare( + values.vectorValue(values.iterator().index()), query); } @Override public DocIdSetIterator iterator() { - return values; + return values.iterator(); } }; } @@ -271,26 +228,6 @@ public int size() { return 0; } - @Override - public float[] vectorValue() throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) throws IOException { - return doc = NO_MORE_DOCS; - } - @Override public OffHeapFloatVectorValues copy() throws IOException { throw new UnsupportedOperationException(); diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java index 0c909e3839df..53ec069f5d4f 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java +++ 
b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java @@ -119,30 +119,6 @@ public DenseOffHeapVectorValues( super(dimension, size, slice, vectorSimilarityFunction, byteSize); } - @Override - public byte[] vectorValue() throws IOException { - return vectorValue(doc); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - if (target >= size) { - return doc = NO_MORE_DOCS; - } - return doc = target; - } - @Override public DenseOffHeapVectorValues copy() throws IOException { return new DenseOffHeapVectorValues( @@ -160,12 +136,13 @@ public VectorScorer scorer(byte[] query) throws IOException { return new VectorScorer() { @Override public float score() throws IOException { - return vectorSimilarityFunction.compare(copy.vectorValue(), query); + return vectorSimilarityFunction.compare( + copy.vectorValue(copy.docToOrd(copy.iterator().docID())), query); } @Override public DocIdSetIterator iterator() { - return copy; + return copy.iterator(); } }; } @@ -202,27 +179,6 @@ public SparseOffHeapVectorValues( fieldEntry.size()); } - @Override - public byte[] vectorValue() throws IOException { - return vectorValue(disi.index()); - } - - @Override - public int docID() { - return disi.docID(); - } - - @Override - public int nextDoc() throws IOException { - return disi.nextDoc(); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - return disi.advance(target); - } - @Override public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( @@ -258,12 +214,12 @@ public VectorScorer scorer(byte[] query) throws IOException { return new VectorScorer() { @Override public float score() throws IOException { - return vectorSimilarityFunction.compare(copy.vectorValue(), 
query); + return vectorSimilarityFunction.compare(copy.vectorValue(copy.disi.index()), query); } @Override public DocIdSetIterator iterator() { - return copy; + return copy.disi; } }; } @@ -287,26 +243,6 @@ public int size() { return 0; } - @Override - public byte[] vectorValue() throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) throws IOException { - return doc = NO_MORE_DOCS; - } - @Override public OffHeapByteVectorValues copy() throws IOException { throw new UnsupportedOperationException(); diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java index 91f97b8a41fa..27dfbd98c937 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java @@ -115,30 +115,6 @@ public DenseOffHeapVectorValues( super(dimension, size, slice, vectorSimilarityFunction, byteSize); } - @Override - public float[] vectorValue() throws IOException { - return vectorValue(doc); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - if (target >= size) { - return doc = NO_MORE_DOCS; - } - return doc = target; - } - @Override public DenseOffHeapVectorValues copy() throws IOException { return new DenseOffHeapVectorValues( @@ -156,12 +132,13 @@ public VectorScorer scorer(float[] query) throws IOException { return new VectorScorer() { @Override public float score() throws 
IOException { - return values.vectorSimilarityFunction.compare(values.vectorValue(), query); + return values.vectorSimilarityFunction.compare( + values.vectorValue(values.iterator.index()), query); } @Override public DocIdSetIterator iterator() { - return values; + return values.iterator(); } }; } @@ -198,27 +175,6 @@ public SparseOffHeapVectorValues( fieldEntry.size()); } - @Override - public float[] vectorValue() throws IOException { - return vectorValue(disi.index()); - } - - @Override - public int docID() { - return disi.docID(); - } - - @Override - public int nextDoc() throws IOException { - return disi.nextDoc(); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - return disi.advance(target); - } - @Override public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( @@ -254,12 +210,13 @@ public VectorScorer scorer(float[] query) throws IOException { return new VectorScorer() { @Override public float score() throws IOException { - return values.vectorSimilarityFunction.compare(values.vectorValue(), query); + return values.vectorSimilarityFunction.compare( + values.vectorValue(values.iterator().index()), query); } @Override public DocIdSetIterator iterator() { - return values; + return values.iterator(); } }; } @@ -283,26 +240,6 @@ public int size() { return 0; } - @Override - public float[] vectorValue() throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) throws IOException { - return doc = NO_MORE_DOCS; - } - @Override public OffHeapFloatVectorValues copy() throws IOException { throw new UnsupportedOperationException(); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java 
b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java index 39828524d264..aa54f03071b6 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java @@ -17,8 +17,6 @@ package org.apache.lucene.backward_codecs.lucene90; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -35,7 +33,6 @@ import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.IOUtils; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; /** * Writes vector values and knn graphs to index segments. @@ -188,12 +185,13 @@ private static int[] writeVectorData(IndexOutput output, FloatVectorValues vecto int count = 0; ByteBuffer binaryVector = ByteBuffer.allocate(vectors.dimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc(), count++) { + for (int ord = 0; ord < vectors.size(); ord++) { + int docV = vectors.ordToDoc(ord); // write vector - float[] vectorValue = vectors.vectorValue(); + float[] vectorValue = vectors.vectorValue(ord); binaryVector.asFloatBuffer().put(vectorValue); output.writeBytes(binaryVector.array(), binaryVector.limit()); - docIds[count] = docV; + docIds[ord] = docV; } if (docIds.length > count) { @@ -234,7 +232,7 @@ private void writeGraphOffsets(IndexOutput out, long[] offsets) throws IOExcepti private void writeGraph( IndexOutput graphData, - RandomAccessVectorValues.Floats vectorValues, + FloatVectorValues vectorValues, VectorSimilarityFunction similarityFunction, long graphDataOffset, long[] offsets, diff --git 
a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java index dbb9a71b4218..5ef85a8419c2 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswGraphBuilder.java @@ -25,6 +25,7 @@ import java.util.SplittableRandom; import java.util.concurrent.TimeUnit; import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.InfoStream; @@ -32,7 +33,6 @@ import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.apache.lucene.util.hnsw.HnswGraphSearcher; import org.apache.lucene.util.hnsw.NeighborQueue; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; /** @@ -57,7 +57,7 @@ public final class Lucene91HnswGraphBuilder { private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); private final VectorSimilarityFunction similarityFunction; - private final RandomAccessVectorValues.Floats vectorValues; + private final FloatVectorValues vectorValues; private final SplittableRandom random; private final Lucene91BoundsChecker bound; private final HnswGraphSearcher graphSearcher; @@ -68,7 +68,7 @@ public final class Lucene91HnswGraphBuilder { // we need two sources of vectors in order to perform diversity check comparisons without // colliding - private RandomAccessVectorValues.Floats buildVectors; + private FloatVectorValues buildVectors; /** * Reads all the vectors from vector values, builds a graph connecting them by their dense @@ -83,7 +83,7 @@ public final class Lucene91HnswGraphBuilder { * 
to ensure repeatable construction. */ public Lucene91HnswGraphBuilder( - RandomAccessVectorValues.Floats vectors, + FloatVectorValues vectors, VectorSimilarityFunction similarityFunction, int maxConn, int beamWidth, @@ -113,14 +113,14 @@ public Lucene91HnswGraphBuilder( } /** - * Reads all the vectors from two copies of a {@link RandomAccessVectorValues}. Providing two - * copies enables efficient retrieval without extra data copying, while avoiding collision of the + * Reads all the vectors from two copies of a {@link FloatVectorValues}. Providing two copies + * enables efficient retrieval without extra data copying, while avoiding collision of the * returned values. * - * @param vectors the vectors for which to build a nearest neighbors graph. Must be an independet + * @param vectors the vectors for which to build a nearest neighbors graph. Must be an independent * accessor for the vectors */ - public Lucene91OnHeapHnswGraph build(RandomAccessVectorValues.Floats vectors) throws IOException { + public Lucene91OnHeapHnswGraph build(FloatVectorValues vectors) throws IOException { if (vectors == vectorValues) { throw new IllegalArgumentException( "Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()"); @@ -254,7 +254,7 @@ private boolean diversityCheck( float[] candidate, float score, Lucene91NeighborArray neighbors, - RandomAccessVectorValues.Floats vectorValues) + FloatVectorValues vectorValues) throws IOException { bound.set(score); for (int i = 0; i < neighbors.size(); i++) { diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java index 37b752503817..430313e75685 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java +++ 
b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java @@ -17,8 +17,6 @@ package org.apache.lucene.backward_codecs.lucene91; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -37,7 +35,6 @@ import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.hnsw.HnswGraph; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; /** * Writes vector values and knn graphs to index segments. @@ -183,9 +180,10 @@ private static DocsWithFieldSet writeVectorData(IndexOutput output, FloatVectorV DocsWithFieldSet docsWithField = new DocsWithFieldSet(); ByteBuffer binaryVector = ByteBuffer.allocate(vectors.dimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) { + for (int ord = 0; ord < vectors.size(); ord++) { + int docV = vectors.ordToDoc(ord); // write vector - float[] vectorValue = vectors.vectorValue(); + float[] vectorValue = vectors.vectorValue(ord); binaryVector.asFloatBuffer().put(vectorValue); output.writeBytes(binaryVector.array(), binaryVector.limit()); docsWithField.add(docV); @@ -243,7 +241,7 @@ private void writeMeta( } private Lucene91OnHeapHnswGraph writeGraph( - RandomAccessVectorValues.Floats vectorValues, VectorSimilarityFunction similarityFunction) + FloatVectorValues vectorValues, VectorSimilarityFunction similarityFunction) throws IOException { // build graph diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java index caa8fc3da149..e0c6895c7be9 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java +++ 
b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java @@ -18,7 +18,6 @@ package org.apache.lucene.backward_codecs.lucene92; import static org.apache.lucene.backward_codecs.lucene92.Lucene92RWHnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; import java.nio.ByteBuffer; @@ -43,7 +42,6 @@ import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.apache.lucene.util.hnsw.NeighborArray; import org.apache.lucene.util.hnsw.OnHeapHnswGraph; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.packed.DirectMonotonicWriter; @@ -190,9 +188,10 @@ private static DocsWithFieldSet writeVectorData(IndexOutput output, FloatVectorV DocsWithFieldSet docsWithField = new DocsWithFieldSet(); ByteBuffer binaryVector = ByteBuffer.allocate(vectors.dimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - for (int docV = vectors.nextDoc(); docV != NO_MORE_DOCS; docV = vectors.nextDoc()) { + for (int ord = 0; ord < vectors.size(); ord++) { + int docV = vectors.ordToDoc(ord); // write vector - float[] vectorValue = vectors.vectorValue(); + float[] vectorValue = vectors.vectorValue(ord); binaryVector.asFloatBuffer().put(vectorValue); output.writeBytes(binaryVector.array(), binaryVector.limit()); docsWithField.add(docV); @@ -277,7 +276,7 @@ private void writeMeta( } private OnHeapHnswGraph writeGraph( - RandomAccessVectorValues.Floats vectorValues, VectorSimilarityFunction similarityFunction) + FloatVectorValues vectorValues, VectorSimilarityFunction similarityFunction) throws IOException { DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer(); // build graph diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java 
b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java index 1cb445cab776..476f2502e0e5 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java @@ -18,7 +18,6 @@ package org.apache.lucene.backward_codecs.lucene94; import static org.apache.lucene.backward_codecs.lucene94.Lucene94RWHnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; import java.nio.ByteBuffer; @@ -52,7 +51,6 @@ import org.apache.lucene.util.hnsw.HnswGraphBuilder; import org.apache.lucene.util.hnsw.NeighborArray; import org.apache.lucene.util.hnsw.OnHeapHnswGraph; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.packed.DirectMonotonicWriter; @@ -590,11 +588,10 @@ private void writeMeta( private static DocsWithFieldSet writeByteVectorData( IndexOutput output, ByteVectorValues byteVectorValues) throws IOException { DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - for (int docV = byteVectorValues.nextDoc(); - docV != NO_MORE_DOCS; - docV = byteVectorValues.nextDoc()) { + for (int ord = 0; ord < byteVectorValues.size(); ord++) { + int docV = byteVectorValues.ordToDoc(ord); // write vector - byte[] binaryValue = byteVectorValues.vectorValue(); + byte[] binaryValue = byteVectorValues.vectorValue(ord); assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize; output.writeBytes(binaryValue, binaryValue.length); docsWithField.add(docV); @@ -611,11 +608,10 @@ private static DocsWithFieldSet writeVectorData( ByteBuffer binaryVector = ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize) .order(ByteOrder.LITTLE_ENDIAN); 
- for (int docV = floatVectorValues.nextDoc(); - docV != NO_MORE_DOCS; - docV = floatVectorValues.nextDoc()) { + for (int ord = 0; ord < floatVectorValues.size(); ord++) { + int docV = floatVectorValues.ordToDoc(ord); // write vector - float[] vectorValue = floatVectorValues.vectorValue(); + float[] vectorValue = floatVectorValues.vectorValue(ord); binaryVector.asFloatBuffer().put(vectorValue); output.writeBytes(binaryVector.array(), binaryVector.limit()); docsWithField.add(docV); @@ -672,11 +668,11 @@ public float[] copyValue(float[] value) { case BYTE -> defaultFlatVectorScorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), - RandomAccessVectorValues.fromBytes((List) vectors, dim)); + ByteVectorValues.fromBytes((List) vectors, dim)); case FLOAT32 -> defaultFlatVectorScorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), - RandomAccessVectorValues.fromFloats((List) vectors, dim)); + FloatVectorValues.fromFloats((List) vectors, dim)); }; hnswGraphBuilder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java index 37c39d311d6b..db7081d04c4f 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java @@ -39,6 +39,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; @@ -56,7 +57,6 @@ import 
org.apache.lucene.util.hnsw.IncrementalHnswGraphMerger; import org.apache.lucene.util.hnsw.NeighborArray; import org.apache.lucene.util.hnsw.OnHeapHnswGraph; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.packed.DirectMonotonicWriter; @@ -222,7 +222,7 @@ private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocM int offset = 1; // 0 means no vector for this (field, document) DocIdSetIterator iterator = fieldData.docsWithField.iterator(); for (int docID = iterator.nextDoc(); - docID != DocIdSetIterator.NO_MORE_DOCS; + docID != NO_MORE_DOCS; docID = iterator.nextDoc()) { int newDocID = sortMap.oldToNew(docID); docIdOffsets[newDocID] = offset++; @@ -482,18 +482,18 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]); } } - DocIdSetIterator mergedVectorIterator = null; + KnnVectorValues mergedVectorValues = null; switch (fieldInfo.getVectorEncoding()) { case BYTE -> - mergedVectorIterator = + mergedVectorValues = KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); case FLOAT32 -> - mergedVectorIterator = + mergedVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); } graph = merger.merge( - mergedVectorIterator, segmentWriteState.infoStream, docsWithField.cardinality()); + mergedVectorValues, segmentWriteState.infoStream, docsWithField.cardinality()); vectorIndexNodeOffsets = writeGraph(graph); } long vectorIndexLength = vectorIndex.getFilePointer() - vectorIndexOffset; @@ -636,11 +636,10 @@ private void writeMeta( private static DocsWithFieldSet writeByteVectorData( IndexOutput output, ByteVectorValues byteVectorValues) throws IOException { DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - for (int docV = byteVectorValues.nextDoc(); - docV != 
NO_MORE_DOCS; - docV = byteVectorValues.nextDoc()) { + for (int ord = 0; ord < byteVectorValues.size(); ord++) { + int docV = byteVectorValues.ordToDoc(ord); // write vector - byte[] binaryValue = byteVectorValues.vectorValue(); + byte[] binaryValue = byteVectorValues.vectorValue(ord); assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize; output.writeBytes(binaryValue, binaryValue.length); docsWithField.add(docV); @@ -657,11 +656,10 @@ private static DocsWithFieldSet writeVectorData( ByteBuffer buffer = ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize) .order(ByteOrder.LITTLE_ENDIAN); - for (int docV = floatVectorValues.nextDoc(); - docV != NO_MORE_DOCS; - docV = floatVectorValues.nextDoc()) { + for (int ord = 0; ord < floatVectorValues.size(); ord++) { + int docV = floatVectorValues.ordToDoc(ord); // write vector - float[] value = floatVectorValues.vectorValue(); + float[] value = floatVectorValues.vectorValue(ord); buffer.asFloatBuffer().put(value); output.writeBytes(buffer.array(), buffer.limit()); docsWithField.add(docV); @@ -718,11 +716,11 @@ public float[] copyValue(float[] value) { case BYTE -> defaultFlatVectorScorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), - RandomAccessVectorValues.fromBytes((List) vectors, dim)); + ByteVectorValues.fromBytes((List) vectors, dim)); case FLOAT32 -> defaultFlatVectorScorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), - RandomAccessVectorValues.fromFloats((List) vectors, dim)); + FloatVectorValues.fromFloats((List) vectors, dim)); }; hnswGraphBuilder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, HnswGraphBuilder.randSeed); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBasicBackwardsCompatibility.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBasicBackwardsCompatibility.java index 8d35a1128be9..a223d24608ad 100644 --- 
a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBasicBackwardsCompatibility.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBasicBackwardsCompatibility.java @@ -477,10 +477,15 @@ public static void searchIndex( FloatVectorValues values = ctx.reader().getFloatVectorValues(KNN_VECTOR_FIELD); if (values != null) { assertEquals(KNN_VECTOR_FIELD_TYPE.vectorDimension(), values.dimension()); - for (int doc = values.nextDoc(); doc != NO_MORE_DOCS; doc = values.nextDoc()) { + for (int doc = values.iterator().nextDoc(); + doc != NO_MORE_DOCS; + doc = values.iterator().nextDoc()) { float[] expectedVector = {KNN_VECTOR[0], KNN_VECTOR[1], KNN_VECTOR[2] + 0.1f * cnt}; assertArrayEquals( - "vectors do not match for doc=" + cnt, expectedVector, values.vectorValue(), 0); + "vectors do not match for doc=" + cnt, + expectedVector, + values.vectorValue(values.iterator().index()), + 0); cnt++; } } diff --git a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerBenchmark.java b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerBenchmark.java index c4d3040f2835..0a4da1f48867 100644 --- a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerBenchmark.java +++ b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorScorerBenchmark.java @@ -25,6 +25,7 @@ import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; @@ -32,7 +33,6 @@ import org.apache.lucene.store.IndexOutput; import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.util.IOUtils; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import 
org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.openjdk.jmh.annotations.*; @@ -55,7 +55,7 @@ public class VectorScorerBenchmark { Directory dir; IndexInput in; - RandomAccessVectorValues vectorValues; + KnnVectorValues vectorValues; byte[] vec1, vec2; RandomVectorScorer scorer; @@ -95,7 +95,7 @@ public float binaryDotProductMemSeg() throws IOException { return scorer.score(1); } - static RandomAccessVectorValues vectorValues( + static KnnVectorValues vectorValues( int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { return new OffHeapByteVectorValues.DenseOffHeapVectorValues( dims, size, in.slice("test", 0, in.length()), dims, new ThrowingFlatVectorScorer(), sim); @@ -105,23 +105,19 @@ static final class ThrowingFlatVectorScorer implements FlatVectorsScorer { @Override public RandomVectorScorerSupplier getRandomVectorScorerSupplier( - VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) { + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) { throw new UnsupportedOperationException(); } @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - float[] target) { + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) { throw new UnsupportedOperationException(); } @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - byte[] target) { + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) { throw new UnsupportedOperationException(); } } diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java b/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java index b8ff37c2654a..8ffcc1c8d50e 
100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/bitvectors/FlatBitVectorsScorer.java @@ -19,10 +19,11 @@ import java.io.IOException; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.Bits; import org.apache.lucene.util.VectorUtil; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; @@ -30,45 +31,39 @@ public class FlatBitVectorsScorer implements FlatVectorsScorer { @Override public RandomVectorScorerSupplier getRandomVectorScorerSupplier( - VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) throws IOException { - assert vectorValues instanceof RandomAccessVectorValues.Bytes; - if (vectorValues instanceof RandomAccessVectorValues.Bytes byteVectorValues) { + assert vectorValues instanceof ByteVectorValues; + if (vectorValues instanceof ByteVectorValues byteVectorValues) { return new BitRandomVectorScorerSupplier(byteVectorValues); } - throw new IllegalArgumentException( - "vectorValues must be an instance of RandomAccessVectorValues.Bytes"); + throw new IllegalArgumentException("vectorValues must be an instance of ByteVectorValues"); } @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - float[] target) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) throws IOException { throw new IllegalArgumentException("bit vectors do not support float[] targets"); } @Override public RandomVectorScorer 
getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - byte[] target) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) throws IOException { - assert vectorValues instanceof RandomAccessVectorValues.Bytes; - if (vectorValues instanceof RandomAccessVectorValues.Bytes byteVectorValues) { + assert vectorValues instanceof ByteVectorValues; + if (vectorValues instanceof ByteVectorValues byteVectorValues) { return new BitRandomVectorScorer(byteVectorValues, target); } - throw new IllegalArgumentException( - "vectorValues must be an instance of RandomAccessVectorValues.Bytes"); + throw new IllegalArgumentException("vectorValues must be an instance of ByteVectorValues"); } static class BitRandomVectorScorer implements RandomVectorScorer { - private final RandomAccessVectorValues.Bytes vectorValues; + private final ByteVectorValues vectorValues; private final int bitDimensions; private final byte[] query; - BitRandomVectorScorer(RandomAccessVectorValues.Bytes vectorValues, byte[] query) { + BitRandomVectorScorer(ByteVectorValues vectorValues, byte[] query) { this.query = query; this.bitDimensions = vectorValues.dimension() * Byte.SIZE; this.vectorValues = vectorValues; @@ -97,12 +92,11 @@ public Bits getAcceptOrds(Bits acceptDocs) { } static class BitRandomVectorScorerSupplier implements RandomVectorScorerSupplier { - protected final RandomAccessVectorValues.Bytes vectorValues; - protected final RandomAccessVectorValues.Bytes vectorValues1; - protected final RandomAccessVectorValues.Bytes vectorValues2; + protected final ByteVectorValues vectorValues; + protected final ByteVectorValues vectorValues1; + protected final ByteVectorValues vectorValues2; - public BitRandomVectorScorerSupplier(RandomAccessVectorValues.Bytes vectorValues) - throws IOException { + public BitRandomVectorScorerSupplier(ByteVectorValues vectorValues) throws IOException { this.vectorValues = 
vectorValues; this.vectorValues1 = vectorValues.copy(); this.vectorValues2 = vectorValues.copy(); diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index faba629715b7..66faf0f24fd0 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -192,8 +192,8 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits } FieldInfo info = readState.fieldInfos.fieldInfo(field); VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction(); - int doc; - while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + for (int ord = 0; ord < values.size(); ord++) { + int doc = values.ordToDoc(ord); if (acceptDocs != null && acceptDocs.get(doc) == false) { continue; } @@ -202,7 +202,7 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits break; } - float[] vector = values.vectorValue(); + float[] vector = values.vectorValue(ord); float score = vectorSimilarity.compare(vector, target); knnCollector.collect(doc, score); knnCollector.incVisitedCount(1); @@ -223,8 +223,8 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits FieldInfo info = readState.fieldInfos.fieldInfo(field); VectorSimilarityFunction vectorSimilarity = info.getVectorSimilarityFunction(); - int doc; - while ((doc = values.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) { + for (int ord = 0; ord < values.size(); ord++) { + int doc = values.ordToDoc(ord); if (acceptDocs != null && acceptDocs.get(doc) == false) { continue; } @@ -233,7 +233,7 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits break; } - byte[] vector = values.vectorValue(); + byte[] vector = values.vectorValue(ord); float score = 
vectorSimilarity.compare(vector, target); knnCollector.collect(doc, score); knnCollector.incVisitedCount(1); @@ -327,35 +327,8 @@ public int size() { } @Override - public float[] vectorValue() { - return values[curOrd]; - } - - @Override - public int docID() { - if (curOrd == -1) { - return -1; - } else if (curOrd >= entry.size()) { - // when call to advance / nextDoc below already returns NO_MORE_DOCS, calling docID - // immediately afterward should also return NO_MORE_DOCS - // this is needed for TestSimpleTextKnnVectorsFormat.testAdvance test case - return NO_MORE_DOCS; - } - - return entry.ordToDoc[curOrd]; - } - - @Override - public int nextDoc() throws IOException { - if (++curOrd < entry.size()) { - return docID(); - } - return NO_MORE_DOCS; - } - - @Override - public int advance(int target) throws IOException { - return slowAdvance(target); + public float[] vectorValue(int ord) { + return values[ord]; } @Override @@ -368,14 +341,16 @@ public VectorScorer scorer(float[] target) { return new VectorScorer() { @Override public float score() throws IOException { + int docId = simpleTextFloatVectorValues.iterator().docID(); + int ord = simpleTextFloatVectorValues.docToOrd(docId); return entry .similarityFunction() - .compare(simpleTextFloatVectorValues.vectorValue(), target); + .compare(simpleTextFloatVectorValues.vectorValue(ord), target); } @Override public DocIdSetIterator iterator() { - return simpleTextFloatVectorValues; + return simpleTextFloatVectorValues.iterator(); } }; } @@ -439,38 +414,11 @@ public int size() { } @Override - public byte[] vectorValue() { - binaryValue.bytes = values[curOrd]; + public byte[] vectorValue(int ord) { + binaryValue.bytes = values[ord]; return binaryValue.bytes; } - @Override - public int docID() { - if (curOrd == -1) { - return -1; - } else if (curOrd >= entry.size()) { - // when call to advance / nextDoc below already returns NO_MORE_DOCS, calling docID - // immediately afterward should also return NO_MORE_DOCS - // this 
is needed for TestSimpleTextKnnVectorsFormat.testAdvance test case - return NO_MORE_DOCS; - } - - return entry.ordToDoc[curOrd]; - } - - @Override - public int nextDoc() throws IOException { - if (++curOrd < entry.size()) { - return docID(); - } - return NO_MORE_DOCS; - } - - @Override - public int advance(int target) throws IOException { - return slowAdvance(target); - } - @Override public VectorScorer scorer(byte[] target) { if (size() == 0) { @@ -480,14 +428,16 @@ public VectorScorer scorer(byte[] target) { return new VectorScorer() { @Override public float score() throws IOException { + int docId = simpleTextByteVectorValues.iterator().docID(); + int ord = simpleTextByteVectorValues.docToOrd(docId); return entry .similarityFunction() - .compare(simpleTextByteVectorValues.vectorValue(), target); + .compare(simpleTextByteVectorValues.vectorValue(ord), target); } @Override public DocIdSetIterator iterator() { - return simpleTextByteVectorValues; + return simpleTextByteVectorValues.iterator(); } }; } diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java index a7a76ac1bb98..c71037f2ebd8 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java @@ -17,8 +17,6 @@ package org.apache.lucene.codecs.simpletext; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -77,19 +75,17 @@ public void writeField(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, throws IOException { long vectorDataOffset = vectorData.getFilePointer(); List docIds = new ArrayList<>(); - for (int docV = floatVectorValues.nextDoc(); - docV != NO_MORE_DOCS; - docV = floatVectorValues.nextDoc()) { - 
writeFloatVectorValue(floatVectorValues); - docIds.add(docV); + for (int ord = 0; ord < floatVectorValues.size(); ord++) { + writeFloatVectorValue(floatVectorValues, ord); + docIds.add(floatVectorValues.ordToDoc(ord)); } long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; writeMeta(fieldInfo, vectorDataOffset, vectorDataLength, docIds); } - private void writeFloatVectorValue(FloatVectorValues vectors) throws IOException { + private void writeFloatVectorValue(FloatVectorValues vectors, int ord) throws IOException { // write vector value - float[] value = vectors.vectorValue(); + float[] value = vectors.vectorValue(ord); assert value.length == vectors.dimension(); write(vectorData, Arrays.toString(value)); newline(vectorData); @@ -100,19 +96,17 @@ public void writeField(FieldInfo fieldInfo, ByteVectorValues byteVectorValues, i throws IOException { long vectorDataOffset = vectorData.getFilePointer(); List docIds = new ArrayList<>(); - for (int docV = byteVectorValues.nextDoc(); - docV != NO_MORE_DOCS; - docV = byteVectorValues.nextDoc()) { - writeByteVectorValue(byteVectorValues); - docIds.add(docV); + for (int ord = 0; ord < byteVectorValues.size(); ord++) { + writeByteVectorValue(byteVectorValues, ord); + docIds.add(byteVectorValues.ordToDoc(ord)); } long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; writeMeta(fieldInfo, vectorDataOffset, vectorDataLength, docIds); } - private void writeByteVectorValue(ByteVectorValues vectors) throws IOException { + private void writeByteVectorValue(ByteVectorValues vectors, int ord) throws IOException { // write vector value - byte[] value = vectors.vectorValue(); + byte[] value = vectors.vectorValue(ord); assert value.length == vectors.dimension(); write(vectorData, Arrays.toString(value)); newline(vectorData); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java index 
8a9b4816571e..675025a1c93e 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java @@ -24,10 +24,9 @@ import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.Sorter; -import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.VectorScorer; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.RamUsageEstimator; @@ -80,9 +79,7 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { case FLOAT32: BufferedFloatVectorValues bufferedFloatVectorValues = new BufferedFloatVectorValues( - fieldData.docsWithField, - (List) fieldData.vectors, - fieldData.fieldInfo.getVectorDimension()); + (List) fieldData.vectors, fieldData.fieldInfo.getVectorDimension()); FloatVectorValues floatVectorValues = sortMap != null ? new SortingFloatVectorValues(bufferedFloatVectorValues, sortMap) @@ -92,9 +89,7 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { case BYTE: BufferedByteVectorValues bufferedByteVectorValues = new BufferedByteVectorValues( - fieldData.docsWithField, - (List) fieldData.vectors, - fieldData.fieldInfo.getVectorDimension()); + (List) fieldData.vectors, fieldData.fieldInfo.getVectorDimension()); ByteVectorValues byteVectorValues = sortMap != null ? 
new SortingByteVectorValues(bufferedByteVectorValues, sortMap) @@ -105,126 +100,88 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { } } - /** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */ + /** + * Sorting FloatVectorValues that maps ordinals using the provided sortMap expressed in terms of + * docids + */ private static class SortingFloatVectorValues extends FloatVectorValues { - private final BufferedFloatVectorValues randomAccess; - private final int[] docIdOffsets; - private int docId = -1; + private final BufferedFloatVectorValues delegate; + private final int[] newToOld; SortingFloatVectorValues(BufferedFloatVectorValues delegate, Sorter.DocMap sortMap) throws IOException { - this.randomAccess = delegate.copy(); - this.docIdOffsets = new int[sortMap.size()]; - - int offset = 1; // 0 means no vector for this (field, document) - int docID; - while ((docID = delegate.nextDoc()) != NO_MORE_DOCS) { - int newDocID = sortMap.oldToNew(docID); - docIdOffsets[newDocID] = offset++; - } + this.delegate = delegate.copy(); + newToOld = docMapToOrdMap(delegate, sortMap); } @Override - public int docID() { - return docId; - } - - @Override - public int nextDoc() throws IOException { - while (docId < docIdOffsets.length - 1) { - ++docId; - if (docIdOffsets[docId] != 0) { - return docId; - } - } - docId = NO_MORE_DOCS; - return docId; - } - - @Override - public float[] vectorValue() throws IOException { - return randomAccess.vectorValue(docIdOffsets[docId] - 1); + public float[] vectorValue(int ord) throws IOException { + return delegate.vectorValue(newToOld[ord]); } @Override public int dimension() { - return randomAccess.dimension(); + return delegate.dimension(); } @Override public int size() { - return randomAccess.size(); + return delegate.size(); } @Override - public int advance(int target) throws IOException { + public SortingFloatVectorValues copy() { throw new UnsupportedOperationException(); 
} + } - @Override - public VectorScorer scorer(float[] target) { - throw new UnsupportedOperationException(); + private static int[] docMapToOrdMap(KnnVectorValues values, Sorter.DocMap docMap) { + int[] newToOld = new int[docMap.size()]; + int ord = 0; + for (int doc = 0; doc < newToOld.length; doc++) { + int oldDoc = docMap.newToOld(doc); + int oldOrd = values.docToOrd(oldDoc); // no value represented by -1 + if (oldOrd >= 0) { + newToOld[ord++] = oldOrd; + } + } + if (ord < newToOld.length) { + newToOld = ArrayUtil.copyOfSubArray(newToOld, 0, ord); } + return newToOld; } - /** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */ + /** + * Sorting ByteVectorValues that maps ordinals using the provided sortMap expressed in terms of + * docids + */ private static class SortingByteVectorValues extends ByteVectorValues { - private final BufferedByteVectorValues randomAccess; - private final int[] docIdOffsets; - private int docId = -1; + private final BufferedByteVectorValues delegate; + private final int[] newToOld; SortingByteVectorValues(BufferedByteVectorValues delegate, Sorter.DocMap sortMap) throws IOException { - this.randomAccess = delegate.copy(); - this.docIdOffsets = new int[sortMap.size()]; - - int offset = 1; // 0 means no vector for this (field, document) - int docID; - while ((docID = delegate.nextDoc()) != NO_MORE_DOCS) { - int newDocID = sortMap.oldToNew(docID); - docIdOffsets[newDocID] = offset++; - } + this.delegate = delegate; + newToOld = docMapToOrdMap(delegate, sortMap); } @Override - public int docID() { - return docId; - } - - @Override - public int nextDoc() throws IOException { - while (docId < docIdOffsets.length - 1) { - ++docId; - if (docIdOffsets[docId] != 0) { - return docId; - } - } - docId = NO_MORE_DOCS; - return docId; - } - - @Override - public byte[] vectorValue() throws IOException { - return randomAccess.vectorValue(docIdOffsets[docId] - 1); + public byte[] vectorValue(int ord) throws 
IOException { + return delegate.vectorValue(newToOld[ord]); } @Override public int dimension() { - return randomAccess.dimension(); + return delegate.dimension(); } @Override public int size() { - return randomAccess.size(); - } - - @Override - public int advance(int target) throws IOException { - throw new UnsupportedOperationException(); + return delegate.size(); } @Override - public VectorScorer scorer(byte[] target) { + public SortingByteVectorValues copy() { throw new UnsupportedOperationException(); } } @@ -296,7 +253,9 @@ public final void addValue(int docID, T value) { @Override public final long ramBytesUsed() { - if (vectors.size() == 0) return 0; + if (vectors.isEmpty()) { + return 0; + } return docsWithField.ramBytesUsed() + vectors.size() * (long) @@ -307,25 +266,13 @@ public final long ramBytesUsed() { } private static class BufferedFloatVectorValues extends FloatVectorValues { - final DocsWithFieldSet docsWithField; - // These are always the vectors of a VectorValuesWriter, which are copied when added to it final List vectors; final int dimension; - DocIdSetIterator docsWithFieldIter; - int ord = -1; - - BufferedFloatVectorValues( - DocsWithFieldSet docsWithField, List vectors, int dimension) { - this.docsWithField = docsWithField; + BufferedFloatVectorValues(List vectors, int dimension) { this.vectors = vectors; this.dimension = dimension; - docsWithFieldIter = docsWithField.iterator(); - } - - public BufferedFloatVectorValues copy() { - return new BufferedFloatVectorValues(docsWithField, vectors, dimension); } @Override @@ -339,58 +286,24 @@ public int size() { } @Override - public float[] vectorValue() { - return vectors.get(ord); - } - - float[] vectorValue(int targetOrd) { + public float[] vectorValue(int targetOrd) { return vectors.get(targetOrd); } @Override - public int docID() { - return docsWithFieldIter.docID(); - } - - @Override - public int nextDoc() throws IOException { - int docID = docsWithFieldIter.nextDoc(); - if (docID != 
NO_MORE_DOCS) { - ++ord; - } - return docID; - } - - @Override - public int advance(int target) { - throw new UnsupportedOperationException(); - } - - @Override - public VectorScorer scorer(float[] target) { - throw new UnsupportedOperationException(); + public BufferedFloatVectorValues copy() { + return this; } } private static class BufferedByteVectorValues extends ByteVectorValues { - final DocsWithFieldSet docsWithField; - // These are always the vectors of a VectorValuesWriter, which are copied when added to it final List vectors; final int dimension; - DocIdSetIterator docsWithFieldIter; - int ord = -1; - - BufferedByteVectorValues(DocsWithFieldSet docsWithField, List vectors, int dimension) { - this.docsWithField = docsWithField; + BufferedByteVectorValues(List vectors, int dimension) { this.vectors = vectors; this.dimension = dimension; - docsWithFieldIter = docsWithField.iterator(); - } - - public BufferedByteVectorValues copy() { - return new BufferedByteVectorValues(docsWithField, vectors, dimension); } @Override @@ -404,36 +317,13 @@ public int size() { } @Override - public byte[] vectorValue() { - return vectors.get(ord); - } - - byte[] vectorValue(int targetOrd) { + public byte[] vectorValue(int targetOrd) { return vectors.get(targetOrd); } @Override - public int docID() { - return docsWithFieldIter.docID(); - } - - @Override - public int nextDoc() throws IOException { - int docID = docsWithFieldIter.nextDoc(); - if (docID != NO_MORE_DOCS) { - ++ord; - } - return docID; - } - - @Override - public int advance(int target) { - throw new UnsupportedOperationException(); - } - - @Override - public VectorScorer scorer(byte[] target) { - throw new UnsupportedOperationException(); + public BufferedByteVectorValues copy() { + return this; } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java index 3b185fd13a07..408a336f1561 100644 --- 
a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java @@ -60,10 +60,8 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE (KnnFieldVectorsWriter) addField(fieldInfo); ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); - for (int doc = mergedBytes.nextDoc(); - doc != DocIdSetIterator.NO_MORE_DOCS; - doc = mergedBytes.nextDoc()) { - byteWriter.addValue(doc, mergedBytes.vectorValue()); + for (int ord = 0; ord < mergedBytes.size(); ord++) { + byteWriter.addValue(mergedBytes.ordToDoc(ord), mergedBytes.vectorValue(ord)); } break; case FLOAT32: @@ -71,10 +69,8 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE (KnnFieldVectorsWriter) addField(fieldInfo); FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); - for (int doc = mergedFloats.nextDoc(); - doc != DocIdSetIterator.NO_MORE_DOCS; - doc = mergedFloats.nextDoc()) { - floatWriter.addValue(doc, mergedFloats.vectorValue()); + for (int ord = 0; ord < mergedFloats.size(); ord++) { + floatWriter.addValue(mergedFloats.ordToDoc(ord), mergedFloats.vectorValue(ord)); } break; } @@ -118,15 +114,16 @@ private static class FloatVectorValuesSub extends DocIDMerger.Sub { final FloatVectorValues values; + int ord = -1; + FloatVectorValuesSub(MergeState.DocMap docMap, FloatVectorValues values) { super(docMap); this.values = values; - assert values.docID() == -1; } @Override public int nextDoc() throws IOException { - return values.nextDoc(); + return values.ordToDoc(++ord); } } @@ -134,15 +131,16 @@ private static class ByteVectorValuesSub extends DocIDMerger.Sub { final ByteVectorValues values; + int ord = -1; + ByteVectorValuesSub(MergeState.DocMap docMap, ByteVectorValues values) { super(docMap); this.values = values; - assert values.docID() == -1; } @Override public int nextDoc() throws 
IOException { - return values.nextDoc(); + return values.ordToDoc(++ord); } } @@ -287,44 +285,48 @@ static class MergedFloat32VectorValues extends FloatVectorValues { private final List subs; private final DocIDMerger docIdMerger; private final int size; - private int docId; - FloatVectorValuesSub current; + private final int[] ends; private MergedFloat32VectorValues(List subs, MergeState mergeState) throws IOException { this.subs = subs; docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort); int totalSize = 0; + ends = new int[subs.size()]; + int iSub = 0; for (FloatVectorValuesSub sub : subs) { totalSize += sub.values.size(); + ends[iSub++] = totalSize; } size = totalSize; - docId = -1; - } - - @Override - public int docID() { - return docId; } @Override - public int nextDoc() throws IOException { - current = docIdMerger.next(); - if (current == null) { - docId = NO_MORE_DOCS; + public float[] vectorValue(int ord) throws IOException { + // TODO: if this is a bottleneck we could predict the last iSub + int iSub = Arrays.binarySearch(ends, ord); + if (iSub < 0) { + iSub = -(iSub + 1); + } + if (iSub == 0) { + return subs.get(iSub).values.vectorValue(ord); } else { - docId = current.mappedDocID; + return subs.get(iSub).values.vectorValue(ord - ends[iSub - 1]); } - return docId; } @Override - public float[] vectorValue() throws IOException { - return current.values.vectorValue(); + public int ordToDoc(int ord) { + throw new UnsupportedOperationException(); + /** + * TODO: is this needed? 
int iSub = Arrays.binarySearch(ends, ord); if (iSub < 0) { iSub = + * -(iSub + 1); } if (iSub == 0) { return subs.get(iSub).values.ordToDoc(ord); } else { + * return subs.get(iSub).values.ordToDoc(ord - ends[iSub - 1]); } + */ } @Override - public int advance(int target) { + public int docToOrd(int ord) { throw new UnsupportedOperationException(); } @@ -342,12 +344,18 @@ public int dimension() { public VectorScorer scorer(float[] target) { throw new UnsupportedOperationException(); } + + @Override + public FloatVectorValues copy() { + throw new UnsupportedOperationException(); + } } static class MergedByteVectorValues extends ByteVectorValues { private final List subs; private final DocIDMerger docIdMerger; private final int size; + private final int[] ends; private int docId; ByteVectorValuesSub current; @@ -357,37 +365,36 @@ private MergedByteVectorValues(List subs, MergeState mergeS this.subs = subs; docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort); int totalSize = 0; + ends = new int[subs.size()]; + int iSub = 0; for (ByteVectorValuesSub sub : subs) { totalSize += sub.values.size(); + ends[iSub++] = totalSize; } size = totalSize; - docId = -1; } @Override - public byte[] vectorValue() throws IOException { - return current.values.vectorValue(); - } - - @Override - public int docID() { - return docId; - } - - @Override - public int nextDoc() throws IOException { - current = docIdMerger.next(); - if (current == null) { - docId = NO_MORE_DOCS; + public byte[] vectorValue(int ord) throws IOException { + int iSub = Arrays.binarySearch(ends, ord); + if (iSub < 0) { + iSub = -(iSub + 1); + } + if (iSub == 0) { + return subs.get(iSub).values.vectorValue(ord); } else { - docId = current.mappedDocID; + return subs.get(iSub).values.vectorValue(ord - ends[iSub - 1]); } - return docId; } @Override - public int advance(int target) { + public int ordToDoc(int ord) { throw new UnsupportedOperationException(); + /** + * TODO: is this needed? 
int iSub = Arrays.binarySearch(ends, ord); if (iSub < 0) { iSub = + * -(iSub + 1); } if (iSub == 0) { return subs.get(iSub).values.ordToDoc(ord); } else { + * return subs.get(iSub).values.ordToDoc(ord - ends[iSub - 1]); } + */ } @Override @@ -404,6 +411,11 @@ public int dimension() { public VectorScorer scorer(byte[] target) { throw new UnsupportedOperationException(); } + + @Override + public ByteVectorValues copy() { + throw new UnsupportedOperationException(); + } } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java index 1274e1c789e4..e1f1b24d3685 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java @@ -18,8 +18,10 @@ package org.apache.lucene.codecs.hnsw; import java.io.IOException; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; @@ -34,24 +36,22 @@ public class DefaultFlatVectorScorer implements FlatVectorsScorer { @Override public RandomVectorScorerSupplier getRandomVectorScorerSupplier( - VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) throws IOException { - if (vectorValues instanceof RandomAccessVectorValues.Floats floatVectorValues) { + if (vectorValues instanceof FloatVectorValues floatVectorValues) { return new FloatScoringSupplier(floatVectorValues, similarityFunction); - } else if (vectorValues instanceof RandomAccessVectorValues.Bytes byteVectorValues) { + 
} else if (vectorValues instanceof ByteVectorValues byteVectorValues) { return new ByteScoringSupplier(byteVectorValues, similarityFunction); } throw new IllegalArgumentException( - "vectorValues must be an instance of RandomAccessVectorValues.Floats or RandomAccessVectorValues.Bytes"); + "vectorValues must be an instance of FloatVectorValues or ByteVectorValues"); } @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - float[] target) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) throws IOException { - assert vectorValues instanceof RandomAccessVectorValues.Floats; + assert vectorValues instanceof FloatVectorValues; if (target.length != vectorValues.dimension()) { throw new IllegalArgumentException( "vector query dimension: " @@ -59,17 +59,14 @@ public RandomVectorScorer getRandomVectorScorer( + " differs from field dimension: " + vectorValues.dimension()); } - return new FloatVectorScorer( - (RandomAccessVectorValues.Floats) vectorValues, target, similarityFunction); + return new FloatVectorScorer((FloatVectorValues) vectorValues, target, similarityFunction); } @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - byte[] target) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) throws IOException { - assert vectorValues instanceof RandomAccessVectorValues.Bytes; + assert vectorValues instanceof ByteVectorValues; if (target.length != vectorValues.dimension()) { throw new IllegalArgumentException( "vector query dimension: " @@ -77,8 +74,7 @@ public RandomVectorScorer getRandomVectorScorer( + " differs from field dimension: " + vectorValues.dimension()); } - return new ByteVectorScorer( - (RandomAccessVectorValues.Bytes) vectorValues, target, similarityFunction); + return new 
ByteVectorScorer((ByteVectorValues) vectorValues, target, similarityFunction); } @Override @@ -88,14 +84,13 @@ public String toString() { /** RandomVectorScorerSupplier for bytes vector */ private static final class ByteScoringSupplier implements RandomVectorScorerSupplier { - private final RandomAccessVectorValues.Bytes vectors; - private final RandomAccessVectorValues.Bytes vectors1; - private final RandomAccessVectorValues.Bytes vectors2; + private final ByteVectorValues vectors; + private final ByteVectorValues vectors1; + private final ByteVectorValues vectors2; private final VectorSimilarityFunction similarityFunction; private ByteScoringSupplier( - RandomAccessVectorValues.Bytes vectors, VectorSimilarityFunction similarityFunction) - throws IOException { + ByteVectorValues vectors, VectorSimilarityFunction similarityFunction) throws IOException { this.vectors = vectors; vectors1 = vectors.copy(); vectors2 = vectors.copy(); @@ -125,14 +120,13 @@ public String toString() { /** RandomVectorScorerSupplier for Float vector */ private static final class FloatScoringSupplier implements RandomVectorScorerSupplier { - private final RandomAccessVectorValues.Floats vectors; - private final RandomAccessVectorValues.Floats vectors1; - private final RandomAccessVectorValues.Floats vectors2; + private final FloatVectorValues vectors; + private final FloatVectorValues vectors1; + private final FloatVectorValues vectors2; private final VectorSimilarityFunction similarityFunction; private FloatScoringSupplier( - RandomAccessVectorValues.Floats vectors, VectorSimilarityFunction similarityFunction) - throws IOException { + FloatVectorValues vectors, VectorSimilarityFunction similarityFunction) throws IOException { this.vectors = vectors; vectors1 = vectors.copy(); vectors2 = vectors.copy(); @@ -162,14 +156,12 @@ public String toString() { /** A {@link RandomVectorScorer} for float vectors. 
*/ private static class FloatVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { - private final RandomAccessVectorValues.Floats values; + private final FloatVectorValues values; private final float[] query; private final VectorSimilarityFunction similarityFunction; public FloatVectorScorer( - RandomAccessVectorValues.Floats values, - float[] query, - VectorSimilarityFunction similarityFunction) { + FloatVectorValues values, float[] query, VectorSimilarityFunction similarityFunction) { super(values); this.values = values; this.query = query; @@ -184,14 +176,12 @@ public float score(int node) throws IOException { /** A {@link RandomVectorScorer} for byte vectors. */ private static class ByteVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { - private final RandomAccessVectorValues.Bytes values; + private final ByteVectorValues values; private final byte[] query; private final VectorSimilarityFunction similarityFunction; public ByteVectorScorer( - RandomAccessVectorValues.Bytes values, - byte[] query, - VectorSimilarityFunction similarityFunction) { + ByteVectorValues values, byte[] query, VectorSimilarityFunction similarityFunction) { super(values); this.values = values; this.query = query; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsScorer.java index 17430c24f276..6ed170731de4 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsScorer.java @@ -18,8 +18,8 @@ package org.apache.lucene.codecs.hnsw; import java.io.IOException; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; @@ -40,8 +40,7 @@ 
public interface FlatVectorsScorer { * @throws IOException if an I/O error occurs */ RandomVectorScorerSupplier getRandomVectorScorerSupplier( - VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) - throws IOException; + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) throws IOException; /** * Returns a {@link RandomVectorScorer} for the given set of vectors and target vector. @@ -53,9 +52,7 @@ RandomVectorScorerSupplier getRandomVectorScorerSupplier( * @throws IOException if an I/O error occurs when reading from the index. */ RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - float[] target) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) throws IOException; /** @@ -68,8 +65,6 @@ RandomVectorScorer getRandomVectorScorer( * @throws IOException if an I/O error occurs when reading from the index. */ RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - byte[] target) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) throws IOException; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java index 4b73e1f7a4a6..ceb826aa3a11 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/ScalarQuantizedVectorScorer.java @@ -18,13 +18,13 @@ package org.apache.lucene.codecs.hnsw; import java.io.IOException; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.VectorUtil; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; 
import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; -import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizedVectorSimilarity; import org.apache.lucene.util.quantization.ScalarQuantizer; @@ -60,9 +60,9 @@ public ScalarQuantizedVectorScorer(FlatVectorsScorer flatVectorsScorer) { @Override public RandomVectorScorerSupplier getRandomVectorScorerSupplier( - VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) throws IOException { - if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) { + if (vectorValues instanceof QuantizedByteVectorValues quantizedByteVectorValues) { return new ScalarQuantizedRandomVectorScorerSupplier( similarityFunction, quantizedByteVectorValues.getScalarQuantizer(), @@ -74,11 +74,9 @@ public RandomVectorScorerSupplier getRandomVectorScorerSupplier( @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - float[] target) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) throws IOException { - if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) { + if (vectorValues instanceof QuantizedByteVectorValues quantizedByteVectorValues) { ScalarQuantizer scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer(); byte[] targetBytes = new byte[target.length]; float offsetCorrection = @@ -104,9 +102,7 @@ public float score(int node) throws IOException { @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - byte[] target) + 
VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) throws IOException { return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); } @@ -124,14 +120,14 @@ public String toString() { public static class ScalarQuantizedRandomVectorScorerSupplier implements RandomVectorScorerSupplier { - private final RandomAccessQuantizedByteVectorValues values; + private final QuantizedByteVectorValues values; private final ScalarQuantizedVectorSimilarity similarity; private final VectorSimilarityFunction vectorSimilarityFunction; public ScalarQuantizedRandomVectorScorerSupplier( VectorSimilarityFunction similarityFunction, ScalarQuantizer scalarQuantizer, - RandomAccessQuantizedByteVectorValues values) { + QuantizedByteVectorValues values) { this.similarity = ScalarQuantizedVectorSimilarity.fromVectorSimilarity( similarityFunction, @@ -144,7 +140,7 @@ public ScalarQuantizedRandomVectorScorerSupplier( private ScalarQuantizedRandomVectorScorerSupplier( ScalarQuantizedVectorSimilarity similarity, VectorSimilarityFunction vectorSimilarityFunction, - RandomAccessQuantizedByteVectorValues values) { + QuantizedByteVectorValues values) { this.similarity = similarity; this.values = values; this.vectorSimilarityFunction = vectorSimilarityFunction; @@ -152,7 +148,7 @@ private ScalarQuantizedRandomVectorScorerSupplier( @Override public RandomVectorScorer scorer(int ord) throws IOException { - final RandomAccessQuantizedByteVectorValues vectorsCopy = values.copy(); + final QuantizedByteVectorValues vectorsCopy = values.copy(); final byte[] queryVector = values.vectorValue(ord); final float queryOffset = values.getScoreCorrectionConstant(ord); return new RandomVectorScorer.AbstractRandomVectorScorer(vectorsCopy) { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index 
f45158eadac7..9a399e581b8c 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -132,9 +132,6 @@ public static OffHeapByteVectorValues load( * vector. */ public static class DenseOffHeapVectorValues extends OffHeapByteVectorValues { - - private int doc = -1; - public DenseOffHeapVectorValues( int dimension, int size, @@ -146,27 +143,8 @@ public DenseOffHeapVectorValues( } @Override - public byte[] vectorValue() throws IOException { - return vectorValue(doc); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - if (target >= size) { - return doc = NO_MORE_DOCS; - } - return doc = target; + public byte[] vectorValue(int ord) throws IOException { + return vectorValue(ord); } @Override @@ -188,12 +166,12 @@ public VectorScorer scorer(byte[] query) throws IOException { return new VectorScorer() { @Override public float score() throws IOException { - return scorer.score(copy.doc); + return scorer.score(copy.iterator().docID()); } @Override public DocIdSetIterator iterator() { - return copy; + return copy.iterator(); } }; } @@ -238,27 +216,6 @@ public SparseOffHeapVectorValues( configuration.size); } - @Override - public byte[] vectorValue() throws IOException { - return vectorValue(disi.index()); - } - - @Override - public int docID() { - return disi.docID(); - } - - @Override - public int nextDoc() throws IOException { - return disi.nextDoc(); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - return disi.advance(target); - } - @Override public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( @@ -307,7 +264,7 @@ public float score() throws IOException { 
@Override public DocIdSetIterator iterator() { - return copy; + return copy.disi; } }; } @@ -335,35 +292,15 @@ public int size() { } @Override - public byte[] vectorValue() throws IOException { + public byte[] vectorValue(int ord) throws IOException { throw new UnsupportedOperationException(); } - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) throws IOException { - return doc = NO_MORE_DOCS; - } - @Override public EmptyOffHeapVectorValues copy() throws IOException { throw new UnsupportedOperationException(); } - @Override - public byte[] vectorValue(int targetOrd) throws IOException { - throw new UnsupportedOperationException(); - } - @Override public int ordToDoc(int ord) { throw new UnsupportedOperationException(); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index 1f61283b5002..73d71e31e69b 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -71,7 +71,6 @@ public int size() { return size; } - @Override public IndexInput getSlice() { return slice; } @@ -128,8 +127,6 @@ public static OffHeapFloatVectorValues load( */ public static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues { - private int doc = -1; - public DenseOffHeapVectorValues( int dimension, int size, @@ -140,30 +137,6 @@ public DenseOffHeapVectorValues( super(dimension, size, slice, byteSize, flatVectorsScorer, similarityFunction); } - @Override - public float[] vectorValue() throws IOException { - return vectorValue(doc); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - 
@Override - public int advance(int target) throws IOException { - assert docID() < target; - if (target >= size) { - return doc = NO_MORE_DOCS; - } - return doc = target; - } - @Override public DenseOffHeapVectorValues copy() throws IOException { return new DenseOffHeapVectorValues( @@ -183,12 +156,12 @@ public VectorScorer scorer(float[] query) throws IOException { return new VectorScorer() { @Override public float score() throws IOException { - return randomVectorScorer.score(copy.doc); + return randomVectorScorer.score(copy.iterator().docID()); } @Override public DocIdSetIterator iterator() { - return copy; + return copy.iterator(); } }; } @@ -227,27 +200,6 @@ public SparseOffHeapVectorValues( configuration.size); } - @Override - public float[] vectorValue() throws IOException { - return vectorValue(disi.index()); - } - - @Override - public int docID() { - return disi.docID(); - } - - @Override - public int nextDoc() throws IOException { - return disi.nextDoc(); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - return disi.advance(target); - } - @Override public SparseOffHeapVectorValues copy() throws IOException { return new SparseOffHeapVectorValues( @@ -296,7 +248,7 @@ public float score() throws IOException { @Override public DocIdSetIterator iterator() { - return copy; + return copy.iterator(); } }; } @@ -311,8 +263,6 @@ public EmptyOffHeapVectorValues( super(dimension, 0, null, 0, flatVectorsScorer, similarityFunction); } - private int doc = -1; - @Override public int dimension() { return super.dimension(); @@ -323,26 +273,6 @@ public int size() { return 0; } - @Override - public float[] vectorValue() throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) { - return doc = NO_MORE_DOCS; - } - @Override 
public EmptyOffHeapVectorValues copy() { throw new UnsupportedOperationException(); @@ -353,11 +283,6 @@ public float[] vectorValue(int targetOrd) { throw new UnsupportedOperationException(); } - @Override - public int ordToDoc(int ord) { - throw new UnsupportedOperationException(); - } - @Override public Bits getAcceptOrds(Bits acceptDocs) { return null; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java index 1af68618d833..d60f52698552 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java @@ -18,7 +18,6 @@ package org.apache.lucene.codecs.lucene99; import static org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.Closeable; import java.io.IOException; @@ -361,14 +360,12 @@ private void writeMeta( private static DocsWithFieldSet writeByteVectorData( IndexOutput output, ByteVectorValues byteVectorValues) throws IOException { DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - for (int docV = byteVectorValues.nextDoc(); - docV != NO_MORE_DOCS; - docV = byteVectorValues.nextDoc()) { + for (int ord = 0; ord < byteVectorValues.size(); ord++) { // write vector - byte[] binaryValue = byteVectorValues.vectorValue(); + byte[] binaryValue = byteVectorValues.vectorValue(ord); assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize; output.writeBytes(binaryValue, binaryValue.length); - docsWithField.add(docV); + docsWithField.add(byteVectorValues.ordToDoc(ord)); } return docsWithField; } @@ -382,14 +379,12 @@ private static DocsWithFieldSet writeVectorData( ByteBuffer buffer = ByteBuffer.allocate(floatVectorValues.dimension() * 
VectorEncoding.FLOAT32.byteSize) .order(ByteOrder.LITTLE_ENDIAN); - for (int docV = floatVectorValues.nextDoc(); - docV != NO_MORE_DOCS; - docV = floatVectorValues.nextDoc()) { + for (int ord = 0; ord < floatVectorValues.size(); ord++) { // write vector - float[] value = floatVectorValues.vectorValue(); + float[] value = floatVectorValues.vectorValue(ord); buffer.asFloatBuffer().put(value); output.writeBytes(buffer.array(), buffer.limit()); - docsWithField.add(docV); + docsWithField.add(floatVectorValues.ordToDoc(ord)); } return docsWithField; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java index dc0fb7184c7c..0f4e8196d52d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsWriter.java @@ -32,14 +32,16 @@ import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.IOUtils; @@ -54,7 +56,6 @@ import org.apache.lucene.util.hnsw.IncrementalHnswGraphMerger; import org.apache.lucene.util.hnsw.NeighborArray; import org.apache.lucene.util.hnsw.OnHeapHnswGraph; 
-import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.packed.DirectMonotonicWriter; @@ -359,18 +360,18 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE mergeState.knnVectorsReaders[i], mergeState.docMaps[i], mergeState.liveDocs[i]); } } - DocIdSetIterator mergedVectorIterator = null; + KnnVectorValues mergedVectorValues = null; switch (fieldInfo.getVectorEncoding()) { case BYTE -> - mergedVectorIterator = + mergedVectorValues = KnnVectorsWriter.MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); case FLOAT32 -> - mergedVectorIterator = + mergedVectorValues = KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); } graph = merger.merge( - mergedVectorIterator, + mergedVectorValues, segmentWriteState.infoStream, scorerSupplier.totalVectorCount()); vectorIndexNodeOffsets = writeGraph(graph); @@ -582,13 +583,13 @@ static FieldWriter create( case BYTE -> scorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), - RandomAccessVectorValues.fromBytes( + ByteVectorValues.fromBytes( (List) flatFieldVectorsWriter.getVectors(), fieldInfo.getVectorDimension())); case FLOAT32 -> scorer.getRandomVectorScorerSupplier( fieldInfo.getVectorSimilarityFunction(), - RandomAccessVectorValues.fromFloats( + FloatVectorValues.fromFloats( (List) flatFieldVectorsWriter.getVectors(), fieldInfo.getVectorDimension())); }; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java index 8443017d3f9a..a4770f01f46d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorScorer.java @@ -21,12 +21,12 @@ import 
java.io.IOException; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.VectorUtil; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; -import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizer; /** @@ -45,9 +45,9 @@ public Lucene99ScalarQuantizedVectorScorer(FlatVectorsScorer flatVectorsScorer) @Override public RandomVectorScorerSupplier getRandomVectorScorerSupplier( - VectorSimilarityFunction similarityFunction, RandomAccessVectorValues vectorValues) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) throws IOException { - if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) { + if (vectorValues instanceof QuantizedByteVectorValues quantizedByteVectorValues) { return new ScalarQuantizedRandomVectorScorerSupplier( quantizedByteVectorValues, similarityFunction); } @@ -57,11 +57,9 @@ public RandomVectorScorerSupplier getRandomVectorScorerSupplier( @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - float[] target) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, float[] target) throws IOException { - if (vectorValues instanceof RandomAccessQuantizedByteVectorValues quantizedByteVectorValues) { + if (vectorValues instanceof QuantizedByteVectorValues quantizedByteVectorValues) { ScalarQuantizer scalarQuantizer = quantizedByteVectorValues.getScalarQuantizer(); byte[] targetBytes = new byte[target.length]; float offsetCorrection = @@ -79,9 +77,7 @@ public 
RandomVectorScorer getRandomVectorScorer( @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityFunction, - RandomAccessVectorValues vectorValues, - byte[] target) + VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues, byte[] target) throws IOException { return nonQuantizedDelegate.getRandomVectorScorer(similarityFunction, vectorValues, target); } @@ -96,7 +92,7 @@ static RandomVectorScorer fromVectorSimilarity( float offsetCorrection, VectorSimilarityFunction sim, float constMultiplier, - RandomAccessQuantizedByteVectorValues values) { + QuantizedByteVectorValues values) { return switch (sim) { case EUCLIDEAN -> new Euclidean(values, constMultiplier, targetBytes); case COSINE, DOT_PRODUCT -> @@ -120,7 +116,7 @@ private static RandomVectorScorer.AbstractRandomVectorScorer dotProductFactory( byte[] targetBytes, float offsetCorrection, float constMultiplier, - RandomAccessQuantizedByteVectorValues values, + QuantizedByteVectorValues values, FloatToFloatFunction scoreAdjustmentFunction) { if (values.getScalarQuantizer().getBits() <= 4) { if (values.getVectorByteLength() != values.dimension() && values.getSlice() != null) { @@ -137,10 +133,9 @@ private static RandomVectorScorer.AbstractRandomVectorScorer dotProductFactory( private static class Euclidean extends RandomVectorScorer.AbstractRandomVectorScorer { private final float constMultiplier; private final byte[] targetBytes; - private final RandomAccessQuantizedByteVectorValues values; + private final QuantizedByteVectorValues values; - private Euclidean( - RandomAccessQuantizedByteVectorValues values, float constMultiplier, byte[] targetBytes) { + private Euclidean(QuantizedByteVectorValues values, float constMultiplier, byte[] targetBytes) { super(values); this.values = values; this.constMultiplier = constMultiplier; @@ -159,13 +154,13 @@ public float score(int node) throws IOException { /** Calculates dot product on quantized vectors, applying the 
appropriate corrections */ private static class DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer { private final float constMultiplier; - private final RandomAccessQuantizedByteVectorValues values; + private final QuantizedByteVectorValues values; private final byte[] targetBytes; private final float offsetCorrection; private final FloatToFloatFunction scoreAdjustmentFunction; public DotProduct( - RandomAccessQuantizedByteVectorValues values, + QuantizedByteVectorValues values, float constMultiplier, byte[] targetBytes, float offsetCorrection, @@ -193,14 +188,14 @@ public float score(int vectorOrdinal) throws IOException { private static class CompressedInt4DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer { private final float constMultiplier; - private final RandomAccessQuantizedByteVectorValues values; + private final QuantizedByteVectorValues values; private final byte[] compressedVector; private final byte[] targetBytes; private final float offsetCorrection; private final FloatToFloatFunction scoreAdjustmentFunction; private CompressedInt4DotProduct( - RandomAccessQuantizedByteVectorValues values, + QuantizedByteVectorValues values, float constMultiplier, byte[] targetBytes, float offsetCorrection, @@ -231,13 +226,13 @@ public float score(int vectorOrdinal) throws IOException { private static class Int4DotProduct extends RandomVectorScorer.AbstractRandomVectorScorer { private final float constMultiplier; - private final RandomAccessQuantizedByteVectorValues values; + private final QuantizedByteVectorValues values; private final byte[] targetBytes; private final float offsetCorrection; private final FloatToFloatFunction scoreAdjustmentFunction; public Int4DotProduct( - RandomAccessQuantizedByteVectorValues values, + QuantizedByteVectorValues values, float constMultiplier, byte[] targetBytes, float offsetCorrection, @@ -271,13 +266,12 @@ private static final class ScalarQuantizedRandomVectorScorerSupplier implements 
RandomVectorScorerSupplier { private final VectorSimilarityFunction vectorSimilarityFunction; - private final RandomAccessQuantizedByteVectorValues values; - private final RandomAccessQuantizedByteVectorValues values1; - private final RandomAccessQuantizedByteVectorValues values2; + private final QuantizedByteVectorValues values; + private final QuantizedByteVectorValues values1; + private final QuantizedByteVectorValues values2; public ScalarQuantizedRandomVectorScorerSupplier( - RandomAccessQuantizedByteVectorValues values, - VectorSimilarityFunction vectorSimilarityFunction) + QuantizedByteVectorValues values, VectorSimilarityFunction vectorSimilarityFunction) throws IOException { this.values = values; this.values1 = values.copy(); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java index b8188a43bae5..8badd1f248fa 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java @@ -420,29 +420,8 @@ public int size() { } @Override - public float[] vectorValue() throws IOException { - return rawVectorValues.vectorValue(); - } - - @Override - public int docID() { - return rawVectorValues.docID(); - } - - @Override - public int nextDoc() throws IOException { - int rawDocId = rawVectorValues.nextDoc(); - int quantizedDocId = quantizedVectorValues.nextDoc(); - assert rawDocId == quantizedDocId; - return quantizedDocId; - } - - @Override - public int advance(int target) throws IOException { - int rawDocId = rawVectorValues.advance(target); - int quantizedDocId = quantizedVectorValues.advance(target); - assert rawDocId == quantizedDocId; - return quantizedDocId; + public float[] vectorValue(int ord) throws IOException { + return rawVectorValues.vectorValue(ord); } 
@Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index bb333ad45c22..4b7e73cf8623 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -22,7 +22,6 @@ import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.DYNAMIC_CONFIDENCE_INTERVAL; import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT; import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; import java.io.Closeable; @@ -30,6 +29,7 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Objects; import org.apache.lucene.codecs.CodecUtil; @@ -51,7 +51,6 @@ import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.internal.hppc.IntArrayList; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; @@ -652,18 +651,8 @@ public static ScalarQuantizer mergeAndRecalculateQuantiles( // TODO: this is very conservative, could we reuse information for even int4 quantization? 
|| bits <= 4 || shouldRecomputeQuantiles(mergedQuantiles, quantizationStates)) { - int numVectors = 0; - FloatVectorValues vectorValues = - KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); - // iterate vectorValues and increment numVectors - for (int doc = vectorValues.nextDoc(); - doc != DocIdSetIterator.NO_MORE_DOCS; - doc = vectorValues.nextDoc()) { - numVectors++; - } return buildScalarQuantizer( KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState), - numVectors, fieldInfo.getVectorSimilarityFunction(), confidenceInterval, bits); @@ -673,7 +662,6 @@ public static ScalarQuantizer mergeAndRecalculateQuantiles( static ScalarQuantizer buildScalarQuantizer( FloatVectorValues floatVectorValues, - int numVectors, VectorSimilarityFunction vectorSimilarityFunction, Float confidenceInterval, byte bits) @@ -684,14 +672,13 @@ static ScalarQuantizer buildScalarQuantizer( } if (confidenceInterval != null && confidenceInterval == DYNAMIC_CONFIDENCE_INTERVAL) { return ScalarQuantizer.fromVectorsAutoInterval( - floatVectorValues, vectorSimilarityFunction, numVectors, bits); + floatVectorValues, vectorSimilarityFunction, bits); } return ScalarQuantizer.fromVectors( floatVectorValues, confidenceInterval == null ? calculateDefaultConfidenceInterval(floatVectorValues.dimension()) : confidenceInterval, - numVectors, bits); } @@ -730,11 +717,9 @@ public static DocsWithFieldSet writeQuantizedVectorData( ? 
OffHeapQuantizedByteVectorValues.compressedArray( quantizedByteVectorValues.dimension(), bits) : null; - for (int docV = quantizedByteVectorValues.nextDoc(); - docV != NO_MORE_DOCS; - docV = quantizedByteVectorValues.nextDoc()) { + for (int ord = 0; ord < quantizedByteVectorValues.size(); ord++) { // write vector - byte[] binaryValue = quantizedByteVectorValues.vectorValue(); + byte[] binaryValue = quantizedByteVectorValues.vectorValue(ord); assert binaryValue.length == quantizedByteVectorValues.dimension() : "dim=" + quantizedByteVectorValues.dimension() + " len=" + binaryValue.length; if (compressedVector != null) { @@ -743,8 +728,9 @@ public static DocsWithFieldSet writeQuantizedVectorData( } else { output.writeBytes(binaryValue, binaryValue.length); } - output.writeInt(Float.floatToIntBits(quantizedByteVectorValues.getScoreCorrectionConstant())); - docsWithField.add(docV); + output.writeInt( + Float.floatToIntBits(quantizedByteVectorValues.getScoreCorrectionConstant(ord))); + docsWithField.add(quantizedByteVectorValues.ordToDoc(ord)); } return docsWithField; } @@ -805,7 +791,6 @@ ScalarQuantizer createQuantizer() throws IOException { ScalarQuantizer quantizer = buildScalarQuantizer( new FloatVectorWrapper(floatVectors), - floatVectors.size(), fieldInfo.getVectorSimilarityFunction(), confidenceInterval, bits); @@ -872,51 +857,27 @@ public int size() { } @Override - public float[] vectorValue() throws IOException { - if (curDoc == -1 || curDoc >= vectorList.size()) { - throw new IOException("Current doc not set or too many iterations"); - } - return vectorList.get(curDoc); - } - - @Override - public int docID() { - if (curDoc >= vectorList.size()) { - return NO_MORE_DOCS; + public float[] vectorValue(int ord) throws IOException { + if (ord < 0 || ord >= vectorList.size()) { + throw new IOException("vector ord " + ord + " out of bounds"); } - return curDoc; - } - - @Override - public int nextDoc() throws IOException { - curDoc++; - return docID(); - } - - 
@Override - public int advance(int target) throws IOException { - curDoc = target; - return docID(); - } - - @Override - public VectorScorer scorer(float[] target) { - throw new UnsupportedOperationException(); + return vectorList.get(ord); } } static class QuantizedByteVectorValueSub extends DocIDMerger.Sub { private final QuantizedByteVectorValues values; + int ord = -1; + QuantizedByteVectorValueSub(MergeState.DocMap docMap, QuantizedByteVectorValues values) { super(docMap); this.values = values; - assert values.docID() == -1; } @Override public int nextDoc() throws IOException { - return values.nextDoc(); + return values.ordToDoc(++ord); } } @@ -972,6 +933,7 @@ public static MergedQuantizedVectorValues mergeQuantizedByteVectorValues( private final List subs; private final DocIDMerger docIdMerger; private final int size; + private final int[] ends; private int docId; private QuantizedByteVectorValueSub current; @@ -981,37 +943,36 @@ private MergedQuantizedVectorValues( this.subs = subs; docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort); int totalSize = 0; + ends = new int[subs.size()]; + int iSub = 0; for (QuantizedByteVectorValueSub sub : subs) { totalSize += sub.values.size(); + ends[iSub++] = totalSize; } size = totalSize; - docId = -1; - } - - @Override - public byte[] vectorValue() throws IOException { - return current.values.vectorValue(); } @Override - public int docID() { - return docId; - } - - @Override - public int nextDoc() throws IOException { - current = docIdMerger.next(); - if (current == null) { - docId = NO_MORE_DOCS; + public byte[] vectorValue(int ord) throws IOException { + int iSub = Arrays.binarySearch(ends, ord); + if (iSub < 0) { + iSub = -(iSub + 1); + } + if (iSub == 0) { + return subs.get(iSub).values.vectorValue(ord); } else { - docId = current.mappedDocID; + return subs.get(iSub).values.vectorValue(ord - ends[iSub - 1]); } - return docId; } @Override - public int advance(int target) { + public int ordToDoc(int ord) { 
throw new UnsupportedOperationException(); + /** + * TODO: is this needed? int iSub = Arrays.binarySearch(ends, ord); if (iSub < 0) { iSub = + * -(iSub + 1); } if (iSub == 0) { return subs.get(iSub).values.ordToDoc(ord); } else { return + * subs.get(iSub).values.ordToDoc(ord - ends[iSub - 1]); } + */ } @Override @@ -1025,13 +986,8 @@ public int dimension() { } @Override - public float getScoreCorrectionConstant() throws IOException { - return current.values.getScoreCorrectionConstant(); - } - - @Override - public VectorScorer scorer(float[] target) throws IOException { - throw new UnsupportedOperationException(); + public float getScoreCorrectionConstant(int ord) throws IOException { + return current.values.getScoreCorrectionConstant(ord); } } @@ -1039,6 +995,7 @@ static class QuantizedFloatVectorValues extends QuantizedByteVectorValues { private final FloatVectorValues values; private final ScalarQuantizer quantizer; private final byte[] quantizedVector; + private int lastOrd = -1; private float offsetValue = 0f; private final VectorSimilarityFunction vectorSimilarityFunction; @@ -1054,7 +1011,14 @@ public QuantizedFloatVectorValues( } @Override - public float getScoreCorrectionConstant() { + public float getScoreCorrectionConstant(int ord) { + if (ord != lastOrd) { + throw new IllegalStateException( + "attempt to retrieve score correction for different ord " + + ord + + " than the quantization was done for: " + + lastOrd); + } return offsetValue; } @@ -1069,31 +1033,12 @@ public int size() { } @Override - public byte[] vectorValue() throws IOException { - return quantizedVector; - } - - @Override - public int docID() { - return values.docID(); - } - - @Override - public int nextDoc() throws IOException { - int doc = values.nextDoc(); - if (doc != NO_MORE_DOCS) { - quantize(); - } - return doc; - } - - @Override - public int advance(int target) throws IOException { - int doc = values.advance(target); - if (doc != NO_MORE_DOCS) { - quantize(); + public byte[] 
vectorValue(int ord) throws IOException { + if (ord != lastOrd) { + offsetValue = quantize(ord); + lastOrd = ord; } - return doc; + return quantizedVector; } @Override @@ -1101,9 +1046,8 @@ public VectorScorer scorer(float[] target) throws IOException { throw new UnsupportedOperationException(); } - private void quantize() throws IOException { - offsetValue = - quantizer.quantize(values.vectorValue(), quantizedVector, vectorSimilarityFunction); + private float quantize(int ord) throws IOException { + return quantizer.quantize(values.vectorValue(ord), quantizedVector, vectorSimilarityFunction); } } @@ -1160,9 +1104,9 @@ static final class OffsetCorrectedQuantizedByteVectorValues extends QuantizedByt } @Override - public float getScoreCorrectionConstant() throws IOException { + public float getScoreCorrectionConstant(int ord) throws IOException { return scalarQuantizer.recalculateCorrectiveOffset( - in.vectorValue(), oldScalarQuantizer, vectorSimilarityFunction); + in.vectorValue(ord), oldScalarQuantizer, vectorSimilarityFunction); } @Override @@ -1176,35 +1120,15 @@ public int size() { } @Override - public byte[] vectorValue() throws IOException { - return in.vectorValue(); - } - - @Override - public int docID() { - return in.docID(); - } - - @Override - public int nextDoc() throws IOException { - return in.nextDoc(); - } - - @Override - public int advance(int target) throws IOException { - return in.advance(target); - } - - @Override - public VectorScorer scorer(float[] target) throws IOException { - throw new UnsupportedOperationException(); + public byte[] vectorValue(int ord) throws IOException { + return in.vectorValue(ord); } } static final class NormalizedFloatVectorValues extends FloatVectorValues { private final FloatVectorValues values; private final float[] normalizedVector; - int curDoc = -1; + int curOrd = -1; public NormalizedFloatVectorValues(FloatVectorValues values) { this.values = values; @@ -1222,38 +1146,10 @@ public int size() { } @Override - 
public float[] vectorValue() throws IOException { + public float[] vectorValue(int ord) throws IOException { + System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length); + VectorUtil.l2normalize(normalizedVector); return normalizedVector; } - - @Override - public VectorScorer scorer(float[] query) throws IOException { - throw new UnsupportedOperationException(); - } - - @Override - public int docID() { - return values.docID(); - } - - @Override - public int nextDoc() throws IOException { - curDoc = values.nextDoc(); - if (curDoc != NO_MORE_DOCS) { - System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length); - VectorUtil.l2normalize(normalizedVector); - } - return curDoc; - } - - @Override - public int advance(int target) throws IOException { - curDoc = values.advance(target); - if (curDoc != NO_MORE_DOCS) { - System.arraycopy(values.vectorValue(), 0, normalizedVector, 0, normalizedVector.length); - VectorUtil.l2normalize(normalizedVector); - } - return curDoc; - } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java index 655dcca11667..5d488c200ce1 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java @@ -27,10 +27,10 @@ import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Bits; +import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.packed.DirectMonotonicReader; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; -import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; import 
org.apache.lucene.util.quantization.ScalarQuantizer; /** @@ -38,7 +38,7 @@ * supports both iterated and random access. */ public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues - implements RandomAccessQuantizedByteVectorValues { + implements RandomAccessVectorValues.Bytes { protected final int dimension; protected final int size; @@ -141,11 +141,6 @@ public byte[] vectorValue(int targetOrd) throws IOException { return binaryValue; } - @Override - public float getScoreCorrectionConstant() { - return scoreCorrectionConstant[0]; - } - @Override public float getScoreCorrectionConstant(int targetOrd) throws IOException { if (lastOrd == targetOrd) { @@ -226,30 +221,6 @@ public DenseOffHeapVectorValues( super(dimension, size, scalarQuantizer, similarityFunction, vectorsScorer, compress, slice); } - @Override - public byte[] vectorValue() throws IOException { - return vectorValue(doc); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - if (target >= size) { - return doc = NO_MORE_DOCS; - } - return doc = target; - } - @Override public DenseOffHeapVectorValues copy() throws IOException { return new DenseOffHeapVectorValues( @@ -280,7 +251,7 @@ public float score() throws IOException { @Override public DocIdSetIterator iterator() { - return copy; + return copy.iterator(); } }; } @@ -312,24 +283,11 @@ public SparseOffHeapVectorValues( } @Override - public byte[] vectorValue() throws IOException { - return vectorValue(disi.index()); - } - - @Override - public int docID() { - return disi.docID(); - } - - @Override - public int nextDoc() throws IOException { - return disi.nextDoc(); - } - - @Override - public int advance(int target) throws IOException { - assert docID() < target; - return disi.advance(target); + public KnnValuesDocIterator iterator() { + 
if (iterator == null) { + iterator = fromIndexedDISI(disi); + } + return iterator; } @Override @@ -382,7 +340,7 @@ public float score() throws IOException { @Override public DocIdSetIterator iterator() { - return copy; + return copy.iterator(); } }; } @@ -416,26 +374,6 @@ public int size() { return 0; } - @Override - public byte[] vectorValue() { - throw new UnsupportedOperationException(); - } - - @Override - public int docID() { - return doc; - } - - @Override - public int nextDoc() throws IOException { - return advance(doc + 1); - } - - @Override - public int advance(int target) { - return doc = NO_MORE_DOCS; - } - @Override public EmptyOffHeapVectorValues copy() { throw new UnsupportedOperationException(); @@ -461,4 +399,33 @@ public VectorScorer scorer(float[] target) { return null; } } + + static KnnValuesDocIterator fromIndexedDISI(IndexedDISI disi) { + return new KnnValuesDocIterator() { + @Override + public int docID() { + return disi.docID(); + } + + @Override + public int index() { + return disi.index(); + } + + @Override + public int nextDoc() throws IOException { + return disi.nextDoc(); + } + + @Override + public int advance(int target) throws IOException { + return disi.advance(target); + } + + @Override + public long cost() { + return disi.cost(); + } + }; + } } diff --git a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java index d33ca1ca3544..1c59928c64bb 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java @@ -17,9 +17,10 @@ package org.apache.lucene.index; import java.io.IOException; +import java.util.List; import org.apache.lucene.document.KnnByteVectorField; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; /** * This class provides access to per-document floating point 
vector values indexed as {@link @@ -27,34 +28,23 @@ * * @lucene.experimental */ -public abstract class ByteVectorValues extends DocIdSetIterator { +public abstract class ByteVectorValues extends KnnVectorValues { /** Sole constructor */ protected ByteVectorValues() {} - /** Return the dimension of the vectors */ - public abstract int dimension(); - - /** - * Return the number of vectors for this field. - * - * @return the number of vectors returned by this iterator - */ - public abstract int size(); - @Override - public final long cost() { - return size(); + public ByteVectorValues copy() throws IOException { + return this; } /** - * Return the vector value for the current document ID. It is illegal to call this method when the - * iterator is not positioned: before advancing, or after failing to advance. The returned array - * may be shared across calls, re-used, and modified as the iterator advances. + * Return the vector value for the given vector ordinal which must be in [0, size() - 1], + * otherwise IndexOutOfBoundsException is thrown. The returned array may be shared across calls. * * @return the vector value */ - public abstract byte[] vectorValue() throws IOException; + public abstract byte[] vectorValue(int ord) throws IOException; /** * Checks the Vector Encoding of a field @@ -78,12 +68,57 @@ public static void checkField(LeafReader in, String field) { } /** - * Return a {@link VectorScorer} for the given query vector. The iterator for the scorer is not - * the same instance as the iterator for this {@link ByteVectorValues}. It is a copy, and - * iteration over the scorer will not affect the iteration of this {@link ByteVectorValues}. + * Return a {@link VectorScorer} for the given query vector. 
* * @param query the query vector * @return a {@link VectorScorer} instance or null */ - public abstract VectorScorer scorer(byte[] query) throws IOException; + public VectorScorer scorer(byte[] query) throws IOException { + throw new UnsupportedOperationException(); + } + + /** Returns the vector byte length, defaults to dimension multiplied by byte size */ + @Override + public int getVectorByteLength() { + return dimension() * Byte.BYTES; + } + + /** + * Returns a slice of the underlying {@link IndexInput} that contains the vector values if + * available + */ + public IndexInput getSlice() { + return null; + } + + /** + * Creates a {@link ByteVectorValues} from a list of byte arrays. + * + * @param vectors the list of byte arrays + * @param dim the dimension of the vectors + * @return a {@link ByteVectorValues} instancec + */ + public static ByteVectorValues fromBytes(List vectors, int dim) { + return new ByteVectorValues() { + @Override + public int size() { + return vectors.size(); + } + + @Override + public int dimension() { + return dim; + } + + @Override + public byte[] vectorValue(int targetOrd) { + return vectors.get(targetOrd); + } + + @Override + public ByteVectorValues copy() { + return this; + } + }; + } } diff --git a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java index d768aaf1ab23..cef6a3354fe8 100644 --- a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java +++ b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java @@ -2760,16 +2760,16 @@ private static void checkFloatVectorValues( CheckIndex.Status.VectorValuesStatus status, CodecReader codecReader) throws IOException { - int docCount = 0; + int count = 0; int everyNdoc = Math.max(values.size() / 64, 1); - while (values.nextDoc() != NO_MORE_DOCS) { + while (count < values.size()) { // search the first maxNumSearches vectors to exercise the graph - if (values.docID() % everyNdoc == 0) { + if 
(values.ordToDoc(count) % everyNdoc == 0) { KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE); if (vectorsReaderSupportsSearch(codecReader, fieldInfo.name)) { codecReader .getVectorReader() - .search(fieldInfo.name, values.vectorValue(), collector, null); + .search(fieldInfo.name, values.vectorValue(count), collector, null); TopDocs docs = collector.topDocs(); if (docs.scoreDocs.length == 0) { throw new CheckIndexException( @@ -2777,7 +2777,7 @@ private static void checkFloatVectorValues( } } } - int valueLength = values.vectorValue().length; + int valueLength = values.vectorValue(count).length; if (valueLength != fieldInfo.getVectorDimension()) { throw new CheckIndexException( "Field \"" @@ -2787,19 +2787,19 @@ private static void checkFloatVectorValues( + " not matching the field's dimension=" + fieldInfo.getVectorDimension()); } - ++docCount; + ++count; } - if (docCount != values.size()) { + if (count != values.size()) { throw new CheckIndexException( "Field \"" + fieldInfo.name + "\" has size=" + values.size() + " but when iterated, returns " - + docCount + + count + " docs with values"); } - status.totalVectorValues += docCount; + status.totalVectorValues += count; } private static void checkByteVectorValues( @@ -2808,21 +2808,23 @@ private static void checkByteVectorValues( CheckIndex.Status.VectorValuesStatus status, CodecReader codecReader) throws IOException { - int docCount = 0; + int count = 0; int everyNdoc = Math.max(values.size() / 64, 1); boolean supportsSearch = vectorsReaderSupportsSearch(codecReader, fieldInfo.name); - while (values.nextDoc() != NO_MORE_DOCS) { + while (count < values.size()) { // search the first maxNumSearches vectors to exercise the graph - if (supportsSearch && values.docID() % everyNdoc == 0) { + if (supportsSearch && values.ordToDoc(count) % everyNdoc == 0) { KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE); - codecReader.getVectorReader().search(fieldInfo.name, values.vectorValue(), 
collector, null); + codecReader + .getVectorReader() + .search(fieldInfo.name, values.vectorValue(count), collector, null); TopDocs docs = collector.topDocs(); if (docs.scoreDocs.length == 0) { throw new CheckIndexException( "Field \"" + fieldInfo.name + "\" failed to search k nearest neighbors"); } } - int valueLength = values.vectorValue().length; + int valueLength = values.vectorValue(count).length; if (valueLength != fieldInfo.getVectorDimension()) { throw new CheckIndexException( "Field \"" @@ -2832,19 +2834,19 @@ private static void checkByteVectorValues( + " not matching the field's dimension=" + fieldInfo.getVectorDimension()); } - ++docCount; + ++count; } - if (docCount != values.size()) { + if (count != values.size()) { throw new CheckIndexException( "Field \"" + fieldInfo.name + "\" has size=" + values.size() + " but when iterated, returns " - + docCount + + count + " docs with values"); } - status.totalVectorValues += docCount; + status.totalVectorValues += count; } /** diff --git a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java index ca2cb1a27d45..e2395dfd40a3 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java @@ -429,47 +429,32 @@ private void checkAndThrow(DocIdSetIterator in) { } private class ExitableFloatVectorValues extends FloatVectorValues { - private int docToCheck; + private int nextCheck; private final FloatVectorValues vectorValues; public ExitableFloatVectorValues(FloatVectorValues vectorValues) { this.vectorValues = vectorValues; - docToCheck = 0; } @Override - public int advance(int target) throws IOException { - final int advance = vectorValues.advance(target); - if (advance >= docToCheck) { - checkAndThrow(); - docToCheck = advance + DOCS_BETWEEN_TIMEOUT_CHECK; - } - return advance; - } - - @Override - public int docID() 
{ - return vectorValues.docID(); + public int dimension() { + return vectorValues.dimension(); } @Override - public int nextDoc() throws IOException { - final int nextDoc = vectorValues.nextDoc(); - if (nextDoc >= docToCheck) { + public float[] vectorValue(int ord) throws IOException { + if (nextCheck >= DOCS_BETWEEN_TIMEOUT_CHECK) { checkAndThrow(); - docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK; + nextCheck = 0; + } else { + nextCheck++; } - return nextDoc; - } - - @Override - public int dimension() { - return vectorValues.dimension(); + return vectorValues.vectorValue(ord); } @Override - public float[] vectorValue() throws IOException { - return vectorValues.vectorValue(); + public int ordToDoc(int ord) { + return vectorValues.ordToDoc(ord); } @Override @@ -482,6 +467,11 @@ public VectorScorer scorer(float[] target) throws IOException { return vectorValues.scorer(target); } + @Override + public FloatVectorValues copy() { + throw new UnsupportedOperationException(); + } + /** * Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or * if {@link Thread#interrupted()} returns true. 
@@ -501,57 +491,47 @@ private void checkAndThrow() { } private class ExitableByteVectorValues extends ByteVectorValues { - private int docToCheck; + private int nextCheck; private final ByteVectorValues vectorValues; public ExitableByteVectorValues(ByteVectorValues vectorValues) { this.vectorValues = vectorValues; - docToCheck = 0; } @Override - public int advance(int target) throws IOException { - final int advance = vectorValues.advance(target); - if (advance >= docToCheck) { - checkAndThrow(); - docToCheck = advance + DOCS_BETWEEN_TIMEOUT_CHECK; - } - return advance; + public int dimension() { + return vectorValues.dimension(); } @Override - public int docID() { - return vectorValues.docID(); + public int size() { + return vectorValues.size(); } @Override - public int nextDoc() throws IOException { - final int nextDoc = vectorValues.nextDoc(); - if (nextDoc >= docToCheck) { + public byte[] vectorValue(int ord) throws IOException { + if (nextCheck >= DOCS_BETWEEN_TIMEOUT_CHECK) { checkAndThrow(); - docToCheck = nextDoc + DOCS_BETWEEN_TIMEOUT_CHECK; + nextCheck = 0; + } else { + nextCheck++; } - return nextDoc; - } - - @Override - public int dimension() { - return vectorValues.dimension(); + return vectorValues.vectorValue(ord); } @Override - public int size() { - return vectorValues.size(); + public int ordToDoc(int ord) { + return vectorValues.ordToDoc(ord); } @Override - public byte[] vectorValue() throws IOException { - return vectorValues.vectorValue(); + public VectorScorer scorer(byte[] target) throws IOException { + return vectorValues.scorer(target); } @Override - public VectorScorer scorer(byte[] target) throws IOException { - return vectorValues.scorer(target); + public ByteVectorValues copy() { + throw new UnsupportedOperationException(); } /** diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index e5dbc620f5c3..4a9408a3c399 100644 --- 
a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -17,9 +17,10 @@ package org.apache.lucene.index; import java.io.IOException; +import java.util.List; import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; /** * This class provides access to per-document floating point vector values indexed as {@link @@ -27,34 +28,31 @@ * * @lucene.experimental */ -public abstract class FloatVectorValues extends DocIdSetIterator { +public abstract class FloatVectorValues extends KnnVectorValues { /** Sole constructor */ protected FloatVectorValues() {} - /** Return the dimension of the vectors */ - public abstract int dimension(); - - /** - * Return the number of vectors for this field. - * - * @return the number of vectors returned by this iterator - */ - public abstract int size(); - @Override - public final long cost() { - return size(); + public FloatVectorValues copy() throws IOException { + return this; } /** - * Return the vector value for the current document ID. It is illegal to call this method when the - * iterator is not positioned: before advancing, or after failing to advance. The returned array - * may be shared across calls, re-used, and modified as the iterator advances. + * Return the vector value for the given vector ordinal which must be in [0, size() - 1], + * otherwise IndexOutOfBoundsException is thrown. The returned array may be shared across calls. 
* * @return the vector value */ - public abstract float[] vectorValue() throws IOException; + public abstract float[] vectorValue(int ord) throws IOException; + + /** + * Returns a slice of the underlying {@link IndexInput} that contains the vector values if + * available + */ + public IndexInput getSlice() { + return null; + } /** * Checks the Vector Encoding of a field @@ -79,12 +77,49 @@ public static void checkField(LeafReader in, String field) { /** * Return a {@link VectorScorer} for the given query vector and the current {@link - * FloatVectorValues}. The iterator for the scorer is not the same instance as the iterator for - * this {@link FloatVectorValues}. It is a copy, and iteration over the scorer will not affect the - * iteration of this {@link FloatVectorValues}. + * FloatVectorValues}. * - * @param query the query vector + * @param target the query vector * @return a {@link VectorScorer} instance or null */ - public abstract VectorScorer scorer(float[] query) throws IOException; + public VectorScorer scorer(float[] target) throws IOException { + throw new UnsupportedOperationException(); + } + + /** Returns the vector byte length, defaults to dimension multiplied by float byte size */ + @Override + public int getVectorByteLength() { + return dimension() * Float.BYTES; + } + + /** + * Creates a {@link FloatVectorValues} from a list of float arrays. 
+ * + * @param vectors the list of float arrays + * @param dim the dimension of the vectors + * @return a {@link FloatVectorValues} instance + */ + public static FloatVectorValues fromFloats(List vectors, int dim) { + return new FloatVectorValues() { + @Override + public int size() { + return vectors.size(); + } + + @Override + public int dimension() { + return dim; + } + + @Override + public float[] vectorValue(int targetOrd) { + return vectors.get(targetOrd); + } + + @Override + public FloatVectorValues copy() { + return this; + } + }; + } } diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java new file mode 100644 index 000000000000..7dafab09ce80 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.index; + +import java.io.IOException; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.util.Bits; + +/** + * This class abstracts addressing of document vector values indexed as {@link KnnFloatVectorField} + * or {@link KnnByteVectorField}. Hmm, we also need docToOrd(s)? We need to be able to retrieve a + * vector value for a document. + * + * @lucene.experimental + */ +public abstract class KnnVectorValues { + + protected KnnValuesDocIterator iterator; + + /** Return the dimension of the vectors */ + public abstract int dimension(); + + /** + * Return the number of vectors for this field. + * + * @return the number of vectors returned by this iterator + */ + public abstract int size(); + + /** Return the docid of the document indexed with the given vector ordinal */ + public int ordToDoc(int ord) { + throw new UnsupportedOperationException(); + } + + /** Return the vector ordinal indexed for the given document or -1 if there is none */ + public int docToOrd(int ord) { + throw new UnsupportedOperationException(); + } + + /** + * Creates a new copy of this {@link KnnVectorValues}. This is helpful when you need to access + * different values at once, to avoid overwriting the underlying vector returned. + */ + public KnnVectorValues copy() throws IOException { + throw new UnsupportedOperationException(); + } + + /** Returns the byte length of the vector values. 
*/ + public abstract int getVectorByteLength(); + + public Bits getAcceptOrds(Bits acceptDocs) { + // FIXME: change default to return acceptDocs and provide this impl + // somewhere more specialized + if (acceptDocs == null) { + return null; + } + return new Bits() { + @Override + public boolean get(int index) { + return acceptDocs.get(ordToDoc(index)); + } + + @Override + public int length() { + return size(); + } + }; + } + + public abstract static class KnnValuesDocIterator extends DocIdSetIterator { + + /** return the value index (aka "ordinal" or "ord") corresponding to the current doc */ + public abstract int index(); + } + + public KnnValuesDocIterator iterator() { + if (iterator == null) { + iterator = + new KnnValuesDocIterator() { + + int ord = -1; + int doc = -1; + + @Override + public int docID() { + return doc; + } + + @Override + public int index() { + return ord; + } + + @Override + public int nextDoc() throws IOException { + if (ord >= size() - 1) { + return NO_MORE_DOCS; + } else { + doc = docToOrd(++ord); + return doc; + } + } + + @Override + public int advance(int target) { + throw new UnsupportedOperationException(); + } + + @Override + public long cost() { + return size(); + } + }; + } + return iterator; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java index 6d8342908ef5..68e489a31dd3 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java @@ -832,10 +832,11 @@ public void checkIntegrity() throws IOException { @Override public FloatVectorValues getFloatVectorValues(String field) throws IOException { - List> subs = new ArrayList<>(); + List subs = new ArrayList<>(); int i = 0; int dimension = -1; int size = 0; + int[] ends = new int[codecReaders.length]; for (CodecReader reader : 
codecReaders) { FloatVectorValues values = reader.getFloatVectorValues(field); if (values != null) { @@ -844,12 +845,11 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { } size += values.size(); } - subs.add(new DocValuesSub<>(values, docStarts[i], docStarts[i + 1])); - i++; + ends[i++] = size; + subs.add(values); } final int finalDimension = dimension; final int finalSize = size; - MergedDocIdSetIterator mergedIterator = new MergedDocIdSetIterator<>(subs); return new FloatVectorValues() { @Override @@ -863,27 +863,32 @@ public int size() { } @Override - public float[] vectorValue() throws IOException { - return mergedIterator.current.sub.vectorValue(); - } - - @Override - public int docID() { - return mergedIterator.docID(); + public float[] vectorValue(int ord) throws IOException { + int iSub = Arrays.binarySearch(ends, ord); + if (iSub < 0) { + iSub = -(iSub + 1); + } + int subOrd; + if (iSub == 0) { + subOrd = ord; + } else { + subOrd = ord - ends[iSub - 1]; + } + return subs.get(iSub).vectorValue(subOrd); } @Override - public int nextDoc() throws IOException { - return mergedIterator.nextDoc(); + public int ordToDoc(int ord) { + throw new UnsupportedOperationException(); } @Override - public int advance(int target) throws IOException { - return mergedIterator.advance(target); + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); } @Override - public VectorScorer scorer(float[] target) { + public FloatVectorValues copy() { throw new UnsupportedOperationException(); } }; @@ -891,10 +896,11 @@ public VectorScorer scorer(float[] target) { @Override public ByteVectorValues getByteVectorValues(String field) throws IOException { - List> subs = new ArrayList<>(); + List subs = new ArrayList<>(); int i = 0; int dimension = -1; int size = 0; + int[] ends = new int[codecReaders.length]; for (CodecReader reader : codecReaders) { ByteVectorValues values = reader.getByteVectorValues(field); if (values 
!= null) { @@ -903,12 +909,11 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } size += values.size(); } - subs.add(new DocValuesSub<>(values, docStarts[i], docStarts[i + 1])); - i++; + ends[i++] = size; + subs.add(values); } final int finalDimension = dimension; final int finalSize = size; - MergedDocIdSetIterator mergedIterator = new MergedDocIdSetIterator<>(subs); return new ByteVectorValues() { @Override @@ -922,27 +927,32 @@ public int size() { } @Override - public byte[] vectorValue() throws IOException { - return mergedIterator.current.sub.vectorValue(); - } - - @Override - public int docID() { - return mergedIterator.docID(); + public byte[] vectorValue(int ord) throws IOException { + int iSub = Arrays.binarySearch(ends, ord); + if (iSub < 0) { + iSub = -(iSub + 1); + } + int subOrd; + if (iSub == 0) { + subOrd = ord; + } else { + subOrd = ord - ends[iSub - 1]; + } + return subs.get(iSub).vectorValue(subOrd); } @Override - public int nextDoc() throws IOException { - return mergedIterator.nextDoc(); + public int ordToDoc(int ord) { + throw new UnsupportedOperationException(); } @Override - public int advance(int target) throws IOException { - return mergedIterator.advance(target); + public VectorScorer scorer(byte[] target) { + throw new UnsupportedOperationException(); } @Override - public VectorScorer scorer(byte[] target) { + public ByteVectorValues copy() { throw new UnsupportedOperationException(); } }; diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index 191111c0ad36..7e8a151a3465 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -206,121 +206,83 @@ public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue } } - /** Sorting FloatVectorValues that iterate over documents in the order of 
the provided sortMap */ + /** Sorting FloatVectorValues that maps ordinals using the provided sortMap */ private static class SortingFloatVectorValues extends FloatVectorValues { - final int size; - final int dimension; - final FixedBitSet docsWithField; - final float[][] vectors; - - private int docId = -1; + final FloatVectorValues delegate; + final Sorter.DocMap sortMap; SortingFloatVectorValues(FloatVectorValues delegate, Sorter.DocMap sortMap) throws IOException { - this.size = delegate.size(); - this.dimension = delegate.dimension(); - docsWithField = new FixedBitSet(sortMap.size()); - vectors = new float[sortMap.size()][]; - for (int doc = delegate.nextDoc(); doc != NO_MORE_DOCS; doc = delegate.nextDoc()) { - int newDocID = sortMap.oldToNew(doc); - docsWithField.set(newDocID); - vectors[newDocID] = delegate.vectorValue().clone(); - } - } - - @Override - public int docID() { - return docId; + this.delegate = delegate; + this.sortMap = sortMap; } @Override - public int nextDoc() throws IOException { - return advance(docId + 1); + public float[] vectorValue(int ord) throws IOException { + return delegate.vectorValue(sortMap.newToOld(ord)); } @Override - public float[] vectorValue() throws IOException { - return vectors[docId]; + public int ordToDoc(int ord) { + return delegate.ordToDoc(sortMap.newToOld(ord)); } @Override public int dimension() { - return dimension; + return delegate.dimension(); } @Override public int size() { - return size; + return delegate.size(); } @Override - public int advance(int target) throws IOException { - if (target >= docsWithField.length()) { - return NO_MORE_DOCS; - } - return docId = docsWithField.nextSetBit(target); + public VectorScorer scorer(float[] target) { + throw new UnsupportedOperationException(); } @Override - public VectorScorer scorer(float[] target) { + public FloatVectorValues copy() { throw new UnsupportedOperationException(); } } private static class SortingByteVectorValues extends ByteVectorValues { - final int 
size; - final int dimension; - final FixedBitSet docsWithField; - final byte[][] vectors; - - private int docId = -1; + final ByteVectorValues delegate; + final Sorter.DocMap sortMap; SortingByteVectorValues(ByteVectorValues delegate, Sorter.DocMap sortMap) throws IOException { - this.size = delegate.size(); - this.dimension = delegate.dimension(); - docsWithField = new FixedBitSet(sortMap.size()); - vectors = new byte[sortMap.size()][]; - for (int doc = delegate.nextDoc(); doc != NO_MORE_DOCS; doc = delegate.nextDoc()) { - int newDocID = sortMap.oldToNew(doc); - docsWithField.set(newDocID); - vectors[newDocID] = delegate.vectorValue().clone(); - } - } - - @Override - public int docID() { - return docId; + this.delegate = delegate; + this.sortMap = sortMap; } @Override - public int nextDoc() throws IOException { - return advance(docId + 1); + public byte[] vectorValue(int ord) throws IOException { + return delegate.vectorValue(sortMap.newToOld(ord)); } @Override - public byte[] vectorValue() throws IOException { - return vectors[docId]; + public int ordToDoc(int ord) { + return delegate.ordToDoc(sortMap.newToOld(ord)); } @Override public int dimension() { - return dimension; + return delegate.dimension(); } @Override public int size() { - return size; + return delegate.size(); } @Override - public int advance(int target) throws IOException { - if (target >= docsWithField.length()) { - return NO_MORE_DOCS; - } - return docId = docsWithField.nextSetBit(target); + public VectorScorer scorer(byte[] target) { + throw new UnsupportedOperationException(); } @Override - public VectorScorer scorer(byte[] target) { + public ByteVectorValues copy() { throw new UnsupportedOperationException(); } } diff --git a/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java b/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java index 409bcbc0b643..adaace27727e 100644 --- a/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java +++ 
b/lucene/core/src/java/org/apache/lucene/search/FieldExistsQuery.java @@ -181,8 +181,8 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti } else if (fieldInfo.getVectorDimension() != 0) { // the field indexes vectors iterator = switch (fieldInfo.getVectorEncoding()) { - case FLOAT32 -> context.reader().getFloatVectorValues(field); - case BYTE -> context.reader().getByteVectorValues(field); + case FLOAT32 -> context.reader().getFloatVectorValues(field).iterator(); + case BYTE -> context.reader().getByteVectorValues(field).iterator(); }; } else if (fieldInfo.getDocValuesType() != DocValuesType.NONE) { // the field indexes doc values diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java index 392d83fa262c..c4e7d159b489 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/ConcurrentHnswMerger.java @@ -19,7 +19,7 @@ import java.io.IOException; import org.apache.lucene.codecs.hnsw.HnswGraphProvider; import org.apache.lucene.index.FieldInfo; -import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.search.TaskExecutor; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.FixedBitSet; @@ -46,7 +46,7 @@ public ConcurrentHnswMerger( } @Override - protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int maxOrd) + protected HnswBuilder createBuilder(KnnVectorValues mergedVectorValues, int maxOrd) throws IOException { if (initReader == null) { return new HnswConcurrentMergeBuilder( @@ -61,7 +61,7 @@ protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int m HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name); BitSet initializedNodes = new FixedBitSet(maxOrd); - int[] oldToNewOrdinalMap = 
getNewOrdMapping(mergedVectorIterator, initializedNodes); + int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorValues, initializedNodes); return new HnswConcurrentMergeBuilder( taskExecutor, diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphMerger.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphMerger.java index 7ed5dd142de5..31e9c768dc03 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphMerger.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphMerger.java @@ -18,8 +18,8 @@ import java.io.IOException; import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.MergeState; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.Bits; import org.apache.lucene.util.InfoStream; @@ -45,12 +45,12 @@ HnswGraphMerger addReader(KnnVectorsReader reader, MergeState.DocMap docMap, Bit /** * Merge and produce the on heap graph * - * @param mergedVectorIterator iterator over the vectors in the merged segment + * @param mergedVectorValues view of the vectors in the merged segment * @param infoStream optional info stream to set to builder * @param maxOrd max number of vectors that will be added to the graph * @return merged graph * @throws IOException during merge */ - OnHeapHnswGraph merge(DocIdSetIterator mergedVectorIterator, InfoStream infoStream, int maxOrd) + OnHeapHnswGraph merge(KnnVectorValues mergedVectorValues, InfoStream infoStream, int maxOrd) throws IOException; } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java index 7331111d45a9..85583654a53e 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java @@ -16,8 +16,6 @@ */ package 
org.apache.lucene.util.hnsw; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - import java.io.IOException; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.hnsw.HnswGraphProvider; @@ -25,9 +23,9 @@ import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.internal.hppc.IntIntHashMap; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; @@ -108,12 +106,12 @@ public IncrementalHnswGraphMerger addReader( * Builds a new HnswGraphBuilder using the biggest graph from the merge state as a starting point. * If no valid readers were added to the merge state, a new graph is created. * - * @param mergedVectorIterator iterator over the vectors in the merged segment + * @param mergedVectorValues vector values in the merged segment * @param maxOrd max num of vectors that will be merged into the graph * @return HnswGraphBuilder * @throws IOException If an error occurs while reading from the merge state */ - protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int maxOrd) + protected HnswBuilder createBuilder(KnnVectorValues mergedVectorValues, int maxOrd) throws IOException { if (initReader == null) { return HnswGraphBuilder.create( @@ -123,7 +121,7 @@ protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int m HnswGraph initializerGraph = ((HnswGraphProvider) initReader).getGraph(fieldInfo.name); BitSet initializedNodes = new FixedBitSet(maxOrd); - int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorIterator, initializedNodes); + int[] oldToNewOrdinalMap = getNewOrdMapping(mergedVectorValues, initializedNodes); return InitializedHnswGraphBuilder.fromGraph( 
scorerSupplier, M, @@ -137,8 +135,8 @@ protected HnswBuilder createBuilder(DocIdSetIterator mergedVectorIterator, int m @Override public OnHeapHnswGraph merge( - DocIdSetIterator mergedVectorIterator, InfoStream infoStream, int maxOrd) throws IOException { - HnswBuilder builder = createBuilder(mergedVectorIterator, maxOrd); + KnnVectorValues mergedVectorValues, InfoStream infoStream, int maxOrd) throws IOException { + HnswBuilder builder = createBuilder(mergedVectorValues, maxOrd); builder.setInfoStream(infoStream); return builder.build(maxOrd); } @@ -147,46 +145,39 @@ public OnHeapHnswGraph merge( * Creates a new mapping from old ordinals to new ordinals and returns the total number of vectors * in the newly merged segment. * - * @param mergedVectorIterator iterator over the vectors in the merged segment + * @param mergedVectorValues vector values in the merged segment * @param initializedNodes track what nodes have been initialized * @return the mapping from old ordinals to new ordinals * @throws IOException If an error occurs while reading from the merge state */ protected final int[] getNewOrdMapping( - DocIdSetIterator mergedVectorIterator, BitSet initializedNodes) throws IOException { - DocIdSetIterator initializerIterator = null; + KnnVectorValues mergedVectorValues, BitSet initializedNodes) throws IOException { + KnnVectorValues initializerValues = null; switch (fieldInfo.getVectorEncoding()) { - case BYTE -> initializerIterator = initReader.getByteVectorValues(fieldInfo.name); - case FLOAT32 -> initializerIterator = initReader.getFloatVectorValues(fieldInfo.name); + case BYTE -> initializerValues = initReader.getByteVectorValues(fieldInfo.name); + case FLOAT32 -> initializerValues = initReader.getFloatVectorValues(fieldInfo.name); } IntIntHashMap newIdToOldOrdinal = new IntIntHashMap(initGraphSize); - int oldOrd = 0; int maxNewDocID = -1; - for (int oldId = initializerIterator.nextDoc(); - oldId != NO_MORE_DOCS; - oldId = initializerIterator.nextDoc()) { - 
int newId = initDocMap.get(oldId); + for (int oldOrd = 0; oldOrd < initializerValues.size(); oldOrd++) { + int newId = initDocMap.get(initializerValues.ordToDoc(oldOrd)); maxNewDocID = Math.max(newId, maxNewDocID); newIdToOldOrdinal.put(newId, oldOrd); - oldOrd++; } if (maxNewDocID == -1) { return new int[0]; } final int[] oldToNewOrdinalMap = new int[initGraphSize]; - int newOrd = 0; - for (int newDocId = mergedVectorIterator.nextDoc(); - newDocId <= maxNewDocID; - newDocId = mergedVectorIterator.nextDoc()) { + for (int newOrd = 0; newOrd < mergedVectorValues.size(); newOrd++) { + int newDocId = mergedVectorValues.ordToDoc(newOrd); int hashDocIndex = newIdToOldOrdinal.indexOf(newDocId); if (newIdToOldOrdinal.indexExists(hashDocIndex)) { initializedNodes.set(newOrd); oldToNewOrdinalMap[newIdToOldOrdinal.indexGet(hashDocIndex)] = newOrd; } - newOrd++; } return oldToNewOrdinalMap; } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java index e2c7372b667a..411468c1b5e4 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java @@ -17,11 +17,6 @@ package org.apache.lucene.util.hnsw; -import java.io.IOException; -import java.util.List; -import org.apache.lucene.store.IndexInput; -import org.apache.lucene.util.Bits; - /** * Provides random access to vectors by dense ordinal. This interface is used by HNSW-based * implementations of KNN search. @@ -30,146 +25,9 @@ */ public interface RandomAccessVectorValues { - /** Return the number of vector values */ - int size(); - - /** Return the dimension of the returned vector values */ - int dimension(); - - /** - * Creates a new copy of this {@link RandomAccessVectorValues}. This is helpful when you need to - * access different values at once, to avoid overwriting the underlying vector returned. 
- */ - RandomAccessVectorValues copy() throws IOException; - - /** - * Returns a slice of the underlying {@link IndexInput} that contains the vector values if - * available - */ - default IndexInput getSlice() { - return null; - } - - /** Returns the byte length of the vector values. */ - int getVectorByteLength(); - - /** - * Translates vector ordinal to the correct document ID. By default, this is an identity function. - * - * @param ord the vector ordinal - * @return the document Id for that vector ordinal - */ - default int ordToDoc(int ord) { - return ord; - } - - /** - * Returns the {@link Bits} representing live documents. By default, this is an identity function. - * - * @param acceptDocs the accept docs - * @return the accept docs - */ - default Bits getAcceptOrds(Bits acceptDocs) { - return acceptDocs; - } - /** Float vector values. */ - interface Floats extends RandomAccessVectorValues { - @Override - RandomAccessVectorValues.Floats copy() throws IOException; - - /** - * Return the vector value indexed at the given ordinal. - * - * @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}. - */ - float[] vectorValue(int targetOrd) throws IOException; - - /** Returns the vector byte length, defaults to dimension multiplied by float byte size */ - @Override - default int getVectorByteLength() { - return dimension() * Float.BYTES; - } - } + interface Floats extends RandomAccessVectorValues {} /** Byte vector values. */ - interface Bytes extends RandomAccessVectorValues { - @Override - RandomAccessVectorValues.Bytes copy() throws IOException; - - /** - * Return the vector value indexed at the given ordinal. - * - * @param targetOrd a valid ordinal, ≥ 0 and < {@link #size()}. 
- */ - byte[] vectorValue(int targetOrd) throws IOException; - - /** Returns the vector byte length, defaults to dimension multiplied by byte size */ - @Override - default int getVectorByteLength() { - return dimension() * Byte.BYTES; - } - } - - /** - * Creates a {@link RandomAccessVectorValues.Floats} from a list of float arrays. - * - * @param vectors the list of float arrays - * @param dim the dimension of the vectors - * @return a {@link RandomAccessVectorValues.Floats} instance - */ - static RandomAccessVectorValues.Floats fromFloats(List vectors, int dim) { - return new RandomAccessVectorValues.Floats() { - @Override - public int size() { - return vectors.size(); - } - - @Override - public int dimension() { - return dim; - } - - @Override - public float[] vectorValue(int targetOrd) { - return vectors.get(targetOrd); - } - - @Override - public RandomAccessVectorValues.Floats copy() { - return this; - } - }; - } - - /** - * Creates a {@link RandomAccessVectorValues.Bytes} from a list of byte arrays. 
- * - * @param vectors the list of byte arrays - * @param dim the dimension of the vectors - * @return a {@link RandomAccessVectorValues.Bytes} instance - */ - static RandomAccessVectorValues.Bytes fromBytes(List vectors, int dim) { - return new RandomAccessVectorValues.Bytes() { - @Override - public int size() { - return vectors.size(); - } - - @Override - public int dimension() { - return dim; - } - - @Override - public byte[] vectorValue(int targetOrd) { - return vectors.get(targetOrd); - } - - @Override - public RandomAccessVectorValues.Bytes copy() { - return this; - } - }; - } + interface Bytes extends RandomAccessVectorValues {} } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java index fc8ed3d004a1..a135df436991 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomVectorScorer.java @@ -18,6 +18,7 @@ package org.apache.lucene.util.hnsw; import java.io.IOException; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.util.Bits; /** A {@link RandomVectorScorer} for scoring random nodes in batches against an abstract query. */ @@ -57,14 +58,14 @@ default Bits getAcceptOrds(Bits acceptDocs) { /** Creates a default scorer for random access vectors. */ abstract class AbstractRandomVectorScorer implements RandomVectorScorer { - private final RandomAccessVectorValues values; + private final KnnVectorValues values; /** * Creates a new scorer for the given vector values. 
* * @param values the vector values */ - public AbstractRandomVectorScorer(RandomAccessVectorValues values) { + public AbstractRandomVectorScorer(KnnVectorValues values) { this.values = values; } diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java index a0fe957fecb4..0626f6669221 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java @@ -18,8 +18,8 @@ import java.io.IOException; import org.apache.lucene.index.ByteVectorValues; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; /** * A version of {@link ByteVectorValues}, but additionally retrieving score correction offset for @@ -27,31 +27,34 @@ * * @lucene.experimental */ -public abstract class QuantizedByteVectorValues extends DocIdSetIterator { - public abstract float getScoreCorrectionConstant() throws IOException; +public abstract class QuantizedByteVectorValues extends ByteVectorValues { - public abstract byte[] vectorValue() throws IOException; + public ScalarQuantizer getScalarQuantizer() { + throw new UnsupportedOperationException(); + } - /** Return the dimension of the vectors */ - public abstract int dimension(); + public abstract float getScoreCorrectionConstant(int ord) throws IOException; /** - * Return the number of vectors for this field. + * Return a {@link VectorScorer} for the given query vector. 
* - * @return the number of vectors returned by this iterator + * @param query the query vector + * @return a {@link VectorScorer} instance or null */ - public abstract int size(); + public VectorScorer scorer(float[] query) throws IOException { + throw new UnsupportedOperationException(); + } @Override - public final long cost() { - return size(); + public QuantizedByteVectorValues copy() throws IOException { + return this; } /** - * Return a {@link VectorScorer} for the given query vector. - * - * @param query the query vector - * @return a {@link VectorScorer} instance or null + * Returns a slice of the underlying {@link IndexInput} that contains the vector values if + * available */ - public abstract VectorScorer scorer(float[] query) throws IOException; + public IndexInput getSlice() { + return null; + } } diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/RandomAccessQuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/quantization/RandomAccessQuantizedByteVectorValues.java deleted file mode 100644 index b86009a690e1..000000000000 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/RandomAccessQuantizedByteVectorValues.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.lucene.util.quantization; - -import java.io.IOException; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; - -/** - * Random access values for byte[], but also includes accessing the score correction - * constant for the current vector in the buffer. - * - * @lucene.experimental - */ -public interface RandomAccessQuantizedByteVectorValues extends RandomAccessVectorValues.Bytes { - - ScalarQuantizer getScalarQuantizer(); - - float getScoreCorrectionConstant(int vectorOrd) throws IOException; - - @Override - RandomAccessQuantizedByteVectorValues copy() throws IOException; -} diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java index ab8a911ddfae..364b6a78e05b 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java @@ -16,8 +16,6 @@ */ package org.apache.lucene.util.quantization; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -237,43 +235,33 @@ private static int[] reservoirSampleIndices(int numFloatVecs, int sampleSize) { * * @param floatVectorValues the float vector values from which to calculate the quantiles * @param confidenceInterval the confidence interval used to calculate the quantiles - * @param totalVectorCount the total number of live float vectors in the index. This is vital for - * accounting for deleted documents when calculating the quantiles. 
* @param bits the number of bits to use for quantization * @return A new {@link ScalarQuantizer} instance * @throws IOException if there is an error reading the float vector values */ public static ScalarQuantizer fromVectors( - FloatVectorValues floatVectorValues, - float confidenceInterval, - int totalVectorCount, - byte bits) - throws IOException { + FloatVectorValues floatVectorValues, float confidenceInterval, byte bits) throws IOException { return fromVectors( - floatVectorValues, - confidenceInterval, - totalVectorCount, - bits, - SCALAR_QUANTIZATION_SAMPLE_SIZE); + floatVectorValues, confidenceInterval, bits, SCALAR_QUANTIZATION_SAMPLE_SIZE); } static ScalarQuantizer fromVectors( FloatVectorValues floatVectorValues, float confidenceInterval, - int totalVectorCount, byte bits, int quantizationSampleSize) throws IOException { assert 0.9f <= confidenceInterval && confidenceInterval <= 1f; assert quantizationSampleSize > SCRATCH_SIZE; + int totalVectorCount = floatVectorValues.size(); if (totalVectorCount == 0) { return new ScalarQuantizer(0f, 0f, bits); } if (confidenceInterval == 1f) { float min = Float.POSITIVE_INFINITY; float max = Float.NEGATIVE_INFINITY; - while (floatVectorValues.nextDoc() != NO_MORE_DOCS) { - for (float v : floatVectorValues.vectorValue()) { + for (int ord = 0; ord < floatVectorValues.size(); ord++) { + for (float v : floatVectorValues.vectorValue(ord)) { min = Math.min(min, v); max = Math.max(max, v); } @@ -289,8 +277,8 @@ static ScalarQuantizer fromVectors( if (totalVectorCount <= quantizationSampleSize) { int scratchSize = Math.min(SCRATCH_SIZE, totalVectorCount); int i = 0; - while (floatVectorValues.nextDoc() != NO_MORE_DOCS) { - float[] vectorValue = floatVectorValues.vectorValue(); + for (int ord = 0; ord < floatVectorValues.size(); ord++) { + float[] vectorValue = floatVectorValues.vectorValue(ord); System.arraycopy( vectorValue, 0, quantileGatheringScratch, i * vectorValue.length, vectorValue.length); i++; @@ -309,13 +297,7 @@ 
static ScalarQuantizer fromVectors( int index = 0; int idx = 0; for (int i : vectorsToTake) { - while (index <= i) { - // We cannot use `advance(docId)` as MergedVectorValues does not support it - floatVectorValues.nextDoc(); - index++; - } - assert floatVectorValues.docID() != NO_MORE_DOCS; - float[] vectorValue = floatVectorValues.vectorValue(); + float[] vectorValue = floatVectorValues.vectorValue(i); System.arraycopy( vectorValue, 0, quantileGatheringScratch, idx * vectorValue.length, vectorValue.length); idx++; @@ -329,12 +311,10 @@ static ScalarQuantizer fromVectors( } public static ScalarQuantizer fromVectorsAutoInterval( - FloatVectorValues floatVectorValues, - VectorSimilarityFunction function, - int totalVectorCount, - byte bits) + FloatVectorValues floatVectorValues, VectorSimilarityFunction function, byte bits) throws IOException { assert function != VectorSimilarityFunction.COSINE; + int totalVectorCount = floatVectorValues.size(); if (totalVectorCount == 0) { return new ScalarQuantizer(0f, 0f, bits); } @@ -356,8 +336,8 @@ public static ScalarQuantizer fromVectorsAutoInterval( if (totalVectorCount <= sampleSize) { int scratchSize = Math.min(SCRATCH_SIZE, totalVectorCount); int i = 0; - while (floatVectorValues.nextDoc() != NO_MORE_DOCS) { - gatherSample(floatVectorValues, quantileGatheringScratch, sampledDocs, i); + for (int ord = 0; ord < floatVectorValues.size(); ord++) { + gatherSample(floatVectorValues.vectorValue(ord), quantileGatheringScratch, sampledDocs, i); i++; if (i == scratchSize) { extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum); @@ -372,13 +352,7 @@ public static ScalarQuantizer fromVectorsAutoInterval( int index = 0; int idx = 0; for (int i : vectorsToTake) { - while (index <= i) { - // We cannot use `advance(docId)` as MergedVectorValues does not support it - floatVectorValues.nextDoc(); - index++; - } - assert floatVectorValues.docID() != NO_MORE_DOCS; - gatherSample(floatVectorValues, 
quantileGatheringScratch, sampledDocs, idx); + gatherSample(floatVectorValues.vectorValue(i), quantileGatheringScratch, sampledDocs, idx); idx++; if (idx == SCRATCH_SIZE) { extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum); @@ -437,12 +411,8 @@ private static void extractQuantiles( } private static void gatherSample( - FloatVectorValues floatVectorValues, - float[] quantileGatheringScratch, - List sampledDocs, - int i) + float[] vectorValue, float[] quantileGatheringScratch, List sampledDocs, int i) throws IOException { - float[] vectorValue = floatVectorValues.vectorValue(); float[] copy = new float[vectorValue.length]; System.arraycopy(vectorValue, 0, copy, 0, vectorValue.length); sampledDocs.add(copy); diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java index 0798885c9067..dae2cc3502cd 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorer.java @@ -19,11 +19,11 @@ import java.io.IOException; import java.lang.foreign.MemorySegment; import java.util.Optional; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.FilterIndexInput; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; abstract sealed class Lucene99MemorySegmentByteVectorScorer @@ -39,10 +39,7 @@ abstract sealed class Lucene99MemorySegmentByteVectorScorer * returned. 
*/ public static Optional create( - VectorSimilarityFunction type, - IndexInput input, - RandomAccessVectorValues values, - byte[] queryVector) { + VectorSimilarityFunction type, IndexInput input, KnnVectorValues values, byte[] queryVector) { input = FilterIndexInput.unwrapOnlyTest(input); if (!(input instanceof MemorySegmentAccessInput msInput)) { return Optional.empty(); @@ -58,7 +55,7 @@ public static Optional create( } Lucene99MemorySegmentByteVectorScorer( - MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] queryVector) { + MemorySegmentAccessInput input, KnnVectorValues values, byte[] queryVector) { super(values); this.input = input; this.vectorByteSize = values.getVectorByteLength(); @@ -92,7 +89,7 @@ final void checkOrdinal(int ord) { } static final class CosineScorer extends Lucene99MemorySegmentByteVectorScorer { - CosineScorer(MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { + CosineScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] query) { super(input, values, query); } @@ -105,8 +102,7 @@ public float score(int node) throws IOException { } static final class DotProductScorer extends Lucene99MemorySegmentByteVectorScorer { - DotProductScorer( - MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { + DotProductScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] query) { super(input, values, query); } @@ -120,7 +116,7 @@ public float score(int node) throws IOException { } static final class EuclideanScorer extends Lucene99MemorySegmentByteVectorScorer { - EuclideanScorer(MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { + EuclideanScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] query) { super(input, values, query); } @@ -133,8 +129,7 @@ public float score(int node) throws IOException { } static final class MaxInnerProductScorer extends Lucene99MemorySegmentByteVectorScorer { - 
MaxInnerProductScorer( - MemorySegmentAccessInput input, RandomAccessVectorValues values, byte[] query) { + MaxInnerProductScorer(MemorySegmentAccessInput input, KnnVectorValues values, byte[] query) { super(input, values, query); } diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java index 90b3bfb014c3..9dd2b4620ace 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentByteVectorScorerSupplier.java @@ -19,11 +19,11 @@ import java.io.IOException; import java.lang.foreign.MemorySegment; import java.util.Optional; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.FilterIndexInput; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; @@ -33,7 +33,7 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier final int vectorByteSize; final int maxOrd; final MemorySegmentAccessInput input; - final RandomAccessVectorValues values; // to support ordToDoc/getAcceptOrds + final KnnVectorValues values; // to support ordToDoc/getAcceptOrds byte[] scratch1, scratch2; /** @@ -41,7 +41,7 @@ public abstract sealed class Lucene99MemorySegmentByteVectorScorerSupplier * optional is returned. 
*/ static Optional create( - VectorSimilarityFunction type, IndexInput input, RandomAccessVectorValues values) { + VectorSimilarityFunction type, IndexInput input, KnnVectorValues values) { input = FilterIndexInput.unwrapOnlyTest(input); if (!(input instanceof MemorySegmentAccessInput msInput)) { return Optional.empty(); @@ -56,7 +56,7 @@ static Optional create( } Lucene99MemorySegmentByteVectorScorerSupplier( - MemorySegmentAccessInput input, RandomAccessVectorValues values) { + MemorySegmentAccessInput input, KnnVectorValues values) { this.input = input; this.values = values; this.vectorByteSize = values.getVectorByteLength(); @@ -103,7 +103,7 @@ final MemorySegment getSecondSegment(int ord) throws IOException { static final class CosineSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier { - CosineSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) { + CosineSupplier(MemorySegmentAccessInput input, KnnVectorValues values) { super(input, values); } @@ -128,7 +128,7 @@ public CosineSupplier copy() throws IOException { static final class DotProductSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier { - DotProductSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) { + DotProductSupplier(MemorySegmentAccessInput input, KnnVectorValues values) { super(input, values); } @@ -155,7 +155,7 @@ public DotProductSupplier copy() throws IOException { static final class EuclideanSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier { - EuclideanSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) { + EuclideanSupplier(MemorySegmentAccessInput input, KnnVectorValues values) { super(input, values); } @@ -181,7 +181,7 @@ public EuclideanSupplier copy() throws IOException { static final class MaxInnerProductSupplier extends Lucene99MemorySegmentByteVectorScorerSupplier { - MaxInnerProductSupplier(MemorySegmentAccessInput input, RandomAccessVectorValues values) { + 
MaxInnerProductSupplier(MemorySegmentAccessInput input, KnnVectorValues values) { super(input, values); } diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java index b085185fb113..d901ac655d91 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java @@ -19,11 +19,12 @@ import java.io.IOException; import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; -import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; public class Lucene99MemorySegmentFlatVectorsScorer implements FlatVectorsScorer { @@ -38,15 +39,15 @@ private Lucene99MemorySegmentFlatVectorsScorer(FlatVectorsScorer delegate) { @Override public RandomVectorScorerSupplier getRandomVectorScorerSupplier( - VectorSimilarityFunction similarityType, RandomAccessVectorValues vectorValues) - throws IOException { + VectorSimilarityFunction similarityType, KnnVectorValues vectorValues) throws IOException { // a quantized values here is a wrapping or delegation issue - assert !(vectorValues instanceof RandomAccessQuantizedByteVectorValues); + assert !(vectorValues instanceof QuantizedByteVectorValues); // currently only supports binary vectors - if (vectorValues instanceof 
RandomAccessVectorValues.Bytes && vectorValues.getSlice() != null) { + if (vectorValues instanceof ByteVectorValues byteVectorValues + && byteVectorValues.getSlice() != null) { var scorer = Lucene99MemorySegmentByteVectorScorerSupplier.create( - similarityType, vectorValues.getSlice(), vectorValues); + similarityType, byteVectorValues.getSlice(), vectorValues); if (scorer.isPresent()) { return scorer.get(); } @@ -56,9 +57,7 @@ public RandomVectorScorerSupplier getRandomVectorScorerSupplier( @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityType, - RandomAccessVectorValues vectorValues, - float[] target) + VectorSimilarityFunction similarityType, KnnVectorValues vectorValues, float[] target) throws IOException { // currently only supports binary vectors, so always delegate return delegate.getRandomVectorScorer(similarityType, vectorValues, target); @@ -66,17 +65,16 @@ public RandomVectorScorer getRandomVectorScorer( @Override public RandomVectorScorer getRandomVectorScorer( - VectorSimilarityFunction similarityType, - RandomAccessVectorValues vectorValues, - byte[] queryVector) + VectorSimilarityFunction similarityType, KnnVectorValues vectorValues, byte[] queryVector) throws IOException { checkDimensions(queryVector.length, vectorValues.dimension()); // a quantized values here is a wrapping or delegation issue - assert !(vectorValues instanceof RandomAccessQuantizedByteVectorValues); - if (vectorValues instanceof RandomAccessVectorValues.Bytes && vectorValues.getSlice() != null) { + assert !(vectorValues instanceof QuantizedByteVectorValues); + if (vectorValues instanceof ByteVectorValues byteVectorValues + && byteVectorValues.getSlice() != null) { var scorer = Lucene99MemorySegmentByteVectorScorer.create( - similarityType, vectorValues.getSlice(), vectorValues, queryVector); + similarityType, byteVectorValues.getSlice(), vectorValues, queryVector); if (scorer.isPresent()) { return scorer.get(); } diff --git 
a/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java b/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java index 9bce1f10a432..6fe9a685e1b4 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/hnsw/TestFlatVectorScorer.java @@ -35,6 +35,8 @@ import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; import org.apache.lucene.codecs.lucene95.OffHeapFloatVectorValues; import org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorScorer; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; @@ -42,7 +44,6 @@ import org.apache.lucene.store.IndexOutput; import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.tests.util.LuceneTestCase; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.hamcrest.Matcher; import org.hamcrest.MatcherAssert; @@ -174,13 +175,13 @@ public void testCheckFloatDimensions() throws IOException { } } - RandomAccessVectorValues byteVectorValues( - int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { + ByteVectorValues byteVectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) + throws IOException { return new OffHeapByteVectorValues.DenseOffHeapVectorValues( dims, size, in.slice("byteValues", 0, in.length()), dims, flatVectorsScorer, sim); } - RandomAccessVectorValues floatVectorValues( + FloatVectorValues floatVectorValues( int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { return new OffHeapFloatVectorValues.DenseOffHeapVectorValues( dims, diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java 
b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java index 642b20b73dc2..4ae6b8879fe6 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java @@ -17,7 +17,6 @@ package org.apache.lucene.codecs.lucene99; import static java.lang.String.format; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.oneOf; @@ -140,7 +139,7 @@ public void testQuantizedVectorsWriteAndRead() throws Exception { new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors); ScalarQuantizer scalarQuantizer = Lucene99ScalarQuantizedVectorsWriter.buildScalarQuantizer( - toQuantize, numVectors, similarityFunction, confidenceInterval, (byte) bits); + toQuantize, similarityFunction, confidenceInterval, (byte) bits); float[] expectedCorrections = new float[numVectors]; byte[][] expectedVectors = new byte[numVectors][]; for (int i = 0; i < numVectors; i++) { @@ -192,14 +191,13 @@ public void testQuantizedVectorsWriteAndRead() throws Exception { assertNotNull(hnswReader.getQuantizationState("f")); QuantizedByteVectorValues quantizedByteVectorValues = hnswReader.getQuantizedVectorValues("f"); - int docId = -1; - while ((docId = quantizedByteVectorValues.nextDoc()) != NO_MORE_DOCS) { - byte[] vector = quantizedByteVectorValues.vectorValue(); - float offset = quantizedByteVectorValues.getScoreCorrectionConstant(); + for (int ord = 0; ord < quantizedByteVectorValues.size(); ord++) { + byte[] vector = quantizedByteVectorValues.vectorValue(ord); + float offset = quantizedByteVectorValues.getScoreCorrectionConstant(ord); for (int i = 0; i < dim; i++) { - assertEquals(vector[i], expectedVectors[docId][i]); + assertEquals(vector[i], expectedVectors[ord][i]); } - assertEquals(offset, 
expectedCorrections[docId], 0.00001f); + assertEquals(offset, expectedCorrections[ord], 0.00001f); } } else { fail("reader is not Lucene99HnswVectorsReader"); diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java index 58e6c27e326a..867ae49b0589 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java @@ -46,7 +46,7 @@ import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomVectorScorer; -import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.apache.lucene.util.quantization.QuantizedByteVectorValues; import org.apache.lucene.util.quantization.ScalarQuantizer; public class TestLucene99ScalarQuantizedVectorScorer extends LuceneTestCase { @@ -100,8 +100,8 @@ private void vectorNonZeroScoringTest(int bits, boolean compress) throws IOExcep try (IndexInput in = dir.openInput(fileName, IOContext.DEFAULT)) { Lucene99ScalarQuantizedVectorScorer scorer = new Lucene99ScalarQuantizedVectorScorer(new DefaultFlatVectorScorer()); - RandomAccessQuantizedByteVectorValues values = - new RandomAccessQuantizedByteVectorValues() { + QuantizedByteVectorValues values = + new QuantizedByteVectorValues() { @Override public int dimension() { return 32; @@ -128,7 +128,7 @@ public float getScoreCorrectionConstant(int ord) { } @Override - public RandomAccessQuantizedByteVectorValues copy() throws IOException { + public QuantizedByteVectorValues copy() throws IOException { return this; } diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java 
b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java index 094d90ba5a24..59cb3f68a0e2 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java @@ -17,7 +17,6 @@ package org.apache.lucene.codecs.lucene99; import static java.lang.String.format; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.oneOf; @@ -116,7 +115,6 @@ public void testQuantizedVectorsWriteAndRead() throws Exception { ScalarQuantizer scalarQuantizer = Lucene99ScalarQuantizedVectorsWriter.buildScalarQuantizer( new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors), - numVectors, similarityFunction, confidenceInterval, (byte) bits); @@ -172,9 +170,9 @@ public void testQuantizedVectorsWriteAndRead() throws Exception { QuantizedByteVectorValues quantizedByteVectorValues = quantizedReader.getQuantizedVectorValues("f"); int docId = -1; - while ((docId = quantizedByteVectorValues.nextDoc()) != NO_MORE_DOCS) { - byte[] vector = quantizedByteVectorValues.vectorValue(); - float offset = quantizedByteVectorValues.getScoreCorrectionConstant(); + for (int ord = 0; ord < quantizedByteVectorValues.size(); ord++) { + byte[] vector = quantizedByteVectorValues.vectorValue(ord); + float offset = quantizedByteVectorValues.getScoreCorrectionConstant(ord); for (int i = 0; i < dim; i++) { assertEquals(vector[i], expectedVectors[docId][i]); } diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsWriter.java index 0bf2a4ef6b87..a397f7f88bcd 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsWriter.java +++ 
b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsWriter.java @@ -96,7 +96,7 @@ private void assertScalarQuantizer( new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors); ScalarQuantizer scalarQuantizer = Lucene99ScalarQuantizedVectorsWriter.buildScalarQuantizer( - vectorValues, 30, vectorSimilarityFunction, confidenceInterval, bits); + vectorValues, vectorSimilarityFunction, confidenceInterval, bits); assertEquals(expectedQuantiles[0], scalarQuantizer.getLowerQuantile(), 0.0001f); assertEquals(expectedQuantiles[1], scalarQuantizer.getUpperQuantile(), 0.0001f); } diff --git a/lucene/core/src/test/org/apache/lucene/document/TestField.java b/lucene/core/src/test/org/apache/lucene/document/TestField.java index 6e3a855a0df4..685f5b6bd1a5 100644 --- a/lucene/core/src/test/org/apache/lucene/document/TestField.java +++ b/lucene/core/src/test/org/apache/lucene/document/TestField.java @@ -18,6 +18,7 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import java.io.IOException; import java.io.StringReader; import java.nio.charset.StandardCharsets; import org.apache.lucene.codecs.Codec; @@ -713,17 +714,19 @@ public void testKnnVectorField() throws Exception { try (IndexReader r = DirectoryReader.open(w)) { ByteVectorValues binary = r.leaves().get(0).reader().getByteVectorValues("binary"); assertEquals(1, binary.size()); - assertNotEquals(NO_MORE_DOCS, binary.nextDoc()); - assertNotNull(binary.vectorValue()); - assertArrayEquals(b, binary.vectorValue()); - assertEquals(NO_MORE_DOCS, binary.nextDoc()); + assertNotEquals(NO_MORE_DOCS, binary.iterator().nextDoc()); + assertNotNull(binary.vectorValue(0)); + assertArrayEquals(b, binary.vectorValue(0)); + assertEquals(NO_MORE_DOCS, binary.iterator().nextDoc()); + expectThrows(IOException.class, () -> binary.vectorValue(1)); FloatVectorValues floatValues = r.leaves().get(0).reader().getFloatVectorValues("float"); assertEquals(1, floatValues.size()); - 
assertNotEquals(NO_MORE_DOCS, floatValues.nextDoc()); - assertEquals(vector.length, floatValues.vectorValue().length); - assertEquals(vector[0], floatValues.vectorValue()[0], 0); - assertEquals(NO_MORE_DOCS, floatValues.nextDoc()); + assertNotEquals(NO_MORE_DOCS, floatValues.iterator().nextDoc()); + assertEquals(vector.length, floatValues.vectorValue(0).length); + assertEquals(vector[0], floatValues.vectorValue(0)[0], 0); + assertEquals(NO_MORE_DOCS, floatValues.iterator().nextDoc()); + expectThrows(IOException.class, () -> floatValues.vectorValue(1)); } } } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java index 3c82cd6b33e4..92f458a1ffb5 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java @@ -459,8 +459,8 @@ public void testFloatVectorValues() throws IOException { expectThrows( ExitingReaderException.class, () -> { - DocIdSetIterator iter = leaf.getFloatVectorValues("vector"); - scanAndRetrieve(leaf, iter); + KnnVectorValues values = leaf.getFloatVectorValues("vector"); + scanAndRetrieve(leaf, values); }); expectThrows( @@ -473,8 +473,8 @@ public void testFloatVectorValues() throws IOException { leaf.getLiveDocs(), Integer.MAX_VALUE)); } else { - DocIdSetIterator iter = leaf.getFloatVectorValues("vector"); - scanAndRetrieve(leaf, iter); + KnnVectorValues values = leaf.getFloatVectorValues("vector"); + scanAndRetrieve(leaf, values); leaf.searchNearestVectors( "vector", @@ -534,8 +534,8 @@ public void testByteVectorValues() throws IOException { expectThrows( ExitingReaderException.class, () -> { - DocIdSetIterator iter = leaf.getByteVectorValues("vector"); - scanAndRetrieve(leaf, iter); + KnnVectorValues values = leaf.getByteVectorValues("vector"); + scanAndRetrieve(leaf, values); }); expectThrows( @@ -549,8 +549,8 @@ public 
void testByteVectorValues() throws IOException { Integer.MAX_VALUE)); } else { - DocIdSetIterator iter = leaf.getByteVectorValues("vector"); - scanAndRetrieve(leaf, iter); + KnnVectorValues values = leaf.getByteVectorValues("vector"); + scanAndRetrieve(leaf, values); leaf.searchNearestVectors( "vector", @@ -564,20 +564,24 @@ public void testByteVectorValues() throws IOException { directory.close(); } - private static void scanAndRetrieve(LeafReader leaf, DocIdSetIterator iter) throws IOException { - for (iter.nextDoc(); - iter.docID() != DocIdSetIterator.NO_MORE_DOCS && iter.docID() < leaf.maxDoc(); ) { - final int nextDocId = iter.docID() + 1; + private static void scanAndRetrieve(LeafReader leaf, KnnVectorValues values) throws IOException { + KnnVectorValues.KnnValuesDocIterator iter = values.iterator(); + for (int ord = 0; ord < values.size(); ord++) { + int docId = values.ordToDoc(ord); + if (docId >= leaf.maxDoc()) { + break; + } + final int nextDocId = docId + 1; if (random().nextBoolean() && nextDocId < leaf.maxDoc()) { iter.advance(nextDocId); } else { iter.nextDoc(); } - + ord = iter.index(); if (random().nextBoolean() && iter.docID() != DocIdSetIterator.NO_MORE_DOCS - && iter instanceof FloatVectorValues) { - ((FloatVectorValues) iter).vectorValue(); + && values instanceof FloatVectorValues) { + ((FloatVectorValues) values).vectorValue(ord); } } } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java index 72be0bd929fa..644bfa3e5048 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java @@ -414,7 +414,7 @@ private void assertConsistentGraph(IndexWriter iw, float[][] values, String vect int nextDocWithVectors = 0; StoredFields storedFields = reader.storedFields(); for (int i = 0; i < reader.maxDoc(); i++) { - nextDocWithVectors = vectorValues.advance(i); + nextDocWithVectors = 
vectorValues.iterator().advance(i); while (i < nextDocWithVectors && i < reader.maxDoc()) { int id = Integer.parseInt(storedFields.document(i).get("id")); assertNull("document " + id + " has no vector, but was expected to", values[id]); @@ -425,7 +425,7 @@ private void assertConsistentGraph(IndexWriter iw, float[][] values, String vect } int id = Integer.parseInt(storedFields.document(i).get("id")); // documents with KnnGraphValues have the expected vectors - float[] scratch = vectorValues.vectorValue(); + float[] scratch = vectorValues.vectorValue(vectorValues.iterator().index()); assertArrayEquals( "vector did not match for doc " + i + ", id=" + id + ": " + Arrays.toString(scratch), values[id], @@ -435,9 +435,9 @@ private void assertConsistentGraph(IndexWriter iw, float[][] values, String vect } // if IndexDisi.doc == NO_MORE_DOCS, we should not call IndexDisi.nextDoc() if (nextDocWithVectors != NO_MORE_DOCS) { - assertEquals(NO_MORE_DOCS, vectorValues.nextDoc()); + assertEquals(NO_MORE_DOCS, vectorValues.iterator().nextDoc()); } else { - assertEquals(NO_MORE_DOCS, vectorValues.docID()); + assertEquals(NO_MORE_DOCS, vectorValues.iterator().docID()); } // assert graph values: diff --git a/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java b/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java index 241fc0a5fe5f..82f96897af1f 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java @@ -262,7 +262,7 @@ public void testSortOnAddIndicesRandom() throws IOException { assertTrue(sorted_numeric_dv.advanceExact(idNext)); assertTrue(sorted_set_dv.advanceExact(idNext)); assertTrue(binary_sorted_dv.advanceExact(idNext)); - assertEquals(idNext, vectorValues.advance(idNext)); + assertEquals(idNext, vectorValues.ordToDoc(i)); assertEquals(new BytesRef(ids.longValue() + ""), binary_dv.binaryValue()); assertEquals( new 
BytesRef(ids.longValue() + ""), @@ -274,7 +274,7 @@ public void testSortOnAddIndicesRandom() throws IOException { assertEquals(1, sorted_numeric_dv.docValueCount()); assertEquals(ids.longValue(), sorted_numeric_dv.nextValue()); - float[] vectorValue = vectorValues.vectorValue(); + float[] vectorValue = vectorValues.vectorValue(i); assertEquals(1, vectorValue.length); assertEquals((float) ids.longValue(), vectorValue[0], 0.001f); diff --git a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorScorer.java b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorScorer.java index da9c312ef96d..b935e83331bf 100644 --- a/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/internal/vectorization/TestVectorScorer.java @@ -39,6 +39,7 @@ import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene95.OffHeapByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; @@ -47,7 +48,6 @@ import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.NamedThreadFactory; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.junit.BeforeClass; @@ -329,8 +329,8 @@ public void testLarge() throws IOException { } } - RandomAccessVectorValues vectorValues( - int dims, int size, IndexInput in, VectorSimilarityFunction sim) throws IOException { + KnnVectorValues vectorValues(int dims, int size, IndexInput in, VectorSimilarityFunction sim) + throws IOException { return new OffHeapByteVectorValues.DenseOffHeapVectorValues( dims, size, in.slice("byteValues", 0, in.length()), dims, MEMSEG_SCORER, sim); } 
diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index 21a33f9ca3e7..afa150e387fe 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -38,6 +38,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.QueryTimeout; @@ -740,7 +741,7 @@ public void testMergeAwayAllValues() throws IOException { LeafReader leafReader = getOnlyLeafReader(reader); FieldInfo fi = leafReader.getFieldInfos().fieldInfo("field"); assertNotNull(fi); - DocIdSetIterator vectorValues; + KnnVectorValues vectorValues; switch (fi.getVectorEncoding()) { case BYTE: vectorValues = leafReader.getByteVectorValues("field"); @@ -752,7 +753,7 @@ public void testMergeAwayAllValues() throws IOException { throw new AssertionError(); } assertNotNull(vectorValues); - assertEquals(NO_MORE_DOCS, vectorValues.nextDoc()); + assertEquals(NO_MORE_DOCS, vectorValues.iterator().nextDoc()); } } } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/AbstractMockVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/AbstractMockVectorValues.java index 54de3919b516..83603de88ed2 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/AbstractMockVectorValues.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/AbstractMockVectorValues.java @@ -20,9 +20,10 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.util.BytesRef; -abstract class AbstractMockVectorValues 
implements RandomAccessVectorValues { +abstract class AbstractMockVectorValues extends KnnVectorValues { protected final int dimension; protected final T[] denseValues; diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index 4a6794b4994e..24c43cb9e48b 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -56,6 +56,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.StoredFields; @@ -112,18 +113,17 @@ abstract AbstractMockVectorValues vectorValues( abstract Field knnVectorField(String name, T vector, VectorSimilarityFunction similarityFunction); - abstract RandomAccessVectorValues circularVectorValues(int nDoc); + abstract KnnVectorValues circularVectorValues(int nDoc); abstract T getTargetVector(); - protected RandomVectorScorerSupplier buildScorerSupplier(RandomAccessVectorValues vectors) + protected RandomVectorScorerSupplier buildScorerSupplier(KnnVectorValues vectors) throws IOException { return flatVectorScorer.getRandomVectorScorerSupplier(similarityFunction, vectors); } - protected RandomVectorScorer buildScorer(RandomAccessVectorValues vectors, T query) - throws IOException { - RandomAccessVectorValues vectorsCopy = vectors.copy(); + protected RandomVectorScorer buildScorer(KnnVectorValues vectors, T query) throws IOException { + KnnVectorValues vectorsCopy = vectors.copy(); return switch (getVectorEncoding()) { case BYTE -> flatVectorScorer.getRandomVectorScorer(similarityFunction, vectorsCopy, (byte[]) query); @@ -461,7 +461,7 @@ void assertGraphEqual(HnswGraph g, HnswGraph h) 
throws IOException { public void testAknnDiverse() throws IOException { int nDoc = 100; similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; - RandomAccessVectorValues vectors = circularVectorValues(nDoc); + KnnVectorValues vectors = circularVectorValues(nDoc); RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 10, 100, random().nextInt()); OnHeapHnswGraph hnsw = builder.build(vectors.size()); @@ -493,7 +493,7 @@ public void testAknnDiverse() throws IOException { @SuppressWarnings("unchecked") public void testSearchWithAcceptOrds() throws IOException { int nDoc = 100; - RandomAccessVectorValues vectors = circularVectorValues(nDoc); + KnnVectorValues vectors = circularVectorValues(nDoc); similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt()); @@ -518,7 +518,7 @@ public void testSearchWithAcceptOrds() throws IOException { @SuppressWarnings("unchecked") public void testSearchWithSelectiveAcceptOrds() throws IOException { int nDoc = 100; - RandomAccessVectorValues vectors = circularVectorValues(nDoc); + KnnVectorValues vectors = circularVectorValues(nDoc); similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt()); @@ -711,7 +711,7 @@ private int[] createOffsetOrdinalMap( public void testVisitedLimit() throws IOException { int nDoc = 500; similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; - RandomAccessVectorValues vectors = circularVectorValues(nDoc); + KnnVectorValues vectors = circularVectorValues(nDoc); RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = 
HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt()); OnHeapHnswGraph hnsw = builder.build(vectors.size()); @@ -746,7 +746,7 @@ public void testRamUsageEstimate() throws IOException { int M = randomIntBetween(4, 96); similarityFunction = RandomizedTest.randomFrom(VectorSimilarityFunction.values()); - RandomAccessVectorValues vectors = vectorValues(size, dim); + KnnVectorValues vectors = vectorValues(size, dim); RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = @@ -1076,8 +1076,7 @@ private int computeOverlap(int[] a, int[] b) { } /** Returns vectors evenly distributed around the upper unit semicircle. */ - static class CircularFloatVectorValues extends FloatVectorValues - implements RandomAccessVectorValues.Floats { + static class CircularFloatVectorValues extends FloatVectorValues { private final int size; private final float[] value; @@ -1103,22 +1102,18 @@ public int size() { return size; } - @Override public float[] vectorValue() { return vectorValue(doc); } - @Override public int docID() { return doc; } - @Override public int nextDoc() { return advance(doc + 1); } - @Override public int advance(int target) { if (target >= 0 && target < size) { doc = target; @@ -1140,8 +1135,7 @@ public VectorScorer scorer(float[] target) { } /** Returns vectors evenly distributed around the upper unit semicircle. 
*/ - static class CircularByteVectorValues extends ByteVectorValues - implements RandomAccessVectorValues.Bytes { + static class CircularByteVectorValues extends ByteVectorValues { private final int size; private final float[] value; private final byte[] bValue; @@ -1169,22 +1163,18 @@ public int size() { return size; } - @Override public byte[] vectorValue() { return vectorValue(doc); } - @Override public int docID() { return doc; } - @Override public int nextDoc() { return advance(doc + 1); } - @Override public int advance(int target) { if (target >= 0 && target < size) { doc = target; diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java index a3b17b9a621e..6a51cb60335f 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java @@ -20,8 +20,7 @@ import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.ArrayUtil; -class MockByteVectorValues extends AbstractMockVectorValues - implements RandomAccessVectorValues.Bytes { +class MockByteVectorValues extends AbstractMockVectorValues { private final byte[] scratch; static MockByteVectorValues fromValues(byte[][] values) { @@ -71,4 +70,9 @@ public byte[] vectorValue() { return scratch; } } + + @Override + public int getVectorByteLength() { + throw new UnsupportedOperationException(); + } } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java index f183f6c99a67..cfd0f5a11e05 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java @@ -20,8 +20,7 @@ import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.ArrayUtil; -class MockVectorValues extends 
AbstractMockVectorValues - implements RandomAccessVectorValues.Floats { +class MockVectorValues extends AbstractMockVectorValues { private final float[] scratch; static MockVectorValues fromValues(float[][] values) { @@ -70,4 +69,9 @@ public float[] vectorValue() { public float[] vectorValue(int targetOrd) { return denseValues[targetOrd]; } + + @Override + public int getVectorByteLength() { + throw new UnsupportedOperationException(); + } } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java index 649bc1a64519..88149b8f54d0 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java @@ -118,10 +118,9 @@ AbstractMockVectorValues vectorValues(LeafReader reader, String fieldNam throws IOException { ByteVectorValues vectorValues = reader.getByteVectorValues(fieldName); byte[][] vectors = new byte[reader.maxDoc()][]; - while (vectorValues.nextDoc() != NO_MORE_DOCS) { - vectors[vectorValues.docID()] = - ArrayUtil.copyOfSubArray( - vectorValues.vectorValue(), 0, vectorValues.vectorValue().length); + for (int i = 0; i < vectorValues.size(); i++) { + vectors[vectorValues.ordToDoc(i)] = + ArrayUtil.copyOfSubArray(vectorValues.vectorValue(i), 0, vectorValues.dimension()); } return MockByteVectorValues.fromValues(vectors); } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java index 5621edc4b35e..8500db4a9cba 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java @@ -74,10 +74,9 @@ AbstractMockVectorValues vectorValues(LeafReader reader, String fieldNa throws IOException { FloatVectorValues vectorValues = 
reader.getFloatVectorValues(fieldName); float[][] vectors = new float[reader.maxDoc()][]; - while (vectorValues.nextDoc() != NO_MORE_DOCS) { - vectors[vectorValues.docID()] = - ArrayUtil.copyOfSubArray( - vectorValues.vectorValue(), 0, vectorValues.vectorValue().length); + for (int i = 0; i < vectorValues.size(); i++) { + vectors[vectorValues.ordToDoc(i)] = + ArrayUtil.copyOfSubArray(vectorValues.vectorValue(i), 0, vectorValues.dimension()); } return MockVectorValues.fromValues(vectors); } @@ -129,7 +128,7 @@ float[] getTargetVector() { public void testSearchWithSkewedAcceptOrds() throws IOException { int nDoc = 1000; similarityFunction = VectorSimilarityFunction.EUCLIDEAN; - RandomAccessVectorValues.Floats vectors = circularVectorValues(nDoc); + FloatVectorValues vectors = circularVectorValues(nDoc); RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 16, 100, random().nextInt()); OnHeapHnswGraph hnsw = builder.build(vectors.size()); diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java index bdba822d4eca..a001982049b5 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java @@ -59,8 +59,7 @@ public void testToEuclidean() throws IOException { float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f); FloatVectorValues floatVectorValues = fromFloats(floats); ScalarQuantizer scalarQuantizer = - ScalarQuantizer.fromVectors( - floatVectorValues, confidenceInterval, floats.length, (byte) 7); + ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, (byte) 7); byte[][] quantized = new byte[floats.length][]; float[] offsets = 
quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.EUCLIDEAN); @@ -92,8 +91,7 @@ public void testToCosine() throws IOException { float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f); FloatVectorValues floatVectorValues = fromFloatsNormalized(floats, null); ScalarQuantizer scalarQuantizer = - ScalarQuantizer.fromVectors( - floatVectorValues, confidenceInterval, floats.length, (byte) 7); + ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, (byte) 7); byte[][] quantized = new byte[floats.length][]; float[] offsets = quantizeVectorsNormalized( @@ -129,8 +127,7 @@ public void testToDotProduct() throws IOException { float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f); FloatVectorValues floatVectorValues = fromFloats(floats); ScalarQuantizer scalarQuantizer = - ScalarQuantizer.fromVectors( - floatVectorValues, confidenceInterval, floats.length, (byte) 7); + ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, (byte) 7); byte[][] quantized = new byte[floats.length][]; float[] offsets = quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.DOT_PRODUCT); @@ -162,8 +159,7 @@ public void testToMaxInnerProduct() throws IOException { float error = Math.max((100 - confidenceInterval) * 0.5f, 0.5f); FloatVectorValues floatVectorValues = fromFloats(floats); ScalarQuantizer scalarQuantizer = - ScalarQuantizer.fromVectors( - floatVectorValues, confidenceInterval, floats.length, (byte) 7); + ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, (byte) 7); byte[][] quantized = new byte[floats.length][]; float[] offsets = quantizeVectors( diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java index 48eb7ce651c6..5fffc1b69582 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java +++ 
b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java @@ -16,6 +16,7 @@ */ package org.apache.lucene.util.quantization; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.util.quantization.ScalarQuantizer.SCRATCH_SIZE; import java.io.IOException; @@ -46,9 +47,8 @@ public void testTinyVectors() throws IOException { } ScalarQuantizer scalarQuantizer = random().nextBoolean() - ? ScalarQuantizer.fromVectors(floatVectorValues, 0.9f, numVecs, bits) - : ScalarQuantizer.fromVectorsAutoInterval( - floatVectorValues, function, numVecs, bits); + ? ScalarQuantizer.fromVectors(floatVectorValues, 0.9f, bits) + : ScalarQuantizer.fromVectorsAutoInterval(floatVectorValues, function, bits); // We simply assert that we created a scalar quantizer and didn't trip any assertions // the quality of the quantization might be poor, but this is expected as sampling size is // tiny @@ -71,16 +71,14 @@ public void testNanAndInfValueFailure() { FloatVectorValues floatVectorValues = fromFloats(floats); expectThrows( IllegalStateException.class, - () -> ScalarQuantizer.fromVectors(floatVectorValues, 0.9f, numVecs, bits)); + () -> ScalarQuantizer.fromVectors(floatVectorValues, 0.9f, bits)); VectorSimilarityFunction actualFunction = function == VectorSimilarityFunction.COSINE ? 
VectorSimilarityFunction.DOT_PRODUCT : function; expectThrows( IllegalStateException.class, - () -> - ScalarQuantizer.fromVectorsAutoInterval( - floatVectorValues, actualFunction, numVecs, bits)); + () -> ScalarQuantizer.fromVectorsAutoInterval(floatVectorValues, actualFunction, bits)); } } } @@ -92,8 +90,7 @@ public void testQuantizeAndDeQuantize7Bit() throws IOException { float[][] floats = randomFloats(numVecs, dims); FloatVectorValues floatVectorValues = fromFloats(floats); - ScalarQuantizer scalarQuantizer = - ScalarQuantizer.fromVectors(floatVectorValues, 1, numVecs, (byte) 7); + ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, 1, (byte) 7); float[] dequantized = new float[dims]; byte[] quantized = new byte[dims]; byte[] requantized = new byte[dims]; @@ -156,7 +153,6 @@ public void testScalarWithSampling() throws IOException { ScalarQuantizer.fromVectors( floatVectorValues, 0.99f, - floatVectorValues.numLiveVectors, (byte) 7, Math.max(floatVectorValues.numLiveVectors - 1, SCRATCH_SIZE + 1)); } @@ -166,7 +162,6 @@ public void testScalarWithSampling() throws IOException { ScalarQuantizer.fromVectors( floatVectorValues, 0.99f, - floatVectorValues.numLiveVectors, (byte) 7, Math.max(floatVectorValues.numLiveVectors - 1, SCRATCH_SIZE + 1)); } @@ -176,7 +171,6 @@ public void testScalarWithSampling() throws IOException { ScalarQuantizer.fromVectors( floatVectorValues, 0.99f, - floatVectorValues.numLiveVectors, (byte) 7, Math.max(floatVectorValues.numLiveVectors - 1, SCRATCH_SIZE + 1)); } @@ -186,7 +180,6 @@ public void testScalarWithSampling() throws IOException { ScalarQuantizer.fromVectors( floatVectorValues, 0.99f, - floatVectorValues.numLiveVectors, (byte) 7, Math.max(random().nextInt(floatVectorValues.floats.length - 1) + 1, SCRATCH_SIZE + 1)); } @@ -203,8 +196,7 @@ public void testFromVectorsAutoInterval4Bit() throws IOException { } FloatVectorValues floatVectorValues = fromFloats(floats); ScalarQuantizer scalarQuantizer = - 
ScalarQuantizer.fromVectorsAutoInterval( - floatVectorValues, similarityFunction, numVecs, (byte) 4); + ScalarQuantizer.fromVectorsAutoInterval(floatVectorValues, similarityFunction, (byte) 4); assertNotNull(scalarQuantizer); float[] dequantized = new float[dims]; byte[] quantized = new byte[dims]; @@ -289,10 +281,15 @@ public int dimension() { @Override public int size() { + // FIXME I don't get what this is up to return floats.length; } @Override + public float[] vectorValue(int ord) throws IOException { + return floats[ord]; + } + public float[] vectorValue() throws IOException { if (curDoc == -1 || curDoc >= floats.length) { throw new IOException("Current doc not set or too many iterations"); @@ -300,7 +297,6 @@ public float[] vectorValue() throws IOException { return floats[curDoc]; } - @Override public int docID() { if (curDoc >= floats.length) { return NO_MORE_DOCS; @@ -308,7 +304,6 @@ public int docID() { return curDoc; } - @Override public int nextDoc() throws IOException { while (++curDoc < floats.length) { if (deletedVectors == null || !deletedVectors.contains(curDoc)) { @@ -318,7 +313,6 @@ public int nextDoc() throws IOException { return docID(); } - @Override public int advance(int target) throws IOException { curDoc = target - 1; return nextDoc(); diff --git a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java index 8874f4d30c6b..1d46428dae2f 100644 --- a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java +++ b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java @@ -2302,8 +2302,8 @@ public int size() { } @Override - public float[] vectorValue() { - if (currentDoc == 0) { + public float[] vectorValue(int ord) { + if (ord == 0) { return info.floatVectorValues[0]; } else { return null; @@ -2325,40 +2325,17 @@ public VectorScorer scorer(float[] query) { public float score() throws IOException { return info.fieldInfo 
.getVectorSimilarityFunction() - .compare(vectorValues.vectorValue(), query); + .compare( + vectorValues.vectorValue(vectorValues.docToOrd(vectorValues.iterator().docID())), + query); } @Override public DocIdSetIterator iterator() { - return vectorValues; + return vectorValues.iterator(); } }; } - - @Override - public int docID() { - return currentDoc; - } - - @Override - public int nextDoc() { - int doc = ++currentDoc; - if (doc == 0) { - return doc; - } else { - return NO_MORE_DOCS; - } - } - - @Override - public int advance(int target) { - if (target == 0) { - currentDoc = target; - return target; - } else { - return NO_MORE_DOCS; - } - } } private static final class MemoryByteVectorValues extends ByteVectorValues { @@ -2380,8 +2357,8 @@ public int size() { } @Override - public byte[] vectorValue() { - if (currentDoc == 0) { + public byte[] vectorValue(int ord) { + if (ord == 0) { return info.byteVectorValues[0]; } else { return null; @@ -2403,39 +2380,16 @@ public VectorScorer scorer(byte[] query) { public float score() { return info.fieldInfo .getVectorSimilarityFunction() - .compare(vectorValues.vectorValue(), query); + .compare( + vectorValues.vectorValue(vectorValues.docToOrd(vectorValues.iterator().docID())), + query); } @Override public DocIdSetIterator iterator() { - return vectorValues; + return vectorValues.iterator(); } }; } - - @Override - public int docID() { - return currentDoc; - } - - @Override - public int nextDoc() { - int doc = ++currentDoc; - if (doc == 0) { - return doc; - } else { - return NO_MORE_DOCS; - } - } - - @Override - public int advance(int target) { - if (target == 0) { - currentDoc = target; - return target; - } else { - return NO_MORE_DOCS; - } - } } } diff --git a/lucene/memory/src/test/org/apache/lucene/index/memory/TestMemoryIndex.java b/lucene/memory/src/test/org/apache/lucene/index/memory/TestMemoryIndex.java index 18e97c67d9d9..6a68de001832 100644 --- 
a/lucene/memory/src/test/org/apache/lucene/index/memory/TestMemoryIndex.java +++ b/lucene/memory/src/test/org/apache/lucene/index/memory/TestMemoryIndex.java @@ -851,9 +851,9 @@ private static void assertFloatVectorValue(MemoryIndex mi, String fieldName, flo .reader() .getFloatVectorValues(fieldName); assertNotNull(fvv); - assertEquals(0, fvv.nextDoc()); - assertArrayEquals(expected, fvv.vectorValue(), 1e-6f); - assertEquals(DocIdSetIterator.NO_MORE_DOCS, fvv.nextDoc()); + assertEquals(0, fvv.iterator().nextDoc()); + assertArrayEquals(expected, fvv.vectorValue(0), 1e-6f); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, fvv.iterator().nextDoc()); } private static void assertFloatVectorScore( @@ -868,7 +868,7 @@ private static void assertFloatVectorScore( .getFloatVectorValues(fieldName); assertNotNull(fvv); if (random().nextBoolean()) { - fvv.nextDoc(); + fvv.iterator().nextDoc(); } VectorScorer scorer = fvv.scorer(queryVector); assertEquals(0, scorer.iterator().nextDoc()); @@ -886,9 +886,9 @@ private static void assertByteVectorValue(MemoryIndex mi, String fieldName, byte .reader() .getByteVectorValues(fieldName); assertNotNull(bvv); - assertEquals(0, bvv.nextDoc()); - assertArrayEquals(expected, bvv.vectorValue()); - assertEquals(DocIdSetIterator.NO_MORE_DOCS, bvv.nextDoc()); + assertEquals(0, bvv.iterator().nextDoc()); + assertArrayEquals(expected, bvv.vectorValue(0)); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, bvv.iterator().nextDoc()); } private static void assertByteVectorScore( @@ -903,7 +903,7 @@ private static void assertByteVectorScore( .getByteVectorValues(fieldName); assertNotNull(bvv); if (random().nextBoolean()) { - bvv.nextDoc(); + bvv.iterator().nextDoc(); } VectorScorer scorer = bvv.scorer(queryVector); assertEquals(0, scorer.iterator().nextDoc()); diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java 
b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java index 32517496d542..26d3f21486dd 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java @@ -67,7 +67,7 @@ protected DocIdSetIterator getVectorIterator() { @Override public byte[] byteVectorVal(int doc) throws IOException { if (exists(doc)) { - return vectorValues.vectorValue(); + return vectorValues.vectorValue(vectorValues.docToOrd(vectorValues.iterator().docID())); } else { return null; } @@ -75,7 +75,7 @@ public byte[] byteVectorVal(int doc) throws IOException { @Override protected DocIdSetIterator getVectorIterator() { - return vectorValues; + return vectorValues.iterator(); } }; } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java index 43cc3aff880e..1438477b61c9 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java @@ -66,7 +66,7 @@ protected DocIdSetIterator getVectorIterator() { @Override public float[] floatVectorVal(int doc) throws IOException { if (exists(doc)) { - return vectorValues.vectorValue(); + return vectorValues.vectorValue(vectorValues.docToOrd(doc)); } else { return null; } @@ -74,7 +74,7 @@ public float[] floatVectorVal(int doc) throws IOException { @Override protected DocIdSetIterator getVectorIterator() { - return vectorValues; + return vectorValues.iterator(); } }; } diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/KMeans.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/KMeans.java 
index bb9d3ca63df5..88d2adba5fad 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/KMeans.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/KMeans.java @@ -25,11 +25,11 @@ import java.util.List; import java.util.Random; import java.util.Set; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.NeighborQueue; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; /** KMeans clustering algorithm for vectors */ public class KMeans { @@ -38,7 +38,7 @@ public class KMeans { public static final int DEFAULT_ITRS = 10; public static final int DEFAULT_SAMPLE_SIZE = 100_000; - private final RandomAccessVectorValues.Floats vectors; + private final FloatVectorValues vectors; private final int numVectors; private final int numCentroids; private final Random random; @@ -57,9 +57,7 @@ public class KMeans { * @throws IOException when if there is an error accessing vectors */ public static Results cluster( - RandomAccessVectorValues.Floats vectors, - VectorSimilarityFunction similarityFunction, - int numClusters) + FloatVectorValues vectors, VectorSimilarityFunction similarityFunction, int numClusters) throws IOException { return cluster( vectors, @@ -93,7 +91,7 @@ public static Results cluster( * @throws IOException if there is error accessing vectors */ public static Results cluster( - RandomAccessVectorValues.Floats vectors, + FloatVectorValues vectors, int numClusters, boolean assignCentroidsToVectors, long seed, @@ -124,7 +122,7 @@ public static Results cluster( if (numClusters == 1) { centroids = new float[1][vectors.dimension()]; } else { - RandomAccessVectorValues.Floats sampleVectors = + FloatVectorValues sampleVectors = vectors.size() <= sampleSize ? 
vectors : createSampleReader(vectors, sampleSize, seed); KMeans kmeans = new KMeans(sampleVectors, numClusters, random, initializationMethod, restarts, iters); @@ -142,7 +140,7 @@ public static Results cluster( } private KMeans( - RandomAccessVectorValues.Floats vectors, + FloatVectorValues vectors, int numCentroids, Random random, KmeansInitializationMethod initializationMethod, @@ -276,7 +274,7 @@ private float[][] initializePlusPlus() throws IOException { * @throws IOException if there is an error accessing vector values */ private static double runKMeansStep( - RandomAccessVectorValues.Floats vectors, + FloatVectorValues vectors, float[][] centroids, short[] docCentroids, boolean useKahanSummation, @@ -348,9 +346,7 @@ private static double runKMeansStep( * descending distance to the current centroid set */ static void assignCentroids( - RandomAccessVectorValues.Floats vectors, - float[][] centroids, - List unassignedCentroidsIdxs) + FloatVectorValues vectors, float[][] centroids, List unassignedCentroidsIdxs) throws IOException { int[] assignedCentroidsIdxs = new int[centroids.length - unassignedCentroidsIdxs.size()]; int assignedIndex = 0; diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java index 9a718c811017..09a5b911e17a 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java @@ -20,18 +20,17 @@ import java.io.IOException; import java.util.Random; import java.util.function.IntUnaryOperator; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Bits; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; /** A reader of vector values that samples a subset of the vectors. 
*/ -public class SampleReader implements RandomAccessVectorValues.Floats { - private final RandomAccessVectorValues.Floats origin; +public class SampleReader extends FloatVectorValues { + private final FloatVectorValues origin; private final int sampleSize; private final IntUnaryOperator sampleFunction; - SampleReader( - RandomAccessVectorValues.Floats origin, int sampleSize, IntUnaryOperator sampleFunction) { + SampleReader(FloatVectorValues origin, int sampleSize, IntUnaryOperator sampleFunction) { this.origin = origin; this.sampleSize = sampleSize; this.sampleFunction = sampleFunction; @@ -48,7 +47,7 @@ public int dimension() { } @Override - public Floats copy() throws IOException { + public FloatVectorValues copy() throws IOException { throw new IllegalStateException("Not supported"); } @@ -77,8 +76,7 @@ public Bits getAcceptOrds(Bits acceptDocs) { throw new IllegalStateException("Not supported"); } - public static SampleReader createSampleReader( - RandomAccessVectorValues.Floats origin, int k, long seed) { + public static SampleReader createSampleReader(FloatVectorValues origin, int k, long seed) { int[] samples = reservoirSample(origin.size(), k, seed); return new SampleReader(origin, samples.length, i -> samples[i]); } diff --git a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/quantization/TestKMeans.java b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/quantization/TestKMeans.java index 61c0e58c91ef..3669079b719d 100644 --- a/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/quantization/TestKMeans.java +++ b/lucene/sandbox/src/test/org/apache/lucene/sandbox/codecs/quantization/TestKMeans.java @@ -20,9 +20,9 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.tests.util.LuceneTestCase; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; public class 
TestKMeans extends LuceneTestCase { @@ -32,7 +32,7 @@ public void testKMeansAPI() throws IOException { int dims = random().nextInt(2, 20); int randIdx = random().nextInt(VectorSimilarityFunction.values().length); VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.values()[randIdx]; - RandomAccessVectorValues.Floats vectors = generateData(nVectors, dims, nClusters); + FloatVectorValues vectors = generateData(nVectors, dims, nClusters); // default case { @@ -75,7 +75,7 @@ public void testKMeansSpecialCases() throws IOException { // nClusters > nVectors int nClusters = 20; int nVectors = 10; - RandomAccessVectorValues.Floats vectors = generateData(nVectors, 5, nClusters); + FloatVectorValues vectors = generateData(nVectors, 5, nClusters); KMeans.Results results = KMeans.cluster(vectors, VectorSimilarityFunction.EUCLIDEAN, nClusters); // assert that we get 1 centroid, as nClusters will be adjusted @@ -87,7 +87,7 @@ public void testKMeansSpecialCases() throws IOException { int sampleSize = 2; int nClusters = 2; int nVectors = 300; - RandomAccessVectorValues.Floats vectors = generateData(nVectors, 5, nClusters); + FloatVectorValues vectors = generateData(nVectors, 5, nClusters); KMeans.KmeansInitializationMethod initializationMethod = KMeans.KmeansInitializationMethod.PLUS_PLUS; KMeans.Results results = @@ -108,7 +108,7 @@ public void testKMeansSpecialCases() throws IOException { // test unassigned centroids int nClusters = 4; int nVectors = 400; - RandomAccessVectorValues.Floats vectors = generateData(nVectors, 5, nClusters); + FloatVectorValues vectors = generateData(nVectors, 5, nClusters); KMeans.Results results = KMeans.cluster(vectors, VectorSimilarityFunction.EUCLIDEAN, nClusters); float[][] centroids = results.centroids(); @@ -118,8 +118,7 @@ public void testKMeansSpecialCases() throws IOException { } } - private static RandomAccessVectorValues.Floats generateData( - int nSamples, int nDims, int nClusters) { + private static FloatVectorValues 
generateData(int nSamples, int nDims, int nClusters) { List vectors = new ArrayList<>(nSamples); float[][] centroids = new float[nClusters][nDims]; // Generate random centroids @@ -137,6 +136,6 @@ private static RandomAccessVectorValues.Floats generateData( } vectors.add(vector); } - return RandomAccessVectorValues.fromFloats(vectors, nDims); + return FloatVectorValues.fromFloats(vectors, nDims); } } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java index 501e2e5616f0..21c62090a698 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java @@ -125,7 +125,7 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { && fi.getVectorEncoding() == VectorEncoding.FLOAT32; FloatVectorValues floatValues = delegate.getFloatVectorValues(field); assert floatValues != null; - assert floatValues.docID() == -1; + assert floatValues.iterator().docID() == -1; assert floatValues.size() >= 0; assert floatValues.dimension() > 0; return floatValues; @@ -139,7 +139,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { && fi.getVectorEncoding() == VectorEncoding.BYTE; ByteVectorValues values = delegate.getByteVectorValues(field); assert values != null; - assert values.docID() == -1; + assert values.iterator().docID() == -1; assert values.size() >= 0; assert values.dimension() > 0; return values; diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index ed4abb6f2c62..51daaa0380d5 100644 --- 
a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -434,9 +434,9 @@ public void testAddIndexesDirectory0() throws Exception { try (IndexReader reader = DirectoryReader.open(w2)) { LeafReader r = getOnlyLeafReader(reader); FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); - assertEquals(0, vectorValues.nextDoc()); - assertEquals(0, vectorValues.vectorValue()[0], 0); - assertEquals(NO_MORE_DOCS, vectorValues.nextDoc()); + assertEquals(0, vectorValues.iterator().nextDoc()); + assertEquals(0, vectorValues.vectorValue(0)[0], 0); + assertEquals(NO_MORE_DOCS, vectorValues.iterator().nextDoc()); } } } @@ -459,9 +459,9 @@ public void testAddIndexesDirectory1() throws Exception { try (IndexReader reader = DirectoryReader.open(w2)) { LeafReader r = getOnlyLeafReader(reader); FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); - assertNotEquals(NO_MORE_DOCS, vectorValues.nextDoc()); - assertEquals(0, vectorValues.vectorValue()[0], 0); - assertEquals(NO_MORE_DOCS, vectorValues.nextDoc()); + assertNotEquals(NO_MORE_DOCS, vectorValues.iterator().nextDoc()); + assertEquals(0, vectorValues.vectorValue(vectorValues.iterator().index())[0], 0); + assertEquals(NO_MORE_DOCS, vectorValues.iterator().nextDoc()); } } } @@ -486,12 +486,12 @@ public void testAddIndexesDirectory01() throws Exception { try (IndexReader reader = DirectoryReader.open(w2)) { LeafReader r = getOnlyLeafReader(reader); FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); - assertEquals(0, vectorValues.nextDoc()); + assertEquals(0, vectorValues.iterator().nextDoc()); // The merge order is randomized, we might get 0 first, or 1 - float value = vectorValues.vectorValue()[0]; + float value = vectorValues.vectorValue(0)[0]; assertTrue(value == 0 || value == 1); - assertEquals(1, vectorValues.nextDoc()); - value += 
vectorValues.vectorValue()[0]; + assertEquals(1, vectorValues.iterator().nextDoc()); + value += vectorValues.vectorValue(1)[0]; assertEquals(1, value, 0); } } @@ -876,8 +876,9 @@ public void testSparseVectors() throws Exception { ByteVectorValues byteVectorValues = ctx.reader().getByteVectorValues(fieldName); if (byteVectorValues != null) { docCount += byteVectorValues.size(); - while (byteVectorValues.nextDoc() != NO_MORE_DOCS) { - checksum += byteVectorValues.vectorValue()[0]; + while (byteVectorValues.iterator().nextDoc() != NO_MORE_DOCS) { + checksum += + byteVectorValues.vectorValue(byteVectorValues.iterator().index())[0]; } } } @@ -887,8 +888,8 @@ public void testSparseVectors() throws Exception { FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName); if (vectorValues != null) { docCount += vectorValues.size(); - while (vectorValues.nextDoc() != NO_MORE_DOCS) { - checksum += vectorValues.vectorValue()[0]; + while (vectorValues.iterator().nextDoc() != NO_MORE_DOCS) { + checksum += vectorValues.vectorValue(vectorValues.iterator().index())[0]; } } } @@ -945,12 +946,14 @@ public void testFloatVectorScorerIteration() throws Exception { assertNotNull(scorer); DocIdSetIterator iterator = scorer.iterator(); assertSame(iterator, scorer.iterator()); + assertSame(iterator, vectorValues.iterator()); assertNotSame(iterator, scorer); // verify scorer iteration scores are valid & iteration with vectorValues is consistent - while (iterator.nextDoc() != NO_MORE_DOCS && vectorValues.nextDoc() != NO_MORE_DOCS) { + while (iterator.nextDoc() != NO_MORE_DOCS + && vectorValues.iterator().nextDoc() != NO_MORE_DOCS) { float score = scorer.score(); assertTrue(score >= 0f); - assertEquals(iterator.docID(), vectorValues.docID()); + assertEquals(iterator.docID(), vectorValues.iterator().docID()); } // verify that a new scorer can be obtained after iteration VectorScorer newScorer = vectorValues.scorer(vectorToScore); @@ -1004,12 +1007,14 @@ public void 
testByteVectorScorerIteration() throws Exception { assertNotNull(scorer); DocIdSetIterator iterator = scorer.iterator(); assertSame(iterator, scorer.iterator()); + assertSame(iterator, vectorValues.iterator()); assertNotSame(iterator, scorer); // verify scorer iteration scores are valid & iteration with vectorValues is consistent - while (iterator.nextDoc() != NO_MORE_DOCS && vectorValues.nextDoc() != NO_MORE_DOCS) { + while (iterator.nextDoc() != NO_MORE_DOCS + && vectorValues.iterator().nextDoc() != NO_MORE_DOCS) { float score = scorer.score(); assertTrue(score >= 0f); - assertEquals(iterator.docID(), vectorValues.docID()); + assertEquals(iterator.docID(), vectorValues.iterator().docID()); } // verify that a new scorer can be obtained after iteration VectorScorer newScorer = vectorValues.scorer(vectorToScore); @@ -1115,12 +1120,15 @@ public void testIndexedValueNotAliased() throws Exception { LeafReader r = getOnlyLeafReader(reader); FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); assertEquals(3, vectorValues.size()); - vectorValues.nextDoc(); - assertEquals(1, vectorValues.vectorValue()[0], 0); - vectorValues.nextDoc(); - assertEquals(1, vectorValues.vectorValue()[0], 0); - vectorValues.nextDoc(); - assertEquals(2, vectorValues.vectorValue()[0], 0); + vectorValues.iterator().nextDoc(); + assertEquals(0, vectorValues.iterator().index()); + assertEquals(1, vectorValues.vectorValue(0)[0], 0); + vectorValues.iterator().nextDoc(); + assertEquals(1, vectorValues.iterator().index()); + assertEquals(1, vectorValues.vectorValue(1)[0], 0); + vectorValues.iterator().nextDoc(); + assertEquals(2, vectorValues.iterator().index()); + assertEquals(2, vectorValues.vectorValue(2)[0], 0); } } } @@ -1143,13 +1151,13 @@ public void testSortedIndex() throws Exception { FloatVectorValues vectorValues = leaf.getFloatVectorValues(fieldName); assertEquals(2, vectorValues.dimension()); assertEquals(3, vectorValues.size()); - assertEquals("1", 
storedFields.document(vectorValues.nextDoc()).get("id")); - assertEquals(-1f, vectorValues.vectorValue()[0], 0); - assertEquals("2", storedFields.document(vectorValues.nextDoc()).get("id")); - assertEquals(1, vectorValues.vectorValue()[0], 0); - assertEquals("4", storedFields.document(vectorValues.nextDoc()).get("id")); - assertEquals(0, vectorValues.vectorValue()[0], 0); - assertEquals(NO_MORE_DOCS, vectorValues.nextDoc()); + assertEquals("1", storedFields.document(vectorValues.iterator().nextDoc()).get("id")); + assertEquals(-1f, vectorValues.vectorValue(0)[0], 0); + assertEquals("2", storedFields.document(vectorValues.iterator().nextDoc()).get("id")); + assertEquals(1, vectorValues.vectorValue(1)[0], 0); + assertEquals("4", storedFields.document(vectorValues.iterator().nextDoc()).get("id")); + assertEquals(0, vectorValues.vectorValue(2)[0], 0); + assertEquals(NO_MORE_DOCS, vectorValues.iterator().nextDoc()); } } } @@ -1172,13 +1180,13 @@ public void testSortedIndexBytes() throws Exception { ByteVectorValues vectorValues = leaf.getByteVectorValues(fieldName); assertEquals(2, vectorValues.dimension()); assertEquals(3, vectorValues.size()); - assertEquals("1", storedFields.document(vectorValues.nextDoc()).get("id")); - assertEquals(-1, vectorValues.vectorValue()[0], 0); - assertEquals("2", storedFields.document(vectorValues.nextDoc()).get("id")); - assertEquals(1, vectorValues.vectorValue()[0], 0); - assertEquals("4", storedFields.document(vectorValues.nextDoc()).get("id")); - assertEquals(0, vectorValues.vectorValue()[0], 0); - assertEquals(NO_MORE_DOCS, vectorValues.nextDoc()); + assertEquals("1", storedFields.document(vectorValues.iterator().nextDoc()).get("id")); + assertEquals(-1, vectorValues.vectorValue(0)[0], 0); + assertEquals("2", storedFields.document(vectorValues.iterator().nextDoc()).get("id")); + assertEquals(1, vectorValues.vectorValue(1)[0], 0); + assertEquals("4", storedFields.document(vectorValues.iterator().nextDoc()).get("id")); + 
assertEquals(0, vectorValues.vectorValue(2)[0], 0); + assertEquals(NO_MORE_DOCS, vectorValues.iterator().nextDoc()); } } } @@ -1208,27 +1216,27 @@ public void testIndexMultipleKnnVectorFields() throws Exception { FloatVectorValues vectorValues = leaf.getFloatVectorValues("field1"); assertEquals(2, vectorValues.dimension()); assertEquals(2, vectorValues.size()); - vectorValues.nextDoc(); - assertEquals(1f, vectorValues.vectorValue()[0], 0); - vectorValues.nextDoc(); - assertEquals(2f, vectorValues.vectorValue()[0], 0); - assertEquals(NO_MORE_DOCS, vectorValues.nextDoc()); + vectorValues.iterator().nextDoc(); + assertEquals(1f, vectorValues.vectorValue(0)[0], 0); + vectorValues.iterator().nextDoc(); + assertEquals(2f, vectorValues.vectorValue(1)[0], 0); + assertEquals(NO_MORE_DOCS, vectorValues.iterator().nextDoc()); FloatVectorValues vectorValues2 = leaf.getFloatVectorValues("field2"); assertEquals(4, vectorValues2.dimension()); assertEquals(2, vectorValues2.size()); - vectorValues2.nextDoc(); - assertEquals(2f, vectorValues2.vectorValue()[1], 0); - vectorValues2.nextDoc(); - assertEquals(2f, vectorValues2.vectorValue()[1], 0); - assertEquals(NO_MORE_DOCS, vectorValues2.nextDoc()); + vectorValues2.iterator().nextDoc(); + assertEquals(2f, vectorValues2.vectorValue(0)[1], 0); + vectorValues2.iterator().nextDoc(); + assertEquals(2f, vectorValues2.vectorValue(1)[1], 0); + assertEquals(NO_MORE_DOCS, vectorValues2.iterator().nextDoc()); FloatVectorValues vectorValues3 = leaf.getFloatVectorValues("field3"); assertEquals(4, vectorValues3.dimension()); assertEquals(1, vectorValues3.size()); - vectorValues3.nextDoc(); - assertEquals(1f, vectorValues3.vectorValue()[0], 0.1); - assertEquals(NO_MORE_DOCS, vectorValues3.nextDoc()); + vectorValues3.iterator().nextDoc(); + assertEquals(1f, vectorValues3.vectorValue(0)[0], 0.1); + assertEquals(NO_MORE_DOCS, vectorValues3.iterator().nextDoc()); } } } @@ -1292,8 +1300,8 @@ public void testRandom() throws Exception { totalSize += 
vectorValues.size(); StoredFields storedFields = ctx.reader().storedFields(); int docId; - while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) { - float[] v = vectorValues.vectorValue(); + while ((docId = vectorValues.iterator().nextDoc()) != NO_MORE_DOCS) { + float[] v = vectorValues.vectorValue(vectorValues.iterator().index()); assertEquals(dimension, v.length); String idString = storedFields.document(docId).getField("id").stringValue(); int id = Integer.parseInt(idString); @@ -1372,8 +1380,8 @@ public void testRandomBytes() throws Exception { totalSize += vectorValues.size(); StoredFields storedFields = ctx.reader().storedFields(); int docId; - while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) { - byte[] v = vectorValues.vectorValue(); + while ((docId = vectorValues.iterator().nextDoc()) != NO_MORE_DOCS) { + byte[] v = vectorValues.vectorValue(vectorValues.iterator().index()); assertEquals(dimension, v.length); String idString = storedFields.document(docId).getField("id").stringValue(); int id = Integer.parseInt(idString); @@ -1492,8 +1500,8 @@ public void testRandomWithUpdatesAndGraph() throws Exception { StoredFields storedFields = ctx.reader().storedFields(); int docId; int numLiveDocsWithVectors = 0; - while ((docId = vectorValues.nextDoc()) != NO_MORE_DOCS) { - float[] v = vectorValues.vectorValue(); + while ((docId = vectorValues.iterator().nextDoc()) != NO_MORE_DOCS) { + float[] v = vectorValues.vectorValue(vectorValues.iterator().index()); assertEquals(dimension, v.length); String idString = storedFields.document(docId).getField("id").stringValue(); int id = Integer.parseInt(idString); @@ -1701,24 +1709,25 @@ public void testAdvance() throws Exception { int[] vectorDocs = new int[vectorValues.size() + 1]; int cur = -1; while (++cur < vectorValues.size() + 1) { - vectorDocs[cur] = vectorValues.nextDoc(); + vectorDocs[cur] = vectorValues.iterator().nextDoc(); if (cur != 0) { assertTrue(vectorDocs[cur] > vectorDocs[cur - 1]); } } vectorValues = 
r.getFloatVectorValues(fieldName); + DocIdSetIterator iter = vectorValues.iterator(); cur = -1; for (int i = 0; i < numdocs; i++) { // randomly advance to i if (random().nextInt(4) == 3) { while (vectorDocs[++cur] < i) {} - assertEquals(vectorDocs[cur], vectorValues.advance(i)); - assertEquals(vectorDocs[cur], vectorValues.docID()); - if (vectorValues.docID() == NO_MORE_DOCS) { + assertEquals(vectorDocs[cur], iter.advance(i)); + assertEquals(vectorDocs[cur], iter.docID()); + if (iter.docID() == NO_MORE_DOCS) { break; } // make i equal to docid so that it is greater than docId in the next loop iteration - i = vectorValues.docID(); + i = iter.docID(); } } } @@ -1776,9 +1785,10 @@ public void testVectorValuesReportCorrectDocs() throws Exception { if (byteVectorValues != null) { docCount += byteVectorValues.size(); StoredFields storedFields = ctx.reader().storedFields(); - while (byteVectorValues.nextDoc() != NO_MORE_DOCS) { - checksum += byteVectorValues.vectorValue()[0]; - Document doc = storedFields.document(byteVectorValues.docID(), Set.of("id")); + for (int ord = 0; ord < byteVectorValues.size(); ord++) { + checksum += byteVectorValues.vectorValue(ord)[0]; + Document doc = + storedFields.document(byteVectorValues.ordToDoc(ord), Set.of("id")); sumDocIds += Integer.parseInt(doc.get("id")); } } @@ -1790,9 +1800,9 @@ public void testVectorValuesReportCorrectDocs() throws Exception { if (vectorValues != null) { docCount += vectorValues.size(); StoredFields storedFields = ctx.reader().storedFields(); - while (vectorValues.nextDoc() != NO_MORE_DOCS) { - checksum += vectorValues.vectorValue()[0]; - Document doc = storedFields.document(vectorValues.docID(), Set.of("id")); + for (int ord = 0; ord < vectorValues.size(); ord++) { + checksum += vectorValues.vectorValue(ord)[0]; + Document doc = storedFields.document(vectorValues.ordToDoc(ord), Set.of("id")); sumDocIds += Integer.parseInt(doc.get("id")); } } From 2bbf8f1feda6c959dea190696c10d1a57972c954 Mon Sep 17 00:00:00 2001 
From: Michael Sokolov Date: Mon, 2 Sep 2024 15:10:14 -0400 Subject: [PATCH 02/25] adding some ordToDoc --- .../codecs/BufferingKnnVectorsWriter.java | 10 +++++++ .../lucene/codecs/KnnVectorsWriter.java | 30 +++++++++++-------- .../lucene95/OffHeapByteVectorValues.java | 5 ---- .../lucene95/OffHeapFloatVectorValues.java | 10 +++++++ .../apache/lucene/index/KnnVectorValues.java | 6 ++-- .../index/BaseKnnVectorsFormatTestCase.java | 1 - 6 files changed, 41 insertions(+), 21 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java index 675025a1c93e..11673f357bd5 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java @@ -285,6 +285,16 @@ public int size() { return vectors.size(); } + @Override + public int ordToDoc(int ord) { + return ord; + } + + @Override + public int docToOrd(int doc) { + return doc; + } + @Override public float[] vectorValue(int targetOrd) { return vectors.get(targetOrd); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java index 408a336f1561..3b50c1a2496d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java @@ -317,12 +317,15 @@ public float[] vectorValue(int ord) throws IOException { @Override public int ordToDoc(int ord) { - throw new UnsupportedOperationException(); - /** - * TODO: is this needed? 
int iSub = Arrays.binarySearch(ends, ord); if (iSub < 0) { iSub = - * -(iSub + 1); } if (iSub == 0) { return subs.get(iSub).values.ordToDoc(ord); } else { - * return subs.get(iSub).values.ordToDoc(ord - ends[iSub - 1]); } - */ + int iSub = Arrays.binarySearch(ends, ord); + if (iSub < 0) { + iSub = -(iSub + 1); + } + if (iSub == 0) { + return subs.get(iSub).values.ordToDoc(ord); + } else { + return subs.get(iSub).values.ordToDoc(ord - ends[iSub - 1]); + } } @Override @@ -389,12 +392,15 @@ public byte[] vectorValue(int ord) throws IOException { @Override public int ordToDoc(int ord) { - throw new UnsupportedOperationException(); - /** - * TODO: is this needed? int iSub = Arrays.binarySearch(ends, ord); if (iSub < 0) { iSub = - * -(iSub + 1); } if (iSub == 0) { return subs.get(iSub).values.ordToDoc(ord); } else { - * return subs.get(iSub).values.ordToDoc(ord - ends[iSub - 1]); } - */ + int iSub = Arrays.binarySearch(ends, ord); + if (iSub < 0) { + iSub = -(iSub + 1); + } + if (iSub == 0) { + return subs.get(iSub).values.ordToDoc(ord); + } else { + return subs.get(iSub).values.ordToDoc(ord - ends[iSub - 1]); + } } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index 9a399e581b8c..c6e9850331d2 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -142,11 +142,6 @@ public DenseOffHeapVectorValues( super(dimension, size, slice, byteSize, flatVectorsScorer, vectorSimilarityFunction); } - @Override - public byte[] vectorValue(int ord) throws IOException { - return vectorValue(ord); - } - @Override public DenseOffHeapVectorValues copy() throws IOException { return new DenseOffHeapVectorValues( diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java 
b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index 73d71e31e69b..821db225e67e 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -143,6 +143,16 @@ public DenseOffHeapVectorValues copy() throws IOException { dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction); } + @Override + public int ordToDoc(int ord) { + return ord; + } + + @Override + public int docToOrd(int doc) { + return doc; + } + @Override public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java index 7dafab09ce80..4302ea87a0b3 100644 --- a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java @@ -45,12 +45,12 @@ public abstract class KnnVectorValues { /** Return the docid of the document indexed with the given vector ordingl */ public int ordToDoc(int ord) { - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException("by class " + getClass().getName()); } /** Return the vector ordinal indexed for the given document or -1 if there is none */ public int docToOrd(int ord) { - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException("by class " + getClass().getName()); } /** @@ -58,7 +58,7 @@ public int docToOrd(int ord) { * different values at once, to avoid overwriting the underlying vector returned. */ public KnnVectorValues copy() throws IOException { - throw new UnsupportedOperationException(); + throw new UnsupportedOperationException("by class " + getClass().getName()); } /** Returns the byte length of the vector values. 
*/ diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index 51daaa0380d5..af2742b87fb3 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -946,7 +946,6 @@ public void testFloatVectorScorerIteration() throws Exception { assertNotNull(scorer); DocIdSetIterator iterator = scorer.iterator(); assertSame(iterator, scorer.iterator()); - assertSame(iterator, vectorValues.iterator()); assertNotSame(iterator, scorer); // verify scorer iteration scores are valid & iteration with vectorValues is consistent while (iterator.nextDoc() != NO_MORE_DOCS From a451fdb5494859f1da293b3b6ff19c36cb1fd812 Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Mon, 2 Sep 2024 15:26:32 -0400 Subject: [PATCH 03/25] restore vector count argument to scalarquantizer methods --- .../Lucene99ScalarQuantizedVectorsWriter.java | 16 ++++++++++++++- .../util/quantization/ScalarQuantizer.java | 20 ++++++++++++++----- 2 files changed, 30 insertions(+), 6 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index 4b7e73cf8623..3b650bbf09ff 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -51,6 +51,7 @@ import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.internal.hppc.IntArrayList; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.VectorScorer; import 
org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; @@ -651,8 +652,18 @@ public static ScalarQuantizer mergeAndRecalculateQuantiles( // TODO: this is very conservative, could we reuse information for even int4 quantization? || bits <= 4 || shouldRecomputeQuantiles(mergedQuantiles, quantizationStates)) { + int numVectors = 0; + DocIdSetIterator iter = + KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState).iterator(); + // iterate vectorValues and increment numVectors + for (int doc = iter.nextDoc(); + doc != DocIdSetIterator.NO_MORE_DOCS; + doc = iter.nextDoc()) { + numVectors++; + } return buildScalarQuantizer( KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState), + numVectors, fieldInfo.getVectorSimilarityFunction(), confidenceInterval, bits); @@ -662,6 +673,7 @@ public static ScalarQuantizer mergeAndRecalculateQuantiles( static ScalarQuantizer buildScalarQuantizer( FloatVectorValues floatVectorValues, + int numVectors, VectorSimilarityFunction vectorSimilarityFunction, Float confidenceInterval, byte bits) @@ -672,13 +684,14 @@ static ScalarQuantizer buildScalarQuantizer( } if (confidenceInterval != null && confidenceInterval == DYNAMIC_CONFIDENCE_INTERVAL) { return ScalarQuantizer.fromVectorsAutoInterval( - floatVectorValues, vectorSimilarityFunction, bits); + floatVectorValues, vectorSimilarityFunction, numVectors, bits); } return ScalarQuantizer.fromVectors( floatVectorValues, confidenceInterval == null ? 
calculateDefaultConfidenceInterval(floatVectorValues.dimension()) : confidenceInterval, + numVectors, bits); } @@ -791,6 +804,7 @@ ScalarQuantizer createQuantizer() throws IOException { ScalarQuantizer quantizer = buildScalarQuantizer( new FloatVectorWrapper(floatVectors), + floatVectors.size(), fieldInfo.getVectorSimilarityFunction(), confidenceInterval, bits); diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java index 364b6a78e05b..271d059362a2 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java @@ -240,20 +240,28 @@ private static int[] reservoirSampleIndices(int numFloatVecs, int sampleSize) { * @throws IOException if there is an error reading the float vector values */ public static ScalarQuantizer fromVectors( - FloatVectorValues floatVectorValues, float confidenceInterval, byte bits) throws IOException { + FloatVectorValues floatVectorValues, + float confidenceInterval, + int totalVectorCount, + byte bits) + throws IOException { return fromVectors( - floatVectorValues, confidenceInterval, bits, SCALAR_QUANTIZATION_SAMPLE_SIZE); + floatVectorValues, + confidenceInterval, + totalVectorCount, + bits, + SCALAR_QUANTIZATION_SAMPLE_SIZE); } static ScalarQuantizer fromVectors( FloatVectorValues floatVectorValues, float confidenceInterval, + int totalVectorCount, byte bits, int quantizationSampleSize) throws IOException { assert 0.9f <= confidenceInterval && confidenceInterval <= 1f; assert quantizationSampleSize > SCRATCH_SIZE; - int totalVectorCount = floatVectorValues.size(); if (totalVectorCount == 0) { return new ScalarQuantizer(0f, 0f, bits); } @@ -311,10 +319,12 @@ static ScalarQuantizer fromVectors( } public static ScalarQuantizer fromVectorsAutoInterval( - FloatVectorValues floatVectorValues, VectorSimilarityFunction 
function, byte bits) + FloatVectorValues floatVectorValues, + VectorSimilarityFunction function, + int totalVectorCount, + byte bits) throws IOException { assert function != VectorSimilarityFunction.COSINE; - int totalVectorCount = floatVectorValues.size(); if (totalVectorCount == 0) { return new ScalarQuantizer(0f, 0f, bits); } From 8152b9dcb4254f15898470e37bc3a364a744fde8 Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Tue, 3 Sep 2024 08:16:26 -0400 Subject: [PATCH 04/25] remove docToOrd; mostly can use iterator.index() --- .../lucene94/OffHeapByteVectorValues.java | 2 +- .../SimpleTextKnnVectorsReader.java | 6 ++-- .../codecs/BufferingKnnVectorsWriter.java | 32 +++++++++++-------- .../lucene/codecs/KnnVectorsWriter.java | 5 --- .../lucene95/OffHeapFloatVectorValues.java | 5 --- .../apache/lucene/index/KnnVectorValues.java | 10 ++---- .../lucene/index/memory/MemoryIndex.java | 6 ++-- .../valuesource/ByteKnnVectorFieldSource.java | 2 +- .../FloatKnnVectorFieldSource.java | 2 +- 9 files changed, 29 insertions(+), 41 deletions(-) diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java index 53ec069f5d4f..0577528446cd 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java @@ -137,7 +137,7 @@ public VectorScorer scorer(byte[] query) throws IOException { @Override public float score() throws IOException { return vectorSimilarityFunction.compare( - copy.vectorValue(copy.docToOrd(copy.iterator().docID())), query); + copy.vectorValue(copy.iterator().docID()), query); } @Override diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java 
b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index 66faf0f24fd0..714e31a8beb0 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -341,8 +341,7 @@ public VectorScorer scorer(float[] target) { return new VectorScorer() { @Override public float score() throws IOException { - int docId = simpleTextFloatVectorValues.iterator().docID(); - int ord = simpleTextFloatVectorValues.docToOrd(docId); + int ord = simpleTextFloatVectorValues.iterator().index(); return entry .similarityFunction() .compare(simpleTextFloatVectorValues.vectorValue(ord), target); @@ -428,8 +427,7 @@ public VectorScorer scorer(byte[] target) { return new VectorScorer() { @Override public float score() throws IOException { - int docId = simpleTextByteVectorValues.iterator().docID(); - int ord = simpleTextByteVectorValues.docToOrd(docId); + int ord = simpleTextByteVectorValues.iterator().index(); return entry .similarityFunction() .compare(simpleTextByteVectorValues.vectorValue(ord), target); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java index 11673f357bd5..bad64c74bb17 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java @@ -135,18 +135,27 @@ public SortingFloatVectorValues copy() { } } + // create a map from new ord to old ord assuming ords are sorted by doc but + // may be sparse private static int[] docMapToOrdMap(KnnVectorValues values, Sorter.DocMap docMap) { - int[] newToOld = new int[docMap.size()]; - int ord = 0; - for (int doc = 0; doc < newToOld.length; doc++) { - int oldDoc = docMap.newToOld(doc); - int oldOrd = values.docToOrd(oldDoc); // no value 
represented by -1 - if (oldOrd >= 0) { - newToOld[ord++] = oldOrd; + // fill with -1? + int[] newDocToOldOrd = new int[docMap.size()]; + int count = 0; + for (int ord = 0; ord < values.size(); ord++) { + int oldDoc = values.ordToDoc(ord); + int newDoc = docMap.oldToNew(oldDoc); + // no value will be represented by 0 + if (newDoc >= 0) { + newDocToOldOrd[newDoc] = ord + 1; + ++count; } } - if (ord < newToOld.length) { - newToOld = ArrayUtil.copyOfSubArray(newToOld, 0, ord); + int [] newToOld = new int[count]; + count = 0; + for (int ord = 0; ord < newDocToOldOrd.length; ord++) { + if (newDocToOldOrd[ord] > 0) { + newToOld[count++] = newDocToOldOrd[ord] - 1; + } } return newToOld; } @@ -290,11 +299,6 @@ public int ordToDoc(int ord) { return ord; } - @Override - public int docToOrd(int doc) { - return doc; - } - @Override public float[] vectorValue(int targetOrd) { return vectors.get(targetOrd); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java index 3b50c1a2496d..231438430134 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java @@ -328,11 +328,6 @@ public int ordToDoc(int ord) { } } - @Override - public int docToOrd(int ord) { - throw new UnsupportedOperationException(); - } - @Override public int size() { return size; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index 821db225e67e..71ed68c74e6c 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -148,11 +148,6 @@ public int ordToDoc(int ord) { return ord; } - @Override - public int docToOrd(int doc) { - return doc; - } - @Override public Bits 
getAcceptOrds(Bits acceptDocs) { return acceptDocs; diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java index 4302ea87a0b3..f488fa5a2fa1 100644 --- a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java @@ -24,8 +24,7 @@ /** * This class abstracts addressing of document vector values indexed as {@link KnnFloatVectorField} - * or {@link KnnByteVectorField}. Hmm, we also need docToOrd(s)? We need to be able to retrieve a - * vector value for a document. + * or {@link KnnByteVectorField}. * * @lucene.experimental */ @@ -48,11 +47,6 @@ public int ordToDoc(int ord) { throw new UnsupportedOperationException("by class " + getClass().getName()); } - /** Return the vector ordinal indexed for the given document or -1 if there is none */ - public int docToOrd(int ord) { - throw new UnsupportedOperationException("by class " + getClass().getName()); - } - /** * Creates a new copy of this {@link KnnVectorValues}. This is helpful when you need to access * different values at once, to avoid overwriting the underlying vector returned. 
@@ -112,7 +106,7 @@ public int nextDoc() throws IOException { if (ord >= size() - 1) { return NO_MORE_DOCS; } else { - doc = docToOrd(++ord); + doc = ordToDoc(++ord); return doc; } } diff --git a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java index 1d46428dae2f..3d577e06e6f1 100644 --- a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java +++ b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java @@ -2323,10 +2323,11 @@ public VectorScorer scorer(float[] query) { return new VectorScorer() { @Override public float score() throws IOException { + assert iterator().docID() == 0; return info.fieldInfo .getVectorSimilarityFunction() .compare( - vectorValues.vectorValue(vectorValues.docToOrd(vectorValues.iterator().docID())), + vectorValues.vectorValue(0), query); } @@ -2378,10 +2379,11 @@ public VectorScorer scorer(byte[] query) { return new VectorScorer() { @Override public float score() { + assert iterator().docID() == 0; return info.fieldInfo .getVectorSimilarityFunction() .compare( - vectorValues.vectorValue(vectorValues.docToOrd(vectorValues.iterator().docID())), + vectorValues.vectorValue(0), query); } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java index 26d3f21486dd..362a8e1f7853 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java @@ -67,7 +67,7 @@ protected DocIdSetIterator getVectorIterator() { @Override public byte[] byteVectorVal(int doc) throws IOException { if (exists(doc)) { - return vectorValues.vectorValue(vectorValues.docToOrd(vectorValues.iterator().docID())); + return 
vectorValues.vectorValue(vectorValues.iterator().index()); } else { return null; } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java index 1438477b61c9..824cc8a782df 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java @@ -66,7 +66,7 @@ protected DocIdSetIterator getVectorIterator() { @Override public float[] floatVectorVal(int doc) throws IOException { if (exists(doc)) { - return vectorValues.vectorValue(vectorValues.docToOrd(doc)); + return vectorValues.vectorValue(vectorValues.iterator().index()); } else { return null; } From dce766c9751d4379c6c730b322a05a95b73e7ff1 Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Thu, 5 Sep 2024 17:53:04 -0400 Subject: [PATCH 05/25] Make KnnVectorValues primarily a random access API --- .../lucene90/Lucene90HnswVectorsReader.java | 3 +- .../lucene91/Lucene91HnswVectorsReader.java | 11 +- .../lucene92/OffHeapFloatVectorValues.java | 4 - .../lucene94/OffHeapByteVectorValues.java | 7 +- .../lucene94/OffHeapFloatVectorValues.java | 4 - .../SimpleTextKnnVectorsReader.java | 10 ++ .../SimpleTextKnnVectorsWriter.java | 10 +- .../codecs/BufferingKnnVectorsWriter.java | 7 +- .../lucene/codecs/KnnVectorsWriter.java | 160 +++++++++++------- .../codecs/hnsw/DefaultFlatVectorScorer.java | 14 +- .../lucene95/OffHeapByteVectorValues.java | 15 ++ .../lucene95/OffHeapFloatVectorValues.java | 15 ++ .../lucene99/Lucene99FlatVectorsWriter.java | 16 +- .../Lucene99ScalarQuantizedVectorsReader.java | 10 ++ .../Lucene99ScalarQuantizedVectorsWriter.java | 141 +++++++++++---- .../OffHeapQuantizedByteVectorValues.java | 46 ++--- .../apache/lucene/index/ByteVectorValues.java | 16 +- 
.../lucene/index/ExitableDirectoryReader.java | 86 +++++----- .../lucene/index/FloatVectorValues.java | 10 +- .../apache/lucene/index/KnnVectorValues.java | 145 +++++++++++----- .../SlowCompositeCodecReaderWrapper.java | 92 ++++++---- .../lucene/index/SortingCodecReader.java | 85 +++++++--- .../util/hnsw/IncrementalHnswGraphMerger.java | 24 ++- .../util/quantization/ScalarQuantizer.java | 43 +++-- ...estLucene99HnswQuantizedVectorsFormat.java | 2 +- ...tLucene99ScalarQuantizedVectorsFormat.java | 10 +- ...tLucene99ScalarQuantizedVectorsWriter.java | 2 +- .../index/TestExitableDirectoryReader.java | 10 +- .../org/apache/lucene/index/TestKnnGraph.java | 3 +- .../lucene/index/TestSortingCodecReader.java | 4 +- .../util/hnsw/AbstractMockVectorValues.java | 90 ---------- .../lucene/util/hnsw/HnswGraphTestCase.java | 146 +++++++++------- .../util/hnsw/MockByteVectorValues.java | 42 +++-- .../lucene/util/hnsw/MockVectorValues.java | 39 +++-- .../util/hnsw/TestHnswByteVectorGraph.java | 33 ++-- .../util/hnsw/TestHnswFloatVectorGraph.java | 32 ++-- .../TestScalarQuantizedVectorSimilarity.java | 14 +- .../quantization/TestScalarQuantizer.java | 91 ++++++---- .../lucene/index/memory/MemoryIndex.java | 8 +- .../index/BaseKnnVectorsFormatTestCase.java | 24 ++- .../lucene/tests/search/AssertingScorer.java | 2 +- 41 files changed, 919 insertions(+), 607 deletions(-) delete mode 100644 lucene/core/src/test/org/apache/lucene/util/hnsw/AbstractMockVectorValues.java diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index e9a74b29726c..3b8334e05e51 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -31,7 +31,6 @@ import 
org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; -import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.KnnCollector; @@ -422,7 +421,7 @@ public float score() throws IOException { } @Override - public KnnVectorValues.KnnValuesDocIterator iterator() { + public DocIterator iterator() { return values.iterator(); } }; diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index 80fccb5a7ad8..b3b2e3324f4a 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -398,8 +398,7 @@ int ordToDoc(int ord) { } /** Read the vector values from the index input. This supports both iterated and random access. 
*/ - static class OffHeapFloatVectorValues extends FloatVectorValues - implements RandomAccessVectorValues.Floats { + static class OffHeapFloatVectorValues extends FloatVectorValues { private final int dimension; private final int size; @@ -410,9 +409,6 @@ static class OffHeapFloatVectorValues extends FloatVectorValues private final float[] value; private final VectorSimilarityFunction similarityFunction; - private int ord = -1; - private int doc = -1; - OffHeapFloatVectorValues( int dimension, int size, @@ -452,6 +448,11 @@ public float[] vectorValue(int targetOrd) throws IOException { return value; } + @Override + public int ordToDoc(int ord) { + return ordToDocOperator.applyAsInt(ord); + } + @Override public VectorScorer scorer(float[] target) { if (size == 0) { diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java index 5cdfc5d7ad2c..ae232b1fb828 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java @@ -95,8 +95,6 @@ static OffHeapFloatVectorValues load( static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues { - private int doc = -1; - public DenseOffHeapVectorValues( int dimension, int size, @@ -216,8 +214,6 @@ public EmptyOffHeapVectorValues(int dimension) { super(dimension, 0, VectorSimilarityFunction.COSINE, null); } - private int doc = -1; - @Override public int dimension() { return super.dimension(); diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java index 0577528446cd..b8ada89da3e3 100644 --- 
a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java @@ -108,8 +108,6 @@ static OffHeapByteVectorValues load( static class DenseOffHeapVectorValues extends OffHeapByteVectorValues { - private int doc = -1; - public DenseOffHeapVectorValues( int dimension, int size, @@ -136,8 +134,7 @@ public VectorScorer scorer(byte[] query) throws IOException { return new VectorScorer() { @Override public float score() throws IOException { - return vectorSimilarityFunction.compare( - copy.vectorValue(copy.iterator().docID()), query); + return vectorSimilarityFunction.compare(copy.vectorValue(copy.iterator().docID()), query); } @Override @@ -231,8 +228,6 @@ public EmptyOffHeapVectorValues(int dimension) { super(dimension, 0, null, VectorSimilarityFunction.COSINE, 0); } - private int doc = -1; - @Override public int dimension() { return super.dimension(); diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java index 27dfbd98c937..77ca4c2bb9c5 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java @@ -104,8 +104,6 @@ static OffHeapFloatVectorValues load( static class DenseOffHeapVectorValues extends OffHeapFloatVectorValues { - private int doc = -1; - public DenseOffHeapVectorValues( int dimension, int size, @@ -228,8 +226,6 @@ public EmptyOffHeapVectorValues(int dimension) { super(dimension, 0, null, VectorSimilarityFunction.COSINE, 0); } - private int doc = -1; - @Override public int dimension() { return super.dimension(); diff --git 
a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index 714e31a8beb0..18fb7b2a085c 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -331,6 +331,11 @@ public float[] vectorValue(int ord) { return values[ord]; } + @Override + protected DocIterator createIterator() { + return createDenseIterator(this); + } + @Override public VectorScorer scorer(float[] target) { if (size() == 0) { @@ -418,6 +423,11 @@ public byte[] vectorValue(int ord) { return binaryValue.bytes; } + @Override + protected DocIterator createIterator() { + return createDenseIterator(this); + } + @Override public VectorScorer scorer(byte[] target) { if (size() == 0) { diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java index c71037f2ebd8..d6b25f0e4761 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java @@ -17,6 +17,8 @@ package org.apache.lucene.codecs.simpletext; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -26,6 +28,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.BytesRef; @@ -75,9 +78,10 @@ public void writeField(FieldInfo fieldInfo, FloatVectorValues 
floatVectorValues, throws IOException { long vectorDataOffset = vectorData.getFilePointer(); List docIds = new ArrayList<>(); - for (int ord = 0; ord < floatVectorValues.size(); ord++) { - writeFloatVectorValue(floatVectorValues, ord); - docIds.add(floatVectorValues.ordToDoc(ord)); + KnnVectorValues.DocIterator iter = floatVectorValues.iterator(); + for (int docId = iter.nextDoc(); docId != NO_MORE_DOCS; docId = iter.nextDoc()) { + writeFloatVectorValue(floatVectorValues, iter.index()); + docIds.add(docId); } long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; writeMeta(fieldInfo, vectorDataOffset, vectorDataLength, docIds); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java index bad64c74bb17..620e083a42fa 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java @@ -150,7 +150,7 @@ private static int[] docMapToOrdMap(KnnVectorValues values, Sorter.DocMap docMap ++count; } } - int [] newToOld = new int[count]; + int[] newToOld = new int[count]; count = 0; for (int ord = 0; ord < newDocToOldOrd.length; ord++) { if (newDocToOldOrd[ord] > 0) { @@ -304,6 +304,11 @@ public float[] vectorValue(int targetOrd) { return vectors.get(targetOrd); } + @Override + public DocIterator createIterator() { + return createDenseIterator(this); + } + @Override public BufferedFloatVectorValues copy() { return this; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java index 231438430134..b6c6ad53a199 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java @@ -30,6 +30,7 @@ import org.apache.lucene.index.FieldInfo; import 
org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.Sorter; import org.apache.lucene.index.VectorEncoding; @@ -55,24 +56,26 @@ protected KnnVectorsWriter() {} @SuppressWarnings("unchecked") public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { switch (fieldInfo.getVectorEncoding()) { - case BYTE: + case BYTE -> { KnnFieldVectorsWriter byteWriter = (KnnFieldVectorsWriter) addField(fieldInfo); ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); - for (int ord = 0; ord < mergedBytes.size(); ord++) { - byteWriter.addValue(mergedBytes.ordToDoc(ord), mergedBytes.vectorValue(ord)); + KnnVectorValues.DocIterator iter = mergedBytes.iterator(); + for (int doc = iter.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iter.nextDoc()) { + byteWriter.addValue(doc, mergedBytes.vectorValue(iter.index())); } - break; - case FLOAT32: + } + case FLOAT32 -> { KnnFieldVectorsWriter floatWriter = (KnnFieldVectorsWriter) addField(fieldInfo); FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); - for (int ord = 0; ord < mergedFloats.size(); ord++) { - floatWriter.addValue(mergedFloats.ordToDoc(ord), mergedFloats.vectorValue(ord)); + KnnVectorValues.DocIterator iter = mergedFloats.iterator(); + for (int doc = iter.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iter.nextDoc()) { + floatWriter.addValue(doc, mergedFloats.vectorValue(iter.index())); } - break; + } } } @@ -114,16 +117,15 @@ private static class FloatVectorValuesSub extends DocIDMerger.Sub { final FloatVectorValues values; - int ord = -1; - FloatVectorValuesSub(MergeState.DocMap docMap, FloatVectorValues values) { super(docMap); this.values = values; + assert values.iterator().docID() == -1; } @Override public int nextDoc() throws 
IOException { - return values.ordToDoc(++ord); + return values.iterator().nextDoc(); } } @@ -131,16 +133,15 @@ private static class ByteVectorValuesSub extends DocIDMerger.Sub { final ByteVectorValues values; - int ord = -1; - ByteVectorValuesSub(MergeState.DocMap docMap, ByteVectorValues values) { super(docMap); this.values = values; + assert values.iterator().docID() == -1; } @Override public int nextDoc() throws IOException { - return values.ordToDoc(++ord); + return values.iterator().nextDoc(); } } @@ -285,47 +286,62 @@ static class MergedFloat32VectorValues extends FloatVectorValues { private final List subs; private final DocIDMerger docIdMerger; private final int size; - private final int[] ends; + private int docId; + FloatVectorValuesSub current; private MergedFloat32VectorValues(List subs, MergeState mergeState) throws IOException { this.subs = subs; docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort); int totalSize = 0; - ends = new int[subs.size()]; - int iSub = 0; for (FloatVectorValuesSub sub : subs) { totalSize += sub.values.size(); - ends[iSub++] = totalSize; } size = totalSize; + docId = -1; } @Override - public float[] vectorValue(int ord) throws IOException { - // TODO: if this is a bottleneck we could predict the last iSub - int iSub = Arrays.binarySearch(ends, ord); - if (iSub < 0) { - iSub = -(iSub + 1); - } - if (iSub == 0) { - return subs.get(iSub).values.vectorValue(ord); - } else { - return subs.get(iSub).values.vectorValue(ord - ends[iSub - 1]); - } + protected DocIterator createIterator() { + return new DocIterator() { + private int index = -1; + + @Override + public int docID() { + return docId; + } + + @Override + public int index() { + return index; + } + + @Override + public int nextDoc() throws IOException { + current = docIdMerger.next(); + if (current == null) { + docId = NO_MORE_DOCS; + index = NO_MORE_DOCS; + } else { + docId = current.mappedDocID; + ++index; + } + return docId; + } + + @Override + public long cost() { 
+ return size; + } + }; } + // This really only needs to support iterative access as a first-class citizen (it is never + // used for searching) so simply ignore the supplied ordinal ... we'd like to assert that + // docId == ordToDoc(ord) but we didn't implement ordToDoc @Override - public int ordToDoc(int ord) { - int iSub = Arrays.binarySearch(ends, ord); - if (iSub < 0) { - iSub = -(iSub + 1); - } - if (iSub == 0) { - return subs.get(iSub).values.ordToDoc(ord); - } else { - return subs.get(iSub).values.ordToDoc(ord - ends[iSub - 1]); - } + public float[] vectorValue(int ord) throws IOException { + return current.values.vectorValue(current.values.iterator().index()); } @Override @@ -338,6 +354,11 @@ public int dimension() { return subs.get(0).values.dimension(); } + @Override + public int ordToDoc(int ord) { + throw new UnsupportedOperationException(); + } + @Override public VectorScorer scorer(float[] target) { throw new UnsupportedOperationException(); @@ -353,7 +374,6 @@ static class MergedByteVectorValues extends ByteVectorValues { private final List subs; private final DocIDMerger docIdMerger; private final int size; - private final int[] ends; private int docId; ByteVectorValuesSub current; @@ -363,39 +383,50 @@ private MergedByteVectorValues(List subs, MergeState mergeS this.subs = subs; docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort); int totalSize = 0; - ends = new int[subs.size()]; - int iSub = 0; for (ByteVectorValuesSub sub : subs) { totalSize += sub.values.size(); - ends[iSub++] = totalSize; } size = totalSize; } @Override public byte[] vectorValue(int ord) throws IOException { - int iSub = Arrays.binarySearch(ends, ord); - if (iSub < 0) { - iSub = -(iSub + 1); - } - if (iSub == 0) { - return subs.get(iSub).values.vectorValue(ord); - } else { - return subs.get(iSub).values.vectorValue(ord - ends[iSub - 1]); - } + return current.values.vectorValue(current.values.iterator().index()); } @Override - public int ordToDoc(int ord) { - int 
iSub = Arrays.binarySearch(ends, ord); - if (iSub < 0) { - iSub = -(iSub + 1); - } - if (iSub == 0) { - return subs.get(iSub).values.ordToDoc(ord); - } else { - return subs.get(iSub).values.ordToDoc(ord - ends[iSub - 1]); - } + protected DocIterator createIterator() { + return new DocIterator() { + private int index = -1; + + @Override + public int docID() { + return docId; + } + + @Override + public int index() { + return index; + } + + @Override + public int nextDoc() throws IOException { + current = docIdMerger.next(); + if (current == null) { + docId = NO_MORE_DOCS; + index = NO_MORE_DOCS; + } else { + docId = current.mappedDocID; + ++index; + } + return docId; + } + + @Override + public long cost() { + return size; + } + }; } @Override @@ -408,6 +439,11 @@ public int dimension() { return subs.get(0).values.dimension(); } + @Override + public int ordToDoc(int ord) { + throw new UnsupportedOperationException(); + } + @Override public VectorScorer scorer(byte[] target) { throw new UnsupportedOperationException(); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java index e1f1b24d3685..3e506037969a 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/DefaultFlatVectorScorer.java @@ -38,13 +38,17 @@ public class DefaultFlatVectorScorer implements FlatVectorsScorer { public RandomVectorScorerSupplier getRandomVectorScorerSupplier( VectorSimilarityFunction similarityFunction, KnnVectorValues vectorValues) throws IOException { - if (vectorValues instanceof FloatVectorValues floatVectorValues) { - return new FloatScoringSupplier(floatVectorValues, similarityFunction); - } else if (vectorValues instanceof ByteVectorValues byteVectorValues) { - return new ByteScoringSupplier(byteVectorValues, similarityFunction); + switch (vectorValues.getEncoding()) { + 
case FLOAT32 -> { + return new FloatScoringSupplier((FloatVectorValues) vectorValues, similarityFunction); + } + case BYTE -> { + return new ByteScoringSupplier((ByteVectorValues) vectorValues, similarityFunction); + } } throw new IllegalArgumentException( - "vectorValues must be an instance of FloatVectorValues or ByteVectorValues"); + "vectorValues must be an instance of FloatVectorValues or ByteVectorValues, got a " + + vectorValues.getClass().getName()); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index c6e9850331d2..4f4e5adbf887 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -148,6 +148,11 @@ public DenseOffHeapVectorValues copy() throws IOException { dimension, size, slice.clone(), byteSize, flatVectorsScorer, similarityFunction); } + @Override + public DocIterator createIterator() { + return createDenseIterator(this); + } + @Override public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; @@ -228,6 +233,11 @@ public int ordToDoc(int ord) { return (int) ordToDoc.get(ord); } + @Override + protected DocIterator createIterator() { + return DocIterator.fromIndexedDISI(disi); + } + @Override public Bits getAcceptOrds(Bits acceptDocs) { if (acceptDocs == null) { @@ -291,6 +301,11 @@ public byte[] vectorValue(int ord) throws IOException { throw new UnsupportedOperationException(); } + @Override + protected DocIterator createIterator() { + return createDenseIterator(this); + } + @Override public EmptyOffHeapVectorValues copy() throws IOException { throw new UnsupportedOperationException(); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java 
index 71ed68c74e6c..8d6241faa422 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -153,6 +153,11 @@ public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; } + @Override + public DocIterator createIterator() { + return createDenseIterator(this); + } + @Override public VectorScorer scorer(float[] query) throws IOException { DenseOffHeapVectorValues copy = copy(); @@ -240,6 +245,11 @@ public int length() { }; } + @Override + protected DocIterator createIterator() { + return DocIterator.fromIndexedDISI(disi); + } + @Override public VectorScorer scorer(float[] query) throws IOException { SparseOffHeapVectorValues copy = copy(); @@ -288,6 +298,11 @@ public float[] vectorValue(int targetOrd) { throw new UnsupportedOperationException(); } + @Override + protected DocIterator createIterator() { + return createDenseIterator(this); + } + @Override public Bits getAcceptOrds(Bits acceptDocs) { return null; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java index d60f52698552..b331fca151ef 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java @@ -18,6 +18,7 @@ package org.apache.lucene.codecs.lucene99; import static org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.Closeable; import java.io.IOException; @@ -38,6 +39,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; import 
org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; @@ -360,12 +362,13 @@ private void writeMeta( private static DocsWithFieldSet writeByteVectorData( IndexOutput output, ByteVectorValues byteVectorValues) throws IOException { DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - for (int ord = 0; ord < byteVectorValues.size(); ord++) { + KnnVectorValues.DocIterator iter = byteVectorValues.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector - byte[] binaryValue = byteVectorValues.vectorValue(ord); + byte[] binaryValue = byteVectorValues.vectorValue(iter.index()); assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize; output.writeBytes(binaryValue, binaryValue.length); - docsWithField.add(byteVectorValues.ordToDoc(ord)); + docsWithField.add(docV); } return docsWithField; } @@ -379,12 +382,13 @@ private static DocsWithFieldSet writeVectorData( ByteBuffer buffer = ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize) .order(ByteOrder.LITTLE_ENDIAN); - for (int ord = 0; ord < floatVectorValues.size(); ord++) { + KnnVectorValues.DocIterator iter = floatVectorValues.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector - float[] value = floatVectorValues.vectorValue(ord); + float[] value = floatVectorValues.vectorValue(iter.index()); buffer.asFloatBuffer().put(value); output.writeBytes(buffer.array(), buffer.limit()); - docsWithField.add(floatVectorValues.ordToDoc(ord)); + docsWithField.add(docV); } return docsWithField; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java index 8badd1f248fa..90befb38ad38 100644 --- 
a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java @@ -424,9 +424,19 @@ public float[] vectorValue(int ord) throws IOException { return rawVectorValues.vectorValue(ord); } + @Override + public int ordToDoc(int ord) { + return rawVectorValues.ordToDoc(ord); + } + @Override public VectorScorer scorer(float[] query) throws IOException { return quantizedVectorValues.scorer(query); } + + @Override + public DocIterator iterator() { + return rawVectorValues.iterator(); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index 3b650bbf09ff..f6dab2b67031 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -22,6 +22,7 @@ import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.DYNAMIC_CONFIDENCE_INTERVAL; import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT; import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; import java.io.Closeable; @@ -29,7 +30,6 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.Objects; import org.apache.lucene.codecs.CodecUtil; @@ -45,6 +45,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; 
+import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; @@ -654,11 +655,10 @@ public static ScalarQuantizer mergeAndRecalculateQuantiles( || shouldRecomputeQuantiles(mergedQuantiles, quantizationStates)) { int numVectors = 0; DocIdSetIterator iter = - KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState).iterator(); + KnnVectorsWriter.MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState) + .iterator(); // iterate vectorValues and increment numVectors - for (int doc = iter.nextDoc(); - doc != DocIdSetIterator.NO_MORE_DOCS; - doc = iter.nextDoc()) { + for (int doc = iter.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iter.nextDoc()) { numVectors++; } return buildScalarQuantizer( @@ -730,9 +730,10 @@ public static DocsWithFieldSet writeQuantizedVectorData( ? OffHeapQuantizedByteVectorValues.compressedArray( quantizedByteVectorValues.dimension(), bits) : null; - for (int ord = 0; ord < quantizedByteVectorValues.size(); ord++) { + KnnVectorValues.DocIterator iter = quantizedByteVectorValues.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector - byte[] binaryValue = quantizedByteVectorValues.vectorValue(ord); + byte[] binaryValue = quantizedByteVectorValues.vectorValue(iter.index()); assert binaryValue.length == quantizedByteVectorValues.dimension() : "dim=" + quantizedByteVectorValues.dimension() + " len=" + binaryValue.length; if (compressedVector != null) { @@ -742,8 +743,8 @@ public static DocsWithFieldSet writeQuantizedVectorData( output.writeBytes(binaryValue, binaryValue.length); } output.writeInt( - Float.floatToIntBits(quantizedByteVectorValues.getScoreCorrectionConstant(ord))); - docsWithField.add(quantizedByteVectorValues.ordToDoc(ord)); + Float.floatToIntBits(quantizedByteVectorValues.getScoreCorrectionConstant(iter.index()))); + 
docsWithField.add(docV); } return docsWithField; } @@ -877,21 +878,30 @@ public float[] vectorValue(int ord) throws IOException { } return vectorList.get(ord); } + + @Override + public DocIterator createIterator() { + return createDenseIterator(this); + } } static class QuantizedByteVectorValueSub extends DocIDMerger.Sub { + private final int maxDoc; private final QuantizedByteVectorValues values; int ord = -1; - QuantizedByteVectorValueSub(MergeState.DocMap docMap, QuantizedByteVectorValues values) { + QuantizedByteVectorValueSub( + int maxDoc, MergeState.DocMap docMap, QuantizedByteVectorValues values) { super(docMap); + this.maxDoc = maxDoc; this.values = values; + assert values.iterator().docID() == -1; } @Override public int nextDoc() throws IOException { - return values.ordToDoc(++ord); + return values.iterator().nextDoc(); } } @@ -925,12 +935,14 @@ public static MergedQuantizedVectorValues mergeQuantizedByteVectorValues( } sub = new QuantizedByteVectorValueSub( + mergeState.maxDocs[i], mergeState.docMaps[i], new QuantizedFloatVectorValues( toQuantize, fieldInfo.getVectorSimilarityFunction(), scalarQuantizer)); } else { sub = new QuantizedByteVectorValueSub( + mergeState.maxDocs[i], mergeState.docMaps[i], new OffsetCorrectedQuantizedByteVectorValues( reader.getQuantizedVectorValues(fieldInfo.name), @@ -949,9 +961,6 @@ public static MergedQuantizedVectorValues mergeQuantizedByteVectorValues( private final int size; private final int[] ends; - private int docId; - private QuantizedByteVectorValueSub current; - private MergedQuantizedVectorValues( List subs, MergeState mergeState) throws IOException { this.subs = subs; @@ -968,25 +977,15 @@ private MergedQuantizedVectorValues( @Override public byte[] vectorValue(int ord) throws IOException { - int iSub = Arrays.binarySearch(ends, ord); - if (iSub < 0) { - iSub = -(iSub + 1); - } - if (iSub == 0) { - return subs.get(iSub).values.vectorValue(ord); - } else { - return subs.get(iSub).values.vectorValue(ord - 
ends[iSub - 1]); - } + int iSub = ((CompositeIterator) iterator()).iSub; + QuantizedByteVectorValues values = subs.get(iSub).values; + assert ord == values.iterator().index() + (iSub == 0 ? 0 : ends[iSub - 1]); + return values.vectorValue(values.iterator().index()); } @Override - public int ordToDoc(int ord) { - throw new UnsupportedOperationException(); - /** - * TODO: is this needed? int iSub = Arrays.binarySearch(ends, ord); if (iSub < 0) { iSub = - * -(iSub + 1); } if (iSub == 0) { return subs.get(iSub).values.ordToDoc(ord); } else { return - * subs.get(iSub).values.ordToDoc(ord - ends[iSub - 1]); } - */ + protected DocIterator createIterator() { + return new CompositeIterator(subs); } @Override @@ -1001,7 +1000,57 @@ public int dimension() { @Override public float getScoreCorrectionConstant(int ord) throws IOException { - return current.values.getScoreCorrectionConstant(ord); + int iSub = ((CompositeIterator) iterator()).iSub; + + QuantizedByteVectorValues current = subs.get(iSub).values; + return current.getScoreCorrectionConstant(current.iterator().index()); + } + + static class CompositeIterator extends KnnVectorValues.DocIterator { + private final List subs; + int iSub; + int docBase; + int ordBase; + int docId; + + public CompositeIterator(List subs) { + this.subs = subs; + iSub = 0; + docBase = 0; + ordBase = 0; + docId = -1; + } + + @Override + public int index() { + if (iSub == subs.size()) { + return NO_MORE_DOCS; + } + return ordBase + subs.get(iSub).values.iterator().index(); + } + + @Override + public int docID() { + if (iSub == subs.size()) { + return NO_MORE_DOCS; + } + int subDocId = subs.get(iSub).values.iterator().docID(); + return docBase + subDocId; + } + + @Override + public int nextDoc() throws IOException { + while (iSub < subs.size()) { + int doc = subs.get(iSub).nextMappedDoc(); + if (doc != NO_MORE_DOCS) { + return docBase + doc; + } + ordBase += subs.get(iSub).values.size(); + docBase += subs.get(iSub).maxDoc; + ++iSub; + } + return 
NO_MORE_DOCS; + } } } @@ -1063,6 +1112,16 @@ public VectorScorer scorer(float[] target) throws IOException { private float quantize(int ord) throws IOException { return quantizer.quantize(values.vectorValue(ord), quantizedVector, vectorSimilarityFunction); } + + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + + @Override + public DocIterator iterator() { + return values.iterator(); + } } static final class ScalarQuantizedCloseableRandomVectorScorerSupplier @@ -1137,6 +1196,16 @@ public int size() { public byte[] vectorValue(int ord) throws IOException { return in.vectorValue(ord); } + + @Override + public int ordToDoc(int ord) { + return in.ordToDoc(ord); + } + + @Override + public DocIterator iterator() { + return in.iterator(); + } } static final class NormalizedFloatVectorValues extends FloatVectorValues { @@ -1159,11 +1228,21 @@ public int size() { return values.size(); } + @Override + public int ordToDoc(int ord) { + return values.ordToDoc(ord); + } + @Override public float[] vectorValue(int ord) throws IOException { System.arraycopy(values.vectorValue(ord), 0, normalizedVector, 0, normalizedVector.length); VectorUtil.l2normalize(normalizedVector); return normalizedVector; } + + @Override + public DocIterator iterator() { + return values.iterator(); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java index 5d488c200ce1..8129e65c87d1 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java @@ -255,6 +255,11 @@ public DocIdSetIterator iterator() { } }; } + + @Override + protected DocIterator createIterator() { + return createDenseIterator(this); + } } private static class SparseOffHeapVectorValues extends 
OffHeapQuantizedByteVectorValues { @@ -283,11 +288,8 @@ public SparseOffHeapVectorValues( } @Override - public KnnValuesDocIterator iterator() { - if (iterator == null) { - iterator = fromIndexedDISI(disi); - } - return iterator; + public DocIterator createIterator() { + return DocIterator.fromIndexedDISI(disi); } @Override @@ -374,6 +376,11 @@ public int size() { return 0; } + @Override + protected DocIterator createIterator() { + return createDenseIterator(this); + } + @Override public EmptyOffHeapVectorValues copy() { throw new UnsupportedOperationException(); @@ -399,33 +406,4 @@ public VectorScorer scorer(float[] target) { return null; } } - - static KnnValuesDocIterator fromIndexedDISI(IndexedDISI disi) { - return new KnnValuesDocIterator() { - @Override - public int docID() { - return disi.docID(); - } - - @Override - public int index() { - return disi.index(); - } - - @Override - public int nextDoc() throws IOException { - return disi.nextDoc(); - } - - @Override - public int advance(int target) throws IOException { - return disi.advance(target); - } - - @Override - public long cost() { - return disi.cost(); - } - }; - } } diff --git a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java index 1c59928c64bb..c1151c7f53f4 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java @@ -77,12 +77,6 @@ public VectorScorer scorer(byte[] query) throws IOException { throw new UnsupportedOperationException(); } - /** Returns the vector byte length, defaults to dimension multiplied by byte size */ - @Override - public int getVectorByteLength() { - return dimension() * Byte.BYTES; - } - /** * Returns a slice of the underlying {@link IndexInput} that contains the vector values if * available @@ -91,6 +85,11 @@ public IndexInput getSlice() { return null; } + @Override + public VectorEncoding 
getEncoding() { + return VectorEncoding.BYTE; + } + /** * Creates a {@link ByteVectorValues} from a list of byte arrays. * @@ -119,6 +118,11 @@ public byte[] vectorValue(int targetOrd) { public ByteVectorValues copy() { return this; } + + @Override + protected DocIterator createIterator() { + return createDenseIterator(this); + } }; } } diff --git a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java index e2395dfd40a3..961681bccae6 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java @@ -429,7 +429,6 @@ private void checkAndThrow(DocIdSetIterator in) { } private class ExitableFloatVectorValues extends FloatVectorValues { - private int nextCheck; private final FloatVectorValues vectorValues; public ExitableFloatVectorValues(FloatVectorValues vectorValues) { @@ -443,12 +442,6 @@ public int dimension() { @Override public float[] vectorValue(int ord) throws IOException { - if (nextCheck >= DOCS_BETWEEN_TIMEOUT_CHECK) { - checkAndThrow(); - nextCheck = 0; - } else { - nextCheck++; - } return vectorValues.vectorValue(ord); } @@ -462,6 +455,11 @@ public int size() { return vectorValues.size(); } + @Override + protected DocIterator createIterator() { + return createExitableIterator(vectorValues.iterator(), queryTimeout); + } + @Override public VectorScorer scorer(float[] target) throws IOException { return vectorValues.scorer(target); @@ -471,27 +469,9 @@ public VectorScorer scorer(float[] target) throws IOException { public FloatVectorValues copy() { throw new UnsupportedOperationException(); } - - /** - * Throws {@link ExitingReaderException} if {@link QueryTimeout#shouldExit()} returns true, or - * if {@link Thread#interrupted()} returns true. 
- */ - private void checkAndThrow() { - if (queryTimeout.shouldExit()) { - throw new ExitingReaderException( - "The request took too long to iterate over vector values. Timeout: " - + queryTimeout.toString() - + ", FloatVectorValues=" - + in); - } else if (Thread.interrupted()) { - throw new ExitingReaderException( - "Interrupted while iterating over vector values. FloatVectorValues=" + in); - } - } } private class ExitableByteVectorValues extends ByteVectorValues { - private int nextCheck; private final ByteVectorValues vectorValues; public ExitableByteVectorValues(ByteVectorValues vectorValues) { @@ -510,12 +490,6 @@ public int size() { @Override public byte[] vectorValue(int ord) throws IOException { - if (nextCheck >= DOCS_BETWEEN_TIMEOUT_CHECK) { - checkAndThrow(); - nextCheck = 0; - } else { - nextCheck++; - } return vectorValues.vectorValue(ord); } @@ -524,6 +498,11 @@ public int ordToDoc(int ord) { return vectorValues.ordToDoc(ord); } + @Override + protected DocIterator createIterator() { + return createExitableIterator(vectorValues.iterator(), queryTimeout); + } + @Override public VectorScorer scorer(byte[] target) throws IOException { return vectorValues.scorer(target); @@ -533,24 +512,47 @@ public VectorScorer scorer(byte[] target) throws IOException { public ByteVectorValues copy() { throw new UnsupportedOperationException(); } + } + } + + private static KnnVectorValues.DocIterator createExitableIterator( + KnnVectorValues.DocIterator delegate, QueryTimeout queryTimeout) { + return new KnnVectorValues.DocIterator() { + private int nextCheck; + + @Override + public int index() { + return delegate.index(); + } + + @Override + public int docID() { + return delegate.docID(); + } + + @Override + public int nextDoc() throws IOException { + int doc = delegate.nextDoc(); + if (doc >= nextCheck) { + checkAndThrow(); + nextCheck = doc + ExitableFilterAtomicReader.DOCS_BETWEEN_TIMEOUT_CHECK; + } + return doc; + } - /** - * Throws {@link ExitingReaderException} if 
{@link QueryTimeout#shouldExit()} returns true, or - * if {@link Thread#interrupted()} returns true. - */ private void checkAndThrow() { if (queryTimeout.shouldExit()) { throw new ExitingReaderException( - "The request took too long to iterate over vector values. Timeout: " + "The request took too long to iterate over knn vector values. Timeout: " + queryTimeout.toString() - + ", ByteVectorValues=" - + in); + + ", KnnVectorValues=" + + delegate); } else if (Thread.interrupted()) { throw new ExitingReaderException( - "Interrupted while iterating over vector values. ByteVectorValues=" + in); + "Interrupted while iterating over knn vector values. KnnVectorValues=" + delegate); } } - } + }; } /** Wrapper class for another PointValues implementation that is used by ExitableFields. */ @@ -663,7 +665,7 @@ private void checkAndThrow() { if (queryTimeout.shouldExit()) { throw new ExitingReaderException( "The request took too long to intersect point values. Timeout: " - + queryTimeout.toString() + + queryTimeout + ", PointValues=" + pointValues); } else if (Thread.interrupted()) { @@ -795,7 +797,7 @@ public void grow(int count) { /** Wrapper class for another Terms implementation that is used by ExitableFields. 
*/ public static class ExitableTerms extends FilterTerms { - private QueryTimeout queryTimeout; + private final QueryTimeout queryTimeout; /** Constructor * */ public ExitableTerms(Terms terms, QueryTimeout queryTimeout) { diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index 4a9408a3c399..5220a2b6729e 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -86,10 +86,9 @@ public VectorScorer scorer(float[] target) throws IOException { throw new UnsupportedOperationException(); } - /** Returns the vector byte length, defaults to dimension multiplied by float byte size */ @Override - public int getVectorByteLength() { - return dimension() * Float.BYTES; + public VectorEncoding getEncoding() { + return VectorEncoding.FLOAT32; } /** @@ -120,6 +119,11 @@ public float[] vectorValue(int targetOrd) { public FloatVectorValues copy() { return this; } + + @Override + protected DocIterator createIterator() { + return createDenseIterator(this); + } }; } } diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java index f488fa5a2fa1..3004878351a1 100644 --- a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java @@ -17,6 +17,7 @@ package org.apache.lucene.index; import java.io.IOException; +import org.apache.lucene.codecs.lucene90.IndexedDISI; import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.search.DocIdSetIterator; @@ -30,7 +31,7 @@ */ public abstract class KnnVectorValues { - protected KnnValuesDocIterator iterator; + protected DocIterator iterator; /** Return the dimension of the vectors */ public abstract int 
dimension(); @@ -42,9 +43,13 @@ public abstract class KnnVectorValues { */ public abstract int size(); - /** Return the docid of the document indexed with the given vector ordingl */ + /** + * Return the docid of the document indexed with the given vector ordinal. This default + * implementation returns the argument and is appropriate for dense values implementations where + * every doc has a value. + */ public int ordToDoc(int ord) { - throw new UnsupportedOperationException("by class " + getClass().getName()); + return ord; } /** @@ -55,8 +60,12 @@ public KnnVectorValues copy() throws IOException { throw new UnsupportedOperationException("by class " + getClass().getName()); } - /** Returns the byte length of the vector values. */ - public abstract int getVectorByteLength(); + /** Returns the vector byte length, defaults to dimension multiplied by float byte size */ + public int getVectorByteLength() { + return dimension() * getEncoding().byteSize; + } + + public abstract VectorEncoding getEncoding(); public Bits getAcceptOrds(Bits acceptDocs) { // FIXME: change default to return acceptDocs and provide this impl @@ -77,51 +86,101 @@ public int length() { }; } - public abstract static class KnnValuesDocIterator extends DocIdSetIterator { + public abstract static class DocIterator extends DocIdSetIterator { /** return the value index (aka "ordinal" or "ord") corresponding to the current doc */ public abstract int index(); + + @Override + public int advance(int target) throws IOException { + return slowAdvance(target); + } + + @Override + public long cost() { + throw new UnsupportedOperationException("for class " + getClass().getName()); + } + + public static DocIterator fromIndexedDISI(IndexedDISI disi) { + return new DocIterator() { + @Override + public int docID() { + return disi.docID(); + } + + @Override + public int index() { + return disi.index(); + } + + @Override + public int nextDoc() throws IOException { + return disi.nextDoc(); + } + + @Override + public 
int advance(int target) throws IOException { + return disi.advance(target); + } + + @Override + public long cost() { + return disi.cost(); + } + }; + } } - public KnnValuesDocIterator iterator() { + public DocIterator iterator() { if (iterator == null) { - iterator = - new KnnValuesDocIterator() { - - int ord = -1; - int doc = -1; - - @Override - public int docID() { - return doc; - } - - @Override - public int index() { - return ord; - } - - @Override - public int nextDoc() throws IOException { - if (ord >= size() - 1) { - return NO_MORE_DOCS; - } else { - doc = ordToDoc(++ord); - return doc; - } - } - - @Override - public int advance(int target) { - throw new UnsupportedOperationException(); - } - - @Override - public long cost() { - return size(); - } - }; + iterator = createIterator(); } return iterator; } + + protected DocIterator createIterator() { + // don't force every class to implement; some are just wrappers of other values and use their + // iterators + throw new UnsupportedOperationException(); + } + ; + + protected static DocIterator createDenseIterator(KnnVectorValues values) { + return new DocIterator() { + + int doc = -1; + + @Override + public int docID() { + return doc; + } + + @Override + public int index() { + return doc; + } + + @Override + public int nextDoc() throws IOException { + if (doc >= values.size() - 1) { + return doc = NO_MORE_DOCS; + } else { + return ++doc; + } + } + + @Override + public int advance(int target) { + if (target >= values.size()) { + return doc = NO_MORE_DOCS; + } + return doc = target; + } + + @Override + public long cost() { + return values.size(); + } + }; + } } diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java index 68e489a31dd3..66aa99010b23 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java +++ 
b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java @@ -832,11 +832,9 @@ public void checkIntegrity() throws IOException { @Override public FloatVectorValues getFloatVectorValues(String field) throws IOException { - List subs = new ArrayList<>(); - int i = 0; + List subs = new ArrayList<>(); int dimension = -1; int size = 0; - int[] ends = new int[codecReaders.length]; for (CodecReader reader : codecReaders) { FloatVectorValues values = reader.getFloatVectorValues(field); if (values != null) { @@ -845,7 +843,6 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { } size += values.size(); } - ends[i++] = size; subs.add(values); } final int finalDimension = dimension; @@ -864,17 +861,14 @@ public int size() { @Override public float[] vectorValue(int ord) throws IOException { - int iSub = Arrays.binarySearch(ends, ord); - if (iSub < 0) { - iSub = -(iSub + 1); - } - int subOrd; - if (iSub == 0) { - subOrd = ord; - } else { - subOrd = ord - ends[iSub - 1]; - } - return subs.get(iSub).vectorValue(subOrd); + FloatVectorValues sub = + (FloatVectorValues) subs.get(((CompositeIterator) iterator()).iSub); + return sub.vectorValue(sub.iterator().index()); + } + + @Override + protected DocIterator createIterator() { + return new CompositeIterator(subs); } @Override @@ -896,11 +890,9 @@ public FloatVectorValues copy() { @Override public ByteVectorValues getByteVectorValues(String field) throws IOException { - List subs = new ArrayList<>(); - int i = 0; + List subs = new ArrayList<>(); int dimension = -1; int size = 0; - int[] ends = new int[codecReaders.length]; for (CodecReader reader : codecReaders) { ByteVectorValues values = reader.getByteVectorValues(field); if (values != null) { @@ -909,7 +901,6 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } size += values.size(); } - ends[i++] = size; subs.add(values); } final int finalDimension = dimension; @@ -928,17 +919,8 @@ public int 
size() { @Override public byte[] vectorValue(int ord) throws IOException { - int iSub = Arrays.binarySearch(ends, ord); - if (iSub < 0) { - iSub = -(iSub + 1); - } - int subOrd; - if (iSub == 0) { - subOrd = ord; - } else { - subOrd = ord - ends[iSub - 1]; - } - return subs.get(iSub).vectorValue(subOrd); + ByteVectorValues sub = (ByteVectorValues) subs.get(((CompositeIterator) iterator()).iSub); + return sub.vectorValue(sub.iterator().index()); } @Override @@ -955,6 +937,11 @@ public VectorScorer scorer(byte[] target) { public ByteVectorValues copy() { throw new UnsupportedOperationException(); } + + @Override + protected DocIterator createIterator() { + return new CompositeIterator(subs); + } }; } @@ -969,6 +956,51 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits throws IOException { throw new UnsupportedOperationException(); } + + static class CompositeIterator extends KnnVectorValues.DocIterator { + private final List subs; + int iSub; + int docId; + + public CompositeIterator(List subs) { + this.subs = subs; + iSub = 0; + docId = -1; + } + + @Override + public int index() { + if (iSub == subs.size()) { + return NO_MORE_DOCS; + } + return subs.get(iSub).iterator().index(); + } + + @Override + public int docID() { + if (iSub == subs.size()) { + return NO_MORE_DOCS; + } + return subs.get(iSub).iterator().docID(); + } + + @Override + public int nextDoc() throws IOException { + while (iSub < subs.size()) { + int doc = subs.get(iSub).iterator().nextDoc(); + if (doc != NO_MORE_DOCS) { + return doc; + } + ++iSub; + } + return NO_MORE_DOCS; + } + + @Override + public long cost() { + return 0; + } + } } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index 7e8a151a3465..73db8af872c9 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ 
b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -35,7 +35,7 @@ import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; -import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.IOSupplier; @@ -206,24 +206,67 @@ public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue } } + private static class SortingValuesIterator extends KnnVectorValues.DocIterator { + private final BitSet docsWithValues; + private final int[] docToOrd; + private final int size; + + int doc = -1; + + SortingValuesIterator(KnnVectorValues value, Sorter.DocMap docMap) throws IOException { + docToOrd = new int[docMap.size()]; + docsWithValues = new FixedBitSet(docMap.size()); + KnnVectorValues.DocIterator iter = value.createIterator(); + int count = 0; + for (int doc = iter.nextDoc(); doc != NO_MORE_DOCS; doc = iter.nextDoc()) { + int newDocId = docMap.oldToNew(doc); + if (newDocId != -1) { + docToOrd[newDocId] = iter.index(); + docsWithValues.set(newDocId); + ++count; + } + } + size = count; + } + + @Override + public int docID() { + return doc; + } + + @Override + public int index() { + return docToOrd[doc]; + } + + @Override + public int nextDoc() { + if (doc >= docsWithValues.length() - 1) { + doc = NO_MORE_DOCS; + } else { + doc = docsWithValues.nextSetBit(doc + 1); + } + return doc; + } + + @Override + public long cost() { + return size; + } + } + /** Sorting FloatVectorValues that maps ordinals using the provided sortMap */ private static class SortingFloatVectorValues extends FloatVectorValues { final FloatVectorValues delegate; - final Sorter.DocMap sortMap; SortingFloatVectorValues(FloatVectorValues delegate, Sorter.DocMap sortMap) throws IOException { this.delegate = delegate; - this.sortMap = sortMap; + iterator = new 
SortingValuesIterator(delegate, sortMap); } @Override public float[] vectorValue(int ord) throws IOException { - return delegate.vectorValue(sortMap.newToOld(ord)); - } - - @Override - public int ordToDoc(int ord) { - return delegate.ordToDoc(sortMap.newToOld(ord)); + return delegate.vectorValue(iterator.index()); } @Override @@ -233,37 +276,36 @@ public int dimension() { @Override public int size() { - return delegate.size(); + return (int) iterator.cost(); } @Override - public VectorScorer scorer(float[] target) { + public FloatVectorValues copy() { throw new UnsupportedOperationException(); } @Override - public FloatVectorValues copy() { - throw new UnsupportedOperationException(); + protected DocIterator createIterator() { + throw new IllegalStateException(); } } private static class SortingByteVectorValues extends ByteVectorValues { final ByteVectorValues delegate; - final Sorter.DocMap sortMap; SortingByteVectorValues(ByteVectorValues delegate, Sorter.DocMap sortMap) throws IOException { this.delegate = delegate; - this.sortMap = sortMap; + iterator = new SortingValuesIterator(delegate, sortMap); } @Override public byte[] vectorValue(int ord) throws IOException { - return delegate.vectorValue(sortMap.newToOld(ord)); + return delegate.vectorValue(iterator().index()); } @Override - public int ordToDoc(int ord) { - return delegate.ordToDoc(sortMap.newToOld(ord)); + protected DocIterator createIterator() { + throw new IllegalStateException(); } @Override @@ -273,12 +315,7 @@ public int dimension() { @Override public int size() { - return delegate.size(); - } - - @Override - public VectorScorer scorer(byte[] target) { - throw new UnsupportedOperationException(); + return (int) iterator().cost(); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java index 85583654a53e..9188805d020a 100644 --- 
a/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java @@ -16,6 +16,8 @@ */ package org.apache.lucene.util.hnsw; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + import java.io.IOException; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.hnsw.HnswGraphProvider; @@ -152,29 +154,35 @@ public OnHeapHnswGraph merge( */ protected final int[] getNewOrdMapping( KnnVectorValues mergedVectorValues, BitSet initializedNodes) throws IOException { - KnnVectorValues initializerValues = null; + KnnVectorValues.DocIterator initializerIterator = null; switch (fieldInfo.getVectorEncoding()) { - case BYTE -> initializerValues = initReader.getByteVectorValues(fieldInfo.name); - case FLOAT32 -> initializerValues = initReader.getFloatVectorValues(fieldInfo.name); + case BYTE -> initializerIterator = initReader.getByteVectorValues(fieldInfo.name).iterator(); + case FLOAT32 -> + initializerIterator = initReader.getFloatVectorValues(fieldInfo.name).iterator(); } IntIntHashMap newIdToOldOrdinal = new IntIntHashMap(initGraphSize); int maxNewDocID = -1; - for (int oldOrd = 0; oldOrd < initializerValues.size(); oldOrd++) { - int newId = initDocMap.get(initializerValues.ordToDoc(oldOrd)); + for (int docId = initializerIterator.nextDoc(); + docId != NO_MORE_DOCS; + docId = initializerIterator.nextDoc()) { + int newId = initDocMap.get(docId); maxNewDocID = Math.max(newId, maxNewDocID); - newIdToOldOrdinal.put(newId, oldOrd); + newIdToOldOrdinal.put(newId, initializerIterator.index()); } if (maxNewDocID == -1) { return new int[0]; } final int[] oldToNewOrdinalMap = new int[initGraphSize]; - for (int newOrd = 0; newOrd < mergedVectorValues.size(); newOrd++) { - int newDocId = mergedVectorValues.ordToDoc(newOrd); + KnnVectorValues.DocIterator mergedVectorIterator = mergedVectorValues.iterator(); + for (int newDocId = 
mergedVectorIterator.nextDoc(); + newDocId <= maxNewDocID; + newDocId = mergedVectorIterator.nextDoc()) { int hashDocIndex = newIdToOldOrdinal.indexOf(newDocId); if (newIdToOldOrdinal.indexExists(hashDocIndex)) { + int newOrd = mergedVectorIterator.index(); initializedNodes.set(newOrd); oldToNewOrdinalMap[newIdToOldOrdinal.indexGet(hashDocIndex)] = newOrd; } diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java index 271d059362a2..bea76f231cef 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java @@ -16,6 +16,8 @@ */ package org.apache.lucene.util.quantization; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; @@ -235,6 +237,8 @@ private static int[] reservoirSampleIndices(int numFloatVecs, int sampleSize) { * * @param floatVectorValues the float vector values from which to calculate the quantiles * @param confidenceInterval the confidence interval used to calculate the quantiles + * @param totalVectorCount the total number of live float vectors in the index. This is vital for + * accounting for deleted documents when calculating the quantiles. 
* @param bits the number of bits to use for quantization * @return A new {@link ScalarQuantizer} instance * @throws IOException if there is an error reading the float vector values @@ -268,8 +272,8 @@ static ScalarQuantizer fromVectors( if (confidenceInterval == 1f) { float min = Float.POSITIVE_INFINITY; float max = Float.NEGATIVE_INFINITY; - for (int ord = 0; ord < floatVectorValues.size(); ord++) { - for (float v : floatVectorValues.vectorValue(ord)) { + while (floatVectorValues.iterator().nextDoc() != NO_MORE_DOCS) { + for (float v : floatVectorValues.vectorValue(floatVectorValues.iterator().index())) { min = Math.min(min, v); max = Math.max(max, v); } @@ -285,8 +289,8 @@ static ScalarQuantizer fromVectors( if (totalVectorCount <= quantizationSampleSize) { int scratchSize = Math.min(SCRATCH_SIZE, totalVectorCount); int i = 0; - for (int ord = 0; ord < floatVectorValues.size(); ord++) { - float[] vectorValue = floatVectorValues.vectorValue(ord); + while (floatVectorValues.iterator().nextDoc() != NO_MORE_DOCS) { + float[] vectorValue = floatVectorValues.vectorValue(floatVectorValues.iterator().index()); System.arraycopy( vectorValue, 0, quantileGatheringScratch, i * vectorValue.length, vectorValue.length); i++; @@ -305,7 +309,13 @@ static ScalarQuantizer fromVectors( int index = 0; int idx = 0; for (int i : vectorsToTake) { - float[] vectorValue = floatVectorValues.vectorValue(i); + while (index <= i) { + // We cannot use `advance(docId)` as MergedVectorValues does not support it + floatVectorValues.iterator().nextDoc(); + index++; + } + assert floatVectorValues.iterator().docID() != NO_MORE_DOCS; + float[] vectorValue = floatVectorValues.vectorValue(floatVectorValues.iterator().index()); System.arraycopy( vectorValue, 0, quantileGatheringScratch, idx * vectorValue.length, vectorValue.length); idx++; @@ -346,8 +356,12 @@ public static ScalarQuantizer fromVectorsAutoInterval( if (totalVectorCount <= sampleSize) { int scratchSize = Math.min(SCRATCH_SIZE, 
totalVectorCount); int i = 0; - for (int ord = 0; ord < floatVectorValues.size(); ord++) { - gatherSample(floatVectorValues.vectorValue(ord), quantileGatheringScratch, sampledDocs, i); + while (floatVectorValues.iterator().nextDoc() != NO_MORE_DOCS) { + gatherSample( + floatVectorValues.vectorValue(floatVectorValues.iterator().index()), + quantileGatheringScratch, + sampledDocs, + i); i++; if (i == scratchSize) { extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum); @@ -362,7 +376,17 @@ public static ScalarQuantizer fromVectorsAutoInterval( int index = 0; int idx = 0; for (int i : vectorsToTake) { - gatherSample(floatVectorValues.vectorValue(i), quantileGatheringScratch, sampledDocs, idx); + while (index <= i) { + // We cannot use `advance(docId)` as MergedVectorValues does not support it + floatVectorValues.iterator().nextDoc(); + index++; + } + assert floatVectorValues.iterator().docID() != NO_MORE_DOCS; + gatherSample( + floatVectorValues.vectorValue(floatVectorValues.iterator().index()), + quantileGatheringScratch, + sampledDocs, + idx); idx++; if (idx == SCRATCH_SIZE) { extractQuantiles(confidenceIntervals, quantileGatheringScratch, upperSum, lowerSum); @@ -421,8 +445,7 @@ private static void extractQuantiles( } private static void gatherSample( - float[] vectorValue, float[] quantileGatheringScratch, List sampledDocs, int i) - throws IOException { + float[] vectorValue, float[] quantileGatheringScratch, List sampledDocs, int i) { float[] copy = new float[vectorValue.length]; System.arraycopy(vectorValue, 0, copy, 0, vectorValue.length); sampledDocs.add(copy); diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java index 4ae6b8879fe6..0449851dee2d 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java +++ 
b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java @@ -139,7 +139,7 @@ public void testQuantizedVectorsWriteAndRead() throws Exception { new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors); ScalarQuantizer scalarQuantizer = Lucene99ScalarQuantizedVectorsWriter.buildScalarQuantizer( - toQuantize, similarityFunction, confidenceInterval, (byte) bits); + toQuantize, numVectors, similarityFunction, confidenceInterval, (byte) bits); float[] expectedCorrections = new float[numVectors]; byte[][] expectedVectors = new byte[numVectors][]; for (int i = 0; i < numVectors; i++) { diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java index 59cb3f68a0e2..5b1e912dc64f 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java @@ -17,6 +17,7 @@ package org.apache.lucene.codecs.lucene99; import static java.lang.String.format; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.oneOf; @@ -36,6 +37,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.NoMergePolicy; import org.apache.lucene.index.VectorSimilarityFunction; @@ -115,6 +117,7 @@ public void testQuantizedVectorsWriteAndRead() throws Exception { ScalarQuantizer scalarQuantizer = Lucene99ScalarQuantizedVectorsWriter.buildScalarQuantizer( new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors), + numVectors, similarityFunction, 
confidenceInterval, (byte) bits); @@ -170,9 +173,10 @@ public void testQuantizedVectorsWriteAndRead() throws Exception { QuantizedByteVectorValues quantizedByteVectorValues = quantizedReader.getQuantizedVectorValues("f"); int docId = -1; - for (int ord = 0; ord < quantizedByteVectorValues.size(); ord++) { - byte[] vector = quantizedByteVectorValues.vectorValue(ord); - float offset = quantizedByteVectorValues.getScoreCorrectionConstant(ord); + KnnVectorValues.DocIterator iter = quantizedByteVectorValues.iterator(); + for (docId = iter.nextDoc(); docId != NO_MORE_DOCS; docId = iter.nextDoc()) { + byte[] vector = quantizedByteVectorValues.vectorValue(iter.index()); + float offset = quantizedByteVectorValues.getScoreCorrectionConstant(iter.index()); for (int i = 0; i < dim; i++) { assertEquals(vector[i], expectedVectors[docId][i]); } diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsWriter.java index a397f7f88bcd..0bf2a4ef6b87 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsWriter.java @@ -96,7 +96,7 @@ private void assertScalarQuantizer( new Lucene99ScalarQuantizedVectorsWriter.FloatVectorWrapper(vectors); ScalarQuantizer scalarQuantizer = Lucene99ScalarQuantizedVectorsWriter.buildScalarQuantizer( - vectorValues, vectorSimilarityFunction, confidenceInterval, bits); + vectorValues, 30, vectorSimilarityFunction, confidenceInterval, bits); assertEquals(expectedQuantiles[0], scalarQuantizer.getLowerQuantile(), 0.0001f); assertEquals(expectedQuantiles[1], scalarQuantizer.getUpperQuantile(), 0.0001f); } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java 
b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java index 92f458a1ffb5..f60cd28c0ee1 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java @@ -565,9 +565,10 @@ public void testByteVectorValues() throws IOException { } private static void scanAndRetrieve(LeafReader leaf, KnnVectorValues values) throws IOException { - KnnVectorValues.KnnValuesDocIterator iter = values.iterator(); - for (int ord = 0; ord < values.size(); ord++) { - int docId = values.ordToDoc(ord); + KnnVectorValues.DocIterator iter = values.iterator(); + for (iter.nextDoc(); + iter.docID() != DocIdSetIterator.NO_MORE_DOCS && iter.docID() < leaf.maxDoc(); ) { + int docId = iter.docID(); if (docId >= leaf.maxDoc()) { break; } @@ -577,11 +578,10 @@ private static void scanAndRetrieve(LeafReader leaf, KnnVectorValues values) thr } else { iter.nextDoc(); } - ord = iter.index(); if (random().nextBoolean() && iter.docID() != DocIdSetIterator.NO_MORE_DOCS && values instanceof FloatVectorValues) { - ((FloatVectorValues) values).vectorValue(ord); + ((FloatVectorValues) values).vectorValue(iter.index()); } } } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java index 644bfa3e5048..d3e69986c279 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java @@ -417,7 +417,8 @@ private void assertConsistentGraph(IndexWriter iw, float[][] values, String vect nextDocWithVectors = vectorValues.iterator().advance(i); while (i < nextDocWithVectors && i < reader.maxDoc()) { int id = Integer.parseInt(storedFields.document(i).get("id")); - assertNull("document " + id + " has no vector, but was expected to", values[id]); + assertNull( + "document " + id + ", expected to have no vector, does have one", 
values[id]); ++i; } if (nextDocWithVectors == NO_MORE_DOCS) { diff --git a/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java b/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java index 82f96897af1f..de588d3e15c7 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java @@ -262,7 +262,7 @@ public void testSortOnAddIndicesRandom() throws IOException { assertTrue(sorted_numeric_dv.advanceExact(idNext)); assertTrue(sorted_set_dv.advanceExact(idNext)); assertTrue(binary_sorted_dv.advanceExact(idNext)); - assertEquals(idNext, vectorValues.ordToDoc(i)); + assertEquals(idNext, vectorValues.iterator().advance(idNext)); assertEquals(new BytesRef(ids.longValue() + ""), binary_dv.binaryValue()); assertEquals( new BytesRef(ids.longValue() + ""), @@ -274,7 +274,7 @@ public void testSortOnAddIndicesRandom() throws IOException { assertEquals(1, sorted_numeric_dv.docValueCount()); assertEquals(ids.longValue(), sorted_numeric_dv.nextValue()); - float[] vectorValue = vectorValues.vectorValue(i); + float[] vectorValue = vectorValues.vectorValue(vectorValues.iterator().index()); assertEquals(1, vectorValue.length); assertEquals((float) ids.longValue(), vectorValue[0], 0.001f); diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/AbstractMockVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/AbstractMockVectorValues.java deleted file mode 100644 index 83603de88ed2..000000000000 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/AbstractMockVectorValues.java +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.lucene.util.hnsw; - -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - -import java.io.IOException; -import org.apache.lucene.index.KnnVectorValues; -import org.apache.lucene.util.BytesRef; - -abstract class AbstractMockVectorValues extends KnnVectorValues { - - protected final int dimension; - protected final T[] denseValues; - protected final T[] values; - protected final int numVectors; - protected final BytesRef binaryValue; - - protected int pos = -1; - - AbstractMockVectorValues(T[] values, int dimension, T[] denseValues, int numVectors) { - this.dimension = dimension; - this.values = values; - this.denseValues = denseValues; - // used by tests that build a graph from bytes rather than floats - binaryValue = new BytesRef(dimension); - binaryValue.length = dimension; - this.numVectors = numVectors; - } - - @Override - public int size() { - return numVectors; - } - - @Override - public int dimension() { - return dimension; - } - - public T vectorValue(int targetOrd) { - return denseValues[targetOrd]; - } - - @Override - public abstract AbstractMockVectorValues copy(); - - public abstract T vectorValue() throws IOException; - - private boolean seek(int target) { - if (target >= 0 && target < values.length && values[target] != null) { - pos = target; - return true; - } else { - return false; - } - } - - public int docID() { - return pos; - 
} - - public int nextDoc() { - return advance(pos + 1); - } - - public int advance(int target) { - while (++pos < values.length) { - if (seek(pos)) { - return pos; - } - } - return NO_MORE_DOCS; - } -} diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index 24c43cb9e48b..dea6b679d9be 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -98,18 +98,14 @@ abstract class HnswGraphTestCase extends LuceneTestCase { abstract T randomVector(int dim); - abstract AbstractMockVectorValues vectorValues(int size, int dimension); + abstract KnnVectorValues vectorValues(int size, int dimension); - abstract AbstractMockVectorValues vectorValues(float[][] values); + abstract KnnVectorValues vectorValues(float[][] values); - abstract AbstractMockVectorValues vectorValues(LeafReader reader, String fieldName) - throws IOException; + abstract KnnVectorValues vectorValues(LeafReader reader, String fieldName) throws IOException; - abstract AbstractMockVectorValues vectorValues( - int size, - int dimension, - AbstractMockVectorValues pregeneratedVectorValues, - int pregeneratedOffset); + abstract KnnVectorValues vectorValues( + int size, int dimension, KnnVectorValues pregeneratedVectorValues, int pregeneratedOffset); abstract Field knnVectorField(String name, T vector, VectorSimilarityFunction similarityFunction); @@ -134,6 +130,7 @@ protected RandomVectorScorer buildScorer(KnnVectorValues vectors, T query) throw // Tests writing segments of various sizes and merging to ensure there are no errors // in the HNSW graph merging logic. 
+ @SuppressWarnings("unchecked") public void testRandomReadWriteAndMerge() throws IOException { int dim = random().nextInt(100) + 1; int[] segmentSizes = @@ -148,7 +145,7 @@ public void testRandomReadWriteAndMerge() throws IOException { int M = random().nextInt(4) + 2; int beamWidth = random().nextInt(10) + 5; long seed = random().nextLong(); - AbstractMockVectorValues vectors = vectorValues(numVectors, dim); + KnnVectorValues vectors = vectorValues(numVectors, dim); HnswGraphBuilder.randSeed = seed; try (Directory dir = newDirectory()) { @@ -173,7 +170,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { try (IndexWriter iw = new IndexWriter(dir, iwc)) { for (int i = 0; i < segmentSizes.length; i++) { int size = segmentSizes[i]; - while (vectors.nextDoc() < size) { + for (int ord = 0; ord < size; ord++) { if (isSparse[i] && random().nextBoolean()) { int d = random().nextInt(10) + 1; for (int j = 0; j < d; j++) { @@ -182,8 +179,24 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { } } Document doc = new Document(); - doc.add(knnVectorField("field", vectors.vectorValue(), similarityFunction)); - doc.add(new StringField("id", Integer.toString(vectors.docID()), Field.Store.NO)); + switch (vectors.getEncoding()) { + case BYTE -> { + doc.add( + knnVectorField( + "field", + (T) ((ByteVectorValues) vectors).vectorValue(ord), + similarityFunction)); + } + case FLOAT32 -> { + doc.add( + knnVectorField( + "field", + (T) ((FloatVectorValues) vectors).vectorValue(ord), + similarityFunction)); + } + } + ; + doc.add(new StringField("id", Integer.toString(vectors.ordToDoc(ord)), Field.Store.NO)); iw.addDocument(doc); } iw.commit(); @@ -199,13 +212,26 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { } try (IndexReader reader = DirectoryReader.open(dir)) { for (LeafReaderContext ctx : reader.leaves()) { - AbstractMockVectorValues values = vectorValues(ctx.reader(), "field"); + KnnVectorValues values = 
vectorValues(ctx.reader(), "field"); assertEquals(dim, values.dimension()); } } } } + @SuppressWarnings("unchecked") + private T vectorValue(KnnVectorValues vectors, int ord) throws IOException { + switch (vectors.getEncoding()) { + case BYTE -> { + return (T) ((ByteVectorValues) vectors).vectorValue(ord); + } + case FLOAT32 -> { + return (T) ((FloatVectorValues) vectors).vectorValue(ord); + } + } + throw new AssertionError("unknown encoding " + vectors.getEncoding()); + } + // test writing out and reading in a graph gives the expected graph public void testReadWrite() throws IOException { int dim = random().nextInt(100) + 1; @@ -213,8 +239,8 @@ public void testReadWrite() throws IOException { int M = random().nextInt(4) + 2; int beamWidth = random().nextInt(10) + 5; long seed = random().nextLong(); - AbstractMockVectorValues vectors = vectorValues(nDoc, dim); - AbstractMockVectorValues v2 = vectors.copy(), v3 = vectors.copy(); + KnnVectorValues vectors = vectorValues(nDoc, dim); + KnnVectorValues v2 = vectors.copy(), v3 = vectors.copy(); RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, M, beamWidth, seed); HnswGraph hnsw = builder.build(vectors.size()); @@ -242,15 +268,16 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { } }); try (IndexWriter iw = new IndexWriter(dir, iwc)) { - while (v2.nextDoc() != NO_MORE_DOCS) { - while (indexedDoc < v2.docID()) { + while (v2.iterator().nextDoc() != NO_MORE_DOCS) { + while (indexedDoc < v2.iterator().docID()) { // increment docId in the index by adding empty documents iw.addDocument(new Document()); indexedDoc++; } Document doc = new Document(); - doc.add(knnVectorField("field", v2.vectorValue(), similarityFunction)); - doc.add(new StoredField("id", v2.docID())); + doc.add( + knnVectorField("field", vectorValue(v2, v2.iterator().index()), similarityFunction)); + doc.add(new StoredField("id", 
v2.iterator().docID())); iw.addDocument(doc); nVec++; indexedDoc++; @@ -258,7 +285,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { } try (IndexReader reader = DirectoryReader.open(dir)) { for (LeafReaderContext ctx : reader.leaves()) { - AbstractMockVectorValues values = vectorValues(ctx.reader(), "field"); + KnnVectorValues values = vectorValues(ctx.reader(), "field"); assertEquals(dim, values.dimension()); assertEquals(nVec, values.size()); assertEquals(indexedDoc, ctx.reader().maxDoc()); @@ -280,7 +307,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { public void testSortedAndUnsortedIndicesReturnSameResults() throws IOException { int dim = random().nextInt(10) + 3; int nDoc = random().nextInt(200) + 100; - AbstractMockVectorValues vectors = vectorValues(nDoc, dim); + KnnVectorValues vectors = vectorValues(nDoc, dim); int M = random().nextInt(10) + 5; int beamWidth = random().nextInt(10) + 10; @@ -323,15 +350,15 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { int indexedDoc = 0; try (IndexWriter iw = new IndexWriter(dir, iwc); IndexWriter iw2 = new IndexWriter(dir2, iwc2)) { - while (vectors.nextDoc() != NO_MORE_DOCS) { - while (indexedDoc < vectors.docID()) { + for (int ord = 0; ord < vectors.size(); ord++) { + while (indexedDoc < vectors.ordToDoc(ord)) { // increment docId in the index by adding empty documents iw.addDocument(new Document()); indexedDoc++; } Document doc = new Document(); - doc.add(knnVectorField("vector", vectors.vectorValue(), similarityFunction)); - doc.add(new StoredField("id", vectors.docID())); + doc.add(knnVectorField("vector", vectorValue(vectors, ord), similarityFunction)); + doc.add(new StoredField("id", vectors.ordToDoc(ord))); doc.add(new NumericDocValuesField("sortkey", random().nextLong())); iw.addDocument(doc); iw2.addDocument(doc); @@ -552,13 +579,13 @@ public void testHnswGraphBuilderInitializationFromGraph_withOffsetZero() throws int dim = atLeast(10); 
long seed = random().nextLong(); - AbstractMockVectorValues initializerVectors = vectorValues(initializerSize, dim); + KnnVectorValues initializerVectors = vectorValues(initializerSize, dim); RandomVectorScorerSupplier initialscorerSupplier = buildScorerSupplier(initializerVectors); HnswGraphBuilder initializerBuilder = HnswGraphBuilder.create(initialscorerSupplier, 10, 30, seed); OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.size()); - AbstractMockVectorValues finalVectorValues = + KnnVectorValues finalVectorValues = vectorValues(totalSize, dim, initializerVectors, docIdOffset); int[] initializerOrdMap = createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset); @@ -598,13 +625,13 @@ public void testHnswGraphBuilderInitializationFromGraph_withNonZeroOffset() thro int dim = atLeast(10); long seed = random().nextLong(); - AbstractMockVectorValues initializerVectors = vectorValues(initializerSize, dim); + KnnVectorValues initializerVectors = vectorValues(initializerSize, dim); RandomVectorScorerSupplier initialscorerSupplier = buildScorerSupplier(initializerVectors); HnswGraphBuilder initializerBuilder = HnswGraphBuilder.create(initialscorerSupplier, 10, 30, seed); OnHeapHnswGraph initializerGraph = initializerBuilder.build(initializerVectors.size()); - AbstractMockVectorValues finalVectorValues = + KnnVectorValues finalVectorValues = vectorValues(totalSize, dim, initializerVectors.copy(), docIdOffset); int[] initializerOrdMap = createOffsetOrdinalMap(initializerSize, finalVectorValues, docIdOffset); @@ -688,19 +715,18 @@ private int[] mapArrayAndSort(int[] arr, int[] offset) { } private int[] createOffsetOrdinalMap( - int docIdSize, AbstractMockVectorValues totalVectorValues, int docIdOffset) { + int docIdSize, KnnVectorValues totalVectorValues, int docIdOffset) throws IOException { // Compute the offset for the ordinal map to be the number of non-null vectors in the total - // vector values - // before the docIdOffset 
+ // vector values before the docIdOffset int ordinalOffset = 0; - while (totalVectorValues.nextDoc() < docIdOffset) { + while (totalVectorValues.iterator().nextDoc() < docIdOffset) { ordinalOffset++; } int[] offsetOrdinalMap = new int[docIdSize]; for (int curr = 0; - totalVectorValues.docID() < docIdOffset + docIdSize; - totalVectorValues.nextDoc()) { + totalVectorValues.iterator().docID() < docIdOffset + docIdSize; + totalVectorValues.iterator().nextDoc()) { offsetOrdinalMap[curr] = ordinalOffset + curr++; } @@ -771,7 +797,7 @@ public void testDiversity() throws IOException { unitVector2d(0.77), unitVector2d(0.6) }; - AbstractMockVectorValues vectors = vectorValues(values); + KnnVectorValues vectors = vectorValues(values); // First add nodes until everybody gets a full neighbor list RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 2, 10, random().nextInt()); @@ -825,7 +851,7 @@ public void testDiversityFallback() throws IOException { {10, 0, 0}, {0, 4, 0} }; - AbstractMockVectorValues vectors = vectorValues(values); + KnnVectorValues vectors = vectorValues(values); // First add nodes until everybody gets a full neighbor list RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 1, 10, random().nextInt()); @@ -855,7 +881,7 @@ public void testDiversity3d() throws IOException { {0, 0, 20}, {0, 9, 0} }; - AbstractMockVectorValues vectors = vectorValues(values); + KnnVectorValues vectors = vectorValues(values); // First add nodes until everybody gets a full neighbor list RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 1, 10, random().nextInt()); @@ -891,7 +917,7 @@ private void assertLevel0Neighbors(OnHeapHnswGraph graph, int node, int... 
expec public void testRandom() throws IOException { int size = atLeast(100); int dim = atLeast(10); - AbstractMockVectorValues vectors = vectorValues(size, dim); + KnnVectorValues vectors = vectorValues(size, dim); int topK = 5; RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 10, 30, random().nextLong()); @@ -908,15 +934,13 @@ public void testRandom() throws IOException { TopDocs topDocs = actual.topDocs(); NeighborQueue expected = new NeighborQueue(topK, false); for (int j = 0; j < size; j++) { - if (vectors.vectorValue(j) != null && (acceptOrds == null || acceptOrds.get(j))) { + if (vectorValue(vectors, j) != null && (acceptOrds == null || acceptOrds.get(j))) { if (getVectorEncoding() == VectorEncoding.BYTE) { assert query instanceof byte[]; - expected.add( - j, similarityFunction.compare((byte[]) query, (byte[]) vectors.vectorValue(j))); + expected.add(j, similarityFunction.compare((byte[]) query, vectorValue(vectors, j))); } else { assert query instanceof float[]; - expected.add( - j, similarityFunction.compare((float[]) query, (float[]) vectors.vectorValue(j))); + expected.add(j, similarityFunction.compare((float[]) query, vectorValue(vectors, j))); } if (expected.size() > topK) { expected.pop(); @@ -940,7 +964,7 @@ public void testOnHeapHnswGraphSearch() throws IOException, ExecutionException, InterruptedException, TimeoutException { int size = atLeast(100); int dim = atLeast(10); - AbstractMockVectorValues vectors = vectorValues(size, dim); + KnnVectorValues vectors = vectorValues(size, dim); RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); HnswGraphBuilder builder = HnswGraphBuilder.create(scorerSupplier, 10, 30, random().nextLong()); OnHeapHnswGraph hnsw = builder.build(vectors.size()); @@ -1004,7 +1028,7 @@ public void testOnHeapHnswGraphSearch() public void testConcurrentMergeBuilder() throws IOException { int size = atLeast(1000); int 
dim = atLeast(10); - AbstractMockVectorValues vectors = vectorValues(size, dim); + KnnVectorValues vectors = vectorValues(size, dim); RandomVectorScorerSupplier scorerSupplier = buildScorerSupplier(vectors); ExecutorService exec = Executors.newFixedThreadPool(4, new NamedThreadFactory("hnswMerge")); TaskExecutor taskExecutor = new TaskExecutor(exec); @@ -1033,7 +1057,7 @@ public void testAllNodesVisitedInSingleLevel() throws IOException { // Search for a large number of results int topK = size - 1; - AbstractMockVectorValues docVectors = vectorValues(size, dim); + KnnVectorValues docVectors = vectorValues(size, dim); HnswGraph graph = HnswGraphBuilder.create(buildScorerSupplier(docVectors), 10, 30, random().nextLong()) .build(size); @@ -1047,8 +1071,8 @@ public int numLevels() { } }; - AbstractMockVectorValues queryVectors = vectorValues(1, dim); - RandomVectorScorer queryScorer = buildScorer(docVectors, queryVectors.vectorValue(0)); + KnnVectorValues queryVectors = vectorValues(1, dim); + RandomVectorScorer queryScorer = buildScorer(docVectors, vectorValue(queryVectors, 0)); KnnCollector collector = new TopKnnCollector(topK, Integer.MAX_VALUE); HnswGraphSearcher.search(queryScorer, collector, singleLevelGraph, null); @@ -1217,27 +1241,25 @@ private Set getNeighborNodes(HnswGraph g) throws IOException { return neighbors; } - void assertVectorsEqual(AbstractMockVectorValues u, AbstractMockVectorValues v) - throws IOException { + void assertVectorsEqual(KnnVectorValues u, KnnVectorValues v) throws IOException { int uDoc, vDoc; - while (true) { - uDoc = u.nextDoc(); - vDoc = v.nextDoc(); + assertEquals(u.size(), v.size()); + for (int ord = 0; ord < u.size(); ord++) { + uDoc = u.ordToDoc(ord); + vDoc = v.ordToDoc(ord); assertEquals(uDoc, vDoc); - if (uDoc == NO_MORE_DOCS) { - break; - } + assertNotEquals(NO_MORE_DOCS, uDoc); switch (getVectorEncoding()) { case BYTE -> assertArrayEquals( "vectors do not match for doc=" + uDoc, - (byte[]) u.vectorValue(), - (byte[]) 
v.vectorValue()); + vectorValue(u, ord), + (byte[]) vectorValue(v, ord)); case FLOAT32 -> assertArrayEquals( "vectors do not match for doc=" + uDoc, - (float[]) u.vectorValue(), - (float[]) v.vectorValue(), + vectorValue(u, ord), + (float[]) vectorValue(v, ord), 1e-4f); default -> throw new IllegalArgumentException("unknown vector encoding: " + getVectorEncoding()); diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java index 6a51cb60335f..360f1b3cb0f5 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java @@ -17,12 +17,21 @@ package org.apache.lucene.util.hnsw; +import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.BytesRef; -class MockByteVectorValues extends AbstractMockVectorValues { +class MockByteVectorValues extends ByteVectorValues { + private final int dimension; + private final byte[][] denseValues; + protected final byte[][] values; + private final int numVectors; + private final BytesRef binaryValue; private final byte[] scratch; + private int pos = -1; + static MockByteVectorValues fromValues(byte[][] values) { byte[] firstNonNull = null; int j = 0; @@ -42,10 +51,26 @@ static MockByteVectorValues fromValues(byte[][] values) { } MockByteVectorValues(byte[][] values, int dimension, byte[][] denseValues, int numVectors) { - super(values, dimension, denseValues, numVectors); + this.dimension = dimension; + this.values = values; + this.denseValues = denseValues; + this.numVectors = numVectors; + // used by tests that build a graph from bytes rather than floats + binaryValue = new BytesRef(dimension); + binaryValue.length = dimension; scratch = new byte[dimension]; } + @Override + public int size() { + return 
values.length; + } + + @Override + public int dimension() { + return dimension; + } + @Override public MockByteVectorValues copy() { return new MockByteVectorValues( @@ -54,25 +79,20 @@ public MockByteVectorValues copy() { @Override public byte[] vectorValue(int ord) { - return values[ord]; - } - - @Override - public byte[] vectorValue() { if (LuceneTestCase.random().nextBoolean()) { - return values[pos]; + return values[ord]; } else { // Sometimes use the same scratch array repeatedly, mimicing what the codec will do. // This should help us catch cases of aliasing where the same ByteVectorValues source is used // twice in a // single computation. - System.arraycopy(values[pos], 0, scratch, 0, dimension); + System.arraycopy(values[ord], 0, scratch, 0, dimension); return scratch; } } @Override - public int getVectorByteLength() { - throw new UnsupportedOperationException(); + protected DocIterator createIterator() { + return createDenseIterator(this); } } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java index cfd0f5a11e05..e494232774c3 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java @@ -17,12 +17,19 @@ package org.apache.lucene.util.hnsw; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.util.ArrayUtil; -class MockVectorValues extends AbstractMockVectorValues { +class MockVectorValues extends FloatVectorValues { + private final int dimension; + private final float[][] denseValues; + protected final float[][] values; + private final int numVectors; private final float[] scratch; + private int pos = -1; + static MockVectorValues fromValues(float[][] values) { float[] firstNonNull = null; int j = 0; @@ -42,10 +49,23 @@ static MockVectorValues fromValues(float[][] values) { } 
MockVectorValues(float[][] values, int dimension, float[][] denseValues, int numVectors) { - super(values, dimension, denseValues, numVectors); + this.dimension = dimension; + this.values = values; + this.denseValues = denseValues; + this.numVectors = numVectors; this.scratch = new float[dimension]; } + @Override + public int size() { + return values.length; + } + + @Override + public int dimension() { + return dimension; + } + @Override public MockVectorValues copy() { return new MockVectorValues( @@ -53,25 +73,20 @@ public MockVectorValues copy() { } @Override - public float[] vectorValue() { + public float[] vectorValue(int ord) { if (LuceneTestCase.random().nextBoolean()) { - return values[pos]; + return values[ord]; } else { // Sometimes use the same scratch array repeatedly, mimicing what the codec will do. // This should help us catch cases of aliasing where the same vector values source is used // twice in a single computation. - System.arraycopy(values[pos], 0, scratch, 0, dimension); + System.arraycopy(values[ord], 0, scratch, 0, dimension); return scratch; } } @Override - public float[] vectorValue(int targetOrd) { - return denseValues[targetOrd]; - } - - @Override - public int getVectorByteLength() { - throw new UnsupportedOperationException(); + protected DocIterator createIterator() { + return createDenseIterator(this); } } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java index 88149b8f54d0..f0e6745211c6 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswByteVectorGraph.java @@ -17,13 +17,12 @@ package org.apache.lucene.util.hnsw; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - import com.carrotsearch.randomizedtesting.RandomizedTest; import java.io.IOException; import org.apache.lucene.document.Field; 
import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; @@ -56,7 +55,7 @@ byte[] randomVector(int dim) { } @Override - AbstractMockVectorValues vectorValues(int size, int dimension) { + MockByteVectorValues vectorValues(int size, int dimension) { return MockByteVectorValues.fromValues(createRandomByteVectors(size, dimension, random())); } @@ -65,7 +64,7 @@ static boolean fitsInByte(float v) { } @Override - AbstractMockVectorValues vectorValues(float[][] values) { + MockByteVectorValues vectorValues(float[][] values) { byte[][] bValues = new byte[values.length][]; // The case when all floats fit within a byte already. boolean scaleSimple = fitsInByte(values[0][0]); @@ -86,36 +85,30 @@ AbstractMockVectorValues vectorValues(float[][] values) { } @Override - AbstractMockVectorValues vectorValues( - int size, - int dimension, - AbstractMockVectorValues pregeneratedVectorValues, - int pregeneratedOffset) { + MockByteVectorValues vectorValues( + int size, int dimension, KnnVectorValues pregeneratedVectorValues, int pregeneratedOffset) { + + MockByteVectorValues pvv = (MockByteVectorValues) pregeneratedVectorValues; byte[][] vectors = new byte[size][]; - byte[][] randomVectors = - createRandomByteVectors(size - pregeneratedVectorValues.values.length, dimension, random()); + byte[][] randomVectors = createRandomByteVectors(size - pvv.values.length, dimension, random()); for (int i = 0; i < pregeneratedOffset; i++) { vectors[i] = randomVectors[i]; } - int currentDoc; - while ((currentDoc = pregeneratedVectorValues.nextDoc()) != NO_MORE_DOCS) { - vectors[pregeneratedOffset + currentDoc] = pregeneratedVectorValues.values[currentDoc]; + for (int currentOrd = 0; currentOrd < pvv.size(); currentOrd++) { + vectors[pregeneratedOffset + 
currentOrd] = pvv.values[currentOrd]; } - for (int i = pregeneratedOffset + pregeneratedVectorValues.values.length; - i < vectors.length; - i++) { - vectors[i] = randomVectors[i - pregeneratedVectorValues.values.length]; + for (int i = pregeneratedOffset + pvv.values.length; i < vectors.length; i++) { + vectors[i] = randomVectors[i - pvv.values.length]; } return MockByteVectorValues.fromValues(vectors); } @Override - AbstractMockVectorValues vectorValues(LeafReader reader, String fieldName) - throws IOException { + MockByteVectorValues vectorValues(LeafReader reader, String fieldName) throws IOException { ByteVectorValues vectorValues = reader.getByteVectorValues(fieldName); byte[][] vectors = new byte[reader.maxDoc()][]; for (int i = 0; i < vectorValues.size(); i++) { diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java index 8500db4a9cba..52d1da3dfa83 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswFloatVectorGraph.java @@ -17,13 +17,12 @@ package org.apache.lucene.util.hnsw; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - import com.carrotsearch.randomizedtesting.RandomizedTest; import java.io.IOException; import org.apache.lucene.document.Field; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; @@ -60,18 +59,17 @@ float[] randomVector(int dim) { } @Override - AbstractMockVectorValues vectorValues(int size, int dimension) { + MockVectorValues vectorValues(int size, int dimension) { return MockVectorValues.fromValues(createRandomFloatVectors(size, dimension, random())); 
} @Override - AbstractMockVectorValues vectorValues(float[][] values) { + MockVectorValues vectorValues(float[][] values) { return MockVectorValues.fromValues(values); } @Override - AbstractMockVectorValues vectorValues(LeafReader reader, String fieldName) - throws IOException { + MockVectorValues vectorValues(LeafReader reader, String fieldName) throws IOException { FloatVectorValues vectorValues = reader.getFloatVectorValues(fieldName); float[][] vectors = new float[reader.maxDoc()][]; for (int i = 0; i < vectorValues.size(); i++) { @@ -82,29 +80,23 @@ AbstractMockVectorValues vectorValues(LeafReader reader, String fieldNa } @Override - AbstractMockVectorValues vectorValues( - int size, - int dimension, - AbstractMockVectorValues pregeneratedVectorValues, - int pregeneratedOffset) { + MockVectorValues vectorValues( + int size, int dimension, KnnVectorValues pregeneratedVectorValues, int pregeneratedOffset) { + MockVectorValues pvv = (MockVectorValues) pregeneratedVectorValues; float[][] vectors = new float[size][]; float[][] randomVectors = - createRandomFloatVectors( - size - pregeneratedVectorValues.values.length, dimension, random()); + createRandomFloatVectors(size - pvv.values.length, dimension, random()); for (int i = 0; i < pregeneratedOffset; i++) { vectors[i] = randomVectors[i]; } - int currentDoc; - while ((currentDoc = pregeneratedVectorValues.nextDoc()) != NO_MORE_DOCS) { - vectors[pregeneratedOffset + currentDoc] = pregeneratedVectorValues.values[currentDoc]; + for (int currentOrd = 0; currentOrd < pvv.size(); currentOrd++) { + vectors[pregeneratedOffset + currentOrd] = pvv.values[currentOrd]; } - for (int i = pregeneratedOffset + pregeneratedVectorValues.values.length; - i < vectors.length; - i++) { - vectors[i] = randomVectors[i - pregeneratedVectorValues.values.length]; + for (int i = pregeneratedOffset + pvv.values.length; i < vectors.length; i++) { + vectors[i] = randomVectors[i - pvv.values.length]; } return 
MockVectorValues.fromValues(vectors); diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java index a001982049b5..69e8d030c119 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java @@ -59,7 +59,7 @@ public void testToEuclidean() throws IOException { float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f); FloatVectorValues floatVectorValues = fromFloats(floats); ScalarQuantizer scalarQuantizer = - ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, (byte) 7); + ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, numVecs, (byte) 7); byte[][] quantized = new byte[floats.length][]; float[] offsets = quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.EUCLIDEAN); @@ -91,7 +91,7 @@ public void testToCosine() throws IOException { float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f); FloatVectorValues floatVectorValues = fromFloatsNormalized(floats, null); ScalarQuantizer scalarQuantizer = - ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, (byte) 7); + ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, numVecs, (byte) 7); byte[][] quantized = new byte[floats.length][]; float[] offsets = quantizeVectorsNormalized( @@ -127,7 +127,7 @@ public void testToDotProduct() throws IOException { float error = Math.max((100 - confidenceInterval) * 0.01f, 0.01f); FloatVectorValues floatVectorValues = fromFloats(floats); ScalarQuantizer scalarQuantizer = - ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, (byte) 7); + ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, numVecs, (byte) 7); byte[][] quantized = new byte[floats.length][]; 
float[] offsets = quantizeVectors(scalarQuantizer, floats, quantized, VectorSimilarityFunction.DOT_PRODUCT); @@ -159,7 +159,7 @@ public void testToMaxInnerProduct() throws IOException { float error = Math.max((100 - confidenceInterval) * 0.5f, 0.5f); FloatVectorValues floatVectorValues = fromFloats(floats); ScalarQuantizer scalarQuantizer = - ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, (byte) 7); + ScalarQuantizer.fromVectors(floatVectorValues, confidenceInterval, numVecs, (byte) 7); byte[][] quantized = new byte[floats.length][]; float[] offsets = quantizeVectors( @@ -238,11 +238,11 @@ private static FloatVectorValues fromFloatsNormalized( float[][] floats, Set deletedVectors) { return new TestScalarQuantizer.TestSimpleFloatVectorValues(floats, deletedVectors) { @Override - public float[] vectorValue() throws IOException { - if (curDoc == -1 || curDoc >= floats.length) { + public float[] vectorValue(int ord) throws IOException { + if (iterator().docID() == -1 || iterator().docID() >= floats.length) { throw new IOException("Current doc not set or too many iterations"); } - float[] v = ArrayUtil.copyArray(floats[curDoc]); + float[] v = ArrayUtil.copyArray(floats[ordToDoc[ord]]); VectorUtil.l2normalize(v); return v; } diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java index 5fffc1b69582..f88648b7abf6 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java @@ -47,8 +47,9 @@ public void testTinyVectors() throws IOException { } ScalarQuantizer scalarQuantizer = random().nextBoolean() - ? ScalarQuantizer.fromVectors(floatVectorValues, 0.9f, bits) - : ScalarQuantizer.fromVectorsAutoInterval(floatVectorValues, function, bits); + ? 
ScalarQuantizer.fromVectors(floatVectorValues, 0.9f, numVecs, bits) + : ScalarQuantizer.fromVectorsAutoInterval( + floatVectorValues, function, numVecs, bits); // We simply assert that we created a scalar quantizer and didn't trip any assertions // the quality of the quantization might be poor, but this is expected as sampling size is // tiny @@ -71,14 +72,16 @@ public void testNanAndInfValueFailure() { FloatVectorValues floatVectorValues = fromFloats(floats); expectThrows( IllegalStateException.class, - () -> ScalarQuantizer.fromVectors(floatVectorValues, 0.9f, bits)); + () -> ScalarQuantizer.fromVectors(floatVectorValues, 0.9f, numVecs, bits)); VectorSimilarityFunction actualFunction = function == VectorSimilarityFunction.COSINE ? VectorSimilarityFunction.DOT_PRODUCT : function; expectThrows( IllegalStateException.class, - () -> ScalarQuantizer.fromVectorsAutoInterval(floatVectorValues, actualFunction, bits)); + () -> + ScalarQuantizer.fromVectorsAutoInterval( + floatVectorValues, actualFunction, numVecs, bits)); } } } @@ -90,7 +93,8 @@ public void testQuantizeAndDeQuantize7Bit() throws IOException { float[][] floats = randomFloats(numVecs, dims); FloatVectorValues floatVectorValues = fromFloats(floats); - ScalarQuantizer scalarQuantizer = ScalarQuantizer.fromVectors(floatVectorValues, 1, (byte) 7); + ScalarQuantizer scalarQuantizer = + ScalarQuantizer.fromVectors(floatVectorValues, 1, numVecs, (byte) 7); float[] dequantized = new float[dims]; byte[] quantized = new byte[dims]; byte[] requantized = new byte[dims]; @@ -153,6 +157,7 @@ public void testScalarWithSampling() throws IOException { ScalarQuantizer.fromVectors( floatVectorValues, 0.99f, + floatVectorValues.numLiveVectors, (byte) 7, Math.max(floatVectorValues.numLiveVectors - 1, SCRATCH_SIZE + 1)); } @@ -162,6 +167,7 @@ public void testScalarWithSampling() throws IOException { ScalarQuantizer.fromVectors( floatVectorValues, 0.99f, + floatVectorValues.numLiveVectors, (byte) 7, 
Math.max(floatVectorValues.numLiveVectors - 1, SCRATCH_SIZE + 1)); } @@ -171,6 +177,7 @@ public void testScalarWithSampling() throws IOException { ScalarQuantizer.fromVectors( floatVectorValues, 0.99f, + floatVectorValues.numLiveVectors, (byte) 7, Math.max(floatVectorValues.numLiveVectors - 1, SCRATCH_SIZE + 1)); } @@ -180,6 +187,7 @@ public void testScalarWithSampling() throws IOException { ScalarQuantizer.fromVectors( floatVectorValues, 0.99f, + floatVectorValues.numLiveVectors, (byte) 7, Math.max(random().nextInt(floatVectorValues.floats.length - 1) + 1, SCRATCH_SIZE + 1)); } @@ -196,7 +204,8 @@ public void testFromVectorsAutoInterval4Bit() throws IOException { } FloatVectorValues floatVectorValues = fromFloats(floats); ScalarQuantizer scalarQuantizer = - ScalarQuantizer.fromVectorsAutoInterval(floatVectorValues, similarityFunction, (byte) 4); + ScalarQuantizer.fromVectorsAutoInterval( + floatVectorValues, similarityFunction, numVecs, (byte) 4); assertNotNull(scalarQuantizer); float[] dequantized = new float[dims]; byte[] quantized = new byte[dims]; @@ -264,14 +273,27 @@ static TestSimpleFloatVectorValues fromFloatsWithRandomDeletions( static class TestSimpleFloatVectorValues extends FloatVectorValues { protected final float[][] floats; protected final Set deletedVectors; + protected final int[] ordToDoc; protected final int numLiveVectors; - protected int curDoc = -1; TestSimpleFloatVectorValues(float[][] values, Set deletedVectors) { this.floats = values; this.deletedVectors = deletedVectors; - this.numLiveVectors = + numLiveVectors = deletedVectors == null ? 
values.length : values.length - deletedVectors.size(); + ordToDoc = new int[numLiveVectors]; + if (deletedVectors == null) { + for (int i = 0; i < numLiveVectors; i++) { + ordToDoc[i] = i; + } + } else { + int ord = 0; + for (int doc = 0; doc < values.length; doc++) { + if (!deletedVectors.contains(doc)) { + ordToDoc[ord++] = doc; + } + } + } } @Override @@ -281,41 +303,48 @@ public int dimension() { @Override public int size() { - // FIXME I don't get what this is up to return floats.length; } @Override public float[] vectorValue(int ord) throws IOException { - return floats[ord]; + return floats[ordToDoc(ord)]; } - public float[] vectorValue() throws IOException { - if (curDoc == -1 || curDoc >= floats.length) { - throw new IOException("Current doc not set or too many iterations"); - } - return floats[curDoc]; + @Override + public int ordToDoc(int ord) { + return ordToDoc[ord]; } - public int docID() { - if (curDoc >= floats.length) { - return NO_MORE_DOCS; - } - return curDoc; - } + @Override + public DocIterator createIterator() { + return new DocIterator() { + + int ord = -1; + int doc = -1; - public int nextDoc() throws IOException { - while (++curDoc < floats.length) { - if (deletedVectors == null || !deletedVectors.contains(curDoc)) { - return curDoc; + @Override + public int docID() { + return doc; } - } - return docID(); - } - public int advance(int target) throws IOException { - curDoc = target - 1; - return nextDoc(); + @Override + public int nextDoc() throws IOException { + while (doc < floats.length - 1) { + ++doc; + if (deletedVectors == null || !deletedVectors.contains(doc)) { + ++ord; + return doc; + } + } + return doc = NO_MORE_DOCS; + } + + @Override + public int index() { + return ord; + } + }; } @Override diff --git a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java index 3d577e06e6f1..817a219d6542 100644 --- 
a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java +++ b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java @@ -2326,9 +2326,7 @@ public float score() throws IOException { assert iterator().docID() == 0; return info.fieldInfo .getVectorSimilarityFunction() - .compare( - vectorValues.vectorValue(0), - query); + .compare(vectorValues.vectorValue(0), query); } @Override @@ -2382,9 +2380,7 @@ public float score() { assert iterator().docID() == 0; return info.fieldInfo .getVectorSimilarityFunction() - .compare( - vectorValues.vectorValue(0), - query); + .compare(vectorValues.vectorValue(0), query); } @Override diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index af2742b87fb3..2294d4d01cbc 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -54,6 +54,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.MergePolicy; @@ -1006,7 +1007,6 @@ public void testByteVectorScorerIteration() throws Exception { assertNotNull(scorer); DocIdSetIterator iterator = scorer.iterator(); assertSame(iterator, scorer.iterator()); - assertSame(iterator, vectorValues.iterator()); assertNotSame(iterator, scorer); // verify scorer iteration scores are valid & iteration with vectorValues is consistent while (iterator.nextDoc() != NO_MORE_DOCS @@ -1777,6 +1777,7 @@ public void testVectorValuesReportCorrectDocs() throws Exception { double checksum = 0; int docCount = 0; long sumDocIds = 0; 
+ long sumOrdToDocIds = 0; switch (vectorEncoding) { case BYTE -> { for (LeafReaderContext ctx : r.leaves()) { @@ -1784,11 +1785,17 @@ public void testVectorValuesReportCorrectDocs() throws Exception { if (byteVectorValues != null) { docCount += byteVectorValues.size(); StoredFields storedFields = ctx.reader().storedFields(); - for (int ord = 0; ord < byteVectorValues.size(); ord++) { + KnnVectorValues.DocIterator iter = byteVectorValues.iterator(); + for (iter.nextDoc(); iter.docID() != NO_MORE_DOCS; iter.nextDoc()) { + int ord = iter.index(); checksum += byteVectorValues.vectorValue(ord)[0]; + Document doc = storedFields.document(iter.docID(), Set.of("id")); + sumDocIds += Integer.parseInt(doc.get("id")); + } + for (int ord = 0; ord < byteVectorValues.size(); ord++) { Document doc = storedFields.document(byteVectorValues.ordToDoc(ord), Set.of("id")); - sumDocIds += Integer.parseInt(doc.get("id")); + sumOrdToDocIds += Integer.parseInt(doc.get("id")); } } } @@ -1799,11 +1806,17 @@ public void testVectorValuesReportCorrectDocs() throws Exception { if (vectorValues != null) { docCount += vectorValues.size(); StoredFields storedFields = ctx.reader().storedFields(); - for (int ord = 0; ord < vectorValues.size(); ord++) { + KnnVectorValues.DocIterator iter = vectorValues.iterator(); + for (iter.nextDoc(); iter.docID() != NO_MORE_DOCS; iter.nextDoc()) { + int ord = iter.index(); checksum += vectorValues.vectorValue(ord)[0]; - Document doc = storedFields.document(vectorValues.ordToDoc(ord), Set.of("id")); + Document doc = storedFields.document(iter.docID(), Set.of("id")); sumDocIds += Integer.parseInt(doc.get("id")); } + for (int ord = 0; ord < vectorValues.size(); ord++) { + Document doc = storedFields.document(vectorValues.ordToDoc(ord), Set.of("id")); + sumOrdToDocIds += Integer.parseInt(doc.get("id")); + } } } } @@ -1815,6 +1828,7 @@ public void testVectorValuesReportCorrectDocs() throws Exception { vectorEncoding == VectorEncoding.BYTE ? 
numDocs * 0.2 : 1e-5); assertEquals(fieldDocCount, docCount); assertEquals(fieldSumDocIDs, sumDocIds); + assertEquals(fieldSumDocIDs, sumOrdToDocIds); } } } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingScorer.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingScorer.java index 8badba0d12b7..dd408befdbf3 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingScorer.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/AssertingScorer.java @@ -183,7 +183,7 @@ public int advance(int target) throws IOException { } else { state = IteratorState.ITERATING; } - assert in.docID() == advanced; + assert in.docID() == advanced : in.docID() + " != " + advanced + " in " + in; assert AssertingScorer.this.in.docID() == in.docID(); return doc = advanced; } From 2f0cc8c6de9bc493f2e22bd9457d203981b1380f Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Fri, 6 Sep 2024 12:56:34 -0400 Subject: [PATCH 06/25] HasIndexSlice --- .../lucene/codecs/lucene95/HasIndexSlice.java | 23 +++++++++++++++++++ .../lucene95/OffHeapByteVectorValues.java | 4 +--- .../lucene95/OffHeapFloatVectorValues.java | 3 ++- .../Lucene99ScalarQuantizedVectorsWriter.java | 2 ++ .../apache/lucene/index/ByteVectorValues.java | 8 ------- .../lucene/index/FloatVectorValues.java | 9 -------- .../apache/lucene/index/KnnVectorValues.java | 7 +++--- .../SlowCompositeCodecReaderWrapper.java | 1 + .../QuantizedByteVectorValues.java | 9 -------- ...ucene99MemorySegmentFlatVectorsScorer.java | 5 ++-- .../codecs/quantization/SampleReader.java | 6 +++-- 11 files changed, 39 insertions(+), 38 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/codecs/lucene95/HasIndexSlice.java diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/HasIndexSlice.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/HasIndexSlice.java new file mode 100644 index 000000000000..8f2129fc16f0 --- 
/dev/null +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/HasIndexSlice.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.codecs.lucene95; + +import org.apache.lucene.store.IndexInput; + +public interface HasIndexSlice { + IndexInput getSlice(); +} diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index 4f4e5adbf887..71211a4586f0 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -35,7 +35,7 @@ /** Read the vector values from the index input. This supports both iterated and random access. 
*/ public abstract class OffHeapByteVectorValues extends ByteVectorValues - implements RandomAccessVectorValues.Bytes { + implements HasIndexSlice { protected final int dimension; protected final int size; @@ -284,8 +284,6 @@ public EmptyOffHeapVectorValues( super(dimension, 0, null, 0, flatVectorsScorer, vectorSimilarityFunction); } - private int doc = -1; - @Override public int dimension() { return super.dimension(); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index 8d6241faa422..c99b72b6242f 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -34,7 +34,7 @@ /** Read the vector values from the index input. This supports both iterated and random access. */ public abstract class OffHeapFloatVectorValues extends FloatVectorValues - implements RandomAccessVectorValues.Floats { + implements HasIndexSlice { protected final int dimension; protected final int size; @@ -71,6 +71,7 @@ public int size() { return size; } + @Override public IndexInput getSlice() { return slice; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index f6dab2b67031..be101cc9244e 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -1040,6 +1040,8 @@ public int docID() { @Override public int nextDoc() throws IOException { + // FIXME: this is incorrect for a sorted index. 
+ // We need to use DocIDMerger while also tracking ordinals while (iSub < subs.size()) { int doc = subs.get(iSub).nextMappedDoc(); if (doc != NO_MORE_DOCS) { diff --git a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java index c1151c7f53f4..91894352ce62 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java @@ -77,14 +77,6 @@ public VectorScorer scorer(byte[] query) throws IOException { throw new UnsupportedOperationException(); } - /** - * Returns a slice of the underlying {@link IndexInput} that contains the vector values if - * available - */ - public IndexInput getSlice() { - return null; - } - @Override public VectorEncoding getEncoding() { return VectorEncoding.BYTE; diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index 5220a2b6729e..84f8113577c8 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -20,7 +20,6 @@ import java.util.List; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.search.VectorScorer; -import org.apache.lucene.store.IndexInput; /** * This class provides access to per-document floating point vector values indexed as {@link @@ -46,14 +45,6 @@ public FloatVectorValues copy() throws IOException { */ public abstract float[] vectorValue(int ord) throws IOException; - /** - * Returns a slice of the underlying {@link IndexInput} that contains the vector values if - * available - */ - public IndexInput getSlice() { - return null; - } - /** * Checks the Vector Encoding of a field * diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java index 
3004878351a1..dcf710a9847d 100644 --- a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java @@ -21,6 +21,7 @@ import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Bits; /** @@ -46,7 +47,7 @@ public abstract class KnnVectorValues { /** * Return the docid of the document indexed with the given vector ordinal. This default * implementation returns the argument and is appropriate for dense values implementations where - * every doc has a value. + * every doc has a single value. */ public int ordToDoc(int ord) { return ord; @@ -56,9 +57,7 @@ public int ordToDoc(int ord) { * Creates a new copy of this {@link KnnVectorValues}. This is helpful when you need to access * different values at once, to avoid overwriting the underlying vector returned. */ - public KnnVectorValues copy() throws IOException { - throw new UnsupportedOperationException("by class " + getClass().getName()); - } + public abstract KnnVectorValues copy() throws IOException; /** Returns the vector byte length, defaults to dimension multiplied by float byte size */ public int getVectorByteLength() { diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java index 66aa99010b23..73621f751e84 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java @@ -295,6 +295,7 @@ public void checkIntegrity() throws IOException { private record DocValuesSub(T sub, int docStart, int docEnd) {} + // TODO: use this instead of the CompositeIterator we created here? 
private static class MergedDocIdSetIterator extends DocIdSetIterator { final Iterator> it; diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java index 0626f6669221..4af62b21a83f 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java @@ -19,7 +19,6 @@ import java.io.IOException; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.search.VectorScorer; -import org.apache.lucene.store.IndexInput; /** * A version of {@link ByteVectorValues}, but additionally retrieving score correction offset for @@ -49,12 +48,4 @@ public VectorScorer scorer(float[] query) throws IOException { public QuantizedByteVectorValues copy() throws IOException { return this; } - - /** - * Returns a slice of the underlying {@link IndexInput} that contains the vector values if - * available - */ - public IndexInput getSlice() { - return null; - } } diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java index d901ac655d91..3e1113de0242 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java +++ b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java @@ -19,6 +19,7 @@ import java.io.IOException; import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; +import org.apache.lucene.codecs.lucene95.HasIndexSlice; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.KnnVectorValues; import 
org.apache.lucene.index.VectorSimilarityFunction; @@ -43,7 +44,7 @@ public RandomVectorScorerSupplier getRandomVectorScorerSupplier( // a quantized values here is a wrapping or delegation issue assert !(vectorValues instanceof QuantizedByteVectorValues); // currently only supports binary vectors - if (vectorValues instanceof ByteVectorValues byteVectorValues + if (vectorValues instanceof HasIndexSlice byteVectorValues && byteVectorValues.getSlice() != null) { var scorer = Lucene99MemorySegmentByteVectorScorerSupplier.create( @@ -70,7 +71,7 @@ public RandomVectorScorer getRandomVectorScorer( checkDimensions(queryVector.length, vectorValues.dimension()); // a quantized values here is a wrapping or delegation issue assert !(vectorValues instanceof QuantizedByteVectorValues); - if (vectorValues instanceof ByteVectorValues byteVectorValues + if (vectorValues instanceof HasIndexSlice byteVectorValues && byteVectorValues.getSlice() != null) { var scorer = Lucene99MemorySegmentByteVectorScorer.create( diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java index 09a5b911e17a..e781c168cd97 100644 --- a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java @@ -20,12 +20,14 @@ import java.io.IOException; import java.util.Random; import java.util.function.IntUnaryOperator; + +import org.apache.lucene.codecs.lucene95.HasIndexSlice; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Bits; /** A reader of vector values that samples a subset of the vectors. 
*/ -public class SampleReader extends FloatVectorValues { +public class SampleReader extends FloatVectorValues implements HasIndexSlice { private final FloatVectorValues origin; private final int sampleSize; private final IntUnaryOperator sampleFunction; @@ -53,7 +55,7 @@ public FloatVectorValues copy() throws IOException { @Override public IndexInput getSlice() { - return origin.getSlice(); + return ((HasIndexSlice) origin).getSlice(); } @Override From 327b930d245c612e5186b7a9eeb6e94e1c8fe95e Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Sat, 7 Sep 2024 08:55:24 -0400 Subject: [PATCH 07/25] remove RandomAccessVectorValues --- .../lucene90/Lucene90HnswGraphBuilder.java | 3 +- .../lucene90/Lucene90HnswVectorsReader.java | 4 +-- .../lucene91/Lucene91HnswVectorsReader.java | 1 - .../lucene92/OffHeapFloatVectorValues.java | 4 +-- .../lucene94/OffHeapByteVectorValues.java | 4 +-- .../lucene94/OffHeapFloatVectorValues.java | 4 +-- .../lucene95/OffHeapByteVectorValues.java | 1 - .../lucene95/OffHeapFloatVectorValues.java | 1 - .../OffHeapQuantizedByteVectorValues.java | 4 +-- .../util/hnsw/RandomAccessVectorValues.java | 33 ------------------- .../QuantizedByteVectorValues.java | 10 +++++- 11 files changed, 15 insertions(+), 54 deletions(-) delete mode 100644 lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java index 30b49f64b542..638adb944898 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java @@ -26,7 +26,6 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.InfoStream; import 
org.apache.lucene.util.hnsw.NeighborQueue; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; /** * Builder for HNSW graph. See {@link Lucene90OnHeapHnswGraph} for a gloss on the algorithm and the @@ -98,7 +97,7 @@ public Lucene90HnswGraphBuilder( } /** - * Reads all the vectors from two copies of a {@link RandomAccessVectorValues}. Providing two + * Reads all the vectors from two copies of a {@link FloatVectorValues}. Providing two * copies enables efficient retrieval without extra data copying, while avoiding collision of the * returned values. * diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index 3b8334e05e51..8682e9ccc606 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -42,7 +42,6 @@ import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.NeighborQueue; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; /** * Reads vectors from the index segments along with index data structures supporting KNN search. @@ -353,8 +352,7 @@ int size() { } /** Read the vector values from the index input. This supports both iterated and random access. 
*/ - static class OffHeapFloatVectorValues extends FloatVectorValues - implements RandomAccessVectorValues.Floats { + static class OffHeapFloatVectorValues extends FloatVectorValues { final int dimension; final int[] ordToDoc; diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index b3b2e3324f4a..dbc4061e6cf7 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -46,7 +46,6 @@ import org.apache.lucene.util.hnsw.HnswGraph; import org.apache.lucene.util.hnsw.HnswGraphSearcher; import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; /** diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java index ae232b1fb828..a5298e3c2f9f 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java @@ -26,12 +26,10 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.packed.DirectMonotonicReader; /** Read the vector values from the index input. This supports both iterated and random access. 
*/ -abstract class OffHeapFloatVectorValues extends FloatVectorValues - implements RandomAccessVectorValues.Floats { +abstract class OffHeapFloatVectorValues extends FloatVectorValues { protected final int dimension; protected final int size; diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java index b8ada89da3e3..28306e0221bd 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java @@ -28,12 +28,10 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.packed.DirectMonotonicReader; /** Read the vector values from the index input. This supports both iterated and random access. 
*/ -abstract class OffHeapByteVectorValues extends ByteVectorValues - implements RandomAccessVectorValues.Bytes { +abstract class OffHeapByteVectorValues extends ByteVectorValues { protected final int dimension; protected final int size; diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java index 77ca4c2bb9c5..b9895eb12d30 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java @@ -26,12 +26,10 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.packed.DirectMonotonicReader; /** Read the vector values from the index input. This supports both iterated and random access. 
*/ -abstract class OffHeapFloatVectorValues extends FloatVectorValues - implements RandomAccessVectorValues.Floats { +abstract class OffHeapFloatVectorValues extends FloatVectorValues { protected final int dimension; protected final int size; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index 71211a4586f0..4137089a8dd5 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -29,7 +29,6 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.packed.DirectMonotonicReader; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index c99b72b6242f..a9c593e61653 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -28,7 +28,6 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.RandomAccessInput; import org.apache.lucene.util.Bits; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.packed.DirectMonotonicReader; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java index 8129e65c87d1..ecb5a0bc8e38 100644 --- 
a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java @@ -27,7 +27,6 @@ import org.apache.lucene.search.VectorScorer; import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Bits; -import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.packed.DirectMonotonicReader; import org.apache.lucene.util.quantization.QuantizedByteVectorValues; @@ -37,8 +36,7 @@ * Read the quantized vector values and their score correction values from the index input. This * supports both iterated and random access. */ -public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues - implements RandomAccessVectorValues.Bytes { +public abstract class OffHeapQuantizedByteVectorValues extends QuantizedByteVectorValues { protected final int dimension; protected final int size; diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java deleted file mode 100644 index 411468c1b5e4..000000000000 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/RandomAccessVectorValues.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.lucene.util.hnsw; - -/** - * Provides random access to vectors by dense ordinal. This interface is used by HNSW-based - * implementations of KNN search. - * - * @lucene.experimental - */ -public interface RandomAccessVectorValues { - - /** Float vector values. */ - interface Floats extends RandomAccessVectorValues {} - - /** Byte vector values. */ - interface Bytes extends RandomAccessVectorValues {} -} diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java index 4af62b21a83f..a08ab64a29ee 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java @@ -17,8 +17,11 @@ package org.apache.lucene.util.quantization; import java.io.IOException; + +import org.apache.lucene.codecs.lucene95.HasIndexSlice; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.search.VectorScorer; +import org.apache.lucene.store.IndexInput; /** * A version of {@link ByteVectorValues}, but additionally retrieving score correction offset for @@ -26,7 +29,7 @@ * * @lucene.experimental */ -public abstract class QuantizedByteVectorValues extends ByteVectorValues { +public abstract class QuantizedByteVectorValues extends ByteVectorValues implements HasIndexSlice { public ScalarQuantizer getScalarQuantizer() { throw new UnsupportedOperationException(); @@ -48,4 +51,9 @@ 
public VectorScorer scorer(float[] query) throws IOException { public QuantizedByteVectorValues copy() throws IOException { return this; } + + @Override + public IndexInput getSlice() { + return null; + } } From 98ab0a6b4e0eb2d4f451172bf55abc338e82935e Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Tue, 10 Sep 2024 08:35:20 -0400 Subject: [PATCH 08/25] tests pass --- .../lucene90/Lucene90HnswGraphBuilder.java | 4 +- .../lucene90/Lucene90HnswVectorsReader.java | 13 +- .../lucene91/Lucene91HnswVectorsReader.java | 5 + .../lucene92/OffHeapFloatVectorValues.java | 10 ++ .../lucene94/OffHeapByteVectorValues.java | 10 ++ .../lucene94/OffHeapFloatVectorValues.java | 10 ++ .../lucene90/Lucene90HnswVectorsWriter.java | 11 +- .../lucene91/Lucene91HnswVectorsWriter.java | 7 +- .../lucene92/Lucene92HnswVectorsWriter.java | 9 +- .../lucene94/Lucene94HnswVectorsWriter.java | 22 ++- .../lucene95/Lucene95HnswVectorsWriter.java | 18 +- .../SimpleTextKnnVectorsReader.java | 9 +- .../SimpleTextKnnVectorsWriter.java | 7 +- .../codecs/BufferingKnnVectorsWriter.java | 97 +++++----- .../lucene/codecs/KnnVectorsWriter.java | 1 + .../lucene95/OffHeapByteVectorValues.java | 3 +- .../lucene95/OffHeapFloatVectorValues.java | 3 +- .../Lucene99ScalarQuantizedVectorsWriter.java | 73 +++----- .../OffHeapQuantizedByteVectorValues.java | 2 - .../apache/lucene/index/ByteVectorValues.java | 1 - .../apache/lucene/index/KnnVectorValues.java | 72 +++++++- .../SlowCompositeCodecReaderWrapper.java | 167 +++++------------- .../lucene/index/SortingCodecReader.java | 28 +-- .../QuantizedByteVectorValues.java | 1 - ...ucene99MemorySegmentFlatVectorsScorer.java | 1 - .../lucene/util/hnsw/HnswGraphTestCase.java | 14 +- .../util/hnsw/MockByteVectorValues.java | 2 - .../lucene/util/hnsw/MockVectorValues.java | 2 - .../quantization/TestScalarQuantizer.java | 1 - .../lucene/index/memory/MemoryIndex.java | 12 +- .../codecs/quantization/SampleReader.java | 1 - .../index/BaseKnnVectorsFormatTestCase.java | 2 +- 
32 files changed, 312 insertions(+), 306 deletions(-) diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java index 638adb944898..0d7fd520a303 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswGraphBuilder.java @@ -97,8 +97,8 @@ public Lucene90HnswGraphBuilder( } /** - * Reads all the vectors from two copies of a {@link FloatVectorValues}. Providing two - * copies enables efficient retrieval without extra data copying, while avoiding collision of the + * Reads all the vectors from two copies of a {@link FloatVectorValues}. Providing two copies + * enables efficient retrieval without extra data copying, while avoiding collision of the * returned values. * * @param vectors the vectors for which to build a nearest neighbors graph. 
Must be an independet diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index 8682e9ccc606..03f921b8df11 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -363,9 +363,6 @@ static class OffHeapFloatVectorValues extends FloatVectorValues { final float[] value; final VectorSimilarityFunction similarityFunction; - int ord = -1; - int doc = -1; - OffHeapFloatVectorValues( int dimension, int[] ordToDoc, @@ -406,6 +403,16 @@ public float[] vectorValue(int targetOrd) throws IOException { return value; } + @Override + public int ordToDoc(int ord) { + return ordToDoc[ord]; + } + + @Override + protected DocIterator createIterator() { + return fromOrdToDoc(this); + } + @Override public VectorScorer scorer(float[] target) { if (size() == 0) { diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index dbc4061e6cf7..395f6a1fceef 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -452,6 +452,11 @@ public int ordToDoc(int ord) { return ordToDocOperator.applyAsInt(ord); } + @Override + protected DocIterator createIterator() { + return fromOrdToDoc(this); + } + @Override public VectorScorer scorer(float[] target) { if (size == 0) { diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java 
b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java index a5298e3c2f9f..1daaba1a5b83 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java @@ -106,6 +106,11 @@ public DenseOffHeapVectorValues copy() throws IOException { return new DenseOffHeapVectorValues(dimension, size, vectorSimilarityFunction, slice.clone()); } + @Override + protected DocIterator createIterator() { + return createDenseIterator(this); + } + @Override public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; @@ -165,6 +170,11 @@ public SparseOffHeapVectorValues copy() throws IOException { fieldEntry, dataIn, vectorSimilarityFunction, slice.clone()); } + @Override + protected DocIterator createIterator() { + return DocIterator.fromIndexedDISI(disi); + } + @Override public int ordToDoc(int ord) { return (int) ordToDoc.get(ord); diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java index 28306e0221bd..489d826076e5 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java @@ -121,6 +121,11 @@ public DenseOffHeapVectorValues copy() throws IOException { dimension, size, slice.clone(), vectorSimilarityFunction, byteSize); } + @Override + public DocIterator createIterator() { + return createDenseIterator(this); + } + @Override public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; @@ -185,6 +190,11 @@ public int ordToDoc(int ord) { return (int) ordToDoc.get(ord); } + @Override + protected DocIterator createIterator() { + return 
fromDISI(disi); + } + @Override public Bits getAcceptOrds(Bits acceptDocs) { if (acceptDocs == null) { diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java index b9895eb12d30..5330fb964af9 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java @@ -117,6 +117,11 @@ public DenseOffHeapVectorValues copy() throws IOException { dimension, size, slice.clone(), vectorSimilarityFunction, byteSize); } + @Override + public DocIterator createIterator() { + return createDenseIterator(this); + } + @Override public Bits getAcceptOrds(Bits acceptDocs) { return acceptDocs; @@ -177,6 +182,11 @@ public SparseOffHeapVectorValues copy() throws IOException { fieldEntry, dataIn, slice.clone(), vectorSimilarityFunction, byteSize); } + @Override + protected DocIterator createIterator() { + return DocIterator.fromIndexedDISI(disi); + } + @Override public int ordToDoc(int ord) { return (int) ordToDoc.get(ord); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java index aa54f03071b6..c1e45beed059 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java @@ -17,6 +17,8 @@ package org.apache.lucene.backward_codecs.lucene90; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + import java.io.IOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; @@ -27,6 +29,7 @@ import 
org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; @@ -185,13 +188,13 @@ private static int[] writeVectorData(IndexOutput output, FloatVectorValues vecto int count = 0; ByteBuffer binaryVector = ByteBuffer.allocate(vectors.dimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - for (int ord = 0; ord < vectors.size(); ord++) { - int docV = vectors.ordToDoc(ord); + KnnVectorValues.DocIterator iter = vectors.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector - float[] vectorValue = vectors.vectorValue(ord); + float[] vectorValue = vectors.vectorValue(iter.index()); binaryVector.asFloatBuffer().put(vectorValue); output.writeBytes(binaryVector.array(), binaryVector.limit()); - docIds[ord] = docV; + docIds[count++] = docV; } if (docIds.length > count) { diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java index 430313e75685..ec43577a0633 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java @@ -28,6 +28,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; @@ -180,10 +181,10 @@ private static 
DocsWithFieldSet writeVectorData(IndexOutput output, FloatVectorV DocsWithFieldSet docsWithField = new DocsWithFieldSet(); ByteBuffer binaryVector = ByteBuffer.allocate(vectors.dimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - for (int ord = 0; ord < vectors.size(); ord++) { - int docV = vectors.ordToDoc(ord); + KnnVectorValues.DocIterator iter = vectors.iterator(); + for (int docV = iter.nextDoc(); docV != DocIdSetIterator.NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector - float[] vectorValue = vectors.vectorValue(ord); + float[] vectorValue = vectors.vectorValue(iter.index()); binaryVector.asFloatBuffer().put(vectorValue); output.writeBytes(binaryVector.array(), binaryVector.limit()); docsWithField.add(docV); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java index e0c6895c7be9..9e0e9e638809 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java @@ -32,6 +32,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.DocIdSetIterator; @@ -188,10 +189,12 @@ private static DocsWithFieldSet writeVectorData(IndexOutput output, FloatVectorV DocsWithFieldSet docsWithField = new DocsWithFieldSet(); ByteBuffer binaryVector = ByteBuffer.allocate(vectors.dimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - for (int ord = 0; ord < vectors.size(); ord++) { - int docV = vectors.ordToDoc(ord); + KnnVectorValues.DocIterator iterator = 
vectors.iterator(); + for (int docV = iterator.nextDoc(); + docV != DocIdSetIterator.NO_MORE_DOCS; + docV = iterator.nextDoc()) { // write vector - float[] vectorValue = vectors.vectorValue(ord); + float[] vectorValue = vectors.vectorValue(iterator.index()); binaryVector.asFloatBuffer().put(vectorValue); output.writeBytes(binaryVector.array(), binaryVector.limit()); docsWithField.add(docV); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java index 476f2502e0e5..24b99cd30550 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java @@ -18,6 +18,7 @@ package org.apache.lucene.backward_codecs.lucene94; import static org.apache.lucene.backward_codecs.lucene94.Lucene94RWHnswVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; import java.nio.ByteBuffer; @@ -35,6 +36,7 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; @@ -214,9 +216,7 @@ private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocM final int[] docIdOffsets = new int[sortMap.size()]; int offset = 1; // 0 means no vector for this (field, document) DocIdSetIterator iterator = fieldData.docsWithField.iterator(); - for (int docID = iterator.nextDoc(); - docID != DocIdSetIterator.NO_MORE_DOCS; - docID = iterator.nextDoc()) { + for (int docID = iterator.nextDoc(); docID != NO_MORE_DOCS; docID = 
iterator.nextDoc()) { int newDocID = sortMap.oldToNew(docID); docIdOffsets[newDocID] = offset++; } @@ -554,9 +554,7 @@ private void writeMeta( final DirectMonotonicWriter ordToDocWriter = DirectMonotonicWriter.getInstance(meta, vectorData, count, DIRECT_MONOTONIC_BLOCK_SHIFT); DocIdSetIterator iterator = docsWithField.iterator(); - for (int doc = iterator.nextDoc(); - doc != DocIdSetIterator.NO_MORE_DOCS; - doc = iterator.nextDoc()) { + for (int doc = iterator.nextDoc(); doc != NO_MORE_DOCS; doc = iterator.nextDoc()) { ordToDocWriter.add(doc); } ordToDocWriter.finish(); @@ -588,10 +586,10 @@ private void writeMeta( private static DocsWithFieldSet writeByteVectorData( IndexOutput output, ByteVectorValues byteVectorValues) throws IOException { DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - for (int ord = 0; ord < byteVectorValues.size(); ord++) { - int docV = byteVectorValues.ordToDoc(ord); + KnnVectorValues.DocIterator iter = byteVectorValues.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector - byte[] binaryValue = byteVectorValues.vectorValue(ord); + byte[] binaryValue = byteVectorValues.vectorValue(iter.index()); assert binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize; output.writeBytes(binaryValue, binaryValue.length); docsWithField.add(docV); @@ -605,13 +603,13 @@ private static DocsWithFieldSet writeByteVectorData( private static DocsWithFieldSet writeVectorData( IndexOutput output, FloatVectorValues floatVectorValues) throws IOException { DocsWithFieldSet docsWithField = new DocsWithFieldSet(); + KnnVectorValues.DocIterator iter = floatVectorValues.iterator(); ByteBuffer binaryVector = ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize) .order(ByteOrder.LITTLE_ENDIAN); - for (int ord = 0; ord < floatVectorValues.size(); ord++) { - int docV = floatVectorValues.ordToDoc(ord); + for (int docV = iter.nextDoc(); docV != 
NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector - float[] vectorValue = floatVectorValues.vectorValue(ord); + float[] vectorValue = floatVectorValues.vectorValue(iter.index()); binaryVector.asFloatBuffer().put(vectorValue); output.writeBytes(binaryVector.array(), binaryVector.limit()); docsWithField.add(docV); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java index db7081d04c4f..9cdb7f1737a3 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java @@ -221,9 +221,7 @@ private void writeSortingField(FieldWriter fieldData, int maxDoc, Sorter.DocM final int[] docIdOffsets = new int[sortMap.size()]; int offset = 1; // 0 means no vector for this (field, document) DocIdSetIterator iterator = fieldData.docsWithField.iterator(); - for (int docID = iterator.nextDoc(); - docID != NO_MORE_DOCS; - docID = iterator.nextDoc()) { + for (int docID = iterator.nextDoc(); docID != NO_MORE_DOCS; docID = iterator.nextDoc()) { int newDocID = sortMap.oldToNew(docID); docIdOffsets[newDocID] = offset++; } @@ -636,13 +634,13 @@ private void writeMeta( private static DocsWithFieldSet writeByteVectorData( IndexOutput output, ByteVectorValues byteVectorValues) throws IOException { DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - for (int ord = 0; ord < byteVectorValues.size(); ord++) { - int docV = byteVectorValues.ordToDoc(ord); + KnnVectorValues.DocIterator iter = byteVectorValues.iterator(); + for (int docId = iter.nextDoc(); docId != NO_MORE_DOCS; docId = iter.nextDoc()) { // write vector - byte[] binaryValue = byteVectorValues.vectorValue(ord); + byte[] binaryValue = byteVectorValues.vectorValue(iter.index()); assert 
binaryValue.length == byteVectorValues.dimension() * VectorEncoding.BYTE.byteSize; output.writeBytes(binaryValue, binaryValue.length); - docsWithField.add(docV); + docsWithField.add(docId); } return docsWithField; } @@ -656,10 +654,10 @@ private static DocsWithFieldSet writeVectorData( ByteBuffer buffer = ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize) .order(ByteOrder.LITTLE_ENDIAN); - for (int ord = 0; ord < floatVectorValues.size(); ord++) { - int docV = floatVectorValues.ordToDoc(ord); + KnnVectorValues.DocIterator iter = floatVectorValues.iterator(); + for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector - float[] value = floatVectorValues.vectorValue(ord); + float[] value = floatVectorValues.vectorValue(iter.index()); buffer.asFloatBuffer().put(value); output.writeBytes(buffer.array(), buffer.limit()); docsWithField.add(docV); diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index 18fb7b2a085c..9cdbd96da38d 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -331,9 +331,14 @@ public float[] vectorValue(int ord) { return values[ord]; } + @Override + public int ordToDoc(int ord) { + return entry.ordToDoc[ord]; + } + @Override protected DocIterator createIterator() { - return createDenseIterator(this); + return fromOrdToDoc(this); } @Override @@ -425,7 +430,7 @@ public byte[] vectorValue(int ord) { @Override protected DocIterator createIterator() { - return createDenseIterator(this); + return fromOrdToDoc(this); } @Override diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java 
b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java index d6b25f0e4761..005fb0b5065c 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java @@ -100,9 +100,10 @@ public void writeField(FieldInfo fieldInfo, ByteVectorValues byteVectorValues, i throws IOException { long vectorDataOffset = vectorData.getFilePointer(); List docIds = new ArrayList<>(); - for (int ord = 0; ord < byteVectorValues.size(); ord++) { - writeByteVectorValue(byteVectorValues, ord); - docIds.add(byteVectorValues.ordToDoc(ord)); + KnnVectorValues.DocIterator it = byteVectorValues.iterator(); + for (int docV = it.nextDoc(); docV != NO_MORE_DOCS; docV = it.nextDoc()) { + writeByteVectorValue(byteVectorValues, it.index()); + docIds.add(docV); } long vectorDataLength = vectorData.getFilePointer() - vectorDataOffset; writeMeta(fieldInfo, vectorDataOffset, vectorDataLength, docIds); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java index 620e083a42fa..480b8f0cbecd 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java @@ -24,9 +24,10 @@ import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.SortingCodecReader.SortingValuesIterator; +import org.apache.lucene.search.DocIdSet; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.RamUsageEstimator; @@ -79,20 +80,26 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws 
IOException { case FLOAT32: BufferedFloatVectorValues bufferedFloatVectorValues = new BufferedFloatVectorValues( - (List) fieldData.vectors, fieldData.fieldInfo.getVectorDimension()); + (List) fieldData.vectors, + fieldData.fieldInfo.getVectorDimension(), + fieldData.docsWithField); FloatVectorValues floatVectorValues = sortMap != null - ? new SortingFloatVectorValues(bufferedFloatVectorValues, sortMap) + ? new SortingFloatVectorValues( + bufferedFloatVectorValues, fieldData.docsWithField, sortMap) : bufferedFloatVectorValues; writeField(fieldData.fieldInfo, floatVectorValues, maxDoc); break; case BYTE: BufferedByteVectorValues bufferedByteVectorValues = new BufferedByteVectorValues( - (List) fieldData.vectors, fieldData.fieldInfo.getVectorDimension()); + (List) fieldData.vectors, + fieldData.fieldInfo.getVectorDimension(), + fieldData.docsWithField); ByteVectorValues byteVectorValues = sortMap != null - ? new SortingByteVectorValues(bufferedByteVectorValues, sortMap) + ? new SortingByteVectorValues( + bufferedByteVectorValues, fieldData.docsWithField, sortMap) : bufferedByteVectorValues; writeField(fieldData.fieldInfo, byteVectorValues, maxDoc); break; @@ -106,17 +113,18 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { */ private static class SortingFloatVectorValues extends FloatVectorValues { private final BufferedFloatVectorValues delegate; - private final int[] newToOld; + private final DocIterator iterator; - SortingFloatVectorValues(BufferedFloatVectorValues delegate, Sorter.DocMap sortMap) + SortingFloatVectorValues( + BufferedFloatVectorValues delegate, DocsWithFieldSet docsWithField, Sorter.DocMap sortMap) throws IOException { this.delegate = delegate.copy(); - newToOld = docMapToOrdMap(delegate, sortMap); + iterator = new SortingValuesIterator(delegate.copy().iterator(), sortMap); } @Override public float[] vectorValue(int ord) throws IOException { - return delegate.vectorValue(newToOld[ord]); + return 
delegate.vectorValue(ord); } @Override @@ -133,31 +141,11 @@ public int size() { public SortingFloatVectorValues copy() { throw new UnsupportedOperationException(); } - } - // create a map from new ord to old ord assuming ords are sorted by doc but - // may be sparse - private static int[] docMapToOrdMap(KnnVectorValues values, Sorter.DocMap docMap) { - // fill with -1? - int[] newDocToOldOrd = new int[docMap.size()]; - int count = 0; - for (int ord = 0; ord < values.size(); ord++) { - int oldDoc = values.ordToDoc(ord); - int newDoc = docMap.oldToNew(oldDoc); - // no value will be represented by 0 - if (newDoc >= 0) { - newDocToOldOrd[newDoc] = ord + 1; - ++count; - } - } - int[] newToOld = new int[count]; - count = 0; - for (int ord = 0; ord < newDocToOldOrd.length; ord++) { - if (newDocToOldOrd[ord] > 0) { - newToOld[count++] = newDocToOldOrd[ord] - 1; - } + @Override + public DocIterator iterator() { + return iterator; } - return newToOld; } /** @@ -166,17 +154,18 @@ private static int[] docMapToOrdMap(KnnVectorValues values, Sorter.DocMap docMap */ private static class SortingByteVectorValues extends ByteVectorValues { private final BufferedByteVectorValues delegate; - private final int[] newToOld; + private final DocIterator iterator; - SortingByteVectorValues(BufferedByteVectorValues delegate, Sorter.DocMap sortMap) + SortingByteVectorValues( + BufferedByteVectorValues delegate, DocsWithFieldSet docsWithField, Sorter.DocMap sortMap) throws IOException { this.delegate = delegate; - newToOld = docMapToOrdMap(delegate, sortMap); + iterator = new SortingValuesIterator(delegate.copy().iterator(), sortMap); } @Override public byte[] vectorValue(int ord) throws IOException { - return delegate.vectorValue(newToOld[ord]); + return delegate.vectorValue(ord); } @Override @@ -193,6 +182,11 @@ public int size() { public SortingByteVectorValues copy() { throw new UnsupportedOperationException(); } + + @Override + public DocIterator iterator() { + return iterator; + } } 
@Override @@ -278,10 +272,15 @@ private static class BufferedFloatVectorValues extends FloatVectorValues { // These are always the vectors of a VectorValuesWriter, which are copied when added to it final List vectors; final int dimension; + private final DocIdSet docsWithField; + private final DocIterator iterator; - BufferedFloatVectorValues(List vectors, int dimension) { + BufferedFloatVectorValues(List vectors, int dimension, DocIdSet docsWithField) + throws IOException { this.vectors = vectors; this.dimension = dimension; + this.docsWithField = docsWithField; + this.iterator = fromDISI(docsWithField.iterator()); } @Override @@ -305,13 +304,13 @@ public float[] vectorValue(int targetOrd) { } @Override - public DocIterator createIterator() { - return createDenseIterator(this); + public DocIterator iterator() { + return iterator; } @Override - public BufferedFloatVectorValues copy() { - return this; + public BufferedFloatVectorValues copy() throws IOException { + return new BufferedFloatVectorValues(vectors, dimension, docsWithField); } } @@ -319,10 +318,15 @@ private static class BufferedByteVectorValues extends ByteVectorValues { // These are always the vectors of a VectorValuesWriter, which are copied when added to it final List vectors; final int dimension; + private final DocIdSet docsWithField; + private final DocIterator iterator; - BufferedByteVectorValues(List vectors, int dimension) { + BufferedByteVectorValues(List vectors, int dimension, DocIdSet docsWithField) + throws IOException { this.vectors = vectors; this.dimension = dimension; + this.docsWithField = docsWithField; + iterator = fromDISI(docsWithField.iterator()); } @Override @@ -341,8 +345,13 @@ public byte[] vectorValue(int targetOrd) { } @Override - public BufferedByteVectorValues copy() { - return this; + public DocIterator iterator() { + return iterator; + } + + @Override + public BufferedByteVectorValues copy() throws IOException { + return new BufferedByteVectorValues(vectors, dimension, 
docsWithField); } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java index b6c6ad53a199..2a551fd77050 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java @@ -356,6 +356,7 @@ public int dimension() { @Override public int ordToDoc(int ord) { + // FIXME support ordToDoc for backward-codecs? throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index 4137089a8dd5..74f7060ea5dd 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -33,8 +33,7 @@ import org.apache.lucene.util.packed.DirectMonotonicReader; /** Read the vector values from the index input. This supports both iterated and random access. */ -public abstract class OffHeapByteVectorValues extends ByteVectorValues - implements HasIndexSlice { +public abstract class OffHeapByteVectorValues extends ByteVectorValues implements HasIndexSlice { protected final int dimension; protected final int size; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index a9c593e61653..db6cb157253d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -32,8 +32,7 @@ import org.apache.lucene.util.packed.DirectMonotonicReader; /** Read the vector values from the index input. This supports both iterated and random access. 
*/ -public abstract class OffHeapFloatVectorValues extends FloatVectorValues - implements HasIndexSlice { +public abstract class OffHeapFloatVectorValues extends FloatVectorValues implements HasIndexSlice { protected final int dimension; protected final int size; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index be101cc9244e..8a111d874d92 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -886,15 +886,10 @@ public DocIterator createIterator() { } static class QuantizedByteVectorValueSub extends DocIDMerger.Sub { - private final int maxDoc; private final QuantizedByteVectorValues values; - int ord = -1; - - QuantizedByteVectorValueSub( - int maxDoc, MergeState.DocMap docMap, QuantizedByteVectorValues values) { + QuantizedByteVectorValueSub(MergeState.DocMap docMap, QuantizedByteVectorValues values) { super(docMap); - this.maxDoc = maxDoc; this.values = values; assert values.iterator().docID() == -1; } @@ -935,14 +930,12 @@ public static MergedQuantizedVectorValues mergeQuantizedByteVectorValues( } sub = new QuantizedByteVectorValueSub( - mergeState.maxDocs[i], mergeState.docMaps[i], new QuantizedFloatVectorValues( toQuantize, fieldInfo.getVectorSimilarityFunction(), scalarQuantizer)); } else { sub = new QuantizedByteVectorValueSub( - mergeState.maxDocs[i], mergeState.docMaps[i], new OffsetCorrectedQuantizedByteVectorValues( reader.getQuantizedVectorValues(fieldInfo.name), @@ -959,33 +952,28 @@ public static MergedQuantizedVectorValues mergeQuantizedByteVectorValues( private final List subs; private final DocIDMerger docIdMerger; private final int size; - private final int[] ends; + + private QuantizedByteVectorValueSub current; private 
MergedQuantizedVectorValues( List subs, MergeState mergeState) throws IOException { this.subs = subs; docIdMerger = DocIDMerger.of(subs, mergeState.needsIndexSort); int totalSize = 0; - ends = new int[subs.size()]; - int iSub = 0; for (QuantizedByteVectorValueSub sub : subs) { totalSize += sub.values.size(); - ends[iSub++] = totalSize; } size = totalSize; } @Override public byte[] vectorValue(int ord) throws IOException { - int iSub = ((CompositeIterator) iterator()).iSub; - QuantizedByteVectorValues values = subs.get(iSub).values; - assert ord == values.iterator().index() + (iSub == 0 ? 0 : ends[iSub - 1]); - return values.vectorValue(values.iterator().index()); + return current.values.vectorValue(current.values.iterator().index()); } @Override protected DocIterator createIterator() { - return new CompositeIterator(subs); + return new CompositeIterator(); } @Override @@ -1000,58 +988,39 @@ public int dimension() { @Override public float getScoreCorrectionConstant(int ord) throws IOException { - int iSub = ((CompositeIterator) iterator()).iSub; - - QuantizedByteVectorValues current = subs.get(iSub).values; - return current.getScoreCorrectionConstant(current.iterator().index()); + return current.values.getScoreCorrectionConstant(current.values.iterator().index()); } - static class CompositeIterator extends KnnVectorValues.DocIterator { - private final List subs; - int iSub; - int docBase; - int ordBase; - int docId; + private class CompositeIterator extends KnnVectorValues.DocIterator { + private int docId; + private int ord; - public CompositeIterator(List subs) { - this.subs = subs; - iSub = 0; - docBase = 0; - ordBase = 0; + public CompositeIterator() { docId = -1; + ord = -1; } @Override public int index() { - if (iSub == subs.size()) { - return NO_MORE_DOCS; - } - return ordBase + subs.get(iSub).values.iterator().index(); + return ord; } @Override public int docID() { - if (iSub == subs.size()) { - return NO_MORE_DOCS; - } - int subDocId = 
subs.get(iSub).values.iterator().docID(); - return docBase + subDocId; + return docId; } @Override public int nextDoc() throws IOException { - // FIXME: this is incorrect for a sorted index. - // We need to use DocIDMerger while also tracking ordinals - while (iSub < subs.size()) { - int doc = subs.get(iSub).nextMappedDoc(); - if (doc != NO_MORE_DOCS) { - return docBase + doc; - } - ordBase += subs.get(iSub).values.size(); - docBase += subs.get(iSub).maxDoc; - ++iSub; + current = docIdMerger.next(); + if (current == null) { + docId = NO_MORE_DOCS; + ord = NO_MORE_DOCS; + } else { + docId = current.mappedDocID; + ++ord; } - return NO_MORE_DOCS; + return docId; } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java index ecb5a0bc8e38..b08387ca100c 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java @@ -362,8 +362,6 @@ public EmptyOffHeapVectorValues( null); } - private int doc = -1; - @Override public int dimension() { return super.dimension(); diff --git a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java index 91894352ce62..052efc97a8ce 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java @@ -20,7 +20,6 @@ import java.util.List; import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.search.VectorScorer; -import org.apache.lucene.store.IndexInput; /** * This class provides access to per-document floating point vector values indexed as {@link diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java 
b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java index dcf710a9847d..f397ff9055e3 100644 --- a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java @@ -21,7 +21,6 @@ import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.Bits; /** @@ -68,7 +67,7 @@ public int getVectorByteLength() { public Bits getAcceptOrds(Bits acceptDocs) { // FIXME: change default to return acceptDocs and provide this impl - // somewhere more specialized + // somewhere more specialized (in every non-dense impl). if (acceptDocs == null) { return null; } @@ -101,6 +100,7 @@ public long cost() { } public static DocIterator fromIndexedDISI(IndexedDISI disi) { + // can we replace with fromDISI? return new DocIterator() { @Override public int docID() { @@ -182,4 +182,72 @@ public long cost() { } }; } + + protected static DocIterator fromDISI(DocIdSetIterator docsWithField) { + return new DocIterator() { + + int ord = -1; + + @Override + public int docID() { + return docsWithField.docID(); + } + + @Override + public int index() { + return ord; + } + + @Override + public int nextDoc() throws IOException { + if (docID() == NO_MORE_DOCS) { + return NO_MORE_DOCS; + } + ord++; + return docsWithField.nextDoc(); + } + + @Override + public long cost() { + return docsWithField.cost(); + } + }; + } + + protected static DocIterator fromOrdToDoc(KnnVectorValues values) { + return new DocIterator() { + private int ord = -1; + + @Override + public int docID() { + if (ord == -1) { + return -1; + } + if (ord == NO_MORE_DOCS) { + return NO_MORE_DOCS; + } + return values.ordToDoc(ord); + } + + @Override + public int index() { + return ord; + } + + @Override + public int nextDoc() throws IOException { + if (ord >= values.size() - 1) { + ord = 
NO_MORE_DOCS; + } else { + ++ord; + } + return docID(); + } + + @Override + public long cost() { + return values.size(); + } + }; + } } diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java index 73621f751e84..2779a97dfd4a 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java @@ -33,9 +33,7 @@ import org.apache.lucene.index.MultiDocValues.MultiSortedDocValues; import org.apache.lucene.index.MultiDocValues.MultiSortedSetDocValues; import org.apache.lucene.internal.hppc.IntObjectHashMap; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; -import org.apache.lucene.search.VectorScorer; import org.apache.lucene.util.ArrayUtil; import org.apache.lucene.util.Bits; import org.apache.lucene.util.IOUtils; @@ -293,22 +291,23 @@ public void checkIntegrity() throws IOException { } } - private record DocValuesSub(T sub, int docStart, int docEnd) {} + private record DocValuesSub(T sub, int docStart) {} - // TODO: use this instead of the CompositeIterator we created here? 
- private static class MergedDocIdSetIterator extends DocIdSetIterator { + private static class MergedDocIterator + extends KnnVectorValues.DocIterator { final Iterator> it; final long cost; DocValuesSub current; - int currentIndex = 0; + int ord = -1; int doc = -1; + int currentOrdBase = 0; - MergedDocIdSetIterator(List> subs) { + MergedDocIterator(List> subs) { long cost = 0; for (DocValuesSub sub : subs) { if (sub.sub != null) { - cost += sub.sub.cost(); + cost += sub.sub.size(); } } this.cost = cost; @@ -316,29 +315,23 @@ private static class MergedDocIdSetIterator extends current = it.next(); } - private boolean advanceSub(int target) { - while (current.sub == null || current.docEnd <= target) { - if (it.hasNext() == false) { - doc = NO_MORE_DOCS; - return false; - } - current = it.next(); - currentIndex++; - } - return true; - } - @Override public int docID() { return doc; } + @Override + public int index() { + return ord; + } + @Override public int nextDoc() throws IOException { while (true) { if (current.sub != null) { - int next = current.sub.nextDoc(); + int next = current.sub.iterator().nextDoc(); if (next != NO_MORE_DOCS) { + ++ord; return doc = current.docStart + next; } } @@ -346,22 +339,7 @@ public int nextDoc() throws IOException { return doc = NO_MORE_DOCS; } current = it.next(); - currentIndex++; - } - } - - @Override - public int advance(int target) throws IOException { - while (true) { - if (advanceSub(target) == false) { - return DocIdSetIterator.NO_MORE_DOCS; - } - int next = current.sub.advance(target - current.docStart); - if (next == DocIdSetIterator.NO_MORE_DOCS) { - target = current.docEnd; - } else { - return doc = current.docStart + next; - } + currentOrdBase = ord; } } @@ -833,7 +811,8 @@ public void checkIntegrity() throws IOException { @Override public FloatVectorValues getFloatVectorValues(String field) throws IOException { - List subs = new ArrayList<>(); + List> subs = new ArrayList<>(); + int i = 0; int dimension = -1; int size = 
0; for (CodecReader reader : codecReaders) { @@ -844,12 +823,20 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException { } size += values.size(); } - subs.add(values); + subs.add(new DocValuesSub<>(values, docStarts[i])); + i++; } final int finalDimension = dimension; final int finalSize = size; return new FloatVectorValues() { + final MergedDocIterator iter = new MergedDocIterator<>(subs); + + @Override + public MergedDocIterator iterator() { + return iter; + } + @Override public int dimension() { return finalDimension; @@ -862,36 +849,15 @@ public int size() { @Override public float[] vectorValue(int ord) throws IOException { - FloatVectorValues sub = - (FloatVectorValues) subs.get(((CompositeIterator) iterator()).iSub); - return sub.vectorValue(sub.iterator().index()); - } - - @Override - protected DocIterator createIterator() { - return new CompositeIterator(subs); - } - - @Override - public int ordToDoc(int ord) { - throw new UnsupportedOperationException(); - } - - @Override - public VectorScorer scorer(float[] target) { - throw new UnsupportedOperationException(); - } - - @Override - public FloatVectorValues copy() { - throw new UnsupportedOperationException(); + return iter.current.sub.vectorValue(ord - iter.currentOrdBase); } }; } @Override public ByteVectorValues getByteVectorValues(String field) throws IOException { - List subs = new ArrayList<>(); + List> subs = new ArrayList<>(); + int i = 0; int dimension = -1; int size = 0; for (CodecReader reader : codecReaders) { @@ -902,12 +868,20 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } size += values.size(); } - subs.add(values); + subs.add(new DocValuesSub<>(values, docStarts[i])); + i++; } final int finalDimension = dimension; final int finalSize = size; return new ByteVectorValues() { + final MergedDocIterator iter = new MergedDocIterator<>(subs); + + @Override + public MergedDocIterator iterator() { + return iter; + } + @Override public 
int dimension() { return finalDimension; @@ -920,28 +894,12 @@ public int size() { @Override public byte[] vectorValue(int ord) throws IOException { - ByteVectorValues sub = (ByteVectorValues) subs.get(((CompositeIterator) iterator()).iSub); - return sub.vectorValue(sub.iterator().index()); - } - - @Override - public int ordToDoc(int ord) { - throw new UnsupportedOperationException(); - } - - @Override - public VectorScorer scorer(byte[] target) { - throw new UnsupportedOperationException(); - } - - @Override - public ByteVectorValues copy() { - throw new UnsupportedOperationException(); + return iter.current.sub.vectorValue(ord - iter.currentOrdBase); } @Override protected DocIterator createIterator() { - return new CompositeIterator(subs); + return new MergedDocIterator(subs); } }; } @@ -957,51 +915,6 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits throws IOException { throw new UnsupportedOperationException(); } - - static class CompositeIterator extends KnnVectorValues.DocIterator { - private final List subs; - int iSub; - int docId; - - public CompositeIterator(List subs) { - this.subs = subs; - iSub = 0; - docId = -1; - } - - @Override - public int index() { - if (iSub == subs.size()) { - return NO_MORE_DOCS; - } - return subs.get(iSub).iterator().index(); - } - - @Override - public int docID() { - if (iSub == subs.size()) { - return NO_MORE_DOCS; - } - return subs.get(iSub).iterator().docID(); - } - - @Override - public int nextDoc() throws IOException { - while (iSub < subs.size()) { - int doc = subs.get(iSub).iterator().nextDoc(); - if (doc != NO_MORE_DOCS) { - return doc; - } - ++iSub; - } - return NO_MORE_DOCS; - } - - @Override - public long cost() { - return 0; - } - } } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index 73db8af872c9..e3e9cb66a427 100644 --- 
a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -32,10 +32,11 @@ import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.TermVectorsReader; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; -import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.IOSupplier; @@ -206,27 +207,28 @@ public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue } } - private static class SortingValuesIterator extends KnnVectorValues.DocIterator { - private final BitSet docsWithValues; + public static class SortingValuesIterator extends KnnVectorValues.DocIterator { + private final DocIdSetIterator docsWithValues; private final int[] docToOrd; private final int size; int doc = -1; - SortingValuesIterator(KnnVectorValues value, Sorter.DocMap docMap) throws IOException { + public SortingValuesIterator(KnnVectorValues.DocIterator iter, Sorter.DocMap docMap) + throws IOException { docToOrd = new int[docMap.size()]; - docsWithValues = new FixedBitSet(docMap.size()); - KnnVectorValues.DocIterator iter = value.createIterator(); + FixedBitSet docBits = new FixedBitSet(docMap.size()); int count = 0; for (int doc = iter.nextDoc(); doc != NO_MORE_DOCS; doc = iter.nextDoc()) { int newDocId = docMap.oldToNew(doc); if (newDocId != -1) { docToOrd[newDocId] = iter.index(); - docsWithValues.set(newDocId); + docBits.set(newDocId); ++count; } } size = count; + docsWithValues = new BitSetIterator(docBits, count); } @Override @@ -240,11 +242,9 @@ public int index() { } @Override - public int nextDoc() { - if (doc >= docsWithValues.length() - 1) { - doc = 
NO_MORE_DOCS; - } else { - doc = docsWithValues.nextSetBit(doc + 1); + public int nextDoc() throws IOException { + if (doc != NO_MORE_DOCS) { + doc = docsWithValues.nextDoc(); } return doc; } @@ -261,7 +261,7 @@ private static class SortingFloatVectorValues extends FloatVectorValues { SortingFloatVectorValues(FloatVectorValues delegate, Sorter.DocMap sortMap) throws IOException { this.delegate = delegate; - iterator = new SortingValuesIterator(delegate, sortMap); + iterator = new SortingValuesIterator(delegate.iterator(), sortMap); } @Override @@ -295,7 +295,7 @@ private static class SortingByteVectorValues extends ByteVectorValues { SortingByteVectorValues(ByteVectorValues delegate, Sorter.DocMap sortMap) throws IOException { this.delegate = delegate; - iterator = new SortingValuesIterator(delegate, sortMap); + iterator = new SortingValuesIterator(delegate.iterator(), sortMap); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java index a08ab64a29ee..b90ab8276dd1 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/QuantizedByteVectorValues.java @@ -17,7 +17,6 @@ package org.apache.lucene.util.quantization; import java.io.IOException; - import org.apache.lucene.codecs.lucene95.HasIndexSlice; import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.search.VectorScorer; diff --git a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java index 3e1113de0242..63e79bccbdea 100644 --- a/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java +++ 
b/lucene/core/src/java21/org/apache/lucene/internal/vectorization/Lucene99MemorySegmentFlatVectorsScorer.java @@ -20,7 +20,6 @@ import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer; import org.apache.lucene.codecs.hnsw.FlatVectorsScorer; import org.apache.lucene.codecs.lucene95.HasIndexSlice; -import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.util.hnsw.RandomVectorScorer; diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index dea6b679d9be..50a817f49c8b 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -220,7 +220,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { } @SuppressWarnings("unchecked") - private T vectorValue(KnnVectorValues vectors, int ord) throws IOException { + private T vectorValue(KnnVectorValues vectors, int ord) throws IOException { switch (vectors.getEncoding()) { case BYTE -> { return (T) ((ByteVectorValues) vectors).vectorValue(ord); @@ -936,11 +936,11 @@ public void testRandom() throws IOException { for (int j = 0; j < size; j++) { if (vectorValue(vectors, j) != null && (acceptOrds == null || acceptOrds.get(j))) { if (getVectorEncoding() == VectorEncoding.BYTE) { - assert query instanceof byte[]; - expected.add(j, similarityFunction.compare((byte[]) query, vectorValue(vectors, j))); + expected.add( + j, similarityFunction.compare((byte[]) query, (byte[]) vectorValue(vectors, j))); } else { - assert query instanceof float[]; - expected.add(j, similarityFunction.compare((float[]) query, vectorValue(vectors, j))); + expected.add( + j, similarityFunction.compare((float[]) query, (float[]) vectorValue(vectors, j))); } if (expected.size() > topK) { expected.pop(); @@ 
-1253,12 +1253,12 @@ void assertVectorsEqual(KnnVectorValues u, KnnVectorValues v) throws IOException case BYTE -> assertArrayEquals( "vectors do not match for doc=" + uDoc, - vectorValue(u, ord), + (byte[]) vectorValue(u, ord), (byte[]) vectorValue(v, ord)); case FLOAT32 -> assertArrayEquals( "vectors do not match for doc=" + uDoc, - vectorValue(u, ord), + (float[]) vectorValue(u, ord), (float[]) vectorValue(v, ord), 1e-4f); default -> diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java index 360f1b3cb0f5..d88980294915 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java @@ -30,8 +30,6 @@ class MockByteVectorValues extends ByteVectorValues { private final BytesRef binaryValue; private final byte[] scratch; - private int pos = -1; - static MockByteVectorValues fromValues(byte[][] values) { byte[] firstNonNull = null; int j = 0; diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java index e494232774c3..a37c3f208037 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java @@ -28,8 +28,6 @@ class MockVectorValues extends FloatVectorValues { private final int numVectors; private final float[] scratch; - private int pos = -1; - static MockVectorValues fromValues(float[][] values) { float[] firstNonNull = null; int j = 0; diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java index f88648b7abf6..4f09decd5eb2 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java +++ 
b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java @@ -16,7 +16,6 @@ */ package org.apache.lucene.util.quantization; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.util.quantization.ScalarQuantizer.SCRATCH_SIZE; import java.io.IOException; diff --git a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java index 817a219d6542..679b1ebc9138 100644 --- a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java +++ b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java @@ -2285,7 +2285,6 @@ public int[] clear() { private static final class MemoryFloatVectorValues extends FloatVectorValues { private final Info info; - private int currentDoc = -1; MemoryFloatVectorValues(Info info) { this.info = info; @@ -2310,6 +2309,11 @@ public float[] vectorValue(int ord) { } } + @Override + protected DocIterator createIterator() { + return createDenseIterator(this); + } + @Override public VectorScorer scorer(float[] query) { if (query.length != info.fieldInfo.getVectorDimension()) { @@ -2339,7 +2343,6 @@ public DocIdSetIterator iterator() { private static final class MemoryByteVectorValues extends ByteVectorValues { private final Info info; - private int currentDoc = -1; MemoryByteVectorValues(Info info) { this.info = info; @@ -2364,6 +2367,11 @@ public byte[] vectorValue(int ord) { } } + @Override + protected DocIterator createIterator() { + return createDenseIterator(this); + } + @Override public VectorScorer scorer(byte[] query) { if (query.length != info.fieldInfo.getVectorDimension()) { diff --git a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java index e781c168cd97..684c9fac838f 100644 --- 
a/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java +++ b/lucene/sandbox/src/java/org/apache/lucene/sandbox/codecs/quantization/SampleReader.java @@ -20,7 +20,6 @@ import java.io.IOException; import java.util.Random; import java.util.function.IntUnaryOperator; - import org.apache.lucene.codecs.lucene95.HasIndexSlice; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.store.IndexInput; diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index 2294d4d01cbc..d66e0f88c47e 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -1305,7 +1305,7 @@ public void testRandom() throws Exception { String idString = storedFields.document(docId).getField("id").stringValue(); int id = Integer.parseInt(idString); if (ctx.reader().getLiveDocs() == null || ctx.reader().getLiveDocs().get(docId)) { - assertArrayEquals(idString, values[id], v, 0); + assertArrayEquals(idString + " " + docId, values[id], v, 0); ++valueCount; } else { ++numDeletes; From 1450b440e8242cb7a1ef9daded3c54fac8aaf205 Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Tue, 10 Sep 2024 13:45:04 -0400 Subject: [PATCH 09/25] fixing up javadocs and making iterator methods instance methods --- .../lucene90/Lucene90HnswVectorsReader.java | 2 +- .../lucene91/Lucene91HnswVectorsReader.java | 2 +- .../lucene92/OffHeapFloatVectorValues.java | 2 +- .../lucene94/OffHeapByteVectorValues.java | 2 +- .../lucene94/OffHeapFloatVectorValues.java | 2 +- .../SimpleTextKnnVectorsReader.java | 4 +- .../lucene/codecs/lucene95/HasIndexSlice.java | 6 ++ .../lucene95/OffHeapByteVectorValues.java | 4 +- .../lucene95/OffHeapFloatVectorValues.java | 4 +- 
.../Lucene99ScalarQuantizedVectorsWriter.java | 2 +- .../OffHeapQuantizedByteVectorValues.java | 4 +- .../apache/lucene/index/ByteVectorValues.java | 2 +- .../lucene/index/FloatVectorValues.java | 2 +- .../apache/lucene/index/KnnVectorValues.java | 74 +++++++++++++------ .../lucene/index/SortingCodecReader.java | 5 ++ .../util/hnsw/MockByteVectorValues.java | 2 +- .../lucene/util/hnsw/MockVectorValues.java | 2 +- .../lucene/index/memory/MemoryIndex.java | 4 +- 18 files changed, 82 insertions(+), 43 deletions(-) diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index 03f921b8df11..cc59c654d4c5 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -410,7 +410,7 @@ public int ordToDoc(int ord) { @Override protected DocIterator createIterator() { - return fromOrdToDoc(this); + return fromOrdToDoc(); } @Override diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index 395f6a1fceef..71e1273b78df 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -454,7 +454,7 @@ public int ordToDoc(int ord) { @Override protected DocIterator createIterator() { - return fromOrdToDoc(this); + return fromOrdToDoc(); } @Override diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java 
b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java index 1daaba1a5b83..a79f47a85e30 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java @@ -108,7 +108,7 @@ public DenseOffHeapVectorValues copy() throws IOException { @Override protected DocIterator createIterator() { - return createDenseIterator(this); + return createDenseIterator(); } @Override diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java index 489d826076e5..40c218e518b2 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java @@ -123,7 +123,7 @@ public DenseOffHeapVectorValues copy() throws IOException { @Override public DocIterator createIterator() { - return createDenseIterator(this); + return createDenseIterator(); } @Override diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java index 5330fb964af9..69776bc4f826 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java @@ -119,7 +119,7 @@ public DenseOffHeapVectorValues copy() throws IOException { @Override public DocIterator createIterator() { - return createDenseIterator(this); + return createDenseIterator(); } @Override diff --git 
a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index 9cdbd96da38d..864100177fac 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -338,7 +338,7 @@ public int ordToDoc(int ord) { @Override protected DocIterator createIterator() { - return fromOrdToDoc(this); + return fromOrdToDoc(); } @Override @@ -430,7 +430,7 @@ public byte[] vectorValue(int ord) { @Override protected DocIterator createIterator() { - return fromOrdToDoc(this); + return fromOrdToDoc(); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/HasIndexSlice.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/HasIndexSlice.java index 8f2129fc16f0..2bfe72386a05 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/HasIndexSlice.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/HasIndexSlice.java @@ -18,6 +18,12 @@ import org.apache.lucene.store.IndexInput; +/** + * Implementors can return the IndexInput from which their values are read. For use by vector + * quantizers. + */ public interface HasIndexSlice { + + /** Returns an IndexInput from which to read this instance's values. 
*/ IndexInput getSlice(); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index 74f7060ea5dd..244ffc68b770 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -148,7 +148,7 @@ public DenseOffHeapVectorValues copy() throws IOException { @Override public DocIterator createIterator() { - return createDenseIterator(this); + return createDenseIterator(); } @Override @@ -299,7 +299,7 @@ public byte[] vectorValue(int ord) throws IOException { @Override protected DocIterator createIterator() { - return createDenseIterator(this); + return createDenseIterator(); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index db6cb157253d..0e216d22a3f9 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -154,7 +154,7 @@ public Bits getAcceptOrds(Bits acceptDocs) { @Override public DocIterator createIterator() { - return createDenseIterator(this); + return createDenseIterator(); } @Override @@ -299,7 +299,7 @@ public float[] vectorValue(int targetOrd) { @Override protected DocIterator createIterator() { - return createDenseIterator(this); + return createDenseIterator(); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index 8a111d874d92..b756047f4a5d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ 
b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -881,7 +881,7 @@ public float[] vectorValue(int ord) throws IOException { @Override public DocIterator createIterator() { - return createDenseIterator(this); + return createDenseIterator(); } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java index b08387ca100c..8a0c3d188a0c 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java @@ -256,7 +256,7 @@ public DocIdSetIterator iterator() { @Override protected DocIterator createIterator() { - return createDenseIterator(this); + return createDenseIterator(); } } @@ -374,7 +374,7 @@ public int size() { @Override protected DocIterator createIterator() { - return createDenseIterator(this); + return createDenseIterator(); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java index 052efc97a8ce..94aeb1190e37 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java @@ -112,7 +112,7 @@ public ByteVectorValues copy() { @Override protected DocIterator createIterator() { - return createDenseIterator(this); + return createDenseIterator(); } }; } diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index 84f8113577c8..fb97e5861b84 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -113,7 +113,7 @@ public FloatVectorValues copy() { 
@Override protected DocIterator createIterator() { - return createDenseIterator(this); + return createDenseIterator(); } }; } diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java index f397ff9055e3..44b1462aca83 100644 --- a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java @@ -31,6 +31,7 @@ */ public abstract class KnnVectorValues { + /** The iterator associated with these values. */ protected DocIterator iterator; /** Return the dimension of the vectors */ @@ -63,8 +64,10 @@ public int getVectorByteLength() { return dimension() * getEncoding().byteSize; } + /** The vector encoding of these values. */ public abstract VectorEncoding getEncoding(); + /** Returns a Bits accepting docs accepted by the argument and having a vector value */ public Bits getAcceptOrds(Bits acceptDocs) { // FIXME: change default to return acceptDocs and provide this impl // somewhere more specialized (in every non-dense impl). @@ -84,6 +87,29 @@ public int length() { }; } + /** + * Return the iterator for this instance. If you need multiple iterators, call + * this.copy().iterator(). + */ + public DocIterator iterator() { + if (iterator == null) { + iterator = createIterator(); + } + return iterator; + } + + /** + * Create an iterator for this instance; typically called once by iterator(). Wrapper + * value classes delegate to their inner instance's iterator and shouldn't implement this. + */ + protected DocIterator createIterator() { + throw new UnsupportedOperationException(); + } + + /** + * A DocIdSetIterator that also provides an index() method tracking a distinct ordinal for a + * vector associated with each doc. 
+ */ public abstract static class DocIterator extends DocIdSetIterator { /** return the value index (aka "ordinal" or "ord") corresponding to the current doc */ @@ -96,9 +122,13 @@ public int advance(int target) throws IOException { @Override public long cost() { - throw new UnsupportedOperationException("for class " + getClass().getName()); + throw new UnsupportedOperationException(); } + /** + * Returns an iterator that delegates to the IndexedDISI. Advancing this iterator will advance + * the underlying IndexedDISI, and vice-versa. + */ public static DocIterator fromIndexedDISI(IndexedDISI disi) { // can we replace with fromDISI? return new DocIterator() { @@ -130,21 +160,11 @@ public long cost() { } } - public DocIterator iterator() { - if (iterator == null) { - iterator = createIterator(); - } - return iterator; - } - - protected DocIterator createIterator() { - // don't force every class to implement; some are just wrappers of other values and use their - // iterators - throw new UnsupportedOperationException(); - } - ; - - protected static DocIterator createDenseIterator(KnnVectorValues values) { + /** + * Creates an iterator for instances where every doc has a value, and the value ordinals are equal + * to the docids. 
+ */ + protected DocIterator createDenseIterator() { return new DocIterator() { int doc = -1; @@ -161,7 +181,7 @@ public int index() { @Override public int nextDoc() throws IOException { - if (doc >= values.size() - 1) { + if (doc >= size() - 1) { return doc = NO_MORE_DOCS; } else { return ++doc; @@ -170,7 +190,7 @@ public int nextDoc() throws IOException { @Override public int advance(int target) { - if (target >= values.size()) { + if (target >= size()) { return doc = NO_MORE_DOCS; } return doc = target; @@ -178,11 +198,15 @@ public int advance(int target) { @Override public long cost() { - return values.size(); + return size(); } }; } + /** + * Creates an iterator from a DocIdSetIterator indicating which docs have values, and for which + * ordinals increase monotonically with docid. + */ protected static DocIterator fromDISI(DocIdSetIterator docsWithField) { return new DocIterator() { @@ -214,7 +238,11 @@ public long cost() { }; } - protected static DocIterator fromOrdToDoc(KnnVectorValues values) { + /** + * Creates an iterator from this instance's ordinal-to-docid mapping which must be monotonic + * (docid increases when ordinal does). 
+ */ + protected DocIterator fromOrdToDoc() { return new DocIterator() { private int ord = -1; @@ -226,7 +254,7 @@ public int docID() { if (ord == NO_MORE_DOCS) { return NO_MORE_DOCS; } - return values.ordToDoc(ord); + return ordToDoc(ord); } @Override @@ -236,7 +264,7 @@ public int index() { @Override public int nextDoc() throws IOException { - if (ord >= values.size() - 1) { + if (ord >= size() - 1) { ord = NO_MORE_DOCS; } else { ++ord; @@ -246,7 +274,7 @@ public int nextDoc() throws IOException { @Override public long cost() { - return values.size(); + return size(); } }; } diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index e3e9cb66a427..a5eeb121305d 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -207,6 +207,10 @@ public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue } } + /** + * Iterator over KnnVectorValues accepting a mapping to differently-sorted docs. Consequently + * index() may skip around, not increasing monotonically as iteration proceeds. + */ public static class SortingValuesIterator extends KnnVectorValues.DocIterator { private final DocIdSetIterator docsWithValues; private final int[] docToOrd; @@ -214,6 +218,7 @@ public static class SortingValuesIterator extends KnnVectorValues.DocIterator { int doc = -1; + /** Creates an iterator accepting a mapping to differently-sorted docs. 
*/ public SortingValuesIterator(KnnVectorValues.DocIterator iter, Sorter.DocMap docMap) throws IOException { docToOrd = new int[docMap.size()]; diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java index d88980294915..f9c362774b49 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java @@ -91,6 +91,6 @@ public byte[] vectorValue(int ord) { @Override protected DocIterator createIterator() { - return createDenseIterator(this); + return createDenseIterator(); } } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java index a37c3f208037..b29445d258df 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java @@ -85,6 +85,6 @@ public float[] vectorValue(int ord) { @Override protected DocIterator createIterator() { - return createDenseIterator(this); + return createDenseIterator(); } } diff --git a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java index 679b1ebc9138..bd652f2e5f6e 100644 --- a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java +++ b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java @@ -2311,7 +2311,7 @@ public float[] vectorValue(int ord) { @Override protected DocIterator createIterator() { - return createDenseIterator(this); + return createDenseIterator(); } @Override @@ -2369,7 +2369,7 @@ public byte[] vectorValue(int ord) { @Override protected DocIterator createIterator() { - return createDenseIterator(this); + return createDenseIterator(); } @Override From 8d087e21cad24f8415243d6848fd4613b7c81dd7 
Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Tue, 10 Sep 2024 14:04:38 -0400 Subject: [PATCH 10/25] rename DocIterator to DocIndexIterator --- .../lucene90/Lucene90HnswVectorsReader.java | 4 ++-- .../lucene91/Lucene91HnswVectorsReader.java | 2 +- .../lucene92/OffHeapFloatVectorValues.java | 6 ++--- .../lucene94/OffHeapByteVectorValues.java | 4 ++-- .../lucene94/OffHeapFloatVectorValues.java | 6 ++--- .../lucene90/Lucene90HnswVectorsWriter.java | 2 +- .../TestLucene90HnswVectorsFormat.java | 2 +- .../lucene91/Lucene91HnswVectorsWriter.java | 2 +- .../lucene92/Lucene92HnswVectorsWriter.java | 2 +- .../lucene94/Lucene94HnswVectorsWriter.java | 4 ++-- .../lucene95/Lucene95HnswVectorsWriter.java | 4 ++-- .../SimpleTextKnnVectorsReader.java | 4 ++-- .../SimpleTextKnnVectorsWriter.java | 4 ++-- .../codecs/BufferingKnnVectorsWriter.java | 16 ++++++------- .../lucene/codecs/KnnVectorsWriter.java | 12 +++++----- .../lucene95/OffHeapByteVectorValues.java | 8 +++---- .../lucene95/OffHeapFloatVectorValues.java | 8 +++---- .../lucene99/Lucene99FlatVectorsWriter.java | 4 ++-- .../Lucene99ScalarQuantizedVectorsReader.java | 2 +- .../Lucene99ScalarQuantizedVectorsWriter.java | 14 +++++------ .../OffHeapQuantizedByteVectorValues.java | 8 +++---- .../apache/lucene/index/ByteVectorValues.java | 2 +- .../lucene/index/ExitableDirectoryReader.java | 10 ++++---- .../lucene/index/FloatVectorValues.java | 2 +- .../apache/lucene/index/KnnVectorValues.java | 24 +++++++++---------- .../SlowCompositeCodecReaderWrapper.java | 4 ++-- .../lucene/index/SortingCodecReader.java | 8 +++---- .../util/hnsw/IncrementalHnswGraphMerger.java | 4 ++-- ...tLucene99ScalarQuantizedVectorsFormat.java | 2 +- .../index/TestExitableDirectoryReader.java | 2 +- .../util/hnsw/MockByteVectorValues.java | 2 +- .../lucene/util/hnsw/MockVectorValues.java | 2 +- .../quantization/TestScalarQuantizer.java | 4 ++-- .../lucene/index/memory/MemoryIndex.java | 4 ++-- .../index/BaseKnnVectorsFormatTestCase.java | 4 ++-- 
35 files changed, 96 insertions(+), 96 deletions(-) diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index cc59c654d4c5..9ef5dc95bc21 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -409,7 +409,7 @@ public int ordToDoc(int ord) { } @Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { return fromOrdToDoc(); } @@ -426,7 +426,7 @@ public float score() throws IOException { } @Override - public DocIterator iterator() { + public DocIndexIterator iterator() { return values.iterator(); } }; diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index 71e1273b78df..0863db920b67 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -453,7 +453,7 @@ public int ordToDoc(int ord) { } @Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { return fromOrdToDoc(); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java index a79f47a85e30..ea66a70bba60 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java +++ 
b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java @@ -107,7 +107,7 @@ public DenseOffHeapVectorValues copy() throws IOException { } @Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { return createDenseIterator(); } @@ -171,8 +171,8 @@ public SparseOffHeapVectorValues copy() throws IOException { } @Override - protected DocIterator createIterator() { - return DocIterator.fromIndexedDISI(disi); + protected DocIndexIterator createIterator() { + return DocIndexIterator.fromIndexedDISI(disi); } @Override diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java index 40c218e518b2..44a9d3dd2c97 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java @@ -122,7 +122,7 @@ public DenseOffHeapVectorValues copy() throws IOException { } @Override - public DocIterator createIterator() { + public DocIndexIterator createIterator() { return createDenseIterator(); } @@ -191,7 +191,7 @@ public int ordToDoc(int ord) { } @Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { return fromDISI(disi); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java index 69776bc4f826..ac68a85e3a14 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java @@ -118,7 +118,7 @@ public 
DenseOffHeapVectorValues copy() throws IOException { } @Override - public DocIterator createIterator() { + public DocIndexIterator createIterator() { return createDenseIterator(); } @@ -183,8 +183,8 @@ public SparseOffHeapVectorValues copy() throws IOException { } @Override - protected DocIterator createIterator() { - return DocIterator.fromIndexedDISI(disi); + protected DocIndexIterator createIterator() { + return DocIndexIterator.fromIndexedDISI(disi); } @Override diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java index c1e45beed059..f60411752d20 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsWriter.java @@ -188,7 +188,7 @@ private static int[] writeVectorData(IndexOutput output, FloatVectorValues vecto int count = 0; ByteBuffer binaryVector = ByteBuffer.allocate(vectors.dimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - KnnVectorValues.DocIterator iter = vectors.iterator(); + KnnVectorValues.DocIndexIterator iter = vectors.iterator(); for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector float[] vectorValue = vectors.vectorValue(iter.index()); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/TestLucene90HnswVectorsFormat.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/TestLucene90HnswVectorsFormat.java index 2c689d5c0e5d..0a8f541cb73b 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/TestLucene90HnswVectorsFormat.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene90/TestLucene90HnswVectorsFormat.java @@ -12,7 +12,7 @@ * 
distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and - * limitations under the License. + * limIndexedDISIitations under the License. */ package org.apache.lucene.backward_codecs.lucene90; diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java index ec43577a0633..a984a3ef1f8b 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsWriter.java @@ -181,7 +181,7 @@ private static DocsWithFieldSet writeVectorData(IndexOutput output, FloatVectorV DocsWithFieldSet docsWithField = new DocsWithFieldSet(); ByteBuffer binaryVector = ByteBuffer.allocate(vectors.dimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - KnnVectorValues.DocIterator iter = vectors.iterator(); + KnnVectorValues.DocIndexIterator iter = vectors.iterator(); for (int docV = iter.nextDoc(); docV != DocIdSetIterator.NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector float[] vectorValue = vectors.vectorValue(iter.index()); diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java index 9e0e9e638809..bf1c89a536d8 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsWriter.java @@ -189,7 +189,7 @@ private static DocsWithFieldSet writeVectorData(IndexOutput output, FloatVectorV 
DocsWithFieldSet docsWithField = new DocsWithFieldSet(); ByteBuffer binaryVector = ByteBuffer.allocate(vectors.dimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - KnnVectorValues.DocIterator iterator = vectors.iterator(); + KnnVectorValues.DocIndexIterator iterator = vectors.iterator(); for (int docV = iterator.nextDoc(); docV != DocIdSetIterator.NO_MORE_DOCS; docV = iterator.nextDoc()) { diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java index 24b99cd30550..01698da79893 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsWriter.java @@ -586,7 +586,7 @@ private void writeMeta( private static DocsWithFieldSet writeByteVectorData( IndexOutput output, ByteVectorValues byteVectorValues) throws IOException { DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - KnnVectorValues.DocIterator iter = byteVectorValues.iterator(); + KnnVectorValues.DocIndexIterator iter = byteVectorValues.iterator(); for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector byte[] binaryValue = byteVectorValues.vectorValue(iter.index()); @@ -603,7 +603,7 @@ private static DocsWithFieldSet writeByteVectorData( private static DocsWithFieldSet writeVectorData( IndexOutput output, FloatVectorValues floatVectorValues) throws IOException { DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - KnnVectorValues.DocIterator iter = floatVectorValues.iterator(); + KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator(); ByteBuffer binaryVector = ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize) .order(ByteOrder.LITTLE_ENDIAN); diff --git 
a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java index 9cdb7f1737a3..c855d8f5e073 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsWriter.java @@ -634,7 +634,7 @@ private void writeMeta( private static DocsWithFieldSet writeByteVectorData( IndexOutput output, ByteVectorValues byteVectorValues) throws IOException { DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - KnnVectorValues.DocIterator iter = byteVectorValues.iterator(); + KnnVectorValues.DocIndexIterator iter = byteVectorValues.iterator(); for (int docId = iter.nextDoc(); docId != NO_MORE_DOCS; docId = iter.nextDoc()) { // write vector byte[] binaryValue = byteVectorValues.vectorValue(iter.index()); @@ -654,7 +654,7 @@ private static DocsWithFieldSet writeVectorData( ByteBuffer buffer = ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize) .order(ByteOrder.LITTLE_ENDIAN); - KnnVectorValues.DocIterator iter = floatVectorValues.iterator(); + KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator(); for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector float[] value = floatVectorValues.vectorValue(iter.index()); diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index 864100177fac..3b860c0ad537 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -337,7 +337,7 @@ public int ordToDoc(int ord) { } 
@Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { return fromOrdToDoc(); } @@ -429,7 +429,7 @@ public byte[] vectorValue(int ord) { } @Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { return fromOrdToDoc(); } diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java index 005fb0b5065c..eaf4b657755c 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsWriter.java @@ -78,7 +78,7 @@ public void writeField(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, throws IOException { long vectorDataOffset = vectorData.getFilePointer(); List docIds = new ArrayList<>(); - KnnVectorValues.DocIterator iter = floatVectorValues.iterator(); + KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator(); for (int docId = iter.nextDoc(); docId != NO_MORE_DOCS; docId = iter.nextDoc()) { writeFloatVectorValue(floatVectorValues, iter.index()); docIds.add(docId); @@ -100,7 +100,7 @@ public void writeField(FieldInfo fieldInfo, ByteVectorValues byteVectorValues, i throws IOException { long vectorDataOffset = vectorData.getFilePointer(); List docIds = new ArrayList<>(); - KnnVectorValues.DocIterator it = byteVectorValues.iterator(); + KnnVectorValues.DocIndexIterator it = byteVectorValues.iterator(); for (int docV = it.nextDoc(); docV != NO_MORE_DOCS; docV = it.nextDoc()) { writeByteVectorValue(byteVectorValues, it.index()); docIds.add(docV); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java index 480b8f0cbecd..0e371a88e790 100644 --- 
a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java @@ -113,7 +113,7 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { */ private static class SortingFloatVectorValues extends FloatVectorValues { private final BufferedFloatVectorValues delegate; - private final DocIterator iterator; + private final DocIndexIterator iterator; SortingFloatVectorValues( BufferedFloatVectorValues delegate, DocsWithFieldSet docsWithField, Sorter.DocMap sortMap) @@ -143,7 +143,7 @@ public SortingFloatVectorValues copy() { } @Override - public DocIterator iterator() { + public DocIndexIterator iterator() { return iterator; } } @@ -154,7 +154,7 @@ public DocIterator iterator() { */ private static class SortingByteVectorValues extends ByteVectorValues { private final BufferedByteVectorValues delegate; - private final DocIterator iterator; + private final DocIndexIterator iterator; SortingByteVectorValues( BufferedByteVectorValues delegate, DocsWithFieldSet docsWithField, Sorter.DocMap sortMap) @@ -184,7 +184,7 @@ public SortingByteVectorValues copy() { } @Override - public DocIterator iterator() { + public DocIndexIterator iterator() { return iterator; } } @@ -273,7 +273,7 @@ private static class BufferedFloatVectorValues extends FloatVectorValues { final List vectors; final int dimension; private final DocIdSet docsWithField; - private final DocIterator iterator; + private final DocIndexIterator iterator; BufferedFloatVectorValues(List vectors, int dimension, DocIdSet docsWithField) throws IOException { @@ -304,7 +304,7 @@ public float[] vectorValue(int targetOrd) { } @Override - public DocIterator iterator() { + public DocIndexIterator iterator() { return iterator; } @@ -319,7 +319,7 @@ private static class BufferedByteVectorValues extends ByteVectorValues { final List vectors; final int dimension; private final DocIdSet docsWithField; - private final 
DocIterator iterator; + private final DocIndexIterator iterator; BufferedByteVectorValues(List vectors, int dimension, DocIdSet docsWithField) throws IOException { @@ -345,7 +345,7 @@ public byte[] vectorValue(int targetOrd) { } @Override - public DocIterator iterator() { + public DocIndexIterator iterator() { return iterator; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java index 2a551fd77050..2838e1bdf317 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java @@ -61,7 +61,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE (KnnFieldVectorsWriter) addField(fieldInfo); ByteVectorValues mergedBytes = MergedVectorValues.mergeByteVectorValues(fieldInfo, mergeState); - KnnVectorValues.DocIterator iter = mergedBytes.iterator(); + KnnVectorValues.DocIndexIterator iter = mergedBytes.iterator(); for (int doc = iter.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iter.nextDoc()) { byteWriter.addValue(doc, mergedBytes.vectorValue(iter.index())); } @@ -71,7 +71,7 @@ public void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOE (KnnFieldVectorsWriter) addField(fieldInfo); FloatVectorValues mergedFloats = MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState); - KnnVectorValues.DocIterator iter = mergedFloats.iterator(); + KnnVectorValues.DocIndexIterator iter = mergedFloats.iterator(); for (int doc = iter.nextDoc(); doc != DocIdSetIterator.NO_MORE_DOCS; doc = iter.nextDoc()) { floatWriter.addValue(doc, mergedFloats.vectorValue(iter.index())); } @@ -302,8 +302,8 @@ private MergedFloat32VectorValues(List subs, MergeState me } @Override - protected DocIterator createIterator() { - return new DocIterator() { + protected DocIndexIterator createIterator() { + return new DocIndexIterator() { private int index = -1; 
@Override @@ -396,8 +396,8 @@ public byte[] vectorValue(int ord) throws IOException { } @Override - protected DocIterator createIterator() { - return new DocIterator() { + protected DocIndexIterator createIterator() { + return new DocIndexIterator() { private int index = -1; @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index 244ffc68b770..5e669e2fcfe6 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -147,7 +147,7 @@ public DenseOffHeapVectorValues copy() throws IOException { } @Override - public DocIterator createIterator() { + public DocIndexIterator createIterator() { return createDenseIterator(); } @@ -232,8 +232,8 @@ public int ordToDoc(int ord) { } @Override - protected DocIterator createIterator() { - return DocIterator.fromIndexedDISI(disi); + protected DocIndexIterator createIterator() { + return DocIndexIterator.fromIndexedDISI(disi); } @Override @@ -298,7 +298,7 @@ public byte[] vectorValue(int ord) throws IOException { } @Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { return createDenseIterator(); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index 0e216d22a3f9..cb891fb0a68b 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -153,7 +153,7 @@ public Bits getAcceptOrds(Bits acceptDocs) { } @Override - public DocIterator createIterator() { + public DocIndexIterator createIterator() { return createDenseIterator(); } @@ -245,8 +245,8 @@ public int 
length() { } @Override - protected DocIterator createIterator() { - return DocIterator.fromIndexedDISI(disi); + protected DocIndexIterator createIterator() { + return DocIndexIterator.fromIndexedDISI(disi); } @Override @@ -298,7 +298,7 @@ public float[] vectorValue(int targetOrd) { } @Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { return createDenseIterator(); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java index b331fca151ef..b731e758b7a8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99FlatVectorsWriter.java @@ -362,7 +362,7 @@ private void writeMeta( private static DocsWithFieldSet writeByteVectorData( IndexOutput output, ByteVectorValues byteVectorValues) throws IOException { DocsWithFieldSet docsWithField = new DocsWithFieldSet(); - KnnVectorValues.DocIterator iter = byteVectorValues.iterator(); + KnnVectorValues.DocIndexIterator iter = byteVectorValues.iterator(); for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector byte[] binaryValue = byteVectorValues.vectorValue(iter.index()); @@ -382,7 +382,7 @@ private static DocsWithFieldSet writeVectorData( ByteBuffer buffer = ByteBuffer.allocate(floatVectorValues.dimension() * VectorEncoding.FLOAT32.byteSize) .order(ByteOrder.LITTLE_ENDIAN); - KnnVectorValues.DocIterator iter = floatVectorValues.iterator(); + KnnVectorValues.DocIndexIterator iter = floatVectorValues.iterator(); for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector float[] value = floatVectorValues.vectorValue(iter.index()); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java 
b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java index 90befb38ad38..08267fc68c01 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java @@ -435,7 +435,7 @@ public VectorScorer scorer(float[] query) throws IOException { } @Override - public DocIterator iterator() { + public DocIndexIterator iterator() { return rawVectorValues.iterator(); } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index b756047f4a5d..833eaccb1629 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -730,7 +730,7 @@ public static DocsWithFieldSet writeQuantizedVectorData( ? 
OffHeapQuantizedByteVectorValues.compressedArray( quantizedByteVectorValues.dimension(), bits) : null; - KnnVectorValues.DocIterator iter = quantizedByteVectorValues.iterator(); + KnnVectorValues.DocIndexIterator iter = quantizedByteVectorValues.iterator(); for (int docV = iter.nextDoc(); docV != NO_MORE_DOCS; docV = iter.nextDoc()) { // write vector byte[] binaryValue = quantizedByteVectorValues.vectorValue(iter.index()); @@ -880,7 +880,7 @@ public float[] vectorValue(int ord) throws IOException { } @Override - public DocIterator createIterator() { + public DocIndexIterator createIterator() { return createDenseIterator(); } } @@ -972,7 +972,7 @@ public byte[] vectorValue(int ord) throws IOException { } @Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { return new CompositeIterator(); } @@ -991,7 +991,7 @@ public float getScoreCorrectionConstant(int ord) throws IOException { return current.values.getScoreCorrectionConstant(current.values.iterator().index()); } - private class CompositeIterator extends KnnVectorValues.DocIterator { + private class CompositeIterator extends DocIndexIterator { private int docId; private int ord; @@ -1090,7 +1090,7 @@ public int ordToDoc(int ord) { } @Override - public DocIterator iterator() { + public DocIndexIterator iterator() { return values.iterator(); } } @@ -1174,7 +1174,7 @@ public int ordToDoc(int ord) { } @Override - public DocIterator iterator() { + public DocIndexIterator iterator() { return in.iterator(); } } @@ -1212,7 +1212,7 @@ public float[] vectorValue(int ord) throws IOException { } @Override - public DocIterator iterator() { + public DocIndexIterator iterator() { return values.iterator(); } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java index 8a0c3d188a0c..72be31949f25 100644 --- 
a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java @@ -255,7 +255,7 @@ public DocIdSetIterator iterator() { } @Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { return createDenseIterator(); } } @@ -286,8 +286,8 @@ public SparseOffHeapVectorValues( } @Override - public DocIterator createIterator() { - return DocIterator.fromIndexedDISI(disi); + public DocIndexIterator createIterator() { + return DocIndexIterator.fromIndexedDISI(disi); } @Override @@ -373,7 +373,7 @@ public int size() { } @Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { return createDenseIterator(); } diff --git a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java index 94aeb1190e37..f63c26d204bb 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java @@ -111,7 +111,7 @@ public ByteVectorValues copy() { } @Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { return createDenseIterator(); } }; diff --git a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java index 961681bccae6..aa49c2443ea1 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java @@ -456,7 +456,7 @@ public int size() { } @Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { return createExitableIterator(vectorValues.iterator(), queryTimeout); } @@ -499,7 +499,7 @@ public int ordToDoc(int ord) { } @Override - protected 
DocIterator createIterator() { + protected DocIndexIterator createIterator() { return createExitableIterator(vectorValues.iterator(), queryTimeout); } @@ -515,9 +515,9 @@ public ByteVectorValues copy() { } } - private static KnnVectorValues.DocIterator createExitableIterator( - KnnVectorValues.DocIterator delegate, QueryTimeout queryTimeout) { - return new KnnVectorValues.DocIterator() { + private static KnnVectorValues.DocIndexIterator createExitableIterator( + KnnVectorValues.DocIndexIterator delegate, QueryTimeout queryTimeout) { + return new KnnVectorValues.DocIndexIterator() { private int nextCheck; @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index fb97e5861b84..13c8c047d8e7 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -112,7 +112,7 @@ public FloatVectorValues copy() { } @Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { return createDenseIterator(); } }; diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java index 44b1462aca83..09423b785657 100644 --- a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java @@ -32,7 +32,7 @@ public abstract class KnnVectorValues { /** The iterator associated with these values. */ - protected DocIterator iterator; + protected DocIndexIterator iterator; /** Return the dimension of the vectors */ public abstract int dimension(); @@ -91,7 +91,7 @@ public int length() { * Return the iterator for this instance. If you need multiple iterators, call * this.copy().iterator(). 
*/ - public DocIterator iterator() { + public DocIndexIterator iterator() { if (iterator == null) { iterator = createIterator(); } @@ -102,7 +102,7 @@ public DocIterator iterator() { * Create an iterator for this instance; typically called once by iterator(). Wrapper * value classes delegate to their inner instance's iterator and shouldn't implement this. */ - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { throw new UnsupportedOperationException(); } @@ -110,7 +110,7 @@ protected DocIterator createIterator() { * A DocIdSetIterator that also provides an index() method tracking a distinct ordinal for a * vector associated with each doc. */ - public abstract static class DocIterator extends DocIdSetIterator { + public abstract static class DocIndexIterator extends DocIdSetIterator { /** return the value index (aka "ordinal" or "ord") corresponding to the current doc */ public abstract int index(); @@ -129,9 +129,9 @@ public long cost() { * Returns an iterator that delegates to the IndexedDISI. Advancing this iterator will advance * the underlying IndexedDISI, and vice-versa. */ - public static DocIterator fromIndexedDISI(IndexedDISI disi) { + public static DocIndexIterator fromIndexedDISI(IndexedDISI disi) { // can we replace with fromDISI? - return new DocIterator() { + return new DocIndexIterator() { @Override public int docID() { return disi.docID(); @@ -164,8 +164,8 @@ public long cost() { * Creates an iterator for instances where every doc has a value, and the value ordinals are equal * to the docids. */ - protected DocIterator createDenseIterator() { - return new DocIterator() { + protected DocIndexIterator createDenseIterator() { + return new DocIndexIterator() { int doc = -1; @@ -207,8 +207,8 @@ public long cost() { * Creates an iterator from a DocIdSetIterator indicating which docs have values, and for which * ordinals increase monotonically with docid. 
*/ - protected static DocIterator fromDISI(DocIdSetIterator docsWithField) { - return new DocIterator() { + protected static DocIndexIterator fromDISI(DocIdSetIterator docsWithField) { + return new DocIndexIterator() { int ord = -1; @@ -242,8 +242,8 @@ public long cost() { * Creates an iterator from this instance's ordinal-to-docid mapping which must be monotonic * (docid increases when ordinal does). */ - protected DocIterator fromOrdToDoc() { - return new DocIterator() { + protected DocIndexIterator fromOrdToDoc() { + return new DocIndexIterator() { private int ord = -1; @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java index 2779a97dfd4a..fcfd9da83fba 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java @@ -294,7 +294,7 @@ public void checkIntegrity() throws IOException { private record DocValuesSub(T sub, int docStart) {} private static class MergedDocIterator - extends KnnVectorValues.DocIterator { + extends KnnVectorValues.DocIndexIterator { final Iterator> it; final long cost; @@ -898,7 +898,7 @@ public byte[] vectorValue(int ord) throws IOException { } @Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { return new MergedDocIterator(subs); } }; diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index a5eeb121305d..c1699388de1a 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -211,7 +211,7 @@ public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue * Iterator over KnnVectorValues accepting a mapping to 
differently-sorted docs. Consequently * index() may skip around, not increasing monotonically as iteration proceeds. */ - public static class SortingValuesIterator extends KnnVectorValues.DocIterator { + public static class SortingValuesIterator extends KnnVectorValues.DocIndexIterator { private final DocIdSetIterator docsWithValues; private final int[] docToOrd; private final int size; @@ -219,7 +219,7 @@ public static class SortingValuesIterator extends KnnVectorValues.DocIterator { int doc = -1; /** Creates an iterator accepting a mapping to differently-sorted docs. */ - public SortingValuesIterator(KnnVectorValues.DocIterator iter, Sorter.DocMap docMap) + public SortingValuesIterator(KnnVectorValues.DocIndexIterator iter, Sorter.DocMap docMap) throws IOException { docToOrd = new int[docMap.size()]; FixedBitSet docBits = new FixedBitSet(docMap.size()); @@ -290,7 +290,7 @@ public FloatVectorValues copy() { } @Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { throw new IllegalStateException(); } } @@ -309,7 +309,7 @@ public byte[] vectorValue(int ord) throws IOException { } @Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { throw new IllegalStateException(); } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java index 9188805d020a..d64961a02ee4 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/IncrementalHnswGraphMerger.java @@ -154,7 +154,7 @@ public OnHeapHnswGraph merge( */ protected final int[] getNewOrdMapping( KnnVectorValues mergedVectorValues, BitSet initializedNodes) throws IOException { - KnnVectorValues.DocIterator initializerIterator = null; + KnnVectorValues.DocIndexIterator initializerIterator = null; switch 
(fieldInfo.getVectorEncoding()) { case BYTE -> initializerIterator = initReader.getByteVectorValues(fieldInfo.name).iterator(); @@ -176,7 +176,7 @@ protected final int[] getNewOrdMapping( return new int[0]; } final int[] oldToNewOrdinalMap = new int[initGraphSize]; - KnnVectorValues.DocIterator mergedVectorIterator = mergedVectorValues.iterator(); + KnnVectorValues.DocIndexIterator mergedVectorIterator = mergedVectorValues.iterator(); for (int newDocId = mergedVectorIterator.nextDoc(); newDocId <= maxNewDocID; newDocId = mergedVectorIterator.nextDoc()) { diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java index 5b1e912dc64f..e02d6f49d8ac 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java @@ -173,7 +173,7 @@ public void testQuantizedVectorsWriteAndRead() throws Exception { QuantizedByteVectorValues quantizedByteVectorValues = quantizedReader.getQuantizedVectorValues("f"); int docId = -1; - KnnVectorValues.DocIterator iter = quantizedByteVectorValues.iterator(); + KnnVectorValues.DocIndexIterator iter = quantizedByteVectorValues.iterator(); for (docId = iter.nextDoc(); docId != NO_MORE_DOCS; docId = iter.nextDoc()) { byte[] vector = quantizedByteVectorValues.vectorValue(iter.index()); float offset = quantizedByteVectorValues.getScoreCorrectionConstant(iter.index()); diff --git a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java index f60cd28c0ee1..d03c8cf42b59 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java 
@@ -565,7 +565,7 @@ public void testByteVectorValues() throws IOException { } private static void scanAndRetrieve(LeafReader leaf, KnnVectorValues values) throws IOException { - KnnVectorValues.DocIterator iter = values.iterator(); + KnnVectorValues.DocIndexIterator iter = values.iterator(); for (iter.nextDoc(); iter.docID() != DocIdSetIterator.NO_MORE_DOCS && iter.docID() < leaf.maxDoc(); ) { int docId = iter.docID(); diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java index f9c362774b49..e67c0b8f4980 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java @@ -90,7 +90,7 @@ public byte[] vectorValue(int ord) { } @Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { return createDenseIterator(); } } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java index b29445d258df..45d6006a10d3 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java @@ -84,7 +84,7 @@ public float[] vectorValue(int ord) { } @Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { return createDenseIterator(); } } diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java index 4f09decd5eb2..17e61b8a22ab 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java @@ -316,8 +316,8 @@ public int ordToDoc(int ord) { } @Override - public 
DocIterator createIterator() { - return new DocIterator() { + public DocIndexIterator createIterator() { + return new DocIndexIterator() { int ord = -1; int doc = -1; diff --git a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java index bd652f2e5f6e..119c8b099dfc 100644 --- a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java +++ b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java @@ -2310,7 +2310,7 @@ public float[] vectorValue(int ord) { } @Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { return createDenseIterator(); } @@ -2368,7 +2368,7 @@ public byte[] vectorValue(int ord) { } @Override - protected DocIterator createIterator() { + protected DocIndexIterator createIterator() { return createDenseIterator(); } diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index d66e0f88c47e..645bc0261588 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -1785,7 +1785,7 @@ public void testVectorValuesReportCorrectDocs() throws Exception { if (byteVectorValues != null) { docCount += byteVectorValues.size(); StoredFields storedFields = ctx.reader().storedFields(); - KnnVectorValues.DocIterator iter = byteVectorValues.iterator(); + KnnVectorValues.DocIndexIterator iter = byteVectorValues.iterator(); for (iter.nextDoc(); iter.docID() != NO_MORE_DOCS; iter.nextDoc()) { int ord = iter.index(); checksum += byteVectorValues.vectorValue(ord)[0]; @@ -1806,7 +1806,7 @@ public void testVectorValuesReportCorrectDocs() throws Exception { if (vectorValues != null) { docCount += vectorValues.size(); 
StoredFields storedFields = ctx.reader().storedFields(); - KnnVectorValues.DocIterator iter = vectorValues.iterator(); + KnnVectorValues.DocIndexIterator iter = vectorValues.iterator(); for (iter.nextDoc(); iter.docID() != NO_MORE_DOCS; iter.nextDoc()) { int ord = iter.index(); checksum += vectorValues.vectorValue(ord)[0]; From c2ae86b640f929ac76664c1f7fe7f9a1f6989ae9 Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Tue, 10 Sep 2024 14:53:06 -0400 Subject: [PATCH 11/25] clean up some comments --- .../lucene/codecs/BufferingKnnVectorsWriter.java | 12 +++--------- .../org/apache/lucene/codecs/KnnVectorsWriter.java | 4 ---- 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java index 0e371a88e790..e9b777fb24e8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java @@ -107,11 +107,8 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { } } - /** - * Sorting FloatVectorValues that maps ordinals using the provided sortMap expressed in terms of - * docids - */ - private static class SortingFloatVectorValues extends FloatVectorValues { + /** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */ + private static class SortingFloatVectorValues extends FloatVectorValues { private final BufferedFloatVectorValues delegate; private final DocIndexIterator iterator; @@ -148,10 +145,7 @@ public DocIndexIterator iterator() { } } - /** - * Sorting ByteVectorValues that maps ordinals using the provided sortMap expressed in terms of - * docids - */ + /** Sorting ByteVectorValues that iterate over documents in the order of the provided sortMap */ private static class SortingByteVectorValues extends ByteVectorValues { private final 
BufferedByteVectorValues delegate; private final DocIndexIterator iterator; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java index 2838e1bdf317..dc4544e11637 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java @@ -336,9 +336,6 @@ public long cost() { }; } - // This really only needs to support iterative access as a first-class citizen (it is never - // used for searching) so simply ignore the supplied ordinal ... we'd like to assert that - // docId == ordToDoc(ord) but we didn't implement ordToDoc @Override public float[] vectorValue(int ord) throws IOException { return current.values.vectorValue(current.values.iterator().index()); @@ -356,7 +353,6 @@ public int dimension() { @Override public int ordToDoc(int ord) { - // FIXME support ordToDoc for backward-codecs? throw new UnsupportedOperationException(); } From ff7a317a81df18ac944f07d66ef770cfd4ab3b06 Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Thu, 12 Sep 2024 12:56:46 -0400 Subject: [PATCH 12/25] fix case where index is reordered --- .../codecs/BufferingKnnVectorsWriter.java | 2 +- .../lucene/codecs/KnnVectorsWriter.java | 1 + .../SlowCompositeCodecReaderWrapper.java | 164 ++++++++++++------ .../lucene/index/SortingCodecReader.java | 15 +- 4 files changed, 128 insertions(+), 54 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java index e9b777fb24e8..95cb1bdf533d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java @@ -108,7 +108,7 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { } /** Sorting FloatVectorValues that iterate over 
documents in the order of the provided sortMap */ - private static class SortingFloatVectorValues extends FloatVectorValues { + private static class SortingFloatVectorValues extends FloatVectorValues { private final BufferedFloatVectorValues delegate; private final DocIndexIterator iterator; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java index dc4544e11637..f0d6639388e3 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java @@ -338,6 +338,7 @@ public long cost() { @Override public float[] vectorValue(int ord) throws IOException { + assert ord == iterator.index(); return current.values.vectorValue(current.values.iterator().index()); } diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java index fcfd9da83fba..9256f9876022 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java @@ -291,7 +291,7 @@ public void checkIntegrity() throws IOException { } } - private record DocValuesSub(T sub, int docStart) {} + private record DocValuesSub(T sub, int docStart, int ordStart) {} private static class MergedDocIterator extends KnnVectorValues.DocIndexIterator { @@ -301,7 +301,6 @@ private static class MergedDocIterator DocValuesSub current; int ord = -1; int doc = -1; - int currentOrdBase = 0; MergedDocIterator(List> subs) { long cost = 0; @@ -336,10 +335,11 @@ public int nextDoc() throws IOException { } } if (it.hasNext() == false) { + ord = NO_MORE_DOCS; return doc = NO_MORE_DOCS; } current = it.next(); - currentOrdBase = ord; + ord = current.ordStart - 1; } } @@ -817,41 +817,65 @@ public FloatVectorValues 
getFloatVectorValues(String field) throws IOException { int size = 0; for (CodecReader reader : codecReaders) { FloatVectorValues values = reader.getFloatVectorValues(field); + subs.add(new DocValuesSub<>(values, docStarts[i], size)); if (values != null) { if (dimension == -1) { dimension = values.dimension(); } size += values.size(); } - subs.add(new DocValuesSub<>(values, docStarts[i])); i++; } - final int finalDimension = dimension; - final int finalSize = size; - return new FloatVectorValues() { - - final MergedDocIterator iter = new MergedDocIterator<>(subs); + return new MergedFloatVectorValues(dimension, size, subs); + } - @Override - public MergedDocIterator iterator() { - return iter; + class MergedFloatVectorValues extends FloatVectorValues { + final int dimension; + final int size; + final DocValuesSub[] subs; + final MergedDocIterator iter; + final int[] starts; + int lastSubIndex; + + MergedFloatVectorValues(int dimension, int size, List> subs) { + this.dimension = dimension; + this.size = size; + this.subs = subs.toArray(new DocValuesSub[0]); + iter = new MergedDocIterator<>(subs); + // [0, start(1), ..., size] - we want the extra element + // to avoid checking for out-of-array bounds + starts = new int[subs.size() + 1]; + for (int i = 0; i < subs.size(); i++) { + starts[i] = subs.get(i).ordStart; } + starts[starts.length - 1] = size; + } - @Override - public int dimension() { - return finalDimension; - } + @Override + public MergedDocIterator iterator() { + return iter; + } - @Override - public int size() { - return finalSize; - } + @Override + public int dimension() { + return dimension; + } - @Override - public float[] vectorValue(int ord) throws IOException { - return iter.current.sub.vectorValue(ord - iter.currentOrdBase); - } - }; + @Override + public int size() { + return size; + } + + @Override + public float[] vectorValue(int ord) throws IOException { + assert ord >= 0 && ord < size; + // We need to implement fully random-access API here in 
order to support callers like + // SortingCodecReader that + // rely on it. + lastSubIndex = findSub(ord, lastSubIndex, starts); + return ((FloatVectorValues) subs[lastSubIndex].sub) + .vectorValue(ord - subs[lastSubIndex].ordStart); + } } @Override @@ -862,46 +886,86 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { int size = 0; for (CodecReader reader : codecReaders) { ByteVectorValues values = reader.getByteVectorValues(field); + subs.add(new DocValuesSub<>(values, docStarts[i], size)); if (values != null) { if (dimension == -1) { dimension = values.dimension(); } size += values.size(); } - subs.add(new DocValuesSub<>(values, docStarts[i])); i++; } - final int finalDimension = dimension; - final int finalSize = size; - return new ByteVectorValues() { - - final MergedDocIterator iter = new MergedDocIterator<>(subs); + return new MergedByteVectorValues(dimension, size, subs); + } - @Override - public MergedDocIterator iterator() { - return iter; + class MergedByteVectorValues extends ByteVectorValues { + final int dimension; + final int size; + final DocValuesSub[] subs; + final MergedDocIterator iter; + final int[] starts; + int lastSubIndex; + + MergedByteVectorValues(int dimension, int size, List> subs) { + this.dimension = dimension; + this.size = size; + this.subs = subs.toArray(new DocValuesSub[0]); + iter = new MergedDocIterator<>(subs); + // [0, start(1), ..., size] - we want the extra element + // to avoid checking for out-of-array bounds + starts = new int[subs.size() + 1]; + for (int i = 0; i < subs.size(); i++) { + starts[i] = subs.get(i).ordStart; } + starts[starts.length - 1] = size; + } - @Override - public int dimension() { - return finalDimension; - } + @Override + public MergedDocIterator iterator() { + return iter; + } - @Override - public int size() { - return finalSize; - } + @Override + public int dimension() { + return dimension; + } - @Override - public byte[] vectorValue(int ord) throws IOException { - 
return iter.current.sub.vectorValue(ord - iter.currentOrdBase); - } + @Override + public int size() { + return size; + } - @Override - protected DocIndexIterator createIterator() { - return new MergedDocIterator(subs); + @Override + public byte[] vectorValue(int ord) throws IOException { + assert ord >= 0 && ord < size; + // We need to implement fully random-access API here in order to support callers like + // SortingCodecReader that rely on it. We maintain lastSubIndex since we expect some + // repetition. + lastSubIndex = findSub(ord, lastSubIndex, starts); + return ((ByteVectorValues) subs[lastSubIndex].sub) + .vectorValue(ord - subs[lastSubIndex].ordStart); + } + } + + private static int findSub(int ord, int lastSubIndex, int[] starts) { + if (ord >= starts[lastSubIndex]) { + if (ord >= starts[lastSubIndex + 1]) { + return binarySearchStarts(starts, ord, lastSubIndex + 1, starts.length); } - }; + } else { + return binarySearchStarts(starts, ord, 0, lastSubIndex); + } + return lastSubIndex; + } + + private static int binarySearchStarts(int[] starts, int ord, int from, int to) { + int pos = Arrays.binarySearch(starts, from, to, ord); + // also subtract one since starts[] is shifted by one + if (pos < 0) { + return -2 - pos; + } else { + return pos - 1; + } } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index c1699388de1a..31b79f6b0235 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -212,6 +212,7 @@ public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue * index() may skip around, not increasing monotonically as iteration proceeds. 
*/ public static class SortingValuesIterator extends KnnVectorValues.DocIndexIterator { + private final FixedBitSet docBits; private final DocIdSetIterator docsWithValues; private final int[] docToOrd; private final int size; @@ -222,8 +223,10 @@ public static class SortingValuesIterator extends KnnVectorValues.DocIndexIterat public SortingValuesIterator(KnnVectorValues.DocIndexIterator iter, Sorter.DocMap docMap) throws IOException { docToOrd = new int[docMap.size()]; - FixedBitSet docBits = new FixedBitSet(docMap.size()); + docBits = new FixedBitSet(docMap.size()); int count = 0; + // Note: docToOrd will contain zero for docids that have no vector. This is OK though + // because the iterator cannot be positioned on such docs for (int doc = iter.nextDoc(); doc != NO_MORE_DOCS; doc = iter.nextDoc()) { int newDocId = docMap.oldToNew(doc); if (newDocId != -1) { @@ -243,6 +246,7 @@ public int docID() { @Override public int index() { + assert docBits.get(doc); return docToOrd[doc]; } @@ -266,12 +270,15 @@ private static class SortingFloatVectorValues extends FloatVectorValues { SortingFloatVectorValues(FloatVectorValues delegate, Sorter.DocMap sortMap) throws IOException { this.delegate = delegate; + // SortingValuesIterator consumes the iterator and records the docs and ord mapping iterator = new SortingValuesIterator(delegate.iterator(), sortMap); } @Override public float[] vectorValue(int ord) throws IOException { - return delegate.vectorValue(iterator.index()); + // ords are interpreted in the delegate's ord-space. 
+ assert ord == iterator.index(); + return delegate.vectorValue(ord); } @Override @@ -300,12 +307,14 @@ private static class SortingByteVectorValues extends ByteVectorValues { SortingByteVectorValues(ByteVectorValues delegate, Sorter.DocMap sortMap) throws IOException { this.delegate = delegate; + // SortingValuesIterator consumes the iterator and records the docs and ord mapping iterator = new SortingValuesIterator(delegate.iterator(), sortMap); } @Override public byte[] vectorValue(int ord) throws IOException { - return delegate.vectorValue(iterator().index()); + assert ord == iterator.index(); + return delegate.vectorValue(ord); } @Override From 9e5b9f9ba2fa634854be97e125d52a1800019c8a Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Sun, 15 Sep 2024 15:52:24 -0400 Subject: [PATCH 13/25] rename 'fromOrdToDoc' to 'all'; move fromIndexedDISI to codecs/lucene90 --- .../lucene90/Lucene90HnswVectorsReader.java | 2 +- .../lucene91/Lucene91HnswVectorsReader.java | 2 +- .../lucene92/OffHeapFloatVectorValues.java | 2 +- .../lucene94/OffHeapFloatVectorValues.java | 2 +- .../SimpleTextKnnVectorsReader.java | 4 +- .../lucene/codecs/lucene90/IndexedDISI.java | 36 +++++++++++++++++ .../lucene95/OffHeapByteVectorValues.java | 2 +- .../lucene95/OffHeapFloatVectorValues.java | 2 +- .../OffHeapQuantizedByteVectorValues.java | 2 +- .../apache/lucene/index/KnnVectorValues.java | 39 ++----------------- 10 files changed, 49 insertions(+), 44 deletions(-) diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index 9ef5dc95bc21..371e7ada08ed 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -410,7 +410,7 @@ public int ordToDoc(int 
ord) { @Override protected DocIndexIterator createIterator() { - return fromOrdToDoc(); + return all(); } @Override diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index 0863db920b67..009435af9752 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -454,7 +454,7 @@ public int ordToDoc(int ord) { @Override protected DocIndexIterator createIterator() { - return fromOrdToDoc(); + return all(); } @Override diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java index ea66a70bba60..5b7b19db6015 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java @@ -172,7 +172,7 @@ public SparseOffHeapVectorValues copy() throws IOException { @Override protected DocIndexIterator createIterator() { - return DocIndexIterator.fromIndexedDISI(disi); + return IndexedDISI.asDocIndexIterator(disi); } @Override diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java index ac68a85e3a14..07b7beca03d4 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java @@ -184,7 
+184,7 @@ public SparseOffHeapVectorValues copy() throws IOException { @Override protected DocIndexIterator createIterator() { - return DocIndexIterator.fromIndexedDISI(disi); + return IndexedDISI.asDocIndexIterator(disi); } @Override diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index 3b860c0ad537..7e87a4435368 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -338,7 +338,7 @@ public int ordToDoc(int ord) { @Override protected DocIndexIterator createIterator() { - return fromOrdToDoc(); + return all(); } @Override @@ -430,7 +430,7 @@ public byte[] vectorValue(int ord) { @Override protected DocIndexIterator createIterator() { - return fromOrdToDoc(); + return all(); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java index a2b2c84e12ae..f29f090849cd 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java @@ -18,6 +18,8 @@ import java.io.DataInput; import java.io.IOException; + +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; @@ -439,6 +441,40 @@ public static RandomAccessInput createJumpTable( // ALL variables int gap; + /** + * Returns an iterator that delegates to the IndexedDISI. Advancing this iterator will advance + * the underlying IndexedDISI, and vice-versa. + */ + public static KnnVectorValues.DocIndexIterator asDocIndexIterator(IndexedDISI disi) { + // can we replace with fromDISI? 
+ return new KnnVectorValues.DocIndexIterator() { + @Override + public int docID() { + return disi.docID(); + } + + @Override + public int index() { + return disi.index(); + } + + @Override + public int nextDoc() throws IOException { + return disi.nextDoc(); + } + + @Override + public int advance(int target) throws IOException { + return disi.advance(target); + } + + @Override + public long cost() { + return disi.cost(); + } + }; + } + @Override public int docID() { return doc; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index 5e669e2fcfe6..2b8f7a282f05 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -233,7 +233,7 @@ public int ordToDoc(int ord) { @Override protected DocIndexIterator createIterator() { - return DocIndexIterator.fromIndexedDISI(disi); + return IndexedDISI.asDocIndexIterator(disi); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index cb891fb0a68b..0872026d8e32 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -246,7 +246,7 @@ public int length() { @Override protected DocIndexIterator createIterator() { - return DocIndexIterator.fromIndexedDISI(disi); + return IndexedDISI.asDocIndexIterator(disi); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java index 72be31949f25..6d1fea07b4f9 100644 --- 
a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java @@ -287,7 +287,7 @@ public SparseOffHeapVectorValues( @Override public DocIndexIterator createIterator() { - return DocIndexIterator.fromIndexedDISI(disi); + return IndexedDISI.asDocIndexIterator(disi); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java index 09423b785657..125f57432574 100644 --- a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java @@ -17,7 +17,7 @@ package org.apache.lucene.index; import java.io.IOException; -import org.apache.lucene.codecs.lucene90.IndexedDISI; + import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.search.DocIdSetIterator; @@ -57,6 +57,8 @@ public int ordToDoc(int ord) { * Creates a new copy of this {@link KnnVectorValues}. This is helpful when you need to access * different values at once, to avoid overwriting the underlying vector returned. */ + // FIXME: replace with some kind of wrapper??? Or can we eliminate completely given + // that we no longer have seek() API? public abstract KnnVectorValues copy() throws IOException; /** Returns the vector byte length, defaults to dimension multiplied by float byte size */ @@ -125,39 +127,6 @@ public long cost() { throw new UnsupportedOperationException(); } - /** - * Returns an iterator that delegates to the IndexedDISI. Advancing this iterator will advance - * the underlying IndexedDISI, and vice-versa. - */ - public static DocIndexIterator fromIndexedDISI(IndexedDISI disi) { - // can we replace with fromDISI? 
- return new DocIndexIterator() { - @Override - public int docID() { - return disi.docID(); - } - - @Override - public int index() { - return disi.index(); - } - - @Override - public int nextDoc() throws IOException { - return disi.nextDoc(); - } - - @Override - public int advance(int target) throws IOException { - return disi.advance(target); - } - - @Override - public long cost() { - return disi.cost(); - } - }; - } } /** @@ -242,7 +211,7 @@ public long cost() { * Creates an iterator from this instance's ordinal-to-docid mapping which must be monotonic * (docid increases when ordinal does). */ - protected DocIndexIterator fromOrdToDoc() { + protected DocIndexIterator all() { return new DocIndexIterator() { private int ord = -1; From d43785d5a68c08280f0d06edfba7e89feefe28cf Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Sun, 15 Sep 2024 18:30:37 -0400 Subject: [PATCH 14/25] no default advance(); default cost() unsupported --- .../codecs/BufferingKnnVectorsWriter.java | 11 ++- .../lucene/codecs/KnnVectorsWriter.java | 8 +- .../lucene/codecs/lucene90/IndexedDISI.java | 5 +- .../Lucene99ScalarQuantizedVectorsWriter.java | 9 +- .../lucene/index/ExitableDirectoryReader.java | 10 +++ .../apache/lucene/index/KnnVectorValues.java | 17 ++-- .../SlowCompositeCodecReaderWrapper.java | 12 +-- .../lucene/index/SortingCodecReader.java | 84 +++++++++++++------ .../quantization/TestScalarQuantizer.java | 5 ++ 9 files changed, 104 insertions(+), 57 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java index 95cb1bdf533d..d4a782eeb863 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java @@ -20,12 +20,14 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; +import java.util.function.Supplier; import 
org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.DocsWithFieldSet; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.SortingCodecReader; import org.apache.lucene.index.SortingCodecReader.SortingValuesIterator; import org.apache.lucene.search.DocIdSet; import org.apache.lucene.util.ArrayUtil; @@ -110,13 +112,14 @@ public void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { /** Sorting FloatVectorValues that iterate over documents in the order of the provided sortMap */ private static class SortingFloatVectorValues extends FloatVectorValues { private final BufferedFloatVectorValues delegate; - private final DocIndexIterator iterator; + private final Supplier iteratorSupplier; SortingFloatVectorValues( BufferedFloatVectorValues delegate, DocsWithFieldSet docsWithField, Sorter.DocMap sortMap) throws IOException { this.delegate = delegate.copy(); - iterator = new SortingValuesIterator(delegate.copy().iterator(), sortMap); + iteratorSupplier = SortingCodecReader.iteratorSupplier(delegate, sortMap); + iterator = iteratorSupplier.get(); } @Override @@ -148,13 +151,15 @@ public DocIndexIterator iterator() { /** Sorting ByteVectorValues that iterate over documents in the order of the provided sortMap */ private static class SortingByteVectorValues extends ByteVectorValues { private final BufferedByteVectorValues delegate; + private final Supplier iteratorSupplier; private final DocIndexIterator iterator; SortingByteVectorValues( BufferedByteVectorValues delegate, DocsWithFieldSet docsWithField, Sorter.DocMap sortMap) throws IOException { this.delegate = delegate; - iterator = new SortingValuesIterator(delegate.copy().iterator(), sortMap); + iteratorSupplier = SortingCodecReader.iteratorSupplier(delegate, sortMap); + iterator = iteratorSupplier.get(); } @Override diff --git 
a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java index f0d6639388e3..ae2cac5ccb39 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java @@ -330,8 +330,8 @@ public int nextDoc() throws IOException { } @Override - public long cost() { - return size; + public int advance(int target) throws IOException { + throw new UnsupportedOperationException(); } }; } @@ -421,8 +421,8 @@ public int nextDoc() throws IOException { } @Override - public long cost() { - return size; + public int advance(int target) throws IOException { + throw new UnsupportedOperationException(); } }; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java index f29f090849cd..dbd56125fcd1 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java @@ -18,7 +18,6 @@ import java.io.DataInput; import java.io.IOException; - import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.store.IndexInput; @@ -442,8 +441,8 @@ public static RandomAccessInput createJumpTable( int gap; /** - * Returns an iterator that delegates to the IndexedDISI. Advancing this iterator will advance - * the underlying IndexedDISI, and vice-versa. + * Returns an iterator that delegates to the IndexedDISI. Advancing this iterator will advance the + * underlying IndexedDISI, and vice-versa. */ public static KnnVectorValues.DocIndexIterator asDocIndexIterator(IndexedDISI disi) { // can we replace with fromDISI? 
diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index 833eaccb1629..e2578bf5cd68 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -19,9 +19,7 @@ import static org.apache.lucene.codecs.KnnVectorsWriter.MergedVectorValues.hasVectorValues; import static org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat.DIRECT_MONOTONIC_BLOCK_SHIFT; -import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.DYNAMIC_CONFIDENCE_INTERVAL; -import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.QUANTIZED_VECTOR_COMPONENT; -import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.calculateDefaultConfidenceInterval; +import static org.apache.lucene.codecs.lucene99.Lucene99ScalarQuantizedVectorsFormat.*; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.apache.lucene.util.RamUsageEstimator.shallowSizeOfInstance; @@ -1022,6 +1020,11 @@ public int nextDoc() throws IOException { } return docId; } + + @Override + public int advance(int target) throws IOException { + throw new UnsupportedOperationException(); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java index aa49c2443ea1..f458aef87e45 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java @@ -540,6 +540,16 @@ public int nextDoc() throws IOException { return doc; } + @Override + public int advance(int target) throws IOException { + int doc = 
delegate.advance(target); + if (doc >= nextCheck) { + checkAndThrow(); + nextCheck = doc + ExitableFilterAtomicReader.DOCS_BETWEEN_TIMEOUT_CHECK; + } + return doc; + } + private void checkAndThrow() { if (queryTimeout.shouldExit()) { throw new ExitingReaderException( diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java index 125f57432574..ac4cd865f9fc 100644 --- a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java @@ -17,7 +17,6 @@ package org.apache.lucene.index; import java.io.IOException; - import org.apache.lucene.document.KnnByteVectorField; import org.apache.lucene.document.KnnFloatVectorField; import org.apache.lucene.search.DocIdSetIterator; @@ -117,16 +116,10 @@ public abstract static class DocIndexIterator extends DocIdSetIterator { /** return the value index (aka "ordinal" or "ord") corresponding to the current doc */ public abstract int index(); - @Override - public int advance(int target) throws IOException { - return slowAdvance(target); - } - @Override public long cost() { throw new UnsupportedOperationException(); } - } /** @@ -200,6 +193,11 @@ public int nextDoc() throws IOException { return docsWithField.nextDoc(); } + @Override + public int advance(int target) throws IOException { + return docsWithField.advance(target); + } + @Override public long cost() { return docsWithField.cost(); @@ -241,6 +239,11 @@ public int nextDoc() throws IOException { return docID(); } + @Override + public int advance(int target) throws IOException { + return slowAdvance(target); + } + @Override public long cost() { return size(); diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java index 9256f9876022..a0a355d689ae 100644 --- 
a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java @@ -297,19 +297,11 @@ private static class MergedDocIterator extends KnnVectorValues.DocIndexIterator { final Iterator> it; - final long cost; DocValuesSub current; int ord = -1; int doc = -1; MergedDocIterator(List> subs) { - long cost = 0; - for (DocValuesSub sub : subs) { - if (sub.sub != null) { - cost += sub.sub.size(); - } - } - this.cost = cost; this.it = subs.iterator(); current = it.next(); } @@ -344,8 +336,8 @@ public int nextDoc() throws IOException { } @Override - public long cost() { - return cost; + public int advance(int target) throws IOException { + throw new UnsupportedOperationException(); } } diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index 31b79f6b0235..85cd22830b11 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -25,6 +25,7 @@ import java.util.Iterator; import java.util.Map; import java.util.Objects; +import java.util.function.Supplier; import org.apache.lucene.codecs.DocValuesProducer; import org.apache.lucene.codecs.FieldsProducer; import org.apache.lucene.codecs.KnnVectorsReader; @@ -207,6 +208,48 @@ public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue } } + // FIXME rename this better? 
+ static class IteratorSupplier implements Supplier { + private final FixedBitSet docBits; + private final int[] docToOrd; + private final int size; + + IteratorSupplier(FixedBitSet docBits, int[] docToOrd, int size) { + this.docBits = docBits; + this.docToOrd = docToOrd; + this.size = size; + } + + @Override + public SortingValuesIterator get() { + return new SortingValuesIterator(docBits, docToOrd, size); + } + + public int size() { + return size; + } + } + + public static IteratorSupplier iteratorSupplier(KnnVectorValues values, Sorter.DocMap docMap) + throws IOException { + + final int[] docToOrd = new int[docMap.size()]; + final FixedBitSet docBits = new FixedBitSet(docMap.size()); + int count = 0; + // Note: docToOrd will contain zero for docids that have no vector. This is OK though + // because the iterator cannot be positioned on such docs + KnnVectorValues.DocIndexIterator iter = values.iterator(); + for (int doc = iter.nextDoc(); doc != NO_MORE_DOCS; doc = iter.nextDoc()) { + int newDocId = docMap.oldToNew(doc); + if (newDocId != -1) { + docToOrd[newDocId] = iter.index(); + docBits.set(newDocId); + ++count; + } + } + return new IteratorSupplier(docBits, docToOrd, count); + } + /** * Iterator over KnnVectorValues accepting a mapping to differently-sorted docs. Consequently * index() may skip around, not increasing monotonically as iteration proceeds. @@ -215,28 +258,13 @@ public static class SortingValuesIterator extends KnnVectorValues.DocIndexIterat private final FixedBitSet docBits; private final DocIdSetIterator docsWithValues; private final int[] docToOrd; - private final int size; int doc = -1; - /** Creates an iterator accepting a mapping to differently-sorted docs. */ - public SortingValuesIterator(KnnVectorValues.DocIndexIterator iter, Sorter.DocMap docMap) - throws IOException { - docToOrd = new int[docMap.size()]; - docBits = new FixedBitSet(docMap.size()); - int count = 0; - // Note: docToOrd will contain zero for docids that have no vector. 
This is OK though - // because the iterator cannot be positioned on such docs - for (int doc = iter.nextDoc(); doc != NO_MORE_DOCS; doc = iter.nextDoc()) { - int newDocId = docMap.oldToNew(doc); - if (newDocId != -1) { - docToOrd[newDocId] = iter.index(); - docBits.set(newDocId); - ++count; - } - } - size = count; - docsWithValues = new BitSetIterator(docBits, count); + SortingValuesIterator(FixedBitSet docBits, int[] docToOrd, int size) { + this.docBits = docBits; + this.docToOrd = docToOrd; + docsWithValues = new BitSetIterator(docBits, size); } @Override @@ -259,19 +287,20 @@ public int nextDoc() throws IOException { } @Override - public long cost() { - return size; + public int advance(int target) { + throw new UnsupportedOperationException(); } } /** Sorting FloatVectorValues that maps ordinals using the provided sortMap */ private static class SortingFloatVectorValues extends FloatVectorValues { final FloatVectorValues delegate; + final IteratorSupplier iteratorSupplier; SortingFloatVectorValues(FloatVectorValues delegate, Sorter.DocMap sortMap) throws IOException { this.delegate = delegate; // SortingValuesIterator consumes the iterator and records the docs and ord mapping - iterator = new SortingValuesIterator(delegate.iterator(), sortMap); + iteratorSupplier = iteratorSupplier(delegate, sortMap); } @Override @@ -288,7 +317,7 @@ public int dimension() { @Override public int size() { - return (int) iterator.cost(); + return iteratorSupplier.size(); } @Override @@ -298,17 +327,18 @@ public FloatVectorValues copy() { @Override protected DocIndexIterator createIterator() { - throw new IllegalStateException(); + return iteratorSupplier.get(); } } private static class SortingByteVectorValues extends ByteVectorValues { final ByteVectorValues delegate; + final IteratorSupplier iteratorSupplier; SortingByteVectorValues(ByteVectorValues delegate, Sorter.DocMap sortMap) throws IOException { this.delegate = delegate; // SortingValuesIterator consumes the iterator and 
records the docs and ord mapping - iterator = new SortingValuesIterator(delegate.iterator(), sortMap); + iteratorSupplier = iteratorSupplier(delegate, sortMap); } @Override @@ -319,7 +349,7 @@ public byte[] vectorValue(int ord) throws IOException { @Override protected DocIndexIterator createIterator() { - throw new IllegalStateException(); + return iteratorSupplier.get(); } @Override @@ -329,7 +359,7 @@ public int dimension() { @Override public int size() { - return (int) iterator().cost(); + return iteratorSupplier.size(); } @Override diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java index 17e61b8a22ab..68551f275ce7 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java @@ -343,6 +343,11 @@ public int nextDoc() throws IOException { public int index() { return ord; } + + @Override + public int advance(int target) throws IOException { + throw new UnsupportedOperationException(); + } }; } From 787e89cafb9ac7831aa8aae9976d471cb1faefe5 Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Sun, 15 Sep 2024 21:03:17 -0400 Subject: [PATCH 15/25] make iterator() API sane --- .../lucene90/Lucene90HnswVectorsReader.java | 7 +- .../lucene91/Lucene91HnswVectorsReader.java | 8 +- .../lucene92/OffHeapFloatVectorValues.java | 14 ++- .../lucene94/OffHeapByteVectorValues.java | 14 ++- .../lucene94/OffHeapFloatVectorValues.java | 15 ++- .../TestBasicBackwardsCompatibility.java | 8 +- .../SimpleTextKnnVectorsReader.java | 15 ++- .../codecs/BufferingKnnVectorsWriter.java | 7 +- .../lucene/codecs/KnnVectorsWriter.java | 33 ++++-- .../lucene95/OffHeapByteVectorValues.java | 11 +- .../lucene95/OffHeapFloatVectorValues.java | 16 +-- .../Lucene99ScalarQuantizedVectorsWriter.java | 18 ++- .../OffHeapQuantizedByteVectorValues.java | 18 +-- 
.../apache/lucene/index/ByteVectorValues.java | 2 +- .../lucene/index/ExitableDirectoryReader.java | 4 +- .../lucene/index/FloatVectorValues.java | 2 +- .../apache/lucene/index/KnnVectorValues.java | 19 +--- .../lucene/index/SortingCodecReader.java | 6 +- .../util/quantization/ScalarQuantizer.java | 27 +++-- .../org/apache/lucene/document/TestField.java | 11 +- .../org/apache/lucene/index/TestKnnGraph.java | 9 +- .../lucene/index/TestSortingCodecReader.java | 5 +- .../lucene/util/hnsw/HnswGraphTestCase.java | 17 ++- .../util/hnsw/MockByteVectorValues.java | 2 +- .../lucene/util/hnsw/MockVectorValues.java | 2 +- .../TestScalarQuantizedVectorSimilarity.java | 3 - .../quantization/TestScalarQuantizer.java | 2 +- .../lucene/index/memory/MemoryIndex.java | 14 ++- .../lucene/index/memory/TestMemoryIndex.java | 11 +- .../index/BaseKnnVectorsFormatTestCase.java | 106 +++++++++++------- 30 files changed, 234 insertions(+), 192 deletions(-) diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index 371e7ada08ed..cc658ea714a8 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -409,7 +409,7 @@ public int ordToDoc(int ord) { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return all(); } @@ -419,15 +419,16 @@ public VectorScorer scorer(float[] target) { return null; } OffHeapFloatVectorValues values = this.copy(); + DocIndexIterator iterator = values.iterator(); return new VectorScorer() { @Override public float score() throws IOException { - return values.similarityFunction.compare(values.vectorValue(iterator().index()), target); + return 
values.similarityFunction.compare(values.vectorValue(iterator.index()), target); } @Override public DocIndexIterator iterator() { - return values.iterator(); + return iterator; } }; } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index 009435af9752..42c9c0257514 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -453,7 +453,7 @@ public int ordToDoc(int ord) { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return all(); } @@ -463,16 +463,16 @@ public VectorScorer scorer(float[] target) { return null; } OffHeapFloatVectorValues values = this.copy(); + DocIndexIterator iterator = values.iterator(); return new VectorScorer() { @Override public float score() throws IOException { - return values.similarityFunction.compare( - values.vectorValue(values.iterator().index()), target); + return values.similarityFunction.compare(values.vectorValue(iterator.index()), target); } @Override public DocIdSetIterator iterator() { - return values.iterator(); + return iterator; } }; } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java index 5b7b19db6015..c53594f36a4d 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java @@ -107,7 +107,7 @@ public DenseOffHeapVectorValues copy() throws IOException { } @Override - protected 
DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return createDenseIterator(); } @@ -119,16 +119,17 @@ public Bits getAcceptOrds(Bits acceptDocs) { @Override public VectorScorer scorer(float[] query) throws IOException { DenseOffHeapVectorValues values = this.copy(); + DocIndexIterator iterator = values.iterator(); return new VectorScorer() { @Override public float score() throws IOException { return values.vectorSimilarityFunction.compare( - values.vectorValue(values.iterator().index()), query); + values.vectorValue(iterator.index()), query); } @Override public DocIdSetIterator iterator() { - return values.iterator(); + return iterator; } }; } @@ -171,7 +172,7 @@ public SparseOffHeapVectorValues copy() throws IOException { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return IndexedDISI.asDocIndexIterator(disi); } @@ -201,16 +202,17 @@ public int length() { @Override public VectorScorer scorer(float[] query) throws IOException { SparseOffHeapVectorValues values = this.copy(); + DocIndexIterator iterator = values.iterator(); return new VectorScorer() { @Override public float score() throws IOException { return values.vectorSimilarityFunction.compare( - values.vectorValue(values.iterator().index()), query); + values.vectorValue(iterator.index()), query); } @Override public DocIdSetIterator iterator() { - return values.iterator(); + return iterator; } }; } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java index 44a9d3dd2c97..aedda7a6258c 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java @@ -122,7 +122,7 @@ public DenseOffHeapVectorValues copy() 
throws IOException { } @Override - public DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return createDenseIterator(); } @@ -134,15 +134,16 @@ public Bits getAcceptOrds(Bits acceptDocs) { @Override public VectorScorer scorer(byte[] query) throws IOException { DenseOffHeapVectorValues copy = this.copy(); + DocIndexIterator iterator = copy.iterator(); return new VectorScorer() { @Override public float score() throws IOException { - return vectorSimilarityFunction.compare(copy.vectorValue(copy.iterator().docID()), query); + return vectorSimilarityFunction.compare(copy.vectorValue(iterator.index()), query); } @Override public DocIdSetIterator iterator() { - return copy.iterator(); + return iterator; } }; } @@ -191,7 +192,7 @@ public int ordToDoc(int ord) { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return fromDISI(disi); } @@ -216,15 +217,16 @@ public int length() { @Override public VectorScorer scorer(byte[] query) throws IOException { SparseOffHeapVectorValues copy = this.copy(); + IndexedDISI disi = copy.disi; return new VectorScorer() { @Override public float score() throws IOException { - return vectorSimilarityFunction.compare(copy.vectorValue(copy.disi.index()), query); + return vectorSimilarityFunction.compare(copy.vectorValue(disi.index()), query); } @Override public DocIdSetIterator iterator() { - return copy.disi; + return disi; } }; } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java index 07b7beca03d4..02664837982b 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java @@ -118,7 +118,7 @@ public DenseOffHeapVectorValues copy() 
throws IOException { } @Override - public DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return createDenseIterator(); } @@ -130,16 +130,18 @@ public Bits getAcceptOrds(Bits acceptDocs) { @Override public VectorScorer scorer(float[] query) throws IOException { DenseOffHeapVectorValues values = this.copy(); + DocIndexIterator iterator = values.iterator(); + return new VectorScorer() { @Override public float score() throws IOException { return values.vectorSimilarityFunction.compare( - values.vectorValue(values.iterator.index()), query); + values.vectorValue(iterator.index()), query); } @Override public DocIdSetIterator iterator() { - return values.iterator(); + return iterator; } }; } @@ -183,7 +185,7 @@ public SparseOffHeapVectorValues copy() throws IOException { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return IndexedDISI.asDocIndexIterator(disi); } @@ -213,16 +215,17 @@ public int length() { @Override public VectorScorer scorer(float[] query) throws IOException { SparseOffHeapVectorValues values = this.copy(); + DocIndexIterator iterator = values.iterator(); return new VectorScorer() { @Override public float score() throws IOException { return values.vectorSimilarityFunction.compare( - values.vectorValue(values.iterator().index()), query); + values.vectorValue(iterator.index()), query); } @Override public DocIdSetIterator iterator() { - return values.iterator(); + return iterator; } }; } diff --git a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBasicBackwardsCompatibility.java b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBasicBackwardsCompatibility.java index a223d24608ad..cf50b9e1526d 100644 --- a/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBasicBackwardsCompatibility.java +++ b/lucene/backward-codecs/src/test/org/apache/lucene/backward_index/TestBasicBackwardsCompatibility.java @@ -52,6 +52,7 @@ import 
org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; import org.apache.lucene.index.IndexableField; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.LogByteSizeMergePolicy; import org.apache.lucene.index.MultiBits; @@ -477,14 +478,13 @@ public static void searchIndex( FloatVectorValues values = ctx.reader().getFloatVectorValues(KNN_VECTOR_FIELD); if (values != null) { assertEquals(KNN_VECTOR_FIELD_TYPE.vectorDimension(), values.dimension()); - for (int doc = values.iterator().nextDoc(); - doc != NO_MORE_DOCS; - doc = values.iterator().nextDoc()) { + KnnVectorValues.DocIndexIterator it = values.iterator(); + for (int doc = it.nextDoc(); doc != NO_MORE_DOCS; doc = it.nextDoc()) { float[] expectedVector = {KNN_VECTOR[0], KNN_VECTOR[1], KNN_VECTOR[2] + 0.1f * cnt}; assertArrayEquals( "vectors do not match for doc=" + cnt, expectedVector, - values.vectorValue(values.iterator().index()), + values.vectorValue(it.index()), 0); cnt++; } diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index 7e87a4435368..ce03705fa467 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -337,7 +337,7 @@ public int ordToDoc(int ord) { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return all(); } @@ -348,10 +348,11 @@ public VectorScorer scorer(float[] target) { } SimpleTextFloatVectorValues simpleTextFloatVectorValues = new SimpleTextFloatVectorValues(this); + DocIndexIterator iterator = simpleTextFloatVectorValues.iterator(); return new VectorScorer() { @Override public float score() throws IOException { - int ord = 
simpleTextFloatVectorValues.iterator().index(); + int ord = iterator.index(); return entry .similarityFunction() .compare(simpleTextFloatVectorValues.vectorValue(ord), target); @@ -359,7 +360,7 @@ public float score() throws IOException { @Override public DocIdSetIterator iterator() { - return simpleTextFloatVectorValues.iterator(); + return iterator; } }; } @@ -429,7 +430,7 @@ public byte[] vectorValue(int ord) { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return all(); } @@ -440,9 +441,11 @@ public VectorScorer scorer(byte[] target) { } SimpleTextByteVectorValues simpleTextByteVectorValues = new SimpleTextByteVectorValues(this); return new VectorScorer() { + DocIndexIterator it = simpleTextByteVectorValues.iterator(); + @Override public float score() throws IOException { - int ord = simpleTextByteVectorValues.iterator().index(); + int ord = it.index(); return entry .similarityFunction() .compare(simpleTextByteVectorValues.vectorValue(ord), target); @@ -450,7 +453,7 @@ public float score() throws IOException { @Override public DocIdSetIterator iterator() { - return simpleTextByteVectorValues.iterator(); + return it; } }; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java index d4a782eeb863..96b0f75a259f 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/BufferingKnnVectorsWriter.java @@ -119,7 +119,6 @@ private static class SortingFloatVectorValues extends FloatVectorValues { throws IOException { this.delegate = delegate.copy(); iteratorSupplier = SortingCodecReader.iteratorSupplier(delegate, sortMap); - iterator = iteratorSupplier.get(); } @Override @@ -144,7 +143,7 @@ public SortingFloatVectorValues copy() { @Override public DocIndexIterator iterator() { - return iterator; + return 
iteratorSupplier.get(); } } @@ -152,14 +151,12 @@ public DocIndexIterator iterator() { private static class SortingByteVectorValues extends ByteVectorValues { private final BufferedByteVectorValues delegate; private final Supplier iteratorSupplier; - private final DocIndexIterator iterator; SortingByteVectorValues( BufferedByteVectorValues delegate, DocsWithFieldSet docsWithField, Sorter.DocMap sortMap) throws IOException { this.delegate = delegate; iteratorSupplier = SortingCodecReader.iteratorSupplier(delegate, sortMap); - iterator = iteratorSupplier.get(); } @Override @@ -184,7 +181,7 @@ public SortingByteVectorValues copy() { @Override public DocIndexIterator iterator() { - return iterator; + return iteratorSupplier.get(); } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java index ae2cac5ccb39..0ff782781ffe 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java @@ -116,32 +116,44 @@ public final void merge(MergeState mergeState) throws IOException { private static class FloatVectorValuesSub extends DocIDMerger.Sub { final FloatVectorValues values; + final KnnVectorValues.DocIndexIterator iterator; FloatVectorValuesSub(MergeState.DocMap docMap, FloatVectorValues values) { super(docMap); this.values = values; - assert values.iterator().docID() == -1; + this.iterator = values.iterator(); + assert iterator.docID() == -1; } @Override public int nextDoc() throws IOException { - return values.iterator().nextDoc(); + return iterator.nextDoc(); + } + + public int index() { + return iterator.index(); } } private static class ByteVectorValuesSub extends DocIDMerger.Sub { final ByteVectorValues values; + final KnnVectorValues.DocIndexIterator iterator; ByteVectorValuesSub(MergeState.DocMap docMap, ByteVectorValues values) { super(docMap); this.values = values; - assert 
values.iterator().docID() == -1; + iterator = values.iterator(); + assert iterator.docID() == -1; } @Override public int nextDoc() throws IOException { - return values.iterator().nextDoc(); + return iterator.nextDoc(); + } + + int index() { + return iterator.index(); } } @@ -302,7 +314,7 @@ private MergedFloat32VectorValues(List subs, MergeState me } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return new DocIndexIterator() { private int index = -1; @@ -338,8 +350,9 @@ public int advance(int target) throws IOException { @Override public float[] vectorValue(int ord) throws IOException { - assert ord == iterator.index(); - return current.values.vectorValue(current.values.iterator().index()); + // FIXME what can we assert here? + // assert ord == iterator.index(); + return current.values.vectorValue(current.index()); } @Override @@ -389,11 +402,13 @@ private MergedByteVectorValues(List subs, MergeState mergeS @Override public byte[] vectorValue(int ord) throws IOException { - return current.values.vectorValue(current.values.iterator().index()); + // FIXME + // assert ord == iterator.index(); + return current.values.vectorValue(current.index()); } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return new DocIndexIterator() { private int index = -1; diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index 2b8f7a282f05..1e78c8ea7aa2 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -147,7 +147,7 @@ public DenseOffHeapVectorValues copy() throws IOException { } @Override - public DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return createDenseIterator(); } @@ -159,17 +159,18 @@ 
public Bits getAcceptOrds(Bits acceptDocs) { @Override public VectorScorer scorer(byte[] query) throws IOException { DenseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); RandomVectorScorer scorer = flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); return new VectorScorer() { @Override public float score() throws IOException { - return scorer.score(copy.iterator().docID()); + return scorer.score(iterator.docID()); } @Override public DocIdSetIterator iterator() { - return copy.iterator(); + return iterator; } }; } @@ -232,7 +233,7 @@ public int ordToDoc(int ord) { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return IndexedDISI.asDocIndexIterator(disi); } @@ -298,7 +299,7 @@ public byte[] vectorValue(int ord) throws IOException { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return createDenseIterator(); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index 0872026d8e32..2384657e93e1 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -153,24 +153,25 @@ public Bits getAcceptOrds(Bits acceptDocs) { } @Override - public DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return createDenseIterator(); } @Override public VectorScorer scorer(float[] query) throws IOException { DenseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); RandomVectorScorer randomVectorScorer = flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); return new VectorScorer() { @Override public float score() throws IOException { - return randomVectorScorer.score(copy.iterator().docID()); + 
return randomVectorScorer.score(iterator.docID()); } @Override public DocIdSetIterator iterator() { - return copy.iterator(); + return iterator; } }; } @@ -245,24 +246,25 @@ public int length() { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return IndexedDISI.asDocIndexIterator(disi); } @Override public VectorScorer scorer(float[] query) throws IOException { SparseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); RandomVectorScorer randomVectorScorer = flatVectorsScorer.getRandomVectorScorer(similarityFunction, copy, query); return new VectorScorer() { @Override public float score() throws IOException { - return randomVectorScorer.score(copy.disi.index()); + return randomVectorScorer.score(iterator.index()); } @Override public DocIdSetIterator iterator() { - return copy.iterator(); + return iterator; } }; } @@ -298,7 +300,7 @@ public float[] vectorValue(int targetOrd) { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return createDenseIterator(); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index e2578bf5cd68..6e6b4b0e50cd 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -878,23 +878,29 @@ public float[] vectorValue(int ord) throws IOException { } @Override - public DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return createDenseIterator(); } } static class QuantizedByteVectorValueSub extends DocIDMerger.Sub { private final QuantizedByteVectorValues values; + private final KnnVectorValues.DocIndexIterator iterator; QuantizedByteVectorValueSub(MergeState.DocMap docMap, 
QuantizedByteVectorValues values) { super(docMap); this.values = values; - assert values.iterator().docID() == -1; + iterator = values.iterator(); + assert iterator.docID() == -1; } @Override public int nextDoc() throws IOException { - return values.iterator().nextDoc(); + return iterator.nextDoc(); + } + + public int index() { + return iterator.index(); } } @@ -966,11 +972,11 @@ private MergedQuantizedVectorValues( @Override public byte[] vectorValue(int ord) throws IOException { - return current.values.vectorValue(current.values.iterator().index()); + return current.values.vectorValue(current.index()); } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return new CompositeIterator(); } @@ -986,7 +992,7 @@ public int dimension() { @Override public float getScoreCorrectionConstant(int ord) throws IOException { - return current.values.getScoreCorrectionConstant(current.values.iterator().index()); + return current.values.getScoreCorrectionConstant(current.index()); } private class CompositeIterator extends DocIndexIterator { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java index 6d1fea07b4f9..051c926a679e 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java @@ -206,8 +206,6 @@ public static OffHeapQuantizedByteVectorValues load( */ public static class DenseOffHeapVectorValues extends OffHeapQuantizedByteVectorValues { - private int doc = -1; - public DenseOffHeapVectorValues( int dimension, int size, @@ -239,23 +237,24 @@ public Bits getAcceptOrds(Bits acceptDocs) { @Override public VectorScorer scorer(float[] target) throws IOException { DenseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = 
copy.iterator(); RandomVectorScorer vectorScorer = vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); return new VectorScorer() { @Override public float score() throws IOException { - return vectorScorer.score(copy.doc); + return vectorScorer.score(iterator.index()); } @Override public DocIdSetIterator iterator() { - return copy.iterator(); + return iterator; } }; } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return createDenseIterator(); } } @@ -286,7 +285,7 @@ public SparseOffHeapVectorValues( } @Override - public DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return IndexedDISI.asDocIndexIterator(disi); } @@ -330,17 +329,18 @@ public int length() { @Override public VectorScorer scorer(float[] target) throws IOException { SparseOffHeapVectorValues copy = copy(); + DocIndexIterator iterator = copy.iterator(); RandomVectorScorer vectorScorer = vectorsScorer.getRandomVectorScorer(similarityFunction, copy, target); return new VectorScorer() { @Override public float score() throws IOException { - return vectorScorer.score(copy.disi.index()); + return vectorScorer.score(iterator.index()); } @Override public DocIdSetIterator iterator() { - return copy.iterator(); + return iterator; } }; } @@ -373,7 +373,7 @@ public int size() { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return createDenseIterator(); } diff --git a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java index f63c26d204bb..bb84ba51ef8c 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java @@ -111,7 +111,7 @@ public ByteVectorValues copy() { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return createDenseIterator(); } }; 
diff --git a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java index f458aef87e45..c7531207c23d 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java @@ -456,7 +456,7 @@ public int size() { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return createExitableIterator(vectorValues.iterator(), queryTimeout); } @@ -499,7 +499,7 @@ public int ordToDoc(int ord) { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return createExitableIterator(vectorValues.iterator(), queryTimeout); } diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index 13c8c047d8e7..d35cdcea5c86 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -112,7 +112,7 @@ public FloatVectorValues copy() { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return createDenseIterator(); } }; diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java index ac4cd865f9fc..e7d6595dd5d9 100644 --- a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java @@ -30,9 +30,6 @@ */ public abstract class KnnVectorValues { - /** The iterator associated with these values. */ - protected DocIndexIterator iterator; - /** Return the dimension of the vectors */ public abstract int dimension(); @@ -88,22 +85,8 @@ public int length() { }; } - /** - * Return the iterator for this instance. 
If you need multiple iterators, call - * this.copy().iterator(). - */ + /** Create an iterator for this instance. */ public DocIndexIterator iterator() { - if (iterator == null) { - iterator = createIterator(); - } - return iterator; - } - - /** - * Create an iterator for this instance; typically called once by iterator(). Wrapper - * value classes delegate to their inner instance's iterator and shouldn't implement this. - */ - protected DocIndexIterator createIterator() { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index 85cd22830b11..3d6927b9b51d 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -306,7 +306,6 @@ private static class SortingFloatVectorValues extends FloatVectorValues { @Override public float[] vectorValue(int ord) throws IOException { // ords are interpreted in the delegate's ord-space. 
- assert ord == iterator.index(); return delegate.vectorValue(ord); } @@ -326,7 +325,7 @@ public FloatVectorValues copy() { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return iteratorSupplier.get(); } } @@ -343,12 +342,11 @@ private static class SortingByteVectorValues extends ByteVectorValues { @Override public byte[] vectorValue(int ord) throws IOException { - assert ord == iterator.index(); return delegate.vectorValue(ord); } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return iteratorSupplier.get(); } diff --git a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java index bea76f231cef..3f7bcf6c5c45 100644 --- a/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java +++ b/lucene/core/src/java/org/apache/lucene/util/quantization/ScalarQuantizer.java @@ -25,6 +25,7 @@ import java.util.Random; import java.util.stream.IntStream; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.HitQueue; import org.apache.lucene.search.ScoreDoc; @@ -269,11 +270,12 @@ static ScalarQuantizer fromVectors( if (totalVectorCount == 0) { return new ScalarQuantizer(0f, 0f, bits); } + KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); if (confidenceInterval == 1f) { float min = Float.POSITIVE_INFINITY; float max = Float.NEGATIVE_INFINITY; - while (floatVectorValues.iterator().nextDoc() != NO_MORE_DOCS) { - for (float v : floatVectorValues.vectorValue(floatVectorValues.iterator().index())) { + while (iterator.nextDoc() != NO_MORE_DOCS) { + for (float v : floatVectorValues.vectorValue(iterator.index())) { min = Math.min(min, v); max = Math.max(max, v); } @@ -289,8 +291,8 @@ static ScalarQuantizer 
fromVectors( if (totalVectorCount <= quantizationSampleSize) { int scratchSize = Math.min(SCRATCH_SIZE, totalVectorCount); int i = 0; - while (floatVectorValues.iterator().nextDoc() != NO_MORE_DOCS) { - float[] vectorValue = floatVectorValues.vectorValue(floatVectorValues.iterator().index()); + while (iterator.nextDoc() != NO_MORE_DOCS) { + float[] vectorValue = floatVectorValues.vectorValue(iterator.index()); System.arraycopy( vectorValue, 0, quantileGatheringScratch, i * vectorValue.length, vectorValue.length); i++; @@ -311,11 +313,11 @@ static ScalarQuantizer fromVectors( for (int i : vectorsToTake) { while (index <= i) { // We cannot use `advance(docId)` as MergedVectorValues does not support it - floatVectorValues.iterator().nextDoc(); + iterator.nextDoc(); index++; } - assert floatVectorValues.iterator().docID() != NO_MORE_DOCS; - float[] vectorValue = floatVectorValues.vectorValue(floatVectorValues.iterator().index()); + assert iterator.docID() != NO_MORE_DOCS; + float[] vectorValue = floatVectorValues.vectorValue(iterator.index()); System.arraycopy( vectorValue, 0, quantileGatheringScratch, idx * vectorValue.length, vectorValue.length); idx++; @@ -353,12 +355,13 @@ public static ScalarQuantizer fromVectorsAutoInterval( / (floatVectorValues.dimension() + 1), 1 - 1f / (floatVectorValues.dimension() + 1) }; + KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); if (totalVectorCount <= sampleSize) { int scratchSize = Math.min(SCRATCH_SIZE, totalVectorCount); int i = 0; - while (floatVectorValues.iterator().nextDoc() != NO_MORE_DOCS) { + while (iterator.nextDoc() != NO_MORE_DOCS) { gatherSample( - floatVectorValues.vectorValue(floatVectorValues.iterator().index()), + floatVectorValues.vectorValue(iterator.index()), quantileGatheringScratch, sampledDocs, i); @@ -378,12 +381,12 @@ public static ScalarQuantizer fromVectorsAutoInterval( for (int i : vectorsToTake) { while (index <= i) { // We cannot use `advance(docId)` as MergedVectorValues 
does not support it - floatVectorValues.iterator().nextDoc(); + iterator.nextDoc(); index++; } - assert floatVectorValues.iterator().docID() != NO_MORE_DOCS; + assert iterator.docID() != NO_MORE_DOCS; gatherSample( - floatVectorValues.vectorValue(floatVectorValues.iterator().index()), + floatVectorValues.vectorValue(iterator.index()), quantileGatheringScratch, sampledDocs, idx); diff --git a/lucene/core/src/test/org/apache/lucene/document/TestField.java b/lucene/core/src/test/org/apache/lucene/document/TestField.java index 685f5b6bd1a5..5c1b8f17294f 100644 --- a/lucene/core/src/test/org/apache/lucene/document/TestField.java +++ b/lucene/core/src/test/org/apache/lucene/document/TestField.java @@ -28,6 +28,7 @@ import org.apache.lucene.index.IndexOptions; import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.Term; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; @@ -714,18 +715,20 @@ public void testKnnVectorField() throws Exception { try (IndexReader r = DirectoryReader.open(w)) { ByteVectorValues binary = r.leaves().get(0).reader().getByteVectorValues("binary"); assertEquals(1, binary.size()); - assertNotEquals(NO_MORE_DOCS, binary.iterator().nextDoc()); + KnnVectorValues.DocIndexIterator iterator = binary.iterator(); + assertNotEquals(NO_MORE_DOCS, iterator.nextDoc()); assertNotNull(binary.vectorValue(0)); assertArrayEquals(b, binary.vectorValue(0)); - assertEquals(NO_MORE_DOCS, binary.iterator().nextDoc()); + assertEquals(NO_MORE_DOCS, iterator.nextDoc()); expectThrows(IOException.class, () -> binary.vectorValue(1)); FloatVectorValues floatValues = r.leaves().get(0).reader().getFloatVectorValues("float"); assertEquals(1, floatValues.size()); - assertNotEquals(NO_MORE_DOCS, floatValues.iterator().nextDoc()); + KnnVectorValues.DocIndexIterator iterator1 = floatValues.iterator(); + 
assertNotEquals(NO_MORE_DOCS, iterator1.nextDoc()); assertEquals(vector.length, floatValues.vectorValue(0).length); assertEquals(vector[0], floatValues.vectorValue(0)[0], 0); - assertEquals(NO_MORE_DOCS, floatValues.iterator().nextDoc()); + assertEquals(NO_MORE_DOCS, iterator1.nextDoc()); expectThrows(IOException.class, () -> floatValues.vectorValue(1)); } } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java index d3e69986c279..41410ad4e39d 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java @@ -413,8 +413,9 @@ private void assertConsistentGraph(IndexWriter iw, float[][] values, String vect // stored vector values are the same as original int nextDocWithVectors = 0; StoredFields storedFields = reader.storedFields(); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); for (int i = 0; i < reader.maxDoc(); i++) { - nextDocWithVectors = vectorValues.iterator().advance(i); + nextDocWithVectors = iterator.advance(i); while (i < nextDocWithVectors && i < reader.maxDoc()) { int id = Integer.parseInt(storedFields.document(i).get("id")); assertNull( @@ -426,7 +427,7 @@ private void assertConsistentGraph(IndexWriter iw, float[][] values, String vect } int id = Integer.parseInt(storedFields.document(i).get("id")); // documents with KnnGraphValues have the expected vectors - float[] scratch = vectorValues.vectorValue(vectorValues.iterator().index()); + float[] scratch = vectorValues.vectorValue(iterator.index()); assertArrayEquals( "vector did not match for doc " + i + ", id=" + id + ": " + Arrays.toString(scratch), values[id], @@ -436,9 +437,9 @@ private void assertConsistentGraph(IndexWriter iw, float[][] values, String vect } // if IndexDisi.doc == NO_MORE_DOCS, we should not call IndexDisi.nextDoc() if (nextDocWithVectors != NO_MORE_DOCS) { - assertEquals(NO_MORE_DOCS, 
vectorValues.iterator().nextDoc()); + assertEquals(NO_MORE_DOCS, iterator.nextDoc()); } else { - assertEquals(NO_MORE_DOCS, vectorValues.iterator().docID()); + assertEquals(NO_MORE_DOCS, iterator.docID()); } // assert graph values: diff --git a/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java b/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java index de588d3e15c7..fcbd0cdea21f 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestSortingCodecReader.java @@ -242,6 +242,7 @@ public void testSortOnAddIndicesRandom() throws IOException { NumericDocValues ids = leaf.getNumericDocValues("id"); long prevValue = -1; boolean usingAltIds = false; + KnnVectorValues.DocIndexIterator valuesIterator = vectorValues.iterator(); for (int i = 0; i < actualNumDocs; i++) { int idNext = ids.nextDoc(); if (idNext == DocIdSetIterator.NO_MORE_DOCS) { @@ -262,7 +263,7 @@ public void testSortOnAddIndicesRandom() throws IOException { assertTrue(sorted_numeric_dv.advanceExact(idNext)); assertTrue(sorted_set_dv.advanceExact(idNext)); assertTrue(binary_sorted_dv.advanceExact(idNext)); - assertEquals(idNext, vectorValues.iterator().advance(idNext)); + assertEquals(idNext, valuesIterator.advance(idNext)); assertEquals(new BytesRef(ids.longValue() + ""), binary_dv.binaryValue()); assertEquals( new BytesRef(ids.longValue() + ""), @@ -274,7 +275,7 @@ public void testSortOnAddIndicesRandom() throws IOException { assertEquals(1, sorted_numeric_dv.docValueCount()); assertEquals(ids.longValue(), sorted_numeric_dv.nextValue()); - float[] vectorValue = vectorValues.vectorValue(vectorValues.iterator().index()); + float[] vectorValue = vectorValues.vectorValue(valuesIterator.index()); assertEquals(1, vectorValue.length); assertEquals((float) ids.longValue(), vectorValue[0], 0.001f); diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java 
b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java index 50a817f49c8b..41aeef2e5c8d 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/HnswGraphTestCase.java @@ -268,16 +268,16 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { } }); try (IndexWriter iw = new IndexWriter(dir, iwc)) { - while (v2.iterator().nextDoc() != NO_MORE_DOCS) { - while (indexedDoc < v2.iterator().docID()) { + KnnVectorValues.DocIndexIterator it2 = v2.iterator(); + while (it2.nextDoc() != NO_MORE_DOCS) { + while (indexedDoc < it2.docID()) { // increment docId in the index by adding empty documents iw.addDocument(new Document()); indexedDoc++; } Document doc = new Document(); - doc.add( - knnVectorField("field", vectorValue(v2, v2.iterator().index()), similarityFunction)); - doc.add(new StoredField("id", v2.iterator().docID())); + doc.add(knnVectorField("field", vectorValue(v2, it2.index()), similarityFunction)); + doc.add(new StoredField("id", it2.docID())); iw.addDocument(doc); nVec++; indexedDoc++; @@ -719,14 +719,13 @@ private int[] createOffsetOrdinalMap( // Compute the offset for the ordinal map to be the number of non-null vectors in the total // vector values before the docIdOffset int ordinalOffset = 0; - while (totalVectorValues.iterator().nextDoc() < docIdOffset) { + KnnVectorValues.DocIndexIterator it = totalVectorValues.iterator(); + while (it.nextDoc() < docIdOffset) { ordinalOffset++; } int[] offsetOrdinalMap = new int[docIdSize]; - for (int curr = 0; - totalVectorValues.iterator().docID() < docIdOffset + docIdSize; - totalVectorValues.iterator().nextDoc()) { + for (int curr = 0; it.docID() < docIdOffset + docIdSize; it.nextDoc()) { offsetOrdinalMap[curr] = ordinalOffset + curr++; } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java 
index e67c0b8f4980..4ab86c707816 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockByteVectorValues.java @@ -90,7 +90,7 @@ public byte[] vectorValue(int ord) { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return createDenseIterator(); } } diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java index 45d6006a10d3..5411f2418de3 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/MockVectorValues.java @@ -84,7 +84,7 @@ public float[] vectorValue(int ord) { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return createDenseIterator(); } } diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java index 69e8d030c119..f2cc3ac35c05 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizedVectorSimilarity.java @@ -239,9 +239,6 @@ private static FloatVectorValues fromFloatsNormalized( return new TestScalarQuantizer.TestSimpleFloatVectorValues(floats, deletedVectors) { @Override public float[] vectorValue(int ord) throws IOException { - if (iterator().docID() == -1 || iterator().docID() >= floats.length) { - throw new IOException("Current doc not set or too many iterations"); - } float[] v = ArrayUtil.copyArray(floats[ordToDoc[ord]]); VectorUtil.l2normalize(v); return v; diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java 
b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java index 68551f275ce7..b5b37d291a50 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java @@ -316,7 +316,7 @@ public int ordToDoc(int ord) { } @Override - public DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return new DocIndexIterator() { int ord = -1; diff --git a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java index 119c8b099dfc..11147decf25c 100644 --- a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java +++ b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java @@ -2310,7 +2310,7 @@ public float[] vectorValue(int ord) { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return createDenseIterator(); } @@ -2324,10 +2324,11 @@ public VectorScorer scorer(float[] query) { + info.fieldInfo.getVectorDimension()); } MemoryFloatVectorValues vectorValues = new MemoryFloatVectorValues(info); + DocIndexIterator iterator = vectorValues.iterator(); return new VectorScorer() { @Override public float score() throws IOException { - assert iterator().docID() == 0; + assert iterator.docID() == 0; return info.fieldInfo .getVectorSimilarityFunction() .compare(vectorValues.vectorValue(0), query); @@ -2335,7 +2336,7 @@ public float score() throws IOException { @Override public DocIdSetIterator iterator() { - return vectorValues.iterator(); + return iterator; } }; } @@ -2368,7 +2369,7 @@ public byte[] vectorValue(int ord) { } @Override - protected DocIndexIterator createIterator() { + public DocIndexIterator iterator() { return createDenseIterator(); } @@ -2382,10 +2383,11 @@ public VectorScorer scorer(byte[] query) { + info.fieldInfo.getVectorDimension()); } 
MemoryByteVectorValues vectorValues = new MemoryByteVectorValues(info); + DocIndexIterator iterator = vectorValues.iterator(); return new VectorScorer() { @Override public float score() { - assert iterator().docID() == 0; + assert iterator.docID() == 0; return info.fieldInfo .getVectorSimilarityFunction() .compare(vectorValues.vectorValue(0), query); @@ -2393,7 +2395,7 @@ public float score() { @Override public DocIdSetIterator iterator() { - return vectorValues.iterator(); + return iterator; } }; } diff --git a/lucene/memory/src/test/org/apache/lucene/index/memory/TestMemoryIndex.java b/lucene/memory/src/test/org/apache/lucene/index/memory/TestMemoryIndex.java index 6a68de001832..7c5928689127 100644 --- a/lucene/memory/src/test/org/apache/lucene/index/memory/TestMemoryIndex.java +++ b/lucene/memory/src/test/org/apache/lucene/index/memory/TestMemoryIndex.java @@ -63,6 +63,7 @@ import org.apache.lucene.index.IndexReader; import org.apache.lucene.index.IndexableField; import org.apache.lucene.index.IndexableFieldType; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.NumericDocValues; import org.apache.lucene.index.PostingsEnum; @@ -851,9 +852,10 @@ private static void assertFloatVectorValue(MemoryIndex mi, String fieldName, flo .reader() .getFloatVectorValues(fieldName); assertNotNull(fvv); - assertEquals(0, fvv.iterator().nextDoc()); + KnnVectorValues.DocIndexIterator iterator = fvv.iterator(); + assertEquals(0, iterator.nextDoc()); assertArrayEquals(expected, fvv.vectorValue(0), 1e-6f); - assertEquals(DocIdSetIterator.NO_MORE_DOCS, fvv.iterator().nextDoc()); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); } private static void assertFloatVectorScore( @@ -886,9 +888,10 @@ private static void assertByteVectorValue(MemoryIndex mi, String fieldName, byte .reader() .getByteVectorValues(fieldName); assertNotNull(bvv); - assertEquals(0, bvv.iterator().nextDoc()); + 
KnnVectorValues.DocIndexIterator iterator = bvv.iterator(); + assertEquals(0, iterator.nextDoc()); assertArrayEquals(expected, bvv.vectorValue(0)); - assertEquals(DocIdSetIterator.NO_MORE_DOCS, bvv.iterator().nextDoc()); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iterator.nextDoc()); } private static void assertByteVectorScore( diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index 645bc0261588..93abd5becc76 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -435,9 +435,10 @@ public void testAddIndexesDirectory0() throws Exception { try (IndexReader reader = DirectoryReader.open(w2)) { LeafReader r = getOnlyLeafReader(reader); FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); - assertEquals(0, vectorValues.iterator().nextDoc()); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + assertEquals(0, iterator.nextDoc()); assertEquals(0, vectorValues.vectorValue(0)[0], 0); - assertEquals(NO_MORE_DOCS, vectorValues.iterator().nextDoc()); + assertEquals(NO_MORE_DOCS, iterator.nextDoc()); } } } @@ -460,9 +461,10 @@ public void testAddIndexesDirectory1() throws Exception { try (IndexReader reader = DirectoryReader.open(w2)) { LeafReader r = getOnlyLeafReader(reader); FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); - assertNotEquals(NO_MORE_DOCS, vectorValues.iterator().nextDoc()); - assertEquals(0, vectorValues.vectorValue(vectorValues.iterator().index())[0], 0); - assertEquals(NO_MORE_DOCS, vectorValues.iterator().nextDoc()); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + assertNotEquals(NO_MORE_DOCS, iterator.nextDoc()); + assertEquals(0, 
vectorValues.vectorValue(iterator.index())[0], 0); + assertEquals(NO_MORE_DOCS, iterator.nextDoc()); } } } @@ -487,11 +489,12 @@ public void testAddIndexesDirectory01() throws Exception { try (IndexReader reader = DirectoryReader.open(w2)) { LeafReader r = getOnlyLeafReader(reader); FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); - assertEquals(0, vectorValues.iterator().nextDoc()); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + assertEquals(0, iterator.nextDoc()); // The merge order is randomized, we might get 0 first, or 1 float value = vectorValues.vectorValue(0)[0]; assertTrue(value == 0 || value == 1); - assertEquals(1, vectorValues.iterator().nextDoc()); + assertEquals(1, iterator.nextDoc()); value += vectorValues.vectorValue(1)[0]; assertEquals(1, value, 0); } @@ -877,9 +880,10 @@ public void testSparseVectors() throws Exception { ByteVectorValues byteVectorValues = ctx.reader().getByteVectorValues(fieldName); if (byteVectorValues != null) { docCount += byteVectorValues.size(); - while (byteVectorValues.iterator().nextDoc() != NO_MORE_DOCS) { - checksum += - byteVectorValues.vectorValue(byteVectorValues.iterator().index())[0]; + KnnVectorValues.DocIndexIterator iterator = byteVectorValues.iterator(); + while (true) { + if (!(iterator.nextDoc() != NO_MORE_DOCS)) break; + checksum += byteVectorValues.vectorValue(iterator.index())[0]; } } } @@ -889,8 +893,10 @@ public void testSparseVectors() throws Exception { FloatVectorValues vectorValues = ctx.reader().getFloatVectorValues(fieldName); if (vectorValues != null) { docCount += vectorValues.size(); - while (vectorValues.iterator().nextDoc() != NO_MORE_DOCS) { - checksum += vectorValues.vectorValue(vectorValues.iterator().index())[0]; + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + while (true) { + if (!(iterator.nextDoc() != NO_MORE_DOCS)) break; + checksum += vectorValues.vectorValue(iterator.index())[0]; } } } @@ -949,11 +955,12 @@ public 
void testFloatVectorScorerIteration() throws Exception { assertSame(iterator, scorer.iterator()); assertNotSame(iterator, scorer); // verify scorer iteration scores are valid & iteration with vectorValues is consistent - while (iterator.nextDoc() != NO_MORE_DOCS - && vectorValues.iterator().nextDoc() != NO_MORE_DOCS) { + KnnVectorValues.DocIndexIterator valuesIterator = vectorValues.iterator(); + while (iterator.nextDoc() != NO_MORE_DOCS) { + if (!(valuesIterator.nextDoc() != NO_MORE_DOCS)) break; float score = scorer.score(); assertTrue(score >= 0f); - assertEquals(iterator.docID(), vectorValues.iterator().docID()); + assertEquals(iterator.docID(), valuesIterator.docID()); } // verify that a new scorer can be obtained after iteration VectorScorer newScorer = vectorValues.scorer(vectorToScore); @@ -1009,11 +1016,12 @@ public void testByteVectorScorerIteration() throws Exception { assertSame(iterator, scorer.iterator()); assertNotSame(iterator, scorer); // verify scorer iteration scores are valid & iteration with vectorValues is consistent - while (iterator.nextDoc() != NO_MORE_DOCS - && vectorValues.iterator().nextDoc() != NO_MORE_DOCS) { + KnnVectorValues.DocIndexIterator valuesIterator = vectorValues.iterator(); + while (iterator.nextDoc() != NO_MORE_DOCS) { + if (!(valuesIterator.nextDoc() != NO_MORE_DOCS)) break; float score = scorer.score(); assertTrue(score >= 0f); - assertEquals(iterator.docID(), vectorValues.iterator().docID()); + assertEquals(iterator.docID(), valuesIterator.docID()); } // verify that a new scorer can be obtained after iteration VectorScorer newScorer = vectorValues.scorer(vectorToScore); @@ -1119,14 +1127,15 @@ public void testIndexedValueNotAliased() throws Exception { LeafReader r = getOnlyLeafReader(reader); FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); assertEquals(3, vectorValues.size()); - vectorValues.iterator().nextDoc(); - assertEquals(0, vectorValues.iterator().index()); + KnnVectorValues.DocIndexIterator 
iterator = vectorValues.iterator(); + iterator.nextDoc(); + assertEquals(0, iterator.index()); assertEquals(1, vectorValues.vectorValue(0)[0], 0); - vectorValues.iterator().nextDoc(); - assertEquals(1, vectorValues.iterator().index()); + iterator.nextDoc(); + assertEquals(1, iterator.index()); assertEquals(1, vectorValues.vectorValue(1)[0], 0); - vectorValues.iterator().nextDoc(); - assertEquals(2, vectorValues.iterator().index()); + iterator.nextDoc(); + assertEquals(2, iterator.index()); assertEquals(2, vectorValues.vectorValue(2)[0], 0); } } @@ -1150,13 +1159,14 @@ public void testSortedIndex() throws Exception { FloatVectorValues vectorValues = leaf.getFloatVectorValues(fieldName); assertEquals(2, vectorValues.dimension()); assertEquals(3, vectorValues.size()); - assertEquals("1", storedFields.document(vectorValues.iterator().nextDoc()).get("id")); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + assertEquals("1", storedFields.document(iterator.nextDoc()).get("id")); assertEquals(-1f, vectorValues.vectorValue(0)[0], 0); - assertEquals("2", storedFields.document(vectorValues.iterator().nextDoc()).get("id")); + assertEquals("2", storedFields.document(iterator.nextDoc()).get("id")); assertEquals(1, vectorValues.vectorValue(1)[0], 0); - assertEquals("4", storedFields.document(vectorValues.iterator().nextDoc()).get("id")); + assertEquals("4", storedFields.document(iterator.nextDoc()).get("id")); assertEquals(0, vectorValues.vectorValue(2)[0], 0); - assertEquals(NO_MORE_DOCS, vectorValues.iterator().nextDoc()); + assertEquals(NO_MORE_DOCS, iterator.nextDoc()); } } } @@ -1215,27 +1225,30 @@ public void testIndexMultipleKnnVectorFields() throws Exception { FloatVectorValues vectorValues = leaf.getFloatVectorValues("field1"); assertEquals(2, vectorValues.dimension()); assertEquals(2, vectorValues.size()); - vectorValues.iterator().nextDoc(); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + iterator.nextDoc(); 
assertEquals(1f, vectorValues.vectorValue(0)[0], 0); - vectorValues.iterator().nextDoc(); + iterator.nextDoc(); assertEquals(2f, vectorValues.vectorValue(1)[0], 0); - assertEquals(NO_MORE_DOCS, vectorValues.iterator().nextDoc()); + assertEquals(NO_MORE_DOCS, iterator.nextDoc()); FloatVectorValues vectorValues2 = leaf.getFloatVectorValues("field2"); + KnnVectorValues.DocIndexIterator it2 = vectorValues2.iterator(); assertEquals(4, vectorValues2.dimension()); assertEquals(2, vectorValues2.size()); - vectorValues2.iterator().nextDoc(); + it2.nextDoc(); assertEquals(2f, vectorValues2.vectorValue(0)[1], 0); - vectorValues2.iterator().nextDoc(); + it2.nextDoc(); assertEquals(2f, vectorValues2.vectorValue(1)[1], 0); - assertEquals(NO_MORE_DOCS, vectorValues2.iterator().nextDoc()); + assertEquals(NO_MORE_DOCS, it2.nextDoc()); FloatVectorValues vectorValues3 = leaf.getFloatVectorValues("field3"); assertEquals(4, vectorValues3.dimension()); assertEquals(1, vectorValues3.size()); - vectorValues3.iterator().nextDoc(); + KnnVectorValues.DocIndexIterator it3 = vectorValues3.iterator(); + it3.nextDoc(); assertEquals(1f, vectorValues3.vectorValue(0)[0], 0.1); - assertEquals(NO_MORE_DOCS, vectorValues3.iterator().nextDoc()); + assertEquals(NO_MORE_DOCS, it3.nextDoc()); } } } @@ -1299,8 +1312,10 @@ public void testRandom() throws Exception { totalSize += vectorValues.size(); StoredFields storedFields = ctx.reader().storedFields(); int docId; - while ((docId = vectorValues.iterator().nextDoc()) != NO_MORE_DOCS) { - float[] v = vectorValues.vectorValue(vectorValues.iterator().index()); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + while (true) { + if (!((docId = iterator.nextDoc()) != NO_MORE_DOCS)) break; + float[] v = vectorValues.vectorValue(iterator.index()); assertEquals(dimension, v.length); String idString = storedFields.document(docId).getField("id").stringValue(); int id = Integer.parseInt(idString); @@ -1379,8 +1394,10 @@ public void 
testRandomBytes() throws Exception { totalSize += vectorValues.size(); StoredFields storedFields = ctx.reader().storedFields(); int docId; - while ((docId = vectorValues.iterator().nextDoc()) != NO_MORE_DOCS) { - byte[] v = vectorValues.vectorValue(vectorValues.iterator().index()); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + while (true) { + if (!((docId = iterator.nextDoc()) != NO_MORE_DOCS)) break; + byte[] v = vectorValues.vectorValue(iterator.index()); assertEquals(dimension, v.length); String idString = storedFields.document(docId).getField("id").stringValue(); int id = Integer.parseInt(idString); @@ -1499,8 +1516,10 @@ public void testRandomWithUpdatesAndGraph() throws Exception { StoredFields storedFields = ctx.reader().storedFields(); int docId; int numLiveDocsWithVectors = 0; - while ((docId = vectorValues.iterator().nextDoc()) != NO_MORE_DOCS) { - float[] v = vectorValues.vectorValue(vectorValues.iterator().index()); + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); + while (true) { + if (!((docId = iterator.nextDoc()) != NO_MORE_DOCS)) break; + float[] v = vectorValues.vectorValue(iterator.index()); assertEquals(dimension, v.length); String idString = storedFields.document(docId).getField("id").stringValue(); int id = Integer.parseInt(idString); @@ -1707,8 +1726,9 @@ public void testAdvance() throws Exception { FloatVectorValues vectorValues = r.getFloatVectorValues(fieldName); int[] vectorDocs = new int[vectorValues.size() + 1]; int cur = -1; + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); while (++cur < vectorValues.size() + 1) { - vectorDocs[cur] = vectorValues.iterator().nextDoc(); + vectorDocs[cur] = iterator.nextDoc(); if (cur != 0) { assertTrue(vectorDocs[cur] > vectorDocs[cur - 1]); } From 4feecf8ad246c5cabbc60a39e283d9e1ae0e921d Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Mon, 16 Sep 2024 07:31:01 -0700 Subject: [PATCH 16/25] Rename 
IteratorSupplier->SortingIteratorSupplier and add javadoc --- .../lucene/index/SortingCodecReader.java | 26 +++++++++++++------ 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index 392e89bab591..c8249c95347f 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -208,13 +208,16 @@ public PointValues.Relation compare(byte[] minPackedValue, byte[] maxPackedValue } } - // FIXME rename this better? - static class IteratorSupplier implements Supplier { + /** + * Factory for SortingValuesIterator. This enables us to create new iterators as needed without + * recomputing the sorting mappings. + */ + static class SortingIteratorSupplier implements Supplier { private final FixedBitSet docBits; private final int[] docToOrd; private final int size; - IteratorSupplier(FixedBitSet docBits, int[] docToOrd, int size) { + SortingIteratorSupplier(FixedBitSet docBits, int[] docToOrd, int size) { this.docBits = docBits; this.docToOrd = docToOrd; this.size = size; @@ -230,8 +233,15 @@ public int size() { } } - public static IteratorSupplier iteratorSupplier(KnnVectorValues values, Sorter.DocMap docMap) - throws IOException { + /** + * Creates a factory for SortingValuesIterator. Does the work of computing the (new docId to old + * ordinal) mapping, and caches the result, enabling it to create new iterators cheaply. + * + * @param values the values over which to iterate + * @param docMap the mapping from "old" docIds to "new" (sorted) docIds. 
+ */ + public static SortingIteratorSupplier iteratorSupplier( + KnnVectorValues values, Sorter.DocMap docMap) throws IOException { final int[] docToOrd = new int[docMap.size()]; final FixedBitSet docBits = new FixedBitSet(docMap.size()); @@ -247,7 +257,7 @@ public static IteratorSupplier iteratorSupplier(KnnVectorValues values, Sorter.D ++count; } } - return new IteratorSupplier(docBits, docToOrd, count); + return new SortingIteratorSupplier(docBits, docToOrd, count); } /** @@ -295,7 +305,7 @@ public int advance(int target) { /** Sorting FloatVectorValues that maps ordinals using the provided sortMap */ private static class SortingFloatVectorValues extends FloatVectorValues { final FloatVectorValues delegate; - final IteratorSupplier iteratorSupplier; + final SortingIteratorSupplier iteratorSupplier; SortingFloatVectorValues(FloatVectorValues delegate, Sorter.DocMap sortMap) throws IOException { this.delegate = delegate; @@ -332,7 +342,7 @@ public DocIndexIterator iterator() { private static class SortingByteVectorValues extends ByteVectorValues { final ByteVectorValues delegate; - final IteratorSupplier iteratorSupplier; + final SortingIteratorSupplier iteratorSupplier; SortingByteVectorValues(ByteVectorValues delegate, Sorter.DocMap sortMap) throws IOException { this.delegate = delegate; From abc17132d4175f455849713a285074e42e17a1fa Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Mon, 16 Sep 2024 07:52:54 -0700 Subject: [PATCH 17/25] cache vector values iterators in VectorFieldSources --- .../function/valuesource/ByteKnnVectorFieldSource.java | 6 ++++-- .../function/valuesource/FloatKnnVectorFieldSource.java | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java index 362a8e1f7853..c95bf632a73a 100644 --- 
a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/ByteKnnVectorFieldSource.java @@ -20,6 +20,7 @@ import java.util.Map; import java.util.Objects; import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorEncoding; @@ -63,11 +64,12 @@ protected DocIdSetIterator getVectorIterator() { } return new VectorFieldFunction(this) { + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); @Override public byte[] byteVectorVal(int doc) throws IOException { if (exists(doc)) { - return vectorValues.vectorValue(vectorValues.iterator().index()); + return vectorValues.vectorValue(iterator.index()); } else { return null; } @@ -75,7 +77,7 @@ public byte[] byteVectorVal(int doc) throws IOException { @Override protected DocIdSetIterator getVectorIterator() { - return vectorValues.iterator(); + return iterator; } }; } diff --git a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java index 824cc8a782df..f026d9537bc6 100644 --- a/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java +++ b/lucene/queries/src/java/org/apache/lucene/queries/function/valuesource/FloatKnnVectorFieldSource.java @@ -20,6 +20,7 @@ import java.util.Map; import java.util.Objects; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.KnnVectorValues; import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.index.VectorEncoding; @@ -62,11 +63,12 @@ protected DocIdSetIterator getVectorIterator() { } return new 
VectorFieldFunction(this) { + KnnVectorValues.DocIndexIterator iterator = vectorValues.iterator(); @Override public float[] floatVectorVal(int doc) throws IOException { if (exists(doc)) { - return vectorValues.vectorValue(vectorValues.iterator().index()); + return vectorValues.vectorValue(iterator.index()); } else { return null; } @@ -74,7 +76,7 @@ public float[] floatVectorVal(int doc) throws IOException { @Override protected DocIdSetIterator getVectorIterator() { - return vectorValues.iterator(); + return iterator; } }; } From 3f6091cea2c97977c7ce3e781ce8d237fb3311f1 Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Tue, 17 Sep 2024 11:42:29 -0700 Subject: [PATCH 18/25] rename KnnvectorValues.all() to createSparseIterator() --- .../backward_codecs/lucene90/Lucene90HnswVectorsReader.java | 2 +- .../backward_codecs/lucene91/Lucene91HnswVectorsReader.java | 2 +- .../lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java | 4 ++-- .../src/java/org/apache/lucene/index/KnnVectorValues.java | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index cc658ea714a8..1196ed3fdb64 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -410,7 +410,7 @@ public int ordToDoc(int ord) { @Override public DocIndexIterator iterator() { - return all(); + return createSparseIterator(); } @Override diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index 42c9c0257514..a140b4fd7f39 100644 --- 
a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -454,7 +454,7 @@ public int ordToDoc(int ord) { @Override public DocIndexIterator iterator() { - return all(); + return createSparseIterator(); } @Override diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index ce03705fa467..0135a4a995c2 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -338,7 +338,7 @@ public int ordToDoc(int ord) { @Override public DocIndexIterator iterator() { - return all(); + return createSparseIterator(); } @Override @@ -431,7 +431,7 @@ public byte[] vectorValue(int ord) { @Override public DocIndexIterator iterator() { - return all(); + return createSparseIterator(); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java index e7d6595dd5d9..a43ad6296e86 100644 --- a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java @@ -192,7 +192,7 @@ public long cost() { * Creates an iterator from this instance's ordinal-to-docid mapping which must be monotonic * (docid increases when ordinal does). 
*/ - protected DocIndexIterator all() { + protected DocIndexIterator createSparseIterator() { return new DocIndexIterator() { private int ord = -1; From d8ab1ec901b9323297507ea540e61265c7e8ba23 Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Wed, 18 Sep 2024 09:53:02 -0700 Subject: [PATCH 19/25] implement cost(); enforce forward iteration in KnnVectorsWriter --- .../lucene/codecs/KnnVectorsWriter.java | 33 +++++++++++++++---- .../Lucene99ScalarQuantizedVectorsWriter.java | 5 +++ .../lucene/index/ExitableDirectoryReader.java | 5 +++ .../apache/lucene/index/KnnVectorValues.java | 5 --- .../SlowCompositeCodecReaderWrapper.java | 5 +++ .../lucene/index/SortingCodecReader.java | 5 +++ .../quantization/TestScalarQuantizer.java | 5 +++ 7 files changed, 51 insertions(+), 12 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java index 0ff782781ffe..cbf68dd8c20d 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsWriter.java @@ -298,7 +298,8 @@ static class MergedFloat32VectorValues extends FloatVectorValues { private final List subs; private final DocIDMerger docIdMerger; private final int size; - private int docId; + private int docId = -1; + private int lastOrd = -1; FloatVectorValuesSub current; private MergedFloat32VectorValues(List subs, MergeState mergeState) @@ -310,7 +311,6 @@ private MergedFloat32VectorValues(List subs, MergeState me totalSize += sub.values.size(); } size = totalSize; - docId = -1; } @Override @@ -345,13 +345,22 @@ public int nextDoc() throws IOException { public int advance(int target) throws IOException { throw new UnsupportedOperationException(); } + + @Override + public long cost() { + return size; + } }; } @Override public float[] vectorValue(int ord) throws IOException { - // FIXME what can we assert here? 
- // assert ord == iterator.index(); + if (ord != lastOrd + 1) { + throw new IllegalStateException( + "only supports forward iteration: ord=" + ord + ", lastOrd=" + lastOrd); + } else { + lastOrd = ord; + } return current.values.vectorValue(current.index()); } @@ -386,7 +395,8 @@ static class MergedByteVectorValues extends ByteVectorValues { private final DocIDMerger docIdMerger; private final int size; - private int docId; + private int lastOrd = -1; + private int docId = -1; ByteVectorValuesSub current; private MergedByteVectorValues(List subs, MergeState mergeState) @@ -402,8 +412,12 @@ private MergedByteVectorValues(List subs, MergeState mergeS @Override public byte[] vectorValue(int ord) throws IOException { - // FIXME - // assert ord == iterator.index(); + if (ord != lastOrd + 1) { + throw new IllegalStateException( + "only supports forward iteration: ord=" + ord + ", lastOrd=" + lastOrd); + } else { + lastOrd = ord; + } return current.values.vectorValue(current.index()); } @@ -439,6 +453,11 @@ public int nextDoc() throws IOException { public int advance(int target) throws IOException { throw new UnsupportedOperationException(); } + + @Override + public long cost() { + return size; + } }; } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index 6e6b4b0e50cd..4d3c33940a56 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -1031,6 +1031,11 @@ public int nextDoc() throws IOException { public int advance(int target) throws IOException { throw new UnsupportedOperationException(); } + + @Override + public long cost() { + return size; + } } } diff --git a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java 
b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java index c7531207c23d..614a652cd35a 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java @@ -540,6 +540,11 @@ public int nextDoc() throws IOException { return doc; } + @Override + public long cost() { + return delegate.cost(); + } + @Override public int advance(int target) throws IOException { int doc = delegate.advance(target); diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java index a43ad6296e86..ffcb1686adb2 100644 --- a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java @@ -98,11 +98,6 @@ public abstract static class DocIndexIterator extends DocIdSetIterator { /** return the value index (aka "ordinal" or "ord") corresponding to the current doc */ public abstract int index(); - - @Override - public long cost() { - throw new UnsupportedOperationException(); - } } /** diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java index e0634df91e37..ab18edde0ed8 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java @@ -344,6 +344,11 @@ public int nextDoc() throws IOException { } } + @Override + public long cost() { + throw new UnsupportedOperationException(); + } + @Override public int advance(int target) throws IOException { throw new UnsupportedOperationException(); diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index c8249c95347f..daec0c197d6a 100644 --- 
a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -296,6 +296,11 @@ public int nextDoc() throws IOException { return doc; } + @Override + public long cost() { + return docBits.cardinality(); + } + @Override public int advance(int target) { throw new UnsupportedOperationException(); diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java index b5b37d291a50..50313de2ae17 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java @@ -344,6 +344,11 @@ public int index() { return ord; } + @Override + public long cost() { + return floats.length - deletedVectors.size(); + } + @Override public int advance(int target) throws IOException { throw new UnsupportedOperationException(); From a2ca172ab0abe8568485599d5e222904b6f8c40a Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Wed, 18 Sep 2024 17:19:48 -0700 Subject: [PATCH 20/25] add implementations of KnnVectorValues.copy() --- .../SimpleTextKnnVectorsReader.java | 10 ++++++++++ .../Lucene99ScalarQuantizedVectorsReader.java | 9 +++++++-- .../Lucene99ScalarQuantizedVectorsWriter.java | 12 +++++++++-- .../apache/lucene/index/ByteVectorValues.java | 8 +++----- .../lucene/index/FloatVectorValues.java | 8 +++----- .../apache/lucene/index/KnnVectorValues.java | 2 -- .../SlowCompositeCodecReaderWrapper.java | 20 +++++++++++++++++++ .../quantization/TestScalarQuantizer.java | 5 +++++ .../lucene/index/memory/MemoryIndex.java | 10 ++++++++++ 9 files changed, 68 insertions(+), 16 deletions(-) diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index 
0135a4a995c2..97a518701b00 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -382,6 +382,11 @@ private void readVector(float[] value) throws IOException { value[i] = Float.parseFloat(floatStrings[i]); } } + + @Override + public SimpleTextFloatVectorValues copy() { + return this; + } } private static class SimpleTextByteVectorValues extends ByteVectorValues { @@ -475,6 +480,11 @@ private void readVector(byte[] value) throws IOException { value[i] = (byte) Float.parseFloat(floatStrings[i]); } } + + @Override + public SimpleTextByteVectorValues copy() { + return this; + } } private int readInt(IndexInput in, BytesRef field) throws IOException { diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java index 5f846cc21410..24123a4f21e3 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsReader.java @@ -402,10 +402,10 @@ static FieldEntry create( private static final class QuantizedVectorValues extends FloatVectorValues { private final FloatVectorValues rawVectorValues; - private final OffHeapQuantizedByteVectorValues quantizedVectorValues; + private final QuantizedByteVectorValues quantizedVectorValues; QuantizedVectorValues( - FloatVectorValues rawVectorValues, OffHeapQuantizedByteVectorValues quantizedVectorValues) { + FloatVectorValues rawVectorValues, QuantizedByteVectorValues quantizedVectorValues) { this.rawVectorValues = rawVectorValues; this.quantizedVectorValues = quantizedVectorValues; } @@ -430,6 +430,11 @@ public int ordToDoc(int ord) { return rawVectorValues.ordToDoc(ord); } + @Override + public 
QuantizedVectorValues copy() throws IOException { + return new QuantizedVectorValues(rawVectorValues.copy(), quantizedVectorValues.copy()); + } + @Override public VectorScorer scorer(float[] query) throws IOException { return quantizedVectorValues.scorer(query); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java index 4d3c33940a56..1a30b5271cd7 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99ScalarQuantizedVectorsWriter.java @@ -853,7 +853,6 @@ public DocsWithFieldSet getDocsWithFieldSet() { static class FloatVectorWrapper extends FloatVectorValues { private final List vectorList; - protected int curDoc = -1; FloatVectorWrapper(List vectorList) { this.vectorList = vectorList; @@ -869,6 +868,11 @@ public int size() { return vectorList.size(); } + @Override + public FloatVectorValues copy() throws IOException { + return this; + } + @Override public float[] vectorValue(int ord) throws IOException { if (ord < 0 || ord >= vectorList.size()) { @@ -1196,7 +1200,6 @@ public DocIndexIterator iterator() { static final class NormalizedFloatVectorValues extends FloatVectorValues { private final FloatVectorValues values; private final float[] normalizedVector; - int curOrd = -1; public NormalizedFloatVectorValues(FloatVectorValues values) { this.values = values; @@ -1229,5 +1232,10 @@ public float[] vectorValue(int ord) throws IOException { public DocIndexIterator iterator() { return values.iterator(); } + + @Override + public NormalizedFloatVectorValues copy() throws IOException { + return new NormalizedFloatVectorValues(values.copy()); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java 
index bb84ba51ef8c..e9be3423c181 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java @@ -32,11 +32,6 @@ public abstract class ByteVectorValues extends KnnVectorValues { /** Sole constructor */ protected ByteVectorValues() {} - @Override - public ByteVectorValues copy() throws IOException { - return this; - } - /** * Return the vector value for the given vector ordinal which must be in [0, size() - 1], * otherwise IndexOutOfBoundsException is thrown. The returned array may be shared across calls. @@ -45,6 +40,9 @@ public ByteVectorValues copy() throws IOException { */ public abstract byte[] vectorValue(int ord) throws IOException; + @Override + public abstract ByteVectorValues copy() throws IOException; + /** * Checks the Vector Encoding of a field * diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index d35cdcea5c86..aa840fc39319 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -32,11 +32,6 @@ public abstract class FloatVectorValues extends KnnVectorValues { /** Sole constructor */ protected FloatVectorValues() {} - @Override - public FloatVectorValues copy() throws IOException { - return this; - } - /** * Return the vector value for the given vector ordinal which must be in [0, size() - 1], * otherwise IndexOutOfBoundsException is thrown. The returned array may be shared across calls. 
@@ -45,6 +40,9 @@ public FloatVectorValues copy() throws IOException { */ public abstract float[] vectorValue(int ord) throws IOException; + @Override + public abstract FloatVectorValues copy() throws IOException; + /** * Checks the Vector Encoding of a field * diff --git a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java index ffcb1686adb2..8e58f387a334 100644 --- a/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/KnnVectorValues.java @@ -53,8 +53,6 @@ public int ordToDoc(int ord) { * Creates a new copy of this {@link KnnVectorValues}. This is helpful when you need to access * different values at once, to avoid overwriting the underlying vector returned. */ - // FIXME: replace with some kind of wrapper??? Or can we eliminate completely given - // that we no longer have seek() API? public abstract KnnVectorValues copy() throws IOException; /** Returns the vector byte length, defaults to dimension multiplied by float byte size */ diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java index ab18edde0ed8..3f4dc75ffecd 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java @@ -872,6 +872,16 @@ public int size() { return size; } + @SuppressWarnings("unchecked") + @Override + public FloatVectorValues copy() throws IOException { + List> subsCopy = new ArrayList<>(); + for (Object sub : subs) { + subsCopy.add((DocValuesSub) sub); + } + return new MergedFloatVectorValues(dimension, size, subsCopy); + } + @Override public float[] vectorValue(int ord) throws IOException { assert ord >= 0 && ord < size; @@ -951,6 +961,16 @@ public byte[] vectorValue(int ord) throws IOException { 
return ((ByteVectorValues) subs[lastSubIndex].sub) .vectorValue(ord - subs[lastSubIndex].ordStart); } + + @SuppressWarnings("unchecked") + @Override + public ByteVectorValues copy() throws IOException { + List> newSubs = new ArrayList<>(); + for (Object sub : subs) { + newSubs.add((DocValuesSub) sub); + } + return new MergedByteVectorValues(dimension, size, newSubs); + } } private static int findSub(int ord, int lastSubIndex, int[] starts) { diff --git a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java index 50313de2ae17..7f56688b7999 100644 --- a/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java +++ b/lucene/core/src/test/org/apache/lucene/util/quantization/TestScalarQuantizer.java @@ -360,5 +360,10 @@ public int advance(int target) throws IOException { public VectorScorer scorer(float[] target) { throw new UnsupportedOperationException(); } + + @Override + public TestSimpleFloatVectorValues copy() { + return this; + } } } diff --git a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java index 11147decf25c..ce86365dab26 100644 --- a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java +++ b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java @@ -2340,6 +2340,11 @@ public DocIdSetIterator iterator() { } }; } + + @Override + public MemoryFloatVectorValues copy() { + return this; + } } private static final class MemoryByteVectorValues extends ByteVectorValues { @@ -2399,5 +2404,10 @@ public DocIdSetIterator iterator() { } }; } + + @Override + public MemoryByteVectorValues copy() { + return this; + } } } From 2b21668bd25976c85ed5764db48532c561482cea Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Fri, 20 Sep 2024 14:54:28 -0400 Subject: [PATCH 21/25] fix SlowCOmpositeCodecReaderWrapper; 
off-by-one AND lazy iterator access --- .../lucene/index/SlowCompositeCodecReaderWrapper.java | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java index 22d0b74b8464..7f5d8926b638 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java @@ -308,12 +308,14 @@ private static class MergedDocIterator final Iterator> it; DocValuesSub current; + KnnVectorValues.DocIndexIterator currentIterator; int ord = -1; int doc = -1; MergedDocIterator(List> subs) { this.it = subs.iterator(); current = it.next(); + currentIterator = current.sub.iterator(); } @Override @@ -330,7 +332,7 @@ public int index() { public int nextDoc() throws IOException { while (true) { if (current.sub != null) { - int next = current.sub.iterator().nextDoc(); + int next = currentIterator.nextDoc(); if (next != NO_MORE_DOCS) { ++ord; return doc = current.docStart + next; @@ -341,6 +343,7 @@ public int nextDoc() throws IOException { return doc = NO_MORE_DOCS; } current = it.next(); + currentIterator = current.sub.iterator(); ord = current.ordStart - 1; } } @@ -987,11 +990,11 @@ private static int findSub(int ord, int lastSubIndex, int[] starts) { private static int binarySearchStarts(int[] starts, int ord, int from, int to) { int pos = Arrays.binarySearch(starts, from, to, ord); - // also subtract one since starts[] is shifted by one if (pos < 0) { + // subtract one since binarySearch returns an *insertion point* return -2 - pos; } else { - return pos - 1; + return pos; } } From cb62025c536fc41310af8e158ba516d22d4b012c Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Fri, 20 Sep 2024 19:26:09 -0400 Subject: [PATCH 22/25] resolve merge conflicts --- .../index/BaseKnnVectorsFormatTestCase.java | 22 
++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index 23e79663b8c4..e42d3e189819 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -1882,25 +1882,27 @@ public void testMismatchedFields() throws Exception { ByteVectorValues byteVectors = leafReader.getByteVectorValues("byte"); assertNotNull(byteVectors); - assertEquals(0, byteVectors.nextDoc()); - assertArrayEquals(new byte[] {42}, byteVectors.vectorValue()); - assertEquals(1, byteVectors.nextDoc()); - assertArrayEquals(new byte[] {42}, byteVectors.vectorValue()); - assertEquals(DocIdSetIterator.NO_MORE_DOCS, byteVectors.nextDoc()); + KnnVectorValues.DocIndexIterator iter = byteVectors.iterator(); + assertEquals(0, iter.nextDoc()); + assertArrayEquals(new byte[] {42}, byteVectors.vectorValue(0)); + assertEquals(1, iter.nextDoc()); + assertArrayEquals(new byte[] {42}, byteVectors.vectorValue(1)); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iter.nextDoc()); FloatVectorValues floatVectors = leafReader.getFloatVectorValues("float"); assertNotNull(floatVectors); - assertEquals(0, floatVectors.nextDoc()); - float[] vector = floatVectors.vectorValue(); + iter = floatVectors.iterator(); + assertEquals(0, iter.nextDoc()); + float[] vector = floatVectors.vectorValue(0); assertEquals(2, vector.length); assertEquals(1f, vector[0], 0f); assertEquals(2f, vector[1], 0f); - assertEquals(1, floatVectors.nextDoc()); - vector = floatVectors.vectorValue(); + assertEquals(1, iter.nextDoc()); + vector = floatVectors.vectorValue(1); assertEquals(2, vector.length); assertEquals(1f, vector[0], 0f); assertEquals(2f, vector[1], 0f); - 
assertEquals(DocIdSetIterator.NO_MORE_DOCS, floatVectors.nextDoc()); + assertEquals(DocIdSetIterator.NO_MORE_DOCS, iter.nextDoc()); IOUtils.close(reader, w2, dir1, dir2); } From 29c9e00fbb96f6c4a07a96b014ec6be52f243e69 Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Thu, 26 Sep 2024 13:52:00 -0400 Subject: [PATCH 23/25] fix NPE introduced in recent patch when segment has no vectors --- .../index/SlowCompositeCodecReaderWrapper.java | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java index 7f5d8926b638..ec2423310d93 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java @@ -315,7 +315,7 @@ private static class MergedDocIterator MergedDocIterator(List> subs) { this.it = subs.iterator(); current = it.next(); - currentIterator = current.sub.iterator(); + currentIterator = currentIterator(); } @Override @@ -343,11 +343,19 @@ public int nextDoc() throws IOException { return doc = NO_MORE_DOCS; } current = it.next(); - currentIterator = current.sub.iterator(); + currentIterator = currentIterator(); ord = current.ordStart - 1; } } + private KnnVectorValues.DocIndexIterator currentIterator() { + if (current.sub != null) { + return current.sub.iterator(); + } else { + return null; + } + } + @Override public long cost() { throw new UnsupportedOperationException(); From e219f3b7fa205c9125a66ff27d3420aa2bed1e29 Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Thu, 26 Sep 2024 13:55:06 -0400 Subject: [PATCH 24/25] fix failing test due to leaking static in test class when iters>0 --- .../org/apache/lucene/search/TestTimeLimitingBulkScorer.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/lucene/core/src/test/org/apache/lucene/search/TestTimeLimitingBulkScorer.java b/lucene/core/src/test/org/apache/lucene/search/TestTimeLimitingBulkScorer.java index cce82cd34ac0..de5512a904a1 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestTimeLimitingBulkScorer.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestTimeLimitingBulkScorer.java @@ -113,7 +113,7 @@ public long cost() { private static QueryTimeout countingQueryTimeout(int timeallowed) { return new QueryTimeout() { - static int counter = 0; + int counter = 0; @Override public boolean shouldExit() { From a8dfe680c092d1525bc202b37d9467e92fad75b1 Mon Sep 17 00:00:00 2001 From: Michael Sokolov Date: Fri, 27 Sep 2024 08:47:55 -0400 Subject: [PATCH 25/25] remove stray print in test --- .../src/test/org/apache/lucene/util/hnsw/TestHnswUtil.java | 6 ------ 1 file changed, 6 deletions(-) diff --git a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswUtil.java b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswUtil.java index 3ec133ac46e0..316afff5ee25 100644 --- a/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswUtil.java +++ b/lucene/core/src/test/org/apache/lucene/util/hnsw/TestHnswUtil.java @@ -138,12 +138,6 @@ public void testRandom() throws Exception { } } MockGraph graph = new MockGraph(nodes); - /**/ - if (i == 2) { - System.out.println("iter " + i); - System.out.print(graph.toString()); - } - /**/ assertEquals(isRooted(nodes), HnswUtil.isRooted(graph)); } }