dp 套 dp(dp of dp)小記

KingPowers發表於2024-08-08

其實並不是什麼很高大上的東西,就是把內層 dp 的結果壓到外層 dp 的狀態裡。

通常解決的是“限制某種值為 \(x\) 的方案數”之類的問題,而限制的值通常是一個經典的 dp 問題。

沒有啥好直接介紹的,就寫三道做過的題。

BZOJ3864 Hero meet devil

題目連結

算是一道入門題目。

我們先回憶一下一個經典問題:給定兩個串 \(s\)\(t\),求它們的 \(\text{LCS}\)

考慮 dp:設 \(g_{i,j}\) 表示串 \(t\) 的前 \(i\) 位和串 \(s\) 的前 \(j\) 位的 \(\text{LCS}\),轉移是比較簡單的:

\[g_{i,j}=\begin{cases}g_{i-1,j-1}+1,&t_i=s_j\\\max(g_{i-1,j},g_{i,j-1}),&t_i\not=s_j\end{cases} \]

現在我們就是要對每個 \(0\le i\le|s|\) 統計有多少個串 \(t\) 滿足 \(g_{|t|,|s|}=i\)

要統計 dp 值為某個值的串有幾個,不妨直接將 \(g\) 作為內層 dp 記錄到狀態裡:設 \(f_{i,S}\) 表示已經填了 \(i\) 個字元,\(g\) 陣列的第 \(i\) 行結果為 \(S\) 的串有多少個。

直接這麼記錄狀態數肯定是爆炸了,畢竟你要用一個數表示整個陣列。但是我們冷靜分析下發現,固定 \(g\) 的第 \(i\) 行時,相鄰的 \(g_{i,j}\) 之間的差值不會超過 \(1\)(這是顯然的,多加一個字元 \(\text{LCS}\) 的長度至多增加 \(1\)),換句話說就是 \(g\) 的每行的差分陣列只有 \(0/1\) 兩種數。所以我們可以直接把 \(S\) 記錄成 \(g\)\(i\) 行的差分陣列,這樣我們就把第二維的狀態數壓到了 \(2^{|s|}\) 種。

因此我們提前預處理出 \(nxt_{S,c}\) 表示狀態為 \(S\) 時接上一個字元 \(c\) 會轉移到哪種狀態,轉移就是直接列舉新加入的字元 \(c\)\(f_{i,S}\to f_{i+1,nxt_{S,c}}\) 即可。

這次直接把程式碼放上來。

#include<bits/stdc++.h>
#define int long long
#define For(i, a, b) for(int i = (a); i <= (b); i++)
#define Rof(i, a, b) for(int i = (a); i >= (b); i--)
#define deb(x) cerr << #x"=" << x << '\n';
using namespace std;
const int mod = 1e9 + 7;
int n, m, nxt[1 << 15][5];
int f[1005][1 << 15], g[20], h[20], a[20], ans[20];
string s;
int get_nxt(int S, int c){
	For(i, 1, n) g[i] = g[i - 1] + ((S >> i - 1) & 1);
	For(i, 1, n){
		if(a[i] == c) h[i] = g[i - 1] + 1;
		else h[i] = max(h[i - 1], g[i]);
	}
	int T = 0;
	For(i, 1, n) T |= (h[i] - h[i - 1]) << i - 1;
	return T;
}
void Solve(){
	cin >> s;
	n = s.size(); s = ' ' + s; 
	For(i, 1, n){
		if(s[i] == 'A') a[i] = 1;
		else if(s[i] == 'G') a[i] = 2;
		else if(s[i] == 'T') a[i] = 3;
		else a[i] = 4;
	}
	For(i, 0, (1 << n) - 1){
		nxt[i][1] = get_nxt(i, 1);
		nxt[i][2] = get_nxt(i, 2);
		nxt[i][3] = get_nxt(i, 3);
		nxt[i][4] = get_nxt(i, 4);
	}
	cin >> m; f[0][0] = 1;
	For(i, 0, m - 1) For(S, 0, (1 << n) - 1){
		if(!f[i][S]) continue;
		(f[i + 1][nxt[S][1]] += f[i][S]) %= mod;
		(f[i + 1][nxt[S][2]] += f[i][S]) %= mod;
		(f[i + 1][nxt[S][3]] += f[i][S]) %= mod;
		(f[i + 1][nxt[S][4]] += f[i][S]) %= mod;
	}
	For(S, 0, (1 << n) - 1) (ans[__builtin_popcount(S)] += f[m][S]) %= mod;
	For(i, 0, n) cout << ans[i] << '\n', ans[i] = 0;
	For(i, 0, m) For(S, 0, (1 << n) - 1) f[i][S] = 0;
}
signed main(){
	ios::sync_with_stdio(0);
	cin.tie(0); cout.tie(0);
	int T = 1; cin >> T;
	while(T--) Solve();
	return 0;
}

