Skip to content

Commit

Permalink
Fix failing tests
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Jan 23, 2025
1 parent ebe809c commit 6721604
Show file tree
Hide file tree
Showing 11 changed files with 117 additions and 44 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Add binary index support for Lucene engine. (#2292)[https://github.com/opensearch-project/k-NN/pull/2292]
- Add expand_nested_docs Parameter support to NMSLIB engine (#2331)[https://github.com/opensearch-project/k-NN/pull/2331]
- Add cosine similarity support for faiss engine (#2376)[https://github.com/opensearch-project/k-NN/pull/2376]
- Add support for Faiss onDisk 4x compression (#2425)[https://github.com/opensearch-project/k-NN/pull/2425]
### Enhancements
- Introduced a writing layer in native engines where relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290)[https://github.com/opensearch-project/k-NN/pull/2290]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ public KNNEngine resolveEngine(
}

// 4x is supported by Lucene engine before version 2.19.0
if (knnMethodConfigContext.getVersionCreated().before(Version.V_2_19_0) && compressionLevel == CompressionLevel.x4) {
if (compressionLevel == CompressionLevel.x4 && knnMethodConfigContext.getVersionCreated().before(Version.V_2_19_0)) {
return KNNEngine.LUCENE;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
import org.opensearch.knn.index.vectorvalues.TestVectorValues;
import org.opensearch.knn.plugin.stats.KNNGraphValue;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.test.OpenSearchTestCase;

Expand Down Expand Up @@ -256,9 +258,13 @@ public void testFlush_WithQuantization() {
).thenReturn(expectedVectorValues.get(i));

when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
when(quantizationParams.getTypeIdentifier()).thenReturn(
ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT)
);
try {
when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size()))
.thenReturn(quantizationState);
when(
quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo)
).thenReturn(quantizationState);
} catch (Exception e) {
throw new RuntimeException(e);
}
Expand Down Expand Up @@ -689,9 +695,13 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres
).thenReturn(expectedVectorValues.get(i));

when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
when(quantizationParams.getTypeIdentifier()).thenReturn(
ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT)
);
try {
when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size()))
.thenReturn(quantizationState);
when(
quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo)
).thenReturn(quantizationState);
} catch (Exception e) {
throw new RuntimeException(e);
}
Expand Down Expand Up @@ -792,9 +802,13 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres
).thenReturn(expectedVectorValues.get(i));

