更多幹貨內容請關注微信公眾號“AI 前線”,(ID:ai-front)
本文主要講解 TFGAN 如何應用於原生 GAN、CGAN、InfoGAN、WGAN 等場景,如下所示:
其中,原生 GAN 生成的 Mnist 影像不可控:CGAN 可按照數字標籤生成相應標籤的數字影像;InfoGAN 可認為是無監督的 CGAN,前兩行表示用分類潛變數控制數字的生成類別,中間兩行表示用連續型潛變數控制數字的粗細,最後兩行表示用連續型潛變數控制數字的傾斜方向;ImageToImage 是 CGAN 的一種,實現影像的風格轉換。
GAN 由 Goodfellow 首先提出,主要由兩部分構成:Generator(生成器),簡稱 G;Discriminator(判別器), 簡稱 D。生成器主要用噪聲 z 生成一個類似真實資料的樣本,樣本越逼真越好;判別器用於估計一個樣本來自於真實資料還是生成資料,判定越準確越好。如下圖所示:
上圖中,對於真實的取樣資料,通過判別網路後,生成 D(x)。D(x) 的輸出是 0-1 範圍內的一個實數,用來判斷這個圖片是一個真實圖片的概率是多大。這樣對於真實資料,D(x) 越接近 1 越好。對於隨機噪聲 z,通過生成網路 G 後,G 將這個隨機噪聲轉化為生成資料 x。如果是圖片生成問題,G 網路的輸出就是一張生成的假圖片,用 G(z) 表示。判別模型 D 要使得 D(G(z)) 接近與 0,即能夠判斷生成的圖片是假的;生成模型 G 要使得 D(G(z)) 接近於 1,即要能夠要欺騙判別模型,使得 D 認為 G(z) 生成的假資料是真的。這樣通過判別模型 D 和生成模型 G 的博弈,使得 D 無法判斷一張圖片是生成出來的還是真實的而結束。
假設 P_r 和 P_g 分別代表真實資料的分佈與生成資料的分佈,這樣判別模型的目標函式可以表示為:
而生成模型的是讓判別模型 D 無法區別真實資料與生成資料,這樣優化目標函式為:
TFGAN 庫的地址為 https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/gan, 主要包含以下幾個元件:
-
核心架構,主要包括建立 TFGAN 模型,新增 Loss 值,建立訓練 operation,執行訓練 operation。
-
常用操作,主要提供了梯度修剪操作,歸一化操作及條件化操作等。
-
損失函式,主要提供了 GAN 中常用的損失和懲罰函式,如 Wasserstein 損失、梯度懲罰、互資訊懲罰等。
-
模型評估,提供了 Inception Score 和 Frechet Distance 指標,用於評估無條件生成模型。
-
示例,谷歌同時開源了常用的 GAN 網路示例程式碼,包括 unconditional GAN,conditional GAN, InfoGAN,WGAN 等。相關用例可從 https://github.com/tensorflow/models/tree/master/research/gan/ 地址下載。
使用 TFGAN 庫訓練 GAN 網路主要包含如下幾個步驟:
-
確定 GAN 網路的輸入,如下所示:
-
設定 GANModel 中的生成模型和判別模型,如下所示:
-
設定 GANLoss 中的損失方程,如下所示:
-
設定 GANTrainOps 中的訓練操作,如下所示:
-
執行模型訓練,如下所示:
CGAN(Conditional Generative Adversarial Nets),針對 GAN 本身不可控的缺點,加入監督資訊,訓練從無監督變成有監督,指導 GAN 網路進行生成。例如輸入分類的標籤,可生成相應標籤的影像。這樣 CGAN 的目標方程可以轉換為:
其中,y 是加入的監督資訊,D(x|y) 表示在 y 的條件下判定真實資料 x,D(G(z|y)) 表示在 y 的條件下判定生成資料 G(z|y)。例如,MNIST 資料集可根據數字 label 資訊,生成相應標籤的圖片;人臉生成資料集,可根據性別、是否微笑、年齡等資訊,生成相應的人臉圖片。CGAN 的架構如下圖所示:
在 TFGAN 中提供了,基於 one_hot_labels 變數和輸入 tensor 生成 condition tensor 的 API,如下所示:
tfgan.features.condition_tensor_from_onehot
(tensor, one_hot_labels, embedding_size)複製程式碼
其中,tensor 為輸入資料,one_hot_labels 為 onehot 標籤,shape 為 [batch_size, num_classes],embedding_size 為每個 label 對應的 embedding 大小,返回值為 condition tensor。
Phillip Isola 等提出了基於 CGAN 的圖片生成圖片的對抗神經網路《Image-to-Image Translation with Conditional Adversarial Networks》。網路設計的基本思想如下所示:
其中,x 為輸入的線條圖,G(x) 為生成圖片,y 為線條圖 x 對應渲染後的真圖片,生成模型 G 用於生成圖片,判斷模型 D 用於判定生成圖片的真假。判別網路能夠最大化判斷 (x,y) 的資料為真,判斷 (x,G(x)) 資料為假。而生成網路使得判別網路判斷 (x,G(x)) 資料為真,從而進行生成模型和判別模型的相互博弈。為了使生成模型不僅能夠欺騙判別模型,還要使得生成影像要像真實圖片,這樣在目標函式中加入了真實影像和生成影像的 L1 距離,如下所示:
TFGAN 庫,提供了 ImageToImage 生成對抗網路的相關損失方程 API 使用示例,如下所示:
# 定義真實資料與生成資料的 L1 損失
# gan_loss 為目標函式損失
gan_loss = tfgan.losses.combine_adversarial_loss
(gan_loss, gan_model, l1_pixel_loss, weight_factor=FLAGS.
weight_factor)複製程式碼
在 GAN 中,生成器用噪聲 z 生成資料時,沒有加任何的條件限制,很難用 z 的任何一個維度資訊表示相關的語義特徵。所以在資料生成過程中,無法控制什麼樣的噪聲 z 可以生成什麼樣的資料,在很大程度上限制了 GAN 的使用。InfoGAN 可以認為是無監督的 CGAN,在噪聲 z 上增加潛變數 c,使得生成模型生成的資料與淺變數 c 具有較高的互資訊,其中 Info 就是代表互資訊的含義。互資訊定義為兩個熵的差值,H(x) 是先驗分佈的熵,H(x|y) 代表後驗分佈的熵。如果 x,y 是相互獨立的變數,那麼互資訊的值為 0,表示 x,y 沒有關係;如果 x,y 有相關性,那麼互資訊大於 0。這樣在已知 y 的情況下,可以推斷出那些 x 的值出現高。這樣 InfoGAN 的目標方程為:
InfoGAN 的網路結構如下所示:
上圖中 InfoGAN 與 GAN 的區別在於,對應判別網路的輸出 D(x),生成變分分佈 Q(c|x),從而能用 Q(c|x) 來逼近 P(c|x),從而增大生成資料與潛變數 c 的互資訊。TFGAN 中提供了 InfoGan 相關 API,如下所示:
# 通過 tfgan.infogan_model,定義 infogan 模型
# 通過 tfgan.gan_loss,生成 infogan 模型的 loss 值:
# InfoGan 的 Loss 值為在 GAN 的 loss 值上,加上互資訊 I(c;G(z,c)),TFGAN 中提供了互資訊計算的 API,如下所示。其中 structured_generator_inputs 為潛變數的噪音資訊,predicted_distributions 為變分分佈 Q(c|x)。
Martin Arjovsky 等提出了 WGAN(Wasserstein GAN),解決了傳統 GAN 訓練困難、生成器和判別器的 loss 很難指示訓練程式、生成樣本缺乏多樣性等問題,主要有以下優點:
-
能夠平衡生成器和判別器的訓練程度,使得 GAN 的模型訓練穩定。
-
能夠保證生產樣本的多樣性。
-
提出使用 Wasserstein 距離來衡量模型訓練的程度,數值越小表示訓練得越好,成器生成的影像質量越高。
WGAN 的演算法與原始 GAN 演算法的差異主要體現在:
-
去掉判別模型最後一層的 sigmoid 操作。
-
生成模型和判別模型的 loss 值不取 log 操作。
-
每次更新判別模型的引數之後把模型引數的絕對值截斷到不超過固定常數 c。
-
使用 RMSProp 演算法,不用基於動量的優化演算法,例如 momentum 和 Adam。
WGAN 的演算法結構如下所示:
TFGAN 中提供了 WGan 相關 API,如下所示:
#生成網路損失方程
generator_loss_fn=tfgan_losses.wasserstein_generator_loss複製程式碼
#判別網路損失方程
discriminator_loss_fn=tfgan_losses.wasserstein_discriminator
_loss複製程式碼
本文首先介紹了生成對抗網路和 TFGAN,生成對抗網路模型用於影像生成、超解析度圖片生成、影像壓縮、影像風格轉換、資料增強、文字生成等場景;TFGAN 是 TensorFlow 庫,用於快速實踐各種 GAN 模型。然後講解了 CGAN、ImageToImage、InfoGAN、WGAN 模型的主要思想,並對關鍵技術進行了分析,主要包括目標函式、網路架構、損失方程及相應的 TFGAN API。使用者可基於 TFGAN 快速實踐生成對抗網路模型,並應用到工業領域中的相關場景。
參考文獻
[1] Generative Adversarial Networks.
[2] Conditional Generative Adversarial Nets.
[3] InfoGAN: Interpretable Representation Learning by Information MaximizingGenerative Adversarial Nets.
[4] Wasserstein GAN.
[5] Image-to-Image Translation with Conditional Adversarial Networks.
[6]https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/gan.
[7] https://github.com/tensorflow/models/tree/master/research/gan.
更多幹貨內容請關注微信公眾號“AI 前線”,(ID:ai-front)