Skip to content

Commit

Permalink
Improves union find by implementing weighted union and path compressi…
Browse files Browse the repository at this point in the history
…on (#12)
  • Loading branch information
matt035343 authored Jul 7, 2022
1 parent 663890f commit a417a37
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/prepare_release.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ jobs:
- uses: SneaksAndData/github-actions/[email protected]
with:
major_v: 0
minor_v: 1
minor_v: 2
2 changes: 1 addition & 1 deletion anti_clustering/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def _post_process(
:param cluster_assignment_matrix: A matrix containing for each pair of elements if they belong to the same anti-cluster.
:return: The inputted dataframe with the new destination column.
"""
components = UnionFind({i: i for i in range(len(df))})
components = UnionFind(len(df))

for j in range(len(df)):
for i in range(0, j):
Expand Down
4 changes: 3 additions & 1 deletion anti_clustering/_cluster_swap_heuristic.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def _get_random_clusters(self, num_groups: int, num_elements: int) -> npt.NDArra
initial_clusters = [i % num_groups for i in range(num_elements - num_groups)]
self.rnd.shuffle(initial_clusters)
initial_clusters = list(range(num_groups)) + initial_clusters
uf_init = UnionFind({i: cluster for i, cluster in enumerate(initial_clusters)}) # pylint: disable = R1721
uf_init = UnionFind(len(initial_clusters)) # pylint: disable = R1721
for i, cluster in enumerate(initial_clusters):
uf_init.union(i, cluster)

cluster_assignment = np.array(
[[uf_init.connected(i, j) for i in range(num_elements)] for j in range(num_elements)]
Expand Down
45 changes: 35 additions & 10 deletions anti_clustering/_union_find.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,37 +13,48 @@
# limitations under the License.
"""
A union find data structure for collecting results of the anti-clustering algorithm.
Based on:
Sedgewick, R. & Wayne, K. (2011), Algorithms, 4th Edition. , Addison-Wesley .
"""

from typing import Dict, TypeVar, Generic
from typing import TypeVar, Generic

T = TypeVar('T') # pylint: disable=C0103


class UnionFind(Generic[T]):
"""
A union find data structure for collecting results of the anti-clustering algorithm.
This implementation uses the weighted quick union with path compression.
"""
# A mapping from an element to its parent. If a parent maps to itself, it is the root of the component.
parent = {}
_parent = {}
_size = {}
components_count = 0

def __init__(self, parent: Dict[T, T]):
def __init__(self, initial_components_count: int):
"""
Initialize UnionFind with components.
:param parent: The initial components.
In most use cases all components will point to themselves (example: {0: 0, 1: 1, ...}).
:param initial_components_count: The initial number of components.
"""
self.parent = parent
self.components_count = initial_components_count
self._parent = {i: i for i in range(initial_components_count)}
self._size = {i: 1 for i in range(initial_components_count)}

def _find(self, a: T) -> T:
"""
Find the root of component of element a.
:param a: Element to find root of.
:return: The root of the component.
"""
if self.parent[a] == a:
return a
return self._find(self.parent[a])
# Compresses path while iterating up the tree.
while a != self._parent[a]:
b = self._parent[a]
self._parent[a] = self._parent[b]
a = b

return a

def find(self, a: T) -> T:
"""
Expand All @@ -60,9 +71,23 @@ def union(self, a: T, b: T) -> None:
:param b: Other element to unify.
:return:
"""
if a == b:
return

x = self._find(a)
y = self._find(b)
self.parent[x] = y

if x == y:
return

# Weighted union - the smaller component becomes the child of the root of the larger component.
if self._size[x] < self._size[y]:
self._parent[x] = y
self._size[y] += self._size[x]
else:
self._parent[y] = x
self._size[x] += self._size[y]
self.components_count -= 1

def connected(self, a: T, b: T) -> bool:
"""
Expand Down
53 changes: 53 additions & 0 deletions tests/test_union_find.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from anti_clustering._union_find import UnionFind


def test_union_self():
"""
Tests that unioning a node with itself does not connect it to other components.
"""
uf = UnionFind(5)
uf.union(0, 0)
uf.union(1, 2)
uf.union(2, 3)
uf.union(3, 4)
uf.union(4, 4)
assert uf.components_count == 2
assert uf.find(0) == 0
assert uf.find(1) == uf.find(2) == uf.find(3) == uf.find(4)
assert uf.find(0) != uf.find(1)
assert uf.connected(0, 0)
assert not uf.connected(0, 1)
assert uf.connected(1, 4)


def test_construction():
"""
Tests that the initial state is completely disconnected.
"""
uf = UnionFind(3)
assert uf.components_count == 3
assert uf.find(0) == 0
assert uf.find(1) == 1
assert uf.find(2) == 2
assert uf.connected(0, 0)
assert not uf.connected(0, 1)
assert not uf.connected(1, 2)
assert not uf.connected(0, 2)


def test_imbalanced_union():
"""
Tests that performing a chain of unioned elements connects all in the same components.
This is interesting due to the weighted union.
"""
uf = UnionFind(10)

for i in range(9):
uf.union(i, i+1)

assert uf.components_count == 1
assert uf.find(0) == uf.find(1) == uf.find(2) == uf.find(3) ==\
uf.find(4) == uf.find(5) == uf.find(6) == uf.find(7) == \
uf.find(8) == uf.find(9) == 0

assert uf.connected(0, 9)

0 comments on commit a417a37

Please sign in to comment.