diff --git a/CHANGELOG.md b/CHANGELOG.md index 6d6a2eb79..2bd66a117 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292] - Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331] - Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376] +- Add support for Faiss onDisk 4x compression (#2425)[https://github.com/opensearch-project/k-NN/pull/2425] ### Enhancements - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241] - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290] diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java index 6685e2b22..db7664d54 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -257,8 +257,9 @@ public void testFlush_WithQuantization() { when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); try { - when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) - .thenReturn(quantizationState); + when( + quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo) + ).thenReturn(quantizationState); } catch (Exception e) { throw new RuntimeException(e); } @@ -690,8 +691,9 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); try { - when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) - .thenReturn(quantizationState); + when( + quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo) + ).thenReturn(quantizationState); } catch (Exception e) { throw new RuntimeException(e); } @@ -793,8 +795,9 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); try { - when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) - .thenReturn(quantizationState); + when( + quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo) + ).thenReturn(quantizationState); } catch (Exception e) { throw new RuntimeException(e); } diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java index cdc372bda..097144951 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java @@ -325,7 +325,9 @@ public void testMerge_WithQuantization() { when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); try { - when(quantizationService.train(quantizationParams, knnVectorValues, mergedVectors.size())).thenReturn(quantizationState); + when(quantizationService.train(quantizationParams, knnVectorValues, mergedVectors.size(), fieldInfo)).thenReturn( + quantizationState + ); } catch (Exception e) { throw new RuntimeException(e); } diff --git a/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java b/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java index 690391dbd..9b1904720 100644 --- a/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java +++ b/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java @@ -5,10 +5,13 @@ package org.opensearch.knn.index.quantizationservice; +import org.apache.lucene.index.FieldInfo; import org.opensearch.knn.KNNTestCase; import org.junit.Before; import org.opensearch.knn.index.VectorDataType; +import org.opensearch.knn.index.codec.KNNCodecTestUtil; +import org.opensearch.knn.index.mapper.KNNVectorFieldMapper; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.opensearch.knn.index.vectorvalues.TestVectorValues; @@ -24,6 +27,7 @@ public class QuantizationServiceTests extends KNNTestCase { private QuantizationService quantizationService; private KNNVectorValues knnVectorValues; + private FieldInfo fieldInfo; @Before public void setUp() throws Exception { @@ -42,11 +46,19 @@ public void setUp() throws Exception { VectorDataType.FLOAT, new TestVectorValues.PreDefinedFloatVectorValues(floatVectors) ); + + fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field").addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true").build(); } public void testTrain_oneBitQuantizer_success() throws IOException { ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + + QuantizationState quantizationState = quantizationService.train( + oneBitParams, + knnVectorValues, + knnVectorValues.totalLiveDocs(), + fieldInfo + ); assertTrue(quantizationState instanceof OneBitScalarQuantizationState); OneBitScalarQuantizationState oneBitState = (OneBitScalarQuantizationState) quantizationState; @@ -62,7 +74,12 @@ public void testTrain_oneBitQuantizer_success() throws IOException { public void testTrain_twoBitQuantizer_success() throws IOException { ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); - QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train( + twoBitParams, + knnVectorValues, + knnVectorValues.totalLiveDocs(), + fieldInfo + ); assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState; @@ -85,7 +102,12 @@ public void testTrain_twoBitQuantizer_success() throws IOException { public void testTrain_fourBitQuantizer_success() throws IOException { ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); - QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train( + fourBitParams, + knnVectorValues, + knnVectorValues.totalLiveDocs(), + fieldInfo + ); assertTrue(quantizationState instanceof MultiBitScalarQuantizationState); MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState; @@ -110,7 +132,12 @@ public void testTrain_fourBitQuantizer_success() throws IOException { public void testQuantize_oneBitQuantizer_success() throws IOException { ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train( + oneBitParams, + knnVectorValues, + knnVectorValues.totalLiveDocs(), + fieldInfo + ); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams); @@ -125,7 +152,12 @@ public void testQuantize_oneBitQuantizer_success() throws IOException { public void testQuantize_twoBitQuantizer_success() throws IOException { ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT); - QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train( + twoBitParams, + knnVectorValues, + knnVectorValues.totalLiveDocs(), + fieldInfo + ); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(twoBitParams); byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 4.0f, 5.0f, 6.0f }, quantizationOutput); @@ -138,7 +170,12 @@ public void testQuantize_twoBitQuantizer_success() throws IOException { public void testQuantize_fourBitQuantizer_success() throws IOException { ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT); - QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train( + fourBitParams, + knnVectorValues, + knnVectorValues.totalLiveDocs(), + fieldInfo + ); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(fourBitParams); byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 7.0f, 8.0f, 9.0f }, quantizationOutput); @@ -152,7 +189,12 @@ public void testQuantize_fourBitQuantizer_success() throws IOException { public void testQuantize_whenInvalidInput_thenThrows() throws IOException { ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT); - QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs()); + QuantizationState quantizationState = quantizationService.train( + oneBitParams, + knnVectorValues, + knnVectorValues.totalLiveDocs(), + fieldInfo + ); QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams); assertThrows(IllegalArgumentException.class, () -> quantizationService.quantize(quantizationState, null, quantizationOutput)); }