形式化題面:
求
其中 \(f(l,r)\) 為 \(a_l,...,a_r\) 中有多少個不同的數字。
注意到,除了 Sub2,其餘資料點都有 \(\max f\le 800\),這啟發我們考慮 \(O(nm)\) 的演算法。
套路地,掃描線列舉右端點,則現在只需要考慮其對所有左端點的貢獻。
設 \(pre_i\) 表示 \(a_i\) 上一次出現的位置,維護一個 ODT 狀物,即所有 \(f\) 的連續段。每次 \(r-1\to r\) 就相當於在 \(pre_{r}+1\) 處 split 一下,後面位置的 \(f\) 值 \(+1\),然後新 push 進去一個 \(([r,r],1)\) 的段,最後合併一些段。注意到只會有 \(O(m)\) 段,所以可以直接用陣列維護,每次直接重構都是可以接受的。
然後注意到,只有每一段的右端點才有可能貢獻到答案。記這些點為“關鍵點”。
考慮關鍵點 \(i\) 對答案的貢獻:\((i-l+1)\times f\)。其中 \(f\) 是定值。拆項可得 \(-f\cdot l+(i+1)f\)。不難發現這是一個一次函式的形式。
考慮從左到右加入關鍵點,那麼可以注意到每次加入的一次函式斜率遞增,那麼其一定會更新一段字尾的答案。考慮將每個位置的最優點描出來,不難發現其構成了一個下凸殼。於是插入線段也是簡單的,先 pop 掉那些被完全覆蓋的線段,然後 \(O(1)\) 求出兩條線段的交點即可。最後再掃描凸殼計算答案即可。這形如若干等差數列求和,容易 \(O(1)\) 計算。
注意到每條線段只會被 pop 一次,且求交點複雜度為 \(O(1)\),所以總的時間複雜度為 \(O(nm)\)。
至於 Sub2,根據基礎不等式知識不難發現 \(i\) 取區間中點最優,因此直接列舉區間長度計算即可,複雜度 \(O(n)\)。
程式碼:
#include<bits/stdc++.h>
#include<ext/pb_ds/assoc_container.hpp>
#include<ext/pb_ds/tree_policy.hpp>
#include<ext/pb_ds/hash_policy.hpp>
#define gt getchar
#define pt putchar
#define fst first
#define scd second
#define SZ(s) ((int)s.size())
#define all(s) s.begin(),s.end()
#define pb push_back
#define eb emplace_back
typedef long long ll;
typedef double db;
typedef long double ld;
typedef unsigned long long ull;
typedef unsigned int uint;
const int N=1e5+5;
const int mod=998244353;
using namespace std;
using namespace __gnu_pbds;
typedef pair<int,int> pii;
template<class T,class I> inline void chkmax(T &a,I b){a=max(a,(T)b);}
template<class T,class I> inline void chkmin(T &a,I b){a=min(a,(T)b);}
inline bool __(char ch){return ch>=48&&ch<=57;}
template<class T> inline void read(T &x){
x=0;bool sgn=0;static char ch=gt();
while(!__(ch)&&ch!=EOF) sgn|=(ch=='-'),ch=gt();
while(__(ch)) x=(x<<1)+(x<<3)+(ch&15),ch=gt();
if(sgn) x=-x;
}
template<class T,class ...I> inline void read(T &x,I &...x1){
read(x);
read(x1...);
}
template<class T> inline void print(T x){
static char stk[70];short top=0;
if(x<0) pt('-');
do{stk[++top]=x>=0?(x%10+48):(-(x%10)+48),x/=10;}while(x);
while(top) pt(stk[top--]);
}
template<class T> inline void printsp(T x){
print(x);
putchar(' ');
}
template<class T> inline void println(T x){
print(x);
putchar('\n');
}
int n,a[N],pre[N],pos[N],siz;
struct Seg{
int l,r,w;
Seg(int _l=0,int _r=0,int _w=0)
:l(_l),r(_r),w(_w)
{}
}odt[N];
inline bool in(int x,Seg seg){
return seg.l<=x&&x<=seg.r;
}
inline void split(int x){
auto upd=[&](int i){
if(odt[i+1].w==odt[i].w){
odt[i].r=odt[i+1].r;
for(int j=i+1;j<siz;++j) odt[j]=odt[j+1];
siz--;
}
};
for(int i=1;i<=siz;++i){
if(in(x,odt[i])){
if(x==odt[i].l){
for(int j=i;j<=siz;++j) odt[j].w++;
upd(i-1);
}else{
odt[++siz]=Seg(x,odt[i].r,odt[i].w);
odt[i].r=x-1;
for(int j=i+1;j<=siz;++j) odt[j].w++;
for(int j=siz;j>=i+2;--j) swap(odt[j],odt[j-1]);
upd(i);
}
return;
}
}
}
struct func{
int k,b;
func(int _k=0,int _b=0):k(_k),b(_b){}
inline int get(int x){
return k*x+b;
}
};
inline int cross(func a,func b){
// find min x, so that a.get(x) > b.get(x).
if(a.b>b.b) return 1;
int now=(b.b-a.b)/(a.k-b.k);
while(a.get(now)<=b.get(now)) now++;
while(a.get(now-1)>b.get(now-1)) now--;
return now;
}
inline func gen(int i,int w){
// (i-l+1)*w.
// (i+1)*w - l*w.
return func(-w,(i+1)*w);
}
inline int s(int x){
return (1ll*x*(x+1)/2)%mod;
}
struct Node{
func w;
int l,r;
Node(func _w=func(),int _l=0,int _r=0)
:w(_w),l(_l),r(_r)
{}
inline int val(){
return (((1ll*w.b*(r-l+1)%mod+1ll*w.k*(s(r)-s(l-1)+mod)%mod)%mod)+mod)%mod;
}
}conv[N];
int top;
inline int find(func w){
// find min l, so that w.get(l) > others.
while(top&&conv[top].w.get(conv[top].l)<=w.get(conv[top].l)) top--;
if(!top) return 1;
int x=cross(w,conv[top].w);
conv[top].r=x-1;
return x;
}
inline void add(int &a,int b){
a+=b;
if(a>=mod) a-=mod;
}
namespace corner_case{
bool vis[N];
inline bool check(){
int cnt=0;
for(int i=1;i<=n;++i){
if(!vis[a[i]]) cnt++;
vis[a[i]]=1;
}
return cnt>800;
}
inline void solve(){
int ans=0;
for(int i=1;i<=n;++i){
int len1=(i+1)/2,len2=i/2+1;
if(i&1) add(ans,1ll*len1*len1%mod*(n-i+1)%mod);
else add(ans,1ll*len1*len2%mod*(n-i+1)%mod);
}
println(ans);
}
}
signed main(){
read(n);
for(int i=1;i<=n;++i){
read(a[i]);
pre[i]=pos[a[i]];
pos[a[i]]=i;
}
if(corner_case::check()) return corner_case::solve(),0;
int ans=0;
for(int r=1;r<=n;++r){
split(pre[r]+1);
odt[++siz]=Seg(r,r,1);
if(odt[siz-1].w==1) odt[siz-1].r=r,siz--;
conv[top=1]=Node(gen(odt[1].r,odt[1].w),odt[1].l,odt[1].r);
for(int i=2;i<=siz;++i){
func qwq=gen(odt[i].r,odt[i].w);
Node now(qwq,odt[i].l,odt[i].r);
now.l=find(qwq);
conv[++top]=now;
}
for(int i=1;i<=top;++i) add(ans,conv[i].val());
}
println(ans);
return 0;
}