Skip to content

Commit

Permalink
Fix failing tests
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Jan 23, 2025
1 parent ebe809c commit 101c94f
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,9 @@ public void testFlush_WithQuantization() {

when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
try {
when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size()))
.thenReturn(quantizationState);
when(
quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo)
).thenReturn(quantizationState);
} catch (Exception e) {
throw new RuntimeException(e);
}
Expand Down Expand Up @@ -690,8 +691,9 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres

when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
try {
when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size()))
.thenReturn(quantizationState);
when(
quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo)
).thenReturn(quantizationState);
} catch (Exception e) {
throw new RuntimeException(e);
}
Expand Down Expand Up @@ -793,8 +795,9 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres

when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
try {
when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size()))
.thenReturn(quantizationState);
when(
quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo)
).thenReturn(quantizationState);
} catch (Exception e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,9 @@ public void testMerge_WithQuantization() {

when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
try {
when(quantizationService.train(quantizationParams, knnVectorValues, mergedVectors.size())).thenReturn(quantizationState);
when(quantizationService.train(quantizationParams, knnVectorValues, mergedVectors.size(), fieldInfo)).thenReturn(
quantizationState
);
} catch (Exception e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

package org.opensearch.knn.index.quantizationservice;

import org.apache.lucene.index.FieldInfo;
import org.opensearch.knn.KNNTestCase;
import org.junit.Before;

import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.KNNCodecTestUtil;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
import org.opensearch.knn.index.vectorvalues.TestVectorValues;
Expand All @@ -24,6 +27,7 @@
public class QuantizationServiceTests extends KNNTestCase {
private QuantizationService<float[], byte[]> quantizationService;
private KNNVectorValues<float[]> knnVectorValues;
private FieldInfo fieldInfo;

@Before
public void setUp() throws Exception {
Expand All @@ -42,11 +46,19 @@ public void setUp() throws Exception {
VectorDataType.FLOAT,
new TestVectorValues.PreDefinedFloatVectorValues(floatVectors)
);

fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field").addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true").build();
}

public void testTrain_oneBitQuantizer_success() throws IOException {
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());

QuantizationState quantizationState = quantizationService.train(
oneBitParams,
knnVectorValues,
knnVectorValues.totalLiveDocs(),
fieldInfo
);

assertTrue(quantizationState instanceof OneBitScalarQuantizationState);
OneBitScalarQuantizationState oneBitState = (OneBitScalarQuantizationState) quantizationState;
Expand All @@ -62,7 +74,12 @@ public void testTrain_oneBitQuantizer_success() throws IOException {

public void testTrain_twoBitQuantizer_success() throws IOException {
ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
QuantizationState quantizationState = quantizationService.train(
twoBitParams,
knnVectorValues,
knnVectorValues.totalLiveDocs(),
fieldInfo
);

assertTrue(quantizationState instanceof MultiBitScalarQuantizationState);
MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState;
Expand All @@ -85,7 +102,12 @@ public void testTrain_twoBitQuantizer_success() throws IOException {

public void testTrain_fourBitQuantizer_success() throws IOException {
ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
QuantizationState quantizationState = quantizationService.train(
fourBitParams,
knnVectorValues,
knnVectorValues.totalLiveDocs(),
fieldInfo
);

assertTrue(quantizationState instanceof MultiBitScalarQuantizationState);
MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState;
Expand All @@ -110,7 +132,12 @@ public void testTrain_fourBitQuantizer_success() throws IOException {

public void testQuantize_oneBitQuantizer_success() throws IOException {
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
QuantizationState quantizationState = quantizationService.train(
oneBitParams,
knnVectorValues,
knnVectorValues.totalLiveDocs(),
fieldInfo
);

QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams);

Expand All @@ -125,7 +152,12 @@ public void testQuantize_oneBitQuantizer_success() throws IOException {

public void testQuantize_twoBitQuantizer_success() throws IOException {
ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
QuantizationState quantizationState = quantizationService.train(
twoBitParams,
knnVectorValues,
knnVectorValues.totalLiveDocs(),
fieldInfo
);
QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(twoBitParams);
byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 4.0f, 5.0f, 6.0f }, quantizationOutput);

Expand All @@ -138,7 +170,12 @@ public void testQuantize_twoBitQuantizer_success() throws IOException {

public void testQuantize_fourBitQuantizer_success() throws IOException {
ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
QuantizationState quantizationState = quantizationService.train(
fourBitParams,
knnVectorValues,
knnVectorValues.totalLiveDocs(),
fieldInfo
);
QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(fourBitParams);

byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 7.0f, 8.0f, 9.0f }, quantizationOutput);
Expand All @@ -152,7 +189,12 @@ public void testQuantize_fourBitQuantizer_success() throws IOException {

public void testQuantize_whenInvalidInput_thenThrows() throws IOException {
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
QuantizationState quantizationState = quantizationService.train(
oneBitParams,
knnVectorValues,
knnVectorValues.totalLiveDocs(),
fieldInfo
);
QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams);
assertThrows(IllegalArgumentException.class, () -> quantizationService.quantize(quantizationState, null, quantizationOutput));
}
Expand Down

0 comments on commit 101c94f

Please sign in to comment.