線段樹最佳化 DP & CF833B The Bakery 題解

FlyPancake發表於2024-08-18

線段樹最佳化 DP & CF833B The Bakery 題解

題目大意:

將一個長度為 \(n\) 的序列分為 \(m\) 段,使得總價值最大。

一段區間的價值表示為區間內不同數字的個數。

\(n \le 35000, m \le 50\)

(雖然原題是 \(k\),但是我程式碼中寫的是 \(m\),所以就改成 \(m\) 了)


首先看到劃分割槽間,算總價值最大的題,可以先考慮樸素的區間 DP。

\(dp[i][j]\) 表示前 \(j\) 個數劃分為 \(i\) 段的最大總價值。\(val(l, r)\) 表示區間 \([l, r]\) 的價值,即其中有多少個不同的數。

可得轉移方程:

\[dp[i][j] = \max_{i-1 \le k \le j-1} (dp[i-1][k]+val(k+1, j)) \]

\({i-1 \le k \le j-1}\) 的原因是最少 \(i-1\) 個數字,最多 \(j-1\) 個數字可劃分為 \(i-1\) 段。

我們發現對於每次 \(dp[i][j]\) 轉移,只需用到 \(dp[i-1][j]\),所以 \(dp\) 陣列可以滾動掉第一維,但是沒用。

時間複雜度為 \(O(n^3k)\),顯然超時。

考慮最佳化:

最佳化 1:發現每次計算 \(val(l, r)\) 需要 \(O(n)\) 的時間複雜度,其中每個數都會經過很多次重複計算。所以我們反過來考慮,對於每個數,它對於 \(val(l, r)\) 的貢獻在哪裡。

記這個數 \(a[i]\) 前面的第一個與它相等的數的位置為 \(pre[a[i]]\)
那麼這個 \(a[i]\) 對於區間 \([pre[a[i]], i-1]\) 均有 \(1\) 的貢獻。

最佳化 2:發現轉移方程裡有 \(\max(\dots)\),並且沒有其它量(其實有已經確定的量也行),所以考慮用線段樹來最佳化。

因為每次由第 \(i-1\) 層轉移到第 \(i\) 層,所以我們順序 DP,先用上一次的 \(i-1\) 的 DP 值建樹。

然後每掃過一個數 \(a[j]\),就線上段樹上的區間 \([pre[a[j]], j-1]\) 全部加上 \(1\)

每次更新,就線上段樹上 \([i-1, j-1]\) 的區間找 \(\max\) 即可。

總時間複雜度為 \(O(nk\log n)\)


樸素 DP 程式碼:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 105;

int a[N];
int dp[N][N]; // dp[i][j] 表示前 j 個數劃分為 i 個連續區間的最大總價值
bool tong[N];

int calc(int l, int r){
    memset(tong, 0, sizeof(tong));
    int res = 0;
    for(int i=l; i<=r; i++){
        if(tong[a[i]]) continue;
        tong[a[i]] = 1;
        res++;
    }
    return res;
}

int main(){
    ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
    int n, t; cin>>n>>t;
    for(int i=1; i<=n; i++)
        cin>>a[i];
    for(int i=1; i<=t; i++){
        for(int j=1; j<=n; j++){
            for(int k=i-1; k<=j-1; k++){
                dp[i][j] = max(dp[i][j], dp[i-1][k]+calc(k+1, j));
            }
        }
    }
    cout<<dp[t][n];
    return 0;
}

正解程式碼:

#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 35005;

int a[N], pre[N], p[N];
int dp[55][N]; // dp[i][j] 表示前 j 個數劃分為 i 個連續區間的最大總價值
struct node{
    int l, r;
    int maxn, tag;
    #define ls (x<<1)
    #define rs (x<<1|1)
}tr[N<<2];

void pushup(int x){
    tr[x].maxn = max(tr[ls].maxn, tr[rs].maxn);
}

void pushdown(int x){
    if(!tr[x].tag) return;
    tr[ls].maxn += tr[x].tag;
    tr[rs].maxn += tr[x].tag;
    tr[ls].tag += tr[x].tag;
    tr[rs].tag += tr[x].tag;
    tr[x].tag = 0;
}

void build(int x, int l, int r, int k){
    tr[x].l = l, tr[x].r = r, tr[x].maxn = 0; tr[x].tag = 0;
    if(l == r){
        tr[x].maxn = dp[k][l];
        return;
    }
    int mid = (l+r)>>1;
    build(ls, l, mid, k);
    build(rs, mid+1, r, k);
    pushup(x);
}

void update(int x, int l, int r, int v){
    if(tr[x].l>=l && tr[x].r<=r){
        tr[x].maxn += v;
        tr[x].tag += v;
        return;
    }
    int mid = (tr[x].l+tr[x].r)>>1;
    pushdown(x);
    if(l<=mid) update(ls, l, r, v);
    if(r>mid) update(rs, l, r, v);
    pushup(x);
}

後記:不知道什麼時候能花個時間補一下 NOIP2023 T4。

相關文章