2023 6月 dp做題記錄

Fire_Raku發表於2024-04-20

目錄
  • 6月 dp做題記錄
      • P5664 [CSP-S2019] Emiya 家今天的飯
      • P8867 [NOIP2022] 建造軍營
      • [ARC115E] LEQ and NEQ
      • P3800 Power收集
      • P3594 [POI2015] WIL

6月 dp做題記錄

P5664 [CSP-S2019] Emiya 家今天的飯

分析條件,我們要選出來的菜的集合需要滿足的限制,集合不為空和烹飪方法互不相同都好處理,這樣保證每種烹飪方法是獨立不受影響的,並且至多選一種,所以每種烹飪方法 \(i\) 選菜的方案為 \(sum_i=\sum\limits_{j=1}^m a_{i,j}\),總方案就為 \(\sum\limits_{i=1}^n sum_i-1\),減一為集合為空的情況。

在第三種限制裡,集合中每種食材的使用次數不超過 \(\left\lfloor\dfrac{k}{2}\right\rfloor\) 次,若是直接順著計算,肯定不好求,因為不超過的方案對比超過太複雜了。正難則反,考慮容斥,我們前面求出了不考慮第三種限制的方案數,只要我們求出了不符合第三種限制的方案數,兩個相減即可。

考慮用動態規劃,我們在考慮不符合第三種限制是,同時也要滿足前兩種限制,這樣保證求出來的方案一定在總方案數中。最樸素的,設狀態 \(dp_{i,j,k}\) 為前 \(i\) 種烹飪方法中選了 \(j\) 道菜,其中 \(k\) 道菜是第 \(g\) 種食材做的。這裡需要列舉 \(i\)\(j\)\(k\)\(g\) 四個量,轉移很好想

\(dp_{i,j,k}=dp_{i-1,j,k}+dp_{i-1,j-1,k}\times(sum_i-a_{i,g})+dp_{i-1,j-1,k-1}\times a_{i,g}\)

分為不選第 \(i\) 種烹飪方法,選了但不是第 \(g\) 種,選了是第 \(g\) 種。先列舉 \(g\),每次累加,這樣的不合法方案數為 \(\sum\limits_{k}dp_{n,j,k}\),這裡的 \(k>\left\lfloor\dfrac{j}{2}\right\rfloor\) 。這樣的複雜度是 \(O(n^3m)\),透過不了此題。

但我們再思考一下,在一種不合法的方案中不合法的食材有且僅有一種,因為假設有兩種,一定超過總的選菜數量,即\(2\times(\left\lfloor\dfrac{j}{2}\right\rfloor+1)>j\)。所以等價於 \(k>\dfrac{j}{2}\),化簡得到 \(k-(j-k)>0\),感性理解就是當前不合法食材數減去合法食材數大於 \(0\),這樣不滿足第三種限制的方案只需要不滿足這個限制就行了。(這裡也可以考慮感性理解,超過一半的食材肯定不會有兩個,並且一次只會有一種不合法食材,就有了如果不合法的食材比合法食材還多,那麼就是不滿足限制的)

放在狀態裡面就是不關心 \(j\)\(k\),只關心不合法與合法之間的差值,這樣狀態就可以簡化成 \(dp_{i,j}\) 表示前 \(i\) 種烹飪方法,差值為 \(j\) 的方案數,這裡的 \(j\) 可能是負數,所以加一個 \(n\) 來保證是正的(差值最大為 \(n\))。狀態轉移為:

\(dp_{i,j}=dp_{i-1,j}+dp_{i-1,j-1}\times a_{i,g}+dp_{i-1,j+1}\times(sum_i-a_{i,g})\)

每次列舉一個 \(g\),計算當前情況下的不合法數,即 \(\sum\limits_{j>0}dp_{n,j}\),不同 \(g\) 之間不會重合,所以每處理一次就讓總方案數減去它即可。

