-
Notifications
You must be signed in to change notification settings - Fork 144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implementation of Random Rotation Matrix and Query Scaling Matrix with 1-bit Binary Quantization #2524
base: main
Are you sure you want to change the base?
Implementation of Random Rotation Matrix and Query Scaling Matrix with 1-bit Binary Quantization #2524
Changes from all commits
72ce48f
68ac180
01219ae
b1afdef
4465832
f1ff640
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,13 +5,16 @@ | |
|
||
package org.opensearch.knn.quantization.models.quantizationState; | ||
|
||
import lombok.AllArgsConstructor; | ||
import lombok.Getter; | ||
import lombok.NoArgsConstructor; | ||
import lombok.NonNull; | ||
import lombok.AllArgsConstructor; | ||
import lombok.Builder; | ||
import org.apache.lucene.util.RamUsageEstimator; | ||
import org.opensearch.Version; | ||
import org.opensearch.core.common.io.stream.StreamInput; | ||
import org.opensearch.core.common.io.stream.StreamOutput; | ||
import org.opensearch.core.common.io.stream.Writeable; | ||
import org.opensearch.knn.quantization.models.quantizationParams.ScalarQuantizationParams; | ||
|
||
import java.io.IOException; | ||
|
@@ -21,10 +24,12 @@ | |
* including the mean values used for quantization. | ||
*/ | ||
@Getter | ||
@NoArgsConstructor // No-argument constructor for deserialization | ||
@Builder | ||
@AllArgsConstructor | ||
@NoArgsConstructor(force = true) | ||
public final class OneBitScalarQuantizationState implements QuantizationState { | ||
private ScalarQuantizationParams quantizationParams; | ||
@NonNull | ||
private final ScalarQuantizationParams quantizationParams; | ||
/** | ||
* Mean thresholds used in the quantization process. | ||
* Each threshold value corresponds to a dimension of the vector being quantized. | ||
|
@@ -33,7 +38,27 @@ public final class OneBitScalarQuantizationState implements QuantizationState { | |
* If we have a vector [1.2, 3.4, 5.6] and mean thresholds [2.0, 3.0, 4.0], | ||
* The quantized vector will be [0, 1, 1]. | ||
*/ | ||
private float[] meanThresholds; | ||
@NonNull | ||
private final float[] meanThresholds; | ||
|
||
/** | ||
* Represents the mean of all values below the threshold for each dimension. | ||
*/ | ||
@Builder.Default | ||
private float[] belowThresholdMeans = null; | ||
|
||
/** | ||
* Represents the mean of all values above the threshold for each dimension. | ||
*/ | ||
@Builder.Default | ||
private float[] aboveThresholdMeans = null; | ||
@Builder.Default | ||
private double averageL2L1Ratio = 0.0; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add a comment on this? |
||
/** | ||
* Rotation matrix used when L2/L1 ratio > 0.6 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can you add more detail? For instance, what are the properties of this matrix? Also, what happens if the L2/L1 ratio is < 0.6? |
||
*/ | ||
@Builder.Default | ||
private float[][] rotationMatrix = null; | ||
|
||
@Override | ||
public ScalarQuantizationParams getQuantizationParams() { | ||
|
@@ -51,6 +76,20 @@ public void writeTo(StreamOutput out) throws IOException { | |
out.writeVInt(Version.CURRENT.id); // Write the version | ||
quantizationParams.writeTo(out); | ||
out.writeFloatArray(meanThresholds); | ||
out.writeOptionalArray(belowThresholdMeans != null ? new FloatArrayWrapper[] { new FloatArrayWrapper(belowThresholdMeans) } : null); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. For BWC (backward-compatibility) purposes, writeOptionalArray won't work in a mixed-cluster state. We need to check the stream version. |
||
// Serialize aboveThresholdMeans using writeOptionalArray | ||
out.writeOptionalArray(aboveThresholdMeans != null ? new FloatArrayWrapper[] { new FloatArrayWrapper(aboveThresholdMeans) } : null); | ||
out.writeOptionalDouble(averageL2L1Ratio); | ||
// Write rotation matrix | ||
if (rotationMatrix != null) { | ||
out.writeBoolean(true); | ||
out.writeVInt(rotationMatrix.length); | ||
for (float[] row : rotationMatrix) { | ||
out.writeFloatArray(row); | ||
} | ||
} else { | ||
out.writeBoolean(false); | ||
} | ||
} | ||
|
||
/** | ||
|
@@ -63,6 +102,23 @@ public OneBitScalarQuantizationState(StreamInput in) throws IOException { | |
int version = in.readVInt(); // Read the version | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is this going to be BWC (backward compatible)? We need to check whether the segment was created in a previous version, as opposed to checking the version of the node it is coming from. |
||
this.quantizationParams = new ScalarQuantizationParams(in, version); | ||
this.meanThresholds = in.readFloatArray(); | ||
if (Version.fromId(version).onOrAfter(Version.V_3_0_0)) { | ||
// Deserialize belowThresholdMeans using readOptionalArray | ||
FloatArrayWrapper[] wrappedBelowThresholdMeans = in.readOptionalArray(FloatArrayWrapper::new, FloatArrayWrapper[]::new); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is this optional? Is there a case where it's null? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Yes, for versions earlier than 3.0.0. |
||
this.belowThresholdMeans = wrappedBelowThresholdMeans != null ? wrappedBelowThresholdMeans[0].getArray() : null; | ||
// Deserialize aboveThresholdMeans using readOptionalArray | ||
FloatArrayWrapper[] wrappedAboveThresholdMeans = in.readOptionalArray(FloatArrayWrapper::new, FloatArrayWrapper[]::new); | ||
this.aboveThresholdMeans = wrappedAboveThresholdMeans != null ? wrappedAboveThresholdMeans[0].getArray() : null; | ||
this.averageL2L1Ratio = in.readOptionalDouble(); | ||
// Read rotation matrix | ||
if (in.readBoolean()) { | ||
int dimensions = in.readVInt(); | ||
this.rotationMatrix = new float[dimensions][]; | ||
for (int i = 0; i < dimensions; i++) { | ||
this.rotationMatrix[i] = in.readFloatArray(); | ||
} | ||
} | ||
} | ||
} | ||
|
||
/** | ||
|
@@ -139,6 +195,41 @@ public long ramBytesUsed() { | |
long size = RamUsageEstimator.shallowSizeOfInstance(OneBitScalarQuantizationState.class); | ||
size += RamUsageEstimator.shallowSizeOf(quantizationParams); | ||
size += RamUsageEstimator.sizeOf(meanThresholds); | ||
if (belowThresholdMeans != null) { | ||
size += RamUsageEstimator.sizeOf(belowThresholdMeans); | ||
} | ||
if (aboveThresholdMeans != null) { | ||
size += RamUsageEstimator.sizeOf(aboveThresholdMeans); | ||
} | ||
if (rotationMatrix != null) { | ||
size += RamUsageEstimator.shallowSizeOf(rotationMatrix); | ||
// Add size of each row array | ||
for (float[] row : rotationMatrix) { | ||
size += RamUsageEstimator.sizeOf(row); | ||
} | ||
} | ||
return size; | ||
} | ||
|
||
/**
 * Minimal {@link Writeable} holder for a single {@code float[]}.
 *
 * The serialization code in this file wraps the optional threshold-mean arrays in a
 * one-element array of this type so they can be passed to {@code writeOptionalArray}
 * and read back with {@code readOptionalArray}.
 *
 * NOTE(review): declared as a non-static inner class although nothing visible here reads
 * enclosing-instance state; making it {@code static} looks safe, but confirm before changing.
 */
private class FloatArrayWrapper implements Writeable { | ||
// The wrapped array; stored by reference, no defensive copy is made.
private final float[] array; | ||
|
||
/**
 * Wraps an existing array.
 *
 * @param array the array to hold; kept by reference
 */
public FloatArrayWrapper(float[] array) { | ||
this.array = array; | ||
} | ||
|
||
// Constructor that matches Writeable.Reader<T> | ||
/**
 * Reads the wrapped array from the stream with {@code readFloatArray}.
 *
 * @param in stream to read from
 * @throws IOException if reading from the stream fails
 */
public FloatArrayWrapper(StreamInput in) throws IOException { | ||
this.array = in.readFloatArray(); | ||
} | ||
|
||
/**
 * Returns the wrapped array itself (not a copy).
 */
public float[] getArray() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: lombok getter |
||
return array; | ||
} | ||
|
||
/**
 * Writes the wrapped array to the stream with {@code writeFloatArray}, mirroring the
 * stream-reading constructor.
 *
 * @param out stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override | ||
public void writeTo(StreamOutput out) throws IOException { | ||
out.writeFloatArray(array); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you add examples for these, as was done for meanThresholds?