python三層全連線層實現手寫字母識別方式
先用最簡單的三層全連線神經網路,然後新增啟用層檢視實驗結果,最後加上批標準化驗證是否有效
首先根據已有的模板定義網路結構SimpleNet,命名為net.py
import torch from torch.autograd import Variable import numpy as np import matplotlib.pyplot as plt from torch import nn,optim from torch.utils.data import DataLoader from torchvision import datasets,transforms #定義三層全連線神經網路 class simpleNet(nn.Module): def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):#輸入維度,第一層的神經元個數、第二層的神經元個數,以及第三層的神經元個數 super(simpleNet,self).__init__() self.layer1=nn.Linear(in_dim,n_hidden_1) self.layer2=nn.Linear(n_hidden_1,n_hidden_2) self.layer3=nn.Linear(n_hidden_2,out_dim) def forward(self,x): x=self.layer1(x) x=self.layer2(x) x=self.layer3(x) return x #新增啟用函式 class Activation_Net(nn.Module): def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim): super(NeutalNetwork,self).__init__() self.layer1=nn.Sequential(#Sequential組合結構 nn.Linear(in_dim,n_hidden_1),nn.ReLU(True)) self.layer2=nn.Sequential( nn.Linear(n_hidden_1,n_hidden_2),nn.ReLU(True)) self.layer3=nn.Sequential( nn.Linear(n_hidden_2,out_dim)) def forward(self,x): x=self.layer1(x) x=self.layer2(x) x=self.layer3(x) return x #新增批標準化處理模組,皮標準化放在全連線的後面,非線性的前面 class Batch_Net(nn.Module): def _init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim): super(Batch_net,self).__init__() self.layer1=nn.Sequential(nn.Linear(in_dim,n_hidden_1),nn.BatchNormld(n_hidden_1),nn.ReLU(True)) self.layer2=nn.Sequential(nn.Linear(n_hidden_1,n_hidden_2),nn.BatchNormld(n_hidden_2),nn.ReLU(True)) self.layer3=nn.Sequential(nn.Linear(n_hidden_2,out_dim)) def forword(self,x): x=self.layer1(x) x=self.layer2(x) x=self.layer3(x) return x
訓練網路,
import torch from torch.autograd import Variable import numpy as np import matplotlib.pyplot as plt %matplotlib inline from torch import nn,optim from torch.utils.data import DataLoader from torchvision import datasets,transforms #定義一些超引數 import net batch_size=64 learning_rate=1e-2 num_epoches=20 #預處理 data_tf=transforms.Compose( [transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])#將影像轉化成tensor,然後繼續標準化,就是減均值,除以方差 #讀取資料集 train_dataset=datasets.MNIST(root='./data',train=True,transform=data_tf,download=True) test_dataset=datasets.MNIST(root='./data',train=False,transform=data_tf) #使用內建的函式匯入資料集 train_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True) test_loader=DataLoader(test_dataset,batch_size=batch_size,shuffle=False) #匯入網路,定義損失函式和最佳化方法 model=net.simpleNet(28*28,300,100,10) if torch.cuda.is_available():#是否使用cuda加速 model=model.cuda() criterion=nn.CrossEntropyLoss() optimizer=optim.SGD(model.parameters(),lr=learning_rate) import net n_epochs=5 for epoch in range(n_epochs): running_loss=0.0 running_correct=0 print("epoch {}/{}".format(epoch,n_epochs)) print("-"*10) for data in train_loader: img,label=data img=img.view(img.size(0),-1) if torch.cuda.is_available(): img=img.cuda() label=label.cuda() else: img=Variable(img) label=Variable(label) out=model(img)#得到前向傳播的結果 loss=criterion(out,label)#得到損失函式 print_loss=loss.data.item() optimizer.zero_grad()#歸0梯度 loss.backward()#反向傳播 optimizer.step()#最佳化 running_loss+=loss.item() epoch+=1 if epoch%50==0: print('epoch:{},loss:{:.4f}'.format(epoch,loss.data.item()))
來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69993754/viewspace-2752788/,如需轉載,請註明出處,否則將追究法律責任。
相關文章
- 手寫數字圖片識別-全連線網路
- 全連線層的作用解析
- 基於PyTorch框架的多層全連線神經網路實現MNIST手寫數字分類PyTorch框架神經網路
- 卷積層和全連線層之間的關係卷積
- Python底層實現KNNPythonKNN
- 三層(似曾相識)
- Tensorflow實現RNN(LSTM)手寫數字識別RNN
- Vue 高德地圖 API Loca 如何使用 連線線圖層、脈衝連線圖層Vue地圖API
- Pytorch搭建MyNet實現MNIST手寫數字識別PyTorch
- OpenCV + sklearnSVM 實現手寫數字分割和識別OpenCV
- 字母排列(python實現)Python
- 層級聚類和Python實現的初學者指南(附連結)聚類Python
- 深度學習實驗:Softmax實現手寫數字識別深度學習
- 全連線神經網路的原理及Python實現神經網路Python
- Redis(三)--- Redis的五大資料型別的底層實現Redis大資料資料型別
- 機器學習之神經網路識別手寫數字(純python實現)機器學習神經網路Python
- 用tensorflow2實現mnist手寫數字識別
- 在PaddlePaddle上實現MNIST手寫體數字識別
- 二層交換機和三層交換機的區別
- Minya 分層框架實現的思考(三):問題框架
- 自己動手寫一個持久層框架框架
- Python垃圾回收(GC)三層心法,你瞭解到第幾層?PythonGC
- python 三種方式實現截圖Python
- Python實現MySQL連線池PythonMySql
- 【opencv3】 svm實現手寫體與人臉識別OpenCV
- 基於滴滴雲 GPU 實現簡單 MINIST 手寫識別GPU
- 【TensorFlow篇】--Tensorflow框架實現SoftMax模型識別手寫數字集框架模型
- Python實現AI影像識別-身份證識別PythonAI
- opencv python 基於KNN的手寫體識別OpenCVPythonKNN
- opencv python 基於SVM的手寫體識別OpenCVPython
- NSDictionary底層實現原理
- AutoreleasePool底層實現原理
- HashMap底層實現原理HashMap
- mysql索引底層實現MySql索引
- LinkedList的底層實現
- JS實現載入層JS
- 從手寫三層迴圈到標準實現,矩陣相乘執行效率提高三萬六千倍之路矩陣
- Spring原始碼系列:初探底層,手寫SpringSpring原始碼