涵蓋18+ SOTA GAN實現,這個影像生成領域的PyTorch庫火了

機器之心發表於2021-03-01

GAN 自從被提出後,便迅速受到廣泛關注。我們可以將 GAN 分為兩類,一類是無條件下的生成;另一類是基於條件資訊的生成。近日,來自韓國浦項科技大學的碩士生在 GitHub 上開源了一個專案,提供了條件 / 無條件影像生成的代表性生成對抗網路(GAN)的實現。

涵蓋18+ SOTA GAN實現,這個影像生成領域的PyTorch庫火了

近日,機器之心在 GitHub 上看到了一個非常有意義的專案 PyTorch-StudioGAN,它是一個 PyTorch 庫,提供了條件 / 無條件影像生成的代表性生成對抗網路(GAN)的實現。據主頁介紹,該專案旨在提供一個統一的現代 GAN 平臺,這樣機器學習領域的研究者可以快速地比較和分析新思路和新方法等。

該專案的作者為韓國浦項科技大學的碩士生,他的研究興趣主要包括深度學習、機器學習和計算機視覺。

涵蓋18+ SOTA GAN實現,這個影像生成領域的PyTorch庫火了

專案地址:https://github.com/POSTECH-CVLab/PyTorch-StudioGAN

具體而言,該專案具有以下幾個顯著特徵:
  • 提供了大量 PyTorch 框架的 GAN 實現;

  • 基於 CIFAR 10、Tiny ImageNet 和 ImageNet 資料集的 GAN 基準;

  • 相較原始實現的更好的效能和更低的記憶體消耗;

  • 提供完全最新 PyTorch 環境的預訓練模型;

  • 支援多 GPU(DP、DDP 和多節點 DDP)、混合精度、同步批歸一化、LARS、Tensorboard 視覺化和其他分析方法。

對於這個 PyTorch GAN 庫,有網友表示:「看上去很不錯!如果可以提供 top-k 等現代訓練實踐以及各種增強方法就更棒了。」對此,專案作者稱其會在 NeurIPS 論文提交截止日期之後,新增一些改進的方法,如 Sinha 等人的 Tok-K 訓練以及 Langevin 取樣和 SimCLR 增強。

涵蓋18+ SOTA GAN實現,這個影像生成領域的PyTorch庫火了

此外,有網友詢問是否可以將該專案用於影像之外的其他領域。作者表示可以,即使無法使用一些穩定器(如 diffaug、ada 等),依然可以透過調整 dataLoader 來訓練自己的模型。

涵蓋18+ SOTA GAN實現,這個影像生成領域的PyTorch庫火了

18+ SOTA GAN 實現

如下圖所示,專案作者提供了 18 + 個 SOTA GAN 的實現,包括 DCGAN、LSGAN、GGAN、WGAN-WC、WGAN-GP、WGAN-DRA、ACGAN、ProjGAN、SNGAN、SAGAN、BigGAN、BigGAN-Deep、CRGAN、ICRGAN、LOGAN、DiffAugGAN、ADAGAN、ContraGAN 和 FreezeD。

涵蓋18+ SOTA GAN實現,這個影像生成領域的PyTorch庫火了

cBN:條件批歸一化;AC:輔助分類器;PD:Projection 判別器;CL:對比學習。

其中,需要注意以下幾點:
  • G/D_type 表示將標籤資訊注入生成器或判別式的方式;

  • EMA 表示生成器中應用更新後的指數移動平均線;

  • Tiny ImageNet 資料集上的實驗使用的是 ResNet 架構而不是 CNN。

下圖中 StyleGAN2 為即將實現的 GAN 網路,其中 AdaIN 表示自適應例項歸一化(Adaptive Instance Normalization)。

涵蓋18+ SOTA GAN實現,這個影像生成領域的PyTorch庫火了

環境要求
  • Anaconda

  • Python >= 3.6

  • 6.0.0 <= Pillow <= 7.0.0

  • scipy == 1.1.0

  • sklearn

  • seaborn

  • h5py

  • tqdm

  • torch >= 1.6.0 

  • torchvision >= 0.7.0

  • tensorboard

  • 5.4.0 <= gcc <= 7.4.0

  • torchlars 

使用者可以採用以下方法安裝推薦的環境:

conda env create -f environment.yml -n studiogan

在 docker 中還可以採用以下方式:

docker pull mgkang/studiogan:latest

以下是建立名字為「studioGAN」容器的命令,同樣也可以使用埠號為 6006 來連線 tensoreboard。

docker run -it --gpus all --shm-size 128g -p 6006:6006 --name studioGAN -v /home/USER:/root/code --workdir /root/code mgkang/studiogan:latest /bin/bash

使用方法

使用 GPU 0 的情況下,在 CONFIG_PATH 中對於模型的訓練「-t」和評估「-e」進行了定義:

CUDA_VISIBLE_DEVICES=0 python3 src/main.py -t -e -c CONFIG_PATH

在使用 GPU (0, 1, 2, 3) 和 DataParallel 情況下,在 CONFIG_PATH 中對於模型的訓練「-t」和評估「-e」進行了定義:

CUDA_VISIBLE_DEVICES=0,1,2,3 python3 src/main.py -t -e -c CONFIG_PATH

在 python3 src/main.py 程式中檢視可用選項,透過 Tensorboard 可以監控 IS、FID、F_beta、Authenticity Accuracies 以及最大奇異值:

~ PyTorch-StudioGAN/logs/RUN_NAME>>> tensorboard --logdir=./ --port PORT

視覺化以及分析生成影像

StudioGAN 支援影像視覺化、k 最近鄰分析、線性差值以及頻率分析。所有的結果儲存在「./figures/RUN_NAME/*.png」中。

影像視覺化的程式碼和示例如下:

CUDA_VISIBLE_DEVICES=0,...,N python3 src/main.py -iv -std_stat --standing_step STANDING_STEP -c CONFIG_PATH --checkpoint_folder CHECKPOINT_FOLDER --log_output_path LOG_OUTPUT_PATH

涵蓋18+ SOTA GAN實現,這個影像生成領域的PyTorch庫火了

k 最近鄰分析,這裡固定 K=7,第一列中是生成的影像:

CUDA_VISIBLE_DEVICES=0,...,N python3 src/main.py -knn -std_stat --standing_step STANDING_STEP -c CONFIG_PATH --checkpoint_folder CHECKPOINT_FOLDER --log_output_path LOG_OUTPUT_PATH

涵蓋18+ SOTA GAN實現,這個影像生成領域的PyTorch庫火了

線性插值(僅適用於有條件的 Big ResNet 模型 )的程式碼和示例如下:

CUDA_VISIBLE_DEVICES=0,...,N python3 src/main.py -itp -std_stat --standing_step STANDING_STEP -c CONFIG_PATH --checkpoint_folder CHECKPOINT_FOLDER --log_output_path LOG_OUTPUT_PATH

涵蓋18+ SOTA GAN實現,這個影像生成領域的PyTorch庫火了


參考連結:https://www.reddit.com/r/MachineLearning/comments/lu9gen/p_pytorch_gan_library_that_provides/

相關文章