Commit 93c1f96

Explain API changes

Signed-off-by: Neetika Singhal <[email protected]>
neetikasinghal committed Jan 17, 2025
1 parent eecb45c commit 93c1f96
Showing 6 changed files with 129 additions and 14 deletions.
4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
@@ -22,6 +22,10 @@ public class KNNConstants {
public static final String PATH = "path";
public static final String QUERY = "query";
public static final String KNN = "knn";
public static final String EXACT_SEARCH = "Exact";
public static final String ANN_SEARCH = "Approximate-NN";
public static final String RADIAL_SEARCH = "Radial";
public static final String DISK_BASED_SEARCH = "Disk-based";
public static final String VECTOR = "vector";
public static final String K = "k";
public static final String TYPE_KNN_VECTOR = "knn_vector";
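These four labels name the search strategies the new explain output reports: exact, approximate (ANN), radial, and disk-based. Pieced together from the strings built in KNNWeight.explain below, a rendered explanation could look roughly like this (the score and parameters are illustrative, not real output):

    0.36 = the type of knn search executed was Approximate-NN
      0.36 = the type of knn search executed at leaf was Approximate-NN with spaceType = l2, vectorDataType = FLOAT, queryVector = [1.0, 2.5]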
119 changes: 111 additions & 8 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
@@ -46,6 +46,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.stream.Collectors;

@@ -73,6 +74,7 @@ public class KNNWeight extends Weight {

private static ExactSearcher DEFAULT_EXACT_SEARCHER;
private final QuantizationService quantizationService;
private final Map<Object, Integer> annResultPerLeaf;

public KNNWeight(KNNQuery query, float boost) {
super(query);
@@ -82,6 +84,7 @@ public KNNWeight(KNNQuery query, float boost) {
this.filterWeight = null;
this.exactSearcher = DEFAULT_EXACT_SEARCHER;
this.quantizationService = QuantizationService.getInstance();
annResultPerLeaf = new ConcurrentHashMap<>();
}

public KNNWeight(KNNQuery query, float boost, Weight filterWeight) {
@@ -92,6 +95,7 @@ public KNNWeight(KNNQuery query, float boost, Weight filterWeight) {
this.filterWeight = filterWeight;
this.exactSearcher = DEFAULT_EXACT_SEARCHER;
this.quantizationService = QuantizationService.getInstance();
annResultPerLeaf = new ConcurrentHashMap<>();
}

public static void initialize(ModelDao modelDao) {
@@ -106,7 +110,97 @@ static void initialize(ModelDao modelDao, ExactSearcher exactSearcher) {

@Override
public Explanation explain(LeafReaderContext context, int doc) {
-return Explanation.match(1.0f, "No Explanation");
+float score = 0;
try {
final KNNScorer knnScorer = (KNNScorer) scorer(context);
int resDoc = knnScorer.iterator().advance(doc);
if (resDoc == doc) {
score = knnScorer.score();
}
} catch (IOException e) {
throw new RuntimeException(e);
}

final StringBuilder sb = new StringBuilder();
sb.append("the type of knn search executed was ");
if (knnQuery.getRescoreContext() != null) {
sb.append(KNNConstants.DISK_BASED_SEARCH);
boolean isShardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName());
int dimension = knnQuery.getQueryVector().length;
int firstPassK = knnQuery.getRescoreContext().getFirstPassK(knnQuery.getK(), isShardLevelRescoringDisabled, dimension);
sb.append(" and the first pass k was ").append(firstPassK)
.append( " with vector dimension of ").append(dimension)
.append( ", over sampling factor of ").append(knnQuery.getRescoreContext().getOversampleFactor());
if (isShardLevelRescoringDisabled) {
sb.append(", shard level rescoring disabled");
} else {
sb.append(", shard level rescoring enabled");
}
} else if (knnQuery.getRadius() != null) {
sb.append(KNNConstants.RADIAL_SEARCH);
sb.append(" with the radius of ").append(knnQuery.getRadius());
} else {
sb.append(KNNConstants.ANN_SEARCH);
}

int cardinality = 0;
try {
cardinality = getFilteredDocsBitSet(context).cardinality();
} catch (IOException e) {
// cardinality stays 0 if the filtered-docs bitset cannot be read
}
int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName());
StringBuilder stringBuilder = new StringBuilder("the type of knn search executed at leaf was ");
if (filterWeight != null) {
if (isFilterIdCountLessThanK(cardinality)) {
stringBuilder.append(KNNConstants.EXACT_SEARCH)
.append(" since filteredIds = ")
.append(cardinality)
.append(" is less than or equal to K = ")
.append(knnQuery.getK());
} else if (isExactSearchThresholdSettingSet(filterThresholdValue) && (filterThresholdValue >= cardinality)) {
stringBuilder.append(KNNConstants.EXACT_SEARCH)
.append(" since filtered threshold value = ")
.append(filterThresholdValue)
.append(" is greater than or equal to cardinality = ")
.append(cardinality);
} else if (isMDCGreaterThanFilterIdCnt(cardinality)) {
stringBuilder.append(KNNConstants.EXACT_SEARCH)
.append(" since max distance computation = ")
.append(KNNConstants.MAX_DISTANCE_COMPUTATIONS)
.append(" is greater than or equal to cardinality = ")
.append(cardinality);
}
} else if (annResultPerLeaf.get(context.id()) == 0 && isMissingNativeEngineFiles(context)) {
stringBuilder.append(KNNConstants.EXACT_SEARCH)
.append(" since no native engine files are available");
} else if (isFilteredExactSearchRequireAfterANNSearch(cardinality, annResultPerLeaf.get(context.id()))) {
stringBuilder.append(KNNConstants.EXACT_SEARCH)
.append(" since the number of documents returned are less than K = ")
.append(knnQuery.getK())
.append(" and there are more than K filtered Ids = ")
.append(cardinality);
} else {
stringBuilder.append(KNNConstants.ANN_SEARCH);
}

final SegmentReader reader = Lucene.segmentReader(context.reader());
final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField());

if (fieldInfo != null) {
// need to check on space type for exact search/model space/disk/default based search for binary vectors
stringBuilder.append(" with spaceType = ").append(FieldInfoExtractor.getSpaceType(modelDao, fieldInfo).getValue());
}
stringBuilder.append(", vectorDataType = ").append(knnQuery.getVectorDataType());

// optional - need to see if we need to add the query vector
if (knnQuery.getQueryVector() != null) {
stringBuilder.append(", queryVector = ").append(Arrays.toString(knnQuery.getQueryVector()));
}
if (knnQuery.getByteQueryVector() != null) {
stringBuilder.append(", byteVector = ").append(Arrays.toString(knnQuery.getByteQueryVector()));
}

return Explanation.match(score, sb.toString(), Explanation.match(score, stringBuilder.toString()));
}

@Override
@@ -153,7 +247,7 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOException
*/
final BitSet annFilter = (filterWeight != null && cardinality == maxDoc) ? null : filterBitSet;
final Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, annFilter, cardinality, k);
-
+annResultPerLeaf.put(context.id(), docIdsToScoreMap.size());
// See whether we have to perform exact search based on approx search results
// This is required if there are no native engine files, or if the approximate search returned
// fewer than k results even though more than k docs passed the filter
@@ -425,22 +519,31 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) {
);
int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName());
// Refer to this GitHub issue for more details on the logic: https://github.com/opensearch-project/k-NN/issues/1049
-if (knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK()) {
-return true;
-}
+if (isFilterIdCountLessThanK(filterIdsCount)) return true;
// See if the user has defined an Exact Search filtered threshold. If yes, then use that setting.
if (isExactSearchThresholdSettingSet(filterThresholdValue)) {
-return filterThresholdValue >= filterIdsCount;
+if (filterThresholdValue >= filterIdsCount) {
+return true;
+}
+return false;
}

// If no setting is set, then use the default max distance computation value to see if we can do exact search.
/**
* TODO we can have a different MAX_DISTANCE_COMPUTATIONS for a binary index, as the computation cost for a binary index
* is cheaper than the computation cost for a non-binary vector index
*/
+return isMDCGreaterThanFilterIdCnt(filterIdsCount);
+}
+
+private boolean isMDCGreaterThanFilterIdCnt(int filterIdsCount) {
return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * (knnQuery.getVectorDataType() == VectorDataType.FLOAT
? knnQuery.getQueryVector().length
: knnQuery.getByteQueryVector().length);
}
+
+private boolean isFilterIdCountLessThanK(int filterIdsCount) {
+return knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK();
+}

/**
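To make the distance-computation budget in isMDCGreaterThanFilterIdCnt concrete, here is a back-of-the-envelope sketch. The budget constant's value below is an assumption for illustration only; the real number is whatever KNNConstants.MAX_DISTANCE_COMPUTATIONS defines.

    // Illustrative sketch of the exact-search budget check; values are assumptions.
    public class ExactSearchBudgetSketch {
        public static void main(String[] args) {
            long maxDistanceComputations = 2_048_000L; // ASSUMED stand-in for KNNConstants.MAX_DISTANCE_COMPUTATIONS
            int dimension = 128;                       // length of the float query vector
            int filterIdsCount = 10_000;               // docs that matched the filter

            // Exact search scans every filtered doc, and each distance computation
            // touches `dimension` vector components, so the total work is roughly
            // filterIdsCount * dimension component operations.
            boolean exactSearchAffordable = maxDistanceComputations >= (long) filterIdsCount * dimension;
            System.out.println(exactSearchAffordable); // true: 10_000 * 128 = 1_280_000 <= 2_048_000
        }
    }

Under these assumed numbers, exact search stays affordable for a 128-dimensional float index until the filter matches more than 16,000 documents (2,048,000 / 128).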
@@ -14,6 +14,7 @@
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.Weight;
import org.opensearch.knn.index.query.KNNWeight;

import java.io.IOException;
import java.util.Arrays;
@@ -31,13 +32,15 @@ final class DocAndScoreQuery extends Query {
private final float[] scores;
private final int[] segmentStarts;
private final Object contextIdentity;
private final KNNWeight knnWeight;

-public DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) {
+public DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity, KNNWeight knnWeight) {
this.k = k;
this.docs = docs;
this.scores = scores;
this.segmentStarts = segmentStarts;
this.contextIdentity = contextIdentity;
this.knnWeight = knnWeight;
}

@Override
@@ -53,7 +56,7 @@ public Explanation explain(LeafReaderContext context, int doc) {
if (found < 0) {
return Explanation.noMatch("not in top " + k);
}
-return Explanation.match(scores[found] * boost, "within top " + k);
+return knnWeight.explain(context, doc);
}

@Override
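DocAndScoreQuery.explain now delegates to KNNWeight.explain, so the Explain API response becomes a nested Lucene Explanation: the top level names the query-wide search type and the nested detail names the per-leaf decision. A minimal sketch of how such a nested explanation composes and prints, with illustrative scores and descriptions:

    import org.apache.lucene.search.Explanation;

    // Minimal sketch of the nested structure; values and wording are illustrative.
    public class NestedExplanationSketch {
        public static void main(String[] args) {
            Explanation leafDetail = Explanation.match(
                0.42f,
                "the type of knn search executed at leaf was Approximate-NN with spaceType = l2, vectorDataType = FLOAT"
            );
            Explanation top = Explanation.match(0.42f, "the type of knn search executed was Approximate-NN", leafDetail);
            System.out.println(top); // children render indented beneath the parent
        }
    }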
@@ -18,6 +18,7 @@
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.apache.lucene.util.Bits;
import org.opensearch.knn.index.query.KNNWeight;
import org.opensearch.knn.index.query.iterators.GroupedNestedDocIdSetIterator;

import java.io.IOException;
@@ -46,6 +47,10 @@ public class QueryUtils {
* @return a query representing the given TopDocs
*/
public Query createDocAndScoreQuery(final IndexReader reader, final TopDocs topDocs) {
return createDocAndScoreQuery(reader, topDocs, null);
}

public Query createDocAndScoreQuery(final IndexReader reader, final TopDocs topDocs, final KNNWeight knnWeight) {
int len = topDocs.scoreDocs.length;
Arrays.sort(topDocs.scoreDocs, Comparator.comparingInt(a -> a.doc));
int[] docs = new int[len];
@@ -55,7 +60,7 @@ public Query createDocAndScoreQuery(final IndexReader reader, final TopDocs topDocs) {
scores[i] = topDocs.scoreDocs[i].score;
}
int[] segmentStarts = findSegmentStarts(reader, docs);
-return new DocAndScoreQuery(len, docs, scores, segmentStarts, reader.getContext().id());
+return new DocAndScoreQuery(len, docs, scores, segmentStarts, reader.getContext().id(), knnWeight);
}

private int[] findSegmentStarts(final IndexReader reader, final int[] docs) {
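The two-argument overload above keeps existing callers source-compatible, while explain-aware callers thread the KNNWeight through. A sketch of the two call shapes (names are illustrative, and QueryUtils is assumed visible from the calling package); note that the two-argument form stores a null KNNWeight, so the resulting query's explain path should never be exercised on it:

    import org.apache.lucene.index.IndexReader;
    import org.apache.lucene.search.Query;
    import org.apache.lucene.search.TopDocs;
    import org.opensearch.knn.index.query.KNNWeight;

    // Hypothetical call sites, not code from this commit.
    final class CreateQueryCallShapes {
        // Existing callers stay unchanged: knnWeight defaults to null.
        static Query withoutExplain(QueryUtils queryUtils, IndexReader reader, TopDocs topDocs) {
            return queryUtils.createDocAndScoreQuery(reader, topDocs);
        }

        // Explain-aware callers pass the weight so DocAndScoreQuery.explain can delegate to it.
        static Query withExplain(QueryUtils queryUtils, IndexReader reader, TopDocs topDocs, KNNWeight knnWeight) {
            return queryUtils.createDocAndScoreQuery(reader, topDocs, knnWeight);
        }
    }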
@@ -98,7 +98,7 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, float boost) {
if (topK.scoreDocs.length == 0) {
return new MatchNoDocsQuery().createWeight(indexSearcher, scoreMode, boost);
}
-return queryUtils.createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost);
+return queryUtils.createDocAndScoreQuery(reader, topK, knnWeight).createWeight(indexSearcher, scoreMode, boost);
}

/**
@@ -50,7 +50,7 @@ public void testScorer() throws Exception {
int[] expectedDocs = { 0, 1, 2, 3, 4 };
float[] expectedScores = { 0.1f, 1.2f, 2.3f, 5.1f, 3.4f };
int[] findSegments = { 0, 2, 5 };
-objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1);
+objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1, null);

// When
Scorer scorer1 = objectUnderTest.createWeight(indexSearcher, ScoreMode.COMPLETE, 1).scorer(leaf1);
@@ -85,7 +85,7 @@ public void testWeight() {
Explanation expectedExplanation = Explanation.match(1.2f, "within top 4");

// When
-objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1);
+objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1, null);
Weight weight = objectUnderTest.createWeight(indexSearcher, ScoreMode.COMPLETE, 1);
Explanation explanation = weight.explain(leaf1, 1);

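Since the tests above now pass null for the new constructor argument, the delegation itself could be covered with a mock along these lines (a sketch assuming Mockito and the fixtures from this test class; not part of this commit):

    // Hypothetical test sketch; the description string is illustrative.
    KNNWeight mockKnnWeight = mock(KNNWeight.class);
    when(mockKnnWeight.explain(leaf1, 1)).thenReturn(Explanation.match(1.2f, "delegated to KNNWeight"));

    objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1, mockKnnWeight);
    Weight weight = objectUnderTest.createWeight(indexSearcher, ScoreMode.COMPLETE, 1);
    assertEquals("delegated to KNNWeight", weight.explain(leaf1, 1).getDescription());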
