強化學習筆記之【論文精讀】【ACE:一種基於熵規整和因果關係的離線SAC演算法】

这可就有点麻烦了發表於2024-10-17

強化學習筆記之【論文精讀】【ACE:Off-PolicyActor-CriticwithCausality-AwareEntropyRegularization】


目錄
  • 強化學習筆記之【論文精讀】【ACE:Off-PolicyActor-CriticwithCausality-AwareEntropyRegularization】
    • 前言:
    • 論文一覽
    • 論文摘要
    • 論文主要貢獻:
    • 論文程式碼框架
      • 1. 初始化模組
      • 2. 因果發現模組
      • 3. 策略最佳化模組
      • 4. 重置機制模組
    • 論文原始碼主幹
      • 程式碼流程解釋
    • 論文模組程式碼及實現
      • 因果發現模組
      • 策略最佳化模組
        • 1. 取樣經驗資料
        • 2. 計算目標 Q 值
        • 3. 更新 Q 網路
        • 4. 策略網路更新
        • 5. 自適應熵調節
        • 6. 返回值
      • 重置機制模組
        • 重置邏輯
        • 重置機制模組的原理
          • 1. 計算梯度主導度 ( $\beta_\gamma $)
          • 2. 軟重置策略和 Q 網路
          • 3. 策略和 Q 最佳化器的重置
          • 4. 重置機制模組的應用
            • a. 重置間隔達成時
            • b. 主導梯度或因果效應差異滿足條件時
            • c. 總結
        • 擾動因子的計算
          • 擾動因子(factor)
          • 組合擾動因子的公式
      • 評估程式碼
        • 1. 定期評估條件
        • 2. 初始化評估列表
        • 3. 進行評估
          • 3.1 回合初始化
          • 3.2 執行智慧體動作
          • 3.3 儲存回合獎勵
        • 4. 計算平均獎勵
        • 5. 儲存最佳模型
    • 論文復現結果

前言:

強化學習筆記之【論文精讀】【ACE:一種基於熵規整和因果關係的離線SAC演算法】
至少先點個贊吧,寫的很累的

該論文是清華專案組組內博士師兄寫的文章,專案主頁為ACE (ace-rl.github.io),於2024年7月發表在ICML期刊

強化學習筆記之【論文精讀】【ACE:一種基於熵規整和因果關係的離線SAC演算法】
強化學習筆記之【論文精讀】【ACE:一種基於熵規整和因果關係的離線SAC演算法】

因為最近組內(其實只有我)需要從零開始做一個相關專案,前面的幾篇文章都是鋪墊

本文章為強化學習筆記第5篇

本文初編輯於2024.10.5,好像是這個時間,忘記了,前後寫了兩個多星期

CSDN主頁:https://blog.csdn.net/rvdgdsva

部落格園主頁:https://www.cnblogs.com/hassle

部落格園本文連結:


論文一覽

這篇強化學習論文主要介紹了一個名為 ACE 的演算法,完整名稱為 Off-Policy Actor-Critic with Causality-Aware Entropy Regularization,它透過引入因果關係分析和因果熵正則化來解決現有模型在不同動作維度上的不平等探索問題,旨在改進強化學習【註釋1】中探索效率和樣本效率的問題,特別是在高維度連續控制任務中的表現。

【註釋1】:強化學習入門這一篇就夠了


論文摘要

在policy【註釋2】學習過程中,不同原始行為的不同意義被先前的model-free RL 演算法所忽視。利用這一見解,我們探索了不同行動維度和獎勵之間的因果關係,以評估訓練過程中各種原始行為的重要性。我們引入了一個因果關係感知熵【註釋3】項(causality-aware entropy term),它可以有效地識別並優先考慮具有高潛在影響的行為,以實現高效的探索。此外,為了防止過度關注特定的原始行為,我們分析了梯度休眠現象(gradientdormancyphenomenon),並引入了休眠引導的重置機制,以進一步增強我們方法的有效性。與無模型RL基線相比,我們提出的演算法 ACE:Off-policyActor-criticwith Causality-awareEntropyregularization。在跨越7個域的29種不同連續控制任務中顯示出實質性的效能優勢,這強調了我們方法的有效性、多功能性和高效的樣本效率。 基準測試結果和影片可在https://ace-rl.github.io/上獲得。

【註釋2】:強化學習演算法中on-policy和off-policy

【註釋3】:最大熵 RL:從Soft Q-Learning到SAC - 知乎


論文主要貢獻:

