Skip to content

Commit

Permalink
Cache buckets to speed up BytesRefHash#sort (#12784)
Browse files Browse the repository at this point in the history
  • Loading branch information
gf2121 committed Nov 10, 2023
1 parent 9a02453 commit d684987
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 34 deletions.
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ Optimizations

* GITHUB#12381: Skip docs with DocValues in NumericLeafComparator. (Lu Xugang, Adrien Grand)

* GITHUB#12784: Cache buckets to speed up BytesRefHash#sort. (Guo Feng)

Changes in runtime behavior
---------------------

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,10 @@ <E extends Exception> void forEachOrdered(DeletedTermConsumer<E> consumer) throw
scratch.field = deleteFieldEntry.getKey();
BufferedUpdates.BytesRefIntMap terms = deleteFieldEntry.getValue();
int[] indices = terms.bytesRefHash.sort();
for (int index : indices) {
if (index != -1) {
terms.bytesRefHash.get(index, scratch.bytes);
consumer.accept(scratch, terms.values[index]);
}
for (int i = 0; i < terms.bytesRefHash.size(); i++) {
int index = indices[i];
terms.bytesRefHash.get(index, scratch.bytes);
consumer.accept(scratch, terms.values[index]);
}
}
}
Expand Down
55 changes: 55 additions & 0 deletions lucene/core/src/java/org/apache/lucene/util/BytesRefHash.java
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,63 @@ public int[] compact() {
*/
public int[] sort() {
final int[] compact = compact();
assert count * 2 <= compact.length : "We need load factor <= 0.5f to speed up this sort";
final int tmpOffset = count;
new StringSorter(BytesRefComparator.NATURAL) {

@Override
protected Sorter radixSorter(BytesRefComparator cmp) {
return new MSBStringRadixSorter(cmp) {

private int k;

@Override
protected void buildHistogram(
int prefixCommonBucket,
int prefixCommonLen,
int from,
int to,
int k,
int[] histogram) {
this.k = k;
histogram[prefixCommonBucket] = prefixCommonLen;
Arrays.fill(
compact, tmpOffset + from - prefixCommonLen, tmpOffset + from, prefixCommonBucket);
for (int i = from; i < to; ++i) {
int b = getBucket(i, k);
compact[tmpOffset + i] = b;
histogram[b]++;
}
}

@Override
protected boolean shouldFallback(int from, int to, int l) {
// We lower the fallback threshold because the bucket cache speeds up the reorder
return to - from <= LENGTH_THRESHOLD / 2 || l >= LEVEL_THRESHOLD;
}

private void swapBucketCache(int i, int j) {
swap(i, j);
int tmp = compact[tmpOffset + i];
compact[tmpOffset + i] = compact[tmpOffset + j];
compact[tmpOffset + j] = tmp;
}

@Override
protected void reorder(int from, int to, int[] startOffsets, int[] endOffsets, int k) {
assert this.k == k;
for (int i = 0; i < HISTOGRAM_SIZE; ++i) {
final int limit = endOffsets[i];
for (int h1 = startOffsets[i]; h1 < limit; h1 = startOffsets[i]) {
final int b = compact[tmpOffset + from + h1];
final int h2 = startOffsets[b]++;
swapBucketCache(from + h1, from + h2);
}
}
}
};
}

@Override
protected void swap(int i, int j) {
int tmp = compact[i];
Expand Down
22 changes: 11 additions & 11 deletions lucene/core/src/java/org/apache/lucene/util/MSBRadixSorter.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ public abstract class MSBRadixSorter extends Sorter {
// this is used as a protection against the fact that radix sort performs
// worse when there are long common prefixes (probably because of cache
// locality)
private static final int LEVEL_THRESHOLD = 8;
protected static final int LEVEL_THRESHOLD = 8;
// size of histograms: 256 + 1 to indicate that the string is finished
protected static final int HISTOGRAM_SIZE = 257;
// buckets below this size will be sorted with introsort
private static final int LENGTH_THRESHOLD = 100;
// buckets below this size will be sorted with fallback sorter
protected static final int LENGTH_THRESHOLD = 100;

// we store one histogram per recursion level
private final int[][] histograms = new int[LEVEL_THRESHOLD][];
Expand Down Expand Up @@ -130,15 +130,15 @@ public void sort(int from, int to) {
}

protected void sort(int from, int to, int k, int l) {
if (to - from <= LENGTH_THRESHOLD || l >= LEVEL_THRESHOLD) {
introSort(from, to, k);
if (shouldFallback(from, to, l)) {
getFallbackSorter(k).sort(from, to);
} else {
radixSort(from, to, k, l);
}
}

private void introSort(int from, int to, int k) {
getFallbackSorter(k).sort(from, to);
protected boolean shouldFallback(int from, int to, int l) {
return to - from <= LENGTH_THRESHOLD || l >= LEVEL_THRESHOLD;
}

/**
Expand Down Expand Up @@ -233,8 +233,6 @@ private int computeCommonPrefixLengthAndBuildHistogram(int from, int to, int k,
if (b != commonPrefix[j]) {
commonPrefixLength = j;
if (commonPrefixLength == 0) { // we have no common prefix
histogram[commonPrefix[0] + 1] = i - from;
histogram[b + 1] = 1;
break outer;
}
break;
Expand All @@ -245,7 +243,7 @@ private int computeCommonPrefixLengthAndBuildHistogram(int from, int to, int k,
if (i < to) {
// the loop got broken because there is no common prefix
assert commonPrefixLength == 0;
buildHistogram(i + 1, to, k, histogram);
buildHistogram(commonPrefix[0] + 1, i - from, i, to, k, histogram);
} else {
assert commonPrefixLength > 0;
histogram[commonPrefix[0] + 1] = to - from;
Expand All @@ -258,7 +256,9 @@ private int computeCommonPrefixLengthAndBuildHistogram(int from, int to, int k,
* Build an histogram of the k-th characters of values occurring between offsets {@code from} and
* {@code to}, using {@link #getBucket}.
*/
private void buildHistogram(int from, int to, int k, int[] histogram) {
protected void buildHistogram(
int prefixCommonBucket, int prefixCommonLen, int from, int to, int k, int[] histogram) {
histogram[prefixCommonBucket] = prefixCommonLen;
for (int i = from; i < to; ++i) {
histogram[getBucket(i, k)]++;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,9 @@ protected void restore(int i, int j) {

@Override
protected int compare(int i, int j) {
return StableStringSorter.this.compare(i, j);
get(scratch1, scratchBytes1, i);
get(scratch2, scratchBytes2, j);
return cmp.compare(scratchBytes1, scratchBytes2);
}

@Override
Expand Down
47 changes: 30 additions & 17 deletions lucene/core/src/java/org/apache/lucene/util/StringSorter.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,24 +58,35 @@ public void sort(int from, int to) {
}
}

protected Sorter radixSorter(BytesRefComparator cmp) {
return new MSBRadixSorter(cmp.comparedBytesCount) {
@Override
protected void swap(int i, int j) {
StringSorter.this.swap(i, j);
}
/** A radix sorter for {@link BytesRef} */
protected class MSBStringRadixSorter extends MSBRadixSorter {

@Override
protected int byteAt(int i, int k) {
get(scratch1, scratchBytes1, i);
return cmp.byteAt(scratchBytes1, k);
}
private final BytesRefComparator cmp;

@Override
protected Sorter getFallbackSorter(int k) {
return fallbackSorter((o1, o2) -> cmp.compare(o1, o2, k));
}
};
protected MSBStringRadixSorter(BytesRefComparator cmp) {
super(cmp.comparedBytesCount);
this.cmp = cmp;
}

@Override
protected void swap(int i, int j) {
StringSorter.this.swap(i, j);
}

@Override
protected int byteAt(int i, int k) {
get(scratch1, scratchBytes1, i);
return cmp.byteAt(scratchBytes1, k);
}

@Override
protected Sorter getFallbackSorter(int k) {
return fallbackSorter((o1, o2) -> cmp.compare(o1, o2, k));
}
}

protected Sorter radixSorter(BytesRefComparator cmp) {
return new MSBStringRadixSorter(cmp);
}

protected Sorter fallbackSorter(Comparator<BytesRef> cmp) {
Expand All @@ -87,7 +98,9 @@ protected void swap(int i, int j) {

@Override
protected int compare(int i, int j) {
return StringSorter.this.compare(i, j);
get(scratch1, scratchBytes1, i);
get(scratch2, scratchBytes2, j);
return cmp.compare(scratchBytes1, scratchBytes2);
}

@Override
Expand Down

0 comments on commit d684987

Please sign in to comment.