SDOI2022 小 N 的獨立集

題目連結

稍微加了點難度。

給定點權求最大獨立集是一個經典的 dp 問題:設 \(f_{u,0/1}\) 表示 \(u\) 的子樹內 \(u\) 不選/選的最大獨立集。

我們注意到本題擁有著極小的值域(\(k\le 5\)),所以啟發我們直接把 dp 結果扔到狀態裡。設 \(dp_{u,x,y}\) 表示 \(u\) 子樹內 \(f_{u,0/1}\) 的值分別為 \(x\)\(y\) 的方案數,轉移考慮類似樹上揹包,將 \(u\) 的兒子 \(v\) 合併過來,具體地:

\[dp_{u,x,y}\times dp_{v,p,q}\to dp'_{u,x+\max(p,q),v+p} \]

下標的更新方式就是我們內層 dp 原先的轉移方式。

但是我們發現這個做法的狀態數達到了 \(O(n^3k^2)\),難以透過。

我們嘗試減少內層 dp 的狀態數。我們可以發現強制 \(u\) 選的答案比強制 \(u\) 不選的答案不會優太多,稍加分析就可以觀察到這麼一個性質:\(0\le\max(f_{u,0},f_{u,1})-f_{u,0}\le k\)。這是因為對於強制選了 \(u\) 的方案來說,把 \(u\) 去掉最多隻會減少 \(k\) 的權值且會變成一種不選的方案。這啟發我們更改下內層 dp 的定義:設 \(f_{u,0/1}\) 表示 \(u\) 子樹內不強制/強制 \(u\) 不選的方案數,這樣以來就有 \(0\le f_{u,0}-f_{u,1}\le k\)

那麼我們把外層的 dp 狀態也相應地更改為:\(dp_{u,x,y}\) 表示 \(u\) 子樹內 \(f_{u,1}=x\)\(f_{u,0}=x+y\) 的方案數,轉移也不難:

\[dp_{u,x,y}\times dp_{v,p,q}\to dp_{u,x+p+q,\max(x+y+p,x+p+q)-(x+p+q)} \]

此時的狀態數是 \(O(n^2k^2)\),套用樹上揹包可以分析出複雜度的上界為 \(O(n^2k^4)\),這個上界極其寬鬆所以可以透過。

#include<bits/stdc++.h>
#define For(i, a, b) for(int i = (a); i <= (b); i++)
#define Rof(i, a, b) for(int i = (a); i >= (b); i--)
#define deb(x) cerr << #x"=" << x << '\n';
using namespace std;
const int mod = 1e9 + 7;
int n, k, siz[1005], f[1005][5005][6], g[5005][6], ans[5005];
vector<int> e[1005];
void Add(int &x, int y){if((x = x + y) >= mod) x -= mod;}
void dfs(int now, int fa){
	siz[now] = 1;
	For(i, 1, k) f[now][0][i] = 1;
	for(int to : e[now]){
		if(to == fa) continue;
		dfs(to, now);
		memset(g, 0, sizeof g);
		For(x, 0, k * siz[now]) For(y, 0, k) if(f[now][x][y])
			For(p, 0, k * siz[to]) For(q, 0, k) if(f[to][p][q])
				Add(g[x + p + q][max(x + y + p, x + p + q) - (x + p + q)], 1ll * f[now][x][y] * f[to][p][q] % mod);
		memcpy(f[now], g, sizeof g);
		siz[now] += siz[to];
	}
}
void Solve(){
	cin >> n >> k;
	For(i, 1, n - 1){
		int u, v; cin >> u >> v;
		e[u].emplace_back(v);
		e[v].emplace_back(u);
	}
	dfs(1, 0);
	For(i, 1, n * k){
		int ans = 0;
		For(j, 0, min(i, k)) Add(ans, f[1][i - j][j]);
		cout << ans << '\n';
	}
}
signed main(){
	ios::sync_with_stdio(0);
	cin.tie(0); cout.tie(0);
	int T = 1; //cin >> T;
	while(T--) Solve();
	return 0;
}

CF924F Minimal Subset Difference

題目連結

比較困難的一題。

考慮一個數的答案如何計算,這個顯然等價於子集和,只能做類似揹包的東西。我們設 \(f_{i,j}\) 表示用了 \(i\) 位數當前的差是否可以為 \(j\),轉移時如果加入了一個 \(c\) 可以轉移到 \(f_{i+1,j+c}\)\(f_{i+1,|j-c|}\)。事實上第一維可以扔去,只保留 \(f_i\) 表示差值為 \(i\) 的可行性。

因為單個數計算的複雜度是不可能再低下去了,我們就只能計數有多少個數的 dp 狀態是合法的,那就只能是 dp 套 dp。看上去直接把 \(f\) 設在狀態裡狀態數直接昇天了,但是先別急,我們慢慢降。

首先一個觀察是最後的答案一定不超過 \(9\),因為考慮一個貪心:每次往和小的集合裡扔數,這樣能保證最後差不超過 \(9\)。而我們又注意到,如果某個時刻兩個集合的差大於了 \(72\),那麼即使剩下位全都是 \(9\) 扔過去答案也不會小於 \(9\),所以說我們的 \(f\) 其實只需要保留 \(f_0\)\(f_{72}\) 的狀態就好了,這樣我們可以直接用一個 int128 表示 \(f\)

現在的狀態數是 \(2^{73}\) 左右,還是很爆炸。因為填的數最多隻有 \(19\) 位,我們考慮直接爆搜,搜出所有的合法狀態,發現只有一萬多種!

所以我們可以直接把這一萬多種狀態拉出來做 dp 套 dp 了。預處理 \(g_{lim,len,S}\) 表示限制差值不超過 \(lim\),還剩下 \(len\) 位數需要填,當前的狀態為 \(S\) 時還有多少種填法,轉移是容易的。對於詢問,顯然要先差分掉,然後我們列舉 \(\text{LCP}\) 直接在 \(g\) 這個自動機上走就能統計答案。

跑得相當快,CF 上只跑了 234ms。

#include<bits/stdc++.h>
#define int long long
#define For(i, a, b) for(int i = (a); i <= (b); i++)
#define Rof(i, a, b) for(int i = (a); i >= (b); i--)
#define deb(x) cerr << #x"=" << x << '\n';
using namespace std;
using LL = __int128;
const int D = 10, L = 19, MS = 20005, W = 72;
const LL U = ((LL)1 << W + 1) - 1;
map<LL, int> ID;
int tot, ans[MS], nxt[MS][10], f[D][L][MS];
LL val[MS];
vector<int> vec[D];
LL get_nxt(LL S, int c){
	LL T = (S >> c) | (S << c);
	For(i, 0, c) if((S >> i) & 1) T |= (1 << c - i);
	return T & U;
}
int get_ans(LL S){
	For(i, 0, D - 1) if((S >> i) & 1) return i;
	assert(0); return 114514;
}
void bfs(){
	queue<pair<int, int>> q;
	tot++; ans[1] = 0; ID[1] = 1; val[1] = 1; q.push({1, 0});
	while(!q.empty()){
		auto [cur, len] = q.front(); q.pop();
		if(len == L - 1) continue;
		For(c, 1, 9){
			LL to = get_nxt(val[cur], c);
			auto it = ID.find(to);
			if(it != ID.end()) {nxt[cur][c] = it -> second; continue;}
			ID[to] = ++tot; nxt[cur][c] = tot; ans[tot] = get_ans(to); val[tot] = to;
			q.push({tot, len + 1});
		}
	}
	For(i, 1, tot) vec[ans[i]].push_back(i);
	For(i, 0, D - 1){
		For(j, 0, i) for(int k : vec[j]) f[i][0][k] = 1;
		For(j, 1, L - 1) For(k, 1, tot){
			f[i][j][k] += f[i][j - 1][k];
			For(c, 1, 9) f[i][j][k] += f[i][j - 1][nxt[k][c]];
		}
	}
}
int lim;
int query(int x){
	if(lim >= 10) return x + 1;
	int ans = 0, now = 1; x++; vector<int> st;
	while(x) st.push_back(x % 10), x /= 10;
	reverse(st.begin(), st.end());
	int len = st.size(); ans += f[lim][len - 1][1];
	For(i, 0, len - 1){
		int x = st[i];
		For(c, (i == 0), x - 1){
			int to = (c == 0) ? now : nxt[now][c];
			ans += f[lim][len - i - 1][to];
		}
		now = (x == 0) ? now : nxt[now][x];
	}
	return ans;
}
void Solve(){
	int l, r; cin >> l >> r >> lim;
	cout << query(r) - query(l - 1) << '\n';
}
signed main(){
	ios::sync_with_stdio(0);
	cin.tie(0); cout.tie(0);
	int T; cin >> T; bfs();
	while(T--) Solve();
	return 0;
}