【機器學習PAI實戰】—— 玩轉人工智慧之利用GAN自動生成二次元頭像

阿里云云棲號發表於2019-03-26

前言

深度學習作為人工智慧的重要手段,迎來了爆發,在NLP、CV、物聯網、無人機等多個領域都發揮了非常重要的作用。最近幾年,各種深度學習演算法層出不窮, Generative Adverarial Network(GAN)自2014年提出以來,引起廣泛關注,身為深度學習三巨頭之一的Yan Lecun對GAN的評價頗高,認為GAN是近年來在深度學習上最大的突破,是近十年來機器學習上最有意思的工作。圍繞GAN的論文數量也迅速增多,各種版本的GAN出現,主要在CV領域帶來了一些貢獻,如下圖所示。

【機器學習PAI實戰】—— 玩轉人工智慧之利用GAN自動生成二次元頭像

我們可以利用GAN生成一些我們需要的影象或者文字,比如二次元頭像。

AN簡介

GAN主要的應用是自動生成一些東西,包括影象和文字等,比如隨機給一個向量作為輸入,通過GAN的Generator生成一張圖片,或者生成一串語句。Conditional GAN的應用更多一些,比如資料集是一段文字和影象的資料對,通過訓練,GAN可以通過給定一段文字生成對應的影象。

GAN主要可以分為Generator(生成器)和Discriminator(判別器)兩個部分,其中Generator其實就是一個神經網路,輸入一個向量,可以輸出一張影象(即一個高維的向量表示),如下圖示。

【機器學習PAI實戰】—— 玩轉人工智慧之利用GAN自動生成二次元頭像
【機器學習PAI實戰】—— 玩轉人工智慧之利用GAN自動生成二次元頭像

Discriminator也是一個神經網路,輸入為一張影象,輸出為一個數值,輸出的數值用於判斷輸入的影象是否是真的,數值越大,說明影象是真的,數值越小,說明影象為假的,如下圖示。

【機器學習PAI實戰】—— 玩轉人工智慧之利用GAN自動生成二次元頭像
【機器學習PAI實戰】—— 玩轉人工智慧之利用GAN自動生成二次元頭像

Generator負責生成影象,Discriminator負責對Generator生成的影象和真實影象去進行對比,區別出真假,Generator需要不斷優化來欺騙Discriminator,以假亂真;而Discriminator也不斷優化,來提高識別能力,能夠識別出Generator的把戲。二者的這種關係可以形象地通過下圖展示

【機器學習PAI實戰】—— 玩轉人工智慧之利用GAN自動生成二次元頭像



Generator和Discriminator連線起來,形成一個比較大的深層網路,即為GAN網路。

場景描述

深度學習的各種演算法在PAI上可以通過PAI-DSW進行實現,在PAI-DSW上進行訓練資料,利用GAN自動生成二次元頭像。

資料準備

首先需要準備真實的二次元頭像作為資料集,這裡從網上找到一些共享的資源,儲存在了釘釘釘盤中,釘盤地址 ,提取密碼: c2pz,資料集如下圖示,約5萬多張:

【機器學習PAI實戰】—— 玩轉人工智慧之利用GAN自動生成二次元頭像

演算法實踐

利用PAI-DSW進行GAN演算法實踐,首先需要安裝準備好環境。

首先進入到Notebook建模,建立新例項,之後開啟例項,進入Terminal,在Terminal下使用者可以像在自己本地一樣安裝相應的依賴包,進行操作。

準備好環境之後,我們可以通過如下圖示方法,將基於Tensorflow的DCGAN程式碼和資料集上傳上去。

【機器學習PAI實戰】—— 玩轉人工智慧之利用GAN自動生成二次元頭像

用於訓練的DCGAN程式碼地址:github.com/carpedm20/D…,關於DCGAN的網路框架圖如下,詳細介紹可以參考論文:arxiv.org/abs/1511.06…,這裡我們不做詳述。

【機器學習PAI實戰】—— 玩轉人工智慧之利用GAN自動生成二次元頭像

資料集和程式碼上傳成功,如下圖示。

【機器學習PAI實戰】—— 玩轉人工智慧之利用GAN自動生成二次元頭像

