【演算法框架套路】回溯演算法(暴力窮舉的藝術)

雪山飛豬發表於2021-08-24

回溯演算法介紹

回溯演算法可以搜尋一個問題的所有解,本質是用遞迴代替N層for迴圈來“暴力窮舉”

原理如下:

  1. 從根節點出發深度搜尋解空間樹
  2. 搜尋到有解的分支時,繼續向下搜尋
  3. 搜尋到無解的分支時,回退到上一步,顧名思義“回溯”

框架套路

talk is cheap,show you the 套路,框架如下

結果集=[]
function dfs(選擇列表,已選擇的陣列)
    if 結束條件
        結果集追加
        return
    for 選擇 in 選擇列表
        做選擇
        dfs(選擇列表, 已選擇陣列) 進入下一次選擇
        取消選擇
dfs(選擇列表,[])
return 結果集

思路來自labuladong的演算法小抄,自己改成了個人覺得更通用的版本,預設收集所有的解,便於跟蹤除錯。

重點:

  1. 選擇列表。當前可以做出的選擇
  2. 已選擇路徑。已經做出的選擇
  3. 結束條件。無法再做出選擇的條件

有了這框架,以後遇到需要窮舉的演算法,把3個重點想通,直接套用,簡直不要太嗨~

演算法示例

以下演算法全用python實現,需要注意的是python的陣列預設是傳遞引用,引入了copy包來複制陣列

全組合

全組合是窮舉的代表了吧,給定指定不重複的字串,比如給定["a","b"],返回所有的組合結果應該是

aa
ab
ba
bb

我們來套用框架實現一下,程式碼如下

import copy

# 全組合
def combination(str_list):
    res = []

    max_len = len(str_list)

    def dfs(str_list, track_list):
        if len(track_list) == max_len:  # 滿足條件,加入結果集
            res.append(track_list)
            return
        for c in str_list:
            track_list.append(c)  # 選擇
            dfs(str_list, copy.copy(track_list))  # 進入下一次選擇
            track_list.pop()  # 取消選擇

    dfs(str_list, [])
    return res

三個重點:

  1. 選擇列表。可以選擇的字串,比如['a','b','c'],對應變數str_list。
  2. 已選擇路徑。已經做出的選擇,比如已經選擇了['a'],對應變數track_list。
  3. 結束條件。無法再做出選擇的條件,已選擇的陣列長度等於最大長度,對應len(track_list) == max_len

我們來測試一下

for v in combination(['a', 'b']):
    print(v)

執行輸出

全排列

全排列和全組合差不多,唯一的區別是已經選擇過的字串,不讓選擇了。
我們只需要在全組合程式碼的基礎上加上限制即可,程式碼如下

import copy


# 全排列
def permute(str_list):
    res = []

    max_len = len(str_list)

    def dfs(str_list, track_list):

        if len(track_list) == max_len:  # 滿足條件,加入結果集
            res.append(track_list)
            return
        for c in str_list:
            if c in track_list:  # 已經存在的不再新增
                continue
            track_list.append(c)  # 選擇
            dfs(str_list, copy.copy(track_list))  # 進入下一次選擇
            track_list.pop()  # 取消選擇

    dfs(str_list, [])
    return res

我們只是改了一下這裡

我們用chenqionghe的簡稱['c','q','h']來測試一下

for v in permute(['c', 'q', 'h']):
    print(v)

執行輸出

湊零錢

給定數量N種面值的硬幣, 再給定一個金額,返回硬幣湊出這個金額的最少數量。
比如,給定硬幣1, 2, 5,總額為10,最少需要2枚硬幣5+5=10

程式碼實現如下

def coin_change(coins, amount):
    res_list = []

    def dfs(n, track_list):
        if n == 0:
            res_list.append(track_list)  # 滿足條件
            return 0

        if n < 0:
            return -1

        for coin in coins:
            track_list.append(coin)  # 做選擇
            dfs(n - coin, copy.copy(track_list))  # 選擇一個硬幣,目標金額就會減少,解變為1+sub
            track_list.pop()  # 取消選擇

    dfs(amount, [])
    return res_list

三個重點:

  1. 選擇列表。可以選擇的硬幣,對應coins陣列。
  2. 已選擇路徑。已經做出的選擇,對應track_list陣列。
  3. 結束條件。無法再做出選擇的條件,金額為0和負的時候。

需要注意的是:df函式代表的是:目標金額是n,需要dfs[n]個硬幣,比如給定金額10,這次選擇了2,這次選擇能達到的金額數量是1+dfs(10 - 2),也就是1+dfs(8)

