變分自編碼器(五):VAE + BN = 更好的VAE

jasonzhangxianrong發表於2024-07-08

本文我們繼續之前的變分自編碼器系列,分析一下如何防止NLP中的VAE模型出現“KL散度消失(KL Vanishing)”現象。本文受到參考文獻是ACL 2020的論文《A Batch Normalized Inference Network Keeps the KL Vanishing Away》的啟發,並自行做了進一步的完善。

值得一提的是,本文最後得到的方案還是頗為簡潔的——只需往編碼輸出加入BN(Batch Normalization),然後加個簡單的scale——但確實很有效,因此值得正在研究相關問題的讀者一試。同時,相關結論也適用於一般的VAE模型(包括CV的),如果按照筆者的看法,它甚至可以作為VAE模型的“標配”。

最後,要提醒讀者這算是一篇VAE的進階論文,所以請讀者對VAE有一定了解後再來閱讀本文。

VAE簡單回顧 #

這裡我們簡單回顧一下VAE模型,並且討論一下VAE在NLP中所遇到的困難。關於VAE的更詳細介紹,請讀者參考筆者的舊作《變分自編碼器(一):原來是這麼一回事》《變分自編碼器(二):從貝葉斯觀點出發》等。

VAE的訓練流程 #

VAE的訓練流程大概可以圖示為

VAE訓練流程圖示

VAE訓練流程圖示

寫成公式就是
$$\begin{equation}\mathcal{L} = \mathbb{E}_{x\sim \tilde{p}(x)} \Big[\mathbb{E}_{z\sim p(z|x)}\big[-\log q(x|z)\big]+KL\big(p(z|x)\big\Vert q(z)\big)\Big]
\end{equation}$$
其中第一項就是重構項,$\mathbb{E}_{z\sim p(z|x)}$是透過重引數來實現;第二項則稱為KL散度項,這是它跟普通自編碼器的顯式差別,如果沒有這一項,那麼基本上退化為常規的AE。更詳細的符號含義可以參考《變分自編碼器(二):從貝葉斯觀點出發》

NLP中的VAE #

在NLP中,句子被編碼為離散的整數ID,所以$q(x|z)$是一個離散型分佈,可以用萬能的“條件語言模型”來實現,因此理論上$q(x|z)$可以精確地擬合生成分佈,問題就出在$q(x|z)$太強了,訓練時重引數操作會來噪聲,噪聲一大,$z$的利用就變得困難起來,所以它乾脆不要$z$了,退化為無條件語言模型(依然很強),$KL(p(z|x)\Vert q(z))$則隨之下降到0,這就出現了KL散度消失現象

這種情況下的VAE模型並沒有什麼價值:KL散度為0說明編碼器輸出的是常數向量,而解碼器則是一個普通的語言模型。而我們使用VAE通常來說是看中了它無監督構建編碼向量的能力,所以要應用VAE的話還是得解決KL散度消失問題。事實上從2016開始,有不少工作在做這個問題,相應地也提出了很多方案,比如退火策略、更換先驗分佈等,讀者Google一下“KL Vanishing”就可以找到很多文獻了,這裡不一一溯源。

BN的巧與妙 #

本文的方案則是直接針對KL散度項入手,簡單有效而且沒什麼超引數。其思想很簡單:

KL散度消失不就是KL散度項變成0嗎?我調整一下編碼器輸出,讓KL散度有一個大於零的下界,這樣它不就肯定不會消失了嗎?

這個簡單的思想的直接結果就是:在$\mu$後面加入BN層,如圖

往VAE里加入BN

往VAE里加入BN

推導過程簡述 #