其中,data目錄下的faces即為資料集,該資料夾下為對應的5萬多張真實二次元頭像。DCGAN-tensorflow為整個程式碼路徑,其中最主要的兩個程式碼檔案是main.py和model.py,其中最主要的核心程式碼如下。

def main(_):
  pp.pprint(flags.FLAGS.__flags)

  if FLAGS.input_width is None:
    FLAGS.input_width = FLAGS.input_height
  if FLAGS.output_width is None:
    FLAGS.output_width = FLAGS.output_height

  if not os.path.exists(FLAGS.checkpoint_dir):
    os.makedirs(FLAGS.checkpoint_dir)
  if not os.path.exists(FLAGS.sample_dir):
    os.makedirs(FLAGS.sample_dir)

  #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
  run_config = tf.ConfigProto()
  run_config.gpu_options.allow_growth=True

  with tf.Session(config=run_config) as sess:
    if FLAGS.dataset == 'mnist':
      dcgan = DCGAN(
          sess,
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          y_dim=10,
          z_dim=FLAGS.generate_test_images,
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir,
          data_dir=FLAGS.data_dir)
    else:
      dcgan = DCGAN(
          sess,
          input_width=FLAGS.input_width,
          input_height=FLAGS.input_height,
          output_width=FLAGS.output_width,
          output_height=FLAGS.output_height,
          batch_size=FLAGS.batch_size,
          sample_num=FLAGS.batch_size,
          z_dim=FLAGS.generate_test_images,
          dataset_name=FLAGS.dataset,
          input_fname_pattern=FLAGS.input_fname_pattern,
          crop=FLAGS.crop,
          checkpoint_dir=FLAGS.checkpoint_dir,
          sample_dir=FLAGS.sample_dir,
          data_dir=FLAGS.data_dir)

    show_all_variables()

    if FLAGS.train:
      dcgan.train(FLAGS)

else:
          # Update D network
          _, summary_str = self.sess.run([d_optim, self.d_sum],
            feed_dict={ self.inputs: batch_images, self.z: batch_z })
          self.writer.add_summary(summary_str, counter)

          # Update G network
          _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={ self.z: batch_z })
          self.writer.add_summary(summary_str, counter)

          # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
          _, summary_str = self.sess.run([g_optim, self.g_sum],
            feed_dict={ self.z: batch_z })
          self.writer.add_summary(summary_str, counter)
          
          errD_fake = self.d_loss_fake.eval({ self.z: batch_z })
          errD_real = self.d_loss_real.eval({ self.inputs: batch_images })
          errG = self.g_loss.eval({self.z: batch_z})
複製程式碼

一切就緒之後,我們執行命令進行訓練,呼叫命令如下:

​python main.py --input_height 96 --input_width 96 --output_height 48 --output_width 48 --dataset faces --crop --train --epoch 300 --input_fname_pattern "*.jpg"

其中,引數dateset指定資料集的目錄,epoch指定迴圈迭代的次數,input_height、input_width用於指定輸入檔案的大小,輸出檔案的大小同樣也需要引數設定,程式碼執行過程如下圖示:​

【機器學習PAI實戰】—— 玩轉人工智慧之利用GAN自動生成二次元頭像

我們來看下執行結果,分別看一下epoch為1,30,100的時候生成的二次元頭像效果圖。

epoch=1

【機器學習PAI實戰】—— 玩轉人工智慧之利用GAN自動生成二次元頭像

epoch=30

【機器學習PAI實戰】—— 玩轉人工智慧之利用GAN自動生成二次元頭像

epoch=100​

【機器學習PAI實戰】—— 玩轉人工智慧之利用GAN自動生成二次元頭像

我們發現,隨著不斷迭代,生成的二次元頭像也越來越逼真。

總結

通過上面的實踐,我們領略到了GAN的魅力,GAN的變種有很多,除此之外我們還可以利用GAN做非常多的有意思的事情,比如通過文字生成影象,通過簡單文字生成宣傳海報等。PAI-DSW像是一個練武場,為我們準備好了深度學習所需要的環境和條件,讓我們可以盡情享受大資料和深度學習的樂趣,除了GAN,像比較火熱的Bert等模型,我們也都可以試一試。


原文連結

本文為雲棲社群原創內容,未經允許不得轉載。


相關文章