diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index a9db3d9e47..5e51c654b3 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -117,18 +117,18 @@ public Explanation explain(LeafReaderContext context, int doc) { // This method is called for ANN/Exact/Disk-based/Efficient-filtering search public Explanation explain(LeafReaderContext context, int doc, float score, KNNScorer knnScorer) { knnQuery.setExplain(true); - if (knnScorer == null) { - float knnScore; - try { - knnScore = getKnnScore(context, doc); - } catch (IOException e) { - throw new RuntimeException("Error while getting KNN score during explanation", e); - } + try { + knnScorer = getOrCreateKnnScorer(context, knnScorer); + float knnScore = getKnnScore(knnScorer, doc); + if (score == 0) { score = knnScore; } assert score == knnScore : "Score mismatch in explain: provided score does not match KNN score"; + } catch (IOException e) { + throw new RuntimeException("Error while explaining KNN score", e); } + final String highLevelExplanation = getHighLevelExplanation(); final StringBuilder leafLevelExplanation = getLeafLevelExplanation(context); @@ -243,13 +243,23 @@ private String buildDiskBasedSearchExplanation() { return sb.toString(); } - private float getKnnScore(LeafReaderContext context, int doc) throws IOException { - KNNScorer knnScorer = (KNNScorer) scorer(context); - int resDoc = knnScorer.iterator().advance(doc); - if (resDoc == doc) { - return knnScorer.score(); + private KNNScorer getOrCreateKnnScorer(LeafReaderContext context, KNNScorer existingScorer) throws IOException { + if (existingScorer != null) { + return existingScorer; } - return 0; + + KNNScorer cachedScorer = knnExplanation.getKnnScorerPerLeaf().get(context); + if (cachedScorer != null) { + return cachedScorer; + } + + KNNScorer newScorer = (KNNScorer) scorer(context); + knnExplanation.getKnnScorerPerLeaf().put(context, newScorer); + return newScorer; + } + + private float getKnnScore(KNNScorer knnScorer, int doc) throws IOException { + return (knnScorer.iterator().advance(doc) == doc) ? knnScorer.score() : 0; } @Override diff --git a/src/main/java/org/opensearch/knn/index/query/explain/KnnExplanation.java b/src/main/java/org/opensearch/knn/index/query/explain/KnnExplanation.java index c1b6b10a63..ce594b1f14 100644 --- a/src/main/java/org/opensearch/knn/index/query/explain/KnnExplanation.java +++ b/src/main/java/org/opensearch/knn/index/query/explain/KnnExplanation.java @@ -7,6 +7,7 @@ import lombok.Getter; import lombok.Setter; +import org.opensearch.knn.index.query.KNNScorer; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; @@ -23,6 +24,9 @@ public class KnnExplanation { @Getter private final Map rawScores; + @Getter + private final Map knnScorerPerLeaf; + @Setter @Getter private int cardinality; @@ -30,6 +34,7 @@ public class KnnExplanation { public KnnExplanation() { this.annResultPerLeaf = new ConcurrentHashMap<>(); this.rawScores = new ConcurrentHashMap<>(); + this.knnScorerPerLeaf = new ConcurrentHashMap<>(); this.cardinality = 0; } } diff --git a/src/test/java/org/opensearch/knn/index/query/ExplainTests.java b/src/test/java/org/opensearch/knn/index/query/ExplainTests.java index 34188f8c7c..b4701e5bd9 100644 --- a/src/test/java/org/opensearch/knn/index/query/ExplainTests.java +++ b/src/test/java/org/opensearch/knn/index/query/ExplainTests.java @@ -381,7 +381,8 @@ public void testDefaultANNSearch() { public void testANN_FilteredExactSearchAfterANN() { ExactSearcher mockedExactSearcher = mock(ExactSearcher.class); KNNWeight.initialize(null, mockedExactSearcher); - when(mockedExactSearcher.searchLeaf(any(), any())).thenReturn(DOC_ID_TO_SCORES); + final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); + when(mockedExactSearcher.searchLeaf(any(), any())).thenReturn(translatedScores); // Given int k = 4; jniServiceMockedStatic.when( @@ -428,7 +429,6 @@ public void testANN_FilteredExactSearchAfterANN() { ); final List actualDocIds = new ArrayList<>(); - final Map translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation); for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) { actualDocIds.add(docId); float score = translatedScores.get(docId) * boost;