diff --git a/CHANGELOG.md b/CHANGELOG.md index c16c78cf93..7087cf1aa0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292] - Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331] - Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376] +- Add support for Faiss onDisk 4x compression (#2425)[https://github.com/opensearch-project/k-NN/pull/2425] ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241] - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290] diff --git a/jni/include/faiss_index_service.h b/jni/include/faiss_index_service.h index d96c3e7557..605cc3e86d 100644 --- a/jni/include/faiss_index_service.h +++ b/jni/include/faiss_index_service.h @@ -69,10 +69,25 @@ class IndexService { */ virtual void writeIndex(faiss::IOWriter* writer, jlong idMapAddress); + /** + * Initialize index from template + * + * @param jniUtil jni util + * @param env jni environment + * @param dim dimension of vectors + * @param numVectors number of vectors + * @param threadCount number of thread count to be used while adding data + * @param templateIndexJ template index + * @return memory address of the native index object + */ + virtual jlong initIndexFromTemplate(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, int dim, int numVectors, int threadCount, jbyteArray templateIndexJ); + + virtual ~IndexService() = default; protected: virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors); + virtual jlong initAndAllocateIndex(std::unique_ptr &index, size_t threadCount, size_t dim, size_t numVectors); std::unique_ptr faissMethods; }; // class IndexService @@ -132,8 +147,22 @@ class BinaryIndexService final : public IndexService { */ void writeIndex(faiss::IOWriter* writer, jlong idMapAddress) final; + /** + * Initialize index from template + * + * @param jniUtil jni util + * @param env jni environment + * @param dim dimension of vectors + * @param numVectors number of vectors + * @param threadCount number of thread count to be used while adding data + * @param templateIndexJ template index + * @return memory address of the native index object + */ + jlong initIndexFromTemplate(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, int dim, int numVectors, int threadCount, jbyteArray templateIndexJ) final; + protected: void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) final; + jlong initAndAllocateIndex(std::unique_ptr &index, size_t threadCount, size_t dim, size_t numVectors); }; // class BinaryIndexService /** @@ -191,8 +220,22 @@ class ByteIndexService final : public IndexService { */ void writeIndex(faiss::IOWriter* writer, jlong idMapAddress) final; + /** + * Initialize index from template + * + * @param jniUtil jni util + * @param env jni environment + * @param dim dimension of vectors + * @param numVectors number of vectors + * @param threadCount number of thread count to be used while adding data + * @param templateIndexJ template index + * @return memory address of the native index object + */ + jlong initIndexFromTemplate(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, 
int dim, int numVectors, int threadCount, jbyteArray templateIndexJ) final; + protected: void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) final; + jlong initAndAllocateIndex(std::unique_ptr &index, size_t threadCount, size_t dim, size_t numVectors) final; }; // class ByteIndexService } diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index e48e6faa91..23e3ad6b30 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -25,6 +25,9 @@ namespace knn_jni { void WriteIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jobject output, jlong indexAddr, IndexService *indexService); + jlong InitIndexFromTemplate(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong numDocs, jint dimJ, jobject parametersJ, jbyteArray templateIndexJ, IndexService *indexService); + + // Create an index with ids and vectors. Instead of creating a new index, this function creates the index // based off of the template index passed in. The index is serialized to indexPathJ. void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index dce5801383..c3b14bbe3f 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -92,6 +92,32 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex */ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeByteIndex(JNIEnv *, jclass, jlong, jobject); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: initIndexFromTemplate + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndexFromTemplate(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ, jbyteArray templateIndexJ); +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: initBinaryIndexFromTemplate + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndexFromTemplate(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ, jbyteArray templateIndexJ); + +/* + * Class: org_opensearch_knn_jni_FaissService + * Method: initByteIndexFromTemplate + * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + */ +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initByteIndexFromTemplate(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ, jbyteArray templateIndexJ); + /* * Class: org_opensearch_knn_jni_FaissService * Method: createIndexFromTemplate diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp index 4999e31729..3d64a557ea 100644 --- a/jni/src/faiss_index_service.cpp +++ b/jni/src/faiss_index_service.cpp @@ -68,6 +68,25 @@ void IndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVector } } +jlong IndexService::initAndAllocateIndex(std::unique_ptr &index, size_t threadCount, size_t dim, size_t numVectors) { + // Set thread count if it is passed in as a parameter. 
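For orientation, the three JNI entry points declared above correspond to new native methods on FaissService and a dispatching helper on JNIService, both added later in this patch. A minimal caller-side sketch follows; the wrapper class and method names are illustrative only, and the parameter map is assumed to be the usual Map<String, Object> of index-build parameters.

import java.util.Map;

import org.opensearch.knn.index.engine.KNNEngine;
import org.opensearch.knn.jni.JNIService;

// Illustrative wrapper, not part of the patch: initialize a native Faiss index
// from a serialized (trained) template and get back its native memory address.
final class TemplateIndexInitSketch {
    static long initFromTemplate(long numDocs, int dim, Map<String, Object> parameters, byte[] templateIndex) {
        // JNIService picks initIndexFromTemplate / initBinaryIndexFromTemplate /
        // initByteIndexFromTemplate based on the parameters; only FAISS supports this path.
        return JNIService.initIndexFromTemplate(numDocs, dim, parameters, KNNEngine.FAISS, templateIndex);
    }
}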
Setting this variable will only impact the current thread + if (threadCount != 0) { + omp_set_num_threads(threadCount); + } + + std::unique_ptr idMap (faissMethods->indexIdMap(index.get())); + //Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor + idMap->own_fields = true; + + // TODO: allocIndex for IVF + allocIndex(dynamic_cast(idMap->index), dim, numVectors); + + //Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later + //in insert and write operations + index.release(); + return reinterpret_cast(idMap.release()); +} + jlong IndexService::initIndex( knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, @@ -81,11 +100,6 @@ jlong IndexService::initIndex( // Create index using Faiss factory method std::unique_ptr index(faissMethods->indexFactory(dim, indexDescription.c_str(), metric)); - // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread - if (threadCount != 0) { - omp_set_num_threads(threadCount); - } - // Add extra parameters that cant be configured with the index factory SetExtraParameters(jniUtil, env, parameters, index.get()); @@ -94,16 +108,7 @@ jlong IndexService::initIndex( throw std::runtime_error("Index is not trained"); } - std::unique_ptr idMap (faissMethods->indexIdMap(index.get())); - //Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor - idMap->own_fields = true; - - allocIndex(dynamic_cast(idMap->index), dim, numVectors); - - //Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later - //in insert and write operations - index.release(); - return reinterpret_cast(idMap.release()); + return initAndAllocateIndex(index, threadCount, dim, numVectors); } void IndexService::insertToIndex( @@ -155,6 +160,32 @@ void IndexService::writeIndex( } } +jlong IndexService::initIndexFromTemplate( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + int dim, + int numVectors, + int threadCount, + jbyteArray templateIndexJ + ) { + + // Get vector of bytes from jbytearray + int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ); + jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr); + + faiss::VectorIOReader vectorIoReader; + for (int i = 0; i < indexBytesCount; i++) { + vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]); + } + jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT); + + // Create faiss index + std::unique_ptr index; + index.reset(faiss::read_index(&vectorIoReader, 0)); + + return initAndAllocateIndex(index, threadCount, dim, numVectors); +} + BinaryIndexService::BinaryIndexService(std::unique_ptr _faissMethods) : IndexService(std::move(_faissMethods)) { } @@ -166,6 +197,25 @@ void BinaryIndexService::allocIndex(faiss::Index * index, size_t dim, size_t num } } +jlong BinaryIndexService::initAndAllocateIndex(std::unique_ptr &index, size_t threadCount, size_t dim, size_t numVectors) { + // Set thread count if it is passed in as a parameter. 
Setting this variable will only impact the current thread + if (threadCount != 0) { + omp_set_num_threads(threadCount); + } + + std::unique_ptr idMap(faissMethods->indexBinaryIdMap(index.get())); + //Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor + idMap->own_fields = true; + + // TODO: allocIndex for IVF + allocIndex(dynamic_cast(idMap->index), dim, numVectors); + + //Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later + //in insert and write operations + index.release(); + return reinterpret_cast(idMap.release()); +} + jlong BinaryIndexService::initIndex( knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, @@ -178,10 +228,6 @@ jlong BinaryIndexService::initIndex( ) { // Create index using Faiss factory method std::unique_ptr index(faissMethods->indexBinaryFactory(dim, indexDescription.c_str())); - // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread - if (threadCount != 0) { - omp_set_num_threads(threadCount); - } // Add extra parameters that cant be configured with the index factory SetExtraParameters(jniUtil, env, parameters, index.get()); @@ -191,16 +237,7 @@ jlong BinaryIndexService::initIndex( throw std::runtime_error("Index is not trained"); } - std::unique_ptr idMap(faissMethods->indexBinaryIdMap(index.get())); - //Makes sure the index is deleted when the destructor is called - idMap->own_fields = true; - - allocIndex(dynamic_cast(idMap->index), dim, numVectors); - - //Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later - //in insert and write operations - index.release(); - return reinterpret_cast(idMap.release()); + return initAndAllocateIndex(index, threadCount, dim, numVectors); } void BinaryIndexService::insertToIndex( @@ -252,6 +289,35 @@ void BinaryIndexService::writeIndex( } } +jlong BinaryIndexService::initIndexFromTemplate( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + int dim, + int numVectors, + int threadCount, + jbyteArray templateIndexJ + ) { + if (dim % 8 != 0) { + throw std::runtime_error("Dimensions should be multiple of 8"); + } + + // Get vector of bytes from jbytearray + int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ); + jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr); + + faiss::VectorIOReader vectorIoReader; + for (int i = 0; i < indexBytesCount; i++) { + vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]); + } + jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT); + + // Create faiss index + std::unique_ptr index; + index.reset(faiss::read_index_binary(&vectorIoReader, 0)); + + return initAndAllocateIndex(index, threadCount, dim, numVectors); +} + ByteIndexService::ByteIndexService(std::unique_ptr _faissMethods) : IndexService(std::move(_faissMethods)) { } @@ -264,6 +330,25 @@ void ByteIndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVe } } +jlong ByteIndexService::initAndAllocateIndex(std::unique_ptr &index, size_t threadCount, size_t dim, size_t numVectors) { + // Set thread count if it is passed in as a parameter. 
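The same wrap, allocate, release sequence appears in all three index services; it is summarized once here for readability.

// Shared initAndAllocateIndex flow (IndexService, BinaryIndexService, ByteIndexService):
//  1. Optionally set the OpenMP thread count (affects only the current thread).
//  2. Wrap the freshly created faiss index in an IdMap and set own_fields so the
//     IdMap destructor also frees the wrapped index.
//  3. Reserve space for numVectors via allocIndex (IVF support is still a TODO per
//     the comments in this patch).
//  4. Release the caller's unique_ptr so ownership stays with the IdMap, and return
//     the IdMap pointer to Java as a jlong. initIndex and initIndexFromTemplate both
//     delegate here, which is why the thread-count handling moved out of initIndex.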
Setting this variable will only impact the current thread + if (threadCount != 0) { + omp_set_num_threads(threadCount); + } + + std::unique_ptr idMap (faissMethods->indexIdMap(index.get())); + //Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor + idMap->own_fields = true; + + // TODO: allocIndex for IVF + allocIndex(dynamic_cast(idMap->index), dim, numVectors); + + //Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later + //in insert and write operations + index.release(); + return reinterpret_cast(idMap.release()); +} + jlong ByteIndexService::initIndex( knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, @@ -277,11 +362,6 @@ jlong ByteIndexService::initIndex( // Create index using Faiss factory method std::unique_ptr index(faissMethods->indexFactory(dim, indexDescription.c_str(), metric)); - // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread - if (threadCount != 0) { - omp_set_num_threads(threadCount); - } - // Add extra parameters that cant be configured with the index factory SetExtraParameters(jniUtil, env, parameters, index.get()); @@ -290,16 +370,7 @@ jlong ByteIndexService::initIndex( throw std::runtime_error("Index is not trained"); } - std::unique_ptr idMap (faissMethods->indexIdMap(index.get())); - //Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor - idMap->own_fields = true; - - allocIndex(dynamic_cast(idMap->index), dim, numVectors); - - //Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later - //in insert and write operations - index.release(); - return reinterpret_cast(idMap.release()); + return initAndAllocateIndex(index, threadCount, dim, numVectors); } void ByteIndexService::insertToIndex( @@ -368,5 +439,31 @@ void ByteIndexService::writeIndex( throw std::runtime_error("Failed to write index to disk"); } } + +jlong ByteIndexService::initIndexFromTemplate( + knn_jni::JNIUtilInterface * jniUtil, + JNIEnv * env, + int dim, + int numVectors, + int threadCount, + jbyteArray templateIndexJ + ) { + + // Get vector of bytes from jbytearray + int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ); + jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr); + + faiss::VectorIOReader vectorIoReader; + for (int i = 0; i < indexBytesCount; i++) { + vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]); + } + jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT); + + // Create faiss index + std::unique_ptr index; + index.reset(faiss::read_index(&vectorIoReader, 0)); + + return initAndAllocateIndex(index, threadCount, dim, numVectors); +} } // namespace faiss_wrapper } // namesapce knn_jni diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index c02c410c1f..fa2b1068be 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -198,6 +198,49 @@ void knn_jni::faiss_wrapper::WriteIndex(knn_jni::JNIUtilInterface * jniUtil, JNI indexService->writeIndex(&writer, index_ptr); } +jlong knn_jni::faiss_wrapper::InitIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong numDocs, jint dimJ, + jobject parametersJ, jbyteArray templateIndexJ, IndexService* indexService) { + + if(dimJ <= 0) { + throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); + } + + if 
(parametersJ == nullptr) { + throw std::runtime_error("Parameters cannot be null"); + } + + if (templateIndexJ == nullptr) { + throw std::runtime_error("Template index cannot be null"); + } + + // parametersJ is a Java Map. ConvertJavaMapToCppMap converts it to a c++ map + // so that it is easier to access. + auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); + + // Thread count + int threadCount = 0; + if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + } + jniUtil->DeleteLocalRef(env, parametersJ); + + + // Dimension + int dim = (int)dimJ; + + // Number of docs + int docs = (int)numDocs; + // end parameters to pass + + // Create index + return indexService->initIndexFromTemplate(jniUtil, + env, + dim, + docs, + threadCount, + templateIndexJ); +} + void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jobject output, jbyteArray templateIndexJ, jobject parametersJ) { diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 836774402f..99310fb1f2 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -166,6 +166,49 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeByteIndex(J } } +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndexFromTemplate(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ, jbyteArray templateIndexJ) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + return knn_jni::faiss_wrapper::InitIndexFromTemplate(&jniUtil, env, numDocs, dimJ, parametersJ, templateIndexJ, &indexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return (jlong)0; +} + +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndexFromTemplate(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ, jbyteArray templateIndexJ) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods)); + return knn_jni::faiss_wrapper::InitIndexFromTemplate(&jniUtil, env, numDocs, dimJ, parametersJ, templateIndexJ, &binaryIndexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return (jlong)0; +} + +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initByteIndexFromTemplate(JNIEnv * env, jclass cls, + jlong numDocs, jint dimJ, + jobject parametersJ, jbyteArray templateIndexJ) +{ + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::ByteIndexService byteIndexService(std::move(faissMethods)); + return knn_jni::faiss_wrapper::InitIndexFromTemplate(&jniUtil, env, numDocs, dimJ, parametersJ, templateIndexJ, &byteIndexService); + } catch (...) 
{ + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return (jlong)0; +} + + JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate(JNIEnv * env, jclass cls, jintArray idsJ, diff --git a/src/main/java/org/opensearch/knn/common/KNNConstants.java b/src/main/java/org/opensearch/knn/common/KNNConstants.java index a2985e9d8d..82b5686548 100644 --- a/src/main/java/org/opensearch/knn/common/KNNConstants.java +++ b/src/main/java/org/opensearch/knn/common/KNNConstants.java @@ -118,8 +118,11 @@ public class KNNConstants { public static final String FAISS_SQ_DESCRIPTION = "SQ"; public static final String FAISS_SQ_TYPE = "type"; public static final String FAISS_SQ_ENCODER_FP16 = "fp16"; - public static final List FAISS_SQ_ENCODER_TYPES = List.of(FAISS_SQ_ENCODER_FP16); + public static final String FAISS_SQ_ENCODER_INT8 = "int8"; + public static final List FAISS_SQ_ENCODER_TYPES = List.of(FAISS_SQ_ENCODER_FP16, FAISS_SQ_ENCODER_INT8); public static final String FAISS_SIGNED_BYTE_SQ = "SQ8_direct_signed"; + + public static final String FAISS_SQ8 = "SQ8"; public static final String FAISS_SQ_CLIP = "clip"; // Parameter defaults/limits diff --git a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java index ac4c055b01..922f5804c1 100644 --- a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java +++ b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java @@ -28,6 +28,7 @@ import org.opensearch.knn.index.memory.NativeMemoryEntryContext; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import java.io.IOException; import java.util.ArrayList; @@ -184,6 +185,7 @@ List getEngineFileContexts(IndexReader indexReader, KNNEngine String spaceTypeName = fieldInfo.attributes().getOrDefault(SPACE_TYPE, SpaceType.L2.getValue()); SpaceType spaceType = SpaceType.getSpace(spaceTypeName); String modelId = fieldInfo.attributes().getOrDefault(MODEL_ID, null); + QuantizationConfig quantizationConfig = FieldInfoExtractor.extractQuantizationConfig(fieldInfo); engineFiles.addAll( getEngineFileContexts( reader.getSegmentInfo(), @@ -191,11 +193,7 @@ List getEngineFileContexts(IndexReader indexReader, KNNEngine fileExtension, spaceType, modelId, - FieldInfoExtractor.extractQuantizationConfig(fieldInfo) == QuantizationConfig.EMPTY - ? 
VectorDataType.get( - fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue()) - ) - : VectorDataType.BINARY + getVectorDataType(quantizationConfig, fieldInfo) ) ); } @@ -225,6 +223,16 @@ List getEngineFileContexts( .collect(Collectors.toList()); } + private VectorDataType getVectorDataType(QuantizationConfig quantizationConfig, FieldInfo fieldInfo) { + if (quantizationConfig == QuantizationConfig.EMPTY) { + return VectorDataType.get(fieldInfo.attributes().getOrDefault(VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT.getValue())); + } + if (quantizationConfig.getQuantizationType() == ScalarQuantizationType.EIGHT_BIT) { + return VectorDataType.BYTE; + } + return VectorDataType.BINARY; + } + @AllArgsConstructor @Getter @VisibleForTesting diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index fb93bfc073..7f3f04f070 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -30,7 +30,9 @@ import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.plugin.stats.KNNGraphValue; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; @@ -253,10 +255,10 @@ private QuantizationState train( final QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); QuantizationState quantizationState = null; if (quantizationParams != null && totalLiveDocs > 0) { - initQuantizationStateWriterIfNecessary(); KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); - quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs); - quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState); + + quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs, fieldInfo); + writeQuantizationState(quantizationParams, quantizationState, fieldInfo.getFieldNumber()); } return quantizationState; @@ -277,11 +279,24 @@ private int getLiveDocs(KNNVectorValues vectorValues) throws IOException { return liveDocs; } - private void initQuantizationStateWriterIfNecessary() throws IOException { + private void writeQuantizationState(QuantizationParams quantizationParams, QuantizationState quantizationState, int fieldNumber) + throws IOException { + + // We will not write quantization state for 8 bits into a segment file because the query vectors will not be quantized and + // the template index is stored in the quantization state after training to use it later in the index build strategy during + // data ingestion. 
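The writer decides whether to persist a quantization state by comparing type identifiers, and the same test reappears in KNNWeight and SegmentLevelQuantizationInfo further down. A small sketch of that recurring check; the helper class and method names are hypothetical.

import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;

final class EightBitCheckSketch {
    // True when the field is trained for 8-bit (4x) scalar quantization, i.e. the case
    // where the Faiss template index lives in the quantization state and no per-segment
    // quantization state file is written.
    static boolean isEightBit(QuantizationParams params) {
        return params != null
            && params.getTypeIdentifier()
                .equals(ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.EIGHT_BIT));
    }
}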
+ if ((quantizationParams.getTypeIdentifier()).equals( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.EIGHT_BIT) + )) { + return; + } + + // Initialize quantization state writer if required if (quantizationStateWriter == null) { quantizationStateWriter = new KNN990QuantizationStateWriter(segmentWriteState); quantizationStateWriter.writeHeader(segmentWriteState); } + quantizationStateWriter.writeState(fieldNumber, quantizationState); } private boolean shouldSkipBuildingVectorDataStructure(final long docCount) { diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java index 2864be6d2c..afac0d2988 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java @@ -12,6 +12,7 @@ import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.jni.JNIService; +import org.opensearch.knn.quantization.models.quantizationState.ByteScalarQuantizationState; import java.io.IOException; import java.security.AccessController; @@ -57,16 +58,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOExcept KNNEngine engine = indexInfo.getKnnEngine(); Map indexParameters = indexInfo.getParameters(); IndexBuildSetup indexBuildSetup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, indexInfo); - - // Initialize the index - long indexMemoryAddress = AccessController.doPrivileged( - (PrivilegedAction) () -> JNIService.initIndex( - indexInfo.getTotalLiveDocs(), - indexBuildSetup.getDimensions(), - indexParameters, - engine - ) - ); + long indexMemoryAddress = initializeIndex(indexInfo, indexBuildSetup, indexParameters, engine); try ( final OffHeapVectorTransfer vectorTransfer = getVectorTransfer( @@ -133,4 +125,43 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOExcept ); } } + + private long initializeIndex( + BuildIndexParams indexInfo, + IndexBuildSetup indexBuildSetup, + Map indexParameters, + KNNEngine engine + ) { + if (isTemplate(indexInfo)) { + // Initialize the index from Template + return AccessController.doPrivileged( + (PrivilegedAction) () -> JNIService.initIndexFromTemplate( + indexInfo.getTotalLiveDocs(), + indexBuildSetup.getDimensions(), + indexParameters, + engine, + getIndexTemplate(indexInfo) + ) + ); + + } + // Initialize the index + return AccessController.doPrivileged( + (PrivilegedAction) () -> JNIService.initIndex( + indexInfo.getTotalLiveDocs(), + indexBuildSetup.getDimensions(), + indexParameters, + engine + ) + ); + } + + private static boolean isTemplate(final BuildIndexParams indexInfo) { + return (indexInfo.getQuantizationState() instanceof ByteScalarQuantizationState); + } + + private byte[] getIndexTemplate(BuildIndexParams indexInfo) { + ByteScalarQuantizationState byteSQState = (ByteScalarQuantizationState) indexInfo.getQuantizationState(); + return byteSQState.getIndexTemplate(); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java index c5994d66bb..745802dd39 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java +++ 
b/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java @@ -10,6 +10,7 @@ import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationState.ByteScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; @@ -58,7 +59,7 @@ static IndexBuildSetup prepareIndexBuild(KNNVectorValues knnVectorValues, Bui int bytesPerVector; int dimensions; - if (quantizationState != null) { + if (quantizationState != null && !(quantizationState instanceof ByteScalarQuantizationState)) { bytesPerVector = quantizationState.getBytesPerVector(); dimensions = quantizationState.getDimensions(); quantizationOutput = quantizationService.createQuantizationOutput(quantizationState.getQuantizationParams()); diff --git a/src/main/java/org/opensearch/knn/index/engine/EngineResolver.java b/src/main/java/org/opensearch/knn/index/engine/EngineResolver.java index d52c21c4ca..801800ecb7 100644 --- a/src/main/java/org/opensearch/knn/index/engine/EngineResolver.java +++ b/src/main/java/org/opensearch/knn/index/engine/EngineResolver.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.engine; +import org.opensearch.Version; import org.opensearch.knn.index.mapper.CompressionLevel; import org.opensearch.knn.index.mapper.Mode; @@ -52,11 +53,15 @@ public KNNEngine resolveEngine( return mode == Mode.ON_DISK ? KNNEngine.FAISS : KNNEngine.NMSLIB; } - // Lucene is only engine that supports 4x - so we have to default to it here. - if (compressionLevel == CompressionLevel.x4) { + // 4x is supported by Lucene engine before version 2.19.0 + if (compressionLevel == CompressionLevel.x4 && knnMethodConfigContext.getVersionCreated().before(Version.V_2_19_0)) { return KNNEngine.LUCENE; } + if (compressionLevel == CompressionLevel.x4) { + return KNNEngine.FAISS; + } + return KNNEngine.FAISS; } } diff --git a/src/main/java/org/opensearch/knn/index/engine/MethodComponent.java b/src/main/java/org/opensearch/knn/index/engine/MethodComponent.java index d456ea89f4..bcb325fb39 100644 --- a/src/main/java/org/opensearch/knn/index/engine/MethodComponent.java +++ b/src/main/java/org/opensearch/knn/index/engine/MethodComponent.java @@ -339,6 +339,11 @@ public static Map getParameterMapWithDefaultsAdded( || compressionLevel == CompressionLevel.x16 || compressionLevel == CompressionLevel.x8); + // Check if the mode is ON_DISK and the compression level is x4 and index created version is >= 2.19.0 + // This determines whether to use faiss byte quantization-specific values for parameters like ef_search and ef_construction. + boolean isFaissOnDiskWithByteQuantization = compressionLevel == CompressionLevel.x4 + && indexCreationVersion.onOrAfter(Version.V_2_19_0); + for (Parameter parameter : methodComponent.getParameters().values()) { if (methodComponentContext.getParameters().containsKey(parameter.getName())) { parametersWithDefaultsMap.put(parameter.getName(), userProvidedParametersMap.get(parameter.getName())); @@ -346,7 +351,7 @@ public static Map getParameterMapWithDefaultsAdded( // Picking the right values for the parameters whose values are different based on different index // created version. 
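Putting the EngineResolver change above in concrete terms: for a 4x-compressed field, the chosen engine now depends on the version the index was created with, and newer 4x fields also pick up the binary-quantization ef_search/ef_construction defaults handled just below. A sketch of the resulting behaviour; the method name is illustrative.

import org.opensearch.Version;
import org.opensearch.knn.index.engine.KNNEngine;

final class X4EngineSketch {
    // Indices created before 2.19.0 keep the old behaviour (Lucene handles 4x);
    // from 2.19.0 onward, 4x resolves to Faiss with the SQ8 scalar quantizer.
    static KNNEngine engineForX4(Version indexCreatedVersion) {
        return indexCreatedVersion.before(Version.V_2_19_0) ? KNNEngine.LUCENE : KNNEngine.FAISS;
    }
}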
if (parameter.getName().equals(KNNConstants.METHOD_PARAMETER_EF_SEARCH)) { - if (isOnDiskWithBinaryQuantization) { + if (isOnDiskWithBinaryQuantization || isFaissOnDiskWithByteQuantization) { parametersWithDefaultsMap.put(parameter.getName(), IndexHyperParametersUtil.getBinaryQuantizationEFSearchValue()); } else { parametersWithDefaultsMap.put( @@ -355,7 +360,7 @@ public static Map getParameterMapWithDefaultsAdded( ); } } else if (parameter.getName().equals(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION)) { - if (isOnDiskWithBinaryQuantization) { + if (isOnDiskWithBinaryQuantization || isFaissOnDiskWithByteQuantization) { parametersWithDefaultsMap.put( parameter.getName(), IndexHyperParametersUtil.getBinaryQuantizationEFConstructionValue() diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java index 356292678f..2f2286d1e7 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java @@ -17,7 +17,12 @@ import java.util.Objects; import java.util.Set; +import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; import static org.opensearch.knn.common.KNNConstants.FAISS_SIGNED_BYTE_SQ; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ8; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_INT8; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.index.engine.faiss.Faiss.FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; import static org.opensearch.knn.index.engine.faiss.FaissFP16Util.isFaissSQClipToFP16RangeEnabled; @@ -101,23 +106,37 @@ static KNNLibraryIndexingContext adjustIndexDescription( if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BINARY) { prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; } + if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BYTE) { // If VectorDataType is Byte using Faiss engine then manipulate Index Description to use "SQ8_direct_signed" scalar quantizer // For example, Index Description "HNSW16,Flat" will be updated as "HNSW16,SQ8_direct_signed" - String indexDescription = methodAsMapBuilder.indexDescription; - if (StringUtils.isNotEmpty(indexDescription)) { - StringBuilder indexDescriptionBuilder = new StringBuilder(); - indexDescriptionBuilder.append(indexDescription.split(",")[0]); - indexDescriptionBuilder.append(","); - indexDescriptionBuilder.append(FAISS_SIGNED_BYTE_SQ); - methodAsMapBuilder.indexDescription = indexDescriptionBuilder.toString(); - } + methodAsMapBuilder.indexDescription = updateIndexDescription(methodAsMapBuilder.indexDescription, FAISS_SIGNED_BYTE_SQ); + } + + // If SQ encoder type is 'int8' or the compression level is 4x then manipulate Index Description to use "SQ8" scalar quantizer + // For example, Index Description "HNSW16,Flat" will be updated as "HNSW16,SQ8" + if (encoderContext != null + && Objects.equals(encoderContext.getName(), ENCODER_SQ) + && Objects.equals(encoderContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16), FAISS_SQ_ENCODER_INT8)) { + methodAsMapBuilder.indexDescription = updateIndexDescription(methodAsMapBuilder.indexDescription, FAISS_SQ8); } methodAsMapBuilder.indexDescription = prefix + 
methodAsMapBuilder.indexDescription; return methodAsMapBuilder.build(); } + private static String updateIndexDescription(String indexDescription, String indexDescriptionName) { + if (StringUtils.isEmpty(indexDescription)) { + return indexDescription; + } + + StringBuilder indexDescriptionBuilder = new StringBuilder(); + indexDescriptionBuilder.append(indexDescription.split(",")[0]); + indexDescriptionBuilder.append(","); + indexDescriptionBuilder.append(indexDescriptionName); + return indexDescriptionBuilder.toString(); + } + static MethodComponentContext getEncoderMethodComponent(MethodComponentContext methodComponentContext) { if (!methodComponentContext.getParameters().containsKey(METHOD_ENCODER_PARAMETER)) { return null; diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolver.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolver.java index c976a0959b..9934327911 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolver.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolver.java @@ -19,12 +19,14 @@ import org.opensearch.knn.index.mapper.Mode; import java.util.HashMap; +import java.util.Locale; import java.util.Map; import java.util.Set; import static org.opensearch.knn.common.KNNConstants.ENCODER_FLAT; import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_INT8; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; @@ -37,6 +39,7 @@ public class FaissMethodResolver extends AbstractMethodResolver { private static final Set SUPPORTED_COMPRESSION_LEVELS = Set.of( CompressionLevel.x1, CompressionLevel.x2, + CompressionLevel.x4, CompressionLevel.x8, CompressionLevel.x16, CompressionLevel.x32 @@ -75,6 +78,7 @@ public ResolvedMethodContext resolveMethod( // Validate that resolved compression doesnt have any conflicts validateCompressionConflicts(knnMethodConfigContext.getCompressionLevel(), resolvedCompressionLevel); + validateMethodWithCompression(method, resolvedCompressionLevel); knnMethodConfigContext.setCompressionLevel(resolvedCompressionLevel); resolveMethodParams(resolvedKNNMethodContext.getMethodComponentContext(), knnMethodConfigContext, method); @@ -106,6 +110,12 @@ private void resolveEncoder( encoderComponentContext.getParameters().put(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16); } + if (CompressionLevel.x4 == resolvedCompressionLevel) { + encoderComponentContext = new MethodComponentContext(ENCODER_SQ, new HashMap<>()); + encoder = encoderMap.get(ENCODER_SQ); + encoderComponentContext.getParameters().put(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_INT8); + } + if (CompressionLevel.x8 == resolvedCompressionLevel) { encoderComponentContext = new MethodComponentContext(QFrameBitEncoder.NAME, new HashMap<>()); encoder = encoderMap.get(QFrameBitEncoder.NAME); @@ -157,4 +167,20 @@ private CompressionLevel getDefaultCompressionLevel(KNNMethodConfigContext knnMe } return CompressionLevel.x1; } + + private void validateMethodWithCompression(MethodComponent method, CompressionLevel compressionLevel) { + if (method == IVF_COMPONENT && compressionLevel == CompressionLevel.x4) { + ValidationException validationException = new ValidationException(); + validationException.addValidationError( + 
String.format( + Locale.ROOT, + "sq \"%s\" encoder or compression level \"%s\" does not support \"%s\" method", + FAISS_SQ_ENCODER_INT8, + CompressionLevel.x4.getName(), + METHOD_IVF + ) + ); + throw validationException; + } + } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoder.java b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoder.java index cd7e1e5f38..0d48ce5fdb 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoder.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoder.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.engine.faiss; import com.google.common.collect.ImmutableSet; +import org.opensearch.common.ValidationException; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.Encoder; import org.opensearch.knn.index.engine.KNNMethodConfigContext; @@ -21,6 +22,7 @@ import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_DESCRIPTION; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_INT8; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_TYPES; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; @@ -58,7 +60,29 @@ public CompressionLevel calculateCompressionLevel( MethodComponentContext methodComponentContext, KNNMethodConfigContext knnMethodConfigContext ) { - // TODO: Hard code for now - return CompressionLevel.x2; + // Faiss Scalar Quantizer by default sets to fp16 or compression level x2 when encoder type is not set + if (methodComponentContext.getParameters().size() == 0 || !methodComponentContext.getParameters().containsKey(FAISS_SQ_TYPE)) { + return CompressionLevel.x2; + } + + // Map the sq encoder type passed in, back to the compression level + Object value = methodComponentContext.getParameters().get(FAISS_SQ_TYPE); + ValidationException validationException = METHOD_COMPONENT.getParameters() + .get(FAISS_SQ_TYPE) + .validate(value, knnMethodConfigContext); + if (validationException != null) { + throw validationException; + } + + String SQEncoderType = (String) value; + if (FAISS_SQ_ENCODER_FP16.equals(SQEncoderType)) { + return CompressionLevel.x2; + } + + if (FAISS_SQ_ENCODER_INT8.equals(SQEncoderType)) { + return CompressionLevel.x4; + } + + return CompressionLevel.NOT_CONFIGURED; } } diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/MethodAsMapBuilder.java b/src/main/java/org/opensearch/knn/index/engine/faiss/MethodAsMapBuilder.java index 8b688cbcc7..8e544b4f44 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/MethodAsMapBuilder.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/MethodAsMapBuilder.java @@ -6,6 +6,7 @@ package org.opensearch.knn.index.engine.faiss; import lombok.AllArgsConstructor; +import org.opensearch.Version; import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.engine.KNNLibraryIndexingContext; import org.opensearch.knn.index.engine.KNNLibraryIndexingContextImpl; @@ -14,6 +15,8 @@ import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.engine.Parameter; import org.opensearch.knn.index.engine.qframe.QuantizationConfig; +import org.opensearch.knn.index.mapper.CompressionLevel; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import java.util.HashMap; import java.util.Map; @@ -112,6 +115,12 
@@ static MethodAsMapBuilder builder( PARAMETERS, MethodComponent.getParameterMapWithDefaultsAdded(methodComponentContext, methodComponent, knnMethodConfigContext) ); - return new MethodAsMapBuilder(baseDescription, methodComponent, initialMap, knnMethodConfigContext, QuantizationConfig.EMPTY); + + QuantizationConfig quantizationConfig = QuantizationConfig.EMPTY; + if (knnMethodConfigContext.getCompressionLevel() == CompressionLevel.x4 + && knnMethodConfigContext.getVersionCreated().onOrAfter(Version.V_2_19_0)) { + quantizationConfig = QuantizationConfig.builder().quantizationType(ScalarQuantizationType.EIGHT_BIT).build(); + } + return new MethodAsMapBuilder(baseDescription, methodComponent, initialMap, knnMethodConfigContext, quantizationConfig); } } diff --git a/src/main/java/org/opensearch/knn/index/engine/qframe/QuantizationConfigParser.java b/src/main/java/org/opensearch/knn/index/engine/qframe/QuantizationConfigParser.java index f86d7f886b..95a95fba0f 100644 --- a/src/main/java/org/opensearch/knn/index/engine/qframe/QuantizationConfigParser.java +++ b/src/main/java/org/opensearch/knn/index/engine/qframe/QuantizationConfigParser.java @@ -9,6 +9,7 @@ import org.opensearch.knn.index.engine.faiss.QFrameBitEncoder; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import java.util.List; import java.util.Locale; /** @@ -19,6 +20,10 @@ public class QuantizationConfigParser { public static final String SEPARATOR = "="; public static final String TYPE_NAME = "type"; public static final String BINARY_TYPE = QFrameBitEncoder.NAME; + + public static final String BYTE_TYPE = "byte"; + + public static final List QUANTIZATION_CONFIG_TYPES = List.of(BINARY_TYPE, BYTE_TYPE); public static final String BIT_COUNT_NAME = QFrameBitEncoder.BITCOUNT_PARAM; /** @@ -33,6 +38,9 @@ public static String toCsv(QuantizationConfig quantizationConfig) { || quantizationConfig.getQuantizationType() == null) { return ""; } + if (quantizationConfig.getQuantizationType() == ScalarQuantizationType.EIGHT_BIT) { + return TYPE_NAME + SEPARATOR + BYTE_TYPE + "," + BIT_COUNT_NAME + SEPARATOR + quantizationConfig.getQuantizationType().getId(); + } return TYPE_NAME + SEPARATOR + BINARY_TYPE + "," + BIT_COUNT_NAME + SEPARATOR + quantizationConfig.getQuantizationType().getId(); } @@ -54,7 +62,7 @@ public static QuantizationConfig fromCsv(String csv) { } String typeValue = getValueOrThrow(TYPE_NAME, csvArray[0]); - if (!typeValue.equals(BINARY_TYPE)) { + if (!QUANTIZATION_CONFIG_TYPES.contains(typeValue)) { throw new IllegalArgumentException(String.format(Locale.ROOT, "Unsupported quantization type: \"%s\"", typeValue)); } diff --git a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java index c14bb9d826..28a2c7b3d2 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java +++ b/src/main/java/org/opensearch/knn/index/mapper/CompressionLevel.java @@ -7,6 +7,7 @@ import lombok.AllArgsConstructor; import lombok.Getter; +import org.opensearch.Version; import org.opensearch.core.common.Strings; import org.opensearch.knn.index.query.rescore.RescoreContext; @@ -24,7 +25,8 @@ public enum CompressionLevel { NOT_CONFIGURED(-1, "", null, Collections.emptySet()), x1(1, "1x", null, Collections.emptySet()), x2(2, "2x", null, Collections.emptySet()), - x4(4, "4x", null, Collections.emptySet()), + + x4(4, "4x", new RescoreContext(2.0f, false, false), Set.of(Mode.ON_DISK)), x8(8, "8x", new RescoreContext(2.0f, false, 
false), Set.of(Mode.ON_DISK)), x16(16, "16x", new RescoreContext(3.0f, false, false), Set.of(Mode.ON_DISK)), x32(32, "32x", new RescoreContext(3.0f, false, false), Set.of(Mode.ON_DISK)), @@ -101,11 +103,17 @@ public static boolean isConfigured(CompressionLevel compressionLevel) { * * @param mode The {@link Mode} for which to retrieve the {@link RescoreContext}. * @param dimension The dimensional value that determines the {@link RescoreContext} behavior. + * @param indexVersionCreated OpenSearch cluster version in which the index was created * @return A {@link RescoreContext} with an oversample factor of 5.0f if {@code dimension} is less than * or equal to 1000, the default {@link RescoreContext} if greater, or {@code null} if the mode * is invalid. */ - public RescoreContext getDefaultRescoreContext(Mode mode, int dimension) { + public RescoreContext getDefaultRescoreContext(Mode mode, int dimension, Version indexVersionCreated) { + + // x4 compression was supported by Lucene engine before version 2.19.0 and there is no default rescore context + if (compressionLevel == CompressionLevel.x4.compressionLevel && indexVersionCreated.before(Version.V_2_19_0)) { + return null; + } if (modesForRescore.contains(mode)) { // Adjust RescoreContext based on dimension if (dimension <= RescoreContext.DIMENSION_THRESHOLD) { diff --git a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java index 461c6f7c8c..d142d44f6f 100644 --- a/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java +++ b/src/main/java/org/opensearch/knn/index/mapper/KNNVectorFieldType.java @@ -11,6 +11,7 @@ import org.apache.lucene.search.DocValuesFieldExistsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.util.BytesRef; +import org.opensearch.Version; import org.opensearch.index.fielddata.IndexFieldData; import org.opensearch.index.mapper.MappedFieldType; import org.opensearch.index.mapper.ArraySourceValueFetcher; @@ -107,9 +108,10 @@ public Object valueForDisplay(Object value) { * Resolve the rescore context provided for a user based on the field configuration * * @param userProvidedContext {@link RescoreContext} user passed; if null, the default should be configured + * @param indexVersionCreated OpenSearch cluster version in which the index was created * @return resolved {@link RescoreContext} */ - public RescoreContext resolveRescoreContext(RescoreContext userProvidedContext) { + public RescoreContext resolveRescoreContext(RescoreContext userProvidedContext, Version indexVersionCreated) { if (userProvidedContext != null) { return userProvidedContext; } @@ -117,7 +119,7 @@ public RescoreContext resolveRescoreContext(RescoreContext userProvidedContext) int dimension = knnMappingConfig.getDimension(); CompressionLevel compressionLevel = knnMappingConfig.getCompressionLevel(); Mode mode = knnMappingConfig.getMode(); - return compressionLevel.getDefaultRescoreContext(mode, dimension); + return compressionLevel.getDefaultRescoreContext(mode, dimension, indexVersionCreated); } /** diff --git a/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java b/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java index 7718487306..58c81392c6 100644 --- a/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java +++ b/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java @@ -11,12 +11,14 @@ import 
org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.factory.QuantizerFactory; import org.opensearch.knn.quantization.models.quantizationOutput.BinaryQuantizationOutput; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.quantizer.ByteScalarQuantizer; import org.opensearch.knn.quantization.quantizer.Quantizer; import java.io.IOException; @@ -60,7 +62,8 @@ public static QuantizationService getInstance() { public QuantizationState train( final QuantizationParams quantizationParams, final KNNVectorValues knnVectorValues, - final long liveDocs + final long liveDocs, + final FieldInfo fieldInfo ) throws IOException { Quantizer quantizer = QuantizerFactory.getQuantizer(quantizationParams); @@ -68,6 +71,9 @@ public QuantizationState train( KNNVectorQuantizationTrainingRequest trainingRequest = new KNNVectorQuantizationTrainingRequest<>(knnVectorValues, liveDocs); // Train the quantizer and return the quantization state + if (quantizer instanceof ByteScalarQuantizer) { + return quantizer.train(trainingRequest, fieldInfo); + } return quantizer.train(trainingRequest); } @@ -107,6 +113,10 @@ public QuantizationParams getQuantizationParams(final FieldInfo fieldInfo) { */ public VectorDataType getVectorDataTypeForTransfer(final FieldInfo fieldInfo) { QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo); + if (quantizationConfig != QuantizationConfig.EMPTY + && quantizationConfig.getQuantizationType() == ScalarQuantizationType.EIGHT_BIT) { + return VectorDataType.FLOAT; + } if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() != null) { return VectorDataType.BINARY; } diff --git a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java index 3ba273e714..4fbdfb9351 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java @@ -427,7 +427,7 @@ protected Query doToQuery(QueryShardContext context) { MethodComponentContext methodComponentContext = queryConfigFromMapping.get().getMethodComponentContext(); SpaceType spaceType = queryConfigFromMapping.get().getSpaceType(); VectorDataType vectorDataType = queryConfigFromMapping.get().getVectorDataType(); - RescoreContext processedRescoreContext = knnVectorFieldType.resolveRescoreContext(rescoreContext); + RescoreContext processedRescoreContext = knnVectorFieldType.resolveRescoreContext(rescoreContext, context.indexVersionCreated()); knnVectorFieldType.transformQueryVector(vector); VectorQueryType vectorQueryType = getVectorQueryType(k, maxDistance, minScore); diff --git a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java index 37b5cc9ad1..d8992c9ff3 100644 --- a/src/main/java/org/opensearch/knn/index/query/KNNWeight.java +++ b/src/main/java/org/opensearch/knn/index/query/KNNWeight.java @@ -40,6 +40,9 @@ import 
org.opensearch.knn.indices.ModelUtil; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.plugin.stats.KNNCounter; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import java.io.IOException; import java.util.Arrays; @@ -309,7 +312,7 @@ private Map doANNSearch( knnEngine, knnQuery.getIndexName(), // TODO: In the future, more vector data types will be supported with quantization - quantizedVector == null ? vectorDataType : VectorDataType.BINARY + getVectorDataType(fieldInfo, vectorDataType) ), knnQuery.getIndexName(), modelId @@ -414,6 +417,19 @@ public static float normalizeScore(float score) { return -score + 1; } + private VectorDataType getVectorDataType(final FieldInfo fieldInfo, final VectorDataType vectorDataType) { + final QuantizationParams quantizationParams = QuantizationService.getInstance().getQuantizationParams(fieldInfo); + if (quantizationParams == null) { + return vectorDataType; + } + if ((quantizationParams.getTypeIdentifier()).equals( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.EIGHT_BIT) + )) { + return VectorDataType.BYTE; + } + return VectorDataType.BINARY; + } + private boolean isFilteredExactSearchPreferred(final int filterIdsCount) { if (filterWeight == null) { return false; diff --git a/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationInfo.java b/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationInfo.java index d25774cdcf..7fd1f61e62 100644 --- a/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationInfo.java +++ b/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationInfo.java @@ -11,7 +11,9 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReader; import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; @@ -36,7 +38,10 @@ public class SegmentLevelQuantizationInfo { public static SegmentLevelQuantizationInfo build(final LeafReader leafReader, final FieldInfo fieldInfo, final String fieldName) throws IOException { final QuantizationParams quantizationParams = QuantizationService.getInstance().getQuantizationParams(fieldInfo); - if (quantizationParams == null) { + if (quantizationParams == null + || (quantizationParams.getTypeIdentifier()).equals( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.EIGHT_BIT) + )) { return null; } final QuantizationState quantizationState = SegmentLevelQuantizationUtil.getQuantizationState(leafReader, fieldName); diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index dcc7b180d0..47fbb87606 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -154,6 +154,39 @@ class FaissService { */ public static native void writeByteIndex(long indexAddress, IndexOutputWithBuffer output); + /** + * Initialize an index for the native library with a provided template index. 
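The CompressionLevel and KNNVectorFieldType changes earlier in this patch interact at query time roughly as follows; the surrounding call is a sketch, with values taken from the enum and its javadoc.

import org.opensearch.Version;
import org.opensearch.knn.index.mapper.CompressionLevel;
import org.opensearch.knn.index.mapper.Mode;
import org.opensearch.knn.index.query.rescore.RescoreContext;

final class X4RescoreSketch {
    // For a 4x field with no user-provided rescore context:
    //  - index created before 2.19.0 -> null (the Lucene 4x path has no default rescore)
    //  - created on/after 2.19.0     -> the on-disk default: oversample 2.0f, or the higher
    //    5.0f factor when the dimension is at or below the rescore dimension threshold
    //    (1000, per the getDefaultRescoreContext javadoc).
    static RescoreContext defaultRescoreForX4(int dimension, Version indexCreatedVersion) {
        return CompressionLevel.x4.getDefaultRescoreContext(Mode.ON_DISK, dimension, indexCreatedVersion);
    }
}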
Takes in numDocs to + * allocate the correct amount of memory. + * + * @param numDocs number of documents to be added + * @param dim dimension of the vector to be indexed + * @param parameters parameters to build index + * @param templateIndex template index + * @return address of the index in memory + */ + public static native long initIndexFromTemplate(long numDocs, int dim, Map parameters, byte[] templateIndex); + + /** + * Initialize a binary index for the native library with a provided template index. Takes in numDocs to + * allocate the correct amount of memory. + * + * @param numDocs number of documents to be added + * @param dim dimension of the vector to be indexed + * @param parameters parameters to build index + * @param templateIndex template index + * @return address of the index in memory + */ + public static native long initBinaryIndexFromTemplate(long numDocs, int dim, Map parameters, byte[] templateIndex); + + /** + * Initialize a byte index for the native library with a provided template index. Takes in numDocs to + * allocate the correct amount of memory. + * + * @param numDocs number of documents to be added + * @param dim dimension of the vector to be indexed + * @param parameters parameters to build index + * @param templateIndex template index + * @return address of the index in memory + */ + public static native long initByteIndexFromTemplate(long numDocs, int dim, Map parameters, byte[] templateIndex); + /** * Create an index for the native library with a provided template index * diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index 3d2a13f97f..92da8f706b 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -154,6 +154,39 @@ public static void createIndex( ); } + /** + * Initialize an index for the native library with a provided template index. Takes in numDocs to + * allocate the correct amount of memory. + * + * @param numDocs number of documents to be added + * @param dim dimension of the vector to be indexed + * @param parameters parameters to build index + * @param knnEngine knn engine + * @param templateIndex template index + * @return address of the index in memory + */ + public static long initIndexFromTemplate( + long numDocs, + int dim, + Map parameters, + KNNEngine knnEngine, + byte[] templateIndex + ) { + if (KNNEngine.FAISS == knnEngine) { + if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { + return FaissService.initBinaryIndexFromTemplate(numDocs, dim, parameters, templateIndex); + } + if (IndexUtil.isByteIndex(parameters)) { + return FaissService.initByteIndexFromTemplate(numDocs, dim, parameters, templateIndex); + } + return FaissService.initIndexFromTemplate(numDocs, dim, parameters, templateIndex); + } + + throw new IllegalArgumentException( + String.format(Locale.ROOT, "initIndexFromTemplate not supported for provided engine : %s", knnEngine.getName()) + ); + } + /** * Create an index for the native library with a provided template index * diff --git a/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java b/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java index 40347ad932..a4a55bf02c 100644 --- a/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java +++ b/src/main/java/org/opensearch/knn/quantization/enums/ScalarQuantizationType.java @@ -31,7 +31,12 @@ public enum ScalarQuantizationType { /** * FOUR_BIT quantization uses four bits per coordinate. */ - FOUR_BIT(4); + FOUR_BIT(4), + + /** + * EIGHT_BIT quantization uses eight bits per coordinate.
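
The new FaissService entry points and the JNIService.initIndexFromTemplate dispatch above form the Java side of the template-driven build path: an empty native index is allocated from a trained template, vectors are streamed into it, and the populated index is serialized. A hedged sketch of that flow, reusing only calls that appear in this patch (insertToIndex and writeIndex already exist on JNIService); all variable and parameter names below are placeholders, and the generic Map type is assumed:

```java
// Illustrative only; MemOptimizedNativeIndexBuildStrategyTests further down verifies this exact sequence.
static void buildFromTemplate(long numDocs, int dim, Map<String, Object> indexParams, byte[] templateIndex,
                              int[] docIds, long vectorsAddress, IndexOutputWithBuffer output) {
    // 1. Allocate a native Faiss index, pre-sized for numDocs, from the serialized trained template
    long indexAddress = JNIService.initIndexFromTemplate(numDocs, dim, indexParams, KNNEngine.FAISS, templateIndex);
    // 2. Add the off-heap transferred vectors to the native index (may be called in batches)
    JNIService.insertToIndex(docIds, vectorsAddress, dim, indexParams, indexAddress, KNNEngine.FAISS);
    // 3. Serialize the populated index into the segment file
    JNIService.writeIndex(output, indexAddress, KNNEngine.FAISS, indexParams);
}
```
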
+ */ + EIGHT_BIT(8); private final int id; diff --git a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java index 7b542aea0b..2d8f75ee73 100644 --- a/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java +++ b/src/main/java/org/opensearch/knn/quantization/factory/QuantizerRegistrar.java @@ -9,6 +9,7 @@ import lombok.NoArgsConstructor; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.quantizer.ByteScalarQuantizer; import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer; import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; @@ -42,5 +43,10 @@ static synchronized void registerDefaultQuantizers() { ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.FOUR_BIT), new MultiBitScalarQuantizer(4) ); + // Register ByteScalarQuantizer for SQParams with int8 or 8 bits + QuantizerRegistry.register( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.EIGHT_BIT), + new ByteScalarQuantizer(8) + ); } } diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/ByteScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/ByteScalarQuantizationState.java new file mode 100644 index 0000000000..f5715f22c0 --- /dev/null +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/ByteScalarQuantizationState.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.models.quantizationState; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import lombok.NoArgsConstructor; +import org.apache.lucene.util.RamUsageEstimator; +import org.opensearch.Version; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; + +import java.io.IOException; + +@Getter +@NoArgsConstructor +@AllArgsConstructor +public class ByteScalarQuantizationState implements QuantizationState { + private ScalarQuantizationParams quantizationParams; + private byte[] indexTemplate; + + @Override + public QuantizationParams getQuantizationParams() { + return quantizationParams; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(Version.CURRENT.id); // Write the version + quantizationParams.writeTo(out); + out.writeByteArray(indexTemplate); + } + + @Override + public byte[] toByteArray() throws IOException { + return QuantizationStateSerializer.serialize(this); + } + + @Override + public int getBytesPerVector() { + return 0; + } + + @Override + public int getDimensions() { + return 0; + } + + @Override + public long ramBytesUsed() { + long size = RamUsageEstimator.shallowSizeOfInstance(ByteScalarQuantizationState.class); + size += RamUsageEstimator.shallowSizeOf(quantizationParams); + size += RamUsageEstimator.sizeOf(indexTemplate); + return size; + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/ByteScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/ByteScalarQuantizer.java new file mode 100644 index 0000000000..2cabe7e857 --- /dev/null +++ 
b/src/main/java/org/opensearch/knn/quantization/quantizer/ByteScalarQuantizer.java @@ -0,0 +1,128 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.apache.lucene.index.FieldInfo; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.jni.JNIService; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; +import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.ByteScalarQuantizationState; +import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.models.requests.TrainingRequest; +import org.opensearch.knn.quantization.sampler.Sampler; +import org.opensearch.knn.quantization.sampler.SamplerType; +import org.opensearch.knn.quantization.sampler.SamplingFactory; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; +import static org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory.getVectorTransfer; + +public class ByteScalarQuantizer implements Quantizer { + private final int bitsPerCoordinate; + private final int samplingSize; // Sampling size for training + private final Sampler sampler; // Sampler for training + private static final int DEFAULT_SAMPLE_SIZE = 25000; + + public ByteScalarQuantizer(final int bitsPerCoordinate) { + if (bitsPerCoordinate != 8) { + throw new IllegalArgumentException("bitsPerCoordinate must be 8 for byte scalar quantizer."); + } + this.bitsPerCoordinate = bitsPerCoordinate; + this.samplingSize = DEFAULT_SAMPLE_SIZE; + this.sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR); + } + + @Override + public QuantizationState train(final TrainingRequest trainingRequest) throws IOException { + return null; + } + + /** + * Samples vectors and train them using Faiss Scalar Quantizer which returns the templateIndex after training. + * The templateIndex contains trained parameters which will be used later to quantize the ingested vectors inside Faiss. + * + * @param trainingRequest the request containing data and parameters for training. + * @param fieldInfo fieldInfo + * @return a QuantizationState containing the learned parameters. 
+ * @throws IOException + */ + @Override + public QuantizationState train(final TrainingRequest trainingRequest, final FieldInfo fieldInfo) throws IOException { + int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); + if (sampledIndices.length == 0) { + return null; + } + float[] vector = trainingRequest.getVectorAtThePosition(sampledIndices[0]); + if (vector == null) { + throw new IllegalArgumentException("Vector at sampled index " + sampledIndices[0] + " is null."); + } + int dimension = vector.length; + + try ( + final OffHeapVectorTransfer vectorTransfer = getVectorTransfer( + extractVectorDataType(fieldInfo), + 4 * dimension, + sampledIndices.length + ) + ) { + for (int i = 0; i < sampledIndices.length; i++) { + Object vectorToTransfer = trainingRequest.getVectorAtThePosition(sampledIndices[i]); + vectorTransfer.transfer(vectorToTransfer, true); + } + vectorTransfer.flush(true); + + byte[] indexTemplate = JNIService.trainIndex( + getParameters(fieldInfo), + dimension, + vectorTransfer.getVectorAddress(), + KNNEngine.FAISS + ); + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.EIGHT_BIT); + return new ByteScalarQuantizationState(params, indexTemplate); + } + } + + private Map getParameters(final FieldInfo fieldInfo) throws IOException { + Map parameters = new HashMap<>(); + Map fieldAttributes = fieldInfo.attributes(); + String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); + + parameters.putAll( + XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(parametersString), + MediaTypeRegistry.getDefaultMediaType() + ).map() + ); + + parameters.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT); + + // Used to determine how many threads to use when indexing + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + + return parameters; + } + + @Override + public void quantize(float[] vector, QuantizationState state, QuantizationOutput output) { + // Quantization is performed inside Faiss Scalar Quantizer + } +} diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java index 0bcc252d1b..bc50227092 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java @@ -6,6 +6,7 @@ package org.opensearch.knn.quantization.quantizer; +import org.apache.lucene.index.FieldInfo; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; @@ -119,6 +120,11 @@ public QuantizationState train(final TrainingRequest trainingRequest) t return new MultiBitScalarQuantizationState(params, thresholds); } + @Override + public QuantizationState train(final TrainingRequest trainingRequest, final FieldInfo fieldInfo) throws IOException { + return null; + } + /** * Quantizes the provided vector using the provided quantization state, producing a quantized output. * The vector is quantized based on the thresholds in the quantization state. 
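
Unlike the one-bit and multi-bit quantizers, the ByteScalarQuantizer above computes no thresholds of its own: training delegates to the native layer, and the returned state only carries the serialized Faiss template. A hedged sketch of the producer/consumer relationship (trainingRequest and fieldInfo are assumed to be in scope; the calls are the ones added in this patch):

```java
// Illustrative only.
ByteScalarQuantizer quantizer = new ByteScalarQuantizer(8);                 // only 8 bits per coordinate is accepted
QuantizationState state = quantizer.train(trainingRequest, fieldInfo);      // new FieldInfo-aware overload
byte[] template = ((ByteScalarQuantizationState) state).getIndexTemplate(); // trained Faiss SQ8 template
// The template is later handed to JNIService.initIndexFromTemplate(...) at index-build time;
// quantize(...) stays a no-op because Faiss quantizes the ingested vectors itself.
```
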
diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java index 3cba89c39c..70ed81a837 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java @@ -5,6 +5,7 @@ package org.opensearch.knn.quantization.quantizer; +import org.apache.lucene.index.FieldInfo; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; @@ -64,6 +65,11 @@ public QuantizationState train(final TrainingRequest trainingRequest) t return new OneBitScalarQuantizationState(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), meanThresholds); } + @Override + public QuantizationState train(final TrainingRequest trainingRequest, final FieldInfo fieldInfo) throws IOException { + return null; + } + /** * Quantizes the provided vector using the given quantization state. * It compares each dimension of the vector against the corresponding mean (threshold) to determine the quantized value. diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java index 521863205e..1e343b5e8b 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java @@ -5,6 +5,7 @@ package org.opensearch.knn.quantization.quantizer; +import org.apache.lucene.index.FieldInfo; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.knn.quantization.models.requests.TrainingRequest; @@ -31,6 +32,8 @@ public interface Quantizer { */ QuantizationState train(TrainingRequest trainingRequest) throws IOException; + QuantizationState train(TrainingRequest trainingRequest, FieldInfo fieldInfo) throws IOException; + /** * Quantizes the provided vector using the specified quantization state. 
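
The Quantizer interface change above adds a FieldInfo-aware train overload: OneBitScalarQuantizer and MultiBitScalarQuantizer stub it out and keep their existing single-argument training, while ByteScalarQuantizer needs the FieldInfo to read the Faiss index parameters. Callers go through QuantizationService, whose train method now also takes the FieldInfo, as the updated unit tests later in this patch show. A hedged caller-side sketch (knnVectorValues and fieldInfo are assumed to be in scope; generics are omitted as elsewhere in this diff):

```java
// Illustrative only; mirrors the calls exercised in QuantizationServiceTests below.
QuantizationService quantizationService = QuantizationService.getInstance();
ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.EIGHT_BIT);
QuantizationState state = quantizationService.train(
    params,
    knnVectorValues,                 // vectors of the field being built
    knnVectorValues.totalLiveDocs(), // how many vectors training may sample from
    fieldInfo                        // newly required so the quantizer can read the index parameters
);
```
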
* diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index e1f34c798b..2fc58caaf5 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -55,6 +55,7 @@ import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_CLIP; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_INT8; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; import static org.opensearch.knn.common.KNNConstants.FP16_MAX_VALUE; import static org.opensearch.knn.common.KNNConstants.FP16_MIN_VALUE; @@ -556,53 +557,18 @@ public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() { public void testHNSWSQFP16_whenIndexedAndQueried_thenSucceed() { String indexName = "test-index-hnsw-sqfp16"; String fieldName = "test-field-hnsw-sqfp16"; - SpaceType[] spaceTypes = { SpaceType.L2, SpaceType.INNER_PRODUCT }; - Random random = new Random(); - SpaceType spaceType = spaceTypes[random.nextInt(spaceTypes.length)]; - - List mValues = ImmutableList.of(16, 32, 64, 128); - List efConstructionValues = ImmutableList.of(16, 32, 64, 128); - List efSearchValues = ImmutableList.of(16, 32, 64, 128); - int dimension = 128; int numDocs = 100; + validateHNSWSQIndexingAndQuerying(indexName, fieldName, dimension, numDocs, FAISS_SQ_ENCODER_FP16); + } - // Create an index - XContentBuilder builder = XContentFactory.jsonBuilder() - .startObject() - .startObject("properties") - .startObject(fieldName) - .field("type", "knn_vector") - .field("dimension", dimension) - .startObject(KNN_METHOD) - .field(NAME, METHOD_HNSW) - .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) - .field(KNN_ENGINE, KNNEngine.FAISS.getName()) - .startObject(PARAMETERS) - .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) - .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) - .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) - .startObject(METHOD_ENCODER_PARAMETER) - .field(NAME, ENCODER_SQ) - .startObject(PARAMETERS) - .field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16) - .endObject() - .endObject() - .endObject() - .endObject() - .endObject() - .endObject() - .endObject(); - - Map mappingMap = xContentBuilderToMap(builder); - String mapping = builder.toString(); - - createKnnIndex(indexName, mapping); - assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); - indexTestData(indexName, fieldName, dimension, numDocs); - queryTestData(indexName, fieldName, dimension, numDocs); - deleteKNNIndex(indexName); - validateGraphEviction(); + @SneakyThrows + public void testHNSWSQINT8_whenIndexedAndQueried_thenSucceed() { + String indexName = "test-index-hnsw-sqint8"; + String fieldName = "test-field-hnsw-sqint8"; + int dimension = 128; + int numDocs = 100; + validateHNSWSQIndexingAndQuerying(indexName, fieldName, dimension, numDocs, FAISS_SQ_ENCODER_INT8); } @SneakyThrows @@ -2341,4 +2307,52 @@ private void validateNearestNeighborsSearch( } } + @SneakyThrows + private void validateHNSWSQIndexingAndQuerying(String indexName, String fieldName, int dimension, int numDocs, String encoder) { + SpaceType[] spaceTypes = { SpaceType.L2, SpaceType.INNER_PRODUCT }; + Random random = new Random(); + SpaceType spaceType = 
spaceTypes[random.nextInt(spaceTypes.length)]; + + List mValues = ImmutableList.of(16, 32, 64, 128); + List efConstructionValues = ImmutableList.of(16, 32, 64, 128); + List efSearchValues = ImmutableList.of(16, 32, 64, 128); + + // Create an index + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_M, mValues.get(random().nextInt(mValues.size()))) + .field(METHOD_PARAMETER_EF_CONSTRUCTION, efConstructionValues.get(random().nextInt(efConstructionValues.size()))) + .field(KNNConstants.METHOD_PARAMETER_EF_SEARCH, efSearchValues.get(random().nextInt(efSearchValues.size()))) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_SQ) + .startObject(PARAMETERS) + .field(FAISS_SQ_TYPE, encoder) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + Map mappingMap = xContentBuilderToMap(builder); + String mapping = builder.toString(); + + createKnnIndex(indexName, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); + indexTestData(indexName, fieldName, dimension, numDocs); + queryTestData(indexName, fieldName, dimension, numDocs); + deleteKNNIndex(indexName); + validateGraphEviction(); + } + } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java index f87ed6bcf6..b5f4432393 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -27,7 +27,9 @@ import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.opensearch.knn.index.vectorvalues.TestVectorValues; import org.opensearch.knn.plugin.stats.KNNGraphValue; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.test.OpenSearchTestCase; @@ -256,9 +258,13 @@ public void testFlush_WithQuantization() { ).thenReturn(expectedVectorValues.get(i)); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); + when(quantizationParams.getTypeIdentifier()).thenReturn( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT) + ); try { - when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) - .thenReturn(quantizationState); + when( + quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo) + ).thenReturn(quantizationState); } catch (Exception e) { throw new RuntimeException(e); } @@ -688,9 +694,13 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres ).thenReturn(expectedVectorValues.get(i)); 
when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); + when(quantizationParams.getTypeIdentifier()).thenReturn( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT) + ); try { - when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) - .thenReturn(quantizationState); + when( + quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo) + ).thenReturn(quantizationState); } catch (Exception e) { throw new RuntimeException(e); } @@ -790,9 +800,13 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres ).thenReturn(expectedVectorValues.get(i)); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); + when(quantizationParams.getTypeIdentifier()).thenReturn( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT) + ); try { - when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) - .thenReturn(quantizationState); + when( + quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo) + ).thenReturn(quantizationState); } catch (Exception e) { throw new RuntimeException(e); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java index cdc372bda2..7f93148dd2 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java @@ -30,7 +30,9 @@ import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.opensearch.knn.index.vectorvalues.TestVectorValues; import org.opensearch.knn.plugin.stats.KNNGraphValue; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.test.OpenSearchTestCase; @@ -324,8 +326,13 @@ public void testMerge_WithQuantization() { .thenReturn(knnVectorValues); when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); + when(quantizationParams.getTypeIdentifier()).thenReturn( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT) + ); try { - when(quantizationService.train(quantizationParams, knnVectorValues, mergedVectors.size())).thenReturn(quantizationState); + when(quantizationService.train(quantizationParams, knnVectorValues, mergedVectors.size(), fieldInfo)).thenReturn( + quantizationState + ); } catch (Exception e) { throw new RuntimeException(e); } diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java index 08942fe7f5..85c20d19b6 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java +++ 
b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java @@ -21,6 +21,7 @@ import org.opensearch.knn.index.vectorvalues.TestVectorValues; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationState.ByteScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.test.OpenSearchTestCase; @@ -246,4 +247,109 @@ public void testBuildAndWrite_withQuantization() { } } } + + @SneakyThrows + public void testBuildAndWrite_withByteScalarQuantization() { + // Given + ArgumentCaptor vectorAddressCaptor = ArgumentCaptor.forClass(Long.class); + ArgumentCaptor vectorTransferCapture = ArgumentCaptor.forClass(Object.class); + + List vectorValues = List.of(new float[] { 1, 2 }, new float[] { 2, 3 }, new float[] { 3, 4 }); + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + vectorValues + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, randomVectorValues); + + try ( + MockedStatic mockedJNIService = Mockito.mockStatic(JNIService.class); + MockedStatic mockedOffHeapVectorTransferFactory = Mockito.mockStatic( + OffHeapVectorTransferFactory.class + ); + MockedStatic mockedQuantizationIntegration = Mockito.mockStatic(QuantizationService.class) + ) { + byte[] indexTemplate = new byte[] { 1 }; + + // Limits transfer to 2 vectors + OffHeapVectorTransfer offHeapVectorTransfer = mock(OffHeapVectorTransfer.class); + when(offHeapVectorTransfer.getTransferLimit()).thenReturn(2); + mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 8, 3)) + .thenReturn(offHeapVectorTransfer); + + QuantizationService quantizationService = mock(QuantizationService.class); + mockedQuantizationIntegration.when(QuantizationService::getInstance).thenReturn(quantizationService); + + ByteScalarQuantizationState quantizationState = mock(ByteScalarQuantizationState.class); + BuildIndexParams indexInfo = mock(BuildIndexParams.class); + when(indexInfo.getQuantizationState()).thenReturn(quantizationState); + when(quantizationState.getIndexTemplate()).thenReturn(indexTemplate); + + mockedJNIService.when(() -> JNIService.initIndexFromTemplate(3, 2, Map.of("index", "param"), KNNEngine.FAISS, indexTemplate)) + .thenReturn(100L); + + when(offHeapVectorTransfer.transfer(vectorTransferCapture.capture(), eq(false))).thenReturn(false) + .thenReturn(true) + .thenReturn(false); + when(offHeapVectorTransfer.flush(false)).thenReturn(true); + when(offHeapVectorTransfer.getVectorAddress()).thenReturn(200L); + + IndexOutputWithBuffer indexOutputWithBuffer = Mockito.mock(IndexOutputWithBuffer.class); + BuildIndexParams buildIndexParams = BuildIndexParams.builder() + .indexOutputWithBuffer(indexOutputWithBuffer) + .knnEngine(KNNEngine.FAISS) + .vectorDataType(VectorDataType.FLOAT) + .parameters(Map.of("index", "param")) + .quantizationState(quantizationState) + .vectorValues(knnVectorValues) + .totalLiveDocs((int) knnVectorValues.totalLiveDocs()) + .build(); + + // When + MemOptimizedNativeIndexBuildStrategy.getInstance().buildAndWriteIndex(buildIndexParams); + + // Then + mockedJNIService.verify( + () -> JNIService.initIndexFromTemplate( + eq(knnVectorValues.totalLiveDocs()), + eq(knnVectorValues.dimension()), + 
eq(Map.of("index", "param")), + eq(KNNEngine.FAISS), + eq(indexTemplate) + ) + ); + + mockedJNIService.verify( + () -> JNIService.insertToIndex( + eq(new int[] { 0, 1 }), + vectorAddressCaptor.capture(), + eq(knnVectorValues.dimension()), + eq(Map.of("index", "param")), + eq(100L), + eq(KNNEngine.FAISS) + ) + ); + + // For the flush + mockedJNIService.verify( + () -> JNIService.insertToIndex( + eq(new int[] { 2 }), + vectorAddressCaptor.capture(), + eq(knnVectorValues.dimension()), + eq(Map.of("index", "param")), + eq(100L), + eq(KNNEngine.FAISS) + ) + ); + + mockedJNIService.verify( + () -> JNIService.writeIndex(eq(indexOutputWithBuffer), eq(100L), eq(KNNEngine.FAISS), eq(Map.of("index", "param"))) + ); + assertEquals(200L, vectorAddressCaptor.getValue().longValue()); + assertEquals(vectorAddressCaptor.getValue().longValue(), vectorAddressCaptor.getAllValues().get(0).longValue()); + verify(offHeapVectorTransfer, times(0)).reset(); + + for (Object vector : vectorTransferCapture.getAllValues()) { + assertTrue(vector instanceof float[]); + } + } + } } diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtilsTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtilsTests.java index 61d3d7589e..742905987b 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtilsTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtilsTests.java @@ -16,6 +16,7 @@ import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.ByteScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; @@ -82,6 +83,17 @@ public void testPrepareIndexBuild_withoutQuantization_success() throws IOExcepti assertEquals(knnVectorValues.dimension(), setup.getDimensions()); } + public void testPrepareIndexBuild_withByteScalarQuantization_success() throws IOException { + QuantizationState quantizationState = mock(ByteScalarQuantizationState.class); + when(buildIndexParams.getQuantizationState()).thenReturn(quantizationState); + knnVectorValues.nextDoc(); + knnVectorValues.getVector(); + IndexBuildSetup setup = QuantizationIndexUtils.prepareIndexBuild(knnVectorValues, buildIndexParams); + assertNotNull(setup.getQuantizationState()); + assertEquals(knnVectorValues.bytesPerVector(), setup.getBytesPerVector()); + assertEquals(knnVectorValues.dimension(), setup.getDimensions()); + } + public void testProcessAndReturnVector_withoutQuantization_success() throws IOException { // Set up the BuildIndexParams to return no quantization when(buildIndexParams.getQuantizationState()).thenReturn(null); diff --git a/src/test/java/org/opensearch/knn/index/engine/EngineResolverTests.java b/src/test/java/org/opensearch/knn/index/engine/EngineResolverTests.java index 291f0c671a..d000fda47f 100644 --- a/src/test/java/org/opensearch/knn/index/engine/EngineResolverTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/EngineResolverTests.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.engine; +import org.opensearch.Version; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.SpaceType; import 
org.opensearch.knn.index.mapper.CompressionLevel; @@ -68,18 +69,49 @@ public void testResolveEngine_whenCompressionIs1x_thenEngineBasedOnMode() { ); } - public void testResolveEngine_whenCompressionIs4x_thenEngineIsLucene() { + public void testResolveEngine_whenCompressionIs4x_VersionIsBefore2_19_thenEngineIsLucene() { assertEquals( KNNEngine.LUCENE, ENGINE_RESOLVER.resolveEngine( - KNNMethodConfigContext.builder().mode(Mode.ON_DISK).compressionLevel(CompressionLevel.x4).build(), + KNNMethodConfigContext.builder() + .mode(Mode.ON_DISK) + .compressionLevel(CompressionLevel.x4) + .versionCreated(Version.V_2_18_0) + .build(), null, false ) ); assertEquals( KNNEngine.LUCENE, - ENGINE_RESOLVER.resolveEngine(KNNMethodConfigContext.builder().compressionLevel(CompressionLevel.x4).build(), null, false) + ENGINE_RESOLVER.resolveEngine( + KNNMethodConfigContext.builder().compressionLevel(CompressionLevel.x4).versionCreated(Version.V_2_17_0).build(), + null, + false + ) + ); + } + + public void testResolveEngine_whenCompressionIs4x_VersionIsOnOrAfter2_19_thenEngineIsFaiss() { + assertEquals( + KNNEngine.FAISS, + ENGINE_RESOLVER.resolveEngine( + KNNMethodConfigContext.builder() + .mode(Mode.ON_DISK) + .compressionLevel(CompressionLevel.x4) + .versionCreated(Version.CURRENT) + .build(), + null, + false + ) + ); + assertEquals( + KNNEngine.FAISS, + ENGINE_RESOLVER.resolveEngine( + KNNMethodConfigContext.builder().compressionLevel(CompressionLevel.x4).versionCreated(Version.CURRENT).build(), + null, + false + ) ); } diff --git a/src/test/java/org/opensearch/knn/index/engine/MethodComponentTests.java b/src/test/java/org/opensearch/knn/index/engine/MethodComponentTests.java index e247fb70a0..0aeca21885 100644 --- a/src/test/java/org/opensearch/knn/index/engine/MethodComponentTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/MethodComponentTests.java @@ -254,4 +254,36 @@ public void testGetParameterMapWithDefaultsAdded_forOnDiskWithBinaryQuantization assertEquals(IndexHyperParametersUtil.getBinaryQuantizationEFConstructionValue(), resultMap.get(parameterEFConstruction)); } + public void testGetParameterMapWithDefaultsAdded_forOnDiskWithByteQuantization() { + // Set up MethodComponent and context + String methodName = "test-method"; + String parameterEFSearch = "ef_search"; + String parameterEFConstruction = "ef_construction"; + + MethodComponent methodComponent = MethodComponent.Builder.builder(methodName) + .addParameter(parameterEFSearch, new Parameter.IntegerParameter(parameterEFSearch, 512, (v, context) -> v > 0)) + .addParameter(parameterEFConstruction, new Parameter.IntegerParameter(parameterEFConstruction, 512, (v, context) -> v > 0)) + .build(); + + // Simulate ON_DISK mode and byte quantization compression level + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .versionCreated(Version.CURRENT) + .mode(Mode.ON_DISK) // ON_DISK mode + .compressionLevel(CompressionLevel.x4) // Byte quantization compression level + .build(); + + MethodComponentContext methodComponentContext = new MethodComponentContext(methodName, Map.of()); + + // Retrieve parameter map with defaults added + Map resultMap = MethodComponent.getParameterMapWithDefaultsAdded( + methodComponentContext, + methodComponent, + knnMethodConfigContext + ); + + // Check that byte quantization values are used + assertEquals(IndexHyperParametersUtil.getBinaryQuantizationEFSearchValue(), resultMap.get(parameterEFSearch)); + 
assertEquals(IndexHyperParametersUtil.getBinaryQuantizationEFConstructionValue(), resultMap.get(parameterEFConstruction)); + } + } diff --git a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolverTests.java b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolverTests.java index 3a33736fa3..5206d9bf68 100644 --- a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolverTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissMethodResolverTests.java @@ -22,8 +22,12 @@ import java.util.Map; import static org.opensearch.knn.common.KNNConstants.ENCODER_FLAT; +import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_INT8; import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.METHOD_IVF; +import static org.opensearch.knn.common.KNNConstants.TYPE; public class FaissMethodResolverTests extends KNNTestCase { @@ -163,6 +167,34 @@ public void testResolveMethod_whenValid_thenResolve() { ); assertEquals(knnMethodConfigContext.getCompressionLevel(), CompressionLevel.x8); validateResolveMethodContext(resolvedMethodContext, CompressionLevel.x8, SpaceType.L2, QFrameBitEncoder.NAME, true); + + resolvedMethodContext = TEST_RESOLVER.resolveMethod( + null, + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .mode(Mode.ON_DISK) + .compressionLevel(CompressionLevel.x4) + .versionCreated(Version.CURRENT) + .build(), + false, + SpaceType.INNER_PRODUCT + ); + validateResolveMethodContext(resolvedMethodContext, CompressionLevel.x4, SpaceType.INNER_PRODUCT, ENCODER_SQ, false); + + resolvedMethodContext = TEST_RESOLVER.resolveMethod( + new KNNMethodContext( + KNNEngine.FAISS, + SpaceType.L2, + new MethodComponentContext( + METHOD_HNSW, + Map.of(METHOD_ENCODER_PARAMETER, new MethodComponentContext(ENCODER_SQ, Map.of(TYPE, FAISS_SQ_ENCODER_INT8))) + ) + ), + KNNMethodConfigContext.builder().vectorDataType(VectorDataType.FLOAT).versionCreated(Version.CURRENT).build(), + false, + SpaceType.L2 + ); + validateResolveMethodContext(resolvedMethodContext, CompressionLevel.x4, SpaceType.L2, ENCODER_SQ, false); } private void validateResolveMethodContext( @@ -195,21 +227,6 @@ private void validateResolveMethodContext( } public void testResolveMethod_whenInvalid_thenThrow() { - // Invalid compression - expectThrows( - ValidationException.class, - () -> TEST_RESOLVER.resolveMethod( - null, - KNNMethodConfigContext.builder() - .vectorDataType(VectorDataType.FLOAT) - .compressionLevel(CompressionLevel.x4) - .versionCreated(Version.CURRENT) - .build(), - false, - SpaceType.L2 - ) - ); - expectThrows( ValidationException.class, () -> TEST_RESOLVER.resolveMethod( @@ -269,5 +286,29 @@ public void testResolveMethod_whenInvalid_thenThrow() { ) ); + + // Using IVF with Faiss 4x compression + expectThrows( + ValidationException.class, + () -> TEST_RESOLVER.resolveMethod( + new KNNMethodContext( + KNNEngine.FAISS, + SpaceType.INNER_PRODUCT, + new MethodComponentContext( + METHOD_IVF, + Map.of(METHOD_ENCODER_PARAMETER, new MethodComponentContext(ENCODER_SQ, Map.of(TYPE, FAISS_SQ_ENCODER_INT8))) + ) + ), + KNNMethodConfigContext.builder() + .vectorDataType(VectorDataType.FLOAT) + .mode(Mode.ON_DISK) + .compressionLevel(CompressionLevel.x4) + .versionCreated(Version.CURRENT) + .build(), + false, + SpaceType.INNER_PRODUCT + ) + + ); 
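
The resolver tests above pin down the behavior change for 4x compression: indices created before 2.19 keep resolving to Lucene, while on 2.19 and later an ON_DISK (or unspecified mode) 4x configuration resolves to Faiss with the SQ encoder, and the Faiss tests further down show the resulting index description of the form "HNSW<m>,SQ8"; IVF combined with the int8 encoder and on-disk 4x remains invalid. A hedged sketch of the resolution call (ENGINE_RESOLVER is the same test fixture used above; nothing here is new API):

```java
// Illustrative only; mirrors EngineResolverTests in this patch.
KNNMethodConfigContext context = KNNMethodConfigContext.builder()
    .mode(Mode.ON_DISK)
    .compressionLevel(CompressionLevel.x4)
    .versionCreated(Version.CURRENT) // 2.19+ resolves to FAISS; 2.18 and earlier to LUCENE
    .build();
KNNEngine engine = ENGINE_RESOLVER.resolveEngine(context, null, false);
assert engine == KNNEngine.FAISS;
```
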
} } diff --git a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoderTests.java b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoderTests.java index 3905158a27..02183b6262 100644 --- a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoderTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissSQEncoderTests.java @@ -6,11 +6,24 @@ package org.opensearch.knn.index.engine.faiss; import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.index.engine.MethodComponentContext; import org.opensearch.knn.index.mapper.CompressionLevel; +import java.util.Map; + +import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_INT8; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; + public class FaissSQEncoderTests extends KNNTestCase { public void testCalculateCompressionLevel() { FaissSQEncoder encoder = new FaissSQEncoder(); - assertEquals(CompressionLevel.x2, encoder.calculateCompressionLevel(null, null)); + assertEquals(CompressionLevel.x2, encoder.calculateCompressionLevel(generateMethodComponentContext(FAISS_SQ_ENCODER_FP16), null)); + assertEquals(CompressionLevel.x4, encoder.calculateCompressionLevel(generateMethodComponentContext(FAISS_SQ_ENCODER_INT8), null)); + } + + private MethodComponentContext generateMethodComponentContext(String sqType) { + return new MethodComponentContext(ENCODER_SQ, Map.of(FAISS_SQ_TYPE, sqType)); } } diff --git a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissTests.java b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissTests.java index 75da6811e7..06a703c275 100644 --- a/src/test/java/org/opensearch/knn/index/engine/faiss/FaissTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/faiss/FaissTests.java @@ -32,6 +32,7 @@ import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ; import static org.opensearch.knn.common.KNNConstants.FAISS_NAME; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16; +import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_INT8; import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE; import static org.opensearch.knn.common.KNNConstants.INDEX_DESCRIPTION_PARAMETER; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; @@ -141,6 +142,40 @@ public void testGetKNNLibraryIndexingContext_whenMethodIsHNSWSQFP16_thenCreateCo assertEquals(expectedIndexDescription, map.get(INDEX_DESCRIPTION_PARAMETER)); } + @SneakyThrows + public void testGetKNNLibraryIndexingContext_whenMethodIsHNSWSQINT8_thenCreateCorrectIndexDescription() { + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .versionCreated(org.opensearch.Version.CURRENT) + .dimension(4) + .vectorDataType(VectorDataType.FLOAT) + .build(); + int hnswMParam = 24; + String expectedIndexDescription = String.format(Locale.ROOT, "HNSW%d,SQ8", hnswMParam); + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, METHOD_HNSW) + .field(KNN_ENGINE, FAISS_NAME) + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_M, hnswMParam) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_SQ) + .startObject(PARAMETERS) + .field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_INT8) + .endObject() + .endObject() + .endObject() + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + KNNMethodContext 
knnMethodContext = KNNMethodContext.parse(in); + + Map map = Faiss.INSTANCE.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext) + .getLibraryParameters(); + + assertTrue(map.containsKey(INDEX_DESCRIPTION_PARAMETER)); + assertEquals(expectedIndexDescription, map.get(INDEX_DESCRIPTION_PARAMETER)); + } + public void testGetKNNLibraryIndexingContext_whenMethodIsIVFFlat_thenCreateCorrectIndexDescription() throws IOException { KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() .versionCreated(org.opensearch.Version.CURRENT) diff --git a/src/test/java/org/opensearch/knn/index/engine/qframe/QuantizationConfigParserTests.java b/src/test/java/org/opensearch/knn/index/engine/qframe/QuantizationConfigParserTests.java index 317d40e275..6f1a1547f7 100644 --- a/src/test/java/org/opensearch/knn/index/engine/qframe/QuantizationConfigParserTests.java +++ b/src/test/java/org/opensearch/knn/index/engine/qframe/QuantizationConfigParserTests.java @@ -70,5 +70,9 @@ public void testToCsv() { "type=binary,bits=2", QuantizationConfigParser.toCsv(QuantizationConfig.builder().quantizationType(ScalarQuantizationType.TWO_BIT).build()) ); + assertEquals( + "type=byte,bits=8", + QuantizationConfigParser.toCsv(QuantizationConfig.builder().quantizationType(ScalarQuantizationType.EIGHT_BIT).build()) + ); } } diff --git a/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java b/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java index e882d6697a..71f836f097 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/CompressionLevelTests.java @@ -5,6 +5,7 @@ package org.opensearch.knn.index.mapper; +import org.opensearch.Version; import org.opensearch.core.common.Strings; import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.index.query.rescore.RescoreContext; @@ -49,53 +50,62 @@ public void testGetDefaultRescoreContext() { int aboveThresholdDimension = 1500; // A dimension above the threshold // x32 with dimension <= 1000 should have an oversample factor of 5.0f - RescoreContext rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, belowThresholdDimension); + RescoreContext rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, belowThresholdDimension, Version.CURRENT); assertNotNull(rescoreContext); assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f); // x32 with dimension > 1000 should have an oversample factor of 3.0f - rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, aboveThresholdDimension); + rescoreContext = CompressionLevel.x32.getDefaultRescoreContext(mode, aboveThresholdDimension, Version.CURRENT); assertNotNull(rescoreContext); assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f); // x16 with dimension <= 1000 should have an oversample factor of 5.0f - rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, belowThresholdDimension); + rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, belowThresholdDimension, Version.CURRENT); assertNotNull(rescoreContext); assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f); // x16 with dimension > 1000 should have an oversample factor of 3.0f - rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, aboveThresholdDimension); + rescoreContext = CompressionLevel.x16.getDefaultRescoreContext(mode, aboveThresholdDimension, Version.CURRENT); assertNotNull(rescoreContext); 
assertEquals(3.0f, rescoreContext.getOversampleFactor(), 0.0f); // x8 with dimension <= 1000 should have an oversample factor of 5.0f - rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, belowThresholdDimension); + rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, belowThresholdDimension, Version.CURRENT); assertNotNull(rescoreContext); assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f); // x8 with dimension > 1000 should have an oversample factor of 2.0f - rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, aboveThresholdDimension); + rescoreContext = CompressionLevel.x8.getDefaultRescoreContext(mode, aboveThresholdDimension, Version.CURRENT); assertNotNull(rescoreContext); assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f); - // x4 with dimension <= 1000 should have an oversample factor of 5.0f (though it doesn't have its own RescoreContext) - rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, belowThresholdDimension); + // x4 with dimension <= 1000 should have an oversample factor of 5.0f (though it doesn't have its own RescoreContext before V2.19.0) + rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, belowThresholdDimension, Version.V_2_18_1); assertNull(rescoreContext); - // x4 with dimension > 1000 should return null (no RescoreContext is configured for x4) - rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, aboveThresholdDimension); + // x4 with dimension > 1000 should return null (no RescoreContext is configured for x4 before V2.19.0) + rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, aboveThresholdDimension, Version.V_2_18_1); assertNull(rescoreContext); + // x4 with dimension <= 1000 should have an oversample factor of 5.0f on or after V2.19.0 + rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, belowThresholdDimension, Version.CURRENT); + assertNotNull(rescoreContext); + assertEquals(5.0f, rescoreContext.getOversampleFactor(), 0.0f); + // x4 with dimension > 1000 should have an oversample factor of 2.0f on or after V2.19.0 + rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, aboveThresholdDimension, Version.CURRENT); + assertNotNull(rescoreContext); + assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f); + // Other compression levels should behave similarly with respect to dimension - rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, belowThresholdDimension); + rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, belowThresholdDimension, Version.CURRENT); assertNull(rescoreContext); // x2 with dimension > 1000 should return null - rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, aboveThresholdDimension); + rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, aboveThresholdDimension, Version.CURRENT); assertNull(rescoreContext); - rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, belowThresholdDimension); + rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, belowThresholdDimension, Version.CURRENT); assertNull(rescoreContext); // x1 with dimension > 1000 should return null - rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, aboveThresholdDimension); + rescoreContext = CompressionLevel.x1.getDefaultRescoreContext(mode, aboveThresholdDimension, Version.CURRENT); assertNull(rescoreContext); // NOT_CONFIGURED with dimension <= 1000 should return a RescoreContext with an
oversample factor of 5.0f - rescoreContext = CompressionLevel.NOT_CONFIGURED.getDefaultRescoreContext(mode, belowThresholdDimension); + rescoreContext = CompressionLevel.NOT_CONFIGURED.getDefaultRescoreContext(mode, belowThresholdDimension, Version.CURRENT); assertNull(rescoreContext); } } diff --git a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java index 49b15a0f43..0d64992726 100644 --- a/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java +++ b/src/test/java/org/opensearch/knn/index/mapper/KNNVectorFieldMapperTests.java @@ -1831,7 +1831,7 @@ public void testTypeParser_whenModeAndCompressionAreSet_thenHandle() throws IOEx true ); - // For 4x compression on disk, use Lucene + // For 4x compression on disk, use Faiss xContentBuilder = XContentFactory.jsonBuilder() .startObject() .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) @@ -1847,7 +1847,7 @@ public void testTypeParser_whenModeAndCompressionAreSet_thenHandle() throws IOEx ); validateBuilderAfterParsing( builder, - KNNEngine.LUCENE, + KNNEngine.FAISS, SpaceType.L2, VectorDataType.FLOAT, CompressionLevel.x4, @@ -1856,7 +1856,7 @@ public void testTypeParser_whenModeAndCompressionAreSet_thenHandle() throws IOEx false ); - // For 4x compression in memory, use Lucene + // For 4x compression in memory, use Faiss xContentBuilder = XContentFactory.jsonBuilder() .startObject() .field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE) @@ -1872,7 +1872,7 @@ public void testTypeParser_whenModeAndCompressionAreSet_thenHandle() throws IOEx ); validateBuilderAfterParsing( builder, - KNNEngine.LUCENE, + KNNEngine.FAISS, SpaceType.L2, VectorDataType.FLOAT, CompressionLevel.x4, @@ -1971,7 +1971,7 @@ public void testTypeParser_whenModeAndCompressionAreSet_thenHandle() throws IOEx .field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x4.getName()) .startObject(KNN_METHOD) .field(NAME, METHOD_HNSW) - .field(KNN_ENGINE, KNNEngine.FAISS) + .field(KNN_ENGINE, KNNEngine.NMSLIB) .endObject() .endObject(); diff --git a/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java b/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java index 690391dbdb..2fc85c7675 100644 --- a/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java +++ b/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java @@ -5,16 +5,21 @@ package org.opensearch.knn.index.quantizationservice; +import org.apache.lucene.index.FieldInfo; import org.opensearch.knn.KNNTestCase; import org.junit.Before; +import org.opensearch.knn.common.KNNConstants; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.opensearch.knn.index.vectorvalues.TestVectorValues; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.models.quantizationState.ByteScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.MultiBitScalarQuantizationState; import 
org.opensearch.knn.quantization.models.quantizationState.OneBitScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; @@ -24,6 +29,7 @@ public class QuantizationServiceTests extends KNNTestCase { private QuantizationService quantizationService; private KNNVectorValues knnVectorValues; + private FieldInfo fieldInfo; @Before public void setUp() throws Exception { @@ -42,11 +48,19 @@ public void setUp() throws Exception { VectorDataType.FLOAT, new TestVectorValues.PreDefinedFloatVectorValues(floatVectors) ); + + fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field").addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true").build(); } public void testTrain_oneBitQuantizer_success() throws IOException { ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + + QuantizationState quantizationState = quantizationService.train( + oneBitParams, + knnVectorValues, + knnVectorValues.totalLiveDocs(), + fieldInfo + ); assertTrue(quantizationState instanceof OneBitScalarQuantizationState); OneBitScalarQuantizationState oneBitState = (OneBitScalarQuantizationState) quantizationState; @@ -62,7 +76,12 @@ public void testTrain_oneBitQuantizer_success() throws IOException { public void testTrain_twoBitQuantizer_success() throws IOException { ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); - QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train( + twoBitParams, + knnVectorValues, + knnVectorValues.totalLiveDocs(), + fieldInfo + ); assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState; @@ -85,7 +104,12 @@ public void testTrain_twoBitQuantizer_success() throws IOException { public void testTrain_fourBitQuantizer_success() throws IOException { ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); - QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train( + fourBitParams, + knnVectorValues, + knnVectorValues.totalLiveDocs(), + fieldInfo + ); assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState; @@ -108,9 +132,36 @@ public void testTrain_fourBitQuantizer_success() throws IOException { } } + public void testTrain_eightBitQuantizer_success() throws IOException { + ScalarQuantizationParams byteSQParams = new ScalarQuantizationParams(ScalarQuantizationType.EIGHT_BIT); + String parameterString = + "{\"index_description\":\"HNSW16,SQ8\",\"spaceType\":\"l2\",\"name\":\"hnsw\",\"data_type\":\"float\",\"parameters\":{\"ef_search\":256,\"ef_construction\":256,\"encoder\":{\"name\":\"sq\",\"parameters\":{\"clip\":false}}}}"; + fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field") + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(KNNConstants.PARAMETERS, parameterString) + .build(); + QuantizationState quantizationState = 
quantizationService.train( + byteSQParams, + knnVectorValues, + knnVectorValues.totalLiveDocs(), + fieldInfo + ); + assertTrue(quantizationState instanceof ByteScalarQuantizationState); + ByteScalarQuantizationState byteScalarQuantizationState = (ByteScalarQuantizationState) quantizationState; + + // Validate the template index obtained from the training + byte[] indexTemplate = byteScalarQuantizationState.getIndexTemplate(); + assertNotNull("Index Template should not be null", indexTemplate); + } + public void testQuantize_oneBitQuantizer_success() throws IOException { ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train( + oneBitParams, + knnVectorValues, + knnVectorValues.totalLiveDocs(), + fieldInfo + ); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams); @@ -125,7 +176,12 @@ public void testQuantize_oneBitQuantizer_success() throws IOException { public void testQuantize_twoBitQuantizer_success() throws IOException { ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); - QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train( + twoBitParams, + knnVectorValues, + knnVectorValues.totalLiveDocs(), + fieldInfo + ); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(twoBitParams); byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 4.0f, 5.0f, 6.0f }, quantizationOutput); @@ -138,7 +194,12 @@ public void testQuantize_twoBitQuantizer_success() throws IOException { public void testQuantize_fourBitQuantizer_success() throws IOException { ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); - QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train( + fourBitParams, + knnVectorValues, + knnVectorValues.totalLiveDocs(), + fieldInfo + ); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(fourBitParams); byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 7.0f, 8.0f, 9.0f }, quantizationOutput); @@ -150,9 +211,28 @@ public void testQuantize_fourBitQuantizer_success() throws IOException { assertArrayEquals(expectedQuantizedVector, quantizedVector); } + public void testGetVectorDataTypeForTransfer_success() { + fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field") + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(KNNConstants.QFRAMEWORK_CONFIG, "type=byte,bits=8") + .build(); + assertEquals(VectorDataType.FLOAT, quantizationService.getVectorDataTypeForTransfer(fieldInfo)); + + fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field") + .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true") + .addAttribute(KNNConstants.QFRAMEWORK_CONFIG, "type=binary,bits=2") + .build(); + assertEquals(VectorDataType.BINARY, quantizationService.getVectorDataTypeForTransfer(fieldInfo)); + } + public void testQuantize_whenInvalidInput_thenThrows() throws IOException { 
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train( + oneBitParams, + knnVectorValues, + knnVectorValues.totalLiveDocs(), + fieldInfo + ); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams); assertThrows(IllegalArgumentException.class, () -> quantizationService.quantize(quantizationState, null, quantizationOutput)); } diff --git a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java index 8b2cf5d2b8..76e6cde680 100644 --- a/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java +++ b/src/test/java/org/opensearch/knn/integ/ModeAndCompressionIT.java @@ -423,6 +423,16 @@ public void testTraining_whenInvalid_thenFail() { .field(MODE_PARAMETER, Mode.ON_DISK) .endObject(); expectThrows(ResponseException.class, () -> trainModel(modelId, builder2)); + + XContentBuilder builder3 = XContentFactory.jsonBuilder() + .startObject() + .field(TRAIN_INDEX_PARAMETER, TRAINING_INDEX_NAME) + .field(TRAIN_FIELD_PARAMETER, TRAINING_FIELD_NAME) + .field(KNNConstants.DIMENSION, DIMENSION) + .field(MODEL_DESCRIPTION, "test") + .field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x4.getName()) + .endObject(); + expectThrows(ResponseException.class, () -> trainModel(modelId, builder3)); } @SneakyThrows @@ -610,7 +620,7 @@ private void validateSearch( List exactSearchKnnResults = parseSearchResponseScore(exactSearchResponseBody, FIELD_NAME); assertEquals(NUM_DOCS, exactSearchKnnResults.size()); - if (CompressionLevel.x4.getName().equals(compressionLevelString) == false && Mode.ON_DISK.getName().equals(mode)) { + if (Mode.ON_DISK.getName().equals(mode)) { Assert.assertEquals(exactSearchKnnResults, knnResults); } @@ -640,7 +650,8 @@ private void validateSearch( responseBody = EntityUtils.toString(response.getEntity()); knnResults = parseSearchResponseScore(responseBody, FIELD_NAME); assertEquals(K, knnResults.size()); - if (CompressionLevel.x4.getName().equals(compressionLevelString) == false && Mode.ON_DISK.getName().equals(mode)) { + + if (Mode.ON_DISK.getName().equals(mode)) { Assert.assertEquals(exactSearchKnnResults, knnResults); } } diff --git a/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java b/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java index 99621a0e53..91f008b476 100644 --- a/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java +++ b/src/test/java/org/opensearch/knn/quantization/enums/ScalarQuantizationTypeTests.java @@ -15,7 +15,8 @@ public void testSQTypesValues() { ScalarQuantizationType[] expectedValues = { ScalarQuantizationType.ONE_BIT, ScalarQuantizationType.TWO_BIT, - ScalarQuantizationType.FOUR_BIT }; + ScalarQuantizationType.FOUR_BIT, + ScalarQuantizationType.EIGHT_BIT }; assertArrayEquals(expectedValues, ScalarQuantizationType.values()); } @@ -23,6 +24,7 @@ public void testSQTypesValueOf() { assertEquals(ScalarQuantizationType.ONE_BIT, ScalarQuantizationType.valueOf("ONE_BIT")); assertEquals(ScalarQuantizationType.TWO_BIT, ScalarQuantizationType.valueOf("TWO_BIT")); assertEquals(ScalarQuantizationType.FOUR_BIT, ScalarQuantizationType.valueOf("FOUR_BIT")); + assertEquals(ScalarQuantizationType.EIGHT_BIT, 
ScalarQuantizationType.valueOf("EIGHT_BIT")); } public void testUniqueSQTypeValues() { diff --git a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java index 7c974e5172..9d2507ab49 100644 --- a/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java +++ b/src/test/java/org/opensearch/knn/quantization/factory/QuantizerRegistryTests.java @@ -9,6 +9,7 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; +import org.opensearch.knn.quantization.quantizer.ByteScalarQuantizer; import org.opensearch.knn.quantization.quantizer.MultiBitScalarQuantizer; import org.opensearch.knn.quantization.quantizer.OneBitScalarQuantizer; import org.opensearch.knn.quantization.quantizer.Quantizer; @@ -30,6 +31,10 @@ public static void setup() { ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.FOUR_BIT), new MultiBitScalarQuantizer(4) ); + QuantizerRegistry.register( + ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.EIGHT_BIT), + new ByteScalarQuantizer(8) + ); } catch (Exception e) { assertTrue(e.getMessage().contains("already registered")); } @@ -50,6 +55,11 @@ public void testRegisterAndGetQuantizer() { ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); Quantizer fourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams); assertEquals(fourBitQuantizer.getClass(), MultiBitScalarQuantizer.class); + + // Test for ByteScalarQuantizer (8-bit) + ScalarQuantizationParams byteSQParams = new ScalarQuantizationParams(ScalarQuantizationType.EIGHT_BIT); + Quantizer byteScalarQuantizer = QuantizerRegistry.getQuantizer(byteSQParams); + assertEquals(byteScalarQuantizer.getClass(), ByteScalarQuantizer.class); } public void testQuantizerRegistryIsSingleton() { @@ -70,6 +80,12 @@ public void testQuantizerRegistryIsSingleton() { Quantizer firstFourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams); Quantizer secondFourBitQuantizer = QuantizerRegistry.getQuantizer(fourBitParams); assertSame(firstFourBitQuantizer, secondFourBitQuantizer); + + // Ensure the same instance is returned for the same type identifier (8-bit) + ScalarQuantizationParams byteSQParams = new ScalarQuantizationParams(ScalarQuantizationType.EIGHT_BIT); + Quantizer firstByteScalarQuantizer = QuantizerRegistry.getQuantizer(byteSQParams); + Quantizer secondByteScalarQuantizer = QuantizerRegistry.getQuantizer(byteSQParams); + assertSame(firstByteScalarQuantizer, secondByteScalarQuantizer); } public void testRegisterQuantizerThrowsExceptionWhenAlreadyRegistered() { diff --git a/src/test/java/org/opensearch/knn/quantization/quantizer/ByteScalarQuantizerTests.java b/src/test/java/org/opensearch/knn/quantization/quantizer/ByteScalarQuantizerTests.java new file mode 100644 index 0000000000..491a53ae5c --- /dev/null +++ b/src/test/java/org/opensearch/knn/quantization/quantizer/ByteScalarQuantizerTests.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.quantization.quantizer; + +import org.apache.lucene.index.FieldInfo; +import org.opensearch.knn.KNNTestCase; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; +import 
org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
+import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
+import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
+import org.opensearch.knn.quantization.models.quantizationState.ByteScalarQuantizationState;
+import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
+import org.opensearch.knn.quantization.models.requests.TrainingRequest;
+
+import java.io.IOException;
+
+public class ByteScalarQuantizerTests extends KNNTestCase {
+    public void testTrain() throws IOException {
+        float[][] vectors = {
+            { 0.5f, 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f },
+            { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f },
+            { 1.5f, 2.5f, 3.5f, 4.5f, 5.5f, 6.5f, 7.5f, 8.5f } };
+        String parameterString =
+            "{\"index_description\":\"HNSW16,SQ8\",\"spaceType\":\"l2\",\"name\":\"hnsw\",\"data_type\":\"float\",\"parameters\":{\"ef_search\":256,\"ef_construction\":256,\"encoder\":{\"name\":\"sq\",\"parameters\":{\"clip\":false}}}}";
+        FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field")
+            .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true")
+            .addAttribute(KNNConstants.PARAMETERS, parameterString)
+            .build();
+
+        ByteScalarQuantizer byteScalarQuantizer = new ByteScalarQuantizer(8);
+        ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.EIGHT_BIT);
+        TrainingRequest request = new ByteScalarQuantizerTests.MockTrainingRequest(params, vectors);
+        QuantizationState state = byteScalarQuantizer.train(request, fieldInfo);
+
+        assertTrue(state instanceof ByteScalarQuantizationState);
+        ByteScalarQuantizationState byteScalarQuantizationState = (ByteScalarQuantizationState) state;
+        assertNotNull(byteScalarQuantizationState.getIndexTemplate());
+    }
+
+    // Mock classes for testing
+    private static class MockTrainingRequest extends TrainingRequest {
+        private final float[][] vectors;
+
+        public MockTrainingRequest(ScalarQuantizationParams params, float[][] vectors) {
+            super(vectors.length);
+            this.vectors = vectors;
+        }
+
+        @Override
+        public float[] getVectorAtThePosition(int position) {
+            return vectors[position];
+        }
+    }
+}
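
To make the new 8-bit flow easier to follow, here is a minimal sketch (not part of the patch) of what ByteScalarQuantizerTests exercises end to end: a FieldInfo carrying the Faiss HNSW16,SQ8 method parameters is handed to ByteScalarQuantizer.train(), and the returned ByteScalarQuantizationState holds a serialized Faiss index template rather than Java-side thresholds. The class name, field name, and the generic type argument on TrainingRequest are assumptions for illustration; every API call is taken from the tests above, including the test-scope FieldInfoBuilder.

// Illustrative sketch only; mirrors ByteScalarQuantizerTests, names are hypothetical.
import org.apache.lucene.index.FieldInfo;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.index.codec.KNNCodecTestUtil;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.quantization.models.quantizationState.ByteScalarQuantizationState;
import org.opensearch.knn.quantization.models.requests.TrainingRequest;
import org.opensearch.knn.quantization.quantizer.ByteScalarQuantizer;

import java.io.IOException;

final class EightBitTemplateSketch {
    // Trains an 8-bit scalar quantizer and returns the serialized Faiss template held by the state.
    static byte[] trainTemplate(float[][] vectors) throws IOException {
        // Same HNSW16,SQ8 method parameters used by the tests above; the field name is arbitrary.
        String parameterString =
            "{\"index_description\":\"HNSW16,SQ8\",\"spaceType\":\"l2\",\"name\":\"hnsw\",\"data_type\":\"float\","
                + "\"parameters\":{\"ef_search\":256,\"ef_construction\":256,\"encoder\":{\"name\":\"sq\",\"parameters\":{\"clip\":false}}}}";
        FieldInfo fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("example-field")
            .addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true")
            .addAttribute(KNNConstants.PARAMETERS, parameterString)
            .build();

        // Training request backed by the in-memory vectors, like MockTrainingRequest in the new test.
        TrainingRequest<float[]> request = new TrainingRequest<float[]>(vectors.length) {
            @Override
            public float[] getVectorAtThePosition(int position) {
                return vectors[position];
            }
        };

        ByteScalarQuantizer quantizer = new ByteScalarQuantizer(8);
        ByteScalarQuantizationState state = (ByteScalarQuantizationState) quantizer.train(request, fieldInfo);
        return state.getIndexTemplate();
    }
}

A registry-based caller would obtain the same quantizer instance via QuantizerRegistry.getQuantizer(new ScalarQuantizationParams(ScalarQuantizationType.EIGHT_BIT)), as QuantizerRegistryTests shows; the direct constructor is used here only to keep the sketch grounded in the calls visible in this diff.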
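
At the service level, the updated QuantizationServiceTests suggest the following flow for EIGHT_BIT. This is a sketch under assumptions, not the plugin's actual wiring: raw types are used where the generic signatures are not visible in this diff, and the QuantizationService instance, the KNNVectorValues, and the FieldInfo are taken as parameters because their construction happens outside the hunks shown here.

// Illustrative sketch only; fieldInfo is expected to carry the KNN field attributes shown in the tests above.
import org.apache.lucene.index.FieldInfo;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.quantizationservice.QuantizationService;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.ByteScalarQuantizationState;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;

import java.io.IOException;

final class EightBitServiceFlowSketch {
    static byte[] trainViaService(QuantizationService service, KNNVectorValues vectorValues, FieldInfo fieldInfo)
        throws IOException {
        ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.EIGHT_BIT);

        // train() now takes the FieldInfo; for the 8-bit path the returned state carries a Faiss template.
        QuantizationState state = service.train(params, vectorValues, vectorValues.totalLiveDocs(), fieldInfo);
        byte[] template = ((ByteScalarQuantizationState) state).getIndexTemplate();

        // Per the new test, a QFRAMEWORK_CONFIG of "type=byte,bits=8" yields FLOAT here, which suggests
        // full-precision vectors stay on the transfer path and the SQ8 encoding is applied from the template.
        VectorDataType transferType = service.getVectorDataTypeForTransfer(fieldInfo);

        return template;
    }
}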