本文大約 8000 字,閱讀大約需要 12 分鐘
第一次翻譯,限於英語水平,可能不少地方翻譯不準確,請見諒!
最近谷歌開源了一個基於 TensorFlow 的庫–TFGAN,方便開發者快速上手 GAN 的訓練,其 Github 地址如下:
原文網址:Generative Adversarial Networks: Google open sources TensorFlow-GAN (TFGAN)
如果你玩過波斯王子,那你應該知道你需要保護自己不被”影子“所殺掉,但這也是一個矛盾:如果你殺死“影子”,那遊戲就結束了;但你不做任何事情,那麼遊戲也會輸掉。
儘管生成對抗網路(GAN)有不少優點,但它也面臨著相似的區分問題。大部分支援 GAN 的深度學習專業也是非常謹慎的支援它,並指出它確實存在穩定性的問題。
GAN 的這個問題也可以稱做整體收斂性問題。儘管判別器 D 和 生成器 D 相互競爭博弈,但同時也相互依賴對方來達到有效的訓練。如果其中一方訓練得很差,那整個系統也會很差(這也是之前提到的梯度消失或者模式奔潰問題)。並且你也需要確保他們不會訓練太過度,造成另一方無法訓練了。因此,波斯王子是一個很有趣的概念。
首先,神經網路的提出就是為了模仿人類的大腦(儘管是人為的)。它們也已經在物體識別和自然語言處理方面取得成功。但是,想要在思考和行為上與人類一致,這還有非常大的差距。
那麼是什麼讓 GANs 成為機器學習領域一個熱門話題呢?因為它不僅只是一個相對新的結構,它更加是一個比之前其他模型都能更加準確的對真實資料建模,可以說是深度學習的一個革命性的變化。
最後,它是一個同時訓練兩個獨立的網路的新模型,這兩個網路分別是判別器和生成器。這樣一個非監督神經網路卻能比其他傳統網路得到更好效能的結果。
但目前事實是我們對 GANs 的研究還只是非常淺層,仍然有著很多挑戰需要解決。GANs 目前也存在不少問題,比如無法區分在某個位置應該有多少特定的物體,不能應用到 3D 物體,以及也不能理解真實世界的整體結構。當然現在有大量研究正在研究如何解決上述問題,新的模型也取得更好的效能。
而最近谷歌為了讓 GANs 更容易實現,設計開發並開源了一個基於 TensorFlow 的輕量級庫–TFGAN。
根據谷歌的介紹,TFGAN 提供了一個基礎結構來減少訓練一個 GAN 模型的難度,同時提供非常好測試的損失函式和評估標準,以及給出容易上手的例子,這些例子強調了 TFGAN 的靈活性和易於表現的優點。
此外,還提供了一個教程,包含一個高階的 API 可以快速使用自己的資料集訓練一個模型。
上圖是展示了對抗損失在影像壓縮方面的效果。最上方第一行圖片是來自 ImageNet 資料集的圖片,也是原始輸入圖片,中間第二行展示了採用傳統損失函式訓練得到的影像壓縮神經網路的壓縮和解壓縮效果,最底下一行則是結合傳統損失函式和對抗損失函式訓練的網路的結果,可以看到儘管基於對抗損失的圖片並不像原始圖片,但是它比第二行的網路得到更加清晰和細節更好的圖片。
TFGAN 既提供了幾行程式碼就可以實現的簡答函式來呼叫大部分 GAN 的使用例子,也是建立在包含複雜 GAN 設計的模式化方式。這就是說,我們可以採用自己需要的模組,比如損失函式、評估策略、特徵以及訓練等等,這些都是獨立的模組。TFGAN 這樣的設計方式其實就滿足了不同使用者的需求,對於入門新手可以快速訓練一個模型來看看效果,對於需要修改其中任何一個模組的使用者也能修改對應模組,而不會牽一髮而動全身。
最重要的是,谷歌也保證了這個程式碼是經過測試的,不需要擔心一般的 GAN 庫造成的數字或者統計失誤。
開始使用
首先新增以下程式碼來匯入 tensorflow 和 宣告一個 TFGAN 的例項:
import tensorflow as tf
tfgan = tf.contrib.gan
複製程式碼
為何使用 TFGAN
- 採用良好測試並且很靈活的呼叫介面實現快速訓練生成器和判別器網路,此外,還可以混合 TFGAN、原生 TensorFlow以及其他自定義框架程式碼;
- 使用實現好的GAN 的損失函式和懲罰策略 (比如 Wasserstein loss、梯度懲罰等)
- 訓練階段對 GAN 進行監控和視覺化操作,以及評估生成結果
- 使用實現好的技巧來穩定和提高效能
- 基於常規的 GAN 訓練例子來開發
- 採用GANEstimator介面裡快速訓練一個 GAN 模型
- TFGAN 的結構改進也會自動提升你的 TFGAN 專案的效能
- TFGAN 會不斷新增最新研究的演算法成果
TFGAN 的部件有哪些呢?
TFGAN 是由多個設計為獨立的部件組成的,分別是:
- core:提供了一個主要的訓練 GAN 模型的結構。訓練過程分為四個階段,每個階段都可以採用自定義程式碼或者 呼叫 TFGAN 庫介面來完成;
- features:包含許多常見的 GAN 運算和正則化技術,比如例項正則化(instance normalization)
- losses:包含常見的 GAN 的損失函式和懲罰機制,比如 Wasserstein loss、梯度懲罰、相互資訊懲罰等
- evaulation:使用一個預訓練好的 Inception 網路來利用
Inception Score
或者Frechet Distance
評估標準來評估非條件生成模型。當然也支援利用自己訓練的分類器或者其他方法對有條件生成模型的評估 - examples and tutorial:使用 TFGAN 訓練 GAN 模型的例子和教程。包含了使用非條件和條件式的 GANs 模型,比如 InfoGANs 等。
訓練一個 GAN 模型
典型的 GAN 模型訓練步驟如下:
- 為你的網路指定輸入,比如隨機噪聲,或者是輸入圖片(一般是應用在圖片轉換的應用,比如 pix2pixGAN 模型)
- 採用
GANModel
介面定義生成器和判別器網路 - 採用
GANLoss
指定使用的損失函式 - 採用
GANTrainOps
設定訓練運算操作,即優化器 - 開始訓練
當然,GAN 的設定有多種形式。比如,你可以在非條件下訓練生成器生成圖片,或者可以給定一些條件,比如類別標籤等輸入到生成器中來訓練。無論是哪種設定,TFGAN 都有相應的實現。下面將結合程式碼例子來進一步介紹。
例項
非條件 MNIST 圖片生成
第一個例子是訓練一個生成器來生成手寫數字圖片,即 MNIST 資料集。生成器的輸入是從多變數均勻分佈取樣得到的隨機噪聲,目標輸出是 MNIST 的數字圖片。具體檢視論文“Generative Adversarial Networks”。程式碼如下:
# 配置輸入
# 真實資料來自 MNIST 資料集
images = mnist_data_provider.provide_data(FLAGS.batch_size)
# 生成器的輸入,從多變數均勻分佈取樣得到的隨機噪聲
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])
# 呼叫 tfgan.gan_model() 函式定義生成器和判別器網路模型
gan_model = tfgan.gan_model(
generator_fn=mnist.unconditional_generator,
discriminator_fn=mnist.unconditional_discriminator,
real_data=images,
generator_inputs=noise)
# 呼叫 tfgan.gan_loss() 定義損失函式
gan_loss = tfgan.gan_loss(
gan_model,
generator_loss_fn=tfgan_losses.wasserstein_generator_loss,
discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss)
# 呼叫 tfgan.gan_train_ops() 指定生成器和判別器的優化器
train_ops = tfgan.gan_train_ops(
gan_model,
gan_loss,
generator_optimizer=tf.train.AdamOptimizer(gen_lr, 0.5),
discriminator_optimizer=tf.train.AdamOptimizer(dis_lr, 0.5))
# tfgan.gan_train() 開始訓練,並指定訓練迭代次數 num_steps
tfgan.gan_train(
train_ops,
hooks=[tf.train.StopAtStepHook(num_steps=FLAGS.max_number_of_steps)],
logdir=FLAGS.train_log_dir)
複製程式碼
條件式 MNIST 圖片生成
第二個例子同樣還是生成 MNIST 圖片,但是這次輸入到生成器的不僅僅是隨機噪聲,還會給類別標籤,這種 GAN 模型也被稱作條件 GAN,其目的也是為了讓 GAN 訓練不會太過自由。具體可以看論文“Conditional Generative Adversarial Nets”。
程式碼方面,僅僅需要修改輸入和建立生成器與判別器模型部分,如下所示:
# 配置輸入
# 真實資料來自 MNIST 資料集,這裡增加了類別標籤--one_hot_labels
images, one_hot_labels = mnist_data_provider.provide_data(FLAGS.batch_size)
# 生成器的輸入,從多變數均勻分佈取樣得到的隨機噪聲
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])
# 呼叫 tfgan.gan_model() 函式定義生成器和判別器網路模型
gan_model = tfgan.gan_model(
generator_fn=mnist.conditional_generator,
discriminator_fn=mnist.conditional_discriminator,
real_data=images,
generator_inputs=(noise, one_hot_labels)) # 生成器的輸入增加了類別標籤
# 剩餘的程式碼保持一致
...
複製程式碼
對抗損失
第三個例子結合了 L1 pixel loss 和對抗損失來學習自動編碼圖片。瓶頸層可以用來傳輸圖片的壓縮表示。如果僅僅使用 pixel-wise loss,網路只回傾向於生成模糊的圖片,但 GAN 可以用來讓這個圖片重建過程更加逼真。具體可以看論文“Full Resolution Image Compression with Recurrent Neural Networks”來了解如何用 GAN 來實現影像壓縮,以及論文“Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network”瞭解如何用 GANs 來增強生成的圖片質量。
程式碼如下:
# 配置輸入
images = image_provider.provide_data(FLAGS.batch_size)
# 配置生成器和判別器網路
gan_model = tfgan.gan_model(
generator_fn=nets.autoencoder, # 自定義的 autoencoder
discriminator_fn=nets.discriminator, # 自定義的 discriminator
real_data=images,
generator_inputs=images)
# 建立 GAN loss 和 pixel loss
gan_loss = tfgan.gan_loss(
gan_model,
generator_loss_fn=tfgan_losses.wasserstein_generator_loss,
discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,
gradient_penalty=1.0)
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)
# 結合兩個 loss
gan_loss = tfgan.losses.combine_adversarial_loss(
gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)
# 剩下程式碼保持一致
...
複製程式碼
影像轉換
第四個例子是影像轉換,它是將一個領域的圖片轉變成另一個領域的同樣大小的圖片。比如將語義分割圖變成街景圖,或者是灰度圖變成彩色圖。具體細節看論文“Image-to-Image Translation with Conditional Adversarial Networks”。
程式碼如下:
# 配置輸入,注意增加了 target_image
input_image, target_image = data_provider.provide_data(FLAGS.batch_size)
# 配置生成器和判別器網路
gan_model = tfgan.gan_model(
generator_fn=nets.generator,
discriminator_fn=nets.discriminator,
real_data=target_image,
generator_inputs=input_image)
# 建立 GAN loss 和 pixel loss
gan_loss = tfgan.gan_loss(
gan_model,
generator_loss_fn=tfgan_losses.least_squares_generator_loss,
discriminator_loss_fn=tfgan_losses.least_squares_discriminator_loss)
l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1)
# 結合兩個 loss
gan_loss = tfgan.losses.combine_adversarial_loss(
gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.weight_factor)
# 剩下程式碼保持一致
...
複製程式碼
InfoGAN
最後一個例子是採用 InfoGAN 模型來生成 MNIST 圖片,但是可以不需要任何標籤來控制生成的數字型別。具體細節可以看論文“InfoGAN: Interpretable Representation Learning by Information Maximizing Generative Adversarial Nets”。
程式碼如下:
# 配置輸入
images = mnist_data_provider.provide_data(FLAGS.batch_size)
# 配置生成器和判別器網路
gan_model = tfgan.infogan_model(
generator_fn=mnist.infogan_generator,
discriminator_fn=mnist.infogran_discriminator,
real_data=images,
unstructured_generator_inputs=unstructured_inputs, # 自定義輸入
structured_generator_inputs=structured_inputs) # 自定義
# 配置 GAN loss 以及相互資訊懲罰
gan_loss = tfgan.gan_loss(
gan_model,
generator_loss_fn=tfgan_losses.wasserstein_generator_loss,
discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,
gradient_penalty=1.0,
mutual_information_penalty_weight=1.0)
# 剩下程式碼保持一致
...
複製程式碼
自定義模型的建立
最後同樣是非條件 GAN 生成 MNIST 圖片,但利用GANModel
函式來配置更多引數從而更加精確控制模型的建立。
程式碼如下:
# 配置輸入
images = mnist_data_provider.provide_data(FLAGS.batch_size)
noise = tf.random_normal([FLAGS.batch_size, FLAGS.noise_dims])
# 手動定義生成器和判別器模型
with tf.variable_scope(`Generator`) as gen_scope:
generated_images = generator_fn(noise)
with tf.variable_scope(`Discriminator`) as dis_scope:
discriminator_gen_outputs = discriminator_fn(generated_images)
with variable_scope.variable_scope(dis_scope, reuse=True):
discriminator_real_outputs = discriminator_fn(images)
generator_variables = variables_lib.get_trainable_variables(gen_scope)
discriminator_variables = variables_lib.get_trainable_variables(dis_scope)
# 依賴於你需要使用的 TFGAN 特徵,你並不需要指定 `GANModel`函式的每個引數,不過
# 最少也需要指定判別器的輸出和變數
gan_model = tfgan.GANModel(
generator_inputs,
generated_data,
generator_variables,
gen_scope,
generator_fn,
real_data,
discriminator_real_outputs,
discriminator_gen_outputs,
discriminator_variables,
dis_scope,
discriminator_fn)
# 剩下程式碼和第一個例子一樣
...
複製程式碼
最後,再次給出 TFGAN 的 Github 地址如下:
如果有翻譯不當的地方或者有任何建議和看法,歡迎留言交流;也歡迎關注我的微信公眾號–機器學習與計算機視覺或者掃描下方的二維碼,和我分享你的建議和看法,指正文章中可能存在的錯誤,大家一起交流,學習和進步!