如何入門Pytorch之四:搭建神經網路訓練MNIST

jimchen1218發表於2020-09-13

       上一節我們學習了Pytorch優化網路的基本方法,本節我們將以MNIST資料集為例,通過搭建一個完整的神經網路,來加深對Pytorch的理解。

一、資料集

       MNIST是一個非常經典的資料集,下載連結:http://yann.lecun.com/exdb/mnist/

      下載下來的檔案如下:

 

該手寫數字資料庫具有60,000個示例的訓練集和10,000個示例的測試集。它是NIST提供的更大集合的子集。數字已經過尺寸標準化,並以固定尺寸的影像為中心。

手寫數字識別是一個比較簡單的任務,它是一個10分類問題,(0-9),之所以選這個資料集,是因為識別難度低,計算量小,資料容易獲得。

二、模型搭建

    1、網路節點的確定

    對於不同的目的,網路的選擇也是不一樣的。一般來說,網路容量和資料集大小是對應的。一個小型資料集也只需要一個小型的網路。

這裡有一個經驗值:

      1)model_size=sqrt(in_size*out_size)

      2)model_size=log(in_size)

      3)  model_size=sqrt(in_size*out_size)

      model_size:網路的節點量

      in_size:輸入的節點量

      out_size輸出的節點量

     2、匯入pytorch包

import torch
import torchvision
import trochvision import datasets
import trochvision import transforms
from torch.autograd import Variable

    3、獲取訓練集和測試集

#root用於指定資料集下載後的存放路徑
#transform用於指定匯入資料集需要對資料進行變換操作
#train指定在資料集下載後需要載入哪部分資料,true為訓練集,false為測試集
data_train=datasets.MNIST(root="./data/",transform=transform,train=True,download=True) data_test=datasets.MNIST(root='./data/',transform=transform,train=False)

    4、資料預覽和裝載

#資料裝載,可以理解為對圖片的處理
#處理完成後,將圖片送給模型訓練,裝載就是打包的過程
#dataset 用於指定載入的資料集名稱
#batch_size設定了每個包的圖片資料資料個數
#shuffle 裝載過程將資料隨機打亂並打包
data_loader_train=torch.utils.data.DataLoader(dataset=data_train,batch_size=64,shuffle=True)
data_loader_test=torch.utils.data.DataLoader(dataset=data_test,batch_size=64,shuffle=True)

 

  

 

   

 

相關文章