#include<bits/stdc++.h>
using namespace std;
int n,m,a[120][2020];
long long s[120],dp[120][220],ans=1;
const int mod=998244353; 
int main(){
	cin >> n >> m;
	for(int i = 1; i <= n; i++){
		for(int j = 1; j <= m; j++){
			cin >> a[i][j];
			s[i] += a[i][j];
			s[i] %= mod;
		}
	}
	for(int i = 1; i <= n; i++){
		ans *= (s[i] + 1);
		ans %= mod;
	}
	ans = (ans - 1 + mod) % mod;
	for(int k = 1; k <= m; k++){
		memset(dp, 0, sizeof(dp));
		dp[0][100] = 1;
		for(int i = 1; i <= n; i++){
			long long now = s[i] - a[i][k];
			for(int j = 0; j <= n + 100; j++){
				dp[i][j] = dp[i - 1][j];
				if(j) dp[i][j] += dp[i - 1][j - 1] * a[i][k];
				dp[i][j] += dp[i - 1][j + 1] * now;
				dp[i][j] %= mod;
			}
		}
		for(int i = 101; i <= n + 100; i++){
			ans -= dp[n][i];
			ans = (ans + mod) % mod;
		}
	}
	cout << ans << endl;
	return 0;
}

P8867 [NOIP2022] 建造軍營

計算合法的建造軍營和看守道路方案數,合法即為去掉一條沒人看守的邊後軍營之間依然連通,因為是一條,所以容易發現在圖中,強連通分量的邊被割去一條是一定不會影響軍營連通的,即強連通分量的邊想看守就看守,不作為決定性因素。只有割邊與方案的合法性有關。

所以我們考慮縮點,在無向圖縮點後,原圖會變成一個樹,這點方便我們做樹形 dp。

每個強連通分量內的方案數是可以預處理的,處理出點數為 \(v_i\),邊數為 \(e_i\)。那麼在這個強連通分量中,不選軍營的方案數是 \(2^{e_i}\),選至少一個軍營的方案數為 \(2^{v_i+e_i}-2^{e_i}\)

考慮題目,經過上面分析,題目簡化成,在一顆樹上選出若干個點,選出的點在去掉一條邊後依然連通的方案數。意思即選出的點之間相連的唯一路徑上的邊一定要看守,其他隨意。問題縮小到子樹上,限制只在子樹裡有軍營,由於列舉子樹時,唯一不同的就是根節點,它是我們區分不同方案的關鍵,所以我們限制當前的子樹的根節點一定為相連路徑上的一點,這點方便轉移,因為這樣子樹中的軍營就可以透過根節點相連。

在限制下,考慮 \(u\) 節點和它的兒子 \(v\),它們之間相不相連取決於 \(v\) 的子樹中有無軍營。順著這個可以設出狀態 \(dp_{u,0/1}\) 表示在以 \(u\) 節點為根的子樹中的沒有/有軍營的方案數。根據列舉兒子節點順序會有前 \(i\) 個兒子的隱藏狀態,類似揹包,分為當前子節點選不選節點,轉移可以寫成:

\(\begin{cases}dp_{u,1}=dp_{u,0}\times dp_{v,1}+dp_{u,1}\times(2\times dp_{v,0}+dp_{v,1})\\dp_{u,0}=dp_{u,0}\times (2\times dp_{v,0})\end{cases}\)

\(2\) 的地方是因為這裡的邊 \((u,v)\) 由於 \(v\) 中沒有軍營可以選和不選。考慮了子樹內的方案數,並且前面為了統計答案,限制了軍營只在子樹內,所以子樹外的邊可以隨便選。這裡要注意的就是在統計答案上如何保證不重不漏,計算出 \(u\) 子樹的方案後,我們回到了 \(fa_u\) 子樹,為了不重複,\((fa_u,u)\) 這條邊就會選入,這是和 \(u\) 子樹方案根本的區別,所以在計入子樹 \(u\) 的答案時,先預處理出 \(sz_i\) 表示 \(i\) 子樹中的邊數(包括強連通分量的邊),原本為 \(dp_{u,1}\times 2^{sz_1-sz_u}\),但這其中會多計算一次和 \(fa_u\) 相連的方案,需要改成 \(dp_{u,1}\times 2^{sz_1-sz_u-1}\) 保證不重。

答案即為:

\(\begin{cases}ans\leftarrow dp_{u,1}&u=1\\ans\leftarrow dp_{u,1}\times 2^{sz_1-sz_u-1}&u\ne 1\end{cases}\)

要理解不漏也很容易,我們將每個節點作為中轉點,實際上所有的選點方案都一定至少會在一棵子樹上被統計。

在這題中,我們經過了縮點,將題目轉化為樹形 dp,統計方案數,在樹上選點可以依照題意給出可以轉移的狀態,並且為了統計方案,可以給狀態一些隱藏的限制,一是便於轉移,二是可以使狀態的意義更加明晰,特指某一種情況下的方案,使得小的方案之間沒有並集,便於統計答案。最後的複雜度為 \(O(n+m)\)

