1 導引
在機器學習,尤其是涉及異構資料的遷移學習/聯邦學習中,我們常常會涉及互資訊相關的最佳化項,我上半年的第一份工作也是致力於此(ArXiv連結:FedDCSR)。其思想雖然簡單,但其具體的估計與最佳化手段而言卻大有門道,我們今天來好好總結一下,也算是對我研一下學期一個收尾。
我們知道,隨機變數\(X\)和\(Y\)的互資訊定義為其聯合分佈(joint)\(p(x, y)\)和其邊緣分佈(marginal)的乘積\(p(x)p(y)\)之間的KL散度(相對熵)[1]:
直觀地理解,互資訊表示一個隨機變數包含另一個隨機變數資訊量(即統計依賴性)的度量;同時,互資訊也是在給定另一隨機變數知識的條件下,原隨機變數不確定度的縮減量,即\(I(X; Y) = H(X) - H(X \mid Y) = H(Y) - H(Y\mid X)\)。當\(X\)和\(Y\)一一對應時,\(I(X; Y) = H(X) = H(Y)\);當\(X\)和\(Y\)相互獨立時\(I(X; Y)=0\)。
在機器學習的情境下,聯合分佈\(p(x, y)\)一般是未知的,因此我們需要用貝葉斯公式將其繼續轉換為如下形式:
那麼轉換為這種形式之後,我們是否就可以開始對其進行估計了呢?答案是否定的。我們假設現在是深度表徵學習場景,\(X\)是資料,\(Y\)是資料的隨機表徵,則對於第\((1)\)種形式來說,條件機率分佈\(p(x|y)=\frac{p (y|x)p(x)}{\int p(y|x)p(x)dx}\)是難解(intractable)的(由於\(p(x)\)未知);而對於第\((2)\)種形式而言,邊緣分佈\(p(y)\)也需要透過積分\(p(y)=\int p(y \mid x)p(x)d x\)來進行計算,而這也是難解的(由於\(p(x)\)未知)。為了解決互資訊估計的的難解性,我們的方法是不直接對互資訊進行估計,而是採用變分近似的手段,來得出互資訊的下界/上界做為近似,轉而對互資訊的下界/上界進行最大化/最小化[2]。
2 互資訊的變分下界(對應最大化)
我們先來看互資訊的變分下界。我們常常透過最大化互資訊的下界來近似地對其進行最大化。具體而言,按照是否需要解碼器,我們可以將互資訊的下界分為兩類,分別對應變分資訊瓶頸(解碼項)[3][4]和Deep InfoMax[5][6]這兩種方法。
2.1 資料VS表徵:變分資訊瓶頸(解碼項)
對於互資訊的第\((1)\)種表示法即\(I(X ; Y){=}\mathbb{E}_{p(x, y)}\left[\log \frac{p(x \mid y)}{p(x)}\right]\),我們已經知道條件分佈\(p(x|y)\)是難解的,那麼我們就採用變分分佈\(q(x|y)\)將其轉變為可解(tractable)的最佳化問題。這樣就可以匯出互資訊的Barber & Agakov下界(由於KL散度的非負性):
這裡\(H(X)\)是\(X\)的微分熵,BA是論文[7]兩位作者名字的縮寫。當\(q(x|y)=p(x|y)\)時,該下界是緊的,此時上式的第一項就等於條件熵\(H(X|Y)\)。
上式可不可解取決於微分熵\(H(X)\)是否已知。幸運的是,限定在 \(X\)是資料,\(Y\)是表徵 的場景下,\(H(X)=\mathbb{E}_{x\sim p(x)} \log p(x)\)僅涉及資料生成過程,和模型無關。這意味著我們只需要最大化\(I_{\text{BA}}\)的第一項,而這可以理解為最小化VAE中的重構誤差(失真,distortion)。此時,\(I_{\text{BA}}\)的梯度就與“編碼器”\(p(y|x)\)和變分“解碼器”\(q(x|y)\)相關,而這是易於計算的。因此,我們就可以使用該目標函式來學習一個最大化\(I(X; Y)\)的編碼器\(p(y|x)\),這就是大名鼎鼎的變分資訊瓶頸(variational information bottleneck) 的思想(對應其中的解碼項部分)。
2.2 表徵VS表徵:Deep Infomax
我們在 2.1 中介紹的方法雖然簡單好用,但是需要構建一個易於計算的解碼器\(q(x|y)\),這在\(X\)是資料,\(Y\)是表徵的時候非常容易,然而當 \(X\)和\(Y\)都是表徵 的時候就直接寄了,首先是因為解碼器\(q(x|y)\)是難以計算的,其次微分熵\(H(X)\)也是未知的。為了匯出不需要解碼器的可解下界,我們轉向去思考\(q(x|y)\)變分族的的非標準化分佈(unnormalized distributions)。
我們選擇一個基於能量的變分族,它使用一個判別函式/網路(critic)\(f(x, y): \mathcal{X} \times \mathcal{Y}\rightarrow \mathbb{R}\),並經由資料密度\(p(x)\)縮放:
我們將該分佈代入公式\((3)\)中的\(I_{\text{BA}}\)中,就匯出了另一個互資訊的下界,我們將其稱為UBA下界(記作\(I_{\text{UBA}}\)),可視為Barber & Agakov下界的非正太分佈版本(Unnormalized version):
當\(f(x, y)=\log p(y|x) + c(y)\)時,該上界是緊的,這裡\(c(y)\)僅僅是關於\(y\)的函式(而非\(x\))。注意在代入過程中難解的微分熵\(H(X)\)被消掉了,但我們仍然剩下一個難解的\(\log\)配分函式\(\log Z(y)\),它妨礙了我們計算梯度與評估。如果我們對\(\mathbb{E}_{p(y)}[\log Z(y)]\)這個整體應用Jensen不等式(\(\log\)為凹函式),我們能進一步匯出式\((5)\)的下界,即大名鼎鼎的Donsker & Varadhan下界[7]:
然而,該目標函式仍然是難解的。接下來我們換個角度,我們不對\(\mathbb{E}_{p(y)}[\log Z(y)]\)這個整體應用Jensen不等式,而考慮對裡面的\(\log Z(y)\)應用Jensen不等式即\(\log Z(y)=\log \mathbb{E}_{p(x)}\left[e^{f(x, y)}\right]\geq\mathbb{E}_{p(x)}\left[\log e^{f(x, y)}\right]=\mathbb{E}_{p(x)}\left[f(x, y)\right]\),那麼我們就可以匯出式\((5)\)的上界來對其進行近似:
然而式\((5)\)本身做為互資訊的下界而存在,因此\(I_{\text{MINE}}\)嚴格意義上講既不是互資訊的上界也不是互資訊的下界。不過這種方法可視為採用期望的蒙特卡洛近似來評估\(I_{\text{DV}}\),也就是作為互資訊下界的無偏估計。已經有工作證明了這種巢狀蒙特卡洛估計器的收斂性和漸進一致性,但並沒有給出在有限樣本下的成立的界[8][9]。
在\(I_{\text{MINE}}\)思想的基礎之上,論文Deep Infomax[6]又向前推進了一步,認為我們無需死抱著資訊的KL散度形式不放,可以大膽採用非KL散度的形式。事實上,我們主要感興趣的是最大化互資訊,而不關心它的精確值,於是採用非KL散度形式可以為我們提供有利的trade-off。比如我們就可以基於\(p(x, y)\)與\(p(x)p(y)\)的Jensen-Shannon散度(JSD),來定義如下的JS互資訊估計器:
這裡\(x\)是輸入樣本,\(x\prime\)是採自\(p(x^{\prime}) = p(x)\)的負樣本,\(\text{sp}(z) = \log (1+e^x)\)是\(\text{softplus}\)函式。這裡判別網路\(f\)被最佳化來能夠區分來自聯合分佈的樣本對(正樣本對)和來自邊緣乘積分佈的樣本對(負樣本對)。
此外,噪聲對比估計(NCE)[10]做為最先被採用的互資訊下界(被稱為“InfoNCE”),也可以用於互資訊最大化:
對於Deep Infomax而言,\(I_{\text{JSD}}\)和\(I_{\text{InfoNCE}}\)形式的之間差別在於負樣本分佈\(p(x^{\prime})\)的期望是套在正樣本分佈\(p(x, y)\)期望的裡面還是外面,而這個差別就意味著對於\(\text{DV}\)和\(\text{JSD}\)而言一個正樣本只需要一個負樣本,但對於\(\text{InfoNCE}\)而言就是一個正樣本就需要\(N\)個負樣本(\(N\)為batch size)。此外,也有論文[6]分析證明了\(I_{\text{JSD}}\)對負樣本的數量不敏感,而\(I_{\text{InfoNCE}}\)的表現會隨著負樣本的減少而下降。
3 互資訊的變分上界(對應最小化)
我們接下來來看互資訊的變分上界。我們常常透過最小化互資訊的上界來近似地對互資訊進行最小化。具體而言,按照是否需要編碼器,我們可以將互資訊的下界分為兩類,而這兩個類別分別就對應了變分資訊瓶頸的編碼項[4]和解耦表徵學習[11]。
3.1 資料VS表徵:變分資訊瓶頸(編碼項)
對於互資訊的第\((2)\)種表示法即\(I(X ; Y){=}\mathbb{E}_{p(x, y)}\left[\log \frac{p(y \mid x)}{p(y)}\right]\),我們已經知道邊緣分佈\(p(y)=\int p(y \mid x)p(x)d x\)是難解的。但是限定在 \(X\)是資料,\(Y\)是表徵 的場景下,我們能夠透過引入一個變分近似\(q(y)\)來構建一個可解的變分下界:
注意上面的\((1)\)是分子分母同時乘以\(q(y)\);\((2)\)是單獨配湊出KL散度;\((3)\)是利用KL散度的非負性(證明變分上下界的常用技巧)。最後得到的這個上界我們在生成模型在常常被稱為Rate[12](也就是率失真理論裡的那個率),故這裡記為\(R\)。當\(q(y)=p(y)\)時該上界是緊的,且該上界要求\(\log q(y)\)是易於計算的。該變分上界經常在深度生成模型(如VAE)[13][14] 被用來限制隨機表徵的容量。在變分資訊瓶頸[4]這篇論文中,該上界被用於防止表徵攜帶更多與輸入有關,但卻和下游分類任務無關的資訊(即對應其中的編碼項部分)。
3.2 表徵VS表徵:解耦表徵學習
上面介紹的方法需要構建一個易於計算的編碼器\(p(y|x)\),但應用場景也僅限於在\(X\)是資料,\(Y\)是表徵的情況下,當 \(X\)和\(Y\)都是表徵 的時候(即對應解耦表徵學習的場景)也會遇到我們在2.2中所面臨的問題,從而不能夠使用了。那麼我們能不能效仿2.2中的做法,對匯出的\(I_{\text{JSD}}\)和\(I_{\text{InfoNCE}}\)加個負號,從而將互資訊最大化轉換為互資訊最小化呢?當然可以但是效果不會太好。因為對於兩個分佈而言,拉近它們距離的結果是確定可控的,但直接推遠它們距離的結果就是不可控的了——我們無法掌控這兩個分佈推遠之後的具體形態,導致任務的整體表現受到負面影響。那麼有沒有更好的辦法呢?
我們退一步思考:最小化互資訊\(I(X, Y)\)的難點在於\(X\)和\(Y\)都是隨機表徵,那麼我們可以嘗試引入資料隨機變數\(D\),使得互資訊\(I(X, Y)\)可以進一步拆分為\(D\)和\(X\)、\(Y\)之間的互資訊(如\(I(D; X)\)以及\(I(D; Y)\)。已知三個隨機變數的互資訊(稱之為Interation information[1])的定義如下:
聯立上述的等式\((1)\)和等式\((2)\),我們有:
在解耦表徵學習中,由於關於表徵後驗分佈\(q\)滿足結構化假設\(q\left(X \mid D\right)=q\left(X \mid D, Y\right)\),因此上述等式的最後一項就消失了:
這樣我們就有:
上述的\((1)\)是由於\(I(X; Y \mid D)=0\),\((2)\)是由於互資訊的鏈式法則即\(I(D; X, Y)=I(D; Y) + I(D; X \mid Y)\)。
對\(I(X, Y)\)等價變換至此,真相已經逐漸浮出水面:我們可以可以透過最小化\(I\left(D ; X\right)\)、\(I\left(D ; Y\right)\),最大化\(I\left(D ; X, Y\right)\)來完成對\(I(X, Y)\)的最小化。其直觀的物理意義也就是懲罰表徵\(X\)和\(Y\)中涵蓋的總資訊,並使得\(X\)和\(Y\)共同和資料\(D\)相關聯。
基於我們在\(3.1\)、\(2.1\)中所推導的\(I(D; X)\)、\(I(D, Y)\)的變分上界與\(I(D; X, Y)\)的變分下界,我們就得到了\(I(X, Y)\)的變分上界:
直觀地看,上式地物理意義為使後驗\(q(x\mid D)\)、\(q(y\mid D)\)都趨近於各自的先驗分佈(一般取高斯分佈),並減小\(X\)和\(Y\)對\(D\)的重構誤差,直覺上確實符合表徵解耦的目標。
4 總結
總結起來,互資訊的所有上下界可以表示為下圖[2](包括我們前面提到的\(I_{\text{BA}}\)、\(I_{\text{UBA}}\)、\(I_{\text{DV}}\)、\(I_{\text{MINE}}\)、\(I_{\text{InfoNCE}}\)等):
圖中節點的代表了它們估計與最佳化的易處理性:綠色的界表示易估計也易於最佳化,黃色的界表示易於最佳化但不易於估計,紅色的界表示既不易於最佳化也不易於估計。孩子節點透過引入新的近似或假設來從父親節點匯出。
參考
- [1] Cover T M. Elements of information theory[M]. John Wiley & Sons, 1999.
- [2] Poole B, Ozair S, Van Den Oord A, et al. On variational bounds of mutual information[C]//International Conference on Machine Learning. PMLR, 2019: 5171-5180.
- [3] Tishby N, Pereira F C, Bialek W. The information bottleneck method[J]. arXiv preprint physics/0004057, 2000.
- [4] Alemi A A, Fischer I, Dillon J V, et al. Deep variational information bottleneck[J]. arXiv preprint arXiv:1612.00410, 2016.
- [5] Belghazi M I, Baratin A, Rajeshwar S, et al. Mutual information neural estimation[C]//International conference on machine learning. PMLR, 2018: 531-540.
- [6] Hjelm R D, Fedorov A, Lavoie-Marchildon S, et al. Learning deep representations by mutual information estimation and maximization[J]. arXiv preprint arXiv:1808.06670, 2018.
- [7] Barber D, Agakov F. The im algorithm: a variational approach to information maximization[J]. Advances in neural information processing systems, 2004, 16(320): 201.
- [8] Rainforth T, Cornish R, Yang H, et al. On nesting monte carlo estimators[C]//International Conference on Machine Learning. PMLR, 2018: 4267-4276.
- [9] Mathieu E, Rainforth T, Siddharth N, et al. Disentangling disentanglement in variational autoencoders[C]//International conference on machine learning. PMLR, 2019: 4402-4412.
- [10] Oord A, Li Y, Vinyals O. Representation learning with contrastive predictive coding[J]. arXiv preprint arXiv:1807.03748, 2018.
- [11] Variational Interaction Information Maximization for Cross-domain Disentanglement
- [12] Alemi A, Poole B, Fischer I, et al. Fixing a broken ELBO[C]//International conference on machine learning. PMLR, 2018: 159-168.
- [13] Rezende D J, Mohamed S, Wierstra D. Stochastic backpropagation and approximate inference in deep generative models[C]//International conference on machine learning. PMLR, 2014: 1278-1286.
- [14] Kingma D P, Welling M. Auto-encoding variational bayes[J]. arXiv preprint arXiv:1312.6114, 2013.