(四)詳解RLHF

jasonzhangxianrong發表於2024-06-27

一直都特別好奇大模型的強化學習微調是怎麼做的,網上雖然相關文章不少,但找到的文章都是淺嘗輒止說到用PPO訓練,再細緻深入的就沒有講了。。。只能自己看一看程式碼,以前搞過一點用PPO做遊戲,感覺和語言模型PPO的用法不太一樣。在遊戲場景,每個step給環境一個action之後,agent拿到的state都是會變化的,通常也會設計獎勵函式使得每個step都會有reward;但是在用強化學習微調語言模型這裡,prompt是state,只輸入一次,然後輸出一串action(回答的單詞),得到一個reward,模型並沒有在每個action之後得到新的state(感謝評論區大佬的點撥,對於answer的第二個詞,可以把prompt+answer的一個詞當作新的state,而不只是把prompt當作state,狀態轉移蘊含在transformer內部)

本篇文章並不會介紹太多PPO的原理,相關文章已經很多了,比如李宏毅介紹PPO的課程。大模型裡邊的PPO涉及到了critic model的概念,在李宏毅教程裡只提了一下並沒有細講,如果想了解可以看一下這個文章,相當於利用一個critic model預測從t時刻到最後一個時刻的累加獎勵值(強化學習裡邊的第t個時刻對標answer句子裡邊的第t個單詞),而不是透過實際累加得到從t時刻到最後一個時刻的累加獎勵值,這樣可以降低獎勵的方差。下文也結合程式碼介紹critic model輸出的具體含義。同時RLHF是什麼也會再詳細介紹,相關文章已經很多了。

本篇文章涉及的程式碼均來自微軟的deepspeed對RLHF的實現,可配合huggingface官方的部落格一起食用。本文只對演算法的一些有特點的關鍵點進行闡述,並不對整體實現進行介紹。先上一張經典的論文圖。本文重點結合程式碼講解獎勵模型訓練和強化學習訓練部分。

(四)詳解RLHF

獎勵(reward)模型訓練

首先要宣告的是,在強化學習階段,用到的reward model和critic model都使用同一個模型初始化,因此在訓練reward模型的過程中,也是在訓練critic model。其次對符號進行說明,大模型中間隱藏層的引數維度為(B,L,D),B為batch size大小,L為句子長度,D為embedding維度。在接下來的程式碼講解中,我也會標明程式碼中各個變數的維度,以更好的理解其意義。

在進行RLHF時,需要一個獎勵模型來評估語言大模型(actor model)回答的是好是壞,這個獎勵模型通常比被評估的語言大模型小一些(deepspeed的示例中,語言大模型66B,獎勵模型只有350M)。獎勵模型的輸入是prompt+answer的形式,讓模型學會對prompt+answer進行打分。獎勵模型最後一層隱藏層的輸出維度為(B,L,D),透過一個D✖️1的全連線層將維度變為(B, L),在L這個維度上,第i個位置的資料表示:從第i個位置到最後一個位置輸出所能獲得的獎勵分值的累加和(和DQN裡邊的Q值一個意義),這種形式的輸出滿足了critic model的輸出要求。對應程式碼如下:

#huggingface模型返回值是個list,第0位是模型最後輸出的hideen state
hidden_states = transformer_outputs[0]
# v_head為Dx1的全連線網路對最後一維壓縮
rewards = self.v_head(hidden_states).squeeze(-1)

對於一個獎勵模型來說,目標是給一個句子進行打分,按理說每個句子對應一個分值就行了,但是目前對於長度為L的句子,獎勵模型輸出了L個值。我們用L維度上的最後一個位置的值當作為本句話的獎勵得分。獎勵模型訓練最佳化採用pair wiss loss,即同時輸入模型關於同一個問題的兩個回答,讓模型學會這兩個句子哪個分高哪個分低。之所以如此訓練是因為,在給獎勵模型進行資料標註的過程中,給同一個問題的不同回答量化的打具體分值比較難,但是對他們進行排序相對簡單,程式碼如下:

