題目連結
題目大意
給定一個長度為 \(n\) 的序列 \(a\),定義一段區間的價值為該區間的 \(\operatorname{mex}\) 乘上區間元素總和。
你需要將序列劃分成若干個長度 \(\leq k\) 的區間。一個劃分方案的價值為劃分出來的每個區間價值之和,求所有劃分方案的價值最大值。
\(1 \leq k \leq n \leq 2\times 10^5,0 \leq a_i \leq n\)。
題目分析
第一眼:這不直接維護 \(\operatorname{mex}\) 連續段之後對每段維護凸包,然後用 deque+啟發式合併就行了嗎?
實際上不行,因為從左往右 dp 的時候連續段會分裂。
記 \(s\) 為 \(a\) 的字首和,先把 dp 的轉移寫出來:\(f_i=\max\limits_{i-k+1\leq j \leq i}{(f_{j-1}+\operatorname{mex}(a_j,\ldots,a_i)\times(s_i-s_{j-1}))}\)。
然後從左往右維護以 \(i\) 為結尾的 \(\operatorname{mex}\) 連續段,對 \(\operatorname{mex}\) 相同的一段區間 \([l,r]\) 放一起考慮。
將轉移拆成 \((f_{j-1}-\operatorname{mex}\times s_{j-1})+(\operatorname{mex}\times s_i)\),然後分成兩個部分。對於前面的式子,只需要將 \([l,r]\) 內的所有點 \((s_{j-1},f_{j-1})\) 建個凸包,然後查斜率為 \(\operatorname{mex}\) 的最大值即可,設其為 \(v_{\operatorname{mex}}\)。將所有連續段的貢獻求出來後,後面的式子相當於求 \(\max\limits_{\operatorname{mex}}{(v_{\operatorname{mex}}+\operatorname{mex}\times s_i)}\),也可以看成將所有點 \((-\operatorname{mex},v_{\operatorname{mex}})\) 建個凸包,然後查斜率為 \(s_i\) 的最大值。
眾所周知,整個過程只會有 \(\mathcal O(n)\) 個 \(\operatorname{mex}\) 連續段,將其預處理出來,每個連續段 \((l,r,x,m)\) 表示所有左端點在 \([l,r]\) 內,右端點 \(\geq x\) 的區間 \(\operatorname{mex}\) 值 \(\geq m\)(由於是最大值,所以算小了肯定不優)。
然後考慮 \(k\) 的限制,這相當於能貢獻到 \(i\) 的為若干整連續段+某一段的字尾。將字尾單獨考慮,然後就變成了若干個整段可以給 \(i\) 貢獻,用剛才預處理出來的 \((l,r,x,m)\) 表示即為 \([l,r]\) 可以給 \([x,l+k]\) 貢獻,\(\operatorname{mex}\) 是 \(m\)。
第一部分是簡單的,直接用線段樹維護凸包。對於線段樹節點 \([L,R]\),等到 \(f_L\sim f_R\) 求出來後在該節點上建凸包,由於凸包上的點本身是按照橫座標排序的,所以建凸包的時間複雜度是 \(\mathcal O(n\log n)\)。查詢的時候由於隨著時間的推移,以每個位置為左端點的 \(\operatorname{mex}\) 肯定單調不降,所以直接對線段樹上每個凸包維護一個指標即可,查詢的總時間複雜度 \(\mathcal O(n\log n)\)。
然後是第二部分,這相當於對每個 \((l,r,x,m)\) 求出 \([l,r]\) 第一部分的答案 \(v_m\) 後,將 \((-m,v_m)\) 這個點貢獻到 \([x,l+k]\) 這個區間,然後列舉到 \(i\) 的時候查詢 \(i\) 這個位置上凸包在斜率為 \(s_i\) 時的最大值。比較簡單的做法是把 \((-m,v_m)\) 看成直線,然後用線段樹+李超線段樹維護,不過時間複雜度是 \(\mathcal O(n\log^2 n)\) 的,不太行。
還是考慮用線段樹維護凸包,由於查詢時 \(s_i\) 單調不降,所以如果能建出凸包,那麼查詢的時候就只需要維護指標,還是單 \(\log\);關鍵是將 \((-m,v_m)\) 塞到線段樹節點的時候並不能保證 \(m\) 單增,這導致插入凸包時多一個 \(\log\)。
不過由於插入的點的橫座標 \(m\) 是提前知道的,所以可以離線。先將所有 \((l,r,x,m)\) 按照 \(m\) 排序,然後按順序依次插入線段樹上 \([x,l+k]\) 這段區間,這樣就可以把每個線段樹節點將要插入的點提前按照橫座標排序,同時處理出每個 \((l,r,x,m)\) 在每個插入的線段樹節點上的排名。
由於對於所有的 \((l,r,x,m)\),都有 \(r\leq x\),所以遍歷到每個線段樹節點 \([L,R]\) 時,所有可能貢獻給當前節點的連續段都已經計算完畢,這時再用已經排好序的點建凸包即可。時間複雜度 \(\mathcal O(n\log n)\)。
吐槽一下:為什麼單 \(\log\) 跑得比雙 \(\log\) 慢這麼多啊/ll/ll/ll
程式碼
#include<bits/stdc++.h>
using namespace std;
using namespace my_std;
#define LC x<<1
#define RC x<<1|1
set<pair<int,pair<int,int> > > s;
set<pair<pair<int,int>,int> > ss;
vector<int> vec[200020],qry[800080];
vector<pair<int,int> > nd[600060];
int n,k,a[200020],pre[200020],mp[200020],cnt=0;
ll f[200020],sum[200020],pos[800080];
struct node{
int l,r,x,mex;
}b[600060],c[200020];
struct point{
ll x,y;
};
vector<point> tree[800080];
il bl operator<(const node &x,const node &y){
return x.mex<y.mex;
}
il point operator-(const point &x,const point &y){
return (point){x.x-y.x,x.y-y.y};
}
il ll operator*(const point &x,const point &y){
return x.x*y.y-x.y*y.x;
}
struct seg{
vector<point> vec[800080];
ll pos[800080];
il void pushup(ll x){
vec[x]=vec[LC];
fr(i,0,(ll)vec[RC].size()-1){
while((ll)vec[x].size()>1&&(vec[RC][i]-vec[x][(ll)vec[x].size()-1])*(vec[x][(ll)vec[x].size()-1]-vec[x][(ll)vec[x].size()-2])<=0) vec[x].pop_back();
vec[x].push_back(vec[RC][i]);
}
pos[x]=(ll)vec[x].size()-1;
}
void mdf(ll x,ll l,ll r,ll v){
if(l==r){
vec[x].push_back((point){sum[l],f[l]});
return;
}
ll mid=(l+r)>>1;
if(v<=mid) mdf(LC,l,mid,v);
else mdf(RC,mid+1,r,v);
if(v==r) pushup(x);
}
ll query(ll x,ll l,ll r,ll ql,ll qr,ll v){
if(ql<=l&&r<=qr){
while(pos[x]&&(vec[x][pos[x]]-vec[x][pos[x]-1])*(point){1,v}>=0) pos[x]--;
if(pos[x]<(ll)vec[x].size()) return vec[x][pos[x]].y-v*vec[x][pos[x]].x;
else return -inf;
}
ll mid=(l+r)>>1,res=-inf;
if(ql<=mid) res=max(res,query(LC,l,mid,ql,qr,v));
if(mid<qr) res=max(res,query(RC,mid+1,r,ql,qr,v));
return res;
}
}T;
il void pushup(ll x){
ll top=0;
fr(i,0,(ll)tree[x].size()-1){
while(top>1){
if(tree[x][i].x==tree[x][top-1].x){
if(tree[x][i].y>tree[x][top-1].y) top--;
else break;
}
else if((tree[x][i]-tree[x][top-1])*(tree[x][top-1]-tree[x][top-2])<=0) top--;
else break;
}
tree[x][top++]=tree[x][i];
}
tree[x].resize(top);
pos[x]=max(0ll,top-1);
}
void ins(ll x,ll l,ll r,ll ql,ll qr,ll v){
if(ql>qr) return;
if(ql<=l&&r<=qr){
qry[x].push_back(v);
tree[x].push_back((point){0,0});
nd[v].push_back(MP(x,(ll)qry[x].size()-1));
return;
}
ll mid=(l+r)>>1;
if(ql<=mid) ins(LC,l,mid,ql,qr,v);
if(mid<qr) ins(RC,mid+1,r,ql,qr,v);
}
ll query(ll x,ll l,ll r,ll v,ll w){
while(pos[x]&&(tree[x][pos[x]]-tree[x][pos[x]-1])*(point){1,w}>=0) pos[x]--;
ll res=-inf;
if(pos[x]<(ll)tree[x].size()) res=tree[x][pos[x]].y-w*tree[x][pos[x]].x;
if(l==r) return res;
ll mid=(l+r)>>1;
if(v<=mid) res=max(res,query(LC,l,mid,v,w));
else res=max(res,query(RC,mid+1,r,v,w));
return res;
}
void solve(ll x,ll l,ll r){
pushup(x);
if(l==r){
f[l]=max(f[l],query(1,0,n,l,sum[l]));
if(c[l].l){
ll tmp=T.query(1,0,n,c[l].l-1,c[l].r-1,c[l].mex);
f[l]=max(f[l],tmp+c[l].mex*sum[l]);
}
T.mdf(1,0,n,l);
fr(i,0,(ll)vec[l+1].size()-1){
ll id=vec[l+1][i],tmp=T.query(1,0,n,b[id].l-1,b[id].r-1,b[id].mex);
fr(j,0,(ll)nd[id].size()-1) tree[nd[id][j].fir][nd[id][j].sec]=(point){-b[id].mex,tmp};
}
}
else{
ll mid=(l+r)>>1;
solve(LC,l,mid);
solve(RC,mid+1,r);
}
}
int main(){
n=read();
k=read();
fr(i,1,n){
a[i]=read();
pre[i]=mp[a[i]];
mp[a[i]]=i;
}
ll lst=n;
fr(i,0,n){
if(mp[i]<lst){
s.insert(MP(i,MP(mp[i]+1,lst)));
ss.insert(MP(MP(mp[i]+1,lst),i));
lst=mp[i];
}
}
pfr(i,n,1){
set<pair<pair<int,int>,int> >::iterator jt=ss.upper_bound(MP(MP(i-k+1,n+1),n+1));
if(jt!=ss.begin()){
jt--;
if((*jt).fir.sec>=(i-k+1)) c[i]=(node){i-k+1,(*jt).fir.sec,i,(*jt).sec};
}
set<pair<int,pair<int,int> > >::iterator it=s.begin();
pair<int,pair<int,int> > now=*it;
b[++cnt]=(node){now.sec.fir,i,i,now.fir};
s.erase(now);
ss.erase(MP(now.sec,now.fir));
if(now.sec.fir<i){
s.insert(MP(now.fir,MP(now.sec.fir,i-1)));
ss.insert(MP(MP(now.sec.fir,i-1),now.fir));
}
mp[a[i]]=pre[i];
it=s.lower_bound(MP(a[i]+1,MP(0,0)));
ll tmp=0;
if(it!=s.end()) tmp=(*it).sec.sec;
while(it!=s.end()){
now=*it;
if(now.sec.sec<=mp[a[i]]) break;
s.erase(now);
ss.erase(MP(now.sec,now.fir));
b[++cnt]=(node){now.sec.fir,now.sec.sec,i,now.fir};
if(now.sec.fir<=mp[a[i]]){
s.insert(MP(now.fir,MP(now.sec.fir,mp[a[i]])));
ss.insert(MP(MP(now.sec.fir,mp[a[i]]),now.fir));
break;
}
it=s.lower_bound(MP(a[i]+1,MP(0,0)));
}
if(mp[a[i]]<tmp){
s.insert(MP(a[i],MP(mp[a[i]]+1,tmp)));
ss.insert(MP(MP(mp[a[i]]+1,tmp),a[i]));
}
}
fr(i,1,n) sum[i]=sum[i-1]+a[i];
sort(b+1,b+cnt+1);
fr(i,1,cnt) vec[b[i].x].push_back(i);
pfr(i,cnt,1) ins(1,0,n,b[i].x,min(n,b[i].l+k-1),i);
solve(1,0,n);
write(f[n]);
}