插頭DP學習筆記

liuchanglc發表於2021-04-05

插頭DP學習筆記

用途

有些 狀壓 \(DP\) 問題要求我們記錄狀態的連通性資訊,這類問題一般被形象的稱為插頭 \(DP\) 或連通性狀態壓縮 \(DP\)

例如格點圖的哈密頓路徑計數,求棋盤的黑白染色方案滿足相同顏色之間形成一個連通塊的方案數,以及特定圖的生成樹計數等等。

這些問題通常需要我們對狀態的連通性進行編碼,討論狀態轉移過程中連通性的變化。

例題

洛谷P5056 【模板】插頭dp

首先要明確兩個概念:

輪廓線:已決策狀態和未決策狀態的分界線。

插頭:一個格子某個方向的插頭存在,表示這個格子在這個方向與相鄰格子相連。

我們要狀壓的就是輪廓線上插頭的狀態。

具體來說,可以把路徑的合併看作括號的匹配。

一般的狀壓 \(dp\) 只有 \(0,1\) 兩個狀態,分別表示有和沒有,但是這道題需要用三進位制的狀態壓縮,\(0,1,2\) 分別表示沒有括號,有左括號,有右括號。

之所以要分左右括號是因為要區分下面這兩種情況:

第一種情況中間的兩個括號是不能合併的,因為要恰好形成一條迴路,第二種情況則能夠合併。

一條輪廓線會由 \(n+1\) 條線段組成,其中 \(n\) 條是左右方向的,另外 \(1\) 條是上下方向的。

為了方便解壓狀態,我們用四進位制來表示,同時減少列舉的狀態,要把所有的狀態存到雜湊表裡。

轉移的時候大力分類討論:

\(1\)、當前的位置不能有路徑經過

如果沒有向右的插頭或者向下的插頭,直接繼承上一個格子的答案,

否則不存在合法的方案。

if(s[i][j]=='*'){
	if(!r && !dow) f[now].ad(nzt,nval);
} 

\(2\)、當前的位置必須經過並且沒有向右的插頭或者向下的插頭。

需要在當前的格子新開一個向右的插頭和向下的插頭,並且把向右的插頭標記為右括號,把向下的插頭標記為左括號。

我在轉移狀態之前就去判斷這個狀態是否合法,這樣會比較好寫。

else if(!r && !dow){
	if(s[i][j+1]=='.' && s[i+1][j]=='.') f[now].ad(nzt|2|(1<<j*2),nval);
} 

\(3\)、當前的位置必須經過並且只有向右的插頭或者向下的插頭。

可以繼續沿著之前的方向或者改變插頭的方向,左右括號不變。

 else if(r && !dow){
		if(s[i][j+1]=='.') f[now].ad(nzt,nval);
		if(s[i+1][j]=='.') f[now].ad(nzt^r|(r<<j*2),nval);
} else if(dow && !r){
		if(s[i+1][j]=='.') f[now].ad(nzt,nval);
		if(s[i][j+1]=='.') f[now].ad(nzt^(dow<<j*2)|dow,nval);
} 

\(4\)、當前的位置必須經過並且有一個代表左括號的右插頭和一個代表右括號的下插頭。

如果當前的點是圖中右下角的終止節點並且不存在其它匹配的括號更新答案。

else if(r==1 && dow==2){
		if(i==edx && j==edy && (nzt^r^(dow<<j*2))==0) ans+=nval;
} 

\(5\)、當前的位置必須經過並且有一個代表左括號的下插頭和一個代表右括號的右插頭。

將這兩個括號匹配。

else if(r==2 && dow==1){
		f[now].ad(nzt^r^(dow<<j*2),nval);
}

\(6\)、當前的位置必須經過並且有一個下插頭和一個右插頭,並且這兩個插頭都代表左括號。

一直向右找,找到第一個左括號和右括號恰好匹配的位置,把這個位置的右括號改為左括號,之前的兩個左括號直接匹配。

