【DP】區間DP入門

HinanawiTenshi發表於2021-02-15

在開始之前我要感謝y總,是他精彩的講解才讓我對區間DP有較深的認識。

簡介

一般是線性結構上的對區間進行求解最值,計數的動態規劃。大致思路是列舉斷點,然後對斷點兩邊求取最優解,然後進行合併從而得解。

原理

結合模板題(合併石子)講述:https://www.acwing.com/problem/content/284/

因為題目具有合併相鄰物品的性質,所以在合併的過程中,必然會在最後一步出現兩個物品合二為一的情況,而這兩個物品則是分別由左側的物品、右側的物品合併而來的。 因此,我們的思路是列舉最後一步合併兩個物品時候的斷點(記為 \(k\) ),為了方便起見,我們可以將斷點放在某個物品上面。

結合樣例具體來說:

k
1 3 5 2
  k
1 3 5 2
    k
1 3 5 2
    k
1 3 5 2

上面便是四個斷點。


對於本題,我們記f[l][r]為合併 \([l,r]\) 的物品所能得到的最小貢獻。
而斷點將 \([l,r]\) 分為了 \([l,k],[k+1,r]\) ,這兩個區間的貢獻分別是 f[l][k],f[k+1][r] 而合併這兩個區間的貢獻則是 sum(l,r)) (其中sum(l,r) 表示 \([l,r]\) 的物品的權值和)

從而得到遞推方程式: f[l][r] = min(f[l][r],f[l][k]+f[k+1][r]+sum(l,r))

可以看出,在列舉斷點的過程中,我們已經覆蓋了所有情況(根據斷點所有可能位置分類),因此這樣做能夠保證得到答案。

至此,在思維上不會有太大困難。

下面講一下怎麼用遞推的方法求解:

由本題的邏輯結構可知,我們要先處理出小區間的 \(f值\) 才能夠保證大區間可以得到更新,所以我們第一重迴圈列舉的是區間的長度len,下面的部分則是列舉起點(即 l), 結合長度我們可以得到 r = l+len-1 ,進而我們得到了相應的區間 \([l,r]\) ,接下來列舉斷點 \(k\) 即可。

結合程式碼理解:

#include<bits/stdc++.h>
using namespace std;

const int INF=0x3f3f3f3f;
const int N=305;

int f[N][N];
int w[N],s[N];
int n;
int main(){
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>w[i];
        s[i]=s[i-1]+w[i];
    }
    
    for(int len=1;len<=n;len++)
        for(int l=1;l+len-1<=n;l++){
            int r=l+len-1;
            if(len==1){
                f[l][r]=0;
            }else{
                f[l][r]=INF;
                for(int k=l;k<r;k++)
                    f[l][r]=min(f[l][r],f[l][k]+f[k+1][r]+s[r]-s[l-1]);
            }
        }
    cout<<f[1][n]<<endl;
    
    return 0;
}

當然,也可以採取記憶化搜尋,這樣不需要考慮太多。

例題

環形石子合併:https://www.acwing.com/activity/content/problem/content/1297/1/

分析

這題無非是將上題排成一列的物品放在了環上,因此我們可以採取斷環成鏈的技巧:
顯然,合併 \(n\) 個物品需要 \(n-1\) 步,因此,必然存在兩個物品,它們並沒有進行合併,那麼它們之間便出現了“斷邊”,這樣的“斷邊”並不會參與到合併的過程中,問題便由環轉化為鏈的情況,所以我們只需列舉“斷邊”,然後進行求解即可。

有一個技巧:只需將原有的物品再按順序“複製”一份,分別得到區間:

對於樣例:

4 5 9 4

複製:

4 5 9 4 4 5 9 4

然後依次把區間(記為 \([s,t]\) )取出求解:

s     t
4 5 9 4 4 5 9 4
  s     t
4 5 9 4 4 5 9 4
    s     t
4 5 9 4 4 5 9 4
      s     t
4 5 9 4 4 5 9 4

(最後一個複製的元素是沒用的,可以忽略)

這樣分別求解四個子問題就行了。

程式碼:

#include<bits/stdc++.h>
using namespace std;

#define INF 0x3f3f3f3f

const int N = 410;

int f[N][N],g[N][N];
int s[N],w[N];
int n;

int main(){
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>w[i];
        w[i+n]=w[i];
    }
    
    memset(f,0x3f,sizeof f);
    memset(g,0xcf,sizeof g);
    
    for(int i=1;i<=2*n;i++) s[i]=s[i-1]+w[i];
    
    for(int len=1;len<=n;len++){
        for(int l=1;l+len-1<=n*2;l++){
            int r=l+len-1;
            
            if(len==1) f[l][r]=g[l][r]=0;
            else{
                for(int k=l;k<r;k++){
                    f[l][r]=min(f[l][r],f[l][k]+f[k+1][r]+s[r]-s[l-1]);
                    g[l][r]=max(g[l][r],g[l][k]+g[k+1][r]+s[r]-s[l-1]);
                }
            }
                
        }
    }
    
    int maxv=-INF,minv=INF;
    for(int i=1;i<=n;i++){
        maxv=max(maxv,g[i][i+n-1]);
        minv=min(minv,f[i][i+n-1]);
    }
    
    cout<<minv<<endl<<maxv<<endl;
    
    return 0;
}