為什麼會跟BN聯絡起來呢?我們來看KL散度項的形式:
\begin{equation}\mathbb{E}_{x\sim\tilde{p}(x)}\left[KL\big(p(z|x)\big\Vert q(z)\big)\right] = \frac{1}{b} \sum_{i=1}^b \sum_{j=1}^d \frac{1}{2}\Big(\mu_{i,j}^2 + \sigma_{i,j}^2 - \log \sigma_{i,j}^2 - 1\Big)\end{equation}
上式是取樣了$b$個樣本進行計算的結果,而編碼向量的維度則是$d$維。由於我們總是有$e^x \geq x + 1$,所以$\sigma_{i,j}^2 - \log \sigma_{i,j}^2 - 1 \geq 0$,因此
\begin{equation}\mathbb{E}_{x\sim\tilde{p}(x)}\left[KL\big(p(z|x)\big\Vert q(z)\big)\right] \geq \frac{1}{b} \sum_{i=1}^b \sum_{j=1}^d \frac{1}{2}\mu_{i,j}^2 = \frac{1}{2}\sum_{j=1}^d \left(\frac{1}{b} \sum_{i=1}^b \mu_{i,j}^2\right)\label{eq:kl}\end{equation}
留意到括號裡邊的量,其實它就是$\mu$在batch內的二階矩,如果我們往$\mu$加入BN層,那麼大體上可以保證$\mu$的均值為$\beta$,方差為$\gamma^2$($\beta,\gamma$是BN裡邊的可訓練引數),這時候
\begin{equation}\mathbb{E}_{x\sim\tilde{p}(x)}\left[KL\big(p(z|x)\big\Vert q(z)\big)\right] \geq \frac{d}{2}\left(\beta^2 + \gamma^2\right)\label{eq:kl-lb}\end{equation}
所以只要控制好$\beta,\gamma$(主要是固定$\gamma$為某個常數),就可以讓KL散度項有個正的下界,因此就不會出現KL散度消失現象了。這樣一來,KL散度消失現象跟BN就被巧妙地聯絡起來了,透過BN來“杜絕”了KL散度消失的可能性。

為什麼不是LN? #

善於推導的讀者可能會想到,按照上述思路,如果只是為了讓KL散度項有個正的下界,其實LN(Layer Normalization)也可以,也就是在式$\eqref{eq:kl}$中按$j$那一維歸一化。

那為什麼用BN而不是LN呢?

這個問題的答案也是BN的巧妙之處。直觀來理解,KL散度消失是因為$z\sim p(z|x)$的噪聲比較大,解碼器無法很好地辨別出$z$中的非噪聲成分,所以乾脆棄之不用;而當給$\mu(x)$加上BN後,相當於適當地拉開了不同樣本的$z$的距離,使得哪怕$z$帶了噪聲,區分起來也容易一些,所以這時候解碼器樂意用$z$的資訊,因此能緩解這個問題;相比之下,LN是在樣本內進的行歸一化,沒有拉開樣本間差距的作用,所以LN的效果不會有BN那麼好。

進一步的結果 #

事實上,原論文的推導到上面基本上就結束了,剩下的都是實驗部分,包括透過實驗來確定$\gamma$的值。然而,筆者認為目前為止的結論還有一些美中不足的地方,比如沒有提供關於加入BN的更深刻理解,倒更像是一個工程的技巧,又比如只是$\mu(x)$加上了BN,$\sigma(x)$沒有加上,未免有些不對稱之感。

經過筆者的推導,發現上面的結論可以進一步完善。

聯絡到先驗分佈 #

對於VAE來說,它希望訓練好後的模型的隱變數分佈為先驗分佈$q(z)=\mathcal{N}(z;0,1)$,而後驗分佈則是$p(z|x)=\mathcal{N}(z; \mu(x),\sigma^2(x))$,所以VAE希望下式成立:
\begin{equation}q(z) = \int \tilde{p}(x)p(z|x)dx=\int \tilde{p}(x)\mathcal{N}(z; \mu(x),\sigma^2(x))dx\end{equation}
兩邊乘以$z$,並對$z$積分,得到
\begin{equation}0 = \int \tilde{p}(x)\mu(x)dx=\mathbb{E}_{x\sim \tilde{p}(x)}[\mu(x)]\end{equation}
兩邊乘以$z^2$,並對$z$積分,得到
\begin{equation}1 = \int \tilde{p}(x)\left[\mu^2(x) + \sigma^2(x)\right]dx = \mathbb{E}_{x\sim \tilde{p}(x)}\left[\mu^2(x)\right] + \mathbb{E}_{x\sim \tilde{p}(x)}\left[\sigma^2(x)\right]\end{equation}
如果往$\mu(x),\sigma(x)$都加入BN,那麼我們就有
\begin{equation}\begin{aligned}
&0 = \mathbb{E}_{x\sim \tilde{p}(x)}[\mu(x)] = \beta_{\mu}\\
&1 = \mathbb{E}_{x\sim \tilde{p}(x)}\left[\mu^2(x)\right] + \mathbb{E}_{x\sim \tilde{p}(x)}\left[\sigma^2(x)\right] = \beta_{\mu}^2 + \gamma_{\mu}^2 + \beta_{\sigma}^2 + \gamma_{\sigma}^2
\end{aligned}\end{equation}
所以現在我們知道$\beta_{\mu}$一定是0,而如果我們也固定$\beta_{\sigma}=0$,那麼我們就有約束關係:
\begin{equation}1 = \gamma_{\mu}^2 + \gamma_{\sigma}^2\label{eq:gamma2}\end{equation}

