題意簡述
給定一個\(N\)個節點的樹,接下來有\(q\)次詢問。每次詢問給定\(a,b,c\),請問存在多少個節點\(i\),使得這棵樹在以\(i\)為根的情況下,\(a\)和\(b\)的LCA是\(c\)。
解題思路
首先透過分析樣例,我們發現:\(a,b\)的LCA一定在它們之間的簡單路徑上,所以如果\(c\)不在\(a,b\)之間的簡單路徑上,則輸出\(0\)。
進一步分析\(c\)在\(a,b\)間簡單路徑上的情況,我們可以歸納出三種情況:
- 如果\(c=lca(a,b)\),那麼答案是\(n-siz[jump(a,c)]-siz[jump(b,c)]\),其中\(siz[u]\)表示以\(u\)為根的子數大小,\(jump(u,v)\)表示\(u\)點向上跳到\(v\)的下一層所到達的節點。
- 如果\(c\)在\(a\)到\(lca(a,b)\)的路徑上,那麼答案是\(siz[c]-siz[jump(a,c)]\)。
- 如果\(c\)在\(b\)到\(lca(a,b)\)的路徑上,那麼答案是\(siz[c]-siz[jump(b,c)]\)。
接下來我們思考程式碼實現。怎樣知道\(c\)是哪一種情況呢?
- 第一種情況:\(c=lca(a,b)\)。
- 第二種情況:\(lca(a,c)=c\)且\(lca(a,b)=lca(b,c)\)。
- 第三種情況:\(lca(b,c)=c\)且\(lca(a,b)=lca(a,c)\)。
程式碼實現中,\(jump()\)函式我們可以透過倍增的思想在\(O(\log N)\)的時間複雜度內完成。而求\(lca()\)函式同樣可以用倍增達到\(O( \log N)\)的時間複雜度。
點選檢視程式碼
#include<bits/stdc++.h>
#define N 500010
using namespace std;
int n,q,dep[N],fa[N][20],siz[N];
vector<int> G[N];
void dfs(int u,int father){
siz[u]=1;
dep[u]=dep[father]+1;
fa[u][0]=father;
for(int i=1;i<20;i++)
fa[u][i]=fa[fa[u][i-1]][i-1];
for(int i:G[u])
if(i!=father) dfs(i,u),siz[u]+=siz[i];
}
int lca(int u,int v){
if(dep[u]<dep[v]) swap(u,v);
for(int i=19;i>=0;i--)
if(dep[fa[u][i]]>=dep[v])
u=fa[u][i];
if(u==v) return v;
for(int i=19;i>=0;i--)
if(fa[u][i]!=fa[v][i])
u=fa[u][i],v=fa[v][i];
return fa[u][0];
}
int jump(int a,int b){
if(dep[a]==dep[b]) return 0;
//計算a跳到b下一層的位置
for(int i=19;i>=0;i--)
if(dep[fa[a][i]]>dep[b])
a=fa[a][i];
return a;
}
int main(){
cin>>n>>q;
for(int i=1;i<n;i++){
int u,v;
cin>>u>>v;
G[u].emplace_back(v);
G[v].emplace_back(u);
}
dfs(1,0);
while(q--){
int a,b,c;
cin>>a>>b>>c;
int ab=lca(a,b),ac=lca(a,c),bc=lca(b,c);
if(ab==c){
cout<<n-siz[jump(a,c)]-siz[jump(b,c)]<<"\n";
}else if(ac==c&&ab==bc){
cout<<siz[c]-siz[jump(a,c)]<<"\n";
}else if(bc==c&&ab==ac){
cout<<siz[c]-siz[jump(b,c)]<<"\n";
}else{
cout<<"0\n";
}
}
return 0;
}