矩陣樹定理

weixin_45429627發表於2020-10-21

一些定義如下:

  • 度數矩陣:一個 n n n 階無向圖 G G G 的度數矩陣 D D D 的大小為 n × n n \times n n×n,且 D D D 僅在 D i , i D_{i,i} Di,i( i = 1 , 2 , ⋯ n i=1,2,\cdots n i=1,2,n) 處有值,其中 D i , i D_{i,i} Di,i 的值為點 i i i 的度數。
  • (出 / 入) 度矩陣:一個 n n n 階有向圖 G G G 的 (出 / 入) 度矩陣 D D D 的大小為 n × n n\times n n×n,且 D D D 僅在 D i , i D_{i,i} Di,i( i = 1 , 2 , ⋯ n i=1,2,\cdots n i=1,2,n) 處有值,其中 D i , i D_{i,i} Di,i 的值為點 i i i 的 (出 / 入) 度。
  • 鄰接矩陣 —— 無向圖:一個 n n n 階無向圖 G G G 的鄰接矩陣 A A A 的大小為 n × n n \times n n×n A A A A i , j A_{i,j} Ai,j( i ≠ j i\neq j i=j) 處有值,其中 A i , j A_{i,j} Ai,j 的值為點 i , j i,j i,j 間的連邊個數。
  • 鄰接矩陣 —— 有向圖:一個 n n n 階有向圖 G G G 的鄰接矩陣 A A A 的大小為 n × n n \times n n×n A A A A i , j A_{i,j} Ai,j( i ≠ j i\neq j i=j) 處有值,其中 A i , j A_{i,j} Ai,j 的值為從點 i i i 連向點 j j j 的有向邊個數。

矩陣樹定理

由於此定理的證明極其複雜,需要大量高等數學知識,這裡只給出結論。

矩陣樹定理是用於求解圖上生成樹計數問題的重要定理,內容如下:

G G G 為一 n n n 階無向圖,定義 G G G 的基爾霍夫(Kirchhoff) 矩陣 K K K 為其度數矩陣與其鄰接矩陣之差,則 G G G 的無根生成樹的個數為 K K K 的任意一個 n − 1 n - 1 n1 階主子式對應行列式的絕對值。

定理內容十分簡潔,沒什麼好講的。程式碼實現上基爾霍夫矩陣可以直接按照定義計算,對行列式求值用高斯消元即可,時間複雜度 O ( n 3 ) \Omicron(n^3) O(n3)

下面給出高斯消元求行列式的程式碼:

double det(int n){//求矩陣a[1...n][1...n]的行列式
	int f = 1;
	double d = 1;
	for(int i = 1; i <= n; i ++){
		if(fabs(a[i][i]) < eps){//分母不能為0
			int flag = 0;
			for(int j = i + 1; j <= n; j ++){
				if(fabs(a[j][i]) > eps){
					f *= -1,flag = 1;//每次交換要變號
					for(int k = i; k <= n; k ++) swap(a[i][k],a[j][k]);
					break;
				}
			}
			if(!flag) return 0;//若整列均為0,行列式值為0
		}
		for(int j = i + 1; j <= n; j ++){
			double t = a[j][i] / a[i][i];
			for(int k = i; k <= n; k ++) a[j][k] -= t * a[i][k];
		}
		d *= a[i][i];//最終值為消去後主對角線元素之積,這裡由於值不會再改變直接寫在迴圈裡面
	}
	return f * d;
}

有向圖上的推廣

G G G 為一 n n n 階有向圖,定義 G G G 的基爾霍夫(Kirchhoff) 矩陣 K K K 為其 (出 / 入) 度矩陣與其鄰接矩陣之差,則 G G G 中以任一點 i i i 為根的 (內 / 外) 向生成樹的個數為 K i , i K_{i,i} Ki,i 的餘子式對應行列式的絕對值。

變元矩陣樹定理

將上述兩個定理中的度數、邊數改為相應邊邊權和,則得到的值從生成樹個數變為 對應的生成樹所有邊權之積 的和。

容易發現矩陣樹定理是此定理每條邊邊權取一時的特殊情況。

例題

[SDOI2014]重建

分析

