斜率最佳化 DP

lrx139發表於2024-04-08

對於這樣一類方程

\(dp_i=\min \limits_{j=1}^{i-1}(dp_j-a_ic_j)\),其中 \(a,c\) 都為正整數且遞增:

如果直接計算,時間複雜度為 \(\mathcal{O}(N^2)\)

使用斜率最佳化,可以將時間複雜度將為 \(\mathcal{O}(N)\)

在學習本節之前,請先學會單調佇列,還要知道在平面直角座標系中,斜率越小,直線(線段)越平;斜率越大,直線(線段)越陡

建模

我們將方程變形一下。去掉 \(\min\) 和移項,得到:

\(dp_j=a_ic_j+dp_i\)

\(y=dp_j,k=a_i,x=c_j,b=dp_i\),得到:

\(y=kx+b\)

我們稱 \(k\) 為斜率,\(b\) 為截距。可以發現,最小的截距就是我們要求的 \(dp_i\)

求一個 \(dp_i\)

假設我們已經求出了 \(dp_1 \sim dp_{i-1}\)。我們將這 \(i-1\) 個點的 \(x,y\) 座標畫到座標系上。假設這些點的位置如下圖。

因為 \(c\) 陣列遞增,所以每個點的 \(x\) 座標遞增,位置依次向右。

接下來,我們將斜率 \(k(a_i)=1\) 代入,生成直線:

\(b\) 值即為直線與 \(y\) 軸交點的 \(x\) 座標。在上圖中,點 \(2\) 的斜率最小。

假設我們再加入點 \(4\),影像如下:

此時,斜率為任意正整數時,\(4\) 的截距一定都比 \(3\) 的截距小,因此我們刪除點 \(3\)

我們再將每個點順序連線:

然後刪除點 \(3\)

可以發現,這些點組成了下凸殼,而點 \(3\) 形成了上凸殼而被刪除了。

所以在求 \(dp_i\) 時,我們只需要將斜率 \(a_i\) 代入,在這些有用的點中找截距最小的點即可。

求所有 \(dp_i\)

剛才求出一個 \(dp_i\) 的時間複雜度為 \(\mathcal{O}(N)\),總時間複雜度依然為 \(\mathcal{O}(N^2)\)。如何最佳化?

注意條件中 \(a\) 陣列也是遞增的。在斜率逐漸遞增時,能取到最小值的點是逐漸向右的。

比如上圖的斜率為 \(1\),如果把斜率增加到 \(5\),則變成下圖:

此時點 \(4\) 變為最優點。

因此,我們可以用單調佇列來維護有用的點。如果新加的點與隊尾的點不滿足下凸殼,則彈出隊尾。那麼什麼時候該彈出隊頭呢?

可以發現,最優點與其左側點組成線段的斜率一定是小於 \(a_i\) 的,最優點與其右側點組成線段的斜率一定是大於 \(a_i\) 的。如上面斜率為 \(1\) 的圖中,最優點是 \(2\),線段 \(1 \to 2\) 的斜率小於 \(1\),線段 \(2 \to 4\) 的斜率大於 \(1\)

因此,如果佇列中前兩點組成線段的斜率小於等於 \(a_i\),我們就把隊頭彈掉。滿足條件後,隊頭的點即為最優點。

單調佇列的時間複雜度為 \(\mathcal{O}(N)\)

例題

給出兩個序列 \(a,c\),保證這兩個序列中的元素遞增。求出另一個序列 \(dp\),使得:
\(dp_i=\min \limits_{j=1}^{i-1}(dp_j+a_ic_j)\)
特別的,\(dp_1=0\)
\(1 \le N \le 10^6,1 \le a_i,c_i \le 3 \times 10^6\),最後答案不會小於 \(9 \times 10^{18}\)

我們用 \(X(a)\) 表示 \(a\) 點的 \(x\) 座標,\(Y(a)\) 表示 \(a\) 點的 \(y\) 座標。計算斜率時會用到小數,容易有精度錯誤,因此我們改用乘法。用 \(slope1(a,b)\) 表示線段 \(ab(a<b)\) 斜率的分子,\(slope2(a,b)\) 表示線段 \(ab(a<b)\) 斜率的分母,\(cmp1(a,b,k)\) 判斷線段 \(ab\) 的斜率是否小於 等於 \(k\)\(cmp2(a,b,c,d)\) 判斷線段 \(ab\) 的斜率是否大於等於線段 \(cd\) 的斜率。

最後程式碼如下。

#include<cstdio>
#define UP(i,a,b) for(i=a;i<=(b);++i)
#define DN(i,a,b) for(i=a;i>=(b);--i)

typedef long long ll;

const int N=1e6+5;
int a[N],c[N],n,q[N],h,t;
ll dp[N];

ll X(int x){
	return c[x];
}
ll Y(int x){
	return dp[x];
}
ll slope1(int x,int y){
	return Y(y)-Y(x);
}
ll slope2(int x,int y){
	return X(y)-X(x);
}
bool cmp1(int x,int y,int k){
	/*slope1(x,y)/slope2(x,y)<=k*/
	/*slope1(x,y)<=k*slope2(x,y)*/
	return slope1(x,y)<=k*slope2(x,y);
}
bool cmp2(int x,int y,int e,int f){
	/*slope1(x,y)/slope2(x,y)>=slope1(e,f)/slope2(e,f)*/
	/*slope1(x,y)*slope2(e,f)>=slope1(e,f)*slope2(x,y)*/
	return slope1(x,y)*slope2(e,f)>=slope1(e,f)*slope2(x,y);
}
int main(){
	int i,j;
	scanf("%d",&n);
	h=t=1;
	UP(i,1,n){
		scanf("%d%d",a+i,c+i);
		/*彈掉隊頭斜率<=a[i]的*/
		while(h<t&&cmp1(q[h],q[h+1],a[i])){
			++h;
		}
		j=q[h];
		/*此時j即為最優點*/
		dp[i]=dp[j]-1ll*a[i]*c[j];
		/*彈掉 隊尾兩點線段斜率>=線段(隊尾點->i)斜率的點*/
		while(h<t&&cmp2(q[t-1],q[t],q[t],i)){
			--t;
		}
		q[++t]=i;
		printf("%lld%c",dp[i]," \n"[i==n]);
	}
	return 0;
}

相關文章