[Deep Learning] Practical Examples of Generative Adversarial Networks
Series Table of Contents
Deep Learning GAN (1): A Brief Introduction
Deep Learning GAN (2): A DCGAN Example on the CIFAR10 Dataset
Deep Learning GAN (3): A DCGAN Example on the MNIST Handwritten Digits Dataset
Deep Learning GAN (4): A cGAN (Conditional GAN) Example
Deep Learning GAN (5): A Pix2Pix GAN Example
Deep Learning GAN (6): A CycleGAN Example
A Pix2Pix GAN Example
1. Introduction to Pix2Pix
Pix2Pix is a generative adversarial network (GAN) model designed for general-purpose image-to-image translation.
The approach was proposed by Phillip Isola et al. in their 2016 paper "Image-to-Image Translation with Conditional Adversarial Networks", which was presented at CVPR in 2017.
The GAN architecture consists of a generator model that outputs new plausible synthetic images and a discriminator model that classifies images as real (from the dataset) or fake (generated). The discriminator is updated directly, while the generator is updated through the discriminator. The two models are therefore trained together in an adversarial process, in which the generator tries to get better at fooling the discriminator and the discriminator tries to get better at spotting generated images.
The Pix2Pix model is a conditional GAN (cGAN), in which the generated output image is conditioned on an input, in this case a source image. The discriminator is given both the source image and a target image and must decide whether the target is a plausible transformation of the source.
The generator is trained with an adversarial loss, which encourages it to produce plausible images in the target domain. It is also updated with an L1 loss measured between the generated image and the expected output image. This additional loss encourages the generator to produce plausible translations of the source image.
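In the notation of the Pix2Pix paper, the generator's full objective combines the adversarial (cGAN) term with the L1 term, and the weighting matches the loss_weights=[1,100] setting used when the composite GAN model is compiled later in this article:

G^* = \arg\min_G \max_D \; \mathcal{L}_{cGAN}(G, D) + \lambda \, \mathcal{L}_{L1}(G), \qquad \mathcal{L}_{L1}(G) = \mathbb{E}_{x,y,z}\big[\lVert y - G(x, z) \rVert_1\big], \qquad \lambda = 100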
The Pix2Pix GAN has been demonstrated on a range of image-to-image translation tasks, such as converting maps to satellite photos, black-and-white photos to colour, and product sketches to product photos.
Now that we are familiar with the Pix2Pix GAN, let's prepare a dataset that we can use for image-to-image translation.
2. Download the Satellite Maps Dataset
This dataset consists of satellite images of New York and their corresponding Google Maps pages. The image translation problem involves converting satellite photos to the Google Maps format, or the reverse, converting Google Maps images to satellite photos.
The dataset is provided on the pix2pix website and can be downloaded as a 255 MB compressed archive (maps.tar.gz).
Download Maps Dataset (maps.tar.gz)
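If you prefer to fetch and unpack the archive from Python instead of downloading it by hand, a minimal sketch follows. The URL is assumed to be the commonly used Berkeley mirror for the pix2pix datasets, and the target folder matches the path used in the preprocessing script below; adjust both for your environment.

# sketch: fetch and extract maps.tar.gz (URL and paths are assumptions, adjust as needed)
import tarfile
import urllib.request

url = 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/maps.tar.gz'  # assumed mirror
archive = 'maps.tar.gz'

urllib.request.urlretrieve(url, archive)        # ~255 MB download
with tarfile.open(archive, 'r:gz') as tar:
    tar.extractall('D:/ML/datasets/')           # yields D:/ML/datasets/maps/train and .../val
print('Extracted maps dataset')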
After downloading and extracting it, the directory structure looks like this:
Go into either directory and open one of the images:
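Each file is a single composite image: the satellite photo occupies the left half and the corresponding Google map the right half, which is why the preprocessing code below splits every loaded image down the middle. A quick way to check this yourself (the filename is just an example):

# sketch: open one training image and confirm its layout (filename is an example)
from PIL import Image

img = Image.open('D:/ML/datasets/maps/train/1.jpg')
print(img.size)   # (width, height); the width is twice the height because photo and map sit side by side
img.show()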
3. Data Preprocessing
To make the images load faster during training, we load all of the downloaded training images with NumPy and save them to maps_256.npz.
from os import listdir
from numpy import asarray
from numpy import vstack
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import load_img
from numpy import savez_compressed

# load all images in a directory into memory
def load_images(path, size=(256,512)):
    src_list, tar_list = list(), list()
    # enumerate filenames in directory, assume all are images
    for filename in listdir(path):
        # load and resize the image
        pixels = load_img(path + filename, target_size=size)
        # convert to numpy array
        pixels = img_to_array(pixels)
        # split into satellite and map
        sat_img, map_img = pixels[:, :256], pixels[:, 256:]
        src_list.append(sat_img)
        tar_list.append(map_img)
    return [asarray(src_list), asarray(tar_list)]

# dataset path
path = 'D:/ML/datasets/maps/train/'
# load dataset
[src_images, tar_images] = load_images(path)
print('Loaded: ', src_images.shape, tar_images.shape)
# save as compressed numpy array
filename = 'maps_256.npz'
savez_compressed(filename, src_images, tar_images)
print('Saved dataset: ', filename)
The output is:
Loaded: (1096, 256, 256, 3) (1096, 256, 256, 3)
Saved dataset: maps_256.npz
Then run the following code to verify that the saved images can be loaded and displayed correctly.
# load the prepared dataset
from numpy import load
from matplotlib import pyplot

# load the dataset
data = load('maps_256.npz')
src_images, tar_images = data['arr_0'], data['arr_1']
print('Loaded: ', src_images.shape, tar_images.shape)
# plot source images
n_samples = 3
for i in range(n_samples):
    pyplot.subplot(2, n_samples, 1 + i)
    pyplot.axis('off')
    pyplot.imshow(src_images[i].astype('uint8'))
# plot target image
for i in range(n_samples):
    pyplot.subplot(2, n_samples, 1 + n_samples + i)
    pyplot.axis('off')
    pyplot.imshow(tar_images[i].astype('uint8'))
pyplot.show()
4. Define the Discriminator
The discriminator is an implementation of the PatchGAN discriminator model.
Note that it takes two images as input: in_src_image is the satellite image and in_target_image is the Google map.
The Concatenate layer merges them channel-wise into 6 channels, 3 channels (RGB) from each image.
LeakyReLU is used as the activation function, and every layer except the first and the last uses BatchNormalization.
The output of the final layer has shape (16,16,1).
# define the discriminator model
def define_discriminator(image_shape):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # source image input
    in_src_image = Input(shape=image_shape)
    # target image input
    in_target_image = Input(shape=image_shape)
    # concatenate images channel-wise
    merged = Concatenate()([in_src_image, in_target_image])
    # C64
    d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(merged)
    d = LeakyReLU(alpha=0.2)(d)
    # C128
    d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C256
    d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C512
    d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # second last output layer
    d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # patch output
    d = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
    patch_out = Activation('sigmoid')(d)
    # define model
    model = Model([in_src_image, in_target_image], patch_out)
    # compile model
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
    return model

if __name__ == '__main__':
    d_model = define_discriminator((256,256,3))
    print(d_model.summary())
Its structure is as follows:
5. Define the Generator
The generator is an encoder-decoder model using a U-Net architecture. The model takes a source image (e.g. a satellite photo) and generates a target image (e.g. a Google Maps image). It does this by first downsampling, or encoding, the input image down to a bottleneck layer, then upsampling, or decoding, the bottleneck representation back up to the size of the output image. The U-Net architecture means that skip-connections are added between the encoding layers and the corresponding decoding layers, forming a U shape.
The figure below makes the skip-connections clear, showing how the first layer of the encoder is connected to the last layer of the decoder, and so on.
The encoder and decoder of the generator consist of convolutional, batch normalization, dropout, and activation layers. This standardization means that we can write helper functions to create each block of layers and call them repeatedly to build up the encoder and decoder parts of the model.
The define_generator() function below implements the U-Net encoder-decoder generator model. It uses the define_encoder_block() helper function to create blocks of layers for the encoder and the decoder_block() function to create blocks of layers for the decoder. A tanh activation function is used in the output layer, which means that pixel values in the generated image will be in the range [-1,1].
The input is a satellite image; it passes through the encoder-decoder network and a Google map is generated at the output:
(256,256,3) -> Encoder -> (1,1,512) -> Decoder -> (256,256,3)
# define an encoder block
def define_encoder_block(layer_in, n_filters, batchnorm=True):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # add downsampling layer
    g = Conv2D(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
    # conditionally add batch normalization
    if batchnorm:
        g = BatchNormalization()(g, training=True)
    # leaky relu activation
    g = LeakyReLU(alpha=0.2)(g)
    return g

# define a decoder block
def decoder_block(layer_in, skip_in, n_filters, dropout=True):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # add upsampling layer
    g = Conv2DTranspose(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
    # add batch normalization
    g = BatchNormalization()(g, training=True)
    # conditionally add dropout
    if dropout:
        g = Dropout(0.5)(g, training=True)
    # merge with skip connection
    g = Concatenate()([g, skip_in])
    # relu activation
    g = Activation('relu')(g)
    return g

# define the standalone generator model
def define_generator(image_shape=(256,256,3)):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # image input
    in_image = Input(shape=image_shape)
    # encoder model
    e1 = define_encoder_block(in_image, 64, batchnorm=False)
    e2 = define_encoder_block(e1, 128)
    e3 = define_encoder_block(e2, 256)
    e4 = define_encoder_block(e3, 512)
    e5 = define_encoder_block(e4, 512)
    e6 = define_encoder_block(e5, 512)
    e7 = define_encoder_block(e6, 512)
    # bottleneck, no batch norm and relu
    b = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(e7)
    b = Activation('relu')(b)
    # decoder model
    d1 = decoder_block(b, e7, 512)
    d2 = decoder_block(d1, e6, 512)
    d3 = decoder_block(d2, e5, 512)
    d4 = decoder_block(d3, e4, 512, dropout=False)
    d5 = decoder_block(d4, e3, 256, dropout=False)
    d6 = decoder_block(d5, e2, 128, dropout=False)
    d7 = decoder_block(d6, e1, 64, dropout=False)
    # output
    g = Conv2DTranspose(3, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d7)
    out_image = Activation('tanh')(g)
    # define model
    model = Model(in_image, out_image)
    return model

if __name__ == '__main__':
    g_model = define_generator((256,256,3))
    print(g_model.summary())
Its structure is as follows:
6. Define the GAN Model
The GAN model is used mainly to train the generator, so the discriminator is not trained here (d_model.trainable = False).
The input is the satellite image (256,256,3); the two outputs are dis_out = (16,16,1) and gen_out = (256,256,3).
# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model, image_shape):
    # make weights in the discriminator not trainable
    d_model.trainable = False
    # define the source image
    in_src = Input(shape=image_shape)
    # connect the source image to the generator input
    gen_out = g_model(in_src)
    # connect the source input and generator output to the discriminator input
    dis_out = d_model([in_src, gen_out])
    # src image as input, generated image and classification output
    model = Model(in_src, [dis_out, gen_out])
    # compile model
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1,100])
    return model

if __name__ == '__main__':
    d_model = define_discriminator((256,256,3))
    g_model = define_generator((256,256,3))
    gan_model = define_gan(g_model, d_model, (256,256,3))
    print(g_model.summary())
Model: "model_1"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_3 (InputLayer) [(None, 256, 256, 3) 0
__________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, 128, 128, 64) 3136 input_3[0][0]
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU) (None, 128, 128, 64) 0 conv2d_6[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, 64, 64, 128) 131200 leaky_re_lu_5[0][0]
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 64, 64, 128) 512 conv2d_7[0][0]
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU) (None, 64, 64, 128) 0 batch_normalization_4[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, 32, 32, 256) 524544 leaky_re_lu_6[0][0]
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 32, 32, 256) 1024 conv2d_8[0][0]
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU) (None, 32, 32, 256) 0 batch_normalization_5[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, 16, 16, 512) 2097664 leaky_re_lu_7[0][0]
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 16, 16, 512) 2048 conv2d_9[0][0]
__________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU) (None, 16, 16, 512) 0 batch_normalization_6[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, 8, 8, 512) 4194816 leaky_re_lu_8[0][0]
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 8, 8, 512) 2048 conv2d_10[0][0]
__________________________________________________________________________________________________
leaky_re_lu_9 (LeakyReLU) (None, 8, 8, 512) 0 batch_normalization_7[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, 4, 4, 512) 4194816 leaky_re_lu_9[0][0]
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 4, 4, 512) 2048 conv2d_11[0][0]
__________________________________________________________________________________________________
leaky_re_lu_10 (LeakyReLU) (None, 4, 4, 512) 0 batch_normalization_8[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, 2, 2, 512) 4194816 leaky_re_lu_10[0][0]
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 2, 2, 512) 2048 conv2d_12[0][0]
__________________________________________________________________________________________________
leaky_re_lu_11 (LeakyReLU) (None, 2, 2, 512) 0 batch_normalization_9[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, 1, 1, 512) 4194816 leaky_re_lu_11[0][0]
__________________________________________________________________________________________________
activation_1 (Activation) (None, 1, 1, 512) 0 conv2d_13[0][0]
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 2, 2, 512) 4194816 activation_1[0][0]
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 2, 2, 512) 2048 conv2d_transpose[0][0]
__________________________________________________________________________________________________
dropout (Dropout) (None, 2, 2, 512) 0 batch_normalization_10[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 2, 2, 1024) 0 dropout[0][0]
leaky_re_lu_11[0][0]
__________________________________________________________________________________________________
activation_2 (Activation) (None, 2, 2, 1024) 0 concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 4, 4, 512) 8389120 activation_2[0][0]
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 4, 4, 512) 2048 conv2d_transpose_1[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout) (None, 4, 4, 512) 0 batch_normalization_11[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 4, 4, 1024) 0 dropout_1[0][0]
leaky_re_lu_10[0][0]
__________________________________________________________________________________________________
activation_3 (Activation) (None, 4, 4, 1024) 0 concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 8, 8, 512) 8389120 activation_3[0][0]
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 8, 8, 512) 2048 conv2d_transpose_2[0][0]
__________________________________________________________________________________________________
dropout_2 (Dropout) (None, 8, 8, 512) 0 batch_normalization_12[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate) (None, 8, 8, 1024) 0 dropout_2[0][0]
leaky_re_lu_9[0][0]
__________________________________________________________________________________________________
activation_4 (Activation) (None, 8, 8, 1024) 0 concatenate_3[0][0]
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 16, 16, 512) 8389120 activation_4[0][0]
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 16, 16, 512) 2048 conv2d_transpose_3[0][0]
__________________________________________________________________________________________________
concatenate_4 (Concatenate) (None, 16, 16, 1024) 0 batch_normalization_13[0][0]
leaky_re_lu_8[0][0]
__________________________________________________________________________________________________
activation_5 (Activation) (None, 16, 16, 1024) 0 concatenate_4[0][0]
__________________________________________________________________________________________________
conv2d_transpose_4 (Conv2DTrans (None, 32, 32, 256) 4194560 activation_5[0][0]
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 32, 32, 256) 1024 conv2d_transpose_4[0][0]
__________________________________________________________________________________________________
concatenate_5 (Concatenate) (None, 32, 32, 512) 0 batch_normalization_14[0][0]
leaky_re_lu_7[0][0]
__________________________________________________________________________________________________
activation_6 (Activation) (None, 32, 32, 512) 0 concatenate_5[0][0]
__________________________________________________________________________________________________
conv2d_transpose_5 (Conv2DTrans (None, 64, 64, 128) 1048704 activation_6[0][0]
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 64, 64, 128) 512 conv2d_transpose_5[0][0]
__________________________________________________________________________________________________
concatenate_6 (Concatenate) (None, 64, 64, 256) 0 batch_normalization_15[0][0]
leaky_re_lu_6[0][0]
__________________________________________________________________________________________________
activation_7 (Activation) (None, 64, 64, 256) 0 concatenate_6[0][0]
__________________________________________________________________________________________________
conv2d_transpose_6 (Conv2DTrans (None, 128, 128, 64) 262208 activation_7[0][0]
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, 128, 128, 64) 256 conv2d_transpose_6[0][0]
__________________________________________________________________________________________________
concatenate_7 (Concatenate) (None, 128, 128, 128 0 batch_normalization_16[0][0]
leaky_re_lu_5[0][0]
__________________________________________________________________________________________________
activation_8 (Activation) (None, 128, 128, 128 0 concatenate_7[0][0]
__________________________________________________________________________________________________
conv2d_transpose_7 (Conv2DTrans (None, 256, 256, 3) 6147 activation_8[0][0]
__________________________________________________________________________________________________
activation_9 (Activation) (None, 256, 256, 3) 0 conv2d_transpose_7[0][0]
==================================================================================================
Total params: 54,429,315
Trainable params: 54,419,459
Non-trainable params: 9,856
7. Load Real Images and Generate Fake Images
The load_real_samples() function loads the real images from the prepared dataset.
The generate_real_samples() function selects a random batch of real image pairs; the label for each sample is all 1s, with shape (16,16,1).
The generate_fake_samples() function uses the generator to produce fake images; the label for each sample is all 0s, with shape (16,16,1).
Note that the labels here are not the usual single number per sample: each label is a three-dimensional array of shape (16,16,1), one value for every patch of the PatchGAN output.
# load and prepare training images
def load_real_samples(filename):
    # load compressed arrays
    data = load(filename)
    # unpack arrays
    X1, X2 = data['arr_0'], data['arr_1']
    # scale from [0,255] to [-1,1]
    X1 = (X1 - 127.5) / 127.5
    X2 = (X2 - 127.5) / 127.5
    return [X1, X2]

# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):
    # unpack dataset
    trainA, trainB = dataset
    # choose random instances
    ix = randint(0, trainA.shape[0], n_samples)
    # retrieve selected images
    X1, X2 = trainA[ix], trainB[ix]
    # generate 'real' class labels (1)
    y = ones((n_samples, patch_shape, patch_shape, 1))
    return [X1, X2], y

# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, samples, patch_shape):
    # generate fake instance
    X = g_model.predict(samples)
    # create 'fake' class labels (0)
    y = zeros((len(X), patch_shape, patch_shape, 1))
    return X, y
8. Every Few Epochs, Use the Generator to Produce Some Fake Images and Check Progress
# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, dataset, n_samples=3):
    # select a sample of input images
    [X_realA, X_realB], _ = generate_real_samples(dataset, n_samples, 1)
    # generate a batch of fake samples
    X_fakeB, _ = generate_fake_samples(g_model, X_realA, 1)
    # scale all pixels from [-1,1] to [0,1]
    X_realA = (X_realA + 1) / 2.0
    X_realB = (X_realB + 1) / 2.0
    X_fakeB = (X_fakeB + 1) / 2.0
    # plot real source images
    for i in range(n_samples):
        pyplot.subplot(3, n_samples, 1 + i)
        pyplot.axis('off')
        pyplot.imshow(X_realA[i])
    # plot generated target image
    for i in range(n_samples):
        pyplot.subplot(3, n_samples, 1 + n_samples + i)
        pyplot.axis('off')
        pyplot.imshow(X_fakeB[i])
    # plot real target image
    for i in range(n_samples):
        pyplot.subplot(3, n_samples, 1 + n_samples*2 + i)
        pyplot.axis('off')
        pyplot.imshow(X_realB[i])
    # save plot to file
    filename1 = 'pix2pix_plot_%06d.png' % (step+1)
    pyplot.savefig(filename1)
    pyplot.close()
    # save the generator model
    filename2 = 'pix2pix_model_%06d.h5' % (step+1)
    g_model.save(filename2)
    print('>Saved: %s and %s' % (filename1, filename2))
9. Training Process
# train pix2pix models
def train(d_model, g_model, gan_model, dataset, n_epochs=100, n_batch=1):
    # determine the output square shape of the discriminator
    n_patch = d_model.output_shape[1]
    # unpack dataset
    trainA, trainB = dataset
    # calculate the number of batches per training epoch
    bat_per_epo = int(len(trainA) / n_batch)
    # calculate the number of training iterations
    n_steps = bat_per_epo * n_epochs
    # manually enumerate epochs
    for i in range(n_steps):
        # select a batch of real samples
        [X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)
        # generate a batch of fake samples
        X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)
        # update discriminator for real samples
        d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
        # update discriminator for generated samples
        d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
        # update the generator
        g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])
        # summarize performance
        print('>%d, d1[%.3f] d2[%.3f] g[%.3f]' % (i+1, d_loss1, d_loss2, g_loss))
        # summarize model performance
        if (i+1) % (bat_per_epo * 10) == 0:
            summarize_performance(i, g_model, dataset)
10. Results After Training
After the first 10 epochs, plausible-looking map images are already generated, even though the street lines are not perfectly straight and the images are a little blurry. Nonetheless, the large structures are in the right places with mostly the right colours.
Images generated after about 50 training epochs begin to look very realistic, at least on average, and the quality seems to remain good for the rest of the training process.
Note the first generated image example below (right column, middle row), which contains more useful detail than the real Google Maps image.
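Once training has produced a few checkpoints, any of the saved generator files can be reloaded to translate a satellite photo into a map. Below is a minimal sketch, assuming the maps_256.npz file and the load_real_samples() function from above are available; the checkpoint filename is just an example of the pix2pix_model_XXXXXX.h5 pattern produced by summarize_performance().

# sketch: reload a saved generator and translate one satellite photo (checkpoint name is an example)
from tensorflow.keras.models import load_model
from matplotlib import pyplot

# load the prepared dataset and a trained generator checkpoint
[X1, X2] = load_real_samples('maps_256.npz')
model = load_model('pix2pix_model_109600.h5')
# pick one source (satellite) image and generate its map
src = X1[0:1]
gen = model.predict(src)
# scale from [-1,1] back to [0,1] for plotting
images = [(src[0] + 1) / 2.0, (gen[0] + 1) / 2.0, (X2[0] + 1) / 2.0]
titles = ['Source', 'Generated', 'Expected']
for i in range(3):
    pyplot.subplot(1, 3, 1 + i)
    pyplot.axis('off')
    pyplot.imshow(images[i])
    pyplot.title(titles[i])
pyplot.show()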
11. Complete Code
# example of pix2pix gan for satellite to map image-to-image translation
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import matplotlib.pyplot as plt
from numpy import load
from numpy import zeros
from numpy import ones
from numpy.random import randint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import LeakyReLU
from matplotlib import pyplot

# define the discriminator model
def define_discriminator(image_shape):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # source image input
    in_src_image = Input(shape=image_shape)
    # target image input
    in_target_image = Input(shape=image_shape)
    # concatenate images channel-wise
    merged = Concatenate()([in_src_image, in_target_image])
    # C64
    d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(merged)
    d = LeakyReLU(alpha=0.2)(d)
    # C128
    d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C256
    d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # C512
    d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # second last output layer
    d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
    d = BatchNormalization()(d)
    d = LeakyReLU(alpha=0.2)(d)
    # patch output
    d = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
    patch_out = Activation('sigmoid')(d)
    # define model
    model = Model([in_src_image, in_target_image], patch_out)
    # compile model
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
    return model

# define an encoder block
def define_encoder_block(layer_in, n_filters, batchnorm=True):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # add downsampling layer
    g = Conv2D(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
    # conditionally add batch normalization
    if batchnorm:
        g = BatchNormalization()(g, training=True)
    # leaky relu activation
    g = LeakyReLU(alpha=0.2)(g)
    return g

# define a decoder block
def decoder_block(layer_in, skip_in, n_filters, dropout=True):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # add upsampling layer
    g = Conv2DTranspose(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
    # add batch normalization
    g = BatchNormalization()(g, training=True)
    # conditionally add dropout
    if dropout:
        g = Dropout(0.5)(g, training=True)
    # merge with skip connection
    g = Concatenate()([g, skip_in])
    # relu activation
    g = Activation('relu')(g)
    return g

# define the standalone generator model
def define_generator(image_shape=(256,256,3)):
    # weight initialization
    init = RandomNormal(stddev=0.02)
    # image input
    in_image = Input(shape=image_shape)
    # encoder model
    e1 = define_encoder_block(in_image, 64, batchnorm=False)
    e2 = define_encoder_block(e1, 128)
    e3 = define_encoder_block(e2, 256)
    e4 = define_encoder_block(e3, 512)
    e5 = define_encoder_block(e4, 512)
    e6 = define_encoder_block(e5, 512)
    e7 = define_encoder_block(e6, 512)
    # bottleneck, no batch norm and relu
    b = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(e7)
    b = Activation('relu')(b)
    # decoder model
    d1 = decoder_block(b, e7, 512)
    d2 = decoder_block(d1, e6, 512)
    d3 = decoder_block(d2, e5, 512)
    d4 = decoder_block(d3, e4, 512, dropout=False)
    d5 = decoder_block(d4, e3, 256, dropout=False)
    d6 = decoder_block(d5, e2, 128, dropout=False)
    d7 = decoder_block(d6, e1, 64, dropout=False)
    # output
    g = Conv2DTranspose(3, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d7)
    out_image = Activation('tanh')(g)
    # define model
    model = Model(in_image, out_image)
    return model

# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model, image_shape):
    # make weights in the discriminator not trainable
    d_model.trainable = False
    # define the source image
    in_src = Input(shape=image_shape)
    # connect the source image to the generator input
    gen_out = g_model(in_src)
    # connect the source input and generator output to the discriminator input
    dis_out = d_model([in_src, gen_out])
    print(dis_out)
    # src image as input, generated image and classification output
    model = Model(in_src, [dis_out, gen_out])
    # compile model
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1,100])
    return model

# load and prepare training images
def load_real_samples(filename):
    # load compressed arrays
    data = load(filename)
    # unpack arrays
    X1, X2 = data['arr_0'], data['arr_1']
    # scale from [0,255] to [-1,1]
    X1 = (X1 - 127.5) / 127.5
    X2 = (X2 - 127.5) / 127.5
    return [X1, X2]

# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):
    # unpack dataset
    trainA, trainB = dataset
    # choose random instances
    ix = randint(0, trainA.shape[0], n_samples)
    # retrieve selected images
    X1, X2 = trainA[ix], trainB[ix]
    # generate 'real' class labels (1)
    y = ones((n_samples, patch_shape, patch_shape, 1))
    return [X1, X2], y

# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, samples, patch_shape):
    # generate fake instance
    X = g_model.predict(samples)
    # create 'fake' class labels (0)
    y = zeros((len(X), patch_shape, patch_shape, 1))
    return X, y

# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, dataset, n_samples=3):
    # select a sample of input images
    [X_realA, X_realB], _ = generate_real_samples(dataset, n_samples, 1)
    # generate a batch of fake samples
    X_fakeB, _ = generate_fake_samples(g_model, X_realA, 1)
    # scale all pixels from [-1,1] to [0,1]
    X_realA = (X_realA + 1) / 2.0
    X_realB = (X_realB + 1) / 2.0
    X_fakeB = (X_fakeB + 1) / 2.0
    # plot real source images
    for i in range(n_samples):
        pyplot.subplot(3, n_samples, 1 + i)
        pyplot.axis('off')
        pyplot.imshow(X_realA[i])
    # plot generated target image
    for i in range(n_samples):
        pyplot.subplot(3, n_samples, 1 + n_samples + i)
        pyplot.axis('off')
        pyplot.imshow(X_fakeB[i])
    # plot real target image
    for i in range(n_samples):
        pyplot.subplot(3, n_samples, 1 + n_samples*2 + i)
        pyplot.axis('off')
        pyplot.imshow(X_realB[i])
    # save plot to file
    filename1 = 'pix2pix_plot_%06d.png' % (step+1)
    pyplot.savefig(filename1)
    pyplot.close()
    # save the generator model
    filename2 = 'pix2pix_model_%06d.h5' % (step+1)
    g_model.save(filename2)
    print('>Saved: %s and %s' % (filename1, filename2))

# train pix2pix models
def train(d_model, g_model, gan_model, dataset, n_epochs=100, n_batch=1):
    # determine the output square shape of the discriminator
    n_patch = d_model.output_shape[1]
    # unpack dataset
    trainA, trainB = dataset
    # calculate the number of batches per training epoch
    bat_per_epo = int(len(trainA) / n_batch)
    # calculate the number of training iterations
    n_steps = bat_per_epo * n_epochs
    # manually enumerate epochs
    for i in range(n_steps):
        # select a batch of real samples
        [X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)
        # generate a batch of fake samples
        X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)
        # update discriminator for real samples
        d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
        # update discriminator for generated samples
        d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
        # update the generator
        g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])
        # summarize performance
        print('>%d, d1[%.3f] d2[%.3f] g[%.3f]' % (i+1, d_loss1, d_loss2, g_loss))
        # summarize model performance
        if (i+1) % (bat_per_epo * 10) == 0:
            summarize_performance(i, g_model, dataset)

def start_train():
    # load image data
    dataset = load_real_samples('maps_256.npz')
    print('Loaded', dataset[0].shape, dataset[1].shape)
    # define input shape based on the loaded dataset
    image_shape = dataset[0].shape[1:]
    # define the models
    d_model = define_discriminator(image_shape)
    print(image_shape)
    print(d_model.summary())
    g_model = define_generator(image_shape)
    # define the composite model
    gan_model = define_gan(g_model, d_model, image_shape)
    # train model
    train(d_model, g_model, gan_model, dataset)

if __name__ == '__main__':
    #d_model = define_discriminator((256,256,3))
    #print(d_model.summary())
    #g_model = define_generator((256,256,3))
    #print(g_model.summary())
    #gan_model = define_gan(g_model, d_model, (256,256,3))
    #print(g_model.summary())
    start_train()