從GAN到WGAN的來龍去脈

ZhiboZhao發表於2021-07-16

一、原始GAN的理論分析

1.1 數學描述

其實GAN的原理很好理解,網路結構主要包含生成器 (generator) 和鑑別器 (discriminator) ,資料主要包括目標樣本 \(x_r \sim P_{r}\), 隨機輸入樣本 \(z \sim P_{z}\) 。生成器的目的就是根據 \(z\) 生成 \(G(z) \sim P_{r}\) ,而鑑別器則儘量區分出來 \(G(z)\)\(x_{r}\) 的不同。生成器和鑑別器採用生成對抗的方式不斷優化,最終能通過生成器得到期望輸出(比如風格轉換,人臉生成等)。聯想到電影《無雙》的情節,生成器就是造假幣的機器,而鑑別器可以類似為鑑別假幣的手段。在初始情況下,假幣制造機只能生成不是很逼真的假幣,此時鑑別器很輕鬆就能鑑別出來,於是便優化流程和材料,鑑別器鑑別錯誤之後再改進判別手段......如此往復,最終我們可以得到足以以假亂真的假幣。

鑑別器 $D(input;\theta_{d})$ 的目標是對輸入的資料做出準確的判斷,因此目標函式為: $$ \mathop{max}\limits_{D}[E_{x\sim P_{r}}logD(x;\theta_{d})]+E_{z\sim P_{z}}[log(1-D(G(z);\theta_{d}))] (公式1) $$ 生成器 $G(input;\theta_{g})$ 的目標就是輸出的資料儘可能與目標樣本接近騙過鑑別器 $D$,因此: $$ \mathop{min}\limits_{G}E_{z\sim P_{z}}[log(1-D(G(z);\theta_{d}))] (公式2) $$ 因此總的目標函式可以寫為: $$ \mathop{min}\limits_{G}\mathop{max}\limits_{D}[E_{x\sim P_{r}}logD(x;\theta_{d})]+E_{z\sim P_{z}}[log(1-D(G(z);\theta_{d}))] (公式3) $$ 借用原論文的符號,我們將生成器輸出的概率分佈記為 $P_{g}$,於是公式 (3) 可以記為: $$ \mathop{min}\limits_{G}\mathop{max}\limits_{D}[E_{x\sim P_{r}}logD(x;\theta_{d})]+E_{x\sim P_{g}}[log(1-D(x;\theta_{d}))] (公式4) $$

1.2 求出全域性最優解

當固定 \(G\) 的引數時,優化 \(D\) 的引數:

