Union-Find in Python
GOAL
To understand union-find and implement it in Python.
Example problems that use the union-find algorithm
Dividing clusters
Given a list of friendships among N people, divide the people into clusters.
n = 6
friendship = [(1, 2), (3, 4), (1, 6)]
Questions:
Are person 2 and person 3 in the same cluster?
How many members are in the cluster that person 3 belongs to?
Connected Path
nodes = 10
adjacency_list = [[6, 8], [2, 3], [1], [1], [5, 7], [4, 7], [0], [4, 5, 9], [0], [7]]
# => connections = [(0, 6), (0, 8), (1, 2), (1, 3), (4, 5), (4, 7), (5, 7), (7, 9)]
Questions:
Is there any path from 2 to 3?
How many connected components are there in the graph?
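The `# =>` comment above turns the adjacency list into a list of undirected edges. A minimal sketch of that conversion follows, assuming each edge should be listed once with the smaller node first. The helper name `to_connections` is hypothetical and not part of the original example.

```python
def to_connections(adjacency_list):
    """Return each undirected edge once as (smaller, larger), sorted."""
    connections = set()
    for node, neighbors in enumerate(adjacency_list):
        for neighbor in neighbors:
            # Normalize the pair so (6, 0) and (0, 6) count as one edge.
            connections.add((min(node, neighbor), max(node, neighbor)))
    return sorted(connections)


adjacency_list = [[6, 8], [2, 3], [1], [1], [5, 7], [4, 7], [0], [4, 5, 9], [0], [7]]
connections = to_connections(adjacency_list)
print(connections)
# [(0, 6), (0, 8), (1, 2), (1, 3), (4, 5), (4, 7), (5, 7), (7, 9)]
```

Once the graph is a list of pairs, it has the same shape as the friendship list, so both problems can be handled the same way.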
What is the Union-Find algorithm?
Union-find is an algorithm that operates on the disjoint-set data structure, a collection of groups in which every item belongs to exactly one group. It provides two operations:
Union(x, y): Connect x and y. Merge the group that x belongs to with the group that y belongs to.
Find(x): Find the group that x belongs to.
For a more detailed explanation, refer to Basics of Disjoint Data Structures. That article is thorough and easy to follow.
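To make the two operations concrete before looking at an efficient implementation, here is a deliberately naive sketch that stores every group as a Python set. The names `groups`, `find`, and `union` are illustrative only, and this version scans every group on each call, so it is far slower than a tree-based implementation.

```python
# Start with five items, each in its own group.
groups = [{i} for i in range(5)]


def find(x):
    """Return the group (a set) that contains x."""
    for group in groups:
        if x in group:
            return group
    raise ValueError(f"{x} is not in any group")


def union(x, y):
    """Merge the group containing x with the group containing y."""
    group_x, group_y = find(x), find(y)
    if group_x is group_y:
        return  # already in the same group
    group_x |= group_y  # move every member of y's group into x's group
    # Drop the absorbed group (compared by identity, not by contents).
    groups[:] = [group for group in groups if group is not group_y]


union(4, 2)
union(1, 3)
print(sorted(find(2)))     # [2, 4]
print(find(2) is find(4))  # True
print(find(0) is find(1))  # False
```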
Example of implementation
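I’ll introduce one way to implement the union-find algorithm using two lists, “rank” and “parent”. Each item starts as the root of its own tree. “parent” stores the parent of each item (a root is its own parent), and “rank” stores the height of the tree under each root, which decides which root survives a merge. An overview is below.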
class UnionFind:
    def __init__(self, size):
        self.rank = [0 for i in range(size)]
        self.parent = [i for i in range(size)]

    def Find(self, x):
        parent = self.parent[x]
        while parent != x:
            x = parent
            parent = self.parent[x]
        return parent

    def Union(self, x, y):
        root_x = self.Find(x)
        root_y = self.Find(y)
        if root_x == root_y:
            return
        else:
            if self.rank[root_x] >= self.rank[root_y]:
                self.parent[root_y] = root_x
                self.rank[root_x] = max(self.rank[root_y] + 1, self.rank[root_x])
            else:
                self.parent[root_x] = root_y
                self.rank[root_y] = max(self.rank[root_x] + 1, self.rank[root_y])

    def Same(self, x, y):
        return self.Find(x) == self.Find(y)
unionfind = UnionFind(5)
unionfind.Union(4, 2)
# rank [0, 0, 0, 0, 1], parent [0, 1, 4, 3, 4]
print(unionfind.Find(0))  # 0
unionfind.Union(4, 0)
# rank [0, 0, 0, 0, 1], parent [4, 1, 4, 3, 4]
print(unionfind.Find(0))  # 4
unionfind.Union(1, 3)
# rank [0, 1, 0, 0, 1], parent [4, 1, 4, 1, 4]
print(unionfind.Find(3))  # 1
unionfind.Union(1, 2)
# rank [0, 2, 0, 0, 1], parent [4, 1, 4, 1, 1]
print(unionfind.Find(0))  # 1
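As a usage sketch, the first question of each example problem can be answered with `Same`. The block below repeats a condensed copy of the class above so that it runs on its own. It assumes the people in the friendship example are numbered from 1, so one extra slot is allocated and index 0 goes unused. The variable names `clusters` and `graph` are illustrative.

```python
class UnionFind:
    """Condensed copy of the rank/parent implementation above."""

    def __init__(self, size):
        self.rank = [0] * size
        self.parent = list(range(size))

    def Find(self, x):
        while self.parent[x] != x:
            x = self.parent[x]
        return x

    def Union(self, x, y):
        root_x, root_y = self.Find(x), self.Find(y)
        if root_x == root_y:
            return
        # Attach the lower-rank root under the higher-rank one.
        if self.rank[root_x] < self.rank[root_y]:
            root_x, root_y = root_y, root_x
        self.parent[root_y] = root_x
        if self.rank[root_x] == self.rank[root_y]:
            self.rank[root_x] += 1

    def Same(self, x, y):
        return self.Find(x) == self.Find(y)


# Dividing clusters: people are numbered 1..n, so allocate n + 1 slots.
n = 6
friendship = [(1, 2), (3, 4), (1, 6)]
clusters = UnionFind(n + 1)
for a, b in friendship:
    clusters.Union(a, b)
print(clusters.Same(2, 3))  # False
print(clusters.Same(2, 6))  # True

# Connected path: nodes are numbered 0..9.
connections = [(0, 6), (0, 8), (1, 2), (1, 3), (4, 5), (4, 7), (5, 7), (7, 9)]
graph = UnionFind(10)
for a, b in connections:
    graph.Union(a, b)
print(graph.Same(2, 3))  # True
```

Persons 2 and 3 are in different clusters, while nodes 2 and 3 are connected through node 1.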
Variation
Find root and get the size of the tree
Change the list “parent” so that each entry holds either the parent of the item or, if the item is the root of a tree, the size of that tree. To tell the two cases apart, the size is stored as a negative number. For example, when 2 is the root of a tree of size 5, parent[2] is -5. Every item starts as the root of a tree of size 1, so the list is initialized with -1.
class UnionFind:
    def __init__(self, size):
        self.rank = [0 for i in range(size)]
        self.parent = [-1 for i in range(size)]

    def Find(self, x):
        parent = self.parent[x]
        while parent >= 0:  # termination condition has changed
            x = parent
            parent = self.parent[x]
        return x

    def Union(self, x, y):
        root_x = self.Find(x)
        root_y = self.Find(y)
        if root_x == root_y:
            return
        else:
            if self.rank[root_x] >= self.rank[root_y]:
                self.parent[root_x] += self.parent[root_y]  # parent[root_id] is always negative
                self.parent[root_y] = root_x
                self.rank[root_x] = max(self.rank[root_y] + 1, self.rank[root_x])
            else:
                self.parent[root_y] += self.parent[root_x]
                self.parent[root_x] = root_y
                self.rank[root_y] = max(self.rank[root_x] + 1, self.rank[root_y])

    def Same(self, x, y):
        return self.Find(x) == self.Find(y)

    def FindRootAndSize(self):
        return [(idx, -val) for idx, val in enumerate(self.parent) if val < 0]
unionfind = UnionFind(5)
unionfind.Union(4, 2)
unionfind.Union(4, 0)
unionfind.Union(1, 3)
print(unionfind.FindRootAndSize())  # [(1, 2), (4, 3)]
unionfind.Union(1, 2)
print(unionfind.FindRootAndSize())  # [(1, 5)]
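With sizes stored at the roots, the remaining questions from the example problems can be answered directly. The sketch below is self-contained and uses plain functions over a single `parent` list instead of the class. It merges the smaller tree into the larger one (union by size), which is a swapped-in simplification of the rank logic above, not the method used in this article. The names `find`, `union`, `size_of`, and `count_groups` are hypothetical.

```python
def find(parent, x):
    """Follow parent links until reaching a root (a negative entry)."""
    while parent[x] >= 0:
        x = parent[x]
    return x


def union(parent, x, y):
    """Merge the trees of x and y, attaching the smaller under the larger."""
    root_x, root_y = find(parent, x), find(parent, y)
    if root_x == root_y:
        return
    # parent[root] is the negated size, so a larger value means a smaller tree.
    if parent[root_x] > parent[root_y]:
        root_x, root_y = root_y, root_x
    parent[root_x] += parent[root_y]
    parent[root_y] = root_x


def size_of(parent, x):
    """Number of members in the group that x belongs to."""
    return -parent[find(parent, x)]


def count_groups(parent, first=0):
    """Number of roots among indices first..len(parent) - 1."""
    return sum(1 for value in parent[first:] if value < 0)


# Dividing clusters: people are numbered 1..6, so index 0 is unused.
people = [-1] * 7
for a, b in [(1, 2), (3, 4), (1, 6)]:
    union(people, a, b)
print(size_of(people, 3))             # 2
print(count_groups(people, first=1))  # 3

# Connected path: nodes are numbered 0..9.
nodes = [-1] * 10
for a, b in [(0, 6), (0, 8), (1, 2), (1, 3), (4, 5), (4, 7), (5, 7), (7, 9)]:
    union(nodes, a, b)
print(count_groups(nodes))            # 3
```

Person 3 shares a cluster only with person 4, and person 5 forms a cluster alone, which is why the friendship example has three clusters.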
Find all items in the tree
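Add a list “children” that stores, for each root, every item in its tree, including the root itself. When two trees are merged, the list of the absorbed root is appended to the list of the surviving root. The lists of non-root items are not updated afterwards, so only the entries for roots are meaningful.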
class UnionFind:
    def __init__(self, size):
        self.rank = [0 for i in range(size)]
        self.parent = [-1 for i in range(size)]
        self.children = [[i] for i in range(size)]

    def Find(self, x):
        parent = self.parent[x]
        while parent >= 0:
            x = parent
            parent = self.parent[x]
        return x

    def Union(self, x, y):
        root_x = self.Find(x)
        root_y = self.Find(y)
        if root_x == root_y:
            return
        else:
            if self.rank[root_x] >= self.rank[root_y]:
                self.parent[root_x] += self.parent[root_y]
                self.parent[root_y] = root_x
                self.rank[root_x] = max(self.rank[root_y] + 1, self.rank[root_x])
                self.children[root_x] += self.children[root_y]
            else:
                self.parent[root_y] += self.parent[root_x]
                self.parent[root_x] = root_y
                self.rank[root_y] = max(self.rank[root_x] + 1, self.rank[root_y])
                self.children[root_y] += self.children[root_x]

    def Same(self, x, y):
        return self.Find(x) == self.Find(y)

    def FindRootAndSizeAndChildren(self):
        return [(idx, -val, self.children[idx]) for idx, val in enumerate(self.parent) if val < 0]
unionfind = UnionFind(5)
unionfind.Union(4, 2)
unionfind.Union(4, 0)
unionfind.Union(1, 3)
print(unionfind.FindRootAndSizeAndChildren())  # [(1, 2, [1, 3]), (4, 3, [4, 2, 0])]
unionfind.Union(1, 2)
print(unionfind.FindRootAndSizeAndChildren())  # [(1, 5, [1, 3, 4, 2, 0])]