訓練生成對抗網路的一些技巧和陷阱
作者:Utkarsh Desai
編譯:ronghuaiyang
導讀
生成對抗網路是個好東西,不過訓練比較麻煩,這裡有一些技巧和陷阱,分享給大家。
生成對抗網路(GANs)是當前深度學習研究的熱點之一。過去一段時間,GANs上發表的論文數量有了巨大的增長。GANs已經應用於各種各樣的問題。
我讀了很多關於GANs的書,但是我自己從來沒有玩過。因此,在閱讀了一些論文和github repos之後,我決定親自動手訓練一個簡單的GAN,但很快就遇到了問題。
本文的目標讀者是剛開始學習GANs的深度學習愛好者。除非你非常幸運,否則第一次獨自訓練GAN可能是一個令人沮喪的過程,可能需要花費數小時才能正確。當然,隨著時間的推移,隨著經驗的積累,你會很好地訓練GANs,但是對於初學者來說,可能會出現一些錯誤,你甚至不知道從哪裡開始除錯。我想分享我的觀察和經驗教訓,從零開始訓練GANs,希望它可能會節省一些人開始除錯幾個小時的時間。
生成對抗網路
除非你已經斷網一年左右了,否則所有參與深度學習的人——甚至一些沒有參與深度學習的人——都聽說過並談論過GANs。GANs是一種深度神經網路,是資料的生成模型。這意味著,給定一組訓練資料,GANs可以學會估計資料的潛在機率分佈。這是非常有用的,因為除了其他事情,我們現在可以從學習到的機率分佈中生成樣本,這些樣本可能不會出現在原始訓練集中。
該領域的專家已經提供了一些很棒的資源來解釋GANs及其工作原理,所以我不會試圖複製他們的工作。但是為了完整起見,這裡有一個快速的概述。
生成對抗網路實際上是兩個相互競爭的深層網路。給定一個訓練集
X(比如說幾千張貓的圖片),生成網路
G(X),使用一個隨機向量作為輸入,並試圖產生與訓練集類似的影像。鑑別器網路,
D(X),是一個二元分類器,試圖區分真正的訓練集
X中的貓圖片和生成器生成的假貓圖片。因此,生成器網路的工作就是學習資料在
X中的分佈情況,從而生成真實的貓影像,並確保識別器不能區分訓練集中的貓影像和生成器生成的貓影像。鑑別器需要學習跟上生成器的步伐,因為生成器一直嘗試新的技巧來生成假的貓的影像並欺騙鑑別器。
最終,如果一切順利,生成器(或多或少)就會學習訓練資料的真實分佈,並變得非常擅長生成真實的貓影像。識別器不再能夠區分訓練集的貓影像和生成的貓影像。
從這個意義上說,這兩個網路不斷地試圖確保另一個不能很好地完成他們的任務。那麼,這到底是怎麼回事呢?
另一種觀察GAN設定的方法是,鑑別器試圖透過告訴生成程式真實的貓影像是什麼樣子來引導生成器。最終,這臺機器發現了這一點,並開始生成真實的貓的影像。GANs的訓練方法類似於博弈論中的極大極小演算法,這兩個網路試圖達到所謂的納什均衡。
GAN訓練中的挑戰
回到GANs的訓練。首先,我使用Keras和Tensorflow後端,在MNIST資料集上訓練了一個GAN(準確地說,是DC-GAN)。這並不難,經過對生成器和鑑別器網路的一些小的調整,GAN能夠生成MNIST數字的清晰影像。
黑色和白色的數字沒那麼有趣。物體和人的彩色影像是所有酷傢伙玩的東西。這就是事情開始變得棘手的地方。MNIST之後,顯然下一步是生成CIFAR-10影像。在日復一日地調整超引數、更改網路架構、新增和刪除層之後,我終於能夠生成類似CIFAR-10的外觀不錯的影像。
我從一個相當深的網路開始,最終得到了一個實際有效的、簡單得多的網路。當我開始調整網路和訓練過程時,15個epochs後生成的影像從現在的樣子:
到這樣:
最後是這樣:
下面是我認識到自己犯過的一些錯誤,以及我從中學到的一些東西。所以,如果你是GANs的新手,並沒有看到在訓練方面取得很大的成功,也許看看以下幾個方面會有所幫助:
1. 大卷積核和更多的濾波器
更大的卷積核覆蓋了前一層影像中的更多畫素,因此可以檢視更多資訊。5x5的核與CIFAR-10配合良好,在鑑別器中使用3x3核使鑑別器損耗迅速趨近於0。對於生成器,你希望在頂層的卷積層有更大的核,以保持某種平滑。在較低的層,我沒有看到改變核心大小的任何主要影響。
濾波器的數量可以大量增加引數的數量,但通常需要更多的濾波器。我在幾乎所有的卷積層中都使用了128個濾波器。使用較少的濾波器,特別是在生成器中,使得最終生成的影像過於模糊。因此,看起來更多的濾波器可以幫助捕獲額外的資訊,最終為生成的影像增加清晰度。
2. 標籤翻轉(Generated=True, Real=False)
雖然一開始看起來很傻,但對我有用的一個主要技巧是更改標籤分配。
如果你使用的是Real Images = 1,而生成的影像= 0,則使用另一種方法會有所幫助。正如我們將在後面看到的,這有助於在早期迭代中使用梯度流,並幫助使梯度流動。
3. 使用有噪聲的標籤和軟標籤
這在訓練鑑別器時是非常重要的。硬標籤(1或0)幾乎扼殺了早期的所有學習,導致識別器非常快地接近0損失。最後,我使用0到0.1之間的隨機數表示0標籤(真實影像),使用0.9到1.0之間的隨機數表示1標籤(生成的影像)。在訓練生成器時不需要這樣做。
此外,增加一些噪音的訓練標籤也是有幫助的。對於輸入識別器的5%的影像,標籤被隨機翻轉。比如真實的被標記為生成的,生成的被標記為真實的。
4. 使用批歸一化是有用的,但是需要有其他的東西也是合適的
批歸一化無疑有助於最終的結果。新增批歸一化後,生成的影像明顯更清晰。但是,如果你錯誤地設定了卷積核或濾波器,或者識別器的損失很快達到0,新增批歸一化可能並不能真正幫助恢復。
5. 每次一個類別
為了更容易地訓練GANs,確保輸入資料具有相似的特徵是很有用的。例如,與其在CIFAR-10的所有10個類中都訓練GAN,不如選擇一個類(例如,汽車或青蛙)並訓練GANs從該類生成影像。DC-GAN的其他變體在學習生成多個類的影像方面做得更好。例如,以類標籤為輸入,生成基於類標籤的影像。但是,如果你從一個普通的DC-GAN開始,最好保持事情簡單。
6. 檢視梯度
如果可能的話,試著監控梯度以及網路中的損失。這些可以幫助你更好地瞭解訓練的進展,甚至可以幫助你在工作不順利的情況下進行除錯。
理想情況下,生成器應該在訓練的早期獲得較大的梯度,因為它需要學習如何生成真實的資料。另一方面,鑑別器並不總是在早期獲得較大的梯度,因為它可以很容易地區分真假影像。一旦生成器得到足夠的訓練,鑑別器就很難分辨真假影像。它會不斷出錯,並得到大的梯度。
我在CIFAR-10汽車上的最初幾個GAN版本,有許多卷積和批次規範層,沒有標籤翻轉。除了這個趨勢之外,監測梯度的規模也很重要。如果生成器層上的梯度太小,學習可能會很慢,或者根本不會發生。這在GAN的這個版本中是可見的。
在生成器的最下層梯度的規模太小,任何學習都無法進行。鑑別器的梯度始終是一致的,這表明鑑別器並沒有真正學到任何東西。現在,讓我們將其與GAN的梯度進行比較,GAN具有上面描述的所有變化,並生成良好的真實影像:
梯度到達生成器底層的比例明顯高於前一個版本。此外,隨著訓練的進展,梯度流與預期一樣,隨著發生器在早期獲得較大的梯度,一旦訓練足夠,鑑別器在頂層獲得一致的高梯度。
7. 不要提前停止
我犯了一個愚蠢的錯誤——可能是由於我的不耐煩——當我看到損失沒有任何明顯的進展,或者生成的樣本仍然有噪聲時,在進行了幾百次小批次訓練之後,我就終止了訓練。比起等到訓練結束後才意識到網路什麼都沒學到,重啟工作並節省時間是很誘人的。GANs的訓練時間較長,初始損失和生成的樣本值很少,幾乎從未顯示出任何趨勢或進展的跡象。在結束訓練過程並調整設定之前,等待一段時間是很重要的。
這個規則的一個例外是,如果你看到鑑別器損失迅速接近0。如果發生這種情況,幾乎沒有恢復的機會,最好重新開始訓練,最好對網路或訓練過程做一些修改。
最後的GAN是這樣工作的:
英文原文:
來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/29829936/viewspace-2662488/,如需轉載,請註明出處,否則將追究法律責任。
相關文章
- 生成對抗網路綜述:從架構到訓練技巧架構
- GAN生成對抗網路-DCGAN原理與基本實現-深度卷積生成對抗網路03卷積
- 第六週:生成式對抗網路
- 實戰生成對抗網路[1]:簡介
- LSGAN:最小二乘生成對抗網路
- 0901-生成對抗網路GAN的原理簡介
- 生成對抗網路的進步多大,請看此文
- [深度學習]生成對抗網路的實踐例子深度學習
- 卷積生成對抗網路(DCGAN)---生成手寫數字卷積
- 實戰生成對抗網路[2]:生成手寫數字
- 萬字綜述之生成對抗網路(GAN)
- 谷歌獲批GAN專利,一整套對抗訓練網路被收入囊中谷歌
- 解讀生成對抗網路(GAN) 之U-GAN-IT
- 如何應用TFGAN快速實踐生成對抗網路?
- 深度神經網路的分散式訓練概述:常用方法和技巧全面總結神經網路分散式
- 如何應對訓練的神經網路不工作?神經網路
- 【深度學習理論】通俗理解生成對抗網路GAN深度學習
- 海量案例!生成對抗網路(GAN)的18個絕妙應用
- 【機器學習】李宏毅——生成式對抗網路GAN機器學習
- 3.3 神經網路的訓練神經網路
- 極端影像壓縮的生成對抗網路,可生成低位元速率的高質量影像
- 極端影象壓縮的生成對抗網路,可生成低位元速率的高質量影象
- 生成對抗網路,AI將圖片轉成漫畫風格AI
- Python中的一些陷阱與技巧小結Python
- AI和ML如何幫助對抗網路攻擊?AI
- 關於訓練神經網路的諸多技巧Tricks(完全總結版)神經網路
- 深度學習與CV教程(6) | 神經網路訓練技巧 (上)深度學習神經網路
- 對抗網路學習記錄
- 前人總結出的一些學Python中的陷阱和技巧,非常受用!Python
- 帶自注意力機制的生成對抗網路,實現效果怎樣?
- 送書 | AI插畫師:如何用基於PyTorch的生成對抗網路生成動漫頭像?AIPyTorch
- 【生成對抗網路學習 其一】經典GAN與其存在的問題和相關改進
- 醫療領域:合成資料、生成對抗網路、數字孿生的應用
- 一文入門人工智慧的掌上明珠:生成對抗網路(GAN)人工智慧
- dl4j-gans: Deeplearning4j生成對抗網路GNA的示例原始碼原始碼
- 理解生成對抗網路,一步一步推理得到GANs(二)
- 理解生成對抗網路,一步一步推理得到GANs(一)
- 訓練自己的Android TensorFlow神經網路Android神經網路