Skip to content

Commit

Permalink
Add knnScorer to the explanation object
Browse files Browse the repository at this point in the history
Signed-off-by: Neetika Singhal <[email protected]>
  • Loading branch information
neetikasinghal committed Feb 5, 2025
1 parent 7c4f425 commit 8745703
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 15 deletions.
36 changes: 23 additions & 13 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -117,18 +117,18 @@ public Explanation explain(LeafReaderContext context, int doc) {
// This method is called for ANN/Exact/Disk-based/Efficient-filtering search
public Explanation explain(LeafReaderContext context, int doc, float score, KNNScorer knnScorer) {
knnQuery.setExplain(true);
if (knnScorer == null) {
float knnScore;
try {
knnScore = getKnnScore(context, doc);
} catch (IOException e) {
throw new RuntimeException("Error while getting KNN score during explanation", e);
}
try {
knnScorer = getOrCreateKnnScorer(context, knnScorer);
float knnScore = getKnnScore(knnScorer, doc);

if (score == 0) {
score = knnScore;
}
assert score == knnScore : "Score mismatch in explain: provided score does not match KNN score";
} catch (IOException e) {
throw new RuntimeException("Error while explaining KNN score", e);
}

final String highLevelExplanation = getHighLevelExplanation();
final StringBuilder leafLevelExplanation = getLeafLevelExplanation(context);

Expand Down Expand Up @@ -243,13 +243,23 @@ private String buildDiskBasedSearchExplanation() {
return sb.toString();
}

private float getKnnScore(LeafReaderContext context, int doc) throws IOException {
KNNScorer knnScorer = (KNNScorer) scorer(context);
int resDoc = knnScorer.iterator().advance(doc);
if (resDoc == doc) {
return knnScorer.score();
private KNNScorer getOrCreateKnnScorer(LeafReaderContext context, KNNScorer existingScorer) throws IOException {
if (existingScorer != null) {
return existingScorer;
}
return 0;

KNNScorer cachedScorer = knnExplanation.getKnnScorerPerLeaf().get(context);
if (cachedScorer != null) {
return cachedScorer;
}

KNNScorer newScorer = (KNNScorer) scorer(context);
knnExplanation.getKnnScorerPerLeaf().put(context, newScorer);
return newScorer;
}

private float getKnnScore(KNNScorer knnScorer, int doc) throws IOException {
return (knnScorer.iterator().advance(doc) == doc) ? knnScorer.score() : 0;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.Getter;
import lombok.Setter;
import org.opensearch.knn.index.query.KNNScorer;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -23,13 +24,17 @@ public class KnnExplanation {
@Getter
private final Map<Integer, Float> rawScores;

@Getter
private final Map<Object, KNNScorer> knnScorerPerLeaf;

@Setter
@Getter
private int cardinality;

public KnnExplanation() {
this.annResultPerLeaf = new ConcurrentHashMap<>();
this.rawScores = new ConcurrentHashMap<>();
this.knnScorerPerLeaf = new ConcurrentHashMap<>();
this.cardinality = 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,8 @@ public void testDefaultANNSearch() {
public void testANN_FilteredExactSearchAfterANN() {
ExactSearcher mockedExactSearcher = mock(ExactSearcher.class);
KNNWeight.initialize(null, mockedExactSearcher);
when(mockedExactSearcher.searchLeaf(any(), any())).thenReturn(DOC_ID_TO_SCORES);
final Map<Integer, Float> translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation);
when(mockedExactSearcher.searchLeaf(any(), any())).thenReturn(translatedScores);
// Given
int k = 4;
jniServiceMockedStatic.when(
Expand Down Expand Up @@ -428,7 +429,6 @@ public void testANN_FilteredExactSearchAfterANN() {
);

final List<Integer> actualDocIds = new ArrayList<>();
final Map<Integer, Float> translatedScores = getTranslatedScores(SpaceType.L2::scoreTranslation);
for (int docId = docIdSetIterator.nextDoc(); docId != NO_MORE_DOCS; docId = docIdSetIterator.nextDoc()) {
actualDocIds.add(docId);
float score = translatedScores.get(docId) * boost;
Expand Down

0 comments on commit 8745703

Please sign in to comment.