Skip to content

Commit

Permalink
First-class random access API for KnnVectorValues (#13779)
Browse files Browse the repository at this point in the history
  • Loading branch information
msokolov authored Sep 28, 2024
1 parent 7b4b023 commit 6053e1e
Show file tree
Hide file tree
Showing 83 changed files with 2,058 additions and 2,314 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,18 @@
package org.apache.lucene.analysis.synonym.word2vec;

import java.io.IOException;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefHash;
import org.apache.lucene.util.TermAndVector;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

/**
* Word2VecModel is a class representing the parsed Word2Vec model containing the vectors for each
* word in dictionary
*
* @lucene.experimental
*/
public class Word2VecModel implements RandomAccessVectorValues.Floats {
public class Word2VecModel extends FloatVectorValues {

private final int dictionarySize;
private final int vectorDimension;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
import java.util.Objects;
import java.util.SplittableRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

/**
* Builder for HNSW graph. See {@link Lucene90OnHeapHnswGraph} for a gloss on the algorithm and the
Expand All @@ -49,7 +49,7 @@ public final class Lucene90HnswGraphBuilder {
private final Lucene90NeighborArray scratch;

private final VectorSimilarityFunction similarityFunction;
private final RandomAccessVectorValues.Floats vectorValues;
private final FloatVectorValues vectorValues;
private final SplittableRandom random;
private final Lucene90BoundsChecker bound;
final Lucene90OnHeapHnswGraph hnsw;
Expand All @@ -58,7 +58,7 @@ public final class Lucene90HnswGraphBuilder {

// we need two sources of vectors in order to perform diversity check comparisons without
// colliding
private final RandomAccessVectorValues.Floats buildVectors;
private final FloatVectorValues buildVectors;

/**
* Reads all the vectors from vector values, builds a graph connecting them by their dense
Expand All @@ -73,7 +73,7 @@ public final class Lucene90HnswGraphBuilder {
* to ensure repeatable construction.
*/
public Lucene90HnswGraphBuilder(
RandomAccessVectorValues.Floats vectors,
FloatVectorValues vectors,
VectorSimilarityFunction similarityFunction,
int maxConn,
int beamWidth,
Expand All @@ -97,14 +97,14 @@ public Lucene90HnswGraphBuilder(
}

/**
* Reads all the vectors from two copies of a {@link RandomAccessVectorValues}. Providing two
* copies enables efficient retrieval without extra data copying, while avoiding collision of the
* Reads all the vectors from two copies of a {@link FloatVectorValues}. Providing two copies
* enables efficient retrieval without extra data copying, while avoiding collision of the
* returned values.
*
* @param vectors the vectors for which to build a nearest neighbors graph. Must be an independent
* accessor for the vectors
*/
public Lucene90OnHeapHnswGraph build(RandomAccessVectorValues.Floats vectors) throws IOException {
public Lucene90OnHeapHnswGraph build(FloatVectorValues vectors) throws IOException {
if (vectors == vectorValues) {
throw new IllegalArgumentException(
"Vectors to build must be independent of the source of vectors provided to HnswGraphBuilder()");
Expand Down Expand Up @@ -230,7 +230,7 @@ private boolean diversityCheck(
float[] candidate,
float score,
Lucene90NeighborArray neighbors,
RandomAccessVectorValues.Floats vectorValues)
FloatVectorValues vectorValues)
throws IOException {
bound.set(score);
for (int i = 0; i < neighbors.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.SplittableRandom;
Expand All @@ -34,7 +33,6 @@
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.ChecksumIndexInput;
Expand All @@ -44,7 +42,6 @@
import org.apache.lucene.util.IOUtils;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

/**
* Reads vectors from the index segments along with index data structures supporting KNN search.
Expand Down Expand Up @@ -355,8 +352,7 @@ int size() {
}

/** Read the vector values from the index input. This supports both iterated and random access. */
static class OffHeapFloatVectorValues extends FloatVectorValues
implements RandomAccessVectorValues.Floats {
static class OffHeapFloatVectorValues extends FloatVectorValues {

final int dimension;
final int[] ordToDoc;
Expand All @@ -367,9 +363,6 @@ static class OffHeapFloatVectorValues extends FloatVectorValues
final float[] value;
final VectorSimilarityFunction similarityFunction;

int ord = -1;
int doc = -1;

OffHeapFloatVectorValues(
int dimension,
int[] ordToDoc,
Expand All @@ -394,42 +387,6 @@ public int size() {
return ordToDoc.length;
}

@Override
public float[] vectorValue() throws IOException {
return vectorValue(ord);
}

@Override
public int docID() {
return doc;
}

@Override
public int nextDoc() {
if (++ord >= size()) {
doc = NO_MORE_DOCS;
} else {
doc = ordToDoc[ord];
}
return doc;
}

@Override
public int advance(int target) {
assert docID() < target;
ord = Arrays.binarySearch(ordToDoc, ord + 1, ordToDoc.length, target);
if (ord < 0) {
ord = -(ord + 1);
}
assert ord <= ordToDoc.length;
if (ord == ordToDoc.length) {
doc = NO_MORE_DOCS;
} else {
doc = ordToDoc[ord];
}
return doc;
}

@Override
public OffHeapFloatVectorValues copy() {
return new OffHeapFloatVectorValues(dimension, ordToDoc, similarityFunction, dataIn.clone());
Expand All @@ -446,21 +403,32 @@ public float[] vectorValue(int targetOrd) throws IOException {
return value;
}

@Override
public int ordToDoc(int ord) {
return ordToDoc[ord];
}

@Override
public DocIndexIterator iterator() {
return createSparseIterator();
}

@Override
public VectorScorer scorer(float[] target) {
if (size() == 0) {
return null;
}
OffHeapFloatVectorValues values = this.copy();
DocIndexIterator iterator = values.iterator();
return new VectorScorer() {
@Override
public float score() throws IOException {
return values.similarityFunction.compare(values.vectorValue(), target);
return values.similarityFunction.compare(values.vectorValue(iterator.index()), target);
}

@Override
public DocIdSetIterator iterator() {
return values;
public DocIndexIterator iterator() {
return iterator;
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@
import java.util.ArrayList;
import java.util.List;
import java.util.SplittableRandom;
import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.util.Bits;
import org.apache.lucene.util.SparseFixedBitSet;
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.NeighborQueue;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;

/**
* An {@link HnswGraph} where all nodes and connections are held in memory. This class is used to
Expand Down Expand Up @@ -74,7 +74,7 @@ public static NeighborQueue search(
float[] query,
int topK,
int numSeed,
RandomAccessVectorValues.Floats vectors,
FloatVectorValues vectors,
VectorSimilarityFunction similarityFunction,
HnswGraph graphValues,
Bits acceptOrds,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import org.apache.lucene.util.hnsw.HnswGraph;
import org.apache.lucene.util.hnsw.HnswGraphSearcher;
import org.apache.lucene.util.hnsw.OrdinalTranslatedKnnCollector;
import org.apache.lucene.util.hnsw.RandomAccessVectorValues;
import org.apache.lucene.util.hnsw.RandomVectorScorer;

/**
Expand Down Expand Up @@ -398,8 +397,7 @@ int ordToDoc(int ord) {
}

/** Read the vector values from the index input. This supports both iterated and random access. */
static class OffHeapFloatVectorValues extends FloatVectorValues
implements RandomAccessVectorValues.Floats {
static class OffHeapFloatVectorValues extends FloatVectorValues {

private final int dimension;
private final int size;
Expand All @@ -410,9 +408,6 @@ static class OffHeapFloatVectorValues extends FloatVectorValues
private final float[] value;
private final VectorSimilarityFunction similarityFunction;

private int ord = -1;
private int doc = -1;

OffHeapFloatVectorValues(
int dimension,
int size,
Expand All @@ -439,49 +434,6 @@ public int size() {
return size;
}

@Override
public float[] vectorValue() throws IOException {
dataIn.seek((long) ord * byteSize);
dataIn.readFloats(value, 0, value.length);
return value;
}

@Override
public int docID() {
return doc;
}

@Override
public int nextDoc() {
if (++ord >= size) {
doc = NO_MORE_DOCS;
} else {
doc = ordToDocOperator.applyAsInt(ord);
}
return doc;
}

@Override
public int advance(int target) {
assert docID() < target;

if (ordToDoc == null) {
ord = target;
} else {
ord = Arrays.binarySearch(ordToDoc, ord + 1, ordToDoc.length, target);
if (ord < 0) {
ord = -(ord + 1);
}
}

if (ord < size) {
doc = ordToDocOperator.applyAsInt(ord);
} else {
doc = NO_MORE_DOCS;
}
return doc;
}

@Override
public OffHeapFloatVectorValues copy() {
return new OffHeapFloatVectorValues(
Expand All @@ -495,21 +447,32 @@ public float[] vectorValue(int targetOrd) throws IOException {
return value;
}

@Override
public int ordToDoc(int ord) {
return ordToDocOperator.applyAsInt(ord);
}

@Override
public DocIndexIterator iterator() {
return createSparseIterator();
}

@Override
public VectorScorer scorer(float[] target) {
if (size == 0) {
return null;
}
OffHeapFloatVectorValues values = this.copy();
DocIndexIterator iterator = values.iterator();
return new VectorScorer() {
@Override
public float score() throws IOException {
return values.similarityFunction.compare(values.vectorValue(), target);
return values.similarityFunction.compare(values.vectorValue(iterator.index()), target);
}

@Override
public DocIdSetIterator iterator() {
return values;
return iterator;
}
};
}
Expand Down
Loading

0 comments on commit 6053e1e

Please sign in to comment.