【深度學習理論】通俗理解生成對抗網路GAN

AI科技大本營發表於2018-03-02

640?wx_fmt=jpeg&wxfrom=5&wx_lazy=1


作者 | 陳誠

來源 | 機器學習演算法與自然語言處理


1. 引言


自2014年Ian Goodfellow提出了GAN(Generative Adversarial Network)以來,對GAN的研究可謂如火如荼。各種GAN的變體不斷湧現,下圖是GAN相關論文的發表情況:

640?wxfrom=5&wx_lazy=1

大牛Yann LeCun甚至評價GAN為 “adversarial training is the coolest thing since sliced bread”。


那麼到底什麼是GAN呢?它又好在哪裡?下面我們開始進行介紹。


2. GAN的基本思想


GAN全稱對抗生成網路,顧名思義是生成模型的一種,而他的訓練則是處於一種對抗博弈狀態中的。下面舉例來解釋一下GAN的基本思想。


640?


假如你是一名籃球運動員,你想在下次比賽中得到上場機會。


於是在每一次訓練賽之後你跟教練進行溝通:


你:教練,我想打球 

教練:(評估你的訓練賽表現之後)... 算了吧 (你通過跟其他人比較,發現自己的運球很差,於是你苦練了一段時間) 


你:教練,我想打球 

教練:... 嗯 還不行 (你發現大家投籃都很準,於是你苦練了一段時間的投籃) 


你:教練,我想打球 

教練: ... 嗯 還有所欠缺 

(你發現你的身體不夠壯,被人一碰就倒,於是你去泡健身房)


...... 


通過這樣不斷的努力和被拒絕,你最終在某一次訓練賽之後得到教練的讚賞,獲得了上場的機會。


值得一提的是在這個過程中,所有的候選球員都在不斷地進步和提升。因而教練也要不斷地通過對比場上球員和候補球員來學習分辨哪些球員是真正可以上場的,並且要“觀察”得比球員更頻繁。隨著大家的成長教練也會會變得越來越嚴格。


現在大家對於GAN的思想應該有了感性的認識了,下面開始進一步窺探GAN的結構和思想。


3. GAN淺析


3.1 GAN的基本結構


GAN的主要結構包括一個生成器G(Generator)和一個判別器D(Discriminator)。


在上面的例子中的球員就相當於生成器,我們需要他在球場上能有好的表現。而球員一開始都是初學者,這個時候就需要一個教練員來指導他們訓練,告訴他們訓練得怎麼樣,直到真的能夠達到上場的標準。而這個教練就相當於判別器。


下面我們舉另外一個手寫字的例子來進行進一步窺探GAN的結構。


640?


我們現在擁有大量的手寫數字的資料集,我們希望通過GAN生成一些能夠以假亂真的手寫字圖片。主要由如下兩個部分組成:


  • 定義一個模型來作為生成器(圖三中藍色部分Generator),能夠輸出一個向量,輸出手寫數字大小的畫素影像。

  • 定義一個分類器來作為判別器(圖三中紅色部分Discriminator)用來判別圖片是真的還是假的(或者說是來自資料集中的還是生成器中生成的),輸入為手寫圖片,輸出為判別圖片的標籤。


4. GAN的訓練方式


前面已經定義了一個生成器(Generator)來生成手寫數字,一個判別器(Discrimnator)來判別手寫數字是否是真實的,和一些真實的手寫數字資料集。那麼我們怎樣來進行訓練呢?


4.1 關於生成器


對於生成器,輸入需要一個n維度向量,輸出為圖片畫素大小的圖片。因而首先我們需要得到輸入的向量。


Tips: 這裡的生成器可以是任意可以輸出圖片的模型,比如最簡單的全連線神經網路,又或者是反摺積網路等。這裡大家明白就好。


這裡輸入的向量我們將其視為攜帶輸出的某些資訊,比如說手寫數字為數字幾,手寫的潦草程度等等。由於這裡我們對於輸出數字的具體資訊不做要求,只要求其能夠最大程度與真實手寫數字相似(能騙過判別器)即可。所以我們使用隨機生成的向量來作為輸入即可,這裡面的隨機輸入最好是滿足常見的比如均值分佈,高斯分佈等。


Tips: 假如我們後面需要獲得具體的輸出數字等資訊的時候,我們可以對輸入向量產生的輸出進行分析,獲取到哪些維度是用於控制數字編號等資訊的即可以得到具體的輸出。而在訓練之前往往不會去規定它。


4.2 關於判別器


對於判別器不用多說,往往是常見的判別器,輸入為圖片,輸出為圖片的真偽便籤。


Tips: 同理,判別器與生成器一樣,可以是任意的判別器模型,比如全連線,或者是包含卷積的神經網路。


4.3 如何訓練


上面進一步說明了生成器和判別器,接下來說明如何進行訓練。


先用基本流程如下:


640?

  • 迴圈k次更新判別器之後,使用較小的學習率來更新一次生成器的引數,訓練生成器使其儘可能能夠減小生成樣本與真實樣本之間的差距,也相當於儘量使得判別器判別錯誤。

  • 多次更新迭代之後,最終理想情況是使得判別器判別不出樣本來自於生成器的輸出還是真實的輸出。亦即最終樣本判別概率均為0.5。


Tips: 之所以要訓練k次判別器,再訓練生成器,是因為要先擁有一個好的判別器,使得能夠教好地區分出真實樣本和生成樣本之後,才好更為準確地對生成器進行更新。


更直觀的理解可以參考下圖:


640?


