前言
深度學習作為人工智慧的重要手段,迎來了爆發,在NLP、CV、物聯網、無人機等多個領域都發揮了非常重要的作用。最近幾年,各種深度學習演算法層出不窮, Generative Adverarial Network(GAN)自2014年提出以來,引起廣泛關注,身為深度學習三巨頭之一的Yan Lecun對GAN的評價頗高,認為GAN是近年來在深度學習上最大的突破,是近十年來機器學習上最有意思的工作。圍繞GAN的論文數量也迅速增多,各種版本的GAN出現,主要在CV領域帶來了一些貢獻,如下圖所示。
我們可以利用GAN生成一些我們需要的影像或者文字,比如二次元頭像。
GAN簡介
GAN主要的應用是自動生成一些東西,包括影像和文字等,比如隨機給一個向量作為輸入,通過GAN的Generator生成一張圖片,或者生成一串語句。Conditional GAN的應用更多一些,比如資料集是一段文字和影像的資料對,通過訓練,GAN可以通過給定一段文字生成對應的影像。
GAN主要可以分為Generator(生成器)和Discriminator(判別器)兩個部分,其中Generator其實就是一個神經網路,輸入一個向量,可以輸出一張影像(即一個高維的向量表示),如下圖示。
Discriminator也是一個神經網路,輸入為一張影像,輸出為一個數值,輸出的數值用於判斷輸入的影像是否是真的,數值越大,說明影像是真的,數值越小,說明影像為假的,如下圖示。
Generator負責生成影像,Discriminator負責對Generator生成的影像和真實影像去進行對比,區別出真假,Generator需要不斷優化來欺騙Discriminator,以假亂真;而Discriminator也不斷優化,來提高識別能力,能夠識別出Generator的把戲。二者的這種關係可以形象地通過下圖展示。
Generator和Discriminator連線起來,形成一個比較大的深層網路,即為GAN網路。
場景描述
深度學習的各種演算法在PAI上可以通過PAI-DSW進行實現,在PAI-DSW上進行訓練資料,利用GAN自動生成二次元頭像。
資料準備
首先需要準備真實的二次元頭像作為資料集,這裡從網上找到一些共享的資源,儲存在了釘釘釘盤中,釘盤地址 ,提取密碼: c2pz,資料集如下圖示,約5萬多張:
演算法實踐
利用PAI-DSW進行GAN演算法實踐,首先需要安裝準備好環境。
首先進入到Notebook建模,建立新例項,之後開啟例項,進入Terminal,在Terminal下使用者可以像在自己本地一樣安裝相應的依賴包,進行操作。
準備好環境之後,我們可以通過如下圖示方法,將基於Tensorflow的DCGAN程式碼和資料集上傳上去。
用於訓練的DCGAN程式碼地址:github.com/carpedm20/D…,關於DCGAN的網路框架圖如下,詳細介紹可以參考論文:arxiv.org/abs/1511.06…,這裡我們不做詳述。
資料集和程式碼上傳成功,如下圖示。
其中,data目錄下的faces即為資料集,該資料夾下為對應的5萬多張真實二次元頭像。DCGAN-tensorflow為整個程式碼路徑,其中最主要的兩個程式碼檔案是main.py和model.py,其中最主要的核心程式碼如下。
def main(_):
pp.pprint(flags.FLAGS.__flags)
if FLAGS.input_width is None:
FLAGS.input_width = FLAGS.input_height
if FLAGS.output_width is None:
FLAGS.output_width = FLAGS.output_height
if not os.path.exists(FLAGS.checkpoint_dir):
os.makedirs(FLAGS.checkpoint_dir)
if not os.path.exists(FLAGS.sample_dir):
os.makedirs(FLAGS.sample_dir)
#gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
run_config = tf.ConfigProto()
run_config.gpu_options.allow_growth=True
with tf.Session(config=run_config) as sess:
if FLAGS.dataset == 'mnist':
dcgan = DCGAN(
sess,
input_width=FLAGS.input_width,
input_height=FLAGS.input_height,
output_width=FLAGS.output_width,
output_height=FLAGS.output_height,
batch_size=FLAGS.batch_size,
sample_num=FLAGS.batch_size,
y_dim=10,
z_dim=FLAGS.generate_test_images,
dataset_name=FLAGS.dataset,
input_fname_pattern=FLAGS.input_fname_pattern,
crop=FLAGS.crop,
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir,
data_dir=FLAGS.data_dir)
else:
dcgan = DCGAN(
sess,
input_width=FLAGS.input_width,
input_height=FLAGS.input_height,
output_width=FLAGS.output_width,
output_height=FLAGS.output_height,
batch_size=FLAGS.batch_size,
sample_num=FLAGS.batch_size,
z_dim=FLAGS.generate_test_images,
dataset_name=FLAGS.dataset,
input_fname_pattern=FLAGS.input_fname_pattern,
crop=FLAGS.crop,
checkpoint_dir=FLAGS.checkpoint_dir,
sample_dir=FLAGS.sample_dir,
data_dir=FLAGS.data_dir)
show_all_variables()
if FLAGS.train:
dcgan.train(FLAGS)
複製程式碼
else:
# Update D network
_, summary_str = self.sess.run([d_optim, self.d_sum],
feed_dict={ self.inputs: batch_images, self.z: batch_z })
self.writer.add_summary(summary_str, counter)
# Update G network
_, summary_str = self.sess.run([g_optim, self.g_sum],
feed_dict={ self.z: batch_z })
self.writer.add_summary(summary_str, counter)
# Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
_, summary_str = self.sess.run([g_optim, self.g_sum],
feed_dict={ self.z: batch_z })
self.writer.add_summary(summary_str, counter)
errD_fake = self.d_loss_fake.eval({ self.z: batch_z })
errD_real = self.d_loss_real.eval({ self.inputs: batch_images })
errG = self.g_loss.eval({self.z: batch_z})
複製程式碼
一切就緒之後,我們執行命令進行訓練,呼叫命令如下:
python main.py --input_height 96 --input_width 96 --output_height 48 --output_width 48 --dataset faces --crop --train --epoch 300 --input_fname_pattern "*.jpg"
其中,引數dateset指定資料集的目錄,epoch指定迴圈迭代的次數,input_height、input_width用於指定輸入檔案的大小,輸出檔案的大小同樣也需要引數設定,程式碼執行過程如下圖示:
我們來看下執行結果,分別看一下epoch為1,30,100的時候生成的二次元頭像效果圖。
epoch=1
epoch=30
epoch=100
我們發現,隨著不斷迭代,生成的二次元頭像也越來越逼真。
總結
通過上面的實踐,我們領略到了GAN的魅力,GAN的變種有很多,除此之外我們還可以利用GAN做非常多的有意思的事情,比如通過文字生成影像,通過簡單文字生成宣傳海報等。PAI-DSW像是一個練武場,為我們準備好了深度學習所需要的環境和條件,讓我們可以盡情享受大資料和深度學習的樂趣,除了GAN,像比較火熱的Bert等模型,我們也都可以試一試。
本文作者:不等_趙振才
本文為雲棲社群原創內容,未經允許不得轉載。