- 注意可能出現dpx+1在模意義下為0的情況,此時需要額外維護0的個數而不能求逆元
- 記f[x]表示x子樹內包含x的連通子圖的個數,g[x]表示全樹包含x的連通子圖的個數,由於子樹的限制,所有fx互斥 【子樹互斥模型】
- 求出f[x]後換根DP求出g[x]。答案即為u-LCA(u,v)上f的和+g[LCA(u,v)]+v-LCA(u,v)上f的和,注意整個過程都和s[x]無關,顯然一個節點並不能和一種劃分一一對應
- 【關於計數類的題目的迷思】樸素演算法很難實現,樣例很弱,手造樣例也很麻煩。這種情況下或許只能更多地把它當做一道數學題做
點選檢視程式碼
#include <bits/stdc++.h>
using namespace std;
const int mod=998244353;
vector<int>a[100005];
int d[100005],h[100005],z[100005],s[100005],fa[100005],zero[100005];
long long f[100005],g[100005],val[100005],sum[100005],inv[100005];
int power(int n,int p)
{
if(p==0)
{
return 1;
}
long long tmp=power(n,p/2);
if(p%2==0)
{
return tmp*tmp%mod;
}
return tmp*tmp%mod*n%mod;
}
void dfs1(int n1)
{
s[n1]=1;
z[n1]=0;
f[n1]=val[n1]=1;
zero[n1]=0;
for(int i=0;i<a[n1].size();i++)
{
d[a[n1][i]]=d[n1]+1;
dfs1(a[n1][i]);
s[n1]+=s[a[n1][i]];
f[n1]=f[n1]*(f[a[n1][i]]+1)%mod;
if((f[a[n1][i]]+1)%mod)
{
val[n1]=val[n1]*(f[a[n1][i]]+1)%mod;
}
else
{
zero[n1]++;
}
if(s[a[n1][i]]>s[z[n1]])
{
z[n1]=a[n1][i];
}
}
inv[n1]=power(f[n1]+1,998244351);
}
void dfs2(int n1)
{
if(z[n1])
{
h[z[n1]]=h[n1];
dfs2(z[n1]);
}
sum[n1]=(sum[z[n1]]+f[n1])%mod;
for(int i=0;i<a[n1].size();i++)
{
if(a[n1][i]!=z[n1])
{
h[a[n1][i]]=a[n1][i];
dfs2(a[n1][i]);
}
}
}
void dp(int n1,int fa)
{
if(fa)
{
if((f[n1]+1)%mod)
{
g[n1]=f[n1]*(g[fa]*inv[n1]%mod+1)%mod;
if((g[fa]*inv[n1]%mod+1)%mod)
{
val[n1]=val[n1]*(g[fa]*inv[n1]%mod+1)%mod;
}
else
{
zero[n1]++;
}
}
else
{
if(zero[fa]>1)
{
g[n1]=0;
zero[n1]++;
}
else
{
g[n1]=f[n1]*(val[fa]+1)%mod;
if((val[fa]+1)%mod)
{
val[n1]=val[n1]*(val[fa]+1)%mod;
}
else
{
zero[n1]++;
}
}
}
}
for(int i=0;i<a[n1].size();i++)
{
dp(a[n1][i],n1);
}
}
int main()
{
//freopen("example.in","r",stdin);
ios::sync_with_stdio(false);
cin.tie(NULL);
int T;
cin >> T;
while(T--)
{
int n,q;
cin >> n >> q;
for(int i=1;i<=n;i++)
{
a[i].clear();
}
for(int i=2;i<=n;i++)
{
cin >> fa[i];
a[fa[i]].push_back(i);
}
d[1]=1;
dfs1(1);
g[1]=f[1];
dp(1,0);
h[1]=1;
dfs2(1);
for(int i=1;i<=q;i++)
{
long long ans=0;
int u,v,x;
cin >> u >> v;
while(h[u]!=h[v])
{
if(d[h[u]]<d[h[v]])
{
swap(u,v);
}
ans=(ans+sum[h[u]]-sum[z[u]])%mod;
u=fa[h[u]];
}
if(d[u]<d[v])
{
swap(u,v);
}
x=v;
ans=(ans+sum[z[v]]-sum[z[u]])%mod;
ans=(ans+g[x])%mod;
cout<<(ans+mod)%mod<<"\n";
}
}
return 0;
}