-
Notifications
You must be signed in to change notification settings - Fork 144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Explain API changes #2403
base: main
Are you sure you want to change the base?
Explain API changes #2403
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -35,6 +35,7 @@ | |
import org.opensearch.knn.index.engine.KNNEngine; | ||
import org.opensearch.knn.index.quantizationservice.QuantizationService; | ||
import org.opensearch.knn.index.query.ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder; | ||
import org.opensearch.knn.index.query.explain.KnnExplanation; | ||
import org.opensearch.knn.indices.ModelDao; | ||
import org.opensearch.knn.indices.ModelMetadata; | ||
import org.opensearch.knn.indices.ModelUtil; | ||
|
@@ -73,6 +74,7 @@ public class KNNWeight extends Weight { | |
|
||
private static ExactSearcher DEFAULT_EXACT_SEARCHER; | ||
private final QuantizationService quantizationService; | ||
private final KnnExplanation knnExplanation; | ||
|
||
public KNNWeight(KNNQuery query, float boost) { | ||
super(query); | ||
|
@@ -82,6 +84,7 @@ public KNNWeight(KNNQuery query, float boost) { | |
this.filterWeight = null; | ||
this.exactSearcher = DEFAULT_EXACT_SEARCHER; | ||
this.quantizationService = QuantizationService.getInstance(); | ||
this.knnExplanation = new KnnExplanation(); | ||
} | ||
|
||
public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { | ||
|
@@ -92,6 +95,7 @@ public KNNWeight(KNNQuery query, float boost, Weight filterWeight) { | |
this.filterWeight = filterWeight; | ||
this.exactSearcher = DEFAULT_EXACT_SEARCHER; | ||
this.quantizationService = QuantizationService.getInstance(); | ||
this.knnExplanation = new KnnExplanation(); | ||
} | ||
|
||
public static void initialize(ModelDao modelDao) { | ||
|
@@ -105,8 +109,153 @@ static void initialize(ModelDao modelDao, ExactSearcher exactSearcher) { | |
} | ||
|
||
@Override | ||
// This method is called in case of Radial-Search | ||
public Explanation explain(LeafReaderContext context, int doc) { | ||
neetikasinghal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return Explanation.match(1.0f, "No Explanation"); | ||
return explain(context, doc, 0, null); | ||
} | ||
|
||
// This method is called for ANN/Exact/Disk-based/Efficient-filtering search | ||
public Explanation explain(LeafReaderContext context, int doc, float score, KNNScorer knnScorer) { | ||
knnQuery.setExplain(true); | ||
try { | ||
knnScorer = getOrCreateKnnScorer(context, knnScorer); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is there a scenario in which knnScorer is not null? If not, then you should not pass it to this method and should always create a new instance inside the method. |
||
float knnScore = getKnnScore(knnScorer, doc); | ||
|
||
if (score == 0) { | ||
score = knnScore; | ||
} | ||
neetikasinghal marked this conversation as resolved.
Show resolved
Hide resolved
|
||
assert score == knnScore : "Score mismatch in explain: provided score does not match KNN score"; | ||
} catch (IOException e) { | ||
throw new RuntimeException("Error while explaining KNN score", e); | ||
} | ||
|
||
final String highLevelExplanation = getHighLevelExplanation(); | ||
final StringBuilder leafLevelExplanation = getLeafLevelExplanation(context); | ||
|
||
final SegmentReader reader = Lucene.segmentReader(context.reader()); | ||
final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField()); | ||
if (fieldInfo == null) { | ||
return Explanation.match(score, highLevelExplanation, Explanation.match(score, leafLevelExplanation.toString())); | ||
} | ||
final SpaceType spaceType = getSpaceType(fieldInfo); | ||
leafLevelExplanation.append(", spaceType = ").append(spaceType.getValue()); | ||
|
||
final Float rawScore = knnExplanation.getRawScores().get(doc); | ||
Explanation rawScoreDetail = null; | ||
if (rawScore != null && knnQuery.getRescoreContext() == null) { | ||
leafLevelExplanation.append(" where score is computed as ") | ||
.append(spaceType.explainScoreTranslation(rawScore)) | ||
.append(" from:"); | ||
rawScoreDetail = Explanation.match( | ||
rawScore, | ||
"rawScore, returned from " + FieldInfoExtractor.extractKNNEngine(fieldInfo) + " library" | ||
); | ||
} | ||
|
||
return rawScoreDetail != null | ||
? Explanation.match(score, highLevelExplanation, Explanation.match(score, leafLevelExplanation.toString(), rawScoreDetail)) | ||
: Explanation.match(score, highLevelExplanation, Explanation.match(score, leafLevelExplanation.toString())); | ||
} | ||
|
||
private StringBuilder getLeafLevelExplanation(LeafReaderContext context) { | ||
int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()); | ||
int cardinality = knnExplanation.getCardinality(); | ||
StringBuilder sb = new StringBuilder("the type of knn search executed at leaf was "); | ||
if (filterWeight != null) { | ||
if (isFilterIdCountLessThanK(cardinality)) { | ||
sb.append(KNNConstants.EXACT_SEARCH) | ||
.append(" since filteredIds = ") | ||
.append(cardinality) | ||
.append(" is less than or equal to K = ") | ||
.append(knnQuery.getK()); | ||
} else if (isExactSearchThresholdSettingSet(filterThresholdValue) && (filterThresholdValue >= cardinality)) { | ||
sb.append(KNNConstants.EXACT_SEARCH) | ||
.append(" since filtered threshold value = ") | ||
.append(filterThresholdValue) | ||
.append(" is greater than or equal to cardinality = ") | ||
.append(cardinality); | ||
} else if (!isExactSearchThresholdSettingSet(filterThresholdValue) && isMDCGreaterThanFilterIdCnt(cardinality)) { | ||
sb.append(KNNConstants.EXACT_SEARCH) | ||
.append(" since max distance computation = ") | ||
.append(KNNConstants.MAX_DISTANCE_COMPUTATIONS) | ||
.append(" is greater than or equal to cardinality = ") | ||
.append(cardinality); | ||
} | ||
} | ||
Integer annResult = knnExplanation.getAnnResult(context.id()); | ||
if (annResult != null && annResult == 0 && isMissingNativeEngineFiles(context)) { | ||
sb.append(KNNConstants.EXACT_SEARCH).append(" since no native engine files are available"); | ||
} | ||
if (annResult != null && isFilteredExactSearchRequireAfterANNSearch(cardinality, annResult)) { | ||
sb.append(KNNConstants.EXACT_SEARCH) | ||
.append(" since the number of documents returned are less than K = ") | ||
.append(knnQuery.getK()) | ||
.append(" and there are more than K filtered Ids = ") | ||
.append(cardinality); | ||
} | ||
if (annResult != null && annResult > 0 && !isFilteredExactSearchRequireAfterANNSearch(cardinality, annResult)) { | ||
sb.append(KNNConstants.ANN_SEARCH); | ||
} | ||
sb.append(" with vectorDataType = ").append(knnQuery.getVectorDataType()); | ||
return sb; | ||
} | ||
|
||
private SpaceType getSpaceType(FieldInfo fieldInfo) { | ||
try { | ||
return FieldInfoExtractor.getSpaceType(modelDao, fieldInfo); | ||
} catch (IllegalArgumentException e) { | ||
return knnQuery.getVectorDataType() == VectorDataType.BINARY ? SpaceType.DEFAULT_BINARY : SpaceType.DEFAULT; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why are we not signaling to the user that the system is in a bad state? Wouldn't this be misleading, i.e. giving false data to the user? |
||
} | ||
} | ||
|
||
private String getHighLevelExplanation() { | ||
StringBuilder sb = new StringBuilder("the type of knn search executed was "); | ||
if (knnQuery.getRescoreContext() != null) { | ||
sb.append(buildDiskBasedSearchExplanation()); | ||
} else if (knnQuery.getRadius() != null) { | ||
sb.append(KNNConstants.RADIAL_SEARCH).append(" with the radius of ").append(knnQuery.getRadius()); | ||
} else { | ||
sb.append(KNNConstants.ANN_SEARCH); | ||
} | ||
return sb.toString(); | ||
} | ||
|
||
private String buildDiskBasedSearchExplanation() { | ||
StringBuilder sb = new StringBuilder(KNNConstants.DISK_BASED_SEARCH); | ||
boolean isShardLevelRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName()); | ||
int dimension = knnQuery.getQueryVector().length; | ||
int firstPassK = knnQuery.getRescoreContext().getFirstPassK(knnQuery.getK(), isShardLevelRescoringDisabled, dimension); | ||
sb.append(" and the first pass k was ") | ||
.append(firstPassK) | ||
.append(" with vector dimension of ") | ||
.append(dimension) | ||
.append(", over sampling factor of ") | ||
.append(knnQuery.getRescoreContext().getOversampleFactor()); | ||
if (isShardLevelRescoringDisabled) { | ||
sb.append(", shard level rescoring disabled"); | ||
} else { | ||
sb.append(", shard level rescoring enabled"); | ||
} | ||
return sb.toString(); | ||
} | ||
|
||
private KNNScorer getOrCreateKnnScorer(LeafReaderContext context, KNNScorer existingScorer) throws IOException { | ||
if (existingScorer != null) { | ||
return existingScorer; | ||
} | ||
|
||
KNNScorer cachedScorer = knnExplanation.getKnnScorerPerLeaf().get(context); | ||
if (cachedScorer != null) { | ||
return cachedScorer; | ||
} | ||
|
||
KNNScorer newScorer = (KNNScorer) scorer(context); | ||
knnExplanation.addKnnScorer(context, newScorer); | ||
return newScorer; | ||
} | ||
|
||
private float getKnnScore(KNNScorer knnScorer, int doc) throws IOException { | ||
return (knnScorer.iterator().advance(doc) == doc) ? knnScorer.score() : 0; | ||
} | ||
|
||
@Override | ||
|
@@ -137,6 +286,9 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep | |
if (filterWeight != null && cardinality == 0) { | ||
return PerLeafResult.EMPTY_RESULT; | ||
} | ||
if (knnQuery.isExplain()) { | ||
knnExplanation.setCardinality(cardinality); | ||
} | ||
/* | ||
* The idea for this optimization is to get K results, we need to at least look at K vectors in the HNSW graph | ||
* . Hence, if filtered results are less than K and filter query is present we should shift to exact search. | ||
|
@@ -153,7 +305,9 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep | |
*/ | ||
final BitSet annFilter = (filterWeight != null && cardinality == maxDoc) ? null : filterBitSet; | ||
final Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, annFilter, cardinality, k); | ||
|
||
if (knnQuery.isExplain()) { | ||
knnExplanation.addLeafResult(context.id(), docIdsToScoreMap.size()); | ||
} | ||
// See whether we have to perform exact search based on approx search results | ||
// This is required if there are no native engine files or if approximate search returned | ||
// results less than K, though we have more than k filtered docs | ||
|
@@ -383,6 +537,15 @@ private Map<Integer, Float> doANNSearch( | |
log.debug("[KNN] Query yielded 0 results"); | ||
return Collections.emptyMap(); | ||
} | ||
if (knnQuery.isExplain()) { | ||
Arrays.stream(results).forEach(result -> { | ||
if (KNNEngine.FAISS.getName().equals(knnEngine.getName()) && SpaceType.INNER_PRODUCT.equals(spaceType)) { | ||
knnExplanation.addRawScore(result.getId(), -1 * result.getScore()); | ||
} else { | ||
knnExplanation.addRawScore(result.getId(), result.getScore()); | ||
} | ||
}); | ||
} | ||
|
||
if (quantizedVector != null) { | ||
return Arrays.stream(results) | ||
|
@@ -425,24 +588,33 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) { | |
); | ||
int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName()); | ||
// Refer this GitHub around more details https://github.com/opensearch-project/k-NN/issues/1049 on the logic | ||
if (knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK()) { | ||
return true; | ||
} | ||
if (isFilterIdCountLessThanK(filterIdsCount)) return true; | ||
// See user has defined Exact Search filtered threshold. if yes, then use that setting. | ||
if (isExactSearchThresholdSettingSet(filterThresholdValue)) { | ||
return filterThresholdValue >= filterIdsCount; | ||
if (filterThresholdValue >= filterIdsCount) { | ||
return true; | ||
} | ||
return false; | ||
} | ||
|
||
// if no setting is set, then use the default max distance computation value to see if we can do exact search. | ||
/** | ||
* TODO we can have a different MAX_DISTANCE_COMPUTATIONS for binary index as computation cost for binary index | ||
* is cheaper than computation cost for non binary vector | ||
*/ | ||
return isMDCGreaterThanFilterIdCnt(filterIdsCount); | ||
} | ||
|
||
private boolean isMDCGreaterThanFilterIdCnt(int filterIdsCount) { | ||
return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= filterIdsCount * (knnQuery.getVectorDataType() == VectorDataType.FLOAT | ||
? knnQuery.getQueryVector().length | ||
: knnQuery.getByteQueryVector().length); | ||
} | ||
|
||
private boolean isFilterIdCountLessThanK(int filterIdsCount) { | ||
return knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK(); | ||
} | ||
|
||
/** | ||
* This function validates if {@link KNNSettings#ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD} is set or not. This | ||
* is done by validating if the setting value is equal to the default value. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This effectively means that if the object is created with the old constructor (formula is null), this method will throw an exception. Is this the desired behavior? Can we have a constant like "undefined" or similar? Or, if the formula isn't defined, does the rest of the explanation detail not make sense either?