0902-用GAN生成動漫頭像

二十三歲的有德發表於2021-05-13

0902-用GAN生成動漫頭像

pytorch完整教程目錄:https://www.cnblogs.com/nickchen121/p/14662511.html

一、概述

本節將通過 GAN 實現一個生成動漫人物頭像的例子。

在日本的技術部落格網站上有個博主,利用 DCGAN 從 20 萬張動漫頭像中學習,最終能夠利用程式自動生成動漫頭像。源程式是利用 Chainer 框架實現的,在這裡我們將嘗試利用 Pytorch 實現。

原始的圖片是從網站中採集的,並利用 OpenCV 擷取頭像,處理起來非常麻煩。因此我們在這裡通過之乎使用者 何之源 爬取並經過處理的 5 萬張圖片,想要圖片的百度網盤連結的可以加我微信:chenyoudea。需要注意的是,這裡圖片的解析度是 3×96×96,而不是論文中的 3×64×64,因此需要相應地調整網路結構,使生成影像的尺寸為 96。

二、程式碼結構

下面我們首先來看下我們未來的一個程式碼結構。

checkpoints/  # 無程式碼,用來儲存模型
imgs/  # 無程式碼,用來儲存生成的圖片
data/  # 無程式碼,用來儲存訓練所需要的圖片
main.py  # 訓練和生成
model.py  # 模型定義
visualize.py  # 視覺化工具 visdom 的開發
requirement.txt  # 程式中用到的第三方庫
README.MD  # 說明

三、model.py

model.py 主要是用來定義生成器和判別器的。

3.1 生成器

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Coding by https://www.cnblogs.com/nickchen121/
# Datatime:2021/5/10 10:37
# Filename:model.py
# Toolby: PyCharm
from torch import nn


