深度學習例項之基於mnist的手寫數字識別

笨拙的石頭發表於2018-05-25

本文主要是介紹基於mnist資料集的手寫數字識別.

一 資料集

    mnist 資料集:包含 7 萬張黑底白字手寫數字圖片,其中 55000 張為訓練集, 5000 張為驗證集, 10000 張為測試集。每張圖片大小為 28*28 畫素,圖片中純黑色畫素值為 0,純白色畫素值為 1。資料集的標籤是長度為 10 的一維陣列,陣列中索引為 i 的元素表示圖片中數字為 i 的概率(以 one_hot=True 載入時,正確數字所在位置為 1,其餘位置為 0)。在將 mnist 資料集作為輸入喂入神經網路時,需先將每張圖片變為長度為 784(28×28)的一維陣列,再將該陣列作為輸入特徵喂入神經網路。

    1. 使用tensorflow提供的資料集mnist,具體的載入方法為:

from tensorflow.examples.tutorials.mnist import input_data
# input_data is a module; the loader function is read_data_sets.
# data_path: directory of the MNIST data, e.g. './mnist'.
# one_hot=True makes every label a length-10 one-hot vector.
mnist = input_data.read_data_sets(data_path, one_hot=True)

    2. 資料集分為train,validation,test三個資料集.

        ① 返回資料集train樣本數   mnist.train.num_examples

        ② 返回資料集validation樣本數 mnist.validation.num_examples

        ③ 返回資料集test樣本數  mnist.test.num_examples

    3. 使用mnist.train.images返回train資料集中的所有圖片的畫素值

    4. 使用mnist.train.labels返回train資料集中的所有圖片的標籤

    5. 使用mnist.train.next_batch(batch_size)每次取出一批(圖片, 標籤)資料,喂入神經網路進行訓練

二 前向計算(得到預測值)

    廢話不說了,直接看程式碼.(mnist_forward.py)

# _*_coding:utf-8_*_

import tensorflow as tf

# Sizes of the layers of the fully connected network.
input_node = 28 * 28  # 784: one input per pixel of a flattened 28x28 image
output_node = 10      # one output score per digit class 0-9
layer1_node = 500     # width of the single hidden layer


def get_weight(shape, regularizer):
    """Create a trainable weight variable, optionally registering an L2 penalty.

    The initial values come from a truncated normal distribution with
    stddev 0.1: every sampled value lies within two standard deviations
    of the mean.

    Args:
        shape: shape of the weight variable, e.g. [in_units, out_units].
        regularizer: L2 regularization scale. When it is not None, the
            penalty tensor for this weight is added to the 'losses'
            collection so the training loss can sum it in.

    Returns:
        The created tf.Variable.
    """
    initial_value = tf.truncated_normal(shape, stddev=0.1)
    weight = tf.Variable(initial_value)
    if regularizer is None:
        # No regularization requested: nothing to register.
        return weight
    l2_penalty = tf.contrib.layers.l2_regularizer(regularizer)(weight)
    tf.add_to_collection('losses', l2_penalty)
    return weight


def get_bias(shape):
    """Return a trainable bias variable of the given shape, initialized to zeros."""
    return tf.Variable(tf.zeros(shape))


def forward(x, regulaizer):
    """Build the two-layer fully connected network and return its logits.

    Architecture: input_node -> layer1_node (ReLU) -> output_node (linear).

    Args:
        x: input tensor; it is matrix-multiplied with an
            [input_node, layer1_node] weight, so presumably shaped
            (batch, input_node). TODO confirm for new callers.
        regulaizer: (sic; name kept for compatibility) L2 scale passed to
            get_weight for both weight matrices, or None for no
            regularization.

    Returns:
        Un-normalized scores (no softmax applied), one column per output
        class.
    """
    # NOTE: the weights and biases are created without explicit names in the
    # order w1, b1, w2, b2. Keep that order so the auto-generated variable
    # names (presumably what checkpoints are matched on) stay the same.

    # Hidden layer: relu(x @ w1 + b1).
    hidden_w = get_weight(shape=[input_node, layer1_node], regularizer=regulaizer)
    hidden_b = get_bias(shape=[layer1_node])
    hidden = tf.nn.relu(tf.add(tf.matmul(x, hidden_w), hidden_b))

    # Output layer: plain affine map; softmax is left to the loss function.
    out_w = get_weight(shape=[layer1_node, output_node], regularizer=regulaizer)
    out_b = get_bias(shape=[output_node])
    logits = tf.add(tf.matmul(hidden, out_w), out_b)
    return logits

三 反向計算(引數更新)

    (mnist_backward.py)
# _*_coding:utf-8_*_

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os

# Training hyper-parameters.
batch_size = 200             # images per training step
learning_rate_base = 0.1     # initial learning rate
learning_rate_decay = 0.99   # multiplicative decay applied per decay period
regularizer = 0.0001         # L2 regularization scale for the weights
total_steps = 50000          # total number of training steps
moving_average_decay = 0.99  # decay of the exponential moving average of the variables

# Checkpoints are written as <model_save_path>/<model_name>-<global step>.
model_save_path = './model/'
model_name = 'mnist_model'


