KM演算法小記

Sktn0089發表於2024-04-16

這個人踩死了,連 KM 都不會。

之前一直以為費用流一定優於 KM,前幾天做一道題,費用流過不了,非要用 KM 才能過,後來才知道 KM 是能做到 \(O(n^3)\) 的。

二分圖最大權完美匹配

給定一張二分圖,保證有完美匹配。每條邊有權值,求權值和最大的完美匹配。

頂標

頂標是 KM 演算法的核心工具。

我們給左右兩邊每個點分配一個數值,稱為“頂標”。不妨設 \(lx[i]\) 為左邊第 \(i\) 個點的頂標,\(ly[i]\) 則為右邊。

這裡,我們分配的頂標需要滿足 \(\forall (u,v) \in E,\space lx[u]+ly[v]\ge w(u,v)\),也就是說一條邊兩端的頂標之和大於等於該邊權值。

  • 定義 - 相等子圖:二分圖保留 \(E\) 中滿足 \(lx[u]+ly[v]=w(u,v)\) 的邊後的圖。

我們需要知道一個事實:如果當前的相等子圖中存在完美匹配,那麼這個匹配一定是原圖的最大權完美匹配。

注意到一點,因為 \(lx[u]+ly[v]\ge w(u,v)\),所以任何一個匹配的權值和 \(\le \sum\limits_u lx[u]+\sum\limits_u ly[u]\)

而相等子圖中的完美匹配正好取到上屆 \(\sum\limits_u lx[u]+\sum\limits_u ly[u]\),不存在匹配的權值之和比他大,所以一定是我們要的答案。

於是問題轉化為如何合理分配頂標,使得相等子圖存在完美匹配。

分配頂標 —— 增量+調整

我們考慮調整法,初始化頂標,可以把所有頂標賦值為無窮大,或者所連邊的權值最大值。

\(i=1...n\) 的順序為每個 \(i\) 找一個匹配。使用匈牙利演算法,在相等子圖中找一條增廣路。

當然可能找不到增廣路,此時我們需要調整頂標。設 \(slack[v]\) 表示不在搜尋樹中的右部點 \(v\) 與有連邊的在搜尋樹中的左部點 \(u\)\(lx[u]+ly[v]-w(u,v)\) 的最小值。

顯然一個合法的 \(v\)\(slack[v]\) 一定不為 \(0\)。令 \(d=\min\limits_v \{slack[v]\}\),我們令所有在搜尋樹中的左部點 \(u\) 修改 \(lx[u]\gets lx[u]-d\),令所有在搜尋樹中的右部點 \(v\) 修改 \(ly[v]\gets ly[v]-d\)。顯然,一條在搜尋樹中的點一定還在相等子圖中;一條 \(u\) 在搜尋樹、\(v\) 不在搜尋樹的邊,修改後有可能出現在相等子圖中,容易發現至少會有一條出現。

由於原圖保證了一定有完美匹配,所以這樣更新下去,一定可以找到合法的增廣路。

一次頂標的修改會使最多一個點進入搜尋樹,每次進行一個增廣,每個點都要做,時間為 \(O(n^2m)\)。當 \(m=n^2\) 時,為 \(O(n^4)\)

點選檢視程式碼
#include<bits/stdc++.h>
#define ll long long
#define pir pair<ll,ll>
#define fi first
#define se second
#define mkp make_pair
#define pb push_back
using namespace std;
const ll maxn=510, mod=998244353;
ll n,m; ll a[maxn][maxn];
vector<pir>to[maxn];
ll mch[maxn],lx[maxn],ly[maxn];
ll visx[maxn], visy[maxn], d[maxn];
ll dfs(ll u){ visx[u]=1;
	for(ll v=1;v<=n;v++)
		if(!visy[v]){
			if(lx[u]+ly[v]==a[u][v]){
				visy[v]=1;
				if(!mch[v]||dfs(mch[v])){
					mch[v]=u; return 1;
				}
			} d[v]=min(d[v],lx[u]+ly[v]-a[u][v]);
		} return 0;
}
void KM(){
	memset(lx,0xcf,sizeof lx);
	for(ll u=1;u<=n;u++)
		for(ll v=1;v<=n;v++) lx[u]=max(lx[u],a[u][v]);
	for(ll i=1;i<=n;i++){
		while(1){
			memset(visx,0,sizeof visx);
			memset(visy,0,sizeof visy);
			memset(d,0x3f,sizeof d);
			if(dfs(i)) break;
			ll dif=1e17;
			for(ll j=1;j<=n;j++)
				if(!visy[j]) dif=min(dif,d[j]);
			for(ll j=1;j<=n;j++)
				if(visx[j]) lx[j]-=dif;
			for(ll j=1;j<=n;j++)
				if(visy[j]) ly[j]+=dif;
		}
	}
	ll res=0;
	for(ll i=1;i<=n;i++) res+=a[mch[i]][i];
	printf("%lld\n",res);
	for(ll i=1;i<=n;i++) printf("%lld ",mch[i]);
}
int main(){
	scanf("%lld%lld",&n,&m);
	for(ll i=1;i<=n;i++)
		for(ll j=1;j<=n;j++) a[i][j]=-1e16;
	for(ll i=1;i<=m;i++){
		ll u,v,w; scanf("%lld%lld%lld",&u,&v,&w);
		a[u][v]=max(a[u][v],w);
	}
	KM();
	return 0;
}

