上一期中,我們說明了GAN訓練中的幾個問題,例如由於把判別器訓練得太好而引起的梯度消失的問題、通過取樣估算距離而造成偏差的問題、minmax問題不清晰以及模式崩潰、優化選擇在引數空間而非函式空間的問題等,今天這篇小文將從博弈論的角度出發來審視一下GAN訓練時的問題,說明訓練GAN其實是在尋找納什均衡,然後說明達到納什均衡或者說損失函式收斂是很難的,並最後給出了3個穩定訓練的小技巧。
作者 | 小米粥
編輯 | 言有三
1 博弈論與GAN
大家對GAN的基本模型想必已經非常熟悉了,我們先從博弈論的角度來重新描述GAN模型。遊戲中有兩個玩家:D(判別器)和G(生成器),D試圖在判別器的引數空間上尋找最好的解使得它的損失函式最小:
G也試圖在生成器的引數空間上尋找最好的解使得它的損失函式最小:
需要說明,D和G並不是彼此獨立的,對於GAN,整個博弈是“交替進行決策”的。例如先確定生成器G的引數,則D會在給定的G的引數的條件下更新判別器的引數以此最小化D的損失函式,如下面中藍線過程(提升D的辨別能力);接著G會在給定的D的引數的條件下更新判別器的引數以此來最小化G的損失函式,如下面中綠線過程(提升G的生成能力)......直到達到一個穩定的狀態:納什均衡。
在納什均衡點,兩者的引數到達一種“制衡”狀態。在給定G的引數情況下,D當前的引數便對應了D損失函式的最小值,同樣在給定D的引數情況下,G當前的引數便對應了G損失函式的最小值,也就是說在交替更新過程中,D和G均不可能單獨做出任何改變。
解空間中可能存在多個納什均衡點,而且納什均衡點並不意味著全域性最優解,但是是一種經過多次博弈後的穩定狀態,所以說GAN的任務是並非尋找全域性最優解,而是尋找一個納什均衡狀態,損失函式收斂即可。在損失函式非凸、引數連續、引數空間維度很高的情況下,不可能通過嚴格的數學計算去更新引數從而找到納什均衡,在GAN中,每次引數更新(對應藍線、綠線表示的過程)使用的是梯度下降法;另外,每次D或者G對自身引數更新都會減少自身的損失函式同時加大對方的損失函式,這導致了尋找GAN的納什均衡是比較困難的。
這裡有一個比GAN簡單多的例子表明很多時候納什均衡的狀態難以達到:
使用梯度下降法發現x,y在引數空間中並不會收斂到納什均衡點(0,0),損失函式的表現為:不收斂。
針對GAN訓練的收斂性問題,我們接下來將介紹幾種啟發式的訓練技巧。
2 特徵匹配
在GAN中,判別器D輸出一個0到1之間的標量表示接受的樣本來源於真實資料集的概率,而生成器的訓練目標就是努力使得該標量值最大。如果從特徵匹配(feature matching)的角度來看,整個判別器D(x)由兩部分功能組成,先通過前半部分f(x)提取到樣本的抽象特徵,後半部分的神經網路根據抽象特徵進行判定分類,即
f(x)表示判別器中截止到中間某層神經元啟用函式的輸出。在訓練判別器時,我們試圖找到一種能夠區分兩類樣本的特徵提取方式f(x),而在訓練生成器的時候,我們可以不再關注D(x)的概率輸出,我們可以關注:從生成器生成樣本中用f(x)提取的抽象特徵是否與在真實樣本中用f(x)提取的抽象特徵相匹配,另外,為了匹配這兩個抽象特徵的分佈,考慮其一階統計特徵:均值,即可將生成器的目標函式改寫為:
採用這樣的方式,我們可以讓生成器不過度訓練,讓訓練過程相對穩定一些。
3 歷史均值
歷史均值(historical averaging)是一個非常簡單方法,就是在生成器或者判別器的損失函式中新增一項:
這樣做使得判別器或者生成器的引數不會突然產生較大的波動,直覺上看,在快要達到納什均衡點時,引數會在納什均衡點附近不斷調整而不容易跑出去。這個技巧在處理低維問題時確實有助於進入納什均衡狀態從而使損失函式收斂,但是GAN中面臨的是高維問題,助力可能有限。
4 單側標籤平滑
標籤平滑(label smoothing)方法最開始在1980s就提出過,它在分類問題上具有非常廣泛的應用,主要是為了解決過擬合問題。一般的,我們的分類器最後一層使用softmax層輸出分類概率(Sigmoid只是softmax的特殊情況),我們用二分類softmax函式來說明一下標籤平滑的效果。
對於給定的樣本x,其類別為1,則標籤為[1,0],如果不用標籤平滑,只使用“硬”標籤,其交叉熵損失函式為:
這時候通過最小化交叉熵損失函式來訓練分類器,本質上是使得:
其實也就是使得:
對於給定的樣本x,使z1的值無限大(當然這在實際中是不可能的)而使z2趨於0,無休止擬合該標籤1,便產生了過擬合、降低了分類器的泛化能力。如果使用標籤平滑手段,對給定的樣本x,其類別為1,例如平滑標籤為[1-ε ,ε],交叉損失函式為:
當損失函式達到最小值時,有:
選擇合適的引數,理論上的最優解z1與z2存在固定的常數差值(此差值由ε決定),便不會出現z1無限大,遠大於z2的情況了。如果將此技巧用在GAN的判別器中,即對生成器生成的樣本輸出概率值0變為β ,則生成器生成的單樣本交叉熵損失函式為:
而對資料集中的樣本打標籤由1降為α,則資料集中的單樣本交叉熵損失函式為:
總交叉損失函式為:
求導容易得其最優解D(x)為:
實際訓練中,有大量這樣的x:其在訓練資料集中概率分佈為0,而在生成器生成的概率分佈不為0,他們經過判別器後輸出為β。為了能迅速“識破”該樣本,最好將β降為0,這就是所謂的單側標籤平滑。
訓練GAN時,我們對它的要求並不是找到全域性最優解,能進入一個納什均衡狀態、損失函式收斂就可以了。(雖然這個納什均衡狀態可能非常糟糕)最近的幾篇文章將著重於討論GAN訓練的收斂問題。
[1] Müller, Rafael, S. Kornblith , and G. Hinton . "When Does Label Smoothing Help?." 2019
[2] Salimans T , Goodfellow I , Zaremba W , et al. Improved Techniques for Training GANs[J]. 2016.
總結
這篇文章闡述了GAN的訓練其實是一個尋找納什均衡狀態的過程,然而想採用梯度下降達到收斂是比較難的,最後給出了幾條啟發式的方法幫助訓練收斂。