[譯] GAN 的 Keras 實現:構建影象去模糊應用

luochen1992發表於2018-04-18

2014年,Ian Goodfellow 提出了生成對抗網路(Generative Adversarial Networks) (GAN),本文將聚焦於利用 Keras 實現基於對抗生成網路的影象去模糊模型所有的 Keras 程式碼都在 這裡.

檢視原文 scientific publication 以及 Pytorch 版本實現.


快速回顧生成對抗網路

在生成對抗網路中,兩個網路互相訓練。生成模型通過創造以假亂真的輸入誤導判別模型。判別模型則區分輸入是真實的還是偽造的

[譯] GAN 的 Keras 實現:構建影象去模糊應用

GAN 訓練流程 — Source

訓練有3個主要步驟

  • 使用生成模型創造基於噪聲的假輸入
  • 同時使用真實的和虛假的輸入訓練判別模型
  • 訓練整個模型: 該模型是由生成模型後串接判別模型所構成的。

請注意,在第三步中,判別模型的權重不再更新。

串接兩個模型網路的原因是不可能直接對生成模型輸出進行反饋。我們衡量(生成模型的輸出)的唯一標準是判別模型是否接受生成的樣本

這裡簡要回顧了 GAN 的結構。如果你覺得不容易理解,你可以參考這個 excellent introduction.


資料集

Ian Goodfellow 首先應用 GAN 模型生成 MNIST 資料。在本教程中,我們使用生成對抗網路進行影象去模糊。因此,生成模型的輸入不是噪聲而是模糊的影象。

資料集採用 GOPRO 資料集。您可以下載 精簡版 (9GB) 或 完整版 (35GB)。它包含來自多個街景的人為模糊影象。資料集在按場景分的子資料夾裡。

我們先將圖片放在資料夾 A(模糊)和 B(清晰)中。這種 A 和 B 的結構與原論文 pix2pix article 一致。我寫了一個 自定義指令碼 去執行這個任務,按照 README 使用它。


模型

訓練過程保持不變。首先,讓我們看看神經網路結構!

生成模型

生成模型旨在重現清晰的影象。該網路模型是基於 殘差網路(ResNet) 塊(block)。它持續追蹤原始模糊影象的演變。這篇文章是基於 UNet 版本的, 我還沒實現過。這兩種結構都適合用於影象去模糊。

[譯] GAN 的 Keras 實現:構建影象去模糊應用

DeblurGAN 生成模型的網路結構 — Source

核心是應用於原始影象上取樣的 9 個殘差網路塊(ResNet blocks)。讓我們看看 Keras 的實現!

from keras.layers import Input, Conv2D, Activation, BatchNormalization
from keras.layers.merge import Add
from keras.layers.core import Dropout

def res_block(input, filters, kernel_size=(3,3), strides=(1,1), use_dropout=False):
    """
    使用序貫(sequential) API 對 Keras Resnet 塊進行例項化。
    :param input: 輸入張量
    :param filters: 卷積核數目
    :param kernel_size: 卷積核大小
    :param strides: 卷積步幅大小
    :param use_dropout: 布林值,確定是否使用 dropout
    :return: Keras 模型
    """
    x = ReflectionPadding2D((1,1))(input)
    x = Conv2D(filters=filters,
               kernel_size=kernel_size,
               strides=strides,)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    if use_dropout:
        x = Dropout(0.5)(x)

    x = ReflectionPadding2D((1,1))(x)
    x = Conv2D(filters=filters,
                kernel_size=kernel_size,
                strides=strides,)(x)
    x = BatchNormalization()(x)

    # 輸入和輸出之間連線兩個卷積層
    merged = Add()([input, x])
    return merged
複製程式碼

ResNet 層基本是卷積層,新增了輸入和輸出以形成最終輸出。

from keras.layers import Input, Activation, Add
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.core import Lambda
from keras.layers.normalization import BatchNormalization
from keras.models import Model

from layer_utils import ReflectionPadding2D, res_block

ngf = 64
input_nc = 3
output_nc = 3
input_shape_generator = (256, 256, input_nc)
n_blocks_gen = 9


