BZOJ 1044: [HAOI2008]木棍分割 DP,字首和優化,二分答案

just_sort發表於2017-01-21

Description

  有n根木棍, 第i根木棍的長度為Li,n根木棍依次連結了一起, 總共有n-1個連線處. 現在允許你最多砍斷m個連
接處, 砍完後n根木棍被分成了很多段,要求滿足總長度最大的一段長度最小, 並且輸出有多少種砍的方法使得總長
度最大的一段長度最小. 並將結果mod 10007。。。
Input

  輸入檔案第一行有2個數n,m.接下來n行每行一個正整數Li,表示第i根木棍的長度.n<=50000,0<=m<=min(n-1,10
00),1<=Li<=1000.
Output

  輸出有2個數, 第一個數是總長度最大的一段的長度最小值, 第二個數是有多少種砍的方法使得滿足條件.
Sample Input
3 2

1

1

10
Sample Output
10 2
HINT

兩種砍的方法: (1)(1)(10)和(1 1)(10)

解題方法:好題好題好題,磨了一晚上加一早上終於過去了這個題,好難。我們來說一說這個題的思路,第一問顯然一個二分就可以了。關鍵是第二問,我們容易想到這樣一個DP,我們設f[i][j]表示前i段一共分割了j次,設ss[i]為a[i]的字首和,然後寫出dp方程:

f[i][j] = Σf[k][j-1] 其中k要滿足的條件是(1 <= k < i) && (ss[i] - ss[k] <= len)(這是很容易從題目中得出的)。

於是我們就可以完成了。

但是這樣也太簡單了吧……畢竟是HAOI的題目,如果這麼簡單就是NOIP難度了(雖然本人不否認以前的省選題目也有NOIP難度的)

然後注意到資料範圍:n<=50000,0<=m<=min(n-1,1000)

我們注意到我們程式的時間複雜度實際上是O(n^2 m) 的,這明顯就是爆了時間的。

那然後該怎麼辦呢?

我們可以注意到,如果我們設sumf 表示列舉到k的時候Σf[k][j-1],(1 <= k < i) && (ss[i] - ss[k] <= len),mink表示滿足(1 <= k < i) && (ss[i] - ss[k] <= len)的最小的k。

其實對於 f[i][Now] ,其實是 f[mink][Last]…f[i-1][Last] 這一段 f[k][Last] 的和,mink 是滿足 Sum[i] - Sum[k] <= Len 的最小的 k ,對於從 1 到 n 列舉的 i ,相對應的 mink 也一定是非遞減的(因為 Sum[i] 是遞增的)。我們記錄下 f[1][Last]…f[i-1][Last] 的和 Sumf ,mink 初始設為 1,每次對於 i 將 mink 向後推移,推移的同時將被捨棄的 p 對應的 f[p][Last] 從 Sumf 中減去。那麼 f[i][Now] 就是 Sumf 的值。(此段複製自Evensgn的部落格,因為我覺得自己可能寫不出來這麼詳細)

這樣我們就不必列舉k,時間複雜度就降低到可以接受的O(nm)了。

但是這樣就完成了?別天真了,還有一個坑那,時間解決了,空間呢?我們的空間複雜度是O(nm)啊,用計算器算一下明顯超了。

這時候的DP有一個技巧(類似於飛揚的小鳥NOIP2014),我們發現其實j所屬的那一維,只能由j-1轉移而來,所以可以使用最常用的手段——滾動陣列,來滾動掉第二維

使用now和pre,f[maxn][2],now和pre只能為0或1,且pre = now^1,每完成一遍外層m迴圈更新now ^= 1,pre = now^1。

這樣子我們的空間複雜度也降到可以接受的O(n)辣了!(上面的題解是貼上的這個部落格的,見這裡) 我和他的思路是一樣的。

程式碼如下:

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 50005;
const int P = 10007;
int a[N], sum[N], pre[N];
int dp[N][2];
int presum[N];
int ans, n, m;
int answer;
bool check(int x){
    int num = 0, cnt = 0;
    for(int i = 1; i <= n; i++){
        if(num + a[i] <= x){
            num += a[i];
        }
        else{
            if(a[i] > x) return false;
            else{
                cnt++;
                num = a[i];
            }
        }
    }
    if(cnt <= m) return true;
    else return false;
}
int main(){
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n; i++){
        scanf("%d", &a[i]);
        sum[i] = sum[i-1] + a[i];
    }
    int l = 1, r = 1e9;
    while(l < r){
        int mid = (l + r) / 2;
        if(check(mid)) r = mid;
        else l = mid + 1;
    }
    ans = l;
    int now = 0;
    int sumf = 0;
    int mink = 0;
    for(int i = 0; i <= m; i++){
        sumf = 0;
        mink = 1;
        for(int j = 1; j <= n; j++){
            if(i == 0){
                if(sum[j] <= ans) dp[j][now] = 1;
                else dp[j][now] = 0;
            }
            else{
                while(mink < j && sum[j] - sum[mink] > ans){
                    sumf -= dp[mink][now^1];
                    sumf = (sumf + P) % P;
                    mink++;
                }
                dp[j][now] = sumf;
            }
            sumf += dp[j][now^1];
            sumf = (sumf + P) % P;
        }
        answer += dp[n][now];
        answer = (answer + P) % P;
        now ^= 1;
    }
    printf("%d %d\n", ans, answer);
    return 0;
}

相關文章