寫給程式設計師的機器學習入門 (十四) - 對抗生成網路 如何造假臉

q303248153發表於2021-04-21

這篇文章將會教你怎樣用機器學習來偽造假資料,題材還是人臉,以下六張人臉裡面,有兩張是假的,猜猜是哪兩張??

生成假人臉使用的網路是對抗生成網路 (GAN - Generative adversarial network),這個網路與之前介紹的比起來相當特殊,雖然看起來不算複雜,但訓練起來極其困難,以下將從基礎原理開始一直講到具體程式碼,還會引入一些之前沒有講過的元件和訓練方法?。

對抗生成網路 (GAN) 的原理

所謂生成網路就是用於生成文章,音訊,圖片,甚至程式碼等資料的機器學習模型,例如我們可以給出一個需求讓網路生成一份程式碼,如果網路足夠強大,生成的程式碼質量足夠好並且能滿足需求,那碼農們就要面臨失業了?。當然,目前機器學習模型可以生成的資料比較有限並且質量都很一般,碼農們的飯碗還是能保住一段時間的。

生成網路和普通的模型一樣,要求有輸入和輸出,假設我們可以傳入一些條件讓網路生成符合條件的圖片:

看起來非常好用,但訓練這樣的模型需要一個龐大的資料集,並且得一張張圖片去標記它們的屬性,實現起來會累死人。這篇文章介紹的對抗生成網路屬於無監督學習,可以完全不需要給資料打標籤,你只需要給模型認識一些真實資料,就可以讓模型輸出類似真實資料的假資料。對抗生成網路分為兩部分,第一部分是生成器 (Generator),第二部分是識別器 (Discriminator),生成器負責根據隨機條件生成資料,識別器負責識別資料是否為真。

訓練對抗生成網路有兩大目標,這兩大目標是矛盾的,這就是為什麼我們叫對抗生成網路:

  • 生成器需要生成騙過識別器 (輸出為真) 的資料
  • 識別器需要不被生成器騙過去 (針對生成器生成的資料輸出為假,針對真實資料輸出為真)

對抗生成網路的訓練流程大致如下,需要迴圈訓練生成器和識別器:

簡單通俗一點我們可以用造假皮包為例來理解,好理解了吧?:

和現實造假皮包一樣,生成器會生成越來越接近真實資料的假資料,最後會生成和真實資料一模一樣的資料,但這樣反而就遠離我們構建生成網路的目的了(不如直接用真實資料)。使用生成網路通常是為了達到以下的目的:

  • 要求大量看上去是真的,但稍微不一樣的資料
  • 要求沒有版權保護的資料 (假資料沒來的版權?)
  • 生成想要但是現實沒有的資料 (需要更進一步的工作)

看以上的流程你可能會發現,因為對抗生成網路是無監督學習,不需要標籤,我們只能給模型傳入隨機的條件來讓它生成資料,模型生成出來的資料看起來可能像真的但不一定是我們想要的。如果我們想要指定具體的條件,則需要在訓練完成以後分析隨機條件對生成結果的影響,例如隨機生成的第二個數字代表性別,第六個數字代表年齡,第八個數字代表頭髮的數量,這樣我們就可以調整這些條件來讓模型生成想要的圖片。

還記得上一篇人臉識別的模型不?人臉識別的模型會把圖片轉換為某個長度的向量,訓練完成以後這個向量的值會代表人物的屬性,而這一篇是反過來,把某個長度的向量轉換回圖片,訓練成功以後這個向量同樣會代表人物的各個屬性。當然,兩種的向量表現形式是不同的,把人臉識別輸出的向量交給對抗生成網路,生成的圖片和原有的圖片可能會相差很遠,把人臉識別輸出的向量還原回去的方法後面再研究吧?。

對抗生成網路的實現

反摺積層 (ConvTranspose2d)

第八篇介紹 CNN 的文章中,我們瞭解過卷積層運算 (Conv2d) 的實現原理,CNN 模型會利用卷積層來把圖片的長寬逐漸縮小,通道數逐漸擴大,最後扁平化輸出一個代表圖片特徵的向量:

而在對抗生成網路的生成器中,我們需要實現反向的操作,即把向量當作一個 (向量長度, 1, 1) 的圖片,然後把長寬逐漸擴大,通道數 (最開始是向量長度) 逐漸縮小,最後變為 (3, 圖片長度, 圖片寬度) 的圖片 (3 代表 RGB)。

實現反向操作需要反摺積層 (ConvTranspose2d),反摺積層簡單的來說就是在引數數量相同的情況下,把輸出大小的資料還原為輸入大小的資料:

要理解反摺積層的具體運算方式,我們可以把卷積層拆解為簡單的矩陣乘法:

可以看到卷積層計算的時候可以根據核心引數和輸入大小生成一個矩陣,然後計算輸入與這個矩陣的乘積來得到輸出結果。

而反摺積層則會計算輸入與轉置 (Transpose) 後的矩陣的乘積得到輸出結果:

可以看到卷積層與反摺積層的區別只在於是否轉置計算使用的矩陣。此外,通道數量轉換的計算方式也是一樣的。

測試反摺積層的程式碼如下:

>>> import torch

# 生成測試用的矩陣
# 第一個維度代表批次,第二個維度代表通道數量,第三個維度代表長度,第四個維度代表寬度
>>> a = torch.arange(1, 5).float().reshape(1, 1, 2, 2)
>>> a
tensor([[[[1., 2.],
          [3., 4.]]]])

# 建立反摺積層
>>> convtranspose2d = torch.nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2, bias=False)