【1】因果關係分析:透過引入因果政策-獎勵結構模型,評估不同動作維度(即原始行為)對獎勵的影響大小(稱為“因果權重”)。這些權重反映了每個動作維度在不同學習階段的相對重要性。

作出上述改進的原因是:考慮一個簡單的例子,一個機械手最初應該學習放下手臂並抓住物體,然後將注意力轉移到學習手臂朝著最終目標的運動方向上。因此,在策略學習的不同階段強調對最重要的原始行為的探索是 至關重要的。在探索過程中刻意關注各種原始行為,可以加速智慧體在每個階段對基本原始行為的學習,從而提高掌握完整運動任務的效率。

此處可供學習的資料:

【2】因果熵正則化:在最大熵強化學習框架的基礎上(如SAC演算法),加入了因果加權的熵正則化項。與傳統熵正則化不同,這一項根據各個原始行為的因果權重動態調整,強化對重要行為的探索,減少對不重要行為的探索。

作出上述改進的原因是:論文引入了一個因果策略-獎勵結構模型來計算行動空間上的因果權重(causal weights),因果權重會引導agent進行更有效的探索, 鼓勵對因果權重較大的動作維度進行探索,表明對獎勵的重要性更大,並減少對因果權重較小的行為維度的探 索。一般的最大熵目標缺乏對不同學習階段原始行為之間區別的重要性的認識,可能導致低效的探索。為了解決這一限制,論文引入了一個由因果權重加權的策略熵作為因果關係感知的熵最大化目標,有效地加強了對重要原始行為的探索,並導致了更有效的探索。

此處可供學習的資料:

【3】梯度“休眠”現象(Gradient Dormancy):論文觀察到,模型訓練時有些梯度會在某些階段不活躍(即“休眠”)。為了防止模型過度關注某些原始行為,論文引入了梯度休眠導向的重置機制。該機制透過週期性地對模型進行擾動(reset),避免模型陷入區域性最優,促進更廣泛的探索。

作出上述改進的原因是:該機制透過一個由梯度休眠程度決定的因素間歇性地干擾智慧體的神經網路。將因果關係感知探索與這種新穎的重置機制相結合,旨在促進更高效、更有效的探索,最終提高智慧體的整體效能。

透過在多個連續控制任務中的實驗,ACE 展示出了顯著優於主流強化學習演算法(如SAC、TD3)的表現:

  • 29個不同的連續控制任務:包括 Meta-World(12個任務)、DMControl(5個任務)、Dexterous Hand(3個任務)和其他稀疏獎勵任務(6個任務)。
  • 實驗結果表明,ACE 在所有任務中都達到了更好的樣本效率和更高的最終效能。例如,在複雜的稀疏獎勵場景中,ACE 憑藉其因果權重引導的探索策略,顯著超越了 SAC 和 TD3 等現有演算法。

論文中的對比實驗圖表顯示了 ACE 在多種任務下的顯著優勢,尤其是在稀疏獎勵和高維度任務中,ACE 憑藉其探索效率的提升,能更快達到最優策略。


論文程式碼框架

在ACE原論文的第21頁,這玩意兒應該寫在正篇的,害的我看了好久的程式碼去排流程

不過說實話這虛擬碼有夠簡潔的,程式碼多少有點糊成一坨了

強化學習筆記之【論文精讀】【ACE:一種基於熵規整和因果關係的離線SAC演算法】

這是一個強化學習(RL)演算法的框架,具體是一個結合因果推斷(Causal Discovery)的離策略(Off-policy)Actor-Critic方法。下面是對每個模組及其引數的說明:

1. 初始化模組

  • Q網路 ( \(Q_\phi\) ):用於估計動作價值,(\phi) 是權重引數。
  • 策略網路 ( $\pi_\theta $):用於生成動作策略,(\theta) 是其權重。
  • 重放緩衝區 ($ D$ ):儲存環境互動的資料,以便進行取樣。
  • 區域性緩衝區 ( $D_c $):儲存因果發現所需的區域性資料。
  • 因果權重矩陣 ($ B_{a \rightarrow r|s} $):用於捕捉動作與獎勵之間的因果關係。
  • 擾動因子 ( \(f\) ):用於對策略進行微小擾動,增加探索。

2. 因果發現模組

  • 每 ( $$I$$ ) 步更新
    • 樣本取樣:從區域性緩衝區 ( \(D_c\) ) 中抽樣 ( \(N_c\) ) 條轉移。
    • 更新因果權重矩陣:調整 ($ B_{a \rightarrow r|s}$ ),用於反映當前策略和獎勵之間的因果關係。

