[原始碼解析]深度學習利器之自動微分(3) --- 示例解讀
0x00 摘要
本文從 PyTorch 兩篇官方文件開始為大家解讀兩個示例。本文不會逐句翻譯,而是選取重點並且試圖加入自己的理解。
我們在前兩篇文章學習了自動微分的基本概念,從本文開始,我們繼續分析 PyTorch 如何實現自動微分。因為涉及內容太多太複雜,所以計劃使用 2~3篇來介紹前向傳播如何實現,用 3 ~ 4 篇來介紹後向傳播如何實現。
系列前兩篇連線如下:
0x01 概述
在訓練神經網路時,最常用的演算法是 反向傳播。在該演算法中根據損失函式相對於給定引數的梯度來對引數(模型權重)進行調整。為了計算這些梯度,PyTorch 實現了一個名為 torch.autograd
的內建反向自動微分引擎。它支援任何計算圖的梯度自動計算。
1.1 編碼歷史
從概念上講,autograd 記錄了一個計算圖。在建立張量時,如果設定 requires_grad 為Ture,那麼 Pytorch 就知道需要對該張量進行自動求導。於是PyTorch會記錄對該張量的每一步操作歷史,從而生成一個概念上的有向無環圖,該無環圖的葉子節點是模型的輸入張量,其根為模型的輸出張量。使用者不需要對圖的所有執行路徑進行編碼,因為使用者執行的就是使用者後來想微分的。通過從根到葉跟蹤此圖形,使用者可以使用鏈式求導規則來自動計算梯度。
在內部實現上看,autograd 將此圖表示為一個“Function” 或者說是"Node" 物件(真正的表示式)的圖,該圖可以使用apply方法來進行求值。
1.2 如何應用
在前向傳播計算時,autograd做如下操作:
- 執行請求的操作以計算結果張量。
- 建立一個計算梯度的DAG圖,在DAG圖中維護所有已執行操作(包括操作的梯度函式以及由此產生的新張量)的記錄 。每個tensor梯度計算的具體方法存放於tensor節點的grad_fn屬性中。
當向前傳播完成之後,我們通過在在 DAG 根上呼叫.backward()
來執行後向傳播,autograd會做如下操作:
- 利用
.grad_fn
計算每個張量的梯度,並且依據此構建出包含梯度計算方法的反向傳播計算圖。 - 將梯度累積在各自的張量
.grad
屬性中,並且使用鏈式法則,一直傳播到葉張量。 - 每次迭代都會重新建立計算圖,這使得我們可以使用Python程式碼在每次迭代中更改計算圖的形狀和大小。
需要注意是,PyTorch 中 的DAG 是動態的,每次 .backward()
呼叫後,autograd 開始填充新計算圖,該圖是從頭開始重新建立。這使得我們可以使用Python程式碼在每次迭代中更改計算圖的形狀和大小。
0x02 示例
下面我們通過兩個例子來進行解讀,之所以使用兩個例子,因為均來自於PyTorch 官方文件。
2.2 例項解讀 1
我們首先使用 https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html 來進行演示和解讀。
2.2.1 程式碼
示例程式碼如下:
import torch
a = torch.tensor(2., requires_grad=True)
b = torch.tensor(6., requires_grad=True)
O = 3*a**3
P = b**2
Q = O - P
external_grad = torch.tensor(1.)
Q.backward(gradient=external_grad)
print(a.grad)
print(b.grad)
print("=========== grad")
a = torch.tensor(2., requires_grad=True)
b = torch.tensor(6., requires_grad=True)
Q = 3*a**3 - b**2
grads = torch.autograd.grad(Q, [a, b])
print(grads[0])
print(grads[1])
print(Q.grad_fn.next_functions)
print(O.grad_fn.next_functions)
print(P.grad_fn.next_functions)
print(a.grad_fn)
print(b.grad_fn)
輸出為:
tensor(36.)
tensor(-12.)
=========== grad
tensor(36.)
tensor(-12.)
((<MulBackward0 object at 0x000001374DE6C308>, 0), (<PowBackward0 object at 0x000001374DE6C288>, 0))
((<PowBackward0 object at 0x000001374DE6C288>, 0), (None, 0))
((<AccumulateGrad object at 0x000001374DE6C6C8>, 0),)
None
None
這裡的Q運算方式如下:
因此Q對a, b 的求導如下:
2.2.2 分析
動態圖是在前向傳播的時候建立。前向傳播時候,Q是最終的輸出,但是在反向傳播的時候,Q 卻是計算的最初輸入,就是反向傳播圖的Root。
示例中,對應的張量是:
- a 是 2,b 是 6, Q 是
tensor(-12., grad_fn=<SubBackward0>)
。
對應的積分是:
- Q對於 a 的積分是:\(\frac{∂Q}{∂a} = 9a^2\) = 36。
- Q對於b的積分是 \(\frac{∂Q}{∂b} = -2b\) = -12。
當我們呼叫.backward()
時,backward()
只是通過將其引數傳遞給已經生成的反向圖來計算梯度。autograd 計算這些梯度並將它們儲存在各自的張量.grad
屬性中。
我們需要顯式地給Q.backward()
傳入一個gradient
引數,因為它是一個向量。 gradient
是與 形狀相同的張量Q
,它表示 Q 本身的梯度,即
等效地,我們也可以將 Q 聚合為一個標量並隱式地向後呼叫,例如Q.sum().backward()
。
external_grad = torch.tensor([1., 1.])
Q.backward(gradient=external_grad)
下面是我們示例中 DAG 的視覺化表示。在圖中,箭頭指向前向傳遞的方向。節點代表前向傳遞中每個操作的後向函式。藍色的葉子節點代表我們的葉子張量a
和b
。
2.3 例項解讀 2
這次以https://pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html為例子說明。
2.3.1 示例程式碼
考慮最簡單的一層神經網路,具有輸入x
、引數w
和b
,以及一些損失函式。它可以通過以下方式在 PyTorch 中定義:
import torch
x = torch.ones(5) # input tensor
y = torch.zeros(3) # expected output
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w)+b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
2.3.2 張量、函式和計算圖
上述程式碼定義了以下計算圖:
圖片來源是:https://pytorch.org/tutorials/_images/comp-graph.png
在這個網路中,w
和b
是我們需要優化的引數。因此,我們需要計算關於這些變數的損失函式的梯度。為了做到這一點,我們設定了這些張量的requires_grad
屬性。
注意,您可以在建立張量時設定requires_grad
的值,也可以稍後使用x.requires_grad_(True)
方法設定。
我們應用於張量來構建計算圖的函式實際上是一個Function
類的物件。該物件知道如何在前向計算函式,以及如何在反向傳播步驟中計算其導數。對反向傳播函式的引用儲存在grad_fn
張量的屬性中。
print('Gradient function for z =', z.grad_fn)
print('Gradient function for loss =', loss.grad_fn)
輸出如下:
Gradient function for z = <AddBackward0 object at 0x7f4dbd4d3080>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward object at 0x7f4dbd4d3080>
2.3.3 計算梯度
為了優化神經網路中引數的權重,我們需要計算損失函式關於引數的導數,即我們需要在限定一些 x
和y
時候得到 $ \frac{\partial loss}{\partial w}$ 和 $\frac{\partial loss}{\partial b} $ 。為了計算這些導數,我們呼叫 loss.backward()
,然後從w.grad
和 b.grad
之中獲得數值:
loss.backward()
print(w.grad)
print(b.grad)
得出:
tensor([[0.1881, 0.1876, 0.0229],
[0.1881, 0.1876, 0.0229],
[0.1881, 0.1876, 0.0229],
[0.1881, 0.1876, 0.0229],
[0.1881, 0.1876, 0.0229]])
tensor([0.1881, 0.1876, 0.0229])
注意
- 我們只能獲取在計算圖葉子節點的
requires_grad
屬性設定為True
時候得到該節點的grad
屬性。我們沒法得到們圖中的所有其他節點的梯度。 - 出於效能原因,我們只能在給定的計算圖之上使用
backward
執行一次梯度計算 。如果我們需要在同一個圖上多次呼叫backward
,則需要在backward
呼叫時候設定retain_graph=True
。
2.3.4 禁用梯度跟蹤
預設情況下,所有設定requires_grad=True
的張量都會跟蹤其計算曆史並支援梯度計算。但是,有些情況下我們不需要這樣做,例如,當我們已經訓練了模型並且只想將其應用於某些輸入資料時,即我們只想通過網路進行前向計算,這時候我們可以通過用torch.no_grad()
塊包圍我們的計算程式碼以停止跟蹤計算 :
z = torch.matmul(x, w)+b
print(z.requires_grad)
with torch.no_grad():
z = torch.matmul(x, w)+b
print(z.requires_grad)
輸出:
True
False
實現相同結果的另一種方法是在張量上使用detach()
方法:
z = torch.matmul(x, w)+b
z_det = z.detach()
print(z_det.requires_grad)
輸出:
False
您可能想要禁用梯度跟蹤的原因有:
- 將神經網路中的某些引數標記為凍結引數。這是微調預訓練網路的一個非常常見的場景。
- 在僅進行前向傳遞時加快計算速度,因為對不跟蹤梯度的張量進行計算會更有效。
0x03 邏輯關係
如果從計算圖角度來看前向計算的過程,就是在構建圖和執行圖。"構建圖"描述的是節點運算之間的關係。"執行圖"則是在會話中執行這個運算關係,就是張量在計算圖之中進行前向傳播的過程。
前向計算依賴一些基礎類,在具體分析前向傳播之前,我們先要看看這些基礎類之間的邏輯關係。從DAG角度來分析 PyTorch 這個系統,其具體邏輯如下。
- 圖表示計算任務。PyTorch把計算都當作是一種有向無環圖,或者說是計算圖,但這是一種虛擬的圖,程式碼中沒有真實的資料結構。
- 計算圖由節點(Node)和邊(Edge)組成。
- 節點(Node)代表了運算操作。
- 一個節點通過邊來獲得 0 個或多個
Tensor
,節點執行計算之後會產生 0 個或多個Tensor
。 - 節點的成員變數 next_functions 是一個 tuple 列表,此列表就代表本節點要輸出到哪些其他 Function。列表個數就是這個 grad_fn 的 Edge 數目,列表之中每一個 tuple 對應一條 Edge 資訊,內容就是 (Edge.function, Edge.input_nr)。
- 一個節點通過邊來獲得 0 個或多個
- 邊(Edge)就是運算操作之間的流向關係。
- Edge.function :表示此 Edge 需要輸出到哪一個其他 Function。
- Edge.input_nr :指定本 Edge 是 Function 的第幾個輸入。
- 使用張量( Tensor) 表示資料,就是在節點間流動的資料,如果沒有資料,計算圖就沒有任何意義。
具體可以參見下圖:
+---------------------+ +----------------------+
| SubBackward0 | | PowBackward0 |
| | Edge | | Edge
| next_functions +-----+--------> | next_functions +----------> ...
| | | | |
+---------------------+ | +----------------------+
|
|
| +----------------------+
| Edge | MulBackward0 |
+--------> | | Edge
| next_functions +----------> ...
| |
+----------------------+
至此,示例解析結束,我們下一篇介紹PyTorch 微分引擎相關的一些基礎類。
0xFF 參考
https://github.com/KeithYin/read-pytorch-source-code/
pytorch學習筆記(十三):backward過程的底層實現解析
How autograd encodes the history
https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html