遷移學習:互資訊的變分上下界

orion發表於2023-09-21

1 導引

在機器學習,尤其是涉及異構資料的遷移學習/聯邦學習中,我們常常會涉及互資訊相關的最佳化項,我上半年的第一份工作也是致力於此(ArXiv連結:FedDCSR)。其思想雖然簡單,但其具體的估計與最佳化手段而言卻大有門道,我們今天來好好總結一下,也算是對我研一下學期一個收尾。

我們知道,隨機變數\(X\)\(Y\)的互資訊定義為其聯合分佈(joint)\(p(x, y)\)和其邊緣分佈(marginal)的乘積\(p(x)p(y)\)之間的KL散度(相對熵)[1]

\[\begin{aligned} I(X ; Y) &= D_{\text{KL}}\left(p(x, y) \parallel p(x)p(y)\right) \\ &=\mathbb{E}_{p(x, y)}\left[\log \frac{p(x, y)}{p(x)p(y)}\right] \end{aligned} \tag{1} \]

直觀地理解,互資訊表示一個隨機變數包含另一個隨機變數資訊量(即統計依賴性)的度量;同時,互資訊也是在給定另一隨機變數知識的條件下,原隨機變數不確定度的縮減量,即\(I(X; Y) = H(X) - H(X \mid Y) = H(Y) - H(Y\mid X)\)。當\(X\)\(Y\)一一對應時,\(I(X; Y) = H(X) = H(Y)\);當\(X\)\(Y\)相互獨立時\(I(X; Y)=0\)

在機器學習的情境下,聯合分佈\(p(x, y)\)一般是未知的,因此我們需要用貝葉斯公式將其繼續轉換為如下形式:

\[\begin{aligned} I(X ; Y) &\overset{(1)}{=}\mathbb{E}_{p(x, y)}\left[\log \frac{p(x \mid y)}{p(x)}\right] \overset{(2)}{=}\mathbb{E}_{p(x, y)}\left[\log \frac{p(y \mid x)}{p(y)}\right] \end{aligned} \tag{2} \]

那麼轉換為這種形式之後,我們是否就可以開始對其進行估計了呢?答案是否定的。我們假設現在是深度表徵學習場景,\(X\)是資料,\(Y\)是資料的隨機表徵,則對於第\((1)\)種形式來說,條件機率分佈\(p(x|y)=\frac{p (y|x)p(x)}{\int p(y|x)p(x)dx}\)是難解(intractable)的(由於\(p(x)\)未知);而對於第\((2)\)種形式而言,邊緣分佈\(p(y)\)也需要透過積分\(p(y)=\int p(y \mid x)p(x)d x\)來進行計算,而這也是難解的(由於\(p(x)\)未知)。為了解決互資訊估計的的難解性,我們的方法是不直接對互資訊進行估計,而是採用變分近似的手段,來得出互資訊的下界/上界做為近似,轉而對互資訊的下界/上界進行最大化/最小化[2]

2 互資訊的變分下界(對應最大化)

我們先來看互資訊的變分下界。我們常常透過最大化互資訊的下界來近似地對其進行最大化。具體而言,按照是否需要解碼器,我們可以將互資訊的下界分為兩類,分別對應變分資訊瓶頸(解碼項)[3][4]Deep InfoMax[5][6]這兩種方法。

2.1 資料VS表徵:變分資訊瓶頸(解碼項)

對於互資訊的第\((1)\)種表示法即\(I(X ; Y){=}\mathbb{E}_{p(x, y)}\left[\log \frac{p(x \mid y)}{p(x)}\right]\),我們已經知道條件分佈\(p(x|y)\)是難解的,那麼我們就採用變分分佈\(q(x|y)\)將其轉變為可解(tractable)的最佳化問題。這樣就可以匯出互資訊的Barber & Agakov下界(由於KL散度的非負性):

\[\begin{aligned} I(X ; Y)= & \mathbb{E}_{p(x, y)}\left[\log \frac{q(x \mid y)}{p(x)}\right]+\mathbb{E}_{p(y)}\left[D_{\text{K L}}(p(x \mid y) \mid q(x \mid y))\right] \\ \geq & \mathbb{E}_{p(x, y)}[\log q(x \mid y)]+H(X) \triangleq I_{\mathrm{BA}}, \end{aligned} \tag{3} \]