BFS 的最佳化

考慮每次修改完頂標後,搜尋樹中的邊本質不變,唯一變的是會向外擴充一些新的點。我們每次都做一遍增廣,浪費了大量時間。

考慮 DFS 轉 BFS,我們每次修改頂標後只需要檢查沒搜過的所有右部點,其是否可以加入搜尋樹。

這樣一個點只會被搜一次,BFS 部分總時間為 \(O(nm)\)。同時,最多修改 \(n\) 次頂標,每次是 \(O(n)\),這部分是 \(O(n^3)\)。所以總時間為 \(O(nm+n^3)\),當 \(m=n^2\) 時,為 \(O(n^3)\)

點選檢視程式碼
#include<bits/stdc++.h>
#define ll long long
#define pir pair<ll,ll>
#define fi first
#define se second
#define mkp make_pair
#define pb push_back
using namespace std;
const ll maxn=510, mod=998244353;
ll n,m; ll a[maxn][maxn];
vector<pir>to[maxn];
ll mchx[maxn],mchy[maxn], lx[maxn],ly[maxn];
ll visx[maxn],visy[maxn], sla[maxn], pre[maxn]; ll q[maxn], l, r;
void aug(ll u){
	while(u){
		mchy[u]=pre[u];
		swap(u,mchx[pre[u]]);
	}
}
void bfs(ll s){
	memset(visx,0,sizeof visx);
	memset(visy,0,sizeof visy);
	memset(pre,0,sizeof pre);
	memset(sla,0x3f,sizeof sla);
	q[l=r=1]=s;
	while(1){
		while(l<=r){
			ll u=q[l++]; visx[u]=1;
			for(ll v=1;v<=n;v++)
				if(!visy[v]&&sla[v]>lx[u]+ly[v]-a[u][v]){
					sla[v]=lx[u]+ly[v]-a[u][v], pre[v]=u;
					if(!sla[v]){ visy[v]=1;
						if(!mchy[v]){
							aug(v); return;
						} else q[++r]=mchy[v];
					}
				}
		}
		ll dif=1e17;
		for(ll i=1;i<=n;i++)
			if(!visy[i]) dif=min(dif,sla[i]);
		for(ll i=1;i<=n;i++)
			if(visx[i]) lx[i]-=dif;
		for(ll i=1;i<=n;i++)
			if(visy[i]) ly[i]+=dif;
			else sla[i]-=dif;
		for(ll i=1;i<=n;i++)
			if(!visy[i]&&!sla[i]){
				visy[i]=1;
				if(!mchy[i]){
					aug(i); return;
				} q[++r]=mchy[i];
			}
	}
}
void KM(){
	for(ll u=1;u<=n;u++)
		for(ll v=1;v<=n;v++) lx[u]=max(lx[u],a[u][v]);
	for(ll i=1;i<=n;i++){
		bfs(i);
	}
	ll res=0;
	for(ll i=1;i<=n;i++) res+=a[mchy[i]][i];
	printf("%lld\n",res);
	for(ll i=1;i<=n;i++) printf("%lld ",mchy[i]);
}
int main(){
	scanf("%lld%lld",&n,&m);
	for(ll i=1;i<=n;i++)
		for(ll j=1;j<=n;j++) a[i][j]=-1e16;
	for(ll i=1;i<=m;i++){
		ll u,v,w; scanf("%lld%lld%lld",&u,&v,&w);
		a[u][v]=max(a[u][v],w);
	}
	KM();
	return 0;
}

相關文章