斜率優化動態規劃

Viego發表於2021-02-02

前言

斜率優化通常使用單調佇列輔助進行實現,用於優化 \(DP\) 的時間複雜度,比較抽象,需要讀者有較高的數學素養。

本文例題連結

適用範圍

使用單調佇列優化 \(DP\) ,通常可以解決型如: \(dp[i]=min(f(j))+g(i)\) 的狀態轉移方程。其中 \(f(i)\) 是隻關於 \(i\) 的函式, \(g(j)\) 是隻關於 \(j\) 的函式。樸素的解決方法是在第二層迴圈中列舉 \(j\) 來實現最小值,時間複雜度為 \(O(n^2)\) 。可以使用單調佇列來維護這個最小值實現 \(O(n)\) 的時間複雜度。

而斜率優化利用上述方法進行改進,實現對於型如: \(dp[i]=min(f(i,j))+g(i)\) 的狀態轉移方程。對比第一種情況,可以發現函式 \(f\) 函式與兩個值 \(i,j\) 都有關,簡單地使用單調佇列是無法優化的。這時候就開始引入主題斜率優化了。

下面結合一道例題來具體詳解。題目來自於 \(HNOI2008\) 省選題目。

題目大意

\(n\) 個數字 \(C\),把它分為若干組,給出另一個數 \(L\) ,每組的花費為\((i-j+\sum_{k=i}^jC_k-L)^2\),總花費為所有組的花費之和。求最小總花費。

思路

先考慮樸素的 \(dp\) 做法。

\(dp[i]\) 為將前 \(i\) 個數字分組後的最小花費。求和可以考慮使用字首和來優化,設字首和陣列為 \(pre\) 。則狀態轉移方程可以寫為:

\(dp[i]=Min(dp[j]+(sum[i]-sum[j])+(i-(j+1))-L)^2,0≤j<i)\)

即是:

\(dp[i]=Min(dp[j]+(sum[i]-sum[j]+i-j-L-1)^2,0≤j<i)\)

那麼 \(sum\) 陣列可以初始化為:

for(int i = 1; i <= n; i++) {
	Quick_Read(val[i]);
	sum[i] = sum[i - 1] + val[i];
}

\(pre[i]=sum[i]+i\) ,再進一步設 \(l=L+1\) 那麼狀態轉移方程可以寫為:

\(dp[i]=Min(dp[j]+(pre[i]-pre[j]-l)^2,0≤j<i)\)

狀態轉移

int Get_Dp(int i, int j) {
	return dp[j] + (pre[i] - pre[j] - l) * (pre[i] - pre[j] - l);
}

若列舉 \(j\) ,則時間複雜度為 \(O(n)^2\) ,時間複雜度不優。使用斜率優化可以對其進行優化。

假設當前列舉到 \(i\) ,需要得到 \(i\) 的狀態。假設有兩個決策點 \(j\)\(k\) ,滿足決策點 \(j\) 優於決策點 \(k\) 。用符號語言可以表達為:

\(dp[j]+(pre[i]-pre[j]-l)^2<dp[k]+(pre[i]-pre[k]-l)^2\)

展開得:

\(dp[j]+pre[i]^2+pre[j]^2+l^2-2\times pre[i]\times pre[j]-2\times l\times pre[i]+2\times l\times pre[j]<dp[k]+pre[i]^2+pre[k]^2+l^2-2\times pre[i]\times pre[k]-2\times l\times pre[i]+2\times l\times pre[k]\)

進一步整理得 :

\(dp[j]+pre[j]^2-dp[k]-pre[k]^2<(pre[i]-l)\times 2\times (pre[j] - pre[k])\)

觀察可得:左邊的式子只與 \(j\)\(k\) 有關,但右邊的式子還與 \(i\) 有關。也可以發現若滿足上述式子,則會有 \(j\) 優於 \(k\) 。再分類討論:

  1. \(j>k\) ,則 \(pre[j]>pre[k]\),移項得 \(\frac{dp[j]+pre[j]^2-(dp[k]+pre[k]^2)}{pre[j] - pre[k]}<pre[i]-l\)\(pre[i]-l\) 可以 看為一個常數。那麼意味著點 \(j(dp[j]+pre[j]^2,pre[j])\) 與點 \(k(dp[k]+pre[k]^2,pre[k])\) 所構成的直線的斜率小於 \(pre[i]-l\) 這個常數。
  2. \(j<k\) ,則 \(pre[j]<pre[k]\),移項得 \(\frac{dp[j]+pre[j]^2-(dp[k]+pre[k]^2)}{pre[j] - pre[k]}>pre[i]-l\)\(pre[i]-l\) 可以 看為一個常數。那麼意味著點 \(j(dp[j]+pre[j]^2,pre[j])\) 與點 \(k(dp[k]+pre[k]^2,pre[k])\) 所構成的直線的斜率大於 \(pre[i]-l\) 這個常數。

