線段樹最佳化 DP & CF833B The Bakery 題解
題目大意:
將一個長度為 \(n\) 的序列分為 \(m\) 段,使得總價值最大。
一段區間的價值表示為區間內不同數字的個數。
\(n \le 35000, m \le 50\)
(雖然原題是 \(k\),但是我程式碼中寫的是 \(m\),所以就改成 \(m\) 了)
首先看到劃分割槽間,算總價值最大的題,可以先考慮樸素的區間 DP。
設 \(dp[i][j]\) 表示前 \(j\) 個數劃分為 \(i\) 段的最大總價值。\(val(l, r)\) 表示區間 \([l, r]\) 的價值,即其中有多少個不同的數。
可得轉移方程:
\({i-1 \le k \le j-1}\) 的原因是最少 \(i-1\) 個數字,最多 \(j-1\) 個數字可劃分為 \(i-1\) 段。
我們發現對於每次 \(dp[i][j]\) 轉移,只需用到 \(dp[i-1][j]\),所以 \(dp\) 陣列可以滾動掉第一維,但是沒用。
時間複雜度為 \(O(n^3k)\),顯然超時。
考慮最佳化:
最佳化 1:發現每次計算 \(val(l, r)\) 需要 \(O(n)\) 的時間複雜度,其中每個數都會經過很多次重複計算。所以我們反過來考慮,對於每個數,它對於 \(val(l, r)\) 的貢獻在哪裡。
記這個數 \(a[i]\) 前面的第一個與它相等的數的位置為 \(pre[a[i]]\)。
那麼這個 \(a[i]\) 對於區間 \([pre[a[i]], i-1]\) 均有 \(1\) 的貢獻。
最佳化 2:發現轉移方程裡有 \(\max(\dots)\),並且沒有其它量(其實有已經確定的量也行),所以考慮用線段樹來最佳化。
因為每次由第 \(i-1\) 層轉移到第 \(i\) 層,所以我們順序 DP,先用上一次的 \(i-1\) 的 DP 值建樹。
然後每掃過一個數 \(a[j]\),就線上段樹上的區間 \([pre[a[j]], j-1]\) 全部加上 \(1\)。
每次更新,就線上段樹上 \([i-1, j-1]\) 的區間找 \(\max\) 即可。
總時間複雜度為 \(O(nk\log n)\)。
樸素 DP 程式碼:
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 105;
int a[N];
int dp[N][N]; // dp[i][j] 表示前 j 個數劃分為 i 個連續區間的最大總價值
bool tong[N];
int calc(int l, int r){
memset(tong, 0, sizeof(tong));
int res = 0;
for(int i=l; i<=r; i++){
if(tong[a[i]]) continue;
tong[a[i]] = 1;
res++;
}
return res;
}
int main(){
ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
int n, t; cin>>n>>t;
for(int i=1; i<=n; i++)
cin>>a[i];
for(int i=1; i<=t; i++){
for(int j=1; j<=n; j++){
for(int k=i-1; k<=j-1; k++){
dp[i][j] = max(dp[i][j], dp[i-1][k]+calc(k+1, j));
}
}
}
cout<<dp[t][n];
return 0;
}
正解程式碼:
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 35005;
int a[N], pre[N], p[N];
int dp[55][N]; // dp[i][j] 表示前 j 個數劃分為 i 個連續區間的最大總價值
struct node{
int l, r;
int maxn, tag;
#define ls (x<<1)
#define rs (x<<1|1)
}tr[N<<2];
void pushup(int x){
tr[x].maxn = max(tr[ls].maxn, tr[rs].maxn);
}
void pushdown(int x){
if(!tr[x].tag) return;
tr[ls].maxn += tr[x].tag;
tr[rs].maxn += tr[x].tag;
tr[ls].tag += tr[x].tag;
tr[rs].tag += tr[x].tag;
tr[x].tag = 0;
}
void build(int x, int l, int r, int k){
tr[x].l = l, tr[x].r = r, tr[x].maxn = 0; tr[x].tag = 0;
if(l == r){
tr[x].maxn = dp[k][l];
return;
}
int mid = (l+r)>>1;
build(ls, l, mid, k);
build(rs, mid+1, r, k);
pushup(x);
}
void update(int x, int l, int r, int v){
if(tr[x].l>=l && tr[x].r<=r){
tr[x].maxn += v;
tr[x].tag += v;
return;
}
int mid = (tr[x].l+tr[x].r)>>1;
pushdown(x);
if(l<=mid) update(ls, l, r, v);
if(r>mid) update(rs, l, r, v);
pushup(x);
}
後記:不知道什麼時候能花個時間補一下 NOIP2023 T4。