when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
when(quantizationParams.getTypeIdentifier()).thenReturn(
ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT)
);
try {
when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size()))
.thenReturn(quantizationState);
when(
quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size(), fieldInfo)
).thenReturn(quantizationState);
} catch (Exception e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
import org.opensearch.knn.index.vectorvalues.TestVectorValues;
import org.opensearch.knn.plugin.stats.KNNGraphValue;
import org.opensearch.knn.quantization.enums.ScalarQuantizationType;
import org.opensearch.knn.quantization.models.quantizationParams.QuantizationParams;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.test.OpenSearchTestCase;

Expand Down Expand Up @@ -324,8 +326,13 @@ public void testMerge_WithQuantization() {
.thenReturn(knnVectorValues);

when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
when(quantizationParams.getTypeIdentifier()).thenReturn(
ScalarQuantizationParams.generateTypeIdentifier(ScalarQuantizationType.ONE_BIT)
);
try {
when(quantizationService.train(quantizationParams, knnVectorValues, mergedVectors.size())).thenReturn(quantizationState);
when(quantizationService.train(quantizationParams, knnVectorValues, mergedVectors.size(), fieldInfo)).thenReturn(
quantizationState
);
} catch (Exception e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.index.engine;

import org.opensearch.Version;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.mapper.CompressionLevel;
Expand Down Expand Up @@ -68,18 +69,26 @@ public void testResolveEngine_whenCompressionIs1x_thenEngineBasedOnMode() {
);
}

public void testResolveEngine_whenCompressionIs4x_thenEngineIsLucene() {
public void testResolveEngine_whenCompressionIs4x_VersionIsBefore2_19_thenEngineIsLucene() {
assertEquals(
KNNEngine.LUCENE,
ENGINE_RESOLVER.resolveEngine(
KNNMethodConfigContext.builder().mode(Mode.ON_DISK).compressionLevel(CompressionLevel.x4).build(),
KNNMethodConfigContext.builder()
.mode(Mode.ON_DISK)
.compressionLevel(CompressionLevel.x4)
.versionCreated(Version.V_2_18_0)
.build(),
null,
false
)
);
assertEquals(
KNNEngine.LUCENE,
ENGINE_RESOLVER.resolveEngine(KNNMethodConfigContext.builder().compressionLevel(CompressionLevel.x4).build(), null, false)
ENGINE_RESOLVER.resolveEngine(
KNNMethodConfigContext.builder().compressionLevel(CompressionLevel.x4).versionCreated(Version.V_2_17_0).build(),
null,
false
)
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,21 +195,6 @@ private void validateResolveMethodContext(
}

public void testResolveMethod_whenInvalid_thenThrow() {
// Invalid compression
expectThrows(
ValidationException.class,
() -> TEST_RESOLVER.resolveMethod(
null,
KNNMethodConfigContext.builder()
.vectorDataType(VectorDataType.FLOAT)
.compressionLevel(CompressionLevel.x4)
.versionCreated(Version.CURRENT)
.build(),
false,
SpaceType.L2
)
);

expectThrows(
ValidationException.class,
() -> TEST_RESOLVER.resolveMethod(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,24 @@
package org.opensearch.knn.index.engine.faiss;

import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.engine.MethodComponentContext;
import org.opensearch.knn.index.mapper.CompressionLevel;

import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.ENCODER_SQ;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_FP16;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_ENCODER_INT8;
import static org.opensearch.knn.common.KNNConstants.FAISS_SQ_TYPE;

/**
 * Unit tests for {@code FaissSQEncoder#calculateCompressionLevel}, verifying the
 * compression level reported for the Faiss scalar-quantization (SQ) encoder.
 */
public class FaissSQEncoderTests extends KNNTestCase {
// Default (no method context) and explicit fp16 SQ type both resolve to 2x compression;
// the int8 SQ type resolves to 4x compression.
public void testCalculateCompressionLevel() {
FaissSQEncoder encoder = new FaissSQEncoder();
assertEquals(CompressionLevel.x2, encoder.calculateCompressionLevel(null, null));
assertEquals(CompressionLevel.x2, encoder.calculateCompressionLevel(generateMethodComponentContext(FAISS_SQ_ENCODER_FP16), null));
assertEquals(CompressionLevel.x4, encoder.calculateCompressionLevel(generateMethodComponentContext(FAISS_SQ_ENCODER_INT8), null));
}

// Builds a MethodComponentContext for the SQ encoder carrying the given SQ "type" parameter
// (e.g. fp16 or int8), mirroring what mapping parsing would produce.
private MethodComponentContext generateMethodComponentContext(String sqType) {
return new MethodComponentContext(ENCODER_SQ, Map.of(FAISS_SQ_TYPE, sqType));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,11 @@ public void testGetDefaultRescoreContext() {
assertNotNull(rescoreContext);
assertEquals(2.0f, rescoreContext.getOversampleFactor(), 0.0f);

// x4 with dimension <= 1000 should have an oversample factor of 5.0f (though it doesn't have its own RescoreContext)
rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, belowThresholdDimension, Version.CURRENT);
// x4 with dimension <= 1000 should have an oversample factor of 5.0f (though it doesn't have its own RescoreContext before V2.19.0)
rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, belowThresholdDimension, Version.V_2_18_1);
assertNull(rescoreContext);
// x4 with dimension > 1000 should return null (no RescoreContext is configured for x4)
rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, aboveThresholdDimension, Version.CURRENT);
// x4 with dimension > 1000 should return null (no RescoreContext is configured for x4 before V2.19.0)
rescoreContext = CompressionLevel.x4.getDefaultRescoreContext(mode, aboveThresholdDimension, Version.V_2_18_1);
assertNull(rescoreContext);
// Other compression levels should behave similarly with respect to dimension
rescoreContext = CompressionLevel.x2.getDefaultRescoreContext(mode, belowThresholdDimension, Version.CURRENT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1831,7 +1831,7 @@ public void testTypeParser_whenModeAndCompressionAreSet_thenHandle() throws IOEx
true
);

// For 4x compression on disk, use Lucene
// For 4x compression on disk, use Faiss
xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
Expand All @@ -1847,7 +1847,7 @@ public void testTypeParser_whenModeAndCompressionAreSet_thenHandle() throws IOEx
);
validateBuilderAfterParsing(
builder,
KNNEngine.LUCENE,
KNNEngine.FAISS,
SpaceType.L2,
VectorDataType.FLOAT,
CompressionLevel.x4,
Expand All @@ -1856,7 +1856,7 @@ public void testTypeParser_whenModeAndCompressionAreSet_thenHandle() throws IOEx
false
);

// For 4x compression in memory, use Lucene
// For 4x compression in memory, use Faiss
xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
Expand All @@ -1872,7 +1872,7 @@ public void testTypeParser_whenModeAndCompressionAreSet_thenHandle() throws IOEx
);
validateBuilderAfterParsing(
builder,
KNNEngine.LUCENE,
KNNEngine.FAISS,
SpaceType.L2,
VectorDataType.FLOAT,
CompressionLevel.x4,
Expand Down Expand Up @@ -1971,7 +1971,7 @@ public void testTypeParser_whenModeAndCompressionAreSet_thenHandle() throws IOEx
.field(COMPRESSION_LEVEL_PARAMETER, CompressionLevel.x4.getName())
.startObject(KNN_METHOD)
.field(NAME, METHOD_HNSW)
.field(KNN_ENGINE, KNNEngine.FAISS)
.field(KNN_ENGINE, KNNEngine.NMSLIB)
.endObject()
.endObject();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

package org.opensearch.knn.index.quantizationservice;

import org.apache.lucene.index.FieldInfo;
import org.opensearch.knn.KNNTestCase;
import org.junit.Before;

import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.codec.KNNCodecTestUtil;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
import org.opensearch.knn.index.vectorvalues.TestVectorValues;
Expand All @@ -24,6 +27,7 @@
public class QuantizationServiceTests extends KNNTestCase {
private QuantizationService<float[], byte[]> quantizationService;
private KNNVectorValues<float[]> knnVectorValues;
private FieldInfo fieldInfo;

@Before
public void setUp() throws Exception {
Expand All @@ -42,11 +46,19 @@ public void setUp() throws Exception {
VectorDataType.FLOAT,
new TestVectorValues.PreDefinedFloatVectorValues(floatVectors)
);

fieldInfo = KNNCodecTestUtil.FieldInfoBuilder.builder("test-field").addAttribute(KNNVectorFieldMapper.KNN_FIELD, "true").build();
}

public void testTrain_oneBitQuantizer_success() throws IOException {
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());

QuantizationState quantizationState = quantizationService.train(
oneBitParams,
knnVectorValues,
knnVectorValues.totalLiveDocs(),
fieldInfo
);

assertTrue(quantizationState instanceof OneBitScalarQuantizationState);
OneBitScalarQuantizationState oneBitState = (OneBitScalarQuantizationState) quantizationState;
Expand All @@ -62,7 +74,12 @@ public void testTrain_oneBitQuantizer_success() throws IOException {

public void testTrain_twoBitQuantizer_success() throws IOException {
ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
QuantizationState quantizationState = quantizationService.train(
twoBitParams,
knnVectorValues,
knnVectorValues.totalLiveDocs(),
fieldInfo
);

assertTrue(quantizationState instanceof MultiBitScalarQuantizationState);
MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState;
Expand All @@ -85,7 +102,12 @@ public void testTrain_twoBitQuantizer_success() throws IOException {

public void testTrain_fourBitQuantizer_success() throws IOException {
ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
QuantizationState quantizationState = quantizationService.train(
fourBitParams,
knnVectorValues,
knnVectorValues.totalLiveDocs(),
fieldInfo
);

assertTrue(quantizationState instanceof MultiBitScalarQuantizationState);
MultiBitScalarQuantizationState multiBitState = (MultiBitScalarQuantizationState) quantizationState;
Expand All @@ -110,7 +132,12 @@ public void testTrain_fourBitQuantizer_success() throws IOException {

public void testQuantize_oneBitQuantizer_success() throws IOException {
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
QuantizationState quantizationState = quantizationService.train(
oneBitParams,
knnVectorValues,
knnVectorValues.totalLiveDocs(),
fieldInfo
);

QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams);

Expand All @@ -125,7 +152,12 @@ public void testQuantize_oneBitQuantizer_success() throws IOException {

public void testQuantize_twoBitQuantizer_success() throws IOException {
ScalarQuantizationParams twoBitParams = new ScalarQuantizationParams(ScalarQuantizationType.TWO_BIT);
QuantizationState quantizationState = quantizationService.train(twoBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
QuantizationState quantizationState = quantizationService.train(
twoBitParams,
knnVectorValues,
knnVectorValues.totalLiveDocs(),
fieldInfo
);
QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(twoBitParams);
byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 4.0f, 5.0f, 6.0f }, quantizationOutput);

Expand All @@ -138,7 +170,12 @@ public void testQuantize_twoBitQuantizer_success() throws IOException {

public void testQuantize_fourBitQuantizer_success() throws IOException {
ScalarQuantizationParams fourBitParams = new ScalarQuantizationParams(ScalarQuantizationType.FOUR_BIT);
QuantizationState quantizationState = quantizationService.train(fourBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
QuantizationState quantizationState = quantizationService.train(
fourBitParams,
knnVectorValues,
knnVectorValues.totalLiveDocs(),
fieldInfo
);
QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(fourBitParams);

byte[] quantizedVector = quantizationService.quantize(quantizationState, new float[] { 7.0f, 8.0f, 9.0f }, quantizationOutput);
Expand All @@ -152,7 +189,12 @@ public void testQuantize_fourBitQuantizer_success() throws IOException {

public void testQuantize_whenInvalidInput_thenThrows() throws IOException {
ScalarQuantizationParams oneBitParams = new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT);
QuantizationState quantizationState = quantizationService.train(oneBitParams, knnVectorValues, knnVectorValues.totalLiveDocs());
QuantizationState quantizationState = quantizationService.train(
oneBitParams,
knnVectorValues,
knnVectorValues.totalLiveDocs(),
fieldInfo
);
QuantizationOutput quantizationOutput = quantizationService.createQuantizationOutput(oneBitParams);
assertThrows(IllegalArgumentException.class, () -> quantizationService.quantize(quantizationState, null, quantizationOutput));
}
Expand Down
Loading

0 comments on commit 6721604

Please sign in to comment.