P1637 三元上升子序列

归游發表於2024-11-09

P1637 三元上升子序列

簡要題意,在一個序列中尋找長度為三的上升子序列

思路

有兩種思路

直接法

一種是對於一個樹,算一個數左邊比他小的數,算右邊比他大的數,然後相乘即是該該點處值

算比他大的數,和比他小的數,用樹狀陣列或線段樹即皆可

CODE

#include<bits/stdc++.h>
using namespace std;
#define x first
#define y second
#define ll long long 

int n;
const int maxn=1e5+10;
pair<int,int>m[maxn];
int t[maxn<<2];
int num[maxn];
int j[maxn<<2];
int sml[maxn];//比它小的數
int smx[maxn];//比它大的數
void push_up(int p){
	t[p]=t[p<<1]+t[p<<1|1]; 
}
void update(int p,int l,int r,int nl,int nr){
	if(l==r  && nl==l){
		++t[p];
		return ;
	} 
	int mid=(l+r)>>1;
	if(nl<=mid) update(p<<1,l,mid,nl,nr);
	if(nr>mid) update(p<<1|1,mid+1,r,nl,nr);
	push_up(p);
}
ll query(int p,int l,int r,int nl,int nr){
	ll res=0;
	if(nl<=l && r<=nr){
		return t[p];	
	}
	int mid=(l+r)>>1;
	if(nl<=mid) res+=query(p<<1,l,mid,nl,nr);
	if(nr>mid) res+=query(p<<1|1,mid+1,r,nl,nr);
	push_up(p);
	return res;
}
int main(){
	cin>>n;
	for(int i=1;i<=n;++i){
		int x;cin>>m[i].x;
		m[i].y=i;
	}
	sort(m+1,m+1+n);
	int cnt=0;//cnt是離散後的大小 
	for(int i=1;i<=n;++i){
		if(m[i].x>m[i-1].x) ++cnt;
		num[m[i].y]=cnt;//離散化
	}
	for(int i=1;i<=n;++i){
		if(num[i]>1) sml[i]=query(1,1,n,1,num[i]-1);
		update(1,1,n,num[i],num[i]); 
	}
	memset(t,0,sizeof(t));
	for(int i=n;i>=1;--i){
		if(num[i]<n)  smx[i]=query(1,1,n,num[i]+1,n);
		update(1,1,n,num[i],num[i]);
	}
	ll ans=0;
	for(int i=1;i<=n;++i) ans+=(smx[i]*sml[i]);
	cout<<ans<<endl;
	return 0;
}

DP

上升子序列,其實可以讓我們很容易相到,最長上升子序列的求法,只需稍加修改即可

令f[i][j]是以a[j]為結尾長度為i的上升子序列

\(f[i][j]=\sum_{k<j,a[k]<a[j]}f[i-1][k]\)

利用桶排序的思想,儲存f[i][j],在第a[j]個點,這樣在轉移只需要求
小於a[j]的和即可

如何高效的統計這個和,則會用到樹狀陣列或線段樹

遍歷i時

樹狀陣列存的應是f[i-1][k]的和(因此遍歷i之後需要清空樹狀陣列)

遍歷第j+1個序列時,前j個序列 第a[j]點的位置加上f[i-1][j]

尋找滿足狀態轉移方程則 f[i][j]+=sum(a[j]-1)

#include<bits/stdc++.h>
using namespace std;
#define x first 
#define y second
#define ll long long
int n;
const int maxn=1e5+10;
pair<int,int> m[maxn];
int a[maxn];
int f[5][maxn];
int lowbit(int x){
	return x&(-x);
}
int t[maxn];
void add(int x,int k){
	while(x<=n){
		t[x]+=k;
		x+=lowbit(x);
	}
}
ll sum(int x){
	ll sum=0;
	while(x){
		sum+=t[x];
		x-=lowbit(x);
	}
	return sum;
}
int main(){
	cin>>n;
	for(int i=1;i<=n;++i) {
		cin>>m[i].x;
		m[i].y=i;
	}
	int num=0;
	sort(m+1,m+1+n);
	for(int i=1;i<=n;++i){
		if(m[i].x>m[i-1].x) ++num;
		a[m[i].y]=num;
	}
	for(int i=1;i<=n;++i) f[1][i]=1;
	for(int i=2;i<=3;++i){
		memset(t,0,sizeof(t));
		for(int j=1;j<=n;++j){
			f[i][j]=sum(a[j]-1);
			add(a[j],f[i-1][j]);
		}
	}
	ll ans=0;
	for(int i=1;i<=n;++i) ans+=f[3][i];
	cout<<ans<<endl;
	return 0; 
} 

相關文章