參考的實現方案 #

經過這樣的推導,我們發現可以往$\mu(x),\sigma(x)$都加入BN,並且可以固定$\beta_{\mu}=\beta_{\sigma}=0$,但此時需要滿足約束$\eqref{eq:gamma2}$。要注意的是,這部分討論還僅僅是對VAE的一般分析,並沒有涉及到KL散度消失問題,哪怕這些條件都滿足了,也無法保證KL項不趨於0。結合式$\eqref{eq:kl-lb}$我們可以知道,保證KL散度不消失的關鍵是確保$\gamma_{\mu} > 0$,所以,筆者提出的最終策略是:
\begin{equation}\begin{aligned}
&\beta_{\mu}=\beta_{\sigma}=0\\
&\gamma_{\mu} = \sqrt{\tau + (1-\tau)\cdot\text{sigmoid}(\theta)}\\
&\gamma_{\sigma} = \sqrt{(1-\tau)\cdot\text{sigmoid}(-\theta)}
\end{aligned}\end{equation}
其中$\tau\in(0,1)$是一個常數,筆者在自己的實驗中取了$\tau=0.5$,而$\theta$是可訓練引數,上式利用了恆等式$\text{sigmoid}(-\theta) = 1-\text{sigmoid}(\theta)$。

關鍵程式碼參考(Keras):

class Scaler(Layer):
    """特殊的scale層
    """
    def __init__(self, tau=0.5, **kwargs):
        super(Scaler, self).__init__(**kwargs)
        self.tau = tau
def build(self, input_shape):
    super(Scaler, self).build(input_shape)
    self.scale = self.add_weight(
        name='scale', shape=(input_shape[-1],), initializer='zeros'
    )

def call(self, inputs, mode='positive'):
    if mode == 'positive':
        scale = self.tau + (1 - self.tau) * K.sigmoid(self.scale)
    else:
        scale = (1 - self.tau) * K.sigmoid(-self.scale)
    return inputs * K.sqrt(scale)

def get_config(self):
    config = {'tau': self.tau}
    base_config = super(Scaler, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

def sampling(inputs):
"""重引數取樣
"""
z_mean, z_std = inputs
noise = K.random_normal(shape=K.shape(z_mean))
return z_mean + z_std * noise

e_outputs # 假設e_outputs是編碼器的輸出向量
scaler = Scaler()
z_mean = Dense(hidden_dims)(e_outputs)
z_mean = BatchNormalization(scale=False, center=False, epsilon=1e-8)(z_mean)
z_mean = scaler(z_mean, mode='positive')
z_std = Dense(hidden_dims)(e_outputs)
z_std = BatchNormalization(scale=False, center=False, epsilon=1e-8)(z_std)
z_std = scaler(z_std, mode='negative')
z = Lambda(sampling, name='Sampling')([z_mean, z_std])

文章內容小結 #

本文簡單分析了VAE在NLP中的KL散度消失現象,並介紹了透過BN層來防止KL散度消失、穩定訓練流程的方法。這是一種簡潔有效的方案,不單單是原論文,筆者私下也做了簡單的實驗,結果確實也表明了它的有效性,值得各位讀者試用。因為其推導具有一般性,所以甚至任意場景(比如CV)中的VAE模型都可以嘗試一下。

轉載到請包括本文地址:https://spaces.ac.cn/archives/7381

更詳細的轉載事宜請參考:《科學空間FAQ》

相關文章