下面我們用 e e e 表示一條邊, p e p_e pe 表示 e e e 連通的概率。
容易發現題目讓我們求的是這個式子(其中 T T T 表示生成樹,下同)
∑ T ∏ e ∈ T p e ∏ e ∉ T ( 1 − p e ) \sum_T\prod_{e\in T}p_e\prod_{e\not\in T}(1-p_e) TeTpeeT(1pe)

對比一下變元矩陣樹定理的式子
∑ T ∏ e ∈ T w e \sum_T\prod_{e\in T}w_e TeTwe

發現我們要求的式子的邊權似乎不那麼統一,考慮能不能轉化一下。
D = ∏ e ( 1 − p e ) D=\prod_e(1-p_e) D=e(1pe),那麼考慮將 D D D 從式子中提出來,式子變為
D ∑ T ∏ e ∈ T p e 1 − p e D\sum_{T}\prod_{e\in T}\frac{p_e}{1-p_e} DTeT1pepe

於是只要設 e e e 的邊權為 p e 1 − p e \frac{p_e}{1-p_e} 1pepe 就是模板題了,這題可以說是相當裸的了。
要注意當 p e p_e pe 小於 e p s \rm eps eps 時直接取 e p s \rm eps eps,否則計算時會出現無窮,大於 1 − e p s 1 - \rm eps 1eps 同理。

程式碼
#include <iostream>
#include <cstdio>
#include <cmath>
using namespace std;
const int maxn = 55;
const double eps = 1e-8;
int n;
double t,D = 1,a[maxn][maxn];
int read(){
	int x = 0;
	char c = getchar();
	while(c < '0' || c > '9') c = getchar();
	while(c >= '0' && c <= '9') x = x * 10 + (c ^ 48),c = getchar();
	return x;
}
double det(int n){
	int f = 1;
	double d = 1;
	for(int i = 1; i <= n; i ++){
		if(fabs(a[i][i]) < eps){
			for(int j = i + 1; j <= n; j ++){
				if(fabs(a[j][i]) > eps){
					f *= -1;
					for(int k = i; k <= n; k ++) swap(a[i][k],a[j][k]);
					break;
				}
			}
		}
		for(int j = i + 1; j <= n; j ++){
			double t = a[j][i] / a[i][i];
			for(int k = i; k <= n; k ++) a[j][k] -= t * a[i][k];
		}
		d *= a[i][i];
	}
	return f * d;
}
int main(){
	n = read();
	for(int i = 1; i <= n; i ++)
		for(int j = 1; j <= n; j ++){
			scanf("%lf",&t);
			if(t < eps) t = eps;
			if(t > 1 - eps) t = 1 - eps;
			if(i > j) D *= 1 - t;
			if(i != j) a[i][j] = t / (t - 1);
		}
	for(int i = 1; i <= n; i ++)
		for(int j = 1; j <= n; j ++)
			if(i != j) a[i][i] -= a[i][j];
	printf("%f\n",det(n - 1) * D);//記得乘 D
	return 0;
}

[SHOI2016]黑暗前的幻想鄉

分析

直接容斥,然後對每一種情況重新建立矩陣,時間複雜度為 O ( 2 n n 3 ) \Omicron(2^nn^3) O(2nn3)

