變分自編碼器(七):球面上的VAE(vMF-VAE)

jasonzhangxianrong發表於2024-07-08

《變分自編碼器(五):VAE + BN = 更好的VAE》中,我們講到了NLP中訓練VAE時常見的KL散度消失現象,並且提到了透過BN來使得KL散度項有一個正的下界,從而保證KL散度項不會消失。事實上,早在2018年的時候,就有類似思想的工作就被提出了,它們是透過在VAE中改用新的先驗分佈和後驗分佈,來使得KL散度項有一個正的下界。

該思路出現在2018年的兩篇相近的論文中,分別是《Hyperspherical Variational Auto-Encoders》《Spherical Latent Spaces for Stable Variational Autoencoders》,它們都是用定義在超球面的von Mises–Fisher(vMF)分佈來構建先後驗分佈。某種程度上來說,該分佈比我們常用的高斯分佈還更簡單和有趣~

KL散度消失 #

我們知道,VAE的訓練目標是
\begin{equation}\mathcal{L} = \mathbb{E}_{x\sim \tilde{p}(x)} \Big[\mathbb{E}_{z\sim p(z|x)}\big[-\log q(x|z)\big]+KL\big(p(z|x)\big\Vert q(z)\big)\Big]
\end{equation}
其中第一項是重構項,第二項是KL散度項,在《變分自編碼器(一):原來是這麼一回事》中我們就說過,這兩項某種意義上是“對抗”的,KL散度項的存在,會加大解碼器利用編碼資訊的難度,如果KL散度項為0,那麼說明解碼器完全沒有利用到編碼器的資訊。

在NLP中,輸入和重構的物件是句子,為了保證效果,解碼器一般用自迴歸模型。然而,自迴歸模型是非常強大的模型,強大到哪怕沒有輸入,也能完成訓練(退化為無條件語言模型),而剛才我們說了,KL散度項會加大解碼器利用編碼資訊的難度,所以解碼器乾脆棄之不用,這就出現了KL散度消失現象。

早期比較常見的應對方案是逐漸增加KL項的權重,以引導解碼器去利用編碼資訊。現在比較流行的方案就是透過某些改動,直接讓KL散度項有一個正的下界。將先後驗分佈換為vMF分佈,就是這種方案的經典例子之一。

vMF分佈 #

vMF分佈是定義在$d-1$維超球面的分佈,其樣本空間為$S^{d-1}=\{x|x\in\mathbb{R}^d, \Vert x\Vert=1\}$,機率密度函式則為
\begin{equation}p(x) = \frac{e^{\langle\xi,x\rangle}}{Z_{d, \Vert\xi\Vert}},\quad Z_{d, \Vert\xi\Vert}=\int_{S^{d-1}}e^{\langle\xi,x\rangle} dS^{d-1}\end{equation}
其中$\xi\in\mathbb{R}^d$是預先給定的引數向量。不難想象,這是$S^{d-1}$上一個以$\xi$為中心的分佈,歸一化因子寫成$Z_{d, \Vert\xi\Vert}$的形式,意味著它只依賴於$\xi$的模長,這是由於各向同性導致的。由於這個特性,vMF分佈更常見的記法是設$\mu=\xi/\Vert\xi\Vert, \kappa=\Vert\xi\Vert, C_{d,\kappa}=1/Z_{d, \Vert\xi\Vert}$,從而
\begin{equation}p(x) = C_{d,\kappa} e^{\kappa\langle\mu,x\rangle}\end{equation}
這時候$\langle\mu,x\rangle$就是$\mu,x$的夾角餘弦,所以說,vMF分佈實際上就是以餘弦相似度為度量的一種分佈。由於我們經常用餘弦值來度量兩個向量的相似度,因此基於vMF分佈做出來的模型,通常更能滿足我們的這個需求。當$\kappa=0$的時候,vMF分佈是球面上的均勻分佈。

從歸一化因子$Z_{d, \Vert\xi\Vert}$的積分形式來看,它實際上也是vMF的母函式,從而vMF的各階矩也可以透過$Z_{d, \Vert\xi\Vert}$來表達,比如一階矩為
\begin{equation}\mathbb{E}_{x\sim p(x)} [x] = \nabla_{\xi} \log Z_{d, \Vert\xi\Vert}=\frac{d \log Z_{d,\Vert\xi\Vert}}{d\Vert\xi\Vert}\frac{\xi}{\Vert\xi\Vert}\end{equation}
可以看到$\mathbb{E}_{x\sim p(x)} [x]$在方向上跟$\xi$一致。$Z_{d, \Vert\xi\Vert}$的精確形式可以算出來,但比較複雜,而且很多時候我們也不需要精確知道這個歸一化因子,所以這裡我們就不算了。