\[V_{D} = [E_{x\sim P_{r}}logD(x;\theta_{d})]+E_{x\sim P_{g}}[log(1-D(x;\theta_{d}))] \\ V_{D} = \int P_{r}logD(x)dx+\int P_{g}log(1-D(x))dx = \int (P_{r}logD(x)+P_{g}log(1-D(x))dx \\ (公式5) \]

因此,最大值為:

\[\dfrac{\partial{V_{D}}}{\partial{D}} = \dfrac{\partial{}}{\partial{D}}\int (P_{r}logD(x)+P_{g}log(1-D(x))dx \\ \dfrac{\partial{V_{D}}}{\partial{D}} = \int \dfrac{P_{r}}{D(x)}-\dfrac{P_{g}}{1-D(x)}dx = 0\\ (公式6) \]

解得:

\[D^{*}(x) = \dfrac{P_{r}}{P_{r}+P_{g}} (公式7) \]

於是,將 \(D^{*}(x)\) 帶入到公式 (4) 中,得到:

\[V_{G} = [E_{x\sim P_{r}}log\dfrac{P_{r}}{P_{r}+P_{g}}]+E_{x\sim P_{g}}[log(1-\dfrac{P_{r}}{P_{r}+P_{g}})] \\ (公式8) \]

即:

\[V_{G} = [E_{x\sim P_{r}}log\dfrac{P_{r}}{P_{r}+P_{g}}]+E_{x\sim P_{g}}[log(\dfrac{P_{g}}{P_{r}+P_{g}})] (公式9) \]

由於\(P_{r}+P{g} \in [0,2]\),因此公式 (10) 可以寫為:

\[V_{G} = [E_{x\sim P_{r}}log\dfrac{P_{r}}{(P_{r}+P_{g})/2}\times \dfrac{1}{2}]+E_{x\sim P_{g}}[log(\dfrac{P_{g}}{(P_{r}+P_{g})/2}\times \dfrac{1}{2})] \\ V_{G} = KL(P_{r}|| \dfrac{P_{r}+P_{g}}{2})+log \dfrac{1}{2}+KL(P_{g}|| \dfrac{P_{r}+P_{g}}{2})+log \dfrac{1}{2} \\(公式10) \]

最終:

\[V_{G} = KL(P_{r}|| \dfrac{P_{r}+P_{g}}{2})+KL(P_{g}|| \dfrac{P_{r}+P_{g}}{2})-2log2 (公式11) \]

因此,當 \(P_{r} = \dfrac{P_{r}+P_{g}}{2} = P_{g}\) 時,存在唯一極小值 \(P_{r} = P_{g}\),此時 \(D^{*}(x) = \dfrac{1}{2}\)。即公式 (4) 存在全域性最優解,在全域性最優解的情況下,生成器生成的概率分佈與目標樣本概率分佈一樣,此時鑑別器無法準確判斷生成樣本與目標樣本的差異,判斷正確和錯誤的概率各為0.5,類似於瞎猜。

1.3 原始GAN到底出了什麼問題?

GAN的訓練是依靠生成器和鑑別器的相互對抗來完成的,那麼直觀地思考一下:如果鑑別器過於差勁,給不到生成器任何有用的資訊,那麼生成器的更新就會沒有方向;如果鑑別器太好,那麼類似於造假幣的機器極其差,而鑑別器直接就是驗鈔機,那麼直觀上也無法給生成器提供足夠的資訊去更新。因此,原始的GAN理論上可行,而實際上卻受到鑑別器和生成器狀態的影響,不一定能找到最優解,且訓練不穩定。

從數學角度上來描述:我們在 1.2節 求全域性最優解的過程中,先求出了鑑別器 \(D\) 的最優解,然後得到了公式 (11) ,在這種情況下相當於我用已經訓練好的鑑別器來指導生成器的學習,將概率分佈從 \(P_{z}\) 拉向 \(P_{r}\)。乍一看沒什麼問題,但是如果兩個分佈 \(P_{r}\)\(P_{z}\) 完全沒有重疊的部分,或者它們重疊的部分可忽略,會發生什麼情況呢?答案是無論換句話說,無論 \(P_{r}\)\(P_{g}\)是遠在天邊,還是近在眼前,只要它們倆沒有一點重疊或者重疊部分可忽略,公式 (11) 散度就固定是常數 \(log2\),而這對於梯度下降方法意味著——梯度為0!此時對於最優判別器來說,生成器肯定是得不到一丁點梯度資訊的;即使對於接近最優的判別器來說,生成器也有很大機會面臨梯度消失的問題。與我們直觀上的感覺一致。

那麼問題就變成了\(P_{r}\)\(P_{z}\) 沒有重疊的部分的概率大嗎?答案是非常大。首先,\(P_{r}\) 是一個複雜分佈,而 \(P_{z}\) 則是一個簡單分佈,所以在空間上二者不重疊的概率很大。更重要的一個原因是,輸入 \(z \sim P_{r}\) 一般是 100 維,而生成的目標往往是一張圖片,比如 \(64 \times 64\) 就是 \(4096\) 維,低維與高維相重合本來就很少,因此更加證明了原始GAN不容易訓練。總結下來:

原始GAN存在梯度不穩定的問題,即判別器訓練得太好,生成器梯度消失,生成器loss降不下去;判別器訓練得不好,生成器梯度不準,四處亂跑。只有判別器訓練得不好不壞才行,但是這個火候又很難把握,甚至在同一輪訓練的前後不同階段這個火候都可能不一樣,所以GAN才那麼難訓練。 此外,GAN還存在模式崩塌(collapse mode)的問題,即生成樣本多樣性不足。

二、WGAN的前世今生

為了解決原始GAN梯度不穩定的問題,一個過渡的解決方案是強行對生成樣本和真實樣本加噪聲,使得原本兩個分佈彌散到整個高維空間,增加重疊部分。當二者出現重疊部分時,再把噪聲拿掉,這樣也能夠繼續收斂。這只是一個折中的方案,並沒有從本質上解決問題。

2.1 Wasserstein 距離

Wasserstein 距離又叫 Earth-Mover ( EM ) 距離,定義如下:

\[W(P_{r},P_{g}) = \mathop{inf}\limits_{\gamma \sim \prod (P_{r}, P_{g})}E_{(x,y)\sim \gamma}[||x-y||] (公式12) \]

其中:\(\prod (P_{r}, P_{g})\) 表示從概率 \(P_{g}\)\(P_{r}\) 的所有可能分佈,而 \(W(P_{r},P_{g})\) 代表所有可能的分佈中, \(||x-y||\) 的最小期望值距離。舉個例子:如下圖所示,假如將左側的方塊運送到右側的位置,那麼方案有很多種,其中最小的那一種移動所花的消耗即為Wasserstein距離。

**因此,Wasserstein的好處就是無論兩個分佈是否有重疊部分,Wasserstein距離都是連續的,能夠反映兩個分佈的遠近,而JS散度和KL散度既不能反映遠近,也提供不了梯度。**所以,EM距離更適合用作GAN的loss function。

2.2 從EM距離到WGAN

由於在Wasserstein中,\(\mathop{inf}\limits_{\gamma \sim \prod (P_{r}, P_{g})}\) 沒辦法直接求解,因此WGAN的作者通過已有的定理將其轉換成如下形式:

\[W(P_{r},P_{g}) = \dfrac{1}{K} \mathop{sup}\limits_{||f||_{L}<K} E_{x \sim P_{r}}[f(x)]-E_{x \sim P_{g}} [f(x)] (公式13) \]

式子的證明過程對我來說確實難以理解,因此這裡就不作解釋了,有興趣的可以參考WGAN的原論文。最後,WGAN的loss function變成了下面的形式:

\[W(P_{r},P_{g}) = \dfrac{1}{K} \mathop{max}\limits_{||f_{w}||_{L}<K} E_{x \sim P_{r}}[f_{w}(x)]-E_{x \sim P_{g}} [f_{w}(x)] (公式14) \]

於是,可以把函式 \(f\) 用一個引數為 \(w\) 的神經網路來表示。最後,為了滿足 \(||f_{w}||_{L}<K\) 的限制,將神經網路的所有引數 \(w\) 都拉伸到 \([-c,c]\) 中,所以一定滿足Lipschitz連續條件。

因此,我們可以構造一個含引數 \(w\)、最後一層不是非線性啟用層的判別器網路 \(f_{w}\),在限制! \(w\) 不超過某個範圍的條件下,使得:

\[L = E_{x \sim P_{r}}[f_{w}(x)]-E_{x \sim P_{g}} [f_{w}(x)] (公式15) \]

儘可能取到最大,此時的 \(L\) 就可以近似為真實分佈 \(P_{r}\) 與生成分佈 \(P_{g}\) 之間的Wasserstein距離。注意:原始GAN的判別器做的時二分類任務,所以最後一層採用 \(sigmoid\) 函式,而WGAN中的判別器做的是擬合 Wasserstein 距離,屬於迴歸任務,因此把最後一層的 \(sigmoid\) 去掉。

因此判別器的loss function為:

\[E_{x \sim P_{g}}[f_{w}(x)]-E_{x \sim P_{r}} [f_{w}(x)] (公式16) \]

生成器的loss function為:

\[-E_{x \sim P_{g}}[f_{w}(x)] (公式17) \]

所以,不管理論再複雜, WGAN在原始的GAN上只做了三點改進:

  • 判別器最後一層去掉sigmoid
  • 生成器和判別器的loss不取log
  • 每次更新判別器的引數之後把它們的絕對值截斷到不超過一個固定常數c

最後,作者通過經驗發現,不要使用Adam優化演算法,推薦RMSProp或者SGD。

2.3 模型崩塌(collapse mode)問題的解決方法

上述解決了GAN在訓練過程中梯度不穩定的問題,那麼模型崩塌(collapse mode)問題的解決方法如下:

2.3.1 在loss function 層面

通常先更新幾輪生成器,之後再更新一輪鑑別器。因為GAN的訓練是 \(min max\) 的策略,即先更新鑑別器,然後再更新生成器。往往在迭代的過程中,生成器和鑑別器交替優化,容易將問題變成 \(maxmin\) 的問題,這樣一來就變成了:生成器先生成一個輸出,然後鑑別器對這個輸出進行判斷,那麼生成器最後學習到的往往是最保險的,導致模型崩塌(collapse mode),生成樣本多樣性不足。

2.3.2 在網路結構方面

1、採用多個生成器和一個鑑別器,類似於曠視“先發散再收斂”的學習策略,通過正則化約束生成器之間的比重,生成多樣性的樣本。

2、將真實樣本通過一個編碼器 (Encoder) 後再使用生成器進行重構,如下圖所示:

那麼 \(D_{M}\)\(R\) 用來指導生成對應的樣本,而 \(D_{D}\) 則對 \(G(z)\)\(G(E(x))\) 進行判別,顯然二者都是生成的樣本,差別越大那麼表明生成樣本的多樣性越高。

3、Mini-batch discrimination在判別器的中間層建立一個mini-batch layer用於計算基於 \(L_{1}\) 距離的樣本統計量,通過建立該統計量去判別一個batch內某個樣本與其他樣本有多接近。這個資訊可以被判別器利用到,從而甄別出哪些缺乏多樣性的樣本。對生成器而言,則要試圖生成具有多樣性的樣本。

2.4 WGAN 部分程式碼分析

self.G_sample = self.generator(self.z)

self.D_real, _ = self.discriminator(self.X)
self.D_fake, _ = self.discriminator(self.G_sample, reuse = True)

# loss
self.D_loss = - tf.reduce_mean(self.D_real) + tf.reduce_mean(self.D_fake)
self.G_loss = - tf.reduce_mean(self.D_fake)

self.D_solver = tf.train.RMSPropOptimizer(learning_rate=1e-4).minimize(self.D_loss, var_list=self.discriminator.vars)
self.G_solver = tf.train.RMSPropOptimizer(learning_rate=1e-4).minimize(self.G_loss, var_list=self.generator.vars)

# clip
self.clip_D = [var.assign(tf.clip_by_value(var, -0.01, 0.01)) for var in self.discriminator.vars]

然後按照正常的GAN訓練即可。

相關文章