實戰生成對抗網路[2]:生成手寫數字

雲水木石發表於2019-02-27

在開始本文之前,讓我們先看看一則報導:

人民網訊 據英國廣播電視公司10月25日報導,由人工智慧創作的藝術作品以432000美元(約合300萬人民幣)的高價成功拍賣。

看起來一則不起眼的新聞,其實意義深遠,它意味著人們開始認可計算機創作的藝術價值,那些沾沾自喜認為不會被人工智慧取代的藝術家也要瑟瑟發抖了。

這幅由人工智慧創作的作品長啥樣,有啥過人之處?

實戰生成對抗網路[2]:生成手寫數字

嗯,以我這種外行人士看來,實在不怎麼樣,但這不意味著人工智慧不行。要知道,AlphaGo初出道時,也只敢挑戰一下樊麾這樣的二流棋手,接下來挑戰頂級棋手李世石,人類還能勉力一戰,等進化到AlphaGo Master,零封人類棋手。然而這還沒有完,AlphaGo Zero不再學習人類棋譜,完全通過自學,碾壓AlphaGo Master,對付人類棋手,更如我們捏死一隻螞蟻那麼容易。

所以說,儘管人工智慧創作的第一副作品如同鬼畫桃符,但其潛力無可限量。

那麼,接下來我們會探討如何創作出一幅名畫?No. No.

創作一副畫並不是那麼容易。這幅名為《埃德蒙·貝拉米肖像》的畫作是由巴黎一個名為“顯而易見”(Obvious)的藝術團體創作利用人工智慧技術創作而成,這幅作品是用演算法和15000幅從14世紀到20世紀的肖像畫資料製作而成。

我們還沒有那個條件去創作一副人工智慧的畫作,但我們可以先從基本的著手,生成手寫數字。手寫數字對於機器學習的同學來說,太熟悉不過了。既然是老朋友了,那讓我們開始吧!

首先回顧一下《實戰生成對抗網路[1]:簡介》這篇文章的內容,GAN由生成器和判別器組成。簡單起見,我們選擇簡單的二層神經網路來實現生成器和判別器。

生成器

實現生成器並不難,我們採取的全連線網路拓撲結構為:100 --> 128 --> 784,最後的輸出為784是因為MNIST資料集就是由28 x 28畫素的灰度影像組成。程式碼如下:

G_W1 = tf.Variable(initializer([100, 128]), name='G_W1')
G_b1 = tf.Variable(tf.zeros(shape=[128]), name='G_b1')
G_W2 = tf.Variable(initializer([128, 784]), name='G_W2')
G_b2 = tf.Variable(tf.zeros(shape=[784]), name='G_b2')
theta_G = [G_W1, G_W2, G_b1, G_b2]

def generator(z):
  G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
  G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
  G_prob = tf.nn.sigmoid(G_log_prob)

  return G_prob
複製程式碼

判別器

判別器正好相反,以MNIST影像作為輸入並返回一個代表真實影像的概率的標量,程式碼如下:

D_W1 = tf.Variable(initializer(shape=[784, 128]), name='D_W1')
D_b1 = tf.Variable(tf.zeros(shape=[128]), name='D_b1')
D_W2 = tf.Variable(initializer(shape=[128, 1]), name='D_W2')
D_b2 = tf.Variable(tf.zeros(shape=[1]), name="D_W2")
theta_D = [D_W1, D_W2, D_b1, D_b2]

def discriminator(x):
  D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
  D_logit = tf.matmul(D_h1, D_W2) + D_b2
  D_prob = tf.nn.sigmoid(D_logit)

  return D_prob, D_logit
複製程式碼

訓練演算法

在論文arXiv: 1406.2661, 2014中給出了訓練演算法的虛擬碼:

實戰生成對抗網路[2]:生成手寫數字

TensorFlow中的優化器只能做最小化,因為為了最大化損失函式,我們在虛擬碼給出的損失函式前加上一個負號。

D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
G_loss = -tf.reduce_mean(tf.log(D_fake))
複製程式碼

接下來定義優化器:

# 僅更新D(X)的引數, var_list=theta_D
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
# 僅更新G(X)的引數, var_list=theta_G
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)
複製程式碼

最後進行迭代,更新引數:

for it in range(60000):
  X_mb, _ = mnist.train.next_batch(mb_size)

  _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})
  _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})
複製程式碼

整個流程下來,其實和之前的深度學習演算法差不多,非常容易理解。演算法是不是有效果呢?我們可以將迭代過程中生成的手寫數字顯示出來:

實戰生成對抗網路[2]:生成手寫數字

嗯,結果雖然有點差強人意,但差不多是手寫數字的字形,而且隨著迭代,越來越接近手寫數字,可以說GAN演算法還是有效的。

小結

一個簡單的GAN網路就這麼幾行程式碼就能搞定,看樣子生成一副畫也沒有什麼難的。先不要這麼樂觀,其實,GAN網路中的坑還是不少,比如在迭代過程中,就出現過如下提示:

Iter: 9000
D loss: nan
G_loss: nan
複製程式碼

從程式碼中我們可以看出,GAN網路依然採用的梯度下降法來迭代求解引數。梯度下降的啟動會選擇一個減小所定義問題損失的方向,但是我們並沒有一個辦法來確保利用GAN網路可以進入納什均衡的狀態,這是一個高維度的非凸優化目標。網路試圖在接下來的步驟中最小化非凸優化目標,最終有可能導致進入振盪而不是收斂到底層正式目標。

另外還有模型坍塌、計數、角度以及全域性結構方面的問題,要解決這些問題,需要使用一些特殊的技巧和方法,後面我們深入各種GAN模型時將會探討。

本文完整的程式碼請參考: github.com/mogoweb/aie…

參考

  1. 首幅人工智慧畫作拍賣43.2萬美元 遠超預估價
  2. 實戰生成對抗網路[1]:簡介

image

相關文章