【PR #12】劃分序列 / Yet Another Mex Problem 題解

AFewSuns發表於2024-03-15

題目連結

題目大意

給定一個長度為 \(n\) 的序列 \(a\),定義一段區間的價值為該區間的 \(\operatorname{mex}\) 乘上區間元素總和。

你需要將序列劃分成若干個長度 \(\leq k\) 的區間。一個劃分方案的價值為劃分出來的每個區間價值之和,求所有劃分方案的價值最大值。

\(1 \leq k \leq n \leq 2\times 10^5,0 \leq a_i \leq n\)

題目分析

第一眼:這不直接維護 \(\operatorname{mex}\) 連續段之後對每段維護凸包,然後用 deque+啟發式合併就行了嗎?

實際上不行,因為從左往右 dp 的時候連續段會分裂。

\(s\)\(a\) 的字首和,先把 dp 的轉移寫出來:\(f_i=\max\limits_{i-k+1\leq j \leq i}{(f_{j-1}+\operatorname{mex}(a_j,\ldots,a_i)\times(s_i-s_{j-1}))}\)

然後從左往右維護以 \(i\) 為結尾的 \(\operatorname{mex}\) 連續段,對 \(\operatorname{mex}\) 相同的一段區間 \([l,r]\) 放一起考慮。

將轉移拆成 \((f_{j-1}-\operatorname{mex}\times s_{j-1})+(\operatorname{mex}\times s_i)\),然後分成兩個部分。對於前面的式子,只需要將 \([l,r]\) 內的所有點 \((s_{j-1},f_{j-1})\) 建個凸包,然後查斜率為 \(\operatorname{mex}\) 的最大值即可,設其為 \(v_{\operatorname{mex}}\)。將所有連續段的貢獻求出來後,後面的式子相當於求 \(\max\limits_{\operatorname{mex}}{(v_{\operatorname{mex}}+\operatorname{mex}\times s_i)}\),也可以看成將所有點 \((-\operatorname{mex},v_{\operatorname{mex}})\) 建個凸包,然後查斜率為 \(s_i\) 的最大值。

眾所周知,整個過程只會有 \(\mathcal O(n)\)\(\operatorname{mex}\) 連續段,將其預處理出來,每個連續段 \((l,r,x,m)\) 表示所有左端點在 \([l,r]\) 內,右端點 \(\geq x\) 的區間 \(\operatorname{mex}\)\(\geq m\)(由於是最大值,所以算小了肯定不優)。

然後考慮 \(k\) 的限制,這相當於能貢獻到 \(i\) 的為若干整連續段+某一段的字尾。將字尾單獨考慮,然後就變成了若干個整段可以給 \(i\) 貢獻,用剛才預處理出來的 \((l,r,x,m)\) 表示即為 \([l,r]\) 可以給 \([x,l+k]\) 貢獻,\(\operatorname{mex}\)\(m\)

第一部分是簡單的,直接用線段樹維護凸包。對於線段樹節點 \([L,R]\),等到 \(f_L\sim f_R\) 求出來後在該節點上建凸包,由於凸包上的點本身是按照橫座標排序的,所以建凸包的時間複雜度是 \(\mathcal O(n\log n)\)。查詢的時候由於隨著時間的推移,以每個位置為左端點的 \(\operatorname{mex}\) 肯定單調不降,所以直接對線段樹上每個凸包維護一個指標即可,查詢的總時間複雜度 \(\mathcal O(n\log n)\)

然後是第二部分,這相當於對每個 \((l,r,x,m)\) 求出 \([l,r]\) 第一部分的答案 \(v_m\) 後,將 \((-m,v_m)\) 這個點貢獻到 \([x,l+k]\) 這個區間,然後列舉到 \(i\) 的時候查詢 \(i\) 這個位置上凸包在斜率為 \(s_i\) 時的最大值。比較簡單的做法是把 \((-m,v_m)\) 看成直線,然後用線段樹+李超線段樹維護,不過時間複雜度是 \(\mathcal O(n\log^2 n)\) 的,不太行。

還是考慮用線段樹維護凸包,由於查詢時 \(s_i\) 單調不降,所以如果能建出凸包,那麼查詢的時候就只需要維護指標,還是單 \(\log\);關鍵是將 \((-m,v_m)\) 塞到線段樹節點的時候並不能保證 \(m\) 單增,這導致插入凸包時多一個 \(\log\)

