人人都可以學的人工智慧:TensorFlow 入門例子

Bob-Chen發表於2017-02-28

這是用 TensorFlow 來識別手寫數字的官方經典入門例子,資料都是已經處理過準備好了的,但是隻到計算準確度概率那就停了,缺少拿實際圖片運用的例子,初學者看完之後難免發矇。於是,本文第二段用一些實際圖片來驗證我們的模型。文中例子基於 TensorFlow 1.0.0,看過官方文件的直接跳到後面吧。

第一部分,介紹了一些處理資料的基本知識,然後採用一個簡單的模型,用一堆準確的資料去訓練它,訓練完之後拿另一堆資料去評估一下這個模型的準確率(這也是官方例子的內容)。搞清楚這個很重要,不然看完官方例子只會覺得很厲害,但是又不知道哪厲害。

第二部分,我們會拿幾個圖片,告訴模型我們認為這個圖片是幾(當然,這個是隨便說的),然後模型告訴我們它覺得的和我們認為的是否一致。

有很多名詞和數學演算法不懂沒關係,慢慢查,先跑個例子感受一下。

文中部分圖片來自官方文件。

識別手寫圖片

因為這個例子是 TensorFlow 官方的例子,不會說的太詳細,會加入了一點個人的理解,英文文件 是最新的,中文文件 是用 0.5 版本的 TensorFlow,在 1.0 版本跑不起來,建議中文文件和英文文件交叉著看,有助於理解。

準備資料

這裡用來識別的手寫圖片大致是這樣的,為了降低複雜度,每個圖片是 28*28 大小。

人人都可以學的人工智慧:TensorFlow 入門例子

但是直接丟圖片給我們的模型,模型是不認識的,所以必須要對圖片進行一些處理。

如果瞭解線性代數,大概知道圖片的每個畫素點其實可以表示為一個二維的矩陣,對圖片做各種變換,比如翻轉啊什麼的就是對這個矩陣進行運算,於是我們的手寫圖片大概可以看成是這樣的:

人人都可以學的人工智慧:TensorFlow 入門例子

這個矩陣展開成一個向量,長度是 28*28=784。我們還需要另一個東西用來告訴模型我們認為這個圖片是幾,也就是給圖片打個 label。這個 label 也不是隨便打的,這裡用一個類似有 10 個元素的陣列,其中只有一個是 1,其它都是 0,哪位為 1 表示對應的圖片是幾,例如表示數字 8 的標籤值就是 ([0,0,0,0,0,0,0,0,1,0])。

這些就是單張圖片的資料處理,實際上為了高效的訓練模型,會把圖片資料和 label 資料分別打包到一起,也就是 MNIST 資料集了。

MNIST資料集

MNIST 資料集是一個入門級的計算機視覺資料集,官網是Yann LeCun's website。 我們不需要手動去下載這個資料集, 1.0 的 TensorFlow 會自動下載。

這個訓練資料集有 55000 個圖片資料,用張量的方式組織的,形狀如 [55000,784],如下圖:

人人都可以學的人工智慧:TensorFlow 入門例子

還記得為啥是 784 嗎,因為 28*28 的圖片。
label 也是如此,[55000,10]:

人人都可以學的人工智慧:TensorFlow 入門例子

這個資料集裡面除了有訓練用的資料之外,還有 10000 個測試模型準確度的資料。

整個資料集大概是這樣的:

人人都可以學的人工智慧:TensorFlow 入門例子

現在資料有了,來看下我們的模型。

Softmax 迴歸模型

Softmax 中文名叫歸一化指數函式,這個模型可以用來給不同的物件分配概率。比如判斷

人人都可以學的人工智慧:TensorFlow 入門例子

的時候可能認為有 80% 是 9,有 5% 認為可能是 8,因為上面都有個圈。

我們對圖片畫素值進行加權求和。比如某個畫素具有很強的證據說明這個圖不是 1,則這個畫素相應的權值為負數,相反,如果這個畫素特別有利,則權值為正數。

如下圖,紅色區域代表負數權值,藍色代表正數權值。

人人都可以學的人工智慧:TensorFlow 入門例子

同時,還有一個偏置量(bias) 用來減小一些無關的干擾量。

Softmax 迴歸模型的原理大概就是這樣,更多的推導過程,可以查閱一下官方文件,有比較詳細的內容。

說了那麼久,終於可以上程式碼了。

訓練模型

具體引入的一些包這裡就不一一列出來,主要是兩個,一個是 tensorflow 本身,另一個是官方例子裡面用來輸入資料用的方法。

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf複製程式碼

之後就可以建立我們的模型。

  # Create the model
  x = tf.placeholder(tf.float32, [None, 784])
  W = tf.Variable(tf.zeros([784, 10]))
  b = tf.Variable(tf.zeros([10]))
  y = tf.matmul(x, W) + b複製程式碼

