深度學習例項之基於mnist的手寫數字識別
本文主要是介紹基於mnist資料集的手寫數字識別.
一 資料集
mnist 資料集:包含 7 萬張黑底白字手寫數字圖片, 其中 55000 張為訓練集,5000 張為驗證集, 10000 張為測試集。每張圖片大小為 28*28 畫素,圖片中純黑色畫素值為 0, 純白色畫素值為 1。資料集的標籤是長度為 10 的一維陣列,陣列中每個元素索引號表示對應數字出現的概率。在將 mnist 資料集作為輸入喂入神經網路時,需先將資料集中每張圖片變為長度784 一維陣列,將該陣列作為神經網路輸入特徵喂入神經網路。
1. 使用tensorflow提供的資料集mnist,具體的載入方法為:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data(data_path,one_hot=True)
2. 資料集分為train,validation,test三個資料集.
① 返回資料集train樣本數 mnist.train.num_examples
② 返回資料集validation樣本數 mnist.validation.num_examples
③ 返回資料集test樣本數 mnist.test.num_examples
3. 使用mnist.train.images返回train資料集中的所有圖片的畫素值
4. 使用mnist.train.labels返回train資料集中的所有圖片的標籤
5. 使用mnist.train.next_batch()將資料輸入神經網路
二 前向計算(得到預測值)
廢話不說了,直接看程式碼.(mnist_forward.py)
# _*_coding:utf-8_*_
import tensorflow as tf
input_node = 784
output_node = 10
layer1_node = 500
def get_weight(shape, regularizer):
# 表示要求產生的資料服從正態分佈, 並且每個值與均值之間的差值均小於兩倍的標準差
w = tf.Variable(tf.truncated_normal(shape, stddev=0.1))
# 判斷是否進行正則化操作
if regularizer is not None:
tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
return w
def get_bias(shape):
b = tf.Variable(tf.zeros(shape))
return b
def forward(x, regulaizer):
# 輸入層到Layer1層
w1 = get_weight(shape=[input_node, layer1_node], regularizer=regulaizer)
b1 = get_bias(shape=[layer1_node])
y1 = tf.nn.relu(tf.add(tf.matmul(x, w1), b1))
# 從Layer1層到輸出層
w2 = get_weight(shape=[layer1_node, output_node], regularizer=regulaizer)
b2 = get_bias(shape=[output_node])
y = tf.add(tf.matmul(y1, w2), b2)
return y
三 反向計算(引數更新)
(mnist_backward.py)# _*_coding:utf-8_*_
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os
batch_size = 200 # 每輪訓練的圖片數量
learning_rate_base = 0.1 # 初始學習率
learning_rate_decay = 0.99 # 學習率衰減率
regularizer = 0.0001 # 正則化係數
total_steps = 50000 # 訓練輪數
moving_average_decay = 0.99
model_save_path = './model/'
model_name = 'mnist_model'
def backward(mnist):
# 設定x,y
x = tf.placeholder(tf.float32, [None, mnist_forward.input_node])
y = tf.placeholder(tf.float32, [None, mnist_forward.output_node])
y_hat = mnist_forward.forward(x, regularizer) # 獲取forward的返回值
global_step = tf.Variable(0, trainable=False) # 當前輪數值初始化
# 損失函式(softmax和交叉熵共同組成的loss,在加上正則化損失的總和)
ce = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y, 1), logits=y_hat)
cem = tf.reduce_mean(ce) # 求所有元素的均值
loss = cem + tf.add_n(tf.get_collection('losses')) # 得到包含所有引數損失的損失函式
# 學習率梯度衰減的模型
'''
decayed_learning_rate = learning_rate *decay_rate ^ (global_step / decay_steps)
'''
learning_rate = tf.train.exponential_decay(learning_rate_base, global_step, mnist.train.num_examples / batch_size,
learning_rate_decay, staircase=True)
# 定義引數優化方法
train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step)
# 定義引數的滑動平均
ema = tf.train.ExponentialMovingAverage(moving_average_decay, global_step)
ema_op = ema.apply(tf.trainable_variables())
# 定義引數的控制依賴
with tf.control_dependencies([train_step, ema_op]):
train_op = tf.no_op(name='train')
# 建立模型儲存的例項化物件
saver = tf.train.Saver()
with tf.Session() as sess:
# 引數初始化
sess.run(tf.global_variables_initializer())
ckpt = tf.train.get_checkpoint_state(model_save_path)
# 判斷是否有ckpt模型,有則恢復模型(這種方法簡單便捷,不進行重複的訓練)
if ckpt and ckpt.model_checkpoint_path:
# 恢復會話,繼續訓練模型
saver.restore(sess, ckpt.model_checkpoint_path)
# 模型訓練,迭代
for i in range(total_steps):
xs, ys = mnist.train.next_batch(batch_size)
_, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y: ys})
if i % 100 == 0:
print('After {} train step(s), loss on training batch is {}'.format(step, loss_value))
# 將當前會話載入到指定路徑
saver.save(sess, os.path.join(model_save_path, model_name), global_step=global_step)
def main():
mnist = input_data.read_data_sets('./mnist', one_hot=True)
backward(mnist)
if __name__ == '__main__':
main()
四 測試程式碼
(mnist_test.py)# _*_coding:utf-8_*_
import tensorflow as tf
import time
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import mnist_backward
sleep_time = 5
def mnist_test(mnist):
with tf.Graph().as_default() as g:
x = tf.placeholder(tf.float32, shape=[None, mnist_forward.input_node], name='x')
y = tf.placeholder(tf.float32, shape=[None, mnist_forward.output_node], name='y')
y_hat = mnist_forward.forward(x, None)
ema = tf.train.ExponentialMovingAverage(mnist_backward.moving_average_decay)
ema_restore = ema.variables_to_restore()
# 建立模型儲存的例項化物件
saver = tf.train.Saver(ema_restore)
# 模型的準確率計算
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_hat, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# 迴圈計算
while True:
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(mnist_backward.model_save_path)
# 判斷是否有ckpt模型,有則恢復模型
if ckpt and ckpt.model_checkpoint_path:
# 恢復會話
saver.restore(sess, ckpt.model_checkpoint_path)
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
accuracy_score = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels})
print('After {} training steps, test accuracy is {}'.format(global_step, accuracy_score))
else:
print('No checkpoint file found')
return
time.sleep(sleep_time)
def main():
mnist = input_data.read_data_sets('./mnist', one_hot=True)
mnist_test(mnist)
if __name__ == '__main__':
main()
五 小結
筆者覺得文章程式碼中的一個亮點(較為實用,不需要每次執行程式碼都從頭開始,節約時間!!!):
# 判斷是否有ckpt模型,有則恢復模型(這種方法簡單便捷,不進行重複的訓練)
if ckpt and ckpt.model_checkpoint_path:
# 恢復會話,繼續訓練模型
saver.restore(sess, ckpt.model_checkpoint_path)
說明: 本文主要參考了'Tenorflow筆記'這門課程的內容,在這課程中還有講解資料集的製作等沒有在這裡寫出, 如有問題或者建議可以歡迎給作者留言,謝謝!! 相關文章
- 【Get】用深度學習識別手寫數字深度學習
- mnist手寫數字識別——深度學習入門專案(tensorflow+keras+Sequential模型)深度學習Keras模型
- 深度學習實驗:Softmax實現手寫數字識別深度學習
- Tensorflow2.0-mnist手寫數字識別示例
- 深度學習基礎 - 基於Theano-MLP的字元識別實驗(MNIST)深度學習字元
- 學習Pytorch+Python之MNIST手寫字型識別PyTorchPython
- opencv 學習之 基於K近鄰的數字識別OpenCV
- Spark學習筆記——手寫數字識別Spark筆記
- 用tensorflow2實現mnist手寫數字識別
- Pytorch搭建MyNet實現MNIST手寫數字識別PyTorch
- 在PaddlePaddle上實現MNIST手寫體數字識別
- 【機器學習】手寫數字識別機器學習
- 《手寫數字識別》神經網路 學習筆記神經網路筆記
- keras框架下的深度學習(一)手寫體識別Keras框架深度學習
- TensorFlow 實戰Google深度學習框架(第2版)第6章之LeNet-5模型實現MNIST數字識別Go深度學習框架模型
- 動手學深度學習需要這些數學基礎知識深度學習
- 瀏覽器中的手寫數字識別瀏覽器
- Action Recognition——基於深度學習的動作識別綜述深度學習
- 基於深度學習的手勢識別系統(Python程式碼,UI介面版)深度學習PythonUI
- 機器學習演算法(九): 基於線性判別模型的LDA手寫數字分類識別機器學習演算法模型LDA
- 基於 keras-js 快速實現瀏覽器內的 CNN 手寫數字識別KerasJS瀏覽器CNN
- tensorflow.js 手寫數字識別JS
- 深度學習(一)之MNIST資料集分類深度學習
- matlab練習程式(神經網路識別mnist手寫資料集)Matlab神經網路
- 使用神經網路識別手寫數字神經網路
- 程式碼實現(機器學習識別手寫數字)機器學習
- TensorFlow.NET機器學習入門【5】採用神經網路實現手寫數字識別(MNIST)機器學習神經網路
- 基於深度學習的機器人目標識別和跟蹤深度學習機器人
- 深度學習--基於卷積神經網路的歌唱嗓音識別深度學習卷積神經網路
- opencv python 基於KNN的手寫體識別OpenCVPythonKNN
- opencv python 基於SVM的手寫體識別OpenCVPython
- 深度學習——性別識別深度學習
- 基於PyTorch框架的多層全連線神經網路實現MNIST手寫數字分類PyTorch框架神經網路
- 使用深度學習進行基於AI的面部識別的不同方法深度學習AI
- m基於深度學習網路的手勢識別系統matlab模擬,包含GUI介面深度學習MatlabGUI
- Tensorflow實現RNN(LSTM)手寫數字識別RNN
- 機器學習:scikit-learn實現手寫數字識別機器學習
- OpenCV + sklearnSVM 實現手寫數字分割和識別OpenCV