Contrastive Learning 對比學習 | RL 學 representation 時的對比學習

MoonOut發表於2024-11-26

記錄一下讀的三篇相關文章。

01. Representation Learning with Contrastive Predictive Coding

  • arxiv:https://arxiv.org/abs/1807.03748 ,2018 年的文章。
  • 參考部落格:知乎 | 理解 Contrastive Predictive Coding 和 NCE Loss
  • (發現 lilian weng 也寫過 對比學習的部落格

1.1 文章解讀

這篇文章的主要思想是,我們維護一個 discriminator,負責判斷兩個東西是否是一致的(也可認為是一個判斷相似性的函式);比如,我的 encoding 和我下一時刻的 encoding(這篇文章所做的),兩個相同類別的樣本,兩個正樣本,我的 encoding 和我資料增強後的 encoding 等等。

在這篇文章(CPC)裡,我們定義 discriminator 是 \(f_k(x_{x+k},c_t)=\exp(z_{x+k}^TW_kc_t)\),這個函式大概計算了 z 和 c 的內積。其中,\(z_{x+k}\)\(x_{x+k}\) 真實值的 encoding,而 \(c_t\) 是序列預測模型(比如說 RNN 或 LSTM)最後一步的 hidden 值,我們一般用這個值來預測。

這篇文章的 loss function 是

\[L_N = - E\left[\log\frac{f_k(x_{x+k},c_t)}{\sum _{x_j\in X} f_k(x_j,c_t)}\right] \]

這是一種 maximize [exp / Σ exp] 的形式。(照搬原部落格)怎麼理解這個 loss function 呢,\(p(x_{t+k}|c_t)\) 指的是,我們選正在用的那個聲音訊號的 \(x_{t+k}\) ,而 \(p(x_{j})\) 指的是我們可以隨便從其他的聲音訊號裡選擇一個片段。

回憶一下,我們剛才說過, \(f_k()\) 其實是在計算 \(c_t\) 的預測和 \(x_{t+k}\) (未來值)符不符合。那麼對於隨便從其他聲音訊號裡選出的 \(x_j\)\(f_k(x_j,c_t)\) 應是相對較小的。

在具體實踐時,大家常常在對一個 batch 進行訓練時,把當前 sample 的 \((x_{t+k}^i,c_t^i)\)(這裡上標表示 sample 的 id)當作 positive pair,把 batch 裡其他 samples 和當前 sample 的預測值配對 \((x_{t+k}^j,c_t^i)\) 作為 negative pair (注意上標)。

1.2 個人理解

這篇文章主要在說 InfoNCE loss。InfoNCE loss 大概就是 maximize [exp / Σ exp] 的形式,公式:

\[L_\text{InfoNCE} = - E\left[\log\frac{\exp(z^T_{x+k}Wc_t)}{\sum _{x_j\in X} \exp(z^T_{j}Wc_t)}\right] \]

這貌似是比較現代的對比學習 loss function。還有一些比較古早的 loss function 形式,比如 Contrastive loss(Chopra et al. 2005),它希望最小化同類樣本(\(y_i=y_j\))的 embedding 之間的距離,而最大化不同類樣本的 embedding 距離:

\[L(x_i,x_j) = \mathbb 1[y_i=y_j] \big\|f(x_i)-f(x_j)\big\| + \mathbb 1[y_i\neq y_j] \max\big(0,\epsilon- \|f(x_i)-f(x_j)\| \big) \]

第一項代表,如果是同類別樣本,則希望最小化它們 embedding 之間的距離;第二項代表,如果是不同類樣本,則希望最大化 embedding 距離,但不要超過 ε,ε 是超引數,表示不同類之間的距離下限。

Triplet Loss 三元組損失(FaceNet ,Schroff et al. 2015) :

\[L_\text{triplet}(x,x^+,x^-) = \sum_{x\in X} \max\big( 0, \|f(x)-f(x^+)\| - \|f(x)-f(x^+)\| + \epsilon \big) \]

其中,x 是 anchor,x+ 是正樣本,x- 是負樣本。我們希望 x 靠近 x+、遠離 x-。可以理解為,我們希望最大化 \(\|f(x)-f(x^+)\| - \|f(x)-f(x^+)\| - \epsilon\) ,即,anchor 離負樣本的距離應該大於 anchor 離正樣本的距離,距離差超過一個超引數 margin ε。

02. CURL: Contrastive Unsupervised Representations for Reinforcement Learning

  • arxiv:https://arxiv.org/pdf/2004.04136 ,ICML 2020。
  • GitHub:https://www.github.com/MishaLaskin/curl

curl 也應用了這種 maximize [exp / Σ exp] 的形式,它的 loss function 是:

\[L_q=\log\frac{\exp⁡(q^TWk_+)}{\exp⁡(q^TWk_+) + \sum_{i=0}^{K−1}\exp⁡(q^TWk_i)} \]

其中,q 是 query,貌似也可理解為 anchor,k 是 key,k+ 是正樣本,ki 是負樣本。anchor 和正樣本 貌似都是影像裁剪得到的。

Refer to caption

key encoder 的引數是 query encoder 的引數的 moving average,\(\theta_k=m\theta_k+(1-m)\theta_q\)

HIM 中,curl 是一個 baseline,HIM curl 的正樣本是 adding gaussian perturbation ∼ N (µ = 0.0, σ = 0.1) 得到的。

03. Representation Matters: Offline Pretraining for Sequential Decision Making

做了很多 RL 相關的 representation learning 的 review 和技術比較,比較了各種實現在 imitation learning、offline RL 和 offline 2 online RL 上的效果。

arxiv:https://arxiv.org/pdf/2102.05815



相關文章