記憶化搜尋版本:(比較久之前寫的emm)

#include<bits/stdc++.h>
using namespace std;
#define maxn 101
int n;
int a[maxn<<1];
int f_max[maxn][maxn];
int f_min[maxn][maxn];
int rec[maxn];
int s[maxn];

int sum(int l,int r){
    return s[r]-s[l-1];
}

int dfs_max(int l,int r){
    if(l==r) return f_max[l][r]=0;
    if(f_max[l][r]) return f_max[l][r];

    int res=0;
    for(int k=l;k+1<=r;k++){
        res=max(res,dfs_max(l,k)+dfs_max(k+1,r)+sum(l,r));
    }
    return f_max[l][r]=res;
}

int dfs_min(int l,int r){
    if(l==r) return f_min[l][r]=0;
    if(f_min[l][r]) return f_min[l][r];

    int res=INT_MAX;
    for(int k=l;k+1<=r;k++){
        res=min(res,dfs_min(l,k)+dfs_min(k+1,r)+sum(l,r));
    }
    return f_min[l][r]=res;
}

int main(){
    cin>>n;
    for(int i=1;i<=n-1;i++) cin>>a[i],a[i+n]=a[i];
    cin>>a[n];

    int rec_max=0;
    int rec_min=INT_MAX;

    for(int st=1;st<=n;st++){
        memset(rec,0,sizeof(rec));
        memset(s,0,sizeof(s));
        memset(f_max,0,sizeof(f_max));
        memset(f_min,0,sizeof(f_min));
        for(int i=st;i<=st+n-1;i++) rec[i-st+1]=a[i];

        s[1]=rec[1];
        for(int i=2;i<=n;i++) s[i]=s[i-1]+rec[i];

        rec_max=max(rec_max,dfs_max(1,n));
        rec_min=min(rec_min,dfs_min(1,n));
    }
    cout<<rec_min<<endl;
    cout<<rec_max<<endl;
    return 0;
}

能量項鍊:https://www.acwing.com/problem/content/322/

分析
和上面題目類似(事實上區間DP的題都差不多),要注意理解是如何合併珠子的。

程式碼:

#include<bits/stdc++.h>
using namespace std;

const int N=105;

int n;
int w[N<<1];
int f[N<<1][N<<1];

int main(){
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>w[i];
        w[n+i]=w[i];
    }
    
    for(int len=3;len<=n+1;len++)
        for(int l=1;l+len-1<=2*n;l++){
            int r=l+len-1;
            for(int k=l+1;k<=r-1;k++)
                f[l][r]=max(f[l][r],f[l][k]+f[k][r]+w[l]*w[k]*w[r]);
        }
        
    int res=0;
    for(int i=1;i<=n;i++) res=max(res,f[i][i+n]);
    
    cout<<res<<endl;
    
    return 0;
}

記憶化搜尋版本:

#include<bits/stdc++.h>
using namespace std;

const int N=210;

int n;
int w[N];
int f[N][N];

int dp(int l,int r){
    if(f[l][r]>=0) return f[l][r];
    if(r==l || r==l+1) return f[l][r]=0;
    
    int &v=f[l][r];
    for(int k=l+1;k<=r-1;k++){
        v=max(v,dp(l,k)+dp(k,r)+w[l]*w[k]*w[r]);
    }
    return v;
}

int main(){
    cin>>n;
    for(int i=1;i<=n;i++){
        cin>>w[i];
        w[n+i]=w[i];
    }
    
    memset(f,-1,sizeof f);
    
    int res=0;
    for(int i=1;i<=n;i++) res=max(res,dp(i,i+n));
    
    cout<<res<<endl;
    
    return 0;
}

加分二叉樹:https://www.acwing.com/problem/content/481/

分析

g[l][r] 表示 \([l,r]\) 的根節點。
將中序遍歷的序列看作是區間求解,然後列舉根節點(將它作為斷點),記錄答案的過程中要注意當答案得到更新的時候才記錄這個區間的根節點。

#include<bits/stdc++.h>
using namespace std;

const int N=35;

int f[N][N]; //dp
int g[N][N]; //path

int n;
int w[N];

void dfs(int l,int r){
    if(l>r) return;
    
    int root=g[l][r];
    cout<<root<<' ';
    dfs(l,root-1);
    dfs(root+1,r);
}
int main(){
    cin>>n;
    for(int i=1;i<=n;i++) cin>>w[i];
    
    for(int len=1;len<=n;len++)
        for(int l=1;l+len-1<=n;l++){
            int r=l+len-1;
            if(len==1){
                f[l][r]=w[l];
                g[l][r]=l;
            }
            else{
                for(int k=l;k<=r;k++){
                    int left= k==l?1:f[l][k-1];
                    int right= k==r?1:f[k+1][r];
                    int score=left*right + w[k];
                    if(score>f[l][r]){
                        f[l][r]=score;
                        g[l][r]=k;
                    }
                }
            }
        }
    
    cout<<f[1][n]<<endl;
    dfs(1,n);
    
    return 0;
}

相關文章