簡單使用PyTorch搭建GAN模型

格物鈦Graviti發表於2021-08-25

作者|Ta-Ying Cheng,牛津大學博士研究生,Medium技術博主,多篇文章均被平臺官方刊物Towards Data Science收錄

翻譯|頌賢

以往人們普遍認為生成影像是不可能完成的任務,因為按照傳統的機器學習思路,我們根本沒有真值(ground truth)可以拿來檢驗生成的影像是否合格。

2014年,Goodfellow等人則提出生成對抗網路(Generative Adversarial Network, GAN),能夠讓我們完全依靠機器學習來生成極為逼真的圖片。GAN的橫空出世使得整個人工智慧行業都為之震動,計算機視覺和影像生成領域發生了鉅變。

本文將帶大家瞭解GAN的工作原理,並介紹如何透過PyTorch簡單上手GAN

GAN的原理

按照傳統的方法,模型的預測結果可以直接與已有的真值進行比較。然而,我們卻很難定義和衡量到底怎樣才算作是“正確的”生成影像。

Goodfellow等人則提出了一個有趣的解決辦法:我們可以先訓練好一個分類工具,來自動區分生成影像和真實影像。這樣一來,我們就可以用這個分類工具來訓練一個生成網路,直到它能夠輸出完全以假亂真的影像,連分類工具自己都沒有辦法評判真假。 圖 1. GAN的運作流程. 圖源作者. 按照這一思路,我們便有了GAN:也就是一個生成器(generator)和一個判別器(discriminator)。生成器負責根據給定的資料集生成影像,判別器則負責區分影像是真是假。GAN的運作流程如上圖所示。

損失函式

在GAN的運作流程中,我們可以發現一個明顯的矛盾:同時最佳化生成器和判別器是很困難的。可以想象,這兩個模型有著完全相反的目標:生成器想要儘可能偽造出真實的東西,而判別器則必須要識破生成器生成的影像。

為了說明這一點,我們設D(x)為判別器的輸出,即x是真實影像的機率,並設G(z)為生成器的輸出。判別器類似於一種二進位制的分類器,所以其目標是使該函式的結果最大化:請新增圖片描述 這一函式本質上是非負的二元交叉熵損失函式。另一方面,生成器的目標是最小化判別器做出正確判斷的機率,因此它的目標是使上述函式的結果最小化。

因此,最終的損失函式將會是兩個分類器之間的極小極大博弈,表示如下: 請新增圖片描述 理論上來說,博弈的最終結果將是讓判別器判斷成功的機率收斂到0.5。然而在實踐中,極大極小博弈通常會導致網路不收斂,因此仔細調整模型訓練的引數非常重要。

在訓練GAN時,我們尤其要注意學習率等超引數,學習率比較小時能讓GAN在輸入噪音較多的情況下也能有較為統一的輸出。

計算環境

本文將指導大家透過PyTorch搭建整個程式(包括torchvision)。同時,我們將會使用Matplotlib來讓GAN的生成結果視覺化。以下程式碼能夠匯入上述所有庫:

"""
Import necessary libraries to create a generative adversarial network
The code is mainly developed using the PyTorch library
"""
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms
from model import discriminator, generator
import numpy as np
import matplotlib.pyplot as plt

資料集

資料集對於訓練GAN來說非常重要,尤其考慮到我們在GAN中處理的通常是非結構化資料(一般是圖片、影片等),任意一class都可以有資料的分佈。這種資料分佈恰恰是GAN生成輸出的基礎。

為了更好地演示GAN的搭建流程,本文將帶大家使用最簡單的MNIST資料集,其中含有6萬張手寫阿拉伯數字的圖片。

MNIST這樣高質量的非結構化資料集都可以在格物鈦公開資料集網站上找到。事實上,格物鈦Open Datasets平臺涵蓋了很多優質的公開資料集,同時也可以實現資料集託管及一站式搜尋的功能,這對AI開發者來說,是相當實用的社群平臺。 請新增圖片描述

硬體需求

