Distance Queries - CSES
Hello everyone, I hope you are learning well. In this blog, we shall solve yet another application of the lowest common ancestor (LCA) of two nodes: the Distance Queries problem from the Tree Algorithms section of the CSES problem set.
So, let's define the problem statement: we are given a tree consisting of N nodes, and the task is to process Q queries of the form "what is the distance between nodes a and b?"
Let's understand how this problem can be solved using the LCA concept learned in the last video (I will attach the link to the video in the description).
Let's take a test case to understand the problem statement better.
In this tree, we want to find the distance between nodes 2 and 5, so think about how we can do this!
One way is to start a BFS from one of the nodes, say a, and keep track of the distance (number of edges) from a until we reach the other node, b.
This strategy seems easy, but a single BFS costs O(N + E), which is O(N) on a tree because E = N - 1. Running one BFS for each of the Q queries therefore costs O(Q * N). With Q and N both of the order of 10^5, that is about 10^5 * 10^5 = 10^10 operations, and at roughly 10^8 operations per second this takes about 10^10 / 10^8 = 100 seconds.
Hence, this approach needs around 100 seconds in the worst case, far beyond the time limit.
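For reference, here is a minimal sketch of that brute-force idea: one BFS per query that stops as soon as it dequeues b. It assumes a 1-indexed `vector<vector<int>>` adjacency list, and the names `bfsDistance` and `exampleTree` are hypothetical. Since the original figure is not available, the example tree (edges 1-2, 1-3, 3-4, 3-5) is an assumption, chosen to be consistent with the paths and depths used later in this post.

```cpp
#include <cassert>
#include <queue>
#include <vector>
using namespace std;

// Brute force: BFS from a until b is dequeued. O(N) per query on a tree.
// adj is 1-indexed: adj[v] lists the neighbours of node v.
int bfsDistance(const vector<vector<int>>& adj, int a, int b) {
    assert(a >= 1 && a < (int)adj.size() && b >= 1 && b < (int)adj.size());
    vector<int> dist(adj.size(), -1);   // -1 means "not visited yet"
    queue<int> frontier;
    dist[a] = 0;
    frontier.push(a);
    while (!frontier.empty()) {
        int v = frontier.front();
        frontier.pop();
        if (v == b) return dist[v];     // number of edges on the path a -> b
        for (int next : adj[v]) {
            if (dist[next] == -1) {
                dist[next] = dist[v] + 1;
                frontier.push(next);
            }
        }
    }
    return -1;                          // unreachable; cannot happen in a tree
}

// Hypothetical example tree (assumed): edges 1-2, 1-3, 3-4, 3-5.
vector<vector<int>> exampleTree() {
    vector<vector<int>> adj(6);
    int edges[][2] = {{1, 2}, {1, 3}, {3, 4}, {3, 5}};
    for (auto& e : edges) {
        adj[e[0]].push_back(e[1]);
        adj[e[1]].push_back(e[0]);
    }
    return adj;
}
```

Here `bfsDistance(exampleTree(), 2, 5)` walks 2 -> 1 -> 3 -> 5 and returns 3. The sketch is handy as a slow but obviously correct checker for the faster solution.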
A better approach
So, to solve this problem efficiently, we need an approach that fits within the given constraints.
Observe that in a tree there is exactly one path between any two nodes a and b, and we need to find the length of that path.
Here, the path consists of the 4 nodes 2, 1, 3 and 5, so its length is 3 edges.
Observe that this path always passes through the LCA of the two nodes (node 1 here, with the tree rooted at node 1).
So, if we know the distance of each node from their LCA, the distance between the two nodes is simply the sum of those two distances.
Let's calculate the depth of all nodes of the tree.
The code to calculate the depth is:
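// Store the depth of every node in level[], starting from the node of the first call.
void dfs(int node, int par, vector<int>& level, int depth, vector<int> adj[]){
    level[node] = depth;
    for(auto child: adj[node]){
        if(child != par){   // skip only the edge back to the parent
            dfs(child, node, level, depth + 1, adj);
        }
    }
}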
Here, we shall have the depth of each node in the level array.
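To see concrete depths, here is a small self-contained sketch of the same depth DFS run on an assumed example tree (edges 1-2, 1-3, 3-4, 3-5, consistent with the numbers in this post). It uses `vector<vector<int>>` instead of the raw array `adj[]`, and the helper name `exampleLevels` is hypothetical. Note that the parent check compares `child` with `par`, so the recursion never walks back up.

```cpp
#include <cassert>
#include <vector>
using namespace std;

// Depth DFS: level[node] = number of edges from the root to node.
void dfs(int node, int par, vector<int>& level, int depth,
         const vector<vector<int>>& adj) {
    assert(node >= 1 && node < (int)adj.size());   // nodes are 1-indexed
    level[node] = depth;
    for (int child : adj[node]) {
        if (child != par) {                         // do not go back to the parent
            dfs(child, node, level, depth + 1, adj);
        }
    }
}

// Hypothetical example tree (assumed): edges 1-2, 1-3, 3-4, 3-5.
vector<int> exampleLevels() {
    vector<vector<int>> adj(6);
    int edges[][2] = {{1, 2}, {1, 3}, {3, 4}, {3, 5}};
    for (auto& e : edges) {
        adj[e[0]].push_back(e[1]);
        adj[e[1]].push_back(e[0]);
    }
    vector<int> level(adj.size(), 0);
    dfs(1, 0, level, 0, adj);   // root at 1; 0 is not a node, so it means "no parent"
    return level;               // level[0] is unused
}
```

For this tree, level[1..5] comes out as 0, 1, 1, 2, 2, so level[4] = level[5] = 2 and level[3] = 1.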
Now, suppose we need the distance between nodes 4 and 5, which is 2. Adding the depths of 4 and 5 gives the sum of their distances from the root. Both of these root paths pass through the LCA of 4 and 5, so the part from the root down to the LCA (of length level[LCA(4,5)]) is counted twice, once in each depth. Subtracting it twice leaves exactly the distance between 4 and 5.
So, the distance between the two nodes is level[4] + level[5] - 2 * level[LCA(4,5)] = 2 + 2 - 2 * 1 = 2.
Hence, the general formula to solve the problem is level[a] + level[b] - 2 * level[LCA(a,b)].
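Before speeding up the LCA, the formula can be checked with a deliberately simple LCA that just climbs parent pointers, O(depth) per query. This is a minimal sketch: the parent/level arrays describe the assumed example tree (edges 1-2, 1-3, 3-4, 3-5, rooted at 1), and the names `naiveLCA`, `distanceViaLCA`, `kParent` and `kLevel` are hypothetical.

```cpp
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

// Naive LCA: bring the deeper node up to the same depth, then climb both
// nodes together until they meet. O(depth) per call.
int naiveLCA(int a, int b, const vector<int>& parent, const vector<int>& level) {
    assert(a > 0 && b > 0);                      // nodes are 1-indexed
    if (level[a] < level[b]) swap(a, b);         // make a the deeper node
    while (level[a] > level[b]) a = parent[a];
    while (a != b) {
        a = parent[a];
        b = parent[b];
    }
    return a;
}

// distance(a, b) = level[a] + level[b] - 2 * level[LCA(a, b)]
int distanceViaLCA(int a, int b, const vector<int>& parent, const vector<int>& level) {
    return level[a] + level[b] - 2 * level[naiveLCA(a, b, parent, level)];
}

// Assumed example tree as arrays (index 0 unused; parent[1] = 0 marks the root).
const vector<int> kParent = {0, 0, 1, 1, 3, 3};
const vector<int> kLevel  = {0, 0, 1, 1, 2, 2};
```

With these arrays, LCA(4, 5) = 3 gives 2 + 2 - 2 * 1 = 2, and LCA(2, 5) = 1 gives 1 + 2 - 2 * 0 = 3, matching the path 2, 1, 3, 5. Binary lifting only replaces the slow climbing with O(log N) jumps; the formula stays the same.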
So, we need to calculate the LCA of the two nodes, which we have already learned to do. The code is:
const int MAXN = 200111;   // maximum number of nodes (plus some slack)
const int N = 20;          // 2^19 > 2*10^5, so 20 jump levels suffice
vector<int> adj[MAXN];
int up[MAXN][N], level[MAXN];
int q;
// Function to calculate the level and the direct parent up[node][0].
// Nodes are 1-indexed, so call it as dfs(1, 0, 0): 0 marks "no parent".
void dfs(int node, int par, int depth){
    level[node] = depth;
    up[node][0] = par;
    for(auto child: adj[node]){
        if(child != par){
            dfs(child, node, depth + 1);
        }
    }
}
// Preprocessing the up array: up[v][j] is the 2^j-th ancestor of v (0 past the root).
// j must be the OUTER loop so that up[*][j-1] is complete before it is used.
void preprocess(){
    for(int j = 1; j < N; j++){
        for(int i = 0; i < MAXN; i++){
            up[i][j] = up[up[i][j-1]][j-1];
        }
    }
}
int LCA(int a, int b){
    if(level[b] < level[a]) swap(a, b);   // make b the deeper node
    int d = level[b] - level[a];
    for(int i = 0; i < N; i++){           // lift b by d levels, one set bit at a time
        if(d & (1 << i)) b = up[b][i];
    }
    if(a == b) return a;
    for(int i = N - 1; i >= 0; i--){      // lift both while their ancestors differ
        if(up[a][i] != up[b][i]){
            a = up[a][i];
            b = up[b][i];
        }
    }
    return up[a][0];
}
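Binary lifting is easiest to trust once it is tested in isolation. Below is a compact, self-contained sketch of the same idea wrapped in a hypothetical `BinaryLifting` struct (not the author's code). Assumptions: 1-indexed nodes, root at node 1, node 0 as the sentinel "parent of the root", and an iterative DFS instead of recursion. The table is filled level by level (j in the outer loop), so every 2^(j-1)-th ancestor is known before it is used.

```cpp
#include <cassert>
#include <utility>
#include <vector>
using namespace std;

// Binary-lifting LCA sketch. Jumps past the root land on sentinel node 0,
// and up[0][j] stays 0, so no -1 checks are needed.
struct BinaryLifting {
    static constexpr int LOG = 20;      // 2^19 > 2 * 10^5
    vector<vector<int>> up;             // up[v][j] = 2^j-th ancestor of v, or 0
    vector<int> level;                  // depth of each node, root = 0

    BinaryLifting(int n, const vector<pair<int, int>>& edges)
        : up(n + 1, vector<int>(LOG, 0)), level(n + 1, 0) {
        vector<vector<int>> adj(n + 1);
        for (const auto& e : edges) {
            adj[e.first].push_back(e.second);
            adj[e.second].push_back(e.first);
        }
        // Iterative DFS from the root: fill level[] and direct parents up[v][0].
        vector<int> todo = {1};
        vector<bool> seen(n + 1, false);
        seen[1] = true;
        while (!todo.empty()) {
            int v = todo.back();
            todo.pop_back();
            for (int c : adj[v]) {
                if (!seen[c]) {
                    seen[c] = true;
                    up[c][0] = v;
                    level[c] = level[v] + 1;
                    todo.push_back(c);
                }
            }
        }
        // Column j-1 is complete for every node before column j is built.
        for (int j = 1; j < LOG; j++)
            for (int v = 0; v <= n; v++)
                up[v][j] = up[up[v][j - 1]][j - 1];
    }

    int lca(int a, int b) const {
        assert(a >= 1 && b >= 1 && a < (int)level.size() && b < (int)level.size());
        if (level[b] < level[a]) swap(a, b);    // b is the deeper node
        int d = level[b] - level[a];
        for (int j = 0; j < LOG; j++)           // lift b by exactly d levels
            if ((d >> j) & 1) b = up[b][j];
        if (a == b) return a;
        for (int j = LOG - 1; j >= 0; j--)      // lift both while ancestors differ
            if (up[a][j] != up[b][j]) {
                a = up[a][j];
                b = up[b][j];
            }
        return up[a][0];
    }

    int distance(int a, int b) const {
        return level[a] + level[b] - 2 * level[lca(a, b)];
    }
};
```

For example, on a path 1-2-...-10 rooted at 1, `up[10][3]` is the 8th ancestor of 10, which is node 2, and `distance(1, 10)` is 9.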
To process the Q queries, we just need to apply the formula level[a] + level[b] - 2 * level[LCA(a,b)]:
cin >> q;
while(q--){
    int a, b;
    cin >> a >> b;
    cout << level[a] + level[b] - 2 * level[LCA(a, b)] << "\n";
}
The complete solution code for the problem is:
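#include<bits/stdc++.h>
using namespace std;
const int MAXN = 200111;   // maximum number of nodes (plus some slack)
const int N = 20;          // 2^19 > 2*10^5, so 20 jump levels suffice
vector<int> adj[MAXN];
int up[MAXN][N], level[MAXN];
// Record the depth and the direct parent of every node.
// Nodes are 1-indexed, so 0 serves as the "parent" of the root.
void dfs(int node, int par, int lvl){
    level[node] = lvl;
    up[node][0] = par;
    for(auto child: adj[node]){
        if(child != par){
            dfs(child, node, lvl + 1);
        }
    }
}
void preprocess(){
    dfs(1, 0, 0);
    // up[v][j] = 2^j-th ancestor of v. j is the outer loop so that column j-1
    // is complete for every node before column j is built.
    // Jumps past the root land on 0, and up[0][*] stays 0.
    for(int j = 1; j < N; j++){
        for(int i = 0; i < MAXN; i++){
            up[i][j] = up[up[i][j-1]][j-1];
        }
    }
}
int LCA(int a, int b){
    if(level[b] < level[a]) swap(a, b);   // make b the deeper node
    int d = level[b] - level[a];
    for(int i = 0; i < N; i++){           // lift b by d levels
        if(d & (1 << i)) b = up[b][i];
    }
    if(a == b) return a;
    for(int i = N - 1; i >= 0; i--){      // lift both while their ancestors differ
        if(up[a][i] != up[b][i]){
            a = up[a][i];
            b = up[b][i];
        }
    }
    return up[a][0];
}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    int n, q;
    cin >> n >> q;
    for(int i = 2; i <= n; i++){          // a tree with n nodes has n-1 edges
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    preprocess();
    while(q--){
        int a, b;
        cin >> a >> b;
        cout << level[a] + level[b] - 2 * level[LCA(a, b)] << "\n";
    }
}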
Time Complexity:
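Preprocessing (the DFS plus filling the up table) takes O(N * log(N)), and each query takes O(log(N)) for the LCA plus O(1) for the formula, so the total time complexity is O((N + Q) * log(N)).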
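As a rough check, here are the same back-of-the-envelope numbers used for the BFS estimate (N and Q of the order of 10^5, about 10^8 simple operations per second, and log2(10^5) ≈ 17). This is an order-of-magnitude sketch, not a measurement:

$$
\underbrace{N \log_2 N}_{\text{filling } up} + \underbrace{Q \log_2 N}_{\text{queries}} \approx 10^5 \cdot 17 + 10^5 \cdot 17 = 3.4 \cdot 10^6 \ll 10^8
$$

By this rule of thumb that is roughly 0.03 seconds of work, compared with about 100 seconds for one BFS per query. The code fills the table for all MAXN entries with 20 jump levels (the constant N = 20), about 4 * 10^6 steps, which is still well within budget.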
That's all I have for today, see you soon with yet another problem :)
Link to the binary Search Video:
Written by
Ramandeep Singh
A programming enthusiast trying to share my knowledge with the community :D