ARC100D/F Colorful Sequences

dcytrl發表於2024-11-16

題意

定義一個長度為 \(n\) 的序列為 k 好序列當且僅當該序列存在一個長度為 \(k\) 的連續子序列構成 \(1\sim k\) 的排列。

定義一個 k 好序列的權值為特殊序列序列 \(\{b_i\}\) 在該序列中的出現次數。

序列值域為 \([1,k]\),求所有 k 好序列的權值之和。

\(n\le 2.5\times10^4,k\le 400\)

分析

原問題做起來有點棘手,不妨考慮正難則反,將原問題轉化為求 所有序列的權值 減去 非 k 好序列的權值。拆貢獻易得前半部分答案為 \((n-m+1)\cdot k^{n-m}\),現在只需要求後半部分就行了。

分類討論一下:

1. 特殊序列本身存在 k 好序列

此時後半部分答案為 0。

2. 特殊序列存在重複出現的元素

由於有重複元素,所以不可能存在長為 \(k\) 的排列橫跨特殊序列,這樣序列兩端的情況就相互獨立了,只需要保證兩端都沒有長為 \(k\) 的排列即可。設 \(f_{i,j}\) 表示左側擴充套件了 \(i\) 位,極長無重複元素的序列長為 \(j\) 的方案數。\(g_{i,j}\) 表示右側。轉移考慮新接一個未出現的數,\(f_{i+1,j+1}\leftarrow f_{i,j}\cdot (k-j)\),或者接一個已經出現的數,\(f_{i+1,p}\leftarrow f_{i,j},p\le j\)。該 DP 可以用字尾和最佳化做到 \(O(nk)\)

然後列舉一下左側擴充套件多少即可。合併是一個卷積形式。

3. 特殊序列不存在重複元素

若繼續按照上述思路做複雜度會退化到 \(O(nk^2)\),考慮一個新的做法,我們考慮計算所有非 k 好序列的“m 好序列”連續子序列數量之和,由於不同的顏色地位相同,所以在這些 m 好序列上出現特殊序列與出現其他序列的情況數是相等的,最終把求出來的答案除以 \(A_{k-1}^{m-1}\)(因為 DP 陣列定義中最後一個位置的顏色沒有被欽定,我們只需要欽定最後一個位置為特殊序列最後一位即可,故不需要除以 \(A_k^m\))。

所有非 k 好序列的“m 好序列”連續子序列數量之和也可以透過情況 2 的類似 DP 來求解,區別在於要記兩個 DP 陣列分別表示方案數和 m 好序列數,轉移也要比情況 2 多一條。同樣需要字尾和最佳化,時間複雜度 \(O(nk)\)

綜上,時間複雜度 \(O(nk)\)

