CUDA教學(2):反向傳播

7hu95b發表於2024-05-28

cuda 沒有提供自動求導機制,因此我們需要完成以下兩步,實現反向傳播。

image

一、計算所有 trainable 引數的偏微分

判斷哪些引數是 trainable 的?

本例中,輸入 f 的座標是固定的,所以 uvw 的值也是固定的,因此只需要求 feats_interp對各個頂點的特徵量 \(f_i\) 的偏微分即可。

如何進行反向傳播?

思路:先計算正向傳播的 Loss 值,然後對各個頂點的特徵量 \(f_i\) 求偏微分。

image

也就是利用了“鏈式法則”進行計算,但是必須要知道 Loss 對於 feats_interp 的偏導結果,因此我們自己實現反向傳播函式需要傳入這個引數。

二、程式碼實現

實現反向傳播函式的注意事項?

(1)求偏導得到的維度和輸入維度保持一致,因此 dL_dfeats 的維度是 [N,8,F]dL_dfeats_interp 的維度是 [N,F]

(2)需要把前向傳播和反向傳播函式都包裹在新類中,它是 torch.autograd.Function 的子類,如下面程式碼所示:

class Trilinear_interpolation_cuda(torch.autograd.Function):
    @staticmethod
    def forward(ctx, feats, points):# ctx 即 context 儲存了傳播過程中的狀態量
        feat_interp = cppcuda_tutorial.trilinear_interpolation_fw(feats, points)

        ctx.save_for_backward(feats, points)

        return feat_interp

    @staticmethod
    def backward(ctx, dL_dfeat_interp):
        feats, points = ctx.saved_tensors

        dL_dfeats = cppcuda_tutorial.trilinear_interpolation_bw(dL_dfeat_interp.contiguous(), feats, points)

        # forward 輸入有幾個引數,這裡就要回傳幾個引數,如果沒有則寫 None
        # 我們這裡求的是對 feats 的偏導,所以寫在第一個位置
        return dL_dfeats, None

最後的返回值和前向傳播的引數保持一致,如果不需要求偏導則需要對應寫為 None。

(3)傳入的引數都要新增求導支援。

N = 65536; F = 256
rand = torch.rand(N, 8, F, device='cuda')
feats = rand.clone().requires_grad_()

(4)呼叫:前向傳播需要使用 apply 方法,後向傳播直接呼叫:

# 1. 前向傳播
out_cuda = Trilinear_interpolation_cuda.apply(feats, points)

# 2. 計算 Loss(只是簡單的加起來作為損失)
loss = out_cuda.sum()

# 2. Pytorch 會自動計算 dL_dfeat_interp,也就是 Loss 關於 out_cuda 的梯度,傳遞給 backward 函式作為引數
loss.backward()

相關文章