3. 策略最佳化模組

  • 每個梯度步驟
    • 樣本取樣:從重放緩衝區 ( \(D\) ) 中抽樣 ($ N$ ) 條轉移。
    • 計算因果意識熵 ( \(H_c(\pi(\cdot|s))\) ):衡量在給定狀態下策略的隨機性和確定性,用於修改策略。
    • 目標 Q 值計算:更新目標 Q 值,用於訓練 Q 網路。
    • 更新 Q 網路:減少預測的 Q 值與目標 Q 值之間的誤差。
    • 更新策略網路:最大化當前狀態下的 Q 值,以提高收益。

4. 重置機制模組

  • 每個重置間隔
    • 計算梯度主導度 ( $\beta_\gamma $):用來量化策略更新的影響程度。
    • 初始化隨機網路:為新的策略更新準備初始權重 ( $\phi_i $)。
    • 軟重置策略和 Q 網路:根據因果權重進行平滑更新,幫助實現更穩定的最佳化。
    • 重置策略和 Q 最佳化器:在重置時清空狀態,以便進行新的學習過程。

論文原始碼主幹

原始碼上千行呢,這裡只是貼上main_casual裡面的部分程式碼,並且刪掉了很大一部分程式碼以便理清程式脈絡

def train_loop(config, msg = "default"):
    # Agent
    agent = ACE_agent(env.observation_space.shape[0], env.action_space, config)

    memory = ReplayMemory(config.replay_size, config.seed)
    local_buffer = ReplayMemory(config.causal_sample_size, config.seed)

    for i_episode in itertools.count(1):
        done = False

        state = env.reset()
        while not done:
            if config.start_steps > total_numsteps:
                action = env.action_space.sample()  # Sample random action
            else:
                action = agent.select_action(state)  # Sample action from policy

            if len(memory) > config.batch_size:
                for i in range(config.updates_per_step):
                    #* Update parameters of causal weight
                    if (total_numsteps % config.causal_sample_interval == 0) and (len(local_buffer)>=config.causal_sample_size):
                        causal_weight, causal_computing_time = get_sa2r_weight(env, local_buffer, agent, sample_size=config.causal_sample_size, causal_method='DirectLiNGAM')
                        print("Current Causal Weight is: ",causal_weight)
                        
                    dormant_metrics = {}
                    # Update parameters of all the networks
                    critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha, q_sac, dormant_metrics = agent.update_parameters(memory, causal_weight,config.batch_size, updates)

                    updates += 1
            next_state, reward, done, info = env.step(action) # Step
            total_numsteps += 1
            episode_steps += 1
            episode_reward += reward

            #* Ignore the "done" signal if it comes from hitting the time horizon.
            if '_max_episode_steps' in dir(env):  
                mask = 1 if episode_steps == env._max_episode_steps else float(not done)
            elif 'max_path_length' in dir(env):
                mask = 1 if episode_steps == env.max_path_length else float(not done)
            else: 
                mask = 1 if episode_steps == 1000 else float(not done)

            memory.push(state, action, reward, next_state, mask) # Append transition to memory
            local_buffer.push(state, action, reward, next_state, mask) # Append transition to local_buffer
            state = next_state

        if total_numsteps > config.num_steps:
            break

        # test agent
        if i_episode % config.eval_interval == 0 and config.eval is True:
            eval_reward_list = []
            for _  in range(config.eval_episodes):
                state = env.reset()
                episode_reward = []
                done = False
                while not done:
                    action = agent.select_action(state, evaluate=True)
                    next_state, reward, done, info = env.step(action)
                    state = next_state
                    episode_reward.append(reward)
                eval_reward_list.append(sum(episode_reward))

            avg_reward = np.average(eval_reward_list)
          
    env.close() 

程式碼流程解釋

  1. 初始化:

    • 透過配置檔案config設定環境和隨機種子。
    • 使用ACE_agent初始化強化學習智慧體,該智慧體會在後續過程中學習如何在環境中行動。
    • 建立儲存結果和檢查點的目錄,確保訓練過程中的配置和因果權重會被記錄下來。
    • 初始化了兩個重放緩衝區:memory用於儲存所有的歷史資料,local_buffer則用於因果權重的更新。
  2. 主訓練迴圈:

    • 取樣動作:如果總步數較小,則從環境中隨機取樣動作,否則從策略中選擇動作。透過這種方式,確保早期探索和後期利用。

    • 更新因果權重:在特定間隔內,從區域性緩衝區中取樣資料,透過get_sa2r_weight函式使用DirectLiNGAM演算法計算從動作到獎勵的因果權重。這個權重會作為額外資訊,幫助智慧體最佳化策略。

    • 更新網路引數:當memory中的資料足夠多時,開始透過取樣更新Q網路和策略網路,使用計算出的因果權重來修正損失函式。

    • 記錄與儲存模型:每隔一定的步數,演算法會測試當前策略的效能,記錄並比較獎勵是否超過歷史最佳值,如果是,則儲存模型的檢查點。

    • 使用wandb記錄訓練過程中的指標,例如損失函式、獎勵和因果權重的計算時間,這些資訊可以幫助除錯和分析訓練過程。