這裡\(H(X)\)\(X\)的微分熵,BA是論文[7]兩位作者名字的縮寫。當\(q(x|y)=p(x|y)\)時,該下界是緊的,此時上式的第一項就等於條件熵\(H(X|Y)\)

上式可不可解取決於微分熵\(H(X)\)是否已知。幸運的是,限定在 \(X\)是資料,\(Y\)是表徵 的場景下,\(H(X)=\mathbb{E}_{x\sim p(x)} \log p(x)\)僅涉及資料生成過程,和模型無關。這意味著我們只需要最大化\(I_{\text{BA}}\)的第一項,而這可以理解為最小化VAE中的重構誤差(失真,distortion)。此時,\(I_{\text{BA}}\)的梯度就與“編碼器”\(p(y|x)\)和變分“解碼器”\(q(x|y)\)相關,而這是易於計算的。因此,我們就可以使用該目標函式來學習一個最大化\(I(X; Y)\)的編碼器\(p(y|x)\),這就是大名鼎鼎的變分資訊瓶頸(variational information bottleneck) 的思想(對應其中的解碼項部分)。

2.2 表徵VS表徵:Deep Infomax

我們在 2.1 中介紹的方法雖然簡單好用,但是需要構建一個易於計算的解碼器\(q(x|y)\),這在\(X\)是資料,\(Y\)是表徵的時候非常容易,然而當 \(X\)\(Y\)都是表徵 的時候就直接寄了,首先是因為解碼器\(q(x|y)\)是難以計算的,其次微分熵\(H(X)\)也是未知的。為了匯出不需要解碼器的可解下界,我們轉向去思考\(q(x|y)\)變分族的的非標準化分佈(unnormalized distributions)。

我們選擇一個基於能量的變分族,它使用一個判別函式/網路(critic)\(f(x, y): \mathcal{X} \times \mathcal{Y}\rightarrow \mathbb{R}\),並經由資料密度\(p(x)\)縮放:

\[q(x \mid y)=\frac{p(x)}{Z(y)} e^{f(x, y)}, \text { where } Z(y)=\mathbb{E}_{p(x)}\left[e^{f(x, y)}\right]\tag{4} \]

我們將該分佈代入公式\((3)\)中的\(I_{\text{BA}}\)中,就匯出了另一個互資訊的下界,我們將其稱為UBA下界(記作\(I_{\text{UBA}}\)),可視為Barber & Agakov下界的非正太分佈版本(Unnormalized version):

\[\mathbb{E}_{p(x, y)}[f(x, y)]-\mathbb{E}_{p(y)}[\log Z(y)] \triangleq I_{\mathrm{UBA}} \tag{5} \]

\(f(x, y)=\log p(y|x) + c(y)\)時,該上界是緊的,這裡\(c(y)\)僅僅是關於\(y\)的函式(而非\(x\))。注意在代入過程中難解的微分熵\(H(X)\)被消掉了,但我們仍然剩下一個難解的\(\log\)配分函式\(\log Z(y)\),它妨礙了我們計算梯度與評估。如果我們對\(\mathbb{E}_{p(y)}[\log Z(y)]\)這個整體應用Jensen不等式(\(\log\)為凹函式),我們能進一步匯出式\((5)\)的下界,即大名鼎鼎的Donsker & Varadhan下界[7]

\[I_{\mathrm{UBA}} \geq \mathbb{E}_{p(x, y)}[f(x, y)]-\log \mathbb{E}_{p(y)}[Z(y)] \triangleq I_{\mathrm{DV}} \tag{6} \]

然而,該目標函式仍然是難解的。接下來我們換個角度,我們不對\(\mathbb{E}_{p(y)}[\log Z(y)]\)這個整體應用Jensen不等式,而考慮對裡面的\(\log Z(y)\)應用Jensen不等式即\(\log Z(y)=\log \mathbb{E}_{p(x)}\left[e^{f(x, y)}\right]\geq\mathbb{E}_{p(x)}\left[\log e^{f(x, y)}\right]=\mathbb{E}_{p(x)}\left[f(x, y)\right]\),那麼我們就可以匯出式\((5)\)的上界來對其進行近似:

