Description
給定一個 \(h\times w\) 的 \(01\) 矩陣和非負整數序列 \(\{b_n\}\),接下來每秒會在 \(h\times w\) 個格子中均勻隨機地選取一個將其取反。
問期望第幾秒會第一次滿足:對於任意的 \(i\) 有第 \(i\) 行恰有 \(b_i\) 個 \(1\)。
\(1\le h,w\le 50\)。
Solution
閒話:
我終於會 PGF 了!!!!
這真的只是銅牌題嗎???我只能說,too hard for me。
由於筆者是 PGF 初學者,所以本篇文章會講得非常詳細以來加深筆者自己地理解和幫助其他初學者。
下面進入正文:
前置知識:
- 基本組合數學知識(如二項式定理等)。
- 基本微積分知識(如求導,泰勒展開等)。
閱讀此文你並不需要有關於機率生成函式(PGF)的知識,因為本文可能會從頭到尾地介紹一遍。
你甚至可能不需要普通生成函式(OGF)和指數生成函式(EGF)的知識(雖然理論上是這樣,但是我覺得想要做這題還是得了解生成函式)。
符號約定:
- \(F(x)\) 表示序列 \(f\) 的普通生成函式(OGF):\(F(x)=\sum\limits_{i\ge 0} f_ix^i\)。
- \(\hat{F}(x)\) 表示序列 \(f\) 的指數生成函式(EGF):\(\hat{F}(x)=\sum\limits_{i\ge 0} \frac{f_i}{i!}x^i=\sum\limits_{i\ge 0} \frac{[x^i]F(x)}{i!}x^i\)。
- \([x^i]F(x)\) 表示多項式 \(F(x)\) 其第 \(i\) 次項的係數。
- \(F*G\) 表示的是兩個多項式卷積得到的多項式,具體的,若 \(H=F*G\),則我們有:
- \([x^n]H(x)=\sum\limits_{m=0}^n [x^m]F(x)\times [x^{n-m}]G(x)\)。
- 這個東西也可以寫作 \(H(x)=F(x)G(x)\)。
- \(F'(x)\) 表示的是對多項式 \(F(x)\) 求導後得到的多項式。
- \(\Pr(A)\) 表示的是 \(A\) 事件成立的機率。
- \(\operatorname{E}(X)\) 表示的是隨機變數 \(X\) 的期望取值。
Pre - 機率生成函式(PGF)
如果我們現在有一個取值是非負整數的隨機變數 \(X\),則我們記其機率生成函式(PGF)為:
下面簡記為 \(F(x)\)。
或者從另一個角度理解:即序列 \(\{\Pr(X=1),\Pr(X=2),\cdots\}\) 的普通生成函式(OGF)。
更一般的,我們的 PGF 可能是關於一系列事件而非一個隨機變數的(特別的,我們可以將 \(X=1,2,\cdots\) 看作是一系列不同的事件)。
在這題中,我們的隨機變數 \(X\) 也就是停時時間(即第一次符合條件的時間)。
則我們的答案(期望停時時間)就是:
我們發現:如果對 \(F(x)\) 求導,我們就可以得到:
如果我們此時帶入 \(x=1\),就可以驚訝地發現:
也就是說,\(F'(1)\) 就是我們想要的答案!
Pre - 解決「第一次」問題
我們注意到本題中想要的是「第一次符合條件的時間」,這是很煩的,我們考慮這樣子處理:
我們記 \(S,T\) 分別為題目中的初始和最終局面,事件 \(A(U\to V,i)\) 表示的是從初始是 \(U\) 局面經過恰好 \(i\) 時刻變成了 \(V\) 局面,則我們列出關於 \(A(S\to T,*)\) 的 PGF \(G(x)\) 和關於 \(A(T\to T,*)\) 的 PGF \(H(x)\):
則我們有:\(G=F*H\)。
- 證明:我們可以認為是,當我們在 \(i\) 時刻第一次從初始狀態到達了最終狀態,然後又經過了 \(j\) 時刻的一系列操作從最終狀態又回到了最終狀態;又因為這兩個事件是獨立的,則我們有 \([x^i]F(x)\times [x^j]H(x)\to [x^{i+j}]G(x)\),即 \(G=F*H\)。
注意到我們關心的是 \(F'(1)\) 的值,則根據 \(G=F*H\),我們有 \(F=\frac{G}{H}\),根據商的求導法則,我們有:
則我們現在將問題轉換為了求 \(G(1),G'(1),H(1),H'(1)\) 的值,而這樣子就消除了「第一次」的限制。
下面我們預設只討論 \(G(1),G'(1)\) 的求解,對於 \(H(1),H'(1)\) 則是幾乎一模一樣的,在實現上只需要改一下進行 dp 時傳進去的引數即可。
Part 1 - 一個 dp
注意到我們一個格子翻了奇數次相當於翻了一次,翻了偶數次則相當於沒翻。
我們不妨假定:每個格子最多隻會被翻轉一次。考慮計算從 \(S\to T\) 的方案總數。
注意到我們只關心每一行 \(1\) 的個數,所以考慮按行 dp。
記錄 \(f_{i,j}\) 表示考慮前 \(i\) 行翻轉恰好 \(j\) 次使得滿足條件的的方案總數。
我們記第 \(i\) 行在初始狀態有 \(a_i\) 個 \(1\),而最終狀態就是題目給定的有 \(b_i\) 個 \(1\)。
則考慮列舉將多少個 \(1\) change 成了 \(0\)。
如果有 \(k\) 個 \(0\to 1\),則我們應有 \(b_i-(a_i-k)\) 個 \(1\to 0\),則得到轉移方程:
記 \(n=h\times w\),記格子總數量,則這個 dp 是 \(\mathcal{O}(n^2)\) 的,可以接受。
我們設 \(\{c_n\}\) 為最終的答案,即 \(c_i=f_{h,i}\)。
Part 2 - 構造生成函式
現在我們對於每一個格子給出其被翻轉偶數 or 奇數次的 PGF:
這時候我們看起來只需要列舉實際上有多少個格子被翻轉(即被翻轉奇數次),然後將對應的生成函式捲起來即可!
但是實際上並不是,因為不同格子之間的翻轉是區分順序的,所以我們要使用 EGF 進行區分(這是因為兩個 EGF 的卷積表述成最終的序列是會乘上一個組合數的係數),即得到:
嘗試推導 \(\hat{P_0}\) 和 \(\hat{P_1}\)。
我們給出:
可以得到:
最後一步是考慮 \(e^x\) 的泰勒展開。
同理我們可以得到的是:\(\hat{Q}(x)=e^{-\frac{x}{n}}\)。
我們注意到有:
則可以得到的是:
也就是對等式兩邊同時取 EGF,因為我們注意到 \([x^n]\) 的相對關係沒有發生變化,所以等號仍然成立。
即
同理有:
則我們得到了:
Part 2 - 一個柿子的推
考慮化簡上面 EGF:
其中第三個等號處用二項式定理將其暴力展開,倒數第二個等號處改為列舉 \(j+k\),最後一個等號處將 \(\sum\limits_{j=0}^i \binom{i}{j}\binom{n-i}{k-j}(-1)^{n-i-(k-j)}\) 改為其 OGF 寫法(考慮二項式定理將其展開即可獲證)。
Part 3 - 計算答案
我們先暴力計算出 \(\sum\limits_{i=1}^n c_i(x+1)^{i}(x-1)^{n-i}\) 的多項式 \(C(x)\)。
則柿子化為
這時候我們去掉 EGF,還原會 OGF:
這一部分即為先將 \(e^{(\frac{2k-n}{n})x}\) 泰勒展開得到:
然後再化為 OGF 的封閉形式:
此時我們代入 \(x=1\) 即可得到 \(G(1)\) 的值,然而有一個問題是當 \(k=n\) 時我們會得到 \(\frac{1}{0}\),這無法處理。
我們考慮令 \(A(x)=(1-x)G(x),B(x)=(1-x)H(x)\),此時我們仍有 \(F=\frac{A}{B}\),所以不影響答案。
而這時當 \(k=n\) 時我們外層的 \(1-x\) 和內層的 \(1-x\) 相互抵消,所以對 \(A(1)\) 恰好貢獻為 \(1\times [x^n]C(x)\)。
而對於 \(k\ne n\) 時,我們記 \(t=\frac{2t-n}{n},v=[x^k]C(x)\),則所得多項式為 \(\frac{v(x-1)}{1-tx}\)。
- 這時 \(x-1=0\),不對 \(A(1)\) 產生貢獻。
- 此時導函式為 \(\frac{(t-1)v}{(1-tx)^2}\),對 \(A'(1)\) 貢獻為 \(\frac{v}{t-1}\)。
注意別忘記除掉 \(\frac{1}{2^n}\)。
同理我們可以計算出 \(B(1)\) 和 \(B'(1)\) 的值,則可以得到:
這就是答案。
暴力實現複雜度是 \(\mathcal{O}(n^2)\) 的,已經可以透過;使用一些多項式知識可以做到 \(\mathcal{O}(n\operatorname{polylog}(n))\)。
下面給出 \(\mathcal{O}(n^2)\) 的實現。
Code
#include<bits/stdc++.h>
//#pragma GCC optimize(3,"Ofast","inline")
//#define int long long
#define i128 __int128
#define ll long long
#define ull unsigned long long
#define uint unsigned int
#define ld double
#define PII pair<int,int>
#define INF 0x3f3f3f3f
#define INFLL 0x3f3f3f3f3f3f3f3f
#define chkmax(a,b) a=max(a,b)
#define chkmin(a,b) a=min(a,b)
#define rep(k,l,r) for(int k=l;k<=r;++k)
#define per(k,r,l) for(int k=r;k>=l;--k)
#define cl(f,x) memset(f,x,sizeof(f))
#define pcnt(x) __builtin_popcount(x)
#define lg(x) (31-__builtin_clz(x))
using namespace std;
void file_IO() {
// system("fc .out .ans");
freopen(".in","r",stdin);
freopen(".out","w",stdout);
}
bool M1;
template<int p>
struct mint {
int x;
mint() {
x=0;
}
mint(int _x) {
x=_x;
}
friend mint operator + (mint a,mint b) {
return a.x+b.x>=p? a.x+b.x-p:a.x+b.x;
}
friend mint operator - (mint a,mint b) {
return a.x<b.x? a.x-b.x+p:a.x-b.x;
}
friend mint operator * (mint a,mint b) {
return 1ll*a.x*b.x%p;
}
friend mint operator ^ (mint a,ll b) {
mint res=1,base=a;
while(b) {
if(b&1)
res*=base;
base*=base; b>>=1;
}
return res;
}
friend mint operator ~ (mint a) {
return a^(p-2);
}
friend mint operator / (mint a,mint b) {
return a*(~b);
}
friend mint & operator += (mint& a,mint b) {
return a=a+b;
}
friend mint & operator -= (mint& a,mint b) {
return a=a-b;
}
friend mint & operator *= (mint& a,mint b) {
return a=a*b;
}
friend mint & operator /= (mint& a,mint b) {
return a=a/b;
}
friend mint operator ++ (mint& a) {
return a+=1;
}
friend mint operator -- (mint& a) {
return a-=1;
}
};
const int MOD=998244353;
#define mint mint<MOD>
const int N=5e3+5;
mint jc[N],inv_jc[N];
void init(int n=5000) {
jc[0]=1;
rep(i,1,n)
jc[i]=jc[i-1]*i;
inv_jc[n]=~jc[n];
per(i,n-1,0)
inv_jc[i]=inv_jc[i+1]*(i+1);
}
mint C(int n,int m) {
if(m<0||n<m)
return 0;
return jc[n]*inv_jc[n-m]*inv_jc[m];
}
int n,h,w;
mint tmp[N][N],t[N];
void calc(int a[],int b[],mint c[]) {
rep(i,0,h) {
rep(j,0,n)
tmp[i][j]=0;
}
tmp[0][0]=1;
rep(i,1,h) {
rep(j,0,w)
t[j]=0;
rep(j,0,a[i]) {
int d=b[i]-a[i]+j;
if(j+d>=0)
t[j+d]+=C(a[i],j)*C(w-a[i],d);
}
rep(j,0,w) {
rep(k,0,n-j)
tmp[i][j+k]+=tmp[i-1][k]*t[j];
}
}
rep(i,0,n)
c[i]=tmp[h][i];
}
mint p[N];
void calc(mint c[],mint f[]) {
rep(i,0,n)
p[i]=C(n,i);
rep(i,0,n) {
if(i) {
rep(j,1,n)
p[j]-=p[j-1];
per(j,n,1)
p[j]=p[j-1]-p[j];
p[0]=MOD-p[0];
}
rep(j,0,n)
f[j]+=p[j]*c[i];
}
}
int a[N],b[N];
char s[N];
mint c1[N],c2[N],f[N],g[N];
void solve() {
scanf("%d%d",&h,&w);
rep(i,1,h) {
scanf("%s",s+1);
int cnt=0;
rep(j,1,w)
cnt+=s[j]=='1';
a[i]=cnt;
}
rep(i,1,h)
scanf("%d",&b[i]);
n=h*w;
calc(a,b,c1);
calc(c1,f);
calc(b,b,c2);
calc(c2,g);
mint sf=0,sg=0,sdf=0,sdg=0;
rep(i,0,n-1) {
mint val=(mint(2*i)-mint(n))/mint(n);
sdf+=f[i]/(val-mint(1));
sdg+=g[i]/(val-mint(1));
}
sf+=f[n];
sg+=g[n];
mint inv=(~mint(2))^n;
sf*=inv; sg*=inv;
sdf*=inv; sdg*=inv;
printf("%d\n",((sdf*sg-sf*sdg)/(sg*sg)).x);
}
bool M2;
signed main() {
//file_IO();
int testcase=1;
init();
//scanf("%d",&testcase);
while(testcase--)
solve();
fprintf(stderr,"used time = %ldms\n",1000*clock()/CLOCKS_PER_SEC);
fprintf(stderr,"used memory = %lldMB\n",(&M2-&M1)/1024/1024);
return 0;
}