至於引數$\kappa$的含義,或許設$\tau=1/\kappa$我們更好理解,此時$p(x)\sim e^{\langle\mu,x\rangle/\tau}$,熟悉能量模型的同學都知道,這裡的$\tau$就是溫度引數,如果$\tau$越小($\kappa$越大),那麼分佈就越集中在$\mu$附近,反之則越分散(越接近球面上的均勻分佈)。因此,$\kappa$也被形象地稱為“凝聚度(concentration)”引數。

從vMF取樣 #

對於vMF分佈來說,需要解決的第一個難題是如何實現從它裡邊取樣出具體的樣本來。尤其是如果我們要將它應用到VAE中,那麼這一步是至關重要的。

均勻分佈 #

最簡單是$\kappa=0$的情形,也就是$d-1$維球面上的均勻分佈,因為標準正態分佈本來就是各向同性的,其機率密度正比於$e^{-\Vert x\Vert^2/2}$只依賴於模長,所以我們只需要從$d$為標準正態分佈中取樣一個$z$,然後讓$x=z/\Vert z\Vert$就得到了球面上的均勻取樣結果。

特殊方向 #

接著,對於$\kappa > 0$的情形,我們記$x=[x_1,x_2,\cdots,x_d]$,首先考慮一種特殊的情況:$\mu = [1, 0, \cdots, 0]$。事實上,由於各向同性的原因,很多時候我們都只需要考慮這個特殊情況,然後就可以平行地推廣到一般情形。