論文模組程式碼及實現

因果發現模組

因果發現模組主要透過 get_sa2r_weight 函式實現,並且與 DirectLiNGAM 模型結合,負責計算因果權重。具體程式碼在訓練迴圈中如下:

causal_weight, causal_computing_time = get_sa2r_weight(env, local_buffer, agent, sample_size=config.causal_sample_size, causal_method='DirectLiNGAM')

在這個程式碼段,get_sa2r_weight 函式會基於當前環境、樣本資料(local_buffer)和因果模型(這裡使用的是 DirectLiNGAM),計算與行動相關的因果權重(causal_weight)。這些權重會影響後續的策略最佳化和引數更新。關鍵邏輯包括:

  1. 取樣間隔:因果發現是在 total_numsteps % config.causal_sample_interval == 0 時觸發,確保只在指定的步數間隔內計算因果權重,避免每一步都進行因果計算,減輕計算負擔。
  2. 區域性緩衝區local_buffer 中儲存了足夠的樣本(config.causal_sample_size),這些樣本用於因果關係的發現。
  3. 因果方法DirectLiNGAM 是選擇的因果模型,用於從狀態、行動和獎勵之間推匯出因果關係。

因果權重計算完成後,程式會將這些權重應用到策略最佳化中,並且記錄權重及計算時間等資訊。

def get_sa2r_weight(env, memory, agent, sample_size=5000, causal_method='DirectLiNGAM'):
    ······
    return weight, model._running_time

這個程式碼的核心是利用DirectLiNGAM模型計算給定狀態、動作和獎勵之間的因果權重。接下來,用LaTeX公式詳細表述計算因果權重的過程:

  1. 資料預處理
    將從memory中取樣的states(狀態)、actions(動作)和rewards(獎勵)進行拼接,構建輸入資料矩陣 \(X_{\text{ori}}\)

    \[X_{\text{ori}} = [S, A, R] \]

    其中,\(S\) 代表狀態,\(A\) 代表動作,\(R\) 代表獎勵。接著,構建資料框 \(X\) 來進行因果分析。

  2. 因果模型擬合

    X_ori 轉換為 X 是為了利用 pandas 資料框的便利性和靈活性

    使用 DirectLiNGAM 模型對矩陣 \(X\) 進行擬合,得到因果關係的鄰接矩陣 \(A_{\text{model}}\)

    \[A_{\text{model}} = \text{DirectLiNGAM}(X) \]

    該鄰接矩陣表示狀態、動作、獎勵之間的因果結構,特別是從動作到獎勵的影響關係。

  3. 提取動作對獎勵的因果權重
    透過鄰接矩陣提取動作對獎勵的因果權重 \(w_{\text{r}}\),該權重從鄰接矩陣的最後一行中選擇與動作對應的元素:

    \[w_{\text{r}} = A_{\text{model}}[-1, \, d_s:(d_s + d_a)] \]

    其中,\(d_s\) 是狀態的維度,\(d_a\) 是動作的維度。

  4. 因果權重的歸一化
    對因果權重 \(w_{\text{r}}\) 進行Softmax歸一化,確保它們的總和為1:

    \[w = \frac{e^{w_{\text{r},i}}}{\sum_{i} e^{w_{\text{r},i}}} \]

  5. 調整權重的尺度
    最後,因果權重根據動作的數量進行縮放:

    \[w = w \times d_a \]

最終輸出的權重 \(w\) 表示每個動作對獎勵的因果影響,經過歸一化和縮放處理,可以用於進一步的策略調整或分析。

策略最佳化模組

以下是對函式工作原理的逐步解釋:

策略最佳化模組主要由 agent.update_parameters 函式實現。agent.update_parameters 這個函式的主要目的是在強化學習中更新策略 (policy) 和價值網路(critic)的引數,以提升智慧體的效能。這個函式實現了一個基於軟演員評論家(SAC, Soft Actor-Critic)的更新機制,並且加入了因果權重與"休眠"神經元(dormant neurons)的處理,以提高模型的魯棒性和穩定性。