程式碼
#include <iostream>
#include <cstdio>
#include <vector>
using namespace std;
const int maxn = 20,mod = 1e9 + 7;
int n,m,ans,a[maxn][maxn];
vector <pair<int,int> > e[maxn];
int read(){
	int x = 0;
	char c = getchar();
	while(c < '0' || c > '9') c = getchar();
	while(c >= '0' && c <= '9') x = x * 10 + (c ^ 48),c = getchar();
	return x;
}
inline int add(int x,int y){
	if(x + y < mod) return x + y;
	else return x + y - mod;
}
inline int dec(int x,int y){
	if(x - y >= 0) return x - y;
	else return x - y + mod;
}
int qpow(int x,int k){
	int d = 1,t = x;
	while(k){
		if(k & 1) d = 1ll * d * t % mod;
		t = 1ll * t * t % mod,k >>= 1;
	}
	return d;
}
int det(int n){
	int f = 1,d = 1;
	for(int i = 1; i <= n; i ++){
		if(a[i][i] == 0){
			for(int j = i + 1; j <= n; j ++){
				if(a[j][i]){
					f ^= 1;
					for(int k = i; k <= n; k ++) swap(a[i][k],a[j][k]);
					break;
				}
			}
		}
		int inv = qpow(a[i][i],mod - 2);
		for(int j = i + 1; j <= n; j ++){
			int t = 1ll * a[j][i] * inv % mod;
			for(int k = i; k <= n; k ++) a[j][k] = dec(a[j][k],1ll * t * a[i][k] % mod);
		}
		d = 1ll * d * a[i][i] % mod;
	}
	return f ? d : (mod - d);
}
int main(){
	n = read(),m = 1 << (n - 1);
	for(int i = 0; i < n - 1; i ++){
		int t = read();
		while(t --) e[i].push_back({read(),read()});
	}
	for(int i = 0; i < m; i ++){
		for(int j = 1; j <= n; j ++)
			for(int k = 1; k <= n; k ++)
				a[j][k] = 0;
		int cnt = 0;
		for(int j = 0; j < n - 1; j ++){
			if((i >> j & 1) == 0) continue;
			cnt ++;
			for(int k = e[j].size() - 1; k >= 0; k --){
				int u = e[j][k].first,v = e[j][k].second;
				a[u][u] ++,a[v][v] ++;
				a[u][v] --,a[v][u] --;
			}
		}
		for(int j = 1; j <= n; j ++)
			for(int k = 1; k <= n; k ++)
				if(a[j][k] < 0) a[j][k] += mod;
		if((n - cnt) % 2) ans = add(ans,det(n - 1));
		else ans = dec(ans,det(n - 1));
	}
	printf("%d\n",ans);
	return 0;
}

作業題

分析

今年省選的 Day2T3,現在回去看怎麼感覺像是 Day2 最簡單的題 ¿

先推一下這個式子
A n s = ∑ T ∑ e ∈ T w e gcd ⁡ e ∈ T ( w e ) = ∑ T ∑ e ∈ T w e ∑ d   ∣ gcd ⁡ e ∈ T ( w e ) φ ( d ) = ∑ T ∑ e ∈ T w e ∑ ∀ e ∈ T , d ∣ w e φ ( d ) = ∑ d = 1 V φ ( d ) ( ∑ T , ∀ e ∈ T , d ∣ w e ∑ e ∈ T w e ) \begin{aligned} Ans&=\sum_{T}\sum_{e\in T}w_e\gcd_{e\in T}(w_e)\\ &=\sum_{T}\sum_{e\in T}w_e\sum_{d\,|\gcd_{e\in T}(w_e)}\varphi(d)\\ &=\sum_{T}\sum_{e\in T}w_e\sum_{\forall e\in T,d|w_e}\varphi(d)\\ &=\sum_{d=1}^V\varphi(d)\left(\sum_{T,\atop \forall e\in T,d|w_e}\sum_{e\in T} w_e\right) \end{aligned} Ans=TeTweeTgcd(we)=TeTwedgcdeT(we)φ(d)=TeTweeT,dweφ(d)=d=1Vφ(d)eT,dweT,eTwe

其中, V V V 表示值域最大值,下同。

先考慮右邊的部分,如果沒有 d ∣ w e d | w_e dwe 這個限制要怎麼快速求?
我們可以把每條邊的權值設為一個一次多項式 1 + w x 1+wx 1+wx (其中 w w w 表示這條邊的邊權),那麼每個生成樹邊權的和就轉化成了生成樹邊權的積所得多項式的一次項係數,於是就可以直接上變元矩陣樹定理了。

再看回原式子,加上限制以後怎麼搞?
如果我們對每個 d d d 重新加邊,那麼只加滿足 d ∣ w e d | w_e dwe 的邊就一定可以滿足限制條件,就可以用上面的方法求了。但這樣的時間複雜度為 O ( V n 3 ) \Omicron(Vn^3) O(Vn3),不可能通過,考慮優化。

容易發現當圖不連通時對答案是一定沒有貢獻的,而如果我們加入的邊不足 n − 1 n-1 n1 條,圖一定不連通,於是我們可以只對邊數大於 n − 1 n-1 n1 d d d 求行列式。因為對每條邊來說,它最多隻會被它邊權的每個因數算一遍,而每次跑行列式都需要至少 n − 1 n-1 n1 條邊,故複雜度上界為 O ( t m n − 1 n 3 ) = O ( t n 4 ) \Omicron(\frac{tm}{n-1}n^3)=\Omicron(tn^4) O(n1tmn3)=O(tn4) m m m n 2 n^2 n2 級別的),其中 t t t 表示值域中最大的因數個數。
打表可知 t = 144 t=144 t=144,可以通過,且實際上遠跑不滿。

