上一節我們學習了Pytorch優化網路的基本方法,本節我們將以MNIST資料集為例,通過搭建一個完整的神經網路,來加深對Pytorch的理解。
一、資料集
MNIST是一個非常經典的資料集,下載連結:http://yann.lecun.com/exdb/mnist/
下載下來的檔案如下:
該手寫數字資料庫具有60,000個示例的訓練集和10,000個示例的測試集。它是NIST提供的更大集合的子集。數字已經過尺寸標準化,並以固定尺寸的影像為中心。
手寫數字識別是一個比較簡單的任務,它是一個10分類問題,(0-9),之所以選這個資料集,是因為識別難度低,計算量小,資料容易獲得。
二、模型搭建
1、網路節點的確定
對於不同的目的,網路的選擇也是不一樣的。一般來說,網路容量和資料集大小是對應的。一個小型資料集也只需要一個小型的網路。
這裡有一個經驗值:
1)model_size=sqrt(in_size*out_size)
2)model_size=log(in_size)
3) model_size=sqrt(in_size*out_size)
model_size:網路的節點量
in_size:輸入的節點量
out_size輸出的節點量
2、匯入pytorch包
import torch import torchvision import trochvision import datasets import trochvision import transforms from torch.autograd import Variable
3、獲取訓練集和測試集
#root用於指定資料集下載後的存放路徑
#transform用於指定匯入資料集需要對資料進行變換操作
#train指定在資料集下載後需要載入哪部分資料,true為訓練集,false為測試集
data_train=datasets.MNIST(root="./data/",transform=transform,train=True,download=True) data_test=datasets.MNIST(root='./data/',transform=transform,train=False)
4、資料預覽和裝載
#資料裝載,可以理解為對圖片的處理 #處理完成後,將圖片送給模型訓練,裝載就是打包的過程 #dataset 用於指定載入的資料集名稱 #batch_size設定了每個包的圖片資料資料個數 #shuffle 裝載過程將資料隨機打亂並打包 data_loader_train=torch.utils.data.DataLoader(dataset=data_train,batch_size=64,shuffle=True) data_loader_test=torch.utils.data.DataLoader(dataset=data_test,batch_size=64,shuffle=True)