#include<bits/stdc++.h>
using namespace std;
int read(){
	int x = 0, f = 1;
	char c = getchar();
	while(c < '0' || c > '9'){
		if(c == '-') f = -1;
		c = getchar();
	}
	while(c >= '0' && c <= '9'){
		x = (x << 1) + (x << 3) + (c - '0');
		c = getchar();
	}
	return x * f;
}
const int mod = 1000000007;
int n, m, cnt, cnt2, top, idx, tot;
int h[500010], h2[500010];
long long g[500010], sz[500010], sum1[500010];
int low[500010], dfn[500010], bel[500010], ins[500010], st[500010];
long long dp[500010][2], ans;
struct node{
	int to, nxt;
}e[2000010];
struct node2{
	int to, nxt;
}e2[2000010];
void add(int u, int v){
	e[++cnt].to = v;
	e[cnt].nxt = h[u];
	h[u] = cnt;
}
void add2(int u, int v){
	e2[++cnt2].to = v;
	e2[cnt2].nxt = h2[u];
	h2[u] = cnt2;
}
void tarjan(int u, int fa){
	dfn[u] = low[u] = ++tot;
	st[++top] = u;
	ins[u] = 1;
	for(int i = h[u]; i; i = e[i].nxt){
		int v = e[i].to;
		if(!ins[v]){
			tarjan(v, u);
			low[u] = min(low[u], low[v]);
		}
		else if(v != fa){
			low[u] = min(low[u], dfn[v]);
		}
	}
	if(low[u] == dfn[u]){
		++idx;
		int v;
		do{
			v = st[top--];
			bel[v] = idx;
			sum1[idx]++;
			ins[v] = 0;
		}while(v != u);
	}
}
long long ksm(long long a, long long b){
	long long ans = 1;
	while(b){
		if(b & 1) ans = (ans * a) % mod;
		a = (a * a) % mod;
		b >>= 1;
	}
	return ans;
}
void init(int u, int fa){
	sz[u] = g[u];
	for(int i = h2[u]; i; i = e2[i].nxt){
		int v = e2[i].to;
		if(v == fa) continue;
		init(v, u);
		sz[u] += sz[v] + 1;
	}
}
void dfs(int u, int fa){
	dp[u][0] = ksm(2, g[u]) % mod, dp[u][1] = (ksm(2, sum1[u] + g[u]) - dp[u][0] + mod) % mod;
	for(int i = h2[u]; i; i = e2[i].nxt){
		int v = e2[i].to;
		if(v == fa) continue;
		dfs(v, u);
		dp[u][1] = (dp[u][1] * (2 * dp[v][0] % mod + dp[v][1]) % mod + dp[u][0] * dp[v][1] % mod) % mod;
		dp[u][0] = dp[u][0] * (2 * dp[v][0] % mod) % mod;
	}
	if(u == 1) ans += dp[u][1], ans %= mod;
	else ans += (dp[u][1] * ksm(2, sz[1] - sz[u] - 1) % mod) % mod, ans %= mod;
}
int main(){
	n = read(), m = read();
	for(int i = 1; i <= m; i++){
		int u = read(), v = read();
		add(u, v), add(v, u);
	}
	tarjan(1, 0);
	for(int i = 1; i <= n; i++){
		for(int j = h[i]; j; j = e[j].nxt){
			int v = e[j].to;
			if(bel[i] == bel[v]) g[bel[i]]++;
			else add2(bel[i], bel[v]);
		}
	}
	for(int i = 1; i <= idx; i++) g[i] /= 2;
	init(1, 0);
	dfs(1, 0);
	cout << ans << endl;
	return 0;
}

[ARC115E] LEQ and NEQ

如果不考慮第二個條件的話,那麼答案顯然是 \(\sum a_i\),所以我們考慮容斥掉不合法的方案。

容斥的基本條件,我們要找到一個共性,也就是能夠容斥的性質。這一題中,容斥的物件就是不符合第二個條件的兩項。所以我們可以設 \(g_i\) 為剛好有 \(i\) 組違反條件的項,\(f_i\) 為至少有 \(i\) 組違反條件的項。我們直接套用容斥的公式。

\[ans=\sum\limits_{i=0}^{n-1}(-1)^i\ f_i \]

