對於這樣一類方程
\(dp_i=\min \limits_{j=1}^{i-1}(dp_j-a_ic_j)\),其中 \(a,c\) 都為正整數且遞增:
如果直接計算,時間複雜度為 \(\mathcal{O}(N^2)\)。
使用斜率最佳化,可以將時間複雜度將為 \(\mathcal{O}(N)\)。
在學習本節之前,請先學會單調佇列,還要知道在平面直角座標系中,斜率越小,直線(線段)越平;斜率越大,直線(線段)越陡。
建模
我們將方程變形一下。去掉 \(\min\) 和移項,得到:
\(dp_j=a_ic_j+dp_i\)。
設 \(y=dp_j,k=a_i,x=c_j,b=dp_i\),得到:
\(y=kx+b\)。
我們稱 \(k\) 為斜率,\(b\) 為截距。可以發現,最小的截距就是我們要求的 \(dp_i\)。
求一個 \(dp_i\)
假設我們已經求出了 \(dp_1 \sim dp_{i-1}\)。我們將這 \(i-1\) 個點的 \(x,y\) 座標畫到座標系上。假設這些點的位置如下圖。
因為 \(c\) 陣列遞增,所以每個點的 \(x\) 座標遞增,位置依次向右。
接下來,我們將斜率 \(k(a_i)=1\) 代入,生成直線:
\(b\) 值即為直線與 \(y\) 軸交點的 \(x\) 座標。在上圖中,點 \(2\) 的斜率最小。
假設我們再加入點 \(4\),影像如下:
此時,斜率為任意正整數時,\(4\) 的截距一定都比 \(3\) 的截距小,因此我們刪除點 \(3\)。
我們再將每個點順序連線:
然後刪除點 \(3\):
可以發現,這些點組成了下凸殼,而點 \(3\) 形成了上凸殼而被刪除了。
所以在求 \(dp_i\) 時,我們只需要將斜率 \(a_i\) 代入,在這些有用的點中找截距最小的點即可。
求所有 \(dp_i\)
剛才求出一個 \(dp_i\) 的時間複雜度為 \(\mathcal{O}(N)\),總時間複雜度依然為 \(\mathcal{O}(N^2)\)。如何最佳化?
注意條件中 \(a\) 陣列也是遞增的。在斜率逐漸遞增時,能取到最小值的點是逐漸向右的。
比如上圖的斜率為 \(1\),如果把斜率增加到 \(5\),則變成下圖:
此時點 \(4\) 變為最優點。
因此,我們可以用單調佇列來維護有用的點。如果新加的點與隊尾的點不滿足下凸殼,則彈出隊尾。那麼什麼時候該彈出隊頭呢?
可以發現,最優點與其左側點組成線段的斜率一定是小於 \(a_i\) 的,最優點與其右側點組成線段的斜率一定是大於 \(a_i\) 的。如上面斜率為 \(1\) 的圖中,最優點是 \(2\),線段 \(1 \to 2\) 的斜率小於 \(1\),線段 \(2 \to 4\) 的斜率大於 \(1\)。
因此,如果佇列中前兩點組成線段的斜率小於等於 \(a_i\),我們就把隊頭彈掉。滿足條件後,隊頭的點即為最優點。
單調佇列的時間複雜度為 \(\mathcal{O}(N)\)。
例題
給出兩個序列 \(a,c\),保證這兩個序列中的元素遞增。求出另一個序列 \(dp\),使得:
\(dp_i=\min \limits_{j=1}^{i-1}(dp_j+a_ic_j)\)。
特別的,\(dp_1=0\)。
\(1 \le N \le 10^6,1 \le a_i,c_i \le 3 \times 10^6\),最後答案不會小於 \(9 \times 10^{18}\)。
我們用 \(X(a)\) 表示 \(a\) 點的 \(x\) 座標,\(Y(a)\) 表示 \(a\) 點的 \(y\) 座標。計算斜率時會用到小數,容易有精度錯誤,因此我們改用乘法。用 \(slope1(a,b)\) 表示線段 \(ab(a<b)\) 斜率的分子,\(slope2(a,b)\) 表示線段 \(ab(a<b)\) 斜率的分母,\(cmp1(a,b,k)\) 判斷線段 \(ab\) 的斜率是否小於 等於 \(k\),\(cmp2(a,b,c,d)\) 判斷線段 \(ab\) 的斜率是否大於等於線段 \(cd\) 的斜率。
最後程式碼如下。
#include<cstdio>
#define UP(i,a,b) for(i=a;i<=(b);++i)
#define DN(i,a,b) for(i=a;i>=(b);--i)
typedef long long ll;
const int N=1e6+5;
int a[N],c[N],n,q[N],h,t;
ll dp[N];
ll X(int x){
return c[x];
}
ll Y(int x){
return dp[x];
}
ll slope1(int x,int y){
return Y(y)-Y(x);
}
ll slope2(int x,int y){
return X(y)-X(x);
}
bool cmp1(int x,int y,int k){
/*slope1(x,y)/slope2(x,y)<=k*/
/*slope1(x,y)<=k*slope2(x,y)*/
return slope1(x,y)<=k*slope2(x,y);
}
bool cmp2(int x,int y,int e,int f){
/*slope1(x,y)/slope2(x,y)>=slope1(e,f)/slope2(e,f)*/
/*slope1(x,y)*slope2(e,f)>=slope1(e,f)*slope2(x,y)*/
return slope1(x,y)*slope2(e,f)>=slope1(e,f)*slope2(x,y);
}
int main(){
int i,j;
scanf("%d",&n);
h=t=1;
UP(i,1,n){
scanf("%d%d",a+i,c+i);
/*彈掉隊頭斜率<=a[i]的*/
while(h<t&&cmp1(q[h],q[h+1],a[i])){
++h;
}
j=q[h];
/*此時j即為最優點*/
dp[i]=dp[j]-1ll*a[i]*c[j];
/*彈掉 隊尾兩點線段斜率>=線段(隊尾點->i)斜率的點*/
while(h<t&&cmp2(q[t-1],q[t],q[t],i)){
--t;
}
q[++t]=i;
printf("%lld%c",dp[i]," \n"[i==n]);
}
return 0;
}