大名鼎鼎的深度學習之父Yann LeCun曾評價GAN是“20年來機器學習領域最酷的想法”。的確,GAN向世人展示了從無到有、無中生有的神奇過程,並且GAN已經在工業界有著廣泛的應用,是一項令人非常激動的AI技術。今天我將和大家一起去了解GAN及其內部工作原理,洞開GAN的大門。
本文儘量用淺顯易懂的語言來進行表述,少用繁瑣的數學公式,並對幾個典型的GAN模型進行講解。
一、GAN(GenerativeAdversarial Networks)
GAN全名叫Generative Adversarial Networks,即生成對抗網路,是一種典型的無監督學習方法。在GAN出現之前,一般是用AE(AutoEncoder)的方法來做影像生成的,但是得到的影像比較模糊,效果始終都不理想。直到2014年,Goodfellow大神在NIPS2014會議上首次提出了GAN,使得GAN第一次進入了人們的眼簾並大放異彩,到目前為止GAN的變種已經超過400種,並且CVPR2018收錄的論文中有三分之一的主題和GAN有關,可見GAN仍然是當今一大熱門研究方向。
GAN的應用場景非常廣泛,主要有以下幾個方面:
1.影像、音訊生成。比如訓練資料的生成。
2.影像翻譯。從真實場景的影像到漫畫風格的影像、風景畫與油畫間的風格互換等等。
3.影像修復。比如影像去噪、去除影像中的馬賽克(嘿嘿…)。
4.影像超解析度重建。衛星、遙感以及醫學影像中用的比較多,大大提升後續的處理精度。
(一) GAN原理簡述
GAN的原理表現為對抗哲學,舉個例子:警察和小偷的故事,二者滿足兩個對抗條件:
1.小偷不停的更新偷盜技術以避免被抓。
2.警察不停的發現新的方法與工具來抓小偷。
小偷想要不被抓就要去學習國外的先進偷盜技術,而警察想要抓到小偷就要儘可能的去掌握小偷的偷盜習性。兩者在博弈的過程中不斷的總結經驗、吸取教訓,從而都得到穩步的提升,這就是對抗哲學的精髓所在。要注意這個過程一定是一個交替的過程,也就是說兩者是交替提升的。想象一下,如果一開始警察就很強大,把所有小偷全部抓光了,那麼在沒有了小偷之後警察也不會再去學習新的知識了,偵查能力就得不到提升。反之亦然,如果小偷剛開始就很強大,警察根本抓不到小偷,那麼小偷也沒有動力學習新的偷盜技術了,小偷的偷盜能力也得不到提升,這就好比在訓練神經網路時出現了梯度消失一樣。所以一定是一個動態博弈的過程,這也是GAN最顯著的特性之一。
在講完了警察與小偷的故事之後,我們引入今天的主人公——GAN。
(二) 模型架構圖
從上圖能夠看出GAN的整個網路架構是非常簡單明瞭的,GAN由一個生成器(Generator)和一個判別器(Discriminator)組成, 兩者的結構都是多層感知機(MLP),具體有多少層、每層多少個神經元可以根據實際情況自行設計,比較靈活。在這裡,生成器充當著“小偷”的角色,判別器就扮演“警察”的角色。為了方便講解,後面把生成器簡稱為G,判別器簡稱為D。
G:接收一個隨機噪聲向量z(比如z服從高斯分佈),G的目標就是透過這個噪聲來生成一個像真實樣本的假樣本。
D:判別一個樣本是真實樣本還是G自己造的假樣本。它接收一個樣本資料作為輸入,所以這個樣本可以是G生成的假樣本也可以是真實樣本。它輸出一個標量,標量的數值代表了輸入樣本到底是真實樣本還是G生成的假樣本的機率。如果接近1,則代表是真實樣本,接近於0則代表是生成器生成的假樣本,所以此時D最後一層的啟用函式一定為sigmoid。
網路的最終目標是在D很強大的同時,G生成的假樣本送給D後其輸出值變為0.5,說明G已經完全騙過了D,即D已經區分不出來輸入的樣本到底是還是,從而得到一個生成效果很好的G。
損失函式的設計:
從上面的式子可以看出,損失函式是兩個分佈各自期望的和,其中是真實資料的機率分佈,是生成器所生成的假樣本的機率分佈。對於D,它的目的是讓中的樣本的輸出結果儘可能的大,即變大,而讓生成的樣本x的輸出結果儘可能的小,即
變大,導致變大。對於G,它的目的是用噪聲z來生成一個假樣本x並讓D給出一個較大的值,即讓變小,導致變小。綜上,我們得出:
(三) GAN的訓練流程
假設batch_size=m,則在每一個epoch中:
先訓練判別器k(比如3)次:
1. 從噪聲分佈z(比如高斯分佈)中隨機取樣出m個噪聲向量:。
2. 從真實樣本x中隨機取樣出m個樣本:。
3. 用梯度下降法使損失函式: 與1之間的二分類交叉熵減小(因為最後判別器最後一層的啟用函式為sigmoid,所以要與0或者1做二分類交叉熵,這也是為什麼損失函式要取log的原因)。
4. 用梯度下降法使損失函式:與0之間的二分類交叉熵減小。
5. 所以判別器的總損失函式即讓d_loss越小越好。注意在訓練判別器的時候生成器中的所有引數要固定住,即不參加訓練。
再訓練生成器1次:
1. 從噪聲分佈中隨機取樣出m個噪聲向量:。
2. 用梯度下降法使損失函式:與1之間的二分類交叉熵減小。
3. 所以生成器的損失函式即讓g_loss越小越好。注意在訓生成器的時候判別器中的所有引數要固定住,即不參加訓練。
直到所有epoch執行完畢,訓練結束。
從訓練方法中可以看出,生成器和判別器是交替進行訓練的,呈現出一種動態博弈的思想,非常有意思。不過在訓練的時候還有一些注意事項:
1.在訓練G的時候D中的引數不參加訓練,即不需要梯度反傳。同理,訓練D的時候G中的引數不參加訓練。
2.為了讓D保持在一個相對較高的評判水平,從而更好的訓練G。在每一個epoch內,先對D進行k(比如k=3)次訓練,然後訓練G一次,加快網路的收斂速度。
3.在原始論文中,作者在訓練G的時候給出的公式是,然而這個公式有一些隱患,因為在訓練的初始階段,G生成的樣本和真實樣本間的差異一般會很大,此時D能很輕鬆的分辨兩種樣本,導致一直趨近於0,此時梯度消失,G也就得不到訓練,所以這裡的策略是,上面訓練過程的闡述中已經對該處的損失函式做了更正。
(四) 損失函式相關數學推導
我們先將G中的引數固定住,此時的噪聲向量透過G後所生成的樣本是一一對應的,則有如下對映:
由此將由兩個數學期望的和組成的展開:
由於和是固定的常量,另它們等於a,b。令,得到,由於是唯一極值點,則必為最值點,也能夠證明在時,其二階導小於0,那麼該最值點為全域性最大值點。
所以,當G固定住的時候,不斷的訓練D中的引數,理論上可以讓D達到最大值:
此時將帶入進中,得到:
對於兩個機率分佈和,它們之間的KL散度就是資料的原始分佈與近似分佈的機率的對數差的期望值,其公式為:
所以此時得到:
再將兩個KL散度的和合併成JS散度,得到:
從上式可以看出,如果G要讓最小,必須要讓和間的JS散度最小,而JS散度的最小值為0,此時兩個分佈完全重合,即理論上的最小值為,此時存在唯一解:
使得損失函式達到全域性最小值,即生成器完美的實現了生成真實資料的過程,完全掌握了真實資料的機率分佈。
(五) 總結
1.GAN的開山之作。
2.GAN的本質其實是利用神經網路強大的非線性擬合能力來學習從一個任意先驗的噪聲分佈到真實資料分佈的非線性對映,從而讓生成器具有能夠產生逼真樣本的能力。
3.早期GAN的訓練非常不穩定導致訓練難度大,還容易出現梯度爆炸、mode collapse等問題。mode collapse的意思就是生成的樣本大量集中於部分真實樣本,那麼就是很嚴重的mode collapse。以生成動漫頭像圖片為例,從下圖中能夠明顯的看出,紅框標記的影像重複出現了很多次,即存在一定的mode collapse。
二、DCGAN(Deep Convolutional Generative Adversarial Networks)
在GAN被提出之後,GAN的熱度曲線呈指數式增長,期間在原始GAN的結構基礎上進行改進的GAN變種層出不窮,其中最具代表性的當屬DCGAN了,我們來看看它對原始GAN有什麼創新:
1.將兩個多層感知機替換為兩個卷積神經網路。即將CNN融合進GAN中,極大的加速了GAN在影像領域中應用的步伐,此後許多新提出的GAN都一直在沿用DCGAN的網路架構。
2.創新性的將反摺積(也叫轉置卷積)操作應用於生成器中。
3.透過大量實驗,總結出一套構建網路時很有用的trick。
(一) 反摺積
常見的上取樣方式有三種:雙線性插值,反摺積(也叫轉置卷積)和反池化。鑑於篇幅所限,除反摺積以外的兩種上取樣方法就不在這裡介紹了。
常規的卷積操作一般會導致影像的尺寸越來越小,同時影像深度在逐漸增加。而反摺積則使影像尺寸越來越大,而深度在逐漸減小。所以反摺積是卷積操作的逆運算,也就是說反摺積的正向傳播是卷積的反向傳播,其反向傳播是卷積的正向傳播,本文力求用形象的過程來展現反摺積的工作原理(注:下文所闡述的反摺積工作方式為tensorflow機器學習框架反摺積的底層實現方法,其他框架的底層實現方法可能略有不同)。
若輸入為3*3大小的單通道影像:
考慮卷積核大小kernel_size=3*3,stride=2,padding=same的反摺積操作,且卷積核為:
如果stride=2,那麼就在輸入影像的每行和每列之間插入(stride-1)行(列)的零元素,另外還需要在補零後的矩陣的左邊和上邊新增額外的(stride-1)行(列)的零元素:
如果卷積核的大小kernel_size=3,且padding=same的情況下,我們知道在正常的卷積模式下是要上、下、左、右各新增(kernel_size-1)/2個行(列)元素,他們的初始值都為0,以此來保證輸出影像與輸入影像的大小是相同的,所以這裡也採取相同的padding操作。這裡簡單說明一下:如果kernel_size=4,那麼(kernel_size-1)/2=1.5,無法整除,那麼此時左方和上方新增一行(列)零元素,右方和下方新增兩行(列)零元素,總之要保證新增的總行(列)數要和kernel_size-1是相等的,這也是tensorflow機器學習框架在卷積操作中padding=same時的填補方法。所以現在輸入影像變成了這樣:
此時輸入影像的尺寸由3*3變成了8*8,我們用kernel_size=3,stride=1,padding=valid的方式對這張圖進行常規的卷積操作,則輸出尺寸變為:H=(8+0-3)/1+1=6,W=(8+0-3)/1+1=6。注意這步操作中的kernel_size是和反摺積核的kernel_size是保持一致的,stride固定為1,而且不進行padding操作,因為前面已經padding過了,得到:
我們用tensorflow做個小實驗,來驗證上面演算法的正確性。
輸出:
輸出結果和我們自己推導的完全一致!可見,反摺積也僅僅是卷積操作而已,與正常卷積使用相同大小的卷積核,只不過反摺積需要透過特定的規則對輸入tensor透過padding 0元素的方式處理一下。這樣我們最終得到的輸出影像尺寸要比原影像大,即實現了上取樣的功能。怎麼樣,是不是非常簡單。
反摺積的應用領域非常廣泛,不僅僅在GAN中,還在影像分割以及feature map的視覺化領域有著廣泛的應用。好了,簡單講完反摺積後,讓我們回到DCGAN。
(二) 網路實現上的一些tirck
1. 在生成器與判別器中,將所有池化層替換為步長大於1的卷積操作,即拋棄所有池化層,目的是讓網路去學習屬於它自己的上(下)取樣方式。想了一下,確實是非常有效的trick,因為在影像分割領域中,maxpooling操作會破會影像的邊緣與細節,導致分割結果很粗糙,所以一般都透過別的辦法來替代maxpooling,以保證分割結果的細節完好。
2. 移除全域性平均池化層,全域性平均池化在影像分類網路中有著舉足輕重的地位,作者在做實驗的過程中發現在判別器中用全域性平均池化再接全連線層雖然能夠增加模型的穩定性,但同時嚴重減緩了模型的收斂速度,所以決定移除。
3. 除了生成器的最後一層和判別器的輸入層,其餘層都做batch normalization操作。是一個非常有助於網路快速收斂的trick。作者發現如果全部層都用batch normalization,容易發生mode collapse現象,並使得模型變得不穩定。
4. 生成器最後一層的啟用函式採用tanh,其餘層為relu啟用函式。而判別器中則全部採用leaky relu啟用函式。
(三) DCGAN中生成器的網路結構
網路的整體架構和原始GAN是差不多的,不同的僅僅是生成器和判別器的內部結構,由MLP換成了CNN。從圖中來看,主要是由一個啟用函式為relu的全連線層,三個啟用函式為relu的反摺積層,以及最後的啟用函式為tanh的反摺積層,將一個長度為100滿足正態分佈(或者均勻分佈)的向量z變成一個大小為64*64的3通道影像,這也是生成器生成的最終影像。判別器在結構上與生成器是完全對稱的,類似於常規的分類網路,這裡不再贅述。
注意:由於生成器最後一層的啟用函式為tanh,因此輸出值的範圍在[-1, 1]上,所以真實圖片樣本也必須要進行縮放範圍一致的歸一化操作,即,令,將輸入樣本x上的畫素值都歸一化到[-1, 1]上,再將這個歸一化後的圖片送入判別器中,以此來保證每一個輸入進判別器的樣本分佈區間的一致性。當然也可以採用別的歸一化方法,只要能讓就好。
(四) 用DCGAN在MNIST資料集上訓練手寫數字生成
開原始碼倉庫地址: https://github.com/carpedm20/DCGAN-tensorflow
在訓練了30個epoch後,我把每個epoch生成器生成的100張圖片存下來並縮小做成動態圖:
可以看出並沒有出現mode collapse現象,生成樣本具有一定的多樣性,效果還不錯。其實主要還是資料集比較簡單,圖片比較小,複雜紋理資訊不多,比較容易生成。
生成的這些資料就可以用作手寫數字識別的訓練資料。但是這些資料是沒有標籤的,然而手寫數字識別為監督學習,難道還要對它們進行人工標註?這個問題我們留到下一個小節來解決。
(五) 總結
1.訓練方法和訓練原始GAN的方法保持一致。
2.將兩個MLP替換成為兩個CNN,生成的影像較原始GAN來說質量更高,更逼真。
3.透過大量實驗總結出一套非常有用的trick,使得DCGAN在訓練時的穩定性相比原始GAN有顯著改善,要知道原始GAN是非常難訓練的。
4.後面要講的模型中G和D的架構均和DCGAN保持一致,便不再贅述。
三、InfoGAN(InformationMaximizing Generative Adversarial Nets)
DCGAN已經能夠生成足夠逼真的影像了,但是它直接將噪聲向量z作為G的輸入,沒有為z新增任何限制,導致我們根本不知道G主要用到了z的哪個維度來生成圖片,即已經將z進行高度耦合處理,所以z的維度資訊對於真實資料來說不具有語義特徵,也就是說是不可解釋的。
拿上面的圖為例,我們發現第三個“7”中間出現了一個橫線,但是為什麼會出現這個橫線,誰也不知道,為了讓GAN具有可解釋性,比較有代表性的GAN變體——InfoGAN就出現啦,為了解決語義問題,InfoGAN的作者對損失函式進行了一些小的改進,一定程度上讓網路學習到了可解釋的特徵表示,即作者文中所說的interpretable reptesentation。
(一)原理闡述
既然要讓輸入的噪聲向量z帶有一定的語意資訊,那就人為的為它新增上一些限制,於是作者把G的輸入看成兩部分:一部分就是噪聲z,可以將它看成是不可壓縮的噪聲向量。另一部分是若干個離散的和連續的latent variables(潛變數)所拼接而成的向量c,用於代表生成資料的不同語意資訊。
以MNIST資料集為例,可以用一個離散的隨機變數(0-9,用於表示生成數字的具體數值)和兩個連續的隨機變數(假設用於表示筆劃的粗細與傾斜程度)。所以此時的c由一個離散的向量(長度為10)、兩個連續的向量(長度為1)拼接而成,即c長度為12。
上圖是作者用InfoGAN在MNIST資料集上的部分結果,透過保持離散變數不變、逐漸增大某一個連續的潛變數(論文中是從-2到2),可以看出從左到右數字的筆劃逐漸增粗,具有很強的可解釋性。所以上一小節遺留的問題就迎刃而解了,理想情況下,我們可以透過這些潛變數來生成無數個滿足我們需求的手寫數字了!也就不需要再為生成的資料人工打標籤了。
所以此時對於G的輸入來說不再是單純的噪聲z了,而是z和一個長度為12的向量c,但是僅僅有這個設定還不夠,因為生成器的學習具有很高的自由度,它很容易找到一個解,使得:
此時在生成器看來,z和c是兩個完全獨立的向量,有沒有c都一樣可以生成資料,這樣生成器就完全繞過了c,導致它起不到應有的作用。
為了解決這個問題,作者透過最佳化GAN的損失函式來讓和c強制產生聯絡,使得兩者完成建模。作者從資訊理論中得到啟發,提出基於互資訊(mutual information)的正則化項。在資訊理論中,互資訊用來衡量“已知隨機變數Y的情況下,可以獲得多少有關隨機變數X的資訊”,其計算公式為:
上式中,H表示計算熵值,所以I(X;Y)是兩個熵值的差。H(X|Y)衡量的是“給定隨機變數的情況下,隨機變數X的不確定性”。從公式中可以看出,若X和Y是獨立的,此時H(X)=H(X|Y),得到I(X;Y)=0,為最小值。若X和Y有非常強的關聯時,即已知Y時,X沒有不確定性,則H(X|Y)=0 ,I(X;Y)達到最大值。所以為了讓G(z,c)和c之間產生儘量明確的語義資訊,必須要讓它們二者的互資訊足夠的大,所以我們對GAN的損失函式新增一個正則項,就可以改寫為:
注意屬於G的損失函式的一部分,所以這裡為負號,即讓該項越大越好,使得G的損失函式變小。其中為平衡兩個損失函式的權重。但是,在計算的過程中,需要知道後驗機率分佈,而這個分佈在實際中是很難獲取的,因此作者在解決這個問題時採用了變分推理的思想,引入變分分佈來逼近,進而透過輪流迭代的方法用去逼近的下界,得到最終的網路損失函式:
(二)網路結構
從上圖可以清晰的看出,雖然在設計InfoGAN時的數學推導比較複雜,但是網路架構還是非常簡單明瞭的。G和D的網路結構和DCGAN保持一致,均由CNN構成。在此基礎上,改動的地方主要有:
1.G的輸入不僅僅是噪聲向量z了,而是z和具有語意資訊的淺變數c進行拼接後的向量輸入給G。
2.D的輸出在原先的基礎上新增了一個新的輸出分支Q,Q和D共享全部分卷積層,然後各自透過不同的全連線層輸出不同的內容:Q的輸出對應於的c的機率分佈,D則仍然判別真偽。
(三) InfoGAN的訓練流程
假設batch_size=m,資料集為MNIST,則根據作者的方法,不可壓縮噪聲向量的長度為62,離散潛變數的個數為1,取值範圍為[0, 9],代表0-9共10個數字,連續淺變數的個數為2,代表了生成數字的傾斜程度和筆劃粗細,最好服從[-2, 2]上的均勻分佈,因為這樣能夠顯式的透過改變其在[-2,2]上的數值觀察到生成資料相應的變化,便於實驗,所以此時輸入變數的長度為62+10+2=74。
則在每一個epoch中:
先訓練判別器k(比如3)次:
1. 從噪聲分佈(比如高斯分佈)中隨機取樣出m個噪聲向量:。
2.從真實樣本x中隨機取樣出m個樣本:
3. 用梯度下降法使損失函式real_loss:與1之間的二分類交叉熵減小(因為最後判別器最後一層的啟用函式為sigmoid,所以要與0或者1做二分類交叉熵,這也是為什麼損失函式要取log的原因)。
4.用梯度下降法使損失函式fake_loss:與0之間的二分類交叉熵減小。
5. 所以判別器的總損失函式d_loss:即讓d_loss減小。注意在訓練判別器的時候分類器中的所有引數要固定住,即不參加訓練。
再訓練生成器1次:
1. 從噪聲分佈中隨機取樣出m個噪聲向量:。
2. 從離散隨機分佈中隨機取樣m個長度為10、one-hot編碼格式的向量:。
3. 從兩個連續隨機分佈中各隨機取樣m個長度為1的向量:
,
4. 將上面的所有向量進行concat操作,得到長度為74的向量,共m個,並記錄每個向量所在的位置,便於計算損失函式。
5. 此時g_loss由三部分組成:一個是與1之間的二分類交叉熵、一個是Q分支輸出的離散淺變數的預測值和相應的輸入部分的交叉熵以及Q分支輸出的連續淺變數的預測值和輸入部分的互資訊,併為這三部分乘上適當的平衡因子,其中互資訊項的係數是負的。
6. 用梯度下降法使越小越好。注意在訓生成器的時候判別器中的所有引數要固定住,即不參加訓練。
直到所有epoch執行完畢,訓練結束。
(四)總結
1.G的輸入不再是一個單一的噪聲向量,而是噪聲向量與潛變數的拼接。
2.對於潛變數來說,G和D組成的大網路就好比是一個AutoEncoder,不同之處只是將資訊編碼在了影像中,而非向量,最後透過D解碼還原回。
3.D的輸出由原先的單一分支變為兩個不同的分支。
4.從資訊熵的角度對噪聲向量和潛變數的關係完成建模,並透過數學推導以及實驗的方式證明了該方法確實有效。
5.透過潛變數,使得G生成的資料具有一定的可解釋性。
四、WGAN(Wasserstein GAN)
(一) Wassersteindistance
從前面的章節我們知道,DCGAN的損失函式本質上是讓與間的JS散度儘可能的小,但是很有可能出現與兩個分佈根本就沒有重疊的地方,對於任意兩個沒有交疊、距離足夠遠的分佈,它們之間的JS散度恆定為log2,導致梯度消失,此時不可能在訓練的過程中向的方向移動,D也就得不到訓練。而WGAN就著手於從損失函式上進行最佳化,使得訓練更加穩定。
WGAN的作者用大量的數學推導來證明了基於二分類交叉熵的損失函式的缺陷與不合理性,並提出了一種新的損失函式,取名為Wasserstein distance,這個損失函式在任何位置都有著相對平滑的梯度,由於篇幅所限,我儘量直觀的向大家闡述,我們先來看一下網路結構。
(二)網路結構
乍一看怎麼和DCGAN相差無幾呢,是的,作者在網路結構上的變動僅僅是去掉了DCGAN中D最後一層的sigmoid啟用函式,使得網路最後一層的輸出變成線性的了。
(三) 原理闡述
究竟什麼是Wasserstein distance呢?Wasserstein distance用來衡量來兩個分佈間的距離,而且即使兩個分佈間沒有交疊,也會根據分佈相距的遠近程度給出一個相應的數值,即損失函式的值會隨著兩個分佈間的距離的遠近程度而動態的發生改變,在這篇論文中,作者初步給出Wasserstein distance的表示式:
從直觀上理解,損失函式是兩個期望的差值,並讓這個差值儘可能的大,即使儘可能的大,同時使儘可能的小,但僅從這個表示式是不足以讓訓練變得收斂的,我們來看下圖:
雖然能夠準確對生成樣本和真實樣本完美的區分,但是在沒有了sigmoid函式限制值域的情況下會讓D在真實樣本上的輸出值趨於無窮大,而在生成樣本上的輸出值趨於無窮小,導致永遠不會收斂,為了避免出現這個問題,作者在損失函式中新增了一個額外的限制條件:
限制是指:在樣本空間中,要求判別器函式D(x)梯度值不大於一個有限的常數k,透過權重限制的方式保證了權重引數的有界性,間接限制了其梯度資訊。
目的就是讓D的輸出曲線儘可能的平滑,不讓它趨向與無窮大或者無窮小,那麼怎麼限制呢?在作者2017年釋出的WGAN中只是對D的權重進行簡單的clipping操作:
人為的規定一個閾值c,並將D中的網路引數數值全部限制在上[-c,c],對於D中的任意一個引數w,如果 w>c, 則令w=c。如果w<-c,則令w=-c,即始終保持,該操作稱為weight clipping,使得D的輸出曲線比較平滑。是的,就是這麼簡單!實驗證明該演算法雖然簡單粗暴,但確實使得訓練過程變得更加穩定。另一方面,c的取值範圍很難確定,是一個依賴於經驗的數值。如果取的過小,網路引數都被限制在了一個比較小的範圍,導致D的擬合能力受限。如果取的過大,又可能會讓D的輸出值趨近於無窮,網路又無法收斂,所以它的取值極度依賴實驗,不過一般地,將c取為0.01是一個比較個合理的值。綜上,WGAN改動的地方主要有以下三點:
1.D最後一層去掉sigmoid啟用函式,所以它現在的輸出值不再代表二分類的機率了。
2.G和D的loss不再取log,即不再用與0或者1的二分類交叉熵作為損失函式了。
3.每次更新D的引數後,將其所有引數的絕對值截斷到不超過一個固定常數c(經驗數值,可以取為0.01),即weight clipping操作,其實本質上就是對D的引數新增了一個簡單粗暴的正則項。
(四) WGAN的訓練流程
假設batch_size=m,則在每一個epoch中:
先訓練判別器k(比如5)次:
1. 從噪聲分佈(比如高斯分佈)中隨機取樣出m個噪聲向量:.
2. 從真實樣本x中隨機取樣出m個真實樣本:
3. 用梯度下降法使損失函式越小越好(取負號的原因是一般的深度學習框架只能讓損失函式越來越小,所以這裡加個負號就和原先最大化的邏輯保持一致了)。
4. 用梯度下降法使損失函式越小越好,並儲存生成的假樣本的結果,記。
5. 所以判別器的總損失函式,即讓d_loss越小越好。注意在訓練判別器的時候分類器中的所有引數要固定住,即不參加訓練。
6. 檢查D中所有可訓練引數的值,將它們限制在一個人為規定的常數|c|內,即,令(可以將c取為0.01)。
再訓練生成器1次:
1. 從噪聲分佈中隨機取樣出m個噪聲向量:用梯度下降法使損失函式越小越好。注意在訓生成器的時候判別器中的所有引數要固定住,即不參加訓練。
直到所有epoch執行完畢,訓練結束。
(五) 總結
1.修改一直在沿用的原始GAN損失函式,提出一種新的損失函式,使得GAN的訓練變得比以前穩定。
2.提出針對判別器的weight clipping操作,並經過大量實驗證明確實能夠讓訓練變得穩定、加快模型收斂,而且程式碼實現上也非常簡單,對DCGAN程式碼的改動不超過20行就能讓它變成WGAN。
3.模型是否能夠收斂高度依賴於超引數c的取值,而該引數的選取通常依賴於實驗。如果選取得當,能夠提高網路訓練的穩定性,如果選取不當,模型反而無法收斂。
五、WGAN_GP(WassersteinGAN with Gradient Penality)
(一) gradient penality
在提出了WGAN後,作者繼續在WGAN上進行最佳化,又給出了一種新的損失函式,拋棄weight clipping,也就不再需要經驗常數c了,取而代之的是gradient penality(梯度懲罰),因此取名為WGAN_GP,也叫Improved_WGAN。
Gradient penality是指:對D的每一個輸入樣本x,使得。意思是對於任意一個輸入樣本x,用D的輸出結果D(x)對求梯度後的值的L2範數不大於1。
上面的解釋可能有些拗口,我們從一維空間中的函式f(x)來進行闡述:一維函式f(x),對於任意輸入x,該函式滿足:,即任意一點的斜率的平方不大於1,進而可以推出:,可想而知f(x)的函式曲線是比較平滑的,所以稱為梯度懲罰。那為什麼是L2範數呢而不是L1範數呢?原因很簡單,L1範數會破壞一個函式的可微性呀,所以L2範數是非常合理的!
注意前面我說的是針對於D的每一個輸入樣本,都讓它滿足,實際上這是不現實的,所以作者又想了一個辦法來解決這個問題:假設從真實資料中取樣出來的一個點稱為x(這個點是高維空間中的點),G利用取樣得到的噪聲向量所生成的假資料稱為在這兩點之間的某一個位置取樣一個點記為,即對於每一個,儘量讓。那麼最常見的滿足上述要求的取樣方法就是線性取樣方法了,即在x與所形成的超平面上任意選取一個點,換句話說就是在生成樣本和真實樣本間做一個線性插值,所以存在。
在新的損失函式閃亮登場之前,我們還有一個小小的最佳化!因為作者最後發現,其實讓是最好的方案,而不是把1作為上、下限,別問我為什麼,作者也不知道!因為是透過大量的實驗總結出來的。
那麼WGAN_GP的核心就線上性插值這了,為了不讓這部分變得太抽象,我們用pytorch來實現一下插值這部分。
所以,新的損失函式可以寫為:
說了這麼多,其實用數學公式表達出來還是非常簡單的,式子中前兩項仍然是WGAN的損失函式,只是新新增了一個正則項,便是在真實資料和生成資料之間透過線性插值得到的點,即儘量讓D對它的梯度的L2範數越接近於1,使越大越好,透過該正則項,能讓損失函式上的每一點都有較為平滑的梯度,訓練也就更加穩定,大大降低了訓練GAN的難度。超引數用於對這兩部分的損失函式進行平衡,作者透過實驗發現=10是一個比較合理的數值。
(二) WGAN_GP的訓練流程
假設batch_size=m,則在每一個epoch中:
先訓練判別器k(比如5)次:
1. 從噪聲分佈(比如高斯分佈)中隨機取樣出m個噪聲向量:。
2. 從真實樣本x中隨機取樣出m個真實樣本:。
3. 用梯度下降法使損失函式越小越好(取負號的原因是一般的深度學習框架只能讓損失函式越來越小,所以這裡加個負號就和原先最大化的邏輯保持一致了)。
4. 用梯度下降法使損失函式越小越好,並儲存生成的假樣本的結果,記為。
5.在這m個假樣本與已經得到的m個真實樣本進行線性插值,得到m個插值樣本:。將m個插值樣本送入D中得到的結果對輸入求梯度,使越小越好。
6.所以判別器的總損失函式d_loss =read_loss + fake_loss + gp,即讓d_loss越小越好。注意在訓練判別器的時候分類器中的所有引數要固定住,即不參加訓練。為平衡兩個損失函式的權重,取為10是比較合理的數值。
再訓練生成器1次:
從噪聲分佈中隨機取樣出m個噪聲向量:用梯度下降法使損失函式g_loss:越小越好。注意在訓練生成器的時候判別器中的所有引數要固定住,即不參加訓練。
直到所有epoch執行完畢,訓練結束。
(三) WGAN_GP小試牛刀
在寫這篇文章的時候,正好看到TinyMind舉辦了一個關於用GAN生成書法字型的比賽https://www.tinymind.cn/competitions/45 - ranking,當時距離比賽結束僅剩三天時間,但是為了讓文章更充實一些,還是馬不停蹄的把資料集下載到本地,不說了,GAN就完了!
比賽目的是用GAN來生成圖片大小為128*128的書法字型圖片,評判標準是上傳10000張自己生成的書法字進行系統評分,當然質量、多樣性越高越好。訓練集中共有100種字,每種字又有400張不同的字型圖片,所以一共是40000張圖片,每張圖片的高、寬都在200到400之間,並且為灰度影像,那麼我們就來用WGAN_GP來完成這個小比賽!,參考開原始碼地址:https://github.com/igul222/improved_wgan_training,實現框架為tensorflow。
先來看看資料集長什麼樣吧。
這裡我將每種字隨機抽出1個並resize到64*64進行排列展示,所以正好100個不同的字,發現有一些根本不認識!不過認不認識沒關係,對於網路來說它需要的僅僅是資料而已。另外一點就是這裡面有一些髒資料,比如大字下面還有一些小字,這肯定不是我們期望的樣本,但是我在這裡並沒有過濾掉這些髒資料,一是工作量太大,不能自動完成,需要人工檢查。二是先嚐試著訓練一下,不行的話再想辦法剔除,事實證明對結果影響不大。
原repo的程式碼只能生成64*64的圖片,所以需要對其網路結構進行相應的改進,使其能夠產生128*128的圖片,改進的方案也非常簡單:
1)將G的第一個全連線層的輸出神經元個數擴大為原先的兩倍,所以這時reshape後tensor深度變為原先的兩倍,此後卷積核的個數每層都除以2。
2)將生成器最後一層的啟用函式改為relu,接一個batch normalization,並在其後面再新增一個deconv層,啟用函式為tanh。
3)將判別器的最後一層的全連線層改為卷積層,接一個batch normalization,啟用函式為leaky relu,並重復一次,即再降取樣一次,reshape後再接一個單神經元的全連線層就可以了,注意沒有啟用函式。
4)因為是新的資料了,所以資料讀取以及組織資料的程式碼需要自己寫。損失函式、訓練程式碼不用動。可能需要在實驗中對學習率進行調整。
在訓練了40個epoch後,我把每個epoch生成器生成的100張圖片存下來並縮小做成動態圖:
可以看出生成的資料已經趨於穩定,變動不大。由於時間有限再加上工作繁忙,沒有足夠的時間對網路進行最佳化,排名沒進前10,因為前10名才有獎勵呀,重在參與嘛!將10000張生成的圖片上傳後,官網展示了部分圖片:
個人感覺效果一般吧,不過在參加了這個小比賽後讓我學到了很多知識,也認識到了自身的不足。
直接把資料的標籤資訊仍掉了,所有資料同等對待一起訓的,導致最終資料的多樣性可能不夠高,拉低評分。
既然是比賽,採取合理的小技巧來達到更高的評分也是可以的。我們知道越大的圖片越不好生成,而64*64的圖片相對來說比較容易生成,也易於訓練。可以只生成64*64的圖片,提交成績的時候再透過一些好的插值方法(比如雙三次插值)resize到128*128!賽後我知道確實有人是這樣做的。不過這種方法所生成的書法字肯定沒有直接生成128*128的圖片質量高。
其實有很多比WGAN_GP更先進、生成效果更好的網路,畢竟這篇文章的發表時間是在2017年,但是新手上路嘛,以穩為主,就選擇了一個比較經典的模型。
為以後的手寫字識別提供了不少思路,透過GAN來增加訓練資料量是非常可行的方法。
(四) 總結
1.為WGAN的損失函式提出了一種新的正則方法——gradient penality,從而更好的解決了訓練GAN的過程中梯度消失的問題。
2.比標準WGAN擁有更快的收斂速度,並能生成更高質量的樣本。
3.將resnet中的殘差塊成功應用於生成器和判別器中,使網路可以變得更深、同時能夠生成質量更高的樣本,並且訓練過程也更加穩定。
4.不需要過多的調參,成功訓練多種針對圖片的GAN結構。
六 、總結
1.本文沿著GANDCGANInfoGANWGANWGAN_GP的路線來介紹GAN,其初衷是能讓大家對GAN有一個感性的瞭解,所以大量的數學公式推導沒有列出來。當然,還有很多優秀的GAN本文沒有涉及到,畢竟以入門為主嘛!相信在讀完本文後能夠讓大家更好的理解當下比較新穎並且有意思的GAN。
2.其實GAN在最終的實現上都非常簡單,比較難的地方是涉及模型損失函式的最佳化以及相關數學推導、還有就是在現有網路上的創新,從而提出一個新穎並且生成質量高的GAN模型。
3.雖然GAN在影像生成上取得了耀眼的成績,但並沒有在NLP領域取得顯著成果。其中一個主要原因是影像資料都是實數空間上的連續資料,而NLP中大多都是離散資料,例如分詞後的片語。而對於連續型資料,就可以略微改變合成的資料,比如一個浮點型別的畫素值為0.64,將這個值改為0.65是沒有問題的。但是對於離散型資料,如果輸出了一個單詞”hello”,但接下來不能將其改為”hello+0.01”,因為根本沒有這個單詞!所以NLP中應用GAN是比較困難的。但並不代表沒有人研究這個方向,有一些學者已經能夠將GAN應用於NLP中了,大多數要與強化學習結合,感興趣的小夥伴可以讀一讀TextGAN、SeqGAN這兩篇文章。
4.由於平時對GAN的接觸比較少,再加上專業水平有限,文章中出錯之處在所難免,還望多多包涵。
七、參考文獻
[1]IanJ. Goodfellow, Jean Pouget-Abadie and Mehdi Mirza, “Gererative AdversarialNetworks,” ArXiv preprint arXiv:1406.2661, 2014.
[2]AlecRadford, Luke Metz and Soumith Chintala, “Unsupervised Representation Learningwith Deep Convolutional Generative Adversarial Networks,” ArXiv preprintaxXiv:1511.06434, 2016.
[3]Xi Chen, Yan Duan and Rein Houthooft, “InfoGAN:Interpretable Representation Learning by Information Maximizing GenerativeAdversarial Nets,” ArXiv Preprint arXiv:1606.03657, 2016.
[4]Martin Arjovsky, Soumith Chintala and Léon Bottou, “WassersteinGAN,” ArXiv preprint arXiv:1606.03657, 2016.
[5]Ishaan Gulrajani, Faruk Ahmed and Martin Arjovsky, “ImprovedTraining of Wasserstein GANs”, ArXiv preprint arXiv:1704.00028, 2017.
關於作者
馬振宇:達觀資料演算法工程師,負責達觀資料OCR方向的相關演算法研發,最佳化工作。