更快更穩定:這就是Wasserstein GAN

機器之心分析師發表於2018-11-20

Courant 數學科學研究所與 Facebook 人工智慧研究所提出的 Wasserstein GAN 在標準 GAN 的基礎上實現了顯著的改進。機器之心技術分析師對該研究進行了解讀。

更快更穩定:這就是Wasserstein GAN

論文地址:https://arxiv.org/abs/1701.07875

專案地址:https://github.com/martinarjovsky/WassersteinGAN

論文討論:https://www.reddit.com/r/MachineLearning/comments/5qxoaz/r_170107875_wasserstein_gan/

引言

這篇論文介紹了一種名叫 Wasserstein GAN(WGAN)的全新演算法,這是一種可替代標準生成對抗網路(GAN)的訓練方法。這項研究沒有應用傳統 GAN 所用的那種 minimax 形式,而是基於一種名為“Wasserstein 距離”的新型距離指標做了某些修改。

更快更穩定:這就是Wasserstein GAN

這是基於 MLP 生成器的 WGAN(左上圖)和 GAN(右上圖)生成的樣本,很顯然,這裡 WGAN 的影像質量優於標準 GAN。

簡單來說,WGAN 有兩個改變。第一個是取出了判別器中的 sigmoid,這是用於計算輸出均值之間的差異的。第二個改變是判別器(這篇論文稱之為 Critic),這就只是一個函式,其目標是讓假資料有較低的預期值,讓真實資料有較高的預期值。注意這些輸出不再是對數機率,這樣這些損失現在就與二元交叉熵無關了。

Wasserstein GAN

近期一些 GAN 論文提出了一些不同的生成對抗訓練架構。但是,這些架構的一個共同點是 f-距離(包括 KL-距離、總變差散度(total variation divergence))。f-距離是真實資料分佈和生成資料分佈之間的密度比 P_r(x)/P_θ(x) 的函式,非常類似於 Jenson-Shannon(JS)距離。

更快更穩定:這就是Wasserstein GAN

上式是標準 GAN 的目標。在 GAN 的訓練過程中,判別器的目標是最大化上述目標(最大值為 0,最小值為負無窮)。GAN 的估計可對應於 JS 距離度量。我們再看看 f-距離。如果兩個分佈沒有顯著的重疊,我們又能做什麼?如果不能,那麼其機率密度比將為零或無窮,而且其對整體機率估計(比如由 (0, z) 點組成的真實資料,其中 z ~ (0,1))會有巨大的負面影響,於是樣本就會從 y=0 到 y=1 沿垂直軸 x=0 均勻分佈。但如果該模型生成樣本 (θ, z),則其分佈根本不會重疊。在這種情況下,會發生梯度消失問題,會使標準 GAN 崩潰。

所以基於這一事實,這篇論文的作者提出使用 Wasserstein 距離,而不是 JS 距離。Wasserstein 距離定義為:

更快更穩定:這就是Wasserstein GAN

我們可以這樣解讀這一等式:首先,所有可能的配置都會被選取,假設是 P_r(x) 和 P_g(x)。然後這些點會根據這兩個分佈來配對。在那之後,它會計算每組配置中配對的平均距離。這裡的 inf 可以被視為最小值,這樣最後它將從所有可能的配對配置中選擇出最小的平均距離。這篇論文提出使用這一距離度量來替代 f-距離,這樣它就不再是密度比的函式的。透過這種方式,即使兩個分佈沒有重疊,Wasserstein 距離也仍然可以描述它們相距多遠,並且透過這種方式能從根本上解決梯度消失問題。

由於初始的 Wasserstein 距離定義具有難以解決的計算複雜性,所以研究者使用了一種替代定義:

更快更穩定:這就是Wasserstein GAN

這會導致 Kantorovich-Rubinstein二元性。

值得注意的是,當且僅當 f(x) 的梯度的幅度由 K 在該空間的所有部分設定了上界時,f(x) 是 K-Lipschitz。這篇論文透過將權重限制在一定範圍內,使用網路來近似建模 K-Lipschitz。這裡的上界可以被視為是一個最大值(二元表示式)。理論上,其目標是尋找到一個 critic 函式,以最大化真實樣本均值和偽造樣本均值之間餘量。

WGAN 演算法

更快更穩定:這就是Wasserstein GAN

上面描述了 Wasserstein 生成對抗網路(WGAN)演算法。經過前面的知識介紹之後,這個演算法看起來就更簡單一些了。總結如下:

  • 更新 Critic n 次迭代,之後更新生成器;
  • 對於 Critic 的每次迭代,基於 Wasserstein 距離更新梯度,然後剪下權重;
  • 使用 RMSProp;
  • 像普通 GAN 那樣更新生成器。

下面給出了實現 WGAN 演算法的程式碼示例:

   # (1) update Critic Network
             for p in netD.parameters():
                 p.requires_grad = True
             netD.zero_grad()
    # train with real
             real_cpu, _ = data
             netD.zero_grad()
             batch_size = real_cpu.size(0)
             input.data.resize_(real_cpu.size()).copy_(real_cpu)
             errD_real = netD(input)
             errD_real.backward(one)
    # train with fake
             noise.data.resize_(batch_size, nz, 1, 1)
             noise.data.normal_(0, 1)
             fake = netG(noise)
             input.data.copy_(fake.data)
             errD_fake = netD(input)
             errD_fake.backward(mone)
             errD = errD_real - errD_fake
             optimizerD.step()
 
      # (2) Update G network
             for p in netD.parameters():
                 p.requires_grad = False # to avoid computation
             netG.zero_grad()  
             noise.data.resize_(opt.batchSize, nz, 1, 1)
             noise.data.normal_(0, 1)
             fake = netG(noise)
             errG = netD(fake)
             errG.backward(one)
             optimizerG.step()

實證實驗

研究者使用 Wasserstein GAN 進行了一些定量實驗,並且表明相比於標準 GAN,使用 WGAN有顯著的實際好處。

他們提到了兩個優勢:

  • WGAN 的損失表現出了收斂的特性。

更快更穩定:這就是Wasserstein GAN

更快更穩定:這就是Wasserstein GAN

如上所示,上圖為 WGAN,下圖為標準 GAN。對於 WGAN,隨著損失快速下降,樣本質量也會增長。相比於 WGAN,標準 GAN 演算法的誤差曲線是不穩定的,甚至會增大。

  • 最佳化過程的穩定性提升。

更快更穩定:這就是Wasserstein GAN

上圖是使用無批歸一化的該演算法得到的生成器的結果。左上基於 WGAN 演算法,右上基於標準 GAN 演算法。標準 GAN 不能學習的地方,WGAN 依然能穩定地生成合理的樣本。

分析師簡評

這篇論文提出了一種名為 Wasserstein GAN 的新型生成對抗網路。它從理論上向我們說明了已有的 GAN 模型失敗的原因以及 WGAN 有效的原因。相比於 DCGAN 等標準 GAN,這篇論文表明即使沒有批歸一化,WGAN 也能穩定地訓練。但也仍然存在一些值得關注的地方。首先,在更新生成器之前他們更新了 critic n 次迭代,這意味著 critic 的迭代次數仍是人工調節的。是否存在最佳化兩者的更好方法呢?第二,WGAN 在非常深度的網路上的泛化情況如何,比如 152 層的殘差網路?第三,他們限制了權重的範圍以確保 Lipschitz 連續性,但是否存在建模這種情況的方法?最後,生成對抗訓練能否用於詞預測等 NLP 任務,同時還能保持穩定性?

相關文章