這裡的程式碼都是類似佔位符,要填了資料才有用。

  • x 是從圖片資料檔案裡面讀來的,理解為一個常量,一個輸入值,因為是 28*28 的圖片,所以這裡是 784;
  • W 代表權重,因為有 784 個點,然後有 10 個數字的權重,所以是 [784, 10],模型運算過程中會不斷調整這個值,可以理解為一個變數;
  • b 代表偏置量,每個數字的偏置量都不同,所以這裡是 10,模型運算過程中也會不斷調整這個值,也是一個變數;
  • y 基於前面的資料矩陣乘積計算。

tf.zeros 表示初始化為 0。

我們會需要一個東西來接受正確的輸入,也就是放訓練時準確的 label。

  # Define loss and optimizer
  y_ = tf.placeholder(tf.float32, [None, 10])複製程式碼

我們會用一個叫交叉熵的東西來衡量我們的預測的「驚訝」程度。

關於交叉熵,舉個例子,我們平常寫程式碼的時候,一按編譯,一切順利,程式跑起來了,我們就沒那麼「驚訝」,因為我們的程式碼是那麼的優秀;而如果一按編譯,整個就 Crash 了,我們很「驚訝」,一臉蒙逼的想,這怎麼可能。

交叉熵感性的認識就是表達這個的,當輸出的值和我們的期望是一致的時候,我們就「驚訝」值就比較低,當輸出值不是我們期望的時候,「驚訝」值就比較高。

  cross_entropy = tf.reduce_mean(
      tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))複製程式碼

這裡就用了 TensorFlow 實現的 softmax 模型來計算交叉熵。
交叉熵,就是我們想要儘量優化的值,讓結果符合預期,不要讓我們太「驚訝」。

TensorFlow 會自動使用反向傳播演算法(backpropagation algorithm) 來有效的確定變數是如何影響你想最小化的交叉熵。然後 TensorFlow 會用你選擇的優化演算法來不斷地修改變數以降低交叉熵。

  train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)複製程式碼

這裡用了梯度下降演算法(gradient descent algorithm)來優化交叉熵,這裡是以 0.5 的速度來一點點的優化交叉熵。

之後就是初始化變數,以及啟動 Session

  sess = tf.InteractiveSession()
  tf.global_variables_initializer().run()複製程式碼

啟動之後,開始訓練!

  # Train
  for _ in range(1000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})複製程式碼

這裡訓練 1000 次,每次隨機找 100 個資料來練習,這裡的 feed_dict={x: batch_xs, y_: batch_ys},就是我們前面那設定的兩個留著佔位的輸入值。

到這裡基本訓練就完成了。

評估模型

訓練完之後,我們來評估一下模型的準確度。

  # Test trained model
  correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                      y_: mnist.test.labels}))複製程式碼

tf.argmax 給出某個tensor物件在某一維上的其資料最大值所在的索引值。因為我們的 label 只有一個 1,所以 tf.argmax(_y, 1) 就是 label 的索引,也就是表示圖片是幾。把計算值和預測值 equal 一下就可以得出模型算的是否準確。
下面的 accuracy 計算的是一個整體的精確度。

這裡填入的資料不是訓練資料,是測試資料和測試 label。

最終結果,我的是 0.9151,91.51% 的準確度。官方說這個不太好,如果用一些更好的模型,比如多層卷積網路等,這個識別率可以到 99% 以上。

官方的例子到這裡就結束了,雖然說識別了幾萬張圖片,但是我一張像樣的圖片都沒看到,於是我決定想辦法拿這個模型真正找幾個圖片測試一下。

用模型測試

看下上面的例子,重點就是放測試資料進去這裡,如果我們要拿圖片測,需要先把圖片變成相應格式的資料。

sess.run(accuracy, feed_dict={x: mnist.test.images,
                                      y_: mnist.test.labels})複製程式碼

看下這裡 mnist 是從 tensorflow.examples.tutorials.mnist 中的 input_data 的 read_data_sets 方法中來的。

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)複製程式碼

Python 就是好,有啥不懂看下原始碼。原始碼的線上地址在這裡

開啟找 read_data_sets 方法,發現:

from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets複製程式碼

這個檔案裡面。

def read_data_sets(train_dir,
                   fake_data=False,
                   one_hot=False,
                   dtype=dtypes.float32,
                   reshape=True,
                   validation_size=5000):
                   ...
                   ...
                   ...
    train = DataSet(train_images, train_labels, dtype=dtype, reshape=reshape)
    validation = DataSet(validation_images,
                        validation_labels,
                        dtype=dtype,
                        reshape=reshape)
    test = DataSet(test_images, test_labels, dtype=dtype, reshape=reshape)

    return base.Datasets(train=train, validation=validation, test=test)複製程式碼

可以看到,最後返回的都是是一個物件,而我們用的 feeddict={x: mnist.test.images, y: mnist.test.labels} 就是從這來的,是一個 DataSet 物件。這個物件也在這個檔案裡面。

