cuda 沒有提供自動求導機制,因此我們需要完成以下兩步,實現反向傳播。
一、計算所有 trainable 引數的偏微分
判斷哪些引數是 trainable 的?
本例中,輸入 f 的座標是固定的,所以 uvw 的值也是固定的,因此只需要求 feats_interp
對各個頂點的特徵量 \(f_i\) 的偏微分即可。
如何進行反向傳播?
思路:先計算正向傳播的 Loss 值,然後對各個頂點的特徵量 \(f_i\) 求偏微分。
也就是利用了“鏈式法則”進行計算,但是必須要知道 Loss 對於 feats_interp
的偏導結果,因此我們自己實現反向傳播函式需要傳入這個引數。
二、程式碼實現
實現反向傳播函式的注意事項?
(1)求偏導得到的維度和輸入維度保持一致,因此 dL_dfeats
的維度是 [N,8,F]
,dL_dfeats_interp
的維度是 [N,F]
。
(2)需要把前向傳播和反向傳播函式都包裹在新類中,它是 torch.autograd.Function
的子類,如下面程式碼所示:
class Trilinear_interpolation_cuda(torch.autograd.Function):
@staticmethod
def forward(ctx, feats, points):# ctx 即 context 儲存了傳播過程中的狀態量
feat_interp = cppcuda_tutorial.trilinear_interpolation_fw(feats, points)
ctx.save_for_backward(feats, points)
return feat_interp
@staticmethod
def backward(ctx, dL_dfeat_interp):
feats, points = ctx.saved_tensors
dL_dfeats = cppcuda_tutorial.trilinear_interpolation_bw(dL_dfeat_interp.contiguous(), feats, points)
# forward 輸入有幾個引數,這裡就要回傳幾個引數,如果沒有則寫 None
# 我們這裡求的是對 feats 的偏導,所以寫在第一個位置
return dL_dfeats, None
最後的返回值和前向傳播的引數保持一致,如果不需要求偏導則需要對應寫為 None。
(3)傳入的引數都要新增求導支援。
N = 65536; F = 256
rand = torch.rand(N, 8, F, device='cuda')
feats = rand.clone().requires_grad_()
(4)呼叫:前向傳播需要使用 apply
方法,後向傳播直接呼叫:
# 1. 前向傳播
out_cuda = Trilinear_interpolation_cuda.apply(feats, points)
# 2. 計算 Loss(只是簡單的加起來作為損失)
loss = out_cuda.sum()
# 2. Pytorch 會自動計算 dL_dfeat_interp,也就是 Loss 關於 out_cuda 的梯度,傳遞給 backward 函式作為引數
loss.backward()