python三層全連線層實現手寫字母識別方式
先用最簡單的三層全連線神經網路,然後新增啟用層檢視實驗結果,最後加上批標準化驗證是否有效
首先根據已有的模板定義網路結構SimpleNet,命名為net.py
import torch from torch.autograd import Variable import numpy as np import matplotlib.pyplot as plt from torch import nn,optim from torch.utils.data import DataLoader from torchvision import datasets,transforms #定義三層全連線神經網路 class simpleNet(nn.Module): def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim):#輸入維度,第一層的神經元個數、第二層的神經元個數,以及第三層的神經元個數 super(simpleNet,self).__init__() self.layer1=nn.Linear(in_dim,n_hidden_1) self.layer2=nn.Linear(n_hidden_1,n_hidden_2) self.layer3=nn.Linear(n_hidden_2,out_dim) def forward(self,x): x=self.layer1(x) x=self.layer2(x) x=self.layer3(x) return x #新增啟用函式 class Activation_Net(nn.Module): def __init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim): super(NeutalNetwork,self).__init__() self.layer1=nn.Sequential(#Sequential組合結構 nn.Linear(in_dim,n_hidden_1),nn.ReLU(True)) self.layer2=nn.Sequential( nn.Linear(n_hidden_1,n_hidden_2),nn.ReLU(True)) self.layer3=nn.Sequential( nn.Linear(n_hidden_2,out_dim)) def forward(self,x): x=self.layer1(x) x=self.layer2(x) x=self.layer3(x) return x #新增批標準化處理模組,皮標準化放在全連線的後面,非線性的前面 class Batch_Net(nn.Module): def _init__(self,in_dim,n_hidden_1,n_hidden_2,out_dim): super(Batch_net,self).__init__() self.layer1=nn.Sequential(nn.Linear(in_dim,n_hidden_1),nn.BatchNormld(n_hidden_1),nn.ReLU(True)) self.layer2=nn.Sequential(nn.Linear(n_hidden_1,n_hidden_2),nn.BatchNormld(n_hidden_2),nn.ReLU(True)) self.layer3=nn.Sequential(nn.Linear(n_hidden_2,out_dim)) def forword(self,x): x=self.layer1(x) x=self.layer2(x) x=self.layer3(x) return x
訓練網路,
import torch from torch.autograd import Variable import numpy as np import matplotlib.pyplot as plt %matplotlib inline from torch import nn,optim from torch.utils.data import DataLoader from torchvision import datasets,transforms #定義一些超引數 import net batch_size=64 learning_rate=1e-2 num_epoches=20 #預處理 data_tf=transforms.Compose( [transforms.ToTensor(),transforms.Normalize([0.5],[0.5])])#將影像轉化成tensor,然後繼續標準化,就是減均值,除以方差 #讀取資料集 train_dataset=datasets.MNIST(root='./data',train=True,transform=data_tf,download=True) test_dataset=datasets.MNIST(root='./data',train=False,transform=data_tf) #使用內建的函式匯入資料集 train_loader=DataLoader(train_dataset,batch_size=batch_size,shuffle=True) test_loader=DataLoader(test_dataset,batch_size=batch_size,shuffle=False) #匯入網路,定義損失函式和最佳化方法 model=net.simpleNet(28*28,300,100,10) if torch.cuda.is_available():#是否使用cuda加速 model=model.cuda() criterion=nn.CrossEntropyLoss() optimizer=optim.SGD(model.parameters(),lr=learning_rate) import net n_epochs=5 for epoch in range(n_epochs): running_loss=0.0 running_correct=0 print("epoch {}/{}".format(epoch,n_epochs)) print("-"*10) for data in train_loader: img,label=data img=img.view(img.size(0),-1) if torch.cuda.is_available(): img=img.cuda() label=label.cuda() else: img=Variable(img) label=Variable(label) out=model(img)#得到前向傳播的結果 loss=criterion(out,label)#得到損失函式 print_loss=loss.data.item() optimizer.zero_grad()#歸0梯度 loss.backward()#反向傳播 optimizer.step()#最佳化 running_loss+=loss.item() epoch+=1 if epoch%50==0: print('epoch:{},loss:{:.4f}'.format(epoch,loss.data.item()))
來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69993754/viewspace-2752788/,如需轉載,請註明出處,否則將追究法律責任。
相關文章
- 全連線層的作用解析
- 手寫數字圖片識別-全連線網路
- 基於PyTorch框架的多層全連線神經網路實現MNIST手寫數字分類PyTorch框架神經網路
- 卷積層和全連線層之間的關係卷積
- 一個單層的基礎神經網路實現手寫數字識別神經網路
- Python底層實現KNNPythonKNN
- 三層交換機怎麼設定路由連線路由
- ETL連線企業三層資訊系統
- session會話的底層實現方式Session會話
- 使用TensorFlow實現手寫識別(Softmax)
- Vue 高德地圖 API Loca 如何使用 連線線圖層、脈衝連線圖層Vue地圖API
- MVC專案實踐,在三層架構下實現SportsStore-02,DbSession層、BLL層MVC架構Session
- 二層交換機 三層交換機 四層交換機的區別
- 二層、三層交換機和四層交換機的區別(轉)
- 程式碼實現(機器學習識別手寫數字)機器學習
- 機器學習之神經網路識別手寫數字(純python實現)機器學習神經網路Python
- 二層交換機和三層交換機的區別
- MVC專案實踐,在三層架構下實現SportsStore-10,連線字串的加密和解密MVC架構字串加密解密
- Tensorflow實現RNN(LSTM)手寫數字識別RNN
- 機器學習:scikit-learn實現手寫數字識別機器學習
- OpenCV + sklearnSVM 實現手寫數字分割和識別OpenCV
- MVC專案實踐,在三層架構下實現SportsStore,從類圖看三層架構MVC架構
- 深度學習實驗:Softmax實現手寫數字識別深度學習
- 三層網路結構(核心層、匯聚層 、接入層)
- Redis(三)--- Redis的五大資料型別的底層實現Redis大資料資料型別
- 全連線神經網路的原理及Python實現神經網路Python
- python建立多層目錄的方式Python
- 自己動手寫一個持久層框架框架
- MVC與三層架構區別MVC架構
- 三層架構及分層架構
- Minya 分層框架實現的思考(三):問題框架
- 關於三層架構中各層次的關係與實現模型 (轉)架構模型
- 字母排列(python實現)Python
- 【opencv3】 svm實現手寫體與人臉識別OpenCV
- 用tensorflow2實現mnist手寫數字識別
- Pytorch搭建MyNet實現MNIST手寫數字識別PyTorch
- 網路知識梳理--OSI七層網路與TCP/IP五層網路架構及二層/三層網路TCP架構
- Python垃圾回收(GC)三層心法,你瞭解到第幾層?PythonGC