critic_1_loss, critic_2_loss, policy_loss, ent_loss, alpha, q_sac, dormant_metrics = agent.update_parameters(memory, causal_weight, config.batch_size, updates)

透過 agent.update_parameters 函式,程式會更新以下幾個部分:

  1. Critic網路(價值網路)critic_1_losscritic_2_loss 分別是兩個 Critic 網路的損失,用於評估當前策略的價值。
  2. Policy網路(策略網路)policy_loss 表示策略網路的損失,用於最佳化 agent 的行動選擇。
  3. Entropy損失ent_loss 用來調節策略的隨機性,幫助 agent 在探索和利用之間找到平衡。
  4. Alpha:表示自適應的熵係數,用於調整探索與利用之間的權衡。

這些引數的更新在每次訓練迴圈中被呼叫,並使用 wandb.log 記錄損失和其他相關的訓練資料。

update_parametersACE_agent 類中的一個關鍵函式,用於根據經驗回放緩衝區中的樣本資料來更新模型的引數。下面是對其工作原理的詳細解釋:

1. 取樣經驗資料

首先,函式從 memory 中取樣一批樣本(state_batchaction_batchreward_batchnext_state_batchmask_batch),其中包括狀態、動作、獎勵、下一個狀態以及掩碼,用於表示是否為終止狀態。

state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)
  • state_batch:當前的狀態。
  • action_batch:在當前狀態下執行的動作。
  • reward_batch:執行該動作後獲得的獎勵。
  • next_state_batch:執行動作後到達的下一個狀態。
  • mask_batch:掩碼,用於表示是否為終止狀態(1 表示非終止,0 表示終止)。

2. 計算目標 Q 值

利用當前策略(policy)網路,取樣下一個狀態的動作 next_state_action 和其對應的機率分佈對數 next_state_log_pi。然後利用目標 Q 網路 critic_target 估計下一時刻的最小 Q 值,並結合獎勵和折扣因子 \(\gamma\) 計算下一個 Q 值:

\[{min\_qf\_next\_target} = \min(Q_1^{\text{target}}(s', a'), Q_2^{{target}}(s', a')) - \alpha \cdot \log \pi(a'|s') \\ {next\_q\_value} = r + \gamma \cdot \text{mask\_batch} \cdot {min\_qf\_next\_target} \]

with torch.no_grad():
    next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch, causal_weight)
    qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
    min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
    next_q_value = reward_batch + mask_batch * self.gamma * (min_qf_next_target)
  • 透過策略網路 self.policy 為下一個狀態 next_state_batch 取樣動作 next_state_action 和相應的策略熵 next_state_log_pi

  • 使用目標 Q 網路計算 qf1_next_targetqf2_next_target,並取兩者的最小值來減少估計偏差。

  • 最終使用貝爾曼方程計算 next_q_value,即當前的獎勵加上折扣因子 \(\gamma\) 乘以下一個狀態的 Q 值。

  • 這裡,\(\alpha\) 是熵項的權重,用於平衡探索和利用的權衡,而 mask_batch 是為了處理終止狀態的情況。

    使用無偏估計來計算目標 Q 值。透過目標網路 (critic_target) 計算出下一個狀態和動作的 Q 值,並使用獎勵和掩碼更新當前 Q 值

3. 更新 Q 網路

接著,使用當前 Q 網路 critic 估計當前狀態和動作下的 Q 值 \(Q_1\)\(Q_2\),並計算它們與目標 Q 值的均方誤差損失:

\[\text{qf1_loss} = \text{MSE}(Q_1(s, a), \text{next\_q\_value}) \\ \text{qf2\_loss} = \text{MSE}(Q_2(s, a), \text{next\_q\_value}) \]

最終 Q 網路的總損失是兩個 Q 網路損失之和:

\[\text{qf\_loss} = \text{qf1\_loss} + \text{qf2\_loss} \]

然後,透過反向傳播 qf_loss 來更新 Q 網路的引數。

qf1, qf2 = self.critic(state_batch, action_batch)
qf1_loss = F.mse_loss(qf1, next_q_value)
qf2_loss = F.mse_loss(qf2, next_q_value)
qf_loss = qf1_loss + qf2_loss