\[I_{\mathrm{UBA}} \leq \mathbb{E}_{p(x, y)}[f(x, y)]-\mathbb{E}_{p(x)p(y)}\left[f(x, y)\right]\triangleq I_{\mathrm{MINE}} \tag{7} \]

然而式\((5)\)本身做為互資訊的下界而存在,因此\(I_{\text{MINE}}\)嚴格意義上講既不是互資訊的上界也不是互資訊的下界。不過這種方法可視為採用期望的蒙特卡洛近似來評估\(I_{\text{DV}}\),也就是作為互資訊下界的無偏估計。已經有工作證明了這種巢狀蒙特卡洛估計器的收斂性和漸進一致性,但並沒有給出在有限樣本下的成立的界[8][9]

\(I_{\text{MINE}}\)思想的基礎之上,論文Deep Infomax[6]又向前推進了一步,認為我們無需死抱著資訊的KL散度形式不放,可以大膽採用非KL散度的形式。事實上,我們主要感興趣的是最大化互資訊,而不關心它的精確值,於是採用非KL散度形式可以為我們提供有利的trade-off。比如我們就可以基於\(p(x, y)\)\(p(x)p(y)\)Jensen-Shannon散度(JSD),來定義如下的JS互資訊估計器:

\[ I_{\text{JSD}} \triangleq \mathbb{E}_{p(x, y)}\left[-\operatorname{sp}\left(-f\left(x, y\right)\right)\right]-\mathbb{E}_{p(x^{\prime})p(y)}\left[\operatorname{sp}\left(f\left(x^{\prime}, y\right)\right)\right], \tag{8} \]

這裡\(x\)是輸入樣本,\(x\prime\)是採自\(p(x^{\prime}) = p(x)\)的負樣本,\(\text{sp}(z) = \log (1+e^x)\)\(\text{softplus}\)函式。這裡判別網路\(f\)被最佳化來能夠區分來自聯合分佈的樣本對(正樣本對)和來自邊緣乘積分佈的樣本對(負樣本對)。

此外,噪聲對比估計(NCE)[10]做為最先被採用的互資訊下界(被稱為“InfoNCE”),也可以用於互資訊最大化:

\[I(X, Y)\geq \mathbb{E}_{p(x, y)}\left[f\left(x, y\right)-\mathbb{E}_{p(x^{\prime})}\left[\log \sum_{x^{\prime}} e^{f\left(x^{\prime}, y\right)}\right]\right]\triangleq I_{\text{InfoNCE}} \tag{9} \]

對於Deep Infomax而言,\(I_{\text{JSD}}\)\(I_{\text{InfoNCE}}\)形式的之間差別在於負樣本分佈\(p(x^{\prime})\)的期望是套在正樣本分佈\(p(x, y)\)期望的裡面還是外面,而這個差別就意味著對於\(\text{DV}\)\(\text{JSD}\)而言一個正樣本只需要一個負樣本,但對於\(\text{InfoNCE}\)而言就是一個正樣本就需要\(N\)個負樣本(\(N\)為batch size)。此外,也有論文[6]分析證明了\(I_{\text{JSD}}\)對負樣本的數量不敏感,而\(I_{\text{InfoNCE}}\)的表現會隨著負樣本的減少而下降。

3 互資訊的變分上界(對應最小化)

我們接下來來看互資訊的變分上界。我們常常透過最小化互資訊的上界來近似地對互資訊進行最小化。具體而言,按照是否需要編碼器,我們可以將互資訊的下界分為兩類,而這兩個類別分別就對應了變分資訊瓶頸的編碼項[4]解耦表徵學習[11]

3.1 資料VS表徵:變分資訊瓶頸(編碼項)

對於互資訊的第\((2)\)種表示法即\(I(X ; Y){=}\mathbb{E}_{p(x, y)}\left[\log \frac{p(y \mid x)}{p(y)}\right]\),我們已經知道邊緣分佈\(p(y)=\int p(y \mid x)p(x)d x\)是難解的。但是限定在 \(X\)是資料,\(Y\)是表徵 的場景下,我們能夠透過引入一個變分近似\(q(y)\)來構建一個可解的變分下界:

\[\begin{aligned} I(X ; Y) & \equiv \mathbb{E}_{p(x, y)}\left[\log \frac{p(y \mid x)}{p(y)}\right] \\ & \overset{(1)}{=}\mathbb{E}_{p(x, y)}\left[\log \frac{p(y \mid x) q(y)}{q(y) p(y)}\right] \\ & \overset{(2)}{=}\mathbb{E}_{p(x, y)}\left[\log \frac{p(y \mid x)}{q(y)}\right]-D_{\text{K L}}(p(y) \| q(y)) \\ & \overset{(3)}{\leq} \mathbb{E}_{p(x)}\left[D_{\text{K L}}(p(y \mid x) \| q(y))\right] \triangleq R, \end{aligned} \tag{10} \]

注意上面的\((1)\)是分子分母同時乘以\(q(y)\)\((2)\)是單獨配湊出KL散度;\((3)\)是利用KL散度的非負性(證明變分上下界的常用技巧)。最後得到的這個上界我們在生成模型在常常被稱為Rate[12](也就是率失真理論裡的那個率),故這裡記為\(R\)。當\(q(y)=p(y)\)時該上界是緊的,且該上界要求\(\log q(y)\)是易於計算的。該變分上界經常在深度生成模型(如VAE)[13][14] 被用來限制隨機表徵的容量。在變分資訊瓶頸[4]這篇論文中,該上界被用於防止表徵攜帶更多與輸入有關,但卻和下游分類任務無關的資訊(即對應其中的編碼項部分)。

3.2 表徵VS表徵:解耦表徵學習

上面介紹的方法需要構建一個易於計算的編碼器\(p(y|x)\),但應用場景也僅限於在\(X\)是資料,\(Y\)是表徵的情況下,當 \(X\)\(Y\)都是表徵 的時候(即對應解耦表徵學習的場景)也會遇到我們在2.2中所面臨的問題,從而不能夠使用了。那麼我們能不能效仿2.2中的做法,對匯出的\(I_{\text{JSD}}\)\(I_{\text{InfoNCE}}\)加個負號,從而將互資訊最大化轉換為互資訊最小化呢?當然可以但是效果不會太好。因為對於兩個分佈而言,拉近它們距離的結果是確定可控的,但直接推遠它們距離的結果就是不可控的了——我們無法掌控這兩個分佈推遠之後的具體形態,導致任務的整體表現受到負面影響。那麼有沒有更好的辦法呢?