我們來執行一下:

for v in coin_change([2, 3, 5], 10):
    print(v)

輸出如下

給出了所有的方案,如果要最小的硬幣只需要統計長度最小的即可。

N皇后

最典型的是八皇后:

在8×8格的國際象棋上擺放8個皇后,使其不能互相攻擊,即任意兩個皇后都不能處於同一行、同一列或同一斜線上,問有多少種擺法。

以4皇后為例,給定數字4,應該給出兩種方案如下

第一種方案
. Q . .
. . . Q
Q . . .
. . Q .
第二種方案
. . Q .
Q . . .
. . . Q
. Q . .

套用框架實現如下

# N皇后問題
def solve_n_queens(n):
    res = []

    def dfs(board, row):
        if row == n:  # 到達最後一行,追加結果集
            res.append(board)
        for col in range(n):
            # 排除不合法的選擇
            if not is_valid(board, row, col, n):
                continue
            board[row][col] = 'Q'  # 選擇第row行第col列放Q

            dfs(copy.deepcopy(board), row + 1)

            board[row][col] = '.'  # 撤銷選擇
        return False

    board = [['.'] * n for _ in range(n)]  # 初始化二維陣列
    dfs(board, 0)  # 從第0行開始做選擇
    return res

# 判斷是否能在board[row][col]放置Q
def is_valid(board, row, col, n):
    # 垂直方向是否有Q
    for v in range(row):
        if board[v][col] == 'Q':
            return False
    # 左上方是否有Q
    i, j = row - 1, col - 1
    while i >= 0 and j >= 0:
        if board[i][j] == 'Q':
            return False
        i = i - 1
        j = j - 1
    # 右上方是否有Q
    i, j = row - 1, col + 1
    while i >= 0 and j <= n - 1:
        if board[i][j] == 'Q':
            return False
        i = i - 1
        j = j + 1
    return True

N皇后的解法是,在每行做選擇,選擇為N列,做出選擇後,進入下一行繼續做選擇
三個重點:

  1. 選擇列表。可以選擇的列,對應的是0-n的任意一列。
  2. 已選擇路徑。已經做出的選擇,對應board二維陣列。
  3. 結束條件。無法再做出選擇的條件,也就是已經到達最後一行的時候。

注意:is_valid的函式,主要是判斷檢測當前位置是否能放“皇后”,也就是檢查垂直、左上方向和右上方是不是都沒有“皇后”

我們來測試一下

res = solve_n_queens(8)
for data in res:
    print('-' * 20)
    for v in data:
        print(" ".join(v))

執行輸出如下

優化思路

新增備忘錄避免重複計算

以湊零錢為例,裡邊其實會出現很多相同子問題的遞迴
以10舉個例子,當我們選擇了選擇了[2, 3]和[5]的時候,都需要再計算dfs(5)的值。資料越大,重複的遞迴越多,效能越差。

我們可以引入一個map,記錄已經計算出的值,下次遇到相同問題直接返回結果

def coin_change_optimization(coins, amount):
    memo = {}
    def dfs(n):
        if n in memo:
            return memo[n]
        if n == 0:
            return 0
        if n < 0:
            return -1

        min_res = float('INF')
        for coin in coins:
            sub = dfs(n - coin)  # 選擇一個硬幣,目標金額就會減少,解變為1+sub
            if sub == -1:
                continue
            if min_res > 1 + sub:  # 更新最小值
                min_res = 1 + sub

        memo[n] = min_res if min_res != float('INF') else -1
        return memo[n]

    return dfs(amount)

得到解向上返回阻斷遞迴

以N皇后為例,我們只需要在得到解的時候return,並在上層接收即可,程式碼如下

# N皇后問題
def solve_n_queens(n):
    res = []

    def dfs(board, row):
        if row == n:  # 到達最後一行,追加結果集
            res.append(board)
            return True
        for col in range(n):
            # 排除不合法的選擇
            if not is_valid(board, row, col, n):
                continue
            board[row][col] = 'Q'  # 選擇第row行第col列放Q

            if dfs(copy.deepcopy(board), row + 1):
                return True

            board[row][col] = '.'  # 撤銷選擇
        return False

    board = [['.'] * n for _ in range(n)]  # 初始化二維陣列
    dfs(board, 0)  # 從第0行開始做選擇
    return res

以上只是在這裡做了改動

看到沒有,這就是回溯暴力窮舉的藝術,最簡單的框架,解決最難的問題~

相關文章