self.critic_optim.zero_grad()
qf_loss.backward()
self.critic_optim.step()
  • qf1qf2 是兩個 Q 網路的輸出,用於減少正向估計偏差。
  • 損失函式是 Q 值的均方誤差(MSE),qf1_lossqf2_loss 分別計算兩個 Q 網路的誤差,最後將兩者相加為總的 Q 損失 qf_loss
  • 透過 self.critic_optim 最佳化器對損失進行反向傳播和引數更新。

4. 策略網路更新

每隔若干步(透過 target_update_interval 控制),開始更新策略網路 policy。首先,重新取樣當前狀態下的策略 \(\pi(a|s)\),並計算 Q 值和熵權重下的策略損失:

\[\text{policy\_loss} = \mathbb{E}\left[ \alpha \cdot \log \pi(a|s) - \min(Q_1(s, a), Q_2(s, a)) \right] \]

這個損失透過反向傳播更新策略網路。

if updates % self.target_update_interval == 0:
    pi, log_pi, _ = self.policy.sample(state_batch, causal_weight)
    qf1_pi, qf2_pi = self.critic(state_batch, pi)
    min_qf_pi = torch.min(qf1_pi, qf2_pi)
    policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

    self.policy_optim.zero_grad()
    policy_loss.backward()
    self.policy_optim.step()
  • 透過策略網路對當前狀態 state_batch 進行取樣,得到動作 pi 及其對應的策略熵 log_pi
  • 計算策略損失 policy_loss,即 \(\alpha\) 倍的策略熵減去最小的 Q 值。
  • 透過 self.policy_optim 最佳化器對策略損失進行反向傳播和引數更新。

5. 自適應熵調節

如果開啟了自動熵項調整(automatic_entropy_tuning),則會進一步更新熵項 \(\alpha\) 的損失:

\[\alpha_{\text{loss}} = -\mathbb{E}\left[\log \alpha \cdot (\log \pi(a|s) + \text{target\_entropy}) \right] \]

並透過梯度下降更新 \(\alpha\)

如果 automatic_entropy_tuning 為真,則會更新熵項:

if self.automatic_entropy_tuning:
    alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()
    self.alpha_optim.zero_grad()
    alpha_loss.backward()
    self.alpha_optim.step()
    self.alpha = self.log_alpha.exp()
    alpha_tlogs = self.alpha.clone()
else:
    alpha_loss = torch.tensor(0.).to(self.device)
    alpha_tlogs = torch.tensor(self.alpha) # For TensorboardX logs
  • 透過計算 alpha_loss 更新 self.alpha,調整策略的探索-利用平衡。

6. 返回值

  • qf1_loss, qf2_loss: 兩個 Q 網路的損失
  • policy_loss: 策略網路的損失
  • alpha_loss: 熵權重的損失
  • alpha_tlogs: 用於日誌記錄的熵權重
  • next_q_value: 平均下一個 Q 值
  • dormant_metrics: 休眠神經元的相關度量

重置機制模組

重置機制模組在程式碼中主要體現在 update_parameters 函式中,並透過梯度主導度 (dominant metrics) 和擾動函式 (perturbation functions) 實現對策略網路和 Q 網路的重置。

重置邏輯

函式根據設定的 reset_interval 判斷是否需要對策略網路和 Q 網路進行擾動和重置。這裡使用了"休眠"神經元的概念,即一些在梯度更新中影響較小的神經元,可能會被調整或重置。

函式計算了休眠度量 dormant_metrics 和因果權重差異 causal_diff,透過擾動因子 perturb_factor 來決定是否對網路進行部分或全部的擾動與重置。

重置機制模組的原理

重置機制主要由以下部分組成:

1. 計算梯度主導度 ( $\beta_\gamma $)

在更新策略時,計算主導梯度,即某些特定神經元或引數在更新中主導作用的比率。程式碼中透過呼叫 cal_dormant_grad(self.policy, type='policy', percentage=0.05) 實現這一計算,代表提取出 5% 的主導梯度來作為判斷因子。

dormant_metrics = cal_dormant_grad(self.policy, type='policy', percentage=0.05)

根據主導度 ($ \beta_\gamma$ ) 和權重 ($ w$ ),可以得到因果效應的差異。程式碼裡用 causal_diff 來表示因果差異:

\[\text{causal\_diff} = \max(w) - \min(w) \]

2. 軟重置策略和 Q 網路

軟重置機制透過平滑更新策略網路和 Q 網路,避免過大的權重更新導致的網路不穩定。這在程式碼中由 soft_update 實現:

soft_update(self.critic_target, self.critic, self.tau)

具體來說,軟更新的公式為:

\[\theta_{\text{target}} = \tau \theta_{\text{source}} + (1 - \tau) \theta_{\text{target}} \]

