Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Explain API changes #2403

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/common/KNNConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ public class KNNConstants {
public static final String PATH = "path";
public static final String QUERY = "query";
public static final String KNN = "knn";
// Human-readable labels naming the kind of k-NN search that was executed.
// KNNWeight embeds these strings verbatim in the text of its explain output.
public static final String EXACT_SEARCH = "Exact";
public static final String ANN_SEARCH = "Approximate-NN";
public static final String RADIAL_SEARCH = "Radial";
public static final String DISK_BASED_SEARCH = "Disk-based";
public static final String VECTOR = "vector";
public static final String K = "k";
public static final String TYPE_KNN_VECTOR = "knn_vector";
Expand Down
30 changes: 25 additions & 5 deletions src/main/java/org/opensearch/knn/index/SpaceType.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public void validateVectorDataType(VectorDataType vectorDataType) {
throw new IllegalStateException("Unsupported method");
}
},
L2("l2") {
L2("l2", SpaceType.GENERIC_SCORE_TRANSLATION) {
@Override
public float scoreTranslation(float rawScore) {
return 1 / (1 + rawScore);
Expand All @@ -59,7 +59,7 @@ public float scoreToDistanceTranslation(float score) {
return 1 / score - 1;
}
},
COSINESIMIL("cosinesimil") {
COSINESIMIL("cosinesimil", "`Math.max((2.0F - rawScore) / 2.0F, 0.0F)`") {
/**
* Cosine similarity has range of [-1, 1] where -1 represents vectors are at diametrically opposite, and 1 is where
* they are identical in direction and perfectly similar. In Lucene, scores have to be in the range of [0, Float.MAX_VALUE].
Expand Down Expand Up @@ -100,13 +100,13 @@ public void validateVector(float[] vector) {
}
}
},
L1("l1") {
L1("l1", SpaceType.GENERIC_SCORE_TRANSLATION) {
@Override
public float scoreTranslation(float rawScore) {
return 1 / (1 + rawScore);
}
},
LINF("linf") {
LINF("linf", SpaceType.GENERIC_SCORE_TRANSLATION) {
@Override
public float scoreTranslation(float rawScore) {
return 1 / (1 + rawScore);
Expand All @@ -129,12 +129,17 @@ public float scoreTranslation(float rawScore) {
return -rawScore + 1;
}

/**
 * Returns the formula text describing how a raw score is translated for this space type.
 *
 * @param rawScore raw score whose sign selects the formula
 * @return the generic formula text for non-negative raw scores, otherwise {@code `-rawScore + 1`}
 */
@Override
public String explainScoreTranslation(float rawScore) {
    if (rawScore >= 0) {
        return GENERIC_SCORE_TRANSLATION;
    }
    return "`-rawScore + 1`";
}

@Override
public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
return KNNVectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
}
},
HAMMING("hamming") {
HAMMING("hamming", SpaceType.GENERIC_SCORE_TRANSLATION) {
@Override
public float scoreTranslation(float rawScore) {
return 1 / (1 + rawScore);
Expand Down Expand Up @@ -169,14 +174,29 @@ public KNNVectorSimilarityFunction getKnnVectorSimilarityFunction() {
.collect(Collectors.toList())
.toArray(new String[0]);

private static final String GENERIC_SCORE_TRANSLATION = "`1 / (1 + rawScore)`";
private final String value;
private final String explanationFormula;

/**
 * Creates a space type that carries no score-translation formula text.
 * For such a constant, {@link #explainScoreTranslation(float)} throws
 * {@link UnsupportedOperationException} unless the constant overrides it.
 *
 * @param value name of the space type
 */
SpaceType(String value) {
    this(value, null);
}

/**
 * Creates a space type together with the formula text used when explaining scores.
 *
 * @param value              name of the space type
 * @param explanationFormula text returned by {@link #explainScoreTranslation(float)} when it is
 *                           not overridden; may be {@code null}, in which case that method throws
 */
SpaceType(String value, String explanationFormula) {
this.value = value;
this.explanationFormula = explanationFormula;
}

public abstract float scoreTranslation(float rawScore);

public String explainScoreTranslation(float rawScore) {
if (explanationFormula != null) {
return explanationFormula;
}
throw new UnsupportedOperationException("explainScoreTranslation is not defined for this space type.");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This effectively means that if the object is created with the old constructor (formula is null), this method will throw an exception. Is this the desired behavior? Could we return a constant such as "undefined" instead? Or, if the formula isn't defined, does the rest of the explanation detail not make sense either?

}

/**
* Get KNNVectorSimilarityFunction that maps to this SpaceType
*
Expand Down
3 changes: 3 additions & 0 deletions src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ public class KNNQuery extends Query {
private BitSetProducer parentsFilter;
private Float radius;
private Context context;
// Marks that explain output was requested for this query. KNNWeight#explain sets it to true, and
// KNNWeight records explanation data (filter cardinality, per-leaf ANN result counts, raw scores)
// only while it is set.
@Setter
@Getter
private boolean explain;

public KNNQuery(
final String field,
Expand Down
184 changes: 178 additions & 6 deletions src/main/java/org/opensearch/knn/index/query/KNNWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.query.ExactSearcher.ExactSearcherContext.ExactSearcherContextBuilder;
import org.opensearch.knn.index.query.explain.KnnExplanation;
import org.opensearch.knn.indices.ModelDao;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelUtil;
Expand Down Expand Up @@ -73,6 +74,7 @@ public class KNNWeight extends Weight {

private static ExactSearcher DEFAULT_EXACT_SEARCHER;
private final QuantizationService quantizationService;
private final KnnExplanation knnExplanation;

public KNNWeight(KNNQuery query, float boost) {
super(query);
Expand All @@ -82,6 +84,7 @@ public KNNWeight(KNNQuery query, float boost) {
this.filterWeight = null;
this.exactSearcher = DEFAULT_EXACT_SEARCHER;
this.quantizationService = QuantizationService.getInstance();
this.knnExplanation = new KnnExplanation();
}

public KNNWeight(KNNQuery query, float boost, Weight filterWeight) {
Expand All @@ -92,6 +95,7 @@ public KNNWeight(KNNQuery query, float boost, Weight filterWeight) {
this.filterWeight = filterWeight;
this.exactSearcher = DEFAULT_EXACT_SEARCHER;
this.quantizationService = QuantizationService.getInstance();
this.knnExplanation = new KnnExplanation();
}

public static void initialize(ModelDao modelDao) {
Expand All @@ -105,8 +109,153 @@ static void initialize(ModelDao modelDao, ExactSearcher exactSearcher) {
}

@Override
// This method is called in case of Radial-Search
public Explanation explain(LeafReaderContext context, int doc) {
return Explanation.match(1.0f, "No Explanation");
return explain(context, doc, 0, null);
}

// This method is called for ANN/Exact/Disk-based/Efficient-filtering search
public Explanation explain(LeafReaderContext context, int doc, float score, KNNScorer knnScorer) {
knnQuery.setExplain(true);
try {
knnScorer = getOrCreateKnnScorer(context, knnScorer);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a scenario when knnScorer is not null? If no, then you should not pass it to this method and always create new instance inside the method.

float knnScore = getKnnScore(knnScorer, doc);

if (score == 0) {
score = knnScore;
}
assert score == knnScore : "Score mismatch in explain: provided score does not match KNN score";
} catch (IOException e) {
throw new RuntimeException("Error while explaining KNN score", e);
}

final String highLevelExplanation = getHighLevelExplanation();
final StringBuilder leafLevelExplanation = getLeafLevelExplanation(context);

final SegmentReader reader = Lucene.segmentReader(context.reader());
final FieldInfo fieldInfo = FieldInfoExtractor.getFieldInfo(reader, knnQuery.getField());
if (fieldInfo == null) {
return Explanation.match(score, highLevelExplanation, Explanation.match(score, leafLevelExplanation.toString()));
}
final SpaceType spaceType = getSpaceType(fieldInfo);
leafLevelExplanation.append(", spaceType = ").append(spaceType.getValue());

final Float rawScore = knnExplanation.getRawScores().get(doc);
Explanation rawScoreDetail = null;
if (rawScore != null && knnQuery.getRescoreContext() == null) {
leafLevelExplanation.append(" where score is computed as ")
.append(spaceType.explainScoreTranslation(rawScore))
.append(" from:");
rawScoreDetail = Explanation.match(
rawScore,
"rawScore, returned from " + FieldInfoExtractor.extractKNNEngine(fieldInfo) + " library"
);
}

return rawScoreDetail != null
? Explanation.match(score, highLevelExplanation, Explanation.match(score, leafLevelExplanation.toString(), rawScoreDetail))
: Explanation.match(score, highLevelExplanation, Explanation.match(score, leafLevelExplanation.toString()));
}

/**
 * Builds the leaf-level sentence of the explain output: which kind of search ran on this leaf,
 * why it was chosen, and the vector data type of the query.
 *
 * @param context leaf whose recorded ANN result count is looked up by {@code context.id()}
 * @return a builder holding the sentence so far; the caller may keep appending to it
 */
private StringBuilder getLeafLevelExplanation(LeafReaderContext context) {
    final int exactSearchThreshold = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName());
    final int filteredDocCount = knnExplanation.getCardinality();
    final StringBuilder explanation = new StringBuilder("the type of knn search executed at leaf was ");

    // Reasons that apply only when a filter is present.
    if (filterWeight != null) {
        if (isFilterIdCountLessThanK(filteredDocCount)) {
            explanation.append(
                KNNConstants.EXACT_SEARCH
                    + " since filteredIds = "
                    + filteredDocCount
                    + " is less than or equal to K = "
                    + knnQuery.getK()
            );
        } else if (isExactSearchThresholdSettingSet(exactSearchThreshold)) {
            // An explicitly configured threshold takes precedence over the max-distance-computation rule.
            if (exactSearchThreshold >= filteredDocCount) {
                explanation.append(
                    KNNConstants.EXACT_SEARCH
                        + " since filtered threshold value = "
                        + exactSearchThreshold
                        + " is greater than or equal to cardinality = "
                        + filteredDocCount
                );
            }
        } else if (isMDCGreaterThanFilterIdCnt(filteredDocCount)) {
            explanation.append(
                KNNConstants.EXACT_SEARCH
                    + " since max distance computation = "
                    + KNNConstants.MAX_DISTANCE_COMPUTATIONS
                    + " is greater than or equal to cardinality = "
                    + filteredDocCount
            );
        }
    }

    // Reasons derived from the ANN result count recorded for this leaf, when there is one.
    final Integer annResult = knnExplanation.getAnnResult(context.id());
    if (annResult != null) {
        if (annResult == 0 && isMissingNativeEngineFiles(context)) {
            explanation.append(KNNConstants.EXACT_SEARCH + " since no native engine files are available");
        }
        if (isFilteredExactSearchRequireAfterANNSearch(filteredDocCount, annResult)) {
            explanation.append(
                KNNConstants.EXACT_SEARCH
                    + " since the number of documents returned are less than K = "
                    + knnQuery.getK()
                    + " and there are more than K filtered Ids = "
                    + filteredDocCount
            );
        }
        if (annResult > 0 && !isFilteredExactSearchRequireAfterANNSearch(filteredDocCount, annResult)) {
            explanation.append(KNNConstants.ANN_SEARCH);
        }
    }

    explanation.append(" with vectorDataType = " + knnQuery.getVectorDataType());
    return explanation;
}

private SpaceType getSpaceType(FieldInfo fieldInfo) {
try {
return FieldInfoExtractor.getSpaceType(modelDao, fieldInfo);
} catch (IllegalArgumentException e) {
return knnQuery.getVectorDataType() == VectorDataType.BINARY ? SpaceType.DEFAULT_BINARY : SpaceType.DEFAULT;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we not signaling to the user that the system is in a bad state? Wouldn't this be misleading, i.e. giving false data to the user?

}
}

/**
 * Builds the top-level sentence of the explain output naming the overall search type.
 * A rescore context takes precedence over a radius; with neither, the search is reported
 * as approximate nearest neighbor.
 *
 * @return the high-level explanation text
 */
private String getHighLevelExplanation() {
    final String prefix = "the type of knn search executed was ";
    if (knnQuery.getRescoreContext() != null) {
        return prefix + buildDiskBasedSearchExplanation();
    }
    if (knnQuery.getRadius() != null) {
        return prefix + KNNConstants.RADIAL_SEARCH + " with the radius of " + knnQuery.getRadius();
    }
    return prefix + KNNConstants.ANN_SEARCH;
}

/**
 * Describes a disk-based search: the first-pass k, the query vector dimension, the
 * oversampling factor, and whether shard-level rescoring is disabled for the index.
 * Reads the query's rescore context, so it must only be called when that context is non-null.
 *
 * @return the disk-based search description
 */
private String buildDiskBasedSearchExplanation() {
    final boolean shardRescoringDisabled = KNNSettings.isShardLevelRescoringDisabledForDiskBasedVector(knnQuery.getIndexName());
    final int vectorDimension = knnQuery.getQueryVector().length;
    final int firstPassK = knnQuery.getRescoreContext().getFirstPassK(knnQuery.getK(), shardRescoringDisabled, vectorDimension);
    final String rescoringSuffix = shardRescoringDisabled
        ? ", shard level rescoring disabled"
        : ", shard level rescoring enabled";

    return KNNConstants.DISK_BASED_SEARCH
        + " and the first pass k was "
        + firstPassK
        + " with vector dimension of "
        + vectorDimension
        + ", over sampling factor of "
        + knnQuery.getRescoreContext().getOversampleFactor()
        + rescoringSuffix;
}

/**
 * Resolves the scorer to use for a leaf: the supplied scorer if there is one, otherwise the
 * scorer already stored for this leaf in {@code knnExplanation}, otherwise a newly created
 * scorer that is stored for later calls.
 *
 * @param context        leaf the scorer belongs to
 * @param existingScorer scorer supplied by the caller, or {@code null}
 * @return the scorer to use for {@code context}
 * @throws IOException if creating a new scorer fails
 */
private KNNScorer getOrCreateKnnScorer(LeafReaderContext context, KNNScorer existingScorer) throws IOException {
    if (existingScorer != null) {
        return existingScorer;
    }

    KNNScorer scorerForLeaf = knnExplanation.getKnnScorerPerLeaf().get(context);
    if (scorerForLeaf == null) {
        // Nothing stored for this leaf yet: create the scorer and remember it.
        scorerForLeaf = (KNNScorer) scorer(context);
        knnExplanation.addKnnScorer(context, scorerForLeaf);
    }
    return scorerForLeaf;
}

/**
 * Reads the score of a document from the given scorer.
 *
 * @param knnScorer scorer whose iterator is advanced to {@code doc}
 * @param doc       target document id
 * @return the scorer's score if the iterator lands exactly on {@code doc}, otherwise 0
 * @throws IOException if advancing the iterator or reading the score fails
 */
private float getKnnScore(KNNScorer knnScorer, int doc) throws IOException {
    final int landedOn = knnScorer.iterator().advance(doc);
    if (landedOn != doc) {
        return 0;
    }
    return knnScorer.score();
}

@Override
Expand Down Expand Up @@ -137,6 +286,9 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
if (filterWeight != null && cardinality == 0) {
return PerLeafResult.EMPTY_RESULT;
}
if (knnQuery.isExplain()) {
knnExplanation.setCardinality(cardinality);
}
/*
* The idea for this optimization is to get K results, we need to at least look at K vectors in the HNSW graph
* . Hence, if filtered results are less than K and filter query is present we should shift to exact search.
Expand All @@ -153,7 +305,9 @@ public PerLeafResult searchLeaf(LeafReaderContext context, int k) throws IOExcep
*/
final BitSet annFilter = (filterWeight != null && cardinality == maxDoc) ? null : filterBitSet;
final Map<Integer, Float> docIdsToScoreMap = doANNSearch(context, annFilter, cardinality, k);

if (knnQuery.isExplain()) {
knnExplanation.addLeafResult(context.id(), docIdsToScoreMap.size());
}
// See whether we have to perform exact search based on approx search results
// This is required if there are no native engine files or if approximate search returned
// results less than K, though we have more than k filtered docs
Expand Down Expand Up @@ -383,6 +537,15 @@ private Map<Integer, Float> doANNSearch(
log.debug("[KNN] Query yielded 0 results");
return Collections.emptyMap();
}
if (knnQuery.isExplain()) {
Arrays.stream(results).forEach(result -> {
if (KNNEngine.FAISS.getName().equals(knnEngine.getName()) && SpaceType.INNER_PRODUCT.equals(spaceType)) {
knnExplanation.addRawScore(result.getId(), -1 * result.getScore());
} else {
knnExplanation.addRawScore(result.getId(), result.getScore());
}
});
}

if (quantizedVector != null) {
return Arrays.stream(results)
Expand Down Expand Up @@ -425,24 +588,33 @@ private boolean isFilteredExactSearchPreferred(final int filterIdsCount) {
);
int filterThresholdValue = KNNSettings.getFilteredExactSearchThreshold(knnQuery.getIndexName());
// Refer this GitHub around more details https://github.com/opensearch-project/k-NN/issues/1049 on the logic
if (knnQuery.getRadius() == null && filterIdsCount <= knnQuery.getK()) {
return true;
}
if (isFilterIdCountLessThanK(filterIdsCount)) return true;
// See user has defined Exact Search filtered threshold. if yes, then use that setting.
if (isExactSearchThresholdSettingSet(filterThresholdValue)) {
return filterThresholdValue >= filterIdsCount;
if (filterThresholdValue >= filterIdsCount) {
return true;
}
return false;
}

// if no setting is set, then use the default max distance computation value to see if we can do exact search.
/**
* TODO we can have a different MAX_DISTANCE_COMPUTATIONS for binary index as computation cost for binary index
* is cheaper than computation cost for non binary vector
*/
return isMDCGreaterThanFilterIdCnt(filterIdsCount);
}

/**
 * Checks whether the number of distance computations an exact search would perform stays within
 * {@link KNNConstants#MAX_DISTANCE_COMPUTATIONS}. The cost is the filtered document count
 * multiplied by the query vector length (the float vector for FLOAT data, otherwise the byte vector).
 *
 * @param filterIdsCount number of documents that passed the filter
 * @return {@code true} if the maximum is greater than or equal to the computed cost
 */
private boolean isMDCGreaterThanFilterIdCnt(int filterIdsCount) {
    final int vectorLength = knnQuery.getVectorDataType() == VectorDataType.FLOAT
        ? knnQuery.getQueryVector().length
        : knnQuery.getByteQueryVector().length;
    // Multiply as long: an int product overflows for large segments with high-dimensional vectors
    // and could wrap to a negative value, which would wrongly pass this check.
    return KNNConstants.MAX_DISTANCE_COMPUTATIONS >= (long) filterIdsCount * vectorLength;
}

/**
 * Checks whether the filter matched no more documents than the requested k.
 * Always false when the query has a radius.
 *
 * @param filterIdsCount number of documents that passed the filter
 * @return {@code true} if no radius is set and {@code filterIdsCount} is less than or equal to k
 */
private boolean isFilterIdCountLessThanK(int filterIdsCount) {
    if (knnQuery.getRadius() != null) {
        return false;
    }
    return filterIdsCount <= knnQuery.getK();
}

/**
* This function validates if {@link KNNSettings#ADVANCED_FILTERED_EXACT_SEARCH_THRESHOLD} is set or not. This
* is done by validating if the setting value is equal to the default value.
Expand Down
Loading