Pytorch學習(七)---- 儲存提取
莫煩python視訊學習筆記 視訊連結https://www.bilibili.com/video/BV1Vx411j7kT?from=search&seid=3065687802317837578
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# import os 這裡是為了防止報錯加的
torch.manual_seed(1) # reproducible
# fake data
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = x.pow(2) + 0.2*torch.rand(x.size())
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False)
# 神經網路的儲存
def save():
# save net1
net1 = torch.nn.Sequential( # Sequential的功能在這個括號裡逐層搭建神經層
torch.nn.Linear(1, 10),
torch.nn.ReLU(), # 激勵函式
torch.nn.Linear(10, 1)
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.25) # 傳入引數,lr是學習效率
loss_func = torch.nn.MSELoss()
for t in range(100):
prediction = net1(x)
loss = loss_func(prediction, y) # 預測值與真實值對比
optimizer.zero_grad() # 將梯度降為零
loss.backward()
optimizer.step() # 以學習效率0.5優化梯度
# plot result
plt.figure(1, figsize=(10, 3))
plt.subplot(131)
plt.title('Net1')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
torch.save(net1, 'net.pkl') # entire all
torch.save(net1.state_dict(), 'net_params.pkl') # 儲存引數
# 神經網路的提取
# method 1
def restore_net():
net2 = torch.load('net.pkl')
prediction = net2(x)
# plot result
plt.figure(1, figsize=(10, 3))
plt.subplot(132)
plt.title('Net2')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
# method 2
def restore_params():
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x)
# plot result
plt.figure(1, figsize=(10, 3))
plt.subplot(133)
plt.title('Net3')
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.show()
# save net1
save()
# restore entire net
restore_net()
# restore params
restore_params()
一開始將訓練的學習效率設定為0.5,即:
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) # 傳入引數,lr是學習效率
結果影像如下,很明顯學習效果是失敗的。
重新設定學習效率,發現在0.25~0.3學習效果最好
相關文章
- [PyTorch 學習筆記] 7.1 模型儲存與載入PyTorch筆記模型
- 儲存學習
- TensorFlow模型儲存和提取方法模型
- 零基礎學習人工智慧—Python—Pytorch學習(七)人工智慧PythonPyTorch
- 七牛雲物件儲存物件
- 【小白學PyTorch】19 TF2模型的儲存與載入PyTorchTF2模型
- 七牛儲存使用筆記筆記
- ThinkPHP之七牛雲儲存PHP
- 學習筆記14:模型儲存筆記模型
- Pytorch | Tutorial-07 儲存和載入模型PyTorch模型
- 快速使用七牛雲物件儲存物件
- Mysql 5.7儲存過程的學習MySql儲存過程
- Spark 儲存模組原始碼學習Spark原始碼
- GlusterFS分散式儲存學習筆記分散式筆記
- 機器學習7-模型儲存&無監督學習機器學習模型
- 學習OpenCV:骨架提取OpenCV
- 圖片儲存-從七牛到 GithubGithub
- 日常學習儲存--陣列和指標陣列指標
- 樹的學習——樹的儲存結構
- 重新學習Mysql資料庫3:Mysql儲存引擎與資料儲存原理MySql資料庫儲存引擎
- 【小白學PyTorch】6 模型的構建訪問遍歷儲存(附程式碼)PyTorch模型
- pytorch學習筆記PyTorch筆記
- PyTorch 學習筆記PyTorch筆記
- Pytorch_第七篇_深度學習 (DeepLearning) 基礎 [3]---梯度下降PyTorch深度學習梯度
- 深度學習框架Pytorch學習筆記深度學習框架PyTorch筆記
- Flutter學習指南:檔案、儲存和網路Flutter
- docker學習系列2儲存對容器的修改Docker
- etcd學習(9)-etcd中的儲存實現
- 七、函式-儲存過程-觸發器函式儲存過程觸發器
- 機器學習-特徵提取機器學習特徵
- 批次提取畫素差異並儲存二進位制
- 《PyTorch》Part5 PyTorch之遷移學習PyTorch遷移學習
- (pytorch-深度學習系列)pytorch資料操作PyTorch深度學習
- 【Pytorch教程】迅速入門Pytorch深度學習框架PyTorch深度學習框架
- 通過示例學習PYTORCHPyTorch
- Pytorch模型檔案`*.pt`與`*.pth` 的儲存與載入PyTorch模型
- 全面解析Pytorch框架下模型儲存,載入以及凍結PyTorch框架模型
- pytorch入門(七):unsqueezePyTorch