Skip to content

Commit

Permalink
Make Merge Abortable
Browse files Browse the repository at this point in the history
Signed-off-by: luyuncheng <[email protected]>
  • Loading branch information
luyuncheng committed Feb 14, 2025
1 parent 349a715 commit 32e151a
Show file tree
Hide file tree
Showing 9 changed files with 105 additions and 3 deletions.
14 changes: 14 additions & 0 deletions jni/include/faiss_index_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <jni.h>
#include "faiss/MetricType.h"
#include "faiss/impl/io.h"
#include "faiss/impl/AuxIndexStructures.h"
#include "jni_util.h"
#include "faiss_methods.h"
#include "faiss_stream_support.h"
Expand Down Expand Up @@ -195,6 +196,19 @@ class ByteIndexService final : public IndexService {
void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) final;
}; // class ByteIndexService

struct OpenSearchMergeInterruptCallback : faiss::InterruptCallback {

OpenSearchMergeInterruptCallback(JNIUtil *jniUtil, JNIEnv *env) {
mergeHelperClass = jniUtil->FindClass(env,"org/apache/lucene/index/KNNMergeHelper");
isAbortedMethod = jniUtil->FindMethod(env, "org/apache/lucene/index/KNNMergeHelper", "isMergeAborted");
}
bool want_interrupt () override {
return (bool) jenv->CallStaticBooleanMethod(mergeHelperClass, isAbortedMethod);
}
JNIEnv *jenv;
jclass mergeHelperClass;
jmethodID isAbortedMethod;
};
}
}

Expand Down
14 changes: 14 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,20 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSea
JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex
(JNIEnv *, jclass, jlong, jfloatArray, jfloat, jobject, jint, jintArray);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: setMergeInterruptCallback
* Signature: ()V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_setMergeInterruptCallback(JNIEnv * env, jclass cls);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: setMergeInterruptCallback
* Signature: ()V
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_unsetMergeInterruptCallback(JNIEnv * env, jclass cls);

#ifdef __cplusplus
}
#endif
Expand Down
5 changes: 5 additions & 0 deletions jni/src/jni_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ void knn_jni::JNIUtil::Initialize(JNIEnv *env) {
this->cachedClasses["org/opensearch/knn/index/query/KNNQueryResult"] = (jclass) env->NewGlobalRef(tempLocalClassRef);
this->cachedMethods["org/opensearch/knn/index/query/KNNQueryResult:<init>"] = env->GetMethodID(tempLocalClassRef, "<init>", "(IF)V");
env->DeleteLocalRef(tempLocalClassRef);

tempLocalClassRef = env->FindClass("org/apache/lucene/index/KNNMergeHelper");
this->cachedClasses["org/apache/lucene/index/KNNMergeHelper"] = (jclass) env->NewGlobalRef(tempLocalClassRef);
this->cachedMethods["org/apache/lucene/index/KNNMergeHelper:isMergeAborted"] = env->GetStaticMethodID(tempLocalClassRef, "isMergeAborted", "()Z");
env->DeleteLocalRef(tempLocalClassRef);
}

void knn_jni::JNIUtil::Uninitialize(JNIEnv* env) {
Expand Down
20 changes: 20 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -481,3 +481,23 @@ JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSea
}
return nullptr;
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_setMergeInterruptCallback(JNIEnv * env, jclass cls)
{
try {
faiss::InterruptCallback::instance.reset(
new knn_jni::faiss_wrapper::OpenSearchMergeInterruptCallback(&jniUtil, env)
);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
}

JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_unsetMergeInterruptCallback(JNIEnv * env, jclass cls)
{
try {
faiss::InterruptCallback::instance.get()->clear_instance();
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
}
18 changes: 18 additions & 0 deletions src/main/java/org/apache/lucene/index/KNNMergeHelper.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.apache.lucene.index;

public class KNNMergeHelper {

private KNNMergeHelper() {}
public static boolean isMergeAborted() {
Thread mergeThread = Thread.currentThread();
if (mergeThread instanceof ConcurrentMergeScheduler.MergeThread) {
return ((ConcurrentMergeScheduler.MergeThread) mergeThread).merge.isAborted();
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package org.opensearch.knn.index.codec.KNN80Codec;

import lombok.extern.log4j.Log4j2;
import org.apache.lucene.index.ConcurrentMergeScheduler;
import org.apache.lucene.index.KNNMergeHelper;
import org.opensearch.common.StopWatch;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNEngine;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.KNNMergeHelper;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.store.IndexOutput;
import org.opensearch.common.Nullable;
Expand Down Expand Up @@ -46,6 +47,8 @@
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.initializeVectorValues;
import static org.opensearch.knn.index.codec.util.KNNCodecUtil.buildEngineFileName;
import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX;
import static org.opensearch.knn.jni.JNIService.setMergeInterruptCallback;
import static org.opensearch.knn.jni.JNIService.unsetMergeInterruptCallback;

/**
* Writes KNN Index for a field in a segment. This is intended to be used for native engines
Expand Down Expand Up @@ -119,9 +122,18 @@ public void mergeIndex(final KNNVectorValues<?> knnVectorValues, int totalLiveDo
}

long bytesPerVector = knnVectorValues.bytesPerVector();
startMergeStats(totalLiveDocs, bytesPerVector);
buildAndWriteIndex(knnVectorValues, totalLiveDocs);
endMergeStats(totalLiveDocs, bytesPerVector);
final KNNEngine knnEngine = extractKNNEngine(fieldInfo);
setMergeInterruptCallback(knnEngine);
try {
startMergeStats(totalLiveDocs, bytesPerVector);
buildAndWriteIndex(knnVectorValues, totalLiveDocs);
endMergeStats(totalLiveDocs, bytesPerVector);
} catch (Exception ex) {
//TODO handle
log.debug("Merge may abort {}",KNNMergeHelper.isMergeAborted());
} finally {
unsetMergeInterruptCallback(knnEngine);
}
}

private void buildAndWriteIndex(final KNNVectorValues<?> knnVectorValues, int totalLiveDocs) throws IOException {
Expand Down
4 changes: 4 additions & 0 deletions src/main/java/org/opensearch/knn/jni/FaissService.java
Original file line number Diff line number Diff line change
Expand Up @@ -463,4 +463,8 @@ public static native KNNQueryResult[] rangeSearchIndex(
int indexMaxResultWindow,
int[] parentIds
);

public static native void setMergeInterruptCallback();

public static native void unsetMergeInterruptCallback();
}
13 changes: 13 additions & 0 deletions src/main/java/org/opensearch/knn/jni/JNIService.java
Original file line number Diff line number Diff line change
Expand Up @@ -476,4 +476,17 @@ public static KNNQueryResult[] radiusQueryIndex(
}
throw new IllegalArgumentException(String.format(Locale.ROOT, "RadiusQueryIndex not supported for provided engine"));
}

public static void setMergeInterruptCallback(KNNEngine knnEngine) {

if (KNNEngine.FAISS == knnEngine) {
FaissService.setMergeInterruptCallback();
}
}

public static void unsetMergeInterruptCallback(KNNEngine knnEngine) {
if (KNNEngine.FAISS == knnEngine) {
FaissService.unsetMergeInterruptCallback();
}
}
}

0 comments on commit 32e151a

Please sign in to comment.