尋找手寫資料集MNIST程式的最佳引數(learning_rate、nodes、epoch)
尋找手寫資料集MNIST程式的最佳引數(learning_rate、nodes、epoch)
在利用神經網路進行MNIST手寫資料集進行識別時,部分引數是需要人為進行設定的,具有不確定性,本文按照英國人工智慧領域碩士塔裡克·拉希德(Tariq Rashid)先生的方法對於識別MNIST資料集的BP神經網路演算法程式進行了多項引數測試,利用基本的控制變數法,試圖尋找針對MNIST資料集識別的學習率、隱藏層節點數以及訓練迴圈次數的最佳值,並驗證塔裡克·拉希德(Tariq Rashid)在書中得到的結論的準確性
利用神經網路演算法識別手寫資料集MNIST的程式原始碼連線如下:
https://blog.csdn.net/weixin_46076729/article/details/108936565
讀者可根據此程式碼進行測試。
注意:本文利用的是jupyter Notebook進行程式執行,需將建立神經網路的類的程式與訓練神經網路的程式放在同一個Input框中,然後將驗證神經網路的程式放在另一個框中進行執行,點選Jupyter的執行按鈕兩次,如果將程式放在同一個框中執行可能系統會出現報錯。
由於程式一開始進行前向傳播時,需要系統生成隨機矩陣,所以為了避免由於系統產生隨機數帶來的偶然因素影響,每一次試驗進行了三次執行,並最終取平均值進行分析。本文用MNIST資料集的6000條資料進行訓練,用測試集的10000條資料進行驗證,利用識別資料的準確率來表徵演算法識別的質量好壞。
下列表格中,score1、score2、score3表示三次平行測試的結果,average_value是三次測試結果的平均值。
一、對各項引數的分析
1、學習率(learing_rate)
控制變數:迴圈次數(epoch)=5;隱藏層節點數(nodes)=100
取其平均值用折線圖表示如圖所示:
顯然,當learning rate=0.1的時候能夠取得最高的準確率。
2、隱藏層節點數(nodes)
控制變數:學習率(learning rate)=0.3;迴圈次數(epoch)=2
取其平均值用折線圖表示如圖所示:
此外,本文還記錄了執行程式所需要的時間,如下圖所示:
根據上數兩個圖可以看出, 當節點數達到一定數量時,演算法的準去率會趨於一個穩定值,在這個穩定值上下波動。然而,從第二張圖可以看出,隨著節點數的增加,程式的計算量也進一步增加,節點數越多,程式所需執行的時間就越長,所以綜合因素考慮,當節點數在150到200之間的時候基本為最佳值,在此範圍內,可以獲得較高的準確率,並且程式執行時間也較短,能夠有效節約時間成本。
3、訓練迴圈次數(epoch)
控制變數:學習率(learning rate)=0.3;隱藏層節點數(nodes)=100
取其平均值用折線圖表示如圖所示:
很明顯,改變訓練次數並不能提高神經網路的識別準確率,一開始認為資料出現了問題或者程式出現了問題,在經過一番思考後突然意識到,每次迴圈的時候,在前向傳播階段,系統都要先對輸入層與隱藏層之間的權重矩陣進行隨機賦值,數值符合正態分佈,均值為0,標準差為隱藏層節點數的負二分之一次方,正因為計算機在每次迴圈的開頭對權重矩陣的隨機賦值,使得偶然性大大增加。每次迴圈都是重新開始,並不能對上一次的權重矩陣進行反向傳播計算,所以才導致了上圖所示的試驗結果。
二、總結
通過控制變數對三項引數進行定量分析,從而找到利用BP神經網路對MNIST手寫資料集進行識別的程式不確定性引數的最佳值,與塔裡克·拉希德(Tariq Rashid)的結論基本一致,得到最佳學習率(learning rate)為0.1,最佳隱藏層節點數(nodes)為150到200之間。
[1] 《Python神經網路程式設計》 [英]塔裡克·拉希德(Tariq Rashid) 人民郵電出版社
(作為一個初學者來說,第一次學習神經網路與python,若文中有錯誤,歡迎大佬指正。)
相關文章
- matlab練習程式(神經網路識別mnist手寫資料集)Matlab神經網路
- 尋找寫程式碼感覺(七)之封裝請求引數和返回引數封裝
- 尋找寫程式碼感覺(十二)之 封裝分頁請求引數和返回引數封裝
- 尋找寫程式碼感覺(十六)之 整合Validation做引數校驗
- MNIST資料集介紹
- TensorFlow系列專題(六):實戰專案Mnist手寫資料集識別
- keras 手動搭建alexnet並訓練mnist資料集Keras
- TensorFlow 入門(MNIST資料集)
- Tensorflow2.0-mnist手寫數字識別示例
- jquery尋找最佳路徑效果程式碼例項jQuery
- ACM 尋找最大數ACM
- 用tensorflow2實現mnist手寫數字識別
- Pytorch搭建MyNet實現MNIST手寫數字識別PyTorch
- 深度學習例項之基於mnist的手寫數字識別深度學習
- 深度學習(一)之MNIST資料集分類深度學習
- 如何高效尋找素數
- 在PaddlePaddle上實現MNIST手寫體數字識別
- 尋找寫程式碼感覺(十四)之 新增功能的開發
- 大資料叢集核心引數調優大資料
- 如何尋找優質的資料標註公司?
- 尋找風險投資的十大最佳實踐
- 尋找寫程式碼感覺(十三)之 編輯功能的開發
- 尋找寫程式碼感覺(十五)之 刪除功能的開發
- PHP 採集程式中日常的引數PHP
- SAP CRM產品主資料搜尋功能的With individual object搜尋引數Object
- 尋找寫程式碼感覺(三)之使用 Spring Boot 編寫介面Spring Boot
- 目標檢測(2):我用 PyTorch 復現了 LeNet-5 神經網路(MNIST 手寫資料集篇)!PyTorch神經網路
- 尋找頭緒:編寫可維護的 JavaScriptJavaScript
- 尋找產品的需求-試分析手機殼
- LeetCode:尋找丟失的數字LeetCode
- leetcode 287 尋找重複的數LeetCode
- 2837 尋找水仙花數
- 檢視引數(parameter)的字典與資料庫字符集資料庫
- 尋找海量資料集用於大資料開發實戰(維基百科網站統計資料)大資料網站
- 尋找鎖定資料庫使用者的真兇資料庫
- 前饋神經網路進行MNIST資料集分類神經網路
- MNIST資料集詳解及視覺化處理(pytorch)視覺化PyTorch
- window7下caffe安裝與mnist資料集測試