From 2076930ac56a74a54aa581c2d7742a8a5706e2cb Mon Sep 17 00:00:00 2001
From: Naveen Tatikonda <navtat@amazon.com>
Date: Wed, 22 Jan 2025 22:59:01 -0600
Subject: [PATCH] Fix failing tests

Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
---
 CHANGELOG.md                                  |  1 +
 ...eEngines990KnnVectorsWriterFlushTests.java | 15 +++--
 ...eEngines990KnnVectorsWriterMergeTests.java |  4 +-
 .../QuantizationServiceTests.java             | 56 ++++++++++++++++---
 4 files changed, 62 insertions(+), 14 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index b32de7bc5..05e16090a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 - Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292]
 - Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331]
 - Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376]
+- Add support for Faiss onDisk 4x compression (#2425)[https://github.com/opensearch-project/k-NN/pull/2425]
 ### Enhancements
 - Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
 - Allow method parameter override for training based indices (#2290) https://github.com/opensearch-project/k-NN/pull/2290]
diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java
index 03d0f6160..cb91ef6aa 100644
--- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java
+++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java
@@ -257,8 +257,9 @@ public void testFlush_WithQuantization() {
 
                 when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
                 try {
-                    when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size()))
-                        .thenReturn(quantizationState);
+                    when(
+                        quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo)
+                    ).thenReturn(quantizationState);
                 } catch (Exception e) {
                     throw new RuntimeException(e);
                 }
@@ -690,8 +691,9 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres
 
                 when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
                 try {
-                    when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size()))
-                        .thenReturn(quantizationState);
+                    when(
+                        quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo)
+                    ).thenReturn(quantizationState);
                 } catch (Exception e) {
                     throw new RuntimeException(e);
                 }
@@ -793,8 +795,9 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres
 
                 when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
                 try {
-                    when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size()))
-                        .thenReturn(quantizationState);
+                    when(
+                        quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo)
+                    ).thenReturn(quantizationState);
                 } catch (Exception e) {
                     throw new RuntimeException(e);
                 }
diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java
index 77f3fd8ed..aedfc273b 100644
--- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java
+++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterMergeTests.java
@@ -325,7 +325,9 @@ public void testMerge_WithQuantization() {
 
             when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
             try {
-                when(quantizationService.train(quantizationParams, knnVectorValues, mergedVectors.size())).thenReturn(quantizationState);
+                when(quantizationService.train(quantizationParams, knnVectorValues, mergedVectors.size(), fieldInfo)).thenReturn(
+                    quantizationState
+                );
             } catch (Exception e) {
                 throw new RuntimeException(e);
             }
diff --git a/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java b/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java
index 690391dbd..9b1904720 100644
--- a/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java
+++ b/src/test/java/org/opensearch/knn/index/quantizationservice/QuantizationServiceTests.java
@@ -5,10 +5,13 @@
 
 package org.opensearch.knn.index.quantizationservice;
 
+import org.apache.lucene.index.FieldInfo;
 import org.opensearch.knn.KNNTestCase;
 import org.junit.Before;
 
 import org.opensearch.knn.index.VectorDataType;
+import org.opensearch.knn.index.codec.KNNCodecTestUtil;
+import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
 import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
 import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
 import org.opensearch.knn.index.vectorvalues.TestVectorValues;
@@ -24,6 +27,7 @@
 public class QuantizationServiceTests extends KNNTestCase {
     private QuantizationService<float[], byte[]> quantizationService;
     private KNNVectorValues<float[]> knnVectorValues;
+    private FieldInfo fieldInfo;
 
     @Before
     public void setUp() throws Exception {
@@ -42,11 +46,19 @@ public void setUp() throws Exception {
             VectorDataType.FLOAT,
             new TestVectorValues.PreDefinedFloatVectorValues(floatVectors)
         );
+
+        fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field").addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true").build();
     }
 
     public void testTrain_oneBitQuantizer_success() throws IOException {
         ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
-        QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
+
+        QuantizationState quantizationState = quantizationService.train(
+            oneBitParams,
+            knnVectorValues,
+            knnVectorValues.totalLiveDocs(),
+            fieldInfo
+        );
 
         assertTrue(quantizationState instanceof OneBitScalarQuantizationState);
         OneBitScalarQuantizationState oneBitState = (OneBitScalarQuantizationState) quantizationState;
@@ -62,7 +74,12 @@ public void testTrain_oneBitQuantizer_success() throws IOException {
 
     public void testTrain_twoBitQuantizer_success() throws IOException {
         ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
-        QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
+        QuantizationState quantizationState = quantizationService.train(
+            twoBitParams,
+            knnVectorValues,
+            knnVectorValues.totalLiveDocs(),
+            fieldInfo
+        );
 
         assertTrue(quantizationState instanceof MultiBitScalarQuantizationState);
         MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState;
@@ -85,7 +102,12 @@ public void testTrain_twoBitQuantizer_success() throws IOException {
 
     public void testTrain_fourBitQuantizer_success() throws IOException {
         ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
-        QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
+        QuantizationState quantizationState = quantizationService.train(
+            fourBitParams,
+            knnVectorValues,
+            knnVectorValues.totalLiveDocs(),
+            fieldInfo
+        );
 
         assertTrue(quantizationState instanceof MultiBitScalarQuantizationState);
         MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState;
@@ -110,7 +132,12 @@ public void testTrain_fourBitQuantizer_success() throws IOException {
 
     public void testQuantize_oneBitQuantizer_success() throws IOException {
         ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
-        QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
+        QuantizationState quantizationState = quantizationService.train(
+            oneBitParams,
+            knnVectorValues,
+            knnVectorValues.totalLiveDocs(),
+            fieldInfo
+        );
 
         QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams);
 
@@ -125,7 +152,12 @@ public void testQuantize_oneBitQuantizer_success() throws IOException {
 
     public void testQuantize_twoBitQuantizer_success() throws IOException {
         ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
-        QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
+        QuantizationState quantizationState = quantizationService.train(
+            twoBitParams,
+            knnVectorValues,
+            knnVectorValues.totalLiveDocs(),
+            fieldInfo
+        );
         QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(twoBitParams);
         byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 4.0f, 5.0f, 6.0f }, quantizationOutput);
 
@@ -138,7 +170,12 @@ public void testQuantize_twoBitQuantizer_success() throws IOException {
 
     public void testQuantize_fourBitQuantizer_success() throws IOException {
         ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
-        QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
+        QuantizationState quantizationState = quantizationService.train(
+            fourBitParams,
+            knnVectorValues,
+            knnVectorValues.totalLiveDocs(),
+            fieldInfo
+        );
         QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(fourBitParams);
 
         byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 7.0f, 8.0f, 9.0f }, quantizationOutput);
@@ -152,7 +189,12 @@ public void testQuantize_fourBitQuantizer_success() throws IOException {
 
     public void testQuantize_whenInvalidInput_thenThrows() throws IOException {
         ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
-        QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
+        QuantizationState quantizationState = quantizationService.train(
+            oneBitParams,
+            knnVectorValues,
+            knnVectorValues.totalLiveDocs(),
+            fieldInfo
+        );
         QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams);
         assertThrows(IllegalArgumentException.class, () -> quantizationService.quantize(quantizationState, null, quantizationOutput));
     }