帶自注意力機制的生成對抗網路,實現效果怎樣?

機器之心發表於2018-06-06

在前一段時間,Han Zhang 和 Goodfellow 等研究者提出新增了自注意力機制生成對抗網路,這種網路可使用全域性特徵線索來生成高解析度細節。本文介紹了自注意力生成對抗網路的 PyTorch 實現,讀者也可以嘗試這一新型生成對抗網路

專案地址:https://github.com/heykeetae/Self-Attention-GAN

這個資源庫提供了一個使用 PyTorch 實現的 SAGAN。其中作者準備了 wgan-gp 和 wgan-hinge 損失函式,但注意 wgan-gp 有時與譜歸一化(spectral normalization)是不匹配的;因此,作者會移除模型所有的譜歸一化來適應 wgan-gp。

在這個實現中,自注意機制會應用到生成器和鑑別器的兩個網路層。畫素級的自注意力會增加 GPU 資源的排程成本,且每個畫素有不同的注意力掩碼。Titan X GPU 大概可選擇的批量大小為 8,你可能需要減少自注意力模組的數量來減少記憶體消耗。

帶自注意力機制的生成對抗網路,實現效果怎樣?

目前更新狀態:

  • 注意力視覺化 (LSUN Church-outdoor)

  • 無監督設定(現未使用標籤)

  • 已應用:Spectral Normalization(程式碼來自 https://github.com/christiancosgrove/pytorch-spectral-normalization-gan)

  • 已實現:自注意力模組(self-attention module)、兩時間尺度更新規則(TTUR)、wgan-hinge 損失函式和 wgan-gp 損失函式

結果

下圖展示了 LSUN 中的注意力結果 (epoch #8):

帶自注意力機制的生成對抗網路,實現效果怎樣?

SAGAN 在 LSUN church-outdoor 資料集上的逐畫素注意力結果。這表示自注意力模組的無監督訓練依然有效,即使注意力圖本身並不具有可解釋性。更好的圖片生成結果以後會新增,上面這些是在生成器第層 3 和層 4 中的自注意力的視覺化,它們的尺寸依次是 16 x 16 和 32 x 32,每一張都包含 64 張注意力圖的視覺化。要視覺化逐畫素注意力機制,我們只能如左右兩邊的數字顯示選擇一部分畫素。

CelebA 資料集 (epoch on the left, 還在訓練中):

帶自注意力機制的生成對抗網路,實現效果怎樣?

LSUN church-outdoor 資料集 (epoch on the left, 還在訓練中):

帶自注意力機制的生成對抗網路,實現效果怎樣?

訓練環境:

  • Python 3.5+ (https://www.continuum.io/downloads)

  • PyTorch 0.3.0 (http://pytorch.org/)

用法

1. 克隆版本庫

$ git clone https://github.com/heykeetae/Self-Attention-GAN.git

$ cd Self-Attention-GAN

2. 下載資料集 (CelebA 或 LSUN)

$ bash download.sh CelebA
or
$ bash download.sh LSUN

3. 訓練

$ python python main.py --batch_size 6 --imsize 64 --dataset celeb --adv_loss hinge --version sagan_celeb
or
$ python python main.py --batch_size 6 --imsize 64 --dataset lsun --adv_loss hinge --version sagan_lsun

4. 享受結果吧~

$ cd samples/sagan_celeb
or
$ cd samples/sagan_lsun

每 100 次迭代生成一次樣本,抽樣率可根據引數 --sample_step (ex,—sample_step 100) 控制。

論文:Self-Attention Generative Adversarial Networks


帶自注意力機制的生成對抗網路,實現效果怎樣?


論文地址:https://arxiv.org/abs/1805.08318


在此論文中,我們提出了自注意生成式對抗網路(SAGAN),能夠為影象生成任務實現注意力驅動的、長範圍的依存關係建模。傳統的卷積 GAN 只根據低分辨特徵圖中的空間區域性點生成高解析度細節(detail)。在 SAGAN 中,可使用所有特徵點的線索來生成高解析度細節,而且鑑別器能檢查圖片相距較遠部分的細微細節特徵是否彼此一致。不僅如此,近期研究表明鑑別器調節可影響 GAN 的表現。根據這個觀點,我們在 GAN 生成器中加入了譜歸一化(spectral normalization),並發現這樣可以提高訓練動力學。我們所提出的 SAGAN 達到了當前最優水平,在極具挑戰性的 ImageNet 資料集中將最好的 inception 分數記錄從 36.8 提高到 52.52,並將 Frechet Inception 距離從 27.62 減少到 18.65。注意力層的視覺化展現了生成器可利用其附近環境對物體形狀做出反應,而不是直接使用固定形狀的區域性區域。

相關文章