處理 \(f_i\) 的過程需要用到動態規劃。我們發現不同的數字代表一段區間,彼此之間相鄰的違反條件的點也正好對應一段區間,壞點越多序列段數越少,壞點的增多會導致段數的減少,所以反過來,我們可以用段數的多少來滿足”至少“這個條件(比如 \(5\) 個分成 \(3\) 段,說明至少有 \(2\) 個壞點),且分段問題更好解決。於是我們設狀態 \(dp_{i,j}\) 為前 \(i\) 個正好分了 \(j\) 段的方案數。可能會覺得這不就和 \(f_i\)至少衝突了嗎?我們這裡的分段並不是嚴格意義上的分段,我們只規定了段內一定相同,而相鄰的雖然不是一段但也可以相同。

因為我們轉移需要一整段\(a_i\) 的大小關係,又因為不是嚴格分段,不需要考慮段之間是否一定不同,所以轉移為

\(dp_{i,j}=\sum\limits_{k=0}^{i-1}dp_{k,j-1}\times \min\limits_{o=k+1}^ia_o\)

統計答案也就變成

\(ans=\sum\limits_{i=0}^{n-1}(-1)^i\ dp_{n,n-i}\)

發現 \(j\) 位置只需要 \(j\)\(j-1\),即與當前奇偶性有關,在奇偶之間轉移,所以可以用滾動陣列降維。降維之後統計答案也方便,因為統計答案時同樣只需要關心奇偶性。

\(dp_{i,0/1}=\sum\limits_{k=0}^{i-1}dp_{k,1/0}\times \min\limits_{o=k+1}^ia_o\)

這樣轉移是 \(O(n^2)\) 的,需要最佳化轉移。每列舉一個 \(i\),就要多考慮一個 \(a_i\),並且對於連續區間的最小值,它的轉移的貢獻也是連續的,往前列舉 \(j\) 的過程中,會有一個時刻,\(a_i\) 永遠不會是之後的最小值,這個時候即為左邊第一個小於 \(a_i\) 的數。以這個時刻為分隔點 \(k\)\(k\) 及它右邊的貢獻乘的都是 \(a_i\),而 \([k,i-1]\)\(dp\) 值可以透過字首和統計;左邊的貢獻由於已經不受影響新列舉的 \(a_i\),可以發現總貢獻之前已經處理過了,即為 \(dp_{k,0/1}\)

關於找到左邊第一個小於 \(a_i\) 的位置,可以用單調棧實現。複雜度就降到 \(O(n)\)

#include <bits/stdc++.h>
using namespace std;
long long read(){
    int x = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9'){
        if(c == '-') f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9'){
        x = (x << 1) + (x << 3) + (c - '0');
        c = getchar();
    }
    return x * f;
}
long long n, ans, mod = 998244353;
long long a[500010], s[500010][2], dp[500010][2];
long long st[500010], top;
int main(){
    n = read();
    for(int i = 1; i <= n; i++){
        a[i] = read();
    }
    dp[0][0] = s[0][0] = 1;
    for(int i = 1; i <= n; i++){
        while(top > 0 && a[st[top]] >= a[i]) top--;
        if(!top){
            for(int j = 0; j <= 1; j++) dp[i][j] = 1ll * (dp[i][j] + s[i - 1][j ^ 1] * a[i]) % mod;
        }
        else{
            for(int j = 0; j <= 1; j++) dp[i][j] = 1ll * (dp[st[top]][j] + (s[i - 1][j ^ 1] - s[st[top] - 1][j ^ 1] + mod) * a[i]) % mod;
        }
        s[i][0] = (s[i - 1][0] + dp[i][0]) % mod;
        s[i][1] = (s[i - 1][1] + dp[i][1]) % mod;
        st[++top] = i;
    }
    if(n % 2 == 1) ans = (dp[n][1] - dp[n][0] + mod) % mod;
    else ans = (dp[n][0] - dp[n][1] + mod) % mod;
    cout << ans;
    return 0;
}

P3800 Power收集

這題的狀態很明確,因為完全可以把每一個網格看成狀態,網格之間相互轉移,所以設狀態 \(dp_{i,j}\) 為走到第 \(i\) 行第 \(j\) 列時的最大值。依照題意,一層的轉移只與上一層有關,所以轉移為

\(dp_{i,j}=\max\limits_{j-t\le k\le j+t}(dp_{i-1,k})+a_{i,j}\)

