diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index ce6095fd05..b1aa24ec7b 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -22,6 +22,10 @@ public class KNNConstants { public static final String PATH = "path"; public static final String QUERY = "query"; public static final String KNN = "knn"; + public static final String EXACT_SEARCH = "Exact"; + public static final String ANN_SEARCH = "Approximate-NN"; + public static final String RADIAL_SEARCH = "Radial"; + public static final String DISK_BASED_SEARCH = "Disk-based"; public static final String VECTOR = "vector"; public static final String K = "k"; public static final String TYPE_KNN_VECTOR = "knn_vector"; diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 37b5cc9ad1..499cda932e 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -46,6 +46,7 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.stream.Collectors; @@ -73,6 +74,7 @@ public class KNNWeight extends Weight { private static ExactSearcher DEFAULT_EXACT_SEARCHER; private final QuantizationService quantizationService; + private final Map annResultPerLeaf; public KNNWeight(KNNQuery query, float boost) { super(query); @@ -82,6 +84,7 @@ public KNNWeight(KNNQuery query, float boost) { this.filterWeight = null; this.exactSearcher = DEFAULT_EXACT_SEARCHER; this.quantizationService = QuantizationService.getInstance(); + annResultPerLeaf = new ConcurrentHashMap<>(); } public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { @@ -92,6 +95,7 @@ public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { this.filterWeight = filterWeight; this.exactSearcher = DEFAULT_EXACT_SEARCHER; this.quantizationService = QuantizationService.getInstance(); + annResultPerLeaf = new ConcurrentHashMap<>(); } public static void initialize(ModelDao modelDao) { @@ -106,7 +110,97 @@ static void initialize(ModelDao modelDao, ExactSearcher exactSearcher) { @Override public Explanation explain(LeafReaderContext context, int doc) { - return Explanation.match(1.0f, "No Explanation"); + float score = 0; + try { + final KNNScorer knnScorer = (KNNScorer) scorer(context); + int resDoc = knnScorer.iterator().advance(doc); + if (resDoc == doc) { + score = knnScorer.score(); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + + final StringBuilder sb = new StringBuilder(); + sb.append("the type of knn search executed was "); + if (knnQuery.getRescoreContext() != null) { + sb.append(KNNConstants.DISK_BASED_SEARCH); + boolean isShardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName()); + int dimension = knnQuery.getQueryVector().length; + int firstPassK = knnQuery.getRescoreContext().getFirstPassK(knnQuery.getK(), isShardLevelRescoringDisabled, dimension); + sb.append(" and the first pass k was ").append(firstPassK) + .append( " with vector dimension of ").append(dimension) + .append( ", over sampling factor of ").append(knnQuery.getRescoreContext().getOversampleFactor()); + if (isShardLevelRescoringDisabled) { + sb.append(", shard level rescoring disabled"); + } else { + sb.append(", shard level rescoring enabled"); + } + } else if (knnQuery.getRadius() != null) { + sb.append(KNNConstants.RADIAL_SEARCH); + sb.append(" with the radius of ").append(knnQuery.getRadius()); + } else { + sb.append(KNNConstants.ANN_SEARCH); + } + + int cardinality = 0; + try { + cardinality = getFilteredDocsBitSet(context).cardinality(); + } catch (IOException e) { + // do nothing + } + int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()); + StringBuilder stringBuilder = new StringBuilder("the type of knn search executed at leaf was "); + if (filterWeight != null) { + if (isFilterIdCountLessThanK(cardinality)) { + stringBuilder.append(KNNConstants.EXACT_SEARCH) + .append(" since filteredIds = ") + .append(cardinality) + .append(" is less than or equal to K = ") + .append(knnQuery.getK()); + } else if (isExactSearchThresholdSettingSet(filterThresholdValue) && (filterThresholdValue >= cardinality)) { + stringBuilder.append(KNNConstants.EXACT_SEARCH) + .append(" since filtered threshold value = ") + .append(filterThresholdValue) + .append(" is greater than or equal to cardinality = ") + .append(cardinality); + } else if (isMDCGreaterThanFilterIdCnt(cardinality)) { + stringBuilder.append(KNNConstants.EXACT_SEARCH) + .append(" since max distance computation = ") + .append(KNNConstants.MAX_DISTANCE_COMPUTATIONS) + .append(" is greater than or equal to cardinality = ") + .append(cardinality); + } + } else if (annResultPerLeaf.get(context.id()) == 0 && isMissingNativeEngineFiles(context)) { + stringBuilder.append(KNNConstants.EXACT_SEARCH) + .append(" since no native engine files are available"); + } else if (isFilteredExactSearchRequireAfterANNSearch(cardinality, annResultPerLeaf.get(context.id()))) { + stringBuilder.append(KNNConstants.EXACT_SEARCH) + .append(" since the number of documents returned are less than K = ") + .append(knnQuery.getK()) + .append(" and there are more than K filtered Ids = ") + .append(cardinality); + } else { + stringBuilder.append(KNNConstants.ANN_SEARCH); + } + + final SegmentReader reader = Lucene.segmentReader(context.reader()); + final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField()); + + if (fieldInfo != null) { + // need to check on space type for exact search/model space/disk/default based search for binary vectors + stringBuilder.append(" with spaceType = ").append(FieldInfoExtractor.getSpaceType(modelDao, fieldInfo).getValue()); + } + stringBuilder.append(", vectorDataType = ").append(knnQuery.getVectorDataType()); + + // optional - need to see if we need to add the query vector + if (knnQuery.getQueryVector() != null) { + stringBuilder.append(", queryVector = ").append(Arrays.toString(knnQuery.getQueryVector())); + } if (knnQuery.getByteQueryVector() != null) { + stringBuilder.append(", byteVector = ").append(Arrays.toString(knnQuery.getByteQueryVector())); + } + + return Explanation.match(score, sb.toString(), Explanation.match(score, stringBuilder.toString())); } @Override @@ -153,7 +247,7 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep */ final BitSet annFilter = (filterWeight != null && cardinality == maxDoc) ? null : filterBitSet; final Map docIdsToScoreMap = doANNSearch(context, annFilter, cardinality, k); - + annResultPerLeaf.put(context.id(), docIdsToScoreMap.size()); // See whether we have to perform exact search based on approx search results // This is required if there are no native engine files or if approximate search returned // results less than K, though we have more than k filtered docs @@ -425,12 +519,13 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) { ); int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()); // Refer this GitHub around more details https://github.com/opensearch-project/k-NN/issues/1049 on the logic - if (knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK()) { - return true; - } + if (isFilterIdCountLessThanK(filterIdsCount)) return true; // See user has defined Exact Search filtered threshold. if yes, then use that setting. if (isExactSearchThresholdSettingSet(filterThresholdValue)) { - return filterThresholdValue >= filterIdsCount; + if (filterThresholdValue >= filterIdsCount) { + return true; + } + return false; } // if no setting is set, then use the default max distance computation value to see if we can do exact search. @@ -438,9 +533,17 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) { * TODO we can have a different MAX_DISTANCE_COMPUTATIONS for binary index as computation cost for binary index * is cheaper than computation cost for non binary vector */ + return isMDCGreaterThanFilterIdCnt(filterIdsCount); + } + + private boolean isMDCGreaterThanFilterIdCnt(int filterIdsCount) { return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * (knnQuery.getVectorDataType() == VectorDataType.FLOAT - ? knnQuery.getQueryVector().length - : knnQuery.getByteQueryVector().length); + ? knnQuery.getQueryVector().length + : knnQuery.getByteQueryVector().length); + } + + private boolean isFilterIdCountLessThanK(int filterIdsCount) { + return knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK(); } /** diff --git a/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java b/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java index f38cc96c64..8eea9dd88f 100644 --- a/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/common/DocAndScoreQuery.java @@ -14,6 +14,7 @@ import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; import org.apache.lucene.search.Weight; +import org.opensearch.knn.index.query.KNNWeight; import java.io.IOException; import java.util.Arrays; @@ -31,13 +32,15 @@ final class DocAndScoreQuery extends Query { private final float[] scores; private final int[] segmentStarts; private final Object contextIdentity; + private final KNNWeight knnWeight; - public DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity) { + public DocAndScoreQuery(int k, int[] docs, float[] scores, int[] segmentStarts, Object contextIdentity, KNNWeight knnWeight) { this.k = k; this.docs = docs; this.scores = scores; this.segmentStarts = segmentStarts; this.contextIdentity = contextIdentity; + this.knnWeight = knnWeight; } @Override @@ -53,7 +56,7 @@ public Explanation explain(LeafReaderContext context, int doc) { if (found < 0) { return Explanation.noMatch("not in top " + k); } - return Explanation.match(scores[found] * boost, "within top " + k); + return knnWeight.explain(context, doc); } @Override diff --git a/src/main/java/org/opensearch/knn/index/query/common/QueryUtils.java b/src/main/java/org/opensearch/knn/index/query/common/QueryUtils.java index 5fc0fb077e..ce823c229f 100644 --- a/src/main/java/org/opensearch/knn/index/query/common/QueryUtils.java +++ b/src/main/java/org/opensearch/knn/index/query/common/QueryUtils.java @@ -18,6 +18,7 @@ import org.apache.lucene.util.BitSet; import org.apache.lucene.util.BitSetIterator; import org.apache.lucene.util.Bits; +import org.opensearch.knn.index.query.KNNWeight; import org.opensearch.knn.index.query.iterators.GroupedNestedDocIdSetIterator; import java.io.IOException; @@ -46,6 +47,10 @@ public class QueryUtils { * @return a query representing the given TopDocs */ public Query createDocAndScoreQuery(final IndexReader reader, final TopDocs topDocs) { + return createDocAndScoreQuery(reader, topDocs, null); + } + + public Query createDocAndScoreQuery(final IndexReader reader, final TopDocs topDocs, final KNNWeight knnWeight) { int len = topDocs.scoreDocs.length; Arrays.sort(topDocs.scoreDocs, Comparator.comparingInt(a -> a.doc)); int[] docs = new int[len]; @@ -55,7 +60,7 @@ public Query createDocAndScoreQuery(final IndexReader reader, final TopDocs topD scores[i] = topDocs.scoreDocs[i].score; } int[] segmentStarts = findSegmentStarts(reader, docs); - return new DocAndScoreQuery(len, docs, scores, segmentStarts, reader.getContext().id()); + return new DocAndScoreQuery(len, docs, scores, segmentStarts, reader.getContext().id(), knnWeight); } private int[] findSegmentStarts(final IndexReader reader, final int[] docs) { diff --git a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java index 5b4d6e7a1d..9d859e023b 100644 --- a/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java +++ b/src/main/java/org/opensearch/knn/index/query/nativelib/NativeEngineKnnVectorQuery.java @@ -98,7 +98,7 @@ public Weight createWeight(IndexSearcher indexSearcher, ScoreMode scoreMode, flo if (topK.scoreDocs.length == 0) { return new MatchNoDocsQuery().createWeight(indexSearcher, scoreMode, boost); } - return queryUtils.createDocAndScoreQuery(reader, topK).createWeight(indexSearcher, scoreMode, boost); + return queryUtils.createDocAndScoreQuery(reader, topK, knnWeight).createWeight(indexSearcher, scoreMode, boost); } /** diff --git a/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java b/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java index b32496138c..5a912b352b 100644 --- a/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java +++ b/src/test/java/org/opensearch/knn/index/query/common/DocAndScoreQueryTests.java @@ -50,7 +50,7 @@ public void testScorer() throws Exception { int[] expectedDocs = { 0, 1, 2, 3, 4 }; float[] expectedScores = { 0.1f, 1.2f, 2.3f, 5.1f, 3.4f }; int[] findSegments = { 0, 2, 5 }; - objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1); + objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1, null); // When Scorer scorer1 = objectUnderTest.createWeight(indexSearcher, ScoreMode.COMPLETE, 1).scorer(leaf1); @@ -85,7 +85,7 @@ public void testWeight() { Explanation expectedExplanation = Explanation.match(1.2f, "within top 4"); // When - objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1); + objectUnderTest = new DocAndScoreQuery(4, expectedDocs, expectedScores, findSegments, 1, null); Weight weight = objectUnderTest.createWeight(indexSearcher, ScoreMode.COMPLETE, 1); Explanation explanation = weight.explain(leaf1, 1);