P1637 三元上升子序列
簡要題意,在一個序列中尋找長度為三的上升子序列
思路
有兩種思路
直接法
一種是對於一個樹,算一個數左邊比他小的數,算右邊比他大的數,然後相乘即是該該點處值
算比他大的數,和比他小的數,用樹狀陣列或線段樹即皆可
CODE
#include<bits/stdc++.h>
using namespace std;
#define x first
#define y second
#define ll long long
int n;
const int maxn=1e5+10;
pair<int,int>m[maxn];
int t[maxn<<2];
int num[maxn];
int j[maxn<<2];
int sml[maxn];//比它小的數
int smx[maxn];//比它大的數
void push_up(int p){
t[p]=t[p<<1]+t[p<<1|1];
}
void update(int p,int l,int r,int nl,int nr){
if(l==r && nl==l){
++t[p];
return ;
}
int mid=(l+r)>>1;
if(nl<=mid) update(p<<1,l,mid,nl,nr);
if(nr>mid) update(p<<1|1,mid+1,r,nl,nr);
push_up(p);
}
ll query(int p,int l,int r,int nl,int nr){
ll res=0;
if(nl<=l && r<=nr){
return t[p];
}
int mid=(l+r)>>1;
if(nl<=mid) res+=query(p<<1,l,mid,nl,nr);
if(nr>mid) res+=query(p<<1|1,mid+1,r,nl,nr);
push_up(p);
return res;
}
int main(){
cin>>n;
for(int i=1;i<=n;++i){
int x;cin>>m[i].x;
m[i].y=i;
}
sort(m+1,m+1+n);
int cnt=0;//cnt是離散後的大小
for(int i=1;i<=n;++i){
if(m[i].x>m[i-1].x) ++cnt;
num[m[i].y]=cnt;//離散化
}
for(int i=1;i<=n;++i){
if(num[i]>1) sml[i]=query(1,1,n,1,num[i]-1);
update(1,1,n,num[i],num[i]);
}
memset(t,0,sizeof(t));
for(int i=n;i>=1;--i){
if(num[i]<n) smx[i]=query(1,1,n,num[i]+1,n);
update(1,1,n,num[i],num[i]);
}
ll ans=0;
for(int i=1;i<=n;++i) ans+=(smx[i]*sml[i]);
cout<<ans<<endl;
return 0;
}
DP
上升子序列,其實可以讓我們很容易相到,最長上升子序列的求法,只需稍加修改即可
令f[i][j]是以a[j]為結尾長度為i的上升子序列
\(f[i][j]=\sum_{k<j,a[k]<a[j]}f[i-1][k]\)
利用桶排序的思想,儲存f[i][j],在第a[j]個點,這樣在轉移只需要求
小於a[j]的和即可
如何高效的統計這個和,則會用到樹狀陣列或線段樹
遍歷i時
樹狀陣列存的應是f[i-1][k]的和(因此遍歷i之後需要清空樹狀陣列)
遍歷第j+1個序列時,前j個序列 第a[j]點的位置加上f[i-1][j]
尋找滿足狀態轉移方程則 f[i][j]+=sum(a[j]-1)
#include<bits/stdc++.h>
using namespace std;
#define x first
#define y second
#define ll long long
int n;
const int maxn=1e5+10;
pair<int,int> m[maxn];
int a[maxn];
int f[5][maxn];
int lowbit(int x){
return x&(-x);
}
int t[maxn];
void add(int x,int k){
while(x<=n){
t[x]+=k;
x+=lowbit(x);
}
}
ll sum(int x){
ll sum=0;
while(x){
sum+=t[x];
x-=lowbit(x);
}
return sum;
}
int main(){
cin>>n;
for(int i=1;i<=n;++i) {
cin>>m[i].x;
m[i].y=i;
}
int num=0;
sort(m+1,m+1+n);
for(int i=1;i<=n;++i){
if(m[i].x>m[i-1].x) ++num;
a[m[i].y]=num;
}
for(int i=1;i<=n;++i) f[1][i]=1;
for(int i=2;i<=3;++i){
memset(t,0,sizeof(t));
for(int j=1;j<=n;++j){
f[i][j]=sum(a[j]-1);
add(a[j],f[i-1][j]);
}
}
ll ans=0;
for(int i=1;i<=n;++i) ans+=f[3][i];
cout<<ans<<endl;
return 0;
}