class NetG(nn.Module):
    """
    生成器定義
    """

    def __init__(self, opt):
        super(NetG, self).__init__()
        ngf = opt.ngf  # 生成器 feature map 數
        self.main = nn.Sequential(
            # 輸入是 nz 維度的噪聲,可以認識它是一個 nz*1*1 的 feature map
            # H_{out} = (H_{in}-1)*stride - 2*padding + kernel_size
            # 以下面一行程式碼的ConvTranspose2d舉例(初始 H_{in}=1):H_{out} = (1-1)*1-2*0+4 = 4
            nn.ConvTranspose2d(opt.nz, ngf * 8, (4, 4), (1, 1), (0, 0), bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 上一步的輸出形狀:(ngf*8)*4*4,其中(ngf*8)是輸出通道數,4 為 H_{out} 是通過上述公式計算出來的

            # 以下面一行程式碼的ConvTranspose2d舉例(初始 H_{in}=4):H_{out} = (4-1)*2-2*1+4 =8
            nn.ConvTranspose2d(ngf * 8, ngf * 4, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 上一步的輸出形狀:(ngf*4)*8*8

            nn.ConvTranspose2d(ngf * 4, ngf * 2, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 上一步的輸出形狀是:(ngf*2)*16*16

            nn.ConvTranspose2d(ngf * 2, ngf, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 上一步的輸出形狀:(ngf)*32*32

            nn.ConvTranspose2d(ngf, 3, (5, 5), (3, 3), (1, 1), bias=False),
            nn.Tanh()
            # 輸出形狀:3*96*96
        )

    def forward(self, inp):
        return self.main(inp)

從上述生成器的程式碼可以看出生成器的構建比較簡單,直接用 nn.Sequential 把上卷積、啟用等操作拼接起來就行了。這裡稍微注意下 ConvTranspose2d 的使用,當 kernel size 為 4、stride 為 2、padding 為 1 時,根據公式 \(H_{out} = (H_{in}-1)*stride - 2*padding + kernel_size\),輸出尺寸剛好變成輸入的兩倍。

最後一層我們使用了 tanh 把輸出圖片的畫素歸一化至 -1~1,如果希望歸一化到 0~1,可以使用 sigimoid 方法。

3.2 判別器

class NetD(nn.Module):
    """
    判別器定義
    """

    def __init__(self, opt):
        super(NetD, self).__init__()
        ndf = opt.ndf
        self.main = nn.Sequential(
            # 輸入 3*96*96
            nn.Conv2d(3, ndf, (5, 5), (3, 3), (1, 1), bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 輸出 (ndf)*32*32

            nn.Conv2d(ndf, ndf * 2, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 輸出 (ndf*2)*16*16

            nn.Conv2d(ndf * 2, ndf * 4, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 輸出 (ndf*4)*8*8

            nn.Conv2d(ndf * 4, ndf * 8, (4, 4), (2, 2), (1, 1), bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 輸出 (ndf*8)*4*4

            nn.Conv2d(ndf * 8, 1, (4, 4), (1, 1), (0, 0), bias=False),
            nn.Sigmoid()  # 輸出一個數:概率
        )

    def forward(self, inp):
        return self.main(inp).view(-1)

從上述程式碼可以看到判別器和生成器的網路結構幾乎是對稱的,從卷積核大小到 padding、stride 等設定,幾乎一模一樣。

需要注意的是,生成器的啟用函式用的是 ReLU,而判別器使用的是 LeakyReLU,兩者其實沒有太大的區別,這種選擇更多的是經驗的總結。

判別器的最終輸出是一個 0~1 的數,表示這個樣本是真圖片的概率。

四、引數配置

在開始寫訓練函式前,我們可以先配置模型引數

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Coding by https://www.cnblogs.com/nickchen121/
# Datatime:2021/5/11 15:14
# Filename:config.py
# Toolby: PyCharm
class Config(object):
    data_path = 'data/'  # 資料集存放路徑
    num_workers = 4  # 多程式載入資料所用的程式數
    image_size = 96  # 圖片尺寸
    batch_size = 256
    max_epoch = 200
    lr1 = 2e-4  # 生成器的學習率
    lr2 = 2e-4  # 判別器的學習率
    beta1 = 0.5  # Adam 優化器的 beta1 引數
    use_gpu = False  # 是否使用 GPU
    nz = 100  # 噪聲維度
    ngf = 64  # 生成器的 feature map 數
    ndf = 64  # 判別器的 feature map 數

    save_path = 'imgs/'  # 生成圖片儲存路徑

    vis = True  # 是否使用 visdom 視覺化
    env = 'GAN'  # visdom 的 env
    plot_every = 20  # 每隔 20 個 batch,visdom 畫圖一次

    debug_file = '/tmp/debuggan'  # 存在該檔案則進入 debug 模式
    d_every = 1  # 每 1 個 batch 訓練一次判別器
    g_every = 5  # 每 5 個 batch 訓練一次生成器
    decay_everty = 10  # 每 10 個 epoch 儲存一次模型
    save_every = 10  # 每 10個epoch儲存一次模型
    netd_path = 'checkpoints/netd_211.pth'  # 預訓練模型
    netg_path = 'checkpoints/netg_211.pth'

    # 測試時用的引數
    gen_img = 'result.png'
    # 從 512 張生成的圖片路徑中儲存最好的 64 張
    gen_num = 64
    gen_search_num = 512
    gen_mean = 0  # 噪聲的均值
    gen_std = 1  # 噪聲的方差
    
opt = Config()

上述這些都只是模型的預設引數,還可以利用 Fire 等工具通過命令列傳入,覆蓋預設值。

除此之外,還可以使用 opt.atrr,還可以利用 IDE/Python 提供的自動補全功能,十分方便。

上述的超引數大多是照搬 DCGAN 論文的預設值,這些預設值都是坐著經過大量的實驗,發現這些引數能夠更快地去訓練出一個不錯的模型。

五、資料處理

當我們下載完資料之後,需要把所有圖片放在一資料夾,然後把資料夾移動到 data 目錄下(並且要確保 data 下沒有其他的資料夾)。使用這種方法是為了能夠直接使用 pytorchvision 自帶的 ImageFolder 讀取圖片,而沒有必要自己寫一個 Dataset。

資料讀取和載入的程式碼如下所示。

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Coding by https://www.cnblogs.com/nickchen121/
# Datatime:2021/5/12 09:43
# Filename:dataset.py
# Toolby: PyCharm
import torch as t
import torchvision as tv
from torch.utils.data import DataLoader

from config import opt

# 資料處理,輸出規模為 -1~1
transforms = tv.transforms.Compose([
    tv.transforms.Scale(opt.image_size),
    tv.transforms.CenterCrop(opt.image_size),
    tv.transforms.ToTensor(),
    tv.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 載入資料集
dataset = tv.datasets.ImageFolder(opt.data_path, transform=transforms)
dataloader = DataLoader(
    dataset,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=opt.num_workers,
    drop_last=True
)

從上述程式碼中可以發現,用 ImageFolder 配合 DataLoader 載入圖片十分方便。

六、訓練

在訓練之前,我們還需要定義幾個變數:模型、優化器、噪聲等。

#!/usr/bin/env python
# -*- coding:utf-8 -*-
# Coding by https://www.cnblogs.com/nickchen121/
# Datatime:2021/5/10 10:37
# Filename:main.py
# Toolby: PyCharm
import os
import ipdb
import tqdm
import fire
import torch as t
import torchvision as tv
from visualize import Visualizer
from torch.autograd import Variable
from torchnet.meter import AverageValueMeter

from config import opt
from dataset import dataloader
from model import NetD, NetG



def train(**kwargs):
    # 定義模型
    netd = NetD()
    netg = NetG()
    # 定義網路
    map_location = lambda storage, loc: storage
    if opt.netd_path:
        netd.load_state_dict(t.load(opt.netd_path, map_location=map_location))
    if opt.netg_path:
        netg.load_state_dict(t.load(opt.netg_path, map_location=map_location))

    # 定義優化器和損失
    optimizer_g = t.optim.Adam(netg.parameters(), opt.lr1, betas=(opt.beta1, 0.999))
    optimizer_d = t.optim.Adam(netd.parameters(), opt.lr2, betas=(opt.beta1, 0.999))
    criterion = t.nn.BCELoss()

    # 真圖片 label 為 1,假圖片 label 為 0,noises 為生成網路的輸入噪聲
    true_labels = Variable(t.ones(opt.batch_size))
    fake_labels = Variable(t.zeros(opt.batch_size))
    fix_noises = Variable(t.randn(opt.batch_size, opt.nz, 1, 1))
    noises = vars(t.randn(opt.batch_size, opt.nz, 1, 1))

    # 如果使用 GPU 訓練,把資料轉移到 GPU 上
    if opt.use_gpu:
        netd.cuda()
        netg.cuda()
        criterion.cuda()
        true_labels, fake_labels = true_labels.cuda(), fake_labels.cuda()
        fix_noises, noises = fix_noises.cuda(), noises.cuda()

在載入預訓練模型的時候,最好指定 map_location。因為如果程式之前在 GPU 上執行,那麼模型就會被存成 torch.cuda.Tensor,這樣載入的時候會預設把資料載入到視訊記憶體上。如果執行該程式的計算機中沒有 GPU,則會報錯,因此指定 map_location 把 Tensor 預設載入到記憶體上,等有需要的時候再載入到視訊記憶體中。

下面開始訓練網路,訓練的步驟如下所示:

  1. 訓練判別器:
    • 固定生成器
    • 對於真圖片,判別器的輸出概率值儘可能接近 1
    • 對於生成器生成的圖片,判別器儘可能輸出 0
  2. 訓練生成器
    • 固定判別器
    • 生成器生成圖片,儘可能讓判別器輸出 1
  3. 返回第一步,迴圈交替訓練
    epochs = range(opt.max_epoch)
    for epoch in iter(epochs):

        for ii, (img, _) in tqdm.tqdm(enumerate(dataloader)):
            real_img = Variable(img)
            if opt.use_gpu:
                real_img = real_img.cuda()

            # 訓練判別器
            if (ii + 1) % opt.d_every == 0:
                optimizer_d.zero_grad()
                # 儘可能把真圖片判別為 1
                output = netd(real_img)
                error_d_real = criterion(output, true_labels)
                error_d_real.backward()

                # 儘可能把假圖片判別為 0
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises).detach()  # 根據照片生成假圖片
                fake_ouput = netd(fake_img)
                error_d_fake = criterion(fake_ouput, fake_labels)
                error_d_fake.backward()
                optimizer_d.step()

            # 訓練生成器
            if (ii + 1) % opt.g_every == 0:
                optimizer_g.zero_grad()
                noises.data.copy_(t.randn(opt.batch_size, opt.nz, 1, 1))
                fake_img = netg(noises)
                fake_output = netd(fake_img)
                # 儘可能讓判別器把假圖片也判別為 1
                error_g = criterion(fake_output, true_labels)
                error_g.backward()
                optimizer_g.step()

            # 視覺化

            if opt.vis and ii % opt.plot_every == opt.plot_every - 1:
                # 定義視覺化視窗
                vis = Visualizer(opt.env)

                if os.path.exists(opt.debug_file):
                    ipdb.set_trace()
                global fix_fake_imgs
                fix_fake_imgs = netg(fix_noises)
                vis.images(fix_fake_imgs.detach().cpu().numpy()[:64] * 0.5 + 0.5, win='fixfake')
                vis.images(real_img.data.cpu().numpy()[:64] * 0.5 + 0.5, win='real')
                vis.plot('errord', errord_meter.value()[0])
                vis.plot('errorg', errorg_meter.value()[0])

        if (epoch + 1) % opt.save_every == 0:
            # 儲存模型、圖片
            tv.utils.save_image(fix_fake_imgs.data[:64], '%s/%s.png' % (opt.save_path, epoch), normalize=True,
                                range=(-1, 1))
            t.save(netd.state_dict(), 'checkpoints/netd_%s.pth' % epoch)
            t.save(netg.state_dict(), 'checkpoints/netg_%s.pth' % epoch)
            errord_meter.reset()
            errorg_meter.reset()

在上述訓練程式碼中,需要注意以下幾點:

  • 訓練生成器的時候,不需要調整判別器的引數;訓練判別器的時候,也不需要調整生成器的引數
  • 在訓練判別器的時候,需要對生成器生成的圖片用 detach 操作進行計算圖截斷,避免反向傳播把梯度傳到生成器中。因為在訓練判別器的時候我們不需要訓練生成器,也就不需要生成器的梯度。
  • 在訓練分類器的時候,需要反向傳播兩次,一次是希望把真圖片判為 1,一次是希望把假圖片判為 0.也可以把這個兩者的資料放到一個 batch 中,進行一次前向傳播和一次反向傳播即可。但是人們發現,在一個 batch 中只包含真圖片或者只包含假圖片的做法最好。
  • 對於假圖片,在訓練判別器的時候,我們希望它輸出為 0;而在訓練生成器的時候,我們希望它輸出為 1.因此可以看到一堆相互矛盾的程式碼:error_d_fake = criterion(fake_output,fake_labels)error_g = criterion(fake_output, true_labels)。其實這也很好理解,判別器希望能夠把假圖片判別為 fake_label,而生成器希望能把它判別為 true_label,判別器和生成器相互對抗提升。
  • 其中的 Visualize 模組類似於上一章自己的寫的模組,可以直接複製貼上原始碼中的程式碼。

七、隨機生成圖片

除了上述所示的程式碼外,還提供了一個函式,能載入預訓練好的模型,並且利用噪聲隨機生成圖片。

@t.no_grad()
def generate():
    # 定義噪聲和網路
    netg, netd = NetG(opt), NetD(opt)
    noises = t.randn(opt.gen_search_num, opt.nz, 1, 1).normal_(opt.gen_mean, opt.gen_std)
    noises = Variable(noises)

    # 載入預訓練的模型
    netd.load_state_dict(t.load(opt.netd_path))
    netg.load_state_dict(t.load(opt.netg_path))

    # 是否使用 GPU
    if opt.use_gpu:
        netd.cuda()
        netg.cuda()
        noises = noises.cuda()

    # 生成圖片,並計算圖片在判別器的分數
    fake_img = netg(noises)
    scores = netd(fake_img).data

    # 挑選最好的某幾張
    indexs = scores.topk(opt.gen_num)[1]
    result = []
    for ii in indexs:
        result.append(fake_img.data[ii])

    # 儲存圖片
    tv.utils.save_image(t.stack(result), opt.gen_num, normalize=True, range=(-1, 1))

八、訓練模型並測試

完整的程式碼可以新增我微信:chenyoudea,其實上述程式碼已經很完整了,或者去github https://github.com/chenyuntc/pytorch-book/tree/master/chapter07-AnimeGAN下載。

這裡假設你是擁有完整的程式碼,那麼準備好資料後,可以用下面的命令開始訓練:

python main.py train --gpu=True --vis=True --batch-size=256 --max-epoch=200

如果使用了 visdom,此時開啟 http://localhost:8097 就能看到生成的影像。

訓練完成後,我們就可以利用生成網路隨機生成動漫頭像,輸入命令如下:

python main.py generate --gen-img='result.5w.png' --gen-search-num=15000

下圖是 10 個 epoch 的展示:
0902-用GAN生成動漫頭像

相關文章