def backward(mnist):
    """Build the training graph and train the network on ``mnist``.

    The loss is the mean sparse softmax cross-entropy between the logits
    from ``mnist_forward.forward`` and the label indices, plus every L2
    penalty that was added to the ``'losses'`` collection. Parameters are
    updated by gradient descent with an exponentially decayed learning rate.
    An exponential moving average of every trainable variable is updated
    together with each step.

    If a checkpoint exists in ``model_save_path``, training resumes from it
    and continues only until the global step reaches ``total_steps``.
    Re-running the script therefore does not redo work that is already
    saved. A checkpoint is written every 100 steps and after the final step.

    Args:
        mnist: dataset object as returned by
            ``input_data.read_data_sets(..., one_hot=True)``; ``mnist.train``
            must provide ``num_examples`` and ``next_batch``.
    """
    # Inputs: flattened images and one-hot labels.
    x = tf.placeholder(tf.float32, [None, mnist_forward.input_node])
    y = tf.placeholder(tf.float32, [None, mnist_forward.output_node])
    y_hat = mnist_forward.forward(x, regularizer)  # logits
    global_step = tf.Variable(0, trainable=False)  # step counter, not trained

    # sparse_softmax_cross_entropy_with_logits wants class indices, so convert
    # the one-hot labels with argmax; then add the collected L2 penalties.
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y, 1), logits=y_hat)
    cem = tf.reduce_mean(ce)
    loss = cem + tf.add_n(tf.get_collection('losses'))

    # decayed_lr = learning_rate_base * learning_rate_decay ^ floor(global_step / decay_steps).
    # decay_steps = num_examples / batch_size, i.e. the rate drops once per pass
    # over the training set (staircase=True floors the exponent).
    learning_rate = tf.train.exponential_decay(learning_rate_base, global_step,
                                               mnist.train.num_examples / batch_size,
                                               learning_rate_decay, staircase=True)
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step)

    # Moving averages (shadow copies) of all trainable variables.
    ema = tf.train.ExponentialMovingAverage(moving_average_decay, global_step)
    ema_op = ema.apply(tf.trainable_variables())

    # A single op that runs both the gradient step and the moving-average update.
    with tf.control_dependencies([train_step, ema_op]):
        train_op = tf.no_op(name='train')

    saver = tf.train.Saver()
    checkpoint_prefix = os.path.join(model_save_path, model_name)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # Resume from the newest checkpoint if there is one.
        ckpt = tf.train.get_checkpoint_state(model_save_path)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)

        # Start counting from the (possibly restored) global step, so a resumed
        # run stops at total_steps instead of training total_steps more steps.
        start_step = int(sess.run(global_step))
        last_step = total_steps - 1
        for i in range(start_step, total_steps):
            xs, ys = mnist.train.next_batch(batch_size)
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y: ys})
            # Also checkpoint the very last step; otherwise the final
            # (total_steps - 1) % 100 steps of training would never be saved.
            if i % 100 == 0 or i == last_step:
                print('After {} train step(s), loss on training batch is {}'.format(step, loss_value))
                saver.save(sess, checkpoint_prefix, global_step=global_step)


def main():
    """Read the MNIST data sets from './mnist' with one-hot labels and train on them."""
    dataset = input_data.read_data_sets('./mnist', one_hot=True)
    backward(dataset)


if __name__ == '__main__':
    main()

四 測試程式碼

        (mnist_test.py)
# _*_coding:utf-8_*_

import tensorflow as tf
import time
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import mnist_backward

# Seconds mnist_test waits between two evaluation rounds.
sleep_time = 5


def mnist_test(mnist):
    """Repeatedly evaluate the newest training checkpoint on the MNIST test set.

    Builds the inference graph (without regularization) in a fresh tf.Graph,
    then loops forever:
      1. restore the latest checkpoint in ``mnist_backward.model_save_path``,
         loading the exponential-moving-average (shadow) values into the
         model variables via ``variables_to_restore()``;
      2. print the accuracy on ``mnist.test``;
      3. sleep ``sleep_time`` seconds.
    Returns as soon as no checkpoint is found.

    Args:
        mnist: dataset object with one-hot labels, as returned by
            ``input_data.read_data_sets(..., one_hot=True)``.
    """
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, shape=[None, mnist_forward.input_node], name='x')
        y = tf.placeholder(tf.float32, shape=[None, mnist_forward.output_node], name='y')
        logits = mnist_forward.forward(x, None)

        # Saver that maps the moving-average shadow values onto the variables.
        averager = tf.train.ExponentialMovingAverage(mnist_backward.moving_average_decay)
        saver = tf.train.Saver(averager.variables_to_restore())

        # Fraction of images whose predicted class equals the labelled class.
        is_correct = tf.equal(tf.argmax(y, 1), tf.argmax(logits, 1))
        accuracy = tf.reduce_mean(tf.cast(is_correct, tf.float32))

        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(mnist_backward.model_save_path)
                if not (ckpt and ckpt.model_checkpoint_path):
                    print('No checkpoint file found')
                    return
                saver.restore(sess, ckpt.model_checkpoint_path)
                # Checkpoint files are named '<model_name>-<step>'; take the step part.
                step_label = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                score = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
                print('After {} training steps, test accuracy is {}'.format(step_label, score))
            time.sleep(sleep_time)


def main():
    """Read the MNIST data sets from './mnist' with one-hot labels and evaluate checkpoints on them."""
    dataset = input_data.read_data_sets('./mnist', one_hot=True)
    mnist_test(dataset)


if __name__ == '__main__':
    main()

五 小結

    筆者覺得文章程式碼中的一個亮點(較為實用,不需要每次執行程式碼都從頭開始,節約時間!!!):

# If a checkpoint exists, restore the saved variables so training continues from them instead of from freshly initialized weights
        if ckpt and ckpt.model_checkpoint_path:
            # Restore the saved session state and keep training the model
            saver.restore(sess, ckpt.model_checkpoint_path)
    說明: 本文主要參考了'Tensorflow筆記'這門課程的內容,該課程中還講解了資料集的製作等內容,這裡沒有寫出。如有問題或建議,歡迎給作者留言,謝謝!!

相關文章