TensorFlow 入門(MNIST資料集)
TensorFlow 中的layers 模組提供了高階API,使我們能輕易建立出神經網路。包括全連線層,卷積層,啟用函式,dropout regularization。在本文中,將構建CNN網路來進行手寫數字識別。
MNIST資料集包括60000個訓練樣本,10000個測試樣本,每個樣本為28*28畫素的圖片。
本例中使用的網路結構如下:
卷積層1:32個5*5 filter,啟用函式為ReLU
池化層1:2*2,步長為2
卷積層2: 64個5*5 filter,啟用函式為ReLU
池化層2: 2*2,步長為2
Dense Layer 1(全連線層):1024個neurons,dropout regularization rate 為 0.4
API中的三個函式(這三個函式的輸入輸出均為tensor):tf.layers.conv2d(), tf.layers.max_pooling2d(), tf.layers.dense()
首先給出定義模型的程式碼:
下面對程式碼細節進行解釋:
Input Layer:
對於一個二維的影像,首先將資料轉化為以下形式[batch_size,image_width,image_height,channels]
其中batch_size = -1,表示生成自適應大小的tensor。
Convolutional Layer:
此處值得注意的是,padding,padding有兩個可選值 valid和same,預設為valid,當padding值為same值時,卷積是會在邊緣處補0,使得輸出tensor的width*height與輸入相同。
Dense Layer:
這裡的全連線層包含1024個神經元,在連線至全連線層之前,我們首先將原有的tensor轉換。由pool2輸出的tensor([batch_size,image_width,image_height,channels])轉換為二維的tensor([batch_size,features])。Features需要自己計算出來,因為pool2輸出tensor的格式為[batch_size,7,7,64]。
為了防止過擬合,需要使用drop regularization,
這裡的引數training是一個Boolean值,只有在訓練階段,才需要使用dropout。當training = true時,這一層才會執行。輸出為[batch_size, 1024]
Logits Layer:
這裡使用了預設的啟用函式linear activation
Generate Predictions:
Logits Layer輸出一個tensor[batch_size,10](類似一個batch_size 行,10列的一個矩陣,對應每一張圖片,都會產生一個向量,每個向量有10個值),在此,我們可以生成兩種型別的返回值,(1)返回0,1,2,3,4,5,6,7,8,9中的一個數字(2)返回對應每個數字的機率,例如0的機率是0.2,1的機率是0.7等等。
在本例中,我們返回向量(10個值)中,最大值對應的類別。使用tf.argmax()獲得對應的索引。
可以使用tf.nn.softmax()得到每個類別對應的機率:
我們使用name引數來明確命名這個操作softmax_tensor,所以我們可以在稍後引用它.
將預測結果放入一個字典中,返回一個EstimatorSpec物件。
Calculate Loss:
對於多分類問題,多采用cross_entropy作為loss function。
首先,傳入的labels尺寸為[batch_size,1],將其轉化為與logits相同的尺寸[batch_size,10]。
tf.cast(x,dtype): Casts a tensor to a new type.
tf.estimator.EstimatorSpec 類
Ops and objects returned from a model_fn and passed to an Estimator.(從定義模型的函式中返回,並傳遞給Estimator的一箇中間產物)
建構函式:
Training Op:
Evaluation Op:
程式碼:
【本文轉載自:知乎,作者:motto,原文連結:】
來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/31542119/viewspace-2200313/,如需轉載,請註明出處,否則將追究法律責任。
相關文章
- 使用tensorflow操作MNIST資料
- MNIST資料集介紹
- TensorFlow系列專題(六):實戰專案Mnist手寫資料集識別
- Python深度學習入門之mnist-inception(Tensorflow2.0實現)Python深度學習
- tensorflow2.3 mnist
- 如何在Tensorflow.js中處理MNIST影象資料JS
- tensorflow入門
- 深度學習(一)之MNIST資料集分類深度學習
- mnist手寫數字識別——深度學習入門專案(tensorflow+keras+Sequential模型)深度學習Keras模型
- TensorFlow入門文章
- keras 手動搭建alexnet並訓練mnist資料集Keras
- MNIST資料集詳解及視覺化處理(pytorch)視覺化PyTorch
- TensorFlow.NET機器學習入門【6】採用神經網路處理Fashion-MNIST機器學習神經網路
- TensorFlow 入門 上(自用)
- TensorFlow 入門 下(自用)
- ML.NET呼叫Tensorflow模型示例——MNIST模型
- python 將Mnist資料集轉為jpg,並按比例/標籤拆分為多個子資料集Python
- 前饋神經網路進行MNIST資料集分類神經網路
- TensorFlow入門 - 變數(Variables)變數
- [譯] 使用 PyTorch 在 MNIST 資料集上進行邏輯迴歸PyTorch邏輯迴歸
- TensorFlow.NET機器學習入門【7】採用卷積神經網路(CNN)處理Fashion-MNIST機器學習卷積神經網路CNN
- tensorflow載入資料的三種方式
- Tensorflow 2.x入門教程
- 【原創】python實現BP神經網路識別Mnist資料集Python神經網路
- Pytorch筆記之 多層感知機實現MNIST資料集分類PyTorch筆記
- matlab練習程式(神經網路識別mnist手寫資料集)Matlab神經網路
- TensorFlow.NET機器學習入門【5】採用神經網路實現手寫數字識別(MNIST)機器學習神經網路
- Tensorflow2.0-mnist手寫數字識別示例
- 大資料工程師入門系列—常用資料採集工具(Flume、Logstash 和 Fluentd)大資料工程師
- 深度學習:TensorFlow入門實戰深度學習
- 大資料入門大資料
- NLP入門資料
- 尋找手寫資料集MNIST程式的最佳引數(learning_rate、nodes、epoch)
- Flask入門資料庫的查詢集與過濾器(十一)Flask資料庫過濾器
- 如何入門Pytorch之四:搭建神經網路訓練MNISTPyTorch神經網路
- tensorflow資料清洗
- 資料分析 | 零基礎入門資料分析(一):從入門到摔門?
- 面向初學者的快速入門tensorflow