GAN實戰筆記——第三章第一個GAN模型:生成手寫數字

墨戈發表於2022-02-21

第一個GAN模型—生成手寫數字

一、GAN的基礎:對抗訓練

形式上,生成器和判別器由可微函式表示如神經網路,他們都有自己的代價函式。這兩個網路是利用判別器的損失記性反向傳播訓練。判別器努力使真實樣本輸入和偽樣本輸入帶來的損失最小化,而生成器努力使它生成的為樣本造成的判別器損失最大化。

image

訓練資料集決定了生成器要學習模擬的樣本型別,例如,目標是生成貓的逼真影像,我們就會給GAN提供一組貓的影像。

用更專業的術語來說,生成器的目標是生成符合訓練資料集資料分佈的樣本。對計算機來說,影像只是矩陣:灰度圖是二維的,彩色圖是三維的。當在螢幕上呈現時,這些矩陣中的畫素值將顯示為影像線條、邊緣、輪廓等的所有視覺元素。這些值在資料集中的每個影像上遵循複雜的分佈,如果沒有分佈規律,影像將不過是些隨機噪聲。目標識別模型學習影像中的模式以識別影像的內容,生成器所做的可以認為是相反的過程:它學習合成這些模式,而不是識別這些模式。

1. 代價函式

遵循標準的表示形式,用\(\text{J}^{(G)}\)表示生成器的代價函式,用\(\text{J}^{(D)}\)表示判別器的代價函式。兩個網路的訓練引數(權重和偏置)用希臘字母表示:\(\theta^{(G)}\)表示生成器,\(\theta^{(D)}\)表示判別器。

GAN在兩個關鍵方面不同於傳統的神經網路。第一,代價函式\(J\),傳統神經網路的代價函式僅根據其自身可訓練的引數定義,數學表示為\(\text{J}^{(\theta)}\)。相比之下,GAN由兩個網路組成,其代價函式依賴於兩個網路的引數。也就是說,生成器的代價函式是\(\text{J}^{(G)}({\theta}^{(G)}, {\theta}^{(D)})\),而判別器的代價函式是\(\text{J}^{(D)}({\theta}^{(G)}, {\theta}^{(D)})\)​。

第二,在訓練過程中,傳統的神經網路可以調整它的所有引數\(\theta\)。在GAN中,每個網路只能調整自己的權重和偏置。也就是說,在訓練過程中,生成器只能調整\({\theta}^{(G)}\),判別器只能調整\({\theta}^{(D)}\)​​。因此,每個網路只控制了決定損失的部分參量。

為了使上述內容不那麼抽象,考慮下面這個類比。想象一下我們正在選擇下班開車回家的路線,如果交通不堵塞,最快的選擇是高速公路,但在交通高峰期,優選是走一條小路。儘管小路更長更曲折,但當高速公路上交通堵塞時,走小路可能會更快地回家。

讓我們把它當作一道數學題——\(J\)作為代價函式,並定義為回家所需的時間。我們的目標是儘量減小\(J\)。為簡單起見,假設離開辦公室的時間是固定的,既不能提前離開,也不能為了避開高峰時間而晚走。所以唯一能改變的引數是路線\(\theta\)

如果我們所擁有的是路上唯一的車,代價將類似於一個常規的神經網路:它將只取決於路線,且優化\(\text{J}{(\theta)}\)​完全在我們的能力範圍內。然而,一旦將其他駕駛員引入方程式,情況就會變得更加複雜。突然之間,我們回家的時間不僅取決於自己的決定,還取決於其他駕駛員的行路方案,即\(\text{J}({\theta}^{\text{我們}}, {\theta}^{\text{其他駕駛員}})\)。就像生成器網路和判別器網路一樣,“代價函式”將取決於各種因素的相互作用,其中一些因素在我們的掌控之下,而另一些因素則不在。

2. 訓練過程

上面所描述的兩個差異對GAN的訓練過程有著深遠的影響。傳統神經網路的訓練是個優化問題,通過尋找一組引數來最小化代價函式,移動到引數空間中的任何相鄰點都會增加代價。這可能是引數空間中的區域性或全域性最小值,由尋求最小化使用的代價函式所決定。

最小化代價函式的優化過程如下圖所示:

image

碗形網格表示引數空間\(\theta_1\)\(\theta_2\)中的損失\(J\)。黑色點線表示通過優化使引數空間中的損失最小化

因為生成器和判別器智慧調整自己的引數而不能相互調整對方的引數,所以GAN訓練可以用一個博弈過程來更好的描述,而非優化。該博弈中的對手是GAN所包含的兩個網路。

當兩個網路達到納什均衡時GAN訓練結束,在納什均衡點上,雙方都不能通過改變策略來改善自己的情況。從數學角度來說,發上在這樣的情況下——生成器的可訓練引數\(\theta^{(G)}\)對應的生成器的代價函式\(\text{J}^{(G)}({\theta}^{(G)}, {\theta}^{(D)})\)最小化;同時,對應該網路引數\(\theta^{(D)}\)下的判別器的代價函式\(\text{J}^{(D)}({\theta}^{(G)}, {\theta}^{(D)})\)也得到最小化。下圖說明了二者零和博弈的建立和達到納什均衡的過程。

image

玩家1(左)試圖通過調整\(\theta_1\)來最小化V。玩家2(中間)試圖通過調整\(\theta_2\)來最大化V(最小化-V)。鞍形網路(右)顯示了引數空間\(V({\theta_1}, {\theta_2})\)中的組合損失。虛線表示在鞍形網路中心收斂達到納什均衡

回到我們的類比,對於我們和可能在路上遇到的所有其他駕駛員來說,當每一條回家的路線所花費的時間都完全相同時,納什均衡將會發生。任何更快的路線都會被交通擁堵量的成比例增長所抵消,從而減緩了每個人的速度。而這種狀態在現實生活中幾乎是無法實現的,即使使用像谷歌地圖這樣提供實時流量更新的工具,也不可能完美地評估出回家的最佳路徑。

這同樣適用於訓練GAN網路時的高維、非凸情況。即使是像 MNIST資料集中的那些小到只有28×28畫素的灰度影像,也有28×28=784維。如果它們被著色(RGB),它們的維數將增加到3倍變成2352。在訓練資料集中的所有影像上捕獲這種分佈非常困難,特別是當最好的學習方法是從對手(判別器)那裡學習時。

成功地訓練GAN需要反覆試驗,儘管有最優方法,但它是一門科學的同時也是一門藝術。

二、生成器和判別器

現在通過引入更多的表示概括所學的內容。生成器(G)接收隨機噪聲向量z並生成為樣本\(x^*\)。數學上來說,\(G(z) = x^*\)。判別器(D)的輸入要麼是真實樣本x,要麼是偽樣本x*;對於每個輸入,它輸出一個介於0和1之間的值,表示輸入是真實樣本的概率。下圖用剛才介紹的術語和符號描述了GAN架構。

image

生成器網路G將隨機向量z轉換為偽樣本\(x^*: G(z) = x^*\)。判別器網路D對輸入樣本是各真實進行分類並輸出。對於真實樣本x,判別器力求輸出儘可能接近1的值;對於偽樣本x,判別器力求輸出儘可能接近0的值。相反,生成器希望D(x*)儘可能接近1,這表明判別器被欺騙,將偽樣本分類為真實樣本

1. 對抗的目標

判別器的目標是儘可能精確。對於真實樣本x,D(x)力求儘可能接近1(正的標籤);對於偽樣本x*,\(D\left( x^{\ast }\right)\)​力求儘可能接近0(負的標籤)。

生成器的目標正好相反,它試圖通過生成與訓練資料集中的真實資料別無二致的偽樣本x*來欺騙判別器。從數學角度來講,即生成器試圖生成假樣本\(D\left( x^{\ast }\right)\),使得\(D\left( x^{\ast }\right)\)儘可能接近1。

2. 混淆矩陣

判別器的分類可以使用混淆矩陣來表示,混淆矩陣是二元分類中所有可能結果的表格表示。如下表:

image

判別器的分類結果如下:

  1. 真陽性(true positive)——真實樣本正確分類為真D(x)≈1;
  2. 假陰性( false negative)——真實樣本錯誤分類為假D(x)≈0;
  3. 真陰性( true negative)——偽樣本正確分類為假D(x*)≈0;
  4. 假陽性( false positive)——偽樣本錯誤分類為真D(x*)≈1。

使用混淆矩陣的術語,判別器試圖最大化真陽性和真陰性分類,這等同於最小化假陽性和假陰性分類。相反,生成器的目標是最大化判別器的假陽性分類,這樣生成器才能成功地欺騙判別器,使其相信偽樣本是真的。生成器不關心判別器對真實樣本的分類效果如何,只關心對偽樣本的分類。

三、GAN訓練演算法

這裡介紹的演算法使用小批量(mini-batch)而不是一次使用一個樣本。

GAN訓練演算法
對於每次訓練迭代,執行
(1)訓練判別器
a. 取隨機的小批量的真實樣本x
b. 取隨機的小批量的隨機噪聲z,並生成一個小批量偽樣本:G(z) = x*
c. 計算D(x)和D(x*)的分類損失,並反向傳播總誤差以更新\({\theta}^{(D)}\)來最小化分類損失

​ (2)訓練生成器

​ a. 取隨機的小批量的隨機噪聲z生成一小批量偽樣本:G(z) = x*

​ b. 用判別器網路對x*進行分類

​ c. 計算D(x*)的分類損失,並反向傳播總誤差以更新\({\theta}^{(G)}\)來最大化分類損失。

結束

在步驟1中訓練判別器時,生成器的引數保持不變;同樣,在步驟2中,在訓練生成器時保持判別器的引數不變。之所以只允許更新被訓練網路的權重和偏置,是因為要將所有更改隔離到僅受該網路控制的引數中。

四、生成手寫數字

本節將實現一個GAN,它將學習生成外觀逼真的手寫數字,用的是帶有TensorFlow後端的Python神經網路庫Keras。

1. 匯入模組並指定模型輸入維度

#import statements
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Flatten, GlobalAveragePooling2D, Reshape
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from keras.layers.advanced_activations import LeakyReLU 

#模型輸入維度
img_rows = 28
img_cols = 28
channels = 1

img_shape = (img_rows, img_cols, channels)#輸入圖片的維度

z_dim = 100#噪聲向量的大小用作生成器的輸入
img_shape

2. 構造生成器

簡而言之,生成器是一個只有一個隱藏層的神經網路。生成器以z為輸入,生成28×28×1的影像。在隱藏層中使用 LeakyReLU啟用函式,與將任何負輸入對映
到0的常ReLU函式不同, LeakyReLU函式允許存在一個小的正梯度,這樣可以防止梯度在訓練過程中消失,從而產生更好的訓練效果。

在輸出層使用tanh啟用函式,它將輸出值縮放到範圍[-1, 1]。之所以使用tanh(與sigmoid同, sigmoid會輸出更為典型的0到1範圍內的值),是因為它有助於生成更清斷的影像。

def build_generator(img_shape, z_dim):
    model = Sequential([
        Dense(128, input_dim=z_dim),#全連線層
        LeakyReLU(alpha=0.01),
        Dense(28*28*1, activation='tanh'),
        Reshape(img_shape)#生成器的輸出改變為影像尺寸
    ])
    return model

build_generator(img_shape, z_dim).summary()

3. 構造判別器

判別器接收28×28×1的影像,並輸出表示輸入是否被視為真而不是假的概率。判別器由一個兩層神經網路表示,其隱藏層有128個隱藏單元及啟用函式為 LeakyReLU。

為簡單起見,我們構造的判別器網路看起來與生成器幾乎相同,但並非必須如此。實際上,在大多數GAN的實現中,生成器和判別器網路體系結構的大小和複雜性都相差很大。

注意,與生成器不同的是,判別器的輸出層應用了 sigmoid啟用函式。這確保了輸出值將介於0和1之間,可以將其解釋為生成器將輸入認定為真的概率。

def build_discrimination(img_shape):
    model = Sequential([
        Flatten(input_shape=img_shape),#輸入影像展平
        Dense(128),
        LeakyReLU(alpha=0.01),
        Dense(1, activation='sigmoid')
    ])
    return model

build_discrimination(img_shape).summary()

4. 搭建整個模型

構建並編譯先前實現的生成器模型和判別器模型。注意:在用於訓練生成器的組合模型中,通過將discriminator.trainable設定為False來固定判別器引數。還要注意的是,組合模型(其中判別器設定為不可訓練)僅用於訓練生成器。判別器將用單獨編譯的模型訓練。(當回顧訓練迴圈時,這一點會變得很明顯。)

使用二元交又熵作為在訓練中尋求最小化的損失函式。二元交叉熵( binary cross-entropy)用於度量二分類預測計算的概率和實際概率之間的差異;交叉損失越大,預測離真值就越遠。

優化每個網路使用的是Adam優化演算法。該演算法名字源於adaptive moment estimation。這是一種先進的基於梯度下降的優化演算法,Adam憑藉其通常優異的效能已經成為大多數GAN的首選優化器。

def build_gan(generator, discriminator):
    model = Sequential()
    
    #生成器模型和判別器模型結合到一起
    model.add(generator)
    model.add(discriminator)
    
    return model

discriminator = build_discrimination(img_shape)#構建並編譯判別器
discriminator.compile (
    loss='binary_crossentropy',
    optimizer=Adam(),
    metrics=['accuracy']
)
generator = build_generator(img_shape, z_dim)#構建生成器
discriminator.trainable = False#訓練生成器時保持判別器的引數固定

#構建並編譯判別器固定的GAN模型,以生成訓練器
gan = build_gan(generator, discriminator)
gan.compile(
    loss='binary_crossentropy',
    optimizer=Adam()
)

5. 訓練

首先,取隨機小批量的MNIST影像為真實樣本,從隨機噪聲向量z中生成小批量偽樣本,然後在保持生成器引數不變的情況下,利用這些偽樣本訓練判別器網路。其次,生成一小批偽樣本,使用這些影像訓練生成器網路,同時保持判別器的引數不變。演算法在每次送代中都重複這個過程。

我們使用獨熱編碼(one-hot-encoded)標籤:1代表真實影像,0代表偽影像。z從標準正態分佈(平均值為0、標準差為1的鐘形曲線)中取樣得到。訓練判別器使得假標籤分配給偽影像,真標籤分配給真影像。對生成器進行訓練時,生成器要使判別器能將真實的標籤分配給它生成的偽樣本。

注意:訓練資料集中的真實影像被重新縮放到了-1到1。如前例所示,生成器在輸出層使用tanh啟用函式,因此偽樣本同樣將在範圍(-1,1)內。相應地,就得將判別器的所有輸入重新縮放到同一範圍。

losses = []
accuracies = []
iteration_checkpoints = []

def train(iterations, batch_size, sample_interval):
    (x_train, _), (_, _) = mnist.load_data()#載入mnist資料集
    x_train = x_train/127.5 - 1.0#灰度畫素值[0,255]縮放到[-1,1]
    x_train = np.expand_dims(x_train, axis=3)
    real = np.ones((batch_size, 1))#真實影像的標籤都是1
    fake = np.zeros((batch_size, 1))#偽影像的標籤都是0
    for iteration in range(iterations):
        idx = np.random.randint(0, x_train.shape[0], batch_size)#隨機噪聲取樣
        imgs = x_train[idx]
        
        z = np.random.normal(0, 1, (batch_size, 100))#獲取隨機的一批真實影像
        gen_imgs = generator.predict(z)
        
        #影像畫素縮放到[0,1]
        d_loss_real = discriminator.train_on_batch(imgs, real)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        
        d_loss, accuracy = 0.5 * np.add(d_loss_real, d_loss_fake)
        
        z = np.random.normal(0, 1, (batch_size, 100))#生成一批偽影像
        gen_imgs = generator.predict(z)
        
        g_loss = gan.train_on_batch(z, real)#訓練判別器
        
        if(iteration + 1) % sample_interval == 0:
            losses.append((d_loss, g_loss))
            accuracies.append(100.0 * accuracy)
            iteration_checkpoints.append(iteration + 1)
            print("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (iteration + 1, d_loss, 100.0 * accuracy, g_loss))#輸出訓練過程
            
            sample_images(generator)#輸出生成影像的取樣

6. 輸出樣本影像

在生成器訓練程式碼中,你可能注意到呼叫了 sample_images()函式。該函式在每次sample_ interval選代中呼叫,並輸出由生成器在給定迭代中合成的含有4x4幅合成影像的網格。執行模型後,你可以使用這些影像檢查臨時和最終的輸出情況。

def sample_images(generator, image_grid_rows=4, image_grid_columns=4):
    z = np.random.normal(0, 1, (image_grid_rows * image_grid_columns, z_dim))#樣本隨機噪聲
    
    gen_imgs = generator.predict(z)#從隨機噪聲生成影像
    
    gen_imgs = 0.5 * gen_imgs + 0.5#將影像畫素重置縮放至[0, 1]內
    
    #設定影像網格
    fig, axs = plt.subplots(
        image_grid_rows,
        image_grid_columns,
        figsize=(4, 4),
        sharex=True, 
        sharey=True
    )
    
    cnt = 0
    for i in range(image_grid_rows):
        for j in range(image_grid_columns):
            axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')#輸出一個影像網格
            axs[i, j].axis('off')
            cnt += 1
        

7. 執行模型

這是最後一步,設定訓練超引數——迭代次數和批量大小,然後訓練模型。目前沒有一種行之有效的方法來確定正確的迭代次數或正確的批量大小,只能觀察訓
練進度,通過反覆試驗來確定。

也就是說,對這些數有一些重要的實際限制:每個小批量必須足夠小,以適合記憶體器處理(典型使用的批量大小是2的冪:32、64、128、256和512)。迭代次數也有一個實際的限制:擁有的迭代次數越多,訓練過程花費的時間就越長。像GAN這樣複雜的深度學習模型,即使有了強大的計算能力,訓練時長也很容易變得難以控制。

為了確定合適的迭代次數,你需要監控訓練損失,並在損失達到平穩狀態(這意味著我們從進一步的訓練中得到的改進增量很少,甚至沒有)的次數附近設定送代次數。(因為這是個生成模型,像有監督的學習演算法一樣,也需要擔心過擬合問題。)

#設定訓練超引數
iterations = 20000
batch_size = 128
sample_interval = 1000

train(iterations, batch_size, sample_interval)

8. 檢查結果

經過訓練迭代後由生成器生成的樣本影像,按照時間先後排列,如下圖所示。可以看到,生成器起初只會產生隨機噪聲。在訓練迭代的過程中,它越來越擅長模擬訓練資料的特性,每次判別器判斷生成的影像為假或判斷生成的影像為真時,生成器都會稍有改進。

image

生成器經過充分訓練後可以合成的影像樣本,如下圖所示。雖遠非完美,但簡單的雙層生成器生成了逼真的數字,如數字1和7。

image

為了進行比較,我們給出從MNIST資料集中隨機選擇的真實影像樣本。如下圖所示。

image

五、小結

  1. GAN是由兩個網路組成的:生成器(G)和判別器(D)。它們各自有自己的損失函式:\(\text{J}^{(G)}({\theta}^{(G)}, {\theta}^{(D)})\)\(\text{J}^{(D)}({\theta}^{(G)}, {\theta}^{(D)})\)
  2. 在訓練過程中,生成器和判別器只能調整自己的引數,即\(\theta^{(G)}\)\(\theta^{(D)}\)
  3. 兩個網路通過一個類似博弈的動態過程同時訓練:生成器試圖最大化判別器的假陽性分類(將生成的影像分類為真影像),而判別器試圖最小化它的假陽性和假陰性分類。

相關文章