class DataSet(object):

  def __init__(self,
               images,
               labels,
               fake_data=False,
               one_hot=False,
               dtype=dtypes.float32,
               reshape=True):
    """Construct a DataSet.
    one_hot arg is used only if fake_data is true.  `dtype` can be either
    `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
    `[0, 1]`.
    """
    ...
    ...
    ...複製程式碼

這個物件很長,我就只挑重點了,主要看構造方法。一定要傳入的有 images 和 labels。其實這裡已經比較明朗了,我們只要把單張圖片弄成 mnist 格式,分別傳入到這個 DataSet 裡面,就可以得到我們要的資料。

網上查了下還真有,程式碼地址,對應的文章:www.jianshu.com/p/419557758…,文章講的有點不清楚,需要針對 TensorFlow 1.0 版本以及實際目錄情況做點修改。

直接上我修改後的程式碼:

from PIL import Image
from numpy import *

def GetImage(filelist):
    width=28
    height=28
    value=zeros([1,width,height,1])
    value[0,0,0,0]=-1
    label=zeros([1,10])
    label[0,0]=-1

    for filename in filelist:
        img=array(Image.open(filename).convert("L"))
        width,height=shape(img);
        index=0
        tmp_value=zeros([1,width,height,1])
        for i in range(width):
            for j in range(height):
                tmp_value[0,i,j,0]=img[i,j]
                index+=1

        if(value[0,0,0,0]==-1):
            value=tmp_value
        else:
            value=concatenate((value,tmp_value))

        tmp_label=zeros([1,10])
        index=int(filename.strip().split('/')[2][0])
        print "input:",index
        tmp_label[0,index]=1
        if(label[0,0]==-1):
            label=tmp_label
        else:
            label=concatenate((label,tmp_label))

    return array(value),array(label)複製程式碼

這裡讀取圖片依賴 PIL 這個庫,由於 PIL 比較少維護了,可以用它的一個分支 Pillow 來代替。另外依賴 numpy 這個科學計算庫,沒裝的要裝一下。

這裡就是把圖片讀取,並按 mnist 格式化,label 是取圖片檔名的第一個字,所以圖片要用數字開頭命名。

如果懶得 PS 畫或者手寫的畫,可以把測試資料集的資料給轉回圖片,實測成功,參考這篇文章:如何用python解析mnist圖片

新建一個資料夾叫 test_num,裡面圖片如下,這裡命名就是 label 值,可以看到 label 和圖片是對應的:

人人都可以學的人工智慧:TensorFlow 入門例子

開始測試:

  print("Start Test Images")

  dir_name = "./test_num"
  files = glob2.glob(dir_name + "/*.png")
  cnt = len(files)
  for i in range(cnt):
    print(files[i])
    test_img, test_label = GetImage([files[i]])

    testDataSet = DataSet(test_img, test_label, dtype=tf.float32)

    res = accuracy.eval({x: testDataSet.images, y_: testDataSet.labels})

    print("output: ",  res)
    print("----------")複製程式碼

這裡用了 glob2 這個庫來遍歷以及過濾檔案,需要安裝,常規的遍歷會把 Mac 上的 .DS_Store 檔案也會遍歷進去。

人人都可以學的人工智慧:TensorFlow 入門例子

可以看到我們打的 label 和模型算出來的是相符的。

然後我們可以打亂檔名,把 9 說成 8,把 0 也說成 8:

人人都可以學的人工智慧:TensorFlow 入門例子

可以看到,我們的 label 和模型算出來的是不相符的。

人人都可以學的人工智慧:TensorFlow 入門例子

恭喜,到著你就完成了一次簡單的人工智慧之旅。

總結

從這個例子中我們可以大致知道 TensorFlow 的執行模式:

人人都可以學的人工智慧:TensorFlow 入門例子

例子中是每次都要走一遍訓練流程,實際上是可以用 tf.train.Saver() 來儲存訓練好的模型的。這個入門例子完成之後能對 TensorFlow 有個感性認識。

TensorFlow 沒有那麼神祕,沒有我們想的那麼複雜,也沒有我們想的那麼簡單,並且還有很多數學知識要補充呢。

另外,這方面我也是個小白,文中若有錯誤之處,歡迎斧正。

demo 程式碼地址:
github.com/bob-chen/te…

碎碎念

記錄一些所思所想,寫寫科技與人文,寫寫生活狀態,寫寫讀書心得,主要是扯淡和感悟。
歡迎關注,交流。

微信公眾號:程式設計師的詩和遠方

公眾號ID : MonkeyCoder-Life

人人都可以學的人工智慧:TensorFlow 入門例子

參考

github.com/wlmnzf/tens…

blog.csdn.net/u014046170/…

www.jianshu.com/p/419557758…

zhuanlan.zhihu.com/p/22410917?…

stackoverflow.com/questions/3…

相關文章