Skip to content

Commit

Permalink
Quantization using Mean and Standard Deviation
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Jan 3, 2025
1 parent 234030f commit cdc469a
Showing 1 changed file with 47 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,23 @@ public ByteScalarQuantizer(final int bitsPerCoordinate) {
this.sampler = SamplingFactory.getSampler(SamplerType.RESERVOIR);
}

// Previous training approach based on the min-max technique (commented out, kept for reference)
// @Override
// public QuantizationState train(TrainingRequest<float[]> trainingRequest) throws IOException {
// int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize);
// Pair<float[], float[]> minAndMax = QuantizerHelper.calculateMinAndMax(trainingRequest, sampledIndices);
// float[] diff = calculateDiff(minAndMax.getA(), minAndMax.getB());
// ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.EIGHT_BIT);
// return new ByteScalarQuantizationState(params, minAndMax.getA(), diff);
// }

/**
 * Trains the quantizer from a sample of the request's vectors using mean and standard
 * deviation statistics.
 *
 * <p>Steps: sample up to {@code samplingSize} vector positions, obtain the mean and standard
 * deviation arrays for the sampled vectors from {@link QuantizerHelper}, convert them into a
 * (min, diff) pair via {@code calculateMinAndDiff}, and wrap the result in an 8-bit
 * {@link ByteScalarQuantizationState}.
 *
 * @param trainingRequest source of the vectors to sample from
 * @return a {@link ByteScalarQuantizationState} holding the derived min and diff arrays
 * @throws IOException declared by the signature; presumably raised while reading vectors
 *     during sampling or statistics computation (confirm against QuantizerHelper)
 */
@Override
public QuantizationState train(TrainingRequest<float[]> trainingRequest) throws IOException {
    int[] sampledIndices = sampler.sample(trainingRequest.getTotalNumberOfVectors(), samplingSize);
    // Only the mean/std-dev statistics are needed; the earlier min-max computation and its
    // return statement were removed because they made the intended return unreachable.
    Pair<float[], float[]> meanAndStdDev = QuantizerHelper.calculateMeanAndStdDev(trainingRequest, sampledIndices);
    Pair<float[], float[]> minAndDiff = calculateMinAndDiff(meanAndStdDev.getA(), meanAndStdDev.getB());
    ScalarQuantizationParams params = new ScalarQuantizationParams(ScalarQuantizationType.EIGHT_BIT);
    return new ByteScalarQuantizationState(params, minAndDiff.getA(), minAndDiff.getB());
}

@Override
Expand Down Expand Up @@ -89,6 +99,40 @@ private float[] calculateDiff(final float[] minArray, final float[] maxArray) {
return diffArray;
}

/**
 * Derives a lower bound and a range width for every index from the given mean and standard
 * deviation arrays.
 *
 * <p>For each index {@code i}: {@code min = mean - std * k} and {@code diff = (mean + std * k) - min},
 * where {@code k} is a fixed multiplier of {@code 1.0f}.
 *
 * @param meanArray mean values, one per index; its length determines the output length
 * @param stdArray standard deviation values, read at the same indices as {@code meanArray}
 * @return a pair whose first element is the min array and whose second is the diff array
 */
private Pair<float[], float[]> calculateMinAndDiff(final float[] meanArray, final float[] stdArray) {
    final float stdMultiplier = 1.0f;
    final int length = meanArray.length;
    final float[] lowerBounds = new float[length];
    final float[] widths = new float[length];

    int idx = 0;
    while (idx < length) {
        final float spread = stdArray[idx] * stdMultiplier;
        final float lower = meanArray[idx] - spread;
        final float upper = meanArray[idx] + spread;
        lowerBounds[idx] = lower;
        // Width is taken as upper minus lower (not 2 * spread) to keep the float result unchanged.
        widths[idx] = upper - lower;
        idx++;
    }
    return new Pair<>(lowerBounds, widths);
}

// private float[][] transposeVectors(TrainingRequest<float[]> trainingRequest, int[] sampledIndices) throws IOException {
// int totalSamples = sampledIndices.length;
// if (totalSamples > 0) {
// float[] vector = trainingRequest.getVectorAtThePosition(sampledIndices[0]);
// if (vector == null) {
// throw new IllegalArgumentException("Vector at sampled index " + sampledIndices[0] + " is null.");
// }
// }
// int dimension = trainingRequest.getVectorAtThePosition(sampledIndices[0]).length;
// float[][] transposedVec = new float[dimension][totalSamples];
//
// for(int i=0; i<totalSamples; i++) {
// float[] vector = trainingRequest.getVectorAtThePosition(sampledIndices[0]);
// for(int j=0; j<dimension;j++) {
// transposedVec[j][i] = vector[j];
// }
// }
// return transposedVec;
// }

private void validateState(final QuantizationState state) {
if (!(state instanceof ByteScalarQuantizationState)) {
throw new IllegalArgumentException("Quantization state must be of type ByteScalarQuantizationState.");
Expand Down

0 comments on commit cdc469a

Please sign in to comment.