注:圖中的黑色虛線表示真實的樣本的分佈情況,藍色虛線表示判別器判別概率的分佈情況,綠色實線表示生成樣本的分佈。 Z 表示噪聲,Z 到X表示通過生成器之後的分佈的對映情況。

我們的目標是使用生成樣本分佈(綠色實線)去擬合真實的樣本分佈(黑色虛線),來達到生成以假亂真樣本的目的。

可以看到在(a)狀態處於最初始的狀態的時候,生成器分佈和真實分佈區別較大,並且判別器判別出樣本的概率不是很穩定,因此會先訓練判別器來更好地分辨樣本。


通過多次訓練判別器來達到(b)樣本狀態,此時判別樣本區分得非常顯著和良好。然後再對生成器進行訓練。


訓練生成器之後達到(c)樣本狀態,此時生成器分佈相比之前,逼近了真實樣本分佈。
經過多次反覆訓練迭代之後,最終希望能夠達到(d)狀態,生成樣本分佈擬合於真實樣本分佈,並且判別器分辨不出樣本是生成的還是真實的(判別概率均為0.5)。也就是說我們這個時候就可以生成出非常真實的樣本啦,目的達到。


5. 訓練相關理論基礎


前面用了大白話來說明了訓練的大致流程,下面會從交叉熵開始說起,一步步說明損失函式的相關理論,尤其是論文中包含min,max的公式如下圖5形式:


640?判別器在這裡是一種分類器,用於區分樣本的真偽,因此我們常常使用交叉熵(cross entropy)來進行判別分佈的相似性,交叉熵公式如下圖6所示:


640?

Tips: 公式中pi 和qi為真實的樣本分佈和生成器的生成分佈。由於交叉熵是非常常見的損失函式,這裡預設大家都較為熟悉,就不進行贅述了。


在當前模型的情況下,判別器為一個二分類問題,因此可以對基本交叉熵進行更具體地展開如下圖7所示:


640?


Tips: 其中,假定 y1為正確樣本分佈,那麼對應的( 1-y1 )就是生成樣本的分佈。 D表示判別器,則 D(x1) 表示判別樣本為正確的概率,(1-D(x1))則對應著判別為錯誤樣本的概率。這裡僅僅是對當前情況下的交叉熵損失的具體化。相信大家也還是比較熟悉。


將上式推廣到N個樣本後,將N個樣本相加得到對應的公式如下:


640?


將上式推廣到N個樣本後,將N個樣本相加得到對應的公式如下:


640?


OK,到目前為止還是基本的二分類,下面加入GAN中特殊的地方。


對於GAN中的樣本點 xi ,對應於兩個出處,要麼來自於真實樣本,要麼來自於生成器生成的樣本  ~ G(z) ( 這裡的z是服從於投到生成器中噪聲的分佈)。

其中,對於來自於真實的樣本,我們要判別為正確的分佈 yi 。來自於生成的樣本我們要判別其為錯誤分佈(1-yi)。


將上面式子進一步使用概率分佈的期望形式寫出(為了表達無限的樣本情況,相當於無限樣本求和情況),並且讓 yi為 1/2 且使用 G(z) 表示生成樣本可以得到如下圖8的公式:


640?


640?


6. 總結


本文大致介紹了GAN的整體情況。但是對於GAN實際上還有更多更完善的理論相關描述,進一步瞭解可以看相關的論文。


並且在GAN一開始提出來的時候,實際上針對於不同的情況也有存在著一些不足,後面也陸續提出了不同的GAN的變體來完善GAN。


通過一個判別器而不是直接使用損失函式來進行逼近,更能夠自頂向下地把握全域性的資訊。比如在圖片中,雖然都是相差幾畫素點,但是這個畫素點的位置如果在不同地方,那麼他們之間的差別可能就非常之大。


640?


比如上圖10中的兩組生成樣本,對應的目標為字型2,但是圖中上面的兩個樣本雖然只相差一個畫素點,但是這個畫素點對於全域性的影響是比較大的,但是單純地去使用使用損失函式來判斷都是相差一個畫素點,而下面的兩個雖然相差了六個畫素點的差距(粉色部分的畫素點為誤差),但是實際上對於整體的判斷來說,是沒有太大影響的。


但是使用損失函式的話,卻會得到6個畫素點的差距,比上面的兩幅圖差別更大。而如果使用判別器,則可以更好地判別出這種情況(不會拘束於具體畫素的差距)。


總之GAN是一個非常有意思的東西,現在也有很多相關的利用GAN的應用,比如利用GAN來生成人物頭像,用GAN來進行文字的圖片說明等等。後面我也會使用GAN來做一些簡單的實驗來幫助進一步理解GAN。


最後附上論文中的GAN演算法流程,通過上面的介紹,這裡應該非常好理解了。


640?


招聘

新一年,AI科技大本營的目標更加明確,有更多的想法需要落地,不過目前對於營長來說是“現實跟不上靈魂的腳步”,因為缺人~~


所以,AI科技大本營要壯大隊伍了,現招聘AI記者和資深編譯,有意者請將簡歷投至:gulei@csdn.net,期待你的加入!


如果你暫時不能加入營長的隊伍,也歡迎與營長分享你的精彩文章,投稿郵箱:suiling@csdn.net


AI科技大本營讀者群(計算機視覺、機器學習、深度學習、NLP、Python、AI硬體、AI+金融方向)正在招募中,後臺回覆:讀者群,聯絡營長,新增營長請備註姓名,研究方向。


640?wx_fmt=gif


640?wx_fmt=png

640?wx_fmt=png

640?wx_fmt=png


☟☟☟點選 | 閱讀原文 | 檢視更多精彩內容

相關文章