C++記憶化搜尋

薛儒浩發表於2024-08-24

前言(一些小廢話)

C++中的記憶化搜尋(Memoization)是一種最佳化技術,用於減少重複計算的開銷。它常用於動態規劃和遞迴問題中。
記憶化搜尋和動態規劃從根本上來講就是一個東西,任何一個DP方程都能轉為記憶化搜尋 ,反之亦然。
我寫這篇文章,是因為自己的DP基礎薄弱,不易推出狀態轉移方程。實際上,DP的程式碼量更少,也更方便(在能推出狀態轉移方程的基礎上)。

記憶化搜尋的優缺點

優點

  1. 記憶化搜尋可以避免搜到無用狀態, 特別是在有狀態壓縮時
  2. 不需要注意轉移順序(這裡的"轉移順序"指正常DP中for迴圈的巢狀順序以及迴圈變數是遞增還是遞減)
  3. 邊界情況非常好處理, 且能有效防止陣列訪問越界
  4. 對我這種蒟蒻來說寫起來簡單易懂
  5. 有些DP(如區間DP)用記憶化搜尋寫很簡單但正常DP很難

缺點

  1. 不能滾動陣列(雖然我也不大會),要滾動陣列的話還是老老實實寫DP吧
  2. 有些最佳化比較難加
  3. 由於遞迴, 有時效率較低但不至於 TLE (狀壓dp除外)
  4. 程式碼有點長

如何寫記憶化搜尋

不考慮DP
由暴搜開始思考

  1. 寫出這道題的暴搜程式
  2. 將這個DFS改成"無需外部變數"的DFS
  3. 新增記憶化陣列

例題:[NOIP2005 普及組] 採藥

假設我P也不會,只會暴搜,你會得到:

#include<bits/stdc++.h>
using namespace std;
#define N 105
int n,t;
int T[N],W[N];
int ans;
void dfs( int x , int time , int tans ) {
	if(time < 0)
		return;
	if(x == n+1) {
		ans = max(ans,tans);
		return;
	}
	dfs(x+1,time,tans);
	dfs(x+1,time-T[x],tans+W[x]);
}
int main() {
	cin >> t >> n;
	for(int i = 1; i <= n; i++)
		cin >> T[i] >> W[i];
	dfs(1,t,0);
	cout << ans << endl;
	return 0;
}

以及冰冷的30分
開始第二步,將這個DFS改成"無需外部變數"的DFS
引入問題:如何記錄答案?
答:可以將DFS引入返回值
於是我們得到了:

#include<bits/stdc++.h>
using namespace std; 
#define N 105
int n,t;
int T[N],W[N];
int ans;
int dfs( int x , int time) {
	if(x == n+1) {
		return 0;
	}
    int dfs1 ,dfs2 = -1;
	dfs1 = dfs(x+1,time);
	if(time >= T[x]){
		dfs2 = dfs(x+1,time-T[x]) + W[x];
	}
	return max(dfs1,dfs2);
}
int main() {
	cin >> t >> n;
	for(int i = 1; i <= n; i++)
		cin >> T[i] >> W[i];
	cout<<dfs(1,t)<<endl;
	return 0;
}

別急著提交,因為我們只是去掉了"外部變數",卻沒有最佳化複雜度
接下來,我們新增記憶化陣列
多測試幾組樣例發現,其實對於相同的x和time,DFS的返回值總是相同的(廢話)
接下來引入記憶化陣列"mem",用它來記錄下DFS每一個返回值,每次DFS判斷一下mem是否有值,若有值,直接返回mem;若無值,繼續搜尋(類似於剪枝)
於是我們得到了:

#include<bits/stdc++.h>
using namespace std; 
#define N 105
int n,t;
int T[N],W[N];
int mem[N][10*N];
int ans;
int dfs( int x , int time) {
	if(mem[x][time] != 0){
		return mem[x][time];
	}
	if(x == n+1) {
		return 0;
	}
    int dfs1 ,dfs2 = -1;
	dfs1 = dfs(x+1,time);
	if(time >= T[x]){
		dfs2 = dfs(x+1,time-T[x]) + W[x];
	}
	mem[x][time] = max(dfs1,dfs2);
	return mem[x][time];
}
int main() {
	cin >> t >> n;
	for(int i = 1; i <= n; i++)
		cin >> T[i] >> W[i];
	cout<<dfs(1,t)<<endl;
	return 0;
}

注意:mem陣列千萬不要開小了(作者親自踩坑)

鳴謝

參考:https://www.luogu.com.cn/article/qay8mori

相關文章