點選檢視程式碼
#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<cmath>
#include<map>
#include<unordered_map>
#include<vector>
#include<queue>
#include<stack>
#include<bitset>
#include<set>
#include<ctime>
#include<random>
#include<cassert>
#define IOS ios::sync_with_stdio(false)
#define PY puts("Yes")
#define PN puts("No")
#define PW puts("-1")
#define P0 puts("0")
#define P__ puts("")
#define PU puts("--------------------")
#define mp make_pair
#define fi first
#define se second
#define pc putchar
#define pb emplace_back
#define un using namespace
#define il inline
#define popc __builtin_popcountll
#define all(x) x.begin(),x.end()
#define rep(a,b,c) for(int a=(b);a<=(c);++a)
#define per(a,b,c) for(int a=(b);a>=(c);--a)
#define reprange(a,b,c,d) for(int a=(b);a<=(c);a+=(d))
#define perrange(a,b,c,d) for(int a=(b);a>=(c);a-=(d))
#define graph(i,j,k,l) for(int i=k[j];i;i=l[i].nxt)
#define lowbit(x) (x&-x)
#define lson(x) (x<<1)
#define rson(x) (x<<1|1)
#define mem(x,y) memset(x,y,sizeof x)
//#define double long double
//#define int long long
//#define int __int128
using namespace std;
using i64=long long;
using u64=unsigned long long;
using pii=pair<int,int>;
inline int rd(){
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-48;ch=getchar();}return x*f;
}
template<typename T>
inline void write(T x,char ch='\0'){
	if(x<0){x=-x;putchar('-');}
	int y=0;char z[40];
	while(x||!y){z[y++]=x%10+48;x/=10;}
	while(y--)putchar(z[y]);if(ch!='\0')putchar(ch);
}
bool Mbg;
const int maxn=2.5e4+5,maxk=405,inf=0x3f3f3f3f,mod=1e9+7;
const long long llinf=0x3f3f3f3f3f3f3f3f;
int n,k,m,ans;
int a[maxn];
int ksm(int x,int y){
	int res=1;
	for(;y;y>>=1,x=1ll*x*x%mod)if(y&1)res=1ll*res*x%mod;
	return res;
}
inline void add(int &x,int y){x+=y,x=x>=mod?x-mod:x;}
bool vis[maxn];
bool ck(){
	rep(s,1,m-k+1){
		rep(i,1,k)vis[i]=0;
		bool ok=1;
		rep(i,0,k-1){
			if(vis[a[s+i]]){ok=0;break;}
			vis[a[s+i]]=1;
		}
		if(ok)return 1;
	}
	return 0;
}
int sf[maxk],sg[maxk];
int f[maxn][maxk],g[maxn][maxk];
bool visf[maxk],visg[maxk];
void solve1(int Fx,int Gx){
	f[0][Fx]=g[0][Gx]=1;
	rep(i,1,Fx)sf[i]=1;
	rep(i,1,Gx)sg[i]=1;
	rep(i,1,n){
		rep(j,1,k-1){
			f[i][j]=1ll*f[i-1][j-1]*(k-j+1)%mod;
			add(f[i][j],sf[j]);
			g[i][j]=1ll*g[i-1][j-1]*(k-j+1)%mod;
			add(g[i][j],sg[j]);
		}
		sf[k]=sg[k]=0;
		per(j,k-1,1)sf[j]=(sf[j+1]+f[i][j])%mod,sg[j]=(sg[j+1]+g[i][j])%mod;
	}
	int res=0;
	rep(i,0,n-m){
		int sumf=0,sumg=0;
		rep(j,1,k-1)add(sumf,f[i][j]),add(sumg,g[n-m-i][j]);
		add(res,1ll*sumf*sumg%mod);
	}
	write((ans+mod-res)%mod);
}
int fac[maxn],inv[maxn];
int A(int x,int y){
	if(x<y)return 0;
	return 1ll*fac[x]*inv[x-y]%mod;
}
void solve2(){
	fac[0]=1;rep(i,1,k)fac[i]=1ll*fac[i-1]*i%mod;
	inv[k]=ksm(fac[k],mod-2);per(i,k-1,0)inv[i]=1ll*inv[i+1]*(i+1)%mod;
	f[1][1]=sf[1]=1;
	if(m==1)g[1][1]=sg[1]=1;
	rep(i,2,n){
		rep(j,1,k-1){
			f[i][j]=1ll*f[i-1][j-1]*(k-j+1)%mod;
			add(f[i][j],sf[j]);
			g[i][j]=1ll*g[i-1][j-1]*(k-j+1)%mod;
			add(g[i][j],sg[j]);
		}
		sf[k]=sg[k]=0;
		per(j,k-1,1)sf[j]=(sf[j+1]+f[i][j])%mod;
		rep(j,m,k-1)add(g[i][j],f[i][j]);
		per(j,k-1,1)sg[j]=(sg[j+1]+g[i][j])%mod;
	}
//	write(sg[1],32);
	int res=1ll*sg[1]*ksm(A(k-1,m-1),mod-2)%mod;
	write((ans-res+mod)%mod);
}
inline void solve_the_problem(){
	n=rd(),k=rd(),m=rd();
	rep(i,1,m)a[i]=rd();
	ans=1ll*(n-m+1)*ksm(k,n-m)%mod;
	if(ck())return write(ans);
	int fst=-1,lst=-1;
	rep(i,1,m){
		if(visf[a[i]]){fst=i-1;break;}
		visf[a[i]]=1;
	}
	per(i,m,1){
		if(visg[a[i]]){lst=m-i;break;}
		visg[a[i]]=1;
	}
	if(fst==-1&&lst==-1)return solve2();
	return solve1(fst,lst);
}
bool Med;
signed main(){
//	freopen(".in","r",stdin);freopen(".out","w",stdout);
//	fprintf(stderr,"%.3lfMB\n",(&Mbg-&Med)/1048576.0);
	int _=1;
	while(_--)solve_the_problem();
}
/*

*/

相關文章