回溯演算法

zhhfan發表於2021-03-26

解決⼀個回溯問題,實際上就是⼀個決策樹的遍歷過程。你只需要思考 3 個問題:

  1. 路徑:也就是已經做出的選擇
  2. 選擇列表:也就是你當前可以做的選擇
  3. 結束條件:也就是到達決策樹底層,⽆法再做選擇的條件

回溯演算法的框架:

result = [] 
def backtrack(路徑, 選擇列表): 
    if 滿⾜結束條件: 
        result.add(路徑) 
        return 
    for 選擇 in 選擇列表: 
        做選擇 
        backtrack(路徑, 選擇列表) 
        撤銷選擇

【舉例 1】
全排列問題:
給定一個沒有重複數字的序列,返回其所有可能的全排列。

leetcode連結

⽐⽅說給三個數[1,2,3],如下圖,⽐如說你站在下圖的紅⾊節點上,則 [2] 就是「路徑」,記錄你已經做過的選擇; [1,3] 就是「選擇列表」,表⽰你當前可以做出的選擇;「結束條件」就是遍歷到樹的底層,在這⾥就是選擇列表為空的時候。

如此,回溯演算法的核心框架可以表示為:

for 選擇 in 選擇列表:
    # 做選擇 
    將該選擇從選擇列表移除 
    路徑.add(選擇) 
    backtrack(路徑, 選擇列表) 
    # 撤銷選擇 
    路徑.remove(選擇) 
    將該選擇再加⼊選擇列表

我們只要在遞迴之前做出選擇,在遞迴之後撤銷剛才的選擇(如樹的遍歷),就能正確得到每個節點的選擇列表和路徑,則全排列的詳細程式碼為:

class Solution {
public:
    vector<vector<int>> permute(vector<int>& nums) {
        vector<vector<int>> res;
        vector<int> trace;
        traceback(nums, trace, res);
        return res;
    }

    void traceback(vector<int> &nums, vector<int> trace, vector<vector<int>>& res){
        if(trace.size() == nums.size()){
            res.push_back(trace);
            return;
        }
        for(int item: nums){
            if(find(trace.begin(), trace.end(), item) == trace.end()){
                trace.push_back(item);
                traceback(nums, trace, res);
                trace.erase(trace.end()-1);
            }
        }
    }
};

【舉例 2】
N皇后問題:
n 皇后問題研究的是如何將 n 個皇后放置在 n×n 的棋盤上,並且使皇后彼此之間不能相互攻擊。給你一個整數 n ,返回所有不同的 n 皇后問題的解決方案。每一種解法包含一個不同的 n 皇后問題的棋子放置方案,該方案中 'Q' 和 '.' 分別代表了皇后和空位。
注:皇后彼此不能相互攻擊,也就是說:任何兩個皇后都不能處於同一條橫行、縱行或斜線上。

leetcode連結

class Solution {
public:
    vector<vector<string>> solveNQueens(int n) {
        vector<vector<string>> vvs;
        vector<string> vs(n, string(n, '.'));
        traceback(0, vs, vvs);
        return vvs;
    }

    void traceback(int row, vector<string>& vs, vector<vector<string>>& vvs){
        if(row == vs.size()){
            vvs.push_back(vs);
            return;
        }
        for(int i = 0;i < vs.size();i++){
            if(!isValid(row, i, vs)) continue;
            vs[row][i] = 'Q';
            traceback(row+1, vs, vvs);
            vs[row][i] = '.';
        }

    }

    bool isValid(int row, int n, vector<string>& vs){
        // 同一列
        for(int i = 0;i < row;i++)
            if(vs[i][n] == 'Q') return false;
        // 左上斜線
        for(int i = row-1, j = n-1; i >= 0 && j >= 0;i--,j--)
            if(vs[i][j] == 'Q') return false;
        // 右上斜線
        for(int i = row-1, j = n+1; i >= 0 && j < vs.size();i--,j++)
            if(vs[i][j] == 'Q') return false;
        return true;
    }
};

相關文章