程式碼
#include <iostream>
#include <cstdio>
using namespace std;
const int maxn = 35,maxm = 500,maxv = 1.6e5,mod = 998244353;
int n,m,mx,ans,tot,u[maxm],v[maxm],w[maxm],vis[maxv],p[maxv],phi[maxv];
inline int add(int x,int y){
	if(x + y < mod) return x + y;
	else return x + y - mod;
}
inline int dec(int x,int y){
	if(x - y >= 0) return x - y;
	else return x - y + mod;
}
int qpow(int x,int k){
	int d = 1,t = x;
	while(k){
		if(k & 1) d = 1ll * d * t % mod;
		t = 1ll * t * t % mod,k >>= 1;
	}
	return d;
}
struct node{
	int x,y;
	node operator + (node p){
		return {add(x,p.x),add(y,p.y)};
	}
	node operator - (node p){
		return {dec(x,p.x),dec(y,p.y)};
	}
	node operator * (node p){
		return {1ll * x * p.x % mod,add(1ll * x * p.y % mod,1ll * y * p.x % mod)};
	}
}a[maxn][maxn];
int read(){
	int x = 0;
	char c = getchar();
	while(c < '0' || c > '9') c = getchar();
	while(c >= '0' && c <= '9') x = x * 10 + (c ^ 48),c = getchar();
	return x;
}
inline node Inv(node p){
	//    Inv(1+ax)=1-ax (mod x^2)
	// => Inv(a+bx) = Inv(a(1+(b/a)x)) = Inv(a)Inv(1+(b/a)x) = (1/a)(1-(b/a)x) (mod x^2)
	int t = qpow(p.x,mod - 2);
	t = 1ll * t * t % mod;
	p.y = mod - 1ll * p.y * t % mod,p.x = 1ll * p.x * t % mod;
	return p;
}
int det(int n){
	int f = 1;
	node d = {1,0};
	for(int i = 1; i <= n; i ++){
		int p = 0;
		for(int j = i; j <= n; j ++){
			if(a[j][i].x){
				p = j;
				break;
			}
		}
		if(!p) return 0;
		if(p != i){
			f ^= 1;
			for(int j = i; j <= n; j ++) swap(a[p][j],a[i][j]);
		}
		node inv = Inv(a[i][i]);
		for(int j = i + 1; j <= n; j ++){
			node t = a[j][i] * inv;
			for(int k = i; k <= n; k ++) a[j][k] = a[j][k] - t * a[i][k];
		}
		d = d * a[i][i];
	}
	return f ? d.y : mod - d.y;
}
void Sieve(int n){
	phi[1] = 1;
	for(int d = 2; d <= n; d ++){
		if(!vis[d]) p[++tot] = d,phi[d] = d - 1;
		for(int i = 1; i <= tot && p[i] * d <= mx; i ++){
			int v = p[i] * d;
			vis[v] = 1;
			if(d % p[i] == 0){
				phi[v] = p[i] * phi[d];
				break;
			}
			phi[v] = phi[p[i]] * phi[d];
		}
	}
}
int main(){
	n = read(),m = read();
	for(int i = 1; i <= m; i ++) u[i] = read(),v[i] = read(),w[i] = read(),mx = max(mx,w[i]);
	Sieve(mx);
	for(int d = 1; d <= mx; d ++){
		int cnt = 0;
		for(int i = 1; i <= m; i ++) if(w[i] % d == 0) cnt ++;
		if(cnt < n - 1) continue;//很多人這裡都直接寫了並查集,我比較懶,反正都能過
		for(int i = 1; i <= n; i ++)
			for(int j = 1; j <= n; j ++)
				a[i][j] = {0,0};
		for(int i = 1; i <= m; i ++)
			if(w[i] % d == 0){
				a[u[i]][v[i]] = {mod - 1,mod - w[i]};
				a[v[i]][u[i]] = {mod - 1,mod - w[i]};
				a[u[i]][u[i]] = a[u[i]][u[i]] + (node){1,w[i]};
				a[v[i]][v[i]] = a[v[i]][v[i]] + (node){1,w[i]};
			}
		ans = add(ans,1ll * phi[d] * det(n - 1) % mod);
	}
	printf("%d\n",ans);
	return 0;
}

相關文章