學習Pytorch+Python之MNIST手寫字型識別
GoodMai。com
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import numpy as np
import torchvision.utils
from torchvision import datasets, transforms
from torch.autograd import Variable
import torch.utils.data
#判斷是否能用GPU,如果能就用GPU,不能就用CPU
use_gpu = torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#資料轉換,Pytorch的底層是tensor(張量),所有用來訓練的影像均需要轉換成tensor
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
#下載資料集
data_train = datasets.MNIST(root="./data/", transform=transform, train=True, download=True)
data_test = datasets.MNIST(root="./data/", transform=transform, train=False)
#載入資料集,批次大小為64,shuffle表示亂序
data_loader_train = torch.utils.data.DataLoader(dataset=data_train, batch_size=64, shuffle=True)
data_loader_test = torch.utils.data.DataLoader(dataset=data_test, batch_size=64, shuffle=True)
#建立模型即網路架構
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
#建立二維卷積
self.conv1 = nn.Sequential(
#輸入特徵數量為1,輸出特徵數量為64,卷積核大小為3x3,步長為1,邊緣填充為1,保證了卷積後的特徵尺寸與原來一樣
nn.Conv2d(1, 64, kernel_size=3, stride=1, c=1),
nn.ReLU(),
nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
nn.ReLU(),
#最大池化,特徵數量不變,尺寸減半[(input-kernel_size)/stride + 1]
nn.MaxPool2d(stride=2, kernel_size=2)
)
#建立全連線
self.dense = nn.Sequential(
nn.Linear(14*14*128, 1024),
nn.ReLU(),
#隨機丟棄部分結點,防止過擬合
nn.Dropout(p=0.5),
nn.Linear(1024, 10)
)
#建立好網路結構後,建立前向傳播
def forward(self, x):
#對資料進行卷積操作
x = self.conv1(x)
#改變特徵形狀
x = x.c(-1, 14*14*128)
#對特徵進行全連線
x = self.dense(x)
return x
#類例項化
model = Model()
#指定資料訓練次數
epochs = 5
#設定學習率,即梯度下降的權重,其值越大收斂越快,越小收斂越慢
learning_rate = 0.0001
#選用引數最佳化器,這裡使用Adam
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
#選用損失函式,這裡使用交叉熵函式,來判定實際的輸出與期望的輸出的接近程度
criterion = nn.CrossEntropyLoss()
#判斷是否使用GPU訓練
if(use_gpu):
model = model.cuda()
loss_f = criterion.cuda()
#用for迴圈的方式完成資料的批次訓練
for epoch in range(epochs):
#定義並初始化訓練過程的損失以及正確率
running_loss = 0
running_correct = 0
for data in data_loader_train:
x_train, y_train = data
x_train, y_train =x_train.cuda(), y_train.cuda()
x_train = x_train.to(device)
y_train = y_train.to(device)
#將預處理好的資料載入到例項化好的model模型中,進行訓練得到輸出
outputs = model(x_train)
_, pred = torch.max(outputs.data, 1)
#每次迴圈中,梯度必須清零,防止梯度堆疊
optimizer.zero_grad()
#呼叫設定的損失
loss = criterion(outputs, y_train)
#反向傳播損失
loss.backward()
#引數更新
optimizer.step()
#更新損失
running_loss += loss.item()
#更新正確率
running_correct += torch.sum(pred == y_train.data)
testing_correct = 0
#檢視每輪訓練後,測試資料集中的正確率
for data in data_loader_test:
x_test, y_test = data
x_test, y_test = Variable(x_test), Variable(y_test)
x_test = x_test.to(device)
y_test = y_test.to(device)
outputs = model(x_test)
_, pred = torch.max(outputs.data, 1)
testing_correct += torch.sum(pred == y_test.data)
print("Loss is {}, Training Accuray is {}%, Test Accurray is {}".format(running_loss/len(data_train), 100*running_correct/len(data_train), 100*testing_correct/len(data_test)))
#測試訓練好的模型
#隨機載入4個手寫數字
data_loader_test = torch.utils.data.DataLoader(dataset=data_test, batch_size=4, shuffle=True)
#函式next相關
#函式iter相關
x_test,y_test = next(iter(data_loader_test))
inputs = Variable(x_test)
inputs = inputs.to(device)
pred = model(inputs)
#_為輸出的最大值,pred為最大值的索引值
_,pred = torch.max(pred, 1)
print('Predict Label is :', [i for i in pred.data])
print('Real Label is:', [ i for i in y_test] )
img = torchvision.utils.make_grid(x_test)
img = img.numpy().transpose(1, 2, 0)
std = [0.5]
mean = [0.5]
img = img*std+mean
plt.imshow(img)
plt.show()
來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/70008680/viewspace-2838449/,如需轉載,請註明出處,否則將追究法律責任。
相關文章
- 深度學習例項之基於mnist的手寫數字識別深度學習
- Tensorflow2.0-mnist手寫數字識別示例
- mnist手寫數字識別——深度學習入門專案(tensorflow+keras+Sequential模型)深度學習Keras模型
- matlab練習程式(神經網路識別mnist手寫資料集)Matlab神經網路
- Pytorch搭建MyNet實現MNIST手寫數字識別PyTorch
- 用tensorflow2實現mnist手寫數字識別
- 在PaddlePaddle上實現MNIST手寫體數字識別
- 【機器學習】手寫數字識別機器學習
- 【Get】用深度學習識別手寫數字深度學習
- TensorFlow.NET機器學習入門【5】採用神經網路實現手寫數字識別(MNIST)機器學習神經網路
- TensorFlow系列專題(六):實戰專案Mnist手寫資料集識別
- keras框架下的深度學習(一)手寫體識別Keras框架深度學習
- Pytorch 手寫數字識別 深度學習基礎分享PyTorch深度學習
- 《手寫數字識別》神經網路 學習筆記神經網路筆記
- 深度學習實驗:Softmax實現手寫數字識別深度學習
- 手寫識別 b友
- 深度學習(一)之MNIST資料集分類深度學習
- iOS學習筆記06 手勢識別iOS筆記
- Promise學習筆記(知識點 + 手寫Promise)Promise筆記
- 【Python】keras使用Lenet5識別mnistPythonKeras
- 【Python】keras神經網路識別mnistPythonKeras神經網路
- 小熊飛槳練習冊-01手寫數字識別
- 機器學習之神經網路識別手寫數字(純python實現)機器學習神經網路Python
- TensorFlow 實戰Google深度學習框架(第2版)第6章之LeNet-5模型實現MNIST數字識別Go深度學習框架模型
- 【翻譯】GRAIL-手寫識別AI
- C# 手寫識別方案整理C#
- torch--minst手寫體識別
- Spring學習之——手寫Mini版Spring原始碼Spring原始碼
- 疫情防控資訊採集難?深度學習手寫體識別來幫忙深度學習
- 深度學習——性別識別深度學習
- CSS學習筆記之字型樣式CSS筆記
- 恩墨大資料系列免費課之《影像識別揭秘-Python手寫數字識別》大資料Python
- 機器學習從入門到放棄:硬train一發手寫數字識別機器學習AI
- tensorflow.js 手寫數字識別JS
- 【Python】keras卷積神經網路識別mnistPythonKeras卷積神經網路
- 雲脈文件雲識別APP:輕鬆識別潦草手寫體APP
- Deep Learning Tutorial (翻譯) 之使用邏輯迴歸分類手寫數字MNIST邏輯迴歸
- 機器學習演算法(九): 基於線性判別模型的LDA手寫數字分類識別機器學習演算法模型LDA