思路:
首先求出 \(a\) 的字首和陣列 \(s\)。
考慮動態規劃,令 \(dp_{i,j}\) 表示以 \(i\) 結尾,末尾有 \(j\) 個為一組的最小答案,則狀態轉移方程為:
樸素直接轉移是 \(O(N^3)\) 的,可以得到 36pts 的好成績程式碼就懶的給了。
考慮最佳化,對於求出最小的一個 \(k\),使得 \(s_{i-j}-s_{i-j-k} > s_i - s_{i-j}\),那麼狀態轉移方程為:
後面的一串可以提前字首預處理好,現在的複雜度在求 \(k\) 上,注意到 \(s_{i,j} - s_{i-j-k}\) 是單調的,那麼直接二分即可。
時間複雜度最佳化至 \(O(N^2 \log N)\)。
$O(N^2 \log N)$ 程式碼
#include<bits/stdc++.h>
#define Add(x,y) (x+y>=mod)?(x+y-mod):(x+y)
#define lowbit(x) x&(-x)
#define pi pair<ll,ll>
#define pii pair<ll,pair<ll,ll>>
#define iip pair<pair<ll,ll>,ll>
#define ppii pair<pair<ll,ll>,pair<ll,ll>>
#define fi first
#define se second
#define full(l,r,x) for(auto it=l;it!=r;it++) (*it)=x
#define Full(a) memset(a,0,sizeof(a))
#define open(s1,s2) freopen(s1,"r",stdin),freopen(s2,"w",stdout);
using namespace std;
typedef double db;
typedef unsigned long long ull;
typedef long long ll;
bool Begin;
const ll N=5050,INF=4e18;
inline ll read(){
ll x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){
if(c=='-')
f=-1;
c=getchar();
}
while(c>='0'&&c<='9'){
x=(x<<1)+(x<<3)+(c^48);
c=getchar();
}
return x*f;
}
inline void write(ll x){
if(x<0){
putchar('-');
x=-x;
}
if(x>9)
write(x/10);
putchar(x%10+'0');
}
bool op;
ll n,l,r,h,t,ans=INF;
ll s[N],dp[N][N],f[N][N];
ll get(ll l,ll r){
if(l>r)
return 0;
if(!l)
return s[r];
return s[r]-s[l-1];
}
bool End;
int main(){
n=read(),op=read();
for(int i=1;i<=n;i++)
s[i]=s[i-1]+read();
dp[1][0]=f[1][0]=INF;
dp[1][1]=f[1][1]=s[1]*s[1];
for(int i=2;i<=n;i++){
f[i][0]=dp[i][0]=INF;
for(int j=1;j<=i;j++){
l=1,r=i-j,t=0,h=get(i-j+1,i);
if(s[i-j]<=h)
t=i-j+1;
else{
while(l<=r){
ll mid=(l+r)>>1;
if(get(i-j-mid+1,i-j)>h){
t=mid;
r=mid-1;
}
else
l=mid+1;
}
}
dp[i][j]=f[i-j][t-1]+h*h;
f[i][j]=min(f[i][j-1],dp[i][j]);
}
}
for(int i=1;i<=n;i++)
ans=min(ans,dp[n][i]);
write(ans);
cerr<<'\n'<<abs(&Begin-&End)/1048576<<"MB";
return 0;
}
之後我們可以發現,若 \(j\) 的單增的,則 \(i-j-k+1\) 是單降的,那麼我們直接對 \(k\) 進行走指標即可,時間複雜度最佳化至 \(O(N^2)\),可以拿到 64pts 的好成績。
$O(N^2)$ 程式碼
#include<bits/stdc++.h>
#define Add(x,y) (x+y>=mod)?(x+y-mod):(x+y)
#define lowbit(x) x&(-x)
#define pi pair<ll,ll>
#define pii pair<ll,pair<ll,ll>>
#define iip pair<pair<ll,ll>,ll>
#define ppii pair<pair<ll,ll>,pair<ll,ll>>
#define fi first
#define se second
#define full(l,r,x) for(auto it=l;it!=r;it++) (*it)=x
#define Full(a) memset(a,0,sizeof(a))
#define open(s1,s2) freopen(s1,"r",stdin),freopen(s2,"w",stdout);
using namespace std;
typedef double db;
typedef unsigned long long ull;
typedef long long ll;
bool Begin;
const ll N=5050,INF=4e18;
inline ll read(){
ll x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){
if(c=='-')
f=-1;
c=getchar();
}
while(c>='0'&&c<='9'){
x=(x<<1)+(x<<3)+(c^48);
c=getchar();
}
return x*f;
}
inline void write(ll x){
if(x<0){
putchar('-');
x=-x;
}
if(x>9)
write(x/10);
putchar(x%10+'0');
}
bool op;
ll n,t,h,sum,ans=INF;
ll s[N],a[N],dp[N][N],f[N][N];
ll get(ll l,ll r){
if(l>r)
return 0;
if(!l)
return s[r];
return s[r]-s[l-1];
}
bool End;
int main(){
n=read(),op=read();
for(int i=1;i<=n;i++){
a[i]=read();
s[i]=s[i-1]+a[i];
}
dp[1][0]=f[1][0]=INF;
dp[1][1]=f[1][1]=s[1]*s[1];
for(int i=2;i<=n;i++){
f[i][0]=dp[i][0]=INF;
t=i-1,sum=a[i-1];
for(int j=1;j<=i;j++){
ll h=get(i-j+1,i);
while(sum<=h&&t){
t--;
sum+=a[t];
}
dp[i][j]=f[i-j][i-j-t]+h*h;
f[i][j]=min(f[i][j-1],dp[i][j]);
sum-=a[i-j];
}
}
for(int i=1;i<=n;i++)
ans=min(ans,dp[n][i]);
write(ans);
cerr<<'\n'<<abs(&Begin-&End)/1048576<<"MB";
return 0;
}
因為我們這種 dp 的狀態數都已經達到了 \(N^2\),於是考慮找一些性質。
容易打表發現在合法情況下,滿足 \(dp_{i,j} \le dp_{i,j+1}\)。
那麼我們可以找到每個位置 \(i\),記錄一下 \(f_i\) 表示 \(\min dp_{i,j}\),且最後一段為 \([g_i,i]\),則狀態轉移方程為:
此時我們就將狀態時將至 \(O(N)\) 級別,現在考慮來最佳化狀態轉移方程。
$O(N)$ 狀態程式碼
#include<bits/stdc++.h>
#define Add(x,y) (x+y>=mod)?(x+y-mod):(x+y)
#define lowbit(x) x&(-x)
#define pi pair<ll,ll>
#define pii pair<ll,pair<ll,ll>>
#define iip pair<pair<ll,ll>,ll>
#define ppii pair<pair<ll,ll>,pair<ll,ll>>
#define fi first
#define se second
#define full(l,r,x) for(auto it=l;it!=r;it++) (*it)=x
#define Full(a) memset(a,0,sizeof(a))
#define open(s1,s2) freopen(s1,"r",stdin),freopen(s2,"w",stdout);
using namespace std;
typedef double db;
typedef unsigned long long ull;
typedef long long ll;
bool Begin;
const ll N=5e5+10,INF=4e18;
inline ll read(){
ll x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){
if(c=='-')
f=-1;
c=getchar();
}
while(c>='0'&&c<='9'){
x=(x<<1)+(x<<3)+(c^48);
c=getchar();
}
return x*f;
}
inline void write(ll x){
if(x<0){
putchar('-');
x=-x;
}
if(x>9)
write(x/10);
putchar(x%10+'0');
}
bool op;
ll n,t,h,ans=INF;
ll s[N],a[N],f[N],g[N];
ll get(ll l,ll r){
if(l>r)
return 0;
if(l<0)
return s[r];
return s[r]-s[l-1];
}
bool End;
int main(){
n=read(),op=read();
for(int i=1;i<=n;i++){
a[i]=read();
s[i]=s[i-1]+a[i];
f[i]=INF;
}
g[1]=1;
f[0]=g[0]=0;
f[1]=s[1]*s[1];
for(int i=2;i<=n;i++){
for(int j=0;j<i;j++){
h=get(g[j],j),t=get(j+1,i);
if(h>t)
continue;
if(f[j]+t*t<f[i]){
f[i]=f[j]+t*t;
g[i]=j+1;
}
}
}
write(f[n]);
cerr<<'\n'<<abs(&Begin-&End)/1048576<<"MB";
return 0;
}
容易發現,當 \(j\) 最大時,這個式子的值最小,所以我們需要求出一個最大的 \(j\) 滿足 \(s_j-s_{g_j-1} \le s_i - s_j\),即:
注意到 \(s_i\) 單增,我們可以維護一個 \(2s_j - s_{g_j-1}\) 單增的單調佇列,然後找到這個佇列最後一個滿足條件的 \(j\),那麼 \(j\) 以前的數對答案無法造成貢獻,將其彈出。
這樣每個數至多被彈出一次,時間複雜度為 \(O(N)\)。
完整程式碼:
#include<bits/stdc++.h>
#define Add(x,y) (x+y>=mod)?(x+y-mod):(x+y)
#define lowbit(x) x&(-x)
#define pi pair<ll,ll>
#define pii pair<ll,pair<ll,ll>>
#define iip pair<pair<ll,ll>,ll>
#define ppii pair<pair<ll,ll>,pair<ll,ll>>
#define fi first
#define se second
#define full(l,r,x) for(auto it=l;it!=r;it++) (*it)=x
#define Full(a) memset(a,0,sizeof(a))
#define open(s1,s2) freopen(s1,"r",stdin),freopen(s2,"w",stdout);
using namespace std;
typedef double db;
typedef unsigned long long ull;
typedef long long ll;
typedef __int128 __;
bool Begin;
const ll N=4e7+5,mod=1ll<<30;
inline ll read(){
ll x=0,f=1;
char c=getchar();
while(c<'0'||c>'9'){
if(c=='-')
f=-1;
c=getchar();
}
while(c>='0'&&c<='9'){
x=(x<<1)+(x<<3)+(c^48);
c=getchar();
}
return x*f;
}
inline void write(__ x){
if(x<0){
putchar('-');
x=-x;
}
if(x>9)
write(x/10);
putchar(x%10+'0');
}
__ t,ans;
bool op;
int n,head=1,tail=0;
ll s[N],g[N];
int Q[N];
inline void Read(){
ll x,y,z,m,p,l,r,pre=0;
x=read(),y=read(),z=read(),s[1]=read(),s[2]=read(),m=read();
for(int i=3;i<=n;i++)
s[i]=(x*s[i-1]+y*s[i-2]+z)%mod;
for(int i=1;i<=m;i++){
p=read(),l=read(),r=read();
for(int j=pre+1;j<=p;j++)
s[j]=(s[j]%(r-l+1))+l;
pre=p;
}
}
inline ll get(int l,int r){
if(l>r)
return 0;
if(l<1)
return s[r];
return s[r]-s[l-1];
}
inline ll date(ll x){
return 2ll*s[x]-s[g[x]-1];
}
bool End;
int main(){
n=read(),op=read();
if(op==1)
Read();
else{
for(int i=1;i<=n;i++)
s[i]=read();
}
for(int i=1;i<=n;i++)
s[i]+=s[i-1];
g[1]=1,g[0]=0;
Q[++tail]=0,Q[++tail]=1;
for(int i=2;i<=n;i++){
while(date(Q[head+1])<=s[i]&&head+1<=tail)
head++;
g[i]=Q[head]+1;
t=get(g[i],i);
while(date(i)<=date(Q[tail])&&tail>=head)
tail--;
Q[++tail]=i;
}
for(int i=n;i>=1;i=g[i]-1)
ans+=(__)(s[i]-s[g[i]-1])*(s[i]-s[g[i]-1]);
write(ans);
cerr<<'\n'<<abs(&Begin-&End)/1048576<<"MB";
return 0;
}