快速瞭解 變分自編碼器 VAE

倒地發表於2024-03-25

概述

變分自編碼器(Variational Auto-Encoders,VAE)是自編碼器 AE 的變體,由 Kingma 等人於 2014 年提出的生成式網路結構。以機率的方式描述潛在空間,在資料生成方面潛力巨大。

自編碼器 AE

自編碼器(Auto-Encoder,AE),是一種無監督式學習模型。它可以將輸入 \(X\) 對映為資料量小得多的潛在表示 \(h\),並能透過 \(h\) 嘗試還原輸入 \(X'\)

AE 包含兩個部分:

  • Encoder 編碼器,將輸入 \(X\) 編碼為潛在表示 \(h\)
  • Decoder 解碼器,利用 \(h\) 重構輸入 \(X'\)

AE 有著諸多好處。潛在表示 \(h\) 可以視為輸入的重要特徵,可以進行資料降維與壓縮;在解碼器重建資料時可以去除資料噪聲,提高模型對噪聲輸入的魯棒性;比起 CNN,訓練不需要使用帶標籤的影像(無監督訓練)。

不過要尤其注意過擬合問題。例如,AE 完全可以只用 \(h\) 中的一個數字 “死記硬背” 訓練集中的每張圖片 \(X\),這顯然不是我們所期望的結果。

變分自編碼器 VAE

VAE 對 AE 做了兩個改動。

VAE 讓編碼器能夠輸出均值和方差,在推理階段則從這樣的正態分佈裡取樣一個資料,作為解碼器的輸入。直觀上看,這一改動就是在 AE 的基礎上,讓編碼器多輸出了一個方差,使原 AE 編碼器的輸出發生了一點隨機擾動。

AE 的訓練目標是,解碼器輸出儘可能與編碼器輸入相似。VAE 在此基礎上增加了一項訓練目標:讓編碼器輸出儘可能貼近標準正態分佈。

作為結果,VAE 的解碼器被強迫從標準正態分佈重建 \(X'\),有效解決了過擬合問題。並且由於標準正態分佈可以不依靠編碼器隨機生成,VAE 還適合用於憑空生成新影像。

除了 VAE ,DDPM(Denoising Diffusion Probabilistic Model)也能處理過擬合問題,並且效果更好。

VAE 的損失函式

VAE 的 loss 函式包含兩個部分:重構損失(Reconstruct Loss)和 KL 散度。

\[\text{loss}=\text{MSE}(X,X')+\text{KL}(N(\mu,\sigma^2),N(0,1)) \]

Reconstruct Loss 是解碼器輸出 \(X'\)編碼器輸入 \(X\) 之間的 MSE 損失,反映了 VAE 網路生成結果與輸入資料的差異。

KL 散度意圖獲知編碼器輸出的變數分佈與標準正態分佈的差距,網路訓練時期望這個差距越來越小。

KL 散度項的推導

KL 散度(Kullback-Leibler Divergence)是用來度量兩個機率分佈相似度的指標。

針對離散的隨機變數 \(x\),假設有兩種機率分佈 \(P\)\(Q\),則 \(P\)\(Q\) 的 KL 散度為:

\[D_{KL}(P||Q)=\sum_{i}p(x_i)\ln \frac{p(x_i)}{q(x_i)} \]

可見,若兩種分佈完全一致,KL 散度達到最小值 0。\(P\)\(Q\) 差距越大,KL 散度也就越大。

針對 VAE 損失函式中的 \(\text{KL}(N(\mu,\sigma^2),N(0,1))\) 項,\(N(\mu,\sigma^2)\)\(N(0,1)\) 的機率密度函式分別為

\[p(x)=\frac{1}{\sqrt{2\pi \sigma^2}}e^{-\frac{(x-\mu)^2}{2\sigma^2}} \]

\[q(x)=\frac{1}{\sqrt{2\pi}}e^{-\frac{x^2}{2}} \]

帶入到 KL 散度計算公式,可化簡得到

\[\text{KL}(N(\mu,\sigma^2),N(0,1))=\frac{1}{2}(\mu^2+\sigma^2-2\ln (\sigma)-1) \]

VAE 的實現細節

重引數化技巧

VAE 的 encoder 從正態分佈中取樣資料,這個過程是不可微的。這導致梯度會在此不可傳遞,網路無法訓練。

重引數化技巧(reparameterization trick)使得我們可以從帶可變引數 \(\theta\) 的分佈 \(p_\theta(x)\) 中取樣,保留梯度資訊。

具體來說,我們不直接從 \(N(\mu,\sigma^2)\) 中取樣,而是先從 \(N(0,1)\) 取樣,再用 \(\mu\)\(\sigma\) 對取樣結果進行線性變換。這不就相當於也是 \(N(\mu,\sigma^2)\) 的取樣結果。

對數方差

VAE 的 encoder 輸出一組均值和方差,以供取樣。

方差必須為非負數,而網路的輸出可正可負。將網路輸出視為對數方差 \(\ln \sigma\) 會更方便。

避免後驗塌縮

後驗坍塌(Posterior Collapse)問題,可以說是 VAE 獨有的煩惱。

簡單來說,若 decoder 足夠強,強到能從純噪聲中生成理想結果,encoder 就失效了。具體體現是損失函式中的 KL 散度項幾乎為 0,整體 loss 降不下去。

相關各種解決方法可以參考 這個文章

參考來源

  • VoidOc​,“【深度學習】 自編碼器(AutoEncoder)”,https://zhuanlan.zhihu.com/p/133207206
  • 周弈帆,“Stable Diffusion 解讀(一):回顧早期工作”,https://zhuanlan.zhihu.com/p/676705162
  • 哇哦,“VAE模型解析(loss函式,調參...)”,https://zhuanlan.zhihu.com/p/578619659
  • 撿到一束光,“關於KL散度(Kullback-Leibler Divergence)的筆記”,https://zhuanlan.zhihu.com/p/438129018
  • 瑪卡巴卡Bu漆糖,“KL散度 (Kullback-Leibler divergence)”,https://zhuanlan.zhihu.com/p/521804938
  • 馬東什麼,“vae和重引數化技巧”,https://zhuanlan.zhihu.com/p/570269617
  • 半畝糖,“如何避免VAE後驗坍塌?(總)”,https://zhuanlan.zhihu.com/p/389295612

相關文章