HDU7458-啟發式合併最佳化DP

yoshinow2001發表於2024-07-29

link:https://acm.hdu.edu.cn/showproblem.php?pid=7458
題意:給一棵樹,每個點有點權 \(w\) 和顏色 \(c\),選擇若干條不相交的路徑,每條路徑的起始點顏色相同,權值為起始點的權值之和,最大化權值之和。


對每條路徑 \((u,v)\) 可以放到LCA上考慮,即我們對每個子樹考慮,設 \(f(i,0/1)\) 分別表示在 \(i\) 的子樹內,強制選/不選 \(i\) 號點,在子樹內能獲得的最大收益,\(g(i)=\max(f(i,0),f(i,1))\),記 \(S_u\) 表示 \(u\) 的所有子節點好了,那麼:

\[f(x,0)=\sum_{v\in S_x}g(v) \]

\(f(x,1)\)有幾種情情況:從某兩個不同的子節點中的某兩個同色點連上來的,或者是直接從 \(x\) 作為一個端點連到某個孩子節點的,第一種情況是:

算答案的時候剛好多減去一個 \(g\),所以我們直接對每個子樹中每個顏色,維護 \(w_{u1}+\sum (f(u_i,0)-g(u_i))\) 的最大值,因為對每個顏色只關心最大值,可以用一個 map (甚至可以是unordered的)維護,每跳一層就給這個子樹做一個全域性加 \(f(u,0)-g(u)\) 的操作,用樹上啟發式合併的辦法,同時維護一個加法標記即可。

對於直接從 \(x\) 連下去的情況類似

#include<bits/stdc++.h>
#define rep(i,a,b) for(int i=(a);i<=(b);i++)
#define endl '\n'
#define fastio ios_base::sync_with_stdio(false);cin.tie(0);cout.tie(0)
using namespace std;
typedef long long ll;
const int N=2e5+5;
int n,c[N];
ll f[N][2],g[N],w[N];
vector<vector<int>> G;

int bl[N];
ll tag[N];
map<int,ll> mp[N];
void upd(ll &x,ll y){x=max(x,y);}
void dfs(int x,int fa){
    f[x][0]=f[x][1]=g[x]=0;
    for(auto to:G[x])if(to!=fa){
        dfs(to,x);
        f[x][0]+=g[to];
    }

    for(auto v:G[x])if(v!=fa){
        if(mp[bl[v]].count(c[x]))upd(f[x][1],(mp[bl[v]][c[x]]+tag[bl[v]])+f[x][0]+w[x]);
        //merge v to x;
        if(mp[bl[v]].size()>mp[bl[x]].size())swap(bl[v],bl[x]);

        //calc
        for(auto [col,val]:mp[bl[v]]){
            if(mp[bl[x]].count(col))
                upd(f[x][1],(val+tag[bl[v]])+(mp[bl[x]][col]+tag[bl[x]])+f[x][0]);
        }

        //merge
        for(auto [c,val]:mp[bl[v]]){
            if(mp[bl[x]].count(c))upd(mp[bl[x]][c],val+tag[bl[v]]-tag[bl[x]]);
            else mp[bl[x]][c]=val+tag[bl[v]]-tag[bl[x]];
        }
    }
    if(mp[bl[x]].count(c[x]))upd(mp[bl[x]][c[x]],w[x]-tag[bl[x]]);
    else mp[bl[x]][c[x]]=w[x]-tag[bl[x]];

    g[x]=max(f[x][0],f[x][1]);

    tag[bl[x]]+=f[x][0]-g[x];
}

void solve(){
    cin>>n;
    rep(i,1,n)cin>>c[i];
    rep(i,1,n)cin>>w[i];
    G=vector<vector<int>> (n+1);
    rep(i,1,n)bl[i]=i,tag[i]=0,mp[i].clear();
    rep(i,1,n-1){
        int u,v;
        cin>>u>>v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs(1,-1);
    // rep(i,1,n)cout<<f[i][0]<<' '<<f[i][1]<<' '<<g[i]<<endl;
    cout<<g[1]<<endl;
}
int main(){
    fastio;
    int tc;cin>>tc;
    while(tc--)solve();
    return 0;
}

相關文章