else if(r==1 && dow==1){
	cs1=nzt^r^(dow<<j*2);
	for(rg int o=j+1,p=1;o<=m;o++){
		cs2=cs1>>o*2&3;
		p+=(cs2==1)-(cs2==2);
		if(!p){
			cs1^=3<<o*2;
			break;
		}
	}
	f[now].ad(cs1,nval);
} 

\(7\)、當前的位置必須經過並且有一個下插頭和一個右插頭,並且這兩個插頭都代表右括號。

和上面的情況一樣,但是需要改成向左找。

else if(r==2 && dow==2){
	cs1=nzt^r^(dow<<j*2);
	for(rg int o=j-1,p=1;o;o--){
		cs2=cs1>>o*2&3;
		p+=(cs2==2)-(cs2==1);
		if(!p){
			cs1^=3<<o*2;
			break;
		}
	}
	f[now].ad(cs1,nval);
}

程式碼

#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
#include<algorithm>
#define rg register
const int maxn=14,mod=1e5+3,maxm=1e5+5;
int n,m,edx,edy;
char s[maxn][maxn];
long long ans=0;
struct has{
	struct asd{
		int nxt,zt;
		long long val;
	}b[maxm];
	has(){
		memset(h,-1,sizeof(h));
		tot=1;
	}
	int tot,h[maxm];
	void cls(){
		memset(h,-1,sizeof(h));
		tot=1;
	}
	void ad(rg int zt,rg long long val){
		rg int now=zt%mod;
		for(rg int i=h[now];i!=-1;i=b[i].nxt){
			if(b[i].zt==zt){
				b[i].val+=val;
				return;
			}
		}
		b[tot].val=val;
		b[tot].zt=zt;
		b[tot].nxt=h[now];
		h[now]=tot++;
	}
}f[2];
int main(){
	scanf("%d%d",&n,&m);
	for(rg int i=1;i<=n;i++){
		scanf("%s",s[i]+1);
		for(rg int j=1;j<=m;j++){
			if(s[i][j]=='.') edx=i,edy=j;
		}
	}
	rg int now=0,nzt,r,dow,cs1,cs2;
	rg long long nval;
	f[0].ad(0,1);
	for(rg int i=1;i<=n;i++){
		for(rg int j=1;j<=m;j++){
			now^=1;
			f[now].cls();
			for(rg int k=1;k<f[now^1].tot;k++){
				nzt=f[now^1].b[k].zt,nval=f[now^1].b[k].val;
				r=nzt&3,dow=nzt>>j*2&3;
				if(s[i][j]=='*'){
					if(!r && !dow) f[now].ad(nzt,nval);
				} else if(!r && !dow){
					if(s[i][j+1]=='.' && s[i+1][j]=='.') f[now].ad(nzt|2|(1<<j*2),nval);
				} else if(r && !dow){
					if(s[i][j+1]=='.') f[now].ad(nzt,nval);
					if(s[i+1][j]=='.') f[now].ad(nzt^r|(r<<j*2),nval);
				} else if(dow && !r){
					if(s[i+1][j]=='.') f[now].ad(nzt,nval);
					if(s[i][j+1]=='.') f[now].ad(nzt^(dow<<j*2)|dow,nval);
				} else if(r==1 && dow==2){
					if(i==edx && j==edy && (nzt^r^(dow<<j*2))==0) ans+=nval;
				} else if(r==2 && dow==1){
					f[now].ad(nzt^r^(dow<<j*2),nval);
				} else if(r==1 && dow==1){
					cs1=nzt^r^(dow<<j*2);
					for(rg int o=j+1,p=1;o<=m;o++){
						cs2=cs1>>o*2&3;
						p+=(cs2==1)-(cs2==2);
						if(!p){
							cs1^=3<<o*2;
							break;
						}
					}
					f[now].ad(cs1,nval);
				} else if(r==2 && dow==2){
					cs1=nzt^r^(dow<<j*2);
					for(rg int o=j-1,p=1;o;o--){
						cs2=cs1>>o*2&3;
						p+=(cs2==2)-(cs2==1);
						if(!p){
							cs1^=3<<o*2;
							break;
						}
					}
					f[now].ad(cs1,nval);
				}
			}
		}
	}
	printf("%lld\n",ans);
	return 0;
}

相關文章