# 同一個batch裡邊的句子需要等長,短句後邊會被padding
# [divergence_ind:end_ind]索引了padding前一個位置的輸出分值
# chosen_reward是同一個句子pair裡分數高的句子,r_truncated_reward是句子pair裡分數低的句子
c_truncated_reward = chosen_reward[divergence_ind:end_ind]
r_truncated_reward = rejected_reward[divergence_ind:end_ind]

pair wise loss程式碼如下,如果給pair裡邊好的句子打分高(c_truncated_reward),壞的句子(r_truncated_reward)打分低,loss就會小:

loss += -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean()

在訓練強化學習的過程中,會用到reward model(critic model,再次提醒,critic model和reward model是同一個模型的兩個副本)的推理過程,透過呼叫forward_value實現,具體程式碼如下,返回的值中有兩種值,values表示每個位置i,從第i個位置到最後一個位置的獎勵累加值,供強化學習過程中critic model使用;“chosen_end_scores”指的是對每個prompt+answer的打分,供reward model使用。

def forward_value(...):
    ...
    if return_value_only:
        #(B,L)
        return values
    else:
        ...
        return {
            "values": values,
            # (B,)
            "chosen_end_scores": torch.stack(chosen_end_scores),
        }

強化學習微調

強化學習微調階段,會用到4個模型,actor model, ref_model,reward model和critic model(好費視訊記憶體啊!!!)。其中actor model和ref_model是RLHF第一個階段有監督微調模型的兩個副本,reward model和critic model是本文第一部分訓練出來的模型的兩個副本。整體流程見這篇文件,整體流程圖如下所示(沒畫出critic model):

(四)詳解RLHF

首先說明actor model的訓練模式和推理模式的區別( 後邊會用到)。訓練模式是用teacher force的方式(不明白的同學知乎搜一下),將整句話輸入到模型中,並透過mask機制在保證不洩漏未來的單詞情況下預測下一個單詞。推理模式是真正的自迴歸,預測出下一個單詞之後,當作下一步輸入再預測下下個單詞,原理如下圖所示:

(四)詳解RLHF

首先用actor model在推理模式下根據prompt生成一個answer(prompt對應強化學習裡邊的state,answer對應一些列的action),程式碼如下:

# 保證不觸發反向傳播
with torch.no_grad():
    seq = self.actor_model.module.generate(prompts,
    max_length=max_min_length,
    min_length=max_min_length)

然後利用reward model和ciric model對輸出的prompt+answer進行打分(PPO訓練時使用的獎勵值並不單單是reward model的輸出還要考慮kl散度,後文介紹):

# 獎勵模型返回的是個字典,key為chosen_end_scores位置儲存資料維度為(B,),表示對於prompt+answer的打分
reward_score = self.reward_model.forward_value(
                seq, attention_mask,
                prompt_length=self.prompt_length)['chosen_end_scores'].detach(
                )
#critic model返回的資料維度為(B,L),L維度上第i個位置代表從i位置到最後的累積獎勵
#捨去最後一個位置是因為句子“終止符”無意義 
values = self.critic_model.forward_value(
                seq, attention_mask, return_value_only=True).detach()[:, :-1]

actor model是我們想透過強化學習微調的大模型,但是強化學習過程很容易把模型訓練“壞”,因此需要另外一個不會引數更新的 ref_model來當作標的,別讓actor mode跑偏太遠。我們在訓練模式下,將prompt+answer分別輸入到actor mode和ref model,用KL散度來衡量 ref model和actor mode輸出的差別。同時將KL散度(衡量資料分佈差距大小)納入損失函式(KL散度本質是納入到獎勵值裡邊的,獎勵值被納入到了損失函式),進而來約束 ref_model和actor mode的輸出分佈別差距太大。具體程式碼如下:

# 得到兩個模型的輸出
output = self.actor_model(seq, attention_mask=attention_mask)
output_ref = self.ref_model(seq, attention_mask=attention_mask)
logits = output.logits
logits_ref = output_ref.logits
...
return {
...
# 分別得到兩個模型在真實單詞上的預測機率
'logprobs': gather_log_probs(logits[:, :-1, :], seq[:, 1:]),
'ref_logprobs': gather_log_probs(logits_ref[:, :-1, :], seq[:,1:]),
...
}
...
# 計算kl散度,log_probs裡邊存的數字經過log變化了,因此減法就對應除法
kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs)