一般來說,雖然可以使用CPU來訓練神經網路,但最佳選擇其實是GPU,因為這樣可以大幅提升訓練速度。我們可以用下面的程式碼來測試自己的機器能否用GPU來訓練:

"""
Determine if any GPUs are available
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

實現

網路結構

由於數字是非常簡單的資訊,我們可以將判別器和生成器這兩層結構都組建成全連線層(fully connected layers)。

我們可以用以下程式碼在PyTorch中搭建判別器和生成器: 

"""
Network Architectures
The following are the discriminator and generator architectures
"""

class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 1)
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return nn.Sigmoid()(x)


class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.fc1 = nn.Linear(128, 1024)
        self.fc2 = nn.Linear(1024, 2048)
        self.fc3 = nn.Linear(2048, 784)
        self.activation = nn.ReLU()

def forward(self, x):
    x = self.activation(self.fc1(x))
    x = self.activation(self.fc2(x))
    x = self.fc3(x)
    x = x.view(-1, 1, 28, 28)
    return nn.Tanh()(x)

訓練

在訓練GAN的時候,我們需要一邊最佳化判別器,一邊改進生成器,因此每次迭代我們都需要同時最佳化兩個互相矛盾的損失函式。

對於生成器,我們將輸入一些隨機噪音,讓生成器來根據噪音的微小改變輸出的影像:

"""
Network training procedure
Every step both the loss for disciminator and generator is updated
Discriminator aims to classify reals and fakes
Generator aims to generate images as realistic as possible
"""
for epoch in range(epochs):
    for idx, (imgs, _) in enumerate(train_loader):
        idx += 1

        # Training the discriminator
        # Real inputs are actual images of the MNIST dataset
        # Fake inputs are from the generator
        # Real inputs should be classified as 1 and fake as 0
        real_inputs = imgs.to(device)
        real_outputs = D(real_inputs)
        real_label = torch.ones(real_inputs.shape[0], 1).to(device)

        noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
        noise = noise.to(device)
        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs)
        fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device)

        outputs = torch.cat((real_outputs, fake_outputs), 0)
        targets = torch.cat((real_label, fake_label), 0)

        D_loss = loss(outputs, targets)
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # Training the generator
        # For generator, goal is to make the discriminator believe everything is 1
        noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5
        noise = noise.to(device)

        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs)
        fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device)
        G_loss = loss(fake_outputs, fake_targets)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        if idx % 100 == 0 or idx == len(train_loader):
            print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item()))

    if (epoch+1) % 10 == 0:
        torch.save(G, 'Generator_epoch_{}.pth'.format(epoch))
        print('Model saved.')

結果

經過100個訓練時期之後,我們就可以對資料集進行視覺化處理,直接看到模型從隨機噪音生成的數字: 請新增圖片描述 我們可以看到,生成的結果和真實的資料非常相像。考慮到我們在這裡只是搭建了一個非常簡單的模型,實際的應用效果會有非常大的上升空間。

不僅是有樣學樣

GAN和以往機器視覺專家提出的想法都不一樣,而利用GAN進行的具體場景應用更是讓許多人讚歎深度網路的無限潛力。下面我們來看一下兩個最為出名的GAN延申應用。

CycleGAN

朱儁彥等人2017年發表的CycleGAN能夠在沒有配對圖片的情況下將一張圖片從X域直接轉換到Y域,比如把馬變成斑馬、將熱夏變成隆冬、把莫奈的畫變成梵高的畫等等。這些看似天方夜譚的轉換CycleGAN都能輕鬆做到,並且結果非常準確。 請新增圖片描述

GauGAN

英偉達則透過GAN讓人們能夠只需要寥寥數筆勾勒出自己的想法,便能得到一張極為逼真的真實場景圖片。雖然這種應用需要的計算成本極為高昂,但是GauGAN憑藉它的轉換能力探索出了前所未有的研究和應用領域。

請新增圖片描述

結語

相信看到這裡,你已經知道了GAN的大致工作原理,並且能夠自己動手簡單搭建一個GAN了。

相關文章