def generator_model():
    """構建生成模型"""
    # Current version : ResNet block
    inputs = Input(shape=image_shape)

    x = ReflectionPadding2D((3, 3))(inputs)
    x = Conv2D(filters=ngf, kernel_size=(7,7), padding='valid')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Increase filter number
    n_downsampling = 2
    for i in range(n_downsampling):
        mult = 2**i
        x = Conv2D(filters=ngf*mult*2, kernel_size=(3,3), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    # 應用 9 ResNet blocks
    mult = 2**n_downsampling
    for i in range(n_blocks_gen):
        x = res_block(x, ngf*mult, use_dropout=True)

    # 減少卷積核到3個 (RGB)
    for i in range(n_downsampling):
        mult = 2**(n_downsampling - i)
        x = Conv2DTranspose(filters=int(ngf * mult / 2), kernel_size=(3,3), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)

    x = ReflectionPadding2D((3,3))(x)
    x = Conv2D(filters=output_nc, kernel_size=(7,7), padding='valid')(x)
    x = Activation('tanh')(x)

    # Add direct connection from input to output and recenter to [-1, 1]
    outputs = Add()([x, inputs])
    outputs = Lambda(lambda z: z/2)(outputs)

    model = Model(inputs=inputs, outputs=outputs, name='Generator')
    return model
複製程式碼

Keras 實現生成模型

按計劃,9 個 ResNet 塊應用於輸入的上取樣版本。我們新增從輸入端到輸出端的連線併除以 2 以保持標準化的輸出。

這就是生成模型,讓我們看看判別模型。

判別模型

判別模型的目標是確定輸入影象是否是人造的。因此,判別模型的結構是卷積的,並且輸出是單一值

from keras.layers import Input
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import Conv2D
from keras.layers.core import Dense, Flatten
from keras.layers.normalization import BatchNormalization
from keras.models import Model

ndf = 64
output_nc = 3
input_shape_discriminator = (256, 256, output_nc)


def discriminator_model():
    """構建判別模型."""
    n_layers, use_sigmoid = 3, False
    inputs = Input(shape=input_shape_discriminator)

    x = Conv2D(filters=ndf, kernel_size=(4,4), strides=2, padding='same')(inputs)
    x = LeakyReLU(0.2)(x)

    nf_mult, nf_mult_prev = 1, 1
    for n in range(n_layers):
        nf_mult_prev, nf_mult = nf_mult, min(2**n, 8)
        x = Conv2D(filters=ndf*nf_mult, kernel_size=(4,4), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)

    nf_mult_prev, nf_mult = nf_mult, min(2**n_layers, 8)
    x = Conv2D(filters=ndf*nf_mult, kernel_size=(4,4), strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)

    x = Conv2D(filters=1, kernel_size=(4,4), strides=1, padding='same')(x)
    if use_sigmoid:
        x = Activation('sigmoid')(x)

    x = Flatten()(x)
    x = Dense(1024, activation='tanh')(x)
    x = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=inputs, outputs=x, name='Discriminator')
    return model
複製程式碼

Keras實現判別模型

最後一步是構建完整模型。這個 GAN 的 特殊性在於輸入是真實影象而不是噪聲。因此,我們能獲得生成模型輸出的直接反饋。

from keras.layers import Input
from keras.models import Model

def generator_containing_discriminator_multiple_outputs(generator, discriminator):
    inputs = Input(shape=image_shape)
    generated_images = generator(inputs)
    outputs = discriminator(generated_images)
    model = Model(inputs=inputs, outputs=[generated_images, outputs])
    return model
複製程式碼

讓我們看看如何通過使用兩個損失函式來充分利用這種特殊性。


訓練

損失函式

我們在兩個層級抽取損失值,一個是在生成模型的末端,另一個在整個模型的末端。

首先是直接根據生成模型的輸出計算感知損失(perceptual loss)。該損失值確保了 GAN 模型是面向去模糊任務的。它比較了VGG的 第一個卷積輸出。

import keras.backend as K
from keras.applications.vgg16 import VGG16
from keras.models import Model

image_shape = (256, 256, 3)

def perceptual_loss(y_true, y_pred):
    vgg = VGG16(include_top=False, weights='imagenet', input_shape=image_shape)
    loss_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
    loss_model.trainable = False
    return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))
複製程式碼

