Pytorch中backward()的思考記錄
在pytorch中,只能對標量使用backward。如果對向量進行backward,則會報錯:
import torch
x=torch.tensor([2,3,4],dtype=torch.float,requires_grad=True)
print(x)
y=x*2
print(y)
# z=y.mean()
# z.backward()
y.backward()
print(x.requires_grad)
print(x.grad)
以上程式碼執行:RuntimeError: grad can be implicitly created only for scalar outputs
但如果我們將y.backward()
改為y.backward(torch.Tensor([1,1,1]))
,則可以正確輸出x的梯度。
結合官方文件和知乎部落格https://zhuanlan.zhihu.com/p/27808095,進行以下直觀解釋:
我們有計算序列 m = ( x 1 = 2 , x 2 = 3 ) , k = ( x 1 2 + 3 x 2 , x 2 2 + 2 x 1 ) m=(x_1=2,x_2=3),k=(x_1^2+3x_2,x_2^2+2x_1) m=(x1=2,x2=3),k=(x12+3x2,x22+2x1)。為了方便表示,我們令 y 1 = x 1 2 + 3 x 2 , y 2 = x 2 2 + 2 x 1 y_1=x_1^2+3x_2,y_2=x_2^2+2x_1 y1=x12+3x2,y2=x22+2x1則 k = ( y 1 , y 2 ) k=(y_1,y_2) k=(y1,y2)。我們在進行神經網路計算時,通常想要得到的梯度是 ∂ y 1 ∂ x 1 , ∂ y 1 ∂ x 2 , ∂ y 2 ∂ x 1 , ∂ y 2 ∂ x 2 \frac{\partial y_1}{\partial x_1},\frac{\partial y_1}{\partial x_2},\frac{\partial y_2}{\partial x_1},\frac{\partial y_2}{\partial x_2} ∂x1∂y1,∂x2∂y1,∂x1∂y2,∂x2∂y2,並以此進行反向傳播並優化。
但此時,我們的輸出
k
k
k是一個向量,我們直接對
k
k
k進行反向傳播:
∂
k
∂
x
1
=
∂
k
∂
y
1
⋅
∂
y
1
∂
x
1
+
∂
k
∂
y
2
⋅
∂
y
2
∂
x
1
∂
k
∂
x
2
=
∂
k
∂
y
1
⋅
∂
y
1
∂
x
2
+
∂
k
∂
y
2
⋅
∂
y
2
∂
x
2
\frac{\partial k}{\partial x_1}=\frac{\partial k}{\partial y_1}\cdot \frac{\partial y_1}{\partial x_1}+\frac{\partial k}{\partial y_2} \cdot \frac{\partial y_2}{\partial x_1}\\ \frac{\partial k}{\partial x_2}=\frac{\partial k}{\partial y_1}\cdot \frac{\partial y_1}{\partial x_2}+\frac{\partial k}{\partial y_2} \cdot \frac{\partial y_2}{\partial x_2}
∂x1∂k=∂y1∂k⋅∂x1∂y1+∂y2∂k⋅∂x1∂y2∂x2∂k=∂y1∂k⋅∂x2∂y1+∂y2∂k⋅∂x2∂y2
我們可以看到,上述直接對
k
k
k進行反向傳播得到的結果並不是我們想要的。
寫成矩陣運算形式:
[
∂
k
∂
x
1
∂
k
∂
x
2
]
=
[
∂
y
1
∂
x
1
∂
y
2
∂
x
1
∂
y
1
∂
x
2
∂
y
2
∂
x
2
]
[
∂
k
∂
y
1
∂
k
∂
y
2
]
\left[\begin{matrix}\frac{\partial k}{\partial x_1}\\\frac{\partial k}{\partial x_2} \end{matrix}\right]=\left[ \begin{matrix}\frac{\partial y_1}{\partial x_1}& \frac{\partial y_2}{\partial x_1}\\\frac{\partial y_1}{\partial x_2}& \frac{\partial y_2}{\partial x_2} \end{matrix} \right] \left[\begin{matrix}\frac{\partial k}{\partial y_1}\\\frac{\partial k}{\partial y_2} \end{matrix}\right]
[∂x1∂k∂x2∂k]=[∂x1∂y1∂x2∂y1∂x1∂y2∂x2∂y2][∂y1∂k∂y2∂k]
可以看到,矩陣運算形式中前面
2
×
2
2\times2
2×2的矩陣中的結果正是我們想要的。我們稱之為雅可比(Jacobian)矩陣(嚴格地說,是雅可比矩陣的轉置)。
那麼,當我們令 [ ∂ k ∂ y 1 ∂ k ∂ y 2 ] = [ 1 0 ] \left[\begin{matrix}\frac{\partial k}{\partial y_1}\\\frac{\partial k}{\partial y_2} \end{matrix}\right]=\left[\begin{matrix}1\\0 \end{matrix}\right] [∂y1∂k∂y2∂k]=[10]時,即可得到雅可比矩陣中的第一列,令 [ ∂ k ∂ y 1 ∂ k ∂ y 2 ] = [ 0 1 ] \left[\begin{matrix}\frac{\partial k}{\partial y_1}\\\frac{\partial k}{\partial y_2} \end{matrix}\right]=\left[\begin{matrix}0\\1 \end{matrix}\right] [∂y1∂k∂y2∂k]=[01]時,即可得到雅可比矩陣的第二列。以此我們就可以得到想要的梯度值。
這正是pytorch官方文件中所說的思想。當我們對向量進行反向傳播時,通過在backward()
中新增一個向量
v
v
v,就可以分別得到原向量中每一項乘向量
v
v
v中係數後對應的梯度值。
一些需要注意的點:
- 在backward()中新增的向量 v v v的size要與進行反向傳播的向量的size相同。
- 當我們需要分別求反向傳播向量中每一項的梯度時,可能需要分多次分別進行求導。此時我們要注意
backward()
裡面另外的一個引數retain_variables=True
,這個引數預設是False,也就是反向傳播之後這個計算圖的記憶體會被釋放,這樣就沒辦法進行第二次反向傳播了,所以我們需要設定為True,因為這裡我們需要進行兩次反向傳播求得Jacobian矩陣。
相關文章
- 對pytroch中torch.autograd.backward的思考
- win10中pyTorch的GPU模式安裝記錄Win10PyTorchGPU模式
- MVVM的學習記錄和思考MVVM
- 對查詢資料庫中第M到N條記錄的思考資料庫
- Windows10安裝Pytorch步驟記錄WindowsPyTorch
- Pytorch:使用Tensorboard記錄訓練狀態PyTorchORB
- Anaconda Pytorch 深度學習入門記錄PyTorch深度學習
- 再聊對架構決策記錄的一些思考架構
- 輕量化模型訓練加速的思考(Pytorch實現)模型PyTorch
- DNS中MX記錄的理解DNS
- JS中的Promise 物件記錄JSPromise物件
- pytorch中forward的理解PyTorchForward
- pytorch中中的模型剪枝方法PyTorch模型
- 一次遷移思考的記錄--bulk_collect的limit用法MIT
- SQL中的單記錄函式SQL函式
- DNS中MX記錄的理解(轉)DNS
- pytorch 方法筆記PyTorch筆記
- Pytorch筆記(一)PyTorch筆記
- 機器學習第一天--pytorch環境配置踩坑記錄(一)機器學習PyTorch
- Vue中的點點滴滴在此記錄Vue
- 找回Oracle中Delete刪除的記錄Oracledelete
- 客戶主記錄中的資料
- 零基礎入門深度學習-dive in to pytorch 的程式碼報錯記錄深度學習PyTorch
- Pytorch中的損失函式PyTorch函式
- 轉:Pytorch中的register_buffer()PyTorch
- 域名解析的記錄型別:A記錄、CNAME、MX記錄、NS記錄型別
- pytorch學習筆記PyTorch筆記
- PyTorch 學習筆記PyTorch筆記
- 記錄下學習筆記(Laravel 中的事件監聽)筆記Laravel事件
- awk 中的欄位、記錄和變數變數
- 如何優雅的在flask中記錄logFlask
- Oracle遊標遍歷%rowtype中的記錄Oracle
- 查mysql欄位中的數字記錄MySql
- MySQL 查詢所有表中的記錄數MySql
- 如何刪除oracle庫中相同的記錄Oracle
- dns中soa和ns記錄的作用(轉)DNS
- Oracle中取固定記錄數的方法薦Oracle
- 記錄一個Excel中特殊的VLOOKUP方法Excel