典型的單調佇列形式,我們需要的只有一段區間中的最大值,並且區間移動是連續的。瓶頸在於我們列舉 \(j\) 的時候,我們只能知道 \([j-t,j]\) 的最大值。解決方法很簡單,最大值有結合律,即 \(\max(a,b,c)=\max(a,(b,c))\),所以我們正著和反著都做一遍,把 \([j-t,j]\)\([j,j+t]\) 的最大值分別求出來,跑兩遍單調佇列即可。

其他還有可以最佳化的地方,比如由於一層的轉移只與上一層有關,所以可以滾掉 \(i\) 這一維。

P3594 [POI2015] WIL

這題中,可以一次將任意長度小於等於 \(d\) 的區間變為 \(0\),求修改完之後區間和小於 \(p\) 的最長區間長度。

對於區間和,我們可以用字首和 \(sum_x=\sum\limits_{i=1}^xa_i\),來 \(O(1)\) 求出。

可以容易想到,“任意長度小於等於 \(d\) 的區間” 在實際操作中一定是貪心地取長度為 \(d\) 的區間,因為多取一定不劣。所以一個暴力做法是,我們列舉區間的左右斷點 \(l\)\(r\),再列舉一個 在 \([i,j]\) 之間的 \(k\) 為修改的區間右端點,判斷減去一段區間後的區間和是否小於 \(q\) 來更新答案。複雜度 \(O(n^3)\)

考慮最佳化,對於一個左端點,我們一定是找它滿足條件的最遠右端點;同樣,對於一個右端點,我們一定是找它滿足條件的最遠左端點。這個性質可以用上雙指標,只需要列舉 \(r\)\(k\)\(l\) 只需要根據區間和單調向右移動即可。複雜度 \(O(n^2)\)

現在的瓶頸是,因為我們判斷一個區間能否滿足條件,一定要找到它的最大修改區間才能一次做出決定,所以能否最佳化掉尋找最大修改區間的時間呢?可以用到單調佇列,每列舉一個 \(i\),就多一個區間 \([i-d,i]\),所以我們維護當前滿足條件的最大修改區間,判斷時直接取出即可。這裡的條件指當前的 \(l\) 指標是否已經超過當前隊首修改區間的左端點

如果此時的 \([l,r]\) 用上最大修改區間還是大於 \(p\) 的話,那麼 \(l\) 只能往右走,並同時刪去由於 \(l\) 向右走而導致不合法的修改區間,保證下一次取出的最大修改區間是合法的,由於單調性,我們不用擔心區間會不會被多刪。

由於一個數最多進出佇列一次,並且 \(l\) 單調向右移動,所以單調佇列和雙指標都是線性的,複雜度降到 \(O(n)\)

#include <bits/stdc++.h>
using namespace std;
int read(){
    int x = 0, f = 1;
    char c = getchar();
    while(c < '0' || c > '9'){
        if(c == '-') f = -1;
        c = getchar();
    }
    while(c >= '0' && c <= '9'){
        x = (x << 1) + (x << 3) + (c - '0');
        c = getchar();
    }
    return x * f;
}
int n, m, k, t, ans;
int dp[4010][4010], a[4010][4010], q[4010];
int main(){
    n = read(), m = read(), k = read(), t = read();
    for(int i = 1; i <= k; i++){
        int x = read(), y = read(), v = read();
        a[x][y] = v;
    }
    for(int i = 1; i <= n; i++){
        int head = 1, tail = 0;
        for(int j = 1; j <= m; j++){
            while(head <= tail && dp[i - 1][q[tail]] <= dp[i - 1][j]) tail--;
            while(head <= tail && q[head] + t < j) head++;
            q[++tail] = j;
            dp[i][j] = max(dp[i][j], dp[i - 1][q[head]] + a[i][j]); 
        }
        head = 1, tail = 0;
        for(int j = m; j >= 1; j--){
            while(head <= tail && dp[i - 1][q[tail]] <= dp[i - 1][j]) tail--;
            while(head <= tail && q[head] - t > j) head++;
            q[++tail] = j;
            dp[i][j] = max(dp[i][j], dp[i - 1][q[head]] + a[i][j]); 
        }
    }
    for(int i = 1; i <= m; i++) ans = max(ans, dp[n][i]);
    cout << ans << endl;
    return 0;
}

相關文章