第二個損失值是計算整個模型的輸出 Wasserstein loss。它是 兩張影象之間的平均差異。它以改善對抗生成網路收斂性而聞名.

import keras.backend as K

def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true*y_pred)
複製程式碼

訓練過程

第一步是載入資料以及初始化模型。我們使用自定義函式載入資料集以及為模型新增 Adam 優化器。我們通過設定 Keras 可訓練選項以防止判別模型進行訓練。

# 載入資料集
data = load_images('./images/train', n_images)
y_train, x_train = data['B'], data['A']

# 初始化模型
g = generator_model()
d = discriminator_model()
d_on_g = generator_containing_discriminator_multiple_outputs(g, d)

# 初始化優化器
g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
d_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
d_on_g_opt = Adam(lr=1E-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

# 編譯模型
d.trainable = True
d.compile(optimizer=d_opt, loss=wasserstein_loss)
d.trainable = False
loss = [perceptual_loss, wasserstein_loss]
loss_weights = [100, 1]
d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
d.trainable = True
複製程式碼

然後,我們啟動迭代,同時將資料集按批量劃分。

for epoch in range(epoch_num):
  print('epoch: {}/{}'.format(epoch, epoch_num))
  print('batches: {}'.format(x_train.shape[0] / batch_size))

  # 將影象隨機劃入不同批次
  permutated_indexes = np.random.permutation(x_train.shape[0])

  for index in range(int(x_train.shape[0] / batch_size)):
      batch_indexes = permutated_indexes[index*batch_size:(index+1)*batch_size]
      image_blur_batch = x_train[batch_indexes]
      image_full_batch = y_train[batch_indexes]
複製程式碼

最後,我們根據兩種損失先後訓練生成模型和判別模型。我們用生成模型產生假輸入。我們訓練判別模型來區分虛假和真實輸入,然後我們訓練整個模型。

for epoch in range(epoch_num):
  for index in range(batches):
    # [Batch Preparation]

    # 生成假輸入
    generated_images = g.predict(x=image_blur_batch, batch_size=batch_size)
    
    # 在真假輸入上訓練多次判別模型
    for _ in range(critic_updates):
        d_loss_real = d.train_on_batch(image_full_batch, output_true_batch)
        d_loss_fake = d.train_on_batch(generated_images, output_false_batch)
        d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

    d.trainable = False
    # Train generator only on discriminator's decision and generated images
    d_on_g_loss = d_on_g.train_on_batch(image_blur_batch, [image_full_batch, output_true_batch])

    d.trainable = True
複製程式碼

你可以參考 Github 看整個迴圈!

一些材料

我在 Deep Learning AMI (version 3.0) 中使用了 AWS Instance (p2.xlarge) 在 GOPRO 資料集 精簡版下,訓練時間約為5小時(50 次迭代)。

影象去模糊結果

[譯] GAN 的 Keras 實現:構建影象去模糊應用

從左到右: 原始影象、模糊影象、GAN 輸出

上面的輸出是我們 Keras Deblur GAN 的結果。即使在嚴重模糊的情況下,網路也能夠減少並形成更令人信服的影象。車燈更清晰,樹枝更清晰。

[譯] GAN 的 Keras 實現:構建影象去模糊應用

左: GOPRO 測試影象, 右: GAN 輸出.

一個限制是影象上的誘導模式,這可能是由於使用 VGG 作為損失而引起的。

[譯] GAN 的 Keras 實現:構建影象去模糊應用

左: GOPRO 測試影象, 右: GAN 輸出.

我希望你喜歡這篇關於利用生成對抗模型進行影象去模糊的文章。歡迎發表評論,關注我們或 與我聯絡.

如果您對計算機視覺感興趣,可以看看我們以前寫的一篇文章 Keras 實現基於內容的影象檢索。以下是生成對抗網路的資源列表。

[譯] GAN 的 Keras 實現:構建影象去模糊應用

左:GOPRO 測試影象,右:GAN 輸出。

生成對抗網路的資源列表。


掘金翻譯計劃 是一個翻譯優質網際網路技術文章的社群,文章來源為 掘金 上的英文分享文章。內容覆蓋 AndroidiOS前端後端區塊鏈產品設計人工智慧等領域,想要檢視更多優質譯文請持續關注 掘金翻譯計劃官方微博知乎專欄

相關文章