其中,( \(\tau\) ) 是一個較小的常數,通常介於 ( [0, 1] ) 之間,確保目標網路的更新是緩慢的,以提高學習的穩定性。

3. 策略和 Q 最佳化器的重置
4. 重置機制模組的應用

每當經過一定的重置間隔時,判斷是否需要擾動策略和 Q 網路。透過呼叫 perturb()dormant_perturb() 實現對網路的擾動(perturbation)。擾動因子由梯度主導度和因果差異共同決定。

策略與 Q 網路的擾動會在以下兩種情況下發生:

a. 重置間隔達成時

程式碼中每當更新次數 updates 達到設定的重置間隔 self.reset_interval,並且 updates > 5000 時,才會觸發策略與 Q 網路的重置邏輯。這是為了確保擾動不是頻繁發生,而是在經過一段較長的訓練時間後進行。

具體判斷條件:

if updates % self.reset_interval == 0 and updates > 5000:
b. 主導梯度或因果效應差異滿足條件時

在達到了重置間隔後,首先會計算梯度主導度因果效應的差異。這可以透過計算因果差異 causal_diff 或梯度主導度 dormant_metrics['policy_grad_dormant_ratio'] 來決定是否需要擾動。

  • 梯度主導度計算方式透過 cal_dormant_grad() 函式實現,如果梯度主導度較低,意味著網路中的某些神經元更新幅度過小,則需要對網路進行擾動。

  • 因果效應差異透過計算 causal_diff = np.max(causal_weight) - np.min(causal_weight) 得到,如果差異過大,則可能需要重置。

然後根據這些值透過擾動因子 factor 進行判斷:

factor = perturb_factor(dormant_metrics['policy_grad_dormant_ratio'])

如果擾動因子 ( \(\text{factor} < 1\) ),網路會進行擾動:

if factor < 1:
    if self.reset == 'reset' or self.reset == 'causal_reset':
        perturb(self.policy, self.policy_optim, factor)
        perturb(self.critic, self.critic_optim, factor)
        perturb(self.critic_target, self.critic_optim, factor)
c. 總結
  • 更新次數達到設定的重置間隔,且經過了一定時間的訓練(updates > 5000)。
  • 梯度主導度較低或因果效應差異過大,導致計算出的擾動因子 ( $\text{factor} < 1 $)。

這兩種條件同時滿足時,策略和 Q 網路將被擾動或重置。

擾動因子的計算

在這段程式碼中,factor 是基於網路中梯度主導度或者因果效應差異計算出來的擾動因子。擾動因子透過函式 perturb_factor() 進行計算,該函式會根據神經元的梯度主導度(dormant_ratio)或因果效應差異(causal_diff)來調整 factor 的大小。

擾動因子(factor)

擾動因子 factor 的計算公式如下:

\[\text{factor} = \min\left(\max\left(\text{min\_perturb\_factor}, 1 - \text{dormant\_ratio}\right), \text{max\_perturb\_factor}\right) \]

其中:

  • (\(\text{dormant\_ratio}\)) 是網路中梯度主導度,即表示有多少神經元的梯度變化較小(或者接近零),處於“休眠”狀態。

  • (\(\text{min\_perturb\_factor}\)) 是最小擾動因子值,程式碼中設定為 0.2

  • (\(\text{max\_perturb\_factor}\)) 是最大擾動因子值,程式碼中設定為 0.9

  • dormant_ratio:

    • 表示網路中處於“休眠狀態”的梯度比例。這個比例通常透過計算神經網路中梯度幅度小於某個閾值的神經元數量來獲得。dormant_ratio 越大,表示越多神經元的梯度變化很小,說明網路更新不充分,需要擾動。
  • max_perturb_factor:

    • 最大擾動因子值,用來限制擾動因子的上限,程式碼中設定為 0.9,意味著最大擾動幅度不會超過 90%。
  • min_perturb_factor:

    • 最小擾動因子值,用來限制擾動因子的下限,程式碼中設定為 0.2,意味著即使休眠神經元比例很低,擾動幅度也不會小於 20%。

在計算因果效應的部分,擾動因子 factor 還會根據因果效應差異 causal_diff 來調整。causal_diff 是透過計算因果效應的最大值與最小值的差異來獲得的:

\[\text{causal\_diff} = \max(\text{causal\_weight}) - \min(\text{causal\_weight}) \]

計算出的 causal_diff 會影響 causal_factor,並進一步對 factor 進行調整:

\[\text{causal\_factor} = \exp(-8 \cdot \text{causal\_diff}) - 0.5 \]

