E. Disrupting Communications

D06發表於2024-11-09
  • 注意可能出現dpx+1在模意義下為0的情況,此時需要額外維護0的個數而不能求逆元
  • 記f[x]表示x子樹內包含x的連通子圖的個數,g[x]表示全樹包含x的連通子圖的個數,由於子樹的限制,所有fx互斥 【子樹互斥模型】
  • 求出f[x]後換根DP求出g[x]。答案即為u-LCA(u,v)上f的和+g[LCA(u,v)]+v-LCA(u,v)上f的和,注意整個過程都和s[x]無關,顯然一個節點並不能和一種劃分一一對應
  • 【關於計數類的題目的迷思】樸素演算法很難實現,樣例很弱,手造樣例也很麻煩。這種情況下或許只能更多地把它當做一道數學題做
點選檢視程式碼
#include <bits/stdc++.h>
using namespace std;
const int mod=998244353;
vector<int>a[100005];
int d[100005],h[100005],z[100005],s[100005],fa[100005],zero[100005];
long long f[100005],g[100005],val[100005],sum[100005],inv[100005];
int power(int n,int p)
{
    if(p==0)
    {
        return 1;
    }
    long long tmp=power(n,p/2);
    if(p%2==0)
    {
        return tmp*tmp%mod;
    }
    return tmp*tmp%mod*n%mod;
}
void dfs1(int n1)
{
    s[n1]=1;
    z[n1]=0;
    f[n1]=val[n1]=1;
    zero[n1]=0;
    for(int i=0;i<a[n1].size();i++)
    {
        d[a[n1][i]]=d[n1]+1;
        dfs1(a[n1][i]);
        s[n1]+=s[a[n1][i]];
        f[n1]=f[n1]*(f[a[n1][i]]+1)%mod;
        if((f[a[n1][i]]+1)%mod)
        {
            val[n1]=val[n1]*(f[a[n1][i]]+1)%mod;
        }
        else
        {
            zero[n1]++;
        }
        if(s[a[n1][i]]>s[z[n1]])
        {
            z[n1]=a[n1][i];
        }
    }
    inv[n1]=power(f[n1]+1,998244351);
}
void dfs2(int n1)
{
    if(z[n1])
    {
        h[z[n1]]=h[n1];
        dfs2(z[n1]);
    }
    sum[n1]=(sum[z[n1]]+f[n1])%mod;
    for(int i=0;i<a[n1].size();i++)
    {
        if(a[n1][i]!=z[n1])
        {
            h[a[n1][i]]=a[n1][i];
            dfs2(a[n1][i]);
        }
    }
}
void dp(int n1,int fa)
{
    if(fa)
    {
        if((f[n1]+1)%mod)
        {
            g[n1]=f[n1]*(g[fa]*inv[n1]%mod+1)%mod;
            if((g[fa]*inv[n1]%mod+1)%mod)
            {
                val[n1]=val[n1]*(g[fa]*inv[n1]%mod+1)%mod;
            }
            else
            {
                zero[n1]++;
            }
        }
        else
        {
            if(zero[fa]>1)
            {
                g[n1]=0;
                zero[n1]++;
            }
            else
            {
                g[n1]=f[n1]*(val[fa]+1)%mod;
                if((val[fa]+1)%mod)
                {
                    val[n1]=val[n1]*(val[fa]+1)%mod;
                }
                else
                {
                    zero[n1]++;
                }
            }
        }
    }
    for(int i=0;i<a[n1].size();i++)
    {
        dp(a[n1][i],n1);
    }
}
int main()
{
    //freopen("example.in","r",stdin);
    ios::sync_with_stdio(false);
    cin.tie(NULL);
    int T;
    cin >> T;
    while(T--)
    {
        int n,q;
        cin >> n >> q;
        for(int i=1;i<=n;i++)
        {
            a[i].clear();
        }
        for(int i=2;i<=n;i++)
        {
            cin >> fa[i];
            a[fa[i]].push_back(i);
        }
        d[1]=1;
        dfs1(1);
        g[1]=f[1];
        dp(1,0);
        h[1]=1;
        dfs2(1);
        for(int i=1;i<=q;i++)
        {
            long long ans=0;
            int u,v,x;
            cin >> u >> v;
            while(h[u]!=h[v])
            {
                if(d[h[u]]<d[h[v]])
                {
                    swap(u,v);
                }
                ans=(ans+sum[h[u]]-sum[z[u]])%mod;
                u=fa[h[u]];
            }
            if(d[u]<d[v])
            {
                swap(u,v);
            }
            x=v;
            ans=(ans+sum[z[v]]-sum[z[u]])%mod;
            ans=(ans+g[x])%mod;
            cout<<(ans+mod)%mod<<"\n";
        }
    }
    return 0;
}

相關文章