【TensorFlow2.0】LeNet進行手寫體數字識別
LeNet簡介
LeNet-5出自論文Gradient-Based Learning Applied to Document Recognition,是一種用於手寫體字元識別的非常高效的卷積神經網路。
卷積神經網路
卷積神經網路能夠很好的利用影像的結構資訊。LeNet-5是一個較簡單的卷積神經網路。下圖顯示了其結構:輸入的二維影像,先經過兩次卷積層到池化層,再經過全連線層,最後使用softmax分類作為輸出層。下面我們主要介紹卷積層和池化層。
LeNet
1、INPUT層-輸入層
2、C1層-卷積層
3、S2層-池化層(下采樣層)
4、C3層-卷積層
5、S4層-池化層(下采樣層)
6、C5層-卷積層
7、F6層-全連線層
8、Output層-全連線層
#handwrite_Lenet_Tensorflow_train.py
#coding=utf-8
"""
參考:https://blog.csdn.net/suyunzzz/article/details/104195872
參考:https://cuijiahua.com/blog/2018/01/dl_3.html
"""
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import datetime
print ("start")
def train_model():
mnist=tf.keras.datasets.mnist
#獲取資料,訓練集,測試集 60k訓練,10K測試
(x_train,y_train),(x_test,y_test)=mnist.load_data()
#print (x_train.shape,y_train.shape)
#print (x_test.shape,y_test.shape)
#首先是資料 INPUT 層,輸入影像的尺寸統一歸一化為32*32。
#LeNet Input 為32*32
x_train= np.pad(x_train,((0,0),(2,2),(2,2)),'constant',constant_values=0) #28*28-》32*32
x_test= np.pad(x_test,((0,0),(2,2),(2,2)),'constant',constant_values=0) #28*28-》32*32
#print(x_train.shape,x_test.shape)
#資料集格式轉換
x_train=x_train.astype('float32')
x_train=x_train.astype('float32')
#歸一化,就是為了限定你的輸入向量的最大值跟最小值不超過你的隱層跟輸出層函式的限定範圍。
x_train=x_train/255#歸一化
x_test=x_test/255#歸一化
x_train=x_train.reshape(x_train.shape[0],32,32,1)
x_test=x_test.reshape(x_test.shape[0],32,32,1)
print(x_train.shape,x_test.shape)
#模型例項化,根據LeNet 的七層結構
model=tf.keras.models.Sequential([
tf.keras.layers.Conv2D(filters=6,kernel_size=(5,5),padding='valid',activation=tf.nn.relu,input_shape=(32,32,1)),
tf.keras.layers.AveragePooling2D(pool_size=(2,2),strides=(2,2),padding='same'),
tf.keras.layers.Conv2D(filters=16,kernel_size=(5,5),padding='valid',activation=tf.nn.relu,input_shape=(32,32,1)),
tf.keras.layers.AveragePooling2D(pool_size=(2,2),strides=(2,2),padding='same'),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units=120,activation=tf.nn.relu),
tf.keras.layers.Dense(units=84,activation=tf.nn.relu),
tf.keras.layers.Dense(units=10,activation=tf.nn.softmax),
])
model.summary()
#模型訓練
num_epochs=1#訓練次數
batch_size=64#每個批次喂多少張圖片
lr=0.001#學習率
#優化器
adam_optimizer=tf.keras.optimizers.Adam(lr)
model.compile(
optimizer=adam_optimizer,
loss=tf.keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy']
)
start_time=datetime.datetime.now()
model.fit(
x=x_train,
y=y_train,
batch_size=batch_size,
epochs=num_epochs
)
end_time=datetime.datetime.now()
time_cost=end_time-start_time
print('time_cost: ',time_cost)
model.save('leNet_model.h5')
print(model.evaluate(x_test,y_test))
image_index=3
# 預測
pred=model.predict(x_test[image_index].reshape(1,32,32,1))
print("predict result:",pred.argmax())
# 顯示
plt.imshow(x_test[image_index].reshape(32,32))
plt.savefig("predict_num.jpg")
plt.show()
train_model()
print ("end")
#handwrite_Lenet_Tensorflow_load.py
#coding=utf-8
import tensorflow as tf
mnist=tf.keras.datasets.mnist
import matplotlib.pyplot as plt
import matplotlib as m
import numpy as np
import cv2
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# np.set_printoptions(threshold=np.inf)
#載入模型
def digit_predict():
model=tf.keras.models.load_model('leNet_model.h5')
#圖片預處理
img=cv2.imread('0.jpg')
print(img.shape)
plt.imshow(img)
plt.show()
#灰度圖
img=cv2.cvtColor(img,cv2.COLOR_BGR2GRAY)
print(img.shape)
plt.imshow(img,cmap='Greys')
plt.show()
#取反
img=cv2.bitwise_not(img)
plt.imshow(img,cmap='Greys')
plt.show()
print('二值化前:',img.shape)
print('二值化前:',img)
#純黑 純白 二值化
img[img<=100]=0
img[img>=140]=255
plt.imshow(img,cmap='Greys')
plt.show()
print('二值化後:',img.shape)
print('二值化後:',img)
#尺寸
img=cv2.resize(img,(32,32))
print('尺寸:',img.shape)
print('尺寸',img)
#歸一化
img=img/255
print('歸一化:',img.shape)
print('歸一化:',img)#0和1組成
#預測
pred=model.predict(img.reshape(1,32,32,1))
print('prediction Number: ',pred.argmax())
#列印圖片資訊
plt.imshow(img)
plt.show()
digit_predict()
相關文章
- CNN實現手寫數字識別並改變引數進行分析CNN
- tensorflow.js 手寫數字識別JS
- 在PaddlePaddle上實現MNIST手寫體數字識別
- 【機器學習】手寫數字識別機器學習
- 瀏覽器中的手寫數字識別瀏覽器
- Tensorflow2.0-mnist手寫數字識別示例
- 【Get】用深度學習識別手寫數字深度學習
- Tensorflow實現RNN(LSTM)手寫數字識別RNN
- Pytorch搭建MyNet實現MNIST手寫數字識別PyTorch
- OpenCV + sklearnSVM 實現手寫數字分割和識別OpenCV
- Pytorch 手寫數字識別 深度學習基礎分享PyTorch深度學習
- 小熊飛槳練習冊-01手寫數字識別
- 用tensorflow2實現mnist手寫數字識別
- 手寫數字圖片識別-全連線網路
- 計算機視覺—CNN識別手寫數字(11)計算機視覺CNN
- 計算機視覺—kNN識別手寫數字(10)計算機視覺KNN
- 深度學習實驗:Softmax實現手寫數字識別深度學習
- 手寫數字圖片識別-卷積神經網路卷積神經網路
- 《手寫數字識別》神經網路 學習筆記神經網路筆記
- KNN 演算法-實戰篇-如何識別手寫數字KNN演算法
- torch--minst手寫體識別
- 【百度飛槳】手寫數字識別模型部署Paddle Inference模型
- 恩墨大資料系列免費課之《影像識別揭秘-Python手寫數字識別》大資料Python
- 使用人工神經網路訓練手寫數字識別模型神經網路模型
- 深度學習例項之基於mnist的手寫數字識別深度學習
- 【Keras篇】---Keras初始,兩種模型構造方法,利用keras實現手寫數字體識別Keras模型構造方法
- TensorFlow筆記(5)——優化手寫數字識別模型之優化器筆記優化模型
- 雲脈文件雲識別APP:輕鬆識別潦草手寫體APP
- TensorFlow2.0 + CNN + keras + 人臉識別CNNKeras
- 機器學習之神經網路識別手寫數字(純python實現)機器學習神經網路Python
- 手寫識別 b友
- opencv python 基於KNN的手寫體識別OpenCVPythonKNN
- opencv python 基於SVM的手寫體識別OpenCVPython
- 【Python】keras使用Lenet5識別mnistPythonKeras
- 手把手教你使用LabVIEW OpenCV DNN實現手寫數字識別(含原始碼)ViewOpenCVDNN原始碼
- mnist手寫數字識別——深度學習入門專案(tensorflow+keras+Sequential模型)深度學習Keras模型
- 全棧AI工程師指南,DIY一個識別手寫數字的web應用全棧AI工程師Web
- python呼叫hanlp進行命名實體識別PythonHanLP