Skip to content

Commit

Permalink
Implement RandomSamplingUtils.randomSelection (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
gstamatelat committed Jul 7, 2021
1 parent b2e06a4 commit 14f021e
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 0 deletions.
43 changes: 43 additions & 0 deletions src/main/java/gr/james/sampling/RandomSamplingUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -118,4 +118,47 @@ public static <E> boolean iteratorsEquals(Iterator<E> a, Iterator<E> b) {
}
return !a.hasNext() && !b.hasNext();
}

/**
* Performs an unweighted selection without replacement of <code>k</code> elements from a population of
* <code>n</code> elements.
* <p>
* The population and the sample are represented by their indices and, as a result, this method will return
* <code>k</code> random and discrete indices in the range <code>[0,n)</code>.
* <p>
* The selection is performed in such a way that the higher order inclusion probabilities of all
* <code>k</code>-tuples are equal.
* <p>
* This method runs in time proportional to <code>k</code> in the worst case.
*
* @param n the size of the population
* @param k the size of the sample
* @param rng the random number generator to use
* @return an array of <code>k</code> random and discrete integers in the range <code>[0,n)</code>
* @throws IllegalArgumentException if <code>n</code> or <code>k</code> is less than 1
* @throws IllegalArgumentException if <code>k &gt; n</code>
*/
public static int[] randomSelection(int n, int k, Random rng) {
if (n < 1) {
throw new IllegalArgumentException("n must be at least 1");
}
if (k < 1) {
throw new IllegalArgumentException("k must be at least 1");
}
if (k > n) {
throw new IllegalArgumentException("k must be at most n");
}

final int[] a = new int[k];
final Map<Integer, Integer> swaps = new HashMap<>();
for (int i = 0; i < k; i++) {
final int nextIndex = rng.nextInt(n - i);
a[i] = swaps.getOrDefault(nextIndex, nextIndex);
swaps.put(nextIndex, swaps.getOrDefault(n - i - 1, n - i - 1));
}

assert Arrays.stream(a).distinct().count() == k;

return a;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
package gr.james.sampling;

import org.junit.Assert;
import org.junit.Test;

import java.util.*;
import java.util.stream.Collectors;

/**
* Tests for {@link RandomSamplingUtils#randomSelection(int, int, Random)}.
*/
public class RandomSamplingUtilsRandomSelectionTest {
/**
* Check correctness for n=6 and k=3. The total number of 3-tuples should be 60 with equal probability of inclusion.
*/
@Test
public void correctness() {
final int REPS = 200000000;
final Random rng = new Random();
final Map<List<Integer>, Long> frequencies = new HashMap<>();
for (int i = 0; i < REPS; i++) {
final int[] a = RandomSamplingUtils.randomSelection(5, 3, rng);
final List<Integer> aList = Arrays.stream(a).boxed().collect(Collectors.toList());
frequencies.put(aList, frequencies.getOrDefault(aList, 0L) + 1);
}
Assert.assertEquals(60, frequencies.size());
final long firstFrequency = frequencies.values().iterator().next();
for (long v : frequencies.values()) {
Assert.assertEquals(1.0, (double) firstFrequency / v, 1.0e-2);
}
}
}

0 comments on commit 14f021e

Please sign in to comment.