我們退一步思考:最小化互資訊\(I(X, Y)\)的難點在於\(X\)\(Y\)都是隨機表徵,那麼我們可以嘗試引入資料隨機變數\(D\),使得互資訊\(I(X, Y)\)可以進一步拆分為\(D\)\(X\)\(Y\)之間的互資訊(如\(I(D; X)\)以及\(I(D; Y)\)。已知三個隨機變數的互資訊(稱之為Interation information[1])的定義如下:

\[\begin{aligned} I(X ; Y ; D)&\overset{(1)}{=}I(X ; Y)-I(X ; Y \mid D)\\ &\overset{(2)}{=}I(X ; D)-I(X ; D \mid Y)\\ &\overset{(3)}{=}I(Y ; D)-I(Y ; D \mid X) \end{aligned} \tag{11} \]

聯立上述的等式\((1)\)和等式\((2)\),我們有:

\[ I(X ; Y) = I(X; D) - I (X ; D \mid Y) + I(X; Y \mid D) \tag{12} \]

在解耦表徵學習中,由於關於表徵後驗分佈\(q\)滿足結構化假設\(q\left(X \mid D\right)=q\left(X \mid D, Y\right)\),因此上述等式的最後一項就消失了:

\[\begin{aligned} I\left(X ; Y \mid D\right) &= H\left(X \mid D\right)-H\left(X \mid D, Y\right)\\ &=H\left(X \mid D\right)-H\left(X \mid D\right)=0 \end{aligned} \tag{13} \]

這樣我們就有:

\[\begin{aligned} I\left(X ; Y\right) &\overset{(1)}{=}I\left(D; X\right)-I\left(D ; X \mid Y\right) \\ & \overset{(2)}{=}I\left(D ; X\right)+I\left(D ; Y\right)-I\left(D ; X, Y\right) \end{aligned} \tag{14} \]

上述的\((1)\)是由於\(I(X; Y \mid D)=0\)\((2)\)是由於互資訊的鏈式法則即\(I(D; X, Y)=I(D; Y) + I(D; X \mid Y)\)

\(I(X, Y)\)等價變換至此,真相已經逐漸浮出水面:我們可以可以透過最小化\(I\left(D ; X\right)\)\(I\left(D ; Y\right)\),最大化\(I\left(D ; X, Y\right)\)來完成對\(I(X, Y)\)的最小化。其直觀的物理意義也就是懲罰表徵\(X\)\(Y\)中涵蓋的總資訊,並使得\(X\)\(Y\)共同和資料\(D\)相關聯。

基於我們在\(3.1\)\(2.1\)中所推導的\(I(D; X)\)\(I(D, Y)\)的變分上界與\(I(D; X, Y)\)的變分下界,我們就得到了\(I(X, Y)\)的變分上界:

\[\begin{aligned} I\left(X ; Y\right) &\leq \mathbb{E}_{p(D)}\left[D_{\text{K L}}(q(x \mid D) \| p(x)) + D_{\text{K L}}(q(y \mid D) \| p(y))\right] \\ &+ \mathbb{E}_{p(D)}\left[\mathbb{E}_{q(x | D)q(y|D)}[\log p(D \mid x, y)]\right]+H(D) \end{aligned} \tag{15} \]

直觀地看,上式地物理意義為使後驗\(q(x\mid D)\)\(q(y\mid D)\)都趨近於各自的先驗分佈(一般取高斯分佈),並減小\(X\)\(Y\)\(D\)的重構誤差,直覺上確實符合表徵解耦的目標。

4 總結

總結起來,互資訊的所有上下界可以表示為下圖[2](包括我們前面提到的\(I_{\text{BA}}\)\(I_{\text{UBA}}\)\(I_{\text{DV}}\)\(I_{\text{MINE}}\)\(I_{\text{InfoNCE}}\)等):

遷移學習:互資訊的變分上下界

圖中節點的代表了它們估計與最佳化的易處理性:綠色的界表示易估計也易於最佳化,黃色的界表示易於最佳化但不易於估計,紅色的界表示既不易於最佳化也不易於估計。孩子節點透過引入新的近似或假設來從父親節點匯出。

參考

  • [1] Cover T M. Elements of information theory[M]. John Wiley & Sons, 1999.
  • [2] Poole B, Ozair S, Van Den Oord A, et al. On variational bounds of mutual information[C]//International Conference on Machine Learning. PMLR, 2019: 5171-5180.
  • [3] Tishby N, Pereira F C, Bialek W. The information bottleneck method[J]. arXiv preprint physics/0004057, 2000.
  • [4] Alemi A A, Fischer I, Dillon J V, et al. Deep variational information bottleneck[J]. arXiv preprint arXiv:1612.00410, 2016.
  • [5] Belghazi M I, Baratin A, Rajeshwar S, et al. Mutual information neural estimation[C]//International conference on machine learning. PMLR, 2018: 531-540.
  • [6] Hjelm R D, Fedorov A, Lavoie-Marchildon S, et al. Learning deep representations by mutual information estimation and maximization[J]. arXiv preprint arXiv:1808.06670, 2018.
  • [7] Barber D, Agakov F. The im algorithm: a variational approach to information maximization[J]. Advances in neural information processing systems, 2004, 16(320): 201.
  • [8] Rainforth T, Cornish R, Yang H, et al. On nesting monte carlo estimators[C]//International Conference on Machine Learning. PMLR, 2018: 4267-4276.
  • [9] Mathieu E, Rainforth T, Siddharth N, et al. Disentangling disentanglement in variational autoencoders[C]//International conference on machine learning. PMLR, 2019: 4402-4412.
  • [10] Oord A, Li Y, Vinyals O. Representation learning with contrastive predictive coding[J]. arXiv preprint arXiv:1807.03748, 2018.
  • [11] Variational Interaction Information Maximization for Cross-domain Disentanglement
  • [12] Alemi A, Poole B, Fischer I, et al. Fixing a broken ELBO[C]//International conference on machine learning. PMLR, 2018: 159-168.
  • [13] Rezende D J, Mohamed S, Wierstra D. Stochastic backpropagation and approximate inference in deep generative models[C]//International conference on machine learning. PMLR, 2014: 1278-1286.
  • [14] Kingma D P, Welling M. Auto-encoding variational bayes[J]. arXiv preprint arXiv:1312.6114, 2013.

相關文章