此時機率密度正比於$e^{\kappa x_1}$,然後我們轉換到球座標系:
\begin{equation}
\left\{\begin{aligned}
x_1 &= \cos\varphi_1\\
x_2 &= \sin\varphi_1 \cos\varphi_2 \\
x_3 &= \sin\varphi_1 \sin\varphi_2 \cos\varphi_3 \\
&\,\,\vdots \\
x_{d-1} &= \sin\varphi_1 \cdots \sin\varphi_{d-2} \cos\varphi_{d-1}\\
x_d &= \sin\varphi_1 \cdots \sin\varphi_{d-2} \sin\varphi_{d-1}
\end{aligned}\right.
\end{equation}
那麼(超球座標的積分變換,請直接參考“維基百科”)
\begin{equation}\begin{aligned}
e^{\kappa x_1}dS^{d-1} =& e^{\kappa\cos\varphi_1}\sin^{d-2}\varphi_1 \sin^{d-3}\varphi_2 \cdots \sin\varphi_{d-2} d\varphi_1 d\varphi_2 \cdots d\varphi_{d-1} \\
=& \left(e^{\kappa\cos\varphi_1}\sin^{d-2}\varphi_1 d\varphi_1\right)\left(\sin^{d-3}\varphi_2 \cdots \sin\varphi_{d-2} d\varphi_2 \cdots d\varphi_{d-1}\right) \\
=& \left(e^{\kappa\cos\varphi_1}\sin^{d-2}\varphi_1 d\varphi_1\right)dS^{d-2} \\
\end{aligned}\end{equation}
這個分解表明,從該vMF分佈中取樣,等價於先從機率密度正比於$e^{\kappa\cos\varphi_1}\sin^{d-2}\varphi_1$的分佈取樣一個$\varphi_1$,然後從$d-2$維超球面上均勻取樣一個$d-1$維向量$\varepsilon = [\varepsilon_2,\varepsilon_3,\cdots,\varepsilon_d]$,透過如下方式組合成最終取樣結果
\begin{equation}x = [\cos\varphi_1, \varepsilon_2\sin\varphi_1, \varepsilon_3\sin\varphi_1, \cdots, \varepsilon_d\sin\varphi_1]\end{equation}
設$w=\cos\phi_1\in[-1,1]$,那麼
\begin{equation}\left|e^{\kappa\cos\varphi_1}\sin^{d-2}\varphi_1 d\varphi_1\right| = \left|e^{\kappa w} (1-w^2)^{(d-3)/2}dw\right|\end{equation}
所以我們主要研究從機率密度正比於$e^{\kappa w} (1-w^2)^{(d-3)/2}$的分佈中取樣。

然而,筆者所不理解的是,大多數涉及到vMF分佈的論文,都採用了1994年的論文《Simulation of the von mises fisher distribution》提出的基於beta分佈的拒絕取樣方案,整個取樣流程還是頗為複雜的。但現在都2021年了,對於一維分佈的取樣,居然還需要拒絕取樣這麼低效的方案?

事實上,對於任意一維分佈$p(w)$,設它的累積機率函式為$\Phi(w)$,那麼$w=\Phi^{-1}(\varepsilon),\varepsilon\sim U[0,1]$就是一個最方便通用的取樣方案。可能有讀者抗議說“累積機率函式不好算呀”、“它的逆函式更不好算呀”,但是在用程式碼實現取樣的時候,我們壓根就不需要知道$\Phi(w)$長啥樣,只要直接數值計算就行了,參考實現如下:

import numpy as np

def sample_from_pw(size, kappa, dims, epsilon=1e-7):
x = np.arange(-1 + epsilon, 1, epsilon)
y = kappa * x + np.log(1 - x**2) * (dims - 3) / 2
y = np.cumsum(np.exp(y - y.max()))
y = y / y[-1]
return np.interp(np.random.random(size), y, x)

這裡的實現中,計算量最大的是變數y的計算,而一旦計算好之後,可以快取下來,之後只需要執行最後一步來完成取樣,其速度是非常快的。這樣再怎麼看,也比從beta分佈中拒絕取樣要簡單方便吧。順便說,實現上這裡還用到了一個技巧,即先計算對數值,然後減去最大值,最後才算指數,這樣可以防止溢位,哪怕$\kappa$成千上萬,也可以成功計算。

一般情形 #

現在我們已經實現了從$\mu=[1,0,\cdots,0]$的vMF分佈中取樣了,我們可以將取樣結果分解為
\begin{equation}x = w\times\underbrace{[1,0,\cdots,0]}_{\text{引數向量}\mu} + \sqrt{1-w^2}\times\underbrace{[0,\varepsilon_2,\cdots,\varepsilon_d]}_{\begin{array}{c}\text{與}\mu\text{正交的}d-2\text{維}\\ \text{超球面均勻取樣}\end{array}}\end{equation}
同樣由於各向同性的原因,對於一般的$\mu$,取樣結果依然具有同樣的形式:
\begin{equation}\begin{aligned}
&x = w\mu + \sqrt{1-w^2}\nu\\
&w\sim e^{\kappa w} (1-w^2)^{(d-3)/2}\\
&\nu\sim \text{與}\mu\text{正交的}d-2\text{維超球面均勻分佈}
\end{aligned}\end{equation}
對於$\nu$的取樣,關鍵之處是與$\mu$正交,這也不難實現,先從標準正態分佈中取樣一個$d$維向量$z$,然後保留與$\mu$正交的分量並歸一化即可:
\begin{equation}\nu = \frac{\varepsilon - \langle \varepsilon,\mu\rangle \mu}{\Vert \varepsilon - \langle \varepsilon,\mu\rangle \mu\Vert},\quad \varepsilon\sim\mathcal{N}(0,1_d)\end{equation}

vMF-VAE #

至此,我們可謂是已經完成了本篇文章最艱難的部分,剩下的構建vMF-VAE可謂是水到渠成了。vMF-VAE選用球面上的均勻分佈($\kappa=0$)作為先驗分佈$q(z)$,並將後驗分佈選取為vMF分佈:
\begin{equation}p(z|x) = C_{d,\kappa} e^{\kappa\langle\mu(x),z\rangle}\end{equation}
簡單起見,我們將$\kappa$設為超引數(也可以理解為透過人工而不是梯度下降來更新這個引數),這樣一來,$p(z|x)$的唯一引數來源就是$\mu(x)$了。此時我們可以計算KL散度項
\begin{equation}\begin{aligned}
\int p(z|x) \log\frac{p(z|x)}{q(z)} dz =&\, \int C_{d,\kappa} e^{\kappa\langle\mu(x),z\rangle}\left(\kappa\langle\mu(x),z\rangle + \log C_{d,\kappa} - \log C_{d,0}\right)dz\\
=&\,\kappa\left\langle\mu(x),\mathbb{E}_{z\sim p(z|x)}[z]\right\rangle + \log C_{d,\kappa} - \log C_{d,0}
\end{aligned}\end{equation}
前面我們已經討論過,vMF分佈的均值方向跟$\mu(x)$一致,模長則只依賴於$d$和$\kappa$,所以代入上式後我們可以知道KL散度項只依賴於$d$和$\kappa$,當這兩個引數被選定之後,那麼它就是一個常數(根據KL散度的性質,當$\kappa\neq 0$時,它必然大於0),絕對不會出現KL散度消失現象了。

那麼現在就剩下重構項了,我們需要用“重引數(Reparameterization)”來完成取樣並保留梯度,在前面我們已經研究了vMF的取樣過程,所以也不難實現,綜合的流程為:
\begin{equation}\begin{aligned}
&\mathcal{L} = \Vert x - g(z)\Vert^2\\
&z = w\mu(x) + \sqrt{1-w^2}\nu\\
&w\sim e^{\kappa w} (1-w^2)^{(d-3)/2}\\
&\nu=\frac{\varepsilon - \langle \varepsilon,\mu\rangle \mu}{\Vert \varepsilon - \langle \varepsilon,\mu\rangle \mu\Vert}\\
&\varepsilon\sim\mathcal{N}(0,1_d)
\end{aligned}\end{equation}
這裡的重構loss以MSE為例,如果是句子重構,那麼換用交叉熵就好。其中$\mu(x)$就是編碼器,而$g(z)$就是解碼器,由於KL散度項為常數,對最佳化沒影響,所以vMF-VAE相比於普通的自編碼器,只是多了一項稍微有點複雜的重引數操作(以及人工調整$\kappa$)而已,相比基於高斯分佈的標準VAE可謂簡化了不少了。

此外,從該流程我們也可以看出,除了“簡單起見”之外,不將$\kappa$設為可訓練還有一個主要原因,那就是$\kappa$關係到$w$的取樣,而在$w$的取樣過程中要保留$\kappa$的梯度是比較困難的。

參考實現 #

vMF-VAE的實現難度主要是重引數部分,也就還是從vMF分佈中取樣,而關鍵之處就是$w$的取樣。前面我們已經給出了$w$的取樣的numpy實現,但是在tf中未見類似np.interp的函式,因此不容易轉換為純tf的實現。當然,如果是torch或者tf2這種動態圖框架,直接跟numpy的程式碼混合使用也無妨,但這裡還是想構造一種比較通用的方案。

其實也不難,由於$w$只是一個一維變數,每步訓練只需要用到batch_size個取樣結果,所以我們完全可以事先用numpy函式取樣好足夠多(幾十萬)個$w$存好,然後訓練的時候直接從這批取樣好的結果隨機抽就行了,參考實現如下:

def sampling(mu):
    """vMF分佈重引數操作
    """
    dims = K.int_shape(mu)[-1]
    # 預先計算一批w
    epsilon = 1e-7
    x = np.arange(-1 + epsilon, 1, epsilon)
    y = kappa * x + np.log(1 - x**2) * (dims - 3) / 2
    y = np.cumsum(np.exp(y - y.max()))
    y = y / y[-1]
    W = K.constant(np.interp(np.random.random(10**6), y, x))
    # 實時取樣w
    idxs = K.random_uniform(K.shape(mu[:, :1]), 0, 10**6, dtype='int32')
    w = K.gather(W, idxs)
    # 實時取樣z
    eps = K.random_normal(K.shape(mu))
    nu = eps - K.sum(eps * mu, axis=1, keepdims=True) * mu
    nu = K.l2_normalize(nu, axis=-1)
    return w * mu + (1 - w**2)**0.5 * nu

一個基於MNIST的完整例子可見:

https://github.com/bojone/vae/blob/master/vae_vmf_keras.py

至於vMF-VAE用於NLP的例子,我們日後有機會再分享。本文主要還是以理論介紹和簡單演示為主~

文章小結 #

本文介紹了基於vMF分佈的VAE實現,其主要難度在於vMF分佈的取樣。總的來說,vMF分佈建立在餘弦相似度度量之上,在某些方面的性質更符合我們的直觀認知,將其用於VAE中,能夠使得KL散度項為一個常數,從而防止了KL散度消失現象,並且簡化了VAE結構。

轉載到請包括本文地址:https://spaces.ac.cn/archives/8404

更詳細的轉載事宜請參考:《科學空間FAQ》

相關文章