不過由於插入的點的橫座標 \(m\) 是提前知道的,所以可以離線。先將所有 \((l,r,x,m)\) 按照 \(m\) 排序,然後按順序依次插入線段樹上 \([x,l+k]\) 這段區間,這樣就可以把每個線段樹節點將要插入的點提前按照橫座標排序,同時處理出每個 \((l,r,x,m)\) 在每個插入的線段樹節點上的排名。

由於對於所有的 \((l,r,x,m)\),都有 \(r\leq x\),所以遍歷到每個線段樹節點 \([L,R]\) 時,所有可能貢獻給當前節點的連續段都已經計算完畢,這時再用已經排好序的點建凸包即可。時間複雜度 \(\mathcal O(n\log n)\)

吐槽一下:為什麼單 \(\log\) 跑得比雙 \(\log\) 慢這麼多啊/ll/ll/ll

程式碼

#include<bits/stdc++.h>
using namespace std;
using namespace my_std;
#define LC x<<1
#define RC x<<1|1
set<pair<int,pair<int,int> > > s;
set<pair<pair<int,int>,int> > ss;
vector<int> vec[200020],qry[800080];
vector<pair<int,int> > nd[600060];
int n,k,a[200020],pre[200020],mp[200020],cnt=0;
ll f[200020],sum[200020],pos[800080];
struct node{
	int l,r,x,mex;
}b[600060],c[200020];
struct point{
	ll x,y;
};
vector<point> tree[800080];
il bl operator<(const node &x,const node &y){
	return x.mex<y.mex;
}
il point operator-(const point &x,const point &y){
	return (point){x.x-y.x,x.y-y.y};
}
il ll operator*(const point &x,const point &y){
	return x.x*y.y-x.y*y.x;
}
struct seg{
	vector<point> vec[800080];
	ll pos[800080];
	il void pushup(ll x){
		vec[x]=vec[LC];
		fr(i,0,(ll)vec[RC].size()-1){
			while((ll)vec[x].size()>1&&(vec[RC][i]-vec[x][(ll)vec[x].size()-1])*(vec[x][(ll)vec[x].size()-1]-vec[x][(ll)vec[x].size()-2])<=0) vec[x].pop_back();
			vec[x].push_back(vec[RC][i]);
		}
		pos[x]=(ll)vec[x].size()-1;
	}
	void mdf(ll x,ll l,ll r,ll v){
		if(l==r){
			vec[x].push_back((point){sum[l],f[l]});
			return;
		}
		ll mid=(l+r)>>1;
		if(v<=mid) mdf(LC,l,mid,v);
		else mdf(RC,mid+1,r,v);
		if(v==r) pushup(x);
	}
	ll query(ll x,ll l,ll r,ll ql,ll qr,ll v){
		if(ql<=l&&r<=qr){
			while(pos[x]&&(vec[x][pos[x]]-vec[x][pos[x]-1])*(point){1,v}>=0) pos[x]--;
			if(pos[x]<(ll)vec[x].size()) return vec[x][pos[x]].y-v*vec[x][pos[x]].x;
			else return -inf;
		}
		ll mid=(l+r)>>1,res=-inf;
		if(ql<=mid) res=max(res,query(LC,l,mid,ql,qr,v));
		if(mid<qr) res=max(res,query(RC,mid+1,r,ql,qr,v));
		return res;
	}
}T;
il void pushup(ll x){
	ll top=0;
	fr(i,0,(ll)tree[x].size()-1){
		while(top>1){
			if(tree[x][i].x==tree[x][top-1].x){
				if(tree[x][i].y>tree[x][top-1].y) top--;
				else break;
			}
			else if((tree[x][i]-tree[x][top-1])*(tree[x][top-1]-tree[x][top-2])<=0) top--;
			else break;
		}
		tree[x][top++]=tree[x][i];
	}
	tree[x].resize(top);
	pos[x]=max(0ll,top-1);
}
void ins(ll x,ll l,ll r,ll ql,ll qr,ll v){
	if(ql>qr) return;
	if(ql<=l&&r<=qr){
		qry[x].push_back(v);
		tree[x].push_back((point){0,0});
		nd[v].push_back(MP(x,(ll)qry[x].size()-1));
		return;
	}
	ll mid=(l+r)>>1;
	if(ql<=mid) ins(LC,l,mid,ql,qr,v);
	if(mid<qr) ins(RC,mid+1,r,ql,qr,v);
}
ll query(ll x,ll l,ll r,ll v,ll w){
	while(pos[x]&&(tree[x][pos[x]]-tree[x][pos[x]-1])*(point){1,w}>=0) pos[x]--;
	ll res=-inf;
	if(pos[x]<(ll)tree[x].size()) res=tree[x][pos[x]].y-w*tree[x][pos[x]].x;
	if(l==r) return res;
	ll mid=(l+r)>>1;
	if(v<=mid) res=max(res,query(LC,l,mid,v,w));
	else res=max(res,query(RC,mid+1,r,v,w));
	return res;
}
void solve(ll x,ll l,ll r){
	pushup(x);
	if(l==r){
		f[l]=max(f[l],query(1,0,n,l,sum[l]));
		if(c[l].l){
			ll tmp=T.query(1,0,n,c[l].l-1,c[l].r-1,c[l].mex);
			f[l]=max(f[l],tmp+c[l].mex*sum[l]);
		}
		T.mdf(1,0,n,l);
		fr(i,0,(ll)vec[l+1].size()-1){
			ll id=vec[l+1][i],tmp=T.query(1,0,n,b[id].l-1,b[id].r-1,b[id].mex);
			fr(j,0,(ll)nd[id].size()-1) tree[nd[id][j].fir][nd[id][j].sec]=(point){-b[id].mex,tmp};
		}
	}
	else{
		ll mid=(l+r)>>1;
		solve(LC,l,mid);
		solve(RC,mid+1,r);
	}
}
int main(){
	n=read();
	k=read();
	fr(i,1,n){
		a[i]=read();
		pre[i]=mp[a[i]];
		mp[a[i]]=i;
	}
	ll lst=n;
	fr(i,0,n){
		if(mp[i]<lst){
			s.insert(MP(i,MP(mp[i]+1,lst)));
			ss.insert(MP(MP(mp[i]+1,lst),i));
			lst=mp[i];
		}
	}
	pfr(i,n,1){
		set<pair<pair<int,int>,int> >::iterator jt=ss.upper_bound(MP(MP(i-k+1,n+1),n+1));
		if(jt!=ss.begin()){
			jt--;
			if((*jt).fir.sec>=(i-k+1)) c[i]=(node){i-k+1,(*jt).fir.sec,i,(*jt).sec};
		}
		set<pair<int,pair<int,int> > >::iterator it=s.begin();
		pair<int,pair<int,int> > now=*it;
		b[++cnt]=(node){now.sec.fir,i,i,now.fir};
		s.erase(now);
		ss.erase(MP(now.sec,now.fir));
		if(now.sec.fir<i){
			s.insert(MP(now.fir,MP(now.sec.fir,i-1)));
			ss.insert(MP(MP(now.sec.fir,i-1),now.fir));
		}
		mp[a[i]]=pre[i];
		it=s.lower_bound(MP(a[i]+1,MP(0,0)));
		ll tmp=0;
		if(it!=s.end()) tmp=(*it).sec.sec;
		while(it!=s.end()){
			now=*it;
			if(now.sec.sec<=mp[a[i]]) break;
			s.erase(now);
			ss.erase(MP(now.sec,now.fir));
			b[++cnt]=(node){now.sec.fir,now.sec.sec,i,now.fir};
			if(now.sec.fir<=mp[a[i]]){
				s.insert(MP(now.fir,MP(now.sec.fir,mp[a[i]])));
				ss.insert(MP(MP(now.sec.fir,mp[a[i]]),now.fir));
				break;
			}
			it=s.lower_bound(MP(a[i]+1,MP(0,0)));
		}
		if(mp[a[i]]<tmp){
			s.insert(MP(a[i],MP(mp[a[i]]+1,tmp)));
			ss.insert(MP(MP(mp[a[i]]+1,tmp),a[i]));
		}
	}
	fr(i,1,n) sum[i]=sum[i-1]+a[i];
	sort(b+1,b+cnt+1);
	fr(i,1,cnt) vec[b[i].x].push_back(i);
	pfr(i,cnt,1) ins(1,0,n,b[i].x,min(n,b[i].l+k-1),i);
	solve(1,0,n);
	write(f[n]);
}

相關文章