獲得分子的函式:

int Get_Up(int j, int k) {
	return dp[j] + pre[j] * pre[j] - dp[k] - pre[k] * pre[k];
}

獲得分母的函式:

int Get_Down(int j, int k) {
	return pre[j] - pre[k];
}

有了上述的一級結論,可以進一步推匯出二級結論:
在這裡插入圖片描述
\(x,y\) 的斜率表示為 \(k(x,y)\) 。若存在三點 \(a,b,c\) ,有 \(k(a,b)>k(b,c)\) ,即是影像形成上凸的形狀時,那麼點 \(b\) 絕對不是最優的。

分類討論:

  1. \(k(a,b)>k(b,c)>pre[i]-l\) ,則對於上述結論可以得出 \(a\)\(b\) 更優,捨去 \(b\)
  2. \(pre[i]-l>k(a,b)>k(b,c)\) ,則對於上述結論可以得出 \(c\)\(b\) 更優,捨去 \(b\)
  3. \(pre[i]-l<k(a,b)\)\(pre[i]-l>k(b,c)\) ,則對於上述結論可以得出 \(a\)\(c\) 都比 \(b\) 更優,捨去 \(b\)

那麼就可以得出答案的點必須滿足 \(k(a_1,a_2)<k(a_2,a_3)<...<k(a_{m-1},a_m)\) 。全部呈現出下凸狀態,如下圖。
在這裡插入圖片描述
這樣下標遞增,斜率遞減的點集可以使用單調佇列來維護。

找出當前最優的點為 \(que[head]\) ,即隊頭元素。

while(Get_Up(que[head + 1], que[head]) <= 2 * (pre[i] - l) * Get_Down(que[head + 1], que[head]) && head < tail)
	head++;

用當前點 \(i\) 來更新佇列,使得該佇列呈下凸之勢。

while(Get_Up(que[tail], que[tail - 1]) * Get_Down(i, que[tail]) >= Get_Up(i, que[tail]) * Get_Down(que[tail], que[tail - 1]) && head < tail)
	tail--;

按照上述方法進行狀態轉移,得到的 \(dp[n]\) 就是當前的最優解。

C++程式碼

程式碼比較短,一氣呵成。(注意要開 \(long\) \(long\)

#include <cstdio>
#define int long long
void Quick_Read(int &N) {
	N = 0;
	int op = 1;
	char c = getchar();
	while(c < '0' || c > '9') {
		if(c == '-')
			op = -1;
		c = getchar();
	}
	while(c >= '0' && c <= '9') {
		N = (N << 1) + (N << 3) + (c ^ 48);
		c = getchar();
	}
	N *= op;
}
void Quick_Write(int N) {
	if(N < 0) {
		putchar('-');
		N = -N;
	}
	if(N >= 10)
		Quick_Write(N / 10);
	putchar(N % 10 + 48);
}
const int MAXN = 5e5 + 5;
int dp[MAXN];
int pre[MAXN], val[MAXN];
int n, l;
int que[MAXN];
int head, tail;
int Get_Dp(int i, int j) {
	return dp[j] + (pre[i] - pre[j] - l) * (pre[i] - pre[j] - l);
}
int Get_Up(int j, int k) {
	return dp[j] + pre[j] * pre[j] - dp[k] - pre[k] * pre[k];
}
int Get_Down(int j, int k) {
	return pre[j] - pre[k];
}
void Line_Dp() {
	head = 1;
	tail = 1;
	for(int i = 1; i <= n; i++) {
		while(Get_Up(que[head + 1], que[head]) <= 2 * (pre[i] - l) * Get_Down(que[head + 1], que[head]) && head < tail)
			head++;
		dp[i] = Get_Dp(i, que[head]);
		while(Get_Up(que[tail], que[tail - 1]) * Get_Down(i, que[tail]) >= Get_Up(i, que[tail]) * Get_Down(que[tail], que[tail - 1]) && head < tail)
			tail--;
		que[++tail] = i;
	}
	Quick_Write(dp[n]);
}
void Read() {
	Quick_Read(n);
	Quick_Read(l);
	l++;
	for(int i = 1; i <= n; i++) {
		Quick_Read(val[i]);
		pre[i] = pre[i - 1] + val[i] + 1;
	}
}
signed main() {
	Read();
	Line_Dp();
	return 0;
}

相關文章