深度學習界明星:生成對抗網路與Improving GAN
2014年,深度學習三巨頭之一IanGoodfellow提出了生成對抗網路(Generative Adversarial Networks, GANs)這一概念,剛開始並沒有引起轟動,直到2016年,學界、業界對它的興趣如“井噴”一樣爆發,多篇重磅文章陸續發表。2016年12月NIPS大會上,Goodfellow做了關於GANs的專題報告,使得GANs成為了當今最熱門的研究領域之一,本文將介紹如今深度學習界的明星——生成對抗網路。
1何為生成對抗網路
生成對抗網路,根據它的名字,可以推斷這個網路由兩部分組成:第一部分是生成,第二部分是對抗。這個網路的第一部分是生成模型,就像之前介紹的自動編碼器的解碼部分;第二部分是對抗模型,嚴格來說它是一個判斷真假圖片的判別器。生成對抗網路最大的創新在此,這也是生成對抗網路與自動編碼器最大的區別。簡單來說,生成對抗網路就是讓兩個網路相互競爭,通過生成網路來生成假的資料,對抗網路通過判別器判別真偽,最後希望生成網路生成的資料能夠以假亂真騙過判別器。過程如圖1所示。
圖1 生成對抗網路生成資料過程
下面依次介紹生成模型和對抗模型。
1. 生成模型
首先看看生成模型,前一節自動編碼器其實已經給出了一般的生成模型。
在生成對抗網路中,不再是將圖片輸入編碼器得到隱含向量然後生成圖片,而是隨機初始化一個隱含向量,根據變分自動編碼器的特點,初始化一個正態分佈的隱含向量,通過類似解碼的過程,將它對映到一個更高的維度,最後生成一個與輸入資料相似的資料,這就是假的圖片。這時自動編碼器是通過對比兩張圖片之間每個畫素點的差異計算損失函式的,而生成對抗網路會通過對抗過程來計算出這個損失函式,如圖2所示。
圖2 生成模型
2. 對抗模型
重點來介紹對抗過程,這個過程是生成對抗網路相對於之前的生成模型如自動編碼器等最大的創新。
對抗過程簡單來說就是一個判斷真假的判別器,相當於一個二分類問題,輸入一張真的圖片希望判別器輸出的結果是1,輸入一張假的圖片希望判別器輸出的結果是0。
這跟原圖片的label 沒有關係,不管原圖片到底是一個多少類別的圖片,它們都統一稱為真的圖片,輸出的label 是1,則表示是真實的;而生成圖片的label 是0,則表示是假的。
在訓練的時候,先訓練判別器,將假的資料和真的資料都輸入給判別模型,這個時候優化這個判別模型,希望它能夠正確地判斷出真的資料和假的資料,這樣就能夠得到一個比較好的判別器。
然後開始訓練生成器,希望它生成的假的資料能夠騙過現在這個比較好的判別器。
具體做法就是將判別器的引數固定,通過反向傳播優化生成器的引數,希望生成器得到的資料在經過判別器之後得到的結果能儘可能地接近1,這時只需要調整一下損失函式就可以了,之前在優化判別器的時候損失函式是讓假的資料儘可能接近0,而現在訓練生成器的損失函式是讓假的資料儘可能接近1。
這其實就是一個簡單的二分類問題,這個問題可以用前面介紹過的很多方法去處理,比如Logistic 迴歸、多層感知器、卷積神經網路、迴圈神經網路等。
上面是生成對抗網路的簡單解釋,可以通過程式碼更清晰地展示整個過程。
跟自動編碼器一樣,先使用簡單的多層感知器來實現:
class discriminator(nn.Module):
def __init__(self):
super(discriminator, self).__init__()
self.dis = nn.Sequential(
nn.Linear(784, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 256),
nn.LeakyReLU(0.2),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, x):
x = self.dis(x)
return x
上面是判別器的結構,中間使用了斜率設為0.2 的LeakyReLU 啟用函式,最後需要使用nn.Sigmoid() 將結果對映到0 s 1 之間概率進行真假的二分類。這裡之所以用LeakyReLU 啟用函式而不使用ReLU 啟用函式,是因為經過實驗,LeakyReLU 的表現更好。
class generator(nn.Module):
def __init__(self, input_size):
super(generator, self).__init__()
self.gen = nn.Sequential(
nn.Linear(input_size, 256),
nn.ReLU(True),
nn.Linear(256, 256),
nn.ReLU(True),
nn.Linear(256, 784),
nn.Tanh()
)
def forward(self, x):
x = self.gen(x)
return x
這就是生成器的結構,跟自動編碼器中的解碼器是類似的,最後需要使用nn.Tanh(),將資料分佈到-1 ~1 之間,這是因為輸入的圖片會規範化到-1 ~1之間。
接著需要定義損失函式和優化函式:
criterion = nn.BCELoss() # Binary Cross Entropy
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0003)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0003)
這裡使用二分類的損失函式nn.BCELoss(),使用Adam 優化函式,學習率設定為0.0003。
接著是最為重要的訓練過程,這個過程分為兩個部分:一個是判別器的訓練,一個是生成器的訓練。
首先來看看判別器的訓練。
img = img.view(num_img, -1)
real_img = Variable(img).cuda()
real_label = Variable(torch.ones(num_img)).cuda()
fake_label = Variable(torch.zeros(num_img)).cuda()
# compute loss of real_img
real_out = D(real_img)
d_loss_real = criterion(real_out, real_label)
real_scores = real_out
# compute loss of fake_img
z = Variable(torch.randn(num_img, z_dimension)).cuda()
fake_img = G(z)
fake_out = D(fake_img)
d_loss_fake = criterion(fake_out, fake_label)
fake_scores = fake_out
# bp and optimize
d_loss = d_loss_real + d_loss_fake
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
開始需要自己建立label,真實的資料是1,生成的假的資料是0,然後將真實的資料輸入判別器得到loss,將假的資料輸入判別器得到loss,將這兩個loss 加起來得到總的loss,然後反向傳播去更新引數就能夠得到一個優化好的判別器。
接下來是生成模型的訓練:
# compute loss of fake_img
z = Variable(torch.randn(num_img, z_dimension)).cuda() # 得到隨機噪聲
fake_img = G(z) # 生成假的圖片
output = D(fake_img) # 經過判別器得到結果
g_loss = criterion(output, real_label) # 得到假的圖片與真實圖片label的loss
# bp and optimize
g_optimizer.zero_grad() # 歸0梯度
g_loss.backward() # 反向傳播
g_optimizer.step() # 更新生成網路的引數
一個隨機隱含向量通過生成網路得到了一個假的資料,然後希望假的資料經過判別模型後儘可能和真實label 接近,通過g_loss = criterion(output, real_label)實現,然後反向傳播去優化生成器的引數,在這個過程中,判別器的引數不再發生變化,否則生成器永遠無法騙過優化的判別器。
除了使用簡單的多層感知器外,也可以在生成模型和對抗模型中使用更加複雜的卷積神經網路,定義十分簡單。
class discriminator(nn.Module):
def __init__(self):
super(discriminator, self).__init__()
self.conv1 = nn.Sequential(
nn.Conv2d(1, 32, 5, padding=2), # batch, 32, 28, 28
nn.LeakyReLU(0.2, True),
nn.AvgPool2d(2, stride=2), # batch, 32, 14, 14
)
self.conv2 = nn.Sequential(
nn.Conv2d(32, 64, 5, padding=2), # batch, 64, 14, 14
nn.LeakyReLU(0.2, True),
nn.AvgPool2d(2, stride=2) # batch, 64, 7, 7
)
self.fc = nn.Sequential(
nn.Linear(64*7*7, 1024),
nn.LeakyReLU(0.2, True),
nn.Linear(1024, 1),
nn.Sigmoid()
)
def forward(self, x):
'''
x: batch, width, height, channel=1
'''
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1)
x = self.fc(x)
return x
class generator(nn.Module):
def __init__(self, input_size, num_feature):
super(generator, self).__init__()
self.fc = nn.Linear(input_size, num_feature) # batch, 3136=1x56x56
self.br = nn.Sequential(
nn.BatchNorm2d(1),
nn.ReLU(True)
)
self.downsample1 = nn.Sequential(
nn.Conv2d(1, 50, 3, stride=1, padding=1), # batch, 50, 56, 56
nn.BatchNorm2d(50),
nn.ReLU(True)
)
self.downsample2 = nn.Sequential(
nn.Conv2d(50, 25, 3, stride=1, padding=1), # batch, 25, 56, 56
nn.BatchNorm2d(25),
nn.ReLU(True)
)
self.downsample3 = nn.Sequential(
nn.Conv2d(25, 1, 2, stride=2), # batch, 1, 28, 28
nn.Tanh()
)
def forward(self, x):
x = self.fc(x)
x = x.view(x.size(0), 1, 56, 56)
x = self.br(x)
x = self.downsample1(x)
x = self.downsample2(x)
x = self.downsample3(x)
return x
圖3 左邊是多層感知器的生成對抗網路,右邊是卷積生成對抗網路,右邊的圖片比左邊的圖片噪宣告顯更少。在卷積神經網路裡引入了批標準化(Batchnormalization)來穩定訓練,同時使用了LeakyReLU 和平均池化來進行訓練。生成對抗網路的訓練其實是很困難的,因為這是兩個對偶網路在相互學習,所以需要增加一些訓練技巧才能使訓練更加穩定。
圖3生成對抗網路對比結果
以上介紹了生成對抗網路的簡單原理和訓練流程,但是對生成對抗網路而言,它其實並沒有真正地學習到它要表示的物體,通過對抗的過程,它只是生成了一張儘可能真的圖片,這就意味著沒辦法決定用哪種噪聲能夠生成想要的圖片,除非把初始分佈都試一遍。所以在生成對抗網路提出之後,有很多基於標準生成對抗網路的變式來解決各種各樣的問題。
2 Improving GAN
這一節將介紹改善的生成對抗網路,因為生成對抗網路存在很多問題,所以人們研究能否通過改善網路結構或者損害函式來解決這些問題。
1 Wasserstein GAN
Wasserstein GAN 是GAN 的一種變式,我們知道GAN 的訓練是非常麻煩的,需要很多訓練技巧,而且在不同的資料集上,由於資料的分佈會發生變化,也需要重新調整引數,不僅需要小心地平衡生成器和判別器的訓練程式,同時生成的樣本還缺乏多樣性。除此之外最大的問題是沒辦法衡量這個生成器到底好不好,因為沒辦法通過判別器的loss 去判斷這個事情。雖然DC GAN 依靠對生成器和判別器的結構進行列舉,最終找到了一個比較好的網路設定,但還是沒有從根本上解決訓練的問題。
WGAN 的出現,徹底解決了下面這些難點:
(1)徹底解決了訓練不穩定的問題,不再需要設計引數去平衡判別器和生成器;
(2)基本解決了collapse mode 的問題,確保了生成樣本的多樣性;
(3)訓練中有一個向交叉熵、準確率的數值指標來衡量訓練的程式,數值越小代表GAN 訓練得越好,同時也就代表著生成的圖片質量越高;
(4)不需要精心設計網路結構,用簡單的多層感知器就能夠取得比較好的效果。
下面先介紹為什麼GAN 會有這些缺點,然後解釋WGAN是通過什麼辦法解決這些問題的。
① GAN 的侷限性
根據之前介紹的,有下面的式子(1):
從式(1)我們知道原始的GAN 是通過最優判別器下的JS Divergence 來衡量兩種分佈之間的差異的,而且最優判別器下JS Divergence 越小,就說明兩種分佈越接近,但是JS Divergence 有一個嚴重的問題,那就是如果兩種分佈完全沒有重疊部分,或者說重疊部分可忽略,那麼JS Divergence 將恆等於常數log2。換句話說,就算兩種分佈很接近,但是隻要它們沒有重疊,那麼JS Divergence 就是一個常數,這就使得網路沒辦法通過這個損失函式去學習,因為它沒辦法知道它是否做得好,這就會導致梯度消失,同時這也使得我們沒有辦法衡量這兩種分佈到底有多靠近。
而真實分佈與生成的分佈沒有重疊部分的概率有多大呢?其實是非常大的,直觀來講,真實分佈是一個高維分佈,而生成的分佈來自於一個低維分佈,所以其實很有可能生成分佈和真實分佈之間就沒有重疊的部分。除此之外,不可能真正去計算兩個分佈,只能近似取樣,所以也導致了兩種分佈沒有重疊部分。如果判別器訓練得太好,那麼生成的分佈和原來分佈基本沒有重疊部分,這就導致了梯度消失;如果判別器訓練得不好,這樣生成器的梯度又不準,就會出現錯誤的優化方向。如果要使得GAN 能夠完美地收斂,那麼需要判別器的訓練不好也不壞,而這個度是很難把握的,況且這還依賴資料的分佈等條件,所以GAN 才這麼難訓練。
②Wasserstein 距離
既然GAN 存在的問題都是由於JS Divergence 引起的,那麼能不能換一種度量方式去衡量兩種分佈之間的差異,而不使用JS Divergence?答案是肯定的,這就是WGAN中提出的解決辦法。
首先介紹一種新的度量方式去度量兩種分佈之間的差異——Wasserstein 距離,也稱為Earth Mover 距離,定義如下:
看上去可能比較複雜,數學解釋如下:對於兩種分佈Pr 和Pg,它們的聯合分佈是II(Pr,Pg),換句話說II(Pr,Pg) 中每一個聯合分佈的邊緣分佈就是Pr 或者Pg。那麼對每一個聯合分佈而言,從裡面取樣x 和y,並計算x 和y 的距離,然後取遍所有的x 和y 計算一下期望,接著取這些期望裡面最小的作為W 距離的定義。
如果上面的解釋不夠清楚,也可以通俗地解釋,因為它還有一個別名叫Earth mover距離,也就是推土機距離,這是什麼意思呢?可以把兩種分佈想象成兩堆土,然後想想如何用推土機將一種分佈變成另外一種分佈的樣子,會有很多種移動方案,裡面最小消耗的那種方案就是最優的方案,也就是這個距離的定義。
W 距離與JS Divergence 相比有什麼好處呢?最大的好處就是不管兩種分佈是否有重疊,它都是連續變換的而不是突變的,可以用下面這個例子來說明一下,如圖4所示。
圖4 W 距離例子
通過上面這個演示可以發現,雖然兩種分佈更接近,但JS Divergence 仍然是log2,W 距離就能夠連續而有效地衡量兩種分佈之間的差異。
③WGAN
W 距離有很好的優越性,把它拿來作為兩種分佈的度量優化生成器,但是W 距離裡面有一個是沒辦法求解的。作者Martin 在論文附錄裡面通過定理將這個問題轉變成了一個新的問題,有著如下形式:
這裡引入了一個新的概念——Lipschitz 連續。如果函式f 滿足Lipschitz 連續條件,那麼它就滿足下面的式子:
我們不希望函式的變化太快,希望函式f 變化能比較平緩。
那麼可以將上面的式子改成GAN:
也就是說構建一個神經網路D 作為判別器,希望D 輸出的變化比較平緩,在實際計算中限制D 中的引數大小不超過某個範圍,這樣就使得關於輸入的樣本,D 的輸出變化基本不會超過某個範圍,所以就能夠基本滿足Lipschitz 連續條件。
所以最後構造一個判別器D,滿足:
儘可能取到最大,同時D 還要滿足Lipschitz 連續條件,得到的L 可以近似為真實分佈和生成分佈的Wasserstein 距離。原始的GAN 做的是二分類的任務,也就是對於真假圖片進行二分類,而WGAN 做的是迴歸問題,相當於近似擬合Wasserstein 距離。
最後優化生成器的時候希望最小化L,這時候需要滿足Lipschitz 連續條件,所以需要做權重的裁剪,由於W 距離的優越性,不再需要擔心梯度消失的問題,這樣就能夠得到WGAN 的整個訓練過程。
總結一下,WGAN 與原始GAN 相比,只改了以下四點:
(1)判別器最後一層去掉sigmoid;
(2)生成器和判別器的loss 不取log;
(3)每次更新判別器的引數之後把它們的絕對值裁剪到不超過一個固定常數的數;
(4)不要用基於動量的優化演算法(比如momentuem 和Adam),推薦使用RMSProp。
前三點都是從理論分析得到的結果,第(4)點是作者從實驗中發現的。對於WGAN,論文作者做了不少實驗,得到了幾個結論:第一,WGAN 如果使用類似DCGAN 的結構,那麼和DCGAN 生成的圖片差不多,但是WGAN 的優勢就在於不用DCGAN 的結構,也能生成效果比較好的圖片,但是把DCGAN 的Batch Normalization 拿掉的話,DCGAN 就不能生成圖片了;第二,WGAN 和原始的GAN 都是用多層全連線網路的話,WGAN 生成的圖片質量會變得差一些,但是原始的GAN 不僅質量很差,還有多樣性不足的問題。
2 Improving WGAN
WGAN 的提出成功地解決了GAN 的很多問題,最後需要滿足一階Lipschitz 連續性條件,所以在訓練的時候加了一個限制——權重裁剪。
然而權重的裁剪只是一種簡單的做法,不是最好的做法,所以隨後有人提出了一些新的辦法來解決這個問題。
首先提出一個定理:一個可微函式如果滿足1 階Lipschitz 連續,等價於它的梯度範數處小於1。用式子來表示就是:
有了這個定理,就能夠近似地這樣去表達W 距離:
不需要在整個分佈上都滿足Lipschitz 條件,只需要沿著一些直線上的點滿足這些,結果就已經很好了,同時在實際中採用的策略也不是取max,因為不希望太小,所以做的是最小化
,最後改進的WGAN 就是:
改進後的WGAN 和改進前的WGAN 相比,訓練更加穩定,生成的圖片效果也更好。
以上內容節選自《深度學習入門之PyTorch》,點此連結可在博文視點官網檢視此書。
想及時獲得更多精彩文章,可在微信中搜尋“博文視點”或者掃描下方二維碼並關注。
相關文章
- 【深度學習理論】通俗理解生成對抗網路GAN深度學習
- GAN生成對抗網路-DCGAN原理與基本實現-深度卷積生成對抗網路03卷積
- 【機器學習】李宏毅——生成式對抗網路GAN機器學習
- [深度學習]生成對抗網路的實踐例子深度學習
- 解讀生成對抗網路(GAN) 之U-GAN-IT
- 萬字綜述之生成對抗網路(GAN)
- 0901-生成對抗網路GAN的原理簡介
- 【生成對抗網路學習 其一】經典GAN與其存在的問題和相關改進
- GAN實戰筆記——第四章深度卷積生成對抗網路(DCGAN)筆記卷積
- 白話生成對抗網路 GAN,50 行程式碼玩轉 GAN 模型!【附原始碼】行程模型原始碼
- 海量案例!生成對抗網路(GAN)的18個絕妙應用
- 對抗網路學習記錄
- GAN實戰筆記——第七章半監督生成對抗網路(SGAN)筆記
- 【強化學習】使用off-policy演算法機器人抓取任務基準;生成對抗網路 GAN 就是強化學習強化學習演算法機器人
- 一文入門人工智慧的掌上明珠:生成對抗網路(GAN)人工智慧
- 第六週:生成式對抗網路
- 基於深度對抗學習的智慧模糊資料生成方法
- 三大深度學習生成模型:VAE、GAN及其變種深度學習模型
- 資本與科學:抗衰明星布萊恩獲得學術界認可
- 實戰生成對抗網路[1]:簡介
- LSGAN:最小二乘生成對抗網路
- 吳恩達Deeplearning.ai國慶節上新:生成對抗網路(GAN)專項課程吳恩達AI
- GAN實戰筆記——第六章漸進式增長生成對抗網路(PGGAN)筆記
- 人工智慧-深度學習-生成模型:GAN經典模型-->InfoGAN人工智慧深度學習模型
- 人工智慧-深度學習-生成模型:GAN經典模型-->VAEGAN人工智慧深度學習模型
- 實戰生成對抗網路[2]:生成手寫數字
- 卷積生成對抗網路(DCGAN)---生成手寫數字卷積
- 再聊神經網路與深度學習神經網路深度學習
- 深度學習與圖神經網路深度學習神經網路
- 【生成對抗網路學習 其三】BiGAN論文閱讀筆記及其原理理解筆記
- 獨家 | GAN大盤點,聊聊這些年的生成對抗網路 : LSGAN, WGAN, CGAN, infoGAN, EBGAN, BEGAN, VAE
- 如何應用TFGAN快速實踐生成對抗網路?
- 生成對抗網路的進步多大,請看此文
- 深度學習與圖神經網路學習分享:CNN 經典網路之-ResNet深度學習神經網路CNN
- 深度學習一:深度前饋網路深度學習
- 【深度學習】--GAN從入門到初始深度學習
- 深度學習之神經網路(CNN/RNN/GAN)演算法原理+實戰 完整版深度學習神經網路CNNRNN演算法
- 對抗深度學習: 魚 (模型準確性) 與熊掌 (模型魯棒性) 能否兼得?深度學習模型
- 深度學習之Transformer網路深度學習ORM