題意
定義一個長度為 \(n\) 的序列為 k 好序列當且僅當該序列存在一個長度為 \(k\) 的連續子序列構成 \(1\sim k\) 的排列。
定義一個 k 好序列的權值為特殊序列序列 \(\{b_i\}\) 在該序列中的出現次數。
序列值域為 \([1,k]\),求所有 k 好序列的權值之和。
\(n\le 2.5\times10^4,k\le 400\)
分析
原問題做起來有點棘手,不妨考慮正難則反,將原問題轉化為求 所有序列的權值 減去 非 k 好序列的權值。拆貢獻易得前半部分答案為 \((n-m+1)\cdot k^{n-m}\),現在只需要求後半部分就行了。
分類討論一下:
1. 特殊序列本身存在 k 好序列
此時後半部分答案為 0。
2. 特殊序列存在重複出現的元素
由於有重複元素,所以不可能存在長為 \(k\) 的排列橫跨特殊序列,這樣序列兩端的情況就相互獨立了,只需要保證兩端都沒有長為 \(k\) 的排列即可。設 \(f_{i,j}\) 表示左側擴充套件了 \(i\) 位,極長無重複元素的序列長為 \(j\) 的方案數。\(g_{i,j}\) 表示右側。轉移考慮新接一個未出現的數,\(f_{i+1,j+1}\leftarrow f_{i,j}\cdot (k-j)\),或者接一個已經出現的數,\(f_{i+1,p}\leftarrow f_{i,j},p\le j\)。該 DP 可以用字尾和最佳化做到 \(O(nk)\)。
然後列舉一下左側擴充套件多少即可。合併是一個卷積形式。
3. 特殊序列不存在重複元素
若繼續按照上述思路做複雜度會退化到 \(O(nk^2)\),考慮一個新的做法,我們考慮計算所有非 k 好序列的“m 好序列”連續子序列數量之和,由於不同的顏色地位相同,所以在這些 m 好序列上出現特殊序列與出現其他序列的情況數是相等的,最終把求出來的答案除以 \(A_{k-1}^{m-1}\)(因為 DP 陣列定義中最後一個位置的顏色沒有被欽定,我們只需要欽定最後一個位置為特殊序列最後一位即可,故不需要除以 \(A_k^m\))。
所有非 k 好序列的“m 好序列”連續子序列數量之和也可以透過情況 2 的類似 DP 來求解,區別在於要記兩個 DP 陣列分別表示方案數和 m 好序列數,轉移也要比情況 2 多一條。同樣需要字尾和最佳化,時間複雜度 \(O(nk)\)。
綜上,時間複雜度 \(O(nk)\)。
點選檢視程式碼
#include<iostream>
#include<cstdio>
#include<cstring>
#include<string>
#include<algorithm>
#include<cmath>
#include<map>
#include<unordered_map>
#include<vector>
#include<queue>
#include<stack>
#include<bitset>
#include<set>
#include<ctime>
#include<random>
#include<cassert>
#define IOS ios::sync_with_stdio(false)
#define PY puts("Yes")
#define PN puts("No")
#define PW puts("-1")
#define P0 puts("0")
#define P__ puts("")
#define PU puts("--------------------")
#define mp make_pair
#define fi first
#define se second
#define pc putchar
#define pb emplace_back
#define un using namespace
#define il inline
#define popc __builtin_popcountll
#define all(x) x.begin(),x.end()
#define rep(a,b,c) for(int a=(b);a<=(c);++a)
#define per(a,b,c) for(int a=(b);a>=(c);--a)
#define reprange(a,b,c,d) for(int a=(b);a<=(c);a+=(d))
#define perrange(a,b,c,d) for(int a=(b);a>=(c);a-=(d))
#define graph(i,j,k,l) for(int i=k[j];i;i=l[i].nxt)
#define lowbit(x) (x&-x)
#define lson(x) (x<<1)
#define rson(x) (x<<1|1)
#define mem(x,y) memset(x,y,sizeof x)
//#define double long double
//#define int long long
//#define int __int128
using namespace std;
using i64=long long;
using u64=unsigned long long;
using pii=pair<int,int>;
inline int rd(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+ch-48;ch=getchar();}return x*f;
}
template<typename T>
inline void write(T x,char ch='\0'){
if(x<0){x=-x;putchar('-');}
int y=0;char z[40];
while(x||!y){z[y++]=x%10+48;x/=10;}
while(y--)putchar(z[y]);if(ch!='\0')putchar(ch);
}
bool Mbg;
const int maxn=2.5e4+5,maxk=405,inf=0x3f3f3f3f,mod=1e9+7;
const long long llinf=0x3f3f3f3f3f3f3f3f;
int n,k,m,ans;
int a[maxn];
int ksm(int x,int y){
int res=1;
for(;y;y>>=1,x=1ll*x*x%mod)if(y&1)res=1ll*res*x%mod;
return res;
}
inline void add(int &x,int y){x+=y,x=x>=mod?x-mod:x;}
bool vis[maxn];
bool ck(){
rep(s,1,m-k+1){
rep(i,1,k)vis[i]=0;
bool ok=1;
rep(i,0,k-1){
if(vis[a[s+i]]){ok=0;break;}
vis[a[s+i]]=1;
}
if(ok)return 1;
}
return 0;
}
int sf[maxk],sg[maxk];
int f[maxn][maxk],g[maxn][maxk];
bool visf[maxk],visg[maxk];
void solve1(int Fx,int Gx){
f[0][Fx]=g[0][Gx]=1;
rep(i,1,Fx)sf[i]=1;
rep(i,1,Gx)sg[i]=1;
rep(i,1,n){
rep(j,1,k-1){
f[i][j]=1ll*f[i-1][j-1]*(k-j+1)%mod;
add(f[i][j],sf[j]);
g[i][j]=1ll*g[i-1][j-1]*(k-j+1)%mod;
add(g[i][j],sg[j]);
}
sf[k]=sg[k]=0;
per(j,k-1,1)sf[j]=(sf[j+1]+f[i][j])%mod,sg[j]=(sg[j+1]+g[i][j])%mod;
}
int res=0;
rep(i,0,n-m){
int sumf=0,sumg=0;
rep(j,1,k-1)add(sumf,f[i][j]),add(sumg,g[n-m-i][j]);
add(res,1ll*sumf*sumg%mod);
}
write((ans+mod-res)%mod);
}
int fac[maxn],inv[maxn];
int A(int x,int y){
if(x<y)return 0;
return 1ll*fac[x]*inv[x-y]%mod;
}
void solve2(){
fac[0]=1;rep(i,1,k)fac[i]=1ll*fac[i-1]*i%mod;
inv[k]=ksm(fac[k],mod-2);per(i,k-1,0)inv[i]=1ll*inv[i+1]*(i+1)%mod;
f[1][1]=sf[1]=1;
if(m==1)g[1][1]=sg[1]=1;
rep(i,2,n){
rep(j,1,k-1){
f[i][j]=1ll*f[i-1][j-1]*(k-j+1)%mod;
add(f[i][j],sf[j]);
g[i][j]=1ll*g[i-1][j-1]*(k-j+1)%mod;
add(g[i][j],sg[j]);
}
sf[k]=sg[k]=0;
per(j,k-1,1)sf[j]=(sf[j+1]+f[i][j])%mod;
rep(j,m,k-1)add(g[i][j],f[i][j]);
per(j,k-1,1)sg[j]=(sg[j+1]+g[i][j])%mod;
}
// write(sg[1],32);
int res=1ll*sg[1]*ksm(A(k-1,m-1),mod-2)%mod;
write((ans-res+mod)%mod);
}
inline void solve_the_problem(){
n=rd(),k=rd(),m=rd();
rep(i,1,m)a[i]=rd();
ans=1ll*(n-m+1)*ksm(k,n-m)%mod;
if(ck())return write(ans);
int fst=-1,lst=-1;
rep(i,1,m){
if(visf[a[i]]){fst=i-1;break;}
visf[a[i]]=1;
}
per(i,m,1){
if(visg[a[i]]){lst=m-i;break;}
visg[a[i]]=1;
}
if(fst==-1&&lst==-1)return solve2();
return solve1(fst,lst);
}
bool Med;
signed main(){
// freopen(".in","r",stdin);freopen(".out","w",stdout);
// fprintf(stderr,"%.3lfMB\n",(&Mbg-&Med)/1048576.0);
int _=1;
while(_--)solve_the_problem();
}
/*
*/