2020CCPC長春站部分題解

敲程式碼的歐文發表於2020-11-09

F題

dsu on tree維護一個陣列t[][][]。t[i][j][k]表示a[u]=i且u的第j位是k的u的個數。

這個東西沒辦法直接維護的,但是對於j,你沒必要知道i是什麼,假如j的第k位是0,那麼你需要知道第k位是1的i的個數即可。

因此把i直接拆成20位,就可以統計答案了。

我個人理解dsu on tree對於處理子樹間的貢獻和子樹內的貢獻這兩種不同的題型有兩種不同的寫法,需要特別注意。

#include<bits/stdc++.h>
#define pb push_back
#define fi first
#define se second
#define sz(x)  (int)x.size()
#define cl(x)  x.clear()
#define all(x)  x.begin() , x.end()
#define rep(i , x , n)  for(int i = x ; i <= n ; i ++)
#define per(i , n , x)  for(int i = n ; i >= x ; i --)
#define mem0(x)  memset(x , 0 , sizeof(x))
#define mem_1(x)  memset(x , -1 , sizeof(x))
#define mem_inf(x)  memset(x , 0x3f , sizeof(x))
#define debug(x)  cerr << #x << " = " << x << '\n'
#define ddebug(x , y)  cerr << #x << " = " << x << "   " << #y << " = " << y << '\n'
#define ios std::ios::sync_with_stdio(false) , cin.tie(0)
using namespace std ;
typedef long long ll ;
typedef long double ld ;
typedef pair<int , int> pii ;
typedef pair<ll , ll> pll ;
typedef double db ;
const int mod = 998244353 ;
const int maxn = 1e5 + 10 ;
const int maxm = 1e6 + 1e5 + 10 ;
const int inf = 0x3f3f3f3f ;
const double eps = 1e-6 ; 
int n , a[maxn] ;
int c[maxn] ;
vector<int> g[maxn] ;
ll ans = 0 ;
struct Dsu_on_tree
{
    int siz[maxn] , son[maxn] ;
    int flag ;
    int t[maxm][20][2] ; //t[i][j][k]表示a[u]=i且u的第j位是k的u的個數
    void init() 
    {
        flag = 0 ;
        memset(siz , 0 , sizeof(siz)) ;
        memset(son , 0 , sizeof(son)) ;
        memset(t , 0 , sizeof(t)) ;   
        rep(i , 0 , 19)  c[i] = (1 << i) ;
    }
    void dfs1(int f , int u)
    {
        siz[u] = 1 ;
        for(auto v : g[u])
        {
            if(v == f) continue ;
            dfs1(u , v) ; 
            siz[u] += siz[v] ;
            if(siz[v] > siz[son[u]]) son[u] = v ;
        }
    }
    void add(int u , int x)
    {
        int tmp = a[u] ;
        rep(j , 0 , 19)  t[tmp][j][u % 2] += x , u /= 2 ;
    }
    void dfs3(int fa , int u , int lca)
    {
        int s = (a[u] ^ a[lca]) ;
        int tmp = u ;
        rep(j , 0 , 19)  ans += 1ll * t[s][j][1 - (tmp % 2)] * c[j] , tmp /= 2 ;
        for(auto v : g[u])
        {
            if(v == fa)  continue ;
            dfs3(u , v , lca) ;
        }
    }
    void dfs4(int fa , int u , int x)
    {
        add(u , x) ;
        for(auto v : g[u])
        {
            if(v == fa)  continue ;
            dfs4(u , v , x) ;
        }
    }
    void calc(int f , int u , int x)
    {
        for(auto v : g[u])
        {
            if(v == f || v == flag) continue ;
            if(x == 1)  dfs3(u , v , u) ;
            dfs4(u , v , x) ;
        }
        add(u , x) ;
    }
    void dfs2(int f , int u , int keep)
    {
        for(auto v : g[u])
        {
            if(v == f || v == son[u]) continue ;
            dfs2(u , v , 0) ;
        }
        if(son[u])  dfs2(u , son[u] , 1) , flag = son[u] ;
        calc(f , u , 1) ;
        if(son[u]) flag = 0 ;
        if(!keep)  calc(f , u , -1) ;
    }
} dsu_on_tree ;
int main()
{
    ios ;
    cin >> n ;
    rep(i , 1 , n)  cin >> a[i] ;
    rep(i , 1 , n - 1)
    {
        int u , v ;
        cin >> u >> v ;
        g[u].pb(v) , g[v].pb(u) ;
    }
    dsu_on_tree.init() ;
    dsu_on_tree.dfs1(1 , 1) ;
    dsu_on_tree.dfs2(1 , 1 , 0) ;
    cout << ans << '\n' ;
    return 0 ;
}

 

相關文章