[Deep Learning] A Practical Example of Generative Adversarial Networks

Posted by 茫茫人海一粒沙 on 2020-10-26

Articles in this series

Deep Learning GAN (1): A Brief Introduction
Deep Learning GAN (2): A DCGAN Example on the CIFAR-10 Dataset
Deep Learning GAN (3): A DCGAN Example on the MNIST Handwritten-Digit Dataset
Deep Learning GAN (4): A cGAN (Conditional GAN) Example
Deep Learning GAN (5): A Pix2Pix GAN Example
Deep Learning GAN (6): A CycleGAN Example


1. Introduction to Pix2Pix

Pix2Pix is a generative adversarial network (GAN) model designed for general-purpose image-to-image translation.

The approach was proposed by Phillip Isola et al. in their 2016 paper titled "Image-to-Image Translation with Conditional Adversarial Networks", which was presented at CVPR in 2017.

The GAN architecture consists of a generator model that outputs new plausible synthetic images and a discriminator model that classifies images as real (from the dataset) or fake (generated). The discriminator model is updated directly, whereas the generator model is updated through the discriminator model. The two models are therefore trained together in an adversarial process in which the generator tries to get better at fooling the discriminator and the discriminator tries to get better at detecting fake images.

The Pix2Pix model is a type of conditional GAN, or cGAN, where the generated output image is conditioned on an input, in this case a source image. The discriminator is given both the source image and the target image and must decide whether the target is a plausible transformation of the source image.

The generator is trained with an adversarial loss, which encourages it to produce plausible images in the target domain. The generator is also updated with an L1 loss measured between the generated image and the expected output image. This additional loss encourages the generator model to produce plausible translations of the source image.
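Combining the two terms, the generator objective from the original paper can be written as

G* = arg min_G max_D L_cGAN(G, D) + λ · L_L1(G),  with λ = 100

which is exactly what the composite model in section 6 below implements with loss=['binary_crossentropy', 'mae'] and loss_weights=[1, 100].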

The Pix2Pix GAN has been demonstrated on a range of image-to-image translation tasks, such as converting maps to satellite photographs, black-and-white photographs to color, and product sketches to product photographs.

Now that we are familiar with the Pix2Pix GAN, let's prepare a dataset that we can use for image-to-image translation.

2. Download the Satellite Maps Dataset

This dataset consists of satellite images of New York and their corresponding Google Maps pages. The image translation problem involves converting satellite photos to the Google Maps format, or, in the reverse direction, converting Google Maps images to satellite photos.

The dataset is provided on the pix2pix website and can be downloaded as a 255-megabyte archive:
Download Maps Dataset (maps.tar.gz)
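If you prefer to script this step, a minimal download-and-extract sketch is shown below. The URL is a commonly used mirror of the dataset and is an assumption on my part; if it has moved, take the link from the pix2pix project page instead, and adjust the target directory to your environment.

import tarfile
import urllib.request

# download the maps dataset archive (mirror URL assumed; may change)
url = 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/maps.tar.gz'
urllib.request.urlretrieve(url, 'maps.tar.gz')
# extract it; this creates a maps/ directory containing the training images
with tarfile.open('maps.tar.gz', 'r:gz') as archive:
	archive.extractall('D:/ML/datasets/')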

After downloading and extracting the archive, the directory structure looks like this:

[Figure: directory structure of the extracted maps dataset]

Open either subdirectory and view one of the images; each file contains a satellite photo and its corresponding Google Maps image side by side.

[Figure: an example image pair from the dataset]

3. Data Preprocessing

To speed up image loading during training, we load all of the downloaded images with NumPy and save them in a single compressed file, maps_256.npz.

from os import listdir
from numpy import asarray
from numpy import vstack
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.preprocessing.image import load_img
from numpy import savez_compressed

# load all images in a directory into memory
def load_images(path, size=(256,512)):
	src_list, tar_list = list(), list()
	# enumerate filenames in directory, assume all are images
	for filename in listdir(path):
		# load and resize the image
		pixels = load_img(path + filename, target_size=size)
		# convert to numpy array
		pixels = img_to_array(pixels)
		# split into satellite and map
		sat_img, map_img = pixels[:, :256], pixels[:, 256:]
		src_list.append(sat_img)
		tar_list.append(map_img)
	return [asarray(src_list), asarray(tar_list)]

# dataset path
path = 'D:/ML/datasets/maps/train/'
# load dataset
[src_images, tar_images] = load_images(path)
print('Loaded: ', src_images.shape, tar_images.shape)
# save as compressed numpy array
filename = 'maps_256.npz'
savez_compressed(filename, src_images, tar_images)
print('Saved dataset: ', filename)

The output is:

Loaded:  (1096, 256, 256, 3) (1096, 256, 256, 3)
Saved dataset:  maps_256.npz

Then run the code below to verify that the saved images can be displayed correctly.

# load the prepared dataset
from numpy import load
from matplotlib import pyplot
# load the dataset
data = load('maps_256.npz')
src_images, tar_images = data['arr_0'], data['arr_1']
print('Loaded: ', src_images.shape, tar_images.shape)
# plot source images
n_samples = 3
for i in range(n_samples):
	pyplot.subplot(2, n_samples, 1 + i)
	pyplot.axis('off')
	pyplot.imshow(src_images[i].astype('uint8'))
# plot target image
for i in range(n_samples):
	pyplot.subplot(2, n_samples, 1 + n_samples + i)
	pyplot.axis('off')
	pyplot.imshow(tar_images[i].astype('uint8'))
pyplot.show()

[Figure: three satellite photos (top row) and their corresponding Google Maps images (bottom row)]

4. Define the Discriminator

The discriminator is implemented as a PatchGAN discriminator model.

Note that the input is two images: in_src_image is the satellite image and in_target_image is the Google Maps image.
The Concatenate layer merges them channel-wise into 6 channels (3 RGB channels per image).
LeakyReLU is used as the activation function, and BatchNormalization is applied to every block except the first and the last.
The output layer produces a patch map of shape (16,16,1).

# define the discriminator model
def define_discriminator(image_shape):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# source image input
	in_src_image = Input(shape=image_shape)
	# target image input
	in_target_image = Input(shape=image_shape)
	# concatenate images channel-wise
	merged = Concatenate()([in_src_image, in_target_image])
	# C64
	d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(merged)
	d = LeakyReLU(alpha=0.2)(d)
	# C128
	d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# C256
	d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# C512
	d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# second last output layer
	d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# patch output
	d = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
	patch_out = Activation('sigmoid')(d)
	# define model
	model = Model([in_src_image, in_target_image], patch_out)
	# compile model
	opt = Adam(lr=0.0002, beta_1=0.5)
	model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
	return model
if __name__ == '__main__':
	d_model = define_discriminator((256,256,3))
	print(d_model.summary())

Running the snippet prints the discriminator's layer-by-layer summary:
[Figure: discriminator model summary]
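As a quick sanity check (a minimal sketch that assumes define_discriminator() and its imports from the complete code at the end of the article are available), you can pass a pair of random images through the model and confirm the (16,16,1) patch output:

import numpy as np

d_model = define_discriminator((256,256,3))
# two random "images" in [-1,1], batch size 1
src = np.random.uniform(-1, 1, (1, 256, 256, 3)).astype('float32')
tar = np.random.uniform(-1, 1, (1, 256, 256, 3)).astype('float32')
patch = d_model.predict([src, tar])
print(patch.shape)  # expected: (1, 16, 16, 1)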

5. Define the Generator

The generator is an encoder-decoder model that uses a U-Net architecture. The model takes a source image (e.g. a satellite photo) and generates a target image (e.g. a Google Maps image). It does this by first downsampling, or encoding, the input image down to a bottleneck layer, and then upsampling, or decoding, the bottleneck representation to the size of the output image. The U-Net architecture means that skip connections are added between the encoding layers and the corresponding decoding layers, forming a U shape.

The figure below makes the skip connections clear, showing how the first layer of the encoder is connected to the last layer of the decoder, and so on.

[Figure: U-Net generator architecture with skip connections between encoder and decoder layers]

The encoder and decoder of the generator are built from blocks of convolutional, batch normalization, dropout, and activation layers. This standardization means that we can develop helper functions to create each block of layers and call them repeatedly to build up the encoder and decoder parts of the model.

The define_generator() function below implements the U-Net encoder-decoder generator model. It uses the define_encoder_block() helper function to create blocks of layers for the encoder and the decoder_block() function to create blocks of layers for the decoder. The tanh activation function is used in the output layer, meaning that pixel values in the generated image will be in the range [-1,1].

The input is a satellite image; after passing through the encoder-decoder network, it comes out as a Google Maps image:
(256,256,3) -> Encoder -> (1,1,512) -> Decoder -> (256,256,3)

# define an encoder block
def define_encoder_block(layer_in, n_filters, batchnorm=True):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# add downsampling layer
	g = Conv2D(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
	# conditionally add batch normalization
	if batchnorm:
		g = BatchNormalization()(g, training=True)
	# leaky relu activation
	g = LeakyReLU(alpha=0.2)(g)
	return g

# define a decoder block
def decoder_block(layer_in, skip_in, n_filters, dropout=True):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# add upsampling layer
	g = Conv2DTranspose(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
	# add batch normalization
	g = BatchNormalization()(g, training=True)
	# conditionally add dropout
	if dropout:
		g = Dropout(0.5)(g, training=True)
	# merge with skip connection
	g = Concatenate()([g, skip_in])
	# relu activation
	g = Activation('relu')(g)
	return g

# define the standalone generator model
def define_generator(image_shape=(256,256,3)):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# image input
	in_image = Input(shape=image_shape)
	# encoder model
	e1 = define_encoder_block(in_image, 64, batchnorm=False)
	e2 = define_encoder_block(e1, 128)
	e3 = define_encoder_block(e2, 256)
	e4 = define_encoder_block(e3, 512)
	e5 = define_encoder_block(e4, 512)
	e6 = define_encoder_block(e5, 512)
	e7 = define_encoder_block(e6, 512)
	# bottleneck, no batch norm and relu
	b = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(e7)
	b = Activation('relu')(b)
	# decoder model
	d1 = decoder_block(b, e7, 512)
	d2 = decoder_block(d1, e6, 512)
	d3 = decoder_block(d2, e5, 512)
	d4 = decoder_block(d3, e4, 512, dropout=False)
	d5 = decoder_block(d4, e3, 256, dropout=False)
	d6 = decoder_block(d5, e2, 128, dropout=False)
	d7 = decoder_block(d6, e1, 64, dropout=False)
	# output
	g = Conv2DTranspose(3, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d7)
	out_image = Activation('tanh')(g)
	# define model
	model = Model(in_image, out_image)
	return model
if __name__ == '__main__':
	g_model = define_generator((256,256,3))
	print(g_model.summary())

Running the snippet prints the generator's layer-by-layer summary:

[Figure: generator model summary]
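The same kind of check (again a sketch that assumes define_generator() and its imports are available) confirms that a (256,256,3) input comes back as a (256,256,3) image with values in [-1,1], because of the tanh output layer:

import numpy as np

g_model = define_generator((256,256,3))
src = np.random.uniform(-1, 1, (1, 256, 256, 3)).astype('float32')
fake = g_model.predict(src)
print(fake.shape, fake.min(), fake.max())  # expected: (1, 256, 256, 3), values within [-1, 1]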

6. Define the GAN Model

The composite GAN model exists mainly to train the generator, so the discriminator is not trained through it (d_model.trainable = False). The input is a satellite image of shape (256,256,3), and the model has two outputs: dis_out with shape (16,16,1) and gen_out with shape (256,256,3).

# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model, image_shape):
	# make weights in the discriminator not trainable
	d_model.trainable = False
	# define the source image
	in_src = Input(shape=image_shape)
	# connect the source image to the generator input
	gen_out = g_model(in_src)
	# connect the source input and generator output to the discriminator input
	dis_out = d_model([in_src, gen_out])
	# src image as input, generated image and classification output
	model = Model(in_src, [dis_out, gen_out])
	# compile model
	opt = Adam(lr=0.0002, beta_1=0.5)
	model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1,100])
	return model
if __name__ == '__main__':
	d_model = define_discriminator((256,256,3))
	g_model = define_generator((256,256,3))
	gan_model = define_gan(g_model, d_model, (256,256,3))
	print(g_model.summary())
Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_3 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
conv2d_6 (Conv2D)               (None, 128, 128, 64) 3136        input_3[0][0]                    
__________________________________________________________________________________________________
leaky_re_lu_5 (LeakyReLU)       (None, 128, 128, 64) 0           conv2d_6[0][0]                   
__________________________________________________________________________________________________
conv2d_7 (Conv2D)               (None, 64, 64, 128)  131200      leaky_re_lu_5[0][0]              
__________________________________________________________________________________________________
batch_normalization_4 (BatchNor (None, 64, 64, 128)  512         conv2d_7[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_6 (LeakyReLU)       (None, 64, 64, 128)  0           batch_normalization_4[0][0]      
__________________________________________________________________________________________________
conv2d_8 (Conv2D)               (None, 32, 32, 256)  524544      leaky_re_lu_6[0][0]              
__________________________________________________________________________________________________
batch_normalization_5 (BatchNor (None, 32, 32, 256)  1024        conv2d_8[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_7 (LeakyReLU)       (None, 32, 32, 256)  0           batch_normalization_5[0][0]      
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 16, 16, 512)  2097664     leaky_re_lu_7[0][0]              
__________________________________________________________________________________________________
batch_normalization_6 (BatchNor (None, 16, 16, 512)  2048        conv2d_9[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_8 (LeakyReLU)       (None, 16, 16, 512)  0           batch_normalization_6[0][0]      
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 8, 8, 512)    4194816     leaky_re_lu_8[0][0]              
__________________________________________________________________________________________________
batch_normalization_7 (BatchNor (None, 8, 8, 512)    2048        conv2d_10[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_9 (LeakyReLU)       (None, 8, 8, 512)    0           batch_normalization_7[0][0]      
__________________________________________________________________________________________________
conv2d_11 (Conv2D)              (None, 4, 4, 512)    4194816     leaky_re_lu_9[0][0]              
__________________________________________________________________________________________________
batch_normalization_8 (BatchNor (None, 4, 4, 512)    2048        conv2d_11[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_10 (LeakyReLU)      (None, 4, 4, 512)    0           batch_normalization_8[0][0]      
__________________________________________________________________________________________________
conv2d_12 (Conv2D)              (None, 2, 2, 512)    4194816     leaky_re_lu_10[0][0]             
__________________________________________________________________________________________________
batch_normalization_9 (BatchNor (None, 2, 2, 512)    2048        conv2d_12[0][0]                  
__________________________________________________________________________________________________
leaky_re_lu_11 (LeakyReLU)      (None, 2, 2, 512)    0           batch_normalization_9[0][0]      
__________________________________________________________________________________________________
conv2d_13 (Conv2D)              (None, 1, 1, 512)    4194816     leaky_re_lu_11[0][0]             
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 1, 1, 512)    0           conv2d_13[0][0]                  
__________________________________________________________________________________________________
conv2d_transpose (Conv2DTranspo (None, 2, 2, 512)    4194816     activation_1[0][0]               
__________________________________________________________________________________________________
batch_normalization_10 (BatchNo (None, 2, 2, 512)    2048        conv2d_transpose[0][0]           
__________________________________________________________________________________________________
dropout (Dropout)               (None, 2, 2, 512)    0           batch_normalization_10[0][0]     
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 2, 2, 1024)   0           dropout[0][0]                    
                                                                 leaky_re_lu_11[0][0]             
__________________________________________________________________________________________________
activation_2 (Activation)       (None, 2, 2, 1024)   0           concatenate_1[0][0]              
__________________________________________________________________________________________________
conv2d_transpose_1 (Conv2DTrans (None, 4, 4, 512)    8389120     activation_2[0][0]               
__________________________________________________________________________________________________
batch_normalization_11 (BatchNo (None, 4, 4, 512)    2048        conv2d_transpose_1[0][0]         
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 4, 4, 512)    0           batch_normalization_11[0][0]     
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 4, 4, 1024)   0           dropout_1[0][0]                  
                                                                 leaky_re_lu_10[0][0]             
__________________________________________________________________________________________________
activation_3 (Activation)       (None, 4, 4, 1024)   0           concatenate_2[0][0]              
__________________________________________________________________________________________________
conv2d_transpose_2 (Conv2DTrans (None, 8, 8, 512)    8389120     activation_3[0][0]               
__________________________________________________________________________________________________
batch_normalization_12 (BatchNo (None, 8, 8, 512)    2048        conv2d_transpose_2[0][0]         
__________________________________________________________________________________________________
dropout_2 (Dropout)             (None, 8, 8, 512)    0           batch_normalization_12[0][0]     
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 8, 8, 1024)   0           dropout_2[0][0]                  
                                                                 leaky_re_lu_9[0][0]              
__________________________________________________________________________________________________
activation_4 (Activation)       (None, 8, 8, 1024)   0           concatenate_3[0][0]              
__________________________________________________________________________________________________
conv2d_transpose_3 (Conv2DTrans (None, 16, 16, 512)  8389120     activation_4[0][0]               
__________________________________________________________________________________________________
batch_normalization_13 (BatchNo (None, 16, 16, 512)  2048        conv2d_transpose_3[0][0]         
__________________________________________________________________________________________________
concatenate_4 (Concatenate)     (None, 16, 16, 1024) 0           batch_normalization_13[0][0]     
                                                                 leaky_re_lu_8[0][0]              
__________________________________________________________________________________________________
activation_5 (Activation)       (None, 16, 16, 1024) 0           concatenate_4[0][0]              
__________________________________________________________________________________________________
conv2d_transpose_4 (Conv2DTrans (None, 32, 32, 256)  4194560     activation_5[0][0]               
__________________________________________________________________________________________________
batch_normalization_14 (BatchNo (None, 32, 32, 256)  1024        conv2d_transpose_4[0][0]         
__________________________________________________________________________________________________
concatenate_5 (Concatenate)     (None, 32, 32, 512)  0           batch_normalization_14[0][0]     
                                                                 leaky_re_lu_7[0][0]              
__________________________________________________________________________________________________
activation_6 (Activation)       (None, 32, 32, 512)  0           concatenate_5[0][0]              
__________________________________________________________________________________________________
conv2d_transpose_5 (Conv2DTrans (None, 64, 64, 128)  1048704     activation_6[0][0]               
__________________________________________________________________________________________________
batch_normalization_15 (BatchNo (None, 64, 64, 128)  512         conv2d_transpose_5[0][0]         
__________________________________________________________________________________________________
concatenate_6 (Concatenate)     (None, 64, 64, 256)  0           batch_normalization_15[0][0]     
                                                                 leaky_re_lu_6[0][0]              
__________________________________________________________________________________________________
activation_7 (Activation)       (None, 64, 64, 256)  0           concatenate_6[0][0]              
__________________________________________________________________________________________________
conv2d_transpose_6 (Conv2DTrans (None, 128, 128, 64) 262208      activation_7[0][0]               
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, 128, 128, 64) 256         conv2d_transpose_6[0][0]         
__________________________________________________________________________________________________
concatenate_7 (Concatenate)     (None, 128, 128, 128 0           batch_normalization_16[0][0]     
                                                                 leaky_re_lu_5[0][0]              
__________________________________________________________________________________________________
activation_8 (Activation)       (None, 128, 128, 128 0           concatenate_7[0][0]              
__________________________________________________________________________________________________
conv2d_transpose_7 (Conv2DTrans (None, 256, 256, 3)  6147        activation_8[0][0]               
__________________________________________________________________________________________________
activation_9 (Activation)       (None, 256, 256, 3)  0           conv2d_transpose_7[0][0]         
==================================================================================================
Total params: 54,429,315
Trainable params: 54,419,459
Non-trainable params: 9,856

7. Load Real Images and Generate Fake Images

load_real_samples() loads the real images from the compressed file and scales the pixel values from [0,255] to [-1,1].
generate_real_samples() selects a random batch of real image pairs; every label is 1, with shape (16,16,1).
generate_fake_samples() uses the generator to produce fake images; every label is 0, also with shape (16,16,1).
Note that the labels here are not the usual single numbers but three-dimensional arrays of shape (16,16,1), matching the PatchGAN discriminator output.

# load and prepare training images
def load_real_samples(filename):
	# load compressed arrays
	data = load(filename)
	# unpack arrays
	X1, X2 = data['arr_0'], data['arr_1']
	# scale from [0,255] to [-1,1]
	X1 = (X1 - 127.5) / 127.5
	X2 = (X2 - 127.5) / 127.5
	return [X1, X2]

# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):
	# unpack dataset
	trainA, trainB = dataset
	# choose random instances
	ix = randint(0, trainA.shape[0], n_samples)
	# retrieve selected images
	X1, X2 = trainA[ix], trainB[ix]
	# generate 'real' class labels (1)
	y = ones((n_samples, patch_shape, patch_shape, 1))
	return [X1, X2], y

# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, samples, patch_shape):
	# generate fake instance
	X = g_model.predict(samples)
	# create 'fake' class labels (0)
	y = zeros((len(X), patch_shape, patch_shape, 1))
	return X, y
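A minimal usage sketch for these helpers (assuming maps_256.npz from the preprocessing step and the imports from the complete code at the end of the article):

dataset = load_real_samples('maps_256.npz')
# a batch of 4 real pairs with all-ones patch labels
[X_realA, X_realB], y_real = generate_real_samples(dataset, 4, 16)
print(X_realA.shape, X_realB.shape, y_real.shape)  # (4, 256, 256, 3) (4, 256, 256, 3) (4, 16, 16, 1)
# the corresponding fake translations with all-zeros patch labels
g_model = define_generator((256,256,3))
X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, 16)
print(X_fakeB.shape, y_fake.shape)  # (4, 256, 256, 3) (4, 16, 16, 1)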

8. Generate Sample Images with the Generator Every Few Epochs to Check Progress

# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, dataset, n_samples=3):
	# select a sample of input images
	[X_realA, X_realB], _ = generate_real_samples(dataset, n_samples, 1)
	# generate a batch of fake samples
	X_fakeB, _ = generate_fake_samples(g_model, X_realA, 1)
	# scale all pixels from [-1,1] to [0,1]
	X_realA = (X_realA + 1) / 2.0
	X_realB = (X_realB + 1) / 2.0
	X_fakeB = (X_fakeB + 1) / 2.0
	# plot real source images
	for i in range(n_samples):
		pyplot.subplot(3, n_samples, 1 + i)
		pyplot.axis('off')
		pyplot.imshow(X_realA[i])
	# plot generated target image
	for i in range(n_samples):
		pyplot.subplot(3, n_samples, 1 + n_samples + i)
		pyplot.axis('off')
		pyplot.imshow(X_fakeB[i])
	# plot real target image
	for i in range(n_samples):
		pyplot.subplot(3, n_samples, 1 + n_samples*2 + i)
		pyplot.axis('off')
		pyplot.imshow(X_realB[i])
	# save plot to file
	filename1 = 'pix2pix_plot_%06d.png' % (step+1)
	pyplot.savefig(filename1)
	pyplot.close()
	# save the generator model
	filename2 = 'pix2pix_model_%06d.h5' % (step+1)
	g_model.save(filename2)
	print('>Saved: %s and %s' % (filename1, filename2))

9. Training Process

# train pix2pix models
def train(d_model, g_model, gan_model, dataset, n_epochs=100, n_batch=1):
	# determine the output square shape of the discriminator
	n_patch = d_model.output_shape[1]
	# unpack dataset
	trainA, trainB = dataset
	# calculate the number of batches per training epoch
	bat_per_epo = int(len(trainA) / n_batch)
	# calculate the number of training iterations
	n_steps = bat_per_epo * n_epochs
	# manually enumerate epochs
	for i in range(n_steps):
		# select a batch of real samples
		[X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)
		# generate a batch of fake samples
		X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)
		# update discriminator for real samples
		d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
		# update discriminator for generated samples
		d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
		# update the generator
		g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])
		# summarize performance
		print('>%d, d1[%.3f] d2[%.3f] g[%.3f]' % (i+1, d_loss1, d_loss2, g_loss))
		# summarize model performance
		if (i+1) % (bat_per_epo * 10) == 0:
			summarize_performance(i, g_model, dataset)
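To put the defaults into numbers: with the 1,096 training images loaded earlier and n_batch=1, each epoch is 1,096 steps, so the default 100 epochs run 109,600 training iterations in total, and summarize_performance() fires every 10,960 steps, i.e. ten times over the whole run.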

10. Results After Training

After the first 10 epochs, the generated map images look plausible, even though the lines of the streets are not entirely straight and the images contain some blurring. Nevertheless, the large structures are in the right places with mostly the right colors.

[Figure: generated map images after about 10 training epochs]
Images generated after about 50 training epochs begin to look very realistic, and the quality appears to remain good for the remainder of the training process.

Note the first example of a generated image below (right column, middle row), which includes more useful detail than the real Google Maps image.
[Figure: real satellite photos (top), generated map images (middle), and real Google Maps images (bottom) after further training]
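Once training has finished, the generator checkpoints saved by summarize_performance() can be reloaded to translate new satellite photos. A minimal sketch is shown below; the checkpoint filename is only an example, so substitute whichever pix2pix_model_XXXXXX.h5 file you actually saved:

from numpy import load
from tensorflow.keras.models import load_model
from matplotlib import pyplot

# load one satellite image from the prepared dataset and scale it to [-1,1]
data = load('maps_256.npz')
src = (data['arr_0'][0:1] - 127.5) / 127.5
# load a saved generator checkpoint (example filename) and translate the image
model = load_model('pix2pix_model_109600.h5')
gen = model.predict(src)
# scale from [-1,1] back to [0,1] for display
pyplot.imshow((gen[0] + 1) / 2.0)
pyplot.axis('off')
pyplot.show()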

11. Complete Code

# example of pix2pix gan for satellite to map image-to-image translation
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import matplotlib.pyplot as plt
from numpy import load
from numpy import zeros
from numpy import ones
from numpy.random import randint
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose

from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Concatenate
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import LeakyReLU
from matplotlib import pyplot

# define the discriminator model
def define_discriminator(image_shape):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# source image input
	in_src_image = Input(shape=image_shape)
	# target image input
	in_target_image = Input(shape=image_shape)
	# concatenate images channel-wise
	merged = Concatenate()([in_src_image, in_target_image])
	# C64
	d = Conv2D(64, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(merged)
	d = LeakyReLU(alpha=0.2)(d)
	# C128
	d = Conv2D(128, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# C256
	d = Conv2D(256, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# C512
	d = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# second last output layer
	d = Conv2D(512, (4,4), padding='same', kernel_initializer=init)(d)
	d = BatchNormalization()(d)
	d = LeakyReLU(alpha=0.2)(d)
	# patch output
	d = Conv2D(1, (4,4), padding='same', kernel_initializer=init)(d)
	patch_out = Activation('sigmoid')(d)
	# define model
	model = Model([in_src_image, in_target_image], patch_out)
	# compile model
	opt = Adam(lr=0.0002, beta_1=0.5)
	model.compile(loss='binary_crossentropy', optimizer=opt, loss_weights=[0.5])
	return model

# define an encoder block
def define_encoder_block(layer_in, n_filters, batchnorm=True):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# add downsampling layer
	g = Conv2D(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
	# conditionally add batch normalization
	if batchnorm:
		g = BatchNormalization()(g, training=True)
	# leaky relu activation
	g = LeakyReLU(alpha=0.2)(g)
	return g

# define a decoder block
def decoder_block(layer_in, skip_in, n_filters, dropout=True):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# add upsampling layer
	g = Conv2DTranspose(n_filters, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(layer_in)
	# add batch normalization
	g = BatchNormalization()(g, training=True)
	# conditionally add dropout
	if dropout:
		g = Dropout(0.5)(g, training=True)
	# merge with skip connection
	g = Concatenate()([g, skip_in])
	# relu activation
	g = Activation('relu')(g)
	return g

# define the standalone generator model
def define_generator(image_shape=(256,256,3)):
	# weight initialization
	init = RandomNormal(stddev=0.02)
	# image input
	in_image = Input(shape=image_shape)
	# encoder model
	e1 = define_encoder_block(in_image, 64, batchnorm=False)
	e2 = define_encoder_block(e1, 128)
	e3 = define_encoder_block(e2, 256)
	e4 = define_encoder_block(e3, 512)
	e5 = define_encoder_block(e4, 512)
	e6 = define_encoder_block(e5, 512)
	e7 = define_encoder_block(e6, 512)
	# bottleneck, no batch norm and relu
	b = Conv2D(512, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(e7)
	b = Activation('relu')(b)
	# decoder model
	d1 = decoder_block(b, e7, 512)
	d2 = decoder_block(d1, e6, 512)
	d3 = decoder_block(d2, e5, 512)
	d4 = decoder_block(d3, e4, 512, dropout=False)
	d5 = decoder_block(d4, e3, 256, dropout=False)
	d6 = decoder_block(d5, e2, 128, dropout=False)
	d7 = decoder_block(d6, e1, 64, dropout=False)
	# output
	g = Conv2DTranspose(3, (4,4), strides=(2,2), padding='same', kernel_initializer=init)(d7)
	out_image = Activation('tanh')(g)
	# define model
	model = Model(in_image, out_image)
	return model

# define the combined generator and discriminator model, for updating the generator
def define_gan(g_model, d_model, image_shape):
	# make weights in the discriminator not trainable
	d_model.trainable = False
	# define the source image
	in_src = Input(shape=image_shape)
	# connect the source image to the generator input
	gen_out = g_model(in_src)
	# connect the source input and generator output to the discriminator input
	dis_out = d_model([in_src, gen_out])
	print(dis_out)
	# src image as input, generated image and classification output
	model = Model(in_src, [dis_out, gen_out])
	# compile model
	opt = Adam(lr=0.0002, beta_1=0.5)
	model.compile(loss=['binary_crossentropy', 'mae'], optimizer=opt, loss_weights=[1,100])
	return model

# load and prepare training images
def load_real_samples(filename):
	# load compressed arrays
	data = load(filename)
	# unpack arrays
	X1, X2 = data['arr_0'], data['arr_1']
	# scale from [0,255] to [-1,1]
	X1 = (X1 - 127.5) / 127.5
	X2 = (X2 - 127.5) / 127.5
	return [X1, X2]

# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape):
	# unpack dataset
	trainA, trainB = dataset
	# choose random instances
	ix = randint(0, trainA.shape[0], n_samples)
	# retrieve selected images
	X1, X2 = trainA[ix], trainB[ix]
	# generate 'real' class labels (1)
	y = ones((n_samples, patch_shape, patch_shape, 1))
	return [X1, X2], y

# generate a batch of images, returns images and targets
def generate_fake_samples(g_model, samples, patch_shape):
	# generate fake instance
	X = g_model.predict(samples)
	# create 'fake' class labels (0)
	y = zeros((len(X), patch_shape, patch_shape, 1))
	return X, y

# generate samples and save as a plot and save the model
def summarize_performance(step, g_model, dataset, n_samples=3):
	# select a sample of input images
	[X_realA, X_realB], _ = generate_real_samples(dataset, n_samples, 1)
	# generate a batch of fake samples
	X_fakeB, _ = generate_fake_samples(g_model, X_realA, 1)
	# scale all pixels from [-1,1] to [0,1]
	X_realA = (X_realA + 1) / 2.0
	X_realB = (X_realB + 1) / 2.0
	X_fakeB = (X_fakeB + 1) / 2.0
	# plot real source images
	for i in range(n_samples):
		pyplot.subplot(3, n_samples, 1 + i)
		pyplot.axis('off')
		pyplot.imshow(X_realA[i])
	# plot generated target image
	for i in range(n_samples):
		pyplot.subplot(3, n_samples, 1 + n_samples + i)
		pyplot.axis('off')
		pyplot.imshow(X_fakeB[i])
	# plot real target image
	for i in range(n_samples):
		pyplot.subplot(3, n_samples, 1 + n_samples*2 + i)
		pyplot.axis('off')
		pyplot.imshow(X_realB[i])
	# save plot to file
	filename1 = 'pix2pix_plot_%06d.png' % (step+1)
	pyplot.savefig(filename1)
	pyplot.close()
	# save the generator model
	filename2 = 'pix2pix_model_%06d.h5' % (step+1)
	g_model.save(filename2)
	print('>Saved: %s and %s' % (filename1, filename2))

# train pix2pix models
def train(d_model, g_model, gan_model, dataset, n_epochs=100, n_batch=1):
	# determine the output square shape of the discriminator
	n_patch = d_model.output_shape[1]
	# unpack dataset
	trainA, trainB = dataset
	# calculate the number of batches per training epoch
	bat_per_epo = int(len(trainA) / n_batch)
	# calculate the number of training iterations
	n_steps = bat_per_epo * n_epochs
	# manually enumerate epochs
	for i in range(n_steps):
		# select a batch of real samples
		[X_realA, X_realB], y_real = generate_real_samples(dataset, n_batch, n_patch)
		# generate a batch of fake samples
		X_fakeB, y_fake = generate_fake_samples(g_model, X_realA, n_patch)
		# update discriminator for real samples
		d_loss1 = d_model.train_on_batch([X_realA, X_realB], y_real)
		# update discriminator for generated samples
		d_loss2 = d_model.train_on_batch([X_realA, X_fakeB], y_fake)
		# update the generator
		g_loss, _, _ = gan_model.train_on_batch(X_realA, [y_real, X_realB])
		# summarize performance
		print('>%d, d1[%.3f] d2[%.3f] g[%.3f]' % (i+1, d_loss1, d_loss2, g_loss))
		# summarize model performance
		if (i+1) % (bat_per_epo * 10) == 0:
			summarize_performance(i, g_model, dataset)

def start_train():
	# load image data
	dataset = load_real_samples('maps_256.npz')
	print('Loaded', dataset[0].shape, dataset[1].shape)
	# define input shape based on the loaded dataset
	image_shape = dataset[0].shape[1:]
	# define the models
	d_model = define_discriminator(image_shape)
	print(image_shape)
	print(d_model.summary())
	g_model = define_generator(image_shape)
	# define the composite model
	gan_model = define_gan(g_model, d_model, image_shape)
	# train model
	train(d_model, g_model, gan_model, dataset)

if __name__ == '__main__':
	#d_model = define_discriminator((256,256,3))
	#print(d_model.summary())
	#g_model = define_generator((256,256,3))
	#print(g_model.summary())
	#gan_model = define_gan(g_model, d_model, (256,256,3))
	#print(g_model.summary())
	start_train()
