A Beginner's Guide to the Union-Find Data Structure

EJ Jung
3 min read

Union-Find, also known as Disjoint Set Union (DSU), is a data structure that efficiently keeps track of a set of elements partitioned into disjoint (non-overlapping) subsets. It's particularly useful in problems involving connectivity, such as determining whether two elements are in the same group or detecting cycles in a graph.


✨ Summary

  • Purpose: Track and manage disjoint sets

  • Core operations:

    • find(x): Find the representative (root) of the set containing x

    • union(x, y): Merge the sets that contain x and y

  • Optimizations: Path compression and union by rank/size

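  • Time Complexity: Amortized O(α(n)) per operation with both optimizations, where α is the inverse Ackermann function (at most 4 for any practical n), so effectively constant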


🧱 Basic Implementation

def find(x):
    if parent[x] != x:
        parent[x] = find(parent[x])  # Path compression
    return parent[x]

def union(x, y):
    rootX = find(x)
    rootY = find(y)
    if rootX != rootY:
        parent[rootY] = rootX  # Union

Initialization:

parent = [i for i in range(n)]

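Initially, each element is its own parent, so every element starts in its own singleton set. Here n is the number of elements, labeled 0 to n - 1, and parent must exist before find or union is called.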
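The snippets above rely on a global parent list. A minimal sketch that bundles the same idea into one self-contained class might look like the following; the class name UnionFind and the extra connected helper are illustrative, not from any library. One deliberate change: find here is iterative (two passes) instead of recursive, which gives the same path compression while avoiding Python's default recursion limit on very long chains.

```python
class UnionFind:
    """Minimal disjoint-set forest over elements 0..n-1 (hypothetical helper class)."""

    def __init__(self, n: int) -> None:
        # Each element starts as its own root.
        self.parent = list(range(n))

    def find(self, x: int) -> int:
        # First pass: walk up to the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: point every node on the path directly at the root (path compression).
        while x != root:
            nxt = self.parent[x]
            self.parent[x] = root
            x = nxt
        return root

    def union(self, x: int, y: int) -> bool:
        # Returns True if two different sets were merged, False if already connected.
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return False
        self.parent[root_y] = root_x
        return True

    def connected(self, x: int, y: int) -> bool:
        # Two elements are in the same set exactly when they share a root.
        return self.find(x) == self.find(y)


uf = UnionFind(5)
uf.union(0, 1)
uf.union(3, 4)
print(uf.connected(0, 1))  # True
print(uf.connected(1, 3))  # False
```

Returning a bool from union is a small convenience: it tells the caller whether a merge actually happened, which is exactly what cycle detection and component counting need.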


⚡ Optimizations

Path Compression

  • During find(x), we recursively point every node on the path from x up to the root directly at the root.

  • This flattens the tree, making future operations faster.
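To see path compression in action, here is a small sketch on a hand-built chain (a hypothetical input chosen for illustration). A single find on the deepest node rewires every node on the path to point straight at the root.

```python
from typing import List


def find(parent: List[int], x: int) -> int:
    # Same recursive find with path compression as above, but taking the
    # parent list as an argument so this demo is self-contained.
    if parent[x] != x:
        parent[x] = find(parent, parent[x])
    return parent[x]


# Hand-built chain 4 -> 3 -> 2 -> 1 -> 0, where 0 is the root.
parent = [0, 0, 1, 2, 3]
print(find(parent, 4))  # 0
print(parent)           # [0, 0, 0, 0, 0]
```

Before the call, reaching the root from 4 takes four hops; afterwards, every node reaches it in one.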

Union by Rank or Size (Optional)

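  • Attach the root of the shorter tree (lower rank) or the smaller tree (fewer nodes) under the root of the other one.

  • This keeps trees shallow: tree height stays O(log n) even without path compression.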

rank = [1] * n

def union(x, y):
    rootX = find(x)
    rootY = find(y)
    if rootX != rootY:
        if rank[rootX] < rank[rootY]:
            parent[rootX] = rootY
        else:
            parent[rootY] = rootX
            if rank[rootX] == rank[rootY]:
                rank[rootX] += 1
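The bullets mention size as an alternative to rank, but the code above shows only rank. A union-by-size variant could look like this sketch (the class name UnionFindBySize and the component_size helper are hypothetical). Tracking size has a practical bonus: the counter stored at each root tells you how many elements are in that set.

```python
class UnionFindBySize:
    """Disjoint-set forest using union by size (hypothetical class name)."""

    def __init__(self, n: int) -> None:
        self.parent = list(range(n))
        self.size = [1] * n  # size[r] is only meaningful while r is a root

    def find(self, x: int) -> int:
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # path compression
        return self.parent[x]

    def union(self, x: int, y: int) -> bool:
        root_x, root_y = self.find(x), self.find(y)
        if root_x == root_y:
            return False
        # Swap so that root_x is the root of the larger tree.
        if self.size[root_x] < self.size[root_y]:
            root_x, root_y = root_y, root_x
        # Attach the smaller tree under the larger one and update its size.
        self.parent[root_y] = root_x
        self.size[root_x] += self.size[root_y]
        return True

    def component_size(self, x: int) -> int:
        return self.size[self.find(x)]


uf = UnionFindBySize(6)
uf.union(0, 1)
uf.union(1, 2)
uf.union(3, 4)
print(uf.component_size(2))  # 3
print(uf.component_size(4))  # 2
print(uf.component_size(5))  # 1
```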

🧠 Use Cases

| Problem Type | Example Problems |
| --- | --- |
| Connectivity / Components | Number of Provinces (LC 547) |
| Cycle Detection | Redundant Connection (LC 684) |
| Grouping / Merging | Accounts Merge (LC 721) |
| String/Custom Grouping | Sentence Similarity II (LC 737) |
| Weighted Union-Find | Evaluate Division (LC 399, advanced) |
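For the Cycle Detection row, the core idea is that an edge whose two endpoints already share a root closes a cycle. Below is a sketch of that check for Redundant Connection (LC 684), written as a standalone function named find_redundant_connection for illustration rather than LeetCode's class-based signature; it assumes, as in that problem, that nodes are labeled 1 to n and the input has exactly n edges.

```python
from typing import List


def find_redundant_connection(edges: List[List[int]]) -> List[int]:
    # n nodes labeled 1..n with n edges, so exactly one edge closes a cycle.
    n = len(edges)
    parent = list(range(n + 1))  # index 0 is unused

    def find(x: int) -> int:
        if parent[x] != x:
            parent[x] = find(parent[x])  # path compression
        return parent[x]

    for u, v in edges:
        root_u, root_v = find(u), find(v)
        if root_u == root_v:
            # u and v are already connected, so this edge would form a cycle.
            return [u, v]
        parent[root_v] = root_u
    return []


print(find_redundant_connection([[1, 2], [1, 3], [2, 3]]))  # [2, 3]
```

Because edges are processed in input order, the edge that triggers the check is the last edge of the cycle to appear, which matches what LC 684 asks for.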

📌 Example: Number of Provinces (LC 547)

Given an adjacency matrix isConnected, count how many connected components (provinces) exist.

from typing import List

class Solution:
    def findCircleNum(self, isConnected: List[List[int]]) -> int:
        n = len(isConnected)
        parent = [i for i in range(n)]

        def find(x):
            if x != parent[x]:
                parent[x] = find(parent[x])  # Path compression
            return parent[x]

        def union(x, y):
            rootX = find(x)
            rootY = find(y)
            if rootX != rootY:
                parent[rootY] = rootX

        for i in range(n):
            for j in range(i + 1, n):
                if isConnected[i][j]:
                    union(i, j)

        return len(set(find(i) for i in range(n)))
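Instead of collecting roots into a set at the end, you can also keep a running counter: start with n components and subtract one every time a union actually merges two different sets. A sketch of that variant (the function name count_provinces is hypothetical):

```python
from typing import List


def count_provinces(is_connected: List[List[int]]) -> int:
    n = len(is_connected)
    parent = list(range(n))
    components = n  # every city starts as its own province

    def find(x: int) -> int:
        if parent[x] != x:
            parent[x] = find(parent[x])  # path compression
        return parent[x]

    for i in range(n):
        for j in range(i + 1, n):
            if is_connected[i][j]:
                root_i, root_j = find(i), find(j)
                if root_i != root_j:
                    parent[root_j] = root_i
                    components -= 1  # a successful merge removes one component
    return components


print(count_provinces([[1, 1, 0], [1, 1, 0], [0, 0, 1]]))  # 2
print(count_provinces([[1, 0, 0], [0, 1, 0], [0, 0, 1]]))  # 3
```

Both approaches give the same answer; the counter skips the final pass over all elements, and the same "decrement on a successful merge" pattern works whenever you need the number of groups.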

💡 Tip

Consider Union-Find when:

  • The problem involves grouping, connected components, or merging sets.

  • You need to answer: "Are these two elements in the same group?"
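When the elements are not integers 0 to n - 1 (for example, the words in Sentence Similarity II or the emails in Accounts Merge), one common approach is to key the parent map by the elements themselves and register each one the first time it is seen. A minimal sketch, with DictUnionFind and same_group as hypothetical names and a made-up list of word pairs:

```python
from typing import Dict, Hashable


class DictUnionFind:
    """Union-Find over arbitrary hashable keys (hypothetical helper class)."""

    def __init__(self) -> None:
        self.parent: Dict[Hashable, Hashable] = {}

    def find(self, x: Hashable) -> Hashable:
        # Lazily register an unseen key as its own root.
        if x not in self.parent:
            self.parent[x] = x
        if self.parent[x] != x:
            self.parent[x] = self.find(self.parent[x])  # path compression
        return self.parent[x]

    def union(self, x: Hashable, y: Hashable) -> None:
        root_x, root_y = self.find(x), self.find(y)
        if root_x != root_y:
            self.parent[root_y] = root_x

    def same_group(self, x: Hashable, y: Hashable) -> bool:
        return self.find(x) == self.find(y)


uf = DictUnionFind()
for a, b in [("great", "fine"), ("fine", "good"), ("drama", "acting")]:
    uf.union(a, b)

print(uf.same_group("great", "good"))    # True
print(uf.same_group("great", "acting"))  # False
```

Note that in this sketch, querying a word never seen before registers it as its own singleton group.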


✨ Conclusion

Union-Find is a must-know data structure for coding interviews. With near-constant-time operations and a simple API (find and union), it handles many seemingly complex problems involving relationships, grouping, and connectivity.