組合擾動因子的公式

最後,如果選擇了因果重置(causal_reset),擾動因子將使用因果差異計算出的 causal_factor 進行二次調整:

\[\text{factor} = \text{perturb\_factor}(\text{causal\_factor}) \]

綜上所述,factor 的最終值是由梯度主導度或因果效應差異來控制的,當休眠神經元比例較大或因果效應差異較大時,factor 會減小,導致網路進行擾動。

評估程式碼

這段程式碼主要實現了在強化學習(RL)訓練過程中,定期評估智慧體(agent)的效能,並在某些條件下儲存最佳模型的檢查點。我們可以分段解釋該程式碼:

1. 定期評估條件

if i_episode % config.eval_interval == 0 and config.eval is True:

這部分程式碼用於判斷是否應該執行智慧體的評估。條件為:

  • i_episode % config.eval_interval == 0:表示每隔 config.eval_interval 個訓練回合(i_episode 是當前回合數)進行一次評估。
  • config.eval is True:確保 eval 設定為 True,也就是說,評估功能開啟。

如果滿足這兩個條件,程式碼將開始執行評估操作。

2. 初始化評估列表

eval_reward_list = []

用於儲存每個評估回合(episode)的累計獎勵,以便之後計算平均獎勵。

3. 進行評估

for _ in range(config.eval_episodes):

評估階段將執行多個回合(由 config.eval_episodes 指定的回合數),以獲得智慧體的表現。

3.1 回合初始化
state = env.reset()
episode_reward = []
done = False
  • env.reset():重置環境,獲得初始狀態 state
  • episode_reward:初始化一個列表,用於儲存當前回合中智慧體獲得的所有獎勵。
  • done = False:用 done 來跟蹤當前回合是否結束。
3.2 執行智慧體動作
while not done:
    action = agent.select_action(state, evaluate=True)
    next_state, reward, done, info = env.step(action)
    state = next_state
    episode_reward.append(reward)
  • 動作選擇agent.select_action(state, evaluate=True) 在評估模式下根據當前狀態 state 選擇動作。evaluate=True 表示該選擇是在評估模式下,通常意味著探索行為被關閉(即不進行隨機探索,而是選擇最優動作)。

  • 環境反饋next_state, reward, done, info = env.step(action) 透過執行動作 action,環境返回下一個狀態 next_state,當前獎勵 reward,回合是否結束的標誌 done,以及附加資訊 info

  • 狀態更新:當前狀態被更新為 next_state,並將獲得的獎勵 reward 儲存在 episode_reward 列表中。

迴圈持續,直到回合結束(即 done == True)。

3.3 儲存回合獎勵
eval_reward_list.append(sum(episode_reward))

當前回合結束後,累計獎勵(sum(episode_reward))被新增到 eval_reward_list,用於後續計算平均獎勵。

4. 計算平均獎勵

avg_reward = np.average(eval_reward_list)

在所有評估回合結束後,計算 eval_reward_list 的平均值 avg_reward。這是當前評估階段智慧體的表現指標。

5. 儲存最佳模型

if config.save_checkpoint:
    if avg_reward >= best_reward:
        best_reward = avg_reward
        agent.save_checkpoint(checkpoint_path, 'best')
  • 如果 config.save_checkpointTrue,則表示需要檢查是否儲存模型。
  • 透過判斷 avg_reward 是否超過了之前的最佳獎勵 best_reward,如果是,則更新 best_reward,並儲存當前模型的檢查點。
agent.save_checkpoint(checkpoint_path, 'best')

這行程式碼會將智慧體的狀態儲存到指定的路徑 checkpoint_path,並標記為 "best",表示這是效能最佳的模型。

論文復現結果

強化學習筆記之【論文精讀】【ACE:一種基於熵規整和因果關係的離線SAC演算法】

咳咳,可以發現程式只記錄了 0~1000 的資料,從 1001 開始的每一個資料都顯示報錯所以被捨棄掉了。

後面重新下載了github程式碼包,發生了同樣的報錯資訊

報錯資訊是:你在 X+1 輪次中嘗試記載 X 輪次中的資訊,所以這個資料被捨棄掉了

大概是主程式哪裡有問題吧,我自己也沒調 bug

不過這個專案結題了,主要負責這個專案的博士師兄也畢業了,也不好說些什麼(雖然我有他微信),至少論文裡面的模組挺有用的啊(手動滑稽)

強化學習筆記之【論文精讀】【ACE:一種基於熵規整和因果關係的離線SAC演算法】

相關文章