PPO訓練時候的獎勵值綜合考慮KL散度和reward模型的輸出,只考慮answer部分的KL散度,將reward model的輸出加到KL散度L維度的最後一個位置上,得到最終的獎勵值,程式碼如下:

rewards = kl_divergence_estimate
# 只考慮answer部分的獎勵,不考慮prompt
start = prompts.shape[1] - 1
# 不考慮padding部分
ends = start + action_mask[:, start:].sum(1)
reward_clip = torch.clamp(reward_score, -self.clip_reward_value,
                         self.clip_reward_value)
batch_size = log_probs.shape[0]
# 在L維度上,每個位置都有KL散度,但是隻在最後一個位置加上獎勵值
for j in range(batch_size):
    rewards[j, start:ends[j]][-1] += reward_clip[j]

接下來的內容就是PPO的訓練過程的比較核心的內容了,目標是計算PPO更新公示裡邊的advantage,具體公式如下,V就是critic model的輸出。如果原理不懂建議先到這個連結看看。我直接在程式碼中給註釋了。

(四)詳解RLHF
圖片出處:https://huggingface.co/blog/deep-rl-a2c
def get_advantages_and_returns(self, values, rewards, start):
    # values(B,L) critic model輸出
    # rewards(B,L)reward 包含kl散度
    # start answer開始的位置
    # Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134
    lastgaelam = 0
    advantages_reversed = []
    length = rewards.size()[-1]
    # 計算每個時刻(序列位置)的critic model預測誤差
    for t in reversed(range(start, length)):
        nextvalues = values[:, t + 1] if t < length - 1 else 0.0
        # critic model預測的是t到到最後一個時刻的獎勵和,所以變化量delta可以用如下公式表示
        delta = (rewards[:, t] + self.gamma * nextvalues) - values[:, t]
        # self.gamma=1,self.lam=0.95是衰減因子,表示之前計算的delta對現在影響越來越小
        lastgaelam = delta + self.gamma * self.lam * lastgaelam
        advantages_reversed.append(lastgaelam)
    advantages = torch.stack(advantages_reversed[::-1], dim=1)
    # 後續用來更新critic model用
    returns = advantages + values[:, start:]
    return advantages.detach(), returns

以上過程,我們已經拿到了PPO訓練所需要的advantage以及actor model的輸出,我先現在可以對actor model進行訓練啦。具體程式碼如下。logprobs和old_logprobs這兩個引數分別是“老actor(n個epoch才會更新一次)”和新actor(每個batch都會更新它)”在正確單詞上出處的機率,這塊時PPO import sampling相關的知識,就不在這重複介紹了,不明白的同學補習一下哈。借用一下李宏毅老師的PPO公式:

(四)詳解RLHF
def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):
    ## policy gradient loss
    #logprobs, old_logprobs都是經過log變化的單詞機率,這裡帶著log做減法就相當於在做機率除法
    log_ratio = (logprobs - old_logprobs) * mask
    # 指數操作去掉log
    ratio = torch.exp(log_ratio)
    pg_loss1 = -advantages * ratio
    pg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,
                                            1.0 + self.cliprange)
    pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()
    return pg_loss

同樣的,我們也要對critic model進行訓練,更新,loss就是mse loss。

def critic_loss_fn(self, values, old_values, returns, mask):
    ## value loss
    # 用“老critic model”的輸出約束“新critic model”不要步子太大,裁剪一下
    values_clipped = torch.clamp(
        values,
        old_values - self.cliprange_value,
        old_values + self.cliprange_value,
    )
    vf_loss1 = (values - returns)**2
    vf_loss2 = (values_clipped - returns)**2
    vf_loss = 0.5 * torch.sum(
        torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()
    return vf_loss

至此,我們的RLHF訓練流程就結束了。第二部分開頭我們說過,共涉及actor model, ref_model,reward model和critic model這四個模型,其實更新引數的模型只有actor model和critic model。

轉自:詳解大模型RLHF過程(配程式碼解讀) - 知乎 (zhihu.com)

相關文章