并查集 Disjoint Set
并查集(union & find) 是一种树形结构,用于处理一些不交集(Disjoint Sets) 的合并及查询问题
Find: 确定元素属于哪一个子集。他可以被用来确定两个元素是否属于同一个子集。
Union: 将两个子集合合并成同一个集合
适用场景
- 组团、配对问题
- Group or not ?
伪代码
- makeSet(s):建立一个新的并查集,其中包含 s 个单元素集合。
- unionSet(x, y):把元素 x 和元素 y 所在的集合合并,要求 x 和 y 所在的集合不相交,如果相交则不合并。
- find(x):找到元素 x 所在的集合的代表,该操作也可以用于判断两个元 素是否位于同一个集合,只要将它们各自的代表比较一下就可以了。
def makeSet(x):
x.parent := x
def Find(x):
if x.parent == x
return x
else:
return Find(x.parent)
def Union(x, y):
xRoot := Find(x)
yRoot := Find(y)
xRoot.parent := yRoot
优化
- 优化 1: union by rank
def makeSet(x):
x.parent = x
def Find(x):
if x.parent == x
return x
else:
return Find(x.parent)
def Union(x, y):
xRoot := Find(x)
yRoot := Find(y)
if xRoot.rank < yRoot.rank:
xRoot.parent = yRoot
else if XRoot.rank > yRoot.rank:
yRoot.parent = xRoot
else:
yRoot.parent = xRoot
xRoot.rank = xRoot.rank + 1
- 优化2: 调用 find(d) 时路径压缩
class UnionFind {
private int count = 0;
private int[] parent;
public UnionFind(int n) {
count = n;
parent = new int[n];
for (int i = 0; i < n; i++) {
parent[i] = i;
}
}
public int find(int p) {
while (p != parent[p]) {
parent[p] = parent[parent[p]];
p = parent[p];
}
return p;
}
// 路径压缩
public int findRoot(int i) {
int root = i;
while (root != parent[root]) {
root = parent[root];
}
while (i != parent[i]) {
int tmp = parent[i];
parent[i] = root;
i = tmp;
}
return root;
}
public boolean connected(int p, int q) {
return find(p) == find(q);
}
public void union(int p, int q) {
int rootP = find(p);
int rootQ = find(q);
if (rootP == rootQ) return;
parent[rootP] = rootQ;
count--;
}
}
def init(p):
# for i=0 ..n:p[i]=i;
p=[i for i in range(n)]
def union(self,p,i,j):
p1=self.parent(p,i)
p2=self.parent(p,j)
p[p1]=p2
def parent(self,p,i):
root=i
while p[root]!=root:
root=p[root]
while p[i]!=i: # 路径压缩
x=i;i=p[i];p[x]=root
return root
class UnionFind(object):
def __init__(self, grid):
m, n = len(grid), len(grid[0])
self.count = 0
self.parent = [-1] * (m * n)
self.rank = [0] * (m * n)
for i in range(m):
for j in range(n):
if grid[i][j] == '1':
self.parent[i * n + j] = i * n + j
self.count += 1
def find(self, i):
if self.parent[i] != i:
self.parent[i] = self.find(self.parent[i])
return self.parent[i]
def union(self, x, y):
rootx = self.find(x)
rooty = self.find(y)
if rootx != rooty:
if self.rank[rootx] > self.rank[rooty]:
self.parent[rooty] = rootx
elif self.rank[rootx] < self.rank[rooty]:
self.parent[rootx] = rooty
else:
self.parent[rooty] = rootx
self.rank[rootx] += 1
self.count -= 1