Skip to content

Commit

Permalink
Quantization using Quantile method
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Jan 10, 2025
1 parent cdc469a commit 4554912
Showing 1 changed file with 51 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import oshi.util.tuples.Pair;

import java.io.IOException;
import java.util.Arrays;

public class ByteScalarQuantizer implements Quantizer<float[], byte[]> {
private final int bitsPerCoordinate;
Expand Down Expand Up @@ -47,11 +48,21 @@ public ByteScalarQuantizer(final int bitsPerCoordinate) {
// return new ByteScalarQuantizationState(params, minAndMax.getA(), diff);
// }

// Train using mean standard deviation
// @Override
// public QuantizationState train(TrainingRequest<float[]> trainingRequest) throws IOException {
// int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize);
// Pair<float[], float[]> meanAndStdDev = QuantizerHelper.calculateMeanAndStdDev(trainingRequest, sampledIndices);
// Pair<float[], float[]> minAndDiff = calculateMinAndDiff(meanAndStdDev.getA(), meanAndStdDev.getB());
// ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.EIGHT_BIT);
// return new ByteScalarQuantizationState(params, minAndDiff.getA(), minAndDiff.getB());
// }

@Override
public QuantizationState train(TrainingRequest<float[]> trainingRequest) throws IOException {
int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize);
Pair<float[], float[]> meanAndStdDev = QuantizerHelper.calculateMeanAndStdDev(trainingRequest, sampledIndices);
Pair<float[], float[]> minAndDiff = calculateMinAndDiff(meanAndStdDev.getA(), meanAndStdDev.getB());
float[][] transposedVec = transposeVectors(trainingRequest, sampledIndices);
Pair<float[], float[]> minAndDiff = calculateMinAndDiffUsingQuantile(transposedVec);
ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.EIGHT_BIT);
return new ByteScalarQuantizationState(params, minAndDiff.getA(), minAndDiff.getB());
}
Expand Down Expand Up @@ -113,25 +124,44 @@ private Pair<float[], float[]> calculateMinAndDiff(final float[] meanArray, fina
return new Pair<>(minArray, diffArray);
}

// private float[][] transposeVectors(TrainingRequest<float[]> trainingRequest, int[] sampledIndices) throws IOException {
// int totalSamples = sampledIndices.length;
// if (totalSamples > 0) {
// float[] vector = trainingRequest.getVectorAtThePosition(sampledIndices[0]);
// if (vector == null) {
// throw new IllegalArgumentException("Vector at sampled index " + sampledIndices[0] + " is null.");
// }
// }
// int dimension = trainingRequest.getVectorAtThePosition(sampledIndices[0]).length;
// float[][] transposedVec = new float[dimension][totalSamples];
//
// for(int i=0; i<totalSamples; i++) {
// float[] vector = trainingRequest.getVectorAtThePosition(sampledIndices[0]);
// for(int j=0; j<dimension;j++) {
// transposedVec[j][i] = vector[j];
// }
// }
// return transposedVec;
// }
private Pair<float[], float[]> calculateMinAndDiffUsingQuantile(float[][] transposedVectors) {
int dimension = transposedVectors.length;
float rs_arg = 2f;
float[] diffArray = new float[dimension];
float[] minArray = new float[dimension];
int n = transposedVectors[0].length;

for (int i = 0; i < dimension; i++) {
float[] vec = transposedVectors[i];
Arrays.sort(vec);
int o = (int) (rs_arg * n);
if (o < 0) o = 0;
if (o > n - o) o = n / 2;
minArray[i] = vec[o];
diffArray[i] = vec[n - 1 - o] - minArray[i];
}
return new Pair<>(minArray, diffArray);
}

private float[][] transposeVectors(TrainingRequest<float[]> trainingRequest, int[] sampledIndices) throws IOException {
int totalSamples = sampledIndices.length;
if (totalSamples > 0) {
float[] vector = trainingRequest.getVectorAtThePosition(sampledIndices[0]);
if (vector == null) {
throw new IllegalArgumentException("Vector at sampled index " + sampledIndices[0] + " is null.");
}
}
int dimension = trainingRequest.getVectorAtThePosition(sampledIndices[0]).length;
float[][] transposedVec = new float[dimension][totalSamples];

for (int i = 0; i < totalSamples; i++) {
float[] vector = trainingRequest.getVectorAtThePosition(sampledIndices[0]);
for (int j = 0; j < dimension; j++) {
transposedVec[j][i] = vector[j];
}
}
return transposedVec;
}

private void validateState(final QuantizationState state) {
if (!(state instanceof ByteScalarQuantizationState)) {
Expand Down

0 comments on commit 4554912

Please sign in to comment.