思路:
首先可以先考慮沒有換根的情況。
先將樹拍到 dfn 序上,那麼一個子樹 \(u\) 的所有點的 dfn 序區間為 \([dfn_u,dfn_u+siz_u-1]\)。
那麼詢問變為:
-
每次給定兩個區間 \([l_1,r_1],[l_2,r_2]\),對於在第一個區間內的點 \(x\) 和在第二個區間的點 \(y\),若 \((x,y)\) 有貢獻,當且僅當 \(w_x=w_y\)。
-
詢問有貢獻的點對數量。
即 P5268 [SNOI2017] 一個簡單的詢問。
設 \(F(l_1,r_1,l_2,r_2)\) 表示 \([l_1,r_1]\) 與 \([l_2,r_2]\) 的貢獻,那麼:
那麼一個詢問就都轉化為了四個 \(F(1,x,1,y)\) 的形式,考慮如何求 \(F(1,x,1,y)\),先欽定 \(x \le y\),那麼考慮莫隊:
-
設當前 \(p_{1,x},p_{2,x}\) 分別表示兩個區間 \(x\) 的出現次數。
-
若 \(x \gets x+1\) 時,貢獻會增加 \(p_{2,a_{x+1}}\)。
-
若 \(x \gets x-1\) 時,貢獻會減少 \(p_{2,a_x}\)。
-
若 \(y \gets y+1\) 時,貢獻會增加 \(p_{1,a_{y+1}}\)。
-
若 \(y \gets y-1\) 時,貢獻會減少 \(p_{1,a_y}\)。
現在再考慮換根操作,若當前以 \(rt\) 為根:
-
若 \(rt\) 不在初始以 \(1\) 為根時 \(x\) 的子樹內,則不好造成影響。
-
否則 \(x\) 子樹內的點即為除了\((x \to rt)\) 路徑上最接近 \(x\) 的點 \(y\) 子樹內的點的全部點。
因為 \(x\) 在原始樹上始終是 \(rt\) 的父親,則 \(y\) 是 \(rt\) 的 \(dep_{rt}-dep_{x}-1\) 級祖先,直接倍增即可。
時間複雜度為 \(O(N\sqrt{M}+M \log N+M)\)。
完整程式碼:
#include<bits/stdc++.h>
#define Add(x,y) (x+y>=mod)?(x+y-mod):(x+y)
#define lowbit(x) x&(-x)
#define pi pair<ll,ll>
#define pii pair<ll,pair<ll,ll>>
#define iip pair<pair<ll,ll>,ll>
#define ppii pair<pair<ll,ll>,pair<ll,ll>>
#define fi first
#define se second
#define full(l,r,x) for(auto it=l;it!=r;it++) (*it)=x
#define Full(a) memset(a,0,sizeof(a))
#define open(s1,s2) freopen(s1,"r",stdin),freopen(s2,"w",stdout);
#define For(i,l,r) for(int i=l;i<=r;i++)
#define _For(i,l,r) for(int i=r;i>=l;i--)
using namespace std;
typedef double db;
typedef unsigned long long ull;
typedef long long ll;
bool Begin;
const ll N=1e5+10,M=4e6+10,K=17;
inline ll read(){
ll x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){
if(c=='-')
f=-1;
c=getchar();
}
while(c>='0'&&c<='9'){
x=(x<<1)+(x<<3)+(c^48);
c=getchar();
}
return x*f;
}
inline void write(ll x){
if(x<0){
putchar('-');
x=-x;
}
if(x>9)
write(x/10);
putchar(x%10+'0');
}
ll op,n,m,t,q,u,v,rt,sum,l1,r1,l2,r2,l,r,cnt;
ll A[N],a[N],b[N],w[N],d[N],siz[N],dfn[N],p1[N],p2[N],ans[M];
ll F[N][K];
vector<pi> X,Y;
vector<ll> E[N];
struct Ques{
ll x,y;
ll id;
ll v;
inline bool operator<(const Ques &rhs)const{
if(A[x]^A[rhs.x])
return A[x]<A[rhs.x];
return y>rhs.y;
}
}Q[M];
inline void add(ll u,ll v){
E[u].push_back(v);
E[v].push_back(u);
}
inline void dfs(ll u,ll fa){
For(i,1,K-1)
F[u][i]=F[F[u][i-1]][i-1];
dfn[u]=++cnt;
w[cnt]=a[u];
siz[u]=1;
for(auto v:E[u]){
if(v==fa)
continue;
F[v][0]=u;
d[v]=d[u]+1;
dfs(v,u);
siz[u]+=siz[v];
}
}
inline ll get_fa(ll u,ll k){
_For(i,0,K-1){
if((k>>i)&1ll){
k-=(1ll<<i);
u=F[u][i];
}
}
return u;
}
inline vector<pi> get(ll x){
vector<pi> ans;
if(x==rt)
ans.push_back({1,n});
else if(dfn[x]<=dfn[rt]&&dfn[rt]<=dfn[x]+siz[x]-1){
ll y=get_fa(rt,d[rt]-d[x]-1);
if(dfn[y]!=1)
ans.push_back({1,dfn[y]-1});
if(dfn[y]+siz[y]<=n)
ans.push_back({dfn[y]+siz[y],n});
}
else
ans.push_back({dfn[x],dfn[x]+siz[x]-1});
return ans;
}
inline void get(ll l1,ll r1,ll l2,ll r2){
Q[++q]={r1,r2,cnt,1};
if(l1-1)
Q[++q]={l1-1,r2,cnt,-1};
if(l2-1)
Q[++q]={r1,l2-1,cnt,-1};
if(l1-1&&l2-1)
Q[++q]={l1-1,l2-1,cnt,1};
}
inline void insert1(ll x){
sum+=p2[w[x]];
p1[w[x]]++;
}
inline void insert2(ll x){
sum+=p1[w[x]];
p2[w[x]]++;
}
inline void del1(ll x){
sum-=p2[w[x]];
p1[w[x]]--;
}
inline void del2(ll x){
sum-=p1[w[x]];
p2[w[x]]--;
}
bool End;
int main(){
n=read(),m=read();
For(i,1,n){
a[i]=read();
b[++cnt]=a[i];
}
sort(b+1,b+cnt+1);
cnt=unique(b+1,b+cnt+1)-(b+1);
For(i,1,n)
a[i]=lower_bound(b+1,b+cnt+1,a[i])-b;
cnt=0;
For(i,1,n-1){
u=read(),v=read();
add(u,v);
}
dfs(1,1);
cnt=0;
For(i,1,m){
op=read(),u=read();
if(op==1){
rt=u;
continue;
}
++cnt;
v=read();
X=get(u);
Y=get(v);
for(auto x:X)
for(auto y:Y)
get(x.fi,x.se,y.fi,y.se);
}
t=max(n/max((ll)sqrt(m),1ll),1ll);
For(i,1,n)
A[i]=(i-1)/t+1;
For(i,1,q)
if(Q[i].x>Q[i].y)
swap(Q[i].x,Q[i].y);
sort(Q+1,Q+q+1);
For(i,1,q){
while(l<Q[i].x)
insert1(++l);
while(l>Q[i].x)
del1(l--);
while(r<Q[i].y)
insert2(++r);
while(r>Q[i].y)
del2(r--);
ans[Q[i].id]+=sum*Q[i].v;
}
For(i,1,cnt){
write(ans[i]);
putchar('\n');
}
//cerr<<'\n'<<abs(&Begin-&End)/1048576<<"MB";
return 0;
}