# 手動指定權重 (讓計算更好理解)
>>> convtranspose2d.weight = torch.nn.Parameter(torch.tensor([0.1, 0.2, 0.5, 0.8]).reshape(1, 1, 2, 2))
>>> convtranspose2d.weight
Parameter containing:
tensor([[[[0.1000, 0.2000],
          [0.5000, 0.8000]]]], requires_grad=True)

# 測試反摺積層
>>> convtranspose2d(a)
tensor([[[[0.1000, 0.2000, 0.2000, 0.4000],
          [0.5000, 0.8000, 1.0000, 1.6000],
          [0.3000, 0.6000, 0.4000, 0.8000],
          [1.5000, 2.4000, 2.0000, 3.2000]]]],
       grad_fn=<SlowConvTranspose2DBackward>)

需要注意的是,不一定存在一個反摺積層可以把卷積層的輸出還原到輸入,這是因為卷積層的計算是不可逆的,即使存在一個可以把輸出還原到輸入的矩陣,這個矩陣也不一定有一個等效的反摺積層的核心引數。

生成器的實現 (Generator)

接下來我們看一下生成器的定義,原始介紹 GAN 的論文給出了生成 64x64 圖片的網路,而這裡給出的是生成 80x80 圖片的網路,其實區別只在於一開始的輸出通道數量 (論文是 4, 這裡是 5)

class GenerationModel(nn.Module):
    """生成虛假資料的模型"""
    # 編碼長度
    EmbeddedSize = 128

    def __init__(self):
        super().__init__()
        self.generator = nn.Sequential(
            # 128,1,1 => 512,5,5
            nn.ConvTranspose2d(128, 512, kernel_size=5, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            # => 256,10,10
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # => 128,20,20
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # => 64,40,40
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # => 3,80,80
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
            # 限制輸出在 -1 ~ 1,不使用 Hardtanh 是為了讓超過範圍的值可以傳播給上層
            nn.Tanh())

    def forward(self, x):
        y = self.generator(x.view(x.shape[0], x.shape[1], 1, 1))
        return y

表現如下:

其中批次正規化 (BatchNorm) 用於控制引數值範圍,防止層數過多 (後面會結合識別器訓練) 導致梯度爆炸問題。

還有一個要點是生成器輸出的範圍會在 -1 ~ 1,也就是使用 -1 ~ 1 代表 0 ~ 255 的顏色值,這跟我們之前處理圖片的時候把值除以 255 使得範圍在 0 ~ 1 不一樣。使用 -1 ~ 1 可以提升輸出顏色的精度 (減少浮點數的精度損失)。

識別器的實現 (Discriminator)

我們再看以下識別器的定義,基本上就是前面生成器的相反流程:

class DiscriminationModel(nn.Module):
    """識別資料是否真實的模型"""

    def __init__(self):
        super().__init__()
        self.discriminator = nn.Sequential(
            # 3,80,80 => 64,40,40
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # => 128,20,20
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # => 256,10,10
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # => 512,5,5
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # => 1,1,1
            nn.Conv2d(512, 1, kernel_size=5, stride=1, padding=0, bias=False),
            # 扁平化
            nn.Flatten(),
            # 輸出是否真實資料 (0 or 1)
            nn.Sigmoid())

    def forward(self, x):
        y = self.discriminator(x)
        return y

表現如下:

看到這裡你可能會有幾個疑問:

  • 為什麼用 LeakyReLU: 這是為了防止層數疊加次數過多導致的梯度消失問題,參考第三篇,LeakyReLU 對於負數輸入不會返回 0,而是返回 輸入 * slope,這裡的 slope 指定為 0.2
  • 為什麼第一層不加批次正規化 (BatchNorm): 原有論文中提到實際測試中,如果在所有層新增批次正規化會讓模型訓練結果不穩定,生成器的最後一層和識別器的第一層拿掉以後效果會好一些
  • 為什麼不加池化層: 新增池化層以後可逆性將會降低,例如識別器針對假資料返回接近 0 的數值時,判斷哪些部分導致這個輸出的依據會減少

訓練生成器和識別器的方法

接下來就是訓練生成器和識別器,生成器和識別器需要分別訓練,訓練識別器的時候不能動生成器的引數,訓練生成器的時候不能動識別器的引數,使用的程式碼大致如下:

# 建立模型例項
generation_model = GenerationModel().to(device)
discrimination_model = DiscriminationModel().to(device)

# 建立引數調整器
# 根據生成器和識別器分別建立
optimizer_g = torch.optim.Adam(generation_model.parameters())
optimizer_d = torch.optim.Adam(discrimination_model.parameters())

# 隨機生成編碼
def generate_vectors(batch_size):
    vectors = torch.randn((batch_size, GenerationModel.EmbeddedSize), device=device)
    return vectors

# 開始訓練過程
for epoch in range(0, 10000):
    # 列舉真實資料
    for index, batch_x in enumerate(read_batches()):
        # 生成隨機編碼
        training_vectors = generate_vectors(minibatch_size)
        # 生成虛假資料
        generated = generation_model(training_vectors)
        # 獲取真實資料
        real = batch_x
 
        # 訓練識別器 (只調整識別器的引數)
        predicted_t = discrimination_model(real)
        predicted_f = discrimination_model(generated)
        loss_d = (
            nn.functional.binary_cross_entropy(
                predicted_t, torch.ones(predicted_t.shape, device=device)) +
            nn.functional.binary_cross_entropy(
                predicted_f, torch.zeros(predicted_f.shape, device=device)))
        loss_d.backward() # 根據損失自動微分
        optimizer_d.step() # 調整識別器的引數
        optimizer_g.zero_grad() # 清空生成器引數記錄的導函式值
        optimizer_d.zero_grad() # 清空識別器引數記錄的導函式值

        # 訓練生成器 (只調整生成器的引數)
        predicted_f = discrimination_model(generated)
        loss_g = nn.functional.binary_cross_entropy(
            predicted_f, torch.ones(predicted_f.shape, device=device))
        loss_g.backward() # 根據損失自動微分
        optimizer_g.step() # 調整生成器的引數
        optimizer_g.zero_grad() # 清空生成器引數記錄的導函式值
        optimizer_d.zero_grad() # 清空識別器引數記錄的導函式值

上述例子應該可以幫助你理解大致的訓練流程和只訓練識別器或生成器的方法,但是直接這麼做效果會很差?,接下來我們會看看對抗生成網路的問題,並且給出優化方案,後面的完整程式碼會跟上述例子有一些不同。

如果對原始論文有興趣可以參考這裡,原始的對抗生成網路又稱 DCGAN (Deep Convolutional GAN)。

對抗生成網路的問題

看完以上的內容你可能會覺得,嘿嘿,還是挺簡單的。不?,雖然原理看上去挺好理解,模型本身也不復雜,但對抗生成網路是目前介紹過的模型裡面訓練難度最高的,這是因為對抗生成網路建立在矛盾上,沒有一個明確的目標 (之前的模型目標都是針對未學習過的資料預測正確率儘可能接近 100%)。如果生成器生成 100% 可以騙過識別器的資料,那可能代表識別器根本沒正常工作,或者生成器生成的資料跟真實資料 100% 相同,沒實用價值;而如果識別器 100% 可以識別生成器生成的資料,那代表生成器生成的資料太垃圾,一個都騙不過。本篇介紹的例子使用了最蠢最簡單的方法,把每一輪學習後生成器生成的資料輸出到硬碟,然後人工鑑定生成的效果怎樣?,同時還會每 100 輪訓練記錄一次模型狀態,供訓練完以後回滾使用 (最後一個模型狀態效果不會是最好的,後面會說明)。

另一個問題是識別器和生成器不能同時訓練,怎樣安排訓練過程對訓練結果的影響非常大?,理想的過程是:識別器稍微領先生成器,生成器跟著識別器慢慢的生成越來越精準的資料。舉例來說,識別器首先會識別膚色佔比較多的圖片為人臉,接下來生成器會生成全部都是膚色的圖片,然後識別器會識別有兩個看上去是眼睛的圖片為人臉,接下來生成器會加上兩個看上去是眼睛的形狀到圖片,之後識別器會識別帶有五官的圖片為人臉,接下來生成器會加上剩餘的五官到圖片,最後識別器會識別五官和臉形狀比較正常的人為人臉,生成器會盡量調整五官和人臉形狀接近正常水平。而不理想的過程是識別器大幅領先生成器,例如識別器很早就達到了接近 100% 的正確率,而生成器因為找不到學習的方向正確率會一直原地踏步;另一個不理想的過程是生成器領先識別器,這時會出現識別器找不到學習的方向,生成器也找不到學習的方向而原地轉的情況。實現識別器稍微領先生成器,可以增加識別器的訓練次數,常見的方法是每訓練 n 次識別器就訓練 1 次生成器,而本文後面會介紹根據正確率動態調整識別器和生成器學習次數的方法,參考後面的程式碼吧。

對抗生成網路最大的問題是模式崩潰 (Mode Collapse) 問題,這個問題所有訓練對抗生成網路的人都會面對,並且目前沒有 100% 的方法避免?。簡單的來說就是生成器學會偷懶作弊,只會輸出一到幾個與真實資料幾乎一模一樣的虛假資料,因為生成的資料同質化非常嚴重,即使可以騙過識別器也沒什麼實用價值。發生模式崩潰以後的輸出例子如下,可以看到很多人臉都非常接近:

為了儘量避免模式崩潰問題,以下幾個改進的模型被發明瞭出來,這就是人民群眾的智慧啊?。

改進對抗生成網路 (WGAN)

模式崩潰問題的原因之一就是部分模型引數會隨著訓練固化 (達到本地最優),因為原始的對抗生成網路會讓識別器輸出儘可能接近 1 或者 0 的值,如果值已經是 0 或者 1 那麼引數就不會被調整。WGAN (Wasserstein GAN) 的解決方式是不限制識別器輸出的值範圍,只要求識別器針對真實資料輸出的值大於虛假資料輸出的值,和要求生成器生成可以讓識別器輸出更大的值的資料

第一個修改是拿掉識別器最後的 Sigmoid,這樣識別器輸出的值就不會限制在 0 ~ 1 的範圍內。

第二個修改是修改計算損失的方式:

# 計算識別器的損失,修改前
loss_d = (
    nn.functional.binary_cross_entropy(
        predicted_t, torch.ones(predicted_t.shape, device=device)) +
    nn.functional.binary_cross_entropy(
        predicted_f, torch.zeros(predicted_f.shape, device=device)))

# 計算識別器的損失,修改後
loss_d = predicted_f.mean() - predicted_t.mean()
# 計算生成器的損失,修改前
loss_g = nn.functional.binary_cross_entropy(
            predicted_f, torch.ones(predicted_f.shape, device=device))

# 計算生成器的損失,修改後
loss_g = -predicted_f.mean()

這麼修改以後會出現一個問題,識別器輸出的值範圍會隨著訓練越來越大 (生成器提高虛假資料的輸出值,接下來識別器提高真實資料的輸出值,迴圈下去輸出值就會越來越大?),從而導致梯度爆炸問題。為了解決這個問題 WGAN 對識別器引數的可取範圍做出了限制,也就是在調整完引數以後裁剪引數,第三個修改如下:

# 讓識別器引數必須在 -0.1 ~ 0.1 之間
for p in discrimination_model.parameters():
    p.data.clamp_(-0.1, 0.1)

如果有興趣可以參考 WGAN 的原始論文,裡面一大堆數學公式可以把人嚇壞?,但主要的部分只有上面提到的三點。

改進對抗生成網路 (WGAN-GP)

WGAN 為了防止梯度爆炸問題對識別器引數的可取範圍做出了限制,但這個做法比較粗暴,WGAN-GP (Wasserstein GAN Gradient Penalty) 提出了一個更優雅的方法,即限制導函式值的範圍,如果導函式值偏移某個指定的值則通過損失給與模型懲罰。

具體實現如下,看起來比較複雜但做的事情只是計算識別器輸入資料的導函式值,然後判斷所有通道合計的導函式值的 L2 合計與常量 1 相差多少,相差越大就返回越高的損失,這樣識別器模型引數自然會控制在某個水平。

def gradient_penalty(discrimination_model, real, generated):
    """控制導函式值的範圍,用於防止模型引數失控 (https://arxiv.org/pdf/1704.00028.pdf)"""

    # 給批次中的每個樣本分別生成不同的隨機值,範圍在 0 ~ 1
    batch_size = real.shape[0]
    rate = torch.randn(batch_size, 1, 1, 1)
    rate = rate.expand(batch_size, real.shape[1], real.shape[2], real.shape[3]).to(device)

    # 按隨機值比例混合真樣本和假樣本
    mixed = (rate * real + (1 - rate) * generated)

    # 識別混合樣本
    predicted_m = discrimination_model(mixed)

    # 計算 mixed 對 predicted_m 的影響,也就是 mixed => predicted_m 的微分
    # 與以下程式碼計算結果相同,但不會影響途中 (即模型引數) 的 grad 值
    # mixed = torch.tensor(mixed, requires_grad=True)
    # predicted_m.sum().backward()
    # grad = mixed.grad
    grad = torch.autograd.grad(
        outputs = predicted_m,
        inputs = mixed,
        grad_outputs = torch.ones(predicted_m.shape).to(device),
        create_graph=True,
        retain_graph=True)[0]

    # 讓導函式值的 L2 norm (所有通道合計) 在 1 左右,如果偏離 1 則使用損失給與懲罰
    grad_penalty = ((grad.norm(2, dim=1) - 1) ** 2).mean() * 10
    return grad_penalty

然後再修改計算識別器損失的方法:

# 計算識別器的損失,修改前
loss_d = predicted_f.mean() - predicted_t.mean()

# 計算識別器的損失,修改後
loss_d = (predicted_f.mean() - predicted_t.mean() +
    gradient_penalty(discrimination_model, real, generated))

最後把識別器中的批次正規化 (BatchNorm) 刪掉或者改為例項正規化 (InstanceNorm) 就完了。InstanceNorm 和 BatchNorm 的區別在於計算平均值和標準差的時候不會根據整個批次計算,而是隻根據各個樣本自身計算,關於 BatchNorm 的計算方式可以參考第四篇

如果有興趣可以參考 WGAN-GP 的原始論文

完整程式碼

又到完整程式碼的時間了?,這份程式碼同時包含了原始的 GAN 模型 (DCGAN),WGAN 和 WGAN-GP 的實現,後面還會比較它們之間的效果相差多少。

使用的資料集連結如下,前一篇的人臉識別文章也用到了這個資料集:

https://www.kaggle.com/atulanandjha/lfwpeople

需要注意的是人臉圖片數量越多就越容易出現模式崩潰問題,這也是對抗生成網路訓練的難點之一?,這份程式碼只會隨機選取 2000 張圖片用於訓練。

這份程式碼還會根據正確率動態調整生成器和識別器的訓練比例,如果識別器比生成器更強則訓練 1 次生成器,如果生成器比識別器更強則訓練 5 次識別器,這麼做可以省去手動調整訓練比例的麻煩,經實驗效果也不錯?。

import os
import sys
import torch
import gzip
import itertools
import random
import numpy
import math
import json
from PIL import Image
from torch import nn
from matplotlib import pyplot
from functools import lru_cache

# 生成或識別圖片的大小
IMAGE_SIZE = (80, 80)
# 訓練使用的資料集路徑
DATASET_DIR = "./dataset/lfwpeople/lfw_funneled"
# 模型類別, 支援 DCGAN, WGAN, WGAN-GP
MODEL_TYPE = "WGAN-GP"

# 用於啟用 GPU 支援
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class GenerationModel(nn.Module):
    """生成虛假資料的模型"""
    # 編碼長度
    EmbeddedSize = 128

    def __init__(self):
        super().__init__()
        self.generator = nn.Sequential(
            # 128,1,1 => 512,5,5
            nn.ConvTranspose2d(128, 512, kernel_size=5, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            # => 256,10,10
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            # => 128,20,20
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            # => 64,40,40
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            # => 3,80,80
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
            # 限制輸出在 -1 ~ 1,不使用 Hardtanh 是為了讓超過範圍的值可以傳播給上層
            nn.Tanh())

    def forward(self, x):
        y = self.generator(x.view(x.shape[0], x.shape[1], 1, 1))
        return y

    @staticmethod
    def calc_accuracy(predicted_f):
        """正確率計算器"""
        # 返回騙過識別器的虛假資料比例
        if MODEL_TYPE == "DCGAN":
            threshold = 0.5
        elif MODEL_TYPE in ("WGAN", "WGAN-GP"):
            threshold = DiscriminationModel.LastTrueSamplePredictedMean
        else:
            raise ValueError("unknown model type")
        return (predicted_f >= threshold).float().mean().item()

class DiscriminationModel(nn.Module):
    """識別資料是否真實的模型"""
    # 最終識別真實樣本的輸出平均值,WGAN 會使用這個值判斷騙過識別器的虛假資料比例
    LastTrueSamplePredictedMean = 0.5

    def __init__(self):
        super().__init__()
        # 標準化函式
        def norm2d(features):
            if MODEL_TYPE == "WGAN-GP":
                # WGAN-GP 本來不需要 BatchNorm,但可以額外的加 InstanceNorm 改善效果
                # InstanceNorm 不一樣的是平均值和標準差會針對批次中的各個樣本分別計算
                # affine = True 表示調整量可學習 (BatchNorm2d 預設為 True)
                return nn.InstanceNorm2d(features, affine=True)
            return nn.BatchNorm2d(features)
        self.discriminator = nn.Sequential(
            # 3,80,80 => 64,40,40
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # => 128,20,20
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            norm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # => 256,10,10
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            norm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # => 512,5,5
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            norm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # => 1,1,1
            nn.Conv2d(512, 1, kernel_size=5, stride=1, padding=0, bias=False),
            # 扁平化
            nn.Flatten())
        if MODEL_TYPE == "DCGAN":
            # 輸出是否真實資料 (0 or 1)
            # WGAN 不限制輸出值範圍在 0 ~ 1 之間
            self.discriminator.add_module("sigmoid", nn.Sigmoid())

    def forward(self, x):
        y = self.discriminator(x)
        return y

    @staticmethod
    def calc_accuracy(predicted_f, predicted_t):
        """正確率計算器"""
        # 返回正確識別的資料比例
        if MODEL_TYPE == "DCGAN":
            return (((predicted_f <= 0.5).float().mean() + (predicted_t > 0.5).float().mean()) / 2).item()
        elif MODEL_TYPE in ("WGAN", "WGAN-GP"):
            DiscriminationModel.LastTrueSamplePredictedMean = predicted_t.mean()
            return (predicted_t > predicted_f).float().mean().item()
        else:
            raise ValueError("unknown model type")

    def gradient_penalty(self, real, generated):
        """控制導函式值的範圍,用於防止模型引數失控 (https://arxiv.org/pdf/1704.00028.pdf)"""
        # 給批次中的每個樣本分別生成不同的隨機值,範圍在 0 ~ 1
        batch_size = real.shape[0]
        rate = torch.randn(batch_size, 1, 1, 1)
        rate = rate.expand(batch_size, real.shape[1], real.shape[2], real.shape[3]).to(device)
        # 按隨機值比例混合真樣本和假樣本
        mixed = (rate * real + (1 - rate) * generated)
        # 識別混合樣本
        predicted_m = self.forward(mixed)
        # 計算 mixed 對 predicted_m 的影響,也就是 mixed => predicted_m 的微分
        # 與以下程式碼計算結果相同,但不會影響途中 (即模型引數) 的 grad 值
        # mixed = torch.tensor(mixed, requires_grad=True)
        # predicted_m.sum().backward()
        # grad = mixed.grad
        grad = torch.autograd.grad(
            outputs = predicted_m,
            inputs = mixed,
            grad_outputs = torch.ones(predicted_m.shape).to(device),
            create_graph=True,
            retain_graph=True)[0]
        # 讓導函式值的 L2 norm (所有通道合計) 在 1 左右,如果偏離 1 則使用損失給與懲罰
        grad_penalty = ((grad.norm(2, dim=1) - 1) ** 2).mean() * 10
        return grad_penalty

def save_tensor(tensor, path):
    """儲存 tensor 物件到檔案"""
    torch.save(tensor, gzip.GzipFile(path, "wb"))

# 為了減少讀取時間這裡快取了讀取的 tensor 物件
# 如果記憶體不夠應該適當減少 maxsize
@lru_cache(maxsize=200)
def load_tensor(path):
    """從檔案讀取 tensor 物件"""
    return torch.load(gzip.GzipFile(path, "rb"))

def image_to_tensor(img):
    """縮放並轉換圖片物件到 tensor 物件"""
    img = img.resize(IMAGE_SIZE) # 縮放圖片,比例不一致時拉伸
    arr = numpy.asarray(img)
    t = torch.from_numpy(arr)
    t = t.transpose(0, 2) # 轉換維度 H,W,C 到 C,W,H
    t = (t / 255.0) * 2 - 1 # 正規化數值使得範圍在 -1 ~ 1
    return t

def tensor_to_image(t):
    """轉換 tensor 物件到圖片"""
    t = (t + 1) / 2 * 255.0 # 轉換顏色回 0 ~ 255
    t = t.transpose(0, 2) # 轉換維度 C,W,H 到 H,W,C
    t = t.int() # 轉換數值到整數
    img = Image.fromarray(t.numpy().astype("uint8"), "RGB")
    return img

def prepare():
    """準備訓練"""
    # 資料集轉換到 tensor 以後會儲存在 data 資料夾下
    if not os.path.isdir("data"):
        os.makedirs("data")

    # 查詢人臉圖片列表
    # 每個人最多使用 2 張圖片
    image_paths = []
    for dirname in os.listdir(DATASET_DIR):
        dirpath = os.path.join(DATASET_DIR, dirname)
        if not os.path.isdir(dirpath):
            continue
        for filename in os.listdir(dirpath)[:2]:
            image_paths.append(os.path.join(DATASET_DIR, dirname, filename))
    print(f"found {len(image_paths)} images")

    # 隨機打亂人臉圖片列表
    random.shuffle(image_paths)

    # 限制人臉數量
    # 如果數量太多,識別器難以記住人臉的具體特徵,會需要更長時間訓練或直接陷入模式崩潰問題
    image_paths = image_paths[:2000]
    print(f"only use {len(image_paths)} images")

    # 儲存人臉圖片資料
    for batch, index in enumerate(range(0, len(image_paths), 200)):
        paths = image_paths[index:index+200]
        images = []
        for path in paths:
            img = Image.open(path)
            # 擴大人臉佔比
            w, h = img.size
            img = img.crop((int(w*0.25), int(h*0.25), int(w*0.75), int(h*0.75)))
            images.append(img)
        tensors = [ image_to_tensor(img) for img in images ]
        tensor = torch.stack(tensors) # 維度: (圖片數量, 3, 寬度, 高度)
        save_tensor(tensor, os.path.join("data", f"{batch}.pt"))
        print(f"saved batch {batch}")

    print("done")

def train():
    """開始訓練模型"""
    # 建立模型例項
    generation_model = GenerationModel().to(device)
    discrimination_model = DiscriminationModel().to(device)

    # 建立損失計算器
    ones_map = {}
    zeros_map = {}
    def loss_function_t(predicted):
        """損失計算器 (訓練識別結果為 1)"""
        count = predicted.shape[0]
        ones = ones_map.get(count)
        if ones is None:
            ones = torch.ones((count, 1), device=device)
            ones_map[count] = ones
        return nn.functional.binary_cross_entropy(predicted, ones)
    def loss_function_f(predicted):
        """損失計算器 (訓練識別結果為 0)"""
        count = predicted.shape[0]
        zeros = zeros_map.get(count)
        if zeros is None:
            zeros = torch.zeros((count, 1), device=device)
            zeros_map[count] = zeros
        return nn.functional.binary_cross_entropy(predicted, zeros)

    # 建立引數調整器
    # 學習率和 betas 跟各個論文給出的一樣,可以一定程度提升學習效果,但不是決定性的
    if MODEL_TYPE == "DCGAN":
        optimizer_g = torch.optim.Adam(generation_model.parameters(), lr=0.0002, betas=(0.5, 0.999))
        optimizer_d = torch.optim.Adam(discrimination_model.parameters(), lr=0.0002, betas=(0.5, 0.999))
    elif MODEL_TYPE == "WGAN":
        optimizer_g = torch.optim.RMSprop(generation_model.parameters(), lr=0.00005)
        optimizer_d = torch.optim.RMSprop(discrimination_model.parameters(), lr=0.00005)
    elif MODEL_TYPE == "WGAN-GP":
        optimizer_g = torch.optim.Adam(generation_model.parameters(), lr=0.0001, betas=(0.0, 0.999))
        optimizer_d = torch.optim.Adam(discrimination_model.parameters(), lr=0.0001, betas=(0.0, 0.999))
    else:
        raise ValueError("unknown model type")

    # 記錄訓練集和驗證集的正確率變化
    training_accuracy_g_history = []
    training_accuracy_d_history = []

    # 計算正確率的工具函式
    calc_accuracy_g = generation_model.calc_accuracy
    calc_accuracy_d = discrimination_model.calc_accuracy

    # 隨機生成編碼
    def generate_vectors(batch_size):
        vectors = torch.randn((batch_size, GenerationModel.EmbeddedSize), device=device)
        return vectors

    # 輸出生成的圖片樣本
    def output_generated_samples(epoch, samples):
        dir_path = f"./generated_samples/{epoch}"
        if not os.path.isdir(dir_path):
            os.makedirs(dir_path)
        for index, sample in enumerate(samples):
            path = os.path.join(dir_path, f"{index}.png")
            tensor_to_image(sample.cpu()).save(path)

    # 讀取批次的工具函式
    def read_batches():
        for batch in itertools.count():
            path = f"data/{batch}.pt"
            if not os.path.isfile(path):
                break
            x = load_tensor(path)
            yield x.to(device)

    # 開始訓練過程
    validating_vectors = generate_vectors(100)
    for epoch in range(0, 10000):
        print(f"epoch: {epoch}")

        # 根據訓練集訓練並修改引數
        # 切換模型到訓練模式
        generation_model.train()
        discrimination_model.train()
        training_accuracy_g_list = []
        training_accuracy_d_list = []
        last_accuracy_g = 0
        last_accuracy_d = 0
        minibatch_size = 20
        train_discriminator_count = 0
        for index, batch_x in enumerate(read_batches()):
            # 使用小批次訓練
            training_batch_accuracy_g = 0.0
            training_batch_accuracy_d = 0.0
            minibatch_count = 0
            for begin in range(0, batch_x.shape[0], minibatch_size):
                # 測試目前生成器和識別器哪邊佔劣勢,訓練佔劣勢的一方
                # 最終的平衡狀態是: 生成器正確率 = 1.0, 識別器正確率 = 0.5
                # 代表生成器生成的圖片和真實圖片基本完全一樣,但不應該訓練到這個程度
                training_vectors = generate_vectors(minibatch_size) # 隨機向量
                generated = generation_model(training_vectors) # 根據隨機向量生成的虛假資料
                real = batch_x[begin:begin+minibatch_size] # 真實資料
                predicted_t = discrimination_model(real)
                predicted_f = discrimination_model(generated)
                accuracy_g = calc_accuracy_g(predicted_f)
                accuracy_d = calc_accuracy_d(predicted_f, predicted_t)
                train_discriminator = (accuracy_g / 2) >= accuracy_d
                if train_discriminator or train_discriminator_count > 0:
                    # 訓練識別器
                    if MODEL_TYPE == "DCGAN":
                        loss_d = loss_function_f(predicted_f) + loss_function_t(predicted_t)
                    elif MODEL_TYPE == "WGAN":
                        loss_d = predicted_f.mean() - predicted_t.mean()
                    elif MODEL_TYPE == "WGAN-GP":
                        loss_d = (predicted_f.mean() - predicted_t.mean() +
                            discrimination_model.gradient_penalty(real, generated))
                    else:
                        raise ValueError("unknown model type")
                    loss_d.backward()
                    optimizer_d.step()
                    # 限制識別器引數範圍以防止模型引數失控 (WGAN-GP 有更好的方法)
                    # 這裡的限制值比論文的值 (0.01) 更大是因為模型層數和引數量更多
                    if MODEL_TYPE == "WGAN":
                        for p in discrimination_model.parameters():
                            p.data.clamp_(-0.1, 0.1)
                    # 讓識別器訓練次數多於生成器
                    if train_discriminator and train_discriminator_count == 0:
                        train_discriminator_count = 5
                    train_discriminator_count -= 1
                else:
                    # 訓練生成器
                    if MODEL_TYPE == "DCGAN":
                        loss_g = loss_function_t(predicted_f)
                    elif MODEL_TYPE in ("WGAN", "WGAN-GP"):
                        loss_g = -predicted_f.mean()
                    else:
                        raise ValueError("unknown model type")
                    loss_g.backward()
                    optimizer_g.step()
                optimizer_g.zero_grad()
                optimizer_d.zero_grad()
                training_batch_accuracy_g += accuracy_g
                training_batch_accuracy_d += accuracy_d
                minibatch_count += 1
            training_batch_accuracy_g /= minibatch_count
            training_batch_accuracy_d /= minibatch_count
            # 輸出批次正確率
            training_accuracy_g_list.append(training_batch_accuracy_g)
            training_accuracy_d_list.append(training_batch_accuracy_d)
            print(f"epoch: {epoch}, batch: {index},",
                f"accuracy_g: {training_batch_accuracy_g}, accuracy_d: {training_batch_accuracy_d}")
        training_accuracy_g = sum(training_accuracy_g_list) / len(training_accuracy_g_list)
        training_accuracy_d = sum(training_accuracy_d_list) / len(training_accuracy_d_list)
        training_accuracy_g_history.append(training_accuracy_g)
        training_accuracy_d_history.append(training_accuracy_d)
        print(f"training accuracy_g: {training_accuracy_g}, accuracy_d: {training_accuracy_d}")

        # 儲存虛假資料用於評價訓練效果
        output_generated_samples(epoch, generation_model(validating_vectors))

        # 儲存模型狀態
        if (epoch + 1) % 10 == 0:
            save_tensor(generation_model.state_dict(), "model.generation.pt")
            save_tensor(discrimination_model.state_dict(), "model.discrimination.pt")
            if (epoch + 1) % 100 == 0:
                save_tensor(generation_model.state_dict(), f"model.generation.epoch_{epoch}.pt")
                save_tensor(discrimination_model.state_dict(), f"model.discrimination.epoch_{epoch}.pt")
            print("model saved")

    print("training finished")

    # 顯示訓練集的正確率變化
    pyplot.plot(training_accuracy_g_history, label="training_accuracy_g")
    pyplot.plot(training_accuracy_d_history, label="training_accuracy_d")
    pyplot.ylim(0, 1)
    pyplot.legend()
    pyplot.show()

from http.server import HTTPServer, BaseHTTPRequestHandler
from urllib.parse import parse_qs
from io import BytesIO
class RequestHandler(BaseHTTPRequestHandler):
    """用於測試生成圖片的簡單伺服器"""
    # 模型狀態的路徑,這裡使用看起來效果最好的記錄
    MODEL_STATE_PATH = "model.generation.epoch_2999.pt"
    Model = None

    @staticmethod
    def get_model():
        if RequestHandler.Model is None:
            # 建立模型例項,載入訓練好的狀態,然後切換到驗證模式
            model = GenerationModel().to(device)
            model.load_state_dict(load_tensor(RequestHandler.MODEL_STATE_PATH))
            model.eval()
            RequestHandler.Model = model
        return RequestHandler.Model

    def do_GET(self):
        parts = self.path.partition("?")
        if parts[0] == "/":
            self.send_response(200)
            self.send_header("Content-type", "text/html")
            self.end_headers()
            with open("gan_eval.html", "rb") as f:
                self.wfile.write(f.read())
        elif parts[0] == "/generate":
            # 根據傳入的引數生成圖片
            params = parse_qs(parts[-1])
            vector = (torch.tensor([float(x) for x in params["values"][0].split(",")])
                .reshape(1, GenerationModel.EmbeddedSize)
                .to(device))
            generated = RequestHandler.get_model()(vector)[0]
            img = tensor_to_image(generated.cpu())
            bytes_io = BytesIO()
            img.save(bytes_io, format="PNG")
            # 返回圖片
            self.send_response(200)
            self.send_header("Content-type", "image/png")
            self.end_headers()
            self.wfile.write(bytes_io.getvalue())
        else:
            self.send_response(404)
            self.end_headers()
            self.wfile.write(b"Not Found")

def eval_model():
    """使用訓練好的模型生成圖片"""
    server = HTTPServer(("localhost", 8666), RequestHandler)
    print("Please access http://localhost:8666")
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        pass
    server.server_close()
    exit()

def main():
    """主函式"""
    if len(sys.argv) < 2:
        print(f"Please run: {sys.argv[0]} prepare|train|eval")
        exit()

    # 給隨機數生成器分配一個初始值,使得每次執行都可以生成相同的隨機數
    # 這是為了讓過程可重現,你也可以選擇不這樣做
    random.seed(0)
    torch.random.manual_seed(0)

    # 根據命令列引數選擇操作
    operation = sys.argv[1]
    if operation == "prepare":
        prepare()
    elif operation == "train":
        train()
    elif operation == "eval":
        eval_model()
    else:
        raise ValueError(f"Unsupported operation: {operation}")

if __name__ == "__main__":
    main()

儲存程式碼到 gan.py,然後執行以下命令即可開始訓練:

python3 gan.py prepare
python3 gan.py train

同樣訓練 2000 輪以後,DCGAN, WGAN, WGAN-GP 輸出的樣本如下:

DCGAN

WGAN

WGAN-GP

可以看到 WGAN-GP 受模式崩潰問題影響最少,並且效果也更好?。

WGAN-GP 訓練到 3000 次以後輸出的樣本如下:

WGAN-GP 訓練到 10000 次以後輸出的樣本如下:

隨著訓練次數增多,WGAN-GP 一樣無法避免模式崩潰問題,這就是為什麼以上程式碼會記錄每一輪訓練後輸出的樣本,並在每 100 輪訓練以後儲存單獨的模型狀態,這樣訓練結束以後我們可以通過評價輸出的樣本找到效果最好的批次,然後使用該批次的模型狀態。

上述的例子效果最好的狀態是訓練 3000 次以後的狀態。

你可能發現輸出的樣本中夾雜了一些畸形?,這是因為生成器沒有覆蓋到輸入的向量空間,最主要的原因是隨機輸入中包含了很多接近 0 的值,避免這個問題簡單的做法是生成隨機輸入時限制值必須小於或大於某個值。原則上給反摺積層設定 Bias 也可以避免這個問題,但會更容易陷入模式崩潰問題。

使用訓練好的模型生成人臉就比較簡單了:

generation_model = GenerationModel().to(device)
model.load_state_dict(load_tensor("model.generation.epoch_2999.pt"))
model.eval()

# 隨機生成 100 張人臉
vector = torch.randn((100, GenerationModel.EmbeddedSize), device=device)
samples = model(vector)
for index, sample in enumerate(samples):
    img = tensor_to_image(sample.cpu())
    img.save(f"{index}.png")

額外的,我做了一個可以動態調整引數捏臉的網頁,html 程式碼如下:

<!DOCTYPE html>
<html lang="cn">
  <head>
    <meta charset="utf-8">
    <title>測試人臉生成</title>
    <style>
        html, body {
            width: 100%;
            height: 100%;
            margin: 0px;
        }
        .left-pane {
            width: 50%;
            height: 100%;
            border-right: 1px solid #000;
        }
        .right-pane {
            position: fixed;
            left: 70%;
            top: 35%;
            width: 25%;
        }
        .sliders {
            padding: 8px;
        }
        .slider-container {
            display: inline-block;
            min-width: 25%;
        }
        #image {
            left: 25%;
            top: 25%;
            width: 50%;
            height: 50%;
        }
    </style>
  </head>
  <body>
    <div class="left-pane">
        <div class="sliders">
        </div>
    </div>
    <div class="right-pane">
       <p><img id="target" src="data:image/png;base64," alt="image" /></p>
       <p><button class="set-random">隨機生成</button></p>
    </div>
  </body>
  <script>
    (function() {
        // 滑動條改變後的處理
        var onChanged = function() {
            var sliderInputs = document.querySelectorAll(".slider");
            var values = [];
            sliderInputs.forEach(function(s) {
                values.push(s.value);
            });
            var image = document.querySelector("#target");
            image.setAttribute("src", "/generate?values=" + values.join(","));
        };

        // 點選隨機生成時的處理
        var setRandomButton = document.querySelector(".set-random");
        setRandomButton.onclick = function() {
            var sliderInputs = document.querySelectorAll(".slider");
            sliderInputs.forEach(function(s) { s.value = Math.random() * 2 - 1; });
            onChanged();
        };

        // 新增滑動條
        var sliders = document.querySelector(".sliders");
        for (var n = 0; n < 128; ++n) {
            var container = document.createElement("div");
            container.setAttribute("class", "slider-container");
            var span = document.createElement("span");
            span.innerText = n;
            container.appendChild(span);
            var slider = document.createElement("input");
            slider.setAttribute("type", "range")
            slider.setAttribute("class", "slider");
            slider.setAttribute("min", "-1");
            slider.setAttribute("max", "1");
            slider.setAttribute("step", "0.01");
            slider.value = 0;
            slider.onchange = onChanged;
            slider.oninput = onChanged;
            container.appendChild(slider);
            sliders.appendChild(container);
        }
    })();
  </script>
</html>

儲存到 gan_eval.html 以後執行以下命令即可啟動伺服器:

python3 gan.py eval

瀏覽器開啟 http://localhost:8666 以後會顯示以下介面,點選隨機生成按鈕可以隨機生成人臉,拉動左邊的引數條可以動態調整引數:

一些捏臉的網站會分析各個引數的含義,看看哪些引數代表膚色,那些引數代表表情,哪些引數代表脫髮程度,我比較懶就只給出各個引數的序號了?。

寫在最後

又摸完一個新的模型了,跟到這篇的人也越來越少了,估計這個系列再寫一兩篇就會結束 (VAE, 強化學習)。

前一篇論文我提到了可能會開一個新的系列介紹 .NET 的機器學習,但我決定不開了。經過試驗發現沒有達到可用的水平,文件基本等於沒有,社群氣氛也不行 (大會 PPT 倒是做的挺好的)。畢竟語言只是個工具,不是老祖宗,還是看開一點吧。學 python 再做機器學習會輕鬆很多,就像長遠來說學一點基礎英語再程式設計比完全只用中文程式設計 (先把基礎框架類庫系統介面的英文全部翻譯成中文,再用中文寫) 簡單很多,對叭?。

相關文章