題解:CF1799F Halve or Subtract

ffffyc發表於2024-09-26

\(\text{Link}\)

介紹一下一種高維 wqs 的方法。

此方法來自 @YeahPotato 的專欄 嚴謹的 WQS 二分方法

題意

給定一個長為 \(n\) 的序列 \(v_{1\dots n}\),三個常數 \(d,a,b\)。你可以執行若干次以下兩種操作:

  1. 選擇 \(1\le i\le n\),令 \(v_i\gets\lceil\frac{v_i}{2}\rceil\)
  2. 選擇 \(1\le i\le n\),令 \(v_i\gets\max(v_i-d,0)\)

你至多進行 \(a\) 次操作 1,\(b\) 次操作 2,同時對於每個元素,每種操作至多進行一次。

你需要最小化操作後 \(\sum v\) 的值並輸出。

\(1\le n\le 10^5\)

題解

兩個顯然的性質是,我們會把操作用完、我們會先用操作 1 再用操作 2。而根據費用流建圖,答案關於操作次數 \(a,b\) 均為下凸的。

我們設操作次數限制為 \(a,b\) 時的答案為 \(f(a,b)\),那麼我們需要使用兩層 wqs 二分分別去除兩維限制,而外層二分我們需要求出「使得 \(f(x,b)-kx\) 取到最小值的 \(x\)」,而它並不好求。問題的關鍵為我們無法直接透過調整斜率使得求出切到的點恰為給定值,無法同時使兩維取到 \(a,b\)

此時,我們就需要尋找求解凸函式單點值的更優方法。

有如下結論:

  • \(f(x)\) 關於 \(x\) 上凸時,令 \(g_a(k)=ka+\displaystyle\max_{x}(f(x)-kx)\),那麼有:\(g_a(k)\) 關於 \(k\) 下凸且 \(f(a)=\displaystyle\min_kg_a(k)\)
  • \(f(x)\) 關於 \(x\) 下凸時,令 \(g_a(k)=ka+\displaystyle\min_{x}(f(x)-kx)\),那麼有:\(g_a(k)\) 關於 \(k\) 上凸且 \(f(a)=\displaystyle\max_kg_a(k)\)

證明:不妨考慮證明其中第二條。

以下將 \(g_a(k)\) 簡寫為 \(g(k)\)。令 \(h(k)\)\(f(x)-kx\) 取到最小值的某個 \(x\)

證明 \(g(k)\) 上凸即證 \(\forall k_1,k_2,\forall \lambda\in[0,1]\),令 \(k=\lambda k_1+(1-\lambda )k_2\),有 \(\lambda g(k_1)+(1-\lambda)g(k_2)\le g(k)\)

\[\begin{aligned}&\lambda g(k_1)+(1-\lambda)g(k_2)\\=&\lambda [k_1a+\min_x(f(x)-k_1x)]+(1-\lambda)[k_2a+\min_x(f(x)-k_2x)]\\\le&\lambda [k_1a+(f(h(k))-k_1h(k))]+(1-\lambda)[k_2a+(f(h(k))-k_2h(k))]\\=&g(k)\end{aligned} \]

還需證明 \(g(k)\) 的最大值為 \(f(a)\),那麼由於 \(f(x)\) 關於 \(x\) 下凸,必定有 \(g(f'(a))=f(a)\)。而 \(g(k)\le ka+f(a)-ka=f(a)\),所以 \(f(a)=\max_k g(k)\)

有了這個結論,我們就把較對複雜的凸函式求值轉化為了對較簡單的凸函式求最值。

接下來,我們就可二分或三分求 \(g(k_1)=k_1a+\min_x(f(x,b)-k_1x)\) 的最值;而其中 \(\min_x(f(x,y)-k_1x)\) 又是關於 \(y\) 的下凸函式,再用二分或三分求 \(h(k_2)=k_2b+\min_{x,y}(f(x,y)-k_1x-k_2y)\) 的最值即可。

時間複雜度 \(O(n\log^2 v)\)

核心程式碼:

const int N=5e3+10;
int n,d,a,b,v[N];
inline ll calc(int k1,int k2){
	ll s=0;
	for(int i=1;i<=n;i++)
		s+=min({v[i],(v[i]+1)/2-k1,max(v[i]-d,0)-k2,max((v[i]+1)/2-d,0)-k1-k2});
	return s;
}
inline ll solve2(int k1){
	int L=-1e9,R=0;
	while(L<R){
		int mL=L+R>>1,mR=mL+1;
		ll v1=calc(k1,mL)+1ll*mL*b,v2=calc(k1,mR)+1ll*mR*b;
		if(v1==v2) return v1;
		if(v1<v2) L=mL+1;
		else R=mR-1;
	}
	return calc(k1,L)+1ll*L*b;
}
inline ll solve1(){
	int L=-1e9,R=0;
	while(L<R){
		int mL=L+R>>1,mR=mL+1;
		ll v1=solve2(mL)+1ll*mL*a,v2=solve2(mR)+1ll*mR*a;
		if(v1==v2) return v1;
		if(v1<v2) L=mL+1;
		else R=mR-1;
	}
	return solve2(L)+1ll*L*a;
}