P10342 [THUSC 2019] 數列 題解

Southern_Dynasty發表於2024-04-17

形式化題面:

\[\sum_{l=1}^{n}\sum_{r=l}^{n}\max_{i=l}^{r}(i-l+1)\times f(i,r) \]

其中 \(f(l,r)\)\(a_l,...,a_r\) 中有多少個不同的數字。

注意到,除了 Sub2,其餘資料點都有 \(\max f\le 800\),這啟發我們考慮 \(O(nm)\) 的演算法。

套路地,掃描線列舉右端點,則現在只需要考慮其對所有左端點的貢獻。

\(pre_i\) 表示 \(a_i\) 上一次出現的位置,維護一個 ODT 狀物,即所有 \(f\) 的連續段。每次 \(r-1\to r\) 就相當於在 \(pre_{r}+1\) 處 split 一下,後面位置的 \(f\)\(+1\),然後新 push 進去一個 \(([r,r],1)\) 的段,最後合併一些段。注意到只會有 \(O(m)\) 段,所以可以直接用陣列維護,每次直接重構都是可以接受的。

然後注意到,只有每一段的右端點才有可能貢獻到答案。記這些點為“關鍵點”。

考慮關鍵點 \(i\) 對答案的貢獻:\((i-l+1)\times f\)。其中 \(f\) 是定值。拆項可得 \(-f\cdot l+(i+1)f\)。不難發現這是一個一次函式的形式。

考慮從左到右加入關鍵點,那麼可以注意到每次加入的一次函式斜率遞增,那麼其一定會更新一段字尾的答案。考慮將每個位置的最優點描出來,不難發現其構成了一個下凸殼。於是插入線段也是簡單的,先 pop 掉那些被完全覆蓋的線段,然後 \(O(1)\) 求出兩條線段的交點即可。最後再掃描凸殼計算答案即可。這形如若干等差數列求和,容易 \(O(1)\) 計算。

注意到每條線段只會被 pop 一次,且求交點複雜度為 \(O(1)\),所以總的時間複雜度為 \(O(nm)\)

至於 Sub2,根據基礎不等式知識不難發現 \(i\) 取區間中點最優,因此直接列舉區間長度計算即可,複雜度 \(O(n)\)

程式碼:

#include<bits/stdc++.h>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/tree_policy.hpp>
#include<ext/pb_ds/hash_policy.hpp>
#define gt getchar
#define pt putchar
#define fst first
#define scd second
#define SZ(s) ((int)s.size())
#define all(s) s.begin(),s.end()
#define pb push_back
#define eb emplace_back
typedef long long ll;
typedef double db;
typedef long double ld;
typedef unsigned long long ull;
typedef unsigned int uint;
const int N=1e5+5;
const int mod=998244353;
using namespace std;
using namespace __gnu_pbds;
typedef pair<int,int> pii;
template<class T,class I> inline void chkmax(T &a,I b){a=max(a,(T)b);}
template<class T,class I> inline void chkmin(T &a,I b){a=min(a,(T)b);}
inline bool __(char ch){return ch>=48&&ch<=57;}
template<class T> inline void read(T &x){
	x=0;bool sgn=0;static char ch=gt();
	while(!__(ch)&&ch!=EOF) sgn|=(ch=='-'),ch=gt();
	while(__(ch)) x=(x<<1)+(x<<3)+(ch&15),ch=gt();
	if(sgn) x=-x;
}
template<class T,class ...I> inline void read(T &x,I &...x1){
	read(x);
	read(x1...);
}
template<class T> inline void print(T x){
	static char stk[70];short top=0;
	if(x<0) pt('-');
	do{stk[++top]=x>=0?(x%10+48):(-(x%10)+48),x/=10;}while(x);
	while(top) pt(stk[top--]);
}
template<class T> inline void printsp(T x){
	print(x);
	putchar(' ');
}
template<class T> inline void println(T x){
	print(x);
	putchar('\n');
}
int n,a[N],pre[N],pos[N],siz;
struct Seg{
	int l,r,w;
	Seg(int _l=0,int _r=0,int _w=0)
		:l(_l),r(_r),w(_w)
	{}
}odt[N];
inline bool in(int x,Seg seg){
	return seg.l<=x&&x<=seg.r;
}
inline void split(int x){
	auto upd=[&](int i){
		if(odt[i+1].w==odt[i].w){
			odt[i].r=odt[i+1].r;
			for(int j=i+1;j<siz;++j) odt[j]=odt[j+1];
			siz--;
		}
	};
	for(int i=1;i<=siz;++i){
		if(in(x,odt[i])){
			if(x==odt[i].l){
				for(int j=i;j<=siz;++j) odt[j].w++;
				upd(i-1);
			}else{
				odt[++siz]=Seg(x,odt[i].r,odt[i].w);
				odt[i].r=x-1;
				for(int j=i+1;j<=siz;++j) odt[j].w++;
				for(int j=siz;j>=i+2;--j) swap(odt[j],odt[j-1]);
				upd(i);
			}
			return;
		}
	}
}
struct func{
	int k,b;
	func(int _k=0,int _b=0):k(_k),b(_b){}
	inline int get(int x){
		return k*x+b;
	}
};
inline int cross(func a,func b){
	// find min x, so that a.get(x) > b.get(x).
	if(a.b>b.b) return 1;
	int now=(b.b-a.b)/(a.k-b.k);
	while(a.get(now)<=b.get(now)) now++; 
	while(a.get(now-1)>b.get(now-1)) now--;
	return now;
}
inline func gen(int i,int w){
	// (i-l+1)*w.
	// (i+1)*w - l*w.
	return func(-w,(i+1)*w);
}
inline int s(int x){
	return (1ll*x*(x+1)/2)%mod;
}
struct Node{
	func w;
	int l,r;
	Node(func _w=func(),int _l=0,int _r=0)
		:w(_w),l(_l),r(_r)
	{}
	inline int val(){
		return (((1ll*w.b*(r-l+1)%mod+1ll*w.k*(s(r)-s(l-1)+mod)%mod)%mod)+mod)%mod;
	}
}conv[N];
int top;
inline int find(func w){
	// find min l, so that w.get(l) > others.
	while(top&&conv[top].w.get(conv[top].l)<=w.get(conv[top].l)) top--;
	if(!top) return 1;
	int x=cross(w,conv[top].w);
	conv[top].r=x-1;
	return x;
}
inline void add(int &a,int b){
	a+=b;
	if(a>=mod) a-=mod;
}
namespace corner_case{
	bool vis[N];
	inline bool check(){
		int cnt=0;
		for(int i=1;i<=n;++i){
			if(!vis[a[i]]) cnt++;
			vis[a[i]]=1;
		}
		return cnt>800;
	}
	inline void solve(){
		int ans=0;
		for(int i=1;i<=n;++i){
			int len1=(i+1)/2,len2=i/2+1;
			if(i&1) add(ans,1ll*len1*len1%mod*(n-i+1)%mod);
			else add(ans,1ll*len1*len2%mod*(n-i+1)%mod);
		}
		println(ans);
	}
}
signed main(){
	read(n);
	for(int i=1;i<=n;++i){
		read(a[i]);
		pre[i]=pos[a[i]];
		pos[a[i]]=i;
	}
	if(corner_case::check()) return corner_case::solve(),0; 
	int ans=0;
	for(int r=1;r<=n;++r){
		split(pre[r]+1);
		odt[++siz]=Seg(r,r,1);
		if(odt[siz-1].w==1) odt[siz-1].r=r,siz--;
		conv[top=1]=Node(gen(odt[1].r,odt[1].w),odt[1].l,odt[1].r);
		for(int i=2;i<=siz;++i){
			func qwq=gen(odt[i].r,odt[i].w); 
			Node now(qwq,odt[i].l,odt[i].r);
			now.l=find(qwq);
			conv[++top]=now;
		}
		for(int i=1;i<=top;++i) add(ans,conv[i].val());
	}
	println(ans);
	return 0;
}

相關文章