Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of Random Rotation Matrix and Query Scaling Matrix with 1-bit Binary Quantization #2524

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Enabled indices.breaker.total.use_real_memory setting via build.gradle for integTest Cluster to catch heap CB in local ITs and github CI actions [#2395](https://github.com/opensearch-project/k-NN/pull/2395/)
* Fixing Lucene912Codec Issue with BWC for Lucene 10.0.1 upgrade[#2429](https://github.com/opensearch-project/k-NN/pull/2429)
* Enabled idempotency of local builds when using `./gradlew clean` and nest `jni/release` directory under `jni/build` for easier cleanup [#2516](https://github.com/opensearch-project/k-NN/pull/2516)
### Refactoring
### Refactoring
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -595,4 +595,4 @@ task updateVersion {
ant.replaceregexp(file:".github/workflows/backwards_compatibility_tests_workflow.yml", match: oldBWCVersion, replace: oldBWCVersion + '", "' + opensearch_version.tokenize('-')[0], flags:'g', byline:true)
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,7 @@ private QuantizationState train(
QuantizationState quantizationState = null;
if (quantizationParams != null && totalLiveDocs > 0) {
initQuantizationStateWriterIfNecessary();
KNNVectorValues<?> knnVectorValues = knnVectorValuesSupplier.get();
quantizationState = quantizationService.train(quantizationParams, knnVectorValues, totalLiveDocs);
quantizationState = quantizationService.train(quantizationParams, knnVectorValuesSupplier, totalLiveDocs);
quantizationStateWriter.writeState(fieldInfo.getFieldNumber(), quantizationState);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,27 @@
import java.io.IOException;

import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;
import java.util.function.Supplier;

/**
* KNNVectorQuantizationTrainingRequest is a concrete implementation of the abstract TrainingRequest class.
* It provides a mechanism to retrieve float vectors from the KNNVectorValues by document ID.
*/
@Log4j2
final class KNNVectorQuantizationTrainingRequest<T> extends TrainingRequest<T> {

private final KNNVectorValues<T> knnVectorValues;
private final Supplier<KNNVectorValues<T>> knnVectorValuesSupplier;
private KNNVectorValues<T> knnVectorValues;
private int lastIndex;

/**
 * Constructs a new KNNVectorQuantizationTrainingRequest.
 *
 * @param knnVectorValuesSupplier supplier that produces a fresh {@link KNNVectorValues}
 *                                iterator over the training vectors; it is invoked once here
 *                                and again on every {@link #resetVectorValues()} call.
 * @param liveDocs                total number of live documents (vectors) available for training.
 */
KNNVectorQuantizationTrainingRequest(Supplier<KNNVectorValues<T>> knnVectorValuesSupplier, long liveDocs) {
    super((int) liveDocs);
    this.knnVectorValuesSupplier = knnVectorValuesSupplier;
    // Obtain the first KNNVectorValues instance; this also rewinds lastIndex to 0,
    // so no separate lastIndex assignment is needed here.
    resetVectorValues();
}

Expand All @@ -52,4 +54,13 @@ public T getVectorAtThePosition(int position) throws IOException {
// Return the vector
return knnVectorValues.getVector();
}

/**
 * Resets the KNNVectorValues to enable a fresh iteration by calling the supplier again,
 * and rewinds the position tracker to the start of the vector stream.
 */
@Override
public void resetVectorValues() {
    this.knnVectorValues = knnVectorValuesSupplier.get(); // fresh iterator — presumably KNNVectorValues is single-pass; confirm
    this.lastIndex = 0; // next getVectorAtThePosition() call starts from the beginning
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.knn.quantization.models.quantizationState.QuantizationState;
import org.opensearch.knn.quantization.quantizer.Quantizer;
import java.io.IOException;
import java.util.function.Supplier;

import static org.opensearch.knn.common.FieldInfoExtractor.extractQuantizationConfig;

Expand Down Expand Up @@ -53,19 +54,22 @@ public static <T, R> QuantizationService<T, R> getInstance() {
* {@link QuantizationState}. The quantizer is determined based on the given {@link QuantizationParams}.
*
* @param quantizationParams The {@link QuantizationParams} containing the parameters for quantization.
* @param knnVectorValues The {@link KNNVectorValues} representing the vector data to be used for training.
* @param knnVectorValuesSupplier Supplier that produces a fresh {@link KNNVectorValues} over the vector data to be used for training.
* @return The {@link QuantizationState} containing the state of the trained quantizer.
* @throws IOException If an I/O error occurs during the training process.
*/
public QuantizationState train(
final QuantizationParams quantizationParams,
final KNNVectorValues<T> knnVectorValues,
final Supplier<KNNVectorValues<T>> knnVectorValuesSupplier,
final long liveDocs
) throws IOException {
Quantizer<T, R> quantizer = QuantizerFactory.getQuantizer(quantizationParams);

// Create the training request from the vector values
KNNVectorQuantizationTrainingRequest<T> trainingRequest = new KNNVectorQuantizationTrainingRequest<>(knnVectorValues, liveDocs);
// Create the training request using the supplier
KNNVectorQuantizationTrainingRequest<T> trainingRequest = new KNNVectorQuantizationTrainingRequest<>(
knnVectorValuesSupplier,
liveDocs
);

// Train the quantizer and return the quantization state
return quantizer.train(trainingRequest);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@

package org.opensearch.knn.quantization.models.quantizationState;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.NonNull;
import lombok.AllArgsConstructor;
import lombok.Builder;
import org.apache.lucene.util.RamUsageEstimator;
import org.opensearch.Version;
import org.opensearch.core.common.io.stream.StreamInput;
import org.opensearch.core.common.io.stream.StreamOutput;
import org.opensearch.core.common.io.stream.Writeable;
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams;

import java.io.IOException;
Expand All @@ -21,10 +24,12 @@
* including the mean values used for quantization.
*/
@Getter
@NoArgsConstructor // No-argument constructor for deserialization
@Builder
@AllArgsConstructor
@NoArgsConstructor(force = true)
public final class OneBitScalarQuantizationState implements QuantizationState {
private ScalarQuantizationParams quantizationParams;
@NonNull
private final ScalarQuantizationParams quantizationParams;
/**
* Mean thresholds used in the quantization process.
* Each threshold value corresponds to a dimension of the vector being quantized.
Expand All @@ -33,7 +38,27 @@ public final class OneBitScalarQuantizationState implements QuantizationState {
* If we have a vector [1.2, 3.4, 5.6] and mean thresholds [2.0, 3.0, 4.0],
* The quantized vector will be [0, 1, 1].
*/
private float[] meanThresholds;
@NonNull
private final float[] meanThresholds;

/**
* Represents the mean of all values below the threshold for each dimension.
*/
@Builder.Default
private float[] belowThresholdMeans = null;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add examples for these, as was done for meanThresholds?


/**
* Represents the mean of all values above the threshold for each dimension.
*/
@Builder.Default
private float[] aboveThresholdMeans = null;
@Builder.Default
private double averageL2L1Ratio = 0.0;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a comment on this?

/**
* Rotation matrix used when L2/L1 ratio > 0.6
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add more detail? For instance, what are the properties of this matrix. Also, what will happen if L2/L1 ratio is < 0.6

*/
@Builder.Default
private float[][] rotationMatrix = null;

@Override
public ScalarQuantizationParams getQuantizationParams() {
Expand All @@ -51,6 +76,20 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(Version.CURRENT.id); // Write the version
quantizationParams.writeTo(out);
out.writeFloatArray(meanThresholds);
out.writeOptionalArray(belowThresholdMeans != null ? new FloatArrayWrapper[] { new FloatArrayWrapper(belowThresholdMeans) } : null);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For BWC purposes, writeOptionalArray wont work in a mixed cluster state. We need to check the stream version.

// Serialize aboveThresholdMeans using writeOptionalArray
out.writeOptionalArray(aboveThresholdMeans != null ? new FloatArrayWrapper[] { new FloatArrayWrapper(aboveThresholdMeans) } : null);
out.writeOptionalDouble(averageL2L1Ratio);
// Write rotation matrix
if (rotationMatrix != null) {
out.writeBoolean(true);
out.writeVInt(rotationMatrix.length);
for (float[] row : rotationMatrix) {
out.writeFloatArray(row);
}
} else {
out.writeBoolean(false);
}
}

/**
/**
 * Deserializes a OneBitScalarQuantizationState from the given stream.
 * Fields introduced in 3.0.0 (below/above threshold means, average L2/L1 ratio,
 * rotation matrix) are read only when the version embedded in the stream is >= 3.0.0.
 *
 * @param in the stream to read from.
 * @throws IOException if reading from the stream fails.
 */
public OneBitScalarQuantizationState(StreamInput in) throws IOException {
    int version = in.readVInt(); // Version embedded by writeTo
    // NOTE(review): BWC — this gates on the version embedded at write time; if this state
    // was persisted in a segment created by an older version, confirm that the segment's
    // creation version (not just the sending node's version) is what ends up here.
    this.quantizationParams = new ScalarQuantizationParams(in, version);
    this.meanThresholds = in.readFloatArray();
    if (Version.fromId(version).onOrAfter(Version.V_3_0_0)) {
        // Optional arrays: null for states written without these statistics.
        FloatArrayWrapper[] wrappedBelowThresholdMeans = in.readOptionalArray(FloatArrayWrapper::new, FloatArrayWrapper[]::new);
        this.belowThresholdMeans = wrappedBelowThresholdMeans != null ? wrappedBelowThresholdMeans[0].getArray() : null;
        FloatArrayWrapper[] wrappedAboveThresholdMeans = in.readOptionalArray(FloatArrayWrapper::new, FloatArrayWrapper[]::new);
        this.aboveThresholdMeans = wrappedAboveThresholdMeans != null ? wrappedAboveThresholdMeans[0].getArray() : null;
        // Defensive unbox: readOptionalDouble returns a boxed Double that may be null;
        // a direct assignment to the primitive field would throw NPE.
        Double ratio = in.readOptionalDouble();
        this.averageL2L1Ratio = ratio != null ? ratio : 0.0;
        // Rotation matrix: presence flag, then row count, then one float[] per row.
        if (in.readBoolean()) {
            int dimensions = in.readVInt();
            this.rotationMatrix = new float[dimensions][];
            for (int i = 0; i < dimensions; i++) {
                this.rotationMatrix[i] = in.readFloatArray();
            }
        }
    }
}

/**
Expand Down Expand Up @@ -139,6 +195,41 @@ public long ramBytesUsed() {
long size = RamUsageEstimator.shallowSizeOfInstance(OneBitScalarQuantizationState.class);
size += RamUsageEstimator.shallowSizeOf(quantizationParams);
size += RamUsageEstimator.sizeOf(meanThresholds);
if (belowThresholdMeans != null) {
size += RamUsageEstimator.sizeOf(belowThresholdMeans);
}
if (aboveThresholdMeans != null) {
size += RamUsageEstimator.sizeOf(aboveThresholdMeans);
}
if (rotationMatrix != null) {
size += RamUsageEstimator.shallowSizeOf(rotationMatrix);
// Add size of each row array
for (float[] row : rotationMatrix) {
size += RamUsageEstimator.sizeOf(row);
}
}
return size;
}

private class FloatArrayWrapper implements Writeable {
private final float[] array;

public FloatArrayWrapper(float[] array) {
this.array = array;
}

// Constructor that matches Writeable.Reader<T>
public FloatArrayWrapper(StreamInput in) throws IOException {
this.array = in.readFloatArray();
}

public float[] getArray() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: lombok getter

return array;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeFloatArray(array);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,6 @@ public abstract class TrainingRequest<T> {
* @return the vector corresponding to the specified document ID.
*/
public abstract T getVectorAtThePosition(int position) throws IOException;

public abstract void resetVectorValues();
}
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,11 @@ public OneBitScalarQuantizer(final int samplingSize, final Sampler sampler) {
/**
 * Trains the one-bit quantizer: samples up to {@code samplingSize} vectors from the
 * training request and delegates threshold/state computation to {@link QuantizerHelper}.
 *
 * @param trainingRequest provides access to the candidate vectors and their total count.
 * @return the trained {@link QuantizationState} for ONE_BIT scalar quantization.
 * @throws IOException if reading vectors from the training request fails.
 */
@Override
public QuantizationState train(final TrainingRequest<float[]> trainingRequest) throws IOException {
    int[] sampledDocIds = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize);
    return QuantizerHelper.calculateQuantizationState(
        trainingRequest,
        sampledDocIds,
        new ScalarQuantizationParams(ScalarQuantizationType.ONE_BIT)
    );
}

/**
Expand All @@ -73,7 +76,7 @@ public QuantizationState train(final TrainingRequest<float[]> trainingRequest) t
* @param output the QuantizationOutput object to store the quantized representation of the vector.
*/
@Override
public void quantize(final float[] vector, final QuantizationState state, final QuantizationOutput<byte[]> output) {
public void quantize(float[] vector, final QuantizationState state, final QuantizationOutput<byte[]> output) {
if (vector == null) {
throw new IllegalArgumentException("Vector to quantize must not be null.");
}
Expand All @@ -84,6 +87,10 @@ public void quantize(final float[] vector, final QuantizationState state, final
if (thresholds == null || thresholds.length != vectorLength) {
throw new IllegalArgumentException("Thresholds must not be null and must match the dimension of the vector.");
}
float[][] rotationMatrix = binaryState.getRotationMatrix();
if (rotationMatrix != null) {
vector = RandomGaussianRotation.applyRotation(vector, rotationMatrix);
}
output.prepareQuantizedVector(vectorLength);
BitPacker.quantizeAndPackBits(vector, thresholds, output.getQuantizedVector());
}
Expand Down
Loading
Loading