diff --git a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java index 38eb628e2..13c82395e 100644 --- a/src/main/java/org/opensearch/knn/index/KNNIndexShard.java +++ b/src/main/java/org/opensearch/knn/index/KNNIndexShard.java @@ -198,7 +198,7 @@ List getEngineFileContexts(IndexReader indexReader, KNNEngine ) : FieldInfoExtractor.extractQuantizationConfig(fieldInfo) .getQuantizationType() == ScalarQuantizationType.EIGHT_BIT - ? VectorDataType.BYTE + ? VectorDataType.FLOAT : VectorDataType.BINARY ) ); diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java index 07a676dda..fb809850a 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/KNN990QuantizationStateReader.java @@ -89,8 +89,8 @@ public static QuantizationState read(QuantizationStateReadConfig readConfig) thr case TWO_BIT: case FOUR_BIT: return MultiBitScalarQuantizationState.fromByteArray(stateBytes); - case EIGHT_BIT: - return ByteScalarQuantizationState.fromByteArray(stateBytes); +// case EIGHT_BIT: +// return ByteScalarQuantizationState.fromByteArray(stateBytes); default: throw new IllegalArgumentException(String.format("Unexpected scalar quantization type: %s", scalarQuantizationType)); } diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 7c8636577..0d4cd4265 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -30,7 +30,9 @@ import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.plugin.stats.KNNGraphValue; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; @@ -255,10 +257,14 @@ private QuantizationState train( final QuantizationParams quantizationParams = quantizationService.getQuantizationParams(fieldInfo); QuantizationState quantizationState = null; if (quantizationParams != null && totalLiveDocs > 0) { - initQuantizationStateWriterIfNecessary(); KNNVectorValues knnVectorValues = knnVectorValuesSupplier.get(); - quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs); - quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState); + if ((quantizationParams.getTypeIdentifier()).equals(ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.EIGHT_BIT))) { + quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs, fieldInfo); + } else { + initQuantizationStateWriterIfNecessary(); + quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs, fieldInfo); + quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState); + } } return quantizationState; diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java index 23c3ba116..67d4c25f5 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java @@ -12,6 +12,7 @@ import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.jni.JNIService; +import org.opensearch.knn.quantization.models.quantizationState.ByteScalarQuantizationState; import java.io.IOException; import java.security.AccessController; @@ -77,14 +78,14 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOExcept long vectorAddress = vectorTransfer.getVectorAddress(); // Currently this is if else as there are only two cases, with more cases this will have to be made // more maintainable - if (params.containsKey(MODEL_ID)) { + if (params.containsKey(MODEL_ID) || (indexInfo.getQuantizationState() instanceof ByteScalarQuantizationState)) { AccessController.doPrivileged((PrivilegedAction) () -> { JNIService.createIndexFromTemplate( intListToArray(transferredDocIds), vectorAddress, indexBuildSetup.getDimensions(), indexInfo.getIndexOutputWithBuffer(), - (byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER), + getIndexTemplate(params, indexInfo), params, indexInfo.getKnnEngine() ); @@ -112,4 +113,13 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOExcept ); } } + + private byte[] getIndexTemplate(Map params, BuildIndexParams indexInfo) { + if (params.containsKey(MODEL_ID)) { + return (byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER); + } + + ByteScalarQuantizationState byteSQState = (ByteScalarQuantizationState) indexInfo.getQuantizationState(); + return byteSQState.getIndexTemplate(); + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java index 27a1ecfb6..6f9805d8c 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java @@ -32,6 +32,7 @@ import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.plugin.stats.KNNGraphValue; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; @@ -312,11 +313,19 @@ private static NativeIndexWriter createWriter( @Nullable final QuantizationState quantizationState ) { final KNNEngine knnEngine = extractKNNEngine(fieldInfo); - boolean isTemplate = fieldInfo.attributes().containsKey(MODEL_ID); - boolean iterative = !isTemplate && KNNEngine.FAISS == knnEngine; + boolean iterative = !isTemplate(fieldInfo) && KNNEngine.FAISS == knnEngine; NativeIndexBuildStrategy strategy = iterative ? MemOptimizedNativeIndexBuildStrategy.getInstance() : DefaultIndexBuildStrategy.getInstance(); return new NativeIndexWriter(state, fieldInfo, strategy, quantizationState); } + + private static boolean isTemplate(FieldInfo fieldInfo) { + if (fieldInfo.attributes().containsKey(MODEL_ID)) { + return true; + } + + QuantizationConfig quantizationConfig = FieldInfoExtractor.extractQuantizationConfig(fieldInfo); + return quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() == ScalarQuantizationType.EIGHT_BIT; + } } diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java index c5994d66b..745802dd3 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/QuantizationIndexUtils.java @@ -10,6 +10,7 @@ import org.opensearch.knn.index.quantizationservice.QuantizationService; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; +import org.opensearch.knn.quantization.models.quantizationState.ByteScalarQuantizationState; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; @@ -58,7 +59,7 @@ static IndexBuildSetup prepareIndexBuild(KNNVectorValues knnVectorValues, Bui int bytesPerVector; int dimensions; - if (quantizationState != null) { + if (quantizationState != null && !(quantizationState instanceof ByteScalarQuantizationState)) { bytesPerVector = quantizationState.getBytesPerVector(); dimensions = quantizationState.getDimensions(); quantizationOutput = quantizationService.createQuantizationOutput(quantizationState.getQuantizationParams()); diff --git a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java index 3bc6087c8..c8d56fe95 100644 --- a/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java +++ b/src/main/java/org/opensearch/knn/index/engine/faiss/AbstractFaissMethod.java @@ -109,29 +109,38 @@ static KNNLibraryIndexingContext adjustIndexDescription( if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BINARY) { prefix = FAISS_BINARY_INDEX_DESCRIPTION_PREFIX; } - if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BYTE - || (encoderContext != null + + if (knnMethodConfigContext.getVectorDataType() == VectorDataType.BYTE) { + + // If VectorDataType is Byte using Faiss engine then manipulate Index Description to use "SQ8_direct_signed" scalar quantizer + // For example, Index Description "HNSW16,Flat" will be updated as "HNSW16,SQ8_direct_signed" + methodAsMapBuilder.indexDescription = updateIndexDescription(methodAsMapBuilder.indexDescription, FAISS_SIGNED_BYTE_SQ); + } + + if (encoderContext != null && Objects.equals(encoderContext.getName(), ENCODER_SQ) && Objects.equals( encoderContext.getParameters().getOrDefault(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16), FAISS_SQ_ENCODER_INT8 - ))) { - - // If VectorDataType is Byte using Faiss engine then manipulate Index Description to use "SQ8_direct_signed" scalar quantizer - // For example, Index Description "HNSW16,Flat" will be updated as "HNSW16,SQ8_direct_signed" - String indexDescription = methodAsMapBuilder.indexDescription; - if (StringUtils.isNotEmpty(indexDescription)) { - StringBuilder indexDescriptionBuilder = new StringBuilder(); - indexDescriptionBuilder.append(indexDescription.split(",")[0]); - indexDescriptionBuilder.append(","); - indexDescriptionBuilder.append(FAISS_SIGNED_BYTE_SQ); - methodAsMapBuilder.indexDescription = indexDescriptionBuilder.toString(); - } + )) { + methodAsMapBuilder.indexDescription = updateIndexDescription(methodAsMapBuilder.indexDescription, "SQ8"); } methodAsMapBuilder.indexDescription = prefix + methodAsMapBuilder.indexDescription; return methodAsMapBuilder.build(); } + private static String updateIndexDescription(String indexDescription, String indexDescriptionName) { + if (StringUtils.isEmpty(indexDescription)) { + return indexDescription; + } + + StringBuilder indexDescriptionBuilder = new StringBuilder(); + indexDescriptionBuilder.append(indexDescription.split(",")[0]); + indexDescriptionBuilder.append(","); + indexDescriptionBuilder.append(indexDescriptionName); + return indexDescriptionBuilder.toString(); + } + static MethodComponentContext getEncoderMethodComponent(MethodComponentContext methodComponentContext) { if (!methodComponentContext.getParameters().containsKey(METHOD_ENCODER_PARAMETER)) { return null; diff --git a/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java b/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java index ad506e978..e2c6270fb 100644 --- a/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java +++ b/src/main/java/org/opensearch/knn/index/quantizationservice/QuantizationService.java @@ -19,6 +19,7 @@ import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; +import org.opensearch.knn.quantization.quantizer.ByteScalarQuantizer; import org.opensearch.knn.quantization.quantizer.Quantizer; import java.io.IOException; @@ -60,9 +61,10 @@ public static QuantizationService getInstance() { * @throws IOException If an I/O error occurs during the training process. */ public QuantizationState train( - final QuantizationParams quantizationParams, - final KNNVectorValues knnVectorValues, - final long liveDocs + final QuantizationParams quantizationParams, + final KNNVectorValues knnVectorValues, + final long liveDocs, + final FieldInfo fieldInfo ) throws IOException { Quantizer quantizer = QuantizerFactory.getQuantizer(quantizationParams); @@ -70,6 +72,9 @@ public QuantizationState train( KNNVectorQuantizationTrainingRequest trainingRequest = new KNNVectorQuantizationTrainingRequest<>(knnVectorValues, liveDocs); // Train the quantizer and return the quantization state + if(quantizer instanceof ByteScalarQuantizer) { + return quantizer.train(trainingRequest, fieldInfo); + } return quantizer.train(trainingRequest); } @@ -111,7 +116,7 @@ public VectorDataType getVectorDataTypeForTransfer(final FieldInfo fieldInfo) { QuantizationConfig quantizationConfig = extractQuantizationConfig(fieldInfo); if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() == ScalarQuantizationType.EIGHT_BIT) { - return VectorDataType.BYTE; + return VectorDataType.FLOAT; } if (quantizationConfig != QuantizationConfig.EMPTY && quantizationConfig.getQuantizationType() != null) { return VectorDataType.BINARY; @@ -130,9 +135,9 @@ public VectorDataType getVectorDataTypeForTransfer(final FieldInfo fieldInfo) { public QuantizationOutput createQuantizationOutput(final QuantizationParams quantizationParams) { if (quantizationParams instanceof ScalarQuantizationParams) { ScalarQuantizationParams scalarParams = (ScalarQuantizationParams) quantizationParams; - if (scalarParams.getSqType() == ScalarQuantizationType.EIGHT_BIT) { - return (QuantizationOutput) new ByteQuantizationOutput(scalarParams.getSqType().getId()); - } +// if (scalarParams.getSqType() == ScalarQuantizationType.EIGHT_BIT) { +// return (QuantizationOutput) new ByteQuantizationOutput(scalarParams.getSqType().getId()); +// } return (QuantizationOutput) new BinaryQuantizationOutput(scalarParams.getSqType().getId()); } throw new IllegalArgumentException("Unsupported quantization parameters: " + quantizationParams.getClass().getName()); diff --git a/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationInfo.java b/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationInfo.java index d25774cdc..6b7928abf 100644 --- a/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationInfo.java +++ b/src/main/java/org/opensearch/knn/index/query/SegmentLevelQuantizationInfo.java @@ -11,7 +11,9 @@ import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.LeafReader; import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams; +import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; @@ -36,7 +38,7 @@ public class SegmentLevelQuantizationInfo { public static SegmentLevelQuantizationInfo build(final LeafReader leafReader, final FieldInfo fieldInfo, final String fieldName) throws IOException { final QuantizationParams quantizationParams = QuantizationService.getInstance().getQuantizationParams(fieldInfo); - if (quantizationParams == null) { + if (quantizationParams == null || (quantizationParams.getTypeIdentifier()).equals(ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.EIGHT_BIT))) { return null; } final QuantizationState quantizationState = SegmentLevelQuantizationUtil.getQuantizationState(leafReader, fieldName); diff --git a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/ByteScalarQuantizationState.java b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/ByteScalarQuantizationState.java index 679e60325..44d0ebfd8 100644 --- a/src/main/java/org/opensearch/knn/quantization/models/quantizationState/ByteScalarQuantizationState.java +++ b/src/main/java/org/opensearch/knn/quantization/models/quantizationState/ByteScalarQuantizationState.java @@ -22,8 +22,7 @@ @AllArgsConstructor public class ByteScalarQuantizationState implements QuantizationState { private ScalarQuantizationParams quantizationParams; - private float[] min; - private float[] diff; + private byte[] indexTemplate; @Override public QuantizationParams getQuantizationParams() { @@ -34,15 +33,14 @@ public QuantizationParams getQuantizationParams() { public void writeTo(StreamOutput out) throws IOException { out.writeVInt(Version.CURRENT.id); // Write the version quantizationParams.writeTo(out); - out.writeFloatArray(min); - out.writeFloatArray(diff); + out.writeByteArray(indexTemplate); } public ByteScalarQuantizationState(StreamInput in) throws IOException { int version = in.readVInt(); // Read the version this.quantizationParams = new ScalarQuantizationParams(in, version); - this.min = in.readFloatArray(); - this.diff = in.readFloatArray(); + this.indexTemplate = in.readByteArray(); + } @Override @@ -56,20 +54,19 @@ public static ByteScalarQuantizationState fromByteArray(final byte[] bytes) thro @Override public int getBytesPerVector() { - return min.length; + return 0; } @Override public int getDimensions() { - return min.length; + return 0; } @Override public long ramBytesUsed() { long size = RamUsageEstimator.shallowSizeOfInstance(ByteScalarQuantizationState.class); size += RamUsageEstimator.shallowSizeOf(quantizationParams); - size += RamUsageEstimator.sizeOf(min); - size += RamUsageEstimator.sizeOf(diff); + size += RamUsageEstimator.sizeOf(indexTemplate); return size; } } diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/ByteScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/ByteScalarQuantizer.java index b0c121ab3..ae1870fda 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/ByteScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/ByteScalarQuantizer.java @@ -5,6 +5,19 @@ package org.opensearch.knn.quantization.quantizer; +import org.apache.lucene.index.FieldInfo; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.xcontent.DeprecationHandler; +import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.index.KNNSettings; +import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransfer; +import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; @@ -17,7 +30,16 @@ import oshi.util.tuples.Pair; import java.io.IOException; +import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; +import static org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory.getVectorTransfer; public class ByteScalarQuantizer implements Quantizer { private final int bitsPerCoordinate; @@ -58,30 +80,110 @@ public ByteScalarQuantizer(final int bitsPerCoordinate) { // return new ByteScalarQuantizationState(params, minAndDiff.getA(), minAndDiff.getB()); // } + // Train using Quantile +// @Override +// public QuantizationState train(TrainingRequest trainingRequest) throws IOException { +// int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); +// float[][] transposedVec = transposeVectors(trainingRequest, sampledIndices); +// Pair minAndDiff = calculateMinAndDiffUsingQuantile(transposedVec); +// ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.EIGHT_BIT); +// return new ByteScalarQuantizationState(params, minAndDiff.getA(), minAndDiff.getB()); +// } + + @Override + public QuantizationState train(final TrainingRequest trainingRequest) throws IOException { + return null; + } @Override - public QuantizationState train(TrainingRequest trainingRequest) throws IOException { + public QuantizationState train(final TrainingRequest trainingRequest, final FieldInfo fieldInfo) throws IOException { int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize); - float[][] transposedVec = transposeVectors(trainingRequest, sampledIndices); - Pair minAndDiff = calculateMinAndDiffUsingQuantile(transposedVec); + if (sampledIndices.length == 0) { + return null; + } + float[] vector = trainingRequest.getVectorAtThePosition(sampledIndices[0]); + if (vector == null) { + throw new IllegalArgumentException("Vector at sampled index " + sampledIndices[0] + " is null."); + } + int dimension = vector.length; + byte[] indexTemplate = JNIService.trainIndex(getParameters(fieldInfo), dimension, getVectorAddressOfTrainData(sampledIndices, fieldInfo, trainingRequest, dimension), KNNEngine.FAISS); + ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.EIGHT_BIT); - return new ByteScalarQuantizationState(params, minAndDiff.getA(), minAndDiff.getB()); + return new ByteScalarQuantizationState(params, indexTemplate); + } + + private long getVectorAddressOfTrainData(int[] sampledIndices, FieldInfo fieldInfo, final TrainingRequest trainingRequest, int dimension) throws IOException { + int totalSamples = sampledIndices.length; + + final OffHeapVectorTransfer vectorTransfer = getVectorTransfer( + extractVectorDataType(fieldInfo), + 4*dimension, + sampledIndices.length + ); + final List transferredDocIds = new ArrayList<>(totalSamples); + for(int i=0; i< totalSamples; i++) { + Object vectorToTransfer = trainingRequest.getVectorAtThePosition(sampledIndices[i]); + vectorTransfer.transfer(vectorToTransfer, true); + } + vectorTransfer.flush(true); + return vectorTransfer.getVectorAddress(); + } + + private Map getParameters(final FieldInfo fieldInfo) throws IOException { + Map parameters = new HashMap<>(); + Map fieldAttributes = fieldInfo.attributes(); + String parametersString = fieldAttributes.get(KNNConstants.PARAMETERS); + + // parametersString will be null when legacy mapper is used + if (parametersString == null) { + parameters.put(KNNConstants.SPACE_TYPE, fieldAttributes.getOrDefault(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue())); + + Map algoParams = new HashMap<>(); + + String efConstruction = fieldAttributes.get(KNNConstants.HNSW_ALGO_EF_CONSTRUCTION); + if (efConstruction != null) { + algoParams.put(KNNConstants.METHOD_PARAMETER_EF_CONSTRUCTION, Integer.parseInt(efConstruction)); + } + + String m = fieldAttributes.get(KNNConstants.HNSW_ALGO_M); + if (m != null) { + algoParams.put(KNNConstants.METHOD_PARAMETER_M, Integer.parseInt(m)); + } + parameters.put(PARAMETERS, algoParams); + } else { + parameters.putAll( + XContentHelper.createParser( + NamedXContentRegistry.EMPTY, + DeprecationHandler.THROW_UNSUPPORTED_OPERATION, + new BytesArray(parametersString), + MediaTypeRegistry.getDefaultMediaType() + ).map() + ); + } + + parameters.put(KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.FLOAT); + + + // Used to determine how many threads to use when indexing + parameters.put(KNNConstants.INDEX_THREAD_QTY, KNNSettings.state().getSettingValue(KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY)); + + return parameters; } @Override public void quantize(float[] vector, QuantizationState state, QuantizationOutput output) { - if (vector == null) { - throw new IllegalArgumentException("Vector to quantize must not be null."); - } - validateState(state); - int vectorLength = vector.length; - ByteScalarQuantizationState byteSQState = (ByteScalarQuantizationState) state; - float[] minArray = byteSQState.getMin(); - float[] diffArray = byteSQState.getDiff(); - if (minArray == null || minArray.length != vectorLength || diffArray == null || diffArray.length != vectorLength) { - throw new IllegalArgumentException("min and diff arrays must not be null and must match the dimension of the vector."); - } - output.prepareQuantizedVector(vectorLength); - quantizeVector(vector, minArray, diffArray, output.getQuantizedVector()); +// if (vector == null) { +// throw new IllegalArgumentException("Vector to quantize must not be null."); +// } +// validateState(state); +// int vectorLength = vector.length; +// ByteScalarQuantizationState byteSQState = (ByteScalarQuantizationState) state; +// float[] minArray = byteSQState.getMin(); +// float[] diffArray = byteSQState.getDiff(); +// if (minArray == null || minArray.length != vectorLength || diffArray == null || diffArray.length != vectorLength) { +// throw new IllegalArgumentException("min and diff arrays must not be null and must match the dimension of the vector."); +// } +// output.prepareQuantizedVector(vectorLength); +// quantizeVector(vector, minArray, diffArray, output.getQuantizedVector()); } private void quantizeVector(final float[] vector, final float[] min, final float[] diff, byte[] quantizedVector) { diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java index 0bcc252d1..bc5022709 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/MultiBitScalarQuantizer.java @@ -6,6 +6,7 @@ package org.opensearch.knn.quantization.quantizer; +import org.apache.lucene.index.FieldInfo; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; @@ -119,6 +120,11 @@ public QuantizationState train(final TrainingRequest trainingRequest) t return new MultiBitScalarQuantizationState(params, thresholds); } + @Override + public QuantizationState train(final TrainingRequest trainingRequest, final FieldInfo fieldInfo) throws IOException { + return null; + } + /** * Quantizes the provided vector using the provided quantization state, producing a quantized output. * The vector is quantized based on the thresholds in the quantization state. diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java index 3cba89c39..70ed81a83 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/OneBitScalarQuantizer.java @@ -5,6 +5,7 @@ package org.opensearch.knn.quantization.quantizer; +import org.apache.lucene.index.FieldInfo; import org.opensearch.knn.quantization.enums.ScalarQuantizationType; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; @@ -64,6 +65,11 @@ public QuantizationState train(final TrainingRequest trainingRequest) t return new OneBitScalarQuantizationState(new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT), meanThresholds); } + @Override + public QuantizationState train(final TrainingRequest trainingRequest, final FieldInfo fieldInfo) throws IOException { + return null; + } + /** * Quantizes the provided vector using the given quantization state. * It compares each dimension of the vector against the corresponding mean (threshold) to determine the quantized value. diff --git a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java index 521863205..1e343b5e8 100644 --- a/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java +++ b/src/main/java/org/opensearch/knn/quantization/quantizer/Quantizer.java @@ -5,6 +5,7 @@ package org.opensearch.knn.quantization.quantizer; +import org.apache.lucene.index.FieldInfo; import org.opensearch.knn.quantization.models.quantizationOutput.QuantizationOutput; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import org.opensearch.knn.quantization.models.requests.TrainingRequest; @@ -31,6 +32,8 @@ public interface Quantizer { */ QuantizationState train(TrainingRequest trainingRequest) throws IOException; + QuantizationState train(TrainingRequest trainingRequest, FieldInfo fieldInfo